In [169]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import json

from PIL import Image
import torchvision
from torchvision import datasets, models, transforms
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from sklearn.metrics import *
import time
from datetime import datetime
import os
from torch.utils import data
import random
import copy
import itertools
import io
import uuid
from sklearn.model_selection import KFold, train_test_split

import warnings
warnings.filterwarnings('ignore')

import wandb
# User names for this run.
# local_username selects the home directory that holds the image data on the
# current machine; wandb_username is presumably the Weights & Biases account
# used for logging (its use is not visible in this part of the notebook).
local_username = 'andreasabo'
wandb_username = 'andreasabo'

In [170]:
# Pick the compute device. GPU index 1 is the preferred device; when fewer
# GPUs are visible we fall back to the first one instead of building a
# 'cuda:1' device that would only fail later, on first use. Without CUDA the
# notebook runs on the CPU.
PREFERRED_GPU_INDEX = 1

if torch.cuda.is_available():
    gpu_index = PREFERRED_GPU_INDEX if torch.cuda.device_count() > PREFERRED_GPU_INDEX else 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')
print(device)

cuda:1


In [171]:
# Project root (holds the labels CSV) and the folder where the
# cross-validation split file is stored
root_dir = "/home/andreasabo/Documents/HNProject/"
split_file_base = "/home/andreasabo/Documents/HNUltra/"

# Image folder on the current machine; local_username is one of:
# abhishekmoturu, andreasabo, denizjafari, navidkorhani
data_dir = f"/home/{local_username}/Documents/HNProject/all_label_img/"

# Load the label table, keeping only the columns used in this notebook
label_columns = [
    'subj_id', 'scan_num', 'view_label', 'image_ids',
    'reflux_label', 'function_label', 'surgery_label', 'outcome_train',
]
csv_path = os.path.join(root_dir, "all_splits_1000000.csv")
data_df = pd.read_csv(csv_path, usecols=label_columns)

In [172]:
# Show the first rows of the label table, then its total number of rows
first_rows = data_df.head()
print(first_rows)
data_df.shape[0]

  function_label image_ids reflux_label surgery_label view_label  subj_id  \
0        Missing  1323_2_1      Missing       Missing    Missing     1323   
1        Missing  1323_2_2      Missing       Missing    Missing     1323   
2        Missing  1323_2_3      Missing       Missing    Missing     1323   
3        Missing  1323_2_4      Missing       Missing    Missing     1323   
4        Missing  1323_2_5      Missing       Missing    Missing     1323   

   scan_num  outcome_train  
0         2            NaN  
1         2            NaN  
2         2            NaN  
3         2            NaN  
4         2            NaN  


72459

### **Reading Data Indices and Labels**

In [173]:
# Drop the images for which we do not have view labels or the label is "Other"
data_df = data_df[~data_df.view_label.isin(["Missing", "Other"])]
train_df = data_df[data_df.outcome_train == 1]
test_df = data_df[data_df.outcome_train == 0]

print(f"We have {len(test_df) + len(train_df)} images ({len(train_df)} train and {len(test_df)} test)")
unique_subj = train_df.subj_id.unique()


def build_subject_folds(frame, subjects, n_splits=5, seed=0):
    """Split ``frame`` into cross-validation folds grouped by subject.

    Subjects (not images) are assigned to folds, so all images of one subject
    land on the same side of a train/validation split.

    Args:
        frame: DataFrame with the columns subj_id, image_ids, reflux_label,
            function_label, surgery_label, scan_num and view_label.
        subjects: array of the unique subject ids to split.
        n_splits: number of folds.
        seed: random state passed to KFold (shuffling is enabled).

    Returns:
        dict mapping the fold number *as a string* to a dict with the keys
        train_/val_ + images, reflux, function, surgery, scan, views, each a
        plain list. String keys are used on purpose: JSON turns dictionary
        keys into strings, so folds built in memory and folds reloaded from
        the split file are addressed the same way (e.g. ``all_folds['0']``).
    """
    # suffix of the output key -> column of `frame` the values come from
    fold_columns = {
        'images': 'image_ids',
        'reflux': 'reflux_label',
        'function': 'function_label',
        'surgery': 'surgery_label',
        'scan': 'scan_num',
        'views': 'view_label',
    }

    splitter = KFold(n_splits=n_splits, random_state=seed, shuffle=True)
    folds = {}
    for fold, (train_pos, val_pos) in enumerate(splitter.split(subjects)):
        # Select the rows of each side once and reuse them for every column
        train_rows = frame[frame.subj_id.isin(subjects[train_pos])]
        val_rows = frame[frame.subj_id.isin(subjects[val_pos])]

        cur_fold = {}
        for suffix, column in fold_columns.items():
            cur_fold['train_' + suffix] = train_rows[column].tolist()
            cur_fold['val_' + suffix] = val_rows[column].tolist()
        folds[str(fold)] = cur_fold
    return folds


