In [1]:
import pandas as pd
import numpy as np

import torch
import torch.nn.functional as F
import torchvision
import transformers

from PIL import Image

import os
import glob

import wandb
from torchinfo import summary

from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns 

import SimpleITK as sitk



from utils import normalize_sitk_image, MRI_Dataset_within_ROI, MRI_Dataset_within_ROI_both_prepost

from torchvision.transforms import v2
import torchio as tio

from torchmetrics import Accuracy, Recall, Precision, F1Score


torch.set_num_threads(12)
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

from sklearn.model_selection import train_test_split

In [2]:
# parameters
SIZE = (100,100,100)
TRAIN_SPLIT = 0.75
SPLIT_SEED = 123456789
BATCH_SIZE = 5

DATA_PATH = '../../../Processed NIFTI Dataset/'
CLASSIFICATION = 'ER'
dataset_path = f'../../Train Test Splits/{CLASSIFICATION}/'

MODEL_SAVE_PATH = f'basic_model_{CLASSIFICATION}.pth'
PROJECT_NAME = 'Breast Cancer Subtype Prediction'

In [3]:
# dataset loading

train_set = pd.read_csv(dataset_path + 'train.csv')
test_set  = pd.read_csv(dataset_path + 'test.csv')

bounding_boxes = pd.read_csv('../../Data/segmentation_annotations_NIFTI.csv').set_index('Patient_ID')

NUM_CLASSES = len(train_set['label'].unique())

In [4]:
# validation set split

train_set, val_set = train_test_split(train_set, train_size = TRAIN_SPLIT, stratify = train_set['label'])

In [5]:
# transforms

transform = v2.Compose([
    tio.ZNormalization()
    # v2.Normalize(mean = [0.5], std = [0.225])
])

upscaler = tio.Resize(SIZE, 'bspline')

In [6]:
affine_args = {
        'scales' : [0.8, 1, 0.8, 1, 0.8, 1],
        'degrees' : 15,
        'translation' : [0.05, 0.05, 0.05],
        'center' : 'image'
}

augment = tio.Compose([
    tio.transforms.RandomFlip(axes = (0, 1, 2)),
    tio.transforms.RandomAffine(
        **affine_args
    )
])

In [7]:
# # Datasets

# train_dataset = MRI_Dataset_within_ROI(DATA_PATH,
#                                        train_set,
#                                        bounding_boxes, 
#                                        transform,
#                                        upscaler,
#                                     #    augment = augment
#                                        )

# val_dataset = MRI_Dataset_within_ROI(DATA_PATH,
#                                        val_set,
#                                        bounding_boxes, 
#                                        transform,
#                                        upscaler)

# test_dataset = MRI_Dataset_within_ROI(DATA_PATH,
#                                        test_set,
#                                        bounding_boxes, 
#                                        transform,
#                                        upscaler)

In [8]:
# Datasets

train_dataset = MRI_Dataset_within_ROI_both_prepost(DATA_PATH,
                                       train_set,
                                       bounding_boxes, 
                                       transform,
                                       upscaler,
                                    #    augment = augment
                                       )

val_dataset = MRI_Dataset_within_ROI_both_prepost(DATA_PATH,
                                       val_set,
                                       bounding_boxes, 
                                       transform,
                                       upscaler)

test_dataset = MRI_Dataset_within_ROI_both_prepost(DATA_PATH,
                                       test_set,
                                       bounding_boxes, 
                                       transform,
                                       upscaler)

In [9]:
## Dataloaders

train_loader = torch.utils.data.DataLoader(train_dataset,
                                           shuffle = True,
                                           batch_size = BATCH_SIZE,
                                           pin_memory = True,
                                           num_workers = 2
                                           )

val_loader = torch.utils.data.DataLoader(val_dataset,
                                           batch_size = BATCH_SIZE,
                                           pin_memory = True,
                                           num_workers = 2
                                           )

test_loader = torch.utils.data.DataLoader(test_dataset,
                                           batch_size = BATCH_SIZE,
                                           pin_memory = True,
                                           num_workers = 2
                                           )

## Model architecture

