In [None]:
import pandas as pd
import numpy as np
import os

from tqdm.notebook import tqdm

import nibabel as nib
import torchio as tio

import torch
import torch.nn as nn
from fastai.vision.all import *

import import_ipynb
import Utils as u

In [None]:
import warnings
# Suppress every warning of category UserWarning for the rest of the session.
warnings.filterwarnings('ignore', category=UserWarning)

### Data Preparation

##### Load df

In [None]:
# Load the master table. Later cells read the columns 'train_val_test',
# 'tmax_paths', 'target_bin', 'NIH on admission' and 'AccessNumber' from it.
# NOTE(review): the path is absolute (filesystem root) — confirm it matches the runtime environment.
df = pd.read_excel('/df_tr_val_test.xlsx')

##### Split training/validation and test set

In [None]:
# Partition the master table by its 'train_val_test' label column.
# Every split gets a fresh 0..n-1 index: the subject-creation cells look rows up
# by label (`df_x[col][i]` for i in range(len(df_x))), which raises KeyError when
# a split still carries the row labels of the full table.
df_train = df[df['train_val_test']=='train'].reset_index(drop=True)

df_valid = df[df['train_val_test']=='valid'].reset_index(drop=True)

df_test = df[df['train_val_test']=='test'].reset_index(drop=True)

### Create Subjects

In [None]:
# Build one torchio Subject per training row: the image at 'tmax_paths' plus the
# 'target_bin', 'NIH on admission' and 'AccessNumber' values as extra fields.
# Rows are read by position (.iloc) so this works whether or not df_train still
# carries the row labels of the full table; label lookup `col[i]` would raise
# KeyError for any i that is not an existing label.
subjects_train = [tio.Subject(perf1=tio.ScalarImage(df_train['tmax_paths'].iloc[i]),
                              target=df_train['target_bin'].iloc[i],
                              nihss=df_train['NIH on admission'].iloc[i],
                              acc=df_train['AccessNumber'].iloc[i])
                  for i in tqdm(range(len(df_train)))]

In [None]:
# Wrap every validation row in a torchio Subject (image + target + NIHSS + accession number).
valid_paths = df_valid['tmax_paths']
valid_targets = df_valid['target_bin']
valid_nihss = df_valid['NIH on admission']
valid_acc = df_valid['AccessNumber']

subjects_valid = []
for idx in tqdm(range(len(df_valid))):
    subjects_valid.append(
        tio.Subject(
            perf1=tio.ScalarImage(valid_paths[idx]),
            target=valid_targets[idx],
            nihss=valid_nihss[idx],
            acc=valid_acc[idx],
        )
    )

In [None]:
def _test_subject(pos):
    """Assemble the torchio Subject for the df_test row labelled `pos`."""
    return tio.Subject(
        perf1=tio.ScalarImage(df_test['tmax_paths'][pos]),
        target=df_test['target_bin'][pos],
        nihss=df_test['NIH on admission'][pos],
        acc=df_test['AccessNumber'][pos],
    )


# One Subject per test row, in table order.
subjects_test = [_test_subject(pos) for pos in tqdm(range(len(df_test)))]

### Preprocess (perform all static preprocessing steps before training)

##### Define target size:

In [None]:
target_size = (70,84,18)

##### Preprocess images

In [None]:
# Run the static preprocessing (u.preprocess) on every training subject, replacing
# each list entry with the result, and collect the indices for which it raised.
# A subject whose preprocessing fails stays in subjects_train unprocessed.
train_fail = []
for i in tqdm(range(len(subjects_train)), desc='Preprocess training subjects'):
    try:
        subjects_train[i] = u.preprocess(subjects_train[i], target_size)
    except Exception as err:
        # `except Exception` instead of a bare `except:` so that a kernel
        # interrupt (KeyboardInterrupt) still stops the loop; the reason is
        # reported instead of being discarded.
        train_fail.append(i)
        print(f"Preprocessing failed for training subject {i}: {type(err).__name__}: {err}")

if not train_fail:
    print("All preprocessed successfully")
else:
    print(f"Index of failed preprocessing in training subjects: {train_fail}")

