# MA_DL : Codierung 64bit

In [1]:
!pip install facenet-pytorch # due to issues with the google cloud service



In [2]:
import torch
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.models as models
from PIL import Image

from facenet_pytorch import MTCNN, InceptionResnetV1

dtype = torch.cuda.FloatTensor

## Model

In [3]:
class MultiLabel(nn.Sequential):
    def __init__(self, input_dim=512, output_dim=64):
        super(MultiLabel, self).__init__()
        self.l1 = nn.Linear(input_dim, input_dim)
        self.l2 = nn.Linear(input_dim, input_dim)
        self.l5 = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = self.l5(x)
        return torch.sigmoid(x)

In [4]:
model = MultiLabel()
# model.load_state_dict(torch.load('dl64b_v4_new_414.pt'))
model.cuda()

MultiLabel(
  (l1): Linear(in_features=512, out_features=512, bias=True)
  (l2): Linear(in_features=512, out_features=512, bias=True)
  (l5): Linear(in_features=512, out_features=64, bias=True)
)

## Data

In [20]:
scaler = transforms.Resize((224, 224))
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
to_tensor = transforms.ToTensor()
to_image = transforms.ToPILImage()
transform = transforms.Compose([
    scaler, 
    to_tensor,
])

In [22]:
train_ds = datasets.ImageFolder('../../../data/small_data_new/train/', transform=transform)
train_ds.idx_to_class = {i:c for c, i in train_ds.class_to_idx.items()}
train_dl = DataLoader(train_ds, batch_size=1, num_workers=4)

valid_ds = datasets.ImageFolder('../../../data/small_data_new/valid/', transform=transform)
valid_ds.idx_to_class = {i:c for c, i in valid_ds.class_to_idx.items()}
valid_dl = DataLoader(valid_ds, batch_size=1, num_workers=4)

## Gesichtserkennung

In [7]:
mtcnn = MTCNN(
    image_size=160,
    thresholds=[0.6, 0.7, 0.7],
#        factor=0.709,
#        prewhiten=True,
    keep_all=True,
    device=torch.device('cuda')
    )

## Embeddings 

In [8]:
# Alternativ resnet
face_resnet = InceptionResnetV1(pretrained='casia-webface').eval()

## Training

In [9]:
def calculate_mistakes(pred, target):
    #pred = pred[0]
#    print('calculating mistakes : ')
#    print('   pred: ' + str(pred.shape))
#    print('   target: ' + str(target.shape))
    
    mistakes = 0
    
    if len(pred) != len(target):
        raise Exception('sizes of both tensors must match')
        
    for x,y in zip(pred, target):
        if round(x.item()) != y.item():
            mistakes = mistakes + 1
        
    return mistakes

In [10]:
loss_func = nn.BCELoss() #MultiLabelSoftMarginLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, weight_decay=1e-5, momentum=0.9) #weight_decay=1e-4 

In [11]:
import matplotlib.pyplot as plt

In [23]:
epochs = 1000
for e in range(epochs):
    print("======================")
    print("Epoch : " + str(e))
    epoch_loss = 0
    epoch_mistakes = 0
    train_size = len(train_dl)
    
    # activate train mode
    model.train()

    for index, (data, target) in enumerate(train_dl):
        faces, prob = mtcnn(to_image(data[0]), return_prob=True)
                
        target_t = torch.cuda.FloatTensor([int(x) for x in train_ds.idx_to_class[target[0].item()]])

        emb = face_resnet(faces)    
        data_v   = Variable(emb[0], requires_grad=False).type(dtype)
        target_v = Variable(target_t, requires_grad=False).type(dtype)
        
        # forward
#        print('datav: ' + str(data_v.shape))
        pred = model.forward(data_v)
#        print('pred: ' + str(pred.shape))
        # zero grads
        optimizer.zero_grad()
        # calculate loss
        loss = loss_func(pred, target_v.float())
        epoch_loss = epoch_loss + loss
        # back prop
        loss.backward()
        optimizer.step()
        
        epoch_mistakes = epoch_mistakes + calculate_mistakes(pred, target_t)
        
    print("Train loss : " + str(epoch_loss.item()))
    print("Average errors : " + str(epoch_mistakes/train_size))
    
    # activate eval mode
    model.eval()
    
    valid_loss = 0
    valid_mistakes = 0
    valid_size = len(valid_dl)
    for index, (data, target) in enumerate(valid_dl):
        faces, prob = mtcnn(to_image(data[0]), return_prob=True)
        target_t = torch.FloatTensor([int(x) for x in valid_ds.idx_to_class[target[0].item()]])
        
        emb = face_resnet(faces) 
        data_v   = Variable(emb[0], requires_grad=False).type(dtype)
        target_v = Variable(target_t, requires_grad=False).type(dtype)

        pred = model.forward(data_v)
        loss = loss_func(pred, target_v.float())
        valid_loss = valid_loss + loss
        valid_mistakes = valid_mistakes + calculate_mistakes(pred, target_t)
        
        
    print("Valid loss : " + str(valid_loss.item()))
    print("Average valid errors : " + str(valid_mistakes/valid_size))

