In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

torch.manual_seed(777)

if device == 'cuda':
    torch.cuda.manual_seed_all(777)

In [7]:
learning_rate = 0.001
training_epochs = 20
batch_size = 2571
img_size = 28, 28

transform = transforms.Compose(
    [transforms.Resize(img_size),transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.ImageFolder(root='./xray_data', transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.ImageFolder(root='./xray_data', transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

In [8]:
for x, y in trainloader:
    print(x, y)
    break

tensor([[[[-0.9922, -0.9922, -0.9922,  ..., -0.9922, -0.9922, -0.9922],
          [-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
          [-1.0000, -1.0000, -1.0000,  ..., -0.9765, -0.9922, -1.0000],
          ...,
          [-0.9922, -0.6941, -0.1373,  ...,  0.3569,  0.0039, -0.5686],
          [-0.9922, -0.6784, -0.0902,  ...,  0.3412,  0.0980, -0.4745],
          [-0.9765, -0.6235, -0.0353,  ...,  0.3647,  0.1686, -0.3020]],

         [[-0.9922, -0.9922, -0.9922,  ..., -0.9922, -0.9922, -0.9922],
          [-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
          [-1.0000, -1.0000, -1.0000,  ..., -0.9765, -0.9922, -1.0000],
          ...,
          [-0.9922, -0.6941, -0.1373,  ...,  0.3569,  0.0039, -0.5686],
          [-0.9922, -0.6784, -0.0902,  ...,  0.3412,  0.0980, -0.4745],
          [-0.9765, -0.6235, -0.0353,  ...,  0.3647,  0.1686, -0.3020]],

         [[-0.9922, -0.9922, -0.9922,  ..., -0.9922, -0.9922, -0.9922],
          [-1.0000, -1.0000, -

In [62]:
x.size()

torch.Size([2569, 3, 28, 28])

In [63]:
x.dim()

4

In [64]:
x.shape[0]
y.shape[0]

2569

In [65]:
len(trainloader)

1

In [66]:
occ = torch.eye(int(y.max()+1),int(y.max()+1))[y].sum(axis=0)
print(occ)

tensor([1753.,  816.])


In [67]:
dominant_class = torch.argmax(occ)
print(dominant_class)

tensor(0)


In [68]:
n_occ = int(occ[dominant_class].item())
print(n_occ)

1753


In [69]:
for i in range(len(occ)):
    print(i)

0
1


In [70]:
i = 1
N = (n_occ - occ[i]) * 100 / occ[i]
print(N)

tensor(114.8284)


In [71]:
candidates = x[y == i]
print(candidates)

tensor([[[[-0.9922, -0.9922, -0.9922,  ..., -0.9922, -0.9922, -0.9922],
          [-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
          [-1.0000, -1.0000, -1.0000,  ..., -0.9765, -0.9922, -1.0000],
          ...,
          [-0.9922, -0.6941, -0.1373,  ...,  0.3569,  0.0039, -0.5686],
          [-0.9922, -0.6784, -0.0902,  ...,  0.3412,  0.0980, -0.4745],
          [-0.9765, -0.6235, -0.0353,  ...,  0.3647,  0.1686, -0.3020]],

         [[-0.9922, -0.9922, -0.9922,  ..., -0.9922, -0.9922, -0.9922],
          [-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
          [-1.0000, -1.0000, -1.0000,  ..., -0.9765, -0.9922, -1.0000],
          ...,
          [-0.9922, -0.6941, -0.1373,  ...,  0.3569,  0.0039, -0.5686],
          [-0.9922, -0.6784, -0.0902,  ...,  0.3412,  0.0980, -0.4745],
          [-0.9765, -0.6235, -0.0353,  ...,  0.3647,  0.1686, -0.3020]],

         [[-0.9922, -0.9922, -0.9922,  ..., -0.9922, -0.9922, -0.9922],
          [-1.0000, -1.0000, -

In [72]:
T = candidates.shape[0]
print(T)

816


In [73]:
synthetic_arr = torch.zeros(int(N/100)*T,512)
print(synthetic_arr)

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])


In [74]:
synthetic_arr.size()

torch.Size([816, 512])

In [75]:
N = int(N/100)
print(N)

1


In [78]:
euclid_distance = torch.zeros((candidates.shape[0],candidates.shape[2],candidates.shape[3]), dtype = torch.float32)
print(euclid_distance)

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0., 

In [79]:
euclid_distance.size()

torch.Size([816, 28, 28])

In [80]:
X = candidates
print(len(X))
for i in range(len(X)):
    dif = (X - X[i])**2
    dist = torch.sqrt(dif.sum(axis=1))
    print(dist)

816
tensor([[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.1087, 0.1087, 0.0951,  ..., 0.1223, 0.1223, 0.1223],
         [0.0815, 0.0815, 0.0951,  ..., 0.3124, 0.1494, 0.0951],
         [0.2174, 0.3668, 0.4890,  ..., 0.3668, 0.2581, 0.1358],
         ...,
         [1.4807, 1.5351, 0.8694,  ..., 0.0815, 0.3124, 0.8015],
         [1.3449, 1.2090, 0.4211,  ..., 0.0815, 0.0272, 0.4483],
         [1.6438, 1.4128, 0.6385,  ..., 0.0815, 0.1358, 0.0272]],

        [[0.1902, 0.2853, 0.6249,  ..., 0.2038, 0.1902, 0.1902],
         [0.1902, 0.4619, 0.8694,  ..., 0.3940, 0.1630, 0.0272],
         [0.5026, 0.7743, 1.1139,  ..., 1.0596, 0.8151

tensor([[[0.0815, 0.0408, 0.0272,  ..., 0.0815, 0.0815, 0.1087],
         [0.0951, 0.2853, 0.6249,  ..., 0.6385, 0.3396, 0.1087],
         [0.5298, 1.0596, 1.4536,  ..., 1.4264, 1.1275, 0.6521],
         ...,
         [0.1494, 0.4347, 1.3721,  ..., 2.2415, 1.5758, 0.5570],
         [0.1494, 0.4619, 1.4536,  ..., 2.2143, 1.7524, 0.7336],
         [0.1766, 0.5026, 1.4807,  ..., 2.1600, 1.7660, 0.9238]],

        [[0.0272, 0.0679, 0.0679,  ..., 0.0408, 0.0408, 0.0136],
         [0.0136, 0.2038, 0.5298,  ..., 0.3260, 0.1902, 0.0136],
         [0.3124, 0.6928, 0.9645,  ..., 1.0596, 0.8694, 0.5162],
         ...,
         [1.3313, 1.9698, 2.2415,  ..., 2.1600, 1.8883, 1.3585],
         [1.1955, 1.6709, 1.8747,  ..., 2.1328, 1.7796, 1.1819],
         [1.4671, 1.9154, 2.1192,  ..., 2.0785, 1.6302, 0.9509]],

        [[0.1087, 0.2445, 0.5977,  ..., 0.1223, 0.1087, 0.0815],
         [0.0951, 0.1766, 0.2445,  ..., 0.2445, 0.1766, 0.0815],
         [0.0272, 0.2853, 0.3396,  ..., 0.3668, 0.3124, 0.

tensor([[[0.0679, 0.0679, 0.0543,  ..., 0.3124, 0.2989, 0.1902],
         [0.0000, 0.0000, 0.0136,  ..., 0.2038, 0.1766, 0.0679],
         [0.0272, 0.1494, 0.3396,  ..., 0.0136, 0.0136, 0.0272],
         ...,
         [0.0000, 0.5162, 1.0868,  ..., 1.8068, 1.7253, 0.7200],
         [0.0000, 0.5298, 1.0053,  ..., 1.5487, 1.8204, 0.8830],
         [0.0408, 0.6521, 1.4128,  ..., 1.3041, 1.8339, 1.1955]],

        [[0.0408, 0.0408, 0.0408,  ..., 0.1902, 0.1766, 0.0679],
         [0.0815, 0.0815, 0.0815,  ..., 0.1087, 0.0272, 0.0272],
         [0.1902, 0.2174, 0.1494,  ..., 0.3804, 0.2717, 0.1087],
         ...,
         [1.4807, 2.0513, 1.9562,  ..., 1.7253, 2.0377, 1.5215],
         [1.3449, 1.7388, 1.4264,  ..., 1.4671, 1.8475, 1.3313],
         [1.6845, 2.0649, 2.0513,  ..., 1.2226, 1.6981, 1.2226]],

        [[0.1223, 0.2174, 0.5706,  ..., 0.1087, 0.1087, 0.0000],
         [0.1902, 0.4619, 0.8558,  ..., 0.1902, 0.0136, 0.0408],
         [0.4755, 0.6249, 0.7743,  ..., 1.0732, 0.8287, 0.

tensor([[[0.2174, 0.2038, 0.1766,  ..., 0.3668, 0.2853, 0.2309],
         [0.2309, 0.2717, 0.2989,  ..., 1.0868, 1.0189, 0.6385],
         [0.8287, 1.1275, 1.1819,  ..., 1.9290, 1.9019, 1.3721],
         ...,
         [0.2445, 0.3396, 1.2498,  ..., 2.1056, 1.4264, 0.3532],
         [0.2445, 0.3668, 1.3992,  ..., 2.0649, 1.5622, 0.5434],
         [0.2445, 0.4483, 1.5079,  ..., 2.0785, 1.6573, 0.8151]],

        [[0.1087, 0.0951, 0.0815,  ..., 0.2445, 0.1630, 0.1087],
         [0.1494, 0.1902, 0.2038,  ..., 0.7743, 0.8694, 0.5434],
         [0.6113, 0.7607, 0.6928,  ..., 1.5622, 1.6438, 1.2362],
         ...,
         [1.2362, 1.8747, 2.1192,  ..., 2.0241, 1.7388, 1.1547],
         [1.1004, 1.5758, 1.8204,  ..., 1.9834, 1.5894, 0.9917],
         [1.3992, 1.8611, 2.1464,  ..., 1.9970, 1.5215, 0.8423]],

        [[0.0272, 0.0815, 0.4483,  ..., 0.1630, 0.0951, 0.0408],
         [0.0408, 0.1902, 0.5706,  ..., 0.6928, 0.8558, 0.6113],
         [0.3260, 0.3532, 0.0679,  ..., 0.8694, 1.0868, 1.

tensor([[[0.3396, 0.0136, 0.0136,  ..., 0.1494, 0.1087, 0.1494],
         [0.3668, 0.0815, 0.1630,  ..., 0.8151, 0.6792, 0.6657],
         [1.0732, 0.5841, 0.7336,  ..., 1.2634, 1.1275, 1.2362],
         ...,
         [0.0815, 0.5026, 1.0596,  ..., 0.4890, 0.1494, 0.3668],
         [0.0679, 0.5434, 1.2362,  ..., 0.6792, 0.3668, 0.2853],
         [0.1087, 0.5841, 1.1547,  ..., 0.3396, 0.1630, 0.4347]],

        [[0.2309, 0.1223, 0.1087,  ..., 0.0272, 0.0136, 0.0272],
         [0.2853, 0.0000, 0.0679,  ..., 0.5026, 0.5298, 0.5706],
         [0.8558, 0.2174, 0.2445,  ..., 0.8966, 0.8694, 1.1004],
         ...,
         [1.3992, 2.0377, 1.9290,  ..., 0.4075, 0.4619, 0.4347],
         [1.2770, 1.7524, 1.6573,  ..., 0.5977, 0.3940, 0.1630],
         [1.5351, 1.9970, 1.7932,  ..., 0.2581, 0.0272, 0.4075]],

        [[0.1494, 0.2989, 0.6385,  ..., 0.0543, 0.0815, 0.0408],
         [0.1766, 0.3804, 0.7064,  ..., 0.4211, 0.5162, 0.6385],
         [0.5706, 0.1902, 0.3804,  ..., 0.2038, 0.3124, 0.

tensor([[[2.3230, 2.5132, 2.4860,  ..., 2.5403, 2.5947, 2.4860],
         [0.9238, 0.9917, 0.9645,  ..., 1.1139, 1.1547, 1.3856],
         [0.2853, 0.2989, 0.2989,  ..., 0.4619, 0.4347, 0.7064],
         ...,
         [0.2581, 0.2853, 1.2634,  ..., 1.2905, 1.3585, 0.1223],
         [0.2581, 0.3260, 1.3585,  ..., 0.6385, 0.8830, 0.1494],
         [0.2038, 0.4347, 1.4807,  ..., 0.3396, 0.5162, 0.2445]],

        [[2.2143, 2.4045, 2.3909,  ..., 2.4181, 2.4724, 2.3637],
         [0.8423, 0.9102, 0.8694,  ..., 0.8015, 1.0053, 1.2905],
         [0.0679, 0.0679, 0.1902,  ..., 0.0951, 0.1766, 0.5706],
         ...,
         [1.2226, 1.8204, 2.1328,  ..., 1.2090, 1.6709, 0.9238],
         [1.0868, 1.5351, 1.7796,  ..., 0.5570, 0.9102, 0.5977],
         [1.4400, 1.8475, 2.1192,  ..., 0.2581, 0.3804, 0.2717]],

        [[2.1328, 2.2279, 1.8611,  ..., 2.3366, 2.4045, 2.2958],
         [0.7336, 0.5298, 0.0951,  ..., 0.7200, 0.9917, 1.3585],
         [0.2174, 0.4755, 0.8151,  ..., 0.5977, 0.3804, 0.

tensor([[[0.0136, 0.0136, 0.0136,  ..., 0.1766, 0.1223, 0.0815],
         [0.3260, 0.4619, 0.5298,  ..., 0.8015, 0.8694, 0.8015],
         [1.2226, 1.3449, 1.4400,  ..., 1.4943, 1.5351, 1.4400],
         ...,
         [0.5434, 0.9373, 0.6113,  ..., 0.0543, 0.0679, 0.1223],
         [0.4483, 0.8423, 0.5026,  ..., 0.0543, 0.1766, 0.1630],
         [0.4483, 0.7879, 0.4483,  ..., 0.1494, 0.3396, 0.5026]],

        [[0.0951, 0.0951, 0.0815,  ..., 0.0543, 0.0000, 0.0408],
         [0.2445, 0.3804, 0.4347,  ..., 0.4890, 0.7200, 0.7064],
         [1.0053, 0.9781, 0.9509,  ..., 1.1275, 1.2770, 1.3041],
         ...,
         [0.9373, 0.5977, 0.2581,  ..., 0.0272, 0.2445, 0.6792],
         [0.8966, 0.3668, 0.0815,  ..., 0.0272, 0.2038, 0.6113],
         [1.1955, 0.6249, 0.1902,  ..., 0.0679, 0.2038, 0.5298]],

        [[0.1766, 0.2717, 0.6113,  ..., 0.0272, 0.0679, 0.1087],
         [0.1358, 0.0000, 0.3396,  ..., 0.4075, 0.7064, 0.7743],
         [0.7200, 0.5706, 0.3260,  ..., 0.4347, 0.7200, 1.

tensor([[[0.1630, 0.0000, 0.0000,  ..., 0.0000, 0.0136, 0.0000],
         [0.3804, 0.3940, 0.6113,  ..., 0.3940, 0.1902, 0.0815],
         [1.1275, 1.1955, 1.3992,  ..., 1.1275, 0.9373, 0.6928],
         ...,
         [0.7336, 0.8694, 0.4483,  ..., 0.8151, 1.2226, 0.7200],
         [0.6928, 0.7336, 0.2581,  ..., 0.8966, 1.4671, 0.8966],
         [0.6928, 0.6249, 0.1358,  ..., 0.9781, 1.6166, 1.1819]],

        [[0.0543, 0.1087, 0.0951,  ..., 0.1223, 0.1358, 0.1223],
         [0.2989, 0.3124, 0.5162,  ..., 0.0815, 0.0408, 0.0136],
         [0.9102, 0.8287, 0.9102,  ..., 0.7607, 0.6792, 0.5570],
         ...,
         [0.7472, 0.6657, 0.4211,  ..., 0.7336, 1.5351, 1.5215],
         [0.6521, 0.4755, 0.1630,  ..., 0.8151, 1.4943, 1.3449],
         [0.9509, 0.7879, 0.5026,  ..., 0.8966, 1.4807, 1.2090]],

        [[0.0272, 0.2853, 0.6249,  ..., 0.2038, 0.2038, 0.1902],
         [0.1902, 0.0679, 0.2581,  ..., 0.0000, 0.0272, 0.0543],
         [0.6249, 0.4211, 0.2853,  ..., 0.0679, 0.1223, 0.

tensor([[[0.0136, 0.0136, 0.0136,  ..., 0.0136, 0.0272, 0.0272],
         [0.1087, 0.2038, 0.2038,  ..., 0.0136, 0.0136, 0.0272],
         [0.9509, 1.1004, 1.0732,  ..., 0.1358, 0.0679, 0.0543],
         ...,
         [0.0136, 0.4483, 0.5434,  ..., 0.8015, 1.3449, 0.6928],
         [0.0136, 0.5026, 0.7472,  ..., 0.9238, 1.6030, 0.8558],
         [0.0000, 0.6113, 0.9917,  ..., 1.1411, 1.8204, 0.9373]],

        [[0.1223, 0.1223, 0.1087,  ..., 0.1087, 0.0951, 0.0951],
         [0.0272, 0.1223, 0.1087,  ..., 0.2989, 0.1358, 0.0679],
         [0.7336, 0.7336, 0.5841,  ..., 0.2309, 0.1902, 0.0815],
         ...,
         [1.4671, 1.9834, 1.4128,  ..., 0.7200, 1.6573, 1.4943],
         [1.3313, 1.7117, 1.1683,  ..., 0.8423, 1.6302, 1.3041],
         [1.6438, 2.0241, 1.6302,  ..., 1.0596, 1.6845, 0.9645]],

        [[0.2038, 0.2989, 0.6385,  ..., 0.1902, 0.1630, 0.1630],
         [0.0815, 0.2581, 0.6657,  ..., 0.3804, 0.1494, 0.0000],
         [0.4483, 0.3260, 0.0408,  ..., 0.9238, 0.7472, 0.

tensor([[[0.0543, 0.0272, 0.0136,  ..., 0.0136, 0.0000, 0.0000],
         [0.0408, 0.0272, 0.0408,  ..., 0.5026, 0.3396, 0.0543],
         [0.2174, 0.5026, 0.7743,  ..., 1.3041, 1.2362, 0.8015],
         ...,
         [0.0000, 0.4890, 0.7472,  ..., 2.1192, 1.7253, 0.7200],
         [0.0000, 0.5298, 0.9509,  ..., 2.2007, 1.8883, 0.8830],
         [0.1223, 0.5026, 0.9373,  ..., 2.2007, 1.9154, 1.0868]],

        [[0.0543, 0.0815, 0.0815,  ..., 0.1087, 0.1223, 0.1223],
         [0.0408, 0.0543, 0.0543,  ..., 0.1902, 0.1902, 0.0408],
         [0.0000, 0.1358, 0.2853,  ..., 0.9373, 0.9781, 0.6657],
         ...,
         [1.4807, 2.0241, 1.6166,  ..., 2.0377, 2.0377, 1.5215],
         [1.3449, 1.7388, 1.3721,  ..., 2.1192, 1.9154, 1.3313],
         [1.5215, 1.9154, 1.5758,  ..., 2.1192, 1.7796, 1.1139]],

        [[0.1358, 0.2581, 0.6113,  ..., 0.1902, 0.1902, 0.1902],
         [0.1494, 0.4347, 0.8287,  ..., 0.1087, 0.1766, 0.0272],
         [0.2853, 0.2717, 0.3396,  ..., 0.2445, 0.4211, 0.

tensor([[[0.4755, 0.0543, 0.0543,  ..., 0.0679, 0.0679, 0.2581],
         [0.4211, 0.0408, 0.0408,  ..., 0.0272, 0.0408, 0.2309],
         [0.4211, 0.0408, 0.0543,  ..., 0.0000, 0.0272, 0.2309],
         ...,
         [0.5026, 0.3260, 0.1494,  ..., 2.2822, 1.6438, 0.4483],
         [0.4347, 0.0272, 0.2174,  ..., 2.2551, 1.7932, 0.6113],
         [0.3804, 0.3396, 0.7200,  ..., 2.2958, 1.9154, 0.7200]],

        [[0.3668, 0.0543, 0.0408,  ..., 0.0543, 0.0543, 0.1358],
         [0.3396, 0.0408, 0.0543,  ..., 0.2853, 0.1087, 0.1358],
         [0.2038, 0.3260, 0.4347,  ..., 0.3668, 0.2309, 0.0951],
         ...,
         [0.9781, 1.2090, 0.7200,  ..., 2.2007, 1.9562, 1.2498],
         [0.9102, 1.1819, 0.6385,  ..., 2.1736, 1.8204, 1.0596],
         [1.2634, 1.7524, 1.3585,  ..., 2.2143, 1.7796, 0.7472]],

        [[0.2853, 0.2309, 0.5706,  ..., 0.1358, 0.1223, 0.0679],
         [0.2309, 0.4211, 0.8287,  ..., 0.3668, 0.1223, 0.2038],
         [0.0815, 0.7336, 1.0596,  ..., 1.0596, 0.7879, 0.

tensor([[[0.3532, 0.3940, 0.3532,  ..., 0.1087, 0.1087, 0.1223],
         [1.0189, 1.2634, 1.3585,  ..., 0.2581, 0.2853, 0.2581],
         [1.6573, 1.9019, 2.0241,  ..., 1.0596, 1.2634, 1.2226],
         ...,
         [1.2362, 1.5894, 1.0460,  ..., 0.0679, 0.0679, 0.0272],
         [1.1275, 1.4807, 0.9238,  ..., 0.1223, 0.1494, 0.3124],
         [0.9373, 1.2634, 0.7336,  ..., 0.2174, 0.4483, 0.7200]],

        [[0.2445, 0.2853, 0.2581,  ..., 0.0136, 0.0136, 0.0000],
         [0.9373, 1.1819, 1.2634,  ..., 0.0543, 0.1358, 0.1630],
         [1.4400, 1.5351, 1.5351,  ..., 0.6928, 1.0053, 1.0868],
         ...,
         [0.2445, 0.0543, 0.1766,  ..., 0.0136, 0.2445, 0.8287],
         [0.2174, 0.2717, 0.5026,  ..., 0.0408, 0.1766, 0.7607],
         [0.7064, 0.1494, 0.0951,  ..., 0.1358, 0.3124, 0.7472]],

        [[0.1630, 0.1087, 0.2717,  ..., 0.0951, 0.0815, 0.0679],
         [0.8287, 0.8015, 0.4890,  ..., 0.1358, 0.1223, 0.2309],
         [1.1547, 1.1275, 0.9102,  ..., 0.0000, 0.4483, 0.

tensor([[[0.0136, 0.0136, 0.0815,  ..., 0.3124, 0.1358, 0.0272],
         [0.2717, 0.5706, 0.8423,  ..., 1.1004, 0.9373, 0.4890],
         [0.8287, 1.2905, 1.5215,  ..., 1.6438, 1.5894, 1.1139],
         ...,
         [0.0000, 0.4619, 0.6249,  ..., 1.1683, 1.5622, 0.7472],
         [0.0136, 0.5026, 0.7743,  ..., 1.1683, 1.7524, 0.9102],
         [0.0408, 0.6113, 0.9238,  ..., 1.2090, 1.8747, 1.1955]],

        [[0.0951, 0.0951, 0.0136,  ..., 0.1902, 0.0136, 0.0951],
         [0.1902, 0.4890, 0.7472,  ..., 0.7879, 0.7879, 0.3940],
         [0.6113, 0.9238, 1.0324,  ..., 1.2770, 1.3313, 0.9781],
         ...,
         [1.4807, 1.9970, 1.4943,  ..., 1.0868, 1.8747, 1.5487],
         [1.3585, 1.7117, 1.1955,  ..., 1.0868, 1.7796, 1.3585],
         [1.6845, 2.0241, 1.5622,  ..., 1.1275, 1.7388, 1.2226]],

        [[0.1766, 0.2717, 0.5434,  ..., 0.1087, 0.0543, 0.1630],
         [0.0815, 0.1087, 0.0272,  ..., 0.7064, 0.7743, 0.4619],
         [0.3260, 0.5162, 0.4075,  ..., 0.5841, 0.7743, 0.

tensor([[[0.0136, 0.0136, 0.0136,  ..., 0.0000, 0.0000, 0.0136],
         [0.0000, 0.0136, 0.0272,  ..., 0.1358, 0.0000, 0.0000],
         [0.0951, 0.2445, 0.3940,  ..., 0.1902, 0.1087, 0.0408],
         ...,
         [0.0136, 0.4755, 0.8966,  ..., 0.9102, 0.9781, 0.6385],
         [0.0136, 0.4483, 0.8694,  ..., 0.8015, 1.0868, 0.7743],
         [0.0408, 0.4211, 0.7472,  ..., 0.7064, 1.0596, 0.7743]],

        [[0.1223, 0.1223, 0.1087,  ..., 0.1223, 0.1223, 0.1358],
         [0.0815, 0.0679, 0.0679,  ..., 0.1766, 0.1494, 0.0951],
         [0.1223, 0.1223, 0.0951,  ..., 0.1766, 0.1494, 0.0951],
         ...,
         [1.4943, 2.0105, 1.7660,  ..., 0.8287, 1.2905, 1.4400],
         [1.3585, 1.6573, 1.2905,  ..., 0.7200, 1.1139, 1.2226],
         [1.6845, 1.8339, 1.3856,  ..., 0.6249, 0.9238, 0.8015]],

        [[0.2038, 0.2989, 0.6385,  ..., 0.2038, 0.1902, 0.2038],
         [0.1902, 0.4483, 0.8423,  ..., 0.2581, 0.1630, 0.0272],
         [0.4075, 0.5298, 0.7200,  ..., 0.8694, 0.7064, 0.

tensor([[[0.1766, 0.1494, 0.1358,  ..., 0.2445, 0.2445, 0.2717],
         [0.0136, 0.0000, 0.0000,  ..., 0.2174, 0.0272, 0.0000],
         [0.0136, 0.0000, 0.0000,  ..., 0.2717, 0.1494, 0.0679],
         ...,
         [0.0136, 0.3532, 0.3396,  ..., 0.8966, 1.4400, 0.6928],
         [0.0136, 0.4211, 0.4347,  ..., 0.9509, 1.6573, 0.8423],
         [0.0408, 0.5434, 0.5977,  ..., 1.1819, 1.8611, 1.1411]],

        [[0.0679, 0.0408, 0.0408,  ..., 0.1223, 0.1223, 0.1494],
         [0.0679, 0.0815, 0.0951,  ..., 0.0951, 0.1223, 0.0951],
         [0.2038, 0.3668, 0.4890,  ..., 0.0951, 0.1087, 0.0679],
         ...,
         [1.4943, 1.8883, 1.2090,  ..., 0.8151, 1.7524, 1.4943],
         [1.3585, 1.6302, 0.8558,  ..., 0.8694, 1.6845, 1.2905],
         [1.6845, 1.9562, 1.2362,  ..., 1.1004, 1.7253, 1.1683]],

        [[0.0136, 0.1358, 0.4890,  ..., 0.0408, 0.0543, 0.0815],
         [0.1766, 0.4619, 0.8694,  ..., 0.1766, 0.1358, 0.0272],
         [0.4890, 0.7743, 1.1139,  ..., 0.7879, 0.6657, 0.

tensor([[[0.2309, 0.1358, 0.1223,  ..., 0.0408, 0.0543, 0.1223],
         [0.2174, 0.1223, 0.1087,  ..., 0.0543, 0.0408, 0.1087],
         [0.1902, 0.2853, 0.6113,  ..., 0.5706, 0.4483, 0.2717],
         ...,
         [0.1630, 0.3940, 1.3992,  ..., 1.4128, 1.6302, 0.6249],
         [0.1223, 0.4755, 1.5079,  ..., 1.5351, 1.8068, 0.8151],
         [0.4075, 0.2445, 1.3041,  ..., 1.4128, 1.6573, 0.8830]],

        [[0.1223, 0.0272, 0.0272,  ..., 0.0815, 0.0679, 0.0000],
         [0.1358, 0.0408, 0.0136,  ..., 0.2581, 0.1087, 0.0136],
         [0.0272, 0.0815, 0.1223,  ..., 0.2038, 0.1902, 0.1358],
         ...,
         [1.3177, 1.9290, 2.2686,  ..., 1.3313, 1.9426, 1.4264],
         [1.2226, 1.6845, 1.9290,  ..., 1.4536, 1.8339, 1.2634],
         [1.2362, 1.6573, 1.9426,  ..., 1.3313, 1.5215, 0.9102]],

        [[0.0408, 0.1494, 0.5026,  ..., 0.1630, 0.1358, 0.0679],
         [0.0272, 0.3396, 0.7607,  ..., 0.3396, 0.1223, 0.0815],
         [0.3124, 0.4890, 0.5026,  ..., 0.4890, 0.3668, 0.

tensor([[[0.0136, 0.0136, 0.0136,  ..., 0.0136, 0.0136, 0.0136],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.2309, 0.2309, 0.2853,  ..., 0.2309, 0.2038, 0.1223],
         ...,
         [0.0272, 0.1494, 0.0951,  ..., 0.3532, 0.2445, 0.1358],
         [0.0136, 0.1766, 0.1902,  ..., 0.4211, 0.3940, 0.2445],
         [0.0679, 0.3124, 0.2581,  ..., 0.2174, 0.2038, 0.0408]],

        [[0.1223, 0.1223, 0.1087,  ..., 0.1358, 0.1358, 0.1358],
         [0.0815, 0.0815, 0.0951,  ..., 0.3124, 0.1494, 0.0951],
         [0.0136, 0.1358, 0.2038,  ..., 0.1358, 0.0543, 0.0136],
         ...,
         [1.4536, 1.3856, 0.7743,  ..., 0.2717, 0.5570, 0.9373],
         [1.3585, 1.3856, 0.6113,  ..., 0.3396, 0.4211, 0.6928],
         [1.5758, 1.7253, 0.8966,  ..., 0.1358, 0.0679, 0.0136]],

        [[0.2038, 0.2989, 0.6385,  ..., 0.2174, 0.2038, 0.2038],
         [0.1902, 0.4619, 0.8694,  ..., 0.3940, 0.1630, 0.0272],
         [0.2717, 0.5434, 0.8287,  ..., 0.8287, 0.6113, 0.

tensor([[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0136],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0408, 0.1494,  ..., 0.2853, 0.0951, 0.0136],
         ...,
         [0.0679, 0.3396, 0.2174,  ..., 1.1819, 1.4671, 0.7472],
         [0.2174, 0.6113, 0.2989,  ..., 0.9509, 1.4536, 0.9102],
         [0.4619, 0.7472, 0.3260,  ..., 0.8287, 1.3449, 0.9238]],

        [[0.1087, 0.1087, 0.0951,  ..., 0.1223, 0.1223, 0.1087],
         [0.0815, 0.0815, 0.0951,  ..., 0.3124, 0.1494, 0.0951],
         [0.2174, 0.3260, 0.3396,  ..., 0.0815, 0.1630, 0.1223],
         ...,
         [1.4128, 1.1955, 0.6521,  ..., 1.1004, 1.7796, 1.5487],
         [1.1275, 0.5977, 0.1223,  ..., 0.8694, 1.4807, 1.3585],
         [1.1819, 0.6657, 0.3124,  ..., 0.7472, 1.2090, 0.9509]],

        [[0.1902, 0.2853, 0.6249,  ..., 0.2038, 0.1902, 0.1766],
         [0.1902, 0.4619, 0.8694,  ..., 0.3940, 0.1630, 0.0272],
         [0.5026, 0.7336, 0.9645,  ..., 0.7743, 0.7200, 0.

tensor([[[0.0136, 0.0136, 0.0136,  ..., 0.6792, 0.6113, 0.4211],
         [0.0136, 0.0815, 0.2445,  ..., 1.9019, 1.8204, 1.7117],
         [0.8015, 1.1411, 1.3856,  ..., 2.1192, 2.1464, 2.1464],
         ...,
         [0.0000, 0.1494, 0.3804,  ..., 0.0679, 0.0815, 0.4483],
         [0.0136, 0.0272, 0.2717,  ..., 0.0136, 0.2174, 0.5570],
         [0.0408, 0.0543, 0.2038,  ..., 0.0543, 0.2038, 0.6792]],

        [[0.1223, 0.1223, 0.1087,  ..., 0.5570, 0.4890, 0.2989],
         [0.0679, 0.0000, 0.1494,  ..., 1.5894, 1.6709, 1.6166],
         [0.5841, 0.7743, 0.8966,  ..., 1.7524, 1.8883, 2.0105],
         ...,
         [1.4807, 1.3856, 0.4890,  ..., 0.0136, 0.3940, 1.2498],
         [1.3585, 1.1819, 0.1494,  ..., 0.0679, 0.2445, 1.0053],
         [1.6845, 1.4671, 0.4347,  ..., 0.1358, 0.0679, 0.7064]],

        [[0.2038, 0.2989, 0.6385,  ..., 0.4755, 0.4211, 0.2309],
         [0.1766, 0.3804, 0.6249,  ..., 1.5079, 1.6573, 1.6845],
         [0.2989, 0.3668, 0.2717,  ..., 1.0596, 1.3313, 1.

tensor([[[0.0272, 0.0408, 0.1902,  ..., 0.4211, 0.2174, 0.0951],
         [0.4211, 0.8151, 1.1955,  ..., 1.1139, 1.1411, 0.9781],
         [1.2090, 1.5758, 1.8883,  ..., 1.8747, 1.9698, 1.8339],
         ...,
         [0.0000, 0.3396, 0.2174,  ..., 0.2038, 0.0543, 0.0815],
         [0.0136, 0.4619, 0.6113,  ..., 0.2717, 0.4483, 0.4755],
         [0.0136, 0.5841, 0.8558,  ..., 0.4483, 0.8423, 0.9645]],

        [[0.0815, 0.0679, 0.0951,  ..., 0.2989, 0.0951, 0.0272],
         [0.3396, 0.7336, 1.1004,  ..., 0.8015, 0.9917, 0.8830],
         [0.9917, 1.2090, 1.3992,  ..., 1.5079, 1.7117, 1.6981],
         ...,
         [1.4807, 1.8747, 1.0868,  ..., 0.1223, 0.3668, 0.8830],
         [1.3585, 1.6709, 1.0324,  ..., 0.1902, 0.4755, 0.9238],
         [1.6573, 1.9970, 1.4943,  ..., 0.3668, 0.7064, 0.9917]],

        [[0.1630, 0.2445, 0.4347,  ..., 0.2174, 0.0272, 0.0951],
         [0.2309, 0.3532, 0.3260,  ..., 0.7200, 0.9781, 0.9509],
         [0.7064, 0.8015, 0.7743,  ..., 0.8151, 1.1547, 1.

tensor([[[0.2853, 0.1087, 0.1223,  ..., 0.2581, 0.2038, 0.2038],
         [0.3124, 0.1087, 0.1223,  ..., 0.3396, 0.1358, 0.1087],
         [0.7200, 0.7064, 0.8966,  ..., 0.7879, 0.6521, 0.5026],
         ...,
         [0.4619, 0.1766, 0.3396,  ..., 2.2007, 1.5622, 0.5434],
         [0.5434, 0.2309, 0.0543,  ..., 1.9562, 1.7524, 0.7336],
         [0.8830, 0.6657, 0.2309,  ..., 1.3992, 1.8204, 1.0460]],

        [[0.1766, 0.0000, 0.0272,  ..., 0.1358, 0.0815, 0.0815],
         [0.2309, 0.0272, 0.0272,  ..., 0.0272, 0.0136, 0.0136],
         [0.5026, 0.3396, 0.4075,  ..., 0.4211, 0.3940, 0.3668],
         ...,
         [1.0189, 1.7117, 1.2090,  ..., 2.1192, 1.8747, 1.3449],
         [0.8015, 0.9781, 0.4755,  ..., 1.8747, 1.7796, 1.1819],
         [0.7607, 0.7472, 0.4075,  ..., 1.3177, 1.6845, 1.0732]],

        [[0.0951, 0.1766, 0.5026,  ..., 0.0543, 0.0136, 0.0136],
         [0.1223, 0.3532, 0.7472,  ..., 0.0543, 0.0272, 0.0815],
         [0.2174, 0.0679, 0.2174,  ..., 0.2717, 0.1630, 0.

tensor([[[0.2309, 0.0272, 0.0408,  ..., 0.6113, 0.4483, 0.5026],
         [0.2445, 0.1630, 0.4347,  ..., 1.3585, 1.2226, 1.1683],
         [0.6385, 0.7336, 1.2362,  ..., 1.7117, 1.4400, 1.4400],
         ...,
         [0.2038, 0.5298, 1.2362,  ..., 0.6113, 0.7743, 0.3940],
         [0.1902, 0.5570, 1.3313,  ..., 0.5570, 0.8830, 0.5298],
         [0.1630, 0.6521, 1.3992,  ..., 0.5570, 0.9509, 0.8015]],

        [[0.1223, 0.0815, 0.0543,  ..., 0.4890, 0.3260, 0.3804],
         [0.1630, 0.0815, 0.3396,  ..., 1.0460, 1.0732, 1.0732],
         [0.4211, 0.3668, 0.7472,  ..., 1.3449, 1.1819, 1.3041],
         ...,
         [1.2770, 2.0649, 2.1056,  ..., 0.5298, 1.0868, 1.1955],
         [1.1547, 1.7660, 1.7524,  ..., 0.4755, 0.9102, 0.9781],
         [1.4807, 2.0649, 2.0377,  ..., 0.4755, 0.8151, 0.8287]],

        [[0.0408, 0.2581, 0.5841,  ..., 0.4075, 0.2581, 0.3124],
         [0.0543, 0.2989, 0.4347,  ..., 0.9645, 1.0596, 1.1411],
         [0.1358, 0.0408, 0.1223,  ..., 0.6521, 0.6249, 1.

tensor([[[0.1766, 0.1630, 0.1494,  ..., 0.2174, 0.2581, 0.2445],
         [0.2989, 0.2853, 0.2581,  ..., 0.8830, 1.0189, 0.9645],
         [1.1547, 1.1547, 1.1004,  ..., 1.5079, 1.5894, 1.6709],
         ...,
         [0.2581, 0.3396, 0.8966,  ..., 2.1056, 1.4536, 0.4211],
         [0.2717, 0.3532, 1.1411,  ..., 2.0785, 1.6166, 0.5841],
         [0.2853, 0.4075, 1.3177,  ..., 2.0785, 1.7117, 0.8558]],

        [[0.0679, 0.0543, 0.0543,  ..., 0.0951, 0.1358, 0.1223],
         [0.2174, 0.2038, 0.1630,  ..., 0.5706, 0.8694, 0.8694],
         [0.9373, 0.7879, 0.6113,  ..., 1.1411, 1.3313, 1.5351],
         ...,
         [1.2226, 1.8747, 1.7660,  ..., 2.0241, 1.7660, 1.2226],
         [1.0732, 1.5622, 1.5622,  ..., 1.9970, 1.6438, 1.0324],
         [1.3585, 1.8204, 1.9562,  ..., 1.9970, 1.5758, 0.8830]],

        [[0.0136, 0.1223, 0.4755,  ..., 0.0136, 0.0679, 0.0543],
         [0.1087, 0.1766, 0.6113,  ..., 0.4890, 0.8558, 0.9373],
         [0.6521, 0.3804, 0.0136,  ..., 0.4483, 0.7743, 1.

tensor([[[0.0815, 0.0272, 0.0272,  ..., 0.2853, 0.2038, 0.1223],
         [0.0679, 0.0000, 0.0000,  ..., 0.2174, 0.1630, 0.1087],
         [0.1766, 0.1358, 0.2038,  ..., 0.2717, 0.2038, 0.2581],
         ...,
         [0.1902, 0.2174, 0.1358,  ..., 0.3804, 0.6385, 0.2989],
         [0.1902, 0.4483, 0.5977,  ..., 0.5434, 1.0732, 0.5162],
         [0.1494, 0.6249, 1.0324,  ..., 0.7200, 1.3449, 0.8423]],

        [[0.0272, 0.0815, 0.0679,  ..., 0.1630, 0.0815, 0.0000],
         [0.0136, 0.0815, 0.0951,  ..., 0.0951, 0.0136, 0.0136],
         [0.0408, 0.2309, 0.2853,  ..., 0.0951, 0.0543, 0.1223],
         ...,
         [1.2905, 1.7524, 1.0053,  ..., 0.2989, 0.9509, 1.1004],
         [1.1547, 1.6573, 1.0189,  ..., 0.4619, 1.1004, 0.9645],
         [1.4943, 2.0377, 1.6709,  ..., 0.6385, 1.2090, 0.8694]],

        [[0.1087, 0.2581, 0.5977,  ..., 0.0815, 0.0136, 0.0679],
         [0.1223, 0.4619, 0.8694,  ..., 0.1766, 0.0000, 0.0815],
         [0.3260, 0.6385, 0.9102,  ..., 0.7879, 0.6113, 0.

In [81]:
dist.size()

torch.Size([816, 28, 28])

In [86]:
euclid_distance[0].size()
dist.size()
euclid_distance[0] = dist[814]
print(euclid_distance[0])

tensor([[0.1630, 0.2309, 0.2853, 0.3396, 0.4211, 0.5298, 0.6385, 0.6928, 0.6792,
         0.5570, 0.8830, 0.6521, 0.4075, 0.2309, 0.3532, 0.4483, 0.6928, 0.5298,
         0.7064, 0.7607, 0.7472, 0.7472, 0.6928, 0.6657, 0.6521, 0.6657, 0.6792,
         0.7064],
        [0.8423, 1.1275, 1.1955, 1.2362, 1.2905, 1.2090, 0.9509, 0.7607, 0.6249,
         0.4347, 0.7336, 0.7607, 0.4483, 0.2717, 0.3940, 0.6385, 0.7743, 0.6521,
         0.8694, 0.8558, 0.9509, 1.2905, 1.5758, 1.4671, 1.1955, 1.0868, 1.2362,
         1.3721],
        [1.5622, 1.7253, 1.7117, 1.6709, 1.5894, 1.4807, 1.2770, 0.9373, 0.8151,
         0.7336, 0.6113, 0.5706, 0.1902, 0.0408, 0.1087, 0.5162, 0.5706, 0.7879,
         1.1139, 1.2090, 1.4807, 1.8068, 1.8339, 1.6573, 1.6438, 1.6845, 1.6845,
         1.5894],
        [1.4264, 1.2498, 1.1275, 1.0868, 0.8423, 0.7472, 1.0324, 1.2634, 1.2090,
         0.7064, 0.1087, 0.1358, 0.1223, 0.0543, 0.0815, 0.4075, 0.1630, 0.2853,
         0.8287, 1.0596, 1.1275, 1.2226, 1.1955, 1.4264

In [52]:
euclid_distance = torch.zeros(816, 28, 28)
print(euclid_distance)
euclid_distance.size()

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0., 

torch.Size([816, 28, 28])