# Create the splits for 5-fold cross validation based on subj_id, or reuse the
# splits saved by an earlier run so that every run sees the same folds
data_split_file = split_file_base + 'data_splits_outcome.json'
if not os.path.isfile(data_split_file):
    all_folds = build_subject_folds(train_df, unique_subj)

    print("Saving data splits")
    with open(data_split_file, 'w') as f:
        json.dump(all_folds, f)

else:  # just load from file
    print("Reading splits from file")
    with open(data_split_file, 'r') as f:
        all_folds = json.load(f)

We have 9581 images (7230 train and 2351 test)
Reading splits from file


In [178]:
# Architecture to train; one of [resnet, alexnet, vgg, squeezenet, densenet, inception, viewnet]
model_name = "viewnet"

# Number of classes in the dataset: right_sag, right_trav, left_sag, left_trav, bladder, other
num_classes = 6

# Training schedule: images per batch (lower it if memory is tight) and number of epochs
batch_size, num_epochs = 100, 200

# Transfer-learning switches:
#   pretrain        - whether or not to use a pretrained model
#   feature_extract - True: only the reshaped layer params are updated;
#                     False: the whole model is finetuned
pretrain, feature_extract = False, False

# When True, sets of images are sampled by scan, which ensures that all scans
# are seen the same number of times
sample_by_scan = True

In [175]:
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze all parameters of ``model`` when feature extracting.

    Args:
        model: object exposing ``parameters()`` (e.g. a torch ``nn.Module``).
        feature_extracting: when truthy, ``requires_grad`` is set to False on
            every parameter; when falsy the model is left untouched.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad = False

In [176]:
# Code from: https://gist.github.com/stefanonardo/693d96ceb2f531fa05db530f3e21517d
class EarlyStopping(object):
    """Signal when a monitored metric has stopped improving.

    Call ``step(metric)`` once per epoch; it returns True when training
    should stop.

    Args:
        mode: 'min' if a lower metric is better, 'max' if higher is better.
        min_delta: minimum change that counts as an improvement. Interpreted
            as a percentage of the best value when ``percentage`` is True,
            otherwise as an absolute amount.
        patience: number of consecutive non-improving steps tolerated before
            stopping. ``0`` disables early stopping (``step`` always returns
            False).
        percentage: whether ``min_delta`` is relative (in percent) or absolute.

    Raises:
        ValueError: if ``mode`` is neither 'min' nor 'max'.
    """

    def __init__(self, mode='min', min_delta=0, patience=10, percentage=True):
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None
        self.num_bad_epochs = 0
        self.is_better = None
        self._init_is_better(mode, min_delta, percentage)

        if patience == 0:
            # Zero patience turns the stopper into a no-op
            self.is_better = lambda a, b: True
            self.step = lambda a: False

    def step(self, metrics):
        """Record one metric value and return True if training should stop."""
        # A NaN metric always stops training. This is checked before anything
        # else so that a NaN first value is never stored as the best value
        # (no later metric could compare as better than NaN).
        if np.isnan(metrics):
            return True

        if self.best is None:
            self.best = metrics
            return False

        if self.is_better(metrics, self.best):
            self.num_bad_epochs = 0
            self.best = metrics
        else:
            self.num_bad_epochs += 1

        return self.num_bad_epochs >= self.patience

    def _init_is_better(self, mode, min_delta, percentage):
        """Install the comparison function used to detect an improvement."""
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')
        if not percentage:
            if mode == 'min':
                self.is_better = lambda a, best: a < best - min_delta
            if mode == 'max':
                self.is_better = lambda a, best: a > best + min_delta
        else:
            # min_delta is a percentage of the current best value
            if mode == 'min':
                self.is_better = lambda a, best: a < best - (
                            best * min_delta / 100)
            if mode == 'max':
                self.is_better = lambda a, best: a > best + (
                            best * min_delta / 100)

