In [1]:
import copy
from pathlib import Path
import random
from statistics import mean

import numpy as np
import torch
from torch import nn
from tqdm import tqdm

In [2]:
# random_seed = 0
# np.random.seed(random_seed)
# torch.manual_seed(random_seed)
# random.seed(random_seed)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
image_size = 32

In [32]:
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from torchvision import transforms
from numpy.random import RandomState
from torch.utils.data import Subset

from easyfsl.samplers import TaskSampler


batch_size = 128
n_workers = 1

cifar_data = CIFAR10(root="data", train=True, download=True,  transform= transforms.Compose([
            transforms.RandomResizedCrop(image_size),
            transforms.RandomHorizontalFlip(),
            
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
        ]))

cifar_data_val = CIFAR10(root='.',train=True, transform= transforms.Compose([
            # transforms.RandomResizedCrop(image_size),
            # transforms.RandomHorizontalFlip(),
            
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
        ]), download=True)

n_way = 2
n_shot = 1
n_query = 4
n_tasks_per_epoch = 5
n_validation_tasks = 200

def get_dataloader(dataset,dataset_val, n_workers, seed):
    prng = RandomState(seed)
    random_permute = prng.permutation(np.arange(0, 500))
    classes =  prng.permutation(np.arange(0,10))
    indx_train = np.concatenate([np.where(np.array(dataset.targets) == classe)[0][random_permute[0:25]] for classe in classes[0:2]])
    indx_val = np.concatenate([np.where(np.array(dataset_val.targets) == classe)[0][random_permute[25:225]] for classe in classes[0:2]])
    train_targets = np.array(dataset.targets)[indx_train]
    val_targets = np.array(dataset.targets)[indx_val]
    train_data = Subset(cifar_data, indx_train)
    print(len(train_data.dataset.targets))
    val_data = Subset(cifar_data_val, indx_val)
    train_data.get_labels = lambda: [
        t for t in train_targets
    ] 
    val_data.get_labels = lambda: [
        t for t in val_targets
    ] 
    train_sampler = TaskSampler(
        train_data, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_tasks_per_epoch
    )
    val_sampler = TaskSampler(
        val_data, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_validation_tasks
    )

    train_loader = DataLoader(
        train_data,
        batch_sampler=train_sampler,
        num_workers=n_workers,
        pin_memory=True,
        collate_fn=train_sampler.episodic_collate_fn,
    )
    
    val_loader = DataLoader(
        val_data,
        batch_sampler=val_sampler,
        num_workers=n_workers,
        pin_memory=True,
        collate_fn=val_sampler.episodic_collate_fn,
    )
    
    return train_loader, val_loader
    
random_seed = 0


Files already downloaded and verified
Files already downloaded and verified
50000
Train set size: 5
Validation set size: 200


In [42]:
from easyfsl.modules import resnet12, resnet50, 
from torchvision.models import resnet18

DEVICE = "cuda"

# model = resnet18(
#     pretrained=True, 
# )

# model.fc = nn.Sequential(
#     nn.Linear(512, 256), 
#     nn.ReLU(),
#     nn.Linear(256, 256),
#     nn.Linear(256, 10)
# )


model = resnet50()
model = model.to(DEVICE)

In [43]:
from easyfsl.methods import PrototypicalNetworks, LaplacianShot, MatchingNetworks, SimpleShot, TransductiveFinetuning, PTMAP, BDCSPN

few_shot_classifier = LaplacianShot(model).to(DEVICE)



In [50]:
from torch.optim import SGD, Optimizer, Adam
from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR
from torch.utils.tensorboard import SummaryWriter


LOSS_FUNCTION = nn.CrossEntropyLoss()

n_epochs = 20
scheduler_milestones = [120, 180]
scheduler_gamma = 0.1
learning_rate = 0.01
tb_logs_dir = Path(".")

# train_optimizer = SGD(
#     model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4
# )

