### Source: https://github.com/AlbertoUAH/Knee-Lesions-Classification-via-Deep-Learning (2022)


In [None]:
# Libraries
from matplotlib import pyplot as plt
from matplotlib.animation import PillowWriter
from torch.utils import data
from sklearn import metrics
from tqdm import tqdm
import torchvision.models as models
import torchvision.transforms as T
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import matplotlib.animation as animation
import imgaug.augmenters as iaa
import torch.nn as nn
import pandas as pd
import numpy as np
import torch
import cv2
import os
from   google.colab import drive
# CUDA Device setup
# Expose only the first GPU to this process. CUDA reads this variable when it
# initialises, so it presumably has to be set before any CUDA call is made.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# TF1-style session whose GPU allocator is capped at half of the device memory.
# NOTE(review): the model below is PyTorch and TensorFlow is not otherwise used in
# the visible code; confirm this session is still needed, since it may reserve GPU
# memory that PyTorch could use instead.
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.5)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
# Mount Google Drive; the dataset paths used by this notebook live under /content/drive.
drive.mount('/content/drive')
# Constants -- dataset locations on the mounted Google Drive.
MRNET_PATH = '/content/drive/MyDrive/MRNet-v1.0/'
TRAIN_PATH = '/content/drive/MyDrive/MRNet-v1.0/train/'
VAL_PATH = '/content/drive/MyDrive/MRNet-v1.0/valid/'

# Training hyper-parameters.
BATCH_SIZE = 1          # the model and train/val loops only handle one exam per batch
EPOCHS = 50
PATIENT = 10            # epochs without val-loss improvement before early stopping
LOSS_IMPROVE = 1e-4     # minimum val-loss decrease that counts as an improvement
RANDOM_STATE = 1234     # seed for the NumPy and PyTorch RNGs
MAX_PIXEL_VALUE = 255   # divisor applied to raw slice intensities (assumes 8-bit data)

# Specify seeds for reproducibility. NOTE(review): imgaug keeps its own RNG, which is
# not seeded here, so augmentations are not covered by these seeds.
np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)
torch.cuda.manual_seed(RANDOM_STATE)
torch.backends.cudnn.deterministic = True

# Load Stanford MRI Dataset: one header-less CSV per (split, task), whose column 0 is
# the exam id and column 1 that task's label.
train_df_abnormal, train_df_acl, train_df_meniscus = (
    pd.read_csv(MRNET_PATH + csv_name, header=None)
    for csv_name in ('train-abnormal.csv', 'train-acl.csv', 'train-meniscus.csv')
)
valid_df_abnormal, valid_df_acl, valid_df_meniscus = (
    pd.read_csv(MRNET_PATH + csv_name, header=None)
    for csv_name in ('valid-abnormal.csv', 'valid-acl.csv', 'valid-meniscus.csv')
)

# One row per exam: id plus the three task labels side by side.
train_df = pd.concat(
    [train_df_abnormal, train_df_acl[1], train_df_meniscus[1]], axis=1
).drop_duplicates()
valid_df = pd.concat(
    [valid_df_abnormal, valid_df_acl[1], valid_df_meniscus[1]], axis=1
).drop_duplicates()
train_df.columns = valid_df.columns = ['Image', 'Abnormal', 'ACL', 'Meniscus']

# Training rows first, then validation rows; MRDataset splits this file by row position.
(
    pd.concat([train_df, valid_df], axis=0)
    .reset_index(drop=True)
    .to_csv(MRNET_PATH + '/knee_metadata.csv')
)

