# Submission Notebook

The idea is to run inference the same way we did in training, but on the test set, using all the trained components: the CycleGAN that converts images to the stain style we trained on, the MLP classifier, and the DINO feature extractor.

We start by loading the pretrained CycleGAN generator for stain normalization and define a test dataset
that uses this generator to process each test image before inference.


In [None]:

import os

import h5py
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm

from models.multistain_cyclegan_model import networks  # For generator definition

# Set the root directory for our project. The data file and every checkpoint
# path in this notebook are joined onto it, so it is resolved relative to the
# current working directory.
root_path = "../"

# Helper function to load the pretrained generator model.
def load_pretrained_generator(ckpt_path):
    """Build the CycleGAN generator and restore its weights from ``ckpt_path``.

    Args:
        ckpt_path: Path to a checkpoint file containing the generator's
            state dict.

    Returns:
        The generator with the checkpoint weights loaded, switched to eval mode.
    """
    # Positional arguments are forwarded unchanged to ``networks.define_G``.
    # NOTE(review): presumably (input_nc, output_nc, ngf, netG, norm,
    # use_dropout, init_type, init_gain, gpu_ids) as in upstream CycleGAN --
    # confirm against this project's ``networks`` module.
    generator = networks.define_G(
        3, 3, 64, 'resnet_9blocks', 'instance', True, "normal", 0.02, [0]
    )
    # The checkpoint tensors are mapped to CPU while being read.
    generator.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
    # Inference only: put the network in eval mode.
    generator.eval()
    return generator

# Build the full path to the generator checkpoint and load it.
# The generator is created once here, at module level, and the same instance
# is later handed to the test dataset for stain normalization.
gen_ckpt = os.path.join(root_path, "checkps/gen/netG_A_epoch6.pth")
gen_normalizer = load_pretrained_generator(gen_ckpt)

# Define a dataset class for test images using the pretrained generator for stain normalization.
class TestDatasetForSubmission(Dataset):
    """Test-set dataset that optionally stain-normalizes every image.

    Each top-level key of the HDF5 file is treated as an image ID and must
    hold an ``img`` dataset. Images are read lazily, one per ``__getitem__``.

    Args:
        h5_path: Path to the H5 file with test images.
        transform: Optional callable applied to the channel-first float tensor
            (e.g. a resize).
        generator: Optional pretrained generator used to normalize the stain.
            It is called on a batch containing the single image.
    """

    def __init__(self, h5_path, transform=None, generator=None):
        self.h5_path = h5_path
        self.transform = transform
        self.generator = generator
        # Only the keys are read up front; the file is reopened per item.
        with h5py.File(self.h5_path, 'r') as f:
            self.ids = list(f.keys())

    def __len__(self):
        """Return the total number of test images."""
        return len(self.ids)

    def _prepare(self, img):
        """Turn a raw float image tensor into the tensor returned to the loader.

        Steps: move channels first when the last axis has size 3, apply the
        optional transform, rescale with ``x * 2 - 1``, and, if a generator is
        set, run it and map its output back with ``(x + 1) / 2``.
        """
        # Convert image to channel-first if needed.
        if img.ndim == 3 and img.shape[-1] == 3:
            img = img.permute(2, 0, 1)
        if self.transform:
            img = self.transform(img)
        # Maps [0, 1] onto [-1, 1].
        # NOTE(review): assumes stored pixel values are already in [0, 1] -- confirm.
        img = img * 2.0 - 1.0
        if self.generator is not None:
            # Inference only: without no_grad every sample would carry an
            # autograd graph (and requires_grad=True) into the DataLoader,
            # wasting memory for gradients that are never used.
            with torch.no_grad():
                img = self.generator(img.unsqueeze(0))
            img = (img + 1) / 2.0  # Rescale normalized output to [0, 1].
            img = img.squeeze(0)
        return img

    def __getitem__(self, idx):
        """Return ``(image_tensor, image_id)`` for the sample at ``idx``.

        The image ID is the HDF5 key converted with ``int``.
        """
        key = self.ids[idx]
        with h5py.File(self.h5_path, 'r') as f:
            img = torch.tensor(f[key]['img'][()]).float()
        return self._prepare(img), int(key)


We repeat the setup used for classification, including the model definition and the feature extractor, so that the trained model can be loaded for evaluation.


In [None]:

# Resize every image to 98x98 (98 = 7 * 14).
# NOTE(review): presumably chosen as a multiple of the ViT-S/14 patch size -- confirm.
submission_transform = transforms.Compose([transforms.Resize((98, 98))])

