# Brain tumor classification

## Set up

In [None]:
import os
import PIL
import zipfile
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle
import h5py
import cv2
from google.colab.patches import cv2_imshow

In [None]:
# Mount Google Drive inside the Colab VM so the dataset folder can be read.
# force_remount=True asks Colab to mount again even if Drive is already mounted.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

In [None]:
# Check the dataset folder (expected subfolders: train, test, val)
!ls '/content/drive/My Drive/dataset_gestion_couleur'

In [None]:
# Dataset folder on the mounted Google Drive (source of the copy).
drive_path = "/content/drive/My Drive/dataset_gestion_couleur"
# Folder on the Colab VM's local disk that the training cells copy the dataset to
# and read the train/val/test splits from.
local_path = "/content/dataset_local"

## Resnet50: full transfer learning

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import copy
import os
import shutil
from torchvision.models import resnet50, ResNet50_Weights
from tqdm.notebook import tqdm
import wandb
from collections import Counter

# Copy the dataset from Google Drive to the VM's local disk (skipped when the
# local copy already exists) so that training reads images from local storage.
if not os.path.exists(local_path):
    print("Copy the dataset to the local VM to speed up access...")
    shutil.copytree(drive_path, local_path)
    print("Copy finished.")


# Initialisation of wandb
# Hyper-parameters and run description attached to the W&B run.
config = {
    "batch_size_physical": 32,
    "lr": 1e-4,
    "epochs": 60,
    "architecture": "ResNet50",
    "strategy": "Transfer Learning",
    "augmentation": "HorizontalFlip + Rotation"
}

# No API key is written in the notebook: wandb.login() uses the WANDB_API_KEY
# environment variable or a previously stored login, and prompts otherwise.
# This matches the evaluation cells of this notebook.
wandb.login()
wandb.init(project="brain_tumor_classification", name="resnet50", config=config)


# 2. Transformations
# Both splits share resize -> tensor -> normalisation; only the training split
# gets the random flip and the random rotation (up to 10 degrees) in between.
data_transforms = {}
for split_name, augmentations in (
    ('train', [transforms.RandomHorizontalFlip(), transforms.RandomRotation(10)]),
    ('val', []),
):
    data_transforms[split_name] = transforms.Compose([
        transforms.Resize((224,224)),
        *augmentations,
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
    ])


# 3. Datasets & DataLoaders
batch_size = config['batch_size_physical']
num_workers = 2

# One ImageFolder and one DataLoader per split; only the training loader shuffles.
image_datasets = {}
dataloaders = {}
dataset_sizes = {}
for split_name in ('train', 'val'):
    split_dataset = datasets.ImageFolder(
        os.path.join(local_path, split_name),
        transform=data_transforms[split_name],
    )
    image_datasets[split_name] = split_dataset
    dataloaders[split_name] = DataLoader(
        split_dataset,
        batch_size=batch_size,
        shuffle=(split_name == 'train'),
        num_workers=num_workers,
        pin_memory=True,
    )
    dataset_sizes[split_name] = len(split_dataset)

classes = image_datasets['train'].classes

# Use the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print(f"Classes : {classes}")
print(f"Device : {device}")

# Calculating weights for unbalanced classes
# Each class gets the weight total / (n_classes * class_count): classes with
# fewer training images receive a larger weight in the loss.
print("Calcul des poids de classes...")
train_targets = image_datasets['train'].targets
count_dict = Counter(train_targets)
class_count = [count_dict[i] for i in range(len(classes))]
total_count = sum(class_count)
weights = [total_count / (len(classes) * c) for c in class_count]
class_weights = torch.FloatTensor(weights).to(device)
# The accented character was stored as mis-decoded UTF-8 ("appliquÃ©s"); restored.
print(f"Poids appliqués aux classes : {weights}")


# Load ResNet50 pretrained
# NOTE: this rebinds `weights` (previously the list of class weights); the loss
# below uses the `class_weights` tensor, which is unaffected.
weights = ResNet50_Weights.DEFAULT
model_ft = resnet50(weights=weights)

# Freeze every pretrained parameter: only the new head created below is trained.
for param in model_ft.parameters():
    param.requires_grad = False

# Replace the final fully connected layer with one output per dataset class
# (derived from the training folder instead of a hard-coded 4).
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, len(classes))
model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss(weight=class_weights)
# Take the learning rate from `config` so the value logged to W&B is the one used.
optimizer = optim.Adam(model_ft.fc.parameters(), lr=config['lr'])


# Training loop
# Every epoch runs a training pass followed by a validation pass. The weights
# reaching the best validation accuracy are kept, and training stops once
# `patience` consecutive epochs bring no improvement.
num_epochs = config['epochs']
patience = 10
epochs_no_improve = 0
best_model_wts = copy.deepcopy(model_ft.state_dict())
best_acc = 0.0

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print('-'*20)

    for phase in ['train','val']:
        if phase=='train':
            model_ft.train()
        else:
            model_ft.eval()

        # Running totals over the samples processed so far in this phase.
        running_loss = 0.0
        running_corrects = 0
        samples_seen = 0

        loop = tqdm(dataloaders[phase], desc=f"{phase} Epoch {epoch+1}", leave=False)

        for inputs, labels in loop:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            # Gradients are only tracked during the training phase.
            with torch.set_grad_enabled(phase=='train'):
                outputs = model_ft(inputs)
                _, preds = torch.max(outputs,1)
                loss = criterion(outputs, labels)

                if phase=='train':
                    loss.backward()
                    optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds==labels).item()
            samples_seen += inputs.size(0)

            # Divide by the number of samples actually seen: tqdm refreshes
            # `loop.n` only periodically, and the last batch may be smaller
            # than `batch_size`.
            loop.set_postfix(loss=running_loss/samples_seen,
                             acc=running_corrects/samples_seen)

        epoch_loss = running_loss / dataset_sizes[phase]
        epoch_acc = running_corrects / dataset_sizes[phase]

        print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

        # Log metrics on wandb (plain Python floats)
        wandb.log({f"{phase}_loss": epoch_loss, f"{phase}_acc": epoch_acc, "epoch": epoch+1})

        # Best model: checkpoint on every strict improvement of validation accuracy
        if phase=='val':
            if epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model_ft.state_dict())
                epochs_no_improve = 0
                torch.save(model_ft.state_dict(), "best_resnet50_transfer_learning.pth")
                wandb.save("best_resnet50_transfer_learning.pth")
            else:
                epochs_no_improve += 1

    # Early stopping
    if epochs_no_improve >= patience:
        print(f"Early stopping triggered at epoch {epoch+1}")
        break

# Restore the best weights and write them once more, so the checkpoint exists
# even if no epoch ever improved on the initial accuracy of 0.
model_ft.load_state_dict(best_model_wts)
torch.save(model_ft.state_dict(), "best_resnet50_transfer_learning.pth")
wandb.save("best_resnet50_transfer_learning.pth")

print("Model saved : best_resnet50_transfer_learning.pth")


### Evaluation

In [None]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.models import resnet50
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    classification_report,
    confusion_matrix,
    roc_auc_score
)
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt
import seaborn as sns
import wandb
import os


# W&B Initialization
# The evaluation is tracked as its own run, in a project separate from training.
wandb.login()
wandb.init(
    job_type="evaluation",
    name="resnet50_transfert_learning_evaluation",
    project="brain_tumor_classification_test",
)


# Use the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("Device used:", device)


# Load the trained model
model_path = "best_resnet50_transfer_learning.pth"
assert os.path.exists(model_path), "Model not found!"

# Rebuild the architecture (ResNet50 with a 4-output head) without pretrained
# weights, then fill in the parameters from the checkpoint.
model_ft = resnet50(weights=None)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 4)

checkpoint_state = torch.load(model_path, map_location=device)
model_ft.load_state_dict(checkpoint_state)
model_ft = model_ft.to(device)
model_ft.eval()


# Test Dataset & DataLoader
test_dir = "/content/dataset_local/test"