# Define MRDataset class
class MRDataset(data.Dataset):
    """MRNet exams as three per-plane slice stacks plus an [Abnormal, ACL, Meniscus] label.

    Records are read from ``knee_metadata.csv``, which lists training rows first and
    validation rows after them. The first ``train_index_limit`` rows form the training
    split and the remaining rows the validation split.

    Each item is ``([axial, sagittal, coronal], label)``:

    * Every plane is a float tensor of shape ``(slices, 3, H, W)``. The grey channel is
      repeated three times, and intensities are divided by ``MAX_PIXEL_VALUE``. This
      assumes each ``.npy`` volume is stored as ``(slices, H, W)``.
    * ``label`` is a float tensor with the 3 task labels.

    Args:
        transform: if true, augment every plane with a random horizontal flip plus a
            translate/scale/rotate affine. The parameters are drawn once per plane and
            applied identically to all of that plane's slices.
        train: read from ``TRAIN_PATH`` (training split) instead of ``VAL_PATH``.
        train_index_limit: number of leading metadata rows belonging to training.

    Attributes:
        weights: per-task ``negatives / positives`` ratios of this split's labels,
            used as ``pos_weight`` for ``BCEWithLogitsLoss``.
    """

    def __init__(self, transform=False, train=True, train_index_limit=1130):
        super().__init__()
        self.transform = transform
        self.train = train
        self.records = pd.read_csv(MRNET_PATH + '/knee_metadata.csv')
        self.train_index_limit = train_index_limit
        self.planes = ['axial', 'sagittal', 'coronal']

        root = TRAIN_PATH if self.train else VAL_PATH
        self.image_path = {plane: root + '/{0}/'.format(plane) for plane in self.planes}
        if self.train:
            split = self.records.iloc[0:self.train_index_limit, :]
        else:
            split = self.records.iloc[self.train_index_limit:, :]
        # Work on an explicit copy: the iloc slice may be a view of the full table, and
        # assigning the 'Image' column on it triggers pandas' SettingWithCopy warning.
        self.records = split.copy()

        # read_csv may parse ids as integers (dropping leading zeros), while the .npy
        # files are named with a 4-digit zero-padded id.
        self.records['Image'] = self.records['Image'].map(lambda i: str(i).zfill(4))
        filenames = self.records['Image'].tolist()
        self.paths = {
            plane: [self.image_path[plane] + filename + '.npy' for filename in filenames]
            for plane in self.planes
        }

        self.labels = self.records[['Abnormal', 'ACL', 'Meniscus']].values
        # pos_weight = negatives / positives per task, to counter class imbalance.
        weights_ = []
        for disease in range(3):
            pos = sum(self.labels[:, disease])
            neg = len(self.labels[:, disease]) - pos
            weights_.append(neg / pos)
        self.weights = torch.FloatTensor(weights_)

        # Built once here rather than on every __getitem__ call.
        self.augmenter = iaa.Sequential([
            iaa.Fliplr(0.5),
            iaa.Affine(
                translate_percent={"x": (-0.11, 0.11), "y": (-0.11, 0.11)},
                scale={"x": (1, 1.2), "y": (1, 1.2)},
                rotate=(-10, 10)
            )
        ])

    def __len__(self):
        return len(self.records)

    def _augment_volume(self, volume):
        """Apply ONE randomly drawn flip/affine to every slice of ``volume``.

        Passing the stack as ``images=`` would augment each slice independently, so
        slices of a single series could end up flipped or shifted differently. A
        deterministic copy of the augmenter replays the same sampled parameters on
        every call, which keeps the volume geometrically consistent.
        """
        frozen = self.augmenter.to_deterministic()
        return np.stack([frozen(image=slice_) for slice_ in volume])

    def __getitem__(self, index):
        planes = []
        for plane in self.planes:
            volume = np.load(self.paths[plane][index]) / MAX_PIXEL_VALUE
            if self.transform:
                volume = self._augment_volume(volume)
            # Repeat the grey channel 3x along a new axis 1 to match AlexNet's RGB input.
            planes.append(torch.FloatTensor(np.stack((volume,) * 3, axis=1)))

        label = torch.FloatTensor(self.labels[index])
        return planes, label

# Get dataset: training exams are augmented, validation exams are not.
train_dataset, val_dataset = MRDataset(transform=True), MRDataset(train=False)

# BATCH_SIZE exams per batch; only the training order is shuffled.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=False,
    num_workers=2,
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False,
    num_workers=2,
)

