In [70]:
import torch

import matplotlib.pyplot as plt

from torchvision.datasets import Omniglot
from torchvision.transforms import ToTensor, Resize, Compose

from torch.utils.data import Dataset, DataLoader
import numpy as np

In [71]:
# transforms = Compose([Resize(28), ToTensor()])
# train_data = Omniglot('./datasets/omniglot', background=True, download=True, transform=transforms)

In [72]:
# transform data outside this class
class OmniglotDataset(Dataset):
    '''
    Omniglot regrouped by character: one item is every image of one character.

    The wrapped torchvision ``Omniglot`` dataset is a flat list of images.
    This class slices it into consecutive runs of ``examples_per_char``
    images and treats each run as one class, which assumes the wrapped
    dataset stores the images of a character contiguously.
    '''

    def __init__(self, background: bool, device):
        '''
        background: True = use background set, otherwise evaluation set
        device: device that every returned image tensor is moved to
        '''
        self.device = device
        self.examples_per_char = 20
        self.ds = Omniglot(
            'datasets/omniglot',
            background=background,
            download=True,
            transform=Compose([Resize(28), ToTensor()])
        )

    def __len__(self):
        '''Return the number of complete character groups in the wrapped dataset.'''
        # Floor division stays in integers; a trailing partial group is not counted.
        return len(self.ds) // self.examples_per_char

    def __getitem__(self, i):
        '''
        Return ``(images, label)`` for character ``i``.

        images: tensor of shape (examples_per_char, channel, height, width),
            i.e. all images of the character stacked along a new first axis,
            moved to ``self.device``.
        label: the non-negative character index. Since all images in the item
            are the same character, the label is simply the item's index.

        Negative indices count from the end, as for a list.
        Raises IndexError when ``i`` is outside ``[-len(self), len(self))``.
        '''
        n = len(self)
        if not -n <= i < n:
            raise IndexError(
                f"character index {i} out of range for dataset with {n} characters"
            )
        if i < 0:
            # Normalise so the returned label is the real class index, not e.g. -1.
            i += n

        a = i * self.examples_per_char
        b = a + self.examples_per_char
        x = torch.stack([self.ds[j][0] for j in range(a, b)])
        # NOTE: the tensor is moved to the device here, inside the dataset. With a
        # CUDA device this is generally incompatible with DataLoader worker
        # processes (num_workers > 0); kept because callers rely on it.
        return x.to(self.device), i


In [73]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [74]:
# Background split, wrapped so that each item is one character; returned
# tensors are placed on `device`.
train_data = OmniglotDataset(
    background=True,
    device=device,
)

Files already downloaded and verified


In [82]:
# Show the number of items, then item 963 (images tensor and label) on the next line.
print(len(train_data), train_data[963], sep="\n")

964
(tensor([[[[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]]],


        [[[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]]],


        [[[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]]],


        ...,


        [[[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..

In [75]:
# Batch 128 dataset items together; no shuffling or worker processes are requested.
train_dataloader = DataLoader(
    dataset=train_data,
    batch_size=128,
)

In [76]:
# Report shape and dtype of the first batch only, then stop iterating.
for X, y in train_dataloader:
    for tag, batch in (("Shape of X: ", X), ("Shape of y: ", y)):
        print(tag, batch.shape, batch.dtype)
    break


Shape of X:  torch.Size([128, 20, 1, 28, 28]) torch.float32
Shape of y:  torch.Size([128]) torch.int64


In [35]:
# NOTE(review): image1, image2, label and image are not defined in any cell
# visible here, and this cell's execution count (35) predates the cells above.
# Presumably they came from an earlier, since-removed cell — confirm and
# restore their definitions before running on a fresh kernel.
print(image1)
print(image2)
print(label)
print(image.shape)

tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]]])
tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]]])
0
torch.Size([1, 105, 105])


In [34]:
# Check whether the two images are element-wise identical.
# NOTE(review): image1 and image2 are not defined in any visible cell —
# confirm where they come from before re-running.
print(np.array_equal(image1, image2))

False


In [37]:
# Number of characters = raw image count / images per character.
# `train_data` is already grouped per character, so the raw image count has to
# come from the wrapped dataset; len(train_data) / 20 would divide twice.
print(len(train_data.ds) / train_data.examples_per_char)

964.0


In [38]:
# Integers 0..19; start=0 and step=1 are torch.arange's defaults.
torch.arange(20)

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])

In [56]:
# Sanity check of the raw dataset layout: read the first 50 raw images and
# print their labels to see how images of one character are laid out.
# The raw per-image dataset is `train_data.ds`; indexing `train_data` directly
# would return whole character groups instead of single images.
i = 0
a = i * train_data.examples_per_char
b = a + 50  # deliberately spans more than two characters
index = list(range(a, b))
x = torch.stack([train_data.ds[j][0] for j in index])
y = [train_data.ds[j][1] for j in index]
print(y)
print(len(y))


[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
50