In [None]:
# Run the static preprocessing (u.preprocess) on every validation subject, replacing
# each list entry with the result, and collect the indices for which it raised.
# A subject whose preprocessing fails stays in subjects_valid unprocessed.
valid_fail = []
# Progress label and failure message now say "validation"; they were copy-pasted
# from the training cell and wrongly said "training subjects".
for i in tqdm(range(len(subjects_valid)), desc='Preprocess validation subjects'):
    try:
        subjects_valid[i] = u.preprocess(subjects_valid[i], target_size)
    except Exception as err:
        # `except Exception` instead of a bare `except:` so that a kernel
        # interrupt (KeyboardInterrupt) still stops the loop.
        valid_fail.append(i)
        print(f"Preprocessing failed for validation subject {i}: {type(err).__name__}: {err}")

if not valid_fail:
    print("All preprocessed successfully")
else:
    print(f"Index of failed preprocessing in validation subjects: {valid_fail}")

In [None]:
# Run the static preprocessing (u.preprocess) on every test subject, replacing
# each list entry with the result, and collect the indices for which it raised.
# A subject whose preprocessing fails stays in subjects_test unprocessed.
test_fail = []
for i in tqdm(range(len(subjects_test)), desc='Preprocess test subjects'):
    try:
        subjects_test[i] = u.preprocess(subjects_test[i], target_size)
    except Exception as err:
        # `except Exception` instead of a bare `except:` so that a kernel
        # interrupt (KeyboardInterrupt) still stops the loop.
        test_fail.append(i)
        print(f"Preprocessing failed for test subject {i}: {type(err).__name__}: {err}")

if not test_fail:
    print("All preprocessed successfully")
else:
    print(f"Index of failed preprocessing in test subjects: {test_fail}")

##### Visual checks

In [None]:
# Sanity check: display one randomly chosen training subject after preprocessing.
# randrange excludes the upper bound; random.randint(0, n) is inclusive and could
# return n, indexing one past the end of the list.
s = random.randrange(len(subjects_train))

print(f'Subject: {s}')
perfs = ['perf1']
titles = ['TMAX']

# Outer quotes are double quotes: reusing the same quote character inside the
# braces of an f-string is a SyntaxError before Python 3.12.
print(f"Shape: {subjects_train[s]['perf1'].data.shape}")

for perf, title in zip(perfs, titles):
    # Channel 0, slice 7 along the last axis.
    plt.imshow(subjects_train[s][perf][tio.DATA][0,:,:,7])
    plt.title(title)
    plt.show()

### Transforms

##### Define Transforms

In [None]:
# Spatial extent of the first training image; entry 0 of .shape is skipped.
x, y, z = subjects_train[0]['perf1'].shape[1:4]

# RandomSwap exchanges patches that span about one tenth of the image per axis.
swap_patch = (round(x/10), round(y/10), round(z/10))

# Random augmentations, applied to training images only.
train_augmentations = [
    tio.RandomBlur(std=1, p=0.5),
    tio.RandomNoise(mean=0, std=(0, 0.05), p=0.3),
    tio.RandomGhosting(p=0.3),
    tio.RandomSwap(patch_size=swap_patch, num_iterations=20, p=0.5),
]
train_tf = tio.Compose(train_augmentations)

# Validation and test images go through an empty pipeline.
valid_tf = tio.Compose([])
test_tf = tio.Compose([])

##### Visual check

In [None]:
def _show_slice(volume, title, shape_label, sl=10):
    """Plot slice `sl` (last axis) of channel 0 of `volume` and print its spatial shape."""
    plt.imshow(volume[0,:,:,sl])
    plt.title(title)
    plt.show()
    print(f'{shape_label}: {volume.shape[1]}x{volume.shape[2]}x{volume.shape[3]}')


# Pick a random training image. The bound comes from the list itself; the former
# hard-coded random.randint(0, 789) raised IndexError for a smaller training set
# and never sampled the tail of a larger one.
img = subjects_train[random.randrange(len(subjects_train))]['perf1'][tio.DATA]
_show_slice(img, 'Original example', 'Original shape')

# Same image after the training augmentations.
img_t = train_tf(img)
_show_slice(img_t, 'Training example', 'Shape after train_transforms')

# Same image after the (empty) validation pipeline.
img_v = valid_tf(img)
_show_slice(img_v, 'Validation example', 'Shape after valid_transforms')

### Dataset/Dataloader