# Build the test dataset and DataLoader with the pretrained generator for stain normalization.
test_h5_path = os.path.join(root_path, "data/test.h5")
test_dataset = TestDatasetForSubmission(
    test_h5_path,
    transform=submission_transform,
    generator=gen_normalizer,
)
# Order is kept fixed (no shuffling) while batching 256 images at a time.
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

# Define an MLP classifier with three hidden layers.
class MLPClassifier(nn.Module):
    """Fully connected classifier head: in_dim -> 512 -> 256 -> 128 -> out_dim.

    Every hidden stage is Linear -> BatchNorm1d -> ReLU -> Dropout, and the
    output stage is Linear -> Sigmoid. Attribute names and the order of the
    sub-modules determine the ``state_dict`` keys, so they must stay as they
    are for existing checkpoints to load.

    Args:
        in_dim: Size of each input feature vector.
        out_dim: Number of output units.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.layer1 = self._hidden_block(in_dim, 512, 0.3)
        self.layer2 = self._hidden_block(512, 256, 0.1)
        self.layer3 = self._hidden_block(256, 128, 0.1)
        self.output_layer = nn.Sequential(
            nn.Linear(128, out_dim),
            nn.Sigmoid(),
        )

    @staticmethod
    def _hidden_block(n_in, n_out, drop_p):
        """Build one Linear -> BatchNorm1d -> ReLU -> Dropout stage."""
        return nn.Sequential(
            nn.Linear(n_in, n_out),
            nn.BatchNorm1d(n_out),
            nn.ReLU(),
            nn.Dropout(drop_p),
        )

    def forward(self, x):
        """Run the hidden stack and output layer; the result is flattened to 1-D."""
        hidden = self.layer3(self.layer2(self.layer1(x)))
        return self.output_layer(hidden).view(-1)

# Load the pretrained DINO model for feature extraction and move it to GPU.
dino_net = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
dino_net = dino_net.to("cuda")

# A parameter stays trainable when its name contains any of these substrings;
# every other parameter is frozen. The match is a plain substring test, so
# "norm" selects any parameter whose name contains it, not only a final norm.
trainable_tags = ("blocks.9", "blocks.10", "blocks.11", "norm", "head")
for name, param in dino_net.named_parameters():
    param.requires_grad = any(tag in name for tag in trainable_tags)


We load the pretrained classifier and the fine-tuned DINO model from checkpoints, then run inference on the test dataset and generate a CSV file for submission.

In [None]:

# Build the classifier head. The constructor's parameters are named in_dim and
# out_dim; the previous in_features/out_features keywords raised a TypeError.
# NOTE(review): 384 is presumably the dinov2_vits14 feature size -- confirm.
mlp_model = MLPClassifier(in_dim=384, out_dim=1).to("cuda")
cls_ckpt = os.path.join(root_path, "checkps/classif/best_classifier_paper.pth")
mlp_model.load_state_dict(torch.load(cls_ckpt, map_location="cuda"))
# Eval mode: dropout is disabled and batch norm uses its running statistics.
mlp_model.eval()

# Load fine-tuned DINO weights if available and update the DINO model state.
dino_ckpt = os.path.join(root_path, "checkps/classif/best_finetuned_dino_layers_paper.pth")
if os.path.exists(dino_ckpt):
    ft_weights = torch.load(dino_ckpt, map_location="cuda")
    # Overlay the checkpoint entries on the full hub state dict, so entries
    # missing from the checkpoint keep their pretrained values.
    dino_state = dino_net.state_dict()
    dino_state.update(ft_weights)
    dino_net.load_state_dict(dino_state)
dino_net.eval()

# Run inference on the test dataset and collect predictions.
all_predictions = {'ID': [], 'Pred': []}
with torch.no_grad():
    for imgs, ids in tqdm(test_loader, desc="Running Inference"):
        # Resize again on the GPU before feature extraction.
        # NOTE(review): likely needed because the generator output size can
        # differ from its input size -- confirm before removing.
        resized = submission_transform(imgs.to("cuda"))
        probs = mlp_model(dino_net(resized))
        # Threshold at 0.5 for binary classification.
        preds = (probs.cpu().numpy() > 0.5).astype(int)

        for image_id, label in zip(ids, preds):
            all_predictions['ID'].append(int(image_id))
            all_predictions['Pred'].append(int(label))

# Create a DataFrame from predictions and save it as a CSV file for submission.
submission_df = pd.DataFrame(all_predictions)
submission_df = submission_df.set_index("ID")
submission_file = "submission.csv"
submission_df.to_csv(submission_file)
print(f"Submission file created at: {submission_file}")