train_optimizer = Adam(
    model.parameters(), lr=learning_rate, 
)
train_scheduler = MultiStepLR(
    train_optimizer,
    milestones=scheduler_milestones,
    gamma=scheduler_gamma,
)


# train_scheduler = CosineAnnealingLR(
#     train_optimizer,
#     T_max=100
#     # milestones=scheduler_milestones,
#     # gamma=scheduler_gamma,
# )

tb_writer = SummaryWriter(log_dir=str(tb_logs_dir))

In [51]:
def training_epoch(
    model, data_loader: DataLoader, optimizer: Optimizer
):
    all_loss = []
    model.train()
    with tqdm(
        enumerate(data_loader), total=len(data_loader), desc="Training"
    ) as tqdm_train:
        for episode_index, (
            support_images,
            support_labels,
            query_images,
            query_labels,
            _,
        ) in tqdm_train:
            optimizer.zero_grad()
            model.process_support_set(
                support_images.to(DEVICE), support_labels.to(DEVICE)
            )
            classification_scores = model(query_images.to(DEVICE))

            loss = LOSS_FUNCTION(classification_scores, query_labels.to(DEVICE))
            loss.backward()
            optimizer.step()

            all_loss.append(loss.item())

            tqdm_train.set_postfix(loss=mean(all_loss))

    return mean(all_loss)

In [52]:
from easyfsl.utils import evaluate


best_state = model.state_dict()
best_validation_accuracy = 0.0
validation_frequency = 10

validation_accuracies = []
for seed in range(5):
    
    train_loader, val_loader = get_dataloader(cifar_data,cifar_data_val, n_workers, seed)
    print("Train set size:", len(train_loader))
    print("Validation set size:", len(val_loader))
    
    for epoch in range(n_epochs):
        print(f"Epoch {epoch}")
        average_loss = training_epoch(few_shot_classifier, train_loader, train_optimizer)

        # if epoch % validation_frequency == validation_frequency - 1:

            # model.eval()
            # validation_accuracy = evaluate(
            #     few_shot_classifier, val_loader, device=DEVICE, tqdm_prefix="Validation"
            # )
            # # model.train()
            # if validation_accuracy > best_validation_accuracy:
            #     best_validation_accuracy = validation_accuracy
            #     best_state = copy.deepcopy(few_shot_classifier.state_dict())
            #     print(" found a new best model!")

            # tb_writer.add_scalar("Val/acc", validation_accuracy, epoch)

        tb_writer.add_scalar("Train/loss", average_loss, epoch)

        # Warn the scheduler that we did an epoch
        # so it knows when to decrease the learning rate
        train_scheduler.step()
        
    validation_accuracy = evaluate(
        few_shot_classifier, val_loader, device=DEVICE, tqdm_prefix="Validation"
    )
    validation_accuracies.append(validation_accuracy)
    
print(
        f"Accuracy: {100* np.array(validation_accuracies).mean():.3f}% \u00B1 {100*np.array(validation_accuracies).std():.3f}%"
    )

50000
Train set size: 5
Validation set size: 200
Epoch 0


Training: 100%|██████████| 5/5 [00:06<00:00,  1.27s/it, loss=0.788]

Epoch 1



Training: 100%|██████████| 5/5 [00:00<00:00, 15.32it/s, loss=0.863]

Epoch 2



Training: 100%|██████████| 5/5 [00:00<00:00, 14.19it/s, loss=0.688]

Epoch 3



Training: 100%|██████████| 5/5 [00:00<00:00, 15.30it/s, loss=0.688]

Epoch 4



Training: 100%|██████████| 5/5 [00:00<00:00, 13.79it/s, loss=0.738]

Epoch 5



Training: 100%|██████████| 5/5 [00:00<00:00, 14.75it/s, loss=0.813]

Epoch 6



Training: 100%|██████████| 5/5 [00:00<00:00, 15.11it/s, loss=0.788]

Epoch 7



Training: 100%|██████████| 5/5 [00:00<00:00, 14.32it/s, loss=0.888]

Epoch 8