# Deterministic preprocessing only (resize, tensor conversion, normalisation):
# no random augmentation at test time.
normalize_step = transforms.Normalize(
    mean=[0.485,0.456,0.406],
    std =[0.229,0.224,0.225]
)
test_transforms = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    normalize_step,
])

test_dataset = datasets.ImageFolder(test_dir, transform=test_transforms)
loader_options = {"batch_size": 32, "shuffle": False, "num_workers": 2}
test_loader = DataLoader(test_dataset, **loader_options)

classes = test_dataset.classes
print("Classes:", classes)
print("Number of test images:", len(test_dataset))


# 5. Evaluation loop
# Accumulators filled batch by batch: predicted class index, true class index
# and softmax probability vector of every test image (the loader does not shuffle).
all_preds = []
all_labels = []
all_probs = []

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model_ft(inputs)
        probs = torch.softmax(outputs, dim=1)
        _, preds = torch.max(outputs, 1)

        # Move each batch result back to the CPU and append it row by row.
        for store, batch_values in (
            (all_probs, probs),
            (all_preds, preds),
            (all_labels, labels),
        ):
            store.extend(batch_values.cpu().numpy())


# Global Metrics
# Accuracy plus macro-averaged precision / recall / F1 over the whole test set.
accuracy = accuracy_score(all_labels, all_preds)
precision, recall, f1, _ = precision_recall_fscore_support(
    all_labels, all_preds, average="macro"
)

# ROC-AUC multiclass
# One-hot encode the true labels, with the label set derived from the classes
# found in the test folder instead of a hard-coded [0,1,2,3].
y_true_bin = label_binarize(all_labels, classes=list(range(len(classes))))
roc_auc_macro = roc_auc_score(y_true_bin, all_probs, average="macro", multi_class="ovr")

print("\n=== GLOBAL METRICS ===")
print(f"Accuracy        : {accuracy:.4f}")
print(f"Precision macro : {precision:.4f}")
print(f"Recall macro    : {recall:.4f}")
print(f"F1-score macro  : {f1:.4f}")
print(f"ROC-AUC macro   : {roc_auc_macro:.4f}")

# W&B logging
wandb.log({
    "test/accuracy": accuracy,
    "test/precision_macro": precision,
    "test/recall_macro": recall,
    "test/f1_macro": f1,
    "test/roc_auc_macro": roc_auc_macro
})


# Class-wise Metrics
# The dict form is used for logging; the text form is printed just below.
report = classification_report(
    all_labels,
    all_preds,
    target_names=classes,
    output_dict=True
)

print("\n=== CLASS-WISE METRICS ===")
print(classification_report(all_labels, all_preds, target_names=classes))

# One W&B log call per class, keyed by the class (folder) name.
for cls in classes:
    wandb.log({
        f"test/{cls}/precision": report[cls]["precision"],
        f"test/{cls}/recall": report[cls]["recall"],
        f"test/{cls}/f1": report[cls]["f1-score"]
    })


# Confusion Matrix
# Interactive version, logged to W&B.
wb_confusion = wandb.plot.confusion_matrix(
    y_true=all_labels,
    preds=all_preds,
    class_names=classes
)
wandb.log({"test/confusion_matrix": wb_confusion})

# Static version displayed in the notebook: heat-map of the raw counts.
cm = confusion_matrix(all_labels, all_preds)
heatmap_options = {
    "annot": True,
    "fmt": "d",
    "cmap": "Blues",
    "xticklabels": classes,
    "yticklabels": classes,
}
plt.figure(figsize=(6,5))
sns.heatmap(cm, **heatmap_options)
plt.xlabel("Predictions")
plt.ylabel("Ground truth")
plt.title("Confusion Matrix - Test set")
plt.tight_layout()
plt.show()

# Save model as Artifact
# Attach the evaluated checkpoint file to the W&B run.
artifact = wandb.Artifact(
    name="resnet50",
    type="model",
    description="ResNet50 evaluated on test set",
)
artifact.add_file(model_path)
wandb.log_artifact(artifact)

wandb.finish()
print("\nEvaluation completed and logged to W&B")


## ResNet50: fine-tuning (layer4 + FC)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import copy
import os
import shutil
from torchvision.models import resnet50, ResNet50_Weights
from tqdm.notebook import tqdm
import wandb
from collections import Counter

# Copy the dataset from Google Drive to the VM's local disk (skipped when the
# local copy already exists) so that training reads images from local storage.
if not os.path.exists(local_path):
    print("Copy the dataset to the local VM to speed up access...")
    shutil.copytree(drive_path, local_path)
    print("Copy finished.")

# Initialisation of wandb
# Hyper-parameters and run description attached to the W&B run.
config = {
    "batch_size_physical": 32,
    "lr": 1e-4,
    "epochs": 60,
    "architecture": "ResNet50",
    "strategy": "Fine-Tuning (Layer 4 + FC)",
    # The training pipeline of this cell applies a random horizontal flip as
    # well as a rotation, so both are recorded.
    "augmentation": "HorizontalFlip + Rotation"
}

# No API key is written in the notebook: wandb.login() uses the WANDB_API_KEY
# environment variable or a previously stored login, and prompts otherwise.
# This matches the evaluation cells of this notebook.
wandb.login()
wandb.init(project="brain_tumor_classification", name="resnet50", config=config)


# Transformations
# Both splits share resize -> tensor -> normalisation; only the training split
# gets the random flip and the random rotation (up to 10 degrees) in between.
data_transforms = {}
for split_name, augmentations in (
    ('train', [transforms.RandomHorizontalFlip(), transforms.RandomRotation(10)]),
    ('val', []),
):
    data_transforms[split_name] = transforms.Compose([
        transforms.Resize((224,224)),
        *augmentations,
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
    ])


# Datasets & DataLoaders
batch_size = config['batch_size_physical']
num_workers = 2

# One ImageFolder and one DataLoader per split; only the training loader shuffles.
image_datasets = {}
dataloaders = {}
dataset_sizes = {}
for split_name in ('train', 'val'):
    split_dataset = datasets.ImageFolder(
        os.path.join(local_path, split_name),
        transform=data_transforms[split_name],
    )
    image_datasets[split_name] = split_dataset
    dataloaders[split_name] = DataLoader(
        split_dataset,
        batch_size=batch_size,
        shuffle=(split_name == 'train'),
        num_workers=num_workers,
        pin_memory=True,
    )
    dataset_sizes[split_name] = len(split_dataset)

classes = image_datasets['train'].classes

# Use the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print(f"Classes : {classes}")
print(f"Device : {device}")

# Calculation of weights for unbalanced classes
# Each class gets the weight total / (n_classes * class_count): classes with
# fewer training images receive a larger weight in the loss.
train_targets = image_datasets['train'].targets
count_dict = Counter(train_targets)
class_count = [count_dict[class_idx] for class_idx in range(len(classes))]
total_count = sum(class_count)
weights = []
for samples_in_class in class_count:
    weights.append(total_count / (len(classes) * samples_in_class))
class_weights = torch.FloatTensor(weights).to(device)
print(f"Weights: {weights}")


# Load pretrained ResNet50
# NOTE: this rebinds `weights` (previously the list of class weights); the loss
# below uses the `class_weights` tensor, which is unaffected.
weights = ResNet50_Weights.DEFAULT
model_ft = resnet50(weights=weights)

# Freeze everything first ...
for param in model_ft.parameters():
    param.requires_grad = False

# ... then unfreeze the last residual stage so it is fine-tuned with the head.
for param in model_ft.layer4.parameters():
    param.requires_grad = True

# Replace the final fully connected layer with one output per dataset class
# (derived from the training folder instead of a hard-coded 4).
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, len(classes))
model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss(weight=class_weights)
# Only the parameters left trainable above (layer4 and the new head) are optimised.
params_to_update = [p for p in model_ft.parameters() if p.requires_grad]

# Adam optimizer with weight decay to improve regularization
optimizer = optim.Adam(params_to_update, lr=config['lr'], weight_decay=1e-4)

# Learning-rate schedule: multiply the LR by 0.1 when the monitored value
# (mode='min') has stopped decreasing for more than `patience` epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=5
)