In [None]:
# Batch sizes per split.
train_bn = 2
valid_bn = 2
test_bn = 2

# Training / validation loaders.
dls = u.make_dls(subjects_train, subjects_valid, train_tf, valid_tf, train_bn=train_bn, valid_bn=valid_bn)
print(f'Training set: n={len(dls.train.dataset)} ({len(dls.train.dataset)/len(df):.1%})')
print(f'Validation set: n={len(dls.valid.dataset)} ({len(dls.valid.dataset)/len(df):.1%})')

# Test loaders: the test subjects fill both slots and the `.valid` slot is the one
# reported below. The batch size now comes from test_bn, which was defined but
# never used (valid_bn was passed instead).
dls_test = u.make_dls(subjects_test, subjects_test, test_tf, valid_tf, train_bn=test_bn, valid_bn=test_bn)
print(f'Test set: n={len(dls_test.valid.dataset)} ({len(dls_test.valid.dataset)/len(df):.1%})')

### Load parameters

In [None]:
# Two output classes; the input channel count is read off the first training sample's 'image' entry.
num_classes = 2
in_channels = dls.train.dataset[0]['image'].shape[0]

# 3D ResNet-18 from the project Utils module.
model_resnet = u.ResNet3D18(in_channels=in_channels, num_classes=num_classes)

# Optimisation settings.
lr = 5e-6
opt_func = Adam
loss_func = CrossEntropyLossFlat()

# Callbacks handed to the learner.
onecyc = u.OneCycle(lr)
cbs = [onecyc]

### Set up Learners and fit

#### Fit Learner: Image only

In [None]:
# Image-only learner built from the 3D ResNet, the train/validation loaders, the loss, the learning rate and the callback list.
# NOTE(review): `opt_func` is assigned in the parameters cell but not passed here — confirm u.My_Learner_img uses the intended optimizer.
learner_img = u.My_Learner_img(model_resnet, dls, loss_func=loss_func, lr=lr, cbs=cbs)

In [None]:
# Train the learner; presumably 10 is the number of epochs, as in fastai's Learner.fit — TODO confirm for u.My_Learner_img.
learner_img.fit(10)

In [None]:
# Evaluate the trained learner; presumably on the validation loader of `dls` — TODO confirm for u.My_Learner_img.
learner_img.validate()

In [None]:
# torch.save(learner_img.model.state_dict(), '/media/user/Elements/BENIGN_Results_MRP_nih_dich/MRP_25_04_24_kai/weights_resnet18.pth')

### Interpret learners

In [None]:
# Display labels for the two classes, built from the NIHSS dichotomisation cutoff.
# NOTE(review): `cutoff` is not assigned anywhere in the visible notebook; on a fresh
# kernel this line raises NameError unless it is defined beforehand — confirm its
# value and define it explicitly (e.g. in a configuration cell).
class_names = [f"NIHSS 0-{cutoff}", f"NIHSS >{cutoff}"]
class_names

In [None]:
# learner_img.model.load_state_dict(torch.load('/media/user/Elements/BENIGN_Results_MRP_nih_dich/MRP_25_04_24_kai/weights_resnet18.pth'))

In [None]:
# Interpretation of the image-only learner for the validation set.
# Presumably evaluates on whatever loaders learner_img.dls holds when this runs — TODO confirm in Utils.
val_interp_args = dict(
    c=in_channels,
    class_names=class_names,
    use_tabular=False,
    download=False,
    download_path='/media/user/Elements/combined_plot_val.png',
    dpi=300,
    title='Validation set',
)
val_results = u.Interp_from_learner(learner_img, **val_interp_args)
val_plot, val_cm_disp, val_y_true, val_y_pred_proba = val_results

### TEST

In [None]:
# Replace the learner's loaders with the test loaders. This mutates learner_img:
# any later cell that reads learner_img.dls sees dls_test, so re-running the
# validation cells after this one would no longer use the validation data.
learner_img.dls = dls_test

