In [None]:
from pybbbc import BBBC021
from torch.utils.data import DataLoader
from torchvision import transforms
import torch
from tqdm import tqdm


import sys
sys.path.append("..")

from models.load_model import MODEL_NAMES
from training.wsdino_resnet_train import (
    BBBC021WeakLabelDataset,
    get_resnet50,
    dino_loss,
    update_teacher
)

In [None]:
# Run configuration: every tunable value used by the cells below lives here.

# Compute target: use the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# File the trained student weights are written to by the final cell.
save_path = "./resnet_wsdino.pth"

# Model / optimisation settings.
num_classes = 12      # passed to get_resnet50 for both student and teacher
batch_size = 64       # mini-batch size
epochs = 10           # kept small so the notebook finishes quickly
lr = 1e-3             # Adam learning rate
momentum = 0.996      # forwarded to update_teacher after every optimizer step
temperature = 0.07    # NOTE(review): not referenced by the visible cells — confirm it is still needed

In [None]:
# Build everything the training loop needs: dataset, dataloader, student/teacher
# networks and the optimizer.

# Open the preprocessed BBBC021 dataset through pybbbc, pointing at the
# processed dataset directory on disk.
bbbc = BBBC021(root_path="/scratch/cv-course2025/group8/processed")

# torchvision preprocessing: resize every image to 224x224, then convert to a tensor.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Wrap the BBBC021 handle in the project's Dataset class with the transform above.
# NOTE(review): the helper is expected to drop samples whose MoA is 'null' —
# confirm in training.wsdino_resnet_train.
dataset = BBBC021WeakLabelDataset(bbbc, transform)

# Shuffled mini-batches. The size comes from the `batch_size` hyperparameter
# (previously a hard-coded 64 that silently ignored the config cell).
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Student network: ResNet-50 built by the project helper with `num_classes` outputs.
# NOTE(review): BASE_RESNET is presumably the ImageNet-pretrained variant — confirm
# in models.load_model.
# Both networks are moved to `device` because the training loop moves every batch
# there; without this the forward pass fails with a device mismatch on CUDA machines.
# `.to()` is a no-op when the model already lives on that device.
student = get_resnet50(num_classes, MODEL_NAMES.BASE_RESNET).to(device)

# Teacher network: same architecture, initialised with a copy of the student's weights.
teacher = get_resnet50(num_classes, MODEL_NAMES.BASE_RESNET).to(device)
teacher.load_state_dict(student.state_dict())  # synchronize weights

# The teacher is never trained by backprop; it is only changed through update_teacher.
for p in teacher.parameters():
    p.requires_grad = False

# Adam over the student's parameters only. Created after the model has been moved
# to its final device so the optimizer tracks the on-device parameters.
optimizer = torch.optim.Adam(student.parameters(), lr=lr)

In [None]:
# Training loop: each epoch makes one pass over `dataloader`. The student is
# optimised by gradient descent on the DINO loss; the teacher is refreshed from
# the student through `update_teacher` (presumably an exponential moving average
# controlled by `momentum` — confirm in training.wsdino_resnet_train).
for epoch in range(epochs):
    student.train()   # student runs in training mode
    total_loss = 0.0  # running sum of per-batch losses for this epoch

    progress = tqdm(dataloader, desc=f"Epoch {epoch+1}")
    for imgs, labels in progress:
        # Place the batch on the compute device.
        # NOTE(review): `labels` is transferred but not used anywhere in this loop.
        imgs = imgs.to(device)
        labels = labels.to(device)

        # Start the step with clean gradients.
        optimizer.zero_grad()

        # Student prediction (tracked by autograd).
        student_out = student(imgs)

        # Teacher prediction on the same images; no graph is recorded.
        with torch.no_grad():
            teacher_out = teacher(imgs)

        # Distillation loss between the two outputs, then one optimizer step.
        loss = dino_loss(student_out, teacher_out)
        loss.backward()
        optimizer.step()

        # Refresh the teacher from the freshly updated student.
        update_teacher(student, teacher, momentum)

        total_loss += loss.item()

    # Report the mean per-batch loss of the epoch.
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}: Avg Loss = {avg_loss:.4f}")


In [None]:
# Persist the trained student's weights (state dict only) to `save_path`.
student_weights = student.state_dict()
torch.save(student_weights, save_path)
print(f"Model saved to {save_path}")