Training: 100%|██████████| 5/5 [00:00<00:00, 14.14it/s, loss=0.838]

Epoch 9



Training: 100%|██████████| 5/5 [00:00<00:00, 12.89it/s, loss=0.788]

Epoch 10



Training: 100%|██████████| 5/5 [00:00<00:00, 14.40it/s, loss=0.838]

Epoch 11



Training: 100%|██████████| 5/5 [00:00<00:00, 15.42it/s, loss=0.813]

Epoch 12



Training: 100%|██████████| 5/5 [00:00<00:00, 14.87it/s, loss=0.763]

Epoch 13



Training: 100%|██████████| 5/5 [00:00<00:00, 14.47it/s, loss=0.888]

Epoch 14



Training: 100%|██████████| 5/5 [00:00<00:00, 15.06it/s, loss=0.813]

Epoch 15



Training: 100%|██████████| 5/5 [00:00<00:00, 15.58it/s, loss=0.738]

Epoch 16



Training: 100%|██████████| 5/5 [00:00<00:00, 14.42it/s, loss=0.713]

Epoch 17



Training: 100%|██████████| 5/5 [00:00<00:00, 14.55it/s, loss=0.838]

Epoch 18



Training: 100%|██████████| 5/5 [00:00<00:00, 15.28it/s, loss=0.813]

Epoch 19



Training: 100%|██████████| 5/5 [00:00<00:00, 15.31it/s, loss=0.763]
Validation: 100%|██████████| 200/200 [00:03<00:00, 51.61it/s, accuracy=0.498]


50000
Train set size: 5
Validation set size: 200
Epoch 0


Training: 100%|██████████| 5/5 [00:00<00:00, 14.66it/s, loss=0.763]

Epoch 1



Training: 100%|██████████| 5/5 [00:00<00:00, 15.00it/s, loss=0.813]

Epoch 2



Training: 100%|██████████| 5/5 [00:00<00:00, 15.08it/s, loss=0.888]

Epoch 3



Training: 100%|██████████| 5/5 [00:00<00:00, 14.66it/s, loss=0.888]

Epoch 4



Training: 100%|██████████| 5/5 [00:00<00:00, 14.88it/s, loss=0.763]

Epoch 5



Training: 100%|██████████| 5/5 [00:00<00:00, 13.04it/s, loss=0.838]

Epoch 6



Training: 100%|██████████| 5/5 [00:00<00:00, 14.75it/s, loss=0.688]

Epoch 7



Training: 100%|██████████| 5/5 [00:00<00:00, 15.10it/s, loss=0.788]

Epoch 8



Training: 100%|██████████| 5/5 [00:00<00:00, 14.75it/s, loss=0.713]

Epoch 9



Training: 100%|██████████| 5/5 [00:00<00:00, 13.73it/s, loss=0.863]

Epoch 10



Training: 100%|██████████| 5/5 [00:00<00:00, 14.65it/s, loss=0.888]

Epoch 11



Training: 100%|██████████| 5/5 [00:00<00:00, 13.76it/s, loss=0.788]

Epoch 12



Training: 100%|██████████| 5/5 [00:00<00:00, 14.97it/s, loss=0.863]

Epoch 13



Training: 100%|██████████| 5/5 [00:00<00:00, 14.72it/s, loss=0.788]

Epoch 14



Training: 100%|██████████| 5/5 [00:00<00:00, 14.92it/s, loss=0.838]

Epoch 15



Training: 100%|██████████| 5/5 [00:00<00:00, 14.92it/s, loss=0.688]

Epoch 16



Training: 100%|██████████| 5/5 [00:00<00:00, 14.00it/s, loss=0.838]

Epoch 17



Training: 100%|██████████| 5/5 [00:00<00:00, 14.37it/s, loss=0.863]

Epoch 18



Training: 100%|██████████| 5/5 [00:00<00:00, 15.03it/s, loss=0.738]

Epoch 19



