# 1. Library

In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np
from PIL import Image
import torch.nn.functional as F
from google.colab import drive

# 2. Basic configuration

In [2]:
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Device: cuda


In [4]:
# Dataset directory
DATA_DIR = "/content/drive/MyDrive/Brain_Tumor_Detection_Segmentation/brisc2025/classification_task/"

TRAIN_DIR = os.path.join(DATA_DIR, "train")
TEST_DIR = os.path.join(DATA_DIR, "test")

# 3. Data preprocessing

In [5]:
# Data transformation
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.2),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.15, contrast=0.15),
    transforms.GaussianBlur(kernel_size=(3,3)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

In [6]:
train_dataset = datasets.ImageFolder(TRAIN_DIR, transform=transform_train)
test_dataset = datasets.ImageFolder(TEST_DIR, transform=transform_test)

In [7]:
# Splitting training set and validating set (80 train, 20 validation)
val_size = int(0.2 * len(train_dataset))
train_size = len(train_dataset) - val_size
train_ds, val_ds = random_split(train_dataset, [train_size, val_size])

In [8]:
print(f"Train size: {train_size}, Val size: {val_size}, Test size: {len(test_dataset)}")

Train size: 4000, Val size: 1000, Test size: 1000


In [9]:
#DataLoader
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# 4. ResNet18

In [10]:
#Model + Dropout in head
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
num_features = model.fc.in_features
model.fc = nn.Sequential(
    nn.Dropout(p=0.5),      # dropout to reduce overfitting
    nn.Linear(num_features, 4)
)
model = model.to(device)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 188MB/s]


In [21]:
# loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', factor=0.5, patience=2)

In [13]:
# Early stopping params
best_acc = 0.0
patience = 3
patience_counter = 0
EPOCHS = 10

In [18]:
def train_one_epoch(model, loader, optimizer, criterion):
    model.train()
    total_loss, correct, total = 0, 0, 0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    return total_loss / len(loader), correct / total

In [19]:
# evaluation
def evaluate(model, loader, criterion):
    model.eval()
    total_loss, correct, total = 0, 0, 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return total_loss / len(loader), correct / total

In [20]:
def enable_dropout(model):
    """ Enable dropout layers during inference. """
    for m in model.modules():
        if isinstance(m, nn.Dropout):
            m.train()

def predict_with_mc_dropout_classification(model, img_tensor, n_iter=10):
    """
    img_tensor: (1, C, H, W) preprocessed
    returns: mean_probs (1, num_classes), std_probs (1, num_classes)
    """
    model.eval()
    enable_dropout(model)  # keep dropout active
    probs = []
    with torch.no_grad():
        for _ in range(n_iter):
            out = model(img_tensor)                     # logits
            p = F.softmax(out, dim=1).cpu().numpy()     # (1, num_classes)
            probs.append(p)
    probs = np.concatenate(probs, axis=0)  # (n_iter, 1, num_classes)
    mean_probs = probs.mean(axis=0)        # (1, num_classes)
    std_probs = probs.std(axis=0)          # (1, num_classes)
    return mean_probs, std_probs

def predict_with_mc_dropout_segmentation(model, img_tensor, n_iter=10):
    """
    For segmentation model that outputs logits of shape (1, C, H, W).
    Returns:
      mean_probs: (C, H, W)
      std_probs: (C, H, W)
    """
    model.eval()
    enable_dropout(model)
    probs = []
    with torch.no_grad():
        for _ in range(n_iter):
            out = model(img_tensor)  # assume shape (1, C, H, W)
            p = F.softmax(out, dim=1).cpu().numpy()  # (1, C, H, W)
            probs.append(p)
    probs = np.stack(probs, axis=0)  # (n_iter, 1, C, H, W)
    probs = probs[:, 0, ...]         # (n_iter, C, H, W)
    mean_probs = probs.mean(axis=0)  # (C, H, W)
    std_probs = probs.std(axis=0)    # (C, H, W)
    return mean_probs, std_probs