# Training loop
# Every epoch runs a training pass followed by a validation pass. The weights
# reaching the best validation accuracy are kept, and training stops once
# `patience` consecutive epochs bring no improvement.
num_epochs = config['epochs']
patience = 5
epochs_no_improve = 0
best_model_wts = copy.deepcopy(model_ft.state_dict())
best_acc = 0.0

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print('-'*20)

    for phase in ['train','val']:
        if phase=='train':
            model_ft.train()
        else:
            model_ft.eval()

        # Running totals over the samples processed so far in this phase.
        running_loss = 0.0
        running_corrects = 0
        samples_seen = 0

        loop = tqdm(dataloaders[phase], desc=f"{phase} Epoch {epoch+1}", leave=False)

        for inputs, labels in loop:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            # Gradients are only tracked during the training phase.
            with torch.set_grad_enabled(phase=='train'):
                outputs = model_ft(inputs)
                _, preds = torch.max(outputs,1)
                loss = criterion(outputs, labels)

                if phase=='train':
                    loss.backward()
                    optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds==labels).item()
            samples_seen += inputs.size(0)

            # Divide by the number of samples actually seen: tqdm refreshes
            # `loop.n` only periodically, and the last batch may be smaller
            # than `batch_size`.
            loop.set_postfix(loss=running_loss/samples_seen,
                             acc=running_corrects/samples_seen)

        epoch_loss = running_loss / dataset_sizes[phase]
        epoch_acc = running_corrects / dataset_sizes[phase]

        print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

        # Log metrics on wandb (plain Python floats)
        wandb.log({f"{phase}_loss": epoch_loss, f"{phase}_acc": epoch_acc, "epoch": epoch+1})

        if phase=='val':
            # The ReduceLROnPlateau scheduler only acts when it is stepped with
            # the monitored value; it was created with mode='min', so it is fed
            # the validation loss.
            scheduler.step(epoch_loss)

            # Best model: checkpoint on every strict improvement of validation accuracy
            if epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model_ft.state_dict())
                epochs_no_improve = 0
                torch.save(model_ft.state_dict(), "best_resnet50.pth")
                wandb.save("best_resnet50.pth")
            else:
                epochs_no_improve += 1

    # Early stopping
    if epochs_no_improve >= patience:
        print(f"Early stopping triggered at epoch {epoch+1}")
        break

# Restore the best weights and write them once more, so the checkpoint exists
# even if no epoch ever improved on the initial accuracy of 0.
model_ft.load_state_dict(best_model_wts)
torch.save(model_ft.state_dict(), "best_resnet50.pth")
# Upload the file that was actually written (the name used to be misspelled
# "best_restnet50.pth", which does not exist).
wandb.save("best_resnet50.pth")
print("Model saved : best_resnet50.pth")


### Evaluation

In [None]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.models import resnet50
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    classification_report,
    confusion_matrix,
    roc_auc_score
)
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt
import seaborn as sns
import wandb
import os

# W&B Initialization
# The evaluation is tracked as its own run, in a project separate from training.
wandb.login()
wandb.init(
    job_type="evaluation",
    name="resnet50_evaluation",
    project="brain_tumor_classification_test",
)

# Use the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("Device used:", device)


# Load the trained model
model_path = "best_resnet50.pth"
assert os.path.exists(model_path), "Model not found!"

# Rebuild the architecture (ResNet50 with a 4-output head) without pretrained
# weights, then fill in the parameters from the checkpoint.
model_ft = resnet50(weights=None)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 4)

checkpoint_state = torch.load(model_path, map_location=device)
model_ft.load_state_dict(checkpoint_state)
model_ft = model_ft.to(device)
model_ft.eval()

# Test Dataset & DataLoader
test_dir = "/content/dataset_local/test"

# Deterministic preprocessing only (resize, tensor conversion, normalisation):
# no random augmentation at test time.
normalize_step = transforms.Normalize(
    mean=[0.485,0.456,0.406],
    std =[0.229,0.224,0.225]
)
test_transforms = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    normalize_step,
])

test_dataset = datasets.ImageFolder(test_dir, transform=test_transforms)
loader_options = {"batch_size": 32, "shuffle": False, "num_workers": 2}
test_loader = DataLoader(test_dataset, **loader_options)

classes = test_dataset.classes
print("Classes:", classes)
print("Number of test images:", len(test_dataset))

# Evaluation loop
# Accumulators filled batch by batch: predicted class index, true class index
# and softmax probability vector of every test image (the loader does not shuffle).
all_preds = []
all_labels = []
all_probs = []

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model_ft(inputs)
        probs = torch.softmax(outputs, dim=1)
        _, preds = torch.max(outputs, 1)

        # Move each batch result back to the CPU and append it row by row.
        for store, batch_values in (
            (all_probs, probs),
            (all_preds, preds),
            (all_labels, labels),
        ):
            store.extend(batch_values.cpu().numpy())

# Global Metrics
# Accuracy plus macro-averaged precision / recall / F1 over the whole test set.
accuracy = accuracy_score(all_labels, all_preds)
precision, recall, f1, _ = precision_recall_fscore_support(
    all_labels, all_preds, average="macro"
)

# ROC-AUC multiclass
# One-hot encode the true labels, with the label set derived from the classes
# found in the test folder instead of a hard-coded [0,1,2,3].
y_true_bin = label_binarize(all_labels, classes=list(range(len(classes))))
roc_auc_macro = roc_auc_score(y_true_bin, all_probs, average="macro", multi_class="ovr")

print("\n=== GLOBAL METRICS ===")
print(f"Accuracy        : {accuracy:.4f}")
print(f"Precision macro : {precision:.4f}")
print(f"Recall macro    : {recall:.4f}")
print(f"F1-score macro  : {f1:.4f}")
print(f"ROC-AUC macro   : {roc_auc_macro:.4f}")

# W&B logging
wandb.log({
    "test/accuracy": accuracy,
    "test/precision_macro": precision,
    "test/recall_macro": recall,
    "test/f1_macro": f1,
    "test/roc_auc_macro": roc_auc_macro
})

# Class-wise Metrics
# The dict form is used for logging; the text form is printed just below.
report = classification_report(
    all_labels,
    all_preds,
    target_names=classes,
    output_dict=True
)

print("\n=== CLASS-WISE METRICS ===")
print(classification_report(all_labels, all_preds, target_names=classes))

# One W&B log call per class, keyed by the class (folder) name.
for cls in classes:
    wandb.log({
        f"test/{cls}/precision": report[cls]["precision"],
        f"test/{cls}/recall": report[cls]["recall"],
        f"test/{cls}/f1": report[cls]["f1-score"]
    })

# Confusion Matrix
# Interactive version, logged to W&B.
wb_confusion = wandb.plot.confusion_matrix(
    y_true=all_labels,
    preds=all_preds,
    class_names=classes
)
wandb.log({"test/confusion_matrix": wb_confusion})

# Static version displayed in the notebook: heat-map of the raw counts.
cm = confusion_matrix(all_labels, all_preds)
heatmap_options = {
    "annot": True,
    "fmt": "d",
    "cmap": "Blues",
    "xticklabels": classes,
    "yticklabels": classes,
}
plt.figure(figsize=(6,5))
sns.heatmap(cm, **heatmap_options)
plt.xlabel("Predictions")
plt.ylabel("Ground truth")
plt.title("Confusion Matrix - Test set")
plt.tight_layout()
plt.show()

# Save model as Artifact
# Attach the evaluated checkpoint file to the W&B run.
artifact = wandb.Artifact(
    name="resnet50",
    type="model",
    description="ResNet50 evaluated on test set",
)
artifact.add_file(model_path)
wandb.log_artifact(artifact)

wandb.finish()
print("\nEvaluation completed and logged to W&B")


## VGG16

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import copy
import os
import shutil
import numpy as np
from torchvision.models import vgg16_bn, VGG16_BN_Weights
from tqdm.notebook import tqdm
import wandb
from collections import Counter