In [10]:
class ConvNet_MRI3D(torch.nn.Module):
    def __init__(self, in_channels, num_classes):
        super(ConvNet_MRI3D, self).__init__()
        self.conv1 = torch.nn.Conv3d(in_channels, 16, kernel_size = (3,3,3), stride = (1,1,1), padding = 1)
        self.conv2 = torch.nn.Conv3d(16, 16, kernel_size = (3,3,3), stride = (1,1,1), padding = 1)
        self.conv3 = torch.nn.Conv3d(16, 16, kernel_size = (3,3,3), stride = (1,1,1), padding = 1)
        self.conv4 = torch.nn.Conv3d(16, 16, kernel_size = (3,3,3), stride = (1,1,1), padding = 1)
        
        self.conv1x1_1 = torch.nn.Conv3d(16, 16, kernel_size=1, stride=1)
        self.conv1x1_2 = torch.nn.Conv3d(16, 16, kernel_size=1, stride=1)
        self.maxpool1 = torch.nn.MaxPool3d(2)

        self.flatten = torch.nn.Flatten()
        self.fc1 = torch.nn.Linear(3456, num_classes)
    
    def forward(self, inp):
        out = inp
        out = F.relu(self.conv1(out))
        intermediate = self.maxpool1(out)
        
        
        out = F.relu(self.conv2(intermediate))
        out = out + self.conv1x1_1(intermediate)      # residual
        intermediate = self.maxpool1(out)
        
        out = F.relu(self.conv3(intermediate))
        out = out + self.conv1x1_2(intermediate)

        out = self.maxpool1(out)
        out = self.conv4(out)
        out = self.maxpool1(out)
        out = self.flatten(out)
        
        out = self.fc1(out)
        
        return out
        
        

In [11]:
# summary(ConvNet_MRI3D(2, 2), input_size=(20,2,100,100,100), col_names = ['input_size', 'output_size', 'num_params'])

In [19]:
LR = 0.0001
EPOCHS = 10

model = ConvNet_MRI3D(2, NUM_CLASSES).to(DEVICE)
optim = torch.optim.Adam(model.parameters(), lr = LR)

# wandb.login()
# wandb.init(
#     project = PROJECT_NAME,
#     name = CLASSIFICATION,

#     config = {
#         'learning_rate': LR,
#         'architecture': "3D ConvNet",
#         'epochs' : EPOCHS,
#         'batch_size': BATCH_SIZE
#     }
# )


In [20]:
weights = torch.tensor((train_set['label'].value_counts()/len(train_set)).to_numpy())
weights = 1/weights
weights /= weights.min()
weights = weights.float().to(DEVICE)
weights = weights/weights.mean()
weights

tensor([0.5135, 1.4865], device='cuda:0')

In [21]:
# training loop

train_metrics = []

scaler = torch.cuda.amp.grad_scaler.GradScaler()
CrossEntropyLoss = torch.nn.CrossEntropyLoss(weight = weights).to(DEVICE)

acc_train = Accuracy(task = 'multiclass', num_classes = NUM_CLASSES, top_k = 1).to(DEVICE)
acc_val = Accuracy(task = 'multiclass', num_classes = NUM_CLASSES, top_k = 1).to(DEVICE)

recall_train = Recall(task = 'multiclass', num_classes = NUM_CLASSES).to(DEVICE)
recall_val = Recall(task = 'multiclass', num_classes = NUM_CLASSES).to(DEVICE)

precision_train = Precision(task = 'multiclass', num_classes = NUM_CLASSES).to(DEVICE)
precision_val = Precision(task = 'multiclass', num_classes = NUM_CLASSES).to(DEVICE)

f1_train = F1Score(task = 'multiclass', num_classes = NUM_CLASSES).to(DEVICE)
f1_val = F1Score(task = 'multiclass', num_classes = NUM_CLASSES).to(DEVICE)

for epoch in range(1, EPOCHS+1):
    epoch_loss = 0
    samples = 0
    #set model to train
    model.train()
    for data in (pbar:= tqdm(train_loader)):
        #zero optim grad
        optim.zero_grad(set_to_none=True)
        img, labels = data

        # append to sampels
        n_batch = len(img)
        samples+= n_batch

        # to device
        img = img.to(DEVICE)
        labels = labels.to(DEVICE)

        #forward step
        with torch.autocast(device_type = 'cuda', dtype = torch.float16):
            outputs = model(img)
            loss = CrossEntropyLoss(outputs, labels)
        

        # backward step
        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()

        epoch_loss+= (loss.item() * n_batch)
        pbar.set_description(f"CE Loss: {epoch_loss/samples}")

        # accuracy and metrics
        acc_train(outputs, labels)
        precision_train(outputs, labels)
        recall_train(outputs,labels)
        f1_train(outputs, labels)

    # validation
    model.eval()

    val_loss = 0
    val_samples = 0
    for data in (val_loader):
        val_img, val_labels = data

        # append to sampels
        n_batch_val = len(img)
        val_samples+= n_batch_val

        # to device
        val_img = val_img.to(DEVICE)
        val_labels = val_labels.to(DEVICE)


        with torch.no_grad():  
            val_outputs = model(val_img)
            v_loss = CrossEntropyLoss(val_outputs, val_labels)
            val_loss+= (v_loss.item() * n_batch_val)

        # metrics
        acc_val(val_outputs, val_labels)
        precision_val(val_outputs, val_labels)
        recall_val(val_outputs, val_labels)
        f1_val(val_outputs, val_labels)

    logging_dict = {
        'epoch': epoch,
        'train_loss': epoch_loss/samples,
        'val_loss': val_loss/val_samples,
        
        
        'train_acc': acc_train.compute(),
        'train_rec': recall_train.compute(),
        'train_prec': precision_train.compute(),
        'train_f1': f1_train.compute(),
        
        'val_acc_top1': acc_val.compute(),
        'val_rec': recall_val.compute(),
        'val_prec': precision_val.compute(),
        'val_f1' : f1_val.compute()
    }
    
    print('Training Accuracy: {}, Validation Accuracy: {}'.format(logging_dict['train_acc'].item(), logging_dict['val_acc_top1'].item()))
    
    train_metrics.append(logging_dict)
    
    acc_train.reset()
    acc_val.reset()
    recall_train.reset()
    recall_val.reset()
    f1_train.reset()
    f1_val.reset()
    precision_train.reset()
    precision_val.reset()
    torch.save(model.state_dict(), MODEL_SAVE_PATH)
    # wandb.save(MODEL_SAVE_PATH)



