<a href="https://colab.research.google.com/github/mtzig/LIDC_GDRO/blob/main/notebooks/lidc_cnn_ERM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ERM CNN Model for Malignancy

# First, we set up the repo

In [282]:
# Only run if on Colab
#%cd .. #run this on local machine
# (Locally, run the line above instead: it moves from notebooks/ up to the repo root.)


# Clone the repo (on re-runs git reports "destination path already exists" and the clone is skipped)
# and make it the working directory so the relative ./data paths and the project-module imports resolve.
!git clone https://github.com/mtzig/LIDC_GDRO.git
%cd /content/LIDC_GDRO

fatal: destination path 'LIDC_GDRO' already exists and is not an empty directory.
/content/LIDC_GDRO


In [283]:
# !git pull

In [284]:
# Automatically reload edited project modules (dataloaders, datasets, models, loss, train)
# before executing code, so changes to the .py files are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [285]:
import os
import torch
import pandas as pd
import numpy as np
from dataloaders import InfiniteDataLoader
from datasets import NoduleDataset
from models import VGGNet, ResNet18
from loss import ERMLoss
from train import train

In [286]:
# Pick the compute device once; every tensor and the model are moved to DEVICE later on.
_cuda_ok = torch.cuda.is_available()
print("Good to go!" if _cuda_ok else "Using cpu")
DEVICE = torch.device("cuda" if _cuda_ok else "cpu")

Good to go!


# Next, we get our data

## First, some functions to retrieve the data

In [287]:
def getNormed(this_array, this_min = 0, this_max = 255, set_to_int = True):
    '''
        Linearly rescale a tensor so its values span [this_min, this_max] (min-max normalization).

        Input:
        this_array: torch tensor to rescale (not modified)
        this_min: value the smallest element is mapped to
        this_max: value the largest element is mapped to
        set_to_int: if True, cast the result to torch.int (truncating toward zero)

        Output:
        a new tensor with the rescaled values. A constant input (max == min) has no
        range to stretch, so every element is mapped to this_min instead of
        dividing by zero (which used to yield inf/nan).
    '''
    value_range = this_array.max() - this_array.min()

    if value_range == 0:
        # Same dtype the normal path would produce (floating, promoted from the input).
        normed = torch.full_like(this_array, this_min,
                                 dtype=torch.result_type(this_array, 1.0))
    else:
        rat = (this_max - this_min) / value_range
        normed = this_array * rat
        # Shift so the smallest value lands exactly on this_min.
        normed = normed - normed.min() + this_min

    if set_to_int:
        return normed.to(dtype=torch.int)
    return normed

In [288]:
def getImages(image_folder):
    '''
        Load nodule images and split them into a training set and four test subgroups.

        Images are read from image_folder/<class_dir>/<noduleID>.<ext>; the
        'Malignancy_3' directory is skipped. Each file is loaded with np.loadtxt,
        replicated into 3 identical channels, min-max rescaled to [0, 1] and
        moved to DEVICE. Per-nodule metadata (subgroup, malignancy, spiculation)
        comes from ./data/lidc_spic_subgrouped.csv and the train/test assignment
        from ./data/lidc_train_test_split_stratified.csv.

        Input:
        image_folder: directory containing one sub-directory per malignancy class

        Output (7-tuple):
        train_img: list of 3-channel image tensors for nodules in the 'train' split
        train_label: list of 0-d float32 tensors, 1 if malignancy > 3 else 0
        train_spic_label: list of 0-d float32 tensors, 1 if spiculation > 1 else 0
        marked_benign, unmarked_benign, marked_malignant, unmarked_malignant:
            lists of image tensors for non-train nodules, grouped by the CSV
            'subgroup' column (any unrecognized subgroup goes to unmarked_malignant)
    '''
    lidc = pd.read_csv('./data/lidc_spic_subgrouped.csv')
    train_test = pd.read_csv('./data/lidc_train_test_split_stratified.csv')

    train_img = []
    train_label = []
    train_spic_label = []
    test_groups = {
        'marked_benign': [],
        'unmarked_benign': [],
        'marked_malignant': [],
        'unmarked_malignant': [],
    }

    # sorted(): os.listdir order is filesystem-dependent; sorting makes the load order reproducible.
    for dir1 in sorted(os.listdir(image_folder)):
        if dir1 == 'Malignancy_3':
            continue

        for file in sorted(os.listdir(os.path.join(image_folder, dir1))):
            nodule_id = int(file.split('.')[0])
            # Look the nodule's metadata row up once instead of re-filtering per column.
            meta = lidc[lidc['noduleID'] == nodule_id].iloc[0]
            train_type = train_test[train_test['noduleID'] == nodule_id]['dataset'].iloc[0]

            image = torch.from_numpy(np.loadtxt(os.path.join(image_folder, dir1, file))).to(DEVICE)
            # Grey image -> 3 identical channels, rescaled to [0, 255] ints, then to [0, 1].
            rgb_image = getNormed(torch.stack((image, image, image), dim=0)) / 255

            if train_type == 'train':
                train_img.append(rgb_image)
                train_label.append(
                    torch.tensor(float(meta['malignancy'] > 3), dtype=torch.float32, device=DEVICE))
                # Spiculation labels come from the 'spiculation' column (previously read 'malignancy' by mistake).
                train_spic_label.append(
                    torch.tensor(float(meta['spiculation'] > 1), dtype=torch.float32, device=DEVICE))
                continue

            test_groups.get(meta['subgroup'], test_groups['unmarked_malignant']).append(rgb_image)

    return (train_img, train_label, train_spic_label,
            test_groups['marked_benign'], test_groups['unmarked_benign'],
            test_groups['marked_malignant'], test_groups['unmarked_malignant'])

## Now we get the data

In [289]:
# Training images/labels plus the four held-out test subgroups.
(train_img, train_label, train_spic_label,
 marked_benign, unmarked_benign,
 marked_malignant, unmarked_malignant) = getImages('./LIDC(MaxSlices)_Nodules(fixed)')

In [290]:
# Wrap the training images with their binary malignancy labels (1 = malignancy > 3) in a Dataset.
train_dataset = NoduleDataset(train_img, train_label)

In [291]:
# Number of training nodules available before the train/validation split.
len(train_dataset)

1210

In [292]:
# Hold out a fixed-size validation set. The training part is whatever remains, so the
# split keeps working if the number of training nodules changes.
VAL_SIZE = 210
train_set, val_set = torch.utils.data.random_split(
    train_dataset,
    [len(train_dataset) - VAL_SIZE, VAL_SIZE],
    # Seeded generator so the split is reproducible across kernel restarts.
    generator=torch.Generator().manual_seed(0),
)

# Second argument is presumably the batch size (TODO confirm against dataloaders.InfiniteDataLoader).
train_loader = InfiniteDataLoader(train_set, 128)
# Presumably yields the whole validation set as a single batch each time.
val_loader = InfiniteDataLoader(val_set, len(val_set))

# Now we create the model and set up training

First, we make our model.

In [293]:
# ResNet-18 from the project's models module. pretrained=True / freeze=False presumably mean
# "start from pretrained weights and fine-tune all layers" (TODO confirm against models.ResNet18).
model = ResNet18(device=DEVICE, pretrained=True, freeze=False)

In [294]:
# ERM objective wrapping binary cross-entropy on raw logits; the empty dict is passed
# through to ERMLoss (presumably extra loss kwargs — TODO confirm in loss.py).
loss_fn = ERMLoss(
    model,
    torch.nn.functional.binary_cross_entropy_with_logits,
    {},
)
# Adam with L2 weight decay over all model parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-3)


We also make a learning rate scheduler.

In [295]:
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Validation accuracy is the monitored metric (mode='max'): once it has failed to improve
# for more than 2 consecutive epochs, the learning rate is multiplied by 0.2.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.2,
    patience=2,
    verbose=True,
)


## Now we train the model

In [296]:
epochs = 40

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    # The validation step below leaves the model in eval mode. Switch back explicitly so
    # BatchNorm/Dropout behave as in training on every epoch, not just the first
    # (harmless if the project train() already does this).
    model.train()
    train(train_loader, model, loss_fn, optimizer, verbose=True)

    # Validation: one batch from the validation loader, no gradient tracking.
    model.eval()
    with torch.no_grad():
        X, y = next(val_loader)
        preds = torch.round(torch.sigmoid(model(X)))
        accuracy = (torch.sum(preds == y) / X.shape[0]).item()
    print(f'cv accuracy {accuracy}')
    # Scheduler is in mode='max': it lowers the LR when validation accuracy plateaus.
    scheduler.step(accuracy)

Epoch 1/40
Average training loss: 0.5866453945636749
cv accuracy 0.6761904954910278
Epoch 2/40
Average training loss: 0.4466914108821324
cv accuracy 0.6761904954910278
Epoch 3/40
Average training loss: 0.38839320625577656
cv accuracy 0.7000000476837158
Epoch 4/40
Average training loss: 0.242648788860866
cv accuracy 0.7523809671401978
Epoch 5/40
Average training loss: 0.18248883528368814
cv accuracy 0.7809523940086365
Epoch 6/40
Average training loss: 0.17846629342862538
cv accuracy 0.6571428775787354
Epoch 7/40
Average training loss: 0.16402730558599746
cv accuracy 0.6285714507102966
Epoch 8/40
Average training loss: 0.14910744343485152
cv accuracy 0.747619092464447
Epoch 00008: reducing learning rate of group 0 to 2.0000e-04.
Epoch 9/40
Average training loss: 0.09558290083493505
cv accuracy 0.7714285850524902
Epoch 10/40
Average training loss: 0.045237461903265545
cv accuracy 0.785714328289032
Epoch 11/40
Average training loss: 0.02549218146928719
cv accuracy 0.761904776096344
Epoch 1

# Lastly, we evaluate model performance

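We first create a simple function that computes the model's accuracy on a group of images. For a subgroup whose images all share one label, this is the sensitivity (malignant) or specificity (benign) for that subgroup.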

In [297]:
def get_sensitivity(model, imgs, label, label_tensor = False, device = None):
    '''
    Fraction of images whose rounded sigmoid prediction matches the ground truth.

    For a subgroup whose images all share one label this is the per-class recall:
    sensitivity when label == 1, specificity when label == 0.

    Inputs:
    model: callable mapping a batch of stacked images to one logit per image
    imgs: non-empty list of image tensors (stacked along a new batch dimension)
    label: 0 or 1, the ground truth shared by every image; or, when label_tensor
        is True, a tensor with one ground-truth value per image
    label_tensor: if True, `label` is a per-image tensor of ground truths
    device: device to run inference on; defaults to the notebook-level DEVICE

    Output:
    accuracy: 0-d tensor with the fraction of correct predictions

    Raises:
    ValueError: if imgs is empty
    '''
    if len(imgs) == 0:
        raise ValueError('get_sensitivity needs at least one image')
    if device is None:
        device = DEVICE

    # Inference only: disable autograd and use eval mode, restoring the caller's mode afterwards.
    was_training = getattr(model, 'training', False)
    if was_training:
        model.eval()
    try:
        with torch.no_grad():
            logits = model(torch.stack(imgs).to(device))
    finally:
        if was_training:
            model.train()

    # Flatten so an (N, 1) output can't silently broadcast against an (N,) truth vector.
    preds = torch.round(torch.sigmoid(logits)).reshape(-1)

    if label_tensor:
        truth = label.to(device).reshape(-1)
    elif label == 1:
        truth = torch.ones(len(imgs), device=device)
    else:
        truth = torch.zeros(len(imgs), device=device)

    accuracy = torch.sum(preds == truth) / len(imgs)

    return accuracy

## Model Performance on Test Set

In [298]:
#spaghetti code-esque way to get imgs and labels for entire test set
all_test_imgs = marked_benign+unmarked_benign+marked_malignant+unmarked_malignant
all_labels = torch.tensor([0 for _ in marked_benign+unmarked_benign]+[1 for _ in marked_malignant+unmarked_malignant], device=DEVICE)


print(f'spiculated benign accuracy: {get_sensitivity(model, marked_benign, 0):.3f}')
print(f'unspiculated benign accuracy: {get_sensitivity(model, unmarked_benign, 0):.3f}')
print(f'spiculated malignant accuracy: {get_sensitivity(model, marked_malignant, 1):.3f}')
print(f'unspiculated malignant accuracy: {get_sensitivity(model, unmarked_malignant, 1):.3f}')

print(f'Total accuracy: {get_sensitivity(model, all_test_imgs, all_labels, label_tensor=True):.3f}')

spiculated benign accuracy: 0.714
unspiculated benign accuracy: 0.855
spiculated malignant accuracy: 0.829
unspiculated malignant accuracy: 0.649
Total accuracy: 0.800
