In [46]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import PIL
import glob
import shutil
import random
from tqdm import tqdm
from pandas import read_excel
import torch.nn.functional as F
# Record the library versions this notebook was executed with.
print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

PyTorch Version:  1.5.0
Torchvision Version:  0.6.0


In [144]:
# Open the training image folder; it is only used below to recover the class names.
dataset = datasets.ImageFolder('/hddraid5/data/colin/cell_classification/data/train/')

In [148]:
# Class names of the training dataset; later cells index into this list with
# the class indices predicted by the model.
cell_types = dataset.classes

In [38]:
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of ``model`` when ``feature_extracting`` is truthy.

    Args:
        model: Object exposing ``parameters()`` (e.g. an ``nn.Module``).
        feature_extracting: When falsy the model is left untouched; when
            truthy ``requires_grad`` is set to ``False`` on each parameter.

    Returns:
        None. The model is modified in place.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad = False

In [39]:
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    """Build a torchvision classifier whose final layer emits ``num_classes`` outputs.

    Args:
        model_name: One of "resnet", "alexnet", "vgg", "squeezenet",
            "densenet" or "inception".
        num_classes: Output size of the replacement classification layer.
        feature_extract: Passed to ``set_parameter_requires_grad``. When truthy,
            all parameters of the freshly constructed network are frozen
            *before* the head is replaced, so the new head stays trainable.
        use_pretrained: Forwarded as ``pretrained=`` to the torchvision
            model constructor.

    Returns:
        Tuple ``(model_ft, input_size)`` where ``input_size`` is the image
        side length associated with the architecture (299 for inception,
        224 for all others).

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
            (Previously this printed a message and called ``exit()``, which
            terminates the interpreter / notebook kernel.)
    """
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        # Resnet18: the head is the `fc` linear layer.
        model_ft = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "alexnet":
        # Alexnet: the head is the last entry (index 6) of `classifier`.
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "vgg":
        # VGG11_bn: same classifier layout as Alexnet.
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        # Squeezenet 1.0: the head is a 1x1 convolution over 512 channels.
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        # Densenet121: the head is the single `classifier` linear layer.
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "inception":
        # Inception v3: expects (299, 299) images and has an auxiliary output,
        # so both the auxiliary and the primary heads are replaced.
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # Handle the auxiliary net
        num_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        # Handle the primary net
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 299

    else:
        # Raise instead of exit(): killing the kernel loses all notebook state
        # and cannot be handled by the caller.
        raise ValueError(
            "Invalid model name {!r}; expected one of: resnet, alexnet, vgg, "
            "squeezenet, densenet, inception".format(model_name)
        )

    return model_ft, input_size

In [208]:
# Rebuild the densenet architecture with a 9-output head and restore the
# weights stored in 'wbc_dense.pt'.
# use_pretrained=False: load_state_dict (strict by default) overwrites every
# parameter and buffer, so downloading the pretrained weights first is wasted work.
model_ft, input_size = initialize_model('densenet', 9, True, use_pretrained=False)
# map_location='cpu' lets the checkpoint load on a machine without CUDA even
# if it was saved from GPU tensors; the model is run on CPU below.
model_ft.load_state_dict(torch.load('wbc_dense.pt', map_location='cpu'))
# Inference mode: fixes dropout / batch-norm behaviour.
model_ft.eval()
# Upper bound, in bytes, on the size of image files accepted when the image
# database is compiled in a later cell.
cutoff = 800000

