<a href="https://colab.research.google.com/github/mtzig/LIDC_GDRO/blob/main/notebooks/lidc_cnn_ERM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#ERM CNN Model for Malignancy

#First We setup the repo

In [1]:
# Only run if on Colab
#%cd .. #run this on local machine


!git clone https://github.com/mtzig/LIDC_GDRO.git
%cd /content/LIDC_GDRO

Cloning into 'LIDC_GDRO'...
remote: Enumerating objects: 3061, done.[K
remote: Counting objects: 100% (136/136), done.[K
remote: Compressing objects: 100% (100/100), done.[K
remote: Total 3061 (delta 75), reused 85 (delta 36), pack-reused 2925[K
Receiving objects: 100% (3061/3061), 39.79 MiB | 17.73 MiB/s, done.
Resolving deltas: 100% (2887/2887), done.
Checking out files: 100% (5386/5386), done.
/content/LIDC_GDRO


In [None]:
#!git pull

In [2]:
%load_ext autoreload
%autoreload 2

In [26]:
import os
import torch
import pandas as pd
import numpy as np
from dataloaders import InfiniteDataLoader
from datasets import NoduleDataset
from models import VGGNet
from loss import ERMLoss
from train import train

In [4]:
if torch.cuda.is_available():
    print("Good to go!")
    DEVICE = torch.device("cuda")
else:
    print("Using cpu")
    DEVICE = torch.device("cpu")

Good to go!


#Next We get our data

## First some functions to retrive the data

In [5]:
def getNormed(this_array, this_min = 0, this_max = 255, set_to_int = True):
    
    rat = (this_max - this_min)/(this_array.max() - this_array.min())
    this_array = this_array * rat
    this_array -= this_array.min()
    this_array += this_min
    if set_to_int:
        return this_array.to(dtype= torch.int)
    return this_array

In [6]:
def getImages(image_folder):
    '''
        Input:
        image_folder: directory of the image files

        Output:
        m1: list of the labels encountered (1,2,4,5)
        m2: list of binary labels encountered (benign, malignant)
        diff: list of any nodes with discrepency to CSV labels

    '''
    
    train_img = []
    train_label = []
    train_spic_label = []
    marked_benign = []
    unmarked_benign = []
    
    marked_malignant = []
    unmarked_malignant = []

    

    lidc = pd.read_csv('./data/lidc_spic_subgrouped.csv')
    train_test = pd.read_csv('./data/lidc_train_test_split_stratified.csv')
    for dir1 in os.listdir(image_folder):
  
        if dir1 == 'Malignancy_3':
            continue

        for file in os.listdir(os.path.join(image_folder, dir1)):


            temp_nodule_ID = file.split('.')[0]
            subtype = lidc[lidc['noduleID']==int(temp_nodule_ID)]['subgroup'].iloc[0]
            malignancy = lidc[lidc['noduleID']==int(temp_nodule_ID)]['malignancy'].iloc[0]
            spiculation = lidc[lidc['noduleID']==int(temp_nodule_ID)]['malignancy'].iloc[0]
            
            train_type = train_test[train_test['noduleID'] ==int(temp_nodule_ID)]['dataset'].iloc[0]
            
            
            image = np.loadtxt(os.path.join(image_folder, dir1,file))
            image = torch.from_numpy(image).to(DEVICE)
            rgb_image = torch.stack((image,image,image), dim = 0)
            rgb_image = getNormed(rgb_image)
            rgb_image = rgb_image / 255 

            
            
            if train_type == 'train':
                train_img.append(rgb_image)
                train_label.append(torch.tensor(1).to(DEVICE).to(torch.float32) if malignancy > 3 else torch.tensor(0).to(DEVICE).to(torch.float32))
                train_spic_label.append(torch.tensor(1).to(DEVICE).to(torch.float32) if spiculation > 1 else torch.tensor(0).to(DEVICE).to(torch.float32))
                
                continue
            
            if subtype == 'marked_benign':
                image_array = marked_benign
            elif subtype == 'unmarked_benign':
                image_array = unmarked_benign
            elif subtype == 'marked_malignant':
                image_array = marked_malignant
            else:
                image_array = unmarked_malignant
            
            image_array.append(rgb_image)
 


    return train_img, train_label, train_spic_label, marked_benign, unmarked_benign, marked_malignant, unmarked_malignant

## Now we get the data

In [7]:
train_img, train_label, train_spic_label, marked_benign, unmarked_benign, marked_malignant, unmarked_malignant = getImages('./LIDC(MaxSlices)_Nodules(fixed)')

In [10]:
train_dataset = NoduleDataset(train_img, train_label)

In [11]:
len(train_dataset)

1210

In [13]:
train_set, val_set = torch.utils.data.random_split(train_dataset, [1000, 210])

train_loader = InfiniteDataLoader(train_set, 1000)
val_loader = InfiniteDataLoader(val_set, len(val_set))

#Now we create the model and setup training

First we make our model

In [16]:
model = VGGNet(device=DEVICE)

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


  0%|          | 0.00/548M [00:00<?, ?B/s]

In [20]:
loss_fn = ERMLoss(model,torch.nn.functional.binary_cross_entropy_with_logits,{})
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=0.005)


##Now we train the model

In [27]:
epochs = 40

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    train(train_loader, model, loss_fn, optimizer, verbose=True)

    model.eval()
    with torch.no_grad():
      X,y = next(val_loader)
      results = torch.sigmoid(model(X))
      accuracy = torch.sum(torch.round(results) == y)/X.shape[0]
    model.train()
    print(f'cv accuracy {accuracy}')

Epoch 1/40
Average training loss: 0.12029918283224106
cv accuracy 0.6428571939468384
Epoch 2/40
Average training loss: 0.11049578338861465
cv accuracy 0.6523810029029846
Epoch 3/40


KeyboardInterrupt: ignored

#Lastly We evaluate model performance

We first create a simple function to get sensitivities

In [35]:
def get_sensitivity(model, imgs, label, label_tensor = False):
  '''
  Inputs:
  model: the model to use
  img: list of imgs in the class
  label: either 0 or 1 depending on the ground truth of subclass
  label_tensor: if True, then label is tensor of ground truth

  Output:
  accuracy: accuracy for this subgroup

  '''
  results = torch.sigmoid(model(torch.stack(imgs).to(DEVICE)))
  if label_tensor:
    truth = label
  elif label == 1:
    truth = torch.ones(len(imgs), device=DEVICE)
  else:
    truth = torch.zeros(len(imgs), device=DEVICE)
  
  accuracy = torch.sum(torch.round(results) == truth)/len(imgs)

  return accuracy

##Model Performance on Test Set

In [36]:
#spaghetti code-esque way to get imgs and labels for entire test set
all_test_imgs = marked_benign+unmarked_benign+marked_malignant+unmarked_malignant
all_labels = torch.tensor([0 for _ in marked_benign+unmarked_benign]+[1 for _ in marked_malignant+unmarked_malignant], device=DEVICE)


print(f'spiculated benign accuracy: {get_sensitivity(model, marked_benign, 0):.3f}')
print(f'unspiculated benign accuracy: {get_sensitivity(model, unmarked_benign, 0):.3f}')
print(f'spiculated malignant accuracy: {get_sensitivity(model, marked_malignant, 1):.3f}')
print(f'unspiculated malignant accuracy: {get_sensitivity(model, unmarked_malignant, 1):.3f}')

print(f'Total accuracy: {get_sensitivity(model, all_test_imgs, all_labels, label_tensor=True):.3f}')

spiculated benign accuracy: 0.762
unspiculated benign accuracy: 0.834
spiculated malignant accuracy: 0.720
unspiculated malignant accuracy: 0.561
Total accuracy: 0.741
