In [1]:
import pandas as pd
import numpy as np

import torch
import torch.nn.functional as F
import torchvision
import transformers

from PIL import Image

import os
import glob

import wandb
from torchinfo import summary

from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns 

import SimpleITK as sitk



from utils import normalize_sitk_image, MRI_Dataset_within_ROI, MRI_Dataset_within_ROI_both_prepost

from torchvision.transforms import v2
import torchio as tio

from torchmetrics import Accuracy, Recall, Precision, F1Score


torch.set_num_threads(12)
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

from sklearn.model_selection import train_test_split

In [2]:
# parameters
SIZE = (100,100,100)
TRAIN_SPLIT = 0.8
SPLIT_SEED = 123456789
BATCH_SIZE = 5

DATA_PATH = '../../../Processed NIFTI Dataset/'
CLASSIFICATION = 'ER'
dataset_path = f'../../Train Test Splits/{CLASSIFICATION}/'

MODEL_SAVE_PATH = f'basic_model_{CLASSIFICATION}.pth'
PROJECT_NAME = 'Breast Cancer Subtype Prediction'

In [3]:
# dataset loading

train_set = pd.read_csv(dataset_path + 'train.csv')
test_set  = pd.read_csv(dataset_path + 'test.csv')

bounding_boxes = pd.read_csv('../../Data/segmentation_annotations_NIFTI.csv').set_index('Patient_ID')

NUM_CLASSES = len(train_set['label'].unique())

In [4]:
# validation set split

train_set, val_set = train_test_split(train_set, train_size = TRAIN_SPLIT, stratify = train_set['label'])

In [5]:
# transforms

transform = v2.Compose([
    tio.ZNormalization()
    # v2.Normalize(mean = [0.5], std = [0.225])
])

upscaler = tio.Resize(SIZE, 'bspline')

In [6]:
affine_args = {
        'scales' : [0.8, 1, 0.8, 1, 0.8, 1],
        'degrees' : 15,
        'translation' : [0.05, 0.05, 0.05],
        'center' : 'image'
}

augment = tio.Compose([
    # tio.transforms.RandomFlip(axes = (0, 1, 2)),
    tio.transforms.RandomAffine(
        **affine_args
    )
])

In [7]:
# # Datasets

# train_dataset = MRI_Dataset_within_ROI(DATA_PATH,
#                                        train_set,
#                                        bounding_boxes, 
#                                        transform,
#                                        upscaler,
#                                     #    augment = augment
#                                        )

# val_dataset = MRI_Dataset_within_ROI(DATA_PATH,
#                                        val_set,
#                                        bounding_boxes, 
#                                        transform,
#                                        upscaler)

# test_dataset = MRI_Dataset_within_ROI(DATA_PATH,
#                                        test_set,
#                                        bounding_boxes, 
#                                        transform,
#                                        upscaler)

In [8]:
class_sample_count = train_set.label.value_counts()  
class_sample_count = 1/class_sample_count
class_sample_count = class_sample_count.tolist()

c_w = train_set.label.apply(lambda x: class_sample_count[x]).to_numpy()
c_w = torch.tensor(c_w, dtype = torch.double)

weighted_sampler = torch.utils.data.WeightedRandomSampler(c_w, len(c_w))

In [9]:
# Datasets

train_dataset = MRI_Dataset_within_ROI_both_prepost(DATA_PATH,
                                       train_set,
                                       bounding_boxes, 
                                       transform,
                                       upscaler,
                                       augment = augment,
                                       )

val_dataset = MRI_Dataset_within_ROI_both_prepost(DATA_PATH,
                                       val_set,
                                       bounding_boxes, 
                                       transform,
                                       upscaler)

test_dataset = MRI_Dataset_within_ROI_both_prepost(DATA_PATH,
                                       test_set,
                                       bounding_boxes, 
                                       transform,
                                       upscaler)

In [10]:
## Dataloaders

train_loader = torch.utils.data.DataLoader(train_dataset,
                                        #    shuffle = True,
                                           batch_size = BATCH_SIZE,
                                           pin_memory = True,
                                           num_workers = 3,
                                           sampler = weighted_sampler
                                           )

val_loader = torch.utils.data.DataLoader(val_dataset,
                                           batch_size = BATCH_SIZE,
                                           pin_memory = True,
                                           num_workers = 2
                                           )

test_loader = torch.utils.data.DataLoader(test_dataset,
                                           batch_size = BATCH_SIZE,
                                           pin_memory = True,
                                           num_workers = 2
                                           )

## Model architecture