In [22]:
# train and validation
for epoch in range(EPOCHS):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion)
    val_loss, val_acc = evaluate(model, val_loader, criterion)
    scheduler.step(val_acc)  # reduce LR on plateau of val_acc

    print(f"Epoch {epoch+1}/{EPOCHS}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    # Save best model
    if val_acc > best_acc + 1e-6:
        best_acc = val_acc
        patience_counter = 0
        torch.save(model.state_dict(), "best_resnet18.pth")
        print("Saved best model!")
    else:
        patience_counter += 1
        print(f"No improvement, patience {patience_counter}/{patience}")
        if patience_counter >= patience:
            print("Early stopping triggered.")
            break

Epoch 1/10: Train Loss: 0.4234, Train Acc: 0.8357 | Val Loss: 0.2676, Val Acc: 0.9050
Saved best model!
Epoch 2/10: Train Loss: 0.1640, Train Acc: 0.9385 | Val Loss: 0.1716, Val Acc: 0.9450
Saved best model!
Epoch 3/10: Train Loss: 0.1225, Train Acc: 0.9560 | Val Loss: 0.1169, Val Acc: 0.9590
Saved best model!
Epoch 4/10: Train Loss: 0.0771, Train Acc: 0.9738 | Val Loss: 0.1296, Val Acc: 0.9570
No improvement, patience 1/3
Epoch 5/10: Train Loss: 0.0782, Train Acc: 0.9742 | Val Loss: 0.1360, Val Acc: 0.9530
No improvement, patience 2/3
Epoch 6/10: Train Loss: 0.0622, Train Acc: 0.9775 | Val Loss: 0.0863, Val Acc: 0.9710
Saved best model!
Epoch 7/10: Train Loss: 0.0524, Train Acc: 0.9832 | Val Loss: 0.1283, Val Acc: 0.9670
No improvement, patience 1/3
Epoch 8/10: Train Loss: 0.0571, Train Acc: 0.9815 | Val Loss: 0.1012, Val Acc: 0.9690
No improvement, patience 2/3
Epoch 9/10: Train Loss: 0.0437, Train Acc: 0.9848 | Val Loss: 0.0922, Val Acc: 0.9760
Saved best model!
Epoch 10/10: Train L

In [23]:
model.load_state_dict(torch.load("best_resnet18.pth"))
model.eval()

y_true, y_pred = [], []

with torch.no_grad():
    for imgs, labels in test_loader:
        imgs = imgs.to(device)
        outputs = model(imgs)
        _, predicted = torch.max(outputs.data, 1)
        y_true.extend(labels.numpy())
        y_pred.extend(predicted.cpu().numpy())

print("Classification Report:")
print(classification_report(y_true, y_pred, target_names=train_dataset.classes))

Classification Report:
              precision    recall  f1-score   support

      glioma       0.99      0.98      0.99       254
  meningioma       0.98      0.98      0.98       306
    no_tumor       0.99      1.00      0.99       140
   pituitary       0.99      0.99      0.99       300

    accuracy                           0.99      1000
   macro avg       0.99      0.99      0.99      1000
weighted avg       0.99      0.99      0.99      1000



In [24]:
# prediction function
def predict_image(image_path, model, class_names):
    model.eval()
    img = Image.open(image_path).convert("RGB")

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])

    img_tensor = transform(img).unsqueeze(0).to(device)  # add more batch dimension
    with torch.no_grad():
        outputs = model(img_tensor)
        probs = F.softmax(outputs, dim=1)
        conf, pred_class = torch.max(probs, 1)

    label = class_names[pred_class.item()]
    print(f"Prediction: {label} (confidence: {conf.item():.4f})")

    for i, cls in enumerate(class_names):
        print(f"  - {cls}: {probs[0][i].item():.4f}")

    return label, conf.item()

In [36]:
def predict_image_mc(image_path, model, class_names, n_iter=10):
    model.to(device)
    img = Image.open(image_path).convert("RGB")
    img_tensor = transform_test(img).unsqueeze(0).to(device)

    mean_probs, std_probs = predict_with_mc_dropout_classification(model, img_tensor, n_iter=n_iter)

    pred_class = int(np.argmax(mean_probs))
    conf = float(mean_probs[pred_class])
    conf_std = float(std_probs[pred_class])

    print(f"Prediction: {class_names[pred_class]} (mean conf: {conf:.4f} ± {conf_std:.4f})")
    for i, cls in enumerate(class_names):
        print(f"  - {cls}: mean={mean_probs[i]:.4f}, std={std_probs[i]:.4f}")

    return pred_class, mean_probs, std_probs

In [35]:
sample_path = "/content/drive/MyDrive/Brain_Tumor_Detection_Segmentation/brisc2025/classification_task/test/meningioma/brisc2025_test_00255_me_ax_t1.jpg"

model.load_state_dict(torch.load("best_resnet18.pth"))
model.to(device)
predict_image_mc(sample_path, model, train_dataset.classes, n_iter=10)


Prediction: glioma (mean conf: 0.4716 ± 0.1257)
  - glioma: mean=0.4716, std=0.1257
  - meningioma: mean=0.1376, std=0.0447
  - no_tumor: mean=0.2882, std=0.0859
  - pituitary: mean=0.1027, std=0.0565


(0,
 array([0.47156852, 0.13756481, 0.28820685, 0.1026598 ], dtype=float32),
 array([0.12565494, 0.04474884, 0.08594912, 0.05651054], dtype=float32))