In [None]:
!pip install torchvision torch scikit-learn matplotlib opendatasets --quiet

In [None]:
import csv
import os
import shutil

import opendatasets as od

dataset_url = 'https://www.kaggle.com/datasets/kmader/skin-cancer-mnist-ham10000'
od.download(dataset_url)

data_dir = './skin-cancer-mnist-ham10000'


def build_class_folders(data_dir, class_root, metadata_name='HAM10000_metadata.csv'):
    """Create ``class_root/<dx>/<image file>`` entries from the HAM10000 metadata CSV.

    The download places the metadata CSV directly in ``data_dir`` next to the
    image folders, i.e. images are not grouped by class. The episodic dataset
    expects one sub-directory per class, so this builds that view, keyed on the
    CSV's ``image_id`` and ``dx`` columns. Entries are symlinks where supported
    (copies otherwise), and existing entries are left alone, so re-running the
    cell is cheap.

    Args:
        data_dir: root of the download.
        class_root: directory to populate with one sub-folder per ``dx`` value.
        metadata_name: file name of the metadata CSV inside ``data_dir``.

    Returns:
        dict[str, int]: number of images placed in each class folder.
    """
    class_root_abs = os.path.abspath(class_root)

    # Index every .jpg under data_dir by file stem (the image_id), without
    # descending into the class folders this function itself creates.
    image_paths = {}
    for dirpath, dirnames, filenames in os.walk(data_dir):
        dirnames[:] = [
            d for d in dirnames
            if os.path.abspath(os.path.join(dirpath, d)) != class_root_abs
        ]
        for name in filenames:
            stem, ext = os.path.splitext(name)
            if ext.lower() == '.jpg':
                image_paths.setdefault(stem, os.path.join(dirpath, name))

    counts = {}
    with open(os.path.join(data_dir, metadata_name), newline='') as f:
        for row in csv.DictReader(f):
            src = image_paths.get(row['image_id'])
            if src is None:
                continue  # listed in the metadata but no image file found
            class_dir = os.path.join(class_root, row['dx'])
            os.makedirs(class_dir, exist_ok=True)
            dst = os.path.join(class_dir, os.path.basename(src))
            if not os.path.lexists(dst):
                try:
                    os.symlink(os.path.abspath(src), dst)
                except OSError:
                    shutil.copy2(src, dst)  # filesystems without symlink support
            counts[row['dx']] = counts.get(row['dx'], 0) + 1
    return counts


class_counts = build_class_folders(data_dir, os.path.join(data_dir, 'images'))
assert class_counts, "no images were matched to the metadata; check the download"
class_counts

In [None]:
import os
import random
from collections import defaultdict

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.datasets import ImageFolder

Data Preparation

In [None]:

# Preprocessing applied to every support and query image: resize to 224x224,
# convert the PIL image to a float tensor, then normalise each RGB channel
# with the ImageNet mean / std statistics.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

class ISICEpisodicDataset(Dataset):
    """Episodic N-way K-shot dataset over a folder-per-class image directory.

    Each item is one freshly sampled episode: ``n_way`` classes are drawn at
    random, and from each class ``k_shot`` support and ``q_queries`` query
    images are drawn without replacement. Labels are re-indexed to
    ``0..n_way-1`` in the order the classes were drawn.

    Expected layout: ``image_folder/<class_name>/<image file>``. Entries
    directly under ``image_folder`` that are not directories (e.g. a metadata
    CSV) are ignored, as are hidden files inside class folders. Classes with
    fewer than ``k_shot + q_queries`` images are skipped (listed in
    ``self.skipped_classes``) because they cannot fill an episode.

    Args:
        image_folder: root directory containing one sub-directory per class.
        n_way: number of classes per episode.
        k_shot: support images per class.
        q_queries: query images per class.
        episodes_per_epoch: value reported by ``__len__``.
        transform: callable applied to each RGB PIL image. When ``None``, the
            module-level ``transform`` defined in this cell is used.

    Raises:
        ValueError: if fewer than ``n_way`` classes have enough images.
    """

    def __init__(self, image_folder, n_way, k_shot, q_queries, episodes_per_epoch, transform=None):
        self.image_folder = image_folder
        self.n_way = n_way
        self.k_shot = k_shot
        self.q_queries = q_queries
        self.episodes_per_epoch = episodes_per_epoch
        self.transform = transform

        needed = k_shot + q_queries
        # Index the directory once instead of calling os.listdir on every episode.
        self.images_by_class = {}
        self.skipped_classes = []
        # Sorted so class order (and therefore seeded sampling) is reproducible.
        for cls in sorted(os.listdir(image_folder)):
            class_path = os.path.join(image_folder, cls)
            if not os.path.isdir(class_path):
                continue  # stray files such as a metadata CSV are not classes
            files = sorted(
                name for name in os.listdir(class_path)
                if not name.startswith('.') and os.path.isfile(os.path.join(class_path, name))
            )
            if len(files) >= needed:
                self.images_by_class[cls] = files
            else:
                self.skipped_classes.append(cls)
        self.classes = sorted(self.images_by_class)

        if len(self.classes) < n_way:
            raise ValueError(
                f"need at least {n_way} class folders with >= {needed} images under "
                f"{image_folder!r}; found {len(self.classes)}: {self.classes}"
            )

    def __len__(self):
        return self.episodes_per_epoch

    def _load_image(self, path):
        """Open ``path`` as RGB, apply the transform, and close the file handle."""
        tf = self.transform if self.transform is not None else transform
        with Image.open(path) as img:
            return tf(img.convert("RGB"))

    def __getitem__(self, index):
        """Return ``(support_x, support_y, query_x, query_y)`` for a new random episode.

        ``index`` is ignored; every call draws a fresh episode.
        """
        episode_classes = random.sample(self.classes, self.n_way)
        support_images, support_labels = [], []
        query_images, query_labels = [], []

        for label, cls in enumerate(episode_classes):
            class_path = os.path.join(self.image_folder, cls)
            selected = random.sample(self.images_by_class[cls], self.k_shot + self.q_queries)

            for img_name in selected[:self.k_shot]:
                support_images.append(self._load_image(os.path.join(class_path, img_name)))
                support_labels.append(label)

            for img_name in selected[self.k_shot:]:
                query_images.append(self._load_image(os.path.join(class_path, img_name)))
                query_labels.append(label)

        return (torch.stack(support_images), torch.tensor(support_labels),
                torch.stack(query_images), torch.tensor(query_labels))