# Copy the dataset from Google Drive to the VM's local disk, unless a local
# copy is already present.
dataset_is_local = os.path.exists(local_path)
if not dataset_is_local:
    print("Copying the dataset to the local VM...")
    shutil.copytree(drive_path, local_path)
    print("Copy finished.")

# Configuration
# Gradients are accumulated over `accumulation_steps` batches of
# `physical_batch_size` images before each optimizer step.
physical_batch_size = 16
accumulation_steps = 2

config = {
    "batch_size_physical": physical_batch_size,
    "batch_size_effective": physical_batch_size * accumulation_steps,
    "lr": 1e-4,
    "epochs": 60,
    "architecture": "VGG16_BN",
    "strategy": "Fine-Tuning (Full Classifier) + Grad Accumulation",
    "augmentation": "HorizontalFlip+Rotation",
}

wandb.login()
wandb.init(
    config=config,
    name="vgg16",
    project="brain_tumor_classification",
)

# Use the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Device : {device}")

# Transformations
# Both splits share resize -> tensor -> normalisation; only the training split
# gets the random flip and the random rotation (up to 10 degrees) in between.
data_transforms = {}
for split_name, augmentations in (
    ('train', [transforms.RandomHorizontalFlip(), transforms.RandomRotation(10)]),
    ('val', []),
):
    data_transforms[split_name] = transforms.Compose([
        transforms.Resize((224,224)),
        *augmentations,
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
    ])

# Datasets & Dataloaders
# One ImageFolder and one DataLoader per split; only the training loader shuffles.
image_datasets = {}
dataloaders = {}
dataset_sizes = {}
for split_name in ('train', 'val'):
    split_dataset = datasets.ImageFolder(
        os.path.join(local_path, split_name),
        transform=data_transforms[split_name],
    )
    image_datasets[split_name] = split_dataset
    dataloaders[split_name] = DataLoader(
        split_dataset,
        batch_size=physical_batch_size,
        shuffle=(split_name == 'train'),
        num_workers=2,
        pin_memory=True,
    )
    dataset_sizes[split_name] = len(split_dataset)

classes = image_datasets['train'].classes

# Unbalanced weights
# Each class gets the weight total / (n_classes * class_count): classes with
# fewer training images receive a larger weight in the loss.
train_targets = image_datasets['train'].targets
count_dict = Counter(train_targets)
class_count = [count_dict[class_idx] for class_idx in range(len(classes))]
total_count = sum(class_count)
weights = []
for samples_in_class in class_count:
    weights.append(total_count / (len(classes) * samples_in_class))
class_weights = torch.FloatTensor(weights).to(device)
print(f"Weights : {weights}")


# VGG16 model
weights_model = VGG16_BN_Weights.DEFAULT
model_ft = vgg16_bn(weights=weights_model)

# Freeze Features / Unfreeze Classifier
# The `features` sub-module is frozen and the `classifier` sub-module is trained.
for trainable, submodule in ((False, model_ft.features), (True, model_ft.classifier)):
    for param in submodule.parameters():
        param.requires_grad = trainable

# Custom Head
# The last classifier layer (index 6) is replaced by one output per dataset class.
head_index = 6
num_ftrs = model_ft.classifier[head_index].in_features
model_ft.classifier[head_index] = nn.Linear(num_ftrs, len(classes))
model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.Adam(
    model_ft.parameters(),
    lr=config['lr'],
    weight_decay=1e-4,
)
# Multiply the LR by 0.1 when the monitored value (stepped with the validation
# loss in the training loop) stops decreasing.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.1,
    patience=5,
)


# Training loop
# Every epoch runs a training pass (with gradient accumulation) followed by a
# validation pass. The weights reaching the best validation accuracy are kept
# and training stops after `patience` epochs without improvement.
num_epochs = config['epochs']
patience = 5
epochs_no_improve = 0
best_model_wts = copy.deepcopy(model_ft.state_dict())
best_acc = 0.0

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print('-'*20)

    # Log the learning rate of the first parameter group at the start of the epoch.
    current_lr = optimizer.param_groups[0]['lr']
    wandb.log({"lr": current_lr, "epoch": epoch+1})

    for phase in ['train','val']:
        if phase=='train':
            model_ft.train()
        else:
            model_ft.eval()

        running_loss = 0.0
        running_corrects = 0

        # Start each phase from clean gradients.
        optimizer.zero_grad()

        loop = tqdm(dataloaders[phase], desc=f"{phase} Ep {epoch+1}", leave=False)

        for i, (inputs, labels) in enumerate(loop):
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Gradients are only tracked during the training phase.
            with torch.set_grad_enabled(phase=='train'):
                outputs = model_ft(inputs)
                _, preds = torch.max(outputs,1)
                loss = criterion(outputs, labels)

                if phase=='train':
                    # Scale the loss so the gradients summed over
                    # `accumulation_steps` batches average out.
                    loss = loss / accumulation_steps
                    loss.backward()

                    # Update the weights once every `accumulation_steps` batches.
                    if (i + 1) % accumulation_steps == 0:
                        optimizer.step()
                        optimizer.zero_grad()

            # Undo the accumulation scaling so the reported loss is the per-batch loss.
            batch_loss = loss.item() * accumulation_steps if phase == 'train' else loss.item()

            running_loss += batch_loss * inputs.size(0)
            running_corrects += torch.sum(preds==labels.data)

            # The progress bar shows the loss and accuracy of the current batch only.
            loop.set_postfix(loss=batch_loss, acc=(torch.sum(preds==labels.data).double()/inputs.size(0)).item())

        # Apply the gradients left over when the number of training batches is
        # not a multiple of `accumulation_steps`.
        if phase == 'train' and len(dataloaders['train']) % accumulation_steps != 0:
            optimizer.step()
            optimizer.zero_grad()

        epoch_loss = running_loss / dataset_sizes[phase]
        epoch_acc = running_corrects.double() / dataset_sizes[phase]

        print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")
        wandb.log({f"{phase}_loss": epoch_loss, f"{phase}_acc": epoch_acc, "epoch": epoch+1})

        if phase=='val':
            # The LR scheduler monitors the validation loss ...
            scheduler.step(epoch_loss)

            # ... while checkpointing and early stopping monitor validation accuracy.
            if epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model_ft.state_dict())
                epochs_no_improve = 0
                torch.save(model_ft.state_dict(), "best_vgg16.pth")
                wandb.save("best_vgg16.pth")
                print(f"--> New best (Acc: {best_acc:.4f})")
            else:
                epochs_no_improve += 1

    if epochs_no_improve >= patience:
        print("Early stopping triggered")
        break

# Restore the best weights and save them under a separate "final" file name
# ("best_vgg16.pth" is only written when an epoch improves on best_acc).
model_ft.load_state_dict(best_model_wts)
torch.save(model_ft.state_dict(), "final_vgg16.pth")
wandb.save("final_vgg16.pth")
wandb.finish()

### Evaluation

In [None]:
import os
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.models import vgg16_bn, VGG16_BN_Weights

from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    classification_report,
    confusion_matrix,
    roc_auc_score
)
from sklearn.preprocessing import label_binarize

import matplotlib.pyplot as plt
import seaborn as sns
import wandb

# W&B Initialization
# The evaluation is tracked as its own run, in a project separate from training.
wandb.login()
wandb.init(
    job_type="evaluation",
    name="vgg16_evaluation",
    project="brain_tumor_classification_test",
)

# Use the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("Device used:", device)

# Load trained model
model_path = "best_vgg16.pth"
assert os.path.exists(model_path), "Model not found!"

# `weights_model` is assigned but not passed to the constructor: the network is
# built without pretrained weights and filled in from the checkpoint.
weights_model = VGG16_BN_Weights.DEFAULT
model = vgg16_bn(weights=None)

# Same head replacement as at training time: 4 outputs on classifier layer 6.
head_index = 6
num_ftrs = model.classifier[head_index].in_features
model.classifier[head_index] = nn.Linear(num_ftrs, 4)

checkpoint_state = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint_state)
model = model.to(device)
model.eval()

