Import libraries

In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import Omniglot
from torchvision.models import resnet18
from tqdm import tqdm
from fewshot_sampler import FewshotSampler

Load training and test sets, then split training set into training and validation sets

In [11]:
# Side length (in pixels) of the square images fed to the model.
image_size = 28

# NB: background=True selects the train set, background=False selects the test set
# It's the nomenclature from the original paper, we just have to deal with it

train_set = Omniglot(
    root="./data",
    background=True,
    transform=transforms.Compose(
        [
            # Omniglot images have 1 channel, but our model will expect 3-channel images
            transforms.Grayscale(num_output_channels=3),
            # Training-time augmentation: random crop rescaled to image_size, random flip
            transforms.RandomResizedCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    ),
    download=True,
)
test_set = Omniglot(
    root="./data",
    background=False,
    transform=transforms.Compose(
        [
            # Omniglot images have 1 channel, but our model will expect 3-channel images
            transforms.Grayscale(num_output_channels=3),
            # Deterministic preprocessing: resize slightly larger, then crop the centre
            transforms.Resize([int(image_size * 1.15), int(image_size * 1.15)]),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
        ]
    ),
    download=True,
)

# Fraction of the background images kept for training; the remainder is validation.
train_set_percent = 0.8
# Fixed seed so the train/validation partition is identical after a kernel restart.
split_seed = 42

# Derive the sizes from train_set_percent so that changing it actually takes effect.
train_set_size = int(train_set_percent * len(train_set))
val_set_size = len(train_set) - train_set_size

# NOTE(review): random_split returns subsets that index into the same underlying
# dataset, so val_set goes through the same random augmentations as train_set.
# Confirm this is intended for validation.
train_set, val_set = torch.utils.data.random_split(
    train_set,
    [train_set_size, val_set_size],
    generator=torch.Generator().manual_seed(split_seed),
)

Files already downloaded and verified
Files already downloaded and verified


Look at image and set sizes

In [14]:
# Pull the first (image, label) pair from the training split as a sanity check.
img, label = train_set[0]

# Report the sample's shape and label, then the number of items in each split.
# Each caption keeps its trailing space, so print() emits two spaces before the value.
for caption, value in (
    ("Image size: ", img.shape),
    ("Label: ", label),
    ("Training set: ", len(train_set)),
    ("Testing set: ", len(test_set)),
    ("Validation set: ", len(val_set)),
):
    print(caption, value)

Image size:  torch.Size([3, 28, 28])
Label:  339
Training set:  15424
Testing set:  13180
Validation set:  3856