# Build Conv2D model
class CNNModel(nn.Module):
    """Three AlexNet conv trunks (axial, sagittal, coronal) fused into 3 logits.

    Each branch does the following:

    1. Runs one exam's slices through AlexNet's features up to (and excluding) the last
       max-pool.
    2. Records the gradient of those activations for Grad-CAM.
    3. Applies the max-pool and a global average pool.
    4. Takes the element-wise max over the slice axis.

    The three 256-d branch vectors are concatenated and mapped to 3 logits by ``fc``.
    """

    def __init__(self):
        super(CNNModel, self).__init__()
        self.axial = models.alexnet(pretrained=True, progress=False).features
        self.sagittal = models.alexnet(pretrained=True, progress=False).features
        self.coronal = models.alexnet(pretrained=True, progress=False).features
        # Layers [:12] stop before AlexNet's final MaxPool2d, so the hooked activations
        # are the last ReLU outputs; the pool is re-applied in forward via self.max_pool.
        self.features_conv_axial = self.axial[:12]
        self.features_conv_sagittal = self.sagittal[:12]
        self.features_conv_coronal = self.coronal[:12]
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
        self.avg_pool_axial = nn.AdaptiveAvgPool2d(1)
        self.avg_pool_sagittal = nn.AdaptiveAvgPool2d(1)
        self.avg_pool_coronal = nn.AdaptiveAvgPool2d(1)
        self.gradients_axial = None
        self.gradients_sagittal = None
        self.gradients_coronal = None
        self.fc = nn.Sequential(nn.Linear(in_features=3 * 256, out_features=3))

    def activations_hook_axial(self, grad):
        self.gradients_axial = grad

    def activations_hook_sagittal(self, grad):
        self.gradients_sagittal = grad

    def activations_hook_coronal(self, grad):
        self.gradients_coronal = grad

    def forward(self, x):
        """Return ``(1, 3)`` logits for one exam.

        Args:
            x: ``[axial, sagittal, coronal]`` tensors, each ``(1, slices, 3, H, W)``.
                Dim 0 is squeezed, so the batch size must be 1.
        """
        branches = (
            (self.features_conv_axial, self.avg_pool_axial, self.activations_hook_axial),
            (self.features_conv_sagittal, self.avg_pool_sagittal, self.activations_hook_sagittal),
            (self.features_conv_coronal, self.avg_pool_coronal, self.activations_hook_coronal),
        )
        pooled = []
        for volume, (features, avg_pool, hook) in zip(x, branches):
            activations = features(torch.squeeze(volume, dim=0))
            # Gradient hooks feed Grad-CAM. torch refuses to register a hook on a tensor
            # that does not require grad (e.g. under torch.no_grad()), so skip it there
            # instead of crashing inference.
            if activations.requires_grad:
                activations.register_hook(hook)
            per_slice = avg_pool(self.max_pool(activations)).view(activations.size(0), -1)
            # Max over slices -> one 256-d descriptor per plane.
            pooled.append(torch.max(per_slice, dim=0, keepdim=True)[0])
        output = torch.cat(pooled, dim=1)
        return self.fc(output)

    def get_activations_gradient(self):
        return [self.gradients_axial, self.gradients_sagittal, self.gradients_coronal]

    def get_activations(self, x):
        images = [torch.squeeze(img, dim=0) for img in x]
        return [self.features_conv_axial(images[0]), self.features_conv_sagittal(images[1]), self.features_conv_coronal(images[2])]

# Initialize model
model = CNNModel()

# Define error criterion and optimize functions. Each split's loss is weighted by that
# split's own negative/positive ratios.
train_criterion, val_criterion = (
    nn.BCEWithLogitsLoss(pos_weight=split.weights) for split in (train_dataset, val_dataset)
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
# Shrink the learning rate to 30% after 3 epochs without a val-loss improvement.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    patience=3,
    factor=0.3,
    threshold=1e-4,
    verbose=True,
)

# Get Sensitivity-Specificity metrics
def get_sensitivity_specificity(y_true, y_pred):
    """Return per-task sensitivity and specificity for the three MRNet labels.

    Args:
        y_true: sequence of per-exam 0/1 ground-truth vectors laid out as
            ``[abnormal, acl, meniscus]``.
        y_pred: sequence of per-exam 0/1 prediction vectors in the same layout.

    Returns:
        ``(sensitivity, specificity)``: two 3-element lists ordered Abnormal, ACL,
        Meniscus, each value rounded to 4 decimals. A ratio whose denominator is zero
        (no positive cases for sensitivity, no negative cases for specificity) is
        reported as ``nan``. The previous confusion-matrix version raised ValueError
        instead whenever a task had only one class present.
    """
    # reshape(-1, 3) also turns empty input into a (0, 3) array instead of failing.
    truth = np.asarray(y_true, dtype=float).reshape(-1, 3) == 1
    predicted = np.asarray(y_pred, dtype=float).reshape(-1, 3) == 1

    def ratio(hits, misses):
        denominator = hits + misses
        return round(hits / denominator, 4) if denominator else float('nan')

    sensitivity, specificity = [], []
    for task in range(3):
        t, p = truth[:, task], predicted[:, task]
        tp = np.sum(t & p)
        fn = np.sum(t & ~p)
        tn = np.sum(~t & ~p)
        fp = np.sum(~t & p)
        sensitivity.append(ratio(tp, fn))
        specificity.append(ratio(tn, fp))
    return sensitivity, specificity