Prototypical Network with ResNet-18

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F

In [None]:

class ProtoNet(nn.Module):
    """Prototypical network with a torchvision ResNet-18 as the embedding function.

    The ResNet classification head (``fc``) is replaced by ``nn.Identity`` so
    ``forward`` returns the backbone's feature vectors.

    Args:
        pretrained: if True (default), load torchvision's default ImageNet
            weights for ResNet-18 (downloaded on first use). Pass False for a
            randomly initialised encoder, e.g. when working offline.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        # `weights=` replaces the deprecated `pretrained=` flag (torchvision >= 0.13).
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        self.encoder = models.resnet18(weights=weights)
        self.encoder.fc = nn.Identity()  # remove final classification layer

    def forward(self, x):
        return self.encoder(x)

    def compute_prototypes(self, support, support_labels, n_way, k_shot):
        """Return the mean support embedding of each class, stacked in label order.

        Args:
            support: batch of support images accepted by the encoder.
            support_labels: 1-D integer tensor of episode labels in ``[0, n_way)``.
            n_way: number of classes in the episode.
            k_shot: unused; per-class counts come from ``support_labels``.
                Kept for interface compatibility.

        Returns:
            Tensor of shape ``(n_way, embedding_dim)``.

        Raises:
            ValueError: if some label in ``range(n_way)`` has no support example
                (its prototype would otherwise silently be NaN).
        """
        embeddings = self.forward(support)
        counts = torch.bincount(support_labels, minlength=n_way)[:n_way]
        if (counts == 0).any():
            missing = torch.nonzero(counts == 0).flatten().tolist()
            raise ValueError(f"no support examples for class label(s) {missing}")
        return torch.stack([embeddings[support_labels == c].mean(0) for c in range(n_way)])

    def classify(self, query, prototypes):
        """Score each query against each prototype; higher means closer.

        Returns negative Euclidean distances of shape ``(num_queries, n_way)``,
        usable directly as logits for softmax / cross-entropy.
        """
        query_embeddings = self.forward(query)
        dists = torch.cdist(query_embeddings, prototypes)
        return -dists


Training Loop

In [None]:
import torch.optim as optim
from sklearn.metrics import accuracy_score, f1_score

In [None]:

# Seed Python's and torch's RNGs: episode sampling uses `random.sample`, so
# without this every run trains and evaluates on different episodes.
# (GPU kernels may still introduce small nondeterminism.)
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ProtoNet().to(device)

# Adam over all model parameters, so the encoder is fine-tuned end to end.
LEARNING_RATE = 1e-4
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

def train_protonet(model, dataloader, n_way, k_shot, q_queries, epochs, optimizer=None):
    """Episodically train a prototypical network, printing mean accuracy per epoch.

    For every episode: build class prototypes from the support set, score the
    queries against them, minimise cross-entropy, and take an optimizer step.

    Args:
        model: module exposing ``compute_prototypes(support, labels, n_way, k_shot)``
            and ``classify(query, prototypes)``; trained in place. Episodes are
            moved to the device of the model's first parameter.
        dataloader: iterable of ``(support_x, support_y, query_x, query_y)``
            episodes. A leading batch axis of size 1, as added by
            ``DataLoader(batch_size=1)``, is removed automatically.
        n_way, k_shot: forwarded to ``compute_prototypes``.
        q_queries: unused; kept for signature compatibility.
        epochs: number of passes over ``dataloader``.
        optimizer: optimizer over ``model``'s parameters. Defaults to the
            notebook-level ``optimizer`` for backward compatibility.

    Returns:
        list[float]: mean query accuracy for each epoch (NaN for an epoch in
        which ``dataloader`` yielded no episodes).
    """
    if optimizer is None:
        # The parameter shadows the global name, so look it up explicitly.
        optimizer = globals()["optimizer"]
    device = next(model.parameters()).device

    model.train()
    epoch_accuracies = []
    for epoch in range(epochs):
        episode_accs = []
        for support_x, support_y, query_x, query_y in dataloader:
            # Batched episodes carry 2-D labels (1, N); the encoder needs plain
            # (N, ...) image stacks and 1-D labels, so drop the size-1 batch axis.
            if support_y.dim() == 2:
                support_x, support_y = support_x.squeeze(0), support_y.squeeze(0)
                query_x, query_y = query_x.squeeze(0), query_y.squeeze(0)
            support_x, support_y = support_x.to(device), support_y.to(device)
            query_x, query_y = query_x.to(device), query_y.to(device)

            prototypes = model.compute_prototypes(support_x, support_y, n_way, k_shot)
            logits = model.classify(query_x, prototypes)

            loss = F.cross_entropy(logits, query_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            preds = logits.argmax(dim=1)
            episode_accs.append((preds == query_y).float().mean().item())

        mean_acc = sum(episode_accs) / len(episode_accs) if episode_accs else float("nan")
        epoch_accuracies.append(mean_acc)
        print(f"[Epoch {epoch+1}] Avg Accuracy: {mean_acc:.4f}")
    return epoch_accuracies


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 77.4MB/s]


Evaluation Function

In [None]:
def evaluate_protonet(model, dataloader, n_way, k_shot):
    """Evaluate a prototypical network on episodes; print and return accuracy / macro-F1.

    Runs without gradients with the model in eval mode (it is left in eval
    mode). Query predictions from all episodes are pooled before scoring.

    Args:
        model: module exposing ``compute_prototypes`` and ``classify``. Episodes
            are moved to the device of the model's first parameter.
        dataloader: iterable of ``(support_x, support_y, query_x, query_y)``
            episodes. A leading batch axis of size 1, as added by
            ``DataLoader(batch_size=1)``, is removed automatically.
        n_way, k_shot: forwarded to ``compute_prototypes``.

    Returns:
        tuple[float, float]: ``(accuracy, macro_f1)`` over all query predictions.

    Raises:
        ValueError: if ``dataloader`` yields no episodes.
    """
    device = next(model.parameters()).device
    model.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for support_x, support_y, query_x, query_y in dataloader:
            # Drop the size-1 batch axis added by DataLoader(batch_size=1).
            if support_y.dim() == 2:
                support_x, support_y = support_x.squeeze(0), support_y.squeeze(0)
                query_x, query_y = query_x.squeeze(0), query_y.squeeze(0)
            support_x, support_y = support_x.to(device), support_y.to(device)
            query_x, query_y = query_x.to(device), query_y.to(device)

            prototypes = model.compute_prototypes(support_x, support_y, n_way, k_shot)
            logits = model.classify(query_x, prototypes)

            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(query_y.cpu().numpy())

    if not all_labels:
        raise ValueError("evaluate_protonet: dataloader yielded no episodes")

    acc = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='macro')
    print(f"[Evaluation] Accuracy: {acc:.4f}, F1-score: {f1:.4f}")
    return acc, f1


Run Training and Evaluation

In [None]:
# Episode configuration
n_way = 5
k_shot = 1  # or 5 for 5-shot
q_queries = 5
episodes_per_epoch = 50
epochs = 10

# ISICEpisodicDataset expects one sub-folder per class (diagnosis) under this root.
image_root = f"{data_dir}/images"

# Each dataset item is already a complete episode, so disable automatic
# batching with batch_size=None. With batch_size=1 the default collate adds a
# leading size-1 axis, making the stacked support images 5-D, which the
# convolutional encoder cannot take. __getitem__ ignores its index, so
# shuffling the sampler adds nothing.
train_dataset = ISICEpisodicDataset(image_root, n_way, k_shot, q_queries, episodes_per_epoch)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=None)

# Evaluation episodes are sampled independently of the training episodes.
# TODO: they are still drawn from the same images used for training, so these
# scores are optimistic. Hold out a per-class subset of images for a real test set.
eval_dataset = ISICEpisodicDataset(image_root, n_way, k_shot, q_queries, episodes_per_epoch)
eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=None)

# Train
train_protonet(model, train_loader, n_way, k_shot, q_queries, epochs)

evaluate_protonet(model, eval_loader, n_way, k_shot);


NotADirectoryError: [Errno 20] Not a directory: '/content/skin-cancer-mnist-ham10000/HAM10000_metadata.csv'

In [None]:
/content/skin-cancer-mnist-ham10000