# Train a model with Classical Training

Although episodic training has attracted a lot of interest in the early years of Few-Shot Learning research, more recent works suggest that competitive results can be achieved with a simple cross entropy loss across all training classes. Therefore, it is becoming more and more common to use this classical process to train the backbone, that will be common to all methods compared at test time.

This is in fact more representative of real use cases: episodic training assumes that, at training time, you have access to the shape of the few-shot tasks that will be encountered at test time (indeed you choose a specific number of ways for episodic training). You also "force" your inference method into the training of the network. Switching the few-shot learning logic to inference (i.e. no episodic training) allows methods to be agnostic of the backbone.

Nonetheless, if you need to perform episodic training, we also provide [an example notebook](episodic_training.ipynb) for that.

## Getting started
First we're going to do some imports (this is not the interesting part).

In [1]:
from pathlib import Path
import random
from statistics import mean

import numpy as np
import torch
from torch import nn
from tqdm import tqdm
from torchvision import datasets, transforms

Then we're gonna do the most important thing in Machine Learning research: ensuring reproducibility by setting the random seed. We're going to set the seed for all random packages that we could possibly use, plus some other stuff to make CUDA deterministic (see [here](https://pytorch.org/docs/stable/notes/randomness.html)).

I strongly encourage that you do this in **all your scripts**.

In [2]:
random_seed = 0
np.random.seed(random_seed)
torch.manual_seed(random_seed)
random.seed(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

Then we're gonna create our data loader for the training set. You can see that I chose tu use CUB in this notebook, because it’s a small dataset, so we can have good results quite quickly. I set a batch size of 128 but feel free to adapt it to your constraints.

Note that we're not using the `TaskSampler` for the train data loader, because we won't be sampling training data in the shape of tasks as we would have in episodic training. We do it **normally**.

In [3]:
from torch.utils.data import DataLoader
from wrap_few_shot_dataset import WrapFewShotDataset

batch_size = 128
n_workers = 8

# Setup path to data folder
data_path = Path("data")
image_path = data_path / "UCMerced-Fewshot"

# Check if image folder exists
if image_path.is_dir():
    print(f"{image_path} directory exists.")
else:
    print(f"Did not find {image_path} directory")
    exit()

# Setup train and testing paths
train_dir = image_path / "Train"
val_dir = image_path / "Val"

transform = transforms.Compose([
    transforms.RandomResizedCrop(128),
    transforms.ToTensor()
])

train_set = datasets.ImageFolder(
    root=train_dir,
    transform=transform
)

val_set = datasets.ImageFolder(
    root=val_dir,
    transform=transform
)

train_set = WrapFewShotDataset(train_set)
val_set = WrapFewShotDataset(val_set)

train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    num_workers=n_workers,
    pin_memory=True,
    shuffle=True,
)

data/UCMerced-Fewshot directory exists.


Scrolling dataset's labels...: 100%|██████████| 1000/1000 [00:00<00:00, 1011.91it/s]
Scrolling dataset's labels...: 100%|██████████| 500/500 [00:00<00:00, 1084.30it/s]


Now, we are going to create the model that we want to train. Here we choose the ResNet12 that is very often used in Few-Shot Learning research. Note that the default setting of these networks in EasyFSL is to not have a last fully connected layer (as it is usual for most Few-Shot Learning methods), but for classical training we need this layer! We also force it to output a vector which size is the number of different classes in the training set.

In [4]:
from modules.predefined_resnet import resnet12

DEVICE = "cpu"

model = resnet12(
    use_fc=True,
    num_classes=len(set(train_set.get_labels())),
).to(DEVICE)

Now, we still need validation ! Since we're training a model to perform few-shot classification, we will validate on few-shot tasks, so now we'll use the `TaskSampler`. We arbitrarily set the shape of the validation tasks. Ideally, you'd like to perform validation on various shapes of tasks, but we didn't implement this yet (feel free to contribute!).

We also need to define the few-shot classification method that we will use during validation of the neural network we're training.
Here we choose Prototypical Networks, because it's simple and efficient, but this is still an arbitrary choice.

In [5]:
from fewshot_sampler import FewShotSampler

n_way = 5
n_shot = 5
n_query = 10
n_validation_tasks = 50

val_sampler = FewShotSampler(
    val_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_validation_tasks
)
val_loader = DataLoader(
    val_set,
    batch_sampler=val_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=val_sampler.episodic_collate_fn,
)

# Pick one of the following models:

In [None]:
from modules.prototypical import PrototypicalNetworks

few_shot_classifier = PrototypicalNetworks(model).to(DEVICE)

In [None]:
from modules.simple_shot import SimpleShot

few_shot_classifier = SimpleShot(model).to(DEVICE)

In [None]:


few_shot_classifier = 

## Training

Now let's define our training helpers ! I chose to use Stochastic Gradient Descent on 200 epochs with a scheduler that divides the learning rate by 10 after 120 and 160 epochs. The strategy is derived from [this repo](https://github.com/fiveai/on-episodes-fsl).

We're also gonna use a TensorBoard because it's always good to see what your training curves look like.

An other thing: we're doing 200 epochs like in [the episodic training notebook](notebooks/episodic_training.ipynb), but keep in mind that an epoch in classical training means one pass through the 6000 images of the dataset, while in episodic training it's an arbitrary number of episodes. In the episodic training notebook an epoch is 500 episodes of 5-way, 5-shot, 10-query tasks, so 37500 images. TL;DR you may want to monitor your training and increase the number of epochs if necessary.

In [6]:
from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.tensorboard import SummaryWriter


LOSS_FUNCTION = nn.CrossEntropyLoss()

n_epochs = 100
scheduler_milestones = [150, 180]
scheduler_gamma = 0.1
learning_rate = 1e-01
tb_logs_dir = Path(".")

train_optimizer = SGD(
    model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4
)
train_scheduler = MultiStepLR(
    train_optimizer,
    milestones=scheduler_milestones,
    gamma=scheduler_gamma,
)

tb_writer = SummaryWriter(log_dir=str(tb_logs_dir))

And now let's get to it! Here we define the function that performs a training epoch.

We use tqdm to monitor the training in real time in our logs.

In [7]:
def training_epoch(model_: nn.Module, data_loader: DataLoader, optimizer: Optimizer):
    all_loss = []
    model_.train()
    with tqdm(data_loader, total=len(data_loader), desc="Training") as tqdm_train:
        for images, labels in tqdm_train:
            optimizer.zero_grad()

            loss = LOSS_FUNCTION(model_(images.to(DEVICE)), labels.to(DEVICE))
            loss.backward()
            optimizer.step()

            all_loss.append(loss.item())

            tqdm_train.set_postfix(loss=mean(all_loss))

    return mean(all_loss)

And we have everything we need! This is now the time to **start training**.

A few notes:

- We only validate every 10 epochs (you may set an even less frequent validation) because a training epoch is much faster than 500 few-shot tasks, and we don't want validation to be the bottleneck of our training process.

- I also added something to log the state of the model that gave the best performance on the validation set.

In [8]:
from fewshot_utils import evaluate


best_state = model.state_dict()
best_validation_accuracy = 0.0
validation_frequency = 10
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")
    average_loss = training_epoch(model, train_loader, train_optimizer)

    if epoch % validation_frequency == validation_frequency - 1:

        # We use this very convenient method from EasyFSL's ResNet to specify
        # that the model shouldn't use its last fully connected layer during validation.
        model.set_use_fc(False)
        validation_accuracy = evaluate(
            few_shot_classifier, val_loader, device=DEVICE, tqdm_prefix="Validation"
        )
        model.set_use_fc(True)

        if validation_accuracy > best_validation_accuracy:
            best_validation_accuracy = validation_accuracy
            best_state = model.state_dict()
            print("Ding ding ding! We found a new best model!")

        tb_writer.add_scalar("Val/acc", validation_accuracy, epoch)

    tb_writer.add_scalar("Train/loss", average_loss, epoch)

    # Warn the scheduler that we did an epoch
    # so it knows when to decrease the learning rate
    train_scheduler.step()

Epoch 0


Training: 100%|██████████| 8/8 [03:57<00:00, 29.68s/it, loss=2.32]


Epoch 1


Training: 100%|██████████| 8/8 [03:57<00:00, 29.70s/it, loss=1.85]


Epoch 2


Training: 100%|██████████| 8/8 [03:59<00:00, 29.92s/it, loss=1.3] 


Epoch 3


Training: 100%|██████████| 8/8 [03:56<00:00, 29.55s/it, loss=1.26]


Epoch 4


Training: 100%|██████████| 8/8 [04:04<00:00, 30.61s/it, loss=1.28]


Epoch 5


Training: 100%|██████████| 8/8 [04:03<00:00, 30.49s/it, loss=1.06]


Epoch 6


Training: 100%|██████████| 8/8 [04:08<00:00, 31.12s/it, loss=0.942]


Epoch 7


Training: 100%|██████████| 8/8 [04:02<00:00, 30.31s/it, loss=0.866]


Epoch 8


Training: 100%|██████████| 8/8 [04:02<00:00, 30.26s/it, loss=0.808]


Epoch 9


Training: 100%|██████████| 8/8 [04:02<00:00, 30.26s/it, loss=0.832]
Validation: 100%|██████████| 50/50 [04:46<00:00,  5.73s/it, accuracy=0.38] 


Ding ding ding! We found a new best model!
Epoch 10


Training: 100%|██████████| 8/8 [04:09<00:00, 31.23s/it, loss=0.806]


Epoch 11


Training: 100%|██████████| 8/8 [04:01<00:00, 30.24s/it, loss=0.648]


Epoch 12


Training: 100%|██████████| 8/8 [04:01<00:00, 30.22s/it, loss=0.63] 


Epoch 13


Training: 100%|██████████| 8/8 [04:00<00:00, 30.02s/it, loss=0.651]


Epoch 14


Training: 100%|██████████| 8/8 [04:01<00:00, 30.21s/it, loss=0.654]


Epoch 15


Training: 100%|██████████| 8/8 [03:58<00:00, 29.81s/it, loss=0.624]


Epoch 16


Training: 100%|██████████| 8/8 [04:00<00:00, 30.11s/it, loss=0.576]


Epoch 17


Training: 100%|██████████| 8/8 [04:07<00:00, 30.89s/it, loss=0.524]


Epoch 18


Training: 100%|██████████| 8/8 [04:00<00:00, 30.01s/it, loss=0.592]


Epoch 19


Training: 100%|██████████| 8/8 [04:07<00:00, 30.92s/it, loss=0.528]
Validation: 100%|██████████| 50/50 [04:41<00:00,  5.62s/it, accuracy=0.45] 


Ding ding ding! We found a new best model!
Epoch 20


Training: 100%|██████████| 8/8 [03:56<00:00, 29.56s/it, loss=0.48] 


Epoch 21


Training: 100%|██████████| 8/8 [04:10<00:00, 31.34s/it, loss=0.507]


Epoch 22


Training: 100%|██████████| 8/8 [04:01<00:00, 30.15s/it, loss=0.561]


Epoch 23


Training: 100%|██████████| 8/8 [04:03<00:00, 30.40s/it, loss=0.485]


Epoch 24


Training: 100%|██████████| 8/8 [04:02<00:00, 30.36s/it, loss=0.441]


Epoch 25


Training: 100%|██████████| 8/8 [04:00<00:00, 30.03s/it, loss=0.427]


Epoch 26


Training: 100%|██████████| 8/8 [04:00<00:00, 30.08s/it, loss=0.443]


Epoch 27


Training: 100%|██████████| 8/8 [04:08<00:00, 31.07s/it, loss=0.472]


Epoch 28


Training: 100%|██████████| 8/8 [04:05<00:00, 30.65s/it, loss=0.427]


Epoch 29


Training: 100%|██████████| 8/8 [04:03<00:00, 30.46s/it, loss=0.42] 
Validation: 100%|██████████| 50/50 [04:46<00:00,  5.72s/it, accuracy=0.45] 


Epoch 30


Training: 100%|██████████| 8/8 [04:00<00:00, 30.08s/it, loss=0.349]


Epoch 31


Training: 100%|██████████| 8/8 [04:03<00:00, 30.48s/it, loss=0.362]


Epoch 32


Training: 100%|██████████| 8/8 [04:05<00:00, 30.64s/it, loss=0.425]


Epoch 33


Training: 100%|██████████| 8/8 [04:11<00:00, 31.47s/it, loss=0.345]


Epoch 34


Training: 100%|██████████| 8/8 [04:04<00:00, 30.57s/it, loss=0.393]


Epoch 35


Training: 100%|██████████| 8/8 [04:08<00:00, 31.04s/it, loss=0.342]


Epoch 36


Training: 100%|██████████| 8/8 [04:16<00:00, 32.10s/it, loss=0.357]


Epoch 37


Training: 100%|██████████| 8/8 [04:14<00:00, 31.81s/it, loss=0.317]


Epoch 38


Training: 100%|██████████| 8/8 [04:04<00:00, 30.52s/it, loss=0.313]


Epoch 39


Training: 100%|██████████| 8/8 [04:01<00:00, 30.22s/it, loss=0.342]
Validation: 100%|██████████| 50/50 [04:46<00:00,  5.73s/it, accuracy=0.509]


Ding ding ding! We found a new best model!
Epoch 40


Training: 100%|██████████| 8/8 [04:01<00:00, 30.18s/it, loss=0.286]


Epoch 41


Training: 100%|██████████| 8/8 [04:01<00:00, 30.23s/it, loss=0.334]


Epoch 42


Training: 100%|██████████| 8/8 [04:00<00:00, 30.09s/it, loss=0.269]


Epoch 43


Training: 100%|██████████| 8/8 [04:10<00:00, 31.25s/it, loss=0.313]


Epoch 44


Training: 100%|██████████| 8/8 [04:07<00:00, 30.97s/it, loss=0.302]


Epoch 45


Training: 100%|██████████| 8/8 [03:58<00:00, 29.79s/it, loss=0.269]


Epoch 46


Training: 100%|██████████| 8/8 [04:02<00:00, 30.33s/it, loss=0.283]


Epoch 47


Training: 100%|██████████| 8/8 [03:59<00:00, 29.96s/it, loss=0.204]


Epoch 48


Training: 100%|██████████| 8/8 [03:59<00:00, 29.88s/it, loss=0.26] 


Epoch 49


Training: 100%|██████████| 8/8 [03:59<00:00, 30.00s/it, loss=0.251]
Validation: 100%|██████████| 50/50 [04:45<00:00,  5.71s/it, accuracy=0.542]


Ding ding ding! We found a new best model!
Epoch 50


Training: 100%|██████████| 8/8 [04:06<00:00, 30.77s/it, loss=0.231]


Epoch 51


Training: 100%|██████████| 8/8 [04:02<00:00, 30.33s/it, loss=0.289]


Epoch 52


Training: 100%|██████████| 8/8 [04:03<00:00, 30.48s/it, loss=0.26] 


Epoch 53


Training: 100%|██████████| 8/8 [04:02<00:00, 30.35s/it, loss=0.253]


Epoch 54


Training: 100%|██████████| 8/8 [04:04<00:00, 30.53s/it, loss=0.261]


Epoch 55


Training: 100%|██████████| 8/8 [04:02<00:00, 30.31s/it, loss=0.259]


Epoch 56


Training: 100%|██████████| 8/8 [03:58<00:00, 29.84s/it, loss=0.201]


Epoch 57


Training: 100%|██████████| 8/8 [04:00<00:00, 30.00s/it, loss=0.208]


Epoch 58


Training: 100%|██████████| 8/8 [04:00<00:00, 30.03s/it, loss=0.209]


Epoch 59


Training: 100%|██████████| 8/8 [04:01<00:00, 30.21s/it, loss=0.21] 
Validation: 100%|██████████| 50/50 [04:43<00:00,  5.67s/it, accuracy=0.528]


Epoch 60


Training: 100%|██████████| 8/8 [04:04<00:00, 30.51s/it, loss=0.231]


Epoch 61


Training: 100%|██████████| 8/8 [04:01<00:00, 30.20s/it, loss=0.224]


Epoch 62


Training: 100%|██████████| 8/8 [04:01<00:00, 30.22s/it, loss=0.239]


Epoch 63


Training: 100%|██████████| 8/8 [04:03<00:00, 30.42s/it, loss=0.201]


Epoch 64


Training: 100%|██████████| 8/8 [04:03<00:00, 30.46s/it, loss=0.203]


Epoch 65


Training: 100%|██████████| 8/8 [04:03<00:00, 30.42s/it, loss=0.15]  


Epoch 66


Training: 100%|██████████| 8/8 [04:04<00:00, 30.54s/it, loss=0.196]


Epoch 67


Training: 100%|██████████| 8/8 [04:06<00:00, 30.80s/it, loss=0.178]


Epoch 68


Training: 100%|██████████| 8/8 [04:02<00:00, 30.27s/it, loss=0.152]


Epoch 69


Training: 100%|██████████| 8/8 [04:06<00:00, 30.87s/it, loss=0.182] 
Validation: 100%|██████████| 50/50 [04:39<00:00,  5.59s/it, accuracy=0.455]


Epoch 70


Training: 100%|██████████| 8/8 [03:57<00:00, 29.72s/it, loss=0.158]


Epoch 71


Training: 100%|██████████| 8/8 [04:07<00:00, 30.88s/it, loss=0.15] 


Epoch 72


Training: 100%|██████████| 8/8 [04:07<00:00, 30.89s/it, loss=0.197]


Epoch 73


Training: 100%|██████████| 8/8 [04:01<00:00, 30.21s/it, loss=0.238]


Epoch 74


Training: 100%|██████████| 8/8 [04:01<00:00, 30.16s/it, loss=0.215]


Epoch 75


Training: 100%|██████████| 8/8 [04:00<00:00, 30.01s/it, loss=0.18] 


Epoch 76


Training: 100%|██████████| 8/8 [03:57<00:00, 29.66s/it, loss=0.186]


Epoch 77


Training: 100%|██████████| 8/8 [03:59<00:00, 29.90s/it, loss=0.166]


Epoch 78


Training: 100%|██████████| 8/8 [03:58<00:00, 29.77s/it, loss=0.197]


Epoch 79


Training: 100%|██████████| 8/8 [03:59<00:00, 29.91s/it, loss=0.204]
Validation: 100%|██████████| 50/50 [04:44<00:00,  5.70s/it, accuracy=0.613]


Ding ding ding! We found a new best model!
Epoch 80


Training: 100%|██████████| 8/8 [04:06<00:00, 30.78s/it, loss=0.128] 


Epoch 81


Training: 100%|██████████| 8/8 [04:01<00:00, 30.15s/it, loss=0.129]


Epoch 82


Training: 100%|██████████| 8/8 [04:02<00:00, 30.31s/it, loss=0.171]


Epoch 83


Training: 100%|██████████| 8/8 [04:03<00:00, 30.47s/it, loss=0.188]


Epoch 84


Training: 100%|██████████| 8/8 [04:05<00:00, 30.66s/it, loss=0.184]


Epoch 85


Training: 100%|██████████| 8/8 [04:00<00:00, 30.08s/it, loss=0.15] 


Epoch 86


Training: 100%|██████████| 8/8 [04:04<00:00, 30.51s/it, loss=0.158]


Epoch 87


Training: 100%|██████████| 8/8 [04:04<00:00, 30.57s/it, loss=0.181]


Epoch 88


Training: 100%|██████████| 8/8 [04:00<00:00, 30.11s/it, loss=0.139] 


Epoch 89


Training: 100%|██████████| 8/8 [04:03<00:00, 30.44s/it, loss=0.155]
Validation: 100%|██████████| 50/50 [04:49<00:00,  5.78s/it, accuracy=0.627]


Ding ding ding! We found a new best model!
Epoch 90


Training: 100%|██████████| 8/8 [03:57<00:00, 29.70s/it, loss=0.168]


Epoch 91


Training: 100%|██████████| 8/8 [03:59<00:00, 29.97s/it, loss=0.135]


Epoch 92


Training: 100%|██████████| 8/8 [04:05<00:00, 30.73s/it, loss=0.143] 


Epoch 93


Training: 100%|██████████| 8/8 [03:58<00:00, 29.82s/it, loss=0.144]


Epoch 94


Training: 100%|██████████| 8/8 [04:03<00:00, 30.43s/it, loss=0.156]


Epoch 95


Training: 100%|██████████| 8/8 [04:01<00:00, 30.22s/it, loss=0.163] 


Epoch 96


Training: 100%|██████████| 8/8 [04:02<00:00, 30.32s/it, loss=0.15]  


Epoch 97


Training: 100%|██████████| 8/8 [04:03<00:00, 30.42s/it, loss=0.168]


Epoch 98


Training: 100%|██████████| 8/8 [04:03<00:00, 30.47s/it, loss=0.139] 


Epoch 99


Training: 100%|██████████| 8/8 [04:00<00:00, 30.00s/it, loss=0.128]
Validation: 100%|██████████| 50/50 [04:47<00:00,  5.75s/it, accuracy=0.499]


Yay we successfully performed Classical Training! Now if you want to you can retrieve the best model's state.

In [9]:
model.load_state_dict(best_state)

<All keys matched successfully>

In [23]:
torch.save(few_shot_classifier.state_dict(),"models/fewshot_merced_proto_scratch.pth")

## Evaluation

Now that our model is trained, we want to test it.

First step: we fetch the test data. Note that we'll evaluate on the same shape of tasks as in validation. This is malicious practice, because it means that we used *a priori* information about the evaluation tasks during training. This is still less malicious than episodic training, though.

In [12]:
n_test_tasks = 1000

test_dir = image_path / "Test"

# transform = transforms.Compose([
#     transforms.RandomResizedCrop(128),
#     transforms.ToTensor()
# ])

test_set = datasets.ImageFolder(
    root=test_dir,
    transform=transform
)

test_set = WrapFewShotDataset(test_set)

test_sampler = FewShotSampler(
    test_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_test_tasks
)
test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)

Scrolling dataset's labels...: 100%|██████████| 900/900 [00:01<00:00, 694.93it/s]


Second step: we instantiate a few-shot classifier using our trained ResNet as backbone, and run it on the test data. We keep using Prototypical Networks for consistence, but at this point you could basically use any few-shot classifier that takes no additional trainable parameters.

Like we did during validation, we need to tell our ResNet to not use its last fully connected layer.

In [13]:
model.set_use_fc(False)

accuracy = evaluate(few_shot_classifier, test_loader, device=DEVICE)
print(f"Average accuracy : {(100 * accuracy):.2f} %")

 57%|█████▋    | 573/1000 [53:19<39:44,  5.58s/it, accuracy=0.486]  


KeyboardInterrupt: 

In [27]:
from sklearn.metrics import confusion_matrix
import pandas as pd
import  numpy as np
import matplotlib.pyplot as plt
import seaborn as sn

def eval_one_task(
    support_images: torch.Tensor,
    support_labels: torch.Tensor,
    query_images: torch.Tensor,
    query_labels: torch.Tensor,
    class_ids: torch.Tensor
):
    """
    Returns the number of correct predictions of query labels, and the total number of predictions.
    """

    few_shot_classifier.process_support_set(support_images, support_labels)

    print("hi")
    print(type(few_shot_classifier))
    predictions = few_shot_classifier(query_images).detach().data
    print(type(predictions),predictions.shape)
    
    pred_labels = []
    class_labels = []

    for i, label in enumerate(query_labels):
        pred_labels.append(class_ids[predictions[i]])
        class_labels.append(class_ids[label])

    return pred_labels, class_labels


def get_data(data_loader: DataLoader):
    # We'll count everything and compute the ratio at the end
    pred = []
    true = []

    device = "cpu"

    # eval mode affects the behaviour of some layers (such as batch normalization or dropout)
    # no_grad() tells torch not to keep in memory the whole computational graph (it's more lightweight this way)
    model.eval()

    with torch.no_grad():
        with tqdm(
            enumerate(data_loader),
            total=len(data_loader),
            desc="getting data",
        ) as tqdm_eval:
            for _, (
                support_images,
                support_labels,
                query_images,
                query_labels,
                _,
            ) in tqdm_eval:
                correct, total = eval_one_task(
                    model,
                    support_images.to(device),
                    support_labels.to(device),
                    query_images.to(device),
                    query_labels.to(device),
                )

        # for (
        #     support_images,
        #     support_labels,
        #     query_images,
        #     query_labels,
        #     class_ids,
        # ) in tqdm(data_loader, total=len(data_loader)):

        #     correct, total = get_labels(
        #         query_images, query_labels, class_ids
        #     )

            pred.extend(correct)
            true.extend(total)

    return pred, true

# generate confusion matrix
classes = ('buildings', 'chaparral', 'denseresidential', 'intersection', 'mediumresidential',
    'mobilehomepark', 'sparseresidential', 'storagetanks', 'tenniscourt')

pred, true = get_data(test_loader)

cf_matrix = confusion_matrix(true,pred)
df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None], index = [i for i in classes],
                    columns = [i for i in classes])

plt.figure(figsize = (12,7))
sn.heatmap(df_cm, annot=True)
plt.savefig('proto.png')

<class 'modules.resnet.ResNet'>


getting data:   0%|          | 0/1000 [00:00<?, ?it/s]
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f1dd787be20>
Traceback (most recent call last):
  File "/home/eileen/miniconda3/envs/tiles/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/home/eileen/miniconda3/envs/tiles/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1442, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/eileen/miniconda3/envs/tiles/lib/python3.10/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/home/eileen/miniconda3/envs/tiles/lib/python3.10/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
  File "/home/eileen/miniconda3/envs/tiles/lib/python3.10/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "/home/eileen/miniconda3/envs/ti