# 0 Initialise

### Import Packages

In [8]:
# Import packages
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import torch

from fastai.vision import *         # script requires fastai version 1.0.61
from fastai.callbacks import *
from image_tabular.core import *
from image_tabular.metric import *

from src.cnn_functions import SplitData
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix

In [3]:
# Hide UserWarning messages that originate from torch.nn.functional
warnings.filterwarnings(
    'ignore',
    category=UserWarning,
    module='torch.nn.functional',
)

# Run on the first CUDA GPU when one is available; otherwise stay on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

### Paths

In [3]:
# Pickled feature table that is read with pandas in the preprocessing section
data_path = '../Cardiomegaly Classification/MIMIC_features/MIMIC_features.pkl'

# Base directory handed to fastai when it resolves the image file paths
image_path = '../MIMIC/'

# Model artefacts: output folder, and the file name used when exporting the learner
model_folder = '../Cardiomegaly Classification/models/cnn'
model_storage = 'Image_CNN.pkl'

### Model Parameters

In [15]:
# --- Data parameters ---
DataSplits = [0.8, 0.1, 0.1]        # Splits of data passed to SplitData (train, validation, test)
norm_pixel = ([0.4712], [0.3030])   # Data normalisation: pixel statistics given to .normalize() (presumably mean, std)
size = 244                          # Data normalisation: image size given to .transform()
                                    # NOTE(review): 224 is the conventional ResNet input size - confirm 244 is intended
max_rot = 10                        # Data augmentation: maximum rotation
Vflip = True                        # Data augmentation: vertical flips (True/False)
Hflip = True                        # Data augmentation: horizontal flips (True/False)

# --- CNN parameters ---
bs = 64                             # batch size
epochs = 15                         # number of training epochs
lr = 1e-2                           # learning rate

# 1 Data Preprocessing

In [16]:
# Load the pickled feature table and rename the label column to 'class',
# which is the column the fastai labelling step reads later on.
# NOTE: unpickling can execute arbitrary code - only load files from a trusted source.
data = pd.read_pickle(data_path).rename(columns={'Cardiomegaly': 'class'})

# Partition the table into train / validation / test subsets according to DataSplits
train_df, val_df, test_df = SplitData(data, DataSplits)

In [8]:
# Stack train+validation and train+test into single frames for the fastai data block API.
# reset_index(drop=True) gives each combined frame a fresh 0..n-1 index, so the training rows
# occupy positions [0, len(train_df)) and the held-out rows follow directly after them.
train_val_df = pd.concat([train_df, val_df]).reset_index(drop=True)
train_test_df = pd.concat([train_df, test_df]).reset_index(drop=True)

# Positional indices of the held-out rows inside the combined frames.
# They are derived from the frame lengths rather than from the index labels of val_df/test_df,
# so they stay correct even when SplitData returns frames that keep their original
# (shuffled, non 0-based) index. For frames with a default 0-based index the result is unchanged.
n_train = len(train_df)
val_idx = np.arange(n_train, n_train + len(val_df))
test_idx = np.arange(n_train, n_train + len(test_df))

In [7]:
# Augmentation settings for the images, driven by the data parameters defined earlier.
# NOTE(review): per the fastai v1 docs, flip_vert=True combined with do_flip=True allows
# vertical flips and 90-degree rotations rather than horizontal flips only - confirm this is wanted.
tfms = get_transforms(
    do_flip=Hflip,
    flip_vert=Vflip,
    max_rotate=max_rot,
)

In [None]:
# Training DataBunch: rows of train_val_df with the val_idx positions held out for validation.
# Image files come from the 'path' column (resolved against image_path), labels from 'class';
# tfms and the target size are applied, then the data is batched and normalised with norm_pixel.
train_image_data = (
    ImageList.from_df(train_val_df, path=image_path, cols='path')
    .split_by_idx(val_idx)
    .label_from_df(cols='class')
    .transform(tfms, size=size)
    .databunch(bs=bs)
    .normalize(norm_pixel)
)

# Test DataBunch: the same pipeline on train_test_df, except that no augmentation transforms
# are passed (size only) and the held-out split is the test rows (test_idx).
test_image_data = (
    ImageList.from_df(train_test_df, path=image_path, cols='path')
    .split_by_idx(test_idx)
    .label_from_df(cols='class')
    .transform(size=size)
    .databunch(bs=bs)
    .normalize(norm_pixel)
)



# 2 Model Training

### Model Definition

In [10]:
# Re-weight the loss because the classes are heavily imbalanced.
# The mean of the 'class' column is the fraction of positive rows (assumes 0/1 labels - TODO confirm),
# so class 0 is weighted by 1/(1 - p) and class 1 by 1/p.
pos_frac = train_df['class'].mean()
weights = [1/(1-pos_frac), 1/pos_frac]

# Weighted cross-entropy, with the weight tensor moved onto `device`
class_weights = torch.FloatTensor(weights).to(device)
loss_func = CrossEntropyFlat(weight=class_weights)

In [None]:
# Package everything in a fastai learner: ResNet-50 backbone with a custom head
# (lin_ftrs / ps), the weighted loss, and accuracy plus ROC AUC as metrics.
#
# `path` is the learner's working directory: checkpoints written by callbacks and the file
# produced by learn.export() are placed relative to it. It must therefore be the model
# folder; passing the export file name here would create a stray directory called
# 'Image_CNN.pkl' in the current working directory and store the artefacts inside it.
learn = cnn_learner(
    train_image_data,
    models.resnet50,
    lin_ftrs=[512, 256, 32],
    ps=0.2,
    metrics=[accuracy, ROCAUC()],
    loss_func=loss_func,
    path=model_folder,
)

### Model Training

In [None]:
# Run the learning-rate range test and keep the resulting loss-vs-learning-rate figure
learn.lr_find()
x_unfrozenplot = learn.recorder.plot(return_fig=True)

# Save the figure inside model_folder. Plain string concatenation drops the path separator
# and would write '<...>/models/cnnlearning_rate_fig.jpg' next to the folder instead of into it.
# The folder is created first so that savefig cannot fail on a missing directory.
lr_fig_dir = Path(model_folder)
lr_fig_dir.mkdir(parents=True, exist_ok=True)
x_unfrozenplot.savefig(str(lr_fig_dir / 'learning_rate_fig.jpg'))

In [None]:
# Checkpoint callback that tracks validation loss and treats lower values as better
best_model_cb = SaveModelCallback(learn, monitor='valid_loss', mode='min')

# One-cycle training for `epochs` epochs at learning rate `lr`
learn.fit_one_cycle(epochs, lr, callbacks=[best_model_cb])

# Export the trained learner under the configured file name
learn.export(model_storage)

# 3 Testing

In [None]:
# Swap in the DataBunch whose held-out split is the test set
learn.data = test_image_data

# Evaluate on that split
# NOTE(review): the value returned by validate() is neither stored nor displayed
learn.validate()

# Model outputs and targets for the held-out split; the predicted class is the
# index of the largest output along axis 1
preds, targets = learn.get_preds()
class_preds = np.argmax(preds, axis=1)

In [None]:
# Accuracy on the test predictions
Accuracy = accuracy_score(targets, class_preds)
print(f'Accuracy = {Accuracy!s}')

# F1 score on the test predictions
f1Score = f1_score(targets, class_preds)
print(f'F1 Score = {f1Score!s}')

# Confusion matrix (rows: true class, columns: predicted class)
CF = confusion_matrix(targets, class_preds)
print(CF)