In [205]:
import random
class ScanDataset(data.Dataset):
    """PyTorch dataset returning one image per view for a scan, plus outcome labels.

    A scan is identified by the pair (subject id, scan number). Each item is
    a stack of five single-channel images in the order given by
    ``image_return_order``; a view with no image in the scan is replaced by
    an all-NaN placeholder. The label is a float tensor
    ``[reflux, surgery, function]`` with NaN for "Missing".

    With ``sample_by_scan=True`` the dataset has one item per unique scan and
    the image for every view is drawn at random from that scan. Otherwise it
    has one item per image: that image is used for its own view and the other
    views are drawn at random from the same scan.

    Args:
        image_list: image ids of the form "<subj_id>_<scan_num>_<image_num>";
            the image file is read from ``data_dir + image_id + '.jpg'``.
        reflux_labels, surgery_labels, function_labels: per-image outcomes
            ("Yes" / "No" / "Missing" or a numeric value).
        view_labels: per-image view name.
        scan_labels: per-image scan number.
        binarize_labels: when True, numeric outcomes become 1 if they are
            above 60 or below 40 and 0 otherwise.
        transformations: optional callable applied to each PIL image before
            it is converted to a tensor.
        sample_by_scan: see above.
    """

    # Shape of the NaN placeholder used for missing views.
    # NOTE(review): assumes the stored images are 256x256 so that they can be
    # stacked with the placeholders -- confirm against the image folder.
    EMPTY_IMAGE_SHAPE = (1, 256, 256)

    def __init__(self, image_list, reflux_labels, surgery_labels, function_labels, view_labels, scan_labels, binarize_labels = True, transformations=None, sample_by_scan=True):
        'Initialization'
        self.image_list = image_list
        # The subject id is everything before the first underscore; slicing a
        # fixed number of characters would break on ids that are not 4 digits.
        self.subj_id = np.asarray([int(s.split('_')[0]) for s in image_list])

        self.reflux_labels = reflux_labels
        self.surgery_labels = surgery_labels
        self.function_labels = function_labels
        self.view_labels = view_labels
        self.scan_labels = np.asarray(scan_labels)
        self.binarize_labels = binarize_labels
        self.transformations = transformations
        self.sample_by_scan = sample_by_scan
        self.image_return_order = ['Saggital_Right', 'Transverse_Right', 'Saggital_Left', 'Transverse_Left', 'Bladder']

        # Note: this reseeds Python's global random generator, which is the
        # one __getitem__ uses to pick images.
        random.seed(0)

        # Images are accessed in their natural order; shuffling is left to
        # the DataLoader.
        self.index_order = list(range(len(image_list)))
        self.all_view_names = list(set(view_labels))

        # Scan id ("<subj_id>_<scan_num>") of every image, and the unique
        # scans. The unique scans are sorted so that a given index maps to
        # the same scan in every run (set order of strings is not stable).
        self.all_scan_ids = np.asarray([(str(self.subj_id[i]) + "_" + str(self.scan_labels[i])) for i in range(len(self.scan_labels))])
        self.unique_scans = sorted(set(self.all_scan_ids.tolist()))

    def __len__(self):
        'Denotes the total number of image_list or scan_list'
        if self.sample_by_scan:
            return len(self.unique_scans)
        return len(self.image_list)

    def _encode_outcome(self, outcome):
        """Convert one raw outcome label to a number (NaN when missing)."""
        if outcome == "No":
            return 0
        if outcome == "Yes":
            return 1
        if outcome == "Missing":
            return np.nan

        value = float(outcome)
        if np.isnan(value):
            # A NaN numeric label is missing data; keep it NaN instead of
            # letting the threshold test below turn it into 0.
            return np.nan
        if self.binarize_labels:
            return 1 if (value > 60 or value < 40) else 0
        return value

    def __getitem__(self, ind):
        'Generates one sample of data'
        if self.sample_by_scan:
            in_scan = (self.all_scan_ids == self.unique_scans[ind])
            indexes_in_scan = [i for i, hit in enumerate(in_scan) if hit]
            # Labels are read from the first image of the scan
            index = indexes_in_scan[0]
        else:
            # Start from the requested image and collect every image of the
            # same scan (same subject and same scan number)
            index = self.index_order[ind]
            in_scan = (self.all_scan_ids == self.all_scan_ids[index])
            indexes_in_scan = [i for i, hit in enumerate(in_scan) if hit]

        # Group the images of the scan by view
        images_by_view = {}
        for i in indexes_in_scan:
            images_by_view.setdefault(self.view_labels[i], []).append(i)

        # If we directly picked an image, it is the only candidate for its view
        if not self.sample_by_scan:
            images_by_view[self.view_labels[index]] = [index]

        # Randomly select one image per available view
        image_path_by_view = {}
        for view, candidates in images_by_view.items():
            chosen = random.choice(candidates)
            image_path_by_view[view] = data_dir + self.image_list[chosen] + '.jpg'

        output_images = []
        for view in self.image_return_order:
            if view not in image_path_by_view:
                # Missing view: all-NaN placeholder
                output_images.append(torch.full(self.EMPTY_IMAGE_SHAPE, float('nan')))
                continue

            # The context manager closes the file once the pixels are loaded
            with Image.open(image_path_by_view[view]) as raw_image:
                image = raw_image.convert('L')
            if self.transformations:
                image = self.transformations(image)
            output_images.append(ToTensor()(image))

        y = [self._encode_outcome(outcome) for outcome in
             (self.reflux_labels[index], self.surgery_labels[index], self.function_labels[index])]

        y = torch.FloatTensor(y)
        output_images = torch.stack(output_images)
        return output_images, y
    