# Test Dataset & DataLoader
test_dir = "/content/dataset_local/test"

# Deterministic preprocessing only (resize, tensor conversion, normalisation):
# no random augmentation at test time.
normalize_step = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std =[0.229, 0.224, 0.225]
)
test_transforms = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    normalize_step,
])

test_dataset = datasets.ImageFolder(test_dir, transform=test_transforms)
loader_options = {
    "batch_size": 32,
    "shuffle": False,
    "num_workers": 2,
    "pin_memory": True,
}
test_loader = DataLoader(test_dataset, **loader_options)

classes = test_dataset.classes
print("Classes:", classes)
print("Number of test images:", len(test_dataset))


# Evaluation loop
# Accumulators filled batch by batch: predicted class index, true class index
# and softmax probability vector of every test image (the loader does not shuffle).
all_preds = []
all_labels = []
all_probs = []

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        probs = torch.softmax(outputs, dim=1)
        _, preds = torch.max(outputs, 1)

        # Move each batch result back to the CPU and append it row by row.
        for store, batch_values in (
            (all_probs, probs),
            (all_preds, preds),
            (all_labels, labels),
        ):
            store.extend(batch_values.cpu().numpy())

# Global Metrics
# Accuracy plus macro-averaged precision / recall / F1 over the whole test set.
accuracy = accuracy_score(all_labels, all_preds)
macro_scores = precision_recall_fscore_support(
    all_labels, all_preds, average="macro"
)
precision, recall, f1 = macro_scores[0], macro_scores[1], macro_scores[2]

# ROC-AUC multi-class
# One-hot encode the true labels over the class indices 0..len(classes)-1.
class_indices = list(range(len(classes)))
y_true_bin = label_binarize(all_labels, classes=class_indices)
roc_auc_macro = roc_auc_score(
    y_true_bin, all_probs, average="macro", multi_class="ovr"
)

print("\n=== GLOBAL METRICS ===")
print(f"Accuracy        : {accuracy:.4f}")
print(f"Precision macro : {precision:.4f}")
print(f"Recall macro    : {recall:.4f}")
print(f"F1-score macro  : {f1:.4f}")
print(f"ROC-AUC macro   : {roc_auc_macro:.4f}")

global_metrics = {
    "test/accuracy": accuracy,
    "test/precision_macro": precision,
    "test/recall_macro": recall,
    "test/f1_macro": f1,
    "test/roc_auc_macro": roc_auc_macro,
}
wandb.log(global_metrics)

# Class-wise Metrics
# The dict form is used for logging; the text form is printed just below.
report = classification_report(
    all_labels, all_preds, target_names=classes, output_dict=True
)

print("\n=== CLASS-WISE METRICS ===")
print(classification_report(all_labels, all_preds, target_names=classes))

# One W&B log call per class, keyed by the class (folder) name.
for cls in classes:
    class_scores = report[cls]
    wandb.log({
        f"test/{cls}/precision": class_scores["precision"],
        f"test/{cls}/recall": class_scores["recall"],
        f"test/{cls}/f1": class_scores["f1-score"],
    })

# Confusion Matrix
# Interactive version, logged to W&B.
wandb.log({
    "test/confusion_matrix": wandb.plot.confusion_matrix(
        y_true=all_labels,
        preds=all_preds,
        class_names=classes
    )
})

# Static version displayed in the notebook: heat-map of the raw counts.
cm = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(6,5))
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=classes,
    yticklabels=classes
)
plt.xlabel("Predictions")
plt.ylabel("Ground Truth")
# The dash in the title was stored as mis-decoded UTF-8 ("â"); restored as an en dash.
plt.title("Confusion Matrix – Test set")
plt.tight_layout()
plt.show()


# Save model as Artifact
# Attach the evaluated checkpoint file to the W&B run.
artifact = wandb.Artifact(
    name="vgg16_bn",
    type="model",
    description="VGG16_BN evaluated on test set"
)
artifact.add_file(model_path)
wandb.log_artifact(artifact)

wandb.finish()
print("\nEvaluation completed and logged to W&B")


## DenseNet121

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import copy
import os
import shutil
import numpy as np
from torchvision.models import densenet121, DenseNet121_Weights
from tqdm.notebook import tqdm
import wandb
from collections import Counter


# Copy the dataset from Drive onto the VM's local disk once; the copy is
# skipped whenever local_path already exists.
if not os.path.exists(local_path):
    print("Copying the dataset to the local VM...")
    shutil.copytree(drive_path, local_path)
    print("Copy finished.")

# Gradient accumulation: gradients from ACCUMULATION_STEPS mini-batches of
# PHYSICAL_BATCH_SIZE samples are summed before each optimizer step.
PHYSICAL_BATCH_SIZE = 32
ACCUMULATION_STEPS = 4  # Effective batch = 32 * 4 = 128

# Hyper-parameters and run description, stored with the W&B run.
config = {
    "batch_size_physical": PHYSICAL_BATCH_SIZE,
    "batch_size_effective": PHYSICAL_BATCH_SIZE * ACCUMULATION_STEPS,
    "lr": 1e-4,
    "epochs": 60,
    "patience": 10,
    "architecture": "DenseNet121",
    "strategy": "Fine-Tuning (DenseBlock4 + Norm5 + Classifier)",
    "augmentation": "HFlip + Rotation(10)"
}

wandb.login()
wandb.init(project="brain_tumor_classification", name="densenet121", config=config)

# Use the first GPU when one is available, otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device : {device}")

