In [2]:
import torch

import matplotlib.pyplot as plt

from torchvision.datasets import Omniglot
from torchvision.transforms import ToTensor, Resize, Compose

from torch.utils.data import Dataset, DataLoader
import numpy as np

In [3]:
# transforms = Compose([Resize(28), ToTensor()])
# train_data = Omniglot('./datasets/omniglot', background=True, download=True, transform=transforms)

In [4]:
# transform data outside this class
class OmniglotDataset(Dataset):

    def __init__(self, background: bool, device):
        '''
        background: True = use background set, otherwise evaluation set
        '''
        self.device = device
        self.examples_per_char = 20
        self.ds = Omniglot(
            'datasets/omniglot',
            background=background,
            download=True,
            transform=Compose([Resize(28), ToTensor()])
        )

    def __len__(self):
        return int(len(self.ds) / self.examples_per_char)

    # each item is all images of a character (a class): there are 20 images per character and each image is (channel, height, width), so each item is (20, channel, height, width). Since all the images are the same character, the label is an integer, namely the index associated with this item.
    def __getitem__(self, i):
        a = i * self.examples_per_char
        b = a + self.examples_per_char
        index = torch.arange(a, b, 1).tolist()
        x = torch.cat([self.ds[j][0].unsqueeze(0) for j in index])
        return x.to(self.device), i


In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [6]:
train_data = OmniglotDataset(background=True, device=device)
test_data = OmniglotDataset(background=False, device=device)

Files already downloaded and verified
Downloading https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_evaluation.zip to datasets/omniglot/omniglot-py/images_evaluation.zip


6463488it [00:00, 18708771.83it/s]                             


Extracting datasets/omniglot/omniglot-py/images_evaluation.zip to datasets/omniglot/omniglot-py


In [7]:
print(len(train_data))
print(len(test_data))

964
659


In [75]:
train_dataloader = DataLoader(train_data, batch_size=128)

In [76]:
for X, y in train_dataloader:
  print("Shape of X: ", X.shape, X.dtype)
  print("Shape of y: ", y.shape, y.dtype)
  break


Shape of X:  torch.Size([128, 20, 1, 28, 28]) torch.float32
Shape of y:  torch.Size([128]) torch.int64


In [35]:
print(image1)
print(image2)
print(label)
print(image.shape)

tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]]])
tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]]])
0
torch.Size([1, 105, 105])


In [34]:
print(np.array_equal(image1, image2))

False


In [37]:
print(len(train_data)/20)

964.0


In [38]:
torch.arange(0, 20, 1)

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])

In [56]:
i = 0
a = i * 20
b = a + 50
index = torch.arange(a, b, 1).tolist()
x = torch.cat([train_data[j][0].unsqueeze(0) for j in index])
y = [train_data[j][1] for j in index]
print(y)
print(len(y))