In [11]:
class ConvNet_MRI3D(torch.nn.Module):
    def __init__(self, in_channels, num_classes):
        super(ConvNet_MRI3D, self).__init__()
        self.conv1 = torch.nn.Conv3d(in_channels, 16, kernel_size = (3,3,3), stride = (1,1,1), padding = 1)
        self.conv2 = torch.nn.Conv3d(16, 16, kernel_size = (3,3,3), stride = (1,1,1), padding = 1)
        self.conv3 = torch.nn.Conv3d(16, 16, kernel_size = (3,3,3), stride = (1,1,1), padding = 1)
        self.conv4 = torch.nn.Conv3d(16, 32, kernel_size = (3,3,3), stride = (1,1,1), padding = 1)
        
        # self.conv1x1_1 = torch.nn.Conv3d(16, 16, kernel_size=1, stride=1)
        # self.conv1x1_2 = torch.nn.Conv3d(16, 16, kernel_size=1, stride=1)
        self.maxpool1 = torch.nn.MaxPool3d(2)

        self.flatten = torch.nn.Flatten()
        self.fc1 = torch.nn.Linear(3456*2, num_classes)
    
    def forward(self, inp):
        out = inp
        out = F.relu(self.conv1(out))
        intermediate = self.maxpool1(out)
        
        
        out = F.relu(self.conv2(intermediate))
        out = out + intermediate      # residual
        intermediate = self.maxpool1(out)
        
        out = F.relu(self.conv3(intermediate))
        out = out + intermediate

        out = self.maxpool1(out)
        out = self.conv4(out)
        out = self.maxpool1(out)
        out = self.flatten(out)
        
        out = self.fc1(out)
        
        return out
        
        

In [12]:
# summary(ConvNet_MRI3D(2, 2), input_size=(20,2,100,100,100), col_names = ['input_size', 'output_size', 'num_params'])

In [13]:
LR = 0.0001
EPOCHS = 10

model = ConvNet_MRI3D(2, NUM_CLASSES).to(DEVICE)
optim = torch.optim.Adam(model.parameters(), lr = LR)