CE Loss: 0.40904389260101515: 100%|██████████| 97/97 [02:24<00:00,  1.49s/it]


Training Accuracy: 0.7370600700378418, Validation Accuracy: 0.7469135522842407


CE Loss: 0.3807996062573439: 100%|██████████| 97/97 [02:05<00:00,  1.29s/it] 


Training Accuracy: 0.7432712316513062, Validation Accuracy: 0.7469135522842407


CE Loss: 0.3648752240969281: 100%|██████████| 97/97 [02:05<00:00,  1.30s/it] 


Training Accuracy: 0.7432712316513062, Validation Accuracy: 0.7469135522842407


CE Loss: 0.3709760159562587: 100%|██████████| 97/97 [02:06<00:00,  1.30s/it] 


Training Accuracy: 0.7432712316513062, Validation Accuracy: 0.7469135522842407


CE Loss: 0.3717754186433788: 100%|██████████| 97/97 [02:06<00:00,  1.30s/it] 


Training Accuracy: 0.7432712316513062, Validation Accuracy: 0.7469135522842407


CE Loss: 0.3543183697121484: 100%|██████████| 97/97 [02:05<00:00,  1.29s/it] 


Training Accuracy: 0.7432712316513062, Validation Accuracy: 0.7469135522842407


CE Loss: 0.3567116712345346: 100%|██████████| 97/97 [02:05<00:00,  1.29s/it] 


Training Accuracy: 0.7432712316513062, Validation Accuracy: 0.7469135522842407


CE Loss: 0.35058640602274216: 100%|██████████| 97/97 [02:06<00:00,  1.30s/it]


Training Accuracy: 0.7432712316513062, Validation Accuracy: 0.7469135522842407


CE Loss: 0.3465940893683621: 100%|██████████| 97/97 [02:04<00:00,  1.29s/it] 


Training Accuracy: 0.7432712316513062, Validation Accuracy: 0.7469135522842407


CE Loss: 0.3350946472853607: 100%|██████████| 97/97 [02:04<00:00,  1.29s/it] 


Training Accuracy: 0.7453415989875793, Validation Accuracy: 0.7469135522842407


In [22]:
import pickle as pkl

with open("training_logs.pkl", 'wb') as file:
    pkl.dump(train_metrics, file)

In [23]:
# testing
model.eval()

acc_test = Accuracy(task = 'multiclass', num_classes = NUM_CLASSES, top_k = 1).to(DEVICE)
recall_test = Recall(task = 'multiclass', num_classes = NUM_CLASSES).to(DEVICE)
precision_test = Precision(task = 'multiclass', num_classes = NUM_CLASSES).to(DEVICE)
f1_test = F1Score(task = 'multiclass', num_classes = NUM_CLASSES).to(DEVICE)

test_loss = 0
test_samples = 0
for data in (pbar:= tqdm(test_loader)):
    test_img, test_labels = data

    # append to sampels
    n_batch_test = len(img)
    test_samples+= n_batch_test

    # to device
    test_img = test_img.to(DEVICE)
    test_labels = test_labels.to(DEVICE)


    with torch.no_grad():  
        test_outputs = model(test_img)
        t_loss = CrossEntropyLoss(test_outputs, test_labels)
        test_loss+= (t_loss.item() * n_batch_test)

    # metrics
    acc_test(test_outputs, test_labels)
    precision_test(test_outputs, test_labels)
    recall_test(test_outputs, test_labels)
    f1_test(test_outputs, test_labels)
test_metrics = {
    'test_acc': acc_test.compute(),
    'test_rec': recall_test.compute(),
    'test_prec': precision_test.compute(),
    'test_f1': f1_test.compute(),
}
# wandb.log()