# Transformations
# Both splits are resized to 224x224 and normalised with the same per-channel
# statistics; only the training split gets random flip/rotation augmentation.
train_steps = [
    transforms.Resize((224,224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
]
val_steps = [
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
]
data_transforms = {
    'train': transforms.Compose(train_steps),
    'val': transforms.Compose(val_steps),
}

# Datasets & Dataloaders
# One ImageFolder and one DataLoader per split; only the training loader shuffles.
image_datasets = {}
dataloaders = {}
for split in ('train', 'val'):
    image_datasets[split] = datasets.ImageFolder(
        os.path.join(local_path, split),
        transform=data_transforms[split],
    )
    dataloaders[split] = DataLoader(
        image_datasets[split],
        batch_size=PHYSICAL_BATCH_SIZE,
        shuffle=(split == 'train'),
        num_workers=2,
        pin_memory=True,
    )

dataset_sizes = {split: len(folder) for split, folder in image_datasets.items()}
classes = image_datasets['train'].classes

# Class weights for the loss: total / (n_classes * class_count), so classes
# with fewer training samples receive a larger weight.
train_targets = image_datasets['train'].targets
count_dict = Counter(train_targets)
n_classes = len(classes)
class_count = [count_dict[class_idx] for class_idx in range(n_classes)]
total_count = sum(class_count)
weights = [total_count / (n_classes * per_class) for per_class in class_count]
class_weights = torch.FloatTensor(weights).to(device)

# Start from torchvision's default pretrained DenseNet121 weights.
weights_model = DenseNet121_Weights.DEFAULT
model_ft = densenet121(weights=weights_model)

# Freeze every parameter first ...
for param in model_ft.parameters():
    param.requires_grad = False

# ... then re-enable gradients for the last dense block and the final norm layer.
for trainable_module in (model_ft.features.denseblock4, model_ft.features.norm5):
    for param in trainable_module.parameters():
        param.requires_grad = True

# Replace the classifier with a fresh linear head sized to this dataset's classes.
num_ftrs = model_ft.classifier.in_features
model_ft.classifier = nn.Linear(num_ftrs, len(classes))
model_ft = model_ft.to(device)

# Weighted cross-entropy; the optimizer only receives parameters that still
# require gradients.
criterion = nn.CrossEntropyLoss(weight=class_weights)
params_to_update = list(filter(lambda p: p.requires_grad, model_ft.parameters()))
optimizer = optim.Adam(params_to_update, lr=config['lr'], weight_decay=1e-4)
# The scheduler is stepped with the validation loss in the training loop.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.1,
    patience=4,
)

# Training loop
# Fine-tunes model_ft with gradient accumulation. Each epoch runs a 'train'
# phase followed by a 'val' phase. The LR scheduler is driven by validation
# loss, while checkpointing and early stopping are driven by validation accuracy.
num_epochs = config['epochs']
patience = config['patience']
epochs_no_improve = 0  # consecutive epochs without a new best validation accuracy
best_model_wts = copy.deepcopy(model_ft.state_dict())
best_acc = 0.0

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print('-'*20)

    # Log the learning rate of the first parameter group at the start of the epoch.
    current_lr = optimizer.param_groups[0]['lr']
    wandb.log({"lr": current_lr, "epoch": epoch+1})

    for phase in ['train','val']:
        if phase=='train':
            model_ft.train()
        else:
            model_ft.eval()

        running_loss = 0.0
        running_corrects = 0

        # Start every phase with clean gradients.
        optimizer.zero_grad()

        loop = tqdm(dataloaders[phase], desc=f"{phase} Epoch {epoch+1}", leave=False)

        for i, (inputs, labels) in enumerate(loop):
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Gradients are only tracked during the training phase.
            with torch.set_grad_enabled(phase=='train'):
                outputs = model_ft(inputs)
                _, preds = torch.max(outputs,1)
                loss = criterion(outputs, labels)

                if phase=='train':
                    # Scale the loss so the gradients summed over
                    # ACCUMULATION_STEPS mini-batches form an average.
                    loss = loss / ACCUMULATION_STEPS
                    loss.backward()

                    # Apply the accumulated gradients every ACCUMULATION_STEPS batches.
                    if (i + 1) % ACCUMULATION_STEPS == 0:
                        optimizer.step()
                        optimizer.zero_grad()

            # Undo the accumulation scaling so the reported loss is the plain batch loss.
            current_loss_val = loss.item() * ACCUMULATION_STEPS if phase == 'train' else loss.item()

            running_loss += current_loss_val * inputs.size(0)
            running_corrects += torch.sum(preds==labels.data)

            # Progress-bar accuracy is an estimate: the denominator is
            # (loop.n+1) * PHYSICAL_BATCH_SIZE, i.e. it assumes full batches.
            loop.set_postfix(loss=current_loss_val, acc=(running_corrects.double()/((loop.n+1)*PHYSICAL_BATCH_SIZE)).item())

        # Flush gradients left over when the number of training batches is not
        # a multiple of ACCUMULATION_STEPS. Those gradients were still divided
        # by ACCUMULATION_STEPS above.
        if phase == 'train' and len(dataloaders['train']) % ACCUMULATION_STEPS != 0:
            optimizer.step()
            optimizer.zero_grad()

        # Per-sample averages over the whole split.
        epoch_loss = running_loss / dataset_sizes[phase]
        epoch_acc = running_corrects.double() / dataset_sizes[phase]

        print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")
        wandb.log({f"{phase}_loss": epoch_loss, f"{phase}_acc": epoch_acc, "epoch": epoch+1})

        if phase=='val':
            # ReduceLROnPlateau watches the validation loss.
            scheduler.step(epoch_loss)

            # Keep and checkpoint the weights with the best validation accuracy so far.
            if epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model_ft.state_dict())
                epochs_no_improve = 0
                torch.save(model_ft.state_dict(), "best_densenet121.pth")
                wandb.save("best_densenet121.pth")
                print(f"--> New best ! (Acc: {best_acc:.4f})")
            else:
                epochs_no_improve += 1

    # Early stopping: give up after `patience` epochs without improvement.
    if epochs_no_improve >= patience:
        print(f"Early stopping triggered at epoch {epoch+1}")
        break

# Restore the best weights seen during training, save them as the final
# checkpoint, upload that file to W&B and close the run.
final_checkpoint = "final_densenet121.pth"
model_ft.load_state_dict(best_model_wts)
torch.save(model_ft.state_dict(), final_checkpoint)
wandb.save(final_checkpoint)
wandb.finish()

### Evaluation

In [None]:
import os
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.models import densenet121, DenseNet121_Weights

from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    classification_report,
    confusion_matrix,
    roc_auc_score
)
from sklearn.preprocessing import label_binarize

import matplotlib.pyplot as plt
import seaborn as sns
import wandb
import numpy as np
import shutil


# W&B Initialization
# Evaluation runs are logged to a separate project from the training runs.
wandb.login()
wandb.init(project="brain_tumor_classification_test", name="densenet121_evaluation", job_type="evaluation")


# Device Setup
# Use the first GPU when one is available, otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device used:", device)


# Load Trained Model (DenseNet121)
model_path = "best_densenet121.pth"

# Build the bare architecture only. The checkpoint loaded below supplies every
# weight, so fetching the ImageNet weights first would just be a wasted download.
model = densenet121(weights=None)

# The classifier head must match the 4-output head stored in the checkpoint.
num_ftrs = model.classifier.in_features
model.classifier = nn.Linear(num_ftrs, 4)

try:
    model.load_state_dict(torch.load(model_path, map_location=device))
except RuntimeError:
    # load_state_dict raises RuntimeError when keys or shapes do not match.
    print("Error in loading")
    raise

model = model.to(device)
model.eval()

# Test Dataset & DataLoader
test_dir = os.path.join(local_path, "test")

# Deterministic preprocessing only (no augmentation): resize, tensor conversion,
# per-channel normalisation.
test_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std =[0.229, 0.224, 0.225],
)
test_transforms = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    test_normalize,
])

test_dataset = datasets.ImageFolder(test_dir, transform=test_transforms)
# shuffle=False keeps predictions in dataset order.
loader_kwargs = dict(batch_size=32, shuffle=False, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, **loader_kwargs)

classes = test_dataset.classes
print(f"Classes found : {classes}")
print(f"Nombre test images : {len(test_dataset)}")


# Evaluation Loop
# Collect, for every test image, the predicted class index, the true label and
# the softmax probability vector.
all_preds, all_labels, all_probs = [], [], []

with torch.no_grad():
    for batch_inputs, batch_labels in test_loader:
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)

        logits = model(batch_inputs)
        batch_probs = torch.softmax(logits, dim=1)
        batch_preds = torch.max(logits, 1)[1]  # index of the largest logit

        all_probs.extend(batch_probs.cpu().numpy())
        all_preds.extend(batch_preds.cpu().numpy())
        all_labels.extend(batch_labels.cpu().numpy())


# Global Metrics
# Accuracy plus macro-averaged precision / recall / F1 over the test set.
accuracy = accuracy_score(all_labels, all_preds)
precision, recall, f1, _ = precision_recall_fscore_support(
    all_labels, all_preds, average="macro"
)

# ROC-AUC multi-class
# Labels are one-hot encoded so each class is scored one-vs-rest against its
# softmax probability column.
y_true_bin = label_binarize(all_labels, classes=list(range(len(classes))))
try:
    roc_auc_macro = roc_auc_score(
        y_true_bin,
        all_probs,
        average="macro",
        multi_class="ovr"
    )
except ValueError as err:
    # Report the failure and record NaN instead of 0.0: 0.0 is a valid AUC
    # value and would be read as a real (terrible) score in the logs.
    # Presumably raised when a class has no sample in the test labels.
    print(f"ROC-AUC could not be computed: {err}")
    roc_auc_macro = float("nan")

print("\n" + "="*30)
print("=== GLOBAL METRICS ===")
print(f"Accuracy        : {accuracy:.4f}")
print(f"Precision macro : {precision:.4f}")
print(f"Recall macro    : {recall:.4f}")
print(f"F1-score macro  : {f1:.4f}")
print(f"ROC-AUC macro   : {roc_auc_macro:.4f}")
print("="*30)

wandb.log({
    "test/accuracy": accuracy,
    "test/precision_macro": precision,
    "test/recall_macro": recall,
    "test/f1_macro": f1,
    "test/roc_auc_macro": roc_auc_macro
})

# Class-wise Metrics
# Dict form of the report is used for logging, the text form for stdout.
report = classification_report(all_labels, all_preds, target_names=classes, output_dict=True)

