# Train a model with Classical Training

Although episodic training has attracted a lot of interest in the early years of Few-Shot Learning research, more recent works suggest that competitive results can be achieved with a simple cross entropy loss across all training classes. Therefore, it is becoming more and more common to use this classical process to train the backbone, that will be common to all methods compared at test time.

This is in fact more representative of real use cases: episodic training assumes that, at training time, you have access to the shape of the few-shot tasks that will be encountered at test time (indeed you choose a specific number of ways for episodic training). You also "force" your inference method into the training of the network. Switching the few-shot learning logic to inference (i.e. no episodic training) allows methods to be agnostic of the backbone.

Nonetheless, if you need to perform episodic training, we also provide [an example notebook](episodic_training.ipynb) for that.

## Getting started
First we're going to do some imports (this is not the interesting part).

In [30]:
# Ensure working directory is the project's root

%cd E:\mit\masterarbeit\easy-few-shot-learning-master



from pathlib import Path
import random
from statistics import mean
from matplotlib import pyplot as plt
import numpy as np
import torch
from torch import nn
from tqdm import tqdm

print(torch.__version__)
print(torch.version.cuda)
print(torch.backends.cudnn.version())

E:\mit\masterarbeit\easy-few-shot-learning-master
1.12.1
11.3
8302


Then we're gonna do the most important thing in Machine Learning research: ensuring reproducibility by setting the random seed. We're going to set the seed for all random packages that we could possibly use, plus some other stuff to make CUDA deterministic (see [here](https://pytorch.org/docs/stable/notes/randomness.html)).

I strongly encourage that you do this in **all your scripts**.

In [31]:
random_seed = 0
np.random.seed(random_seed)
torch.manual_seed(random_seed)
random.seed(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

Then we're gonna create our data loader for the training set. You can see that I chose tu use CUB in this notebook, because it’s a small dataset, so we can have good results quite quickly. I set a batch size of 128 but feel free to adapt it to your constraints.

Note that we're not using the `TaskSampler` for the train data loader, because we won't be sampling training data in the shape of tasks as we would have in episodic training. We do it **normally**.

In [32]:
from easyfsl.datasets import easy_set
from torch.utils.data import DataLoader

batch_size = 8
n_workers = 0

train_set = easy_set.EasySet(specs_file = "./data/current spe/train.json",
                             # image_size = 360 * 360,
                             # transform = None,
                             training = False
                             )
train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    num_workers=n_workers,
    pin_memory=True,
    shuffle=True,
)

Now, we are going to create the model that we want to train. Here we choose the ResNet12 that is very often used in Few-Shot Learning research. Note that the default setting of these networks in EasyFSL is to not have a last fully connected layer (as it is usual for most Few-Shot Learning methods), but for classical training we need this layer! We also force it to output a vector which size is the number of different classes in the training set.

In [33]:
from easyfsl.modules import resnet12

DEVICE = "cuda"

model = resnet12(
).to(DEVICE)


Now, we still need validation ! Since we're training a model to perform few-shot classification, we will validate on few-shot tasks, so now we'll use the `TaskSampler`. We arbitrarily set the shape of the validation tasks. Ideally, you'd like to perform validation on various shapes of tasks, but we didn't implement this yet (feel free to contribute!).

We also need to define the few-shot classification method that we will use during validation of the neural network we're training.
Here we choose Prototypical Networks, because it's simple and efficient, but this is still an arbitrary choice.

In [34]:
from easyfsl.methods import Finetune
from easyfsl.samplers import TaskSampler

n_way = 2
n_shot = 20
n_query = 1
n_validation_tasks = 10

val_set = easy_set.EasySet(specs_file = "./data/current spe/val.json",
                           # image_size = 360 * 360,
                           # transform = None,
                           training = False
                           )
val_sampler = TaskSampler(
    val_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_validation_tasks
)
val_loader = DataLoader(
    val_set,
    batch_sampler=val_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=val_sampler.episodic_collate_fn,
)

few_shot_classifier = Finetune(model).to(DEVICE)

## Training

Now let's define our training helpers ! I chose to use Stochastic Gradient Descent on 200 epochs with a scheduler that divides the learning rate by 10 after 120 and 160 epochs. The strategy is derived from [this repo](https://github.com/fiveai/on-episodes-fsl).

We're also gonna use a TensorBoard because it's always good to see what your training curves look like.

An other thing: we're doing 200 epochs like in [the episodic training notebook](notebooks/episodic_training.ipynb), but keep in mind that an epoch in classical training means one pass through the 6000 images of the dataset, while in episodic training it's an arbitrary number of episodes. In the episodic training notebook an epoch is 500 episodes of 5-way, 5-shot, 10-query tasks, so 37500 images. TL;DR you may want to monitor your training and increase the number of epochs if necessary.

In [35]:
from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.tensorboard import SummaryWriter


LOSS_FUNCTION = nn.CrossEntropyLoss()

n_epochs = 50
scheduler_milestones = [40, 48]
scheduler_gamma = 0.1
learning_rate = 0.01
tb_logs_dir = Path("./events/")

train_optimizer = SGD(
    model.parameters(), lr=learning_rate,
    momentum=0.9, weight_decay=5e-4
)
train_scheduler = MultiStepLR(
    train_optimizer,
    milestones=scheduler_milestones,
    gamma=scheduler_gamma,
)

tb_writer = SummaryWriter(log_dir=str(tb_logs_dir))

And now let's get to it! Here we define the function that performs a training epoch.

We use tqdm to monitor the training in real time in our logs.

In [36]:
def training_epoch(model_: nn.Module, data_loader: DataLoader, optimizer: Optimizer):
    all_loss = []
    model_.train()
    with tqdm(data_loader, total=len(data_loader), desc="Training") as tqdm_train:
        for images, labels in tqdm_train:
            optimizer.zero_grad()

            loss = LOSS_FUNCTION(model_(images.to(DEVICE)), labels.to(DEVICE))
            loss.backward()
            optimizer.step()

            all_loss.append(loss.item())

            tqdm_train.set_postfix(loss=mean(all_loss))

    return mean(all_loss)

And we have everything we need! This is now the time to **start training**.

A few notes:

- We only validate every 10 epochs (you may set an even less frequent validation) because a training epoch is much faster than 500 few-shot tasks, and we don't want validation to be the bottleneck of our training process.

- I also added something to log the state of the model that gave the best performance on the validation set.

In [37]:
from easyfsl.methods.utils import evaluate, evaluate2


best_state = model.state_dict()
best_validation_accuracy = 0.0
validation_frequency = 5
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")
    average_loss = training_epoch(model, train_loader, train_optimizer)

    if epoch % validation_frequency == validation_frequency - 1:

        # We use this very convenient method from EasyFSL's ResNet to specify
        # that the model shouldn't use its last fully connected layer during validation.
        model.set_use_fc(False)
        validation_accuracy = evaluate(
            few_shot_classifier, val_loader, device=DEVICE, tqdm_prefix="Validation"
        )
        model.set_use_fc(True)

        if validation_accuracy > best_validation_accuracy:
            best_validation_accuracy = validation_accuracy
            best_state = model.state_dict()
            print("Ding ding ding! We found a new best model!")

        tb_writer.add_scalar("Val/acc", validation_accuracy, epoch)

    tb_writer.add_scalar("Train/loss", average_loss, epoch)

# plt.plot(n_epochs, average_loss, 'b', label='Training Loss')
# Warn the scheduler that we did an epoch
# so it knows when to decrease the learning rate
    train_scheduler.step()



Epoch 0


Training: 100%|██████████| 6/6 [00:02<00:00,  2.99it/s, loss=6.21]


Epoch 1


Training: 100%|██████████| 6/6 [00:01<00:00,  3.78it/s, loss=5.9] 


Epoch 2


Training: 100%|██████████| 6/6 [00:01<00:00,  3.76it/s, loss=5.63]


Epoch 3


Training: 100%|██████████| 6/6 [00:01<00:00,  3.77it/s, loss=5.35]


Epoch 4


Training: 100%|██████████| 6/6 [00:01<00:00,  3.73it/s, loss=5.03]
Validation: 100%|██████████| 10/10 [00:06<00:00,  1.64it/s, accuracy=0.75]


Ding ding ding! We found a new best model!
Epoch 5


Training: 100%|██████████| 6/6 [00:01<00:00,  3.76it/s, loss=0.768]


Epoch 6


Training: 100%|██████████| 6/6 [00:01<00:00,  3.77it/s, loss=0.726]


Epoch 7


Training: 100%|██████████| 6/6 [00:01<00:00,  3.73it/s, loss=0.636]


Epoch 8


Training: 100%|██████████| 6/6 [00:01<00:00,  3.74it/s, loss=0.665]


Epoch 9


Training: 100%|██████████| 6/6 [00:01<00:00,  3.71it/s, loss=0.549]
Validation: 100%|██████████| 10/10 [00:06<00:00,  1.59it/s, accuracy=0.7] 


Epoch 10


Training: 100%|██████████| 6/6 [00:01<00:00,  3.74it/s, loss=0.501]


Epoch 11


Training: 100%|██████████| 6/6 [00:01<00:00,  3.72it/s, loss=0.61] 


Epoch 12


Training: 100%|██████████| 6/6 [00:01<00:00,  3.78it/s, loss=0.51] 


Epoch 13


Training: 100%|██████████| 6/6 [00:01<00:00,  3.73it/s, loss=0.513]


Epoch 14


Training: 100%|██████████| 6/6 [00:01<00:00,  3.72it/s, loss=0.472]
Validation: 100%|██████████| 10/10 [00:06<00:00,  1.61it/s, accuracy=0.85]


Ding ding ding! We found a new best model!
Epoch 15


Training: 100%|██████████| 6/6 [00:01<00:00,  3.75it/s, loss=0.374]


Epoch 16


Training: 100%|██████████| 6/6 [00:01<00:00,  3.78it/s, loss=0.301]


Epoch 17


Training: 100%|██████████| 6/6 [00:01<00:00,  3.76it/s, loss=0.424]


Epoch 18


Training: 100%|██████████| 6/6 [00:01<00:00,  3.80it/s, loss=0.725]


Epoch 19


Training: 100%|██████████| 6/6 [00:01<00:00,  3.79it/s, loss=0.911]
Validation: 100%|██████████| 10/10 [00:06<00:00,  1.65it/s, accuracy=0.7] 


Epoch 20


Training: 100%|██████████| 6/6 [00:01<00:00,  3.63it/s, loss=0.549]


Epoch 21


Training: 100%|██████████| 6/6 [00:01<00:00,  3.50it/s, loss=0.281]


Epoch 22


Training: 100%|██████████| 6/6 [00:01<00:00,  3.77it/s, loss=0.198]


Epoch 23


Training: 100%|██████████| 6/6 [00:01<00:00,  3.79it/s, loss=0.228]


Epoch 24


Training: 100%|██████████| 6/6 [00:01<00:00,  3.75it/s, loss=0.115]
Validation: 100%|██████████| 10/10 [00:06<00:00,  1.65it/s, accuracy=1]


Ding ding ding! We found a new best model!
Epoch 25


Training: 100%|██████████| 6/6 [00:01<00:00,  3.80it/s, loss=0.307]


Epoch 26


Training: 100%|██████████| 6/6 [00:01<00:00,  3.80it/s, loss=0.0673]


Epoch 27


Training: 100%|██████████| 6/6 [00:01<00:00,  3.80it/s, loss=0.13]  


Epoch 28


Training: 100%|██████████| 6/6 [00:01<00:00,  3.82it/s, loss=0.0566]


Epoch 29


Training: 100%|██████████| 6/6 [00:01<00:00,  3.77it/s, loss=0.0647]
Validation: 100%|██████████| 10/10 [00:06<00:00,  1.64it/s, accuracy=1]


Epoch 30


Training: 100%|██████████| 6/6 [00:01<00:00,  3.80it/s, loss=0.0496]


Epoch 31


Training: 100%|██████████| 6/6 [00:01<00:00,  3.78it/s, loss=0.0431] 


Epoch 32


Training: 100%|██████████| 6/6 [00:01<00:00,  3.80it/s, loss=0.144] 


Epoch 33


Training: 100%|██████████| 6/6 [00:01<00:00,  3.78it/s, loss=0.45] 


Epoch 34


Training: 100%|██████████| 6/6 [00:01<00:00,  3.77it/s, loss=0.621]
Validation: 100%|██████████| 10/10 [00:06<00:00,  1.58it/s, accuracy=1]


Epoch 35


Training: 100%|██████████| 6/6 [00:01<00:00,  3.78it/s, loss=0.477]


Epoch 36


Training: 100%|██████████| 6/6 [00:01<00:00,  3.80it/s, loss=0.528] 


Epoch 37


Training: 100%|██████████| 6/6 [00:01<00:00,  3.79it/s, loss=0.372]


Epoch 38


Training: 100%|██████████| 6/6 [00:01<00:00,  3.78it/s, loss=0.126] 


Epoch 39


Training: 100%|██████████| 6/6 [00:01<00:00,  3.80it/s, loss=0.382]
Validation: 100%|██████████| 10/10 [00:06<00:00,  1.60it/s, accuracy=1]


Epoch 40


Training: 100%|██████████| 6/6 [00:01<00:00,  3.68it/s, loss=0.271] 


Epoch 41


Training: 100%|██████████| 6/6 [00:01<00:00,  3.71it/s, loss=0.0671]


Epoch 42


Training: 100%|██████████| 6/6 [00:01<00:00,  3.72it/s, loss=0.0397]


Epoch 43


Training: 100%|██████████| 6/6 [00:01<00:00,  3.62it/s, loss=0.0393]


Epoch 44


Training: 100%|██████████| 6/6 [00:01<00:00,  3.73it/s, loss=0.107] 
Validation: 100%|██████████| 10/10 [00:06<00:00,  1.60it/s, accuracy=1]


Epoch 45


Training: 100%|██████████| 6/6 [00:01<00:00,  3.74it/s, loss=0.0762] 


Epoch 46


Training: 100%|██████████| 6/6 [00:01<00:00,  3.74it/s, loss=0.0147]


Epoch 47


Training: 100%|██████████| 6/6 [00:01<00:00,  3.77it/s, loss=0.0232]


Epoch 48


Training: 100%|██████████| 6/6 [00:01<00:00,  3.73it/s, loss=0.0548]


Epoch 49


Training: 100%|██████████| 6/6 [00:01<00:00,  3.78it/s, loss=0.0236] 
Validation: 100%|██████████| 10/10 [00:06<00:00,  1.66it/s, accuracy=1]


Yay we successfully performed Classical Training! Now if you want to you can retrieve the best model's state.

In [38]:
model.load_state_dict(best_state)

<All keys matched successfully>

## Evaluation

Now that our model is trained, we want to test it.

First step: we fetch the test data. Note that we'll evaluate on the same shape of tasks as in validation. This is malicious practice, because it means that we used *a priori* information about the evaluation tasks during training. This is still less malicious than episodic training, though.

In [39]:
n_test_tasks = 15

test_set = easy_set.EasySet(specs_file = "./data/current spe/test.json",
                            # image_size = 360 * 360,
                            # transform = None,
                            training = False
                            )
test_sampler = TaskSampler(
    test_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_test_tasks
)
test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)

Second step: we instantiate a few-shot classifier using our trained ResNet as backbone, and run it on the test data. We keep using Prototypical Networks for consistence, but at this point you could basically use any few-shot classifier that takes no additional trainable parameters.

Like we did during validation, we need to tell our ResNet to not use its last fully connected layer.

In [40]:
import seaborn as sns
model.set_use_fc(False)
TP, total = evaluate2(few_shot_classifier, test_loader, device=DEVICE)


# mythreshold = 0.4
# y_pred = (model.predict(few_shot_classifier) >= mythreshold).astype(int)



100%|██████████| 15/15 [00:09<00:00,  1.63it/s, accuracy=0.633]