# Define train function
def train(train_data, model, criterion):
    """Run one optimisation epoch of ``model`` over ``train_data``.

    Steps the module-level ``optimizer`` once per batch.

    Args:
        train_data: iterable of ``(input_data, label)`` batches, where ``input_data``
            is a list of per-plane tensors and ``label`` a ``(batch, 3)`` tensor. Only
            row 0 of each batch is recorded, so batch size 1 is assumed.
        model: network called as ``model(input_data)`` and returning logits.
        criterion: loss evaluated as ``criterion(logits, labels)`` on CPU tensors.

    Returns:
        ``(mean loss, element-wise accuracy, macro ROC AUC, predictions)``:
        ``predictions`` holds the per-exam 0/1 prediction vectors. The AUC is computed
        from the sigmoid probabilities, because thresholded predictions collapse the
        ROC curve to a single operating point.
    """
    print('Training...')
    model.train()
    counter = 0
    correct = 0
    train_running_loss = 0.0
    total = 0.0
    prediction_list = []
    probability_list = []
    label_list = []
    for input_data, label in tqdm(train_data):
        if torch.cuda.is_available():
            input_data, label = [plane.cuda() for plane in input_data], label.cuda()
        counter += 1
        optimizer.zero_grad()
        outputs = model(input_data)
        outputs_sig = torch.sigmoid(outputs)
        predicted = torch.round(outputs_sig)

        probabilities = outputs_sig.detach().cpu().numpy()[0]
        predictions = predicted.detach().cpu().numpy()[0]
        labels = label.detach().cpu().numpy()[0]
        probability_list.append(probabilities)
        prediction_list.append(predictions)
        label_list.append(labels)
        total += label.size(1)
        correct += (predictions == labels).sum().item()

        # Loss is computed on CPU, presumably because the criterion's pos_weight
        # tensor was created there.
        loss = criterion(outputs.cpu(), label.cpu())
        train_running_loss += loss.item()
        loss.backward()
        optimizer.step()
    train_accuracy = correct / total
    train_loss = train_running_loss / counter
    train_auc = metrics.roc_auc_score(label_list, probability_list, average='macro', multi_class='ovr')
    return train_loss, train_accuracy, train_auc, prediction_list

# Define val function
def val(val_data, model, criterion):
    """Evaluate ``model`` on ``val_data`` without updating its parameters.

    Args:
        val_data: iterable of ``(input_data, label)`` batches, where ``input_data`` is
            a list of per-plane tensors and ``label`` a ``(batch, 3)`` tensor. Only
            row 0 of each batch is recorded, so batch size 1 is assumed.
        model: network called as ``model(input_data)`` and returning logits.
        criterion: loss evaluated as ``criterion(logits, labels)`` on CPU tensors.

    Returns:
        ``(mean loss, element-wise accuracy, macro ROC AUC, predictions,
        sensitivity, specificity)``. The AUC uses sigmoid probabilities, while
        sensitivity and specificity use the 0/1 predictions, which are rounded at 0.5.
    """
    print('Validating...')
    model.eval()
    counter = 0
    correct = 0
    val_running_loss = 0.0
    total = 0.0
    prediction_list = []
    probability_list = []
    label_list = []
    for input_data, label in tqdm(val_data):
        if torch.cuda.is_available():
            input_data, label = [plane.cuda() for plane in input_data], label.cuda()
        counter += 1
        outputs = model(input_data)
        outputs_sig = torch.sigmoid(outputs)
        predicted = torch.round(outputs_sig)

        probabilities = outputs_sig.detach().cpu().numpy()[0]
        predictions = predicted.detach().cpu().numpy()[0]
        labels = label.detach().cpu().numpy()[0]
        probability_list.append(probabilities)
        prediction_list.append(predictions)
        label_list.append(labels)
        total += label.size(1)
        correct += (predictions == labels).sum().item()

        # Loss is computed on CPU, presumably because the criterion's pos_weight
        # tensor was created there.
        loss = criterion(outputs.cpu(), label.cpu())
        val_running_loss += loss.item()
    val_accuracy = correct / total
    val_loss = val_running_loss / counter
    val_auc = metrics.roc_auc_score(label_list, probability_list, average='macro', multi_class='ovr')
    sensitivity, specificity = get_sensitivity_specificity(label_list, prediction_list)
    return val_loss, val_accuracy, val_auc, prediction_list, sensitivity, specificity