print("\n=== CLASS-WISE METRICS ===")
print(classification_report(all_labels, all_preds, target_names=classes))

# One W&B log call per class, keyed as test/<class>/<metric>.
for cls in classes:
    cls_scores = report[cls]
    wandb.log({
        f"test/{cls}/{short_name}": cls_scores[report_key]
        for short_name, report_key in (
            ("precision", "precision"),
            ("recall", "recall"),
            ("f1", "f1-score"),
        )
    })


# Confusion Matrix
# Interactive version for the W&B dashboard.
wandb_cm = wandb.plot.confusion_matrix(
    y_true=all_labels,
    preds=all_preds,
    class_names=classes,
)
wandb.log({"test/confusion_matrix": wandb_cm})

# Static heatmap of the same matrix, saved to disk and shown inline.
cm = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(8, 6))
heatmap_options = dict(
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=classes,
    yticklabels=classes,
)
sns.heatmap(cm, **heatmap_options)
plt.xlabel("Prediction")
plt.ylabel("Ground Truth")
plt.title("Confusion Matrix - Test set")
plt.tight_layout()
plt.savefig("densenet121_confusion_matrix.png", dpi=300)
plt.show()


# Upload the evaluated checkpoint to W&B as a model artifact, then close the run.
artifact_kwargs = {
    "name": "densenet121",
    "type": "model",
    "description": "DenseNet121 evaluated on test set",
}
artifact = wandb.Artifact(**artifact_kwargs)
artifact.add_file(model_path)
wandb.log_artifact(artifact)

wandb.finish()

## Inception

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import copy
import os
import shutil
from torchvision.models import inception_v3, Inception_V3_Weights
from tqdm.notebook import tqdm
import wandb
from collections import Counter

# Copy the dataset from Drive onto the VM's local disk once; the copy is
# skipped whenever local_path already exists.
if not os.path.exists(local_path):
    print(f"Copying dataset from {drive_path} to {local_path}...")
    try:
        shutil.copytree(drive_path, local_path)
    except OSError as e:
        # shutil.Error is an OSError subclass, so this covers copytree failures.
        print(f"Error copying dataset: {e}")
        # Remove whatever was copied so far. Otherwise the existence check
        # above would treat a partial copy as a complete dataset on the next
        # run, and this run would train on truncated data.
        shutil.rmtree(local_path, ignore_errors=True)
        raise
    else:
        print("Copy finished.")
else:
    print("Dataset already exists locally.")

# Configuration
# Hyper-parameters and run description, stored with the W&B run.
config = {
    "batch_size_physical": 32,
    "lr": 1e-4,
    "epochs": 60,
    "architecture": "InceptionV3",
    "strategy": "Fine-Tuning (Mixed_7c + AuxLogits + FC)",
    # The training transforms also apply RandomHorizontalFlip, so it is
    # recorded here alongside the rotation.
    "augmentation": "HFlip + Rotation(10) + Resize 299 + ImageNet Norm"
}

# No key is hard-coded in the notebook: wandb.login() picks the credential up
# from the environment / stored login or prompts for it, as the DenseNet cells do.
wandb.login()
wandb.init(project="brain_tumor_classification", name="inception_v3_optimized", config=config)

# Use the first GPU when one is available, otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# Transformations
# Per-channel normalisation statistics shared by both splits.
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Both splits are resized to 299x299; only the training split gets random
# flip/rotation augmentation.
train_steps = [
    transforms.Resize((299, 299)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
]
val_steps = [
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
]
data_transforms = {
    'train': transforms.Compose(train_steps),
    'val': transforms.Compose(val_steps),
}

# Datasets & DataLoaders
batch_size = config['batch_size_physical']

# One ImageFolder and one DataLoader per split; only the training loader shuffles.
image_datasets = {}
dataloaders = {}
for split in ('train', 'val'):
    image_datasets[split] = datasets.ImageFolder(
        os.path.join(local_path, split),
        transform=data_transforms[split],
    )
    dataloaders[split] = DataLoader(
        image_datasets[split],
        batch_size=batch_size,
        shuffle=(split == 'train'),
        num_workers=2,
        pin_memory=True,
    )

dataset_sizes = {split: len(folder) for split, folder in image_datasets.items()}
classes = image_datasets['train'].classes
num_classes = len(classes)

print(f"Classes: {classes}")
print(f"Dataset sizes: {dataset_sizes}")

# Class Weights Calculation
# weight = total / (n_classes * class_count), so classes with fewer training
# samples receive a larger weight in the loss.
train_targets = image_datasets['train'].targets
count_dict = Counter(train_targets)
class_count = [count_dict[class_idx] for class_idx in range(num_classes)]
weights = [sum(class_count) / (num_classes * per_class) for per_class in class_count]
class_weights = torch.FloatTensor(weights).to(device)
print(f"Class Weights: {weights}")

# Loading Inception V3
# Start from torchvision's default pretrained weights.
weights_inception = Inception_V3_Weights.DEFAULT
model_ft = inception_v3(weights=weights_inception)

# Freeze every parameter first ...
for param in model_ft.parameters():
    param.requires_grad = False

# ... then re-enable gradients for the last mixed block, the auxiliary
# classifier branch and the main fully-connected layer.
for trainable_module in (model_ft.Mixed_7c, model_ft.AuxLogits, model_ft.fc):
    for param in trainable_module.parameters():
        param.requires_grad = True

# Replace the main and auxiliary heads with fresh linear layers sized to
# this dataset's classes.
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, num_classes)

num_aux_ftrs = model_ft.AuxLogits.fc.in_features
model_ft.AuxLogits.fc = nn.Linear(num_aux_ftrs, num_classes)
model_ft = model_ft.to(device)


# Weighted cross-entropy; the optimizer only receives parameters that still
# require gradients.
criterion = nn.CrossEntropyLoss(weight=class_weights)
params_to_update = list(filter(lambda p: p.requires_grad, model_ft.parameters()))
optimizer = optim.Adam(params_to_update, lr=config['lr'], weight_decay=1e-4)
# The scheduler is stepped with the validation loss in the training loop.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)


# Training Loop
# Fine-tunes model_ft. Each epoch runs a 'train' phase followed by a 'val'
# phase. The LR scheduler is driven by validation loss, while checkpointing
# and early stopping are driven by validation accuracy.
num_epochs = config["epochs"]
patience = 7  # epochs without a new best validation accuracy before stopping
epochs_no_improve = 0
best_model_wts = copy.deepcopy(model_ft.state_dict())
best_acc = 0.0

print("\nStarting Training...")

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print('-'*20)

    for phase in ['train', 'val']:
        if phase == 'train':
            model_ft.train()
        else:
            model_ft.eval()

        running_loss = 0.0
        running_corrects = 0

        loop = tqdm(dataloaders[phase], desc=f"{phase} Epoch {epoch+1}", leave=False)

        for inputs, labels in loop:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            # Gradients are only tracked during the training phase.
            with torch.set_grad_enabled(phase == 'train'):
                if phase == 'train':
                    # In train mode the model returns main and auxiliary
                    # outputs; the auxiliary loss is added with weight 0.4.
                    outputs, aux_outputs = model_ft(inputs)
                    loss1 = criterion(outputs, labels)
                    loss2 = criterion(aux_outputs, labels)
                    loss = loss1 + 0.4 * loss2
                else:
                    # In eval mode only the main output is returned.
                    outputs = model_ft(inputs)
                    loss = criterion(outputs, labels)

                # Predictions always come from the main output.
                _, preds = torch.max(outputs, 1)

                if phase == 'train':
                    loss.backward()
                    optimizer.step()

            # Note: the training loss accumulated here includes the auxiliary
            # term, so train and val losses are not computed the same way.
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

            # Progress-bar values are estimates: the denominator is
            # (loop.n+1) * batch_size, i.e. it assumes full batches.
            loop.set_postfix(
                loss=running_loss/((loop.n+1)*batch_size),
                acc=(running_corrects.double()/((loop.n+1)*batch_size)).item()
            )

        # Per-sample averages over the whole split.
        epoch_loss = running_loss / dataset_sizes[phase]
        epoch_acc = running_corrects.double() / dataset_sizes[phase]

        print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")
        wandb.log({f"{phase}_loss": epoch_loss, f"{phase}_acc": epoch_acc, "epoch": epoch+1})

        if phase == 'val':
            # ReduceLROnPlateau watches the validation loss.
            scheduler.step(epoch_loss)

            # Keep and checkpoint the weights with the best validation accuracy so far.
            if epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model_ft.state_dict())
                epochs_no_improve = 0

                # Save checkpoint
                torch.save(model_ft.state_dict(), "best_inception_v3.pth")
                wandb.save("best_inception_v3.pth")
                print(f"--> New best Acc: {best_acc:.4f} (Saved)")
            else:
                epochs_no_improve += 1

    # Early stopping: give up after `patience` epochs without improvement.
    if epochs_no_improve >= patience:
        print(f"\nEarly stopping triggered at epoch {epoch+1}")
        break


