In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, recall_score, roc_curve, auc

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, ConcatDataset, Subset, SubsetRandomSampler
from torchvision import transforms, datasets
import torchvision.models as models
from torchvision import models
import timm
import csv
import seaborn as sns
from sklearn.metrics import confusion_matrix
import copy
import random
from torch.optim.lr_scheduler import StepLR, MultiStepLR, LambdaLR, ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau
from torchsummary import summary

In [None]:
# Set random seeds for reproducibility
# One SEED constant feeds every random number generator the notebook uses
# (Python's `random`, NumPy, PyTorch CPU and PyTorch CUDA), so changing the
# experiment seed is a single edit instead of four.
SEED = 2024
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)  # seeds the current CUDA device's generator

In [None]:
# Set device
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Load ResNet model
# NOTE(review): pretrained=False requests a ResNet-50 without ImageNet weights,
# yet a later cell freezes every parameter except the last 20 tensors. Unless
# weights are loaded somewhere not shown here, the frozen backbone stays at its
# random initialisation -- confirm this is intended.
model = models.resnet50(
    pretrained=False,
)

In [None]:
# Replace the final fully connected layer with a 6-way classification head.
# The head emits log-probabilities (LogSoftmax over the class dimension).
# NOTE(review): the notebook later trains with nn.CrossEntropyLoss, which
# already applies log-softmax itself, so the LogSoftmax here looks redundant
# -- confirm before removing, since saved checkpoints use the `fc.0.*` keys.
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(
    nn.Linear(in_features=num_ftrs, out_features=6),
    nn.LogSoftmax(dim=1),
)

In [None]:
import torch
from timm.models import create_model

# Total number of scalar weights in the model (trainable or not): add up the
# element count of every parameter tensor.
num_params = sum(map(torch.numel, model.parameters()))
print(f"Number of parameters in resnet50 model: {num_params}")

In [None]:
# Data preprocessing and augmentation
# The pipeline only resizes each image to 224x224 and converts it to a
# PyTorch tensor; no random augmentation or normalisation step is applied.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),  # Resize the image
        transforms.ToTensor(),  # Convert image to PyTorch tensor
    ]
)

In [None]:
data_dir = "/root/autodl-tmp/project/MedSAM-0.1/data/MULTI_TUMOR_split" # Root directory for data storage
# Define train, validation datasets
# ImageFolder reads one sub-directory per class under "train" and "val".
train_data = datasets.ImageFolder(os.path.join(data_dir, "train"), transform=transform)
val_data = datasets.ImageFolder(os.path.join(data_dir, "val"), transform=transform)
batch_size = 64
# Never request more loader workers than the machine has CPUs: the previous
# hard-coded 64 oversubscribes smaller hosts. On hosts with >= 64 CPUs the
# value is unchanged.
num_workers = min(64, os.cpu_count() or 1)
# Training batches are reshuffled every epoch; validation order stays fixed.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [None]:
# Function to plot training/validation metrics and ROC curve
def plot_train_curve(modal ,train_acc_history, val_acc_history, val_auc_history, val_f1_history, val_r_history, val_fpr, val_tpr):
    """Save the per-epoch metric curves and a ROC curve as PNG files.

    Five images are written to ``./inde_train_curve_images/`` (the directory
    is created if it does not exist): ``Acc``, ``Auc``, ``F1``, ``Recall`` and
    ``ROC`` ``_curve_of_<modal>.png``.

    Parameters:
        modal: string used as the suffix of every output file name.
        train_acc_history: per-epoch training accuracy values.
        val_acc_history: per-epoch validation accuracy values.
        val_auc_history: per-epoch validation AUC values.
        val_f1_history: per-epoch validation F1 values.
        val_r_history: per-epoch validation recall values.
        val_fpr: false-positive rates of the ROC curve to draw.
        val_tpr: true-positive rates matching ``val_fpr``.
    """
    out_dir = './inde_train_curve_images/'
    # savefig does not create missing directories, so make sure it exists.
    os.makedirs(out_dir, exist_ok=True)

    def _save_history(curves, ylabel, prefix):
        # Draw one or more (label, values) series against the epoch index
        # and save the figure under the given file-name prefix.
        plt.figure()
        for label, values in curves:
            plt.plot(values, label=label)
        plt.xlabel("Epoch")
        plt.ylabel(ylabel)
        plt.legend()
        plt.savefig(out_dir + prefix + modal + '.png')
        plt.close()

    # --- Plot Accuracy curve ---
    _save_history([("Train Acc", train_acc_history), ("Val Acc", val_acc_history)],
                  "Accuracy", 'Acc_curve_of_')
    # --- Plot AUC curve ---
    _save_history([("Val AUC", val_auc_history)], "AUC", 'Auc_curve_of_')
    # --- Plot F1 curve ---
    _save_history([("Val F1", val_f1_history)], "F1 Score", 'F1_curve_of_')
    # --- Plot Recall curve ---
    _save_history([("Val Recall", val_r_history)], "Recall", 'Recall_curve_of_')

    # --- Plot ROC curve ---
    # The area is derived from the curve that is actually drawn instead of
    # reading a global `val_auc`, so the function works with any caller.
    roc_area = auc(val_fpr, val_tpr)
    plt.figure()
    plt.plot(val_fpr, val_tpr, label=f'ROC curve (area = {roc_area:.2f})')
    plt.plot([0, 1], [0, 1], linestyle='--')  # chance-level diagonal
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic')
    plt.legend(loc="lower right")
    plt.savefig(out_dir + 'ROC_curve_of_' + modal + '.png') # Save ROC curve
    plt.close()