[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
50


In [27]:
arr1 = torch.rand(3, 5, 1, 3, 3)
print(arr1)
arr1 = arr1.view(-1, 1, 3, 3)
print(arr1)

tensor([[[[[8.0997e-02, 4.2193e-01, 8.1334e-01],
           [1.2969e-01, 4.2988e-01, 1.2039e-01],
           [4.6172e-01, 7.4664e-01, 1.8254e-02]]],


         [[[5.8016e-01, 9.5926e-01, 9.3822e-01],
           [1.0973e-01, 1.3439e-01, 4.8388e-02],
           [2.8419e-01, 7.5943e-01, 9.1863e-01]]],


         [[[6.4715e-01, 7.2249e-01, 8.6136e-01],
           [9.3812e-01, 6.0957e-01, 8.1645e-01],
           [7.4185e-01, 8.6372e-01, 6.5094e-01]]],


         [[[4.8270e-01, 9.3351e-02, 2.8065e-02],
           [7.4571e-01, 2.9880e-01, 6.0708e-01],
           [8.4472e-01, 2.2107e-01, 6.6672e-01]]],


         [[[6.5210e-01, 8.3995e-01, 3.4862e-01],
           [5.9103e-01, 3.1147e-01, 9.2754e-01],
           [2.8013e-01, 5.2364e-01, 3.7105e-01]]]],



        [[[[5.1782e-01, 7.9067e-01, 6.5458e-01],
           [9.0733e-01, 9.8703e-01, 3.1219e-01],
           [1.0280e-01, 4.7866e-01, 7.8595e-01]]],


         [[[9.1416e-01, 5.3393e-02, 7.4905e-01],
           [4.3032e-02, 1.1632e-01, 2.1146e

In [24]:
print(arr1.shape)

torch.Size([3, 5, 1, 3, 3])


In [26]:
arr1 = arr1.view(-1, 1, 3, 3)
print(arr1)
print(arr1.shape)

tensor([[[[0.8254, 0.2062, 0.4414],
          [0.1693, 0.1156, 0.1034],
          [0.3556, 0.7355, 0.5341]]],


        [[[0.1149, 0.9736, 0.4648],
          [0.5248, 0.6283, 0.2747],
          [0.2308, 0.8994, 0.9189]]],


        [[[0.6436, 0.6326, 0.8499],
          [0.5793, 0.1320, 0.5639],
          [0.1219, 0.1503, 0.1626]]],


        [[[0.4525, 0.4057, 0.9199],
          [0.6694, 0.0708, 0.6379],
          [0.5464, 0.3954, 0.6731]]],


        [[[0.3775, 0.9449, 0.1084],
          [0.0610, 0.1772, 0.7075],
          [0.9018, 0.3233, 0.3586]]],


        [[[0.4487, 0.1838, 0.1892],
          [0.0532, 0.5683, 0.5177],
          [0.1140, 0.5543, 0.2802]]],


        [[[0.0739, 0.5391, 0.3034],
          [0.6709, 0.8047, 0.9850],
          [0.9480, 0.0222, 0.2088]]],


        [[[0.6417, 0.3245, 0.5681],
          [0.8841, 0.1992, 0.2314],
          [0.8945, 0.5146, 0.5665]]],


        [[[0.5854, 0.7760, 0.8065],
          [0.4896, 0.7638, 0.5972],
          [0.9444, 0.9582, 0.098

In [28]:
arr2 = torch.rand(2, 3, 2, 2)

In [36]:
print(arr2)
print()
print(arr2.repeat(4, 1, 1, 1, 1))

tensor([[[[0.1456, 0.3708],
          [0.7828, 0.9570]],

         [[0.3080, 0.5535],
          [0.8589, 0.3495]],

         [[0.9205, 0.7702],
          [0.9663, 0.1048]]],


        [[[0.4162, 0.1656],
          [0.7065, 0.3462]],

         [[0.4843, 0.5078],
          [0.1471, 0.3288]],

         [[0.9974, 0.5821],
          [0.5965, 0.6928]]]])

tensor([[[[[0.1456, 0.3708],
           [0.7828, 0.9570]],

          [[0.3080, 0.5535],
           [0.8589, 0.3495]],

          [[0.9205, 0.7702],
           [0.9663, 0.1048]]],


         [[[0.4162, 0.1656],
           [0.7065, 0.3462]],

          [[0.4843, 0.5078],
           [0.1471, 0.3288]],

          [[0.9974, 0.5821],
           [0.5965, 0.6928]]]],



        [[[[0.1456, 0.3708],
           [0.7828, 0.9570]],

          [[0.3080, 0.5535],
           [0.8589, 0.3495]],

          [[0.9205, 0.7702],
           [0.9663, 0.1048]]],


         [[[0.4162, 0.1656],
           [0.7065, 0.3462]],

          [[0.4843, 0.5078],
           

In [31]:
print(arr2.shape)

torch.Size([2, 3, 2, 2])


In [38]:
k = 5
n = 4
y = torch.eye(k).repeat_interleave(n, dim=0)
print(y)

tensor([[1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.]])


In [39]:
target_labels = y.argmax(dim=1)
print(target_labels)

tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4])


In [40]:
print(len(target_labels))

20


In [41]:
print(y.shape)

torch.Size([20, 5])


In [7]:
class_folders = [314, 32, 562, 171, 5]
labels = np.array(range(5))
print(labels)
print()
labels = dict(zip(class_folders, labels))
print(labels)

[0 1 2 3 4]

{314: 0, 32: 1, 562: 2, 171: 3, 5: 4}


In [16]:
torch.eye(5).repeat_interleave(4, dim=0)

tensor([[1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.]])

In [10]:
arr1 = torch.rand(3, 1, 2, 3, 3)

In [11]:
print(arr1)

tensor([[[[[0.8332, 0.3807, 0.6383],
           [0.4811, 0.3465, 0.3542],
           [0.5666, 0.5240, 0.7081]],

          [[0.4694, 0.3827, 0.0848],
           [0.6445, 0.7771, 0.5237],
           [0.3178, 0.1932, 0.4129]]]],



        [[[[0.3874, 0.4644, 0.1771],
           [0.9633, 0.7331, 0.5700],
           [0.3572, 0.6973, 0.8108]],

          [[0.5166, 0.0056, 0.2524],
           [0.6644, 0.5498, 0.4049],
           [0.8729, 0.2930, 0.5789]]]],



        [[[[0.4650, 0.3470, 0.3959],
           [0.9097, 0.3673, 0.6685],
           [0.4904, 0.6084, 0.5545]],

          [[0.7985, 0.8934, 0.4959],
           [0.1574, 0.9263, 0.3108],
           [0.7100, 0.8997, 0.4894]]]]])


In [12]:
arr1 = torch.sum(arr1, 1)
print(arr1)

tensor([[[[0.8332, 0.3807, 0.6383],
          [0.4811, 0.3465, 0.3542],
          [0.5666, 0.5240, 0.7081]],

         [[0.4694, 0.3827, 0.0848],
          [0.6445, 0.7771, 0.5237],
          [0.3178, 0.1932, 0.4129]]],


        [[[0.3874, 0.4644, 0.1771],
          [0.9633, 0.7331, 0.5700],
          [0.3572, 0.6973, 0.8108]],

         [[0.5166, 0.0056, 0.2524],
          [0.6644, 0.5498, 0.4049],
          [0.8729, 0.2930, 0.5789]]],


        [[[0.4650, 0.3470, 0.3959],
          [0.9097, 0.3673, 0.6685],
          [0.4904, 0.6084, 0.5545]],

         [[0.7985, 0.8934, 0.4959],
          [0.1574, 0.9263, 0.3108],
          [0.7100, 0.8997, 0.4894]]]])