wandb.login()
wandb.init(
    project = f"{PROJECT_NAME} - {CLASSIFICATION}",
    name = '3D Convnet with random sampling and class weights',

    config = {
        'learning_rate': LR,
        'architecture': "3D ConvNet",
        'epochs' : EPOCHS,
        'batch_size': BATCH_SIZE
    }
)


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmadhava20217[0m ([33mmadkri[0m). Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01128888888889479, max=1.0)…

In [14]:
weights = torch.tensor((train_set['label'].value_counts()/len(train_set)).to_numpy())
weights = 1 - weights
weights = weights.float()
print(weights)

tensor([0.2558, 0.7442])


In [15]:
# training loop

# train_metrics = []

scaler = torch.cuda.amp.grad_scaler.GradScaler()
CrossEntropyLoss = torch.nn.CrossEntropyLoss(weight = weights).to(DEVICE)           # unweighted right now

acc_train = Accuracy(task = 'multiclass', num_classes = NUM_CLASSES, average = 'weighted', top_k = 1).to(DEVICE)
acc_val = Accuracy(task = 'multiclass', num_classes = NUM_CLASSES, average = 'weighted', top_k = 1).to(DEVICE)

recall_train = Recall(task = 'multiclass', average = 'weighted', num_classes = NUM_CLASSES).to(DEVICE)
recall_val = Recall(task = 'multiclass', average = 'weighted', num_classes = NUM_CLASSES).to(DEVICE)

precision_train = Precision(task = 'multiclass', average = 'weighted', num_classes = NUM_CLASSES).to(DEVICE)
precision_val = Precision(task = 'multiclass', average = 'weighted', num_classes = NUM_CLASSES).to(DEVICE)

f1_train = F1Score(task = 'multiclass', average = 'weighted', num_classes = NUM_CLASSES).to(DEVICE)
f1_val = F1Score(task = 'multiclass', average = 'weighted', num_classes = NUM_CLASSES).to(DEVICE)

for epoch in range(1, EPOCHS+1):
    epoch_loss = 0
    samples = 0
    #set model to train
    model.train()
    for data in (pbar:= tqdm(train_loader)):
        #zero optim grad
        optim.zero_grad(set_to_none=True)
        img, labels = data

        # append to sampels
        n_batch = len(img)
        samples+= n_batch

        # to device
        img = img.to(DEVICE)
        labels = labels.to(DEVICE)

        #forward step
        with torch.autocast(device_type = 'cuda', dtype = torch.float16):
            outputs = model(img)
            loss = CrossEntropyLoss(outputs, labels)
        

        # backward step
        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()

        epoch_loss+= (loss.item() * n_batch)
        pbar.set_description(f"CE Loss: {epoch_loss/samples}")

        # accuracy and metrics
        acc_train(outputs, labels)
        precision_train(outputs, labels)
        recall_train(outputs,labels)
        f1_train(outputs, labels)

    # validation
    model.eval()

    val_loss = 0
    val_samples = 0
    for data in (val_loader):
        val_img, val_labels = data

        # append to sampels
        n_batch_val = len(img)
        val_samples+= n_batch_val

        # to device
        val_img = val_img.to(DEVICE)
        val_labels = val_labels.to(DEVICE)


        with torch.no_grad():  
            val_outputs = model(val_img)
            v_loss = CrossEntropyLoss(val_outputs, val_labels)
            val_loss+= (v_loss.item() * n_batch_val)

        # metrics
        acc_val(val_outputs, val_labels)
        precision_val(val_outputs, val_labels)
        recall_val(val_outputs, val_labels)
        f1_val(val_outputs, val_labels)

    logging_dict = {
        'epoch': epoch,
        'train_loss': epoch_loss/samples,
        'val_loss': val_loss/val_samples,
        
        
        'train_acc': acc_train.compute(),
        'train_rec': recall_train.compute(),
        'train_prec': precision_train.compute(),
        'train_f1': f1_train.compute(),
        
        'val_acc_top1': acc_val.compute(),
        'val_rec': recall_val.compute(),
        'val_prec': precision_val.compute(),
        'val_f1' : f1_val.compute()
    }
    
    # print('Training Accuracy: {}, Validation Accuracy: {}'.format(logging_dict['train_acc'].item(), logging_dict['val_acc_top1'].item()))
    wandb.log(logging_dict)
    
    # train_metrics.append(logging_dict)
    
    acc_train.reset()
    acc_val.reset()
    recall_train.reset()
    recall_val.reset()
    f1_train.reset()
    f1_val.reset()
    precision_train.reset()
    precision_val.reset()
    torch.save(model.state_dict(), MODEL_SAVE_PATH)
    wandb.save(MODEL_SAVE_PATH)



CE Loss: 0.20225836607706985: 100%|██████████| 104/104 [01:49<00:00,  1.05s/it]
CE Loss: 0.1723579501827038: 100%|██████████| 104/104 [01:46<00:00,  1.02s/it] 
CE Loss: 0.19735226340211978: 100%|██████████| 104/104 [01:45<00:00,  1.01s/it]
CE Loss: 0.19610532026109828: 100%|██████████| 104/104 [01:43<00:00,  1.00it/s]
CE Loss: 0.1830793091514718: 100%|██████████| 104/104 [01:46<00:00,  1.02s/it] 
CE Loss: 0.1815568495040052: 100%|██████████| 104/104 [01:44<00:00,  1.00s/it] 
CE Loss: 0.19360402520644918: 100%|██████████| 104/104 [01:47<00:00,  1.04s/it]
CE Loss: 0.17160414746969707: 100%|██████████| 104/104 [01:44<00:00,  1.00s/it]
CE Loss: 0.16000574607719747: 100%|██████████| 104/104 [01:44<00:00,  1.00s/it]
CE Loss: 0.16030015654107482: 100%|██████████| 104/104 [01:44<00:00,  1.00s/it]


In [16]:
# import pickle as pkl

# with open("training_logs.pkl", 'wb') as file:
#     pkl.dump(train_metrics, file)

In [17]:
# testing
model.eval()

labs = []

acc_test = Accuracy(task = 'multiclass', average = 'weighted', num_classes = NUM_CLASSES, top_k = 1).to(DEVICE)
recall_test = Recall(task = 'multiclass', average = 'weighted', num_classes = NUM_CLASSES).to(DEVICE)
precision_test = Precision(task = 'multiclass', average = 'weighted', num_classes = NUM_CLASSES).to(DEVICE)
f1_test = F1Score(task = 'multiclass', average = 'weighted', num_classes = NUM_CLASSES).to(DEVICE)

test_loss = 0
test_samples = 0
for data in (pbar:= tqdm(test_loader)):
    test_img, test_labels = data

    # append to sampels
    n_batch_test = len(img)
    test_samples+= n_batch_test

    # to device
    test_img = test_img.to(DEVICE)
    test_labels = test_labels.to(DEVICE)


    with torch.no_grad():  
        test_outputs = model(test_img)
        t_loss = CrossEntropyLoss(test_outputs, test_labels)
        test_loss+= (t_loss.item() * n_batch_test)
        labs.append(test_outputs)

    # metrics
    acc_test(test_outputs, test_labels)
    precision_test(test_outputs, test_labels)
    recall_test(test_outputs, test_labels)
    f1_test(test_outputs, test_labels)
test_metrics = {
    'test_acc': acc_test.compute(),
    'test_rec': recall_test.compute(),
    'test_prec': precision_test.compute(),
    'test_f1': f1_test.compute(),
}
wandb.log(test_metrics)

100%|██████████| 56/56 [01:12<00:00,  1.30s/it]


In [18]:
out = []

for i in labs:
    out.extend(torch.argmax(i, 1).detach().cpu().tolist())

In [19]:
np.unique(out, return_counts=True)

(array([1]), array([277], dtype=int64))

In [20]:
# import pickle as pkl

# with open("testing_logs.pkl", 'wb') as file:
#     pkl.dump(test_metrics, file)

In [21]:
wandb.finish()

VBox(children=(Label(value='0.167 MB of 0.167 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▂▃▃▄▅▆▆▇█
test_acc,▁
test_f1,▁
test_prec,▁
test_rec,▁
train_acc,▅▇▅▂▃▃▁▆▇█
train_f1,▄▆▄▂▃▃▁▆▇█
train_loss,█▃▇▇▅▅▇▃▁▁
train_prec,▂▃▂▁▇▁▇▂▃█
train_rec,▅▇▅▂▃▃▁▆▇█

0,1
epoch,10.0
test_acc,0.74368
test_f1,0.63436
test_prec,0.55306
test_rec,0.74368
train_acc,0.90698
train_f1,0.8646
train_loss,0.1603
train_prec,0.91565
train_rec,0.90698