# Restore the best weights seen during training, save them as the final
# checkpoint, upload that file to W&B and close the run.
final_checkpoint = "final_inception_v3.pth"
model_ft.load_state_dict(best_model_wts)
torch.save(model_ft.state_dict(), final_checkpoint)
wandb.save(final_checkpoint)
wandb.finish()

### Evaluation

In [None]:
import os
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.models import inception_v3, Inception_V3_Weights

from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    classification_report,
    confusion_matrix,
    roc_auc_score
)
from sklearn.preprocessing import label_binarize

import matplotlib.pyplot as plt
import seaborn as sns
import wandb
import numpy as np
import shutil

# W&B Initialization
# Evaluation runs are logged to a separate project from the training runs.
wandb.login()
wandb.init(project="brain_tumor_classification_test", name="inception_v3_evaluation", job_type="evaluation")


# Use the first GPU when one is available, otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device used:", device)

# Make sure the dataset is on local disk: copy it from Drive when it is
# missing locally, and only warn when Drive does not have it either.
dataset_missing = not os.path.exists(local_path)
if dataset_missing and os.path.exists(drive_path):
    print("Dataset not found locally. Copying from Drive...")
    shutil.copytree(drive_path, local_path)
    print("Copy completed.")
elif dataset_missing:
    print("Warning: Dataset not found on Drive either.")


# Load Trained Model (Inception V3)
model_path = "best_inception_v3.pth"
weights_model = Inception_V3_Weights.DEFAULT  # NOTE(review): assigned but not used below
# Architecture only: every weight comes from the checkpoint loaded below.
model = inception_v3(weights=None)

# Resize the main and the auxiliary head to 4 outputs before loading the checkpoint.
num_ftrs = model.fc.in_features
num_aux_ftrs = model.AuxLogits.fc.in_features
model.fc = nn.Linear(num_ftrs, 4)
model.AuxLogits.fc = nn.Linear(num_aux_ftrs, 4)

try:
    checkpoint_state = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint_state)
    print("Weights loaded successfully.")
except RuntimeError as e:
    print("Dimension error. Ensure you redefined AuxLogits correctly.")
    raise e

# Move to the target device and switch to inference mode.
model = model.to(device)
model.eval()

# Test Dataset & DataLoader
test_dir = os.path.join(local_path, "test")

# NOTE(review): training normalised with the ImageNet statistics, whereas this
# pipeline uses 0.5 / 0.5. That looks like a mismatch but is presumably
# intentional: torchvision's inception_v3 enables `transform_input` only when
# pretrained weights are passed (as in training), which remaps
# ImageNet-normalised inputs to the 0.5 / 0.5 scale inside the model. The
# evaluation model is built with weights=None, so that remap should be off.
# Confirm against the torchvision docs before changing either side.
test_normalize = transforms.Normalize(
    mean=[0.5, 0.5, 0.5],
    std =[0.5, 0.5, 0.5],
)
test_transforms = transforms.Compose([
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
    test_normalize,
])

test_dataset = datasets.ImageFolder(test_dir, transform=test_transforms)
# shuffle=False keeps predictions in dataset order.
loader_kwargs = dict(batch_size=32, shuffle=False, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, **loader_kwargs)

classes = test_dataset.classes
print(f"Classes found: {classes}")
print(f"Number of test images: {len(test_dataset)}")

# Evaluation loop
# Collect, for every test image, the predicted class index, the true label and
# the softmax probability vector.
all_preds, all_labels, all_probs = [], [], []

with torch.no_grad():
    for batch_inputs, batch_labels in test_loader:
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)

        logits = model(batch_inputs)
        batch_probs = torch.softmax(logits, dim=1)
        batch_preds = torch.max(logits, 1)[1]  # index of the largest logit

        all_probs.extend(batch_probs.cpu().numpy())
        all_preds.extend(batch_preds.cpu().numpy())
        all_labels.extend(batch_labels.cpu().numpy())

# Global Metrics
# Accuracy plus macro-averaged precision / recall / F1 over the test set.
accuracy = accuracy_score(all_labels, all_preds)
precision, recall, f1, _ = precision_recall_fscore_support(
    all_labels, all_preds, average="macro"
)

# ROC-AUC multi-class: labels are one-hot encoded so each class is scored
# one-vs-rest against its softmax probability column.
y_true_bin = label_binarize(all_labels, classes=list(range(len(classes))))
try:
    roc_auc_macro = roc_auc_score(
        y_true_bin,
        all_probs,
        average="macro",
        multi_class="ovr"
    )
except ValueError as err:
    # Report the failure and record NaN instead of 0.0: 0.0 is a valid AUC
    # value and would be read as a real (terrible) score in the logs.
    # Presumably raised when a class has no sample in the test labels.
    print(f"ROC-AUC could not be computed: {err}")
    roc_auc_macro = float("nan")

print("\n" + "="*30)
print("=== GLOBAL METRICS ===")
print(f"Accuracy        : {accuracy:.4f}")
print(f"Precision macro : {precision:.4f}")
print(f"Recall macro    : {recall:.4f}")
print(f"F1-score macro  : {f1:.4f}")
print(f"ROC-AUC macro   : {roc_auc_macro:.4f}")
print("="*30)

wandb.log({
    "test/accuracy": accuracy,
    "test/precision_macro": precision,
    "test/recall_macro": recall,
    "test/f1_macro": f1,
    "test/roc_auc_macro": roc_auc_macro
})

# Class-wise Metrics
# Dict form of the report is used for logging, the text form for stdout.
report = classification_report(all_labels, all_preds, target_names=classes, output_dict=True)

print("\n=== CLASS-WISE METRICS ===")
print(classification_report(all_labels, all_preds, target_names=classes))

# One W&B log call per class, keyed as test/<class>/<metric>.
for cls in classes:
    cls_scores = report[cls]
    wandb.log({
        f"test/{cls}/{short_name}": cls_scores[report_key]
        for short_name, report_key in (
            ("precision", "precision"),
            ("recall", "recall"),
            ("f1", "f1-score"),
        )
    })

# Confusion Matrix
# Interactive version for the W&B dashboard.
wandb_cm = wandb.plot.confusion_matrix(
    y_true=all_labels,
    preds=all_preds,
    class_names=classes,
)
wandb.log({"test/confusion_matrix": wandb_cm})

# Static heatmap of the same matrix, saved to disk and shown inline.
cm = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(8, 6))
heatmap_options = dict(
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=classes,
    yticklabels=classes,
)
sns.heatmap(cm, **heatmap_options)
plt.xlabel("Predicted")
plt.ylabel("Ground Truth")
plt.title("Confusion Matrix - Test set")
plt.tight_layout()
plt.savefig("inception_v3_confusion_matrix.png", dpi=300)
plt.show()

# Upload the evaluated checkpoint to W&B as a model artifact, then close the run.
artifact_kwargs = {
    "name": "inception_v3",
    "type": "model",
    "description": "InceptionV3 evaluated on test set",
}
artifact = wandb.Artifact(**artifact_kwargs)
artifact.add_file(model_path)
wandb.log_artifact(artifact)

wandb.finish()