# Token-Based Knowledge Distillation for ISIC 2024 Skin Cancer Detection

This notebook implements token-based knowledge distillation for the ISIC 2024 – Skin Cancer Detection with 3D-TBP challenge.

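A pre-trained Swin Transformer V2 Small (`swinv2_small_window8_256`) serves as the teacher, and a smaller ViT-Small/16 extended with a DeiT-style distillation token serves as the student. The student is trained on three signals: the ground-truth labels, the teacher's temperature-softened predictions, and the teacher's pooled features. The goal is a smaller, faster model that keeps most of the teacher's skin cancer detection accuracy.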

## Environment Setup and Dependencies

In [67]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torch.amp import GradScaler, autocast
from io import BytesIO
import h5py
import timm
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from PIL import Image
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import warnings
warnings.filterwarnings('ignore')

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

Using device: cuda


## Dataset Preparation

In [68]:
class ISICDataset(Dataset):
    """Dataset over lesion image files on disk.

    Args:
        image_paths: Indexable sequence of image file paths; each item is opened with
            PIL and converted to RGB.
        labels: Optional indexable sequence aligned with ``image_paths``. When given,
            items are ``(image, label)`` pairs; otherwise an item is just the image.
        transform: Optional callable applied to the PIL image.
    """

    def __init__(self, image_paths, labels=None, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        sample = Image.open(self.image_paths[idx]).convert('RGB')
        if self.transform:
            sample = self.transform(sample)

        # Unlabelled datasets (e.g. inference) yield the image alone.
        if self.labels is None:
            return sample
        return sample, self.labels[idx]

# Define data transformations
def get_transforms(mode='train'):
    """Build the torchvision preprocessing pipeline for one split.

    Both pipelines produce 256x256 tensors (the input size of the
    ``swinv2_small_window8_256`` teacher and of the 256-px student), normalized with
    the ImageNet mean/std.

    Args:
        mode: ``'train'`` selects the augmenting pipeline; any other value selects the
            deterministic evaluation pipeline.

    Returns:
        A ``transforms.Compose`` that maps a PIL image to a normalized tensor.
    """
    normalize = transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)

    if mode != 'train':
        # Deterministic: resize the shorter side to 256, then keep the central 256x256.
        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
            normalize,
        ])

    # Random crop/flip/rotation plus mild color jitter for training.
    return transforms.Compose([
        transforms.RandomResizedCrop(256),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        transforms.ToTensor(),
        normalize,
    ])


# Load the dataset
def prepare_data(data_dir="Data", max_label0_train=40000, val_label0_size=10000, val_label1_ratio=0.2, batch_size=32,
                 seed=42):
    """Build reproducible train/validation DataLoaders from the hair-removed image folders.

    Benign images (label 0) are subsampled: ``max_label0_train`` go to training and
    ``val_label0_size`` of the remainder go to validation. Malignant images (label 1)
    are split with a fraction ``val_label1_ratio`` held out for validation.

    Args:
        data_dir: Root containing ``train-image-0_hairRemove`` (benign) and
            ``train-image-1_hairRemove`` (malignant) folders of ``.jpg`` files.
        max_label0_train: Number of benign images sampled for training.
        val_label0_size: Number of benign images sampled (from the rest) for validation.
        val_label1_ratio: Fraction of malignant images held out for validation.
        batch_size: Batch size of both loaders.
        seed: Seed for the benign sampling, the malignant split and the shuffle.

    Returns:
        ``(train_loader, val_loader)``.

    Raises:
        ValueError: (from numpy) if a requested benign sample size exceeds the number
            of images available.
    """
    benign_dir = os.path.join(data_dir, "train-image-0_hairRemove")
    malignant_dir = os.path.join(data_dir, "train-image-1_hairRemove")

    # os.listdir order is filesystem-dependent; sorting makes the seeded sampling below
    # select the same files on every machine.
    benign_images = sorted(
        os.path.join(benign_dir, img) for img in os.listdir(benign_dir) if img.endswith('.jpg'))
    malignant_images = sorted(
        os.path.join(malignant_dir, img) for img in os.listdir(malignant_dir) if img.endswith('.jpg'))

    # Local generator: reproducible without resetting the global numpy RNG.
    rng = np.random.RandomState(seed)

    # Benign: disjoint train / validation samples.
    benign_train = rng.choice(benign_images, size=max_label0_train, replace=False)
    # Order-preserving difference. Iterating a set of strings depends on the per-process
    # hash seed, which would change the validation sample between runs.
    benign_train_set = set(benign_train)
    remaining_benign = [img for img in benign_images if img not in benign_train_set]
    benign_val = rng.choice(remaining_benign, size=val_label0_size, replace=False)

    # Malignant: keep most for training, hold out a small stratum for validation.
    malignant_train, malignant_val = train_test_split(
        malignant_images,
        test_size=val_label1_ratio,
        random_state=seed
    )

    train_imgs = list(benign_train) + list(malignant_train)
    train_labels = [0] * len(benign_train) + [1] * len(malignant_train)

    val_imgs = list(benign_val) + list(malignant_val)
    val_labels = [0] * len(benign_val) + [1] * len(malignant_val)

    # Interleave classes (the DataLoader shuffles again each epoch).
    train_combined = list(zip(train_imgs, train_labels))
    rng.shuffle(train_combined)
    train_imgs, train_labels = zip(*train_combined)

    print(f"Training set: {len(train_imgs)} images (0: {len(benign_train)}, 1: {len(malignant_train)})")
    print(f"Validation set: {len(val_imgs)} images (0: {len(benign_val)}, 1: {len(malignant_val)})")

    train_dataset = ISICDataset(train_imgs, train_labels, transform=get_transforms('train'))
    val_dataset = ISICDataset(val_imgs, val_labels, transform=get_transforms('val'))

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    return train_loader, val_loader




## Teacher Model Setup

In [69]:
def load_teacher_model(weights_path='clean_swin_weights.pth'):
    """Build the Swin V2 Small teacher, load its 2-class weights, and freeze it.

    Args:
        weights_path: State-dict file for ``swinv2_small_window8_256`` with a 2-class head.

    Returns:
        The teacher on ``device`` in eval mode with gradients disabled.
    """
    model = timm.create_model('swinv2_small_window8_256',
                              pretrained=False,
                              num_classes=2)
    # map_location makes a GPU-saved checkpoint loadable on a CPU-only machine;
    # weights_only restricts unpickling to tensors/containers.
    state_dict = torch.load(weights_path, map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict)
    # The teacher is never optimized.
    model.requires_grad_(False)
    return model.to(device).eval()



## Student Model Setup

In [None]:
class DistillableViT(nn.Module):
    """ViT-Small/16 student with an extra DeiT-style distillation token.

    Wraps timm's ``vit_small_patch16_224``, re-created at ``image_size``, and inserts a
    learnable distillation token right after the class token. The position embedding
    is extended by one slot to match. ``forward`` returns ``(logits, cls_features)``
    in both train and eval mode. ``cls_features`` is the final-norm class-token
    embedding.

    Args:
        image_size: Square input resolution.
        patch_size: Patch size passed to timm.
        num_classes: Number of classes of the classification head.
        dim: Kept for backward compatibility only. The token width is read from the
            backbone (``self.vit.embed_dim``), so the new token and position embedding
            always match the pretrained weights.
        depth, heads, mlp_dim, dropout, emb_dropout: Accepted for backward compatibility
            but unused; the architecture is fixed by the timm model name.
    """

    def __init__(self, image_size=256, patch_size=16, num_classes=2, dim=384,
                 depth=12, heads=6, mlp_dim=1536, dropout=0.1, emb_dropout=0.1):
        super().__init__()

        self.vit = timm.create_model(
            'vit_small_patch16_224',
            pretrained=True,
            num_classes=num_classes,
            img_size=image_size,
            patch_size=patch_size,
        )

        # Token width of the backbone (384 for ViT-Small).
        embed_dim = self.vit.embed_dim
        self.dim = embed_dim

        # Small-scale init (truncated normal, std 0.02, as in DeiT) rather than unit
        # variance, so the new token does not dominate the pretrained embeddings early on.
        self.distillation_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        nn.init.trunc_normal_(self.distillation_token, std=0.02)

        # Extend the position embedding from [cls, patches] to [cls, dist, patches].
        # The distillation slot starts at zero; the other slots reuse pretrained values.
        num_patches = (image_size // patch_size) ** 2
        pretrained_pos = self.vit.pos_embed.data
        extended_pos = torch.zeros(1, num_patches + 2, embed_dim)
        extended_pos[:, 0:1, :] = pretrained_pos[:, 0:1, :]  # class token
        extended_pos[:, 2:, :] = pretrained_pos[:, 1:, :]    # patch tokens
        self.vit.pos_embed = nn.Parameter(extended_pos)

    def forward(self, x):
        """Return ``(logits, cls_features)`` for an image batch ``x`` (B, C, H, W)."""
        b = x.shape[0]

        tokens = self.vit.patch_embed(x)
        cls_token = self.vit.cls_token.expand(b, -1, -1)
        dist_token = self.distillation_token.expand(b, -1, -1)
        tokens = torch.cat((cls_token, dist_token, tokens), dim=1)

        tokens = self.vit.pos_drop(tokens + self.vit.pos_embed)
        for blk in self.vit.blocks:
            tokens = blk(tokens)
        tokens = self.vit.norm(tokens)

        # Only the class-token output feeds the head and the feature loss; the
        # distillation-token output (tokens[:, 1]) is not returned.
        cls_token_out = tokens[:, 0]
        return self.vit.head(cls_token_out), cls_token_out


## Knowledge Distillation

In [71]:
class KnowledgeDistillationLoss(nn.Module):
    """Hard-label CE plus logit (KL) and feature (MSE) distillation.

    ``loss = (1 - alpha) * CE(student, labels)
            + alpha * (soft_weight * T^2 * KL(teacher_T || student_T)
                       + feature_weight * MSE(student_features, adapt(teacher_features)))``

    Args:
        alpha: Weight of the distillation terms relative to the hard-label loss.
        temperature: Softmax temperature ``T`` for the soft targets.
        feature_adapter: Optional module mapping teacher features to the student's
            width; it is applied inside ``forward``, so it receives gradients.
        class_weights: Optional per-class weights for the cross-entropy term.
        soft_weight: Weight of the KL term inside the distillation part (default 0.7).
        feature_weight: Weight of the feature MSE term inside the distillation part
            (default 0.3).
    """

    def __init__(self, alpha=0.5, temperature=3.0, feature_adapter=None, class_weights=None,
                 soft_weight=0.7, feature_weight=0.3):
        super().__init__()
        self.alpha = alpha
        self.temperature = temperature
        self.soft_weight = soft_weight
        self.feature_weight = feature_weight

        # weight=None is equivalent to an unweighted CrossEntropyLoss.
        self.ce_loss = nn.CrossEntropyLoss(weight=class_weights)
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')
        self.feature_adapter = feature_adapter

    def forward(self, student_logits, student_features,
                teacher_logits, teacher_features, labels):
        """Return the scalar combined loss for one batch."""
        # Explicit None check: nn.Module truthiness would call __len__ on containers.
        if self.feature_adapter is not None:
            teacher_features = self.feature_adapter(teacher_features)

        hard_loss = self.ce_loss(student_logits, labels)

        t = self.temperature
        soft_targets = nn.functional.softmax(teacher_logits / t, dim=1)
        soft_predictions = nn.functional.log_softmax(student_logits / t, dim=1)
        # T^2 keeps the gradient scale of the soft term comparable across temperatures.
        soft_loss = self.kl_loss(soft_predictions, soft_targets) * (t ** 2)

        feature_loss = nn.functional.mse_loss(student_features, teacher_features)

        distill_loss = self.soft_weight * soft_loss + self.feature_weight * feature_loss
        return (1 - self.alpha) * hard_loss + self.alpha * distill_loss

## Training Loop

In [72]:
def calculate_pAUC(val_targets, val_preds, min_tpr=0.80):
    """Partial ROC AUC over the region where TPR >= ``min_tpr`` (ISIC 2024 metric).

    Flipping the labels and negating the scores turns the high-TPR band of the
    original ROC into the low-FPR band ``[0, 1 - min_tpr]`` of the transformed ROC.
    That band's area is integrated, with linear interpolation at the cut-off.

    Args:
        val_targets: Binary ground truth (1 = malignant).
        val_preds: Scores where larger means more likely malignant.
        min_tpr: Lower TPR bound of the evaluated region.

    Returns:
        Unnormalized partial AUC, in ``[0, 1 - min_tpr]``.
    """
    flipped_labels = 1 - np.array(val_targets)
    flipped_scores = -np.array(val_preds)
    fpr_limit = 1 - min_tpr

    fpr, tpr, _ = roc_curve(flipped_labels, flipped_scores)
    cut = np.searchsorted(fpr, fpr_limit, side='right')

    if cut == 0:
        return 0.0
    if cut == len(fpr):
        return auc(fpr, tpr)

    # TPR where the curve crosses fpr_limit; close the integration region there.
    boundary_tpr = np.interp(fpr_limit, fpr[cut - 1:cut + 1], tpr[cut - 1:cut + 1])
    return auc(np.append(fpr[:cut], fpr_limit), np.append(tpr[:cut], boundary_tpr))

In [73]:
def _student_logits(outputs):
    """Return the logits from a model output that may be a ``(logits, ...)`` tuple.

    ``DistillableViT.forward`` returns ``(logits, features)`` in every mode; plain
    classifiers return the logits tensor directly.
    """
    return outputs[0] if isinstance(outputs, (tuple, list)) else outputs


def _distill_one_epoch(teacher_model, student_model, train_loader, criterion,
                       optimizer, scaler, desc):
    """Run one mixed-precision distillation epoch.

    Returns:
        ``(average_loss, accuracy_percent)`` over the epoch.
    """
    student_model.train()
    teacher_model.eval()

    running_loss = 0.0
    avg_loss = 0.0
    correct = 0
    seen = 0

    progress_bar = tqdm(train_loader, desc=desc)
    for batch_idx, (inputs, targets) in enumerate(progress_bar):
        inputs, targets = inputs.to(device), targets.to(device)

        with autocast(device_type=device.type):
            with torch.no_grad():
                # A single backbone pass feeds both the logits and the pooled features
                # (timm computes forward(x) as forward_head(forward_features(x))).
                teacher_tokens = teacher_model.forward_features(inputs)
                teacher_logits = teacher_model.forward_head(teacher_tokens)
                teacher_features = teacher_model.head.global_pool(teacher_tokens)

            student_logits, student_features = student_model(inputs)
            loss = criterion(student_logits, student_features,
                             teacher_logits, teacher_features, targets)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Each sample is counted exactly once; cast to float32 before the argmax.
        _, predicted = student_logits.float().max(1)
        seen += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        running_loss += loss.item()
        avg_loss = running_loss / (batch_idx + 1)
        progress_bar.set_postfix(loss=avg_loss, acc=f'{100. * correct / seen:.2f}%')

    accuracy = 100. * correct / seen if seen else 0.0
    return avg_loss, accuracy


def _evaluate_student(student_model, val_loader, desc):
    """Score the validation set with the student.

    Returns:
        ``(accuracy_percent, malignant_probs, targets)``, where ``malignant_probs`` is the
        softmax probability of class 1 for each sample.
    """
    student_model.eval()
    correct = 0
    seen = 0
    probs = []
    labels = []

    progress_bar = tqdm(val_loader, desc=desc)
    with torch.no_grad():
        for inputs, targets in progress_bar:
            inputs, targets = inputs.to(device), targets.to(device)
            logits = _student_logits(student_model(inputs)).float()

            _, predicted = logits.max(1)
            seen += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            probs.extend(torch.softmax(logits, dim=1)[:, 1].cpu().numpy())
            labels.extend(targets.cpu().numpy())
            progress_bar.set_postfix(acc=f'{100. * correct / seen:.2f}%')

    accuracy = 100. * correct / seen if seen else 0.0
    return accuracy, probs, labels


def train_with_distillation(teacher_model, student_model, train_loader, val_loader,
                           num_epochs=10, lr=1e-4, weight_decay=1e-5,
                           alpha=0.5, temperature=3.0, feature_adapter=None,
                           patience=5, delta=0.001):
    """Distill ``teacher_model`` into ``student_model`` with early stopping on pAUC.

    The student (and ``feature_adapter``, if given) is optimized with AdamW, a cosine
    LR schedule, and mixed precision. After each epoch the validation AUC, PR-AUC and
    pAUC (TPR >= 0.8) are reported. Model selection and early stopping use the pAUC.

    Args:
        teacher_model: timm classifier used as a frozen teacher.
        student_model: Model returning ``(logits, features)`` in train mode.
        train_loader, val_loader: Loaders yielding ``(images, labels)``.
        num_epochs: Maximum number of epochs (also the cosine schedule length).
        lr, weight_decay: AdamW hyper-parameters.
        alpha, temperature: Passed to ``KnowledgeDistillationLoss``.
        feature_adapter: Optional module projecting teacher features to the student's
            width; trained jointly with the student.
        patience: Epochs without improvement before stopping.
        delta: Minimum pAUC gain that counts as an improvement.

    Returns:
        ``(student_model, best_val_pauc)``. The student holds the best checkpoint's
        weights when one was saved during this call.
    """
    # Loss scaling only matters for fp16 on CUDA; disable it elsewhere instead of warning.
    scaler = torch.amp.GradScaler(enabled=device.type == 'cuda')

    optimizer_params = list(student_model.parameters())
    if feature_adapter is not None:
        optimizer_params += list(feature_adapter.parameters())
    optimizer = optim.AdamW(optimizer_params, lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    # Down-weight the abundant benign class in the hard-label loss.
    class_weights = torch.tensor([0.02, 1.0]).to(device)  # Adjust as needed
    criterion = KnowledgeDistillationLoss(alpha=alpha, temperature=temperature,
                                          feature_adapter=feature_adapter, class_weights=class_weights)

    best_val_auc = 0.0  # best validation pAUC so far
    best_model_path = 'best_student_model.pth'
    saved_checkpoint = False
    counter = 0  # epochs without improvement

    for epoch in range(num_epochs):
        tag = f'Epoch {epoch+1}/{num_epochs}'
        avg_loss, train_acc = _distill_one_epoch(
            teacher_model, student_model, train_loader, criterion,
            optimizer, scaler, desc=f'{tag} [Train]')
        val_acc, val_preds, val_targets = _evaluate_student(
            student_model, val_loader, desc=f'{tag} [Val]')

        val_auc = roc_auc_score(val_targets, val_preds)
        val_prauc = average_precision_score(val_targets, val_preds)
        val_pauc = calculate_pAUC(val_targets, val_preds, min_tpr=0.80)

        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'Train Loss: {avg_loss:.4f} | Train Acc: {train_acc:.2f}%')
        print(f'Val Acc: {val_acc:.2f}% | Val AUC: {val_auc:.4f} | Val PR-AUC: {val_prauc:.4f} | Val pAUC: {val_pauc:.4f}')

        if val_pauc > best_val_auc + delta:
            best_val_auc = val_pauc
            torch.save(student_model.state_dict(), best_model_path)
            saved_checkpoint = True
            print(f'Saved best model with pAUC: {val_pauc:.4f}')
            counter = 0
        else:
            counter += 1
            print(f'No improvement for {counter}/{patience} epochs')
            if counter >= patience:
                print(f'Early stopping at epoch {epoch+1}')
                break

        scheduler.step()

    if saved_checkpoint:
        student_model.load_state_dict(
            torch.load(best_model_path, map_location=device, weights_only=True))
    else:
        # Never restore a checkpoint left over from an earlier run.
        print('No epoch improved on the initial pAUC; keeping the final weights')

    return student_model, best_val_auc


## Feature Adaptation Layer

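When distilling from a Swin Transformer to a ViT, the feature dimensions differ. The Swin V2 Small teacher produces 768-dimensional pooled features, while the ViT-Small student produces 384-dimensional ones, so the feature-matching (MSE) term cannot compare them directly. A small trainable MLP (`FeatureAdapter`) projects the teacher's features to the student's width. It is created only when the two dimensions differ and is optimized together with the student.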

In [74]:
class FeatureAdapter(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) projecting features to a new width.

    Args:
        in_features: Width of the incoming (teacher) features.
        out_features: Target (student) width; also the hidden width.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        layers = [
            nn.Linear(in_features, out_features),
            nn.ReLU(),
            nn.Linear(out_features, out_features),
        ]
        # Keep the attribute name `adapter` so saved state_dict keys stay `adapter.0.*` / `adapter.2.*`.
        self.adapter = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.adapter(x)
        return projected


In [75]:
def interpolate_pos_embed(pos_embed, new_num_tokens):
    """Convert a distillable-ViT position embedding into a plain-ViT one of a new size.

    The input layout is ``[cls, dist, patch_1 .. patch_N]``. The distillation slot is
    dropped and the square patch grid is resized bicubically.

    Args:
        pos_embed: Tensor ``(B, 2 + N, D)`` (``B`` is 1 for a ViT pos_embed);
            ``N`` must be a perfect square.
        new_num_tokens: Target length ``1 + M`` (class token + ``M`` patches); ``M``
            must be a positive perfect square (e.g. 197 for a 224-px ViT-S/16).

    Returns:
        Tensor ``(B, new_num_tokens, D)``: the class-token embedding followed by the
        resized patch grid, in the dtype of ``pos_embed``.

    Raises:
        ValueError: if either patch count is not a positive perfect square.
    """
    def _grid_side(count, which):
        # Integer side of a square grid; rounding avoids float-sqrt truncation.
        side = int(round(count ** 0.5))
        if count <= 0 or side * side != count:
            raise ValueError(f'{which} patch count {count} is not a positive perfect square')
        return side

    batch = pos_embed.shape[0]
    cls_pos = pos_embed[:, :1, :]      # class token (kept)
    patch_pos = pos_embed[:, 2:, :]    # patch tokens; slot 1 (distillation) is discarded
    embed_dim = patch_pos.shape[-1]

    old_side = _grid_side(patch_pos.shape[1], 'source')
    new_side = _grid_side(new_num_tokens - 1, 'target')

    # (B, N, D) -> (B, D, side, side) so interpolate treats D as channels.
    grid = patch_pos.reshape(batch, old_side, old_side, embed_dim).permute(0, 3, 1, 2)
    # Interpolate in float32 for accuracy, then restore the original dtype.
    grid = torch.nn.functional.interpolate(
        grid.float(),
        size=(new_side, new_side),
        mode='bicubic',
        align_corners=False
    ).to(pos_embed.dtype)

    patch_pos = grid.permute(0, 2, 3, 1).reshape(batch, new_side * new_side, embed_dim)
    return torch.cat([cls_pos, patch_pos], dim=1)


## Complete Pipeline

In [76]:
# Seed the global torch and numpy RNGs for reproducibility.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Optimization
BATCH_SIZE = 32
NUM_EPOCHS = 30
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5

# Distillation
ALPHA = 0.5        # weight of the distillation terms vs. the hard-label loss
TEMPERATURE = 3.0  # softmax temperature for the soft targets

In [None]:
# Data loaders at the configured batch size.
train_loader, val_loader = prepare_data(batch_size=BATCH_SIZE)

# Eval-mode Swin V2 teacher and trainable ViT student, both on `device`.
teacher_model = load_teacher_model()
student_model = DistillableViT(image_size=256, patch_size=16, num_classes=2).to(device)

# Pooled feature widths: Swin V2 Small -> 768, ViT-Small -> 384.
teacher_feature_dim = teacher_model.num_features
student_feature_dim = student_model.dim

# Only build a projection when the widths disagree.
if teacher_feature_dim == student_feature_dim:
    feature_adapter = None
else:
    feature_adapter = FeatureAdapter(teacher_feature_dim, student_feature_dim).to(device)

# NOTE(review): unfinished sketch kept for reference; it is not called in the cells shown.
def modified_train_loop():
    """Unfinished sketch of a training step that applies the feature adapter.

    NOTE(review): only the forward passes are implemented. No loss is computed and no
    optimizer step is taken; the tail is a ``...`` placeholder. The adapter is also
    applied inside ``torch.no_grad()``, so it would receive no gradients if this were
    completed as written. ``train_with_distillation`` is the loop actually used; it
    passes the adapter to ``KnowledgeDistillationLoss``, which applies it inside the
    loss computation.
    """
    # Training logic
    # This includes the feature adapter in the forward pass if needed
    student_model.train()
    teacher_model.eval()
    
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        
        # Get teacher predictions
        with torch.no_grad():
            teacher_logits = teacher_model(inputs)
            teacher_features = teacher_model.head.global_pool(
                teacher_model.forward_features(inputs))
            
            # Adapt teacher features if needed (inside no_grad: adapter gets no gradient here)
            if feature_adapter:
                teacher_features = feature_adapter(teacher_features)
        
        # Forward pass through student
        student_logits, student_features = student_model(inputs)
        
        # Rest of the training loop (not implemented: loss, backward, optimizer step)
        # ...

# Distill the student; the feature adapter (if any) is trained jointly.
trained_student, best_auc = train_with_distillation(
    teacher_model, student_model, train_loader, val_loader,
    num_epochs=NUM_EPOCHS, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY,
    alpha=ALPHA, temperature=TEMPERATURE, feature_adapter=feature_adapter,
    patience=10, delta=0.001
)

# `best_auc` is the best validation pAUC (TPR >= 0.8), the metric used for model selection.
print(f"Training completed. Best validation pAUC: {best_auc:.4f}")

# Save the returned student weights. The 'best_auc' key name is kept for compatibility
# with existing readers of this file.
torch.save({
    'model_state_dict': trained_student.state_dict(),
    'best_auc': best_auc
}, 'final_distilled_vit_small.pth')

In [None]:
# Rebuild the trained student and load the best checkpoint. map_location keeps the load
# device-independent; weights_only restricts unpickling to tensors.
trained_student = DistillableViT(image_size=256, patch_size=16, num_classes=2)
trained_student.load_state_dict(
    torch.load('./best_student_model.pth', map_location='cpu', weights_only=True))
trained_student.eval()

# The wrapped timm ViT holds every weight except the distillation token, which is a
# parameter of the wrapper and is therefore excluded here.
student_state_dict = trained_student.vit.state_dict()

# Plain 224-px ViT-S/16 for inference (class token only, 14x14 patch grid).
# NOTE(review): the student was trained at 256 px with the distillation token in the
# sequence. Dropping that token and resizing the grid changes the attention context,
# so this model's scores can differ from the validated student's. Confirm on the
# validation set.
inference_model = timm.create_model('vit_small_patch16_224', num_classes=2)
inference_state_dict = inference_model.state_dict()

# Every inference weight must come from the student; otherwise it would silently stay
# at its random initialization.
missing = sorted(set(inference_state_dict) - set(student_state_dict))
if missing:
    raise KeyError(f'Student checkpoint lacks inference weights: {missing}')

# 1. Copy every same-named weight, including cls_token, except the position
#    embedding (its length differs).
for name, param in student_state_dict.items():
    if name in inference_state_dict and name != 'pos_embed':
        inference_state_dict[name] = param

# 2. pos_embed: [1, 258, 384] (cls + dist + 16x16 patches) -> [1, 197, 384] (cls + 14x14 patches).
inference_state_dict['pos_embed'] = interpolate_pos_embed(
    student_state_dict['pos_embed'],
    inference_state_dict['pos_embed'].shape[1]
)

# 3. All keys are now populated, so load strictly to surface any shape/key mismatch.
inference_model.load_state_dict(inference_state_dict, strict=True)
inference_model = inference_model.to(device).eval()



## Evaluation and Making Kaggle Submissions

In [None]:
def create_submission(model, metadata_path, hdf5_path, output_file='submission.csv'):
    """Predict malignancy probabilities for the ISIC 2024 test set and write a Kaggle CSV.

    Args:
        model: Classifier returning 2-class logits for 224x224 inputs. A
            ``(logits, ...)`` tuple, as returned by ``DistillableViT``, is also accepted.
        metadata_path: CSV with an ``isic_id`` column listing the test images.
        hdf5_path: HDF5 file mapping each ``isic_id`` to encoded image bytes.
        output_file: Destination CSV path.

    Returns:
        The submission DataFrame with columns ``isic_id`` and ``target`` (probability of
        class 1), sorted by ``isic_id``.
    """
    test_meta = pd.read_csv(metadata_path)

    # Resize to 256x256, center-crop 224 (the converted inference model's input size),
    # then ImageNet normalization and HWC->CHW tensor conversion.
    val_transform = A.Compose([
        A.Resize(256, 256),
        A.CenterCrop(224, 224),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            max_pixel_value=255.0
        ),
        ToTensorV2()
    ])

    class ISICTestDataset(Dataset):
        """Test images decoded from the HDF5 file, keyed by ``isic_id``."""

        def __init__(self, df, h5_file_path, transform=None):
            self.df = df
            self.h5_file_path = h5_file_path
            self.isic_ids = df['isic_id'].values
            self.transform = transform
            self.h5_file = None

        def __len__(self):
            return len(self.isic_ids)

        def close(self):
            """Release the HDF5 handle if one was opened."""
            if self.h5_file is not None:
                self.h5_file.close()
                self.h5_file = None

        def __getitem__(self, idx):
            # Opened lazily on first access (a common pattern so each DataLoader
            # worker process opens its own handle).
            if self.h5_file is None:
                self.h5_file = h5py.File(self.h5_file_path, 'r')

            isic_id = self.isic_ids[idx]
            with Image.open(BytesIO(self.h5_file[isic_id][()])) as img:
                # Force 3 channels so the 3-channel normalization always applies.
                image_data = np.array(img.convert('RGB'))

            if self.transform:
                image = self.transform(image=image_data)["image"]  # tensor from ToTensorV2
            else:
                image = torch.from_numpy(image_data).permute(2, 0, 1).float() / 255.0

            return {'image': image, 'isic_id': isic_id}

    test_dataset = ISICTestDataset(test_meta, h5_file_path=hdf5_path, transform=val_transform)
    test_loader = DataLoader(
        test_dataset,
        batch_size=32,
        shuffle=False,
        num_workers=0,
        # Pinned memory only helps host->GPU copies.
        pin_memory=device.type == 'cuda'
    )

    model.eval()
    all_isic_ids = []
    all_probs = []

    try:
        with torch.no_grad():
            for batch in tqdm(test_loader, desc="Generating Predictions"):
                images = batch['image'].to(device)
                outputs = model(images)
                if isinstance(outputs, (tuple, list)):
                    outputs = outputs[0]
                probs = torch.softmax(outputs.float(), dim=1)[:, 1].cpu().numpy()
                all_isic_ids.extend(batch['isic_id'])
                all_probs.extend(probs)
    finally:
        # num_workers=0, so the handle lives in this process and can be closed here.
        test_dataset.close()

    # The ISIC 2024 submission format is `isic_id,target`.
    submission_df = pd.DataFrame({
        'isic_id': all_isic_ids,
        'target': all_probs
    }).sort_values('isic_id')

    submission_df.to_csv(output_file, index=False)
    print(f"Submission file created: {output_file}")
    return submission_df


In [None]:
# Score the hidden test split with the converted inference model and write the CSV.
submission_paths = {
    'metadata_path': './isic-2024-challenge/test-metadata.csv',
    'hdf5_path': './isic-2024-challenge/test-image.hdf5',
}
create_submission(inference_model, output_file='submission.csv', **submission_paths)

In [None]:
# test_transform = A.Compose([
#     A.Resize(256, 256),
#     A.Normalize(
#         mean=[0.4815, 0.4578, 0.4082], 
#         std=[0.2686, 0.2613, 0.2758], 
#         max_pixel_value=255.0,
#         p=1.0
#     ),
#     ToTensorV2(),
# ])

# # Prepare test dataset (no labels)
# class ISICTestDataset(Dataset):
#     def __init__(self, df, h5_file_path, transform=None):
#         self.df = df
#         self.h5_file_path = h5_file_path
#         self.isic_ids = df['isic_id'].values
#         self.transform = transform
#         self.h5_file = None

#     def __len__(self):
#         return len(self.isic_ids)

#     def __getitem__(self, idx):
#         if self.h5_file is None:
#             self.h5_file = h5py.File(self.h5_file_path, mode="r")
#         isic_id = self.isic_ids[idx]
#         img = np.array(Image.open(BytesIO(self.h5_file[isic_id][()])))
#         if self.transform:
#             img = self.transform(image=img)["image"]
#         return {'image': img, 'isic_id': isic_id}

# # Load test metadata
# test_metadata = pd.read_csv('./isic-2024-challenge/test-metadata.csv')

# # Use the same transforms as validation
# test_dataset = ISICTestDataset(
#     test_metadata,
#     h5_file_path='./isic-2024-challenge/test-image.hdf5',
#     transform=test_transform
# )

# test_loader = DataLoader(
#     test_dataset,
#     batch_size=32,
#     shuffle=False,
#     num_workers=0
# )

In [None]:
# # Instantiate the student model with the same config as training
# student_model = DistillableViT(image_size=256, patch_size=16, num_classes=2)
# student_model.load_state_dict(torch.load('best_student_model.pth'))
# student_model.eval()

# all_isic_ids = []
# all_probs = []


# with torch.no_grad():
#     for batch in tqdm(test_loader, desc="Inference"):
#         images = batch['image']
#         isic_ids = batch['isic_id']
#         outputs, _= student_model(images)
#         probs = torch.sigmoid(outputs).cpu().numpy()
#         all_isic_ids.extend(isic_ids)
#         all_probs.extend(probs)

# submission = pd.DataFrame({
#     'isic_id': all_isic_ids,
#     'target': [prob[1] for prob in all_probs]  # Extract second probability value
# })

# submission.to_csv('submission.csv', index=False)
# print("Submission file saved as submission.csv")
# submission.head()

Inference: 100%|██████████| 1/1 [00:00<00:00,  8.29it/s]

Submission file saved as submission.csv





Unnamed: 0,isic_id,target
0,ISIC_0015657,0.331845
1,ISIC_0015729,0.295968
2,ISIC_0015740,0.32907