100%|██████████| 56/56 [01:15<00:00,  1.35s/it]


In [24]:
import pickle as pkl

with open("testing_logs.pkl", 'wb') as file:
    pkl.dump(test_metrics, file)

## Plots

In [25]:
train_d = pd.DataFrame.from_dict(train_metrics)
train_d

Unnamed: 0,epoch,train_loss,val_loss,train_acc,train_rec,train_prec,train_f1,val_acc_top1,val_rec,val_prec,val_f1
0,1,0.409044,0.383698,"tensor(0.7371, device='cuda:0')","tensor(0.7371, device='cuda:0')","tensor(0.7371, device='cuda:0')","tensor(0.7371, device='cuda:0')","tensor(0.7469, device='cuda:0')","tensor(0.7469, device='cuda:0')","tensor(0.7469, device='cuda:0')","tensor(0.7469, device='cuda:0')"
1,2,0.3808,0.379476,"tensor(0.7433, device='cuda:0')","tensor(0.7433, device='cuda:0')","tensor(0.7433, device='cuda:0')","tensor(0.7433, device='cuda:0')","tensor(0.7469, device='cuda:0')","tensor(0.7469, device='cuda:0')","tensor(0.7469, device='cuda:0')","tensor(0.7469, device='cuda:0')"
2,3,0.364875,0.413584,"tensor(0.7433, device='cuda:0')","tensor(0.7433, device='cuda:0')","tensor(0.7433, device='cuda:0')","tensor(0.7433, device='cuda:0')","tensor(0.7469, device='cuda:0')","tensor(0.7469, device='cuda:0')","tensor(0.7469, device='cuda:0')","tensor(0.7469, device='cuda:0')"
3,4,0.370976,0.406642,"tensor(0.7433, device='cuda:0')","tensor(0.7433, device='cuda:0')","tensor(0.7433, device='cuda:0')","tensor(0.7433, device='cuda:0')","tensor(0.7469, device='cuda:0')","tensor(0.7469, device='cuda:0')","tensor(0.7469, device='cuda:0')","tensor(0.7469, device='cuda:0')"
4,5,0.371775,0.398275,"tensor(0.7433, device='cuda:0')","tensor(0.7433, device='cuda:0')","tensor(0.7433, device='cuda:0')","tensor(0.7433, device='cuda:0')","tensor(0.7469, device='cuda:0')","tensor(0.7469, device='cuda:0')","tensor(0.7469, device='cuda:0')","tensor(0.7469, device='cuda:0')"
5,6,0.354318,0.385302,"tensor(0.7433, device='cuda:0')","tensor(0.7433, device='cuda:0')","tensor(0.7433, device='cuda:0')","tensor(0.7433, device='cuda:0')","tensor(0.7469, device='cuda:0')","tensor(0.7469, device='cuda:0')","tensor(0.7469, device='cuda:0')","tensor(0.7469, device='cuda:0')"
6,7,0.356712,0.380451,"tensor(0.7433, device='cuda:0')","tensor(0.7433, device='cuda:0')","tensor(0.7433, device='cuda:0')","tensor(0.7433, device='cuda:0')","tensor(0.7469, device='cuda:0')","tensor(0.7469, device='cuda:0')","tensor(0.7469, device='cuda:0')","tensor(0.7469, device='cuda:0')"
7,8,0.350586,0.373581,"tensor(0.7433, device='cuda:0')","tensor(0.7433, device='cuda:0')","tensor(0.7433, device='cuda:0')","tensor(0.7433, device='cuda:0')","tensor(0.7469, device='cuda:0')","tensor(0.7469, device='cuda:0')","tensor(0.7469, device='cuda:0')","tensor(0.7469, device='cuda:0')"
8,9,0.346594,0.40912,"tensor(0.7433, device='cuda:0')","tensor(0.7433, device='cuda:0')","tensor(0.7433, device='cuda:0')","tensor(0.7433, device='cuda:0')","tensor(0.7469, device='cuda:0')","tensor(0.7469, device='cuda:0')","tensor(0.7469, device='cuda:0')","tensor(0.7469, device='cuda:0')"
9,10,0.335095,0.373481,"tensor(0.7453, device='cuda:0')","tensor(0.7453, device='cuda:0')","tensor(0.7453, device='cuda:0')","tensor(0.7453, device='cuda:0')","tensor(0.7469, device='cuda:0')","tensor(0.7469, device='cuda:0')","tensor(0.7469, device='cuda:0')","tensor(0.7469, device='cuda:0')"


In [26]:
test_metrics

{'test_acc': tensor(0.7437, device='cuda:0'),
 'test_rec': tensor(0.7437, device='cuda:0'),
 'test_prec': tensor(0.7437, device='cuda:0'),
 'test_f1': tensor(0.7437, device='cuda:0')}

In [27]:
wandb.finish()