Epoch : 0
tensor([[[[0.0118, 0.0118, 0.0118,  ..., 0.0078, 0.0118, 0.0118],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.1922, 0.2078, 0.2392,  ..., 0.2392, 0.2431, 0.2471],
          [0.2353, 0.2275, 0.2392,  ..., 0.2431, 0.2471, 0.2510],
          [0.2588, 0.2431, 0.2392,  ..., 0.2431, 0.2471, 0.2510]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0039, 0.0039, 0.0039,  ..., 0.0000, 0.0000, 0.0000],
          [0.0039, 0.0039, 0.0039,  ..., 0.0000, 0.0039, 0.0039],
          ...,
          [0.1765, 0.1804, 0.1882,  ..., 0.1843, 0.1882, 0.1922],
          [0.2078, 0.1843, 0.1843,  ..., 0.1882, 0.1922, 0.1961],
          [0.2157, 0.1843, 0.1804,  ..., 0.1882, 0.1922, 0.1961]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0039, 0.0039, 0.0039,  ..., 0.0078, 0.0078, 0.0078],
          [0.0118, 0.0118, 0.011

tensor([[[[0.4118, 0.4000, 0.3765,  ..., 0.3922, 0.4078, 0.3961],
          [0.4078, 0.3961, 0.3765,  ..., 0.3961, 0.4157, 0.4039],
          [0.4118, 0.4039, 0.3922,  ..., 0.3922, 0.4078, 0.4039],
          ...,
          [0.0588, 0.0863, 0.1569,  ..., 0.0667, 0.0588, 0.0471],
          [0.0588, 0.0980, 0.1412,  ..., 0.0431, 0.0157, 0.0157],
          [0.0431, 0.0588, 0.0588,  ..., 0.0314, 0.0000, 0.0118]],

         [[0.3216, 0.3098, 0.2863,  ..., 0.2392, 0.2549, 0.2431],
          [0.3176, 0.3059, 0.2863,  ..., 0.2431, 0.2627, 0.2549],
          [0.3216, 0.3137, 0.2980,  ..., 0.2353, 0.2549, 0.2549],
          ...,
          [0.2902, 0.3020, 0.3412,  ..., 0.1569, 0.1569, 0.1412],
          [0.2863, 0.3098, 0.3255,  ..., 0.1373, 0.1059, 0.1098],
          [0.2706, 0.2706, 0.2392,  ..., 0.1176, 0.0941, 0.1059]],

         [[0.1882, 0.1843, 0.1725,  ..., 0.0863, 0.1020, 0.0902],
          [0.1843, 0.1804, 0.1725,  ..., 0.0902, 0.1098, 0.1020],
          [0.1843, 0.1843, 0.1804,  ..., 0

tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.1412, 0.1373, 0.1176,  ..., 0.5020, 0.5490, 0.5843],
          [0.1490, 0.1686, 0.1255,  ..., 0.5608, 0.5882, 0.5961],
          [0.1333, 0.1294, 0.1020,  ..., 0.6275, 0.6157, 0.6078]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0392, 0.0392, 0.0392,  ..., 0.5961, 0.6353, 0.6627],
          [0.0471, 0.0706, 0.0471,  ..., 0.6510, 0.6745, 0.6706],
          [0.0314, 0.0314, 0.0196,  ..., 0.7176, 0.6980, 0.6824]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0

tensor([[[[0.1255, 0.1176, 0.1020,  ..., 0.0000, 0.0078, 0.0078],
          [0.1608, 0.1529, 0.1412,  ..., 0.0000, 0.0078, 0.0078],
          [0.2275, 0.2118, 0.2000,  ..., 0.0000, 0.0078, 0.0078],
          ...,
          [0.0941, 0.0902, 0.0863,  ..., 0.0784, 0.0745, 0.0745],
          [0.0902, 0.0863, 0.0824,  ..., 0.0784, 0.0784, 0.0745],
          [0.0902, 0.0863, 0.0824,  ..., 0.0784, 0.0784, 0.0745]],

         [[0.0627, 0.0588, 0.0510,  ..., 0.0118, 0.0157, 0.0157],
          [0.0549, 0.0549, 0.0510,  ..., 0.0118, 0.0157, 0.0157],
          [0.0431, 0.0431, 0.0392,  ..., 0.0118, 0.0157, 0.0157],
          ...,
          [0.1137, 0.1098, 0.1059,  ..., 0.0980, 0.0941, 0.0941],
          [0.1098, 0.1059, 0.1020,  ..., 0.0980, 0.0980, 0.0941],
          [0.1098, 0.1059, 0.1020,  ..., 0.0980, 0.0980, 0.0941]],

         [[0.0745, 0.0706, 0.0627,  ..., 0.0078, 0.0118, 0.0118],
          [0.0824, 0.0784, 0.0706,  ..., 0.0078, 0.0118, 0.0118],
          [0.0824, 0.0784, 0.0706,  ..., 0

tensor([[[[0.0039, 0.0000, 0.0000,  ..., 0.0000, 0.0039, 0.0039],
          [0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0078, 0.0078],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0863, 0.0902, 0.0980,  ..., 0.0980, 0.1529, 0.2431],
          [0.0902, 0.0980, 0.1020,  ..., 0.1529, 0.2196, 0.3216],
          [0.0941, 0.0980, 0.1020,  ..., 0.2039, 0.2706, 0.3569]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0784, 0.0824, 0.0902,  ..., 0.1176, 0.1725, 0.2627],
          [0.0824, 0.0902, 0.0941,  ..., 0.1725, 0.2471, 0.3490],
          [0.0863, 0.0902, 0.0941,  ..., 0.2235, 0.2941, 0.3843]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0039, 0.0039],
          [0.0000, 0.0000, 0.0000,  ..., 0

tensor([[[[-0.6211, -0.6211, -0.6055,  ...,  0.4102,  0.4102,  0.4102],
          [-0.4414, -0.4414, -0.4258,  ...,  0.4727,  0.4727,  0.4727],
          [-0.2148, -0.2148, -0.1992,  ...,  0.5273,  0.5273,  0.5273],
          ...,
          [ 0.3477,  0.3477,  0.3398,  ..., -0.6289, -0.6211, -0.6211],
          [ 0.3555,  0.3555,  0.3477,  ..., -0.3867, -0.3711, -0.3711],
          [ 0.3633,  0.3633,  0.3555,  ..., -0.1602, -0.1445, -0.1445]],

         [[-0.7695, -0.7695, -0.7461,  ..., -0.0117, -0.0117, -0.0117],
          [-0.6367, -0.6367, -0.6211,  ...,  0.0508,  0.0508,  0.0508],
          [-0.4648, -0.4648, -0.4570,  ...,  0.1055,  0.1055,  0.1055],
          ...,
          [ 0.4414,  0.4414,  0.4414,  ..., -0.7383, -0.7227, -0.7227],
          [ 0.4570,  0.4570,  0.4570,  ..., -0.5039, -0.4883, -0.4883],
          [ 0.4648,  0.4648,  0.4648,  ..., -0.2852, -0.2695, -0.2695]],

         [[-0.8242, -0.8242, -0.8008,  ..., -0.2773, -0.2773, -0.2773],
          [-0.6914, -0.6914, -

tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.5843, 0.5686, 0.5490,  ..., 0.6980, 0.7647, 0.8118],
          [0.5922, 0.5686, 0.5490,  ..., 0.7725, 0.8275, 0.8549],
          [0.6000, 0.5686, 0.5529,  ..., 0.8196, 0.8588, 0.8745]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.5843, 0.5686, 0.5490,  ..., 0.6980, 0.7647, 0.8118],
          [0.5922, 0.5686, 0.5490,  ..., 0.7725, 0.8275, 0.8549],
          [0.6000, 0.5686, 0.5529,  ..., 0.8196, 0.8588, 0.8745]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0

tensor([[[[0.7686, 0.7490, 0.7098,  ..., 0.5804, 0.5373, 0.5216],
          [0.7137, 0.7020, 0.6824,  ..., 0.5961, 0.5529, 0.5373],
          [0.6980, 0.6980, 0.6941,  ..., 0.6078, 0.5686, 0.5529],
          ...,
          [0.6588, 0.6588, 0.6588,  ..., 0.8353, 0.8314, 0.8235],
          [0.6588, 0.6627, 0.6588,  ..., 0.8353, 0.8392, 0.8314],
          [0.6627, 0.6627, 0.6627,  ..., 0.8392, 0.8471, 0.8392]],

         [[0.7882, 0.7686, 0.7294,  ..., 0.4706, 0.4314, 0.4157],
          [0.7333, 0.7216, 0.7020,  ..., 0.4863, 0.4471, 0.4314],
          [0.7176, 0.7176, 0.7137,  ..., 0.4980, 0.4627, 0.4471],
          ...,
          [0.5216, 0.5216, 0.5176,  ..., 0.8392, 0.8353, 0.8314],
          [0.5176, 0.5216, 0.5176,  ..., 0.8510, 0.8510, 0.8431],
          [0.5216, 0.5216, 0.5176,  ..., 0.8549, 0.8627, 0.8549]],

         [[0.7608, 0.7412, 0.7020,  ..., 0.3882, 0.3529, 0.3412],
          [0.7059, 0.6941, 0.6745,  ..., 0.4039, 0.3686, 0.3569],
          [0.6902, 0.6902, 0.6863,  ..., 0

tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0039, 0.0039],
          [0.0000, 0.0000, 0.0000,  ..., 0.0078, 0.0078, 0.0078],
          ...,
          [0.2235, 0.2235, 0.2157,  ..., 0.3765, 0.3529, 0.3333],
          [0.2000, 0.2078, 0.2078,  ..., 0.3765, 0.3490, 0.3333],
          [0.1922, 0.2000, 0.2039,  ..., 0.3647, 0.3490, 0.3373]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0039, 0.0039, 0.0039,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.2275, 0.2275, 0.2157,  ..., 0.3255, 0.3059, 0.2863],
          [0.2039, 0.2118, 0.2039,  ..., 0.3255, 0.2980, 0.2863],
          [0.1961, 0.2039, 0.2039,  ..., 0.3176, 0.2980, 0.2863]],

         [[0.0078, 0.0078, 0.0078,  ..., 0.0078, 0.0078, 0.0078],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0

tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0039, 0.0039],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0039, 0.0039, 0.0039,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.9961, 0.9961, 0.9882,  ..., 0.1490, 0.1412, 0.1412],
          [1.0000, 0.9961, 0.9804,  ..., 0.1412, 0.1333, 0.1608],
          [1.0000, 0.9882, 0.9686,  ..., 0.1373, 0.1373, 0.1765]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.9961, 0.9961, 0.9882,  ..., 0.4627, 0.4588, 0.4588],
          [1.0000, 0.9961, 0.9804,  ..., 0.4549, 0.4510, 0.4784],
          [1.0000, 0.9882, 0.9725,  ..., 0.4549, 0.4588, 0.4941]],

         [[0.0314, 0.0314, 0.0314,  ..., 0.0000, 0.0000, 0.0000],
          [0.0118, 0.0118, 0.0118,  ..., 0.0039, 0.0039, 0.0039],
          [0.0000, 0.0000, 0.0000,  ..., 0

tensor([[[[0.0510, 0.0510, 0.0471,  ..., 0.0667, 0.0667, 0.0706],
          [0.0471, 0.0471, 0.0510,  ..., 0.0745, 0.0706, 0.0745],
          [0.0510, 0.0510, 0.0510,  ..., 0.0706, 0.0667, 0.0706],
          ...,
          [0.1176, 0.1216, 0.1294,  ..., 0.1490, 0.1529, 0.1569],
          [0.1255, 0.1255, 0.1294,  ..., 0.1216, 0.1373, 0.1451],
          [0.1294, 0.1294, 0.1294,  ..., 0.1137, 0.1294, 0.1412]],

         [[0.0510, 0.0510, 0.0471,  ..., 0.0667, 0.0667, 0.0706],
          [0.0471, 0.0471, 0.0510,  ..., 0.0745, 0.0706, 0.0745],
          [0.0510, 0.0510, 0.0510,  ..., 0.0706, 0.0667, 0.0706],
          ...,
          [0.1216, 0.1255, 0.1333,  ..., 0.1725, 0.1804, 0.1843],
          [0.1294, 0.1294, 0.1333,  ..., 0.1451, 0.1647, 0.1725],
          [0.1333, 0.1333, 0.1333,  ..., 0.1294, 0.1569, 0.1686]],

         [[0.0588, 0.0588, 0.0549,  ..., 0.0745, 0.0745, 0.0784],
          [0.0549, 0.0549, 0.0588,  ..., 0.0824, 0.0784, 0.0824],
          [0.0588, 0.0588, 0.0588,  ..., 0

tensor([[[[0.7647, 0.7490, 0.7333,  ..., 0.3451, 0.3765, 0.4157],
          [0.7412, 0.7294, 0.7176,  ..., 0.3529, 0.3373, 0.3569],
          [0.7255, 0.7137, 0.7059,  ..., 0.3490, 0.3176, 0.3137],
          ...,
          [0.0471, 0.0510, 0.0510,  ..., 0.0549, 0.0549, 0.0510],
          [0.0471, 0.0471, 0.0510,  ..., 0.0549, 0.0510, 0.0510],
          [0.0471, 0.0510, 0.0510,  ..., 0.0549, 0.0510, 0.0510]],

         [[0.6235, 0.6078, 0.5843,  ..., 0.2196, 0.2471, 0.2941],
          [0.5961, 0.5804, 0.5686,  ..., 0.2235, 0.2078, 0.2275],
          [0.5725, 0.5608, 0.5490,  ..., 0.2196, 0.1882, 0.1843],
          ...,
          [0.0431, 0.0471, 0.0471,  ..., 0.0510, 0.0510, 0.0471],
          [0.0431, 0.0431, 0.0471,  ..., 0.0510, 0.0471, 0.0471],
          [0.0431, 0.0471, 0.0471,  ..., 0.0510, 0.0471, 0.0471]],

         [[0.3176, 0.3059, 0.2941,  ..., 0.0745, 0.1137, 0.1608],
          [0.3059, 0.2941, 0.2863,  ..., 0.0824, 0.0784, 0.0980],
          [0.2980, 0.2902, 0.2824,  ..., 0

tensor([[[[ 0.3867,  0.4180,  0.4648,  ..., -0.6289, -0.6133, -0.6055],
          [ 0.5664,  0.5742,  0.5898,  ..., -0.6289, -0.6211, -0.6133],
          [ 0.6367,  0.6367,  0.6289,  ..., -0.6289, -0.6211, -0.6133],
          ...,
          [ 0.7773,  0.7773,  0.7695,  ..., -0.7383, -0.7461, -0.7461],
          [ 0.7773,  0.7773,  0.7695,  ..., -0.7305, -0.7461, -0.7539],
          [ 0.7695,  0.7695,  0.7773,  ..., -0.8086, -0.8008, -0.7930]],

         [[ 0.2695,  0.2930,  0.3320,  ..., -0.6680, -0.6523, -0.6445],
          [ 0.4414,  0.4414,  0.4414,  ..., -0.6680, -0.6602, -0.6523],
          [ 0.4883,  0.4805,  0.4570,  ..., -0.6680, -0.6602, -0.6523],
          ...,
          [ 0.7695,  0.7695,  0.7617,  ..., -0.7461, -0.7539, -0.7539],
          [ 0.7695,  0.7695,  0.7617,  ..., -0.7461, -0.7539, -0.7617],
          [ 0.7617,  0.7617,  0.7695,  ..., -0.8242, -0.8164, -0.8086]],

         [[ 0.0117,  0.0273,  0.0508,  ..., -0.6992, -0.6836, -0.6758],
          [ 0.1680,  0.1602,  

tensor([[[[ 0.1289,  0.1211,  0.1133,  ..., -0.7070, -0.7148, -0.7227],
          [ 0.1133,  0.1133,  0.1133,  ..., -0.7148, -0.7070, -0.7148],
          [ 0.0898,  0.0977,  0.1133,  ..., -0.7148, -0.7070, -0.6992],
          ...,
          [ 0.8320,  0.8320,  0.8164,  ...,  0.1133,  0.1133,  0.1133],
          [ 0.8398,  0.8320,  0.8086,  ...,  0.1133,  0.1133,  0.1133],
          [ 0.8398,  0.8320,  0.8164,  ...,  0.1133,  0.1055,  0.1055]],

         [[ 0.1367,  0.1445,  0.1523,  ..., -0.6992, -0.7070, -0.7148],
          [ 0.1367,  0.1445,  0.1523,  ..., -0.7070, -0.7070, -0.7070],
          [ 0.1445,  0.1445,  0.1523,  ..., -0.7227, -0.7070, -0.7070],
          ...,
          [ 0.8477,  0.8398,  0.8164,  ..., -0.1992, -0.2070, -0.2070],
          [ 0.8398,  0.8320,  0.8086,  ..., -0.2148, -0.2227, -0.2227],
          [ 0.8398,  0.8320,  0.8164,  ..., -0.2227, -0.2305, -0.2305]],

         [[ 0.0977,  0.0977,  0.1055,  ..., -0.7461, -0.7539, -0.7539],
          [ 0.0898,  0.0898,  

tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.9255, 0.9333, 0.9451,  ..., 0.8588, 0.7529, 0.6353],
          [0.9373, 0.9412, 0.9451,  ..., 0.8824, 0.7725, 0.6196],
          [0.9451, 0.9490, 0.9490,  ..., 0.9020, 0.7451, 0.5686]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.9412, 0.9490, 0.9608,  ..., 0.8824, 0.7765, 0.6588],
          [0.9608, 0.9686, 0.9725,  ..., 0.9059, 0.7961, 0.6431],
          [0.9725, 0.9765, 0.9765,  ..., 0.9137, 0.7647, 0.5922]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0

tensor([[[[0.0039, 0.0039, 0.0039,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0039, 0.0000],
          [0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039],
          ...,
          [0.4314, 0.4353, 0.4431,  ..., 0.0157, 0.0078, 0.0118],
          [0.4431, 0.4588, 0.4745,  ..., 0.0627, 0.0431, 0.0431],
          [0.4588, 0.4745, 0.4941,  ..., 0.0824, 0.0863, 0.0941]],

         [[0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039],
          [0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0078, 0.0078],
          [0.0039, 0.0039, 0.0039,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.4000, 0.4000, 0.4078,  ..., 0.0235, 0.0118, 0.0078],
          [0.4118, 0.4235, 0.4392,  ..., 0.0627, 0.0471, 0.0431],
          [0.4275, 0.4392, 0.4588,  ..., 0.0824, 0.0863, 0.0941]],

         [[0.0039, 0.0039, 0.0039,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0078, 0.0039],
          [0.0039, 0.0039, 0.0039,  ..., 0

tensor([[[[0.6039, 0.6000, 0.5961,  ..., 0.2510, 0.2510, 0.2510],
          [0.5176, 0.5137, 0.5137,  ..., 0.2471, 0.2471, 0.2471],
          [0.4235, 0.4196, 0.4196,  ..., 0.2471, 0.2471, 0.2471],
          ...,
          [0.2000, 0.2039, 0.2039,  ..., 0.4863, 0.4902, 0.4902],
          [0.1961, 0.2078, 0.2078,  ..., 0.4941, 0.4941, 0.4941],
          [0.1961, 0.2000, 0.2000,  ..., 0.4902, 0.4941, 0.4941]],

         [[0.6314, 0.6275, 0.6235,  ..., 0.3216, 0.3216, 0.3216],
          [0.5529, 0.5490, 0.5451,  ..., 0.3176, 0.3176, 0.3176],
          [0.4706, 0.4667, 0.4667,  ..., 0.3176, 0.3176, 0.3176],
          ...,
          [0.2784, 0.2824, 0.2824,  ..., 0.4549, 0.4510, 0.4431],
          [0.2863, 0.2863, 0.2863,  ..., 0.4588, 0.4588, 0.4588],
          [0.2824, 0.2863, 0.2863,  ..., 0.4510, 0.4588, 0.4588]],

         [[0.6784, 0.6745, 0.6784,  ..., 0.4784, 0.4784, 0.4784],
          [0.6078, 0.6039, 0.6118,  ..., 0.4745, 0.4745, 0.4745],
          [0.5412, 0.5373, 0.5451,  ..., 0

tensor([[[[0.7059, 0.7098, 0.7137,  ..., 0.7020, 0.7059, 0.7098],
          [0.7098, 0.7098, 0.7137,  ..., 0.7020, 0.7098, 0.7098],
          [0.7098, 0.7137, 0.7137,  ..., 0.6980, 0.7059, 0.7098],
          ...,
          [0.0039, 0.0039, 0.0039,  ..., 0.5569, 0.5490, 0.5412],
          [0.0039, 0.0039, 0.0039,  ..., 0.5569, 0.5451, 0.5412],
          [0.0039, 0.0039, 0.0039,  ..., 0.5569, 0.5451, 0.5412]],

         [[0.6627, 0.6627, 0.6667,  ..., 0.6353, 0.6431, 0.6471],
          [0.6627, 0.6627, 0.6667,  ..., 0.6353, 0.6431, 0.6471],
          [0.6588, 0.6627, 0.6627,  ..., 0.6275, 0.6392, 0.6431],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.6392, 0.6314, 0.6235],
          [0.0000, 0.0000, 0.0000,  ..., 0.6392, 0.6275, 0.6235],
          [0.0000, 0.0000, 0.0000,  ..., 0.6392, 0.6275, 0.6235]],

         [[0.5765, 0.5765, 0.5804,  ..., 0.5333, 0.5412, 0.5451],
          [0.5765, 0.5765, 0.5804,  ..., 0.5333, 0.5412, 0.5451],
          [0.5686, 0.5765, 0.5725,  ..., 0

tensor([[[[0.8784, 0.8784, 0.8784,  ..., 0.8980, 0.8980, 0.8980],
          [0.8784, 0.8784, 0.8784,  ..., 0.8980, 0.8980, 0.8980],
          [0.8784, 0.8784, 0.8784,  ..., 0.8980, 0.8980, 0.8980],
          ...,
          [0.5451, 0.5529, 0.5647,  ..., 0.0667, 0.0667, 0.0667],
          [0.5608, 0.5725, 0.5843,  ..., 0.0667, 0.0667, 0.0667],
          [0.5686, 0.5804, 0.5882,  ..., 0.0667, 0.0667, 0.0667]],

         [[0.8980, 0.8980, 0.8980,  ..., 0.9020, 0.9020, 0.9020],
          [0.8980, 0.8980, 0.8980,  ..., 0.9020, 0.9020, 0.9020],
          [0.8980, 0.8980, 0.8980,  ..., 0.9020, 0.9020, 0.9020],
          ...,
          [0.5294, 0.5373, 0.5490,  ..., 0.0902, 0.0902, 0.0902],
          [0.5412, 0.5569, 0.5647,  ..., 0.0902, 0.0902, 0.0902],
          [0.5490, 0.5647, 0.5725,  ..., 0.0902, 0.0902, 0.0902]],

         [[0.8745, 0.8745, 0.8745,  ..., 0.8824, 0.8824, 0.8824],
          [0.8745, 0.8745, 0.8745,  ..., 0.8824, 0.8824, 0.8824],
          [0.8745, 0.8745, 0.8745,  ..., 0

tensor([[[[0.1412, 0.1765, 0.2118,  ..., 0.2353, 0.2314, 0.2314],
          [0.1804, 0.2275, 0.3176,  ..., 0.2392, 0.2314, 0.2314],
          [0.2157, 0.3098, 0.4667,  ..., 0.2392, 0.2314, 0.2314],
          ...,
          [0.1725, 0.1725, 0.1725,  ..., 0.8549, 0.8588, 0.8588],
          [0.1686, 0.1686, 0.1725,  ..., 0.8549, 0.8627, 0.8627],
          [0.1686, 0.1686, 0.1725,  ..., 0.8549, 0.8667, 0.8667]],

         [[0.2196, 0.2549, 0.2902,  ..., 0.3451, 0.3412, 0.3412],
          [0.2588, 0.3020, 0.3922,  ..., 0.3490, 0.3412, 0.3412],
          [0.2863, 0.3804, 0.5373,  ..., 0.3490, 0.3412, 0.3412],
          ...,
          [0.2667, 0.2667, 0.2667,  ..., 0.8745, 0.8745, 0.8745],
          [0.2627, 0.2627, 0.2667,  ..., 0.8745, 0.8824, 0.8824],
          [0.2627, 0.2627, 0.2667,  ..., 0.8745, 0.8863, 0.8863]],

         [[0.1176, 0.1569, 0.1961,  ..., 0.1451, 0.1412, 0.1412],
          [0.1608, 0.2078, 0.3020,  ..., 0.1490, 0.1412, 0.1412],
          [0.2000, 0.2980, 0.4549,  ..., 0

tensor([[[[0.0039, 0.0039, 0.0039,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.1137, 0.0941, 0.0510,  ..., 0.0353, 0.0745, 0.0902],
          [0.0980, 0.0784, 0.0431,  ..., 0.0510, 0.0863, 0.1255],
          [0.0941, 0.0745, 0.0392,  ..., 0.0588, 0.0980, 0.1569]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0039, 0.0039],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0039, 0.0039, 0.0039,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.1373, 0.1176, 0.0745,  ..., 0.0510, 0.0784, 0.0941],
          [0.1216, 0.1020, 0.0667,  ..., 0.0588, 0.0902, 0.1255],
          [0.1176, 0.0980, 0.0627,  ..., 0.0627, 0.1020, 0.1451]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0039, 0.0039, 0.0039,  ..., 0

tensor([[[[0.4784, 0.4784, 0.4784,  ..., 0.4510, 0.4510, 0.4510],
          [0.4784, 0.4784, 0.4784,  ..., 0.4549, 0.4510, 0.4510],
          [0.4745, 0.4745, 0.4784,  ..., 0.4667, 0.4627, 0.4627],
          ...,
          [0.6902, 0.6941, 0.7020,  ..., 0.9725, 0.9686, 0.9686],
          [0.7059, 0.7020, 0.6980,  ..., 0.9686, 0.9686, 0.9686],
          [0.7098, 0.7020, 0.6980,  ..., 0.9686, 0.9686, 0.9686]],

         [[0.5020, 0.5020, 0.5020,  ..., 0.5098, 0.5098, 0.5098],
          [0.5020, 0.5020, 0.5020,  ..., 0.5020, 0.5020, 0.5020],
          [0.4980, 0.4980, 0.5020,  ..., 0.4980, 0.4980, 0.4980],
          ...,
          [0.6196, 0.6235, 0.6314,  ..., 0.9804, 0.9765, 0.9765],
          [0.6353, 0.6314, 0.6275,  ..., 0.9765, 0.9765, 0.9765],
          [0.6392, 0.6314, 0.6275,  ..., 0.9765, 0.9765, 0.9765]],

         [[0.5020, 0.5020, 0.5020,  ..., 0.5294, 0.5294, 0.5294],
          [0.5020, 0.5020, 0.5020,  ..., 0.5255, 0.5255, 0.5255],
          [0.4980, 0.4980, 0.5020,  ..., 0

tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039],
          [0.0235, 0.0235, 0.0235,  ..., 0.0235, 0.0196, 0.0196],
          ...,
          [0.1176, 0.1176, 0.1137,  ..., 0.7490, 0.8000, 0.7961],
          [0.1176, 0.1137, 0.1137,  ..., 0.7686, 0.8039, 0.7922],
          [0.1176, 0.1137, 0.1137,  ..., 0.7922, 0.8078, 0.7922]],

         [[0.0078, 0.0078, 0.0078,  ..., 0.0078, 0.0078, 0.0078],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0902, 0.0902, 0.0863,  ..., 0.6745, 0.7216, 0.7137],
          [0.0941, 0.0902, 0.0902,  ..., 0.6941, 0.7255, 0.7098],
          [0.0941, 0.0902, 0.0902,  ..., 0.7176, 0.7294, 0.7098]],

         [[0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039],
          [0.0196, 0.0196, 0.0196,  ..., 0.0196, 0.0157, 0.0157],
          [0.0196, 0.0196, 0.0196,  ..., 0

tensor([[[[0.7216, 0.7373, 0.7373,  ..., 0.2706, 0.2706, 0.2706],
          [0.7098, 0.7373, 0.7412,  ..., 0.2706, 0.2706, 0.2706],
          [0.7020, 0.7333, 0.7412,  ..., 0.2706, 0.2706, 0.2706],
          ...,
          [0.1255, 0.1255, 0.1255,  ..., 0.1059, 0.1059, 0.1059],
          [0.1255, 0.1333, 0.1333,  ..., 0.0863, 0.0863, 0.0863],
          [0.1255, 0.1294, 0.1294,  ..., 0.0745, 0.0745, 0.0745]],

         [[0.7725, 0.7804, 0.7725,  ..., 0.0941, 0.0941, 0.0941],
          [0.7647, 0.7804, 0.7725,  ..., 0.0941, 0.0941, 0.0941],
          [0.7529, 0.7765, 0.7765,  ..., 0.0941, 0.0941, 0.0941],
          ...,
          [0.1922, 0.1922, 0.1922,  ..., 0.0980, 0.0980, 0.0980],
          [0.2000, 0.2000, 0.2000,  ..., 0.0902, 0.0902, 0.0902],
          [0.2000, 0.2039, 0.2000,  ..., 0.0784, 0.0784, 0.0784]],

         [[0.7765, 0.7882, 0.7804,  ..., 0.0118, 0.0118, 0.0118],
          [0.7686, 0.7882, 0.7843,  ..., 0.0118, 0.0118, 0.0118],
          [0.7569, 0.7843, 0.7843,  ..., 0

tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0039, 0.0039],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0706, 0.0627, 0.0588,  ..., 0.0784, 0.0824, 0.0824],
          [0.0706, 0.0667, 0.0667,  ..., 0.0824, 0.0784, 0.0745],
          [0.0706, 0.0667, 0.0667,  ..., 0.0863, 0.0824, 0.0784]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0039, 0.0039],
          ...,
          [0.1059, 0.0980, 0.0941,  ..., 0.1137, 0.1176, 0.1176],
          [0.1059, 0.1020, 0.1020,  ..., 0.1176, 0.1137, 0.1098],
          [0.1059, 0.1020, 0.1020,  ..., 0.1216, 0.1176, 0.1137]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0

tensor([[[[0.3490, 0.3490, 0.3490,  ..., 0.3216, 0.3216, 0.3216],
          [0.3647, 0.3608, 0.3569,  ..., 0.3255, 0.3255, 0.3255],
          [0.3804, 0.3765, 0.3686,  ..., 0.3333, 0.3294, 0.3294],
          ...,
          [0.2000, 0.2000, 0.1961,  ..., 0.1647, 0.1647, 0.1647],
          [0.1961, 0.1961, 0.1922,  ..., 0.1569, 0.1529, 0.1490],
          [0.1961, 0.1961, 0.1922,  ..., 0.1569, 0.1451, 0.1412]],

         [[0.2627, 0.2627, 0.2627,  ..., 0.2549, 0.2549, 0.2549],
          [0.2784, 0.2745, 0.2706,  ..., 0.2588, 0.2588, 0.2588],
          [0.2941, 0.2902, 0.2824,  ..., 0.2667, 0.2627, 0.2627],
          ...,
          [0.1765, 0.1765, 0.1725,  ..., 0.1451, 0.1451, 0.1451],
          [0.1725, 0.1725, 0.1686,  ..., 0.1373, 0.1333, 0.1294],
          [0.1725, 0.1725, 0.1686,  ..., 0.1373, 0.1255, 0.1216]],

         [[0.1725, 0.1725, 0.1725,  ..., 0.2157, 0.2157, 0.2157],
          [0.1882, 0.1843, 0.1843,  ..., 0.2196, 0.2196, 0.2196],
          [0.2118, 0.2078, 0.2000,  ..., 0

tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0039, 0.0039, 0.0039,  ..., 0.0941, 0.0706, 0.0667],
          [0.0039, 0.0039, 0.0039,  ..., 0.0902, 0.0745, 0.0706],
          [0.0039, 0.0039, 0.0039,  ..., 0.0902, 0.0863, 0.0824]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0039, 0.0039, 0.0039,  ..., 0.1451, 0.1294, 0.1255],
          [0.0039, 0.0039, 0.0039,  ..., 0.1412, 0.1294, 0.1294],
          [0.0039, 0.0039, 0.0039,  ..., 0.1412, 0.1412, 0.1412]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0

tensor([[[[0.4000, 0.2353, 0.1137,  ..., 0.0039, 0.0118, 0.0431],
          [0.3804, 0.2039, 0.1137,  ..., 0.0000, 0.0118, 0.0471],
          [0.3216, 0.1569, 0.1216,  ..., 0.0000, 0.0157, 0.0510],
          ...,
          [0.8235, 0.8196, 0.8078,  ..., 0.1294, 0.1216, 0.1216],
          [0.8196, 0.8196, 0.8196,  ..., 0.1529, 0.1294, 0.1098],
          [0.8157, 0.8235, 0.8235,  ..., 0.1490, 0.1216, 0.0941]],

         [[0.4157, 0.2549, 0.1294,  ..., 0.0118, 0.0196, 0.0510],
          [0.3922, 0.2235, 0.1294,  ..., 0.0078, 0.0196, 0.0549],
          [0.3333, 0.1725, 0.1373,  ..., 0.0078, 0.0235, 0.0588],
          ...,
          [0.3294, 0.3294, 0.3176,  ..., 0.1059, 0.0941, 0.0863],
          [0.3255, 0.3255, 0.3255,  ..., 0.1255, 0.0980, 0.0706],
          [0.3216, 0.3294, 0.3294,  ..., 0.1176, 0.0863, 0.0510]],

         [[0.2941, 0.1412, 0.0392,  ..., 0.0000, 0.0078, 0.0392],
          [0.2706, 0.1176, 0.0471,  ..., 0.0000, 0.0078, 0.0431],
          [0.2235, 0.0784, 0.0627,  ..., 0

tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.1961, 0.1843, 0.1804,  ..., 0.6431, 0.5882, 0.6549],
          [0.1961, 0.1882, 0.1843,  ..., 0.7059, 0.6275, 0.6745],
          [0.1961, 0.1882, 0.1843,  ..., 0.7098, 0.6471, 0.7020]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.2353, 0.2235, 0.2275,  ..., 0.4902, 0.4275, 0.4824],
          [0.2353, 0.2275, 0.2314,  ..., 0.5569, 0.4667, 0.5137],
          [0.2353, 0.2275, 0.2314,  ..., 0.5765, 0.4902, 0.5373]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0

tensor([[[[0.0039, 0.0039, 0.0039,  ..., 0.0000, 0.0000, 0.0000],
          [0.0078, 0.0078, 0.0078,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0863, 0.1098, 0.1333,  ..., 0.1725, 0.1882, 0.1882],
          [0.0941, 0.1098, 0.1294,  ..., 0.1765, 0.1804, 0.2039],
          [0.0980, 0.1098, 0.1255,  ..., 0.1725, 0.1765, 0.2039]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0039, 0.0039],
          [0.0039, 0.0039, 0.0039,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0863, 0.1059, 0.1294,  ..., 0.1608, 0.1765, 0.1765],
          [0.0941, 0.1098, 0.1255,  ..., 0.1647, 0.1686, 0.1922],
          [0.0980, 0.1098, 0.1255,  ..., 0.1608, 0.1647, 0.1922]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0039, 0.0039],
          [0.0000, 0.0000, 0.0000,  ..., 0.0157, 0.0157, 0.0196],
          [0.0039, 0.0039, 0.0039,  ..., 0

tensor([[[[ 0.0508,  0.0508,  0.0430,  ..., -0.1914, -0.1992, -0.1992],
          [ 0.0352,  0.0273,  0.0195,  ..., -0.2383, -0.2383, -0.2383],
          [ 0.0273,  0.0195,  0.0117,  ..., -0.2773, -0.2773, -0.2773],
          ...,
          [-0.6836, -0.6836, -0.6836,  ..., -0.0664, -0.0898, -0.0977],
          [-0.6914, -0.6914, -0.6914,  ..., -0.0430, -0.0664, -0.0742],
          [-0.6992, -0.6992, -0.6992,  ..., -0.0195, -0.0352, -0.0430]],

         [[-0.1602, -0.1680, -0.1758,  ..., -0.2930, -0.3008, -0.3086],
          [-0.1758, -0.1836, -0.1992,  ..., -0.3320, -0.3398, -0.3477],
          [-0.1836, -0.1914, -0.2070,  ..., -0.3711, -0.3789, -0.3789],
          ...,
          [-0.6367, -0.6367, -0.6367,  ..., -0.3867, -0.3945, -0.4023],
          [-0.6445, -0.6445, -0.6367,  ..., -0.3789, -0.3867, -0.3945],
          [-0.6523, -0.6523, -0.6445,  ..., -0.3711, -0.3789, -0.3789]],

         [[-0.3633, -0.3633, -0.3711,  ..., -0.4961, -0.5117, -0.5195],
          [-0.3867, -0.3867, -

tensor([[[[0.0078, 0.0078, 0.0078,  ..., 0.0039, 0.0000, 0.0000],
          [0.0039, 0.0039, 0.0039,  ..., 0.0000, 0.0039, 0.0039],
          [0.0039, 0.0039, 0.0039,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0275, 0.1725,  ..., 0.3333, 0.2941, 0.2706],
          [0.0118, 0.0941, 0.2863,  ..., 0.3255, 0.2863, 0.2510],
          [0.0784, 0.2118, 0.3882,  ..., 0.3137, 0.2471, 0.2078]],

         [[0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0039, 0.0039],
          [0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039],
          ...,
          [0.0706, 0.1176, 0.2863,  ..., 0.6235, 0.5843, 0.5647],
          [0.1059, 0.2000, 0.4118,  ..., 0.6157, 0.5765, 0.5451],
          [0.1843, 0.3294, 0.5255,  ..., 0.6039, 0.5373, 0.5020]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0118, 0.0039, 0.0039],
          [0.0039, 0.0039, 0.0039,  ..., 0

tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0039, 0.0039],
          [0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0039, 0.0039],
          [0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0039, 0.0039],
          ...,
          [0.0039, 0.0000, 0.0000,  ..., 0.6510, 0.6706, 0.6588],
          [0.0000, 0.0000, 0.0000,  ..., 0.6510, 0.6627, 0.6667],
          [0.0039, 0.0000, 0.0000,  ..., 0.6706, 0.6745, 0.6667]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0039, 0.0039],
          [0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0039, 0.0039],
          [0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0039, 0.0039],
          ...,
          [0.0039, 0.0039, 0.0078,  ..., 0.5216, 0.5412, 0.5294],
          [0.0039, 0.0039, 0.0078,  ..., 0.5216, 0.5333, 0.5373],
          [0.0039, 0.0039, 0.0078,  ..., 0.5333, 0.5373, 0.5373]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0039, 0.0039],
          [0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0039, 0.0039],
          [0.0000, 0.0000, 0.0000,  ..., 0

tensor([[[[0.4902, 0.4588, 0.4314,  ..., 0.2627, 0.2902, 0.2980],
          [0.5020, 0.4745, 0.4471,  ..., 0.2549, 0.2784, 0.2902],
          [0.5216, 0.5098, 0.4863,  ..., 0.2471, 0.2627, 0.2745],
          ...,
          [0.4549, 0.4667, 0.4745,  ..., 0.7412, 0.7373, 0.7373],
          [0.4510, 0.4627, 0.4706,  ..., 0.7529, 0.7451, 0.7451],
          [0.4471, 0.4588, 0.4706,  ..., 0.7608, 0.7529, 0.7490]],

         [[0.3608, 0.3333, 0.3137,  ..., 0.2039, 0.2314, 0.2392],
          [0.3725, 0.3490, 0.3294,  ..., 0.1961, 0.2196, 0.2275],
          [0.3922, 0.3804, 0.3686,  ..., 0.1882, 0.2039, 0.2039],
          ...,
          [0.1255, 0.1373, 0.1451,  ..., 0.5922, 0.5804, 0.5765],
          [0.1216, 0.1333, 0.1412,  ..., 0.6000, 0.5882, 0.5804],
          [0.1255, 0.1333, 0.1412,  ..., 0.6000, 0.5882, 0.5843]],

         [[0.3255, 0.2980, 0.2745,  ..., 0.1922, 0.2196, 0.2196],
          [0.3373, 0.3137, 0.2902,  ..., 0.1843, 0.2078, 0.2078],
          [0.3569, 0.3451, 0.3294,  ..., 0

tensor([[[[0.4314, 0.4314, 0.4314,  ..., 0.8431, 0.8353, 0.8314],
          [0.4314, 0.4314, 0.4314,  ..., 0.8275, 0.8196, 0.8196],
          [0.4314, 0.4314, 0.4314,  ..., 0.8118, 0.8039, 0.7961],
          ...,
          [0.4980, 0.4745, 0.4667,  ..., 0.5294, 0.5451, 0.5529],
          [0.4902, 0.4667, 0.4588,  ..., 0.5176, 0.5333, 0.5412],
          [0.4824, 0.4588, 0.4510,  ..., 0.5137, 0.5255, 0.5333]],

         [[0.3098, 0.3098, 0.3098,  ..., 0.7529, 0.7412, 0.7373],
          [0.3098, 0.3098, 0.3098,  ..., 0.7373, 0.7255, 0.7255],
          [0.3059, 0.3059, 0.3059,  ..., 0.7137, 0.7059, 0.7020],
          ...,
          [0.3373, 0.3137, 0.3098,  ..., 0.3176, 0.3294, 0.3333],
          [0.3294, 0.3059, 0.3020,  ..., 0.3098, 0.3176, 0.3216],
          [0.3216, 0.2980, 0.2941,  ..., 0.3020, 0.3098, 0.3137]],

         [[0.0980, 0.0980, 0.0980,  ..., 0.4667, 0.4588, 0.4549],
          [0.0980, 0.0980, 0.0980,  ..., 0.4549, 0.4471, 0.4431],
          [0.1059, 0.1059, 0.1059,  ..., 0

tensor([[[[1.0000, 0.9961, 0.9922,  ..., 0.9843, 0.9843, 0.9843],
          [1.0000, 1.0000, 0.9922,  ..., 0.9843, 0.9804, 0.9843],
          [1.0000, 1.0000, 0.9922,  ..., 0.9804, 0.9804, 0.9804],
          ...,
          [0.5961, 0.6000, 0.6000,  ..., 0.4667, 0.6078, 0.6980],
          [0.5804, 0.5882, 0.5882,  ..., 0.5961, 0.7176, 0.7804],
          [0.5804, 0.5765, 0.5843,  ..., 0.6784, 0.7765, 0.8235]],

         [[1.0000, 0.9961, 0.9922,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 0.9922,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 0.9922,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [0.5294, 0.5373, 0.5451,  ..., 0.2824, 0.4157, 0.5059],
          [0.5176, 0.5255, 0.5333,  ..., 0.4000, 0.5098, 0.5686],
          [0.5176, 0.5216, 0.5333,  ..., 0.4667, 0.5647, 0.6039]],

         [[0.9922, 0.9882, 0.9843,  ..., 0.9961, 1.0000, 1.0000],
          [0.9922, 0.9922, 0.9843,  ..., 1.0000, 1.0000, 1.0000],
          [0.9922, 0.9922, 0.9843,  ..., 0

tensor([[[[0.0039, 0.0039, 0.0078,  ..., 0.0000, 0.0000, 0.0000],
          [0.0039, 0.0039, 0.0039,  ..., 0.0118, 0.0118, 0.0118],
          [0.0039, 0.0039, 0.0039,  ..., 0.0235, 0.0235, 0.0235],
          ...,
          [0.1686, 0.1882, 0.2039,  ..., 0.1922, 0.1922, 0.1922],
          [0.1725, 0.1843, 0.1961,  ..., 0.1922, 0.1922, 0.1922],
          [0.1412, 0.1490, 0.1608,  ..., 0.1922, 0.1922, 0.1922]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0118, 0.0118, 0.0118],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.1961, 0.2118, 0.2275,  ..., 0.2078, 0.2078, 0.2078],
          [0.2000, 0.2118, 0.2235,  ..., 0.2078, 0.2078, 0.2078],
          [0.1765, 0.1843, 0.1961,  ..., 0.2039, 0.2039, 0.2039]],

         [[0.0000, 0.0000, 0.0039,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0039, 0.0039, 0.0039,  ..., 0

tensor([[[[0.0039, 0.0039, 0.0039,  ..., 0.0078, 0.0039, 0.0118],
          [0.1569, 0.1569, 0.1569,  ..., 0.1529, 0.1569, 0.1569],
          [0.7529, 0.7529, 0.7529,  ..., 0.7373, 0.7373, 0.7373],
          ...,
          [0.3333, 0.3294, 0.3176,  ..., 0.6431, 0.5765, 0.5765],
          [0.3412, 0.3294, 0.3137,  ..., 0.6745, 0.5451, 0.4431],
          [0.3529, 0.3412, 0.3255,  ..., 0.6471, 0.5333, 0.4039]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0000, 0.0000],
          [0.1490, 0.1490, 0.1490,  ..., 0.1451, 0.1451, 0.1451],
          [0.7294, 0.7294, 0.7294,  ..., 0.7216, 0.7216, 0.7216],
          ...,
          [0.3098, 0.3059, 0.2941,  ..., 0.6078, 0.5412, 0.5412],
          [0.3176, 0.3059, 0.2902,  ..., 0.6431, 0.5137, 0.4078],
          [0.3294, 0.3176, 0.3020,  ..., 0.6157, 0.5059, 0.3725]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.1412, 0.1412, 0.1412,  ..., 0.1373, 0.1373, 0.1373],
          [0.7020, 0.7020, 0.7020,  ..., 0

tensor([[[[0.9216, 0.9216, 0.9216,  ..., 0.6157, 0.6196, 0.6118],
          [0.9216, 0.9216, 0.9216,  ..., 0.6235, 0.6118, 0.5961],
          [0.9216, 0.9216, 0.9216,  ..., 0.6314, 0.6078, 0.5882],
          ...,
          [0.8510, 0.8510, 0.8510,  ..., 0.5020, 0.4980, 0.4980],
          [0.8510, 0.8510, 0.8510,  ..., 0.5020, 0.5020, 0.5020],
          [0.8510, 0.8510, 0.8510,  ..., 0.5020, 0.5020, 0.5020]],

         [[0.9529, 0.9529, 0.9529,  ..., 0.6275, 0.6353, 0.6275],
          [0.9529, 0.9529, 0.9529,  ..., 0.6353, 0.6275, 0.6118],
          [0.9529, 0.9529, 0.9529,  ..., 0.6392, 0.6196, 0.6000],
          ...,
          [0.8745, 0.8745, 0.8745,  ..., 0.4980, 0.4941, 0.4941],
          [0.8745, 0.8745, 0.8745,  ..., 0.4980, 0.4980, 0.4980],
          [0.8745, 0.8745, 0.8745,  ..., 0.4980, 0.4980, 0.4980]],

         [[0.9647, 0.9647, 0.9647,  ..., 0.5333, 0.5373, 0.5294],
          [0.9647, 0.9647, 0.9647,  ..., 0.5412, 0.5294, 0.5137],
          [0.9647, 0.9647, 0.9647,  ..., 0

tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0039, 0.0118,  ..., 0.7098, 0.7686, 0.7922],
          [0.0000, 0.0000, 0.0118,  ..., 0.5255, 0.5843, 0.5961],
          [0.0000, 0.0000, 0.0118,  ..., 0.4667, 0.4784, 0.4667]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0039, 0.0118,  ..., 0.5922, 0.6510, 0.6745],
          [0.0000, 0.0000, 0.0118,  ..., 0.3922, 0.4471, 0.4510],
          [0.0000, 0.0000, 0.0118,  ..., 0.3333, 0.3255, 0.3020]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0

tensor([[[[ 0.6914,  0.6836,  0.6758,  ..., -0.7461, -0.7539, -0.7617],
          [ 0.6602,  0.6523,  0.6523,  ..., -0.7148, -0.7305, -0.7383],
          [ 0.6289,  0.6289,  0.6367,  ..., -0.6680, -0.6914, -0.7070],
          ...,
          [ 0.5898,  0.5820,  0.5742,  ...,  0.6914,  0.8242,  0.8945],
          [ 0.5820,  0.5820,  0.5820,  ...,  0.7383,  0.8477,  0.9180],
          [ 0.6133,  0.6133,  0.6133,  ...,  0.7695,  0.8555,  0.9102]],

         [[ 0.6680,  0.6602,  0.6523,  ..., -0.6211, -0.6289, -0.6367],
          [ 0.6367,  0.6289,  0.6289,  ..., -0.6055, -0.6211, -0.6289],
          [ 0.6133,  0.6133,  0.6211,  ..., -0.5820, -0.6055, -0.6133],
          ...,
          [-0.1836, -0.2148, -0.2695,  ..., -0.6211, -0.5977, -0.5820],
          [-0.3242, -0.3398, -0.3789,  ..., -0.5977, -0.5742, -0.5664],
          [-0.3867, -0.4023, -0.4414,  ..., -0.5742, -0.5664, -0.5586]],

         [[ 0.3477,  0.3398,  0.3320,  ..., -0.8242, -0.8320, -0.8398],
          [ 0.3164,  0.3164,  

tensor([[[[0.1137, 0.1216, 0.1294,  ..., 0.2392, 0.1804, 0.1490],
          [0.0941, 0.0980, 0.1020,  ..., 0.2118, 0.1569, 0.1255],
          [0.0706, 0.0784, 0.0863,  ..., 0.1804, 0.1216, 0.0863],
          ...,
          [0.1255, 0.1255, 0.1255,  ..., 0.1216, 0.1255, 0.1255],
          [0.1373, 0.1333, 0.1333,  ..., 0.1294, 0.1333, 0.1333],
          [0.1412, 0.1373, 0.1333,  ..., 0.1294, 0.1451, 0.1490]],

         [[0.1216, 0.1333, 0.1451,  ..., 0.2510, 0.1961, 0.1725],
          [0.1176, 0.1176, 0.1255,  ..., 0.2353, 0.1882, 0.1647],
          [0.1098, 0.1176, 0.1255,  ..., 0.2196, 0.1725, 0.1412],
          ...,
          [0.1529, 0.1529, 0.1529,  ..., 0.2039, 0.2078, 0.2039],
          [0.1647, 0.1608, 0.1608,  ..., 0.2039, 0.2118, 0.2078],
          [0.1686, 0.1647, 0.1608,  ..., 0.2000, 0.2196, 0.2196]],

         [[0.0667, 0.0667, 0.0745,  ..., 0.1686, 0.1216, 0.0941],
          [0.0549, 0.0471, 0.0510,  ..., 0.1529, 0.1098, 0.0824],
          [0.0353, 0.0392, 0.0431,  ..., 0

KeyboardInterrupt: 

In [None]:
torch.save(model.state_dict(), 'dl64b_v4_new_800.pt')

## Testing 

In [62]:
def tensor_to_binary(tensor):
    if tensor.dim() != 2:
        raise ValueException('tensor must have two dimensions')

    for i, x in enumerate(tensor):
        for j, y in enumerate(x) :
            if y.item() < 0.5:
                tensor[i][j] = 0
            elif y.item() >= 0.5:
                tensor[i][j] = 1

    return tensor

In [65]:
for i, t in train_ds:
    print(i.shape)
    faces, prob = mtcnn(to_image(i), return_prob=True)
    target_t = torch.FloatTensor([int(x) for x in valid_ds.idx_to_class[target[0].item()]])
        
    emb = face_resnet(faces) 
    data_v   = Variable(emb[0], requires_grad=False).type(dtype)
    target_v = Variable(target_t, requires_grad=False).type(dtype)

    pred = model.forward(data_v)
    print(pred.shape)
    print(tensor_to_binary(pred.unsqueeze(0)))

torch.Size([3, 224, 224])
torch.Size([64])
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0., 1., 1., 1.,
         1., 0., 1., 0., 0., 1., 0., 1., 1., 0., 1., 0., 1., 0., 0., 1., 1., 1.,
         0., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 0., 0., 1., 0.,
         0., 0., 0., 1., 1., 1., 0., 0., 0., 1.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0., 1., 1., 1.,
         1., 0., 1., 0., 0., 1., 0., 1., 1., 0., 1., 0., 1., 0., 0., 1., 1., 1.,
         0., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 0., 0., 1., 0.,
         0., 0., 0., 1., 1., 1., 0., 0., 0., 1.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0., 1., 1., 1.,
         1., 0., 1., 0., 0., 1., 0., 1., 1., 0., 1., 0., 1., 0., 0., 1., 1., 1.,
         0., 1.

torch.Size([64])
tensor([[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 1., 0., 0., 1., 1., 0.,
         0., 1., 1., 0., 1., 1., 1., 0., 0., 0., 1., 0., 1., 0., 1., 1., 1., 0.,
         0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 1., 1., 1., 0., 0., 1., 1.,
         1., 1., 1., 1., 1., 0., 1., 0., 0., 0.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 1., 0., 0., 1., 1., 0.,
         0., 1., 1., 0., 1., 1., 1., 0., 0., 0., 1., 0., 1., 0., 1., 1., 1., 0.,
         0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 1., 1., 1., 0., 0., 1., 1.,
         1., 1., 1., 1., 1., 0., 1., 0., 0., 0.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 1., 0., 0., 1., 1., 0.,
         0., 1., 1., 0., 1., 1., 1., 0., 0., 0., 1., 0., 1., 0., 1., 1., 1., 0.,
         0., 1., 0., 0., 0., 0., 1., 0., 

torch.Size([64])
tensor([[0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 1., 0., 1.,
         1., 0., 0., 1., 1., 0., 0., 1., 0., 0., 0., 1., 1., 1., 0., 1., 0., 1.,
         1., 0., 1., 0., 1., 0., 0., 1., 0., 1., 0., 1., 0., 0., 1., 0., 1., 0.,
         0., 1., 0., 0., 1., 0., 1., 1., 0., 1.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[0., 0., 0., 1., 1., 0., 0., 1., 0., 1., 0., 1., 1., 1., 0., 0., 1., 0.,
         0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 1., 0., 1., 0., 0., 1., 1., 1.,
         0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 0., 0., 0.,
         0., 0., 0., 1., 0., 1., 1., 1., 0., 0.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[0., 0., 0., 1., 1., 0., 0., 1., 0., 1., 0., 1., 1., 1., 0., 0., 1., 0.,
         0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 1., 0., 1., 0., 0., 1., 1., 1.,
         0., 0., 1., 1., 1., 1., 1., 1., 

torch.Size([64])
tensor([[0., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 1., 1., 0., 0., 1., 0., 1.,
         0., 1., 1., 1., 0., 0., 1., 1., 1., 0., 0., 1., 1., 0., 0., 0., 1., 0.,
         0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1., 0., 0.,
         0., 1., 0., 1., 1., 0., 0., 1., 0., 0.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[0., 0., 1., 0., 1., 1., 0., 1., 0., 0., 0., 1., 1., 0., 0., 1., 1., 0.,
         0., 1., 0., 0., 1., 0., 0., 1., 0., 1., 1., 0., 0., 1., 1., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 1., 0., 1., 0., 1., 0., 0., 1., 1., 1., 1.,
         1., 1., 1., 1., 0., 1., 0., 0., 0., 1.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[0., 0., 1., 0., 1., 1., 0., 1., 0., 0., 0., 1., 1., 0., 0., 1., 1., 0.,
         0., 1., 0., 0., 1., 0., 0., 1., 0., 1., 1., 0., 0., 1., 1., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 1., 

torch.Size([64])
tensor([[0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 1., 0., 0., 0., 0., 0., 1., 1.,
         0., 0., 0., 0., 1., 0., 1., 1., 0., 1., 1., 0., 1., 1., 1., 1., 0., 0.,
         1., 1., 1., 0., 1., 0., 0., 0., 1., 1., 0., 0., 1., 1., 1., 1., 0., 1.,
         1., 0., 0., 1., 0., 1., 0., 1., 0., 0.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[0., 0., 1., 1., 0., 0., 1., 0., 0., 0., 1., 0., 1., 1., 0., 0., 1., 1.,
         1., 1., 1., 0., 1., 0., 0., 1., 0., 1., 0., 0., 1., 1., 0., 0., 0., 0.,
         0., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1.,
         1., 1., 0., 1., 0., 1., 0., 0., 1., 0.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[0., 0., 1., 1., 0., 0., 1., 0., 0., 0., 1., 0., 1., 1., 0., 0., 1., 1.,
         1., 1., 1., 0., 1., 0., 0., 1., 0., 1., 0., 0., 1., 1., 0., 0., 0., 0.,
         0., 1., 1., 0., 0., 1., 1., 1., 

torch.Size([64])
tensor([[0., 0., 1., 1., 1., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 1., 1., 1.,
         1., 1., 0., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1., 0., 0.,
         0., 1., 1., 0., 1., 1., 1., 1., 1., 0., 0., 0., 1., 1., 0., 1., 0., 1.,
         0., 0., 1., 0., 0., 0., 1., 1., 0., 0.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[0., 0., 1., 1., 1., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 1., 1., 1.,
         1., 1., 0., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1., 0., 0.,
         0., 1., 1., 0., 1., 1., 1., 1., 1., 0., 0., 0., 1., 1., 0., 1., 0., 1.,
         0., 0., 1., 0., 0., 0., 1., 1., 0., 0.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[0., 0., 1., 1., 1., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 1., 1., 1.,
         1., 1., 0., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1., 0., 0.,
         0., 1., 1., 0., 1., 1., 1., 1., 

torch.Size([64])
tensor([[0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0.,
         1., 1., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 1., 1., 0., 0., 1., 1.,
         1., 0., 0., 1., 1., 1., 0., 1., 1., 0., 0., 1., 0., 0., 0., 0., 0., 1.,
         1., 0., 0., 0., 0., 1., 0., 1., 0., 0.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 1., 1., 0., 0., 1., 1., 0.,
         1., 0., 1., 1., 0., 1., 0., 1., 0., 0., 1., 1., 1., 1., 0., 0., 1., 0.,
         1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 1., 0., 1., 0., 1., 0., 0.,
         1., 1., 0., 1., 1., 1., 0., 1., 0., 0.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 1., 1., 0., 0., 1., 1., 0.,
         1., 0., 1., 1., 0., 1., 0., 1., 0., 0., 1., 1., 1., 1., 0., 0., 1., 0.,
         1., 0., 0., 0., 1., 0., 0., 0., 

torch.Size([64])
tensor([[0., 1., 0., 1., 0., 1., 1., 0., 0., 1., 0., 1., 0., 1., 0., 0., 1., 1.,
         1., 1., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 1., 1., 1., 1., 1.,
         1., 0., 0., 0., 1., 1., 1., 0., 1., 1., 0., 0., 1., 1., 0., 0., 1., 1.,
         0., 1., 0., 1., 0., 0., 0., 1., 0., 1.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[0., 1., 0., 1., 0., 1., 1., 0., 0., 1., 0., 1., 0., 1., 0., 0., 1., 1.,
         1., 1., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 1., 1., 1., 1., 1.,
         1., 0., 0., 0., 1., 1., 1., 0., 1., 1., 0., 0., 1., 1., 0., 0., 1., 1.,
         0., 1., 0., 1., 0., 0., 0., 1., 0., 1.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[0., 1., 0., 1., 0., 1., 1., 0., 0., 1., 0., 1., 0., 1., 0., 0., 1., 1.,
         1., 1., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 1., 1., 1., 1., 1.,
         1., 0., 0., 0., 1., 1., 1., 0., 

torch.Size([64])
tensor([[0., 1., 0., 1., 1., 1., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 1., 1.,
         0., 0., 0., 1., 1., 0., 0., 1., 1., 0., 1., 1., 0., 0., 1., 1., 0., 1.,
         1., 1., 0., 1., 1., 1., 1., 1., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1.,
         1., 1., 1., 0., 0., 1., 0., 1., 1., 0.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[0., 1., 0., 1., 1., 1., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 1., 1.,
         0., 0., 0., 1., 1., 0., 0., 1., 1., 0., 1., 1., 0., 0., 1., 1., 0., 1.,
         1., 1., 0., 1., 1., 1., 1., 1., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1.,
         1., 1., 1., 0., 0., 1., 0., 1., 1., 0.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[0., 1., 0., 1., 1., 1., 0., 1., 1., 0., 1., 1., 0., 0., 0., 1., 0., 0.,
         1., 1., 0., 0., 0., 1., 1., 0., 0., 1., 1., 1., 0., 0., 0., 0., 1., 0.,
         1., 1., 0., 1., 1., 1., 1., 1., 

torch.Size([64])
tensor([[0., 1., 1., 0., 0., 1., 1., 1., 1., 0., 0., 1., 0., 1., 1., 1., 1., 0.,
         1., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 1., 0., 1., 0.,
         0., 1., 0., 1., 1., 0., 1., 1., 1., 0., 0., 0., 1., 1., 1., 0., 0., 1.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[0., 1., 1., 0., 0., 1., 1., 1., 1., 0., 0., 1., 0., 1., 1., 1., 1., 0.,
         1., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 1., 0., 1., 0.,
         0., 1., 0., 1., 1., 0., 1., 1., 1., 0., 0., 0., 1., 1., 1., 0., 0., 1.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[0., 1., 1., 1., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 1., 0.,
         0., 0., 1., 1., 1., 0., 1., 1., 1., 1., 1., 0., 0., 0., 0., 1., 0., 0.,
         0., 1., 1., 0., 1., 0., 0., 0., 

torch.Size([64])
tensor([[1., 0., 0., 0., 1., 1., 1., 0., 1., 1., 0., 0., 1., 1., 0., 1., 1., 0.,
         1., 0., 1., 0., 1., 1., 0., 1., 1., 1., 1., 0., 0., 1., 0., 1., 1., 0.,
         0., 0., 1., 0., 1., 0., 0., 0., 1., 1., 0., 0., 1., 1., 0., 0., 0., 1.,
         0., 1., 1., 1., 1., 1., 1., 0., 0., 1.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[1., 0., 0., 0., 1., 1., 1., 0., 1., 1., 0., 0., 1., 1., 0., 1., 1., 0.,
         1., 0., 1., 0., 1., 1., 0., 1., 1., 1., 1., 0., 0., 1., 0., 1., 1., 0.,
         0., 0., 1., 0., 1., 0., 0., 0., 1., 1., 0., 0., 1., 1., 0., 0., 0., 1.,
         0., 1., 1., 1., 1., 1., 1., 0., 0., 1.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[1., 0., 0., 1., 0., 1., 1., 1., 0., 1., 1., 1., 0., 0., 1., 0., 1., 0.,
         1., 1., 0., 0., 0., 1., 1., 0., 1., 0., 0., 0., 0., 1., 1., 1., 1., 0.,
         1., 1., 1., 0., 1., 0., 0., 0., 

torch.Size([64])
tensor([[1., 0., 1., 1., 0., 0., 0., 0., 1., 1., 0., 1., 1., 0., 0., 1., 1., 1.,
         1., 0., 1., 1., 0., 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 0., 1., 0.,
         1., 1., 1., 0., 0., 1., 1., 0., 1., 1., 0., 1., 0., 0., 1., 0., 1., 1.,
         1., 1., 0., 0., 1., 0., 1., 1., 1., 1.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[1., 0., 1., 1., 0., 0., 0., 0., 1., 1., 0., 1., 1., 0., 0., 1., 1., 1.,
         1., 0., 1., 1., 0., 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 0., 1., 0.,
         1., 1., 1., 0., 0., 1., 1., 0., 1., 1., 0., 1., 0., 0., 1., 0., 1., 1.,
         1., 1., 0., 0., 1., 0., 1., 1., 1., 1.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[1., 0., 1., 1., 0., 0., 0., 0., 1., 1., 0., 1., 1., 0., 0., 1., 1., 1.,
         1., 0., 1., 1., 0., 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 0., 1., 0.,
         1., 1., 1., 0., 0., 1., 1., 0., 

torch.Size([64])
tensor([[1., 0., 1., 1., 1., 0., 1., 0., 1., 1., 0., 1., 0., 1., 0., 1., 1., 1.,
         1., 1., 0., 0., 1., 0., 1., 0., 1., 0., 0., 1., 0., 0., 0., 0., 1., 1.,
         1., 1., 0., 0., 1., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 0., 0., 1.,
         0., 1., 0., 0., 1., 0., 1., 0., 0., 0.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[1., 0., 1., 1., 1., 0., 1., 0., 1., 1., 0., 1., 0., 1., 0., 1., 1., 1.,
         1., 1., 0., 0., 1., 0., 1., 0., 1., 0., 0., 1., 0., 0., 0., 0., 1., 1.,
         1., 1., 0., 0., 1., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 0., 0., 1.,
         0., 1., 0., 0., 1., 0., 1., 0., 0., 0.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[1., 0., 1., 1., 1., 0., 1., 0., 1., 1., 0., 1., 0., 1., 0., 1., 1., 1.,
         1., 1., 0., 0., 1., 0., 1., 0., 1., 0., 0., 1., 0., 0., 0., 0., 1., 1.,
         1., 1., 0., 0., 1., 0., 0., 0., 

torch.Size([64])
tensor([[1., 1., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 0., 0., 0., 1., 1.,
         0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0.,
         0., 1., 0., 1., 0., 1., 0., 1., 1., 0., 1., 0., 1., 0., 0., 0., 0., 0.,
         1., 1., 1., 0., 1., 0., 0., 0., 1., 1.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[1., 1., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 0., 0., 0., 1., 1.,
         0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0.,
         0., 1., 0., 1., 0., 1., 0., 1., 1., 0., 1., 0., 1., 0., 0., 0., 0., 0.,
         1., 1., 1., 0., 1., 0., 0., 0., 1., 1.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[1., 1., 0., 0., 0., 1., 1., 0., 0., 1., 1., 0., 0., 1., 1., 0., 0., 0.,
         0., 0., 1., 1., 1., 1., 0., 1., 0., 1., 0., 0., 1., 0., 1., 0., 1., 1.,
         1., 0., 0., 1., 0., 0., 0., 0., 

torch.Size([64])
tensor([[1., 1., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1., 1.,
         0., 1., 1., 0., 1., 0., 1., 1., 1., 0., 0., 0., 0., 0., 1., 0., 0., 1.,
         0., 0., 0., 1., 1., 1., 0., 0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0.,
         0., 0., 1., 0., 0., 0., 0., 0., 0., 1.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[1., 1., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1., 1.,
         0., 1., 1., 0., 1., 0., 1., 1., 1., 0., 0., 0., 0., 0., 1., 0., 0., 1.,
         0., 0., 0., 1., 1., 1., 0., 0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0.,
         0., 0., 1., 0., 0., 0., 0., 0., 0., 1.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[1., 1., 0., 0., 1., 1., 1., 0., 1., 0., 1., 0., 0., 1., 0., 1., 0., 0.,
         1., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0., 0., 1., 0., 1.,
         1., 1., 1., 0., 0., 0., 0., 0., 

torch.Size([64])
tensor([[1., 1., 1., 0., 0., 0., 0., 1., 1., 1., 0., 1., 0., 1., 0., 0., 1., 0.,
         0., 0., 0., 1., 1., 1., 0., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1.,
         0., 1., 0., 0., 1., 0., 0., 1., 1., 1., 0., 0., 1., 0., 0., 0., 0., 0.,
         1., 0., 1., 0., 1., 0., 0., 0., 0., 0.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[1., 1., 1., 0., 0., 0., 0., 1., 1., 1., 0., 1., 0., 1., 0., 0., 1., 0.,
         0., 0., 0., 1., 1., 1., 0., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1.,
         0., 1., 0., 0., 1., 0., 0., 1., 1., 1., 0., 0., 1., 0., 0., 0., 0., 0.,
         1., 0., 1., 0., 1., 0., 0., 0., 0., 0.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[1., 1., 1., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 1., 1., 0., 0.,
         0., 1., 1., 1., 0., 0., 0., 0., 1., 1., 0., 1., 1., 0., 1., 1., 1., 1.,
         0., 0., 0., 1., 0., 1., 1., 1., 

torch.Size([64])
tensor([[1., 1., 1., 1., 1., 0., 0., 0., 1., 0., 0., 0., 1., 1., 0., 1., 1., 1.,
         0., 1., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 1., 0., 1., 0.,
         1., 0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 1., 1., 1., 0., 1., 1.,
         0., 1., 1., 0., 0., 0., 0., 0., 1., 1.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[1., 1., 1., 1., 1., 0., 0., 0., 1., 0., 0., 0., 1., 1., 0., 1., 1., 1.,
         0., 1., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 1., 0., 1., 0.,
         1., 0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 1., 1., 1., 0., 1., 1.,
         0., 1., 1., 0., 0., 0., 0., 0., 1., 1.]], device='cuda:0',
       grad_fn=<AsStridedBackward>)
torch.Size([3, 224, 224])
torch.Size([64])
tensor([[1., 1., 1., 1., 1., 0., 0., 0., 1., 0., 0., 0., 1., 1., 0., 1., 1., 1.,
         0., 1., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 1., 0., 1., 0.,
         1., 0., 0., 1., 0., 0., 1., 0., 

In [47]:
!ls ../../data/small_data/train/1111100010001101110100100001010010101001001010001110110110000011

Tom_Hanks_0001.jpg  Tom_Hanks_0004.jpg	Tom_Hanks_0007.jpg  Tom_Hanks_0010.jpg
Tom_Hanks_0002.jpg  Tom_Hanks_0005.jpg	Tom_Hanks_0008.jpg
Tom_Hanks_0003.jpg  Tom_Hanks_0006.jpg	Tom_Hanks_0009.jpg


In [None]:
path = '../../data/small_data/train/1111100010001101110100100001010010101001001010001110110110000011/'
for i in os.listdir(path):
        Image.open(path + i)
Image.open('../../data/small_data/train/1111100010001101110100100001010010101001001010001110110110000011/T')

In [None]:
model.forward()