Training: 100%|██████████| 5/5 [00:00<00:00, 14.80it/s, loss=0.763]
Validation: 100%|██████████| 200/200 [00:04<00:00, 49.38it/s, accuracy=0.54] 


50000
Train set size: 5
Validation set size: 200
Epoch 0


Training: 100%|██████████| 5/5 [00:00<00:00, 13.67it/s, loss=0.938]

Epoch 1



Training: 100%|██████████| 5/5 [00:00<00:00, 14.02it/s, loss=0.713]

Epoch 2



Training: 100%|██████████| 5/5 [00:00<00:00, 14.76it/s, loss=0.788]

Epoch 3



Training: 100%|██████████| 5/5 [00:00<00:00, 14.92it/s, loss=0.838]

Epoch 4



Training: 100%|██████████| 5/5 [00:00<00:00, 14.93it/s, loss=0.763]

Epoch 5



Training: 100%|██████████| 5/5 [00:00<00:00, 15.01it/s, loss=0.813]

Epoch 6



Training: 100%|██████████| 5/5 [00:00<00:00, 14.57it/s, loss=0.788]

Epoch 7



Training: 100%|██████████| 5/5 [00:00<00:00, 15.14it/s, loss=0.813]

Epoch 8



Training: 100%|██████████| 5/5 [00:00<00:00, 15.11it/s, loss=0.888]

Epoch 9



Training: 100%|██████████| 5/5 [00:00<00:00, 14.64it/s, loss=0.713]

Epoch 10



Training: 100%|██████████| 5/5 [00:00<00:00, 14.46it/s, loss=0.738]

Epoch 11



Training: 100%|██████████| 5/5 [00:00<00:00, 13.96it/s, loss=0.938]

Epoch 12



Training: 100%|██████████| 5/5 [00:00<00:00, 14.42it/s, loss=0.788]

Epoch 13



Training: 100%|██████████| 5/5 [00:00<00:00, 14.67it/s, loss=0.788]

Epoch 14



Training: 100%|██████████| 5/5 [00:00<00:00, 14.94it/s, loss=0.888]

Epoch 15



Training: 100%|██████████| 5/5 [00:00<00:00, 13.11it/s, loss=0.838]

Epoch 16



Training: 100%|██████████| 5/5 [00:00<00:00, 13.45it/s, loss=0.738]

Epoch 17



Training: 100%|██████████| 5/5 [00:00<00:00, 14.40it/s, loss=0.838]

Epoch 18



Training: 100%|██████████| 5/5 [00:00<00:00, 15.20it/s, loss=0.813]

Epoch 19



Training: 100%|██████████| 5/5 [00:00<00:00, 14.41it/s, loss=0.888]
Validation: 100%|██████████| 200/200 [00:04<00:00, 49.37it/s, accuracy=0.526]


50000
Train set size: 5
Validation set size: 200
Epoch 0


Training: 100%|██████████| 5/5 [00:00<00:00, 14.20it/s, loss=0.788]

Epoch 1



Training: 100%|██████████| 5/5 [00:00<00:00, 14.68it/s, loss=0.788]

Epoch 2



Training: 100%|██████████| 5/5 [00:00<00:00, 12.69it/s, loss=0.738]

Epoch 3



Training: 100%|██████████| 5/5 [00:00<00:00, 14.56it/s, loss=0.813]

Epoch 4



Training: 100%|██████████| 5/5 [00:00<00:00, 13.34it/s, loss=0.838]

Epoch 5



Training: 100%|██████████| 5/5 [00:00<00:00, 13.90it/s, loss=0.663]

Epoch 6



Training: 100%|██████████| 5/5 [00:00<00:00, 14.55it/s, loss=0.713]

Epoch 7



Training: 100%|██████████| 5/5 [00:00<00:00, 12.87it/s, loss=0.813]

Epoch 8



Training: 100%|██████████| 5/5 [00:00<00:00, 14.51it/s, loss=0.763]

Epoch 9



Training: 100%|██████████| 5/5 [00:00<00:00, 12.78it/s, loss=0.738]