In [None]:
# Settings for transfer learning layers
# Work on an independent copy so the original `model` object is left untouched.
flair_model = copy.deepcopy(model)
flair_model.to(device)

# Names of every parameter tensor, in registration order. Note these are
# parameter tensors (weights, biases, ...), not whole layers.
all_layer_names = [name for name, _ in model.named_parameters()] # All parameter names in the model
free_layer_names = all_layer_names[-20:] # The last 20 parameter names of the model, you can change the number of the layers you want to finetune

# Only the tensors listed in free_layer_names stay trainable; everything else
# is frozen. The trainable tensors are collected for the optimizer.
flair_params_to_update = [] # Parameters to update
for name, param in flair_model.named_parameters():
    param.requires_grad = name in free_layer_names
    if param.requires_grad:
        flair_params_to_update.append(param)

161

In [None]:
# Adam updates only the parameters left trainable by the freezing step.
optimizer = optim.Adam(params=flair_params_to_update, lr=5e-3)
# Multi-class classification loss applied to the model outputs.
criterion = nn.CrossEntropyLoss()
# NOTE(review): `alpha` is not referenced anywhere else in this notebook and has
# two entries while the head has six outputs -- looks like a leftover; confirm.
alpha = torch.tensor([3.0, 1.0])

In [None]:
# Training and validation
# Each epoch: one pass over train_loader (with weight updates), an optional
# manual learning-rate decay, one pass over val_loader (no gradients), metric
# bookkeeping, best-model checkpointing and an early-stopping check.
num_epochs = 500
train_acc_history = []
val_acc_history = []
val_auc_history = []
val_f1_history = []
val_r_history = []

min_loss = float('inf') # Initialize min_loss to infinity
best_acc = 0.0
best_auc = 0.0
patience = 30
early_stop = patience  # countdown: reset on improvement, decremented otherwise
change_rate = 15 # Change learning rate after this many epochs if performance doesn't improve

