# DenseNet121 Centralized Baseline

This notebook loads the classification manifest, builds train/val/test splits, and fine-tunes a frozen DenseNet121 head for 8 epochs using the centralized training utilities.

In [1]:
from pathlib import Path
import sys

repo_root = Path("/workspace")  # replace with your repo root
if str(repo_root) not in sys.path:
    sys.path.append(str(repo_root))
%load_ext autoreload
%autoreload 2

In [3]:
import os
os.getcwd()

'/workspace/experiments'

In [2]:
import pandas as pd
from data.loader import split_manifest_dataframe, build_datasets_from_manifest
from torch.utils.data import DataLoader

manifest_path = Path("../data/manifests/classification_manifest.csv")
base_dir = Path("../data/brisc2025")

train_df, val_df, test_df = split_manifest_dataframe(
    manifest_path=manifest_path,
    train_split_value="train",
    test_split_value="test",
    val_fraction=0.2,
    stratify_columns=("tumor", "plane"),
    random_state=42,
)

print(f"train rows: {len(train_df)} | val rows: {len(val_df)} | test rows: {len(test_df)}")
display(train_df.head())

ImportError: cannot import name 'split_manifest_dataframe' from 'data.loader' (/workspace/data/loader.py)

In [7]:
train_set, val_set, test_set = build_datasets_from_manifest(
    manifest_path=manifest_path,
    base_dir=base_dir,
    train_split_value="train",
    test_split_value="test",
    val_fraction=0.2,
    stratify_columns=("tumor", "plane"),
    random_state=42,
)

batch_size = 32
num_workers = 0  # set to 0 if running on an environment without worker support

loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)

train_loader = DataLoader(train_set, shuffle=True, drop_last=False, **loader_kwargs)
val_loader = DataLoader(val_set, shuffle=False, drop_last=False, **loader_kwargs)
test_loader = DataLoader(test_set, shuffle=False, drop_last=False, **loader_kwargs)

len(train_set), len(val_set), len(test_set)

(4000, 1000, 1000)

In [8]:
import wandb

wandb.init(project="centralized_baseline", name="densenet121_head", mode="disabled")



In [9]:
import torch
import torch.nn as nn
from torchvision import models

num_classes = 4
weights = models.DenseNet121_Weights.IMAGENET1K_V1
model = models.densenet121(weights=weights)

for param in model.features.parameters():
    param.requires_grad = False

in_features = model.classifier.in_features
model.classifier = nn.Linear(in_features, num_classes)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
device

device(type='cpu')

In [None]:
from training.centralized_training import train_and_validate
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.classifier.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.1)

checkpoint_dir = Path("checkpoints")
checkpoint_dir.mkdir(parents=True, exist_ok=True)
checkpoint_path = checkpoint_dir / "densenet121_centralized_head.pth"

best_val_acc = train_and_validate(
    start_epoch=1,
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    scheduler=scheduler,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    checkpoint_path=str(checkpoint_path),
    num_epochs=8,
    checkpoint_interval=2,
)

best_val_acc

Unnamed: 0,filename,path,split,index,tumor,plane,sequence,label
0,brisc2025_test_00001_gl_ax_t1.jpg,classification_task/test/glioma/brisc2025_test...,test,1,glioma,ax,t1,0
1,brisc2025_test_00002_gl_ax_t1.jpg,classification_task/test/glioma/brisc2025_test...,test,2,glioma,ax,t1,0
2,brisc2025_test_00003_gl_ax_t1.jpg,classification_task/test/glioma/brisc2025_test...,test,3,glioma,ax,t1,0
3,brisc2025_test_00004_gl_ax_t1.jpg,classification_task/test/glioma/brisc2025_test...,test,4,glioma,ax,t1,0
4,brisc2025_test_00005_gl_ax_t1.jpg,classification_task/test/glioma/brisc2025_test...,test,5,glioma,ax,t1,0


In [None]:
from training.centralized_training import test_epoch

test_loss, test_accuracy = test_epoch(model, test_loader, criterion, device)
print(f"Test loss: {test_loss:.4f} | Test accuracy: {test_accuracy:.2f}%")

wandb.finish()

In [None]:
import torch
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

model.eval()
all_preds = []
all_targets = []

with torch.no_grad():
    for images, labels in val_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        preds = torch.argmax(outputs, dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_targets.extend(labels.cpu().numpy())

cm = confusion_matrix(all_targets, all_preds)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["glioma", "meningioma", "no_tumor", "pituitary"])
disp.plot(cmap="Blues")
plt.title("Validation Confusion Matrix")
plt.show()