In [None]:
# Interpretation of the image-only learner for the test set.
# Presumably evaluates on whatever loaders learner_img.dls holds when this runs — TODO confirm in Utils.
tst_interp_args = {
    'c': in_channels,
    'class_names': class_names,
    'use_tabular': False,
    'download': False,
    'download_path': '/media/user/Elements/BENIGN_Results_MRP_nih_dich/MRP_25_04_24_kai/plot_test_voc_only.tiff',
    'dpi': 900,
    'title': 'Test set',
}
tst_results = u.Interp_from_learner(learner_img, **tst_interp_args)
tst_plot, tst_cm_disp, tst_y_true, tst_y_pred_proba = tst_results

In [None]:
# Combine three result panels side by side in one figure (1 row x 3 columns):
# validation set, test set, and test set restricted to vessel occlusions.
# NOTE(review): tst_cm_disp_voc, tst_y_true_voc and tst_y_pred_proba_voc are not
# assigned anywhere in the visible notebook; on a fresh kernel the third call
# below raises NameError — confirm where they are meant to be computed.
from matplotlib import gridspec
import matplotlib.lines as mlines

fig = plt.figure(figsize=(20, 15))
gs = gridspec.GridSpec(1, 3, figure=fig)
plt.axis('off')

# Two dotted vertical separators between the panels, at x = 0.325 and x = 0.66
# in figure coordinates (transform=fig.transFigure).
line = mlines.Line2D([0.325, 0.325], [0.1, 0.91], color='grey', linestyle=':', linewidth=2, transform=fig.transFigure)
fig.add_artist(line)
line = mlines.Line2D([0.66, 0.66], [0.1, 0.91], color='grey', linestyle=':', linewidth=2, transform=fig.transFigure)
fig.add_artist(line)

# One subfigure per grid column; each is handed to the plotting helper via `fig=`.
subfig1 = fig.add_subfigure(gs[0, 0])
subfig2 = fig.add_subfigure(gs[0, 1])
subfig3 = fig.add_subfigure(gs[0, 2])

# download=False in all three calls; the download_path / dpi arguments are passed but presumably unused then — TODO confirm in Utils.
u.create_combined_plot_roc_prc(val_cm_disp, val_y_true, val_y_pred_proba, class_names, n_classes=2, download=False, 
                     download_path='combined_plot_test.tiff', dpi=900,
                     title='Validation set', fig=subfig1)
u.create_combined_plot_roc_prc(tst_cm_disp, tst_y_true, tst_y_pred_proba, class_names, n_classes=2, download=False, 
                     download_path='combined_plot_test.tiff', dpi=900, 
                     title='Test set', fig=subfig2)
u.create_combined_plot_roc_prc(tst_cm_disp_voc, tst_y_true_voc, tst_y_pred_proba_voc, class_names, n_classes=2, download=False, 
                     download_path='combined_plot_test.tiff', dpi=900, 
                     title='Test set (only vessel occlusions)', fig=subfig3)

plt.tight_layout()
plt.subplots_adjust(hspace=0.2)
# plt.savefig('combined_all_plot_test.tiff', dpi=900)
# plt.savefig('combined_all_plot_test.png', dpi=300)
plt.show()

#### Analysis of wrongly classified instances

In [None]:
# Inspect misclassified cases for the image-only learner (no tabular input).
# NOTE(review): learner_img.dls was set to dls_test in an earlier cell, so this presumably reports test-set cases — confirm in Utils.
u.Wrong_instances(learner_img, c=in_channels, use_tabular=False)

##### GradCAM

In [None]:
# Grad-CAM visualisation for a single test image, using the last block of the
# ResNet's layer4 as the target layer.
import matplotlib.colors as mcolors

model = learner_img.model
model.eval()

target_layer = model.resnet.layer4[-1]  # Last layer of ResNet3D18

# Index of the test sample to explain. This cell used to read a leftover global
# `i` (whatever loop last assigned it); on a fresh Restart & Run All that is the
# last test subject, so that is the explicit default. Change it to inspect
# another case.
sample_idx = len(dls_test.valid.dataset) - 1
input_image = dls_test.valid.dataset[sample_idx]['image'].cuda()

# Colour scale for the overlay: values are clipped to the range [3.5, 7].
norm = mcolors.Normalize(vmin=3.5, vmax=7, clip=True)

grad_cam = u.GradCAM(model, target_layer)
grad_cam.plot_cam(input_image, tabular_data=None, norm=norm,
                  download=True,
                  target_class=1,
                  sl=9,
                  alpha=0.6,
                 )