# Start the training and validation
# Per-epoch history.
train_loss = []
train_accuracy = []
train_auc = []
valid_loss = []
valid_accuracy = []
valid_auc = []
total_train_predictions = []
total_val_predictions = []

best_val_loss = float('inf')
best_val_auc = float(0)  # NOTE(review): not updated or read by this loop
patient_counter = 0

if torch.cuda.is_available():
    model = model.cuda()

# torch.save does not create missing parent directories.
os.makedirs(f'{MRNET_PATH}/models', exist_ok=True)

for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1} of {EPOCHS}")
    train_epoch_loss, train_epoch_accuracy, train_epoch_auc, train_predictions = train(train_loader, model, train_criterion)
    val_epoch_loss, val_epoch_accuracy, val_epoch_auc, val_predictions, val_sensitivity, val_specificity = val(val_loader, model, val_criterion)
    # ReduceLROnPlateau: lowers the learning rate when val loss stops improving.
    scheduler.step(val_epoch_loss)

    # Checkpoint only on a val-loss decrease of at least LOSS_IMPROVE; otherwise the
    # epoch counts toward early stopping.
    if best_val_loss - val_epoch_loss >= LOSS_IMPROVE:
        print("Val loss has improved. From {} to {}. Saving model...".format(best_val_loss, val_epoch_loss))
        best_val_loss = val_epoch_loss
        patient_counter = 0
        torch.save(model, f'{MRNET_PATH}/models/mrnet_three_pretrained_models_non_frozen_weights_standarized_img_aug_gradcam_2022_01_28.pth')
    else:
        print("Val loss did not improve")
        patient_counter += 1

    # Record and report BEFORE the early-stopping exit, so the final epoch is not lost.
    train_loss.append(train_epoch_loss)
    train_accuracy.append(train_epoch_accuracy)
    train_auc.append(train_epoch_auc)
    valid_loss.append(val_epoch_loss)
    valid_accuracy.append(val_epoch_accuracy)
    valid_auc.append(val_epoch_auc)
    total_train_predictions.append(train_predictions)
    total_val_predictions.append(val_predictions)
    print(f"Train Accuracy: {train_epoch_accuracy:.4f}")
    print(f'Val Accuracy: {val_epoch_accuracy:.4f}')
    print(f"Train Loss: {train_epoch_loss:.4f}")
    print(f'Val Loss: {val_epoch_loss:.4f}')
    print(f"Train AUC: {train_epoch_auc:.4f}")
    print(f'Val AUC: {val_epoch_auc:.4f}')
    print("Val-Sensitivity. Abnormal : {}, ACL: {}, Meniscus: {}".format(val_sensitivity[0], val_sensitivity[1], val_sensitivity[2]))
    print("Val-Specificity. Abnormal : {}, ACL: {}, Meniscus: {}".format(val_specificity[0], val_specificity[1], val_specificity[2]))
    print("-" * 80)

    if patient_counter >= PATIENT:
        print(f"Early stopping: no val-loss improvement for {PATIENT} epochs.")
        break



### Training Results Summary

| Metric                     | Value  |
|----------------------------|--------|
| Train Accuracy             | 0.5708 |
| Val Accuracy               | 0.6722 |
| Train Loss                 | 0.7467 |
| Val Loss                   | 0.5517 |
| Train AUC                  | 0.5637 |
| Val AUC                    | 0.7098 |
| Val-Sensitivity (Abnormal) | 0.5263 |
| Val-Sensitivity (ACL)      | 0.537  |
| Val-Sensitivity (Meniscus) | 0.8077 |
| Val-Specificity (Abnormal) | 0.92   |
| Val-Specificity (ACL)      | 0.8939 |
| Val-Specificity (Meniscus) | 0.5735 |