In [234]:
input_size = 224
# Preprocessing applied to every image before classification.
data_transforms = transforms.Compose([
    # Resize, then crop the centre to an input_size x input_size square.
    transforms.Resize(size=input_size),
    transforms.CenterCrop(size=input_size),
    # Random colour perturbation: predictions can therefore differ between runs.
    transforms.ColorJitter(hue=0.20, saturation=0.20, contrast=0.20, brightness=0.10),
    transforms.ToTensor(),
    # Per-channel normalisation with fixed mean / std values.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [210]:
# Collect order numbers and COVID test results from every spreadsheet that
# sits directly under base_path. The two lists stay index-aligned.
base_path = '/hddraid5/data/colin/covid-data/'
label_files = glob.glob(os.path.join(base_path, '*.xlsx'))
orders, test_results = [], []
for label_file in label_files:
    table = read_excel(label_file)
    orders += list(table['Order #'])
    test_results += list(table['Covid Test result'])

In [211]:
# lets compile a DB
# all_images maps each order number (as a string) to the list of .jpg files
# found for it under 'COVID Research Images'.
all_images = {}
for order, test_result in zip(orders, test_results):
    try:
        # Non-string results (for instance a float NaN) have no .lower() and
        # raise AttributeError, which skips the row. `label` itself is not
        # stored by this cell.
        label = 'positive' in test_result.lower()
        # Skip rows whose order number is not numeric. The builtin int()
        # replaces np.int, which no longer exists in current NumPy: there the
        # resulting AttributeError was swallowed below and *every* row was
        # silently skipped.
        int(order)
    except (TypeError, ValueError, OverflowError, AttributeError):
        continue
    all_image_paths = glob.glob(
        os.path.join(base_path, 'COVID Research Images', '**', str(order), '*.jpg'),
        recursive=True,
    )
    image_paths = []
    for image_path in all_image_paths:
        # Stat each file once; keep only files strictly between 100 bytes and
        # `cutoff` bytes.
        size = os.path.getsize(image_path)
        if 100 < size < cutoff:
            image_paths.append(image_path)
    all_images[str(order)] = image_paths

In [None]:
# Classify every image of every patient in mini-batches.
# all_cell_classes maps patient -> list of predicted class indices, in the
# same order as all_images[patient].
batch_size = 8
all_cell_classes = {}
# Inference only: disabling autograd avoids building a graph for every batch.
with torch.no_grad():
    for patient, image_paths in all_images.items():
        # Every patient gets an entry, even when there are no images, so
        # later lookups do not fail with KeyError.
        patient_classes = []
        for i in range(0, len(image_paths), batch_size):
            # Load and transform one batch at a time instead of holding all of
            # a patient's images in memory, and close each file after use.
            batch_images = []
            for image_path in image_paths[i:i + batch_size]:
                with PIL.Image.open(image_path) as image:
                    # Force 3 channels: Normalize is configured for 3-channel
                    # input and would fail on grayscale / CMYK JPEGs.
                    batch_images.append(data_transforms(image.convert('RGB')))
            tensors = torch.stack(batch_images)
            class_probs = F.softmax(model_ft(tensors), dim=-1)
            cell_classes = torch.argmax(class_probs, dim=-1).tolist()
            patient_classes.extend(cell_classes)
        all_cell_classes[patient] = patient_classes

In [None]:
# Gather the images and predicted labels of one patient for manual inspection.
# NOTE: `patient` is not assigned in this cell; it must already be bound by an
# earlier cell (e.g. as a leftover loop variable).
image_paths = all_images[patient]
images = list(map(PIL.Image.open, image_paths))
cell_indices = all_cell_classes[patient]
# Translate predicted class indices into class names.
cell_labels = np.array(cell_types)[cell_indices]

In [None]:
# Spot-check a single prediction: print the predicted label, then show the
# corresponding image as the cell's output.
# Requires the current patient to have more than `index` images.
index = 100
print(cell_labels[index])
images[index]    

In [202]:
import csv

In [205]:
# Ad-hoc inspection: display the image paths recorded for whichever `patient`
# is currently bound.
all_images[patient]

[]

In [220]:
import json

In [None]:
# Export predictions as JSON: {patient: {image filename: predicted cell type}}.
patient_data = {}
for patient in all_images.keys():
    # Patients without predictions get an empty mapping.
    predicted_classes = all_cell_classes.get(patient, [])
    patient_data[patient] = {
        # Map the class index to its name via cell_types. The previous
        # cell_labels[cell_type] indexed the per-image label array of a single
        # patient with a class index, yielding unrelated labels.
        os.path.basename(image_path): cell_types[cell_type]
        for image_path, cell_type in zip(all_images[patient], predicted_classes)
    }
with open('wbc_classes_v1_jitter.json', 'w') as fp:
    json.dump(patient_data, fp)

In [None]:
# Export predictions as CSV, one row per classified image.
# Fixes: the file handle is now the name the writer uses (`csvfile` was
# undefined); csv.writer needs a text-mode file opened with newline='' in
# Python 3 ('wb' fails); and the output goes to a .csv path so it no longer
# truncates a JSON file.
with open("wbc_classes_v1.csv", 'w', newline='') as csvfile:
    writer = csv.writer(csvfile, delimiter=',')
    writer.writerow(['Patient ID', 'filename', 'Cell Type'])
    for patient in all_images.keys():
        # Patients without predictions contribute no rows.
        predicted_classes = all_cell_classes.get(patient, [])
        for image_path, cell_type in zip(all_images[patient], predicted_classes):
            # cell_types maps a class index to its name; cell_labels (per-image
            # labels of one patient) was the wrong lookup table here.
            writer.writerow([patient, os.path.basename(image_path), cell_types[cell_type]])

In [214]:
# Export predictions as CSV, one row per classified image.
# Fixes: the file handle is now the name the writer uses (`csvfile` was
# undefined); csv.writer needs a text-mode file opened with newline='' in
# Python 3 ('wb' fails); and the output goes to a .csv path so it no longer
# truncates a JSON file.
with open("wbc_classes_v1.csv", 'w', newline='') as csvfile:
    writer = csv.writer(csvfile, delimiter=',')
    writer.writerow(['Patient ID', 'filename', 'Cell Type'])
    for patient in all_images.keys():
        # Patients without predictions contribute no rows.
        predicted_classes = all_cell_classes.get(patient, [])
        for image_path, cell_type in zip(all_images[patient], predicted_classes):
            # cell_types maps a class index to its name; cell_labels (per-image
            # labels of one patient) was the wrong lookup table here.
            writer.writerow([patient, os.path.basename(image_path), cell_types[cell_type]])