Epoch 10



Training: 100%|██████████| 5/5 [00:00<00:00, 13.78it/s, loss=0.838]

Epoch 11



Training: 100%|██████████| 5/5 [00:00<00:00, 14.46it/s, loss=0.863]

Epoch 12



Training: 100%|██████████| 5/5 [00:00<00:00, 13.42it/s, loss=0.738]

Epoch 13



Training: 100%|██████████| 5/5 [00:00<00:00, 14.55it/s, loss=0.838]

Epoch 14



Training: 100%|██████████| 5/5 [00:00<00:00, 12.78it/s, loss=0.838]

Epoch 15



Training: 100%|██████████| 5/5 [00:00<00:00, 12.60it/s, loss=0.838]

Epoch 16



Training: 100%|██████████| 5/5 [00:00<00:00, 11.66it/s, loss=0.788]


Epoch 17


Training: 100%|██████████| 5/5 [00:00<00:00, 13.83it/s, loss=0.713]

Epoch 18



Training: 100%|██████████| 5/5 [00:00<00:00, 12.96it/s, loss=0.838]

Epoch 19



Training: 100%|██████████| 5/5 [00:00<00:00, 12.48it/s, loss=0.738]
Validation: 100%|██████████| 200/200 [00:04<00:00, 46.50it/s, accuracy=0.503]

50000
Train set size: 5
Validation set size: 200
Epoch 0



Training: 100%|██████████| 5/5 [00:00<00:00, 12.83it/s, loss=0.738]

Epoch 1



Training: 100%|██████████| 5/5 [00:00<00:00, 12.15it/s, loss=0.788]


Epoch 2


Training: 100%|██████████| 5/5 [00:00<00:00, 12.67it/s, loss=0.738]

Epoch 3



Training: 100%|██████████| 5/5 [00:00<00:00, 14.20it/s, loss=0.838]

Epoch 4



Training: 100%|██████████| 5/5 [00:00<00:00, 11.51it/s, loss=0.888]

Epoch 5



Training: 100%|██████████| 5/5 [00:00<00:00, 14.66it/s, loss=0.788]

Epoch 6



Training: 100%|██████████| 5/5 [00:00<00:00, 13.78it/s, loss=0.738]


Epoch 7


Training: 100%|██████████| 5/5 [00:00<00:00, 12.55it/s, loss=0.863]

Epoch 8



Training: 100%|██████████| 5/5 [00:00<00:00, 12.59it/s, loss=0.788]

Epoch 9



Training: 100%|██████████| 5/5 [00:00<00:00, 13.95it/s, loss=0.838]

Epoch 10



Training: 100%|██████████| 5/5 [00:00<00:00, 14.18it/s, loss=0.888]

Epoch 11



Training: 100%|██████████| 5/5 [00:00<00:00, 13.60it/s, loss=0.713]

Epoch 12



Training: 100%|██████████| 5/5 [00:00<00:00, 13.07it/s, loss=0.688]

Epoch 13



Training: 100%|██████████| 5/5 [00:00<00:00, 12.78it/s, loss=0.838]

Epoch 14



Training: 100%|██████████| 5/5 [00:00<00:00, 14.50it/s, loss=0.738]

Epoch 15



Training: 100%|██████████| 5/5 [00:00<00:00, 14.38it/s, loss=0.738]

Epoch 16



Training: 100%|██████████| 5/5 [00:00<00:00, 14.12it/s, loss=0.738]

Epoch 17



Training: 100%|██████████| 5/5 [00:00<00:00, 14.69it/s, loss=0.738]

Epoch 18



Training: 100%|██████████| 5/5 [00:00<00:00, 14.14it/s, loss=0.738]

Epoch 19



Training: 100%|██████████| 5/5 [00:00<00:00, 13.68it/s, loss=0.788]
Validation: 100%|██████████| 200/200 [00:04<00:00, 48.56it/s, accuracy=0.534]


Accuracy: 52.037% ± 1.678%
