In [None]:
from load_data import MultiModalCellDataset
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.nn as nn
from torchvision.models import resnet18
import torch
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# --- Define transforms ---
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

In [47]:
# --- Config ---
csv_path = 'multimodal-cancer-classification-challenge-2025/train.csv'
bf_train_dir = 'multimodal-cancer-classification-challenge-2025/BF/train'
fl_train_dir = 'multimodal-cancer-classification-challenge-2025/FL/train'
modality = 'both'  # Choose from: 'BF', 'FL', or 'both'

images_per_patient_train = 100   # Number of images to sample for training per patient
max_patients = 12             # Set to int (e.g., 6) to limit, or None to use all

# --- Load and prepare ---
df = pd.read_csv(csv_path)
df['patient_id'] = df['Name'].apply(lambda x: x.split('_image_')[0])
unique_patients = df['patient_id'].unique()

print(f"Total patients in dataset: {len(unique_patients)}")

# --- Limit number of patients if needed ---
if max_patients is not None and max_patients < len(unique_patients):
    selected_patients = np.random.choice(unique_patients, size=max_patients, replace=False)
else:
    selected_patients = unique_patients

print(f"Number of patients selected: {len(selected_patients)}")

# --- Sample per patient ---
train_rows = []
val_rows = []

for pid in selected_patients:
    patient_df = df[df['patient_id'] == pid]
    
    # Determine how many images are available
    total_images = len(patient_df)
    train_n = min(images_per_patient_train, total_images)
    val_n = max(1, train_n // 10)  # at least 1 for validation
    
    # Shuffle and split
    sampled = patient_df.sample(frac=1, random_state=42).reset_index(drop=True)
    train_sample = sampled.iloc[:train_n]
    val_sample = sampled.iloc[train_n:train_n + val_n]
    
    train_rows.append(train_sample)
    val_rows.append(val_sample)

# --- Concatenate and save ---
train_df = pd.concat(train_rows).reset_index(drop=True)
val_df = pd.concat(val_rows).reset_index(drop=True)

print(f"Final train set: {len(train_df)} images")
print(f"Final val set:   {len(val_df)} images")

train_df.to_csv('sampled_train.csv', index=False)
val_df.to_csv('sampled_val.csv', index=False)

Total patients in dataset: 12
Number of patients selected: 12
Final train set: 1200 images
Final val set:   120 images


In [48]:
train_dataset = MultiModalCellDataset(
    csv_file='sampled_train.csv',
    bf_dir=bf_train_dir,
    fl_dir=fl_train_dir,
    transform=transform,
    mode=modality
)

val_dataset = MultiModalCellDataset(
    csv_file='sampled_val.csv',
    bf_dir=bf_train_dir,
    fl_dir=fl_train_dir,
    transform=transform,
    mode=modality
)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

In [40]:
# --- Sample test run ---
if __name__ == "__main__":
    for images, labels in train_loader:
        print("Train batch:", images.shape, labels.shape)

Train batch: torch.Size([16, 6, 128, 128]) torch.Size([16])
Train batch: torch.Size([16, 6, 128, 128]) torch.Size([16])
Train batch: torch.Size([16, 6, 128, 128]) torch.Size([16])
Train batch: torch.Size([16, 6, 128, 128]) torch.Size([16])
Train batch: torch.Size([16, 6, 128, 128]) torch.Size([16])
Train batch: torch.Size([16, 6, 128, 128]) torch.Size([16])
Train batch: torch.Size([16, 6, 128, 128]) torch.Size([16])
Train batch: torch.Size([16, 6, 128, 128]) torch.Size([16])
Train batch: torch.Size([16, 6, 128, 128]) torch.Size([16])
Train batch: torch.Size([16, 6, 128, 128]) torch.Size([16])
Train batch: torch.Size([16, 6, 128, 128]) torch.Size([16])
Train batch: torch.Size([16, 6, 128, 128]) torch.Size([16])
Train batch: torch.Size([16, 6, 128, 128]) torch.Size([16])
Train batch: torch.Size([16, 6, 128, 128]) torch.Size([16])
Train batch: torch.Size([16, 6, 128, 128]) torch.Size([16])
Train batch: torch.Size([16, 6, 128, 128]) torch.Size([16])
Train batch: torch.Size([16, 6, 128, 128

In [33]:
def get_modified_resnet18(input_channels=3, pretrained=True):
    model = resnet18(pretrained=pretrained)

    if input_channels != 3:
        # Replace the first conv layer to accept more channels (e.g. 6 for BF+FL)
        model.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)

    # Replace final FC for binary classification
    model.fc = nn.Linear(model.fc.in_features, 1)  # Output: raw logits

    return model

In [49]:
def train_model(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=10):
    model.to(device)

    for epoch in range(num_epochs):
        # --- Training phase ---
        model.train()
        total_train_loss = 0
        train_preds = []
        train_labels = []

        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.float().to(device).unsqueeze(1)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()

            probs = torch.sigmoid(outputs).detach().cpu().numpy()
            pred_classes = [1 if p > 0.5 else 0 for p in probs]
            train_preds.extend(pred_classes)
            train_labels.extend(labels.cpu().numpy())

        train_acc = accuracy_score(train_labels, train_preds)
        avg_train_loss = total_train_loss / len(train_loader)

        # --- Validation phase ---
        model.eval()
        total_val_loss = 0
        val_preds = []
        val_probs = []
        val_labels = []

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.to(device)
                labels = labels.float().to(device).unsqueeze(1)

                outputs = model(inputs)
                loss = criterion(outputs, labels)
                total_val_loss += loss.item()

                probs = torch.sigmoid(outputs).cpu().numpy()
                pred_classes = [1 if p > 0.5 else 0 for p in probs]
                val_probs.extend(probs)
                val_preds.extend(pred_classes)
                val_labels.extend(labels.cpu().numpy())

        val_acc = accuracy_score(val_labels, val_preds)
        val_auc = roc_auc_score(val_labels, val_probs)
        avg_val_loss = total_val_loss / len(val_loader)

        # --- Output ---
        print(f"Epoch {epoch+1:03d} | "
              f"Train Loss: {avg_train_loss:.4f}, Acc: {train_acc:.4f} | "
              f"Val Loss: {avg_val_loss:.4f}, Acc: {val_acc:.4f}, AUC: {val_auc:.4f}")


In [50]:
if modality.lower() == 'both':
    input_channels = 6
else:
    input_channels = 3

model = get_modified_resnet18(input_channels=input_channels ,pretrained=True)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)



In [42]:
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    num_epochs=10
)