for epoch in range(num_epochs):
    # Training
    flair_model.train()
    train_correct = 0
    train_total = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = flair_model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        predicted = torch.argmax(outputs, 1)
        train_total += labels.size(0)
        train_correct += (predicted == labels).sum().item()

    lr = optimizer.param_groups[0]['lr'] # Learning rate for the current epoch
    # Decay learning rate if early stop counter is below patience - change_rate,
    # i.e. once validation accuracy has failed to improve for more than
    # `change_rate` consecutive epochs.
    if early_stop < (patience - change_rate):
        lr -= lr / (change_rate + early_stop)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        print ('Decay learning rate to lr: {}.'.format(lr))

    train_acc = train_correct / train_total
    train_acc_history.append(train_acc)

    # Validation
    flair_model.eval()
    val_correct = 0
    val_total = 0
    val_outputs_list = []
    val_labels_list = []
    val_loss = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = flair_model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            predicted = torch.argmax(outputs, 1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()
            val_outputs_list.append(outputs.cpu().numpy())
            val_labels_list.append(labels.cpu().numpy())

    val_acc = val_correct / val_total
    val_acc_history.append(val_acc)

    val_outputs = np.concatenate(val_outputs_list, axis=0)
    val_labels = np.concatenate(val_labels_list, axis=0)

    # ROC/AUC is computed one-vs-rest for class index 1 only, using the model
    # output in column 1 as the score.
    val_fpr, val_tpr, _ = roc_curve(val_labels, val_outputs[:, 1], pos_label=1) # Calculate AUC
    val_auc = auc(val_fpr, val_tpr)
    val_auc_history.append(val_auc)

    val_f1 = f1_score(val_labels, np.argmax(val_outputs, axis=1),average='weighted')
    val_f1_history.append(val_f1)

    val_r = recall_score(val_labels, np.argmax(val_outputs, axis=1),average='weighted')
    val_r_history.append(val_r)

    # Calculate average validation loss (mean of the per-batch losses)
    avg_val_loss = val_loss / len(val_loader)

    # Early stopping logic
    if best_acc < val_acc: # Save best model based on highest accuracy
        best_acc = val_acc
        early_stop = patience
        model_name = 'flair_model_fml.pth' # Set model save filename
        torch.save(flair_model.state_dict(), model_name)
    else:
        early_stop -= 1

    # Log before the stop check so the final epoch -- the one whose ROC curve
    # is plotted below -- is reported too.
    print(f"Epoch {epoch + 1}/{num_epochs}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Val AUC: {val_auc:.4f}, Val F1: {val_f1:.4f}, Val Recall: {val_r:.4f}, Val Loss: {avg_val_loss:.4f}, Early Stop: {early_stop:.0f}")

    # Stop training when early_stop counter reaches 0
    if early_stop == 0:
        print(f"Early stopping at epoch {epoch + 1}: no validation accuracy improvement for {patience} epochs (best Val Acc: {best_acc:.4f}).")
        break

# The histories cover every epoch that ran; val_fpr/val_tpr come from the last one.
plot_train_curve('flair', train_acc_history, val_acc_history, val_auc_history, val_f1_history, val_r_history, val_fpr, val_tpr)

In [None]:
# Load the best model saved during the training phase.
# map_location remaps the stored tensors onto the active device, so a
# checkpoint written on a GPU can still be loaded on a CPU-only machine.
flair_model.load_state_dict(torch.load("flair_model_fml.pth", map_location=device))

In [49]:
from sklearn.metrics import confusion_matrix, accuracy_score, roc_auc_score, f1_score, recall_score, roc_curve, auc

In [None]:
# Create a data loader for the test images
import torch.nn.functional as F
data_dir = "/root/autodl-tmp/project/MedSAM-0.1/data/MULTI_TUMOR_split/test"
image_data = datasets.ImageFolder(data_dir, transform=transform)
data_loader = DataLoader(image_data, batch_size=1, shuffle=False)

# Predict and evaluate
y_true = []    # ground-truth class indices
y_pred = []    # predicted class indices (argmax over the outputs)
y_scores = []  # per-class probabilities, one array per batch

# Switch to inference mode explicitly instead of relying on whatever mode the
# training cell happened to leave the model in.
flair_model.eval()
with torch.no_grad():
    for images, labels in data_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = flair_model(images)
        probabilities = F.softmax(outputs, dim=1)
        predicted = torch.argmax(outputs, dim=1)
        # extend/tolist works for any batch size, unlike .item().
        y_true.extend(labels.cpu().tolist())
        y_pred.extend(predicted.cpu().tolist())
        y_scores.append(probabilities.cpu().numpy())

y_true = np.array(y_true)
y_pred = np.array(y_pred)
y_scores = np.concatenate(y_scores, axis=0)
# Class-frequency-weighted averages for F1 and recall.
acc = accuracy_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred,average='weighted')
recall = recall_score(y_true, y_pred,average='weighted')
print(f"Accuracy: {acc:.4f}, F1 Score: {f1:.4f}, Recall: {recall:.4f}")

In [None]:
from sklearn.metrics import classification_report
# Per-class precision / recall / F1 on the test set. Rows are named after the
# dataset's own class list, and `labels` pins the full set of class indices so
# every class gets a row even if it never occurs in y_true or y_pred.
print(classification_report(
    y_true,
    y_pred,
    labels=list(range(len(image_data.classes))),
    target_names=image_data.classes,
))

In [None]:
import seaborn as sns
from sklearn.metrics import confusion_matrix
# Define class labels
# Take the names from the test dataset so tick i always names class index i.
# A hand-written list can drift from ImageFolder's sorted folder order (for
# example, 'NORMAL' sorts before 'Neurocitoma' in a case-sensitive sort).
labels = list(image_data.classes)

# Plot confusion matrix
# Passing every class index keeps the matrix len(labels) x len(labels) even if
# some class is absent from both y_true and y_pred.
conf_mat = confusion_matrix(y_true, y_pred, labels=list(range(len(labels))))
plt.figure()
sns.heatmap(conf_mat, annot=True, fmt=".0f", cmap='Blues', xticklabels=labels, yticklabels=labels)
plt.xlabel("Predicted labels")
plt.ylabel("True labels")
plt.title("Confusion Matrix")
plt.show()