In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, random_split
import numpy as np
import h5py
import re
import cv2
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC
import scipy.io
from skimage.transform import resize
from sklearn.preprocessing import label_binarize
from tqdm import tqdm
from torchvision import models

In [27]:
###########################################################
# Class for storing and working with the Brain Tumor
# dataset. Uses h5py to work with the .mat files
# 3064 samples in the dataset
class BrainTumorDataset(Dataset):
    def __init__(self, data_path):
        
        self.images = []
        self.labels = []
        
        for i in range(1,3065):
            data = self.load_mat(data_path + str(i) + ".mat")
            
            try:
                label = data[1]
                mask = data[2]
                
                image = data[0]
                
                # applying tumor mask
#                 image = image * mask
                
                self.images.append(image)
                self.labels.append(label)
            except Exception as e:
                print(f"Error during processing file {i}: {e}")
    
    
    def load_mat(self, file_path):
        data=h5py.File(file_path, 'r')
        image = data.get('cjdata/image/')
        label = data.get('cjdata/label/')
        mask = data.get('cjdata/tumorMask/')
        image = np.array(image, dtype="float32")
        label = np.array(label, dtype="int")
        mask = np.array(mask, dtype="int")
        return image, label, mask
        return data

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        
        image = torch.tensor(self.images[idx], dtype=torch.float32)
        image = image.expand(3, -1, -1)
        image = F.interpolate(image.unsqueeze(0), size=(512, 512), mode='bilinear', align_corners=False).squeeze(0)
        
        label = torch.tensor(self.labels[idx][0], dtype=torch.long)
        
        return image, label


In [28]:
###########################################################
# Function for extracting features using a model.
# Flattens the features and then stacks them before
# returning.
def extract_features(model, dataloader):
    
    features = []
    return_labels = []

    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc='Extracting Features'):
            outputs = model(inputs)
            flattened_features = outputs.view(outputs.size(0), -1)
            features.append(flattened_features.numpy())
            return_labels.append(labels.numpy())

    return_labels = np.concatenate(return_labels)
    features = np.vstack(features)
    return features, return_labels


In [29]:
###########################################################
# Function for evaluating classifiers and printing
# the required metrics
def evaluate_classifier(y_true, y_pred, model_name):
    
    # Accuracy
    accuracy = accuracy_score(y_true, y_pred)
    
    # Precision
    precision = precision_score(y_true, y_pred, average='weighted')
    
    # Recall
    recall = recall_score(y_true, y_pred, average='weighted')
    
    # F1 Score
    f1 = f1_score(y_true, y_pred, average='weighted')

    y_true_one_hot_encoding = label_binarize(y_true, classes=np.unique(y_true))

    if len(y_pred.shape) == 1:
        y_pred_one_hot_encoding = label_binarize(y_pred, classes=np.unique(y_true))
    else:
        y_pred_one_hot_encoding = y_pred

    # AUC-ROC
    roc_auc = roc_auc_score(y_true_one_hot_encoding, y_pred_one_hot_encoding, average='weighted')

    # Confusion Matrix
    cm = confusion_matrix(y_true, y_pred)
    
    print(f"------------------------------------------------------")
    print(f"Evaluation results for {model_name}:\n")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print(f"ROC AUC Score: {roc_auc:.4f}")
    print("\nConfusion Matrix:\n", cm)
    print(f"------------------------------------------------------")
    
    return


In [30]:
# Load the data and split into train/validation/test sets

dataset = BrainTumorDataset('/home/rit/temp/dataset/')

train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

print("Train Dataset size: ", train_size)
print("Validation Dataset size: ", val_size)
print("Test Dataset size: ", test_size)

train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

# Creating dataloaders
trainloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
valloader = DataLoader(val_dataset, batch_size=32, shuffle=False)
testloader = DataLoader(test_dataset, batch_size=32, shuffle=False)


Train Dataset size:  2451
Validation Dataset size:  306
Test Dataset size:  307



## Feature Extraction using ResNet-18


In [31]:

# Use a pre-trained ResNet18 for feature extraction

resnet18 = models.resnet18(pretrained=True)

resnet18 = nn.Sequential(*list(resnet18.children())[:-1])

for param in resnet18.parameters():
    param.requires_grad = False

# Extacting the features
train_features, train_labels = extract_features(resnet18, trainloader)
test_features, test_labels = extract_features(resnet18, testloader)
val_features, val_labels = extract_features(resnet18, valloader)

Extracting Features: 100%|██████████| 77/77 [00:49<00:00,  1.55it/s]
Extracting Features: 100%|██████████| 10/10 [00:06<00:00,  1.61it/s]
Extracting Features: 100%|██████████| 10/10 [00:06<00:00,  1.62it/s]


--------------------------------------------------------------------------------------------------------------------------------


# Bayesian Classifier


In [32]:

# Declaring the Bayesian Classifier
bayesian_classifier = GaussianNB()

bayesian_classifier.fit(train_features, train_labels)

val_preds_bayesian = bayesian_classifier.predict(val_features)

evaluate_classifier(val_labels, val_preds_bayesian, 'Bayesian')

------------------------------------------------------
Evaluation results for Bayesian:

Accuracy: 0.8137
Precision: 0.8239
Recall: 0.8137
F1 Score: 0.8160
ROC AUC Score: 0.8620

Confusion Matrix:
 [[ 55   3   9]
 [ 22 115   8]
 [  4  11  79]]
------------------------------------------------------


  y = column_or_1d(y, warn=True)


--------------------------------------------------------------------------------------------------------------------------------


# Decision Tree Classifier


In [33]:

# Declaring the Decision Tree Classifier
decision_tree_classifier = DecisionTreeClassifier()

decision_tree_classifier.fit(train_features, train_labels)

val_preds_tree = decision_tree_classifier.predict(val_features)

evaluate_classifier(val_labels, val_preds_tree, 'Decision Tree')

------------------------------------------------------
Evaluation results for Decision Tree:

Accuracy: 0.7320
Precision: 0.7329
Recall: 0.7320
F1 Score: 0.7324
ROC AUC Score: 0.7932

Confusion Matrix:
 [[ 37  19  11]
 [ 19 115  11]
 [ 12  10  72]]
------------------------------------------------------


--------------------------------------------------------------------------------------------------------------------------------


# SVM (Support Vector Machines) Classifier


In [34]:

# Declaring the SVM Classifier
svm_classifier = SVC(probability=True)

svm_classifier.fit(train_features, train_labels)

val_preds_svm = svm_classifier.predict(val_features)

evaluate_classifier(val_labels, val_preds_svm, 'Support Vector Machine (SVM)')

  y = column_or_1d(y, warn=True)


------------------------------------------------------
Evaluation results for Support Vector Machine (SVM):

Accuracy: 0.9314
Precision: 0.9350
Recall: 0.9314
F1 Score: 0.9324
ROC AUC Score: 0.9532

Confusion Matrix:
 [[ 59   1   7]
 [ 10 135   0]
 [  3   0  91]]
------------------------------------------------------