# Work with the first cross-validation fold for now
partition = all_folds['0']

# Settings for trying out the dataloaders
shuffle = True
num_workers = 0
binarize_outcomes = True
params = dict(batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

# Augmentation: small random rotation, translation and rescaling
trans = transforms.Compose([transforms.RandomAffine(degrees=8, translate=(0.1, 0.1), scale=(0.95,1.25))])

# Datasets for the training and validation parts of the fold.
# NOTE(review): the validation set goes through the same random augmentation
# and a shuffling loader as the training set -- confirm this is intended.
training_set = ScanDataset(
    image_list=partition['train_images'],
    reflux_labels=partition['train_reflux'],
    surgery_labels=partition['train_surgery'],
    function_labels=partition['train_function'],
    view_labels=partition['train_views'],
    scan_labels=partition['train_scan'],
    binarize_labels=binarize_outcomes,
    transformations=trans,
    sample_by_scan=sample_by_scan,
)
val_set = ScanDataset(
    image_list=partition['val_images'],
    reflux_labels=partition['val_reflux'],
    surgery_labels=partition['val_surgery'],
    function_labels=partition['val_function'],
    view_labels=partition['val_views'],
    scan_labels=partition['val_scan'],
    binarize_labels=binarize_outcomes,
    transformations=trans,
    sample_by_scan=sample_by_scan,
)

# Batch generators
training_generator = data.DataLoader(training_set, **params)
val_generator = data.DataLoader(val_set, **params)





dict_keys(['train_images', 'val_images', 'train_reflux', 'val_reflux', 'train_function', 'val_function', 'train_surgery', 'val_surgery', 'train_scan', 'val_scan', 'train_views', 'val_views'])


In [None]:
# This cell is for visualizing images.
# MAKE SURE THIS IS COMMENTED OUT WHEN COMMITTING!!
# image_return_order = ['Saggital_Right', 'Transverse_Right', 'Saggital_Left', 'Transverse_Left', 'Bladder']

# for inputs, labs in training_generator:
#     plt.figure(figsize=(20,10)) 
#     first_scan = inputs[0]
#     for i in range(5):
#         im = first_scan[i,:, :, :]

#         im_np = np.asarray(im).squeeze()
#         plt.subplot(2,5,i+ 1)
#         plt.imshow(im_np, cmap='gray')
#         frame1 = plt.gca()
#         frame1.axes.get_xaxis().set_visible(False)
#         frame1.axes.get_yaxis().set_visible(False)

#         plt.title(image_return_order[i])
        
        
#     break