Epoch 001 | Train Loss: 0.4805, Acc: 0.7742 | Val Loss: 0.4312, Acc: 0.7417, AUC: 0.8760
Epoch 002 | Train Loss: 0.3071, Acc: 0.8742 | Val Loss: 1.6148, Acc: 0.4833, AUC: 0.7329
Epoch 003 | Train Loss: 0.1950, Acc: 0.9200 | Val Loss: 0.2577, Acc: 0.8750, AUC: 0.9583
Epoch 004 | Train Loss: 0.1578, Acc: 0.9308 | Val Loss: 0.2551, Acc: 0.9000, AUC: 0.9557
Epoch 005 | Train Loss: 0.0846, Acc: 0.9650 | Val Loss: 0.2417, Acc: 0.8750, AUC: 0.9663
Epoch 006 | Train Loss: 0.1423, Acc: 0.9483 | Val Loss: 0.3826, Acc: 0.8500, AUC: 0.9429
Epoch 007 | Train Loss: 0.0943, Acc: 0.9650 | Val Loss: 0.1944, Acc: 0.8750, AUC: 0.9714
Epoch 008 | Train Loss: 0.0646, Acc: 0.9775 | Val Loss: 0.4688, Acc: 0.8250, AUC: 0.9549
Epoch 009 | Train Loss: 0.0632, Acc: 0.9775 | Val Loss: 0.8402, Acc: 0.8500, AUC: 0.9266
Epoch 010 | Train Loss: 0.0734, Acc: 0.9767 | Val Loss: 0.3723, Acc: 0.8417, AUC: 0.9346
