In [1]:
!pip install easyfsl
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR100
from torchvision.models import resnet18
from tqdm import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score

from easyfsl.samplers import TaskSampler
from easyfsl.utils import plot_images, sliding_average

# Parameters
# Episode layout: every task is built from N_WAY classes, with N_SHOT support
# images and N_QUERY query images drawn for each of those classes.
image_size = 32  # CIFAR-100 image size (used by the crop and resize transforms)
N_WAY = 5  # Number of classes in a task
N_SHOT = 5  # Number of images per class in the support set
N_QUERY = 10  # Number of images per class in the query set
N_EVALUATION_TASKS = 100  # Number of tasks requested from the test sampler
N_TRAINING_EPISODES = 40000  # Number of tasks requested from the training sampler

# CIFAR-100 Dataset
# Training pipeline: padded random crop and random horizontal flip as
# augmentation, then conversion to a tensor.
train_transform = transforms.Compose(
    [
        transforms.RandomCrop(image_size, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)
train_set = CIFAR100(root="./data", train=True, transform=train_transform, download=True)

# Test pipeline: no augmentation, only a resize to the configured image size
# followed by conversion to a tensor.
test_transform = transforms.Compose(
    [
        transforms.Resize([image_size, image_size]),
        transforms.ToTensor(),
    ]
)
test_set = CIFAR100(root="./data", train=False, transform=test_transform, download=True)

# Prototypical Networks
class PrototypicalNetworks(nn.Module):
    """Score query images by their distance to per-class prototypes.

    A prototype is the mean backbone embedding of the support images that
    share a label. A query's score for a class is the negated Euclidean
    distance between its embedding and that class's prototype.
    """

    def __init__(self, backbone: nn.Module):
        """Keep ``backbone``, the module used to embed both image sets."""
        super().__init__()
        self.backbone = backbone

    def forward(
        self,
        support_images: torch.Tensor,
        support_labels: torch.Tensor,
        query_images: torch.Tensor,
    ) -> torch.Tensor:
        """Return one row of class scores per query image.

        Prototypes are built for labels ``0 .. n_way - 1``, where ``n_way`` is
        the number of distinct values in ``support_labels``; column ``k`` of
        the result therefore corresponds to label ``k``.
        """
        support_features = self.backbone(support_images)
        query_features = self.backbone(query_images)

        class_count = torch.unique(support_labels).numel()
        prototypes = torch.stack(
            [
                support_features[support_labels == class_index].mean(dim=0)
                for class_index in range(class_count)
            ]
        )

        # A closer prototype must yield a higher score, hence the sign flip.
        return -torch.cdist(query_features, prototypes)

# Backbone: a pretrained ResNet-18 whose final fully connected layer is
# replaced by a Flatten, so the network outputs feature embeddings instead of
# class logits.
# NOTE(review): `pretrained=True` is deprecated in recent torchvision releases
# in favour of the `weights=` argument — confirm the installed version before
# changing it.
convolutional_network = resnet18(pretrained=True)
convolutional_network.fc = nn.Flatten()

# The model is moved to the GPU here, so a CUDA device is required.
model = PrototypicalNetworks(convolutional_network).cuda()

# Test DataLoader
# TaskSampler asks the dataset for the label of every item through
# `get_labels()`. CIFAR100 already keeps those labels in `targets`, so that
# list is returned directly; indexing the dataset item by item would decode
# and transform every image only to read its label.
test_set.get_labels = lambda: test_set.targets
test_sampler = TaskSampler(
    test_set,
    n_way=N_WAY,
    n_shot=N_SHOT,
    n_query=N_QUERY,
    n_tasks=N_EVALUATION_TASKS,
)
test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=12,
    pin_memory=True,
    # The sampler's collate function turns a flat batch into one episode.
    collate_fn=test_sampler.episodic_collate_fn,
)

# Training DataLoader
# Same construction as above, with the training split and episode count.
train_set.get_labels = lambda: train_set.targets
train_sampler = TaskSampler(
    train_set,
    n_way=N_WAY,
    n_shot=N_SHOT,
    n_query=N_QUERY,
    n_tasks=N_TRAINING_EPISODES,
)
train_loader = DataLoader(
    train_set,
    batch_sampler=train_sampler,
    num_workers=12,
    pin_memory=True,
    collate_fn=train_sampler.episodic_collate_fn,
)

# Loss and optimizer used by the training step: cross-entropy over the
# per-query class scores, optimised with Adam over all model parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training function
def fit(
    support_images: torch.Tensor,
    support_labels: torch.Tensor,
    query_images: torch.Tensor,
    query_labels: torch.Tensor,
) -> float:
    """Run one optimisation step on a single episode and return its loss.

    All four tensors are moved to the GPU. The module-level ``model``,
    ``criterion`` and ``optimizer`` are used, and the model's parameters are
    updated in place.
    """
    # Transfer the whole episode to the GPU up front.
    gpu_support_images = support_images.cuda()
    gpu_support_labels = support_labels.cuda()
    gpu_query_images = query_images.cuda()
    gpu_query_labels = query_labels.cuda()

    optimizer.zero_grad()
    query_scores = model(gpu_support_images, gpu_support_labels, gpu_query_images)
    episode_loss = criterion(query_scores, gpu_query_labels)
    episode_loss.backward()
    optimizer.step()

    # Hand back a plain Python float rather than a tensor.
    return episode_loss.item()

# Train the model
# One optimisation step per sampled episode; the progress bar shows a sliding
# average of the most recent losses, refreshed every `log_update_frequency`
# episodes.
log_update_frequency = 10
all_loss = []
model.train()
with tqdm(enumerate(train_loader), total=len(train_loader)) as tqdm_train:
    for episode_index, episode in tqdm_train:
        # The fifth element of an episode is not needed for training.
        support_images, support_labels, query_images, query_labels, _ = episode
        loss_value = fit(support_images, support_labels, query_images, query_labels)
        all_loss.append(loss_value)

        if episode_index % log_update_frequency != 0:
            continue
        tqdm_train.set_postfix(loss=sliding_average(all_loss, log_update_frequency))

# Evaluation function
def evaluate_on_one_task(
    support_images: torch.Tensor,
    support_labels: torch.Tensor,
    query_images: torch.Tensor,
    query_labels: torch.Tensor,
):
    """Classify the query images of one task with the module-level ``model``.

    Returns a tuple ``(number of correct predictions, number of query labels,
    predicted labels)``. The predicted labels are returned as produced by the
    model, i.e. without being moved back to the CPU.
    """
    query_scores = model(
        support_images.cuda(),
        support_labels.cuda(),
        query_images.cuda(),
    ).detach()

    # The predicted label is the index of the highest score in each row.
    _, predicted_labels = torch.max(query_scores, 1)
    hits = (predicted_labels == query_labels.cuda()).sum().item()
    return hits, len(query_labels), predicted_labels

def evaluate(data_loader: DataLoader):
    """Evaluate the module-level ``model`` on every task of ``data_loader``.

    Switches the model to eval mode, classifies each task's query set without
    tracking gradients, then prints overall accuracy together with weighted
    precision, recall and F1 computed over all query predictions pooled
    across tasks. Nothing is returned.

    NOTE(review): predictions are score-column indices compared directly with
    ``query_labels``, so labels look task-local; the pooled precision/recall/F1
    would then mix different underlying classes across tasks — confirm against
    the sampler before interpreting them per class.
    """
    seen_queries = 0
    correct_queries = 0
    pooled_predictions = []
    pooled_targets = []

    model.eval()
    with torch.no_grad():
        progress = tqdm(data_loader, total=len(data_loader))
        for support_images, support_labels, query_images, query_labels, _ in progress:
            hits, count, predicted_labels = evaluate_on_one_task(
                support_images, support_labels, query_images, query_labels
            )
            correct_queries += hits
            seen_queries += count
            # Predictions are moved to the CPU so they can be handed to sklearn.
            pooled_predictions.extend(predicted_labels.cpu().numpy())
            pooled_targets.extend(query_labels.cpu().numpy())

    accuracy = 100 * correct_queries / seen_queries
    precision, recall, f1 = (
        metric(pooled_targets, pooled_predictions, average='weighted')
        for metric in (precision_score, recall_score, f1_score)
    )

    print(f"Model tested on {len(data_loader)} tasks. Accuracy: {accuracy:.2f}%")
    print(f"Precision: {precision:.2f}")
    print(f"Recall: {recall:.2f}")
    print(f"F1 Score: {f1:.2f}")

# Evaluate the trained model on tasks sampled from the CIFAR-100 test split.
# NOTE(review): the CIFAR-100 train and test splits contain the same classes,
# so these tasks are not built from classes unseen during training — confirm
# this is the intended evaluation protocol.
evaluate(test_loader)


Files already downloaded and verified
Files already downloaded and verified


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 173MB/s]
100%|██████████| 40000/40000 [35:09<00:00, 18.96it/s, loss=0.727]
100%|██████████| 100/100 [00:03<00:00, 28.20it/s]

Model tested on 100 tasks. Accuracy: 69.68%
Precision: 0.70
Recall: 0.70
F1 Score: 0.70



