The following is nowhere near the full set of combinations I have tried in order to get this network to train.
I have overwritten dozens of versions of the code for defining transforms and datasets, defining the model architecture, and defining the training/validation loops.

I reset the kernel before each attempt to clear out existing variables (though I wonder whether resetting the kernel might not be clearing the GPU memory).

Results have been mixed, but never good. 

Sometimes I get a few training epochs that look hopeful before the loss suddenly explodes to very high numbers.
Sometimes I get what look like multiple epochs with tiny improvements in the loss before it begins bouncing around with no progress over time, always stuck at around 4.8.

In the latter circumstance, when I include code to print the top-probability class, I have found that the network makes the same prediction for every photo across multiple batches (anywhere from 8 to 40 batches in a row with a single class predicted for every image) before switching to a different class for the next 8 to 40 batches. This pattern stays consistent through the training epochs.

Test batches, then, of course, come out the same way - a single class predicted for all images in a batch.
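
One way to quantify this collapse is to tally how many distinct classes are predicted in each batch. The sketch below is a minimal, hypothetical helper (`summarize_predictions` is not part of the notebook) that works on plain lists of class indices; with PyTorch, a prediction tensor `pred` could be converted first with `pred.view(-1).tolist()`. The example values are illustrative only.

```python
from collections import Counter


def summarize_predictions(batches):
    """Tally predicted classes over an iterable of per-batch prediction lists.

    Returns (overall_counts, collapsed_batches), where collapsed_batches is
    the number of batches in which every image received the same class.
    """
    overall = Counter()
    collapsed = 0
    for preds in batches:
        overall.update(preds)
        if len(set(preds)) == 1:
            collapsed += 1
    return overall, collapsed


# Illustrative values only (not real output): two collapsed batches, one mixed.
example_batches = [[4] * 8, [4] * 8, [4, 17, 4, 4, 90, 4, 4, 4]]
counts, collapsed = summarize_predictions(example_batches)
print(counts.most_common(3))
print("collapsed batches:", collapsed, "of", len(example_batches))
```

Tracking the collapsed-batch count per epoch would show whether the single-class behavior is present from the start or appears only after the loss plateaus.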

**While my training loops here are limited to 10 epochs, this is only so I can get this notebook created and uploaded tonight. I have run all these and many many more versions of this code with 30-100+ epochs. If it's making no progress by 10 epochs, it makes no progress by 30 either.**

It has always been clear by the 10th epoch whether or not the training is progressing properly.
No version has ever achieved above 1% test accuracy, which is consistent with the network making the same prediction for image after image.
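
For reference, a model that ignores its input and spreads probability evenly over C classes has a cross-entropy of ln(C), and a constant prediction scores roughly 1/C accuracy when classes are about balanced. The short calculation below assumes C = 133, the output size of the final layer in the model defined later in this notebook; it is a sanity check, not a diagnosis.

```python
import math

num_classes = 133  # assumption: output size of the final fully connected layer

# Cross-entropy when every class gets probability 1 / num_classes.
uniform_loss = math.log(num_classes)

# Accuracy of always guessing one class, assuming roughly balanced classes.
chance_accuracy = 1.0 / num_classes

print(f"cross-entropy of a uniform guess: {uniform_loss:.4f}")
print(f"chance accuracy: {chance_accuracy:.2%}")
```

A loss stuck in the 4.8 range and accuracy below 1% are both roughly what this baseline gives, which would suggest the network has stopped using the image content.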

# Latest Model Implementing Mentor Suggestions

### Transforms, datasets and dataloaders
**Training Data Transforms:** RandomHorizontalFlip, RandomRotation(30), Resize so the smallest dimension is 250, then RandomCrop to 224x224. Converted to a tensor, then normalized with the standard ImageNet channel means and stds.

**Validation and Testing Data Transforms:** Resize so the smallest dimension is 250, then CenterCrop to 224x224, converted to a tensor, then normalized with the same ImageNet mean/std values.

**Batch size:** `batch_size` is set to 8.

**DataLoaders:** all three datasets are loaded with `shuffle=True`.

In [1]:
import os
import numpy as np
from torchvision import datasets, transforms
import torch
import torch.nn as nn
import torch.nn.functional as F

## Define transforms, Datasets, batch_size, and Dataloaders

#
train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                      transforms.RandomRotation(30),
                                      transforms.Resize(250),
                                      transforms.RandomCrop(224),
                                      transforms.ToTensor(),
                                      transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                           std=[0.229, 0.224, 0.225])
                                     ])

test_transform = transforms.Compose([transforms.Resize(250),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], 
                                                          [0.229, 0.224, 0.225])
                                    ])


data_dir = 'dogImages/'
train_dir = os.path.join(data_dir, 'train/')
valid_dir = os.path.join(data_dir, 'valid/')
test_dir = os.path.join(data_dir, 'test/')

train_data = datasets.ImageFolder(train_dir, transform=train_transform)
valid_data = datasets.ImageFolder(valid_dir, transform=test_transform)
test_data = datasets.ImageFolder(test_dir, transform=test_transform)

print('Num Training imgs: ', len(train_data))
print('Num Validation imgs: ', len(valid_data))
print('Num Test imgs: ', len(test_data))

batch_size = 8
num_workers = 0

train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, num_workers=num_workers, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size, num_workers=num_workers, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, num_workers=num_workers, shuffle=True)


loaders_scratch = {'train': train_loader, 
                   'valid': valid_loader, 
                   'test': test_loader}

classes = train_data.classes

Num Training imgs:  6680
Num Validation imgs:  835
Num Test imgs:  836
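
Given the counts printed above and `batch_size = 8`, the number of batches per epoch and the average number of training images per class can be worked out directly. This is a small sketch; `num_classes = 133` is an assumption taken from the output size of the final layer defined later, and `batches_per_epoch` is a hypothetical helper.

```python
import math

# Counts printed by the cell above; batch size from the notebook.
num_train, num_valid, num_test = 6680, 835, 836
batch_size = 8
num_classes = 133  # assumption: output size of the final layer defined later


def batches_per_epoch(num_images, batch_size):
    """Number of batches a DataLoader yields with drop_last=False (the default)."""
    return math.ceil(num_images / batch_size)


train_batches = batches_per_epoch(num_train, batch_size)
valid_batches = batches_per_epoch(num_valid, batch_size)
test_batches = batches_per_epoch(num_test, batch_size)
print(train_batches, valid_batches, test_batches)
print(f"average training images per class: {num_train / num_classes:.1f}")
```

With roughly 50 training images per class, each class is seen only about 50 times per epoch, which is worth keeping in mind when judging how quickly a network trained from scratch can be expected to improve.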


## Model Architecture
* 3 Convolutional Layers (all with kernel_size=3, padding=1), 3 Max Pooling Layers (kernel=2, stride=2), 2 Fully Connected Layers
* relu activation applied to each convolutional layer output, followed by batch normalization.
* 30% dropout applied after the 1st and 2nd convolutional blocks (before pooling), none after the 3rd
* MaxPool layer of kernel size=2, stride=2 after each convolutional layer
* ReLU non-linearity applied to output of 1st fully connected layer. 
* No dropout in FC layers.

* Cross-Entropy Loss Function
* Adam optimizer with learning rate = 0.0025 (the value used in the code below)
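
The input size of the first fully connected layer follows from the layer list above. The sketch below traces the spatial size through three conv + pool stages, assuming a 224x224 input; the 128 output channels of the last convolution and the 1000-unit first fully connected layer are taken from the model code that follows. `conv_out` is a hypothetical helper, not part of the notebook.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 224  # input height/width after the crop
for _ in range(3):
    size = conv_out(size, kernel=3, padding=1)  # 3x3 conv with padding=1 keeps the size
    size = conv_out(size, kernel=2, stride=2)   # 2x2 max pool with stride 2 halves it

flat_features = 128 * size * size          # 128 channels out of the last conv
fc1_params = flat_features * 1000 + 1000   # weights + biases of Linear(flat_features, 1000)
print(size, flat_features, fc1_params)
```

The first fully connected layer alone holds about 100 million parameters, far more than the three convolutional layers combined, so most of the model's capacity sits in that one layer.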

In [2]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.batchnorm1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.batchnorm2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.batchnorm3 = nn.BatchNorm2d(128)
#         self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
#         self.batchnorm4 = nn.BatchNorm2d(256)
#         self.conv5 = nn.Conv2d(256, 512, 3, padding=1)
#         self.batchnorm5 = nn.BatchNorm2d(512)
        
        self.pool = nn.MaxPool2d(2, 2)
        
        self.fc1 = nn.Linear(128*28*28, 1000, bias=True)
        self.fc2 = nn.Linear(1000, 133, bias=True)
#         self.fc3 = nn.Linear(500, 133, bias=True)
        self.dropout = nn.Dropout(p=0.3)
        
    def forward(self, x):
        x = self.batchnorm1(F.relu(self.conv1(x)))
        x = self.dropout(x)
        x = self.pool(x)
        x = self.batchnorm2(F.relu(self.conv2(x)))
        x = self.dropout(x)
        x = self.pool(x)
        x = self.batchnorm3(F.relu(self.conv3(x)))
        x = self.pool(x)
#         x = self.batchnorm4(F.relu(self.conv4(x)))
#         x = self.dropout(x)
#         x = self.pool(x)
#         x = self.batchnorm5(F.relu(self.conv5(x)))
#         x = self.pool(x)
        
        x = x.view(-1, 128*28*28)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)

        return x

#-#-# You do NOT have to modify the code below this line. #-#-#

# instantiate the CNN
model_scratch = Net()

# move tensors to GPU if CUDA is available
use_cuda = torch.cuda.is_available()
if use_cuda:
    model_scratch.cuda()

In [3]:
import torch.optim as optim

### TODO: select loss function
criterion_scratch = nn.CrossEntropyLoss()

### TODO: select optimizer
optimizer_scratch = optim.Adam(model_scratch.parameters(), lr=0.0025)

### Training and Validation

All fairly standard, following the way the models in the program have been written up to this point.

The validation forward pass runs inside `with torch.no_grad():`.


In [4]:
# the following import is required for training to be robust to truncated images
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


def train(n_epochs, loaders, model, optimizer, criterion, use_cuda, save_path):
    """returns trained model"""
    # initialize tracker for minimum validation loss
    valid_loss_min = np.inf  # np.Inf was removed in NumPy 2.0
    
    for epoch in range(1, n_epochs+1):

        # initialize variables to monitor training and validation loss
        train_loss = 0.0
        valid_loss = 0.0
        ###################
        # train the model #
        ###################
        model.train()
        for batch_idx, (data, target) in enumerate(loaders['train']):
            # move to GPU
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            ## find the loss and update the model parameters accordingly
            optimizer.zero_grad()
            output = model(data) #get predictions
            loss = criterion(output, target) # calculate loss
            loss.backward()  # calculate the gradients
            optimizer.step() # perform optimization step
            train_loss += loss.item()

    
        ######################    
        # validate the model #
        ######################
        model.eval()
        for batch_idx, (data, target) in enumerate(loaders['valid']):
            # move to GPU
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            with torch.no_grad():
                output = model(data)
                loss = criterion(output, target)
                valid_loss += loss.item()
            
        # average over batches, using the loaders passed in rather than globals
        train_loss = train_loss / len(loaders['train'])
        valid_loss = valid_loss / len(loaders['valid'])
        
            
        # print training/validation statistics 
        print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
            epoch, 
            train_loss,
            valid_loss
            ))
        
        
        ## TODO: save the model if validation loss has decreased
        if valid_loss <= valid_loss_min:
            print('Validation Loss Decreased. Saving model')
            torch.save(model.state_dict(), save_path)
            valid_loss_min = valid_loss
    # return trained model
    return model

# train the model
model_scratch = train(25, loaders_scratch, model_scratch, optimizer_scratch, 
                      criterion_scratch, use_cuda, 'model_scratch_newest.pt')

Epoch: 1 	Training Loss: 10.118235 	Validation Loss: 4.906532
Validation Loss Decreased. Saving model
Epoch: 2 	Training Loss: 4.910345 	Validation Loss: 4.881038
Validation Loss Decreased. Saving model
Epoch: 3 	Training Loss: 4.885892 	Validation Loss: 4.869480
Validation Loss Decreased. Saving model
Epoch: 4 	Training Loss: 4.933945 	Validation Loss: 4.868146
Validation Loss Decreased. Saving model
Epoch: 5 	Training Loss: 4.866847 	Validation Loss: 4.867838
Validation Loss Decreased. Saving model
Epoch: 6 	Training Loss: 4.866725 	Validation Loss: 4.868748
Epoch: 7 	Training Loss: 4.866579 	Validation Loss: 4.870161
Epoch: 8 	Training Loss: 4.866401 	Validation Loss: 4.868974
Epoch: 9 	Training Loss: 4.866455 	Validation Loss: 4.869438
Epoch: 10 	Training Loss: 4.866441 	Validation Loss: 4.871444
Epoch: 11 	Training Loss: 4.866513 	Validation Loss: 4.869052
Epoch: 12 	Training Loss: 4.866430 	Validation Loss: 4.869843
Epoch: 13 	Training Loss: 4.866408 	Validation Loss: 4.870361
Ep

### Testing
Using the code supplied by Udacity, with a couple of added print statements to view the inputs, targets, and predictions.
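
The test function keeps its average loss with an incremental update, `test_loss + (1 / (batch_idx + 1)) * (loss - test_loss)`, rather than summing and dividing at the end. The sketch below, using a hypothetical `running_mean` helper and made-up loss values, checks that this update reproduces the ordinary arithmetic mean.

```python
def running_mean(values):
    """Incrementally update a mean, one value at a time."""
    mean = 0.0
    for idx, value in enumerate(values):
        mean = mean + (1 / (idx + 1)) * (value - mean)
    return mean


losses = [4.9, 4.8, 5.0, 4.7]  # illustrative per-batch losses, not real output
print(running_mean(losses), sum(losses) / len(losses))
```

Note that this is a mean over batches, so a smaller final batch carries the same weight as a full one.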

In [5]:
def test(loaders, model, criterion, use_cuda):

    # monitor test loss and accuracy
    test_loss = 0.0
    correct = 0.0
    total = 0.0
    model.eval()
    
    for batch_idx, (data, target) in enumerate(loaders['test']):
        # move to GPU
        if use_cuda:
            data, target = data.cuda(), target.cuda()
        with torch.no_grad():
            # forward pass: compute predicted outputs by passing inputs to the model
            output = model(data)
            print('Data: ', data)
            print('targets: ', target)
            # calculate the loss
            loss = criterion(output, target)
            # update average test loss 
            test_loss = test_loss + ((1 / (batch_idx + 1)) * (loss.data - test_loss))
            # convert output probabilities to predicted class
            pred = output.data.max(1, keepdim=True)[1]
            print('Predictions: ', pred)
            # compare predictions to true label
            correct += np.sum(np.squeeze(pred.eq(target.data.view_as(pred))).cpu().numpy())
            total += data.size(0)
            
    print('Test Loss: {:.6f}\n'.format(test_loss))

    print('\nTest Accuracy: %2d%% (%2d/%2d)' % (
        100. * correct / total, correct, total))

# call test function    
test(loaders_scratch, model_scratch, criterion_scratch, use_cuda)

Data:  tensor([[[[-0.1999, -0.3369, -0.3883,  ..., -0.1657, -0.1486, -0.1486],
          [-0.2171, -0.3369, -0.3712,  ..., -0.1999, -0.2684, -0.3541],
          [-0.2513, -0.3369, -0.3712,  ..., -0.1657, -0.2856, -0.4397],
          ...,
          [ 0.7591,  0.7762,  0.7933,  ...,  0.9817,  0.9474,  1.0502],
          [ 0.8618,  0.9132,  0.8789,  ...,  0.7591,  0.9303,  0.8961],
          [ 0.7762,  0.7762,  0.7419,  ...,  0.8104,  0.5878,  0.6392]],

         [[-0.9328, -1.0553, -1.0553,  ..., -0.8803, -0.8452, -0.8452],
          [-0.9678, -1.0553, -1.0553,  ..., -0.9153, -0.9678, -1.0553],
          [-0.9853, -1.0553, -1.0553,  ..., -0.8803, -0.9853, -1.1429],
          ...,
          [ 0.2227,  0.2577,  0.2577,  ...,  1.0980,  1.1155,  1.1681],
          [ 0.3102,  0.3803,  0.3452,  ...,  0.8880,  1.0980,  0.9930],
          [ 0.1001,  0.1176,  0.1176,  ...,  0.9405,  0.6954,  0.7129]],

         [[-1.4559, -1.4384, -1.3513,  ..., -1.4384, -1.3861, -1.3861],
          [-1.4907, -1.

Data:  tensor([[[[-0.4054,  0.3823,  0.5022,  ...,  0.7077,  1.2899,  0.3994],
          [-0.5424,  0.4508,  0.5022,  ...,  1.3070,  0.6392,  0.3652],
          [-0.4054,  0.2453,  0.3138,  ...,  0.9132,  0.4337,  0.3309],
          ...,
          [ 0.4337,  0.5707,  0.6563,  ...,  0.0398,  0.0741,  0.1597],
          [ 0.4679,  0.5878,  0.6049,  ...,  0.0741,  0.0741,  0.1426],
          [ 0.6049,  0.5707,  0.4337,  ...,  0.1939,  0.0912,  0.0912]],

         [[-0.2850,  0.5203,  0.6429,  ...,  0.5903,  1.1856,  0.2577],
          [-0.4251,  0.5903,  0.6429,  ...,  1.2206,  0.5203,  0.2227],
          [-0.2850,  0.3803,  0.4503,  ...,  0.7829,  0.2927,  0.1877],
          ...,
          [ 0.2577,  0.3978,  0.5203,  ...,  0.0126,  0.0476,  0.1352],
          [ 0.2927,  0.4153,  0.4503,  ...,  0.0476,  0.0476,  0.1176],
          [ 0.4328,  0.3978,  0.2927,  ...,  0.1702,  0.0651,  0.0651]],

         [[-0.0615,  0.7402,  0.8622,  ...,  0.5485,  1.1411,  0.2173],
          [-0.2010,  0.

Data:  tensor([[[[-0.5596, -0.3712, -0.1999,  ..., -1.4158, -1.4329, -1.4672],
          [-0.5767, -0.3883, -0.2513,  ..., -1.4329, -1.4329, -1.4500],
          [-0.5767, -0.4397, -0.3369,  ..., -1.3815, -1.4158, -1.4158],
          ...,
          [ 0.1768,  0.1939,  0.1597,  ...,  0.6392,  0.5878,  0.7419],
          [ 0.1768,  0.2796,  0.2111,  ...,  0.5364,  0.3481,  0.2453],
          [ 0.2624,  0.2624,  0.2111,  ...,  0.8789,  0.6392,  0.3481]],

         [[-0.2675, -0.0749,  0.0651,  ..., -1.1078, -1.1604, -1.1954],
          [-0.2500, -0.0924,  0.0126,  ..., -1.1779, -1.1954, -1.2654],
          [-0.2850, -0.1450, -0.0924,  ..., -1.1954, -1.2479, -1.2829],
          ...,
          [-0.1450, -0.1450, -0.1450,  ...,  0.8704,  0.8354,  0.9055],
          [-0.0574, -0.0049, -0.1099,  ...,  0.7654,  0.6254,  0.5378],
          [ 0.0476,  0.0301, -0.0924,  ...,  1.1856,  0.9055,  0.7304]],

         [[-0.6193, -0.3404, -0.2184,  ..., -1.2467, -1.2467, -1.2641],
          [-0.6890, -0.

Data:  tensor([[[[ 0.8789,  0.8961,  0.9646,  ...,  0.7248,  0.6734,  0.6049],
          [ 0.9132,  0.9303,  0.9132,  ...,  0.7419,  0.6392,  0.5536],
          [ 0.8961,  0.9132,  0.9303,  ...,  0.6221,  0.5878,  0.5022],
          ...,
          [ 1.1015,  1.1015,  1.1358,  ...,  1.6324,  1.7180,  1.9407],
          [ 1.1187,  1.1700,  1.1529,  ...,  1.6324,  1.7180,  1.7865],
          [ 1.1700,  1.1872,  1.1872,  ...,  1.5468,  1.4440,  1.5810]],

         [[ 1.0105,  1.0105,  1.0630,  ...,  0.8704,  0.8529,  0.7829],
          [ 1.0280,  1.0455,  1.0280,  ...,  0.8704,  0.7654,  0.6954],
          [ 1.0105,  1.0105,  1.0105,  ...,  0.7304,  0.6779,  0.6078],
          ...,
          [ 1.1155,  1.1331,  1.1681,  ...,  1.6758,  1.7808,  2.0434],
          [ 1.1681,  1.2206,  1.2031,  ...,  1.6758,  1.7633,  1.8508],
          [ 1.2556,  1.2731,  1.2731,  ...,  1.5707,  1.5007,  1.6583]],

         [[ 0.9494,  0.9668,  1.0365,  ...,  0.7576,  0.7402,  0.6356],
          [ 0.9668,  1.

Data:  tensor([[[[ 1.2728,  1.2557,  1.2385,  ..., -0.9705, -1.1075, -1.7412],
          [ 1.2728,  1.2728,  1.2728,  ..., -0.9534, -1.0733, -1.7069],
          [ 1.2899,  1.2899,  1.2728,  ..., -0.9363, -1.0562, -1.6042],
          ...,
          [ 0.4679,  0.4166,  0.3652,  ...,  0.3138,  0.3481,  0.3823],
          [ 0.4337,  0.4166,  0.3652,  ...,  0.3138,  0.3481,  0.3652],
          [ 0.4166,  0.3994,  0.3652,  ...,  0.2967,  0.3481,  0.3652]],

         [[ 1.2731,  1.2556,  1.2381,  ..., -0.9503, -1.0728, -1.6856],
          [ 1.2731,  1.2731,  1.2731,  ..., -0.8978, -1.0203, -1.6506],
          [ 1.2906,  1.2906,  1.2731,  ..., -0.8978, -1.0028, -1.5280],
          ...,
          [ 0.3452,  0.3277,  0.2927,  ...,  0.3978,  0.4328,  0.4678],
          [ 0.3452,  0.3277,  0.3102,  ...,  0.4153,  0.4503,  0.4678],
          [ 0.3277,  0.3102,  0.2927,  ...,  0.3978,  0.4503,  0.4678]],

         [[ 1.3851,  1.3502,  1.3328,  ..., -0.7936, -0.8807, -1.4733],
          [ 1.4025,  1.

Data:  tensor([[[[ 0.6906,  0.6906,  0.6734,  ...,  0.5536,  0.5364,  0.5022],
          [ 0.6906,  0.7077,  0.7077,  ...,  0.5364,  0.5364,  0.5364],
          [ 0.7248,  0.7591,  0.7419,  ...,  0.5364,  0.5022,  0.4851],
          ...,
          [-0.1486, -0.1314, -0.1657,  ..., -0.5082, -0.5253, -0.5253],
          [-0.1657, -0.1657, -0.1657,  ..., -0.5082, -0.4911, -0.4739],
          [-0.1657, -0.1657, -0.1486,  ..., -0.5082, -0.4739, -0.4568]],

         [[ 0.6078,  0.6078,  0.5903,  ...,  0.4503,  0.4503,  0.4153],
          [ 0.6078,  0.6254,  0.6254,  ...,  0.4328,  0.4503,  0.4503],
          [ 0.6429,  0.6779,  0.6604,  ...,  0.4328,  0.4153,  0.3978],
          ...,
          [-0.2675, -0.2675, -0.3025,  ..., -0.5651, -0.5826, -0.5826],
          [-0.2675, -0.3025, -0.3025,  ..., -0.5651, -0.5476, -0.5301],
          [-0.2675, -0.2850, -0.2675,  ..., -0.5476, -0.5301, -0.4951]],

         [[ 0.6879,  0.6705,  0.6531,  ...,  0.5136,  0.5136,  0.4788],
          [ 0.6879,  0.

Data:  tensor([[[[ 0.9474,  0.7591,  0.5536,  ...,  1.4269,  1.4098,  1.3413],
          [ 0.8961,  0.7419,  0.5878,  ...,  1.4783,  1.4440,  1.3927],
          [ 0.8961,  0.8104,  0.7248,  ...,  1.4612,  1.4440,  1.4098],
          ...,
          [ 0.2967,  0.2967,  0.2967,  ...,  0.4166,  0.0569, -0.8678],
          [ 0.2624,  0.2624,  0.2796,  ...,  0.2624,  0.2796, -0.7479],
          [ 0.2282,  0.2453,  0.2624,  ...,  0.1254,  0.4166, -0.0458]],

         [[ 1.0980,  0.9055,  0.6954,  ...,  1.5882,  1.5707,  1.5007],
          [ 1.0455,  0.8880,  0.7304,  ...,  1.6408,  1.6057,  1.5532],
          [ 1.0455,  0.9580,  0.8704,  ...,  1.6232,  1.6057,  1.5707],
          ...,
          [ 0.4328,  0.4328,  0.4328,  ...,  0.2227, -0.2500, -1.0728],
          [ 0.3978,  0.3978,  0.4153,  ..., -0.1099,  0.0126, -0.9853],
          [ 0.3627,  0.3803,  0.3978,  ..., -0.3200,  0.1001, -0.3375]],

         [[ 1.3154,  1.1237,  0.9145,  ...,  1.8034,  1.7860,  1.7163],
          [ 1.2631,  1.

       device='cuda:0')
targets:  tensor([ 63, 103,  55,  59,  19,  60,  10,  20], device='cuda:0')
Predictions:  tensor([[4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4]], device='cuda:0')
Data:  tensor([[[[ 1.9749,  1.9920,  1.9749,  ...,  1.1187,  1.1015,  1.1015],
          [ 1.9920,  1.9920,  1.9749,  ...,  1.1187,  1.1015,  1.1015],
          [ 1.9749,  1.9920,  1.9920,  ...,  1.1187,  1.1015,  1.1015],
          ...,
          [ 1.8550,  1.8208,  1.7694,  ...,  1.8893,  1.9578,  1.8208],
          [ 1.9235,  1.7865,  1.8208,  ...,  1.6495,  1.8208,  1.8208],
          [ 1.9578,  1.8550,  1.9064,  ...,  1.6838,  1.7865,  1.7523]],

         [[ 2.1835,  2.2010,  2.1835,  ...,  1.2206,  1.2031,  1.2031],
          [ 2.2010,  2.2010,  2.1835,  ...,  1.2206,  1.2031,  1.2031],
          [ 2.1835,  2.2010,  2.2010,  ...,  1.2206,  1.2031,  1.2031],
          ...,
          [ 1.9209,  1.9034,  1.8508,  ...,  1.9384,  2.0084,  1.8508],
     

targets:  tensor([120,  31,   0,   3,  47,  20,   3,  75], device='cuda:0')
Predictions:  tensor([[4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4]], device='cuda:0')
Data:  tensor([[[[ 1.2728,  1.2043,  1.2043,  ...,  1.8722,  1.8379,  1.8037],
          [ 1.2214,  1.1872,  1.1700,  ...,  1.8722,  1.8208,  1.7865],
          [ 1.1872,  1.1700,  1.1700,  ...,  1.8722,  1.8208,  1.7865],
          ...,
          [ 0.6049,  0.4851,  0.4166,  ...,  1.0331,  1.0673,  1.0673],
          [ 0.8276,  0.6049,  0.4851,  ...,  1.0502,  1.1358,  1.1187],
          [ 0.5536,  0.5364,  0.8447,  ...,  1.1015,  1.1358,  1.1700]],

         [[ 1.1856,  1.1506,  1.1681,  ...,  1.7108,  1.6583,  1.6232],
          [ 1.2031,  1.1506,  1.1331,  ...,  1.7458,  1.6933,  1.6232],
          [ 1.1681,  1.1331,  1.1331,  ...,  1.7458,  1.6933,  1.6232],
          ...,
          [-0.0224, -0.2150, -0.3200,  ...,  0.3277,  0.3627,  0.3627],
          [ 0.1001, -0.0574, 

targets:  tensor([116,  12,  53,  45, 114,  88,  67,   2], device='cuda:0')
Predictions:  tensor([[4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4]], device='cuda:0')
Data:  tensor([[[[-0.8335, -0.9020, -0.9534,  ..., -0.1143, -0.1143, -0.1486],
          [-0.8335, -0.9020, -0.9534,  ..., -0.1314, -0.1314, -0.1486],
          [-0.8335, -0.9020, -0.9534,  ..., -0.1486, -0.1486, -0.1486],
          ...,
          [ 0.3994,  0.3994,  0.3481,  ...,  0.1254,  0.1254,  0.0398],
          [ 0.4337,  0.3994,  0.3994,  ...,  0.0569,  0.0227,  0.0227],
          [ 0.5193,  0.5878,  0.6392,  ...,  0.1254,  0.1597,  0.1083]],

         [[-1.0728, -1.1253, -1.1779,  ..., -0.1275, -0.1450, -0.1625],
          [-1.0728, -1.1253, -1.1779,  ..., -0.1450, -0.1450, -0.1800],
          [-1.0553, -1.1078, -1.1604,  ..., -0.1450, -0.1625, -0.1625],
          ...,
          [ 0.3627,  0.3277,  0.2577,  ...,  0.1527,  0.1352,  0.0651],
          [ 0.3277,  0.3102, 

Data:  tensor([[[[-0.3198, -0.3541, -0.3883,  ..., -0.2856, -0.3027, -0.3369],
          [-0.1999, -0.2513, -0.2856,  ..., -0.2513, -0.2684, -0.3198],
          [-0.1486, -0.1999, -0.2342,  ..., -0.1828, -0.2513, -0.3027],
          ...,
          [ 0.4508,  0.3823,  0.3823,  ..., -0.5253, -0.5938, -0.6794],
          [ 0.4166,  0.3994,  0.4166,  ..., -0.5253, -0.6109, -0.6623],
          [ 0.3994,  0.4508,  0.4679,  ..., -0.4911, -0.5596, -0.6281]],

         [[-0.3725, -0.4076, -0.4426,  ..., -0.1975, -0.2150, -0.2675],
          [-0.3025, -0.3200, -0.3375,  ..., -0.1625, -0.1800, -0.2325],
          [-0.2675, -0.2850, -0.3025,  ..., -0.1450, -0.1800, -0.2150],
          ...,
          [-0.6352, -0.6702, -0.6877,  ..., -0.3375, -0.3725, -0.4601],
          [-0.6702, -0.6527, -0.6702,  ..., -0.3375, -0.3901, -0.4426],
          [-0.6877, -0.6176, -0.6176,  ..., -0.3025, -0.3550, -0.4251]],

         [[-1.3687, -1.3861, -1.3861,  ..., -1.4559, -1.4559, -1.4733],
          [-1.3164, -1.

Data:  tensor([[[[-0.6452, -0.5082,  0.0056,  ...,  0.9303,  0.5707,  0.2796],
          [-0.5938, -0.4397, -0.0801,  ...,  0.9646,  0.6392,  0.3823],
          [-0.4054, -0.3369, -0.0287,  ...,  0.9646,  0.6734,  0.3652],
          ...,
          [ 1.0844,  1.2728,  1.2899,  ...,  1.7180,  1.4954,  1.4098],
          [ 1.1015,  1.2385,  1.2557,  ...,  1.8722,  1.6495,  1.5639],
          [ 1.2728,  1.3755,  1.3927,  ...,  2.0092,  1.8037,  1.6838]],

         [[-0.5651, -0.3375,  0.1527,  ...,  0.9580,  0.6078,  0.3277],
          [-0.4776, -0.2500,  0.1001,  ...,  1.0455,  0.6954,  0.4503],
          [-0.2675, -0.0749,  0.1877,  ...,  1.0805,  0.7829,  0.5028],
          ...,
          [ 0.8880,  1.0805,  1.0980,  ...,  1.2381,  1.0105,  0.9230],
          [ 0.9055,  1.1155,  1.1331,  ...,  1.2556,  1.0280,  0.9055],
          [ 1.1155,  1.3256,  1.3256,  ...,  1.2731,  1.0280,  0.8880]],

         [[-0.3927, -0.2010,  0.3045,  ...,  1.1062,  0.8274,  0.6182],
          [-0.2532, -0.

Data:  tensor([[[[-0.0116,  0.6049,  0.8276,  ...,  0.1768,  0.1083,  0.1768],
          [-0.1999, -0.0116,  0.4166,  ...,  0.4851,  0.4508,  0.4337],
          [-0.4739, -0.6281,  0.0398,  ...,  0.8276,  0.7762,  0.5878],
          ...,
          [ 0.6049,  0.5536,  0.5364,  ...,  0.2624,  0.2967,  0.1939],
          [ 0.5022,  0.4166,  0.5193,  ...,  0.2111,  0.2967,  0.3994],
          [ 0.3138,  0.2111,  0.5022,  ...,  0.3994,  0.3823,  0.4166]],

         [[-0.1975,  0.1176,  0.2752,  ..., -0.0924, -0.1800, -0.1800],
          [-0.4076, -0.2500,  0.0301,  ...,  0.0301,  0.0126, -0.0574],
          [-0.4426, -0.6877, -0.2850,  ...,  0.2227,  0.2227,  0.0651],
          ...,
          [-0.0399, -0.0749, -0.0749,  ..., -0.1450, -0.1800, -0.3550],
          [ 0.0476, -0.0924, -0.0574,  ..., -0.2850, -0.3200, -0.2500],
          [-0.0049, -0.2500, -0.0749,  ..., -0.1800, -0.1275, -0.0924]],

         [[-0.6890, -0.4101, -0.3927,  ..., -0.5844, -0.6018, -0.5844],
          [-0.8284, -0.

Data:  tensor([[[[ 1.4440,  1.4612,  1.4612,  ...,  2.2489,  2.2489,  2.2489],
          [ 1.4612,  1.4783,  1.4783,  ...,  2.2489,  2.2489,  2.2489],
          [ 1.4440,  1.4783,  1.4783,  ...,  2.2489,  2.2489,  2.2489],
          ...,
          [ 2.2147,  2.1804,  2.2147,  ...,  1.0159,  1.0159,  1.0331],
          [ 2.2318,  2.1975,  2.2489,  ...,  1.0159,  1.0159,  1.0673],
          [ 2.2318,  2.2489,  2.2318,  ...,  0.9988,  1.0159,  1.0502]],

         [[-0.4776, -0.4951, -0.4951,  ...,  2.0609,  2.0609,  2.0784],
          [-0.4601, -0.4776, -0.4776,  ...,  2.0434,  2.0434,  2.0434],
          [-0.4776, -0.4776, -0.4776,  ...,  2.0434,  2.0434,  2.0609],
          ...,
          [ 2.3410,  2.3410,  2.3060,  ..., -0.5651, -0.5651, -0.5826],
          [ 2.3410,  2.3235,  2.3585,  ..., -0.5651, -0.5651, -0.5476],
          [ 2.3585,  2.4111,  2.3936,  ..., -0.5301, -0.5301, -0.5826]],

         [[ 0.4962,  0.4614,  0.4788,  ...,  2.4831,  2.4831,  2.5006],
          [ 0.4788,  0.

Data:  tensor([[[[ 0.1768,  0.2453,  0.3138,  ..., -1.0219, -1.3130, -1.5357],
          [ 0.2282,  0.2967,  0.3652,  ...,  1.7865,  1.7180,  1.5468],
          [ 0.2282,  0.2967,  0.3823,  ...,  2.0948,  2.0777,  2.0948],
          ...,
          [-0.0287,  0.0227,  0.0741,  ...,  0.2282,  0.2111,  0.2282],
          [-0.0287,  0.0398,  0.1254,  ...,  0.2282,  0.2453,  0.2624],
          [ 0.0569,  0.0912,  0.1254,  ...,  0.2453,  0.2624,  0.2796]],

         [[ 0.7129,  0.7829,  0.7654,  ..., -0.7577, -0.9853, -1.1429],
          [ 0.7129,  0.7304,  0.7129,  ...,  2.0259,  1.9559,  1.7633],
          [ 0.6954,  0.7304,  0.7304,  ...,  2.3761,  2.3585,  2.3585],
          ...,
          [ 0.3627,  0.4153,  0.4678,  ..., -0.2500, -0.2675, -0.2500],
          [ 0.3627,  0.4328,  0.5203,  ..., -0.2500, -0.2325, -0.2150],
          [ 0.4503,  0.4853,  0.5203,  ..., -0.2325, -0.2150, -0.1975]],

         [[ 1.0539,  1.0539,  0.9668,  ..., -0.0964, -0.3230, -0.4624],
          [ 1.0714,  1.

Data:  tensor([[[[ 1.8208,  1.8208,  1.8208,  ...,  1.7009,  1.7009,  1.6838],
          [ 1.8208,  1.8208,  1.8208,  ...,  1.6838,  1.6667,  1.6667],
          [ 1.8208,  1.8208,  1.8208,  ...,  1.6838,  1.6667,  1.6667],
          ...,
          [-0.6623, -0.6623, -0.6109,  ..., -0.7993, -0.8164, -0.8335],
          [-0.6452, -0.5938, -0.5767,  ..., -0.8164, -0.8164, -0.8507],
          [-0.6452, -0.5938, -0.4568,  ..., -0.7993, -0.7993, -0.8335]],

         [[ 2.0434,  2.0434,  2.0434,  ...,  1.8508,  1.8508,  1.8333],
          [ 2.0434,  2.0434,  2.0434,  ...,  1.8333,  1.8158,  1.8158],
          [ 2.0434,  2.0434,  2.0434,  ...,  1.8508,  1.8333,  1.8333],
          ...,
          [ 0.7479,  0.7129,  0.6779,  ...,  0.3803,  0.3627,  0.3452],
          [ 0.7654,  0.7479,  0.6429,  ...,  0.3627,  0.3627,  0.3277],
          [ 0.7479,  0.6604,  0.6254,  ...,  0.3803,  0.3803,  0.3452]],

         [[ 2.4134,  2.4134,  2.4134,  ...,  2.1520,  2.1520,  2.1346],
          [ 2.4134,  2.

Data:  tensor([[[[-1.7412, -1.7412, -1.7240,  ..., -1.6555, -1.6384, -1.6384],
          [-1.7240, -1.7069, -1.7412,  ..., -1.6898, -1.6727, -1.6555],
          [-1.7412, -1.7240, -1.7240,  ..., -1.6898, -1.6727, -1.6727],
          ...,
          [-1.0562, -1.0904, -1.1589,  ..., -0.8678, -0.8678, -0.8507],
          [-1.0904, -1.1247, -1.1247,  ..., -0.8849, -0.9192, -0.9020],
          [-1.0904, -1.1075, -1.1418,  ..., -0.8849, -0.9192, -0.9192]],

         [[-1.4930, -1.4930, -1.5105,  ..., -1.3880, -1.3704, -1.3704],
          [-1.4930, -1.4930, -1.4930,  ..., -1.4230, -1.4055, -1.3880],
          [-1.4930, -1.4930, -1.4755,  ..., -1.4055, -1.3880, -1.4055],
          ...,
          [-0.8627, -0.8978, -0.9328,  ..., -0.6527, -0.6527, -0.6527],
          [-0.9153, -0.9503, -0.9328,  ..., -0.6702, -0.6702, -0.6352],
          [-0.8978, -0.9328, -0.9678,  ..., -0.6877, -0.7052, -0.6702]],

         [[-1.3164, -1.3164, -1.3164,  ..., -1.1944, -1.1596, -1.1596],
          [-1.3164, -1.

Data:  tensor([[[[-0.6452, -1.0562, -1.0904,  ...,  1.3070,  1.2214,  1.2043],
          [-0.4568, -1.0048, -1.2445,  ...,  0.9474,  0.8618,  0.8961],
          [-0.1828, -0.8849, -1.1589,  ...,  0.5878,  0.5193,  0.5878],
          ...,
          [-1.4843, -1.4672, -1.3987,  ..., -1.3302, -1.3302, -1.3130],
          [-1.4329, -1.3987, -1.3473,  ..., -1.2445, -1.2445, -1.2617],
          [-1.3815, -1.3815, -1.3302,  ..., -1.2103, -1.2445, -1.2445]],

         [[-0.2500, -1.0028, -1.4230,  ...,  1.5007,  1.5007,  1.4657],
          [ 0.0476, -0.7402, -1.3179,  ...,  1.2556,  1.2556,  1.2381],
          [ 0.1352, -0.7577, -1.2654,  ...,  0.9230,  0.8529,  0.8880],
          ...,
          [-1.7556, -1.7731, -1.7731,  ..., -1.3880, -1.3704, -1.3704],
          [-1.7556, -1.7556, -1.7731,  ..., -1.3529, -1.3354, -1.3529],
          [-1.7381, -1.7556, -1.7731,  ..., -1.3179, -1.3529, -1.3529]],

         [[ 0.1302, -0.5321, -0.9504,  ...,  1.1062,  1.0714,  1.0365],
          [ 0.3219, -0.

Data:  tensor([[[[ 0.7248,  0.7248,  0.7933,  ...,  0.4337,  0.4166,  0.4337],
          [ 0.6906,  0.6906,  0.7591,  ...,  0.4679,  0.3994,  0.3652],
          [ 0.7248,  0.7591,  0.7762,  ...,  0.3652,  0.3652,  0.3994],
          ...,
          [-2.1008, -2.0152, -1.9809,  ...,  1.6153,  1.2899,  1.2043],
          [-2.0837, -2.0323, -1.9809,  ...,  1.6324,  1.3413,  1.2043],
          [-1.9980, -2.0665, -2.0837,  ...,  1.6324,  1.4954,  1.2214]],

         [[ 0.6078,  0.6078,  0.6779,  ...,  0.2927,  0.2752,  0.2927],
          [ 0.5728,  0.5728,  0.6429,  ...,  0.3277,  0.2577,  0.2402],
          [ 0.6078,  0.6429,  0.6604,  ...,  0.2927,  0.2927,  0.3102],
          ...,
          [-2.0007, -1.9482, -1.8957,  ...,  1.5182,  1.1681,  1.0455],
          [-2.0182, -1.9657, -1.9132,  ...,  1.5182,  1.2206,  1.0280],
          [-1.9657, -2.0182, -2.0357,  ...,  1.5182,  1.3606,  1.0280]],

         [[ 0.1476,  0.1476,  0.2173,  ..., -0.0790, -0.0964, -0.0790],
          [ 0.1128,  0.

Data:  tensor([[[[ 0.8447,  0.8447,  0.8447,  ...,  0.9132,  0.9132,  0.9646],
          [ 0.8447,  0.8447,  0.8618,  ...,  0.9132,  0.9132,  0.9646],
          [ 0.8447,  0.8618,  0.8789,  ...,  0.9132,  0.9132,  0.9646],
          ...,
          [ 0.5707,  0.8447,  1.0331,  ...,  0.4679,  0.3994,  0.5878],
          [ 0.9132,  0.7933,  1.1358,  ...,  0.8276,  0.6392,  0.6392],
          [ 1.1358,  0.6221,  0.9646,  ...,  0.3309,  0.5878,  0.7762]],

         [[ 2.0959,  2.0959,  2.0959,  ...,  2.1835,  2.1835,  2.1660],
          [ 2.0959,  2.0959,  2.1134,  ...,  2.1835,  2.1835,  2.1660],
          [ 2.0959,  2.1134,  2.1310,  ...,  2.1835,  2.1835,  2.1660],
          ...,
          [ 0.9055,  1.2381,  1.4482,  ...,  0.7304,  0.7304,  1.0805],
          [ 1.2031,  1.1155,  1.5182,  ...,  1.0805,  0.9755,  1.1681],
          [ 1.3782,  0.8704,  1.2906,  ...,  0.5553,  0.9580,  1.3782]],

         [[ 2.6226,  2.6226,  2.6051,  ...,  2.6051,  2.6051,  2.6226],
          [ 2.6051,  2.

Data:  tensor([[[[ 2.2318,  2.2318,  2.2318,  ...,  2.2318,  2.2318,  2.2318],
          [ 2.2318,  2.2318,  2.2318,  ...,  2.2318,  2.2318,  2.2318],
          [ 2.2318,  2.2318,  2.2318,  ...,  2.2318,  2.2318,  2.2318],
          ...,
          [ 2.2318,  2.2318,  2.2318,  ...,  2.2489,  2.2318,  2.2318],
          [ 2.2318,  2.2318,  2.2318,  ...,  2.2489,  2.2318,  2.2318],
          [ 2.2318,  2.2318,  2.2318,  ...,  2.2489,  2.2318,  2.2318]],

         [[ 2.4111,  2.4111,  2.4111,  ...,  2.4111,  2.4111,  2.4111],
          [ 2.4111,  2.4111,  2.4111,  ...,  2.4111,  2.4111,  2.4111],
          [ 2.4111,  2.4111,  2.4111,  ...,  2.4111,  2.4111,  2.4111],
          ...,
          [ 2.4111,  2.4111,  2.4111,  ...,  2.4286,  2.4111,  2.4111],
          [ 2.4111,  2.4111,  2.4111,  ...,  2.4286,  2.4111,  2.4111],
          [ 2.4111,  2.4111,  2.4111,  ...,  2.4286,  2.4111,  2.4111]],

         [[ 2.6226,  2.6226,  2.6226,  ...,  2.6226,  2.6226,  2.6226],
          [ 2.6226,  2.

Data:  tensor([[[[-1.2617, -1.2103, -1.1589,  ..., -1.8439, -1.8439, -1.8439],
          [-1.3644, -1.3130, -1.2617,  ..., -1.8268, -1.8439, -1.8439],
          [-1.4158, -1.3644, -1.3302,  ..., -1.8439, -1.8610, -1.8610],
          ...,
          [-0.9534, -0.9534, -0.9534,  ...,  0.9646,  0.9474,  0.9646],
          [-1.0904, -1.1075, -1.1418,  ...,  0.9988,  0.9988,  1.0159],
          [-1.1932, -1.2274, -1.2788,  ...,  1.0502,  1.0502,  1.0159]],

         [[-1.4755, -1.4580, -1.4230,  ..., -1.7031, -1.7206, -1.7206],
          [-1.5455, -1.5105, -1.4580,  ..., -1.7031, -1.7206, -1.7206],
          [-1.5805, -1.5280, -1.4930,  ..., -1.7381, -1.7381, -1.7381],
          ...,
          [-0.9503, -0.8978, -0.8978,  ...,  0.5203,  0.5028,  0.4853],
          [-1.0378, -1.0203, -1.0553,  ...,  0.5903,  0.5553,  0.5378],
          [-1.1253, -1.1253, -1.1779,  ...,  0.6429,  0.6078,  0.5378]],

         [[-1.6127, -1.6302, -1.5953,  ..., -1.4733, -1.4733, -1.4733],
          [-1.6302, -1.

       device='cuda:0')
targets:  tensor([102,  41,  52,  35,  28,  31,  95,   5], device='cuda:0')
Predictions:  tensor([[4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4]], device='cuda:0')
Data:  tensor([[[[ 0.8447,  0.8447,  0.8447,  ...,  0.4166,  0.3823,  0.3823],
          [ 0.8961,  0.8789,  0.8789,  ...,  0.4337,  0.4166,  0.3823],
          [ 0.9646,  0.9646,  0.9646,  ...,  0.4508,  0.4337,  0.4166],
          ...,
          [ 1.1187,  1.1358,  1.0673,  ..., -0.5082, -0.5596, -0.5253],
          [ 1.1700,  1.0844,  0.9132,  ..., -0.5082, -0.5938, -0.4568],
          [ 1.0331,  1.0331,  1.0331,  ..., -0.4739, -0.6109, -0.4911]],

         [[ 1.2906,  1.2906,  1.2906,  ...,  1.2031,  1.2206,  1.2381],
          [ 1.3256,  1.3256,  1.3256,  ...,  1.2031,  1.2206,  1.2206],
          [ 1.3606,  1.3606,  1.3606,  ...,  1.2031,  1.2206,  1.2206],
          ...,
          [ 1.1155,  1.1506,  1.0630,  ..., -1.3704, -1.4230, -1.4055],
     

Data:  tensor([[[[-0.1314, -0.2171,  0.1426,  ..., -0.3369, -0.4911, -0.5253],
          [-0.0116, -0.3369,  0.2624,  ...,  0.2111, -0.0287,  0.0227],
          [ 0.1254, -0.3541,  0.3309,  ...,  0.3309,  0.3309,  0.2453],
          ...,
          [ 0.4337, -0.0801,  0.4337,  ..., -0.6623, -1.4500, -1.9980],
          [ 0.7248, -0.0629,  0.2453,  ..., -0.1314, -1.1418, -1.8097],
          [ 0.0056, -0.4911, -0.9363,  ..., -1.1075, -1.0048, -1.4158]],

         [[ 0.0126, -0.0749,  0.2927,  ...,  0.0476, -0.0574, -0.0399],
          [ 0.1352, -0.1975,  0.4153,  ...,  0.6254,  0.4153,  0.5028],
          [ 0.2752, -0.2150,  0.4853,  ...,  0.7129,  0.7479,  0.6954],
          ...,
          [ 0.5553,  0.0651,  0.6429,  ..., -0.5826, -1.3880, -1.9482],
          [ 0.8529,  0.0826,  0.4328,  ..., -0.0399, -1.0728, -1.7556],
          [ 0.1176, -0.3550, -0.7927,  ..., -1.0378, -0.9328, -1.3529]],

         [[-0.7587, -0.8458, -0.4798,  ..., -0.9156, -1.0550, -1.0201],
          [-0.6367, -0.

Data:  tensor([[[[-0.5424, -0.1143, -0.0801,  ...,  0.2796,  0.3994,  0.4508],
          [-0.7650, -0.6452, -0.2513,  ...,  0.0398,  0.1426,  0.2796],
          [-0.7822, -0.8164, -0.5596,  ..., -0.0629,  0.0056,  0.0398],
          ...,
          [-1.4329, -1.4500, -1.2617,  ..., -0.7137, -0.7479, -0.7650],
          [-1.4329, -1.4843, -1.4329,  ..., -0.7308, -0.7137, -0.7822],
          [-1.4672, -1.4672, -1.4843,  ..., -0.7993, -0.7479, -0.8164]],

         [[-0.2850,  0.1176,  0.1352,  ...,  0.3452,  0.4853,  0.5728],
          [-0.5126, -0.3901, -0.0049,  ...,  0.1877,  0.3102,  0.4503],
          [-0.5826, -0.6176, -0.3375,  ...,  0.1352,  0.2052,  0.2402],
          ...,
          [-1.3704, -1.4230, -1.2479,  ..., -0.7752, -0.8102, -0.8277],
          [-1.3704, -1.4405, -1.4055,  ..., -0.7927, -0.7752, -0.8452],
          [-1.4055, -1.3880, -1.4230,  ..., -0.8627, -0.8102, -0.8803]],

         [[ 0.1651,  0.5311,  0.4788,  ...,  0.6879,  0.8099,  0.8622],
          [ 0.0256,  0.

Data:  tensor([[[[ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489],
          [ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489],
          [ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489],
          ...,
          [ 0.4508,  0.3481,  0.5364,  ...,  1.2728,  1.0331,  0.8789],
          [ 0.6049,  0.5536,  0.6906,  ...,  1.3242,  1.1872,  1.1529],
          [ 0.6392,  0.7933,  0.6221,  ...,  1.3070,  1.0844,  1.1015]],

         [[ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286],
          [ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286],
          [ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286],
          ...,
          [ 1.2381,  1.0805,  1.5707,  ...,  2.0084,  1.7283,  1.8158],
          [ 1.5357,  1.4832,  1.6933,  ...,  2.1134,  2.0609,  2.1485],
          [ 1.6933,  1.8859,  1.7108,  ...,  2.2010,  2.0609,  2.0434]],

         [[ 2.6400,  2.6400,  2.6400,  ...,  2.6400,  2.6400,  2.6400],
          [ 2.6400,  2.

Data:  tensor([[[[-1.4843, -1.6213, -1.6042,  ..., -1.3130, -1.3130, -1.3130],
          [-1.5528, -1.7240, -1.4672,  ..., -1.2959, -1.3644, -1.3302],
          [-1.5185, -1.5357, -1.4158,  ..., -1.5185, -1.6042, -1.2959],
          ...,
          [-1.4843, -1.5185, -1.3815,  ..., -1.5185, -1.2274, -1.1418],
          [-1.1760, -1.4158, -1.4500,  ..., -1.2617, -1.2445, -1.1760],
          [-1.4158, -1.4843, -1.5699,  ..., -1.4500, -1.2959, -1.4158]],

         [[-0.4776, -0.5651, -0.5476,  ..., -0.2325, -0.2675, -0.2675],
          [-0.5301, -0.6702, -0.4076,  ..., -0.2150, -0.3200, -0.3025],
          [-0.5126, -0.4951, -0.3550,  ..., -0.4426, -0.5651, -0.2675],
          ...,
          [-0.6352, -0.6702, -0.5301,  ..., -0.6877, -0.3901, -0.3550],
          [-0.3200, -0.5651, -0.6001,  ..., -0.3901, -0.4076, -0.3901],
          [-0.5651, -0.6352, -0.7227,  ..., -0.5826, -0.4426, -0.5826]],

         [[-1.3687, -1.4907, -1.4907,  ..., -1.2990, -1.2990, -1.3164],
          [-1.4210, -1.

Data:  tensor([[[[-0.9877, -1.0219, -0.9192,  ..., -1.1418, -1.2959, -1.2959],
          [-0.8164, -1.0562, -1.0904,  ..., -1.1247, -1.1418, -1.2103],
          [-0.6452, -0.9877, -1.1075,  ..., -1.3130, -1.1247, -1.0219],
          ...,
          [-0.6452, -0.5253, -0.3369,  ..., -1.2617, -1.4672, -1.5185],
          [-0.6965, -0.5596, -0.3541,  ..., -1.6555, -1.7412, -1.6898],
          [-0.5767, -0.5253, -0.5082,  ..., -1.8268, -1.8439, -1.8439]],

         [[-0.3550, -0.3550, -0.3550,  ..., -0.4251, -0.6001, -0.7052],
          [-0.1975, -0.3200, -0.3725,  ..., -0.5651, -0.4601, -0.5476],
          [-0.0924, -0.2675, -0.3375,  ..., -0.8452, -0.5476, -0.4251],
          ...,
          [-0.2150, -0.2150, -0.0399,  ..., -0.6001, -0.7577, -0.8277],
          [-0.2325, -0.2325, -0.0924,  ..., -1.1779, -1.1779, -1.1078],
          [-0.1800, -0.1800, -0.2325,  ..., -1.3880, -1.3880, -1.3529]],

         [[-1.1944, -1.1596, -1.0201,  ..., -1.0724, -1.1770, -1.2816],
          [-0.8633, -1.

Data:  tensor([[[[ 0.2453,  0.3994,  0.1597,  ..., -0.5938, -0.6109, -0.9877],
          [-0.1657, -0.5938, -0.0801,  ..., -0.5082, -0.5424, -0.5767],
          [ 0.5022,  0.1426, -0.5082,  ..., -0.7993, -0.6794, -0.5424],
          ...,
          [-0.6109, -0.5253, -1.2274,  ..., -1.1589, -1.0904, -0.8849],
          [-1.3815, -0.7479, -0.6281,  ..., -1.2788, -1.2274, -1.0219],
          [-1.0048, -0.8164, -0.5767,  ..., -1.0904, -0.9877, -1.1932]],

         [[ 0.7304,  0.7479,  0.3277,  ..., -0.0574, -0.0049, -0.5301],
          [ 0.2927, -0.3200,  0.1001,  ...,  0.0301,  0.0651, -0.0749],
          [ 0.8529,  0.5203, -0.2500,  ..., -0.2850, -0.0574, -0.0049],
          ...,
          [-0.1099, -0.0049, -0.6877,  ..., -0.6352, -0.4251, -0.1275],
          [-1.0378, -0.3725, -0.0924,  ..., -0.8277, -0.6352, -0.3901],
          [-0.4951, -0.1450,  0.1877,  ..., -0.6527, -0.4076, -0.5826]],

         [[-0.8110, -0.9156, -1.1247,  ..., -0.6193, -0.4624, -0.9504],
          [-1.1247, -1.

Data:  tensor([[[[-0.3369, -0.2856, -0.1999,  ..., -1.9467, -2.0152, -1.9638],
          [-0.3369, -0.2856, -0.1999,  ..., -1.9809, -1.9980, -1.9638],
          [-0.3198, -0.2856, -0.1999,  ..., -1.9980, -1.9638, -1.9467],
          ...,
          [-0.7993, -0.8849, -0.8678,  ..., -0.5596, -0.5253, -0.5253],
          [-0.8335, -0.9020, -0.8507,  ..., -0.6281, -0.6109, -0.5767],
          [-0.8335, -0.8849, -0.8507,  ..., -0.5938, -0.5767, -0.5767]],

         [[-0.4601, -0.4076, -0.3200,  ..., -1.7731, -1.8606, -1.8256],
          [-0.4601, -0.4076, -0.3200,  ..., -1.8431, -1.8782, -1.8606],
          [-0.4426, -0.4076, -0.3200,  ..., -1.8957, -1.8606, -1.8606],
          ...,
          [-1.2654, -1.3704, -1.3704,  ..., -0.6352, -0.6001, -0.6001],
          [-1.3004, -1.3704, -1.3529,  ..., -0.7052, -0.6877, -0.6527],
          [-1.2479, -1.3354, -1.3354,  ..., -0.6702, -0.6527, -0.6527]],

         [[-0.4624, -0.4101, -0.3230,  ..., -1.5604, -1.6302, -1.5953],
          [-0.4624, -0.

Data:  tensor([[[[ 1.1872,  0.5878,  0.0741,  ...,  0.0056,  0.1083,  0.7933],
          [ 0.8447,  0.1254, -0.6794,  ...,  0.6049,  1.0331,  1.7180],
          [ 0.6563,  0.2282, -0.0629,  ...,  1.2214,  1.6324,  2.0434],
          ...,
          [-0.0801,  0.0056, -0.4911,  ...,  0.0912,  0.0227,  0.0056],
          [-0.3198, -0.0629, -0.1999,  ...,  0.2967,  0.2453,  0.1083],
          [-0.3198,  0.0741,  0.2282,  ...,  0.5022,  0.2967, -0.0629]],

         [[ 1.3957,  0.7829,  0.2752,  ..., -0.7227, -0.3725,  0.5028],
          [ 1.0630,  0.3277, -0.4951,  ...,  0.0826,  0.8704,  1.7633],
          [ 0.8704,  0.4153,  0.1176,  ...,  1.0105,  1.6933,  2.2710],
          ...,
          [ 0.6254,  0.6429,  0.0651,  ...,  0.2402,  0.1702,  0.1527],
          [ 0.3803,  0.5903,  0.3803,  ...,  0.4503,  0.3978,  0.2577],
          [ 0.3803,  0.7479,  0.8529,  ...,  0.6604,  0.4503,  0.0826]],

         [[ 1.4374,  0.7925,  0.2348,  ..., -0.5844, -0.1661,  0.7751],
          [ 1.0714,  0.

       device='cuda:0')
targets:  tensor([ 87,  37,   4,  71,  46,  73, 113,  37], device='cuda:0')
Predictions:  tensor([[4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4]], device='cuda:0')
Data:  tensor([[[[-0.6965, -0.5767, -0.5253,  ..., -0.4226, -0.5253, -0.4739],
          [-0.7137, -0.6794, -0.6452,  ..., -0.4568, -0.4568, -0.4054],
          [-0.6965, -0.7479, -0.6794,  ..., -0.3712, -0.2513, -0.1999],
          ...,
          [-1.1418, -1.2274, -1.2445,  ..., -1.2788, -1.1589, -0.9534],
          [-1.0048, -1.1075, -1.1075,  ..., -1.4158, -1.2103, -1.0904],
          [-0.9020, -1.0219, -1.0219,  ..., -1.0562, -1.1760, -1.3130]],

         [[-0.2325, -0.1099, -0.0574,  ..., -0.3200, -0.4076, -0.3200],
          [-0.2500, -0.2150, -0.1800,  ..., -0.3375, -0.3200, -0.2500],
          [-0.2325, -0.2675, -0.1975,  ..., -0.2150, -0.0924, -0.0399],
          ...,
          [-1.3354, -1.4055, -1.4405,  ..., -0.9678, -0.8627, -0.6702],
     

Data:  tensor([[[[ 1.0844,  1.0844,  1.0331,  ..., -0.6623, -0.6452, -0.6109],
          [ 1.0159,  0.9303,  0.9132,  ..., -0.6452, -0.6281, -0.5767],
          [ 0.9303,  0.8961,  0.8789,  ..., -0.5596, -0.5424, -0.5253],
          ...,
          [ 0.0398,  0.3309,  0.4679,  ...,  0.7591,  0.5536,  0.6221],
          [ 0.5536,  0.6049,  0.3481,  ...,  0.1597,  0.1426,  0.3823],
          [ 0.5193,  0.7591,  0.7248,  ...,  0.2453,  0.1426,  0.3309]],

         [[ 0.7829,  0.7654,  0.8179,  ..., -0.1450, -0.1450, -0.1099],
          [ 0.6429,  0.5728,  0.6429,  ..., -0.1450, -0.1275, -0.0924],
          [ 0.5553,  0.5378,  0.5903,  ..., -0.0574, -0.0399, -0.0399],
          ...,
          [ 0.2402,  0.5378,  0.6779,  ...,  0.8704,  0.6604,  0.7304],
          [ 0.7479,  0.8179,  0.6254,  ...,  0.2577,  0.2402,  0.4853],
          [ 0.7129,  0.9755,  1.0105,  ...,  0.3452,  0.2402,  0.4328]],

         [[ 0.9842,  0.9668,  0.9319,  ..., -0.6541, -0.6193, -0.5321],
          [ 0.7925,  0.

       device='cuda:0')
targets:  tensor([ 57,  77,  57,  37, 110,  81,  58,  34], device='cuda:0')
Predictions:  tensor([[4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4]], device='cuda:0')
Data:  tensor([[[[ 0.3138,  0.2624,  0.1597,  ...,  1.0673,  1.1358,  1.0844],
          [ 0.3481,  0.2967,  0.1254,  ...,  1.0331,  1.0159,  0.9303],
          [ 0.3823,  0.3309,  0.1426,  ...,  1.1015,  1.0502,  0.8447],
          ...,
          [ 0.0227, -0.0629, -0.0629,  ...,  0.6392,  0.5878,  0.3138],
          [ 0.4166,  0.2796,  0.2967,  ...,  0.4166,  0.5022,  0.4851],
          [ 0.4166,  0.4166,  0.4166,  ...,  0.1597,  0.1768,  0.3138]],

         [[ 0.4503,  0.3978,  0.2927,  ...,  1.2556,  1.3256,  1.2731],
          [ 0.4853,  0.4328,  0.2577,  ...,  1.2206,  1.2031,  1.1155],
          [ 0.5203,  0.4678,  0.2752,  ...,  1.2906,  1.2381,  1.0280],
          ...,
          [ 0.1352,  0.0476,  0.0476,  ...,  0.7829,  0.7304,  0.4503],
     

Data:  tensor([[[[-2.0665, -2.0665, -2.0494,  ..., -0.5596, -0.5596, -0.5596],
          [-2.0494, -2.0837, -2.0665,  ..., -0.5767, -0.5767, -0.5767],
          [-2.0152, -2.0152, -2.0152,  ..., -0.5253, -0.5253, -0.5253],
          ...,
          [-1.7412, -1.7412, -1.7412,  ..., -0.6281, -0.6281, -0.6623],
          [-1.7583, -1.7412, -1.7412,  ..., -0.6452, -0.6281, -0.6452],
          [-1.7583, -1.7412, -1.7412,  ..., -0.6794, -0.6794, -0.6794]],

         [[-1.9832, -1.9832, -1.9657,  ..., -0.3725, -0.3725, -0.3725],
          [-1.9657, -2.0007, -1.9832,  ..., -0.3901, -0.3901, -0.3901],
          [-1.9482, -1.9482, -1.9482,  ..., -0.3375, -0.3375, -0.3375],
          ...,
          [-1.6506, -1.6506, -1.6506,  ..., -0.5651, -0.5651, -0.6001],
          [-1.6681, -1.6506, -1.6506,  ..., -0.5476, -0.5301, -0.5476],
          [-1.6681, -1.6506, -1.6506,  ..., -0.5826, -0.5826, -0.5826]],

         [[-1.6999, -1.7173, -1.6999,  ..., -0.1312, -0.1312, -0.1312],
          [-1.6824, -1.

Data:  tensor([[[[ 1.5297,  1.5297,  1.5125,  ...,  1.0159,  1.1015,  1.2043],
          [ 1.5468,  1.5297,  1.5468,  ...,  0.8276,  0.8447,  0.9303],
          [ 1.5468,  1.5468,  1.5639,  ...,  0.8447,  0.8104,  0.9303],
          ...,
          [ 0.0912,  0.0912,  0.1254,  ...,  0.7419,  0.6906,  0.6221],
          [ 0.0569,  0.0912,  0.1254,  ...,  0.6563,  0.7591,  0.5193],
          [ 0.0569,  0.0912,  0.1254,  ...,  0.7933,  0.7419,  0.5022]],

         [[ 1.6933,  1.6933,  1.6758,  ...,  0.8529,  0.9055,  0.9930],
          [ 1.6933,  1.6933,  1.6933,  ...,  0.6429,  0.6604,  0.7304],
          [ 1.6933,  1.6933,  1.7108,  ...,  0.6779,  0.6604,  0.7479],
          ...,
          [-0.0749, -0.0749, -0.0399,  ...,  0.5553,  0.4503,  0.3978],
          [-0.1099, -0.0749, -0.0399,  ...,  0.4853,  0.5903,  0.3452],
          [-0.1275, -0.0749, -0.0399,  ...,  0.5903,  0.5728,  0.3102]],

         [[ 1.9080,  1.9080,  1.8905,  ...,  0.6356,  0.6879,  0.8099],
          [ 1.8731,  1.

Data:  tensor([[[[-1.7069, -1.7925, -1.8610,  ..., -1.4843, -1.5014, -1.5528],
          [-1.7069, -1.7925, -1.8610,  ..., -1.4672, -1.4672, -1.5014],
          [-1.7069, -1.7925, -1.8439,  ..., -1.5528, -1.5699, -1.5870],
          ...,
          [ 0.1597,  0.1254,  0.0569,  ...,  0.3481,  0.3481,  0.2967],
          [ 0.1597,  0.1939,  0.1768,  ...,  0.4508,  0.4166,  0.4166],
          [ 0.1426,  0.1426,  0.2282,  ...,  0.4337,  0.4679,  0.5193]],

         [[-1.6506, -1.7206, -1.7731,  ..., -1.4930, -1.4930, -1.5280],
          [-1.6506, -1.7206, -1.7731,  ..., -1.4405, -1.4405, -1.4580],
          [-1.6506, -1.7206, -1.7556,  ..., -1.4930, -1.4930, -1.4930],
          ...,
          [ 0.0826,  0.0476, -0.0224,  ...,  0.2402,  0.2402,  0.1877],
          [ 0.0826,  0.1176,  0.1001,  ...,  0.3452,  0.3102,  0.3102],
          [ 0.0301,  0.0301,  0.1352,  ...,  0.3452,  0.3627,  0.4153]],

         [[-1.4036, -1.4733, -1.5430,  ..., -1.3164, -1.3339, -1.3687],
          [-1.4036, -1.

Data:  tensor([[[[-2.0152, -2.0665, -1.9980,  ..., -1.7240, -1.6898, -1.6213],
          [-1.9809, -2.0837, -2.1179,  ..., -1.8097, -1.7240, -1.6727],
          [-1.9638, -2.0494, -2.0665,  ..., -1.9467, -1.8610, -1.7583],
          ...,
          [ 0.3481,  0.4166,  0.2453,  ...,  0.3309,  0.1597,  0.1768],
          [ 0.1426,  0.1939,  0.0056,  ...,  0.3138, -0.0972,  0.0569],
          [-0.0458,  0.1597,  0.0398,  ...,  0.3652, -0.3883, -0.2856]],

         [[-1.7556, -1.8256, -1.7206,  ..., -1.0728, -1.0203, -1.0203],
          [-1.7381, -1.9307, -1.9132,  ..., -1.1779, -1.0903, -1.0903],
          [-1.7206, -1.8081, -1.8081,  ..., -1.4580, -1.3354, -1.2654],
          ...,
          [ 1.2731,  1.3431,  1.2031,  ...,  1.1681,  1.0805,  1.1155],
          [ 1.0630,  1.1155,  0.9405,  ...,  1.1856,  0.8354,  0.9930],
          [ 0.8179,  1.0630,  0.9580,  ...,  1.2556,  0.5378,  0.6429]],

         [[-1.7173, -1.7522, -1.6824,  ..., -1.4733, -1.4036, -1.3513],
          [-1.6650, -1.

Data:  tensor([[[[-0.7993, -0.7308, -0.6452,  ..., -1.7925, -1.7754, -1.7412],
          [-0.7479, -0.7308, -0.6109,  ..., -1.7412, -1.7240, -1.7754],
          [-0.6623, -0.6281, -0.5767,  ..., -1.7925, -1.8097, -1.7925],
          ...,
          [-1.7412, -1.7240, -1.7069,  ...,  0.6563,  0.6563,  0.5707],
          [-1.7583, -1.7240, -1.7412,  ...,  0.6563,  0.6221,  0.5878],
          [-1.7240, -1.6898, -1.7412,  ...,  0.9303,  0.5707,  0.6392]],

         [[-0.8102, -0.8452, -0.7577,  ..., -1.7381, -1.7031, -1.6856],
          [-0.8102, -0.8102, -0.6702,  ..., -1.7206, -1.7031, -1.6856],
          [-0.7927, -0.7752, -0.6702,  ..., -1.7031, -1.7206, -1.6856],
          ...,
          [-1.6506, -1.6856, -1.6681,  ..., -0.0924, -0.0749, -0.1099],
          [-1.7206, -1.7031, -1.7381,  ..., -0.0399, -0.0049, -0.0399],
          [-1.7206, -1.7031, -1.7381,  ...,  0.4678, -0.0049, -0.0399]],

         [[-1.0550, -1.0550, -1.0027,  ..., -1.6127, -1.6302, -1.6476],
          [-1.0724, -1.

Data:  tensor([[[[ 0.8789,  0.8447,  1.1700,  ...,  1.2899,  1.4098,  1.4612],
          [ 1.2043,  0.8618,  0.9132,  ...,  1.3070,  1.3070,  1.2214],
          [ 1.3927,  1.0502,  0.7419,  ...,  1.2899,  1.2214,  1.2214],
          ...,
          [ 0.2111,  0.1597,  0.0227,  ...,  0.8961,  1.0502,  1.4783],
          [ 0.0912,  0.0912,  0.0227,  ...,  0.7762,  0.7077,  1.2728],
          [ 0.0056,  0.0741,  0.0741,  ...,  1.8037,  1.9578,  1.9749]],

         [[ 0.8004,  0.7654,  1.0980,  ...,  0.8880,  1.0280,  1.1155],
          [ 1.1331,  0.7829,  0.8179,  ...,  0.9755,  0.9930,  0.9405],
          [ 1.3431,  0.9930,  0.6604,  ...,  1.0280,  0.9930,  1.0105],
          ...,
          [ 0.3277,  0.2927,  0.1352,  ...,  0.7829,  0.9580,  1.4307],
          [ 0.2577,  0.2577,  0.1702,  ...,  0.7304,  0.6604,  1.2206],
          [ 0.2227,  0.2927,  0.2927,  ...,  1.8683,  1.9909,  1.9909]],

         [[ 0.7228,  0.6705,  0.9842,  ...,  0.7925,  0.9319,  1.0365],
          [ 1.0365,  0.

Data:  tensor([[[[ 2.1975,  2.2147,  2.2318,  ..., -1.1075, -1.0904, -1.1247],
          [ 2.2318,  2.2147,  2.2318,  ..., -1.2445, -1.2617, -1.2788],
          [ 2.2489,  2.2147,  2.1975,  ..., -1.3815, -1.3815, -1.3987],
          ...,
          [-0.0629, -0.3369, -0.6965,  ..., -1.1760, -1.1760, -1.3302],
          [-0.1999, -0.5596, -0.9534,  ..., -1.2274, -1.3644, -1.5528],
          [-0.2684, -0.6965, -1.1075,  ..., -1.5014, -1.5870, -1.7069]],

         [[ 2.4111,  2.4111,  2.4286,  ..., -1.2829, -1.2654, -1.2829],
          [ 2.4286,  2.4111,  2.4286,  ..., -1.3704, -1.3704, -1.3880],
          [ 2.4111,  2.4111,  2.4286,  ..., -1.4930, -1.4755, -1.4930],
          ...,
          [-0.2500, -0.5126, -0.8627,  ..., -0.9678, -0.9678, -1.1429],
          [-0.3725, -0.7402, -1.1253,  ..., -1.0378, -1.1779, -1.3704],
          [-0.4426, -0.8627, -1.2654,  ..., -1.3004, -1.3880, -1.5280]],

         [[ 2.5703,  2.6400,  2.6400,  ..., -1.1421, -1.1247, -1.1421],
          [ 2.5180,  2.

Data:  tensor([[[[ 1.2728,  1.3070,  1.1529,  ...,  0.7419,  0.7248,  0.7248],
          [ 1.2385,  1.2557,  1.1187,  ...,  0.7762,  0.7591,  0.7762],
          [ 1.2043,  1.1872,  1.1015,  ...,  0.9132,  0.8447,  0.8276],
          ...,
          [-0.2171, -0.4739, -0.1999,  ...,  0.1597,  0.4508,  0.2796],
          [-1.0562, -0.6109,  0.0056,  ...,  0.4166,  0.2967,  0.0398],
          [-1.0219, -0.4397,  0.0569,  ...,  0.4166,  0.0227, -0.1486]],

         [[ 1.9384,  1.9734,  1.8158,  ...,  1.5357,  1.4832,  1.4832],
          [ 1.9034,  1.9209,  1.7808,  ...,  1.5532,  1.5182,  1.5182],
          [ 1.8859,  1.8683,  1.7633,  ...,  1.6758,  1.6057,  1.5707],
          ...,
          [ 0.2927,  0.0301,  0.2752,  ...,  0.5028,  0.9580,  0.8529],
          [-0.4951, -0.0224,  0.5378,  ...,  0.8004,  0.8179,  0.6078],
          [-0.3725,  0.2577,  0.6779,  ...,  0.8529,  0.5553,  0.3978]],

         [[ 1.2108,  1.2457,  1.0365,  ..., -0.0441,  0.0605,  0.1302],
          [ 1.1062,  1.

Data:  tensor([[[[-1.6384, -1.5185, -1.6898,  ..., -0.1828, -0.0801, -0.4739],
          [-1.6898, -1.3473, -1.6384,  ..., -0.6109, -0.1999, -0.3712],
          [-1.3302, -1.2103, -1.6727,  ..., -0.7993, -0.6794, -0.6281],
          ...,
          [-1.7925, -1.2445, -1.5357,  ..., -0.1999, -0.4568, -0.7479],
          [-1.7412, -1.3815, -1.7412,  ..., -0.4397, -0.4226, -1.0048],
          [-1.7069, -1.3987, -1.6898,  ..., -0.5596, -0.1828, -0.7308]],

         [[-1.3529, -1.1954, -1.3880,  ..., -0.0049,  0.1702, -0.1625],
          [-1.4230, -1.1604, -1.4055,  ..., -0.2675,  0.0826, -0.0224],
          [-1.2304, -1.1078, -1.4405,  ..., -0.4426, -0.3375, -0.3025],
          ...,
          [-1.5280, -0.9503, -1.3880,  ...,  0.3277, -0.0224, -0.4076],
          [-1.4930, -1.1779, -1.5455,  ...,  0.0826,  0.0476, -0.6877],
          [-1.4755, -1.1253, -1.5455,  ..., -0.0049,  0.3452, -0.4076]],

         [[-1.4907, -1.3861, -1.5953,  ..., -1.0724, -0.8981, -1.2293],
          [-1.6127, -1.

       device='cuda:0')
targets:  tensor([89, 29, 68, 23, 55, 58,  1, 45], device='cuda:0')
Predictions:  tensor([[4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4]], device='cuda:0')
Data:  tensor([[[[-0.8507, -0.7479, -0.7822,  ..., -0.5082,  0.0398,  0.4851],
          [-1.5699, -1.6042, -1.5870,  ..., -0.3541,  0.0741,  0.7591],
          [-0.7137, -1.0048, -1.1932,  ..., -0.1657, -0.0629,  0.5193],
          ...,
          [-1.3130, -1.3473, -1.3644,  ...,  0.3652,  0.6563,  0.5878],
          [-1.2445, -1.3302, -1.2788,  ...,  0.3823,  0.6906,  0.7077],
          [-1.1932, -1.2445, -1.2445,  ...,  0.4166,  0.4851,  0.8276]],

         [[-0.3550, -0.2500, -0.2500,  ...,  0.0476,  0.5378,  0.9230],
          [-1.1253, -1.1253, -1.0903,  ...,  0.1352,  0.5028,  1.1331],
          [-0.4251, -0.7052, -0.8452,  ...,  0.3277,  0.3452,  0.8704],
          ...,
          [-0.6527, -0.6877, -0.7227,  ...,  0.1352,  0.4853,  0.4503],
          [-0

Data:  tensor([[[[ 2.2489,  2.2489,  2.2489,  ...,  1.1872,  1.1529,  1.1187],
          [ 2.2489,  2.2489,  2.2489,  ...,  1.1015,  1.1015,  1.0844],
          [ 2.2489,  2.2489,  2.2489,  ...,  1.0331,  1.0159,  0.9988],
          ...,
          [ 0.4337,  0.6049,  0.7933,  ...,  0.5536,  0.6221,  0.6392],
          [ 0.3823,  0.5707,  0.7591,  ...,  0.6049,  0.6221,  0.5707],
          [ 0.3138,  0.4851,  0.6563,  ...,  0.5707,  0.5707,  0.5707]],

         [[ 2.4286,  2.4286,  2.4286,  ...,  1.6408,  1.6057,  1.5707],
          [ 2.4286,  2.4286,  2.4286,  ...,  1.5532,  1.5357,  1.5182],
          [ 2.4286,  2.4286,  2.4286,  ...,  1.4832,  1.4482,  1.4307],
          ...,
          [ 0.9055,  1.0455,  1.2381,  ...,  0.9230,  0.9930,  1.0105],
          [ 0.8529,  1.0280,  1.2031,  ...,  0.9755,  0.9930,  0.9405],
          [ 0.8004,  0.9755,  1.1331,  ...,  0.9580,  0.9580,  0.9580]],

         [[ 2.6400,  2.6400,  2.6400,  ...,  0.6705,  0.6182,  0.6008],
          [ 2.6400,  2.

Data:  tensor([[[[ 1.4612,  1.4612,  1.4783,  ...,  1.6153,  1.6153,  1.6153],
          [ 1.4612,  1.4612,  1.4612,  ...,  1.6153,  1.6153,  1.6153],
          [ 1.4954,  1.4954,  1.4954,  ...,  1.5982,  1.5982,  1.5982],
          ...,
          [ 1.5125,  1.5125,  1.5125,  ...,  1.5982,  1.5982,  1.5982],
          [ 1.5125,  1.5125,  1.5125,  ...,  1.5982,  1.5982,  1.5982],
          [ 1.5125,  1.5125,  1.5125,  ...,  1.5982,  1.5982,  1.5982]],

         [[ 1.6933,  1.6933,  1.7108,  ...,  1.8333,  1.8508,  1.8508],
          [ 1.6933,  1.6933,  1.6933,  ...,  1.8333,  1.8508,  1.8508],
          [ 1.6583,  1.6583,  1.6933,  ...,  1.8333,  1.8333,  1.8333],
          ...,
          [ 1.7458,  1.7458,  1.7458,  ...,  1.8333,  1.8333,  1.8333],
          [ 1.7458,  1.7458,  1.7458,  ...,  1.8333,  1.8333,  1.8333],
          [ 1.7458,  1.7458,  1.7458,  ...,  1.8333,  1.8333,  1.8333]],

         [[ 1.9603,  1.9603,  1.9777,  ...,  2.1171,  2.1171,  2.1171],
          [ 1.9603,  1.

Data:  tensor([[[[-0.4397, -0.5424, -1.0048,  ...,  0.5707,  0.3138, -1.0219],
          [-0.4397, -0.6623, -0.9534,  ...,  0.3994,  0.0741, -0.9192],
          [-0.4054, -0.5767, -0.5767,  ...,  0.0912, -0.0458, -0.7650],
          ...,
          [ 0.3994,  0.3481, -0.7993,  ...,  0.9817,  1.1015,  0.6049],
          [ 0.7591,  0.1597, -0.7479,  ...,  0.0912,  0.2624,  0.9303],
          [ 0.2967,  0.4337,  0.5707,  ...,  0.0912,  0.0398,  0.6221]],

         [[-0.1450, -0.1800, -0.2325,  ...,  1.3606,  1.0630, -0.4776],
          [-0.1099, -0.3375, -0.2500,  ...,  1.2906,  0.8880, -0.3550],
          [-0.0924, -0.2325,  0.0126,  ...,  0.9755,  0.7654, -0.1625],
          ...,
          [ 0.7829,  1.0805,  0.1352,  ...,  0.9755,  1.0280,  0.4678],
          [ 1.2906,  0.7129, -0.2850,  ..., -0.0574,  0.2052,  1.1856],
          [ 0.8529,  0.8354,  0.9405,  ...,  0.0301,  0.1176,  1.0280]],

         [[-0.5147, -0.6367, -0.7936,  ...,  0.6531,  0.8099, -0.7761],
          [-0.5147, -0.

Data:  tensor([[[[ 0.0056,  0.0398,  0.0912,  ...,  1.8722,  1.9064,  1.8722],
          [ 0.6734,  0.5878,  0.5364,  ...,  1.8722,  1.9064,  1.8722],
          [ 1.4783,  1.3755,  1.3584,  ...,  1.8722,  1.9064,  1.8893],
          ...,
          [-0.6794, -0.6623, -0.6623,  ...,  0.0569, -0.1143, -0.0629],
          [-0.6965, -0.6965, -0.6965,  ...,  0.1254, -0.0972, -0.1143],
          [-0.7308, -0.7308, -0.7137,  ...,  0.1254, -0.0287, -0.1143]],

         [[-0.3375, -0.3550, -0.3025,  ...,  2.2010,  2.2360,  2.2010],
          [ 0.3452,  0.2577,  0.1877,  ...,  2.1835,  2.2185,  2.2010],
          [ 1.1856,  1.0980,  1.0630,  ...,  2.1660,  2.2010,  2.1835],
          ...,
          [-1.7556, -1.7381, -1.7206,  ..., -0.6001, -0.7752, -0.7402],
          [-1.7906, -1.8081, -1.7731,  ..., -0.5476, -0.7752, -0.7927],
          [-1.8081, -1.8431, -1.7906,  ..., -0.5476, -0.7052, -0.7927]],

         [[-0.1835, -0.1835, -0.1487,  ...,  2.5703,  2.5877,  2.5180],
          [ 0.4265,  0.

Data:  tensor([[[[-0.3027, -0.7822, -0.6623,  ..., -0.9363, -0.5596, -0.7137],
          [-0.3369, -0.7822, -1.0390,  ..., -0.7993, -0.4226, -0.5082],
          [-0.5082, -0.5767, -0.6794,  ..., -0.6965, -0.8335, -0.6794],
          ...,
          [-1.4158, -1.3473, -1.2788,  ..., -1.2617, -1.4843, -1.5357],
          [-1.3473, -1.2959, -1.4158,  ..., -0.8507, -1.1418, -1.3987],
          [-1.5185, -1.3302, -1.3473,  ..., -1.2617, -1.2103, -1.3473]],

         [[-0.6527, -1.1779, -1.0553,  ..., -1.1429, -0.7227, -1.0028],
          [-0.6001, -0.9853, -1.2829,  ..., -0.8803, -0.5301, -0.7402],
          [-0.7577, -0.7752, -0.9153,  ..., -0.8277, -1.0028, -0.8803],
          ...,
          [-1.4930, -1.4055, -1.3179,  ..., -1.2654, -1.4930, -1.5105],
          [-1.4055, -1.3880, -1.5280,  ..., -0.9153, -1.2129, -1.5105],
          [-1.5980, -1.3880, -1.4055,  ..., -1.3704, -1.3004, -1.4755]],

         [[-0.9853, -1.4210, -1.2467,  ..., -1.3513, -1.0201, -1.3687],
          [-0.8981, -1.

Data:  tensor([[[[ 1.5982,  1.5468,  1.0673,  ...,  1.5810,  1.6667,  1.8208],
          [ 1.5639,  1.5297,  0.9988,  ...,  1.8722,  1.6324,  1.8379],
          [ 1.5982,  1.4954,  1.4783,  ...,  1.5982,  1.6667,  1.9064],
          ...,
          [ 0.9132,  0.8961,  0.6734,  ..., -0.6109, -1.2274, -1.7583],
          [ 0.7591,  0.6734,  0.3823,  ..., -0.6452, -1.1932, -1.6384],
          [ 0.5536,  0.7762,  0.5878,  ..., -0.7822, -0.8335, -1.2274]],

         [[ 1.4132,  1.4132,  0.9230,  ...,  1.5182,  1.5182,  1.6933],
          [ 1.4482,  1.4657,  0.9055,  ...,  1.7983,  1.5007,  1.7633],
          [ 1.4832,  1.3782,  1.3606,  ...,  1.4832,  1.5182,  1.8508],
          ...,
          [ 0.7304,  0.7479,  0.5378,  ..., -0.7752, -1.3704, -1.7731],
          [ 0.5378,  0.4853,  0.1877,  ..., -0.8102, -1.2829, -1.6155],
          [ 0.3978,  0.6078,  0.4153,  ..., -0.9503, -0.9503, -1.2129]],

         [[ 1.0191,  1.0365,  0.5659,  ...,  0.6705,  0.7925,  1.0714],
          [ 1.1062,  1.

Data:  tensor([[[[ 1.1872,  1.1872,  1.1872,  ..., -1.6042, -1.5357, -1.5185],
          [ 1.1872,  1.1872,  1.1872,  ..., -1.5528, -1.5870, -1.5528],
          [ 1.1872,  1.1872,  1.1872,  ..., -1.5528, -1.5699, -1.5528],
          ...,
          [ 1.6153,  1.6495,  1.5639,  ...,  1.2557,  1.2557,  1.2557],
          [ 1.5982,  1.6838,  1.6667,  ...,  1.3070,  1.3070,  1.2043],
          [ 1.6838,  1.6495,  1.5810,  ...,  1.2557,  1.2043,  1.2728]],

         [[ 1.2556,  1.2556,  1.2556,  ..., -1.7906, -1.7381, -1.7381],
          [ 1.2556,  1.2556,  1.2556,  ..., -1.7381, -1.7906, -1.7556],
          [ 1.2556,  1.2556,  1.2556,  ..., -1.7381, -1.7556, -1.7381],
          ...,
          [ 1.5357,  1.5707,  1.4832,  ...,  1.0805,  1.0805,  1.0980],
          [ 1.5182,  1.6057,  1.5882,  ...,  1.1331,  1.1331,  1.0455],
          [ 1.6057,  1.5707,  1.5007,  ...,  1.0805,  1.0280,  1.1155]],

         [[ 1.3677,  1.3677,  1.3677,  ..., -1.5604, -1.5256, -1.5430],
          [ 1.3677,  1.

Data:  tensor([[[[-1.8610, -1.8439, -1.8268,  ..., -1.7583, -1.7583, -1.7069],
          [-1.8439, -1.8439, -1.8268,  ..., -1.7069, -1.7412, -1.6555],
          [-1.8439, -1.8439, -1.8268,  ..., -1.5357, -1.6384, -1.5870],
          ...,
          [ 0.6906,  0.7419,  0.7077,  ...,  0.7419,  0.5364,  0.6221],
          [ 0.6392,  0.5878,  0.5707,  ...,  0.6563,  0.7762,  0.6734],
          [ 0.2282,  0.2624,  0.4679,  ...,  0.5707,  0.7762,  0.6734]],

         [[-1.7731, -1.7381, -1.7031,  ..., -1.5980, -1.5630, -1.5455],
          [-1.7381, -1.7206, -1.7031,  ..., -1.5630, -1.5805, -1.5280],
          [-1.7206, -1.7206, -1.6856,  ..., -1.4230, -1.4930, -1.4755],
          ...,
          [ 0.7479,  0.7654,  0.6954,  ...,  0.7829,  0.5728,  0.6254],
          [ 0.6254,  0.5903,  0.5728,  ...,  0.6429,  0.8354,  0.6429],
          [ 0.3978,  0.3803,  0.5553,  ...,  0.5903,  0.7654,  0.6429]],

         [[-1.5430, -1.5081, -1.4907,  ..., -1.4210, -1.4036, -1.4036],
          [-1.5256, -1.

targets:  tensor([101,  98, 115, 128,  99,   6,  25, 111], device='cuda:0')
Predictions:  tensor([[4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4]], device='cuda:0')
Data:  tensor([[[[-1.6555, -1.7069, -1.7754,  ..., -0.8507, -1.1589, -1.2617],
          [-1.5699, -1.6213, -1.6898,  ..., -0.8507, -1.1932, -1.2788],
          [-1.5357, -1.5699, -1.6555,  ..., -0.9363, -1.2103, -1.3130],
          ...,
          [-1.0733, -1.1075, -1.0733,  ..., -1.6727, -1.6213, -1.6727],
          [-1.1589, -1.0733, -1.1760,  ..., -1.6042, -1.5185, -1.5014],
          [-1.0733, -1.2445, -1.2788,  ..., -1.5014, -1.4672, -1.4843]],

         [[-1.5105, -1.4755, -1.4930,  ..., -1.1779, -1.4580, -1.4580],
          [-1.3529, -1.3179, -1.3529,  ..., -1.1779, -1.4580, -1.4755],
          [-1.3004, -1.3004, -1.3529,  ..., -1.2829, -1.4580, -1.4930],
          ...,
          [-1.3179, -1.3529, -1.3354,  ..., -1.5630, -1.5280, -1.5455],
          [-1.3704, -1.3529, 

Data:  tensor([[[[ 2.0263,  2.0263,  2.0434,  ...,  1.0673,  1.2043,  1.3413],
          [ 2.0263,  2.0263,  2.0434,  ...,  1.2557,  1.3755,  1.4783],
          [ 2.0263,  2.0263,  2.0434,  ...,  1.4440,  1.5297,  1.5468],
          ...,
          [ 1.7009,  1.7009,  1.7009,  ...,  0.8789,  0.8447,  0.8104],
          [ 1.7009,  1.7009,  1.7009,  ...,  0.8447,  0.8618,  0.8276],
          [ 1.7009,  1.7009,  1.7009,  ...,  0.7419,  0.8447,  0.7933]],

         [[ 1.7108,  1.7108,  1.7458,  ..., -0.0049,  0.1352,  0.2402],
          [ 1.7108,  1.7108,  1.7458,  ...,  0.1702,  0.2752,  0.3627],
          [ 1.7108,  1.7108,  1.7458,  ...,  0.3803,  0.4328,  0.4328],
          ...,
          [ 1.3606,  1.3606,  1.3606,  ...,  1.1155,  1.0805,  1.0455],
          [ 1.3606,  1.3606,  1.3606,  ...,  1.0805,  1.0980,  1.0630],
          [ 1.3606,  1.3606,  1.3606,  ...,  0.9755,  1.0805,  1.0280]],

         [[ 1.2282,  1.2282,  1.2108,  ..., -0.6367, -0.4973, -0.3753],
          [ 1.2282,  1.

Data:  tensor([[[[ 0.6906,  0.5193,  0.3309,  ...,  0.4679,  0.4679,  0.4508],
          [ 0.3994,  0.0741, -0.2513,  ...,  0.3652,  0.3652,  0.3652],
          [-0.0287, -0.4226, -0.7650,  ...,  0.3823,  0.3823,  0.3823],
          ...,
          [-1.3815, -1.3987, -1.4158,  ...,  0.3481,  0.2967,  0.3823],
          [-1.2959, -1.4158, -1.4843,  ...,  0.3994,  0.1597,  0.2282],
          [-1.2617, -1.4329, -1.5699,  ...,  0.3481,  0.1254, -0.2513]],

         [[ 0.5553,  0.3978,  0.2227,  ..., -0.2850, -0.2850, -0.2850],
          [ 0.3102, -0.0049, -0.3025,  ..., -0.3200, -0.3200, -0.3200],
          [-0.0924, -0.4601, -0.7752,  ..., -0.2500, -0.2500, -0.2500],
          ...,
          [-1.2129, -1.1954, -1.0903,  ...,  0.6604,  0.6078,  0.6954],
          [-1.1779, -1.2654, -1.2479,  ...,  0.7129,  0.4678,  0.5378],
          [-1.1779, -1.3354, -1.3704,  ...,  0.6604,  0.4328,  0.0476]],

         [[ 0.4962,  0.2871,  0.0779,  ..., -0.6890, -0.6890, -0.6890],
          [ 0.2871, -0.

Data:  tensor([[[[ 0.4166,  0.3994,  0.3994,  ..., -0.2342, -0.1657, -0.0801],
          [ 0.3994,  0.3823,  0.4166,  ..., -0.1828, -0.0972, -0.0629],
          [ 0.3994,  0.3823,  0.4166,  ..., -0.1486, -0.0629, -0.0801],
          ...,
          [ 1.8550,  1.8722,  1.9235,  ...,  1.3242,  1.3413,  1.3584],
          [ 1.8550,  1.8722,  1.9407,  ...,  1.3242,  1.2899,  1.3242],
          [ 1.8208,  1.8550,  1.9064,  ...,  1.3242,  1.3584,  1.4098]],

         [[ 0.5028,  0.4853,  0.5203,  ...,  0.1352,  0.2052,  0.2577],
          [ 0.5203,  0.5028,  0.5553,  ...,  0.1702,  0.2402,  0.2402],
          [ 0.5203,  0.5028,  0.5378,  ...,  0.1702,  0.2577,  0.2227],
          ...,
          [ 1.6583,  1.7633,  1.8333,  ...,  1.0630,  1.0805,  1.0980],
          [ 1.7108,  1.8158,  1.8859,  ...,  1.0455,  0.9930,  1.0455],
          [ 1.8158,  1.8683,  1.9034,  ...,  0.9930,  1.0280,  1.0980]],

         [[ 0.1825,  0.1651,  0.1825,  ..., -0.1487, -0.0964, -0.0790],
          [ 0.1651,  0.

Data:  tensor([[[[ 0.2111,  0.3309,  0.7762,  ...,  1.1872,  0.9988,  1.1015],
          [ 0.0227,  0.1939,  0.6563,  ...,  1.4098,  1.3755,  1.6667],
          [ 0.3309,  0.4679,  0.7933,  ...,  1.3242,  1.3755,  1.8208],
          ...,
          [ 1.2728,  1.3070,  1.3413,  ...,  1.2214,  1.2385,  1.3755],
          [ 1.3755,  1.3755,  1.3413,  ...,  1.1872,  1.2214,  1.2557],
          [ 1.5639,  1.5297,  1.3755,  ...,  1.1700,  1.1015,  1.1358]],

         [[ 0.6078,  0.6779,  1.0980,  ...,  1.4832,  1.3256,  1.4307],
          [ 0.5203,  0.5903,  0.9580,  ...,  1.7283,  1.6933,  1.9909],
          [ 0.8529,  0.8880,  1.0630,  ...,  1.6583,  1.6933,  2.0959],
          ...,
          [ 1.2556,  1.3957,  1.4307,  ...,  0.8880,  0.9930,  1.1155],
          [ 1.3256,  1.4657,  1.5532,  ...,  0.7654,  0.8704,  0.9405],
          [ 1.2556,  1.3431,  1.3256,  ...,  0.8704,  0.7829,  0.7829]],

         [[ 1.4722,  1.5071,  1.8731,  ...,  1.6988,  1.6291,  1.8034],
          [ 1.6117,  1.

Data:  tensor([[[[ 0.2453,  0.3138,  0.3652,  ..., -0.0801, -0.0972, -0.0801],
          [ 0.1083,  0.1939,  0.2282,  ...,  0.0569, -0.0116, -0.0287],
          [ 0.0912,  0.1597,  0.2111,  ...,  0.2453,  0.1768,  0.1426],
          ...,
          [-1.1589, -1.0733, -0.9192,  ...,  0.1083,  0.1939, -0.2684],
          [-1.0390, -0.8678, -0.6794,  ...,  0.3481,  0.1597, -0.1999],
          [-1.0904, -1.0390, -0.9192,  ...,  0.0056, -0.1657, -0.1828]],

         [[ 0.6779,  0.7479,  0.8004,  ...,  0.3102,  0.2577,  0.2752],
          [ 0.5553,  0.6429,  0.6779,  ...,  0.4503,  0.3452,  0.3277],
          [ 0.4678,  0.5378,  0.5903,  ...,  0.6078,  0.5378,  0.5028],
          ...,
          [-0.8102, -0.7052, -0.5476,  ...,  0.5553,  0.6429,  0.1702],
          [-0.6702, -0.4951, -0.3025,  ...,  0.8004,  0.6078,  0.2402],
          [-0.7227, -0.6702, -0.5476,  ...,  0.4503,  0.2752,  0.2577]],

         [[-0.4973, -0.4275, -0.3753,  ..., -0.9504, -0.9504, -0.8981],
          [-0.6541, -0.

Data:  tensor([[[[ 2.2489,  2.2318,  2.2147,  ...,  0.5364,  0.5022,  0.5022],
          [ 2.2489,  2.2489,  2.2318,  ...,  0.5193,  0.5536,  0.6049],
          [ 2.0092,  2.1975,  2.2318,  ...,  0.3823,  0.3994,  0.4508],
          ...,
          [ 1.6153,  0.3138, -1.0733,  ..., -1.3644, -1.7754, -1.9124],
          [ 0.6049, -0.9877, -1.8953,  ..., -1.4329, -1.7754, -1.8782],
          [-0.4911, -1.5357, -1.7925,  ..., -1.3987, -1.6042, -1.6213]],

         [[ 2.4286,  2.4111,  2.3936,  ...,  0.2577,  0.2227,  0.1877],
          [ 2.4286,  2.4286,  2.4111,  ...,  0.2402,  0.2577,  0.2927],
          [ 2.2010,  2.3936,  2.4286,  ...,  0.1176,  0.1001,  0.1352],
          ...,
          [-1.3004, -1.5630, -1.7031,  ..., -1.7731, -1.6856, -1.7206],
          [-1.3179, -1.5805, -1.8431,  ..., -1.5455, -1.4580, -1.4755],
          [-1.4405, -1.4755, -1.7206,  ..., -1.2304, -1.1779, -1.3004]],

         [[ 2.6226,  2.6226,  2.6226,  ...,  0.0082,  0.0431,  0.1128],
          [ 2.5877,  2.

       device='cuda:0')
targets:  tensor([122,  14,  79,  47,  46, 108,  59,  62], device='cuda:0')
Predictions:  tensor([[4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4]], device='cuda:0')
Data:  tensor([[[[ 0.8618,  0.8961,  0.8789,  ...,  1.3070,  1.2557,  1.2728],
          [ 0.8961,  0.8961,  0.8618,  ...,  1.3584,  1.3413,  1.3242],
          [ 0.8961,  0.9132,  0.9303,  ...,  1.2899,  1.2899,  1.2899],
          ...,
          [ 0.5536,  0.4337,  0.3994,  ..., -0.8849, -0.9192, -0.9363],
          [ 0.5707,  0.4508,  0.4508,  ..., -0.8335, -0.8849, -0.8678],
          [ 0.6049,  0.5022,  0.5022,  ..., -0.7993, -0.8849, -0.8507]],

         [[ 0.9580,  0.9755,  0.9580,  ...,  1.3081,  1.2381,  1.2731],
          [ 0.9755,  0.9930,  0.9580,  ...,  1.3606,  1.3256,  1.3081],
          [ 0.9580,  0.9930,  1.0280,  ...,  1.3256,  1.2906,  1.2731],
          ...,
          [ 0.5728,  0.4853,  0.4678,  ..., -0.7577, -0.7927, -0.8102],
     

Data:  tensor([[[[ 1.4440,  1.4440,  1.4440,  ...,  0.0912,  0.1597,  0.1083],
          [ 1.4440,  1.4440,  1.4440,  ..., -0.2342,  0.0056,  0.0569],
          [ 1.4440,  1.4440,  1.4440,  ..., -0.6281, -0.3883, -0.3198],
          ...,
          [-0.5253, -0.5767, -0.5424,  ...,  1.6838,  1.6838,  1.5982],
          [-0.5767, -0.5938, -0.6623,  ...,  1.6153,  1.6153,  1.5810],
          [-0.6452, -0.5938, -0.7137,  ...,  1.6153,  1.6153,  1.6153]],

         [[ 1.3606,  1.3606,  1.3957,  ..., -0.2675, -0.1975, -0.2500],
          [ 1.3606,  1.3606,  1.3957,  ..., -0.6352, -0.3901, -0.3200],
          [ 1.3606,  1.3606,  1.3957,  ..., -1.0728, -0.8102, -0.7227],
          ...,
          [-0.6176, -0.6702, -0.6176,  ...,  1.6758,  1.6758,  1.6057],
          [-0.6702, -0.6877, -0.7402,  ...,  1.6057,  1.6057,  1.5882],
          [-0.7402, -0.6877, -0.7927,  ...,  1.6057,  1.6057,  1.6408]],

         [[ 0.8971,  0.8971,  0.9145,  ..., -1.1247, -1.0724, -1.1073],
          [ 0.8971,  0.

Data:  tensor([[[[-0.1999, -0.5767, -0.6965,  ..., -0.3712, -0.5424, -0.3369],
          [-0.1486, -0.4739, -0.6452,  ...,  0.0569, -0.6794, -0.9192],
          [-0.0801, -0.3883, -0.4911,  ...,  0.6049, -0.1999, -0.8849],
          ...,
          [-0.0458, -0.7650, -1.1418,  ..., -1.0390, -1.3815, -1.6727],
          [ 0.1426, -0.4226, -1.0733,  ..., -0.9534, -1.1247, -1.6042],
          [-0.0458, -0.0972, -0.3027,  ..., -0.9877, -1.0733, -1.5870]],

         [[-0.0924, -0.5126, -0.7752,  ...,  0.1527, -0.0049,  0.2227],
          [-0.0399, -0.4251, -0.7227,  ...,  0.5203, -0.2150, -0.4251],
          [ 0.0301, -0.3200, -0.5651,  ...,  0.9930,  0.2227, -0.4776],
          ...,
          [ 0.8354,  0.0826, -0.3200,  ..., -0.2150, -0.5476, -0.8452],
          [ 0.9930,  0.4153, -0.2500,  ..., -0.1275, -0.2850, -0.7752],
          [ 0.7479,  0.7304,  0.5553,  ..., -0.1625, -0.2325, -0.7577]],

         [[-0.2184, -0.6367, -0.9330,  ..., -0.8110, -0.9678, -0.7413],
          [-0.1661, -0.

Data:  tensor([[[[-0.1657, -0.4054, -0.6109,  ..., -1.2445, -1.2274, -1.3473],
          [-0.7308, -0.7479, -0.7993,  ..., -1.1418, -1.0904, -1.3130],
          [-1.0048, -1.0733, -1.1932,  ..., -0.7993, -0.9020, -1.1418],
          ...,
          [-0.5938, -0.5767, -0.4568,  ...,  0.1426,  0.1597,  0.1426],
          [-0.7650, -0.7650, -0.7137,  ..., -0.2171, -0.1999, -0.2513],
          [-0.6623, -0.6623, -0.6794,  ..., -0.4226, -0.4568, -0.4397]],

         [[-1.1078, -1.2829, -1.3880,  ..., -1.6506, -1.6331, -1.6331],
          [-1.3704, -1.3179, -1.2829,  ..., -1.7206, -1.6331, -1.7556],
          [-1.3179, -1.3704, -1.4230,  ..., -1.5805, -1.6331, -1.7556],
          ...,
          [-0.5126, -0.4951, -0.3725,  ...,  0.2577,  0.2752,  0.2752],
          [-0.6702, -0.6702, -0.6352,  ..., -0.1099, -0.0924, -0.1450],
          [-0.5126, -0.5126, -0.5301,  ..., -0.2850, -0.3025, -0.3025]],

         [[-0.7761, -0.9330, -1.0201,  ..., -1.3164, -1.4210, -1.4036],
          [-1.0201, -0.

Data:  tensor([[[[ 0.3652,  0.1939,  0.4166,  ..., -0.8164, -0.8507, -0.7650],
          [ 0.1254,  0.0398,  0.3138,  ..., -0.7822, -0.8335, -0.7993],
          [ 0.2111,  0.3309,  0.4508,  ..., -0.6965, -0.7822, -0.8335],
          ...,
          [-0.1657, -0.2513, -0.2856,  ...,  0.4679,  0.3652,  0.2111],
          [-0.2513, -0.3369, -0.5424,  ...,  0.3994,  0.2796,  0.1426],
          [-0.4226, -0.5596, -0.9020,  ...,  0.2624,  0.1768,  0.1083]],

         [[ 0.8004,  0.5903,  0.8004,  ..., -0.7402, -0.7052, -0.5476],
          [ 0.5903,  0.4678,  0.7129,  ..., -0.7227, -0.7052, -0.6001],
          [ 0.6779,  0.7829,  0.8880,  ..., -0.6527, -0.6527, -0.6176],
          ...,
          [ 0.5378,  0.4678,  0.3803,  ...,  0.5553,  0.5203,  0.4153],
          [ 0.3978,  0.3452,  0.1176,  ...,  0.5203,  0.4503,  0.3627],
          [ 0.2052,  0.0826, -0.2675,  ...,  0.3452,  0.3102,  0.3102]],

         [[ 0.3219,  0.0953,  0.2696,  ..., -0.6890, -0.7587, -0.7064],
          [ 0.0431, -0.

       device='cuda:0')
targets:  tensor([25, 78, 13, 13, 22, 97, 56, 91], device='cuda:0')
Predictions:  tensor([[4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4]], device='cuda:0')
Data:  tensor([[[[-0.0801, -0.0801,  0.1426,  ...,  0.0227,  0.0398,  0.0398],
          [-0.1143, -0.0972,  0.1939,  ...,  0.0569,  0.1083,  0.1254],
          [-0.2171, -0.3198,  0.1083,  ...,  0.0741,  0.0227,  0.0227],
          ...,
          [ 1.6153,  1.5125,  1.3070,  ...,  0.9303,  1.0159,  1.1700],
          [ 1.6324,  1.5810,  1.4954,  ...,  1.0502,  1.1015,  1.0673],
          [ 1.7180,  1.5468,  1.2214,  ...,  0.8789,  0.8961,  0.8447]],

         [[ 0.2227,  0.1527,  0.2927,  ...,  0.2052,  0.2052,  0.1877],
          [ 0.2227,  0.1702,  0.3627,  ...,  0.2052,  0.2577,  0.2227],
          [ 0.1877,  0.0476,  0.3277,  ...,  0.1352,  0.1702, -0.0399],
          ...,
          [ 1.1506,  1.0105,  0.7654,  ...,  1.2381,  1.2906,  1.3782],
          [ 1

Data:  tensor([[[[-1.4843, -1.5357, -1.3473,  ...,  1.1700,  0.5364,  0.4679],
          [-1.3130, -1.5185, -1.4843,  ...,  1.1358,  0.5536,  0.4851],
          [-1.2788, -1.5357, -1.4500,  ...,  1.1015,  0.5536,  0.3823],
          ...,
          [ 1.4954,  1.4098,  1.2043,  ...,  1.6667,  1.8037,  1.7694],
          [ 1.6153,  1.4440,  1.2557,  ...,  1.7865,  1.8208,  1.7352],
          [ 0.8789,  1.1187,  1.6153,  ...,  1.7352,  1.8037,  1.8037]],

         [[-1.4755, -1.5805, -1.5105,  ...,  1.1331,  0.0126, -0.4951],
          [-1.3004, -1.5630, -1.6506,  ...,  1.1331,  0.0476, -0.4776],
          [-1.3354, -1.5980, -1.5980,  ...,  1.0980,  0.1702, -0.3901],
          ...,
          [ 1.8333,  1.7458,  1.5357,  ...,  1.8508,  1.9734,  1.9384],
          [ 1.9384,  1.7633,  1.5707,  ...,  1.9384,  1.9559,  1.8683],
          [ 1.1856,  1.4307,  1.9209,  ...,  1.8508,  1.9209,  1.9034]],

         [[-1.4559, -1.5604, -1.4733,  ...,  1.2282, -0.4973, -1.5779],
          [-1.2990, -1.

Data:  tensor([[[[-0.5253, -0.5596, -0.5596,  ..., -0.3369, -0.3198, -0.3198],
          [-0.5253, -0.5424, -0.5596,  ..., -0.3198, -0.3198, -0.3027],
          [-0.5253, -0.5424, -0.5424,  ..., -0.3027, -0.3027, -0.3027],
          ...,
          [-1.4500, -1.4329, -1.4158,  ..., -0.3883, -0.3883, -0.3712],
          [-1.4500, -1.4672, -1.4672,  ..., -0.3369, -0.3369, -0.3883],
          [-1.4500, -1.4843, -1.4843,  ..., -0.2856, -0.3198, -0.3883]],

         [[-0.8277, -0.8277, -0.8277,  ..., -0.6527, -0.6352, -0.6352],
          [-0.8277, -0.8277, -0.8277,  ..., -0.6352, -0.6352, -0.6176],
          [-0.8102, -0.8102, -0.8102,  ..., -0.6176, -0.6176, -0.6176],
          ...,
          [-0.6176, -0.6702, -0.6877,  ..., -0.5651, -0.5651, -0.5476],
          [-0.6527, -0.6176, -0.6352,  ..., -0.5651, -0.5826, -0.6176],
          [-0.6702, -0.6001, -0.6001,  ..., -0.5476, -0.5826, -0.6702]],

         [[-1.1944, -1.2641, -1.2641,  ..., -1.1421, -1.1247, -1.1247],
          [-1.1944, -1.

Data:  tensor([[[[-0.1314, -0.0972, -0.0458,  ...,  0.1768,  0.2111,  0.2453],
          [-0.0972, -0.0801, -0.0458,  ...,  0.2624,  0.2967,  0.2967],
          [-0.0458, -0.0458, -0.0116,  ...,  0.2967,  0.2967,  0.3138],
          ...,
          [-0.0116,  0.0227,  0.0056,  ...,  0.0227,  0.0569,  0.0056],
          [ 0.0227,  0.0398,  0.0056,  ...,  0.0056, -0.0116, -0.0458],
          [ 0.0056, -0.0116, -0.0458,  ..., -0.0629, -0.1143, -0.1143]],

         [[ 0.0126,  0.0301,  0.0826,  ..., -0.2150, -0.1800, -0.1625],
          [-0.0224, -0.0224,  0.0301,  ..., -0.1625, -0.1275, -0.1275],
          [-0.0574, -0.0749, -0.0399,  ..., -0.1275, -0.1275, -0.1099],
          ...,
          [ 0.3803,  0.4153,  0.3978,  ..., -0.1450, -0.1099, -0.1450],
          [ 0.4153,  0.4328,  0.3978,  ..., -0.1625, -0.1450, -0.1800],
          [ 0.3978,  0.3803,  0.3452,  ..., -0.2150, -0.2500, -0.2500]],

         [[ 0.3742,  0.3916,  0.4091,  ..., -0.3404, -0.3055, -0.2707],
          [ 0.2348,  0.

Data:  tensor([[[[-1.2103, -1.2103, -1.1075,  ..., -1.0390, -1.1418, -1.1589],
          [-1.2617, -1.1932, -1.1932,  ..., -1.1075, -1.1075, -1.1075],
          [-1.1932, -1.1932, -1.1589,  ..., -1.0390, -1.0904, -1.1075],
          ...,
          [-1.8610, -1.8610, -1.8782,  ..., -1.5870, -1.6042, -1.6213],
          [-1.8268, -1.8610, -1.8610,  ..., -1.6213, -1.5870, -1.5870],
          [-1.8268, -1.8439, -1.8268,  ..., -1.5699, -1.5870, -1.5870]],

         [[-0.9328, -0.9853, -0.9153,  ..., -0.9153, -0.9678, -0.9328],
          [-0.9678, -0.9678, -0.9853,  ..., -0.9678, -0.9328, -0.8978],
          [-0.9328, -0.9853, -0.9678,  ..., -0.8978, -0.9153, -0.9153],
          ...,
          [-1.6856, -1.6856, -1.6856,  ..., -1.4930, -1.4580, -1.4405],
          [-1.6506, -1.6856, -1.6856,  ..., -1.5280, -1.4580, -1.4230],
          [-1.6856, -1.7031, -1.6856,  ..., -1.4930, -1.4580, -1.4230]],

         [[-0.6541, -0.6715, -0.5844,  ..., -0.5321, -0.6193, -0.6193],
          [-0.6715, -0.

       device='cuda:0')
targets:  tensor([106,  59, 107,   9,  13,  45, 125,  57], device='cuda:0')
Predictions:  tensor([[4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4]], device='cuda:0')
Data:  tensor([[[[ 1.4098,  1.4269,  1.3927,  ..., -0.4568, -0.5596, -0.6452],
          [ 1.3755,  1.3927,  1.3927,  ..., -0.4226, -0.5253, -0.6281],
          [ 1.4612,  1.4783,  1.4783,  ..., -0.4226, -0.5424, -0.6452],
          ...,
          [-0.7308, -0.7308, -0.6965,  ..., -1.6555, -1.6555, -1.6384],
          [-0.6965, -0.7137, -0.6794,  ..., -1.2959, -1.3130, -1.3302],
          [-0.6452, -0.6623, -0.6452,  ..., -0.8164, -0.6965, -0.5253]],

         [[ 1.6408,  1.6583,  1.6232,  ...,  0.1877,  0.1352,  0.1001],
          [ 1.6057,  1.6232,  1.6232,  ...,  0.1877,  0.1352,  0.0826],
          [ 1.6933,  1.7108,  1.7108,  ...,  0.1877,  0.1176,  0.0651],
          ...,
          [-0.2500, -0.2500, -0.2150,  ..., -1.2654, -1.2654, -1.3004],
     

Data:  tensor([[[[ 1.4440,  1.4098,  1.3927,  ...,  1.8550,  1.8379,  1.8208],
          [ 1.1872,  1.1529,  1.1187,  ...,  1.9064,  1.8550,  1.8037],
          [ 1.0673,  1.0502,  1.0502,  ...,  1.9064,  1.8379,  1.7865],
          ...,
          [ 0.0569,  0.0912,  0.2111,  ..., -0.5596, -0.4911, -0.4054],
          [-0.0287,  0.0227,  0.2111,  ..., -0.5424, -0.4397, -0.4226],
          [-0.1143, -0.0116,  0.0912,  ..., -0.5596, -0.4739, -0.4226]],

         [[ 1.8683,  1.8508,  1.7983,  ...,  1.8683,  1.8333,  1.8333],
          [ 1.5532,  1.5182,  1.4482,  ...,  1.8683,  1.8158,  1.8158],
          [ 1.3431,  1.3431,  1.3431,  ...,  1.8508,  1.8333,  1.8158],
          ...,
          [-0.1625, -0.1275,  0.0301,  ..., -0.8803, -0.8452, -0.8102],
          [-0.2150, -0.1275,  0.0651,  ..., -0.8978, -0.7927, -0.8102],
          [-0.3200, -0.2675, -0.1099,  ..., -0.9328, -0.8102, -0.7927]],

         [[ 2.0648,  2.0474,  2.0648,  ...,  1.9254,  1.8905,  1.8905],
          [ 1.6988,  1.

Data:  tensor([[[[ 0.0056,  0.1939,  0.1939,  ...,  0.4337,  0.2453,  0.7077],
          [ 0.1083,  0.1597,  0.6734,  ...,  0.5022,  0.2967,  0.1768],
          [ 0.3481,  0.4851,  0.5536,  ...,  0.7762,  0.7077,  0.3309],
          ...,
          [-1.2788, -1.1589, -1.0390,  ..., -0.5082, -0.4397, -0.4397],
          [-1.2617, -1.1589, -1.0904,  ..., -0.2171, -0.4226, -0.4226],
          [-1.3130, -1.0219, -1.0904,  ..., -0.0801, -0.4911, -0.7137]],

         [[ 0.4853,  0.7304,  0.7654,  ...,  1.0805,  1.0455,  1.4657],
          [ 0.6254,  0.7129,  1.1506,  ...,  1.0805,  0.9930,  0.9055],
          [ 0.9405,  1.1331,  1.1856,  ...,  1.3782,  1.3606,  1.1155],
          ...,
          [-1.1604, -1.0903, -0.9678,  ...,  0.0126,  0.0651, -0.1099],
          [-1.1429, -1.0728, -0.9853,  ...,  0.1702, -0.0749, -0.1275],
          [-1.2304, -1.0028, -1.0728,  ...,  0.2752, -0.2325, -0.3901]],

         [[-0.6367, -0.3404, -0.4450,  ..., -0.7761, -0.8633, -0.8110],
          [-0.4275, -0.

Data:  tensor([[[[ 2.1975,  2.2318,  2.2147,  ..., -0.7479, -0.7822, -0.7479],
          [ 2.1633,  2.1975,  2.1975,  ..., -0.7650, -0.7650, -0.6794],
          [ 2.1633,  2.1975,  2.2147,  ..., -0.6965, -0.8507, -0.8164],
          ...,
          [ 1.6838,  1.7180,  1.7523,  ...,  1.8550,  1.8208,  1.8379],
          [ 1.6667,  1.7694,  1.8037,  ...,  1.8893,  1.8379,  1.8550],
          [ 1.7180,  1.7523,  1.7694,  ...,  1.8722,  1.8893,  1.9064]],

         [[ 2.2360,  2.3060,  2.3410,  ..., -1.2829, -1.3179, -1.2829],
          [ 2.2535,  2.3060,  2.3410,  ..., -1.3354, -1.3880, -1.3529],
          [ 2.2535,  2.3235,  2.3410,  ..., -1.3529, -1.3704, -1.3529],
          ...,
          [ 1.9209,  1.9209,  1.9209,  ...,  2.0784,  2.0609,  2.0784],
          [ 1.9384,  1.9384,  1.9559,  ...,  2.0959,  2.0784,  2.0609],
          [ 1.9034,  1.8859,  1.9384,  ...,  2.0784,  2.0959,  2.1134]],

         [[ 2.3960,  2.4483,  2.5180,  ..., -1.3513, -1.3861, -1.3861],
          [ 2.3960,  2.

Data:  tensor([[[[-0.3369, -0.3369, -0.3369,  ..., -0.4911, -0.4568, -0.1999],
          [-0.3883, -0.3712, -0.4226,  ..., -1.3815, -1.2445, -0.9363],
          [-0.4568, -0.4226, -0.5082,  ..., -1.3815, -1.2617, -1.1932],
          ...,
          [-1.6042, -1.1760, -1.3644,  ..., -1.4500, -1.3644, -1.3130],
          [-1.4329, -0.8335, -0.9705,  ..., -1.4843, -1.3644, -1.2959],
          [-1.1075, -0.8164, -0.6452,  ..., -1.4672, -1.3473, -1.2959]],

         [[ 0.2227,  0.1702,  0.1001,  ..., -0.8102, -0.7927, -0.8277],
          [ 0.1527,  0.0476, -0.0574,  ..., -1.2829, -1.2129, -1.1429],
          [ 0.0301, -0.1625, -0.3025,  ..., -1.4405, -1.4580, -1.4580],
          ...,
          [-1.1954, -0.7227, -0.8102,  ..., -0.7577, -0.8102, -0.8452],
          [-0.9853, -0.3725, -0.3901,  ..., -0.8277, -0.8452, -0.8102],
          [-0.5476, -0.2500, -0.0224,  ..., -0.8627, -0.8277, -0.8277]],

         [[-0.2881, -0.5321, -0.8807,  ..., -0.5147, -0.4624, -0.2707],
          [-0.3055, -0.

Data:  tensor([[[[-0.9020, -0.9020, -0.8507,  ..., -1.0219, -1.0733, -1.0733],
          [-0.9877, -0.9363, -0.8849,  ..., -1.1589, -1.2445, -1.0904],
          [-0.9363, -0.9192, -0.8507,  ..., -1.1075, -1.0048, -0.9705],
          ...,
          [-0.8849, -0.7822, -0.8507,  ..., -1.0048, -1.1418, -1.4500],
          [-0.8678, -1.1589, -1.2274,  ..., -0.8849, -1.3987, -1.5185],
          [-1.1589, -1.2445, -1.1075,  ..., -1.0562, -1.6042, -1.3130]],

         [[-0.5476, -0.5651, -0.5126,  ..., -0.6527, -0.7052, -0.7052],
          [-0.6176, -0.5651, -0.5126,  ..., -0.7927, -0.8803, -0.7227],
          [-0.6176, -0.6001, -0.5476,  ..., -0.7402, -0.6352, -0.6001],
          ...,
          [-0.6176, -0.5126, -0.5826,  ..., -0.7402, -0.8803, -1.1954],
          [-0.6001, -0.8978, -0.9678,  ..., -0.6176, -1.1429, -1.2654],
          [-0.8978, -0.9853, -0.8277,  ..., -0.7927, -1.3529, -1.0553]],

         [[-1.2293, -1.2293, -1.1944,  ..., -1.4036, -1.4559, -1.4559],
          [-1.2816, -1.
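
The raw tensor dumps above are hard to scan. A compact per-batch summary makes the collapse to a single class easier to see. This is only a sketch: `summarize_batch` is a hypothetical helper, and the values are copied from one of the printed batches. With real tensors, something like `pred.view(-1).tolist()` and `target.tolist()` would presumably supply the two lists.

```python
from collections import Counter

def summarize_batch(predictions, targets):
    """Return (count of each predicted class, number of correct predictions)."""
    counts = Counter(predictions)
    correct = sum(p == t for p, t in zip(predictions, targets))
    return counts, correct

# Values copied from one batch in the printed output
targets = [101, 98, 115, 128, 99, 6, 25, 111]
predictions = [4] * 8  # every image in the batch was predicted as class 4

counts, correct = summarize_batch(predictions, targets)
print(f"distinct predicted classes: {len(counts)} -> {dict(counts)}")
print(f"correct: {correct}/{len(targets)}")
```

A single distinct predicted class per batch, repeated over many batches, suggests the output no longer depends on the input image.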

In [1]:
import os
import numpy as np
from torchvision import datasets, transforms
import torch
import torch.nn as nn
import torch.nn.functional as F

## Define transforms, Datasets, batch_size, and Dataloaders

## While I have used various combinations of augmentation, I eventually removed the RandomRotations and RandomResizedCrops
## after implementing batch normalization, which, according to the paper introducing the concept, reduces the need
## for these kinds of augmentation.
## There was no difference in training results whether using the additional augmentation measures or not.

train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                      transforms.Resize(250),
                                      transforms.RandomCrop(224),
                                      transforms.ToTensor(),
                                      transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                           std=[0.229, 0.224, 0.225])
                                     ])

test_transform = transforms.Compose([transforms.Resize(250),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], 
                                                          [0.229, 0.224, 0.225])
                                    ])


data_dir = 'dogImages/'
train_dir = os.path.join(data_dir, 'train/')
valid_dir = os.path.join(data_dir, 'valid/')
test_dir = os.path.join(data_dir, 'test/')

train_data = datasets.ImageFolder(train_dir, transform=train_transform)
valid_data = datasets.ImageFolder(valid_dir, transform=test_transform)
test_data = datasets.ImageFolder(test_dir, transform=test_transform)

print('Num Training imgs: ', len(train_data))
print('Num Validation imgs: ', len(valid_data))
print('Num Test imgs: ', len(test_data))

batch_size = 8
num_workers = 0

train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, num_workers=num_workers, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size, num_workers=num_workers, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, num_workers=num_workers, shuffle=True)


loaders_scratch = {'train': train_loader, 
                   'valid': valid_loader, 
                   'test': test_loader}

classes = train_data.classes

Num Training imgs:  6680
Num Validation imgs:  835
Num Test imgs:  836


In [None]:
print(classes)

Note: Should these photos really be normalized with these mean and standard deviation values if it distorts their color this way? With this normalization, pixel values go well beyond the [-1, 1] range that I'd think would be optimal.

Normalizing with mean = [0.5, 0.5, 0.5] and std = [0.5, 0.5, 0.5] results in images that look much more natural, with pixel values that actually fall in the range from -1 to 1.

I have tried both ways, and my network has failed to train either way, but just a thought...
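
To make the range question concrete, here is a small sketch in plain Python (no PyTorch needed) of what `Normalize` does to a pixel in [0, 1]: it computes `(x - mean) / std` per channel. The helper name `normalized_range` is made up for illustration.

```python
def normalized_range(mean, std):
    """Return (min, max) per channel after (x - mean) / std for x in [0, 1]."""
    return [((0.0 - m) / s, (1.0 - m) / s) for m, s in zip(mean, std)]

# The two normalization settings discussed in the note
imagenet = normalized_range([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
half = normalized_range([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

for name, ranges in (("ImageNet stats", imagenet), ("0.5 / 0.5", half)):
    for channel, (lo, hi) in zip("RGB", ranges):
        print(f"{name:15s} {channel}: [{lo:+.4f}, {hi:+.4f}]")
```

The ImageNet-style maxima come out near 2.25, 2.43 and 2.64, which matches the largest values visible in the printed tensors. So the wider range is what this normalization is expected to produce. It is a standard choice and, on its own, seems unlikely to be what stops the network from training. The odd colors presumably appear only because the normalized tensor is displayed without undoing the normalization first.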

## Architecture 1: 
* 5 Conv Layers: 32 to 512 channels - all with 3x3 kernels, stride=1 and padding=1
* Batch Normalization and ReLU activation on each Convolutional Layer
* MaxPool(2,2) after each Convolutional Layer
* 3 fully connected layers with ReLU applied to fc1 and fc2
* Dropout with probability of 50% applied after 2nd, 3rd, and 4th convolutional layer (before maxpool layers) and after 1st and 2nd fc layer

* CrossEntropyLoss (applied as log_softmax on the 3rd fc layer, with NLLLoss as the criterion)
* Adam optimizer with lr=0.0025
* Result: after 10 training epochs:

Epoch: 1 	Training Loss: 5.363417 	Validation Loss: 4.876212

Epoch: 2 	Training Loss: 4.888124 	Validation Loss: 4.870528

Epoch: 3 	Training Loss: 4.909029 	Validation Loss: 4.868946

Epoch: 4 	Training Loss: 4.878525 	Validation Loss: 4.867552

Epoch: 5 	Training Loss: 4.878072 	Validation Loss: 4.867843

Epoch: 6 	Training Loss: 4.872267 	Validation Loss: 4.868251

Epoch: 7 	Training Loss: 4.866718 	Validation Loss: 4.868318

Epoch: 8 	Training Loss: 4.867238 	Validation Loss: 4.867780

Epoch: 9 	Training Loss: 4.866567 	Validation Loss: 4.868664

Epoch: 10 	Training Loss: 4.866593 	Validation Loss: 4.868259
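
One way to read the plateau: a classifier that assigns equal probability to every class has a cross-entropy of log(number of classes). The sketch below assumes 133 classes, the size of the final layer in this notebook.

```python
import math

num_classes = 133  # size of the final fully connected layer in this notebook

# Cross-entropy of a uniform prediction: -log(1 / num_classes) = log(num_classes)
chance_loss = math.log(num_classes)
chance_accuracy = 1.0 / num_classes

print(f"chance-level loss:     {chance_loss:.4f}")
print(f"chance-level accuracy: {chance_accuracy:.2%}")
```

The loss values of roughly 4.87 sit right at this chance level of about 4.89, and accuracy below 1% is consistent with 1/133. That suggests the network's output has become essentially independent of the input image, rather than slowly learning.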

In [4]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.batchnorm1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.batchnorm2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.batchnorm3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
        self.batchnorm4 = nn.BatchNorm2d(256)
        self.conv5 = nn.Conv2d(256, 512, 3, padding=1)
        self.batchnorm5 = nn.BatchNorm2d(512)
        
        self.pool = nn.MaxPool2d(2, 2)
        
        self.fc1 = nn.Linear(512*7*7, 1000, bias=True)
        self.fc2 = nn.Linear(1000, 500, bias=True)
        self.fc3 = nn.Linear(500, 133, bias=True)
        self.dropout = nn.Dropout(p=0.5)
        
    def forward(self, x):
        x = F.relu(self.batchnorm1(self.conv1(x)))
        x = self.pool(x)
        x = F.relu(self.batchnorm2(self.conv2(x)))
        x = self.dropout(x)
        x = self.pool(x)
        x = F.relu(self.batchnorm3(self.conv3(x)))
        x = self.dropout(x)
        x = self.pool(x)
        x = F.relu(self.batchnorm4(self.conv4(x)))
        x = self.dropout(x)
        x = self.pool(x)
        x = F.relu(self.batchnorm5(self.conv5(x)))
        x = self.pool(x)
        
        x = x.view(-1, 512*7*7)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = F.log_softmax(self.fc3(x), dim=1)
        return x

#-#-# You do NOT have to modify the code below this line. #-#-#

# instantiate the CNN
model_scratch = Net()

# move tensors to GPU if CUDA is available
use_cuda = torch.cuda.is_available()
if use_cuda:
    model_scratch.cuda()
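
As a sanity check on the `512*7*7` used in `fc1` and in `x.view`, the spatial size can be traced through the five conv + pool stages with the usual output-size formulas. This is a plain-Python sketch; `conv_out` and `pool_out` are illustrative helpers, not PyTorch functions.

```python
def conv_out(size, kernel=3, stride=1, padding=1):
    """Spatial size after a convolution (standard floor formula)."""
    return (size + 2 * padding - kernel) // stride + 1

def pool_out(size, kernel=2, stride=2):
    """Spatial size after max pooling without padding."""
    return (size - kernel) // stride + 1

size = 224                      # side length after RandomCrop / CenterCrop
for _ in range(5):              # five conv + pool stages
    size = pool_out(conv_out(size))

flat = 512 * size * size        # number of inputs to fc1
fc1_params = flat * 1000 + 1000 # fc1 weights plus biases
print(size, flat, fc1_params)
```

The flattened size checks out, so the reshape is not the problem. The count also shows that `fc1` alone holds about 25 million parameters, which is a lot to fit with 6,680 training images.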

In [5]:
import torch.optim as optim

### TODO: select loss function
criterion_scratch = nn.NLLLoss()

### TODO: select optimizer
optimizer_scratch = optim.Adam(model_scratch.parameters(), lr=0.0025)

In [None]:
# the following import is required for training to be robust to truncated images
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


def train(n_epochs, loaders, model, optimizer, criterion, use_cuda, save_path):
    """returns trained model"""
    # initialize tracker for minimum validation loss
        valid_loss_min = np.inf  # np.Inf was removed in NumPy 2.0
    
    for epoch in range(1, n_epochs+1):

        # initialize variables to monitor training and validation loss
        train_loss = 0.0
        valid_loss = 0.0
        ###################
        # train the model #
        ###################
        model.train()
        for batch_idx, (data, target) in enumerate(loaders['train']):
            # move to GPU
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            ## find the loss and update the model parameters accordingly
            optimizer.zero_grad()
            output = model(data)  #output is tensor of shape([batch_size, num_classes]) where largest value is the prediction
            loss = criterion(output, target) # loss is the cross-entropy loss which measures how far the prediction is from the actual target
            loss.backward()  # calculating the gradients for all operations
            optimizer.step() #performing gradient descent step
            train_loss += loss.item()

    
        ######################    
        # validate the model #
        ######################
        model.eval()
        for batch_idx, (data, target) in enumerate(loaders['valid']):
            # move to GPU
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            with torch.no_grad():
                output = model(data)
                loss = criterion(output, target)
                valid_loss += loss.item()
            
        # average over the number of batches, using the loaders passed in rather than globals
        train_loss = train_loss / len(loaders['train'])
        valid_loss = valid_loss / len(loaders['valid'])
        
            
        # print training/validation statistics 
        print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
            epoch, 
            train_loss,
            valid_loss
            ))
        
        
        ## TODO: save the model if validation loss has decreased
        if valid_loss <= valid_loss_min:
            print('Validation Loss Decreased. Saving model')
            torch.save(model.state_dict(), save_path)
            valid_loss_min = valid_loss
    # return trained model
    return model

# train the model
model_scratch = train(10, loaders_scratch, model_scratch, optimizer_scratch, 
                      criterion_scratch, use_cuda, 'model_scratch.pt')

## Architecture 2:
* Same convolutional, batch norm, activation, and max-pool layers as above
* Only 2 fully connected layers
* Dropout at 0.3
* Loss function changed to the actual `nn.CrossEntropyLoss` (and the non-linearity removed from the last fully connected layer)
* Learning rate increased to 0.025
* Same training/validation loop as above
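
For reference, switching to `nn.CrossEntropyLoss` goes together with dropping the final non-linearity because that loss applies log-softmax internally and therefore expects raw scores from the last layer. The sketch below is plain Python rather than PyTorch, and the three scores and the target index are made-up values used only for illustration. It shows that cross-entropy on raw scores gives the same number as negative log-likelihood on log-softmax outputs.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of raw scores."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_sum for v in logits]

def nll(log_probs, target):
    """Negative log-likelihood of the target class."""
    return -log_probs[target]

def cross_entropy(logits, target):
    """Cross-entropy computed directly from raw scores."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum - logits[target]

# Hypothetical raw scores for a 3-class problem; the true class is index 2.
logits = [1.5, -0.3, 0.8]
target = 2

loss_a = nll(log_softmax(logits), target)   # log-softmax output + NLL
loss_b = cross_entropy(logits, target)      # raw scores + cross-entropy
print(loss_a, loss_b)                       # the two values agree
```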

### Results after 10 training epochs:
Epoch: 1 	Training Loss: 20.446863 	Validation Loss: 4.880278

Epoch: 2 	Training Loss: 4.894538 	Validation Loss: 4.889751

Epoch: 3 	Training Loss: 4.894520 	Validation Loss: 4.891614

Epoch: 4 	Training Loss: 4.894344 	Validation Loss: 4.887182

Epoch: 5 	Training Loss: 4.925039 	Validation Loss: 4.884497

Epoch: 6 	Training Loss: 4.895196 	Validation Loss: 4.884095

Epoch: 7 	Training Loss: 4.901975 	Validation Loss: 4.882802

Epoch: 8 	Training Loss: 4.895335 	Validation Loss: 4.886300

Epoch: 9 	Training Loss: 4.894879 	Validation Loss: 4.884101

Epoch: 10 	Training Loss: 4.894126 	Validation Loss: 4.884886


In [5]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.batchnorm1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.batchnorm2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.batchnorm3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
        self.batchnorm4 = nn.BatchNorm2d(256)
        self.conv5 = nn.Conv2d(256, 512, 3, padding=1)
        self.batchnorm5 = nn.BatchNorm2d(512)
        
        self.pool = nn.MaxPool2d(2, 2)
        
        self.fc1 = nn.Linear(512*7*7, 1000, bias=True)
        self.fc2 = nn.Linear(1000, 133, bias=True)
#         self.fc3 = nn.Linear(500, 133, bias=True)
        self.dropout = nn.Dropout(p=0.3)
        
    def forward(self, x):
        x = F.relu(self.batchnorm1(self.conv1(x)))
        x = self.pool(x)
        x = F.relu(self.batchnorm2(self.conv2(x)))
        x = self.dropout(x)
        x = self.pool(x)
        x = F.relu(self.batchnorm3(self.conv3(x)))
        x = self.dropout(x)
        x = self.pool(x)
        x = F.relu(self.batchnorm4(self.conv4(x)))
        x = self.dropout(x)
        x = self.pool(x)
        x = F.relu(self.batchnorm5(self.conv5(x)))
        x = self.pool(x)
        
        x = x.view(-1, 512*7*7)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)

        return x

#-#-# You do NOT have to modify the code below this line. #-#-#

# instantiate the CNN
model_scratch = Net()

# move tensors to GPU if CUDA is available
use_cuda = torch.cuda.is_available()
if use_cuda:
    model_scratch.cuda()

In [8]:
import torch.optim as optim

### TODO: select loss function
criterion_scratch = nn.CrossEntropyLoss()

### TODO: select optimizer
optimizer_scratch = optim.Adam(model_scratch.parameters(), lr=0.025)

In [9]:
# the following import is required for training to be robust to truncated images
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


def train(n_epochs, loaders, model, optimizer, criterion, use_cuda, save_path):
    """returns trained model"""
    # initialize tracker for minimum validation loss
    valid_loss_min = np.inf
    
    for epoch in range(1, n_epochs+1):

        # initialize variables to monitor training and validation loss
        train_loss = 0.0
        valid_loss = 0.0
        ###################
        # train the model #
        ###################
        model.train()
        for batch_idx, (data, target) in enumerate(loaders['train']):
            # move to GPU
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            ## find the loss and update the model parameters accordingly
            optimizer.zero_grad()
            output = model(data) #get predictions
            loss = criterion(output, target) # calculate loss
            loss.backward()  # calculate the gradients
            optimizer.step() # perform optimization step
            train_loss += loss.item()

    
        ######################    
        # validate the model #
        ######################
        model.eval()
        for batch_idx, (data, target) in enumerate(loaders['valid']):
            # move to GPU
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            with torch.no_grad():
                output = model(data)
                loss = criterion(output, target)
                valid_loss += loss.item()
            
        train_loss = train_loss / len(train_loader)
        valid_loss = valid_loss / len(valid_loader)
        
            
        # print training/validation statistics 
        print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
            epoch, 
            train_loss,
            valid_loss
            ))
        
        
        ## TODO: save the model if validation loss has decreased
        if valid_loss <= valid_loss_min:
            print('Validation Loss Decreased. Saving model')
            torch.save(model.state_dict(), save_path)
            valid_loss_min = valid_loss
    # return trained model
    return model

# train the model
model_scratch = train(10, loaders_scratch, model_scratch, optimizer_scratch, 
                      criterion_scratch, use_cuda, 'model_scratch.pt')

Epoch: 1 	Training Loss: 20.446863 	Validation Loss: 4.880278
Validation Loss Decreased. Saving model
Epoch: 2 	Training Loss: 4.894538 	Validation Loss: 4.889751
Epoch: 3 	Training Loss: 4.894520 	Validation Loss: 4.891614
Epoch: 4 	Training Loss: 4.894344 	Validation Loss: 4.887182
Epoch: 5 	Training Loss: 4.925039 	Validation Loss: 4.884497
Epoch: 6 	Training Loss: 4.895196 	Validation Loss: 4.884095
Epoch: 7 	Training Loss: 4.901975 	Validation Loss: 4.882802
Epoch: 8 	Training Loss: 4.895335 	Validation Loss: 4.886300
Epoch: 9 	Training Loss: 4.894879 	Validation Loss: 4.884101
Epoch: 10 	Training Loss: 4.894126 	Validation Loss: 4.884886


I consistently get one epoch where the training loss shows obvious progress, followed by nothing but minor shifts back and forth. That seems to indicate a problem with my training loop, as though the gradients get turned off in the validation loop and are never turned back on. I will run one more version changing only my architecture, then run each of these three models again with a change to my training code (one that seems like it should be unnecessary) to ensure that the gradients are turned back on during each training loop.
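
The idea that gradients stay off after validation can be checked directly. `torch.no_grad()` is used as a context manager, and a context manager is expected to restore the previous state when its `with` block exits. The toy example below is plain Python, not PyTorch's actual implementation: `grad_enabled` and `no_grad_toy` are hypothetical names standing in for PyTorch's gradient mode and for `torch.no_grad()`. In the real notebook, printing `torch.is_grad_enabled()` right after the validation loop would be the direct check.

```python
from contextlib import contextmanager

# Hypothetical global flag standing in for PyTorch's gradient mode.
grad_enabled = True

@contextmanager
def no_grad_toy():
    """Turn the flag off inside the block and restore it on exit."""
    global grad_enabled
    previous = grad_enabled
    grad_enabled = False
    try:
        yield
    finally:
        grad_enabled = previous  # runs even if the block raises

history = []
for epoch in range(2):
    history.append(("train", grad_enabled))      # training phase
    with no_grad_toy():
        history.append(("valid", grad_enabled))  # validation phase

print(history)
```

If the real flag behaves the same way, the flat training loss would have to come from somewhere other than gradients being left off.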

## Architecture 3:
* Batch normalization now applied after ReLU on each convolutional layer
* All else remains the same as above

### Result after 10 training epochs:
Epoch: 1 	Training Loss: 22.869417 	Validation Loss: 4.881092

Epoch: 2 	Training Loss: 4.894224 	Validation Loss: 4.884323

Epoch: 3 	Training Loss: 4.912950 	Validation Loss: 5.126093

Epoch: 4 	Training Loss: 4.897206 	Validation Loss: 4.928959

Epoch: 5 	Training Loss: 4.895872 	Validation Loss: 5.007126

Epoch: 6 	Training Loss: 4.892983 	Validation Loss: 5.116531

Epoch: 7 	Training Loss: 4.893521 	Validation Loss: 5.319191

Epoch: 8 	Training Loss: 4.897356 	Validation Loss: 4.882479

Epoch: 9 	Training Loss: 4.894560 	Validation Loss: 4.885508

Epoch: 10 	Training Loss: 4.894811 	Validation Loss: 4.889231


In [3]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.batchnorm1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.batchnorm2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.batchnorm3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
        self.batchnorm4 = nn.BatchNorm2d(256)
        self.conv5 = nn.Conv2d(256, 512, 3, padding=1)
        self.batchnorm5 = nn.BatchNorm2d(512)
        
        self.pool = nn.MaxPool2d(2, 2)
        
        self.fc1 = nn.Linear(512*7*7, 1000, bias=True)
        self.fc2 = nn.Linear(1000, 133, bias=True)
#         self.fc3 = nn.Linear(500, 133, bias=True)
        self.dropout = nn.Dropout(p=0.3)
        
    def forward(self, x):
        x = self.batchnorm1(F.relu(self.conv1(x)))
        x = self.pool(x)
        x = self.batchnorm2(F.relu(self.conv2(x)))
        x = self.dropout(x)
        x = self.pool(x)
        x = self.batchnorm3(F.relu(self.conv3(x)))
        x = self.dropout(x)
        x = self.pool(x)
        x = self.batchnorm4(F.relu(self.conv4(x)))
        x = self.dropout(x)
        x = self.pool(x)
        x = self.batchnorm5(F.relu(self.conv5(x)))
        x = self.pool(x)
        
        x = x.view(-1, 512*7*7)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)

        return x

#-#-# You do NOT have to modify the code below this line. #-#-#

# instantiate the CNN
model_scratch = Net()

# move tensors to GPU if CUDA is available
use_cuda = torch.cuda.is_available()
if use_cuda:
    model_scratch.cuda()

In [4]:
import torch.optim as optim

### TODO: select loss function
criterion_scratch = nn.CrossEntropyLoss()

### TODO: select optimizer
optimizer_scratch = optim.Adam(model_scratch.parameters(), lr=0.025)

In [5]:
# the following import is required for training to be robust to truncated images
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


def train(n_epochs, loaders, model, optimizer, criterion, use_cuda, save_path):
    """returns trained model"""
    # initialize tracker for minimum validation loss
    valid_loss_min = np.inf
    
    for epoch in range(1, n_epochs+1):

        # initialize variables to monitor training and validation loss
        train_loss = 0.0
        valid_loss = 0.0
        ###################
        # train the model #
        ###################
        model.train()
        for batch_idx, (data, target) in enumerate(loaders['train']):
            # move to GPU
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            ## find the loss and update the model parameters accordingly
            optimizer.zero_grad()
            output = model(data) #get predictions
            loss = criterion(output, target) # calculate loss
            loss.backward()  # calculate the gradients
            optimizer.step() # perform optimization step
            train_loss += loss.item()

    
        ######################    
        # validate the model #
        ######################
        model.eval()
        for batch_idx, (data, target) in enumerate(loaders['valid']):
            # move to GPU
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            with torch.no_grad():
                output = model(data)
                loss = criterion(output, target)
                valid_loss += loss.item()
            
        train_loss = train_loss / len(train_loader)
        valid_loss = valid_loss / len(valid_loader)
        
            
        # print training/validation statistics 
        print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
            epoch, 
            train_loss,
            valid_loss
            ))
        
        
        ## TODO: save the model if validation loss has decreased
        if valid_loss <= valid_loss_min:
            print('Validation Loss Decreased. Saving model')
            torch.save(model.state_dict(), save_path)
            valid_loss_min = valid_loss
    # return trained model
    return model

# train the model
model_scratch = train(10, loaders_scratch, model_scratch, optimizer_scratch, 
                      criterion_scratch, use_cuda, 'model_scratch.pt')

Epoch: 1 	Training Loss: 22.869417 	Validation Loss: 4.881092
Validation Loss Decreased. Saving model
Epoch: 2 	Training Loss: 4.894224 	Validation Loss: 4.884323
Epoch: 3 	Training Loss: 4.912950 	Validation Loss: 5.126093
Epoch: 4 	Training Loss: 4.897206 	Validation Loss: 4.928959
Epoch: 5 	Training Loss: 4.895872 	Validation Loss: 5.007126
Epoch: 6 	Training Loss: 4.892983 	Validation Loss: 5.116531
Epoch: 7 	Training Loss: 4.893521 	Validation Loss: 5.319191
Epoch: 8 	Training Loss: 4.897356 	Validation Loss: 4.882479
Epoch: 9 	Training Loss: 4.894560 	Validation Loss: 4.885508
Epoch: 10 	Training Loss: 4.894811 	Validation Loss: 4.889231


## Architecture 3 with modified training loop:
* Training loop modified to place all training steps inside a `with torch.enable_grad():` block
* Learning rate decreased to 0.0025 to more closely approximate what has been reported as optimal
* Accessing `train_loader` and `valid_loader` directly, rather than through the `loaders` dict
* Iterating through `train_loader` and `valid_loader` directly, rather than enumerating them (I don't use the batch index for anything anyway)

### Results after 10 training epochs:
Epoch: 1 	Training Loss: 6.145714 	Validation Loss: 317.693266

Epoch: 2 	Training Loss: 4.881616 	Validation Loss: 149.638213

Epoch: 3 	Training Loss: 4.881738 	Validation Loss: 76.856270

Epoch: 4 	Training Loss: 4.868647 	Validation Loss: 10.119001

Epoch: 5 	Training Loss: 4.870256 	Validation Loss: 175.708004

Epoch: 6 	Training Loss: 4.885781 	Validation Loss: 5.842024

Epoch: 7 	Training Loss: 4.867848 	Validation Loss: 5.502650

Epoch: 8 	Training Loss: 4.866408 	Validation Loss: 5.315064

Epoch: 9 	Training Loss: 4.866485 	Validation Loss: 5.140727

Epoch: 10 	Training Loss: 4.866341 	Validation Loss: 5.599622


It's at least showing something different from the versions before it, but only in the validation loop.
The training loss looks the same as before: one good step, followed by bouncing around in the 4.8 range.
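
The 4.8 plateau lines up with a simple reference value. The final layer has 133 outputs, and a classifier that assigns equal probability to all of them has a cross-entropy of ln(133), about 4.89, regardless of the input. The short calculation below is plain Python and assumes the targets are spread over all 133 classes. It compares that value with the training losses reported above.

```python
import math

num_classes = 133  # size of the last fully connected layer

# Cross-entropy of a uniform prediction: -log(1 / num_classes)
uniform_loss = -math.log(1.0 / num_classes)
print(f"loss of a uniform guess: {uniform_loss:.4f}")

# Training losses copied from epochs 2-4 of the run above
reported = [4.881616, 4.881738, 4.868647]
gaps = [round(uniform_loss - value, 4) for value in reported]
print("distance below the uniform-guess loss:", gaps)
```

A training loss that sits within a few hundredths of this value is consistent with a network whose output barely depends on its input.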

The validation loss showed progress for a few epochs before exploding in epoch 5. It dropped back to the 5.x level in epoch 6, after which it just bounced around in the 5.x range without progress.

I will keep this checkpoint saved in the model_scratch.pt file and will rename the checkpoint for the next attempts.

Testing shows accuracy of 1%. Predictions seem to start as random, then later settle as class 4 for every image.

In [3]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.batchnorm1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.batchnorm2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.batchnorm3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
        self.batchnorm4 = nn.BatchNorm2d(256)
        self.conv5 = nn.Conv2d(256, 512, 3, padding=1)
        self.batchnorm5 = nn.BatchNorm2d(512)
        
        self.pool = nn.MaxPool2d(2, 2)
        
        self.fc1 = nn.Linear(512*7*7, 1000, bias=True)
        self.fc2 = nn.Linear(1000, 133, bias=True)
#         self.fc3 = nn.Linear(500, 133, bias=True)
        self.dropout = nn.Dropout(p=0.3)
        
    def forward(self, x):
        x = self.batchnorm1(F.relu(self.conv1(x)))
        x = self.pool(x)
        x = self.batchnorm2(F.relu(self.conv2(x)))
        x = self.dropout(x)
        x = self.pool(x)
        x = self.batchnorm3(F.relu(self.conv3(x)))
        x = self.dropout(x)
        x = self.pool(x)
        x = self.batchnorm4(F.relu(self.conv4(x)))
        x = self.dropout(x)
        x = self.pool(x)
        x = self.batchnorm5(F.relu(self.conv5(x)))
        x = self.pool(x)
        
        x = x.view(-1, 512*7*7)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)

        return x

#-#-# You do NOT have to modify the code below this line. #-#-#

# instantiate the CNN
model_scratch = Net()

# move tensors to GPU if CUDA is available
use_cuda = torch.cuda.is_available()
if use_cuda:
    model_scratch.cuda()

In [3]:
import torch.optim as optim

### TODO: select loss function
criterion_scratch = nn.CrossEntropyLoss()

### TODO: select optimizer
optimizer_scratch = optim.Adam(model_scratch.parameters(), lr=0.025)

In [4]:
# the following import is required for training to be robust to truncated images
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


def train(n_epochs, loaders, model, optimizer, criterion, use_cuda, save_path):
    """returns trained model"""
    # initialize tracker for minimum validation loss
    valid_loss_min = np.inf
    
    for epoch in range(1, n_epochs+1):

        # initialize variables to monitor training and validation loss
        train_loss = 0.0
        valid_loss = 0.0
        ###################
        # train the model #
        ###################
        model.train()
        for data, target in train_loader:
            # move to GPU
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            ## find the loss and update the model parameters accordingly
            with torch.enable_grad():
                optimizer.zero_grad()
                output = model(data) #get predictions
                loss = criterion(output, target) # calculate loss
                loss.backward()  # calculate the gradients
                optimizer.step() # perform optimization step
                train_loss += loss.item()

    
        ######################    
        # validate the model #
        ######################
        model.eval()
        for data, target in valid_loader:
            # move to GPU
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            with torch.no_grad():
                output = model(data)
                loss = criterion(output, target)
                valid_loss += loss.item()
            
        train_loss = train_loss / len(train_loader)
        valid_loss = valid_loss / len(valid_loader)
        
            
        # print training/validation statistics 
        print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
            epoch, 
            train_loss,
            valid_loss
            ))
        
        
        ## TODO: save the model if validation loss has decreased
        if valid_loss <= valid_loss_min:
            print('Validation Loss Decreased. Saving model')
            torch.save(model.state_dict(), save_path)
            valid_loss_min = valid_loss
    # return trained model
    return model

# train the model
model_scratch = train(30, loaders_scratch, model_scratch, optimizer_scratch, 
                      criterion_scratch, use_cuda, 'model_scratch_3.pt')

Epoch: 1 	Training Loss: 21.230738 	Validation Loss: 95.244353
Validation Loss Decreased. Saving model
Epoch: 2 	Training Loss: 4.893078 	Validation Loss: 24.413463
Validation Loss Decreased. Saving model
Epoch: 3 	Training Loss: 4.896302 	Validation Loss: 1088.203040
Epoch: 4 	Training Loss: 4.894740 	Validation Loss: 33.911446
Epoch: 5 	Training Loss: 4.895293 	Validation Loss: 1834.588041
Epoch: 6 	Training Loss: 4.897745 	Validation Loss: 1198.642593
Epoch: 7 	Training Loss: 4.896015 	Validation Loss: 1604.381039
Epoch: 8 	Training Loss: 4.894294 	Validation Loss: 3534.684728
Epoch: 9 	Training Loss: 4.893817 	Validation Loss: 2232.242461
Epoch: 10 	Training Loss: 4.893525 	Validation Loss: 2071.671996
Epoch: 11 	Training Loss: 4.894349 	Validation Loss: 1978.222531
Epoch: 12 	Training Loss: 4.896489 	Validation Loss: 1439.038064
Epoch: 13 	Training Loss: 4.894585 	Validation Loss: 2151.704347
Epoch: 14 	Training Loss: 4.895108 	Validation Loss: 152.453516
Epoch: 15 	Training Loss:

## Architecture 1 again, with modified training loop:
* Same training/validation code as used with Architecture 1, with a `with torch.enable_grad():` line added to wrap the training steps
* Also changed the average loss calculation before printing from `train_loss/len(train_loader)` to `train_loss/len(loaders['train'])`, which I believe should make no difference

### Results after 10 epochs
Epoch: 1 	Training Loss: 5.625305 	Validation Loss: 4.871825

Epoch: 2 	Training Loss: 4.876564 	Validation Loss: 4.869084

Epoch: 3 	Training Loss: 4.872355 	Validation Loss: 4.871287

Epoch: 4 	Training Loss: 4.870470 	Validation Loss: 4.868716

Epoch: 5 	Training Loss: 4.869439 	Validation Loss: 4.870064

Epoch: 6 	Training Loss: 4.867907 	Validation Loss: 4.868820

Epoch: 7 	Training Loss: 4.867843 	Validation Loss: 4.867972

Epoch: 8 	Training Loss: 4.867246 	Validation Loss: 4.868295

Epoch: 9 	Training Loss: 4.866786 	Validation Loss: 4.868938

Epoch: 10 	Training Loss: 4.878639 	Validation Loss: 4.867290

No better than any of the others: one seemingly good step, followed by bouncing around in the 4.8 range as usual.
Testing shows an accuracy of 1%, with every single image receiving a prediction of class 4.
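
To put a number on the collapse, the predictions can be tallied per batch. The sketch below is plain Python and uses the one batch of targets and predictions that appears in the test output printed further down in this notebook. `summarize_batch` is a hypothetical helper, not part of the notebook.

```python
from collections import Counter

def summarize_batch(preds, targets):
    """Return (most common predicted class, its count, number correct)."""
    top_class, top_count = Counter(preds).most_common(1)[0]
    correct = sum(p == t for p, t in zip(preds, targets))
    return top_class, top_count, correct

# One batch copied from the printed test output
targets = [122, 91, 71, 37, 5, 110, 58, 18]
preds = [4, 4, 4, 56, 4, 56, 56, 47]

top_class, top_count, correct = summarize_batch(preds, targets)
print(f"class {top_class} predicted {top_count}/{len(preds)} times, "
      f"{correct} correct")

# Accuracy expected from always guessing one class, if classes were balanced
baseline = 100 / 133
print(f"single-class baseline: {baseline:.2f}%")
```

Under the balanced-class assumption, the baseline is just under 1%, which is in line with the test accuracy reported here.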

In [5]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.batchnorm1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.batchnorm2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.batchnorm3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
        self.batchnorm4 = nn.BatchNorm2d(256)
        self.conv5 = nn.Conv2d(256, 512, 3, padding=1)
        self.batchnorm5 = nn.BatchNorm2d(512)
        
        self.pool = nn.MaxPool2d(2, 2)
        
        self.fc1 = nn.Linear(512*7*7, 1000, bias=True)
        self.fc2 = nn.Linear(1000, 500, bias=True)
        self.fc3 = nn.Linear(500, 133, bias=True)
        self.dropout = nn.Dropout(p=0.5)
        
    def forward(self, x):
        x = F.relu(self.batchnorm1(self.conv1(x)))
        x = self.pool(x)
        x = F.relu(self.batchnorm2(self.conv2(x)))
        x = self.dropout(x)
        x = self.pool(x)
        x = F.relu(self.batchnorm3(self.conv3(x)))
        x = self.dropout(x)
        x = self.pool(x)
        x = F.relu(self.batchnorm4(self.conv4(x)))
        x = self.dropout(x)
        x = self.pool(x)
        x = F.relu(self.batchnorm5(self.conv5(x)))
        x = self.pool(x)
        
        x = x.view(-1, 512*7*7)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = F.log_softmax(self.fc3(x), dim=1)
        return x

#-#-# You do NOT have to modify the code below this line. #-#-#

# instantiate the CNN
model_scratch = Net()

# move tensors to GPU if CUDA is available
use_cuda = torch.cuda.is_available()
if use_cuda:
    model_scratch.cuda()

In [6]:
import torch.optim as optim

### TODO: select loss function
criterion_scratch = nn.NLLLoss()

### TODO: select optimizer
optimizer_scratch = optim.Adam(model_scratch.parameters(), lr=0.0025)

In [7]:
# the following import is required for training to be robust to truncated images
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


def train(n_epochs, loaders, model, optimizer, criterion, use_cuda, save_path):
    """returns trained model"""
    # initialize tracker for minimum validation loss
    valid_loss_min = np.inf
    
    for epoch in range(1, n_epochs+1):

        # initialize variables to monitor training and validation loss
        train_loss = 0.0
        valid_loss = 0.0
        ###################
        # train the model #
        ###################
        model.train()
        for batch_idx, (data, target) in enumerate(loaders['train']):
            # move to GPU
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            ## find the loss and update the model parameters accordingly
            with torch.enable_grad():
                optimizer.zero_grad()
                output = model(data)  #output is tensor of shape([batch_size, num_classes]) where largest value is the prediction
                loss = criterion(output, target) # loss is the cross-entropy loss which measures how far the prediction is from the actual target
                loss.backward()  # calculating the gradients for all operations
                optimizer.step() #performing gradient descent step
                train_loss += loss.item()

    
        ######################    
        # validate the model #
        ######################
        model.eval()
        for batch_idx, (data, target) in enumerate(loaders['valid']):
            # move to GPU
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            with torch.no_grad():
                output = model(data)
                loss = criterion(output, target)
                valid_loss += loss.item()
            
        train_loss = train_loss / len(loaders['train'])
        valid_loss = valid_loss / len(loaders['valid'])
        
            
        # print training/validation statistics 
        print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
            epoch, 
            train_loss,
            valid_loss
            ))
        
        
        ## TODO: save the model if validation loss has decreased
        if valid_loss <= valid_loss_min:
            print('Validation Loss Decreased. Saving model')
            torch.save(model.state_dict(), save_path)
            valid_loss_min = valid_loss
    # return trained model
    return model

# train the model
model_scratch = train(10, loaders_scratch, model_scratch, optimizer_scratch, 
                      criterion_scratch, use_cuda, 'model_scratch_2.pt')

Epoch: 1 	Training Loss: 5.625305 	Validation Loss: 4.871825
Validation Loss Decreased. Saving model
Epoch: 2 	Training Loss: 4.876564 	Validation Loss: 4.869084
Validation Loss Decreased. Saving model
Epoch: 3 	Training Loss: 4.872355 	Validation Loss: 4.871287
Epoch: 4 	Training Loss: 4.870470 	Validation Loss: 4.868716
Validation Loss Decreased. Saving model
Epoch: 5 	Training Loss: 4.869439 	Validation Loss: 4.870064
Epoch: 6 	Training Loss: 4.867907 	Validation Loss: 4.868820
Epoch: 7 	Training Loss: 4.867843 	Validation Loss: 4.867972
Validation Loss Decreased. Saving model
Epoch: 8 	Training Loss: 4.867246 	Validation Loss: 4.868295
Epoch: 9 	Training Loss: 4.866786 	Validation Loss: 4.868938
Epoch: 10 	Training Loss: 4.878639 	Validation Loss: 4.867290
Validation Loss Decreased. Saving model


## Test function

In [6]:
def test(loaders, model, criterion, use_cuda):

    # monitor test loss and accuracy
    test_loss = 0.0
    correct = 0.0
    total = 0.0
    model.eval()
    
    for batch_idx, (data, target) in enumerate(loaders['test']):
        # move to GPU
        if use_cuda:
            data, target = data.cuda(), target.cuda()
        with torch.no_grad():
            # forward pass: compute predicted outputs by passing inputs to the model
            output = model(data)
            print('Data: ', data)
            print('target: ', target)
            # calculate the loss
            loss = criterion(output, target)
            # update average test loss 
            test_loss = test_loss + ((1 / (batch_idx + 1)) * (loss.data - test_loss))
            # convert output probabilities to predicted class
            pred = output.data.max(1, keepdim=True)[1]
            print(pred)
            # compare predictions to true label
            correct += np.sum(np.squeeze(pred.eq(target.data.view_as(pred))).cpu().numpy())
            total += data.size(0)
            
    print('Test Loss: {:.6f}\n'.format(test_loss))

    print('\nTest Accuracy: %2d%% (%2d/%2d)' % (
        100. * correct / total, correct, total))

# call test function    
test(loaders_scratch, model_scratch, criterion_scratch, use_cuda)
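
The `test_loss` update inside `test` above is an incremental mean: after each batch it moves the running value toward the new loss by 1/(batch_idx + 1) of the difference. A small plain-Python check, with made-up per-batch losses used only for illustration, shows that it ends at the ordinary average of the batch losses.

```python
# Hypothetical per-batch losses, chosen only for illustration
batch_losses = [5.0, 4.0, 6.0, 5.0]

running = 0.0
for batch_idx, loss in enumerate(batch_losses):
    # Same update rule as in test(): shift the mean toward the new value
    running = running + (1 / (batch_idx + 1)) * (loss - running)

plain_mean = sum(batch_losses) / len(batch_losses)
print(running, plain_mean)
```

Because each batch gets equal weight, a smaller final batch counts as much as a full one, so the result can differ slightly from a per-image average.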

Data:  tensor([[[[-1.4672, -1.4329, -1.3644,  ..., -1.4843, -1.5014, -1.5185],
          [-1.5014, -1.4329, -1.3130,  ..., -1.4672, -1.4843, -1.4843],
          [-1.5185, -1.4500, -1.3302,  ..., -1.4329, -1.4500, -1.4843],
          ...,
          [ 2.0777,  2.0777,  2.0605,  ..., -1.3815, -1.3644, -1.3473],
          [ 2.0777,  2.0605,  2.0605,  ..., -1.3473, -1.3815, -1.3815],
          [ 2.0777,  2.0434,  2.0434,  ..., -1.3815, -1.4158, -1.3987]],

         [[-0.9853, -0.9503, -0.8978,  ..., -0.9678, -0.9853, -1.0028],
          [-1.0028, -0.9328, -0.8452,  ..., -0.9853, -0.9678, -0.9678],
          [-0.9853, -0.9153, -0.8277,  ..., -0.9678, -0.9853, -0.9853],
          ...,
          [ 1.4657,  1.4657,  1.4482,  ..., -1.0378, -1.0203, -1.0378],
          [ 1.4657,  1.4307,  1.4307,  ..., -0.9853, -1.0028, -1.0203],
          [ 1.4482,  1.4132,  1.4132,  ..., -1.0203, -1.0378, -1.0378]],

         [[-1.4907, -1.5081, -1.4733,  ..., -1.2990, -1.3164, -1.3339],
          [-1.5256, -1.

Data:  tensor([[[[-1.3815, -1.2788, -1.3302,  ...,  1.0844,  0.3994, -0.1999],
          [-1.3987, -1.4158, -1.3473,  ...,  1.0159,  0.9817,  0.3994],
          [-0.8507, -1.5014, -1.5014,  ...,  0.8618,  0.8276,  0.4679],
          ...,
          [ 1.3755,  0.7762, -0.5938,  ...,  0.0569,  0.8618,  1.3584],
          [ 1.4783,  0.9303, -0.0629,  ..., -1.5185, -1.1247, -0.1828],
          [ 1.6838,  0.7933,  0.5364,  ..., -1.3987, -0.9534, -0.0972]],

         [[-0.7752, -0.6352, -0.7927,  ...,  1.3081,  0.4678, -0.2150],
          [-0.9328, -0.9503, -0.9503,  ...,  1.3081,  1.2031,  0.5903],
          [-0.4776, -1.1429, -1.1253,  ...,  1.0980,  1.1856,  0.8179],
          ...,
          [ 0.9405,  0.3627, -1.0203,  ..., -0.5301,  0.2402,  0.6954],
          [ 1.0455,  0.4678, -0.4951,  ..., -1.6681, -1.4580, -0.5826],
          [ 1.3256,  0.4153,  0.1001,  ..., -1.5455, -1.1429, -0.3725]],

         [[-1.3861, -1.3164, -1.4210,  ...,  0.4265, -0.1835, -0.8110],
          [-1.2641, -1.

Data:  tensor([[[[-0.9705, -0.9534, -0.9534,  ...,  1.0331,  0.9988,  1.0502],
          [-1.0219, -1.0219, -1.0390,  ...,  0.9474,  0.9132,  0.9646],
          [-0.9192, -0.9363, -0.9705,  ...,  0.8961,  0.8789,  0.9132],
          ...,
          [ 1.3755,  1.3927,  1.3927,  ..., -0.9020, -0.7822, -0.8507],
          [ 1.3755,  1.3927,  1.3927,  ..., -0.8507, -0.7650, -0.9363],
          [ 1.3755,  1.3927,  1.3927,  ..., -0.7822, -0.7993, -1.0390]],

         [[-0.6877, -0.6702, -0.6702,  ...,  0.1702,  0.1527,  0.2052],
          [-0.7402, -0.7402, -0.7577,  ...,  0.0826,  0.0476,  0.1001],
          [-0.6352, -0.6527, -0.6877,  ...,  0.0301,  0.0126,  0.0301],
          ...,
          [ 1.5357,  1.5532,  1.5532,  ..., -0.9678, -0.8627, -0.9678],
          [ 1.5357,  1.5532,  1.5532,  ..., -0.9503, -0.8627, -1.0553],
          [ 1.5357,  1.5532,  1.5532,  ..., -0.8803, -0.9153, -1.1604]],

         [[-0.3055, -0.2881, -0.2881,  ..., -0.8458, -0.8807, -0.8458],
          [-0.3404, -0.

Data:  tensor([[[[ 1.6324,  0.0056,  0.0056,  ..., -1.0733, -0.9877, -1.0048],
          [ 1.5639, -0.0801, -0.0801,  ..., -1.0904, -1.0390, -1.0048],
          [ 1.2899, -0.0801, -0.0287,  ..., -1.0904, -1.0562, -1.0048],
          ...,
          [-0.9534, -0.7308, -0.9363,  ..., -1.5528, -1.0733, -0.9705],
          [-0.7993, -0.8507, -0.7137,  ..., -1.2274, -1.1418, -1.2445],
          [-0.5767, -0.9363, -0.7479,  ..., -1.2445, -1.2274, -1.2959]],

         [[ 1.9384,  0.2402,  0.1527,  ..., -0.2850, -0.1975, -0.1975],
          [ 1.8859,  0.1877,  0.0651,  ..., -0.2850, -0.2325, -0.1975],
          [ 1.6057,  0.1877,  0.1352,  ..., -0.2850, -0.2500, -0.1975],
          ...,
          [-0.5126, -0.3200, -0.4251,  ..., -1.3704, -0.8627, -0.7052],
          [-0.1975, -0.3725, -0.3200,  ..., -1.0203, -0.8627, -0.8978],
          [ 0.1352, -0.4251, -0.4251,  ..., -1.0028, -0.8452, -0.8627]],

         [[ 1.8557,  0.0779,  0.2522,  ...,  0.9494,  1.0539,  1.0539],
          [ 1.7685,  0.

       device='cuda:0')
Output:  tensor([[-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        ...,
        [-1.4882e+04, -1.0445e+04, -1.5580e+04,  ..., -1.2266e+04,
         -1.1713e+04, -1.5692e+04],
        [-1.1380e+05, -6.4257e+04, -1.1777e+05,  ..., -8.9054e+04,
         -9.6022e+04, -1.1970e+05],
        [-3.3350e+01, -3.1345e+01, -3.3692e+01,  ...,  7.1913e+00,
         -2.4132e+01, -4.0143e+01]], device='cuda:0')
target:  tensor([122,  91,  71,  37,   5, 110,  58,  18], device='cuda:0')
tensor([[ 4],
        [ 4],
        [ 4],
        [56],
        [ 4],
        [56],
        [56],
        [47]], device='cuda:0')
Data:  tensor([[[[-0.7479, -0.7479, -0.7308,  ...,  0.5878,  0.7248,  0.7762],
          [-0.7479, -0.7650, -0.78

Data:  tensor([[[[-0.7822, -0.7650, -0.5253,  ..., -0.7479, -1.1589, -1.5014],
          [-0.8678, -0.8335, -0.6794,  ..., -0.7308, -1.0904, -1.4843],
          [-0.8164, -0.9192, -0.8678,  ..., -0.9363, -1.2445, -1.5870],
          ...,
          [-1.0390, -1.0733, -1.0733,  ...,  2.2489,  2.2489,  2.2489],
          [-1.0390, -1.0733, -1.0733,  ...,  2.2489,  2.2489,  2.2489],
          [-1.0048, -1.0390, -1.0390,  ...,  2.2489,  2.2489,  2.2489]],

         [[-1.0728, -1.0028, -0.7402,  ..., -1.2829, -1.4930, -1.6506],
          [-1.1954, -1.1429, -0.9328,  ..., -1.3004, -1.4930, -1.6856],
          [-1.1954, -1.2304, -1.1429,  ..., -1.4755, -1.5980, -1.7381],
          ...,
          [-1.3880, -1.4230, -1.4230,  ...,  2.3585,  2.3585,  2.3410],
          [-1.3529, -1.3880, -1.3880,  ...,  2.3410,  2.3410,  2.3585],
          [-1.3004, -1.3354, -1.3354,  ...,  2.3235,  2.3410,  2.3410]],

         [[-1.7696, -1.7173, -1.5081,  ..., -1.3513, -1.5430, -1.6999],
          [-1.7870, -1.

Data:  tensor([[[[-1.7754, -1.7925, -1.8097,  ..., -1.8439, -1.8439, -1.8439],
          [-1.7754, -1.7925, -1.8097,  ..., -1.8439, -1.8268, -1.8268],
          [-1.7925, -1.8097, -1.8097,  ..., -1.8782, -1.8439, -1.8439],
          ...,
          [ 0.2282,  0.1939,  0.1768,  ..., -1.7583, -1.7583, -1.7069],
          [ 0.1939,  0.1597,  0.1597,  ..., -1.6898, -1.5014, -1.2103],
          [ 0.2282,  0.1939,  0.1939,  ..., -1.0048, -0.8335, -0.7479]],

         [[-1.6856, -1.7031, -1.7206,  ..., -1.7556, -1.7556, -1.7556],
          [-1.6681, -1.7031, -1.7206,  ..., -1.7556, -1.7381, -1.7381],
          [-1.6681, -1.6856, -1.7031,  ..., -1.7556, -1.7556, -1.7556],
          ...,
          [-0.4076, -0.4426, -0.4601,  ..., -1.6856, -1.6681, -1.5630],
          [-0.4426, -0.4776, -0.4776,  ..., -1.4230, -1.1779, -0.8627],
          [-0.4076, -0.4426, -0.4426,  ..., -0.5826, -0.4251, -0.4076]],

         [[-1.2467, -1.2641, -1.2816,  ..., -1.3164, -1.3164, -1.3164],
          [-1.2293, -1.

       device='cuda:0')
target:  tensor([ 52,  46, 101,  23,  10,  16,  86,  38], device='cuda:0')
tensor([[12],
        [ 4],
        [ 4],
        [92],
        [91],
        [ 4],
        [ 4],
        [ 4]], device='cuda:0')
Data:  tensor([[[[ 1.8379,  1.8037,  1.7523,  ...,  1.9407,  2.1804,  2.0263],
          [ 1.8550,  1.8379,  1.7694,  ...,  1.8550,  2.1119,  1.8893],
          [ 1.8379,  1.8037,  1.7352,  ...,  2.0434,  2.2147,  2.0605],
          ...,
          [-1.7069, -1.5870, -1.6042,  ..., -1.2274, -1.3473, -1.3644],
          [-1.6555, -1.5699, -1.6042,  ..., -1.2103, -1.3302, -1.2959],
          [-1.6727, -1.6213, -1.5357,  ..., -1.0733, -1.3302, -1.3302]],

         [[-0.8277, -0.8627, -0.9153,  ...,  0.5553,  0.8880, -0.1800],
          [-0.8102, -0.8277, -0.8978,  ...,  0.6429,  0.8354, -0.2325],
          [-0.8452, -0.8627, -0.9153,  ...,  1.0805,  1.1506,  0.0651],
          ...,
          [-1.6155, -1.4930, -1.5105,  ..., -1.1429, -1.2654, -1.3004],
          [-

       device='cuda:0')
Output:  tensor([[-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-7.5935e+03, -5.4035e+03, -7.8402e+03,  ..., -5.8923e+03,
         -5.1977e+03, -7.9365e+03],
        [-7.8138e+01, -8.6861e+01, -7.0207e+01,  ..., -2.5228e+01,
         -3.1447e+01, -8.0144e+01],
        ...,
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00]], device='cuda:0')
target:  tensor([ 64,  94, 132,  75,  60,  88,  19,  40], device='cuda:0')
tensor([[  4],
        [ 56],
        [ 17],
        [ 56],
        [117],
        [  4],
        [  4],
        [  4]], device='cuda:0')
Data:  tensor([[[[ 1.9578,  1.9749,  1.9920,  ...,  0.0398,  0.0398, -0.2171],
          [ 1.9920,  1.940

Data:  tensor([[[[ 1.9749,  1.6153,  1.2385,  ...,  1.1529,  0.7419,  0.0056],
          [ 1.3242,  0.6221, -0.1314,  ...,  1.1529,  0.7591,  0.0398],
          [ 0.6734, -0.0629, -0.8678,  ...,  1.1015,  0.7077,  0.0227],
          ...,
          [-1.1760, -1.0562, -0.8849,  ..., -1.6384, -1.6898, -1.7583],
          [-1.2274, -1.2103, -1.0904,  ..., -1.6384, -1.6898, -1.7583],
          [-1.3644, -1.3473, -1.2445,  ..., -1.6384, -1.6898, -1.7583]],

         [[-0.1800, -0.2500, -0.3200,  ...,  0.3803,  0.1702, -0.2500],
          [-0.5476, -0.8102, -0.9328,  ...,  0.4153,  0.2227, -0.1625],
          [-0.8277, -1.0728, -1.0553,  ...,  0.3978,  0.2577, -0.1800],
          ...,
          [ 0.9230,  1.0105,  1.0980,  ..., -0.4426, -0.4076, -0.4251],
          [ 0.9230,  0.9405,  0.9930,  ..., -0.4426, -0.4076, -0.4251],
          [ 0.9580,  0.9930,  0.9755,  ..., -0.4426, -0.4076, -0.4251]],

         [[-0.6193, -0.3230, -0.0092,  ...,  0.1128, -0.0615, -0.0267],
          [-0.4624, -0.

       device='cuda:0')
Output:  tensor([[-1.3824e+01, -1.3779e+01, -1.4548e+01,  ..., -1.4196e+01,
         -1.7875e+01, -1.5151e+01],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-2.7612e+03, -1.3576e+03, -2.8441e+03,  ..., -2.1809e+03,
         -1.4459e+03, -2.8185e+03],
        ...,
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-3.2922e+02, -1.1065e+02, -3.4237e+02,  ..., -3.1951e+02,
         -2.4153e+02, -3.7864e+02],
        [-1.6786e+02, -1.0512e+02, -1.5799e+02,  ..., -1.4235e+02,
         -9.6991e+01, -1.6062e+02]], device='cuda:0')
target:  tensor([ 68, 107,  60,  37,  77,  90,  85,  13], device='cuda:0')
tensor([[ 31],
        [  4],
        [ 12],
        [ 91],
        [  4],
        [  4],
        [117],
        [  4]], device='cuda:0')
Data:  tensor([[[[ 0.4679,  0.3994,  0.2967,  ...,  0.8961,  0.8961,  0.8961],
          [ 0.4508,  0.399

       device='cuda:0')
Output:  tensor([[-1.8425e+03, -1.5280e+03, -1.8880e+03,  ..., -1.5618e+03,
         -1.2515e+03, -1.8961e+03],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        ...,
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-2.2234e+02, -1.3643e+02, -2.4167e+02,  ..., -2.0555e+02,
         -1.6334e+02, -2.2952e+02],
        [-2.2921e+00, -2.1033e+00, -1.4818e+00,  ..., -2.7789e+00,
         -2.9601e+00, -2.5029e+00]], device='cuda:0')
target:  tensor([91, 57, 24, 28, 31, 20,  8, 22], device='cuda:0')
tensor([[11],
        [ 4],
        [ 4],
        [ 4],
        [95],
        [ 4],
        [31],
        [95]], device='cuda:0')
Data:  tensor([[[[-0.1314, -0.0972, -0.0458,  ...,  0.1768,  0.2111,  0.2453],
          [-0.0972, -0.0801, -0.0458,  ...

       device='cuda:0')
Output:  tensor([[-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-1.5965e+04, -9.0211e+03, -1.6543e+04,  ..., -1.1550e+04,
         -1.2281e+04, -1.6778e+04],
        [-3.5685e+03, -1.9344e+03, -3.7428e+03,  ..., -2.3414e+03,
         -2.5976e+03, -3.6704e+03],
        ...,
        [-3.5678e+02, -3.2400e+02, -3.4538e+02,  ..., -3.0910e+02,
         -3.3036e+02, -3.6313e+02],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00]], device='cuda:0')
target:  tensor([ 29, 131,  56,  47, 103,  91,  15,  30], device='cuda:0')
tensor([[ 4],
        [36],
        [82],
        [ 4],
        [11],
        [45],
        [ 4],
        [ 4]], device='cuda:0')
Data:  tensor([[[[ 0.0227,  0.0227,  0.1083,  ...,  1.2043,  1.3413,  1.2899],
          [-0.0287,  0.0227,  0.09

       device='cuda:0')
Output:  tensor([[-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-2.3218e+02, -2.1092e+02, -2.2200e+02,  ..., -2.2180e+02,
         -1.5739e+02, -2.3586e+02],
        ...,
        [-3.1696e+01, -3.1766e+01, -3.3870e+01,  ..., -3.1770e+01,
         -4.0294e+01, -3.3836e+01],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00]], device='cuda:0')
target:  tensor([ 68,  46, 122,  65,  64,  34, 124,  62], device='cuda:0')
tensor([[ 4],
        [ 4],
        [91],
        [ 4],
        [ 4],
        [31],
        [ 4],
        [ 4]], device='cuda:0')
Data:  tensor([[[[-0.5938, -0.5424, -0.5253,  ..., -0.0801, -0.0972, -0.0972],
          [-0.5424, -0.5253, -0.50

       device='cuda:0')
Output:  tensor([[-6.2506e+03, -3.1444e+03, -6.3219e+03,  ..., -4.7841e+03,
         -3.8777e+03, -6.3548e+03],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        ...,
        [-7.7880e-01, -6.4979e-01, -4.4476e-01,  ..., -1.3689e+00,
         -1.5116e+00, -1.5142e+00],
        [-2.0678e+01, -2.1695e+01, -2.1240e+01,  ..., -2.1676e+01,
         -1.1239e+01, -2.2399e+01],
        [-5.8642e+00, -5.5251e+00, -5.6286e+00,  ...,  6.6731e-01,
         -4.9937e+00, -7.6853e+00]], device='cuda:0')
target:  tensor([127,  96,  76,  80, 118,  32,  57,  95], device='cuda:0')
tensor([[36],
        [ 4],
        [ 4],
        [31],
        [ 4],
        [ 4],
        [12],
        [47]], device='cuda:0')
Data:  tensor([[[[-0.4739, -0.4226, -0.4739,  ..., -1.2274, -1.2788, -1.2274],
          [-0.0972, -0.1143, -0.37

       device='cuda:0')
Output:  tensor([[-0.4912, -0.3603, -0.1338,  ..., -1.0861, -1.1508, -1.2135],
        [-0.4912, -0.3603, -0.1338,  ..., -1.0861, -1.1508, -1.2135],
        [-0.4912, -0.3603, -0.1338,  ..., -1.0861, -1.1508, -1.2135],
        ...,
        [-0.4912, -0.3603, -0.1338,  ..., -1.0861, -1.1508, -1.2135],
        [-0.4912, -0.3603, -0.1338,  ..., -1.0861, -1.1508, -1.2135],
        [-0.4912, -0.3603, -0.1338,  ..., -1.0861, -1.1508, -1.2135]],
       device='cuda:0')
target:  tensor([14, 55, 79, 23, 23, 92, 68, 14], device='cuda:0')
tensor([[ 4],
        [ 4],
        [ 4],
        [56],
        [ 4],
        [ 4],
        [ 4],
        [ 4]], device='cuda:0')
Data:  tensor([[[[ 0.4851,  0.4508,  0.8276,  ..., -1.4158, -1.1932, -1.0390],
          [ 0.8104,  1.0159,  0.6734,  ..., -1.2959, -1.0904, -0.9534],
          [ 1.0844,  0.8618,  0.8276,  ..., -1.1932, -0.8849, -0.6794],
          ...,
          [-0.9534, -0.8335, -0.7308,  ..., -0.4397, -0.5596, -0.5082],
  

       device='cuda:0')
Output:  tensor([[-2.1166e+04, -1.2596e+04, -2.1961e+04,  ..., -1.5455e+04,
         -1.6838e+04, -2.2198e+04],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        ...,
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-1.1727e+04, -7.0697e+03, -1.2264e+04,  ..., -1.0439e+04,
         -9.2816e+03, -1.2306e+04]], device='cuda:0')
target:  tensor([ 88,  61,  41,  78,  59,  20, 128,  69], device='cuda:0')
tensor([[36],
        [ 4],
        [ 4],
        [ 4],
        [ 4],
        [ 4],
        [ 4],
        [56]], device='cuda:0')
Data:  tensor([[[[ 1.1872,  1.1872,  1.1872,  ..., -1.6042, -1.5357, -1.5185],
          [ 1.1872,  1.1872,  1.18

Data:  tensor([[[[-0.8335, -0.8849, -0.9020,  ..., -0.9705, -0.9534, -0.9363],
          [-0.7137, -0.7308, -0.7479,  ..., -1.0390, -1.0390, -1.0219],
          [-0.6965, -0.7308, -0.7479,  ..., -1.1075, -1.1247, -1.1075],
          ...,
          [-1.6898, -1.5699, -1.6213,  ..., -0.9705, -0.7650, -1.1589],
          [-0.3712, -0.1143, -0.6965,  ..., -1.0904, -0.9877, -1.2617],
          [-0.6281, -0.2513, -0.9020,  ..., -0.8335, -1.3644, -1.2445]],

         [[-0.7227, -0.7752, -0.7752,  ..., -0.7752, -0.7752, -0.7752],
          [-0.7927, -0.7752, -0.7402,  ..., -0.7577, -0.7752, -0.7927],
          [-0.8277, -0.8102, -0.7577,  ..., -0.7927, -0.8102, -0.8102],
          ...,
          [-1.4405, -1.4405, -1.5805,  ..., -0.6176, -0.2850, -0.7227],
          [-0.0924,  0.0476, -0.6527,  ..., -0.6877, -0.5301, -0.9328],
          [-0.4076, -0.0399, -0.8803,  ..., -0.2850, -0.8627, -1.0028]],

         [[ 0.2348,  0.2348,  0.2173,  ...,  0.2173,  0.2871,  0.2871],
          [ 0.3393,  0.

Data:  tensor([[[[-1.9124, -1.9467, -2.0152,  ...,  1.5297,  1.2557,  0.8618],
          [-1.9124, -1.9467, -1.9980,  ...,  1.5125,  1.3927,  1.1015],
          [-1.9467, -1.9467, -1.9638,  ...,  1.4612,  1.4783,  1.2899],
          ...,
          [-1.2445, -1.2274, -1.2103,  ...,  0.7933,  0.8618,  0.9817],
          [-1.3130, -1.3473, -1.3473,  ...,  0.8104,  0.7591,  0.7933],
          [-1.4672, -1.4672, -1.5014,  ...,  0.7591,  0.7933,  0.8447]],

         [[-1.8081, -1.8431, -1.8957,  ...,  0.8529,  0.5728,  0.3102],
          [-1.8081, -1.8431, -1.8957,  ...,  0.8704,  0.7479,  0.5203],
          [-1.8431, -1.8431, -1.8606,  ...,  0.8704,  0.8704,  0.6604],
          ...,
          [-0.9853, -0.9678, -1.0028,  ...,  0.9580,  1.0105,  1.1331],
          [-1.0728, -1.1078, -1.1779,  ...,  0.9930,  0.9055,  0.9405],
          [-1.2654, -1.2829, -1.3529,  ...,  0.9405,  0.9405,  0.9930]],

         [[-1.5430, -1.5604, -1.6476,  ...,  0.8099,  0.5311,  0.3568],
          [-1.5256, -1.

Data:  tensor([[[[-0.0287, -0.0287,  0.0056,  ..., -0.5596, -0.5596, -0.5424],
          [-0.2684, -0.2513, -0.2342,  ..., -0.1314, -0.0972, -0.0801],
          [-0.5596, -0.5253, -0.5253,  ...,  0.3823,  0.4166,  0.4166],
          ...,
          [ 1.4783,  1.5297,  1.6667,  ...,  1.5810,  1.5639,  1.5297],
          [ 1.6495,  1.6324,  1.7523,  ...,  1.7865,  1.7523,  1.7009],
          [ 1.6495,  1.6838,  1.7009,  ...,  1.8379,  1.8208,  1.8208]],

         [[ 0.3627,  0.3452,  0.3452,  ..., -0.1800, -0.1800, -0.1625],
          [ 0.1176,  0.1001,  0.1001,  ...,  0.2227,  0.2577,  0.2752],
          [-0.1800, -0.1975, -0.1975,  ...,  0.6954,  0.7304,  0.7479],
          ...,
          [ 1.4482,  1.4832,  1.6232,  ...,  1.5357,  1.5182,  1.4832],
          [ 1.6232,  1.6057,  1.6933,  ...,  1.7458,  1.7108,  1.6583],
          [ 1.6057,  1.6408,  1.6583,  ...,  1.8333,  1.8158,  1.7633]],

         [[ 0.0256, -0.0790, -0.1487,  ..., -0.0615, -0.0441, -0.0267],
          [-0.0441, -0.

Data:  tensor([[[[ 0.1083,  0.1083,  0.1254,  ..., -0.2684, -0.2856, -0.3027],
          [ 0.0398,  0.0741,  0.0912,  ..., -0.2513, -0.2513, -0.2684],
          [ 0.0569,  0.0741,  0.0912,  ..., -0.1828, -0.1999, -0.2342],
          ...,
          [-0.8678, -0.8678, -1.0048,  ..., -0.9705, -0.9877, -1.1418],
          [-0.8678, -1.0048, -1.1760,  ..., -0.9363, -1.0562, -1.2445],
          [-0.9534, -1.1247, -1.3473,  ..., -0.9192, -1.0904, -1.2788]],

         [[-0.0399, -0.0224, -0.0224,  ...,  0.0301,  0.0301,  0.0126],
          [-0.0574, -0.0224, -0.0049,  ...,  0.0651,  0.0651,  0.0476],
          [-0.0399, -0.0049, -0.0049,  ...,  0.1352,  0.1176,  0.0826],
          ...,
          [-0.4426, -0.3901, -0.4601,  ..., -0.4426, -0.4076, -0.4951],
          [-0.4601, -0.5126, -0.6176,  ..., -0.5126, -0.5126, -0.6176],
          [-0.5301, -0.6352, -0.8277,  ..., -0.5126, -0.5476, -0.6527]],

         [[-0.0615, -0.0615, -0.0964,  ..., -0.3753, -0.4101, -0.4624],
          [-0.1138, -0.

Output:  tensor([[ -0.4912,  -0.3603,  -0.1338,  ...,  -1.0861,  -1.1508,  -1.2135],
        [ -0.4912,  -0.3603,  -0.1338,  ...,  -1.0861,  -1.1508,  -1.2135],
        [-37.2346, -41.5316, -34.7677,  ..., -41.3823, -28.1277, -39.2393],
        ...,
        [-45.8188, -49.0267, -46.9897,  ..., -47.8729, -12.4137, -48.8440],
        [ -0.4912,  -0.3603,  -0.1338,  ...,  -1.0861,  -1.1508,  -1.2135],
        [ -0.4912,  -0.3603,  -0.1338,  ...,  -1.0861,  -1.1508,  -1.2135]],
       device='cuda:0')
target:  tensor([ 80,  74,  41,  92, 130,  43,  47,  31], device='cuda:0')
tensor([[ 4],
        [ 4],
        [17],
        [11],
        [11],
        [12],
        [ 4],
        [ 4]], device='cuda:0')
Data:  tensor([[[[-0.1999, -0.1657, -0.1143,  ..., -0.1486, -0.1143, -0.0972],
          [-0.1657, -0.1486, -0.1143,  ..., -0.1657, -0.1314, -0.0972],
          [-0.1828, -0.1314, -0.0629,  ..., -0.1486, -0.1657, -0.1314],
          ...,
          [-1.4500, -1.4329, -1.3815,  ..., -1.5870, -

Data:  tensor([[[[ 1.6495,  2.1462,  1.9578,  ...,  2.0777,  2.1290,  2.0263],
          [ 1.9920,  2.0605,  2.1462,  ...,  1.2728,  1.6495,  1.8208],
          [ 1.7352,  1.5125,  1.5982,  ...,  1.9578,  1.7523,  1.9064],
          ...,
          [-1.7069, -1.7412, -1.8097,  ...,  0.1768, -0.0287,  0.1083],
          [-1.7240, -1.7754, -1.7754,  ...,  0.6049,  0.7419,  0.4337],
          [-1.7754, -1.8097, -1.7925,  ...,  0.6563,  0.2796, -0.1143]],

         [[ 1.7983,  2.2535,  2.1310,  ...,  2.2360,  2.2535,  2.0959],
          [ 2.2360,  2.3060,  2.3060,  ...,  1.4657,  1.7633,  1.8683],
          [ 2.0959,  1.9034,  1.8859,  ...,  2.1660,  1.9384,  1.9909],
          ...,
          [-1.6331, -1.6856, -1.7731,  ...,  0.3978,  0.2227,  0.3978],
          [-1.6681, -1.7206, -1.7381,  ...,  0.8179,  1.0105,  0.7129],
          [-1.7206, -1.7556, -1.7556,  ...,  0.8354,  0.5028,  0.1001]],

         [[ 0.6879,  1.4025,  0.8448,  ...,  1.7163,  1.8731,  1.6814],
          [ 0.7751,  0.

Data:  tensor([[[[ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489],
          [ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489],
          [ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489],
          ...,
          [ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489],
          [ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489],
          [ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489]],

         [[ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286],
          [ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286],
          [ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286],
          ...,
          [ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286],
          [ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286],
          [ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286]],

         [[ 2.6400,  2.6400,  2.6400,  ...,  2.6400,  2.6400,  2.6400],
          [ 2.6400,  2.

Data:  tensor([[[[-1.1932, -1.5185, -1.2274,  ..., -1.4500, -1.3473, -1.2445],
          [-1.1932, -1.5014, -1.1760,  ..., -1.5014, -1.5357, -1.3815],
          [-1.3473, -1.4500, -1.1247,  ..., -1.5185, -1.4843, -1.4843],
          ...,
          [-0.4568, -0.4739, -0.5767,  ..., -0.7822, -0.4911, -0.4226],
          [-0.3027, -0.3369, -0.3883,  ..., -0.2171, -0.1143,  0.1254],
          [-0.5424, -0.8164, -1.1075,  ...,  0.5707,  0.0741, -0.3369]],

         [[-1.1078, -1.4230, -1.1253,  ..., -1.3529, -1.2304, -1.1078],
          [-1.1078, -1.4055, -1.0728,  ..., -1.4055, -1.4055, -1.2479],
          [-1.2654, -1.3704, -1.0203,  ..., -1.4230, -1.3529, -1.3529],
          ...,
          [-0.6527, -0.6176, -0.6702,  ..., -1.0553, -0.6702, -0.5476],
          [-0.4601, -0.4776, -0.4951,  ..., -0.4951, -0.3550, -0.0749],
          [-0.6877, -0.9503, -1.2304,  ...,  0.3452, -0.1800, -0.5651]],

         [[-0.9678, -1.3513, -1.1073,  ..., -1.2816, -1.1944, -1.1073],
          [-0.9678, -1.

Data:  tensor([[[[ 0.5022,  0.5193,  0.4337,  ...,  0.6221,  0.8276,  0.7933],
          [ 0.3652,  0.4679,  0.3481,  ...,  0.7591,  0.8961,  0.9303],
          [ 0.3823,  0.3309,  0.3823,  ...,  0.9817,  0.8961,  0.9988],
          ...,
          [-0.6623, -0.6623, -0.6452,  ...,  1.7180,  1.7009,  1.6838],
          [-0.7308, -0.7137, -0.6965,  ...,  1.7009,  1.7180,  1.7180],
          [-0.7137, -0.7137, -0.6794,  ...,  1.6838,  1.7180,  1.7352]],

         [[ 0.4503,  0.3627,  0.1001,  ...,  0.3102,  0.4678,  0.3627],
          [ 0.3102,  0.2752, -0.0224,  ...,  0.4853,  0.6078,  0.5728],
          [ 0.2052, -0.0049, -0.1275,  ...,  0.7829,  0.6779,  0.7829],
          ...,
          [-0.7052, -0.7052, -0.6702,  ...,  1.8333,  1.8508,  1.8508],
          [-0.7752, -0.7752, -0.7227,  ...,  1.8333,  1.8508,  1.8683],
          [-0.7927, -0.7752, -0.7052,  ...,  1.8508,  1.8859,  1.8859]],

         [[ 0.6879,  0.5485,  0.1999,  ...,  0.4614,  0.6356,  0.5136],
          [ 0.5659,  0.

Data:  tensor([[[[ 1.1872,  0.9132,  0.8789,  ...,  1.3584,  1.4612,  1.5468],
          [ 0.9817,  0.7591,  0.8276,  ...,  0.9474,  1.1015,  1.3070],
          [ 0.7419,  0.6563,  0.7591,  ...,  0.7077,  0.8447,  1.0331],
          ...,
          [-0.0458, -0.0972, -0.0801,  ..., -0.2171, -0.0287,  0.3309],
          [-0.1999, -0.1999, -0.1828,  ..., -0.1143, -0.0458,  0.5707],
          [-0.1486, -0.2342, -0.2171,  ...,  0.0912, -0.0116,  0.5364]],

         [[ 1.5882,  1.4307,  1.4657,  ...,  1.6583,  1.7108,  1.8333],
          [ 1.5707,  1.3782,  1.4482,  ...,  1.3957,  1.5007,  1.7283],
          [ 1.4132,  1.3606,  1.4132,  ...,  1.2556,  1.3256,  1.6057],
          ...,
          [ 0.3978,  0.3277,  0.3102,  ..., -0.2325, -0.0574,  0.2577],
          [ 0.3277,  0.3102,  0.3102,  ..., -0.0924, -0.0399,  0.5378],
          [ 0.4503,  0.3627,  0.3803,  ...,  0.1176, -0.0049,  0.5203]],

         [[ 1.0539,  0.7228,  0.7925,  ...,  1.1062,  1.2282,  1.4025],
          [ 0.9319,  0.

Data:  tensor([[[[-0.5253, -0.4054, -0.5938,  ..., -0.5938, -0.6794, -0.6109],
          [-0.5424, -0.5082, -0.8164,  ..., -0.3712, -0.8335, -0.9363],
          [-0.6452, -0.4739, -0.6794,  ...,  0.1254, -0.0116,  0.1597],
          ...,
          [-0.7822, -0.3369, -0.2856,  ..., -0.9020, -0.7479, -0.4568],
          [-0.5424, -0.3369, -0.4911,  ..., -0.3883, -0.3027, -0.1657],
          [-0.3541, -0.3883, -0.6452,  ..., -0.6623, -0.3883, -0.3027]],

         [[-0.7052, -0.5651, -0.7752,  ..., -0.7927, -0.8277, -0.7402],
          [-0.7227, -0.6702, -0.9853,  ..., -0.5126, -0.9503, -1.0553],
          [-0.8277, -0.6176, -0.8102,  ...,  0.0126, -0.0749,  0.0826],
          ...,
          [-1.0728, -0.6176, -0.5651,  ..., -1.2129, -1.0203, -0.7227],
          [-0.8277, -0.6176, -0.7752,  ..., -0.6527, -0.5476, -0.4251],
          [-0.6352, -0.6702, -0.9153,  ..., -0.8627, -0.5826, -0.5301]],

         [[-0.7587, -0.6018, -0.7413,  ..., -0.7587, -0.8110, -0.7064],
          [-0.7413, -0.

Data:  tensor([[[[ 1.4269,  1.2899,  0.9303,  ..., -0.5767, -0.4739, -0.3712],
          [ 1.4098,  1.3242,  1.1187,  ..., -0.5938, -0.2684, -0.2856],
          [ 1.4269,  1.4954,  1.1529,  ..., -0.4226, -0.3712, -0.4568],
          ...,
          [ 1.9920,  2.1975,  2.2489,  ..., -0.6109, -0.6452, -0.6623],
          [ 1.9235,  2.1633,  2.2489,  ..., -0.6794, -0.6623, -0.6794],
          [ 1.8379,  2.1290,  2.2489,  ..., -0.6794, -0.6452, -0.6623]],

         [[ 1.3957,  1.3256,  1.0805,  ..., -0.0224,  0.0826,  0.1702],
          [ 1.3957,  1.3606,  1.2381,  ..., -0.0574,  0.2402,  0.1877],
          [ 1.3256,  1.4832,  1.2206,  ...,  0.1176,  0.1527,  0.0476],
          ...,
          [ 2.1835,  2.3761,  2.4286,  ...,  0.0651,  0.0476,  0.0301],
          [ 2.1134,  2.3410,  2.4111,  ..., -0.0049,  0.0126,  0.0126],
          [ 2.0784,  2.3410,  2.4286,  ..., -0.0049,  0.0301,  0.0126]],

         [[ 1.0365,  1.0191,  0.6531,  ..., -0.9156, -0.6367, -0.3753],
          [ 1.0539,  1.

Data:  tensor([[[[-0.4911, -0.3541, -0.3369,  ...,  2.0092,  2.0777,  2.0092],
          [-0.1999, -0.0629, -0.1143,  ...,  1.9749,  2.0948,  1.9749],
          [-0.1143, -0.0458, -0.0972,  ...,  2.0263,  2.0092,  1.8550],
          ...,
          [ 0.3138,  0.3138,  0.2967,  ..., -0.3027, -0.1828, -0.0287],
          [ 0.2624,  0.2967,  0.2796,  ..., -0.4568,  0.4679,  0.2111],
          [ 0.2111,  0.2624,  0.2796,  ..., -0.4911,  0.2282,  0.4166]],

         [[-0.5301, -0.3725, -0.3550,  ...,  2.1660,  2.2535,  2.1835],
          [-0.2325, -0.0574, -0.1275,  ...,  2.1310,  2.2885,  2.1660],
          [-0.1275, -0.0399, -0.0924,  ...,  2.1835,  2.2010,  2.0434],
          ...,
          [ 0.2227,  0.2577,  0.2402,  ..., -0.3375, -0.1625, -0.0224],
          [ 0.1877,  0.2402,  0.2402,  ..., -0.4776,  0.5378,  0.2227],
          [ 0.1527,  0.2052,  0.2402,  ..., -0.5126,  0.2752,  0.3978]],

         [[-0.4275, -0.2707, -0.2184,  ...,  2.4831,  2.5354,  2.4657],
          [-0.1312,  0.

Data:  tensor([[[[ 1.9578,  1.9064,  1.9749,  ...,  1.8037,  1.5468,  0.1426],
          [ 2.0434,  2.0605,  2.1119,  ...,  2.0263,  1.6495, -0.2171],
          [ 0.7077,  0.6906,  0.8447,  ...,  2.1119,  1.4954, -0.3027],
          ...,
          [ 0.6049,  1.0331,  1.0673,  ...,  0.1768,  0.6049,  0.0741],
          [ 0.6392,  1.0159,  1.1872,  ...,  0.6563,  0.8104,  0.5878],
          [ 0.7762,  1.0159,  0.8789,  ...,  0.4508,  0.4679,  0.4508]],

         [[ 2.0609,  1.9909,  2.0784,  ...,  1.9734,  1.7633,  0.3803],
          [ 2.2010,  2.2185,  2.2535,  ...,  2.2185,  1.8508, -0.0924],
          [ 0.9055,  0.9405,  1.1331,  ...,  2.3585,  1.6583, -0.2325],
          ...,
          [ 0.5728,  0.9405,  0.9230,  ...,  0.5903,  1.0630,  0.5728],
          [ 0.6779,  0.8880,  1.0280,  ...,  1.1155,  1.2556,  0.9930],
          [ 0.8704,  1.0105,  0.8529,  ...,  0.9405,  0.7479,  0.6254]],

         [[ 2.3088,  2.2566,  2.3088,  ...,  2.1694,  2.0300,  0.8274],
          [ 2.3611,  2.

Data:  tensor([[[[-1.2274, -1.1247, -1.0904,  ..., -0.4568, -0.5082, -0.5767],
          [-1.1418, -1.0904, -1.1760,  ..., -0.4568, -0.4568, -0.3712],
          [-1.1589, -1.0904, -1.1589,  ..., -0.4226, -0.3369, -0.2684],
          ...,
          [ 0.1254,  0.1939,  0.1597,  ...,  0.0912,  0.0056, -0.2856],
          [ 0.2111,  0.2453,  0.1597,  ...,  0.1254,  0.0741, -0.1143],
          [ 0.2111,  0.2796,  0.1939,  ..., -0.3883, -0.0458,  0.0056]],

         [[-1.0553, -1.0553, -1.0028,  ..., -0.3550, -0.3901, -0.4426],
          [-1.0553, -1.0028, -1.0203,  ..., -0.2850, -0.3200, -0.2500],
          [-1.0378, -1.0378, -1.0553,  ..., -0.2500, -0.2150, -0.1099],
          ...,
          [ 0.2402,  0.3452,  0.3102,  ...,  0.2402,  0.1352, -0.2500],
          [ 0.3627,  0.3978,  0.2927,  ...,  0.2752,  0.2402, -0.0049],
          [ 0.3627,  0.4153,  0.3452,  ..., -0.3550,  0.1001,  0.1352]],

         [[-0.8284, -0.8110, -0.7587,  ..., -0.1661, -0.2010, -0.2532],
          [-0.7936, -0.

Data:  tensor([[[[ 1.4440,  1.4098,  1.3927,  ...,  1.8550,  1.8379,  1.8208],
          [ 1.1872,  1.1529,  1.1187,  ...,  1.9064,  1.8550,  1.8037],
          [ 1.0673,  1.0502,  1.0502,  ...,  1.9064,  1.8379,  1.7865],
          ...,
          [ 0.0569,  0.0912,  0.2111,  ..., -0.5596, -0.4911, -0.4054],
          [-0.0287,  0.0227,  0.2111,  ..., -0.5424, -0.4397, -0.4226],
          [-0.1143, -0.0116,  0.0912,  ..., -0.5596, -0.4739, -0.4226]],

         [[ 1.8683,  1.8508,  1.7983,  ...,  1.8683,  1.8333,  1.8333],
          [ 1.5532,  1.5182,  1.4482,  ...,  1.8683,  1.8158,  1.8158],
          [ 1.3431,  1.3431,  1.3431,  ...,  1.8508,  1.8333,  1.8158],
          ...,
          [-0.1625, -0.1275,  0.0301,  ..., -0.8803, -0.8452, -0.8102],
          [-0.2150, -0.1275,  0.0651,  ..., -0.8978, -0.7927, -0.8102],
          [-0.3200, -0.2675, -0.1099,  ..., -0.9328, -0.8102, -0.7927]],

         [[ 2.0648,  2.0474,  2.0648,  ...,  1.9254,  1.8905,  1.8905],
          [ 1.6988,  1.

Data:  tensor([[[[-0.7479, -0.8335, -1.1418,  ...,  2.0605,  2.1290,  2.1633],
          [-0.8849, -0.9192, -1.1760,  ...,  2.1462,  2.1633,  2.1633],
          [-1.0219, -1.0219, -1.0733,  ...,  2.0777,  2.1633,  2.1975],
          ...,
          [ 0.0227, -0.0972,  0.1426,  ...,  0.3994,  0.5364,  0.4166],
          [ 0.4679,  0.1597, -0.0972,  ...,  0.3481,  0.4679,  0.4337],
          [ 0.3138,  0.0056, -0.3541,  ...,  0.2967,  0.3481,  0.4508]],

         [[ 0.0651, -0.0399, -0.3550,  ...,  2.2010,  2.2710,  2.3235],
          [-0.0749, -0.1099, -0.3901,  ...,  2.2885,  2.3060,  2.3235],
          [-0.2150, -0.2150, -0.3025,  ...,  2.2010,  2.3235,  2.3585],
          ...,
          [ 0.5203,  0.4328,  0.6779,  ..., -0.2150, -0.0749, -0.1975],
          [ 0.9755,  0.6779,  0.4328,  ..., -0.2675, -0.1450, -0.1800],
          [ 0.8179,  0.5028,  0.1352,  ..., -0.3200, -0.2675, -0.1625]],

         [[-1.0724, -1.1770, -1.5256,  ...,  2.1694,  2.2740,  2.3263],
          [-1.2119, -1.

Data:  tensor([[[[ 1.2728,  1.2557,  1.2385,  ..., -0.9705, -1.1075, -1.7412],
          [ 1.2728,  1.2728,  1.2728,  ..., -0.9534, -1.0733, -1.7069],
          [ 1.2899,  1.2899,  1.2728,  ..., -0.9363, -1.0562, -1.6042],
          ...,
          [ 0.4679,  0.4166,  0.3652,  ...,  0.3138,  0.3481,  0.3823],
          [ 0.4337,  0.4166,  0.3652,  ...,  0.3138,  0.3481,  0.3652],
          [ 0.4166,  0.3994,  0.3652,  ...,  0.2967,  0.3481,  0.3652]],

         [[ 1.2731,  1.2556,  1.2381,  ..., -0.9503, -1.0728, -1.6856],
          [ 1.2731,  1.2731,  1.2731,  ..., -0.8978, -1.0203, -1.6506],
          [ 1.2906,  1.2906,  1.2731,  ..., -0.8978, -1.0028, -1.5280],
          ...,
          [ 0.3452,  0.3277,  0.2927,  ...,  0.3978,  0.4328,  0.4678],
          [ 0.3452,  0.3277,  0.3102,  ...,  0.4153,  0.4503,  0.4678],
          [ 0.3277,  0.3102,  0.2927,  ...,  0.3978,  0.4503,  0.4678]],

         [[ 1.3851,  1.3502,  1.3328,  ..., -0.7936, -0.8807, -1.4733],
          [ 1.4025,  1.

Data:  tensor([[[[-0.8507, -0.7308, -0.5938,  ..., -0.1828, -0.0287, -0.1314],
          [-0.7822, -0.8507, -0.7308,  ..., -0.1657, -0.0287, -0.1486],
          [-0.7137, -0.8164, -0.7650,  ...,  0.0056, -0.1657, -0.2856],
          ...,
          [-1.1075, -1.0904, -1.4672,  ..., -0.8678, -0.7993, -1.0733],
          [-1.3815, -1.4158, -1.5185,  ..., -1.0048, -0.7822, -0.8849],
          [-1.1760, -1.0390, -0.9363,  ..., -0.9705, -0.9877, -0.9192]],

         [[-0.4251, -0.3375, -0.2325,  ..., -0.0224,  0.0301, -0.0399],
          [-0.3550, -0.4601, -0.3901,  ...,  0.0651,  0.1352,  0.0126],
          [-0.3550, -0.4601, -0.3901,  ...,  0.2752,  0.0651, -0.0574],
          ...,
          [-0.6702, -0.7402, -1.1779,  ..., -0.4776, -0.3901, -0.7752],
          [-0.8277, -0.8803, -1.0378,  ..., -0.6176, -0.3550, -0.5826],
          [-0.6352, -0.5301, -0.4426,  ..., -0.6001, -0.6176, -0.6877]],

         [[-0.3753, -0.3055, -0.1487,  ...,  0.1128,  0.1999,  0.1476],
          [-0.2707, -0.

Data:  tensor([[[[-0.2513,  0.1597,  0.3652,  ..., -0.9020, -0.7479, -0.6281],
          [-0.1657, -0.2684, -0.1999,  ..., -0.9705, -0.7822, -0.7137],
          [ 0.2453,  0.2453, -0.0287,  ..., -0.9363, -0.8164, -0.7308],
          ...,
          [ 1.2899,  1.1700,  1.1529,  ...,  0.0227, -0.0629, -0.2513],
          [ 1.1872,  1.2385,  1.1700,  ...,  0.0398, -0.0629, -0.0116],
          [ 1.1529,  1.0502,  1.1529,  ...,  0.0056, -0.0116,  0.1426]],

         [[ 0.0651,  0.4503,  0.5903,  ..., -0.8452, -0.6877, -0.5826],
          [ 0.1176, -0.0399, -0.0049,  ..., -0.8627, -0.6877, -0.6352],
          [ 0.4853,  0.4503,  0.2227,  ..., -0.8102, -0.6877, -0.6527],
          ...,
          [ 0.6779,  0.5203,  0.5378,  ...,  0.6954,  0.5903,  0.3277],
          [ 0.5728,  0.5903,  0.6254,  ...,  0.7304,  0.6078,  0.5903],
          [ 0.5378,  0.4153,  0.5728,  ...,  0.6954,  0.7129,  0.8179]],

         [[ 0.6531,  1.1237,  1.3502,  ..., -0.5321, -0.4450, -0.4275],
          [ 0.7576,  0.

       device='cuda:0')
Output:  tensor([[-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-8.5657e+04, -4.7644e+04, -8.9429e+04,  ..., -6.6673e+04,
         -7.4977e+04, -9.0960e+04],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        ...,
        [-4.4161e+00, -1.4468e+00, -3.7769e+00,  ..., -4.7341e+00,
         -5.3595e+00, -4.4009e+00],
        [-5.2899e-01, -3.3132e-01, -1.7916e-01,  ..., -1.1206e+00,
         -1.1952e+00, -1.2497e+00],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00]], device='cuda:0')
target:  tensor([ 98,  34,   0, 105,   0,  48,  84,  15], device='cuda:0')
tensor([[ 4],
        [56],
        [ 4],
        [ 4],
        [77],
        [36],
        [ 4],
        [ 4]], device='cuda:0')
Data:  tensor([[[[-1.4672, -1.4672, -1.5014,  ..., -1.3987, -1.3644, -1.4329],
          [-1.4158, -1.4329, -1.46

       device='cuda:0')
Output:  tensor([[-1.1919e+01,  8.4027e+00, -1.3842e+01,  ..., -1.1526e+01,
         -1.4573e+01, -1.2153e+01],
        [-1.6331e+02, -1.4517e+02, -1.7258e+02,  ..., -1.1268e+02,
         -1.2284e+02, -1.9358e+02],
        [-1.7369e+02, -8.4624e+01, -1.7607e+02,  ..., -2.2381e+01,
         -1.1206e+02, -1.9876e+02],
        ...,
        [-1.4213e+04, -7.1443e+03, -1.4751e+04,  ..., -1.0960e+04,
         -1.0513e+04, -1.4822e+04],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00]], device='cuda:0')
target:  tensor([114,  35,  93,  58,  50,  83,  23, 120], device='cuda:0')
tensor([[47],
        [97],
        [47],
        [ 4],
        [12],
        [36],
        [ 4],
        [ 4]], device='cuda:0')
Data:  tensor([[[[-1.9124, -1.7412, -1.0562,  ..., -1.2103, -0.9192, -0.5424],
          [-1.0904, -1.3130, -1.79

        [ 4]], device='cuda:0')
Data:  tensor([[[[-1.7925, -1.8268, -1.8610,  ..., -0.9192, -1.1075, -1.1589],
          [-1.8610, -1.8610, -1.8782,  ..., -0.9534, -1.0390, -1.1075],
          [-1.8782, -1.8782, -1.8610,  ..., -0.3541, -0.4739, -0.5767],
          ...,
          [-0.5082, -0.9363, -1.2274,  ..., -0.2171, -0.1486, -0.0287],
          [-1.0904, -1.1589, -1.2274,  ..., -0.1314, -0.1143, -0.0458],
          [-1.1760, -1.1760, -1.1589,  ..., -0.1999, -0.1143, -0.1486]],

         [[-1.8431, -1.8782, -1.9132,  ..., -0.7402, -0.9153, -0.9678],
          [-1.9132, -1.9132, -1.9307,  ..., -0.8803, -0.9153, -0.9503],
          [-1.9307, -1.9307, -1.9132,  ..., -0.4601, -0.5301, -0.5826],
          ...,
          [-0.7052, -1.0028, -1.2129,  ..., -0.3375, -0.2500, -0.1625],
          [-1.2304, -1.1604, -1.1604,  ..., -0.3550, -0.3025, -0.2325],
          [-1.2304, -1.1078, -1.0378,  ..., -0.5126, -0.3375, -0.3725]],

         [[-1.6476, -1.6824, -1.7173,  ..., -0.5147, -0.6890, -

Data:  tensor([[[[-1.3473, -1.3130, -1.3130,  ..., -0.5938, -0.6623, -0.7822],
          [-1.3987, -1.3644, -1.3473,  ..., -0.5767, -0.6109, -0.7137],
          [-1.5014, -1.5014, -1.4843,  ..., -0.5424, -0.5767, -0.6794],
          ...,
          [-0.5596, -0.5596, -0.4568,  ..., -0.2171, -0.1657, -0.1486],
          [-0.4054, -0.4911, -0.5767,  ..., -0.2342, -0.2171, -0.1486],
          [-0.3712, -0.4226, -0.4226,  ..., -0.2856, -0.1999, -0.1486]],

         [[-0.9853, -0.9503, -0.9503,  ..., -0.3200, -0.3901, -0.5126],
          [-1.0378, -1.0028, -0.9853,  ..., -0.3025, -0.3375, -0.4426],
          [-1.1604, -1.1429, -1.1253,  ..., -0.2500, -0.2850, -0.3901],
          ...,
          [-0.2850, -0.2850, -0.1800,  ...,  0.0826,  0.1352,  0.1527],
          [-0.1275, -0.2150, -0.3025,  ...,  0.0651,  0.0826,  0.1527],
          [-0.0924, -0.1450, -0.1450,  ...,  0.0126,  0.1001,  0.1527]],

         [[-1.4907, -1.4733, -1.4733,  ..., -1.0550, -1.1421, -1.2467],
          [-1.5256, -1.

Data:  tensor([[[[-1.5870, -1.5699, -1.4158,  ..., -1.4500, -1.4672, -1.3302],
          [-1.6384, -1.6042, -1.5014,  ..., -1.3987, -1.3987, -1.2788],
          [-1.6213, -1.6042, -1.5014,  ..., -1.3815, -1.3302, -1.1932],
          ...,
          [ 0.7762,  0.0056, -0.5596,  ...,  1.2728,  1.3584,  1.4440],
          [ 0.6734,  0.3994,  0.0741,  ...,  0.9988,  1.4440,  1.4612],
          [ 0.6049,  0.4508,  0.8618,  ...,  0.9646,  1.0502,  1.5125]],

         [[-1.4405, -1.5280, -1.5805,  ..., -1.2829, -1.3179, -1.2829],
          [-1.4580, -1.4930, -1.5280,  ..., -1.2829, -1.3004, -1.3004],
          [-1.4755, -1.4755, -1.5105,  ..., -1.2654, -1.3004, -1.3704],
          ...,
          [ 1.1155,  0.2752, -0.4251,  ...,  1.6232,  1.6933,  1.7458],
          [ 0.9230,  0.6429,  0.2927,  ...,  1.2206,  1.7458,  1.8333],
          [ 0.9055,  0.6779,  1.2381,  ...,  1.1681,  1.2556,  1.8333]],

         [[-1.3513, -1.4210, -1.4384,  ..., -1.1770, -1.2119, -1.1944],
          [-1.3687, -1.

Data:  tensor([[[[-2.0837, -2.0837, -2.0837,  ..., -1.5699, -1.5699, -1.5528],
          [-2.0837, -2.0837, -2.0837,  ..., -1.5528, -1.5528, -1.5357],
          [-2.0837, -2.0837, -2.0837,  ..., -1.5699, -1.5528, -1.5357],
          ...,
          [-0.3198, -0.2684, -0.2171,  ..., -2.0837, -2.0837, -2.0837],
          [-0.3369, -0.2856, -0.2171,  ..., -2.0837, -2.0837, -2.0837],
          [-0.3541, -0.3027, -0.2342,  ..., -2.0837, -2.0837, -2.0837]],

         [[-1.9832, -1.9657, -1.9657,  ..., -2.0007, -2.0182, -2.0007],
          [-1.9657, -1.9657, -1.9657,  ..., -2.0182, -2.0182, -2.0007],
          [-1.9657, -1.9657, -1.9657,  ..., -2.0182, -2.0182, -2.0007],
          ...,
          [-0.9678, -0.9153, -0.8803,  ..., -1.9657, -1.9657, -1.9657],
          [-0.9678, -0.9328, -0.8803,  ..., -1.9657, -1.9657, -1.9657],
          [-0.9853, -0.9503, -0.8978,  ..., -1.9657, -1.9657, -1.9657]],

         [[-1.7870, -1.7870, -1.7870,  ..., -1.7522, -1.7696, -1.7696],
          [-1.7870, -1.

Data:  tensor([[[[-1.1932, -1.0390, -0.8507,  ..., -0.0801,  0.1083,  0.0227],
          [-1.3130, -1.0904, -0.8849,  ..., -0.0116,  0.2111,  0.0398],
          [-1.3815, -1.3473, -1.2617,  ...,  0.0569,  0.2453,  0.1768],
          ...,
          [ 1.5125,  1.3927,  1.2043,  ...,  2.1119,  2.0263,  1.9064],
          [ 1.0331,  1.2728,  1.4440,  ...,  1.9920,  1.9749,  1.9407],
          [ 0.8789,  1.0673,  1.0844,  ...,  1.9578,  1.9235,  1.9235]],

         [[-0.7752, -0.6001, -0.4426,  ...,  0.3627,  0.5903,  0.5203],
          [-0.8978, -0.6176, -0.3901,  ...,  0.4328,  0.6954,  0.5553],
          [-1.0203, -0.9328, -0.7927,  ...,  0.4503,  0.6254,  0.5728],
          ...,
          [ 0.7829,  0.7129,  0.5203,  ...,  1.8683,  1.8158,  1.6933],
          [ 0.2052,  0.5028,  0.8004,  ...,  1.7458,  1.7983,  1.6232],
          [ 0.0476,  0.2227,  0.2577,  ...,  1.6758,  1.7108,  1.6232]],

         [[-1.3164, -1.2990, -1.1073,  ..., -0.5147, -0.6193, -0.8110],
          [-1.3861, -1.

Data:  tensor([[[[ 0.2624,  0.2796,  0.3652,  ..., -1.5870, -1.6213, -1.6042],
          [ 0.3823,  0.3138,  0.5022,  ..., -1.5870, -1.5870, -1.5870],
          [ 0.7933,  0.7591,  0.6906,  ..., -1.6042, -1.6042, -1.5870],
          ...,
          [-0.2171,  0.0398,  0.1254,  ..., -1.0562, -1.1247, -1.4672],
          [-0.0116, -0.0116,  0.1939,  ..., -0.9534, -1.1418, -1.3644],
          [ 0.2111,  0.0741,  0.1426,  ..., -1.0904, -1.3815, -1.3473]],

         [[-0.7402, -0.6702, -0.5651,  ..., -1.0203, -1.0378, -1.0203],
          [-0.6001, -0.6001, -0.1975,  ..., -1.0028, -1.0028, -1.0028],
          [ 0.2927,  0.2227,  0.2227,  ..., -1.0028, -1.0203, -1.0203],
          ...,
          [-0.6702, -0.4951, -0.4076,  ..., -1.0728, -1.1604, -1.0728],
          [-0.4601, -0.5476, -0.3375,  ..., -1.0553, -1.1954, -1.1429],
          [-0.2150, -0.4426, -0.3550,  ..., -1.1429, -1.3704, -1.1954]],

         [[-0.7413, -0.6890, -0.5495,  ..., -0.5147, -0.5670, -0.5844],
          [-0.6018, -0.

Data:  tensor([[[[ 0.7591,  0.7933,  0.8618,  ...,  1.5639,  1.5982,  1.5297],
          [ 0.8104,  0.7762,  0.8104,  ...,  1.4612,  1.5639,  1.5468],
          [ 0.7077,  0.6906,  0.7419,  ...,  1.5297,  1.4954,  1.4612],
          ...,
          [-0.7822, -0.6965, -0.6281,  ..., -0.6452, -0.7308, -0.5424],
          [-0.6452, -0.8849, -0.8507,  ..., -1.0733, -0.5082, -0.7822],
          [-0.5596, -0.9534, -0.9534,  ..., -0.8164, -0.7308,  0.1254]],

         [[ 0.6954,  0.7304,  0.8004,  ...,  1.7633,  1.7983,  1.7458],
          [ 0.7479,  0.7129,  0.7479,  ...,  1.6758,  1.7458,  1.7458],
          [ 0.6779,  0.6779,  0.7304,  ...,  1.7108,  1.6758,  1.6408],
          ...,
          [-0.1800, -0.0574, -0.0049,  ..., -0.3375, -0.5301, -0.3550],
          [ 0.0826,  0.0651, -0.0224,  ..., -0.8627, -0.3200, -0.5301],
          [ 0.1527, -0.0399, -0.1450,  ..., -0.5476, -0.4601,  0.4678]],

         [[-0.8458, -0.8284, -0.7761,  ..., -0.4275, -0.3404, -0.4101],
          [-0.8633, -0.

Data:  tensor([[[[ 0.5878,  0.3823,  0.1426,  ...,  1.2214,  1.2214,  1.1015],
          [ 0.5022,  0.4508,  0.3481,  ...,  1.3755,  1.2385,  1.2728],
          [ 0.6563,  0.5193,  0.4851,  ...,  1.3070,  1.1529,  1.2043],
          ...,
          [-0.1143,  0.1939,  0.0227,  ..., -0.0458, -0.0287,  0.0398],
          [ 0.1426,  0.1597, -0.3883,  ...,  0.0741,  0.0912,  0.1768],
          [ 0.3994,  0.0398, -0.7650,  ...,  0.1254,  0.1083,  0.0569]],

         [[ 0.5553,  0.3627,  0.1527,  ...,  1.0630,  1.0805,  0.9580],
          [ 0.4678,  0.4328,  0.3627,  ...,  1.2381,  1.0980,  1.1331],
          [ 0.6779,  0.5378,  0.4678,  ...,  1.1681,  1.0105,  1.0805],
          ...,
          [-0.3200, -0.0049, -0.1625,  ...,  0.1352,  0.1527,  0.2227],
          [-0.0049, -0.0049, -0.5476,  ...,  0.2577,  0.2752,  0.3627],
          [ 0.2927, -0.0924, -0.8978,  ...,  0.3102,  0.2927,  0.2402]],

         [[ 0.7576,  0.5311,  0.2173,  ...,  1.1411,  1.1585,  1.0365],
          [ 0.6531,  0.

Data:  tensor([[[[ 0.3481,  0.6221,  1.3584,  ...,  1.0331,  1.0502,  1.2385],
          [ 0.4508,  0.7933,  1.4440,  ...,  1.1358,  1.2385,  1.4783],
          [ 0.6563,  0.8447,  1.4098,  ...,  1.1872,  1.4612,  1.6667],
          ...,
          [-0.6281, -0.6109, -0.5938,  ...,  0.2453,  0.2967, -0.5938],
          [-0.7822, -0.7137, -0.6623,  ...,  0.3481,  0.3481, -0.4911],
          [-1.2274, -0.9877, -0.8507,  ...,  0.3994,  0.2796, -0.5767]],

         [[-0.1975, -0.0049,  0.4153,  ...,  0.1001,  0.1176,  0.2402],
          [-0.1975,  0.0301,  0.3978,  ...,  0.1702,  0.2577,  0.4503],
          [-0.1099, -0.0224,  0.3277,  ...,  0.1877,  0.3978,  0.5203],
          ...,
          [-0.0749, -0.0224, -0.0049,  ...,  0.7129,  0.8179, -0.0224],
          [-0.2325, -0.1099, -0.0574,  ...,  0.8354,  0.8529,  0.0301],
          [-0.7927, -0.4776, -0.3025,  ...,  0.8880,  0.8354, -0.0399]],

         [[-1.8044, -1.6824, -1.4907,  ..., -1.3339, -1.3687, -1.3861],
          [-1.7870, -1.

Output:  tensor([[-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        ...,
        [-2.4216e+02, -2.2838e+02, -2.3539e+02,  ..., -1.9437e+02,
         -1.9346e+02, -2.3819e+02],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-1.2146e+01, -6.3443e+00, -1.1655e+01,  ..., -1.8509e+01,
         -4.9737e+00, -9.6163e+00]], device='cuda:0')
target:  tensor([115,  26,  67,  77,  63,  93,  56,   9], device='cuda:0')
tensor([[ 4],
        [ 4],
        [ 4],
        [ 4],
        [11],
        [56],
        [ 4],
        [56]], device='cuda:0')
Data:  tensor([[[[ 0.8789,  0.8961,  0.9646,  ...,  0.7248,  0.6734,  0.6049],
          [ 0.9132,  0.9303,  0.9132,  ...,  0.7419,  0.63

       device='cuda:0')
Output:  tensor([[-3.7659e+03, -2.6140e+03, -3.9907e+03,  ..., -3.0529e+03,
         -2.5148e+03, -3.9070e+03],
        [-2.3150e+01, -2.4260e+01, -2.3704e+01,  ...,  3.7073e-02,
         -1.9689e+01, -2.8550e+01],
        [-3.3682e+01, -3.2216e+01, -3.1766e+01,  ..., -3.7947e+01,
         -2.9657e+01, -3.5885e+01],
        ...,
        [-6.2257e+01, -5.7204e+01, -6.3930e+01,  ..., -4.5129e+01,
         -2.2639e+01, -7.0466e+01],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00]], device='cuda:0')
target:  tensor([109,  46,  78,  82,  11,   5,  61, 130], device='cuda:0')
tensor([[11],
        [97],
        [12],
        [ 4],
        [ 4],
        [12],
        [ 4],
        [ 4]], device='cuda:0')
Data:  tensor([[[[ 0.2967,  0.3138,  0.2967,  ...,  0.3652,  0.3823,  0.4166],
          [ 0.1939,  0.2282,  0.21

       device='cuda:0')
Output:  tensor([[-2.8053e+01, -2.2658e+01, -3.0095e+01,  ..., -3.8609e+01,
         -2.4948e+01, -2.9054e+01],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-7.2726e+02, -5.7617e+02, -7.4983e+02,  ..., -5.0895e+02,
         -3.3109e+02, -7.2676e+02],
        ...,
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-4.6625e+02, -2.6911e+02, -4.5784e+02,  ..., -4.0768e+02,
         -3.2209e+02, -4.6348e+02],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00]], device='cuda:0')
target:  tensor([38, 46,  2, 56,  5,  1,  1, 78], device='cuda:0')
tensor([[66],
        [ 4],
        [11],
        [ 4],
        [ 4],
        [ 4],
        [ 4],
        [ 4]], device='cuda:0')
Data:  tensor([[[[ 0.5878,  0.4679,  0.5878,  ...,  0.8104,  0.7248,  0.6049],
          [ 0.4679,  0.5364,  0.7077,  ...

Data:  tensor([[[[ 0.3652,  0.6906,  0.3138,  ...,  1.3070,  1.3070,  1.7180],
          [ 0.7077,  0.4508,  0.3309,  ...,  0.9988,  1.2385,  1.6495],
          [ 0.3823,  0.0398, -0.1657,  ...,  0.8447,  1.0673,  1.5982],
          ...,
          [ 0.4508,  0.1083,  0.4851,  ...,  0.4166, -0.1828, -0.6794],
          [ 1.3755,  1.0673,  0.5878,  ...,  0.0912, -0.5767, -0.8335],
          [ 1.3070,  1.0844,  0.6221,  ..., -0.1657, -0.2513,  0.0227]],

         [[ 0.7654,  0.9930,  0.6254,  ...,  1.4832,  1.4832,  1.7983],
          [ 0.9755,  0.6429,  0.6078,  ...,  1.3081,  1.4482,  1.7983],
          [ 0.5028,  0.1527,  0.0651,  ...,  1.1331,  1.3081,  1.7808],
          ...,
          [ 0.8529,  0.4853,  0.6429,  ...,  0.6954,  0.4853, -0.0049],
          [ 1.5182,  1.1506,  0.6429,  ...,  0.7304,  0.1352, -0.2500],
          [ 1.5357,  1.3256,  0.7479,  ...,  0.3452,  0.3627,  0.5028]],

         [[ 0.4265,  0.8797,  0.3742,  ...,  1.5594,  1.4374,  1.8905],
          [ 0.7576,  0.

Data:  tensor([[[[ 1.3070,  1.3413,  1.3927,  ...,  0.9817,  0.9132,  1.2728],
          [ 1.3070,  1.3070,  1.3927,  ...,  1.0844,  1.0331,  1.6838],
          [ 1.2728,  1.2557,  1.3755,  ...,  1.1529,  1.1529,  1.9235],
          ...,
          [-0.4397, -0.5253, -0.2171,  ..., -0.8335, -0.7822, -0.5767],
          [-1.3815, -1.2788, -0.5767,  ..., -1.0219, -0.8507, -0.6623],
          [-1.2617, -1.6727, -0.4568,  ..., -1.0904, -0.7993, -0.5938]],

         [[ 1.7808,  1.8158,  1.7983,  ...,  1.1506,  1.1856,  1.7108],
          [ 1.7808,  1.7983,  1.7983,  ...,  1.2731,  1.3782,  2.1310],
          [ 1.7808,  1.7983,  1.7983,  ...,  1.3957,  1.6408,  2.2360],
          ...,
          [-0.1450, -0.2150,  0.0826,  ..., -0.1975, -0.0574,  0.2402],
          [-0.9328, -0.7402, -0.1975,  ..., -0.0924,  0.1352,  0.3277],
          [-0.8803, -1.3354, -0.1625,  ..., -0.2500,  0.1702,  0.4153]],

         [[ 2.1171,  2.1520,  2.1171,  ...,  1.4722,  1.5245,  1.9254],
          [ 2.1171,  2.

         -1.1508e+00, -1.2135e+00]], device='cuda:0')
target:  tensor([  2,   9,  48, 110,  44,   3,  72,  67], device='cuda:0')
tensor([[ 4],
        [17],
        [ 4],
        [56],
        [ 4],
        [ 4],
        [ 4],
        [ 4]], device='cuda:0')
Data:  tensor([[[[-1.8439, -1.9124, -1.9124,  ...,  0.4508,  0.3652,  0.3309],
          [-1.8268, -1.8782, -1.8782,  ...,  1.3242,  1.2043,  0.9988],
          [-1.7754, -1.8610, -1.8782,  ...,  1.6495,  1.4098,  1.2214],
          ...,
          [-1.8953, -1.7754, -1.6727,  ..., -1.8268, -1.8610, -1.7925],
          [-1.7925, -1.7754, -1.6898,  ..., -1.8439, -1.8097, -1.7754],
          [-1.6898, -1.7583, -1.7240,  ..., -1.8097, -1.8097, -1.8268]],

         [[-1.7556, -1.8256, -1.8256,  ..., -1.0903, -1.1954, -1.2479],
          [-1.7381, -1.7906, -1.7906,  ..., -0.6001, -0.7052, -0.8102],
          [-1.6856, -1.7731, -1.7906,  ..., -0.5126, -0.7052, -0.7752],
          ...,
          [-1.9307, -1.8431, -1.7556,  ..., -1.9132, -

        [47]], device='cuda:0')
Data:  tensor([[[[-1.3815, -1.3130, -1.5185,  ..., -0.4911, -0.8507, -1.2788],
          [-0.9020, -1.4158, -1.5528,  ..., -0.8507, -1.3644, -1.1589],
          [-1.1589, -1.2445, -1.0904,  ..., -1.0562, -1.2788, -1.1589],
          ...,
          [-1.5870, -1.7240, -1.7412,  ..., -1.6384, -1.5528, -1.5014],
          [-1.6042, -1.6727, -1.9295,  ..., -1.5185, -1.4672, -1.3473],
          [-1.7240, -1.6384, -1.7240,  ..., -1.4329, -1.3644, -1.6898]],

         [[-0.6001, -0.5126, -0.6176,  ...,  0.2052, -0.0224, -0.3901],
          [-0.0924, -0.6176, -0.6877,  ..., -0.0049, -0.5301, -0.2150],
          [-0.3550, -0.6001, -0.5126,  ..., -0.0924, -0.4251, -0.1800],
          ...,
          [-0.8452, -1.1954, -1.4230,  ..., -0.9678, -0.8102, -0.6352],
          [-0.7927, -0.9678, -1.4055,  ..., -0.7402, -0.4951, -0.4251],
          [-1.0378, -0.8803, -1.0553,  ..., -0.5126, -0.4251, -1.0028]],

         [[-1.1073, -1.0376, -1.2293,  ..., -0.1835, -0.7064, -

         -1.1508e+00, -1.2135e+00]], device='cuda:0')
target:  tensor([ 74,  19,  93,  28,  51, 111, 105,  40], device='cuda:0')
tensor([[ 4],
        [ 4],
        [ 4],
        [ 4],
        [ 4],
        [ 4],
        [97],
        [ 4]], device='cuda:0')
Data:  tensor([[[[-0.8164, -1.1932, -0.5082,  ..., -0.9877, -0.7822, -0.5082],
          [-0.8849, -1.1589,  0.1083,  ..., -0.6623,  0.0569, -0.2171],
          [-0.0629, -0.5082, -0.1486,  ..., -0.5938, -0.1486, -0.5596],
          ...,
          [ 0.5707, -0.2342, -0.2684,  ...,  0.2111,  0.0056,  0.4337],
          [ 0.5878,  0.1768, -0.0458,  ...,  0.1939,  0.3481,  0.3994],
          [ 0.4679,  0.1939, -0.1314,  ..., -0.2856,  0.1597,  0.3481]],

         [[-0.2325, -0.5826,  0.1527,  ..., -0.3550, -0.1450,  0.1352],
          [-0.3200, -0.5826,  0.7654,  ..., -0.0574,  0.6954,  0.4503],
          [ 0.4328,  0.0126,  0.4153,  ...,  0.0126,  0.4678,  0.0651],
          ...,
          [ 1.2731,  0.2927,  0.0651,  ...,  0.8004,  

Output:  tensor([[-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-7.2940e+00, -6.8996e+00, -7.0908e+00,  ...,  1.1339e+00,
         -6.0163e+00, -9.4075e+00],
        ...,
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-1.7860e+03, -9.6729e+02, -1.8290e+03,  ..., -1.4794e+03,
         -1.2011e+03, -1.9402e+03],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00]], device='cuda:0')
target:  tensor([ 87,  84,  80,  81,  47, 103, 111, 112], device='cuda:0')
tensor([[ 4],
        [ 4],
        [47],
        [ 4],
        [ 4],
        [ 4],
        [36],
        [ 4]], device='cuda:0')
Data:  tensor([[[[-0.9877, -1.0048, -1.0390,  ...,  1.8722,  1.9235,  1.8893],
          [-1.0048, -1.0219, -1.0390,  ...,  1.8550,  1.88

       device='cuda:0')
Output:  tensor([[-8.3458e+01, -8.2403e+01, -8.0205e+01,  ..., -8.1607e+01,
         -7.0915e+01, -8.8351e+01],
        [-3.9459e+04, -2.3761e+04, -4.0054e+04,  ..., -2.9740e+04,
         -3.0771e+04, -4.1294e+04],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        ...,
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00]], device='cuda:0')
target:  tensor([  3,  86,   4,  16, 129,  96,  45,  41], device='cuda:0')
tensor([[45],
        [36],
        [ 4],
        [ 1],
        [ 4],
        [ 4],
        [ 4],
        [ 4]], device='cuda:0')
Data:  tensor([[[[ 1.9749,  1.8208,  1.7352,  ..., -0.7137, -0.6965, -0.6623],
          [ 1.9920,  1.8379,  1.73

       device='cuda:0')
Output:  tensor([[-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-2.9818e+00, -2.8670e+00, -2.8265e+00,  ..., -3.5351e+00,
         -4.2750e+00, -3.8173e+00],
        [-1.0628e+01, -8.0053e+00, -1.0224e+01,  ..., -1.0558e+01,
         -7.4952e+00, -1.1195e+01],
        ...,
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-9.2642e+03, -5.5258e+03, -9.6424e+03,  ..., -6.6056e+03,
         -6.8373e+03, -9.9631e+03]], device='cuda:0')
target:  tensor([114,  21,  94,  28, 100,  47,  79,  52], device='cuda:0')
tensor([[  4],
        [ 31],
        [ 66],
        [  4],
        [117],
        [  4],
        [  4],
        [ 36]], device='cuda:0')
Data:  tensor([[[[ 0.1768,  0.3481,  0.2967,  ...,  2.1804,  2.1975,  2.2147],
          [ 0.1939,  0.211

       device='cuda:0')
Output:  tensor([[-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-3.1571e+04, -1.9730e+04, -3.3075e+04,  ..., -2.6054e+04,
         -2.6455e+04, -3.3464e+04],
        ...,
        [-6.6468e+04, -3.7280e+04, -6.8842e+04,  ..., -5.2155e+04,
         -5.5263e+04, -6.9931e+04],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00]], device='cuda:0')
target:  tensor([ 29,  56, 114,  25,  63, 100,  52,  82], device='cuda:0')
tensor([[ 4],
        [ 4],
        [56],
        [ 4],
        [ 4],
        [56],
        [ 4],
        [ 4]], device='cuda:0')
Data:  tensor([[[[-1.9638, -1.9467, -1.8439,  ..., -0.3027, -0.1999, -0.2513],
          [-1.9980, -1.9809, -1.80

       device='cuda:0')
Output:  tensor([[-17.7942,   3.4168, -20.3005,  ..., -16.0664, -17.4659, -16.8615],
        [ -0.4912,  -0.3603,  -0.1338,  ...,  -1.0861,  -1.1508,  -1.2135],
        [ -0.4912,  -0.3603,  -0.1338,  ...,  -1.0861,  -1.1508,  -1.2135],
        ...,
        [ -0.4912,  -0.3603,  -0.1338,  ...,  -1.0861,  -1.1508,  -1.2135],
        [ -0.4912,  -0.3603,  -0.1338,  ...,  -1.0861,  -1.1508,  -1.2135],
        [ -0.4912,  -0.3603,  -0.1338,  ...,  -1.0861,  -1.1508,  -1.2135]],
       device='cuda:0')
target:  tensor([ 70,  68,  62,  75,  13,  85,  51, 108], device='cuda:0')
tensor([[ 47],
        [  4],
        [  4],
        [117],
        [  4],
        [  4],
        [  4],
        [  4]], device='cuda:0')
Data:  tensor([[[[ 0.7419,  0.7419,  0.6906,  ...,  0.7248,  0.7419,  0.7248],
          [ 1.2043,  1.1700,  1.1015,  ...,  0.7419,  0.7591,  0.7248],
          [ 1.5297,  1.4954,  1.4954,  ...,  0.7419,  0.7762,  0.7591],
          ...,
          [-0.0116, -0

Data:  tensor([[[[-1.2617, -1.2445, -1.2445,  ..., -1.1247, -1.1247, -1.1075],
          [-1.2617, -1.2445, -1.2445,  ..., -1.1247, -1.1247, -1.1075],
          [-1.2617, -1.2445, -1.2445,  ..., -1.1247, -1.1247, -1.1075],
          ...,
          [ 0.6049,  0.6221,  0.5536,  ..., -0.5938, -0.6281, -0.9534],
          [ 0.5364,  0.5364,  0.4166,  ..., -0.6109, -0.7137, -0.8335],
          [ 0.6049,  0.5193,  0.4337,  ..., -0.5424, -0.8164, -0.6281]],

         [[-0.6001, -0.5826, -0.5826,  ..., -0.4601, -0.4601, -0.4426],
          [-0.5651, -0.5476, -0.5476,  ..., -0.4601, -0.4601, -0.4426],
          [-0.5651, -0.5476, -0.5476,  ..., -0.4601, -0.4601, -0.4426],
          ...,
          [ 0.3978,  0.4678,  0.4678,  ...,  0.6604,  0.7304,  0.4853],
          [ 0.4853,  0.5728,  0.4853,  ...,  0.4503,  0.4328,  0.4153],
          [ 0.5903,  0.5903,  0.5553,  ...,  0.5028,  0.3452,  0.6078]],

         [[-0.6367, -0.6193, -0.6193,  ..., -0.4624, -0.4624, -0.4450],
          [-0.6193, -0.

       device='cuda:0')
Output:  tensor([[-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        ...,
        [-6.2608e+01, -4.2057e+01, -6.6829e+01,  ..., -5.7886e+01,
         -6.5316e+01, -6.2665e+01],
        [-2.6189e+03, -1.9722e+03, -2.7572e+03,  ..., -2.1220e+03,
         -1.4521e+03, -2.7066e+03],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00]], device='cuda:0')
target:  tensor([ 42,  76,  54,  55,  96,   4, 115,  89], device='cuda:0')
tensor([[  4],
        [  4],
        [  4],
        [117],
        [  4],
        [ 91],
        [ 11],
        [  4]], device='cuda:0')
Data:  tensor([[[[-1.7412, -1.7583, -1.7754,  ..., -1.8097, -1.7925, -1.8268],
          [-1.7583, -1.775

       device='cuda:0')
Output:  tensor([[-5.8164e+02, -4.2070e+02, -6.1974e+02,  ..., -4.4948e+02,
         -4.5058e+02, -6.1156e+02],
        [-1.7500e+01, -1.7479e+01, -1.8523e+01,  ..., -1.7811e+01,
         -2.2487e+01, -1.8995e+01],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        ...,
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00]], device='cuda:0')
target:  tensor([112,  33, 105, 129,  74,  12,  28,  86], device='cuda:0')
tensor([[ 41],
        [ 31],
        [  4],
        [117],
        [ 56],
        [  4],
        [  4],
        [  4]], device='cuda:0')
Data:  tensor([[[[-2.0152, -1.9638, -1.9638,  ..., -1.8610, -1.8610, -1.7754],
          [-1.9980, -1.946

       device='cuda:0')
target:  tensor([39, 20,  4,  0, 22, 16, 35, 60], device='cuda:0')
tensor([[ 4],
        [ 4],
        [ 4],
        [ 4],
        [17],
        [ 4],
        [ 4],
        [31]], device='cuda:0')
Data:  tensor([[[[-2.0665, -2.0665, -2.0494,  ..., -1.8610, -1.8610, -1.8439],
          [-2.0494, -2.0494, -2.0494,  ..., -1.8610, -1.8439, -1.8439],
          [-2.0323, -2.0494, -2.0494,  ..., -1.8439, -1.8439, -1.8268],
          ...,
          [ 0.5193,  0.4679,  0.4508,  ..., -0.8164, -0.8164, -0.8335],
          [ 0.5193,  0.4508,  0.4508,  ..., -0.7137, -0.7479, -0.7993],
          [ 0.5536,  0.4851,  0.4508,  ..., -0.1828, -0.2171, -0.3027]],

         [[-1.9482, -1.9482, -1.9307,  ..., -1.7906, -1.7906, -1.7731],
          [-1.9307, -1.9307, -1.9307,  ..., -1.7906, -1.7731, -1.7731],
          [-1.9132, -1.9307, -1.9307,  ..., -1.7731, -1.7731, -1.7556],
          ...,
          [ 0.2227,  0.1352,  0.1001,  ..., -1.1779, -1.1954, -1.1779],
          [ 0.2402, 

Data:  tensor([[[[ 0.4166,  0.4166,  0.3138,  ..., -0.6452, -0.6965, -0.8507],
          [ 0.8618,  0.8276,  0.7591,  ..., -0.7822, -0.8335, -0.8849],
          [ 1.4783,  1.4612,  1.4440,  ..., -0.8335, -0.8678, -0.9534],
          ...,
          [ 1.0844,  1.0844,  1.0844,  ...,  0.5364,  0.5707,  0.6734],
          [ 1.0331,  1.0673,  1.0673,  ...,  0.7591,  0.7762,  0.8276],
          [ 1.0502,  1.0331,  1.0159,  ...,  0.7762,  0.7762,  0.7591]],

         [[ 0.4678,  0.4328,  0.2927,  ..., -1.3004, -1.3704, -1.4930],
          [ 0.4853,  0.4678,  0.3978,  ..., -1.3880, -1.4405, -1.5105],
          [ 0.6779,  0.6604,  0.6604,  ..., -1.4580, -1.5105, -1.5805],
          ...,
          [ 0.3627,  0.3627,  0.3627,  ..., -0.4601, -0.5301, -0.4951],
          [ 0.3277,  0.3277,  0.3452,  ..., -0.4076, -0.4251, -0.4426],
          [ 0.3452,  0.3102,  0.2927,  ..., -0.4776, -0.4776, -0.4951]],

         [[ 0.3045,  0.2871,  0.1651,  ..., -1.5604, -1.5256, -1.6302],
          [ 0.1651,  0.

Data:  tensor([[[[-1.0219, -0.9534, -1.2617,  ...,  0.9303,  1.3070,  1.0159],
          [-0.9877, -1.1932, -1.4843,  ...,  1.2385,  1.5297,  1.1529],
          [-0.5938, -1.0219, -1.5699,  ...,  0.9817,  1.2385,  0.9817],
          ...,
          [-0.9020, -0.7479, -0.6794,  ..., -0.9363, -0.8164, -0.8164],
          [-0.5596, -0.3883, -0.3541,  ..., -0.7479, -0.6623, -0.6623],
          [-0.3883, -0.2513, -0.2684,  ..., -0.5938, -0.5596, -0.5424]],

         [[-0.9503, -0.9153, -1.1429,  ...,  1.2381,  1.5707,  1.2381],
          [-0.8978, -1.1429, -1.3179,  ...,  1.5707,  1.7808,  1.3431],
          [-0.4951, -0.9678, -1.4055,  ...,  1.3081,  1.4832,  1.1506],
          ...,
          [-0.5651, -0.4426, -0.3901,  ..., -0.3550, -0.2675, -0.2850],
          [-0.1800,  0.0126,  0.0301,  ..., -0.2325, -0.1800, -0.1800],
          [ 0.0126,  0.1702,  0.1702,  ..., -0.1450, -0.1450, -0.1099]],

         [[-0.8633, -0.8458, -1.0724,  ...,  0.6531,  1.2805,  1.1062],
          [-0.8633, -1.

Data:  tensor([[[[ 0.8447,  0.8447,  0.8447,  ...,  0.4166,  0.3823,  0.3823],
          [ 0.8961,  0.8789,  0.8789,  ...,  0.4337,  0.4166,  0.3823],
          [ 0.9646,  0.9646,  0.9646,  ...,  0.4508,  0.4337,  0.4166],
          ...,
          [ 1.1187,  1.1358,  1.0673,  ..., -0.5082, -0.5596, -0.5253],
          [ 1.1700,  1.0844,  0.9132,  ..., -0.5082, -0.5938, -0.4568],
          [ 1.0331,  1.0331,  1.0331,  ..., -0.4739, -0.6109, -0.4911]],

         [[ 1.2906,  1.2906,  1.2906,  ...,  1.2031,  1.2206,  1.2381],
          [ 1.3256,  1.3256,  1.3256,  ...,  1.2031,  1.2206,  1.2206],
          [ 1.3606,  1.3606,  1.3606,  ...,  1.2031,  1.2206,  1.2206],
          ...,
          [ 1.1155,  1.1506,  1.0630,  ..., -1.3704, -1.4230, -1.4055],
          [ 1.2206,  1.1155,  0.8880,  ..., -1.3529, -1.3880, -1.3529],
          [ 1.0455,  1.0105,  0.9755,  ..., -1.3179, -1.4055, -1.3529]],

         [[ 1.7337,  1.7337,  1.7337,  ...,  1.8208,  1.8208,  1.8383],
          [ 1.7511,  1.

Data:  tensor([[[[-0.1999, -0.2342, -0.3369,  ..., -1.7069, -1.6213, -1.5528],
          [-0.3541, -0.4397, -0.6281,  ..., -1.7240, -1.6727, -1.6042],
          [-0.7308, -0.8507, -1.0048,  ..., -1.6898, -1.6042, -1.5357],
          ...,
          [ 1.8379,  1.8550,  1.8722,  ...,  0.8447,  0.9646,  0.9303],
          [ 1.9064,  1.9064,  1.9064,  ...,  1.4269,  1.3927,  1.3070],
          [ 1.9235,  1.8893,  1.8893,  ...,  1.5125,  1.4440,  1.3755]],

         [[ 0.5728,  0.5903,  0.5203,  ..., -1.5105, -1.3529, -1.1954],
          [ 0.2927,  0.2227,  0.0651,  ..., -1.5805, -1.4405, -1.2829],
          [-0.1800, -0.3200, -0.4776,  ..., -1.5280, -1.3704, -1.2129],
          ...,
          [ 1.8158,  1.8333,  1.8508,  ...,  0.9580,  1.0805,  1.0280],
          [ 1.8859,  1.8859,  1.8859,  ...,  1.5532,  1.5182,  1.4132],
          [ 1.9034,  1.8683,  1.8683,  ...,  1.6408,  1.5707,  1.4832]],

         [[ 0.1476,  0.1128, -0.0092,  ..., -1.4907, -1.3861, -1.2816],
          [ 0.0256, -0.

Data:  tensor([[[[ 1.9749,  1.9407,  1.9578,  ...,  1.9578,  1.9578,  1.9578],
          [ 1.9407,  1.9578,  1.9749,  ...,  1.9578,  1.9578,  1.9578],
          [ 1.9749,  1.9920,  1.9235,  ...,  1.9578,  1.9578,  1.9578],
          ...,
          [ 1.9407,  1.9749,  1.9578,  ...,  1.9407,  1.9578,  1.9749],
          [ 1.9407,  1.9749,  1.9407,  ...,  1.9407,  1.9578,  1.9578],
          [ 1.9407,  1.9749,  1.9749,  ...,  1.9749,  1.9749,  1.9407]],

         [[ 2.1485,  2.1134,  2.1310,  ...,  2.1310,  2.1310,  2.1310],
          [ 2.1134,  2.1310,  2.1485,  ...,  2.1310,  2.1310,  2.1310],
          [ 2.1485,  2.1660,  2.0959,  ...,  2.1310,  2.1310,  2.1310],
          ...,
          [ 2.1134,  2.1485,  2.1310,  ...,  2.1134,  2.1310,  2.1485],
          [ 2.1134,  2.1485,  2.1134,  ...,  2.1134,  2.1310,  2.1310],
          [ 2.1134,  2.1485,  2.1485,  ...,  2.1485,  2.1485,  2.1134]],

         [[ 2.3611,  2.3263,  2.3437,  ...,  2.3437,  2.3437,  2.3437],
          [ 2.3263,  2.

Data:  tensor([[[[ 0.1254,  0.1254,  0.2624,  ..., -0.0629,  0.0741,  0.1426],
          [-0.0458,  0.0741, -0.0972,  ..., -0.1143, -0.0116,  0.2453],
          [-0.4739,  0.0056,  0.3823,  ..., -0.0458,  0.0227,  0.1254],
          ...,
          [ 0.4166,  0.0569,  0.2796,  ..., -0.0116,  0.3481,  0.5364],
          [ 0.1939,  0.2453,  0.4337,  ...,  0.7077,  0.4851,  0.1597],
          [-0.1486,  0.1426,  0.1426,  ...,  1.0331,  0.4508,  0.0912]],

         [[-0.2850, -0.1800,  0.1001,  ..., -0.5126, -0.3025, -0.1450],
          [-0.3901, -0.2500, -0.2675,  ..., -0.5826, -0.4776, -0.0924],
          [-0.6527, -0.1975,  0.1702,  ..., -0.6001, -0.4601, -0.1450],
          ...,
          [-0.1099, -0.4776, -0.1099,  ..., -0.1975,  0.4678,  0.5028],
          [-0.2325, -0.1800, -0.0574,  ...,  0.6779,  0.3803, -0.0224],
          [-0.5301, -0.2850, -0.1975,  ...,  1.0630,  0.3102, -0.0924]],

         [[-0.4798, -0.2358,  0.1476,  ..., -0.8633, -0.7238, -0.7587],
          [-0.5670, -0.

Data:  tensor([[[[-1.6555, -1.7069, -1.7754,  ..., -0.8507, -1.1589, -1.2617],
          [-1.5699, -1.6213, -1.6898,  ..., -0.8507, -1.1932, -1.2788],
          [-1.5357, -1.5699, -1.6555,  ..., -0.9363, -1.2103, -1.3130],
          ...,
          [-1.0733, -1.1075, -1.0733,  ..., -1.6727, -1.6213, -1.6727],
          [-1.1589, -1.0733, -1.1760,  ..., -1.6042, -1.5185, -1.5014],
          [-1.0733, -1.2445, -1.2788,  ..., -1.5014, -1.4672, -1.4843]],

         [[-1.5105, -1.4755, -1.4930,  ..., -1.1779, -1.4580, -1.4580],
          [-1.3529, -1.3179, -1.3529,  ..., -1.1779, -1.4580, -1.4755],
          [-1.3004, -1.3004, -1.3529,  ..., -1.2829, -1.4580, -1.4930],
          ...,
          [-1.3179, -1.3529, -1.3354,  ..., -1.5630, -1.5280, -1.5455],
          [-1.3704, -1.3529, -1.4580,  ..., -1.5455, -1.4755, -1.4580],
          [-1.3179, -1.4580, -1.4580,  ..., -1.4405, -1.3880, -1.3880]],

         [[-1.0201, -1.1073, -1.1596,  ..., -1.2119, -1.5430, -1.3861],
          [-0.8633, -0.

Data:  tensor([[[[-2.1179, -2.1179, -2.1179,  ..., -2.0837, -2.0837, -2.0837],
          [-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1008, -2.0837],
          [-2.1179, -2.1179, -2.1179,  ..., -2.1008, -2.1008, -2.0837],
          ...,
          [-0.6794, -0.6281, -0.6281,  ...,  0.0227,  0.0569,  0.1939],
          [-0.7308, -0.6794, -0.6794,  ...,  0.0227,  0.0569,  0.0569],
          [-0.8164, -0.7479, -0.6452,  ...,  0.0398,  0.0912,  0.1254]],

         [[-2.0357, -2.0357, -2.0357,  ..., -2.0007, -2.0182, -2.0182],
          [-2.0357, -2.0357, -2.0357,  ..., -2.0182, -2.0182, -2.0357],
          [-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0182],
          ...,
          [-0.9503, -0.8627, -0.8277,  ..., -0.4251, -0.4076, -0.2500],
          [-0.9503, -0.8978, -0.8978,  ..., -0.3901, -0.3725, -0.3725],
          [-0.9678, -0.9503, -0.8978,  ..., -0.3901, -0.3550, -0.3550]],

         [[-1.8044, -1.8044, -1.8044,  ..., -1.8044, -1.8044, -1.8044],
          [-1.8044, -1.

       device='cuda:0')
Output:  tensor([[-4.5175e+02, -3.8342e+02, -4.2334e+02,  ..., -3.3986e+02,
         -2.7619e+02, -4.3015e+02],
        [-1.1141e+03, -6.3974e+02, -1.2074e+03,  ..., -8.3201e+02,
         -7.8294e+02, -1.2291e+03],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        ...,
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00],
        [-7.1847e+03, -4.3439e+03, -7.3547e+03,  ..., -5.9866e+03,
         -5.0664e+03, -7.2794e+03],
        [-4.9118e-01, -3.6031e-01, -1.3381e-01,  ..., -1.0861e+00,
         -1.1508e+00, -1.2135e+00]], device='cuda:0')
target:  tensor([117,  12,  39,  14,  45,  26, 112,  15], device='cuda:0')
tensor([[11],
        [36],
        [ 4],
        [ 4],
        [ 4],
        [ 4],
        [56],
        [ 4]], device='cuda:0')
Data:  tensor([[[[ 1.7009,  1.7180,  1.7009,  ...,  0.9474,  0.9303,  0.9132],
          [ 1.7009,  1.7180,  1.70

Data:  tensor([[[[-0.8849, -0.9020, -0.8849,  ...,  1.8722,  1.8722,  1.9920],
          [-0.8849, -0.9020, -0.8849,  ...,  1.9235,  1.9920,  2.0434],
          [-0.8678, -0.8849, -0.8678,  ...,  1.9578,  1.9749,  1.8893],
          ...,
          [-1.3815, -1.3644, -1.6213,  ..., -1.9124, -1.8782, -1.8782],
          [-1.2274, -1.4672, -1.8097,  ..., -1.8782, -1.8782, -1.8782],
          [-1.2103, -1.7069, -1.7754,  ..., -1.8439, -1.9124, -1.8782]],

         [[-0.5126, -0.5301, -0.5301,  ...,  2.0959,  2.0959,  2.2185],
          [-0.5126, -0.5301, -0.5301,  ...,  2.1485,  2.2185,  2.2710],
          [-0.4951, -0.5126, -0.5126,  ...,  2.1835,  2.2010,  2.1134],
          ...,
          [-0.6352, -0.6527, -0.9503,  ..., -1.5805, -1.5630, -1.6681],
          [-0.4426, -0.7402, -1.1779,  ..., -1.5980, -1.6155, -1.7206],
          [-0.4426, -0.9853, -1.1429,  ..., -1.5630, -1.6506, -1.7206]],

         [[-0.2881, -0.3404, -0.3753,  ...,  1.7685,  1.7685,  1.8905],
          [-0.2881, -0.

Data:  tensor([[[[-0.1486, -0.2342, -0.1828,  ..., -1.6384, -1.5014, -1.3987],
          [-0.1486, -0.2342, -0.1828,  ..., -1.6213, -1.5185, -1.4500],
          [-0.1486, -0.2342, -0.1657,  ..., -1.6213, -1.5528, -1.4843],
          ...,
          [-1.3644, -1.4158, -1.4329,  ..., -1.4329, -1.3473, -1.3815],
          [-1.1418, -1.2445, -1.3987,  ..., -1.1075, -1.0733, -1.3130],
          [-1.0390, -1.1247, -1.3815,  ..., -1.0390, -0.9363, -1.1760]],

         [[ 0.1352,  0.0126,  0.0126,  ..., -1.5980, -1.4405, -1.3354],
          [ 0.1352,  0.0126,  0.0126,  ..., -1.5805, -1.4580, -1.3880],
          [ 0.1352,  0.0126,  0.0301,  ..., -1.5805, -1.4930, -1.4230],
          ...,
          [-0.6176, -0.7052, -0.7227,  ..., -0.7752, -0.6877, -0.7052],
          [-0.3901, -0.5301, -0.7052,  ..., -0.4951, -0.4601, -0.6877],
          [-0.2850, -0.4426, -0.7052,  ..., -0.4601, -0.3725, -0.6001]],

         [[ 0.1302, -0.0441, -0.0790,  ..., -1.5256, -1.4210, -1.3513],
          [ 0.1302, -0.

Data:  tensor([[[[-1.6384, -1.5185, -1.6898,  ..., -0.1828, -0.0801, -0.4739],
          [-1.6898, -1.3473, -1.6384,  ..., -0.6109, -0.1999, -0.3712],
          [-1.3302, -1.2103, -1.6727,  ..., -0.7993, -0.6794, -0.6281],
          ...,
          [-1.7925, -1.2445, -1.5357,  ..., -0.1999, -0.4568, -0.7479],
          [-1.7412, -1.3815, -1.7412,  ..., -0.4397, -0.4226, -1.0048],
          [-1.7069, -1.3987, -1.6898,  ..., -0.5596, -0.1828, -0.7308]],

         [[-1.3529, -1.1954, -1.3880,  ..., -0.0049,  0.1702, -0.1625],
          [-1.4230, -1.1604, -1.4055,  ..., -0.2675,  0.0826, -0.0224],
          [-1.2304, -1.1078, -1.4405,  ..., -0.4426, -0.3375, -0.3025],
          ...,
          [-1.5280, -0.9503, -1.3880,  ...,  0.3277, -0.0224, -0.4076],
          [-1.4930, -1.1779, -1.5455,  ...,  0.0826,  0.0476, -0.6877],
          [-1.4755, -1.1253, -1.5455,  ..., -0.0049,  0.3452, -0.4076]],

         [[-1.4907, -1.3861, -1.5953,  ..., -1.0724, -0.8981, -1.2293],
          [-1.6127, -1.

Data:  tensor([[[[-0.3369, -0.2856, -0.1999,  ..., -1.9467, -2.0152, -1.9638],
          [-0.3369, -0.2856, -0.1999,  ..., -1.9809, -1.9980, -1.9638],
          [-0.3198, -0.2856, -0.1999,  ..., -1.9980, -1.9638, -1.9467],
          ...,
          [-0.7993, -0.8849, -0.8678,  ..., -0.5596, -0.5253, -0.5253],
          [-0.8335, -0.9020, -0.8507,  ..., -0.6281, -0.6109, -0.5767],
          [-0.8335, -0.8849, -0.8507,  ..., -0.5938, -0.5767, -0.5767]],

         [[-0.4601, -0.4076, -0.3200,  ..., -1.7731, -1.8606, -1.8256],
          [-0.4601, -0.4076, -0.3200,  ..., -1.8431, -1.8782, -1.8606],
          [-0.4426, -0.4076, -0.3200,  ..., -1.8957, -1.8606, -1.8606],
          ...,
          [-1.2654, -1.3704, -1.3704,  ..., -0.6352, -0.6001, -0.6001],
          [-1.3004, -1.3704, -1.3529,  ..., -0.7052, -0.6877, -0.6527],
          [-1.2479, -1.3354, -1.3354,  ..., -0.6702, -0.6527, -0.6527]],

         [[-0.4624, -0.4101, -0.3230,  ..., -1.5604, -1.6302, -1.5953],
          [-0.4624, -0.

Data:  tensor([[[[-2.0152, -1.9467, -1.9124,  ...,  1.5639,  1.5297,  1.4954],
          [-2.0152, -1.9638, -1.9295,  ...,  1.6324,  1.5639,  1.5125],
          [-1.9467, -1.9638, -1.9467,  ...,  1.6153,  1.5468,  1.5297],
          ...,
          [ 1.3070,  1.3070,  1.3070,  ...,  1.2385,  1.2214,  1.2043],
          [ 1.3070,  1.2899,  1.2899,  ...,  1.2385,  1.2214,  1.1872],
          [ 1.2728,  1.2557,  1.2557,  ...,  1.2385,  1.2214,  1.1872]],

         [[-1.5280, -1.5455, -1.5805,  ...,  0.9230,  0.8354,  0.7479],
          [-1.4580, -1.5105, -1.5455,  ...,  0.9580,  0.8529,  0.7829],
          [-1.4230, -1.5105, -1.5280,  ...,  0.9405,  0.8354,  0.8004],
          ...,
          [ 1.3957,  1.3957,  1.3957,  ...,  1.4307,  1.4132,  1.3957],
          [ 1.3957,  1.3782,  1.3782,  ...,  1.4307,  1.4132,  1.3782],
          [ 1.3606,  1.3431,  1.3431,  ...,  1.4307,  1.4132,  1.3782]],

         [[-0.6193, -0.7238, -0.8284,  ...,  1.0017,  0.8448,  0.6879],
          [-0.5495, -0.

Data:  tensor([[[[-1.8610, -1.8610, -1.8268,  ..., -1.9467, -1.8610, -1.8439],
          [-1.8439, -1.8268, -1.8439,  ..., -1.8782, -1.8097, -1.8097],
          [-1.7583, -1.7925, -1.8097,  ..., -1.8610, -1.8097, -1.8268],
          ...,
          [ 0.3652,  0.3652,  0.1768,  ...,  0.4679,  0.4508,  0.4851],
          [ 0.1426,  0.2111,  0.1254,  ...,  0.5364,  0.4508,  0.4337],
          [ 0.2111,  0.3994,  0.2796,  ...,  0.5707,  0.4851,  0.4679]],

         [[-1.7556, -1.7906, -1.7731,  ..., -1.8957, -1.8256, -1.7381],
          [-1.7731, -1.7731, -1.7906,  ..., -1.8957, -1.8081, -1.7031],
          [-1.7381, -1.7906, -1.7556,  ..., -1.8606, -1.7906, -1.7206],
          ...,
          [ 0.6429,  0.6254,  0.4853,  ...,  0.7479,  0.7479,  0.7479],
          [ 0.3627,  0.4328,  0.3978,  ...,  0.7654,  0.7129,  0.6779],
          [ 0.3627,  0.6078,  0.6078,  ...,  0.8354,  0.7479,  0.7129]],

         [[-1.6476, -1.6127, -1.6650,  ..., -1.7173, -1.6476, -1.5779],
          [-1.6650, -1.

Data:  tensor([[[[-1.7754, -1.7925, -1.8782,  ..., -0.3369, -0.5424, -0.5938],
          [-1.6727, -1.7754, -1.8439,  ...,  0.0569, -0.3027, -0.4568],
          [-1.7069, -1.7754, -1.8439,  ...,  0.5364,  0.1939,  0.0741],
          ...,
          [-0.8849, -1.2959, -1.5699,  ..., -1.5357, -0.6965, -1.0733],
          [-0.5767, -1.2274, -1.5357,  ..., -1.2617, -1.1247, -1.2274],
          [-0.3541, -0.9020, -1.7754,  ..., -1.0048, -1.1075, -1.1075]],

         [[-1.8782, -1.8957, -1.9832,  ..., -0.0399, -0.1099, -0.0399],
          [-1.7731, -1.8782, -1.9482,  ...,  0.2402,  0.0476,  0.0126],
          [-1.7906, -1.8606, -1.9307,  ...,  0.3277,  0.1702,  0.3277],
          ...,
          [-0.3200, -0.7402, -1.0378,  ..., -0.9503, -0.1450, -0.5826],
          [ 0.0301, -0.6527, -1.0028,  ..., -0.6352, -0.5126, -0.7052],
          [ 0.3277, -0.2675, -1.2304,  ..., -0.3025, -0.4426, -0.5126]],

         [[-1.5779, -1.5953, -1.6824,  ..., -0.4973, -0.6193, -0.6541],
          [-1.4907, -1.

Data:  tensor([[[[ 0.6906,  0.6906,  0.6734,  ...,  0.5536,  0.5364,  0.5022],
          [ 0.6906,  0.7077,  0.7077,  ...,  0.5364,  0.5364,  0.5364],
          [ 0.7248,  0.7591,  0.7419,  ...,  0.5364,  0.5022,  0.4851],
          ...,
          [-0.1486, -0.1314, -0.1657,  ..., -0.5082, -0.5253, -0.5253],
          [-0.1657, -0.1657, -0.1657,  ..., -0.5082, -0.4911, -0.4739],
          [-0.1657, -0.1657, -0.1486,  ..., -0.5082, -0.4739, -0.4568]],

         [[ 0.6078,  0.6078,  0.5903,  ...,  0.4503,  0.4503,  0.4153],
          [ 0.6078,  0.6254,  0.6254,  ...,  0.4328,  0.4503,  0.4503],
          [ 0.6429,  0.6779,  0.6604,  ...,  0.4328,  0.4153,  0.3978],
          ...,
          [-0.2675, -0.2675, -0.3025,  ..., -0.5651, -0.5826, -0.5826],
          [-0.2675, -0.3025, -0.3025,  ..., -0.5651, -0.5476, -0.5301],
          [-0.2675, -0.2850, -0.2675,  ..., -0.5476, -0.5301, -0.4951]],

         [[ 0.6879,  0.6705,  0.6531,  ...,  0.5136,  0.5136,  0.4788],
          [ 0.6879,  0.

Data:  tensor([[[[-0.2856, -0.4226, -0.5424,  ..., -0.5253, -0.4226, -0.2684],
          [-0.1486, -0.2684, -0.3712,  ..., -0.5767, -0.4226, -0.2513],
          [-0.0801, -0.1828, -0.2171,  ..., -0.5082, -0.3712, -0.1999],
          ...,
          [-0.5938, -0.4568, -0.3198,  ..., -0.6109, -0.6109, -0.5424],
          [-0.7822, -0.7308, -0.8507,  ..., -0.6109, -0.6281, -0.5596],
          [-0.7822, -0.6452, -1.0048,  ..., -0.5767, -0.5938, -0.5938]],

         [[-0.2850, -0.3375, -0.4776,  ..., -0.4951, -0.4251, -0.3025],
          [-0.1450, -0.1975, -0.3375,  ..., -0.5476, -0.4251, -0.2850],
          [-0.0924, -0.1099, -0.1975,  ..., -0.4951, -0.3901, -0.2150],
          ...,
          [-0.4251, -0.3725, -0.3200,  ...,  0.1702,  0.1702,  0.2227],
          [-0.5126, -0.5476, -0.7402,  ...,  0.1877,  0.1527,  0.2227],
          [-0.4776, -0.3725, -0.7577,  ...,  0.2052,  0.1702,  0.1352]],

         [[-0.8458, -0.9678, -1.1596,  ..., -0.9504, -0.9156, -0.7761],
          [-0.7238, -0.

Data:  tensor([[[[ 0.3994,  0.8276,  0.8104,  ...,  0.5536,  0.7419,  0.7077],
          [ 0.5878,  0.7933,  0.7933,  ...,  0.2453,  0.4508,  0.5878],
          [ 0.3994,  0.5707,  0.7419,  ...,  0.2111,  0.1939,  0.4337],
          ...,
          [-0.2513, -0.2856,  0.3481,  ...,  0.6221,  0.5364,  0.0741],
          [ 0.0056, -0.3883,  0.0056,  ..., -0.1828, -0.3541, -0.1828],
          [-0.1143, -0.4226, -0.0801,  ..., -0.4739, -0.2856, -0.4911]],

         [[ 0.8704,  0.9755,  0.8354,  ...,  0.9230,  0.9405,  0.7829],
          [ 1.0105,  1.0805,  0.9580,  ...,  0.6604,  0.7654,  0.6954],
          [ 0.8004,  0.8880,  0.8880,  ...,  0.3277,  0.4153,  0.5553],
          ...,
          [-0.0399, -0.1625,  0.3627,  ...,  0.0126,  0.0651, -0.3375],
          [ 0.2227, -0.1099,  0.1001,  ..., -0.4951, -0.5651, -0.4426],
          [-0.0049, -0.1099,  0.1527,  ..., -0.5651, -0.3901, -0.4076]],

         [[-0.6367, -0.5321, -0.5670,  ..., -0.5495, -0.4450, -0.5321],
          [-0.5147, -0.

Data:  tensor([[[[ 0.0227,  0.0569,  0.1083,  ..., -0.2513, -0.2513, -0.2171],
          [-0.0116,  0.0741,  0.1254,  ..., -0.2856, -0.2342, -0.2856],
          [-0.0287,  0.0398,  0.0569,  ..., -0.3027, -0.3027, -0.3369],
          ...,
          [ 0.4337,  0.3994,  0.1254,  ...,  0.5022,  0.4679,  0.5536],
          [ 1.1358,  0.8104,  0.2796,  ...,  0.5878,  0.3823,  0.7591],
          [ 1.0673,  1.0502,  0.8789,  ...,  0.7419,  0.4851,  0.5536]],

         [[ 0.0476,  0.0826,  0.1001,  ...,  0.0126,  0.0126,  0.0476],
          [ 0.0301,  0.0826,  0.1176,  ..., -0.0224,  0.0301, -0.0224],
          [ 0.0126,  0.0826,  0.1001,  ..., -0.0399, -0.0399, -0.0749],
          ...,
          [ 0.4328,  0.4328,  0.2227,  ...,  0.8179,  0.8179,  0.9405],
          [ 1.0980,  0.8004,  0.3627,  ...,  0.9055,  0.7304,  1.1331],
          [ 0.9755,  1.0280,  0.9230,  ...,  1.0455,  0.8354,  0.9580]],

         [[-1.2293, -1.1944, -1.1596,  ..., -1.1073, -1.1073, -1.0724],
          [-1.2467, -1.

Data:  tensor([[[[-0.4397, -0.3712, -0.2856,  ...,  1.7523,  1.8550,  1.9407],
          [-0.2684, -0.1828, -0.1314,  ...,  1.7523,  1.8550,  1.9407],
          [-0.0801,  0.0227,  0.0569,  ...,  1.7523,  1.8550,  1.9407],
          ...,
          [ 0.3652,  0.3481,  0.3481,  ...,  0.4679,  0.5022,  0.5022],
          [ 0.3309,  0.3138,  0.3309,  ...,  0.4679,  0.4851,  0.4851],
          [ 0.3138,  0.2967,  0.2967,  ...,  0.4679,  0.4679,  0.4679]],

         [[-1.0378, -1.0203, -0.9853,  ...,  1.2381,  1.2906,  1.3256],
          [-0.8803, -0.8627, -0.8277,  ...,  1.2381,  1.2906,  1.3256],
          [-0.7052, -0.6527, -0.6527,  ...,  1.2381,  1.2906,  1.3256],
          ...,
          [ 1.0455,  1.0280,  1.0280,  ...,  1.0980,  1.1331,  1.1331],
          [ 1.0105,  0.9930,  1.0105,  ...,  1.0980,  1.1155,  1.1155],
          [ 0.9930,  0.9755,  0.9755,  ...,  1.0980,  1.0980,  1.0980]],

         [[-0.9330, -0.9504, -0.8981,  ...,  0.5485,  0.3742,  0.1476],
          [-0.8110, -0.

Output:  tensor([[ -0.4912,  -0.3603,  -0.1338,  ...,  -1.0861,  -1.1508,  -1.2135],
        [-60.7477, -13.9784, -65.0344,  ..., -57.5465, -68.1473, -65.7572],
        [ -0.4912,  -0.3603,  -0.1338,  ...,  -1.0861,  -1.1508,  -1.2135],
        ...,
        [ -0.4912,  -0.3603,  -0.1338,  ...,  -1.0861,  -1.1508,  -1.2135],
        [-69.4640,  -5.4692, -71.6080,  ..., -63.0048, -70.2237, -59.9931],
        [-20.8908,  15.2818, -24.6035,  ..., -19.7211, -25.1101, -20.7410]],
       device='cuda:0')
target:  tensor([ 28,   0,  86,  43,  78, 108,   1, 109], device='cuda:0')
tensor([[  4],
        [ 31],
        [  4],
        [ 47],
        [ 36],
        [  4],
        [106],
        [ 47]], device='cuda:0')
Data:  tensor([[[[ 1.2728,  1.3070,  1.1529,  ...,  0.7419,  0.7248,  0.7248],
          [ 1.2385,  1.2557,  1.1187,  ...,  0.7762,  0.7591,  0.7762],
          [ 1.2043,  1.1872,  1.1015,  ...,  0.9132,  0.8447,  0.8276],
          ...,
          [-0.2171, -0.4739, -0.1999,  ...,  0

target:  tensor([ 32, 123,  35,  82,  72, 111,  63,  81], device='cuda:0')
tensor([[ 4],
        [ 4],
        [ 4],
        [ 4],
        [ 4],
        [36],
        [ 4],
        [ 4]], device='cuda:0')
Data:  tensor([[[[ 1.1700,  1.3413,  1.1187,  ...,  0.4851,  0.7762,  0.7762],
          [ 1.2557,  1.2385,  1.0331,  ...,  0.3823,  0.5707,  0.5878],
          [ 1.0502,  1.2385,  0.8789,  ..., -0.0458,  0.1768,  0.3994],
          ...,
          [-0.1143, -0.2171, -0.0458,  ..., -0.2856, -0.4397, -0.6109],
          [-0.0458,  0.0056,  0.0227,  ..., -0.4397, -0.3883, -0.3883],
          [-0.2513, -0.0972, -0.0287,  ..., -0.6965, -0.5424, -0.3541]],

         [[ 1.5882,  1.7458,  1.5007,  ...,  0.8354,  1.2206,  1.2906],
          [ 1.6583,  1.6408,  1.3957,  ...,  0.7129,  1.0980,  1.2731],
          [ 1.4482,  1.6057,  1.2381,  ...,  0.3102,  0.7304,  1.1155],
          ...,
          [ 0.3978,  0.2577,  0.3803,  ...,  0.3803,  0.0826, -0.2325],
          [ 0.3803,  0.3978,  0.3803

         -0.6527, -0.5687, -1.0861, -1.1508, -1.2135]], device='cuda:0')
target:  tensor([103,  43, 132,   6], device='cuda:0')
tensor([[4],
        [4],
        [4],
        [4]], device='cuda:0')
Test Loss: 1345.552979


Test Accuracy:  0% ( 7/836)
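
The printed predictions are easier to judge as a tally than by scrolling through tensors. The sketch below is only an illustration: it copies the predicted class indices from the four batches still visible in the output above and counts them. `printed_batches` and `tally_predictions` are hypothetical names, and in the notebook the lists would presumably come from something like `pred.view(-1).tolist()` inside the test loop.

```python
from collections import Counter

# Predicted class indices copied from the batches visible in the output above.
# Assumption: in the notebook these would be collected inside the test loop,
# e.g. with pred.view(-1).tolist().
printed_batches = [
    [11, 36, 4, 4, 4, 4, 56, 4],
    [4, 31, 4, 47, 36, 4, 106, 47],
    [4, 4, 4, 4, 4, 36, 4, 4],
    [4, 4, 4, 4],
]


def tally_predictions(batches):
    """Count how often each class index is predicted across all batches."""
    counts = Counter()
    for batch in batches:
        counts.update(batch)
    return counts


tally = tally_predictions(printed_batches)
top_class, top_count = tally.most_common(1)[0]
total = sum(tally.values())
print(f"class {top_class}: {top_count}/{total} predictions ({top_count / total:.0%})")
```

On these four batches alone, class 4 accounts for 19 of 28 predictions, which matches the "same class for nearly every image" behaviour described earlier.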


### Running architecture 1 again, but accessing global variables directly inside the train function

#### Result:
Again, the model predicts the same class for every image (class 4, which seems to be a favorite).
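
A useful reference point for the losses in this run is the cross-entropy of a model that assigns equal probability to every class, which is ln(C) for C classes. The sketch below assumes 133 classes; that number is an assumption inferred from the printed targets, which go up to 132, and is not stated in this section. `chance_level_cross_entropy` is a hypothetical helper name.

```python
import math


def chance_level_cross_entropy(num_classes: int) -> float:
    """Cross-entropy (in nats) of a uniform prediction over num_classes."""
    return math.log(num_classes)


NUM_CLASSES = 133  # assumption: targets in the printed output range from 0 to 132
baseline = chance_level_cross_entropy(NUM_CLASSES)
print(f"chance-level loss for {NUM_CLASSES} classes: {baseline:.4f}")
```

Under that assumption the baseline is about 4.89, so a training and validation loss that settles around 4.87 means the network is doing barely better than a uniform guess.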

In [6]:
# the following import is required for training to be robust to truncated images
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


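def train(n_epochs, loaders, model, optimizer, criterion, use_cuda, save_path):
    """Train the model and return it.

    This version deliberately ignores the loaders, model, optimizer and
    criterion arguments and uses the notebook's global variables instead.
    """
    global train_loader
    global valid_loader
    global model_scratch
    global optimizer_scratch
    global criterion_scratch
    # initialize tracker for minimum validation loss
    valid_loss_min = np.inf  # np.Inf was removed in NumPy 2.0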
    
    for epoch in range(1, n_epochs+1):

        # initialize variables to monitor training and validation loss
        train_loss = 0.0
        valid_loss = 0.0
        ###################
        # train the model #
        ###################
        model_scratch.train()
        for data, target in train_loader:
            # move to GPU
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            ## find the loss and update the model parameters accordingly
            with torch.enable_grad():
                optimizer_scratch.zero_grad()
                output = model_scratch(data)  # get predictions
                loss = criterion_scratch(output, target)  # calculate loss
                loss.backward()  # calculate the gradients
                optimizer_scratch.step() # perform optimization step
                train_loss += loss.item()

    
        ######################    
        # validate the model #
        ######################
        model_scratch.eval()
        for data, target in valid_loader:
            # move to GPU
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            with torch.no_grad():
                output = model_scratch(data)
                loss = criterion_scratch(output, target)
                valid_loss += loss.item()
            
        train_loss = train_loss / len(train_loader)
        valid_loss = valid_loss / len(valid_loader)
        
            
        # print training/validation statistics 
        print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
            epoch, 
            train_loss,
            valid_loss
            ))
        
        
        # save the model if validation loss has decreased or stayed the same
        if valid_loss <= valid_loss_min:
            print('Validation Loss Decreased. Saving model')
            torch.save(model_scratch.state_dict(), save_path)
            valid_loss_min = valid_loss
    # return trained model
    return model_scratch

# train the model
trained_model = train(20, loaders_scratch, model_scratch, optimizer_scratch, 
                      criterion_scratch, use_cuda, 'arch_1_global_train_variables.pt')

Epoch: 1 	Training Loss: 5.593780 	Validation Loss: 4.872661
Validation Loss Decreased. Saving model
Epoch: 2 	Training Loss: 4.880664 	Validation Loss: 4.870785
Validation Loss Decreased. Saving model
Epoch: 3 	Training Loss: 4.872185 	Validation Loss: 4.868692
Validation Loss Decreased. Saving model
Epoch: 4 	Training Loss: 4.870217 	Validation Loss: 4.868035
Validation Loss Decreased. Saving model
Epoch: 5 	Training Loss: 4.868786 	Validation Loss: 4.868854
Epoch: 6 	Training Loss: 4.867970 	Validation Loss: 4.868792
Epoch: 7 	Training Loss: 4.887500 	Validation Loss: 4.871854
Epoch: 8 	Training Loss: 4.867568 	Validation Loss: 4.868095
Epoch: 9 	Training Loss: 4.866680 	Validation Loss: 4.869761
Epoch: 10 	Training Loss: 4.866748 	Validation Loss: 4.869246
Epoch: 11 	Training Loss: 4.866648 	Validation Loss: 4.868374
Epoch: 12 	Training Loss: 4.866472 	Validation Loss: 4.868917
Epoch: 13 	Training Loss: 4.866339 	Validation Loss: 4.867692
Validation Loss Decreased. Saving model
Epo

### Test function accessing global variables, rather than parameters

In [9]:
def test(loaders, model, criterion, use_cuda):
    global test_loader
    global model_scratch
    global trained_model
    global criterion_scratch
    
    # monitor test loss and accuracy
    test_loss = 0.0
    correct = 0.0
    total = 0.0
    # trained_model is the object returned by train(), i.e. model_scratch itself
    trained_model.eval()
    
    for data, target in test_loader:
        # move to GPU
        if use_cuda:
            data, target = data.cuda(), target.cuda()
        with torch.no_grad():
            # forward pass: compute predicted outputs by passing inputs to the model
            output = trained_model(data)
            print('Output: ', output)
            print('target: ', target)
            # calculate the loss
            loss = criterion_scratch(output, target)
            # update average test loss 
#             test_loss = test_loss + ((1 / (batch_idx + 1)) * (loss.data - test_loss))
            test_loss += loss.item()
            # convert output probabilities to predicted class
            pred = output.data.max(1, keepdim=True)[1]
            print(pred)
            # compare predictions to true label
            correct += np.sum(np.squeeze(pred.eq(target.data.view_as(pred))).cpu().numpy())
            total += data.size(0)
            
    print('Test Loss: {:.6f}\n'.format(test_loss/len(test_loader)))

    print('\nTest Accuracy: %2d%% (%2d/%2d)' % (
        100. * correct / total, correct, total))

# call test function    
test(loaders_scratch, model_scratch, criterion_scratch, use_cuda)

Output:  tensor([[-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        ...,
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034]],
       device='cuda:0')
target:  tensor([127,  59,   7,  48,  26,  13, 106,  51], device='cuda:0')
tensor([[4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4]], device='cuda:0')
Output:  tensor([[-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        ...,
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548

Output:  tensor([[-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        ...,
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034]],
       device='cuda:0')
target:  tensor([28, 91, 46, 30, 43, 89, 26, 74], device='cuda:0')
tensor([[4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4]], device='cuda:0')
Output:  tensor([[-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        ...,
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., 

Output:  tensor([[-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        ...,
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034]],
       device='cuda:0')
target:  tensor([105,  21,  56,  96,  11,  38, 119,  16], device='cuda:0')
tensor([[4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4]], device='cuda:0')
Output:  tensor([[-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        ...,
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548

Output:  tensor([[-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        ...,
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034]],
       device='cuda:0')
target:  tensor([ 71,  40,  34, 122,  17,  42,  47,  79], device='cuda:0')
tensor([[4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4]], device='cuda:0')
Output:  tensor([[-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        ...,
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548

Output:  tensor([[-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        ...,
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034]],
       device='cuda:0')
target:  tensor([122, 117,  26,  96,  25,  72,   1,  13], device='cuda:0')
tensor([[4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4]], device='cuda:0')
Output:  tensor([[-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        ...,
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548

Output:  tensor([[-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        ...,
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034]],
       device='cuda:0')
target:  tensor([15, 93, 67, 51, 67, 57, 23, 87], device='cuda:0')
tensor([[4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4]], device='cuda:0')
Output:  tensor([[-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        ...,
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., 

Output:  tensor([[-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        ...,
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034]],
       device='cuda:0')
target:  tensor([ 47,  21,   8,  16,  54, 103,  72, 102], device='cuda:0')
tensor([[4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4]], device='cuda:0')
Output:  tensor([[-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        ...,
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548

Output:  tensor([[-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        ...,
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034]],
       device='cuda:0')
target:  tensor([ 62,  19,  40, 118, 108,  56,   1,  63], device='cuda:0')
tensor([[4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4]], device='cuda:0')
Output:  tensor([[-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        ...,
        [-4.6683, -4.7427, -4.8548,  ..., -5.4137, -5.5457, -5.4034],
        [-4.6683, -4.7427, -4.8548
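
The repeated rows in the training output above suggest that every image in a batch receives the same scores. Below is a minimal sketch, in plain Python, of two quick checks: whether all rows of a batch output match, and what loss a model would get by assigning equal probability to every class. The helper names `rows_identical` and `uniform_guess_loss` are hypothetical, and the class count of 133 is an assumption (the targets shown only reach 128). In the notebook, the rows could be obtained with `.tolist()` on the output tensor.

```python
import math

def rows_identical(rows, tol=1e-4):
    """Return True if every row matches the first row within tol."""
    first = rows[0]
    return all(
        len(row) == len(first)
        and all(abs(a - b) <= tol for a, b in zip(row, first))
        for row in rows[1:]
    )

def uniform_guess_loss(num_classes):
    """Cross-entropy of a model that gives every class equal probability."""
    return math.log(num_classes)

# Values copied from the visible columns of the training output above.
batch = [
    [-4.6683, -4.7427, -4.8548, -5.4137, -5.5457, -5.4034],
    [-4.6683, -4.7427, -4.8548, -5.4137, -5.5457, -5.4034],
]
print(rows_identical(batch))
print(round(uniform_guess_loss(133), 2))  # 133 classes is an assumption
```

If the rows match and the loss sits close to the uniform-guess value, that would be consistent with the network ignoring its input, which fits the loss plateau around 4.8 described earlier. It does not by itself identify the cause.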

## Test

In [8]:
# load the model that got the best validation accuracy
model_scratch.load_state_dict(torch.load('model_scratch.pt'))

<All keys matched successfully>
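
The test call that follows prints full input and output tensors, which makes the pattern hard to read. As a rough sketch, a compact summary of predicted classes could be used instead. The helper `summarize_predictions` is hypothetical and is not part of the notebook's `test` function. The sample values are copied from one batch of the test output further down. In the notebook, the lists could be collected with `.tolist()` on the prediction and target tensors.

```python
from collections import Counter

def summarize_predictions(preds, targets):
    """Return (class, count) pairs sorted by frequency, plus accuracy."""
    counts = Counter(preds)
    correct = sum(p == t for p, t in zip(preds, targets))
    accuracy = correct / len(targets) if targets else 0.0
    return counts.most_common(), accuracy

# Predictions and targets copied from one test batch in this notebook's output.
preds = [4, 4, 4, 52, 4, 4, 4, 4]
targets = [100, 103, 123, 2, 119, 78, 49, 78]

top, acc = summarize_predictions(preds, targets)
print("predicted class counts:", top)
print("batch accuracy:", acc)
```

Accumulating these counts over the whole test loader would show at a glance whether a single class dominates the predictions, without scrolling through tensor dumps.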

In [9]:
# call test function    
test(loaders_scratch, model_scratch, criterion_scratch, use_cuda)

Data:  tensor([[[[ 0.0912,  0.1254,  0.1939,  ..., -0.0972, -0.0629, -0.3198],
          [ 0.1426,  0.1597,  0.2111,  ..., -0.1657, -0.0801, -0.3369],
          [ 0.1083,  0.1254,  0.1768,  ..., -0.0801,  0.0056, -0.1657],
          ...,
          [-0.9705, -0.8678, -0.6452,  ..., -1.1589, -1.0904, -1.0219],
          [-1.0733, -0.9705, -0.7822,  ..., -1.1247, -1.0904, -1.0562],
          [-0.9020, -0.7650, -0.7308,  ..., -1.0733, -1.0904, -1.0904]],

         [[ 1.0105,  1.0280,  1.0280,  ...,  0.3102,  0.3978,  0.2227],
          [ 1.1155,  1.1331,  1.1331,  ...,  0.2927,  0.4503,  0.2752],
          [ 1.0455,  1.0630,  1.0980,  ...,  0.3627,  0.5378,  0.4328],
          ...,
          [ 0.0126,  0.0301,  0.2227,  ..., -0.1099, -0.0574, -0.0049],
          [-0.1099, -0.0749,  0.0651,  ..., -0.0749, -0.0749, -0.0574],
          [-0.0574,  0.0126,  0.0126,  ..., -0.0399, -0.0749, -0.1099]],

         [[ 0.2522,  0.2696,  0.2871,  ..., -0.3578, -0.2707, -0.4450],
          [ 0.3219,  0.

Data:  tensor([[[[-0.9192, -0.7822, -0.5767,  ...,  0.2796,  0.2453,  0.1426],
          [-0.9020, -0.7822, -0.6794,  ...,  0.2282,  0.1939,  0.1939],
          [-0.9192, -0.7822, -0.7137,  ...,  0.1939,  0.2624,  0.2796],
          ...,
          [-0.4397, -0.5767, -0.8507,  ..., -0.4911, -0.3712, -0.4568],
          [-0.4054, -0.5253, -0.9192,  ..., -0.3541, -0.5767, -0.2513],
          [-0.4397, -0.2342, -0.7137,  ..., -0.2513, -0.2856, -0.1828]],

         [[-0.9153, -0.7927, -0.5651,  ...,  0.3452,  0.2577,  0.1877],
          [-0.8978, -0.7577, -0.6352,  ...,  0.2577,  0.2052,  0.2052],
          [-0.8627, -0.7227, -0.6527,  ...,  0.2402,  0.2927,  0.3102],
          ...,
          [-0.4601, -0.5826, -0.8803,  ..., -0.5126, -0.4076, -0.4776],
          [-0.4426, -0.5301, -0.9328,  ..., -0.3375, -0.6001, -0.2500],
          [-0.4601, -0.2325, -0.7227,  ..., -0.2325, -0.2675, -0.1450]],

         [[-0.9678, -0.8458, -0.6193,  ...,  0.3916,  0.3393,  0.2696],
          [-0.9504, -0.

Data:  tensor([[[[ 1.5297,  1.6324,  1.7523,  ...,  1.2214, -0.1143, -0.7308],
          [ 1.4612,  1.6838,  1.6153,  ...,  0.9303, -0.3712, -0.7479],
          [ 1.5468,  1.6153,  1.3242,  ...,  0.9474, -0.3712, -0.7308],
          ...,
          [ 0.1768,  0.1939, -0.2513,  ...,  0.8789,  1.1187,  1.3070],
          [-0.0458, -0.0458, -0.1314,  ...,  0.1939,  0.0056,  0.4166],
          [-0.1657, -0.1143,  0.0398,  ...,  0.5536,  0.1254, -0.2171]],

         [[ 1.9384,  2.0434,  2.1485,  ...,  1.5007,  0.1702, -0.3901],
          [ 1.8333,  2.0609,  1.9909,  ...,  1.2031, -0.0924, -0.4601],
          [ 1.9209,  1.9909,  1.6933,  ...,  1.2381, -0.0749, -0.4251],
          ...,
          [ 0.2752,  0.2927, -0.1625,  ...,  0.6429,  0.9055,  1.1506],
          [ 0.0476,  0.0476, -0.0399,  ..., -0.0049, -0.1625,  0.2927],
          [-0.0749, -0.0224,  0.1352,  ...,  0.4328,  0.0301, -0.2850]],

         [[ 2.4134,  2.4657,  2.5180,  ...,  2.0997,  0.7925,  0.2173],
          [ 2.3437,  2.

Data:  tensor([[[[ 0.9474,  1.0159,  1.0331,  ..., -1.6384, -1.5699, -1.6042],
          [ 0.9646,  1.0844,  1.0159,  ..., -1.6042, -1.5185, -1.5528],
          [ 0.9988,  1.1358,  1.0331,  ..., -1.5870, -1.5699, -1.5699],
          ...,
          [ 2.2489,  2.2489,  2.2489,  ...,  2.1119,  1.9235,  2.1290],
          [ 2.2489,  2.2489,  2.2489,  ...,  2.2318,  2.2318,  2.2147],
          [ 2.2489,  2.2489,  2.2489,  ...,  2.2147,  2.2318,  2.2318]],

         [[ 1.1155,  1.2731,  1.3431,  ..., -1.4230, -1.4405, -1.4405],
          [ 1.1331,  1.3256,  1.3431,  ..., -1.4230, -1.4405, -1.4755],
          [ 1.1506,  1.3606,  1.3606,  ..., -1.4755, -1.4755, -1.4580],
          ...,
          [ 2.4286,  2.4286,  2.4286,  ...,  2.2010,  1.9034,  2.2360],
          [ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4111,  2.4111],
          [ 2.4286,  2.4286,  2.4286,  ...,  2.4111,  2.4111,  2.4111]],

         [[ 1.2282,  1.5420,  1.6291,  ..., -1.1944, -1.1770, -1.1421],
          [ 1.2457,  1.

Data:  tensor([[[[-0.5424, -0.5767, -0.5938,  ..., -1.1932, -1.2103, -1.0904],
          [-0.4911, -0.5424, -0.5767,  ..., -1.1589, -1.1932, -1.0562],
          [-0.4054, -0.4568, -0.4739,  ..., -1.1760, -1.1760, -1.0562],
          ...,
          [ 0.8276,  0.7762,  0.6734,  ...,  0.3481,  0.3138,  0.2967],
          [ 0.6563,  0.6563,  0.6221,  ...,  0.3481,  0.3481,  0.3309],
          [ 0.3994,  0.4337,  0.4337,  ...,  0.2453,  0.2796,  0.3138]],

         [[-1.0028, -1.0378, -1.0203,  ..., -1.7381, -1.8081, -1.7206],
          [-0.9503, -1.0028, -1.0028,  ..., -1.7031, -1.7906, -1.6856],
          [-0.8627, -0.9153, -0.8978,  ..., -1.7031, -1.7556, -1.6856],
          ...,
          [ 1.0630,  1.0455,  0.9580,  ...,  0.4678,  0.4328,  0.4153],
          [ 0.8704,  0.8704,  0.8704,  ...,  0.4853,  0.4678,  0.4678],
          [ 0.6078,  0.6429,  0.6604,  ...,  0.3978,  0.4328,  0.4503]],

         [[-0.9330, -0.9330, -0.8807,  ..., -1.6650, -1.7347, -1.6824],
          [-0.8807, -0.

Data:  tensor([[[[-1.9809, -1.9809, -1.9809,  ..., -1.3644, -1.4158, -1.4500],
          [-2.0152, -2.0152, -2.0152,  ..., -1.3987, -1.4672, -1.5185],
          [-2.0152, -2.0152, -1.9809,  ..., -1.4158, -1.4843, -1.5185],
          ...,
          [-1.6213, -1.6898, -1.6727,  ..., -0.7479, -0.5938, -0.5424],
          [-1.7069, -1.5528, -1.4500,  ..., -0.7308, -0.4911, -0.4739],
          [-1.5870, -1.3473, -1.3302,  ..., -0.7308, -0.6281, -0.5596]],

         [[-1.8782, -1.8782, -1.8957,  ..., -1.2304, -1.2654, -1.3179],
          [-1.9132, -1.9132, -1.9132,  ..., -1.2654, -1.3004, -1.3529],
          [-1.9307, -1.9482, -1.8957,  ..., -1.2829, -1.3354, -1.3529],
          ...,
          [-1.7031, -1.7906, -1.7731,  ..., -0.8452, -0.7052, -0.6352],
          [-1.8081, -1.6331, -1.5280,  ..., -0.7577, -0.5301, -0.4601],
          [-1.6681, -1.4230, -1.3880,  ..., -0.7052, -0.5826, -0.5126]],

         [[-1.6476, -1.6302, -1.5953,  ..., -1.0898, -1.1247, -1.1596],
          [-1.6302, -1.

Data:  tensor([[[[ 1.2557,  1.5639,  1.6667,  ...,  1.5982,  1.6153,  1.6153],
          [ 1.1529,  1.4783,  1.5639,  ...,  1.5468,  1.5810,  1.5639],
          [ 1.2214,  1.4783,  1.5468,  ...,  1.4612,  1.4612,  1.4612],
          ...,
          [ 0.9132,  1.1872,  0.9646,  ..., -0.0801, -0.9705, -0.8849],
          [ 1.5982,  1.1872,  1.3927,  ..., -0.9020, -0.4911, -0.8507],
          [ 1.4269,  1.4612,  1.7180,  ..., -1.3130, -0.7479, -0.7308]],

         [[ 1.7458,  1.9384,  1.9209,  ...,  1.9209,  1.9909,  1.9734],
          [ 1.6057,  1.8508,  1.8683,  ...,  1.8333,  1.9734,  1.9384],
          [ 1.5357,  1.7458,  1.7983,  ...,  1.7808,  1.8859,  1.8859],
          ...,
          [ 1.0980,  1.4482,  1.2031,  ...,  0.6604, -0.3200, -0.1975],
          [ 1.8333,  1.4832,  1.6583,  ..., -0.2150,  0.2227, -0.1275],
          [ 1.6583,  1.7108,  1.9384,  ..., -0.7227, -0.0224,  0.0126]],

         [[ 0.0431,  0.5485,  0.7576,  ...,  0.8099,  0.7576,  0.8622],
          [ 0.2696,  0.

Data:  tensor([[[[-0.1486, -0.1314, -0.1486,  ...,  1.1358,  1.1700,  1.2043],
          [-0.1828, -0.1486, -0.1143,  ...,  1.1358,  1.1872,  1.2385],
          [-0.1999, -0.1486, -0.1314,  ...,  1.1872,  1.2385,  1.2385],
          ...,
          [-0.6965, -0.7137, -0.7650,  ..., -0.1143, -0.1314, -0.1657],
          [-0.7308, -0.7479, -0.7822,  ..., -0.1486, -0.1486, -0.1828],
          [-0.7308, -0.7479, -0.7993,  ..., -0.1999, -0.2171, -0.2342]],

         [[-0.5651, -0.5476, -0.5651,  ...,  0.6429,  0.6779,  0.7129],
          [-0.6001, -0.5651, -0.5301,  ...,  0.6779,  0.7129,  0.7479],
          [-0.6176, -0.5651, -0.5476,  ...,  0.7304,  0.7479,  0.7654],
          ...,
          [-0.9153, -0.9328, -0.9678,  ..., -0.5826, -0.5826, -0.6176],
          [-0.9328, -0.9328, -0.9678,  ..., -0.6001, -0.6176, -0.6527],
          [-0.9328, -0.9328, -0.9853,  ..., -0.6176, -0.6527, -0.6702]],

         [[-0.6541, -0.6367, -0.6541,  ...,  0.5485,  0.5834,  0.6182],
          [-0.6715, -0.

Data:  tensor([[[[-0.7479, -0.8335, -1.1418,  ...,  2.0605,  2.1290,  2.1633],
          [-0.8849, -0.9192, -1.1760,  ...,  2.1462,  2.1633,  2.1633],
          [-1.0219, -1.0219, -1.0733,  ...,  2.0777,  2.1633,  2.1975],
          ...,
          [ 0.0227, -0.0972,  0.1426,  ...,  0.3994,  0.5364,  0.4166],
          [ 0.4679,  0.1597, -0.0972,  ...,  0.3481,  0.4679,  0.4337],
          [ 0.3138,  0.0056, -0.3541,  ...,  0.2967,  0.3481,  0.4508]],

         [[ 0.0651, -0.0399, -0.3550,  ...,  2.2010,  2.2710,  2.3235],
          [-0.0749, -0.1099, -0.3901,  ...,  2.2885,  2.3060,  2.3235],
          [-0.2150, -0.2150, -0.3025,  ...,  2.2010,  2.3235,  2.3585],
          ...,
          [ 0.5203,  0.4328,  0.6779,  ..., -0.2150, -0.0749, -0.1975],
          [ 0.9755,  0.6779,  0.4328,  ..., -0.2675, -0.1450, -0.1800],
          [ 0.8179,  0.5028,  0.1352,  ..., -0.3200, -0.2675, -0.1625]],

         [[-1.0724, -1.1770, -1.5256,  ...,  2.1694,  2.2740,  2.3263],
          [-1.2119, -1.

Data:  tensor([[[[-0.1999, -0.1999, -0.2171,  ..., -0.9020, -0.9363, -0.9534],
          [-0.2684, -0.1828, -0.1999,  ..., -0.0801, -0.1828, -0.2171],
          [-0.2171, -0.2342, -0.1828,  ...,  0.6563,  0.7248,  0.8104],
          ...,
          [-2.0152, -1.9124, -1.9980,  ..., -1.8953, -1.8097, -1.8610],
          [-2.0323, -1.9295, -1.9295,  ..., -1.7583, -1.7412, -1.8097],
          [-2.0323, -1.9809, -1.9124,  ..., -1.7925, -1.8782, -1.8953]],

         [[-0.1099, -0.1099, -0.1275,  ..., -1.0553, -1.0728, -1.0728],
          [-0.1800, -0.0924, -0.1099,  ..., -0.4951, -0.5826, -0.6352],
          [-0.1275, -0.1450, -0.0924,  ...,  0.5028,  0.5553,  0.6604],
          ...,
          [-1.6506, -1.5280, -1.5980,  ..., -1.4580, -1.3704, -1.4230],
          [-1.6331, -1.5280, -1.5280,  ..., -1.3179, -1.2829, -1.3529],
          [-1.6331, -1.5805, -1.5105,  ..., -1.3529, -1.4405, -1.4755]],

         [[ 0.1999,  0.1999,  0.1825,  ..., -0.9156, -0.9156, -0.9504],
          [ 0.1302,  0.

Data:  tensor([[[[-0.3369, -0.2684, -0.2342,  ..., -1.0219, -0.8678, -0.8164],
          [-0.3369, -0.3027, -0.3027,  ..., -1.4672, -1.2274, -1.1075],
          [-0.3883, -0.3883, -0.4054,  ..., -1.6555, -1.5357, -1.3987],
          ...,
          [-0.9192, -0.9705, -0.9877,  ..., -1.1418, -1.1589, -1.1760],
          [-0.9363, -0.9363, -0.9534,  ..., -1.1760, -1.2103, -1.1760],
          [-0.9534, -0.9020, -0.9020,  ..., -1.2103, -1.2445, -1.1589]],

         [[-0.5826, -0.6352, -0.6702,  ..., -1.3179, -1.1779, -1.1253],
          [-0.7752, -0.7752, -0.8277,  ..., -1.5280, -1.3354, -1.2829],
          [-0.8277, -0.8627, -0.9328,  ..., -1.7556, -1.6681, -1.5980],
          ...,
          [-1.5280, -1.5805, -1.5980,  ..., -1.6681, -1.6856, -1.7031],
          [-1.5455, -1.5455, -1.5630,  ..., -1.7031, -1.7381, -1.7031],
          [-1.5805, -1.5105, -1.5105,  ..., -1.7381, -1.7731, -1.6856]],

         [[-0.4101, -0.4450, -0.4973,  ..., -1.0027, -0.8458, -0.7936],
          [-0.7413, -0.

       device='cuda:0')
Output:  tensor([[ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        ...,
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        [-0.1759,  0.0336, -0.0069,  ..., -0.7462, -0.9685, -0.6770]],
       device='cuda:0')
target:  tensor([ 28,  70,  86, 125, 114,  28,  24,  71], device='cuda:0')
tensor([[4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4]], device='cuda:0')
Data:  tensor([[[[ 1.1872,  1.2385,  1.2385,  ...,  0.7762,  0.9646,  0.8961],
          [ 1.1187,  1.0331,  1.1529,  ...,  0.8961,  0.9303,  0.8447],
          [ 1.1872,  1.1015,  1.0844,  ...,  0.8276,  0.8618,  0.7591],
          ...,
          [ 0.8447,  1.0331,  0.8961,  ...,  0.5707,  0.5364,  0.4337],
  

       device='cuda:0')
Output:  tensor([[ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        ...,
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493]],
       device='cuda:0')
target:  tensor([100, 103, 123,   2, 119,  78,  49,  78], device='cuda:0')
tensor([[ 4],
        [ 4],
        [ 4],
        [52],
        [ 4],
        [ 4],
        [ 4],
        [ 4]], device='cuda:0')
Data:  tensor([[[[ 0.5707,  0.5707,  0.5536,  ...,  2.0948,  2.1462,  2.1975],
          [ 0.5022,  0.5022,  0.5022,  ...,  2.0092,  2.0777,  2.1119],
          [ 0.4337,  0.4166,  0.4166,  ...,  1.9235,  2.0092,  2.0605],
          ...,
          [-0.4739,  0.0398,  0.1939,  ...,  0.2282,  0.4851,  0.6

Data:  tensor([[[[-0.8335, -0.8849, -0.9020,  ..., -0.9705, -0.9534, -0.9363],
          [-0.7137, -0.7308, -0.7479,  ..., -1.0390, -1.0390, -1.0219],
          [-0.6965, -0.7308, -0.7479,  ..., -1.1075, -1.1247, -1.1075],
          ...,
          [-1.6898, -1.5699, -1.6213,  ..., -0.9705, -0.7650, -1.1589],
          [-0.3712, -0.1143, -0.6965,  ..., -1.0904, -0.9877, -1.2617],
          [-0.6281, -0.2513, -0.9020,  ..., -0.8335, -1.3644, -1.2445]],

         [[-0.7227, -0.7752, -0.7752,  ..., -0.7752, -0.7752, -0.7752],
          [-0.7927, -0.7752, -0.7402,  ..., -0.7577, -0.7752, -0.7927],
          [-0.8277, -0.8102, -0.7577,  ..., -0.7927, -0.8102, -0.8102],
          ...,
          [-1.4405, -1.4405, -1.5805,  ..., -0.6176, -0.2850, -0.7227],
          [-0.0924,  0.0476, -0.6527,  ..., -0.6877, -0.5301, -0.9328],
          [-0.4076, -0.0399, -0.8803,  ..., -0.2850, -0.8627, -1.0028]],

         [[ 0.2348,  0.2348,  0.2173,  ...,  0.2173,  0.2871,  0.2871],
          [ 0.3393,  0.

Data:  tensor([[[[ 1.1529,  1.2214,  1.2557,  ..., -1.7754, -1.7754, -1.7925],
          [ 1.1529,  0.8447,  0.6906,  ..., -1.7754, -1.8439, -1.9124],
          [ 0.3309,  0.3823,  1.0331,  ..., -1.9467, -2.0152, -1.9809],
          ...,
          [ 1.3584,  1.3927,  1.3584,  ..., -1.6727, -1.6384, -1.6384],
          [ 1.3755,  1.3755,  1.3755,  ..., -1.6213, -1.6042, -1.6384],
          [ 1.2899,  1.3755,  1.3927,  ..., -1.5699, -1.5870, -1.6213]],

         [[ 1.1155,  1.2731,  1.3431,  ..., -1.6681, -1.6681, -1.6856],
          [ 1.2731,  0.9230,  0.7129,  ..., -1.6681, -1.7381, -1.8256],
          [ 0.1352,  0.0476,  0.7304,  ..., -1.8606, -1.9307, -1.8957],
          ...,
          [ 1.3606,  1.4132,  1.4132,  ..., -1.4930, -1.5105, -1.5455],
          [ 1.3782,  1.3782,  1.4132,  ..., -1.4755, -1.5280, -1.5455],
          [ 1.3081,  1.3782,  1.3957,  ..., -1.4405, -1.4930, -1.5105]],

         [[ 0.9668,  1.1237,  1.2108,  ..., -1.4036, -1.4036, -1.4210],
          [ 0.9842,  0.

Data:  tensor([[[[-0.4054, -0.3198, -0.3712,  ..., -1.1418, -1.2959, -1.3815],
          [-0.7650, -0.7308, -0.5767,  ..., -0.9192, -1.0733, -1.1932],
          [-0.6623, -0.7308, -0.5938,  ..., -0.6794, -0.8164, -1.0048],
          ...,
          [-0.6965, -0.7822, -0.7137,  ...,  1.2728,  1.6838,  1.5982],
          [-0.7308, -0.7822, -0.5767,  ...,  0.6734,  1.4098,  1.5982],
          [-0.7479, -0.5767, -0.4054,  ...,  0.0741,  0.6563,  1.3755]],

         [[-1.1078, -1.0203, -1.0553,  ..., -1.4230, -1.4755, -1.5105],
          [-1.4405, -1.4055, -1.2479,  ..., -1.2654, -1.3179, -1.3179],
          [-1.3004, -1.3704, -1.2304,  ..., -1.0728, -1.1078, -1.1429],
          ...,
          [-0.6001, -0.6877, -0.6176,  ...,  0.7829,  1.2556,  1.2381],
          [-0.6352, -0.6877, -0.4776,  ..., -0.0224,  0.8529,  1.1856],
          [-0.6527, -0.4776, -0.3025,  ..., -0.8102, -0.0574,  0.8880]],

         [[-1.3164, -1.2467, -1.3164,  ..., -1.3861, -1.4210, -1.4036],
          [-1.6476, -1.

       device='cuda:0')
Output:  tensor([[ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        ...,
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        [ 0.2053,  0.0922, -0.0137,  ..., -0.5548, -0.6618, -0.5522]],
       device='cuda:0')
target:  tensor([ 97,  96,  14, 111,  25,  23,  85,  95], device='cuda:0')
tensor([[4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4]], device='cuda:0')
Data:  tensor([[[[-1.4843, -1.6213, -1.6384,  ...,  0.9132,  1.3584,  1.3070],
          [-1.5185, -1.6384, -1.6213,  ...,  0.6049,  1.3584,  1.3927],
          [-1.5870, -1.6898, -1.6555,  ...,  0.8618,  1.3927,  1.5639],
          ...,
          [-1.7583, -1.8953, -1.9124,  ..., -0.4568, -0.7137, -0.3883],
  

       device='cuda:0')
Output:  tensor([[ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        ...,
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493]],
       device='cuda:0')
target:  tensor([115,   4,  33, 108,  64,  25,  95,  38], device='cuda:0')
tensor([[4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4]], device='cuda:0')
Data:  tensor([[[[-0.3541, -0.2684, -0.1486,  ..., -0.8164, -0.5082, -0.5253],
          [-0.4054, -0.2171,  0.0398,  ..., -0.7822, -0.4568, -0.5253],
          [-0.5424, -0.2513,  0.0912,  ..., -0.7308, -0.4568, -0.5424],
          ...,
          [ 1.1015,  1.0673,  1.1015,  ...,  0.7077,  0.4851,  0.5193],
  

       device='cuda:0')
Output:  tensor([[ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        [ 0.1355,  0.0560, -0.0056,  ..., -0.6164, -0.7112, -0.5219],
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        ...,
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493]],
       device='cuda:0')
target:  tensor([ 35,  81, 128,   6, 110,  39,  36,  22], device='cuda:0')
tensor([[4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4]], device='cuda:0')
Data:  tensor([[[[ 0.2282,  0.2282,  0.3652,  ..., -1.4158, -1.4500, -1.5870],
          [ 0.3138,  0.2624,  0.3823,  ..., -1.0562, -1.1418, -1.2959],
          [ 0.3823,  0.2796,  0.3652,  ..., -0.7993, -0.9363, -1.0733],
          ...,
          [-0.4568, -0.4739, -0.5082,  ..., -0.3883, -0.4911, -0.5938],
  

Data:  tensor([[[[-0.4054, -0.3541, -0.2342,  ..., -1.5014, -0.4054, -0.0972],
          [-0.4226, -0.1314,  0.0569,  ..., -1.1932, -0.1999, -0.1999],
          [-0.0629,  0.0569,  0.2453,  ..., -0.8849, -0.2171, -0.1828],
          ...,
          [-0.4054, -0.4739, -0.6965,  ..., -0.1999, -0.4739, -0.3198],
          [-0.5253, -0.4054, -0.6965,  ..., -0.0116, -0.4054, -0.7479],
          [-1.0904, -0.9534, -1.1760,  ..., -0.2171,  0.1768, -0.2171]],

         [[ 0.1877,  0.2052,  0.2752,  ..., -1.0378,  0.1877,  0.5903],
          [ 0.1702,  0.3978,  0.5378,  ..., -0.7227,  0.3803,  0.4853],
          [ 0.5553,  0.5903,  0.6779,  ..., -0.4076,  0.3803,  0.5028],
          ...,
          [ 0.3102,  0.2577,  0.0476,  ..., -0.1625, -0.4776, -0.3550],
          [ 0.2052,  0.3452,  0.0826,  ...,  0.0476, -0.3901, -0.7752],
          [-0.3025, -0.1450, -0.3550,  ..., -0.1275,  0.2402, -0.1975]],

         [[-0.8981, -0.8284, -0.6193,  ..., -1.2816, -0.2881, -0.0615],
          [-0.8633, -0.

Data:  tensor([[[[ 0.3309,  0.2624,  0.1254,  ..., -0.0629, -0.1486, -0.0972],
          [ 0.1768,  0.1254,  0.0227,  ..., -0.0287, -0.0458,  0.0569],
          [ 0.0398,  0.0398, -0.0116,  ...,  0.0398, -0.0116,  0.0569],
          ...,
          [ 0.7248,  1.0159,  0.8276,  ..., -0.3369, -0.3712, -0.4911],
          [ 0.5707,  1.1700,  0.9817,  ..., -0.5767, -0.4911, -0.4054],
          [ 0.5193,  0.9474,  0.9303,  ..., -0.6794, -0.5424, -0.3883]],

         [[ 0.7829,  0.7129,  0.5728,  ...,  0.4328,  0.3978,  0.4678],
          [ 0.6954,  0.6604,  0.5728,  ...,  0.4853,  0.5028,  0.6078],
          [ 0.6078,  0.6078,  0.5553,  ...,  0.5553,  0.5203,  0.5903],
          ...,
          [ 0.8880,  1.0630,  0.7829,  ...,  0.3102,  0.2752,  0.1877],
          [ 0.7304,  1.2381,  0.9580,  ...,  0.0651,  0.1702,  0.2577],
          [ 0.6429,  0.9930,  0.9055,  ..., -0.0749,  0.1001,  0.2402]],

         [[ 0.7576,  0.6879,  0.5311,  ...,  0.3568,  0.2696,  0.3045],
          [ 0.6008,  0.

Data:  tensor([[[[-0.0458, -0.3541, -0.8335,  ..., -0.8507, -0.9877, -1.0562],
          [-0.7308, -0.9705, -0.8849,  ..., -0.9877, -1.0562, -0.9192],
          [-1.1760, -0.9705, -0.9877,  ..., -0.7993, -1.0562, -0.9534],
          ...,
          [ 0.5707,  0.0912, -0.5596,  ..., -0.5253, -0.5938, -0.7822],
          [ 0.2453, -0.5253, -1.0219,  ...,  0.3652,  0.3309,  0.0912],
          [-0.3541, -0.8678, -0.9705,  ..., -0.0287, -0.0458,  0.1254]],

         [[ 0.8354,  0.5553,  0.0651,  ...,  0.0651, -0.0924, -0.0924],
          [ 0.1527, -0.0924, -0.0574,  ...,  0.0126, -0.0574,  0.0826],
          [-0.3550, -0.1975, -0.2500,  ...,  0.1176, -0.0924,  0.0301],
          ...,
          [ 1.5357,  0.9405,  0.0651,  ...,  0.4153,  0.3277,  0.1527],
          [ 1.1506,  0.2052, -0.4426,  ...,  1.1331,  1.1331,  0.9405],
          [ 0.5378,  0.0301, -0.1450,  ...,  0.8179,  0.8179,  1.0630]],

         [[-0.1661, -0.4275, -0.8284,  ..., -0.7413, -0.8981, -0.9678],
          [-0.6541, -0.

Data:  tensor([[[[-1.6555, -1.7069, -1.7754,  ..., -0.8507, -1.1589, -1.2617],
          [-1.5699, -1.6213, -1.6898,  ..., -0.8507, -1.1932, -1.2788],
          [-1.5357, -1.5699, -1.6555,  ..., -0.9363, -1.2103, -1.3130],
          ...,
          [-1.0733, -1.1075, -1.0733,  ..., -1.6727, -1.6213, -1.6727],
          [-1.1589, -1.0733, -1.1760,  ..., -1.6042, -1.5185, -1.5014],
          [-1.0733, -1.2445, -1.2788,  ..., -1.5014, -1.4672, -1.4843]],

         [[-1.5105, -1.4755, -1.4930,  ..., -1.1779, -1.4580, -1.4580],
          [-1.3529, -1.3179, -1.3529,  ..., -1.1779, -1.4580, -1.4755],
          [-1.3004, -1.3004, -1.3529,  ..., -1.2829, -1.4580, -1.4930],
          ...,
          [-1.3179, -1.3529, -1.3354,  ..., -1.5630, -1.5280, -1.5455],
          [-1.3704, -1.3529, -1.4580,  ..., -1.5455, -1.4755, -1.4580],
          [-1.3179, -1.4580, -1.4580,  ..., -1.4405, -1.3880, -1.3880]],

         [[-1.0201, -1.1073, -1.1596,  ..., -1.2119, -1.5430, -1.3861],
          [-0.8633, -0.

       device='cuda:0')
Output:  tensor([[ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        ...,
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493]],
       device='cuda:0')
target:  tensor([ 55,  30,  78,  15, 117,  49, 106,  11], device='cuda:0')
tensor([[ 4],
        [ 4],
        [ 4],
        [79],
        [ 4],
        [ 4],
        [ 4],
        [ 4]], device='cuda:0')
Data:  tensor([[[[ 1.8722,  1.8379,  1.7865,  ...,  2.2489,  2.2489,  2.2489],
          [ 1.8208,  1.7865,  1.7694,  ...,  2.2489,  2.2318,  2.2318],
          [ 1.8550,  1.7523,  1.7523,  ...,  2.2318,  2.2489,  2.2318],
          ...,
          [-1.0562, -0.8849, -0.7993,  ..., -0.5767, -0.2513, -0.2

       device='cuda:0')
Output:  tensor([[  0.2058,   0.0920,  -0.0145,  ...,  -0.5479,  -0.6645,  -0.5493],
        [ -7.7292,  -7.9591,  -9.2863,  ..., -13.3161,  -6.2616,  -4.0299],
        [  0.2058,   0.0920,  -0.0145,  ...,  -0.5479,  -0.6645,  -0.5493],
        ...,
        [  0.2058,   0.0920,  -0.0145,  ...,  -0.5479,  -0.6645,  -0.5493],
        [  0.2058,   0.0920,  -0.0145,  ...,  -0.5479,  -0.6645,  -0.5493],
        [  0.1881,   0.0974,   0.0140,  ...,  -0.7852,  -0.5701,  -0.6498]],
       device='cuda:0')
target:  tensor([ 45,  39,  48,  66, 112,  97,  21,  86], device='cuda:0')
tensor([[ 4],
        [ 5],
        [ 4],
        [ 4],
        [26],
        [ 4],
        [ 4],
        [26]], device='cuda:0')
Data:  tensor([[[[-0.7650, -0.8164, -0.8507,  ...,  1.2557,  1.2557,  1.2385],
          [-0.8164, -0.8678, -0.9020,  ...,  2.0263,  2.0434,  2.0605],
          [-0.8507, -0.9192, -0.9363,  ...,  2.2318,  2.2489,  2.2489],
          ...,
          [-1.6213, -1.6042, -

       device='cuda:0')
target:  tensor([ 7,  2, 86, 48, 37, 29, 43, 57], device='cuda:0')
tensor([[ 4],
        [ 4],
        [ 4],
        [ 4],
        [10],
        [ 4],
        [ 4],
        [ 4]], device='cuda:0')
Data:  tensor([[[[-1.7240, -1.7754, -1.7583,  ..., -1.7925, -1.3815, -0.5938],
          [-1.8268, -1.8097, -1.7583,  ..., -1.7925, -1.3302, -0.8849],
          [-1.8097, -1.8097, -1.7925,  ..., -1.7925, -1.6384, -1.5699],
          ...,
          [-0.8164, -0.7308, -0.7137,  ..., -1.8782, -1.8610, -1.8610],
          [-0.8849, -0.7993, -0.4226,  ..., -1.9467, -1.8782, -1.8953],
          [-0.4911, -0.9877, -0.9534,  ..., -1.9467, -1.9124, -1.8953]],

         [[-1.5455, -1.7031, -1.6331,  ..., -1.7031, -1.0378,  0.0826],
          [-1.7731, -1.7556, -1.6856,  ..., -1.7206, -0.9503, -0.4076],
          [-1.7731, -1.7381, -1.7381,  ..., -1.7381, -1.4230, -1.3354],
          ...,
          [-0.0924, -0.0049,  0.0476,  ..., -1.8431, -1.8081, -1.8081],
          [-0.1625, 

Data:  tensor([[[[-1.1760, -1.2617, -1.3815,  ..., -1.3302, -1.2959, -1.4843],
          [-1.0390, -1.2445, -1.4500,  ..., -1.0562, -1.3815, -1.8268],
          [-1.2274, -1.3302, -1.4500,  ..., -1.0048, -1.3473, -1.9295],
          ...,
          [-0.0458, -0.1314, -0.0287,  ...,  0.3481,  0.0398,  0.1939],
          [ 0.1426, -0.1657,  0.1768,  ...,  0.6563,  0.3994,  0.0227],
          [ 0.2967,  0.2624,  0.4337,  ...,  1.0502,  0.7933,  0.1426]],

         [[-0.4076, -0.5301, -0.7052,  ..., -0.9678, -0.9853, -1.2129],
          [-0.2325, -0.5476, -0.8102,  ..., -0.6176, -1.0378, -1.5805],
          [-0.4776, -0.6702, -0.8277,  ..., -0.4601, -0.9328, -1.6856],
          ...,
          [ 0.8529,  0.8704,  0.9405,  ...,  0.2402, -0.0399,  0.1877],
          [ 0.9055,  0.8004,  1.1506,  ...,  0.3627,  0.2227,  0.0651],
          [ 1.0805,  1.1856,  1.3957,  ...,  0.6604,  0.5203,  0.0651]],

         [[-1.2293, -1.3339, -1.3513,  ..., -1.0027, -0.9678, -1.1596],
          [-1.2119, -1.

Data:  tensor([[[[-1.7754, -1.6555, -1.3130,  ..., -1.8782, -1.3987, -1.3130],
          [-1.3815, -1.5870, -1.7240,  ..., -1.8610, -1.7754, -1.4672],
          [-1.3987, -1.6213, -1.7583,  ..., -1.8953, -1.8439, -1.3130],
          ...,
          [-1.1075, -0.9877, -0.9877,  ..., -1.7754, -1.6727, -1.6213],
          [-0.9534, -0.6623, -0.9534,  ..., -1.3302, -1.4329, -1.7069],
          [-1.2103, -0.9363, -1.3130,  ..., -1.0390, -1.1760, -1.4500]],

         [[-1.3354, -1.0903, -0.6702,  ..., -1.5805, -1.1429, -1.1078],
          [-0.8452, -0.9503, -1.0378,  ..., -1.5105, -1.4930, -1.2129],
          [-0.7927, -0.8978, -0.9678,  ..., -1.5455, -1.5280, -1.0553],
          ...,
          [-0.4251, -0.3025, -0.3200,  ..., -1.0028, -0.8803, -0.8277],
          [-0.1800,  0.1176, -0.1975,  ..., -0.5301, -0.6352, -0.9153],
          [-0.3550, -0.0924, -0.4776,  ..., -0.2325, -0.3725, -0.6527]],

         [[-1.5953, -1.4036, -1.0027,  ..., -1.6999, -1.2119, -1.1247],
          [-1.1421, -1.

Data:  tensor([[[[-1.4329, -1.4158, -1.3644,  ..., -1.5528, -1.5528, -1.5185],
          [-1.3987, -1.3987, -1.3815,  ..., -1.5357, -1.5357, -1.5014],
          [-1.3815, -1.3644, -1.3987,  ..., -1.5357, -1.5014, -1.4843],
          ...,
          [ 0.2453,  0.2111,  0.1597,  ...,  0.1768,  0.2282,  0.1768],
          [ 0.2624,  0.2453,  0.1597,  ...,  0.2282,  0.1768,  0.1254],
          [ 0.2282,  0.2282,  0.1426,  ...,  0.1426,  0.1597,  0.1768]],

         [[-1.6856, -1.6681, -1.6155,  ..., -1.8081, -1.8256, -1.8081],
          [-1.6506, -1.6506, -1.6331,  ..., -1.8431, -1.8431, -1.8081],
          [-1.6331, -1.6155, -1.6506,  ..., -1.8606, -1.8081, -1.7906],
          ...,
          [-0.3200, -0.3375, -0.3725,  ..., -0.4601, -0.4076, -0.4426],
          [-0.3200, -0.3375, -0.4076,  ..., -0.3725, -0.4251, -0.4251],
          [-0.3550, -0.3550, -0.4251,  ..., -0.4426, -0.4426, -0.3901]],

         [[-1.6127, -1.5953, -1.5430,  ..., -1.7347, -1.7522, -1.7173],
          [-1.5779, -1.

Data:  tensor([[[[ 0.4166,  0.2796, -0.1314,  ...,  0.0741, -0.0458,  0.0227],
          [ 0.2967,  0.2111, -0.1999,  ..., -0.0287,  0.0056,  0.0741],
          [ 0.1083, -0.1143, -0.5082,  ...,  0.0912,  0.1083,  0.1083],
          ...,
          [ 0.4508,  0.1426,  0.2967,  ...,  0.1254,  0.1597, -0.0458],
          [-0.1657,  0.0227,  0.3481,  ...,  0.0227,  0.3994,  0.4508],
          [ 0.0912, -0.1828, -0.0972,  ...,  0.0227,  0.4508,  0.3823]],

         [[ 1.2556,  1.1155,  0.5378,  ...,  1.2906,  1.2906,  1.3431],
          [ 1.1155,  1.0280,  0.4153,  ...,  1.1681,  1.2906,  1.3606],
          [ 0.8529,  0.5028, -0.1450,  ...,  1.2731,  1.2906,  1.2731],
          ...,
          [ 1.6758,  1.1331,  1.3431,  ...,  1.2731,  1.3431,  0.8529],
          [ 0.8704,  0.9405,  1.2556,  ...,  1.1506,  1.5707,  1.4832],
          [ 1.1506,  0.8004,  0.9930,  ...,  1.1331,  1.3957,  1.3081]],

         [[ 0.1825,  0.2348, -0.2881,  ..., -0.0790, -0.3404, -0.0615],
          [ 0.1999,  0.

Data:  tensor([[[[-0.0972, -0.1314, -0.1657,  ..., -1.2959, -1.2788, -1.2445],
          [-0.1486, -0.1143, -0.0287,  ..., -1.4158, -1.3815, -1.2617],
          [-0.2342, -0.2342, -0.1314,  ..., -1.3473, -1.2617, -1.2445],
          ...,
          [-0.4397,  0.1426,  0.4679,  ...,  0.0056, -0.2342, -0.4226],
          [-0.1657, -0.3541,  0.3138,  ..., -0.0801,  0.3138,  0.3138],
          [ 0.0741,  0.0569, -0.0458,  ..., -0.8678, -0.2171,  0.3138]],

         [[ 1.0805,  1.0630,  1.0805,  ..., -0.1975, -0.1625, -0.1099],
          [ 1.0630,  1.1155,  1.2206,  ..., -0.3550, -0.2850, -0.1450],
          [ 1.0105,  1.0105,  1.1155,  ..., -0.2850, -0.1975, -0.1450],
          ...,
          [ 0.7304,  1.3782,  1.6758,  ...,  1.2381,  0.8354,  0.2927],
          [ 1.1155,  0.9055,  1.6057,  ...,  1.1856,  1.4307,  1.3606],
          [ 1.3431,  1.3256,  1.2906,  ...,  0.3803,  0.7829,  1.5182]],

         [[ 0.4962,  0.3916,  0.3045,  ..., -0.7587, -0.7413, -0.7413],
          [ 0.4265,  0.

Data:  tensor([[[[-1.7069, -1.6727, -1.6555,  ..., -1.3815, -1.3473, -1.3302],
          [-1.6727, -1.6384, -1.6384,  ..., -1.3644, -1.3644, -1.3644],
          [-1.6555, -1.6042, -1.6213,  ..., -1.3473, -1.3473, -1.3302],
          ...,
          [-0.1143, -0.0801, -0.0116,  ...,  0.8961,  0.8447,  0.7762],
          [-0.1314, -0.0287, -0.0116,  ...,  0.8104,  0.7248,  0.7077],
          [-0.0801, -0.0287, -0.0287,  ...,  0.8276,  0.7762,  0.7419]],

         [[-1.7906, -1.7381, -1.7206,  ..., -1.5805, -1.5805, -1.5805],
          [-1.7381, -1.7031, -1.7031,  ..., -1.5630, -1.5630, -1.5980],
          [-1.7381, -1.7381, -1.7381,  ..., -1.5805, -1.5630, -1.5630],
          ...,
          [-0.2850, -0.2150, -0.1275,  ...,  0.7129,  0.6779,  0.6078],
          [-0.2675, -0.1800, -0.1450,  ...,  0.6604,  0.5903,  0.5728],
          [-0.2325, -0.1975, -0.1975,  ...,  0.6779,  0.6254,  0.5903]],

         [[-1.5953, -1.5953, -1.5953,  ..., -1.4733, -1.4384, -1.4210],
          [-1.5779, -1.

target:  tensor([ 11,  51, 102,  13,  37,  50,  15,  15], device='cuda:0')
tensor([[79],
        [ 4],
        [ 4],
        [17],
        [ 4],
        [ 4],
        [ 4],
        [52]], device='cuda:0')
Data:  tensor([[[[-0.5253, -0.4054, -0.5938,  ..., -0.5938, -0.6794, -0.6109],
          [-0.5424, -0.5082, -0.8164,  ..., -0.3712, -0.8335, -0.9363],
          [-0.6452, -0.4739, -0.6794,  ...,  0.1254, -0.0116,  0.1597],
          ...,
          [-0.7822, -0.3369, -0.2856,  ..., -0.9020, -0.7479, -0.4568],
          [-0.5424, -0.3369, -0.4911,  ..., -0.3883, -0.3027, -0.1657],
          [-0.3541, -0.3883, -0.6452,  ..., -0.6623, -0.3883, -0.3027]],

         [[-0.7052, -0.5651, -0.7752,  ..., -0.7927, -0.8277, -0.7402],
          [-0.7227, -0.6702, -0.9853,  ..., -0.5126, -0.9503, -1.0553],
          [-0.8277, -0.6176, -0.8102,  ...,  0.0126, -0.0749,  0.0826],
          ...,
          [-1.0728, -0.6176, -0.5651,  ..., -1.2129, -1.0203, -0.7227],
          [-0.8277, -0.6176, -0.7752

        [4]], device='cuda:0')
Data:  tensor([[[[-0.7308, -0.0972, -0.4739,  ..., -0.0287,  0.0741,  0.1254],
          [-0.4226, -0.4397, -0.9363,  ...,  0.0398, -0.3198, -0.5596],
          [-0.5767, -1.1589, -1.3473,  ..., -0.2171,  0.3652, -0.0972],
          ...,
          [-0.7650, -1.0562, -1.2617,  ..., -1.2103, -0.8849, -0.7479],
          [-0.8164, -0.5253, -0.4911,  ..., -0.9363, -0.5767, -0.6109],
          [-1.2445, -0.8164, -0.1314,  ..., -0.7822, -0.6452, -0.6281]],

         [[-0.5301,  0.2052, -0.0224,  ..., -0.1099,  0.1001,  0.2577],
          [ 0.0126,  0.0126, -0.5826,  ..., -0.0749, -0.2850, -0.4426],
          [-0.1450, -0.8452, -1.1429,  ..., -0.2325,  0.2577, -0.1975],
          ...,
          [-0.2500, -0.5651, -0.8102,  ..., -0.7402, -0.3725, -0.2325],
          [-0.3550, -0.0224,  0.0651,  ..., -0.4601, -0.0924, -0.1625],
          [-0.8277, -0.3901,  0.3978,  ..., -0.2850, -0.1625, -0.1800]],

         [[-0.3578,  0.3916,  0.0605,  ..., -0.6541, -0.2532, -0

target:  tensor([ 85,  33,  59,  89,  36, 126, 112,  48], device='cuda:0')
tensor([[4],
        [4],
        [4],
        [4],
        [0],
        [4],
        [4],
        [4]], device='cuda:0')
Data:  tensor([[[[-1.0219, -0.7993, -0.6281,  ...,  0.3309,  0.3994,  0.3823],
          [-0.5253, -0.3027, -0.2513,  ...,  0.2796,  0.3481,  0.3652],
          [-0.3541, -0.1828, -0.2171,  ...,  0.1939,  0.2453,  0.3309],
          ...,
          [-0.6794, -0.7137, -0.5767,  ...,  0.2624,  0.2967,  0.2453],
          [-0.6281, -0.6965, -0.7650,  ...,  0.3309,  0.2796,  0.3309],
          [-0.4739, -0.4397, -0.4226,  ...,  0.4337,  0.3138,  0.1597]],

         [[ 0.3102,  0.5028,  0.5553,  ..., -0.3725, -0.3901, -0.4076],
          [ 0.6954,  0.8704,  0.8704,  ..., -0.4426, -0.4426, -0.4251],
          [ 0.6604,  0.8004,  0.7479,  ..., -0.4951, -0.4951, -0.4251],
          ...,
          [-0.6702, -0.7577, -0.6176,  ...,  0.3102,  0.3627,  0.3102],
          [-0.5826, -0.6702, -0.7577,  ..., 

Data:  tensor([[[[-0.0116,  0.0741,  0.0056,  ...,  0.6049,  0.2282,  0.1083],
          [ 0.0741,  0.3994,  0.2111,  ...,  0.6563,  0.3481,  0.2111],
          [-0.0801, -0.0458, -0.1143,  ...,  0.4679, -0.0287, -0.4054],
          ...,
          [-0.4226, -0.5424, -0.7822,  ..., -0.4568, -0.5082, -0.6281],
          [-0.2513, -0.2684, -0.4739,  ..., -0.9363, -0.9877, -0.7650],
          [-0.2856, -0.1657, -0.3541,  ..., -0.9877, -1.1075, -1.0048]],

         [[ 0.6078,  0.5553,  0.4153,  ...,  1.2381,  0.8354,  0.7829],
          [ 0.6078,  0.8880,  0.6254,  ...,  1.2206,  1.0105,  0.9755],
          [ 0.3978,  0.4678,  0.4503,  ...,  1.0280,  0.6078,  0.3102],
          ...,
          [ 0.1176,  0.2752, -0.0399,  ...,  0.1527,  0.1527,  0.0301],
          [ 0.2927,  0.5378,  0.2402,  ..., -0.3375, -0.2675, -0.1099],
          [ 0.2402,  0.4678,  0.3102,  ..., -0.2325, -0.2850, -0.3550]],

         [[-0.5670, -0.4798, -0.5321,  ...,  0.0605, -0.4101, -0.6715],
          [-0.6367, -0.

Data:  tensor([[[[-1.1075, -1.0733, -0.9877,  ..., -1.0904,  0.3309,  1.5125],
          [-1.2788, -1.2103, -1.0048,  ..., -1.3130, -0.3541,  1.2385],
          [-0.9020, -1.1418, -1.0904,  ..., -1.3302, -1.1247, -0.2856],
          ...,
          [-0.4397, -0.5253, -0.6794,  ...,  0.2453,  0.2624, -0.1143],
          [-0.4054, -0.5767, -0.4226,  ..., -0.0458, -0.0629, -0.2342],
          [-0.3198, -0.5596, -0.5767,  ...,  0.0056,  0.0227, -0.1486]],

         [[-0.7402, -0.7052, -0.6176,  ..., -1.2129,  0.2577,  1.4832],
          [-0.9153, -0.8627, -0.6352,  ..., -1.3704, -0.3901,  1.2381],
          [-0.5476, -0.7927, -0.7402,  ..., -1.3354, -1.1253, -0.3200],
          ...,
          [-0.0924, -0.1450, -0.2675,  ...,  0.9405,  0.9580,  0.5028],
          [-0.0924, -0.1625,  0.0476,  ...,  0.6604,  0.6078,  0.3803],
          [-0.0224, -0.1450, -0.1275,  ...,  0.7129,  0.7129,  0.4678]],

         [[-0.9330, -0.8633, -0.7936,  ..., -1.1770,  0.3219,  1.6117],
          [-1.0376, -0.

Output:  tensor([[ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        ...,
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493]],
       device='cuda:0')
target:  tensor([117,  36,  75, 111,  12,  14,  19,  95], device='cuda:0')
tensor([[4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4]], device='cuda:0')
Data:  tensor([[[[ 0.3994,  0.8276,  0.8104,  ...,  0.5536,  0.7419,  0.7077],
          [ 0.5878,  0.7933,  0.7933,  ...,  0.2453,  0.4508,  0.5878],
          [ 0.3994,  0.5707,  0.7419,  ...,  0.2111,  0.1939,  0.4337],
          ...,
          [-0.2513, -0.2856,  0.3481,  ...,  0.6221,  0.5364,  0.0741],
          [ 0.0056, -0.388

         -6.6454e-01, -5.4926e-01]], device='cuda:0')
target:  tensor([118,   3, 128,   9,  40, 102,  69,  60], device='cuda:0')
tensor([[  4],
        [  4],
        [  4],
        [  4],
        [  4],
        [  4],
        [106],
        [  4]], device='cuda:0')
Data:  tensor([[[[-2.0665, -2.0665, -2.0494,  ..., -1.8610, -1.8610, -1.8439],
          [-2.0494, -2.0494, -2.0494,  ..., -1.8610, -1.8439, -1.8439],
          [-2.0323, -2.0494, -2.0494,  ..., -1.8439, -1.8439, -1.8268],
          ...,
          [ 0.5193,  0.4679,  0.4508,  ..., -0.8164, -0.8164, -0.8335],
          [ 0.5193,  0.4508,  0.4508,  ..., -0.7137, -0.7479, -0.7993],
          [ 0.5536,  0.4851,  0.4508,  ..., -0.1828, -0.2171, -0.3027]],

         [[-1.9482, -1.9482, -1.9307,  ..., -1.7906, -1.7906, -1.7731],
          [-1.9307, -1.9307, -1.9307,  ..., -1.7906, -1.7731, -1.7731],
          [-1.9132, -1.9307, -1.9307,  ..., -1.7731, -1.7731, -1.7556],
          ...,
          [ 0.2227,  0.1352,  0.1001,  ..., -1

Data:  tensor([[[[-0.7822, -1.1932, -0.7822,  ..., -0.7650, -0.6452, -0.8164],
          [-0.5253, -0.8164, -0.6281,  ..., -1.0390, -0.8678, -0.7650],
          [-0.5082, -0.4397, -0.2513,  ..., -0.9363, -0.4739, -0.6109],
          ...,
          [-0.8507, -0.8335, -0.5082,  ..., -1.3130, -0.5253, -0.9192],
          [-0.8678, -1.1760, -0.9705,  ..., -1.3815, -0.6965, -0.6623],
          [-0.7822, -1.0390, -1.1589,  ..., -1.0733, -1.0048, -0.7822]],

         [[-0.0049, -0.3550, -0.0749,  ...,  0.0651,  0.1352, -0.1099],
          [ 0.2577,  0.0126,  0.1176,  ..., -0.2675, -0.1275, -0.0399],
          [ 0.2752,  0.3452,  0.4328,  ..., -0.2150,  0.2402,  0.0651],
          ...,
          [-0.5301, -0.5301, -0.1800,  ..., -0.7752,  0.1702, -0.2325],
          [-0.4426, -0.7927, -0.6352,  ..., -0.6001,  0.1352,  0.0826],
          [-0.2500, -0.4776, -0.7402,  ..., -0.0924, -0.1450, -0.0574]],

         [[-1.4384, -1.6824, -1.5256,  ..., -1.5604, -1.3513, -1.4384],
          [-1.1944, -1.

Data:  tensor([[[[ 1.7180,  2.0777,  2.2147,  ...,  0.4337,  0.2111,  0.3481],
          [ 1.6153,  1.9749,  2.1633,  ...,  0.6392,  0.5878,  0.3309],
          [ 1.8722,  2.0948,  2.1462,  ...,  0.6563,  0.8961,  0.3994],
          ...,
          [-0.5082, -0.4568, -0.3712,  ..., -1.4329, -1.3473, -1.5699],
          [ 0.0227,  0.1254,  0.2111,  ..., -1.5357, -1.3815, -1.4500],
          [ 0.2453,  0.4166,  0.5193,  ..., -1.5870, -1.5014, -1.4672]],

         [[ 1.1681,  1.6232,  2.1660,  ...,  0.9930,  0.8354,  0.8704],
          [ 1.1681,  1.5182,  1.9384,  ...,  1.0805,  0.9930,  0.8354],
          [ 1.6758,  1.8683,  1.9734,  ...,  1.1681,  1.2206,  0.7479],
          ...,
          [-0.6877, -0.6352, -0.5476,  ..., -1.3880, -1.3354, -1.5105],
          [-0.3025, -0.1275, -0.0049,  ..., -1.5105, -1.3354, -1.3880],
          [-0.0049,  0.2227,  0.3452,  ..., -1.5105, -1.4405, -1.4055]],

         [[ 1.0714,  1.6814,  2.2914,  ...,  0.5136,  0.4265,  0.5136],
          [ 1.0539,  1.

       device='cuda:0')
target:  tensor([52, 56, 77, 75, 47, 46, 50,  2], device='cuda:0')
tensor([[ 4],
        [ 4],
        [26],
        [ 4],
        [ 4],
        [ 4],
        [ 4],
        [ 4]], device='cuda:0')
Data:  tensor([[[[ 1.2899,  1.3070,  1.2899,  ...,  0.9646,  0.9817,  0.9988],
          [ 1.2728,  1.2728,  1.3070,  ...,  0.9646,  0.9474,  0.9474],
          [ 1.2728,  1.2728,  1.3070,  ...,  0.9817,  0.9817,  0.9817],
          ...,
          [ 1.2899,  1.2728,  1.2043,  ...,  0.3823, -0.1486, -0.2856],
          [ 1.2557,  1.2385,  1.2728,  ...,  0.3481,  0.0227,  0.2111],
          [ 1.2043,  1.1529,  1.0502,  ...,  0.2796,  0.2282,  0.7762]],

         [[ 1.3256,  1.3431,  1.3256,  ...,  0.8529,  0.8704,  0.8704],
          [ 1.3081,  1.3081,  1.3431,  ...,  0.8529,  0.8354,  0.8179],
          [ 1.3081,  1.3081,  1.3431,  ...,  0.8704,  0.8704,  0.8529],
          ...,
          [ 1.8333,  1.7633,  1.7283,  ...,  0.4678,  0.0126, -0.0049],
          [ 1.7283, 

Data:  tensor([[[[-1.2445, -1.2103, -1.1932,  ...,  2.0777,  2.0948,  2.0777],
          [-1.2959, -1.2445, -1.2617,  ...,  2.1633,  2.1975,  2.1462],
          [-1.2617, -1.2788, -1.2617,  ...,  1.8037,  1.8379,  1.8550],
          ...,
          [-0.2856, -0.2513, -0.2171,  ...,  0.4166,  0.4166,  0.3994],
          [-0.3369, -0.3541, -0.3712,  ...,  0.3309,  0.3481,  0.3309],
          [-0.3027, -0.3712, -0.3712,  ...,  0.2624,  0.2796,  0.2624]],

         [[-1.3004, -1.3354, -1.3354,  ...,  2.2360,  2.2535,  2.2360],
          [-1.3179, -1.3354, -1.3004,  ...,  2.3235,  2.3585,  2.3060],
          [-1.3354, -1.3004, -1.2654,  ...,  1.9384,  2.0084,  2.0259],
          ...,
          [-0.1450, -0.1099, -0.0574,  ...,  0.5903,  0.5903,  0.5728],
          [-0.2150, -0.1975, -0.1975,  ...,  0.5378,  0.5203,  0.5028],
          [-0.1975, -0.1975, -0.2150,  ...,  0.4678,  0.4328,  0.4328]],

         [[-1.4036, -1.4210, -1.3513,  ...,  2.5354,  2.5529,  2.5180],
          [-1.4559, -1.

Data:  tensor([[[[ 1.5810,  1.5982,  1.6324,  ..., -0.8849, -0.7993, -0.8164],
          [ 1.6495,  1.6324,  1.6667,  ..., -0.8507, -0.7822, -0.8507],
          [ 1.6495,  1.6495,  1.6495,  ..., -0.8507, -0.7822, -0.8164],
          ...,
          [ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489],
          [ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489],
          [ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489]],

         [[ 1.6057,  1.6583,  1.6933,  ..., -0.4601, -0.4076, -0.4251],
          [ 1.6583,  1.7108,  1.7458,  ..., -0.4426, -0.4076, -0.4776],
          [ 1.6583,  1.7283,  1.7283,  ..., -0.4251, -0.4076, -0.4601],
          ...,
          [ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286],
          [ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286],
          [ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286]],

         [[ 2.0300,  2.0300,  2.0648,  ...,  0.1651,  0.2348,  0.2173],
          [ 2.0648,  2.

Data:  tensor([[[[ 1.7865,  1.7694,  1.7694,  ..., -0.8678, -0.7479, -0.5082],
          [ 1.8037,  1.7865,  1.7865,  ..., -0.9877, -0.9363, -0.7650],
          [ 1.8037,  1.8037,  1.7865,  ..., -1.1075, -1.1075, -1.0390],
          ...,
          [ 1.4954,  1.4783,  1.4783,  ...,  1.1358,  1.3070,  1.2214],
          [ 1.4954,  1.4783,  1.4783,  ...,  0.8618,  0.9817,  1.1015],
          [ 1.4954,  1.4612,  1.4612,  ...,  1.0331,  1.2214,  1.3413]],

         [[ 1.9384,  1.9209,  1.9209,  ..., -0.9678, -0.8452, -0.6001],
          [ 1.9559,  1.9384,  1.9384,  ..., -1.0903, -1.0378, -0.8803],
          [ 1.9559,  1.9559,  1.9384,  ..., -1.2129, -1.2129, -1.1429],
          ...,
          [ 1.6583,  1.6408,  1.6408,  ...,  1.0630,  1.2731,  1.2031],
          [ 1.6583,  1.6408,  1.6408,  ...,  0.8004,  0.9405,  1.0980],
          [ 1.6583,  1.6232,  1.6232,  ...,  0.9755,  1.2031,  1.3782]],

         [[ 2.2566,  2.2391,  2.2391,  ..., -0.5670, -0.4450, -0.2184],
          [ 2.2740,  2.

Data:  tensor([[[[ 1.4612,  1.0331,  0.5878,  ...,  2.0777,  2.0605,  2.0948],
          [ 1.2214,  0.8104,  0.4851,  ...,  2.1119,  2.0948,  2.1290],
          [ 1.0159,  0.5878,  0.2624,  ...,  2.1633,  2.1462,  2.1462],
          ...,
          [ 2.0263,  2.0605,  2.0948,  ..., -0.7137, -0.2684,  0.4166],
          [ 1.9578,  2.0434,  2.1633,  ..., -0.6281, -0.2171, -0.0458],
          [ 1.9749,  2.0777,  2.1462,  ..., -0.5938, -0.4911, -0.2856]],

         [[ 2.3936,  2.3060,  2.1660,  ...,  1.8859,  1.8683,  1.8859],
          [ 2.3235,  2.2185,  2.1835,  ...,  1.9034,  1.8859,  1.8859],
          [ 2.2885,  2.1134,  2.0434,  ...,  1.9034,  1.8859,  1.8683],
          ...,
          [ 2.2360,  2.2185,  2.1310,  ..., -0.8978, -0.4426,  0.2402],
          [ 2.1310,  2.1310,  2.1134,  ..., -0.8102, -0.4076, -0.2325],
          [ 2.0434,  2.0434,  1.9734,  ..., -0.7927, -0.7052, -0.4951]],

         [[ 2.5529,  2.4657,  2.4483,  ...,  1.2282,  1.2108,  1.2457],
          [ 2.4831,  2.

       device='cuda:0')
Output:  tensor([[ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        ...,
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493],
        [ 0.2058,  0.0920, -0.0145,  ..., -0.5479, -0.6645, -0.5493]],
       device='cuda:0')
target:  tensor([ 51, 120,  74,  61,  41,  33,  87, 102], device='cuda:0')
tensor([[4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4]], device='cuda:0')
Data:  tensor([[[[-1.3987, -1.6042, -1.4672,  ..., -0.6623, -0.7822, -0.9020],
          [-0.8678, -0.9877, -1.0733,  ..., -0.1828, -0.3883, -0.6452],
          [-0.1999, -0.1143, -0.3027,  ..., -0.4397, -0.3369, -0.3541],
          ...,
          [ 0.2624,  0.0398,  0.2796,  ..., -0.2342, -0.1486, -0.1486],
  

Data:  tensor([[[[ 1.4440,  1.4440,  1.4612,  ...,  1.3927,  1.3755,  1.3242],
          [ 1.4269,  1.4440,  1.4612,  ...,  1.3584,  1.3755,  1.3070],
          [ 1.4098,  1.4098,  1.4440,  ...,  1.3242,  1.3584,  1.3413],
          ...,
          [ 1.0673,  0.9817,  0.9817,  ..., -0.0458, -0.0458, -0.0629],
          [ 1.0844,  1.0159,  0.9646,  ..., -0.0629, -0.0458,  0.0056],
          [ 0.9817,  1.0159,  0.9988,  ..., -0.0458, -0.0116,  0.0398]],

         [[ 1.6232,  1.6232,  1.6408,  ...,  1.5182,  1.5007,  1.4482],
          [ 1.6057,  1.6232,  1.6408,  ...,  1.4832,  1.5007,  1.4307],
          [ 1.5882,  1.5882,  1.6232,  ...,  1.4482,  1.4832,  1.4657],
          ...,
          [ 1.0980,  1.0105,  1.0105,  ..., -0.0399, -0.0399, -0.0574],
          [ 1.1155,  1.0455,  0.9930,  ..., -0.0574, -0.0399,  0.0126],
          [ 1.0105,  1.0455,  1.0280,  ..., -0.0399, -0.0049,  0.0476]],

         [[ 1.7337,  1.7337,  1.7511,  ...,  1.5245,  1.5071,  1.4548],
          [ 1.7163,  1.

Data:  tensor([[[[ 0.5707,  0.5878,  0.6049,  ...,  0.6906,  0.6906,  0.6734],
          [ 0.5707,  0.6049,  0.5878,  ...,  0.7077,  0.6906,  0.6906],
          [ 0.5878,  0.5707,  0.5878,  ...,  0.6906,  0.6734,  0.6563],
          ...,
          [ 0.3652,  0.3309,  0.3138,  ..., -0.2684, -0.1999, -0.1314],
          [ 0.3138,  0.3138,  0.2967,  ..., -0.3712, -0.3198, -0.2684],
          [ 0.2624,  0.2453,  0.2624,  ..., -0.4054, -0.3712, -0.3198]],

         [[ 0.4328,  0.4503,  0.4678,  ...,  0.5903,  0.5728,  0.5903],
          [ 0.4328,  0.4678,  0.4503,  ...,  0.5903,  0.5728,  0.5903],
          [ 0.4503,  0.4153,  0.4503,  ...,  0.5903,  0.6078,  0.5728],
          ...,
          [ 0.4153,  0.3803,  0.3627,  ..., -0.0924, -0.0399,  0.0826],
          [ 0.3803,  0.3803,  0.3978,  ..., -0.2500, -0.1625, -0.0749],
          [ 0.3452,  0.3452,  0.3627,  ..., -0.3725, -0.2850, -0.2150]],

         [[ 0.3916,  0.4091,  0.4614,  ...,  0.6531,  0.6705,  0.6182],
          [ 0.3916,  0.

Data:  tensor([[[[ 0.1426,  0.1083,  0.0398,  ..., -1.6213, -1.6042, -1.6384],
          [ 0.1083,  0.0912,  0.0741,  ..., -1.6384, -1.6213, -1.6042],
          [ 0.1083,  0.0912,  0.1083,  ..., -1.6213, -1.6213, -1.5870],
          ...,
          [ 1.3070,  1.2728,  1.3242,  ...,  1.5125,  1.5125,  1.4783],
          [ 1.3242,  1.3927,  1.4098,  ...,  1.5125,  1.5468,  1.5297],
          [ 1.3584,  1.3927,  1.4098,  ...,  1.5297,  1.4783,  1.5297]],

         [[-0.5126, -0.4601, -0.5651,  ..., -1.5280, -1.5105, -1.5280],
          [-0.5126, -0.5126, -0.5301,  ..., -1.5455, -1.5280, -1.5105],
          [-0.5126, -0.5651, -0.5476,  ..., -1.4930, -1.4930, -1.4930],
          ...,
          [ 1.4482,  1.4307,  1.4832,  ...,  1.6758,  1.6583,  1.5707],
          [ 1.4832,  1.5532,  1.5707,  ...,  1.6583,  1.6583,  1.6583],
          [ 1.5007,  1.5532,  1.5707,  ...,  1.6933,  1.6057,  1.6933]],

         [[-1.0376, -1.1073, -1.2293,  ..., -1.3164, -1.2990, -1.3339],
          [-1.0898, -1.

Data:  tensor([[[[ 0.2453,  0.3138,  0.3652,  ..., -0.0801, -0.0972, -0.0801],
          [ 0.1083,  0.1939,  0.2282,  ...,  0.0569, -0.0116, -0.0287],
          [ 0.0912,  0.1597,  0.2111,  ...,  0.2453,  0.1768,  0.1426],
          ...,
          [-1.1589, -1.0733, -0.9192,  ...,  0.1083,  0.1939, -0.2684],
          [-1.0390, -0.8678, -0.6794,  ...,  0.3481,  0.1597, -0.1999],
          [-1.0904, -1.0390, -0.9192,  ...,  0.0056, -0.1657, -0.1828]],

         [[ 0.6779,  0.7479,  0.8004,  ...,  0.3102,  0.2577,  0.2752],
          [ 0.5553,  0.6429,  0.6779,  ...,  0.4503,  0.3452,  0.3277],
          [ 0.4678,  0.5378,  0.5903,  ...,  0.6078,  0.5378,  0.5028],
          ...,
          [-0.8102, -0.7052, -0.5476,  ...,  0.5553,  0.6429,  0.1702],
          [-0.6702, -0.4951, -0.3025,  ...,  0.8004,  0.6078,  0.2402],
          [-0.7227, -0.6702, -0.5476,  ...,  0.4503,  0.2752,  0.2577]],

         [[-0.4973, -0.4275, -0.3753,  ..., -0.9504, -0.9504, -0.8981],
          [-0.6541, -0.

Data:  tensor([[[[-1.0562, -0.9192, -0.8849,  ..., -0.8507, -0.8335, -0.9192],
          [-0.9705, -1.0048, -1.0733,  ..., -0.9877, -0.9877, -0.8678],
          [-1.0390, -0.9534, -0.8507,  ..., -0.9534, -0.8678, -0.9192],
          ...,
          [-1.2445, -1.2617, -1.2274,  ..., -0.8335, -0.6794, -0.4397],
          [-1.3130, -1.3302, -1.2445,  ..., -0.8678, -0.7650, -0.8335],
          [-1.1075, -1.1760, -1.1760,  ..., -0.6109, -0.5424, -0.6965]],

         [[-1.0203, -0.8803, -0.8452,  ..., -0.8102, -0.7927, -0.8803],
          [-0.9328, -0.9678, -1.0378,  ..., -0.9503, -0.9503, -0.8277],
          [-1.0028, -0.9153, -0.8102,  ..., -0.9153, -0.8277, -0.8803],
          ...,
          [-1.1604, -1.1779, -1.1429,  ..., -0.8803, -0.7227, -0.4776],
          [-1.2479, -1.2654, -1.1604,  ..., -0.9153, -0.7927, -0.8627],
          [-1.0728, -1.1253, -1.1429,  ..., -0.6001, -0.5126, -0.6702]],

         [[-0.8110, -0.6715, -0.6367,  ..., -0.5844, -0.5844, -0.6715],
          [-0.7238, -0.

Data:  tensor([[[[-1.4843, -1.4843, -1.5870,  ..., -1.2788, -1.2103, -1.1075],
          [-1.4843, -1.5528, -1.6555,  ..., -1.3130, -1.2274, -1.1075],
          [-1.5014, -1.6042, -1.7069,  ..., -1.3130, -1.2103, -1.0733],
          ...,
          [ 0.4679,  0.3481,  0.3994,  ...,  0.1768,  0.0912, -0.0287],
          [ 0.4679,  0.2453,  0.2453,  ...,  0.0741,  0.0227, -0.0458],
          [ 0.4679,  0.0741,  0.0227,  ...,  0.1939,  0.2282,  0.2111]],

         [[-0.8452, -0.8803, -1.0028,  ..., -1.3004, -1.2654, -1.1779],
          [-0.8452, -0.9328, -1.0728,  ..., -1.3354, -1.2654, -1.1779],
          [-0.8627, -0.9853, -1.1253,  ..., -1.3354, -1.2654, -1.1429],
          ...,
          [ 0.7304,  0.6254,  0.6779,  ...,  0.2577,  0.2577,  0.2052],
          [ 0.7479,  0.5378,  0.5378,  ...,  0.1527,  0.1702,  0.1877],
          [ 0.7829,  0.3803,  0.3277,  ...,  0.2752,  0.3978,  0.4853]],

         [[-1.3861, -1.4036, -1.4907,  ..., -0.9330, -0.8807, -0.7936],
          [-1.3513, -1.

Data:  tensor([[[[ 2.1804,  2.2147,  2.2489,  ...,  1.3070,  1.4954,  1.6324],
          [ 2.1633,  2.2147,  2.2318,  ...,  1.2557,  1.4440,  1.6153],
          [ 2.1290,  2.1804,  2.1975,  ...,  1.2043,  1.4098,  1.5810],
          ...,
          [ 1.4783,  1.3242,  0.8447,  ..., -0.0972, -0.1143, -0.1143],
          [ 1.4098,  1.2899,  0.8447,  ..., -0.1486, -0.1314, -0.1143],
          [ 1.3755,  1.2728,  0.8789,  ..., -0.1657, -0.1486, -0.1143]],

         [[ 2.3761,  2.3936,  2.3936,  ...,  1.4132,  1.4132,  1.2906],
          [ 2.3761,  2.3936,  2.3936,  ...,  1.3431,  1.3606,  1.2556],
          [ 2.3585,  2.3761,  2.3761,  ...,  1.2731,  1.3081,  1.2206],
          ...,
          [ 1.7983,  1.6232,  0.9055,  ..., -0.5126, -0.5301, -0.5301],
          [ 1.7108,  1.5532,  0.9055,  ..., -0.5651, -0.5476, -0.5301],
          [ 1.6583,  1.5532,  0.9405,  ..., -0.5826, -0.5651, -0.5301]],

         [[ 2.5703,  2.5354,  2.5180,  ...,  1.2631,  1.2631,  1.0539],
          [ 2.5877,  2.

Data:  tensor([[[[-1.5870, -1.6384, -1.5699,  ..., -0.3712, -0.4911, -0.5424],
          [-1.4500, -1.5014, -1.4843,  ..., -0.3712, -0.3027, -0.2684],
          [-1.3644, -1.4158, -1.4158,  ..., -0.4911, -0.2513, -0.1314],
          ...,
          [ 0.0398,  0.0398,  0.0398,  ...,  1.2043,  0.9303,  0.7591],
          [ 0.0227,  0.0227,  0.0227,  ...,  1.2385,  1.2385,  1.1187],
          [ 0.0227,  0.0227,  0.0227,  ...,  1.2043,  1.3927,  1.3413]],

         [[-1.3880, -1.4405, -1.3704,  ..., -0.4076, -0.5301, -0.5826],
          [-1.2479, -1.3004, -1.2829,  ..., -0.4076, -0.3375, -0.3025],
          [-1.1604, -1.2129, -1.2304,  ..., -0.5301, -0.2850, -0.1625],
          ...,
          [ 0.2052,  0.2052,  0.2052,  ...,  1.5882,  1.2906,  1.0280],
          [ 0.1877,  0.1877,  0.1877,  ...,  1.6758,  1.6583,  1.4657],
          [ 0.1877,  0.1877,  0.1877,  ...,  1.6758,  1.8508,  1.7108]],

         [[-1.1944, -1.2467, -1.1770,  ..., -0.1312, -0.2532, -0.3055],
          [-1.0550, -1.

Data:  tensor([[[[ 0.0227,  0.0056,  0.0398,  ...,  0.2111,  0.2453,  0.3652],
          [ 0.0398,  0.0227,  0.0056,  ...,  0.2967,  0.2967,  0.4166],
          [ 0.1426,  0.1426,  0.0056,  ...,  0.3309,  0.2796,  0.3994],
          ...,
          [-0.7137, -0.9705, -1.0219,  ...,  0.0569,  0.0569, -0.1828],
          [-0.6965, -0.9020, -1.0733,  ..., -0.1314,  0.0056, -0.2513],
          [-0.7137, -1.1247, -1.2103,  ..., -0.1999, -0.0287, -0.2342]],

         [[ 0.8880,  0.8529,  0.8354,  ...,  1.1155,  1.1506,  1.2381],
          [ 0.9055,  0.8704,  0.8179,  ...,  1.1856,  1.1856,  1.2906],
          [ 0.9580,  0.9580,  0.8179,  ...,  1.2206,  1.1681,  1.2556],
          ...,
          [ 0.1527, -0.0924, -0.0924,  ...,  0.9405,  0.9755,  0.7304],
          [ 0.0651, -0.1275, -0.2325,  ...,  0.7479,  0.9230,  0.6604],
          [-0.0049, -0.4776, -0.5651,  ...,  0.6779,  0.9230,  0.6604]],

         [[-0.1138, -0.0790, -0.0092,  ..., -0.0790, -0.0615,  0.0953],
          [-0.0615, -0.

Data:  tensor([[[[-1.7069, -1.7069, -1.7412,  ..., -1.8268, -1.8268, -1.8097],
          [-1.7240, -1.6898, -1.6898,  ..., -1.8097, -1.8097, -1.7925],
          [-1.7069, -1.7069, -1.6898,  ..., -1.7925, -1.7754, -1.7412],
          ...,
          [-1.5528, -1.5528, -1.5699,  ...,  0.9132,  0.8789,  0.7419],
          [-1.6042, -1.5870, -1.6213,  ...,  0.8961,  0.8447,  0.8104],
          [-1.6727, -1.7069, -1.7240,  ...,  0.7933,  0.8447,  0.8961]],

         [[-1.6155, -1.6155, -1.6506,  ..., -1.6681, -1.7031, -1.6856],
          [-1.6331, -1.5980, -1.5980,  ..., -1.6681, -1.6856, -1.6856],
          [-1.6155, -1.6155, -1.5980,  ..., -1.6506, -1.6506, -1.6155],
          ...,
          [-1.4230, -1.4405, -1.4405,  ...,  0.4853,  0.5203,  0.3452],
          [-1.4580, -1.4580, -1.4930,  ...,  0.4503,  0.4503,  0.3978],
          [-1.5105, -1.5455, -1.5630,  ...,  0.4853,  0.4678,  0.4328]],

         [[-1.4210, -1.4210, -1.4559,  ..., -1.4559, -1.4907, -1.4733],
          [-1.4384, -1.

Data:  tensor([[[[-0.5596, -0.5424, -0.4739,  ..., -1.0733, -1.0904, -1.2959],
          [-0.5082, -0.4911, -0.5082,  ..., -1.1760, -1.3302, -1.4329],
          [-0.5424, -0.5596, -0.5253,  ..., -1.2445, -1.2788, -1.4158],
          ...,
          [-0.8335, -0.7479, -0.3198,  ..., -1.4158, -1.2617, -1.3302],
          [-0.8335, -0.4397,  0.3823,  ..., -1.5699, -1.4500, -1.4329],
          [-0.5596, -0.5424, -0.5424,  ..., -1.3473, -1.4158, -1.5185]],

         [[-0.5476, -0.5651, -0.5126,  ..., -0.8978, -0.9853, -1.1954],
          [-0.5126, -0.5126, -0.5826,  ..., -0.9678, -1.1779, -1.3354],
          [-0.5476, -0.6001, -0.5651,  ..., -1.0203, -1.1078, -1.3179],
          ...,
          [-0.8627, -0.8102, -0.3725,  ..., -1.4405, -1.2654, -1.3179],
          [-0.8803, -0.5126,  0.3277,  ..., -1.5455, -1.4405, -1.4405],
          [-0.5826, -0.6001, -0.6001,  ..., -1.3004, -1.4580, -1.5805]],

         [[-0.6018, -0.5844, -0.4798,  ..., -0.6541, -0.7064, -0.9330],
          [-0.5844, -0.

Data:  tensor([[[[ 0.5022,  0.1939,  0.0741,  ..., -1.8097, -1.8268, -1.8268],
          [ 0.7933,  0.6221,  0.3652,  ..., -1.7925, -1.8268, -1.8268],
          [ 0.3994,  0.1597,  0.1426,  ..., -1.7925, -1.8268, -1.8268],
          ...,
          [ 0.3481,  0.2796,  0.1254,  ...,  0.9474,  1.0331,  1.0159],
          [ 0.3309,  0.4851,  0.3652,  ...,  0.8961,  0.9646,  0.9303],
          [-0.0116,  0.3309,  0.3309,  ...,  0.7591,  0.8961,  0.8104]],

         [[ 0.4678, -0.0574, -0.2675,  ..., -1.9657, -1.9832, -1.9832],
          [ 0.7654,  0.3803,  0.0301,  ..., -1.9657, -1.9832, -1.9832],
          [ 0.3627, -0.0749, -0.1800,  ..., -1.9832, -1.9832, -1.9832],
          ...,
          [-0.9853, -1.3004, -1.5630,  ..., -2.0007, -1.9832, -2.0182],
          [-0.7052, -0.8627, -1.1253,  ..., -2.0357, -2.0007, -2.0357],
          [-1.0903, -0.8978, -1.0028,  ..., -2.0357, -1.9657, -2.0357]],

         [[ 0.6705, -0.1138, -0.3927,  ..., -1.7870, -1.7696, -1.7696],
          [ 0.9668,  0.

Data:  tensor([[[[ 1.6838,  1.2385,  2.1975,  ..., -0.4911, -0.3883,  0.3994],
          [ 2.0777,  1.4098,  2.1633,  ..., -0.5253, -0.2513,  0.1939],
          [ 1.9749,  1.3413,  1.9578,  ..., -0.4568, -0.3027, -0.2342],
          ...,
          [-0.5596, -0.5253, -0.1657,  ..., -0.5767, -0.3027, -0.0801],
          [-0.8335, -0.6965, -0.5424,  ..., -0.7993, -0.7993, -0.2513],
          [-0.9877, -0.6794, -0.6452,  ..., -0.8507, -0.6623, -0.3883]],

         [[ 2.1134,  1.8683,  2.3936,  ..., -0.0574,  0.1001,  1.0105],
          [ 2.3235,  1.9034,  2.3585,  ..., -0.0574,  0.1877,  0.6429],
          [ 2.2535,  1.8158,  2.2185,  ...,  0.0126,  0.0126,  0.1176],
          ...,
          [-0.1800, -0.0574,  0.2577,  ..., -0.3025,  0.0651,  0.5553],
          [-0.2150, -0.1275,  0.0651,  ..., -0.4776, -0.2325,  0.4678],
          [-0.3375, -0.0749, -0.1625,  ..., -0.6702, -0.0574,  0.3803]],

         [[ 2.4657,  2.4308,  2.5877,  ..., -0.3055, -0.0441,  1.0365],
          [ 2.5703,  2.

       device='cuda:0')
target:  tensor([ 68, 107,  90,  60,  31,  31,  97,  90], device='cuda:0')
tensor([[4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4]], device='cuda:0')
Data:  tensor([[[[ 0.1768,  0.1597,  0.0741,  ...,  1.1187,  0.1083, -0.1314],
          [ 0.1426,  0.1939,  0.0912,  ...,  1.1529,  0.0741, -0.0972],
          [ 0.0741,  0.2453,  0.2282,  ...,  1.1529, -0.0116, -0.2342],
          ...,
          [-0.1999, -0.0287, -0.0287,  ...,  0.0056,  0.0741,  0.0912],
          [-0.2684, -0.1828, -0.1999,  ...,  0.1768,  0.2796,  0.3481],
          [-0.2342, -0.1999, -0.2513,  ...,  0.6563,  0.1426,  0.3309]],

         [[ 0.5903,  0.5378,  0.4678,  ...,  1.3782,  0.5028,  0.3803],
          [ 0.5553,  0.5728,  0.4678,  ...,  1.4657,  0.5203,  0.4328],
          [ 0.4678,  0.6429,  0.6254,  ...,  1.4832,  0.4503,  0.3277],
          ...,
          [ 0.0301,  0.1702,  0.1352,  ...,  0.0651,  0.1527,  0.1527],
          [-0.0399, 

Data:  tensor([[[[ 1.2043,  0.3309, -0.5596,  ...,  1.5297,  1.5982,  1.6153],
          [ 1.0673,  0.3823, -0.4568,  ...,  1.6153,  1.7180,  1.7009],
          [ 1.2214,  0.7591, -0.0458,  ...,  1.7009,  1.6495,  1.7865],
          ...,
          [ 0.8618,  0.8961,  0.9303,  ...,  0.9646,  1.0673,  1.0502],
          [ 0.9646,  0.9303,  0.8961,  ...,  0.9817,  1.0673,  1.0673],
          [ 0.9474,  0.8789,  0.8447,  ...,  1.0502,  1.0502,  1.0159]],

         [[ 1.0805,  0.2227, -0.7052,  ...,  1.3606,  1.4307,  1.4657],
          [ 0.9405,  0.2752, -0.5826,  ...,  1.4132,  1.5182,  1.5007],
          [ 1.0980,  0.6604, -0.1625,  ...,  1.4657,  1.4307,  1.5707],
          ...,
          [ 1.0455,  1.0805,  1.1155,  ...,  0.8179,  0.9230,  0.9055],
          [ 1.1506,  1.1155,  1.0805,  ...,  0.8354,  0.9230,  0.9230],
          [ 1.1331,  1.0630,  1.0280,  ...,  0.9055,  0.9055,  0.8704]],

         [[ 0.4265, -0.2184, -0.9156,  ...,  0.9668,  1.0365,  1.0714],
          [ 0.2522, -0.

Data:  tensor([[[[-0.5938, -0.6623, -0.5253,  ...,  0.2967,  0.2624,  0.2624],
          [-0.7822, -0.9363, -0.6965,  ...,  0.3823,  0.3138,  0.0912],
          [-0.8849, -0.8678, -0.7993,  ...,  0.3823,  0.4166,  0.1254],
          ...,
          [-1.7069, -1.6555, -1.5357,  ...,  0.0398, -0.0972, -0.2856],
          [-1.7412, -1.6384, -1.5185,  ...,  0.1597,  0.0056, -0.2171],
          [-1.6898, -1.7240, -1.1247,  ...,  0.2453, -0.0287, -0.2684]],

         [[-0.5651, -0.5651, -0.5301,  ...,  0.7304,  0.6954,  0.6954],
          [-0.6702, -0.7927, -0.7052,  ...,  0.7129,  0.6254,  0.3978],
          [-0.8102, -0.7752, -0.8978,  ...,  1.0280,  1.0105,  0.6604],
          ...,
          [-1.5280, -1.4755, -1.3529,  ...,  0.1527, -0.0574, -0.2850],
          [-1.5805, -1.5280, -1.4055,  ...,  0.2052, -0.0574, -0.3025],
          [-1.6331, -1.5630, -1.0728,  ...,  0.2227, -0.1275, -0.3725]],

         [[-0.3230, -0.4275, -0.3927,  ...,  0.8448,  0.8448,  0.9319],
          [-0.5321, -0.

Data:  tensor([[[[-0.9705, -0.9534, -0.9534,  ...,  1.0331,  0.9988,  1.0502],
          [-1.0219, -1.0219, -1.0390,  ...,  0.9474,  0.9132,  0.9646],
          [-0.9192, -0.9363, -0.9705,  ...,  0.8961,  0.8789,  0.9132],
          ...,
          [ 1.3755,  1.3927,  1.3927,  ..., -0.9020, -0.7822, -0.8507],
          [ 1.3755,  1.3927,  1.3927,  ..., -0.8507, -0.7650, -0.9363],
          [ 1.3755,  1.3927,  1.3927,  ..., -0.7822, -0.7993, -1.0390]],

         [[-0.6877, -0.6702, -0.6702,  ...,  0.1702,  0.1527,  0.2052],
          [-0.7402, -0.7402, -0.7577,  ...,  0.0826,  0.0476,  0.1001],
          [-0.6352, -0.6527, -0.6877,  ...,  0.0301,  0.0126,  0.0301],
          ...,
          [ 1.5357,  1.5532,  1.5532,  ..., -0.9678, -0.8627, -0.9678],
          [ 1.5357,  1.5532,  1.5532,  ..., -0.9503, -0.8627, -1.0553],
          [ 1.5357,  1.5532,  1.5532,  ..., -0.8803, -0.9153, -1.1604]],

         [[-0.3055, -0.2881, -0.2881,  ..., -0.8458, -0.8807, -0.8458],
          [-0.3404, -0.

Data:  tensor([[[[ 1.4612,  1.4612,  1.4783,  ...,  1.6153,  1.6153,  1.6153],
          [ 1.4612,  1.4612,  1.4612,  ...,  1.6153,  1.6153,  1.6153],
          [ 1.4954,  1.4954,  1.4954,  ...,  1.5982,  1.5982,  1.5982],
          ...,
          [ 1.5125,  1.5125,  1.5125,  ...,  1.5982,  1.5982,  1.5982],
          [ 1.5125,  1.5125,  1.5125,  ...,  1.5982,  1.5982,  1.5982],
          [ 1.5125,  1.5125,  1.5125,  ...,  1.5982,  1.5982,  1.5982]],

         [[ 1.6933,  1.6933,  1.7108,  ...,  1.8333,  1.8508,  1.8508],
          [ 1.6933,  1.6933,  1.6933,  ...,  1.8333,  1.8508,  1.8508],
          [ 1.6583,  1.6583,  1.6933,  ...,  1.8333,  1.8333,  1.8333],
          ...,
          [ 1.7458,  1.7458,  1.7458,  ...,  1.8333,  1.8333,  1.8333],
          [ 1.7458,  1.7458,  1.7458,  ...,  1.8333,  1.8333,  1.8333],
          [ 1.7458,  1.7458,  1.7458,  ...,  1.8333,  1.8333,  1.8333]],

         [[ 1.9603,  1.9603,  1.9777,  ...,  2.1171,  2.1171,  2.1171],
          [ 1.9603,  1.

Data:  tensor([[[[-1.2617, -1.2959, -1.3130,  ..., -1.7412, -1.7583, -1.7583],
          [-1.2617, -1.2617, -1.2959,  ..., -1.7583, -1.7925, -1.6898],
          [-1.2445, -1.2617, -1.3130,  ..., -1.7583, -1.6555, -1.4158],
          ...,
          [-1.0390, -0.8678, -0.8678,  ..., -0.9020, -0.7137, -0.5938],
          [-0.7650, -0.6281, -0.7993,  ..., -0.9363, -0.7137, -0.8335],
          [-0.8507, -0.5767, -0.8507,  ..., -0.8164, -0.9534, -0.8678]],

         [[-1.0728, -1.1078, -1.0903,  ..., -1.4930, -1.5105, -1.4930],
          [-1.0728, -1.0903, -1.0903,  ..., -1.4930, -1.5280, -1.4230],
          [-1.0728, -1.0903, -1.1078,  ..., -1.4930, -1.3880, -1.1954],
          ...,
          [-0.6877, -0.4076, -0.4251,  ..., -0.6877, -0.4776, -0.3025],
          [-0.3725, -0.1800, -0.3550,  ..., -0.7577, -0.4426, -0.5126],
          [-0.5126, -0.2500, -0.4426,  ..., -0.5301, -0.5826, -0.5476]],

         [[-1.7870, -1.7870, -1.7522,  ..., -1.6824, -1.6999, -1.7347],
          [-1.7696, -1.

Data:  tensor([[[[-1.3473, -1.3302, -1.4500,  ...,  0.1254,  0.3309,  0.5022],
          [-1.3473, -1.4158, -1.4158,  ...,  0.1254,  0.2967,  0.3309],
          [-1.3644, -1.3130, -1.3302,  ...,  0.1939,  0.4166,  0.2624],
          ...,
          [-0.4739, -0.4226, -0.7993,  ...,  0.4166,  0.5536,  0.5707],
          [-0.4911, -1.0733, -1.5185,  ..., -0.3027, -0.0287,  0.3309],
          [-0.6109, -0.3712, -0.5253,  ..., -0.5082,  0.2796, -0.3369]],

         [[-1.0203, -1.0203, -1.1078,  ...,  0.6604,  0.7129,  0.7654],
          [-1.0203, -1.0378, -1.0553,  ...,  0.6254,  0.6779,  0.5903],
          [-1.0378, -0.8978, -0.9153,  ...,  0.6604,  0.8004,  0.5378],
          ...,
          [-0.0749, -0.0049, -0.4601,  ...,  0.2577,  0.4503,  0.5728],
          [-0.1099, -0.8627, -1.4755,  ..., -0.4426,  0.0476,  0.5378],
          [-0.5476, -0.2850, -0.4776,  ..., -0.1450,  0.7479, -0.0399]],

         [[-0.6193, -0.6018, -0.6890,  ..., -0.7064, -0.4798, -0.2707],
          [-0.6367, -0.

       device='cuda:0')
target:  tensor([104,  79,  15,  83,  98,   9, 105, 109], device='cuda:0')
tensor([[4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4]], device='cuda:0')
Data:  tensor([[[[-0.3883, -0.5424, -0.9020,  ..., -0.7137, -0.8164, -0.7308],
          [-0.3712, -0.4568, -0.7650,  ..., -0.2684, -0.4054, -0.5596],
          [-0.4739, -0.4054, -0.7993,  ..., -0.1828, -0.0287, -0.2513],
          ...,
          [-0.4568, -0.8507, -1.1589,  ..., -1.7069, -1.3644, -1.0390],
          [-0.7479, -0.5082, -0.5253,  ..., -1.2959, -0.5596, -0.0801],
          [-1.1247, -0.3198, -0.2513,  ..., -1.5014, -1.3473, -0.4397]],

         [[ 0.4153,  0.3102, -0.0224,  ..., -0.0574, -0.1975, -0.1099],
          [ 0.3452,  0.3452,  0.0826,  ...,  0.2927,  0.1352, -0.0049],
          [ 0.1877,  0.3452,  0.0301,  ...,  0.2927,  0.4678,  0.2402],
          ...,
          [-0.1099, -0.4426, -0.6527,  ..., -1.3529, -0.8978, -0.4776],
          [-0.3025, 

Data:  tensor([[[[ 0.9303,  0.9474,  0.9303,  ...,  0.3481,  0.3481,  0.3481],
          [ 0.8961,  0.9132,  0.8961,  ...,  0.3481,  0.3481,  0.3652],
          [ 0.8618,  0.8618,  0.8789,  ...,  0.3309,  0.3309,  0.3309],
          ...,
          [ 1.1358,  1.1358,  1.1358,  ...,  1.1529,  1.2385,  1.2385],
          [ 1.1187,  1.1187,  1.1187,  ...,  1.0502,  1.0673,  1.1187],
          [ 1.1358,  1.1015,  1.1187,  ...,  0.8789,  0.9988,  1.1529]],

         [[ 1.0980,  1.0805,  1.0805,  ...,  0.5378,  0.5378,  0.5378],
          [ 1.0630,  1.0455,  1.0630,  ...,  0.5378,  0.5203,  0.5378],
          [ 1.0280,  1.0280,  1.0455,  ...,  0.5378,  0.5378,  0.5203],
          ...,
          [ 0.8354,  0.8354,  0.8354,  ...,  0.4503,  0.4678,  0.4853],
          [ 0.8354,  0.8354,  0.8179,  ...,  0.3452,  0.3452,  0.3803],
          [ 0.8179,  0.8179,  0.8354,  ...,  0.1702,  0.2752,  0.3627]],

         [[ 1.2282,  1.2282,  1.2282,  ...,  0.7402,  0.7228,  0.7402],
          [ 1.1934,  1.

Data:  tensor([[[[ 1.7352,  1.9578,  2.0605,  ...,  0.2282,  0.8961,  2.1119],
          [ 1.7865,  1.8037,  1.9064,  ...,  0.5536,  1.0673,  2.1290],
          [ 1.8037,  1.8550,  1.8550,  ...,  0.7591,  1.0673,  2.0434],
          ...,
          [ 0.1254, -0.0629, -0.5938,  ..., -1.4158, -1.4500, -0.9192],
          [ 0.4679, -0.2342, -0.9534,  ..., -1.2445, -1.4500, -0.7650],
          [-0.4568, -0.9192, -0.7137,  ..., -0.5424, -0.8164, -0.4397]],

         [[ 1.6933,  1.9734,  2.1660,  ..., -0.3550,  0.3978,  1.7283],
          [ 1.6933,  1.8158,  2.0434,  ..., -0.1800,  0.3803,  1.5182],
          [ 1.7983,  1.8683,  1.9559,  ...,  0.0301,  0.3102,  1.2906],
          ...,
          [ 0.1001, -0.1450, -0.7227,  ..., -1.4405, -1.4580, -0.9153],
          [ 0.3627, -0.3725, -1.1604,  ..., -1.2479, -1.4405, -0.7577],
          [-0.6176, -1.1078, -0.9328,  ..., -0.4776, -0.7577, -0.3901]],

         [[ 1.9428,  2.2391,  2.4134,  ..., -0.6367, -0.0964,  1.0017],
          [ 1.9254,  2.

Data:  tensor([[[[-0.2856, -0.6794, -0.7137,  ..., -1.3130, -0.9534, -0.8335],
          [-0.6109, -1.2959, -1.2103,  ..., -1.4329, -1.3130, -1.6042],
          [-1.2617, -1.2103, -1.1075,  ..., -1.2445, -1.0562, -1.1589],
          ...,
          [ 0.6563,  0.6906,  0.4679,  ...,  0.7591,  1.0844,  1.3413],
          [ 0.7077,  0.3823,  0.2282,  ...,  1.3413,  1.2728,  0.9646],
          [ 0.7762,  0.4679,  0.0056,  ...,  1.4954,  1.5297,  1.3584]],

         [[-0.1800, -0.6001, -0.6176,  ..., -1.2479, -0.8803, -0.7577],
          [-0.4251, -1.1429, -1.0553,  ..., -1.3880, -1.2654, -1.5805],
          [-1.0903, -1.0378, -0.9678,  ..., -1.2129, -1.0203, -1.1253],
          ...,
          [ 0.8704,  0.8704,  0.6078,  ...,  0.9755,  1.3081,  1.5532],
          [ 0.9230,  0.5728,  0.3452,  ...,  1.6057,  1.5357,  1.1856],
          [ 0.9755,  0.6254,  0.1176,  ...,  1.8158,  1.8333,  1.6583]],

         [[ 0.1128, -0.2881, -0.3230,  ..., -1.0898, -0.7064, -0.5495],
          [-0.1487, -0.

Data:  tensor([[[[-0.2856, -0.1999, -0.0972,  ..., -0.2171, -0.2171, -0.1999],
          [-0.4568, -0.3712, -0.2513,  ..., -0.1486, -0.1657, -0.1828],
          [-0.5253, -0.4911, -0.4226,  ..., -0.2171, -0.2684, -0.3369],
          ...,
          [-1.5185, -1.3473, -1.6727,  ...,  0.8789,  0.8789,  0.8618],
          [-1.5870, -1.4672, -1.6898,  ...,  0.9132,  0.9132,  0.8789],
          [-1.5357, -1.5699, -1.6555,  ...,  0.9132,  0.9132,  0.8789]],

         [[ 0.7654,  0.8529,  0.9580,  ...,  0.9405,  0.9405,  0.9580],
          [ 0.6429,  0.7304,  0.8354,  ...,  0.9580,  0.9580,  0.9230],
          [ 0.6078,  0.6429,  0.7129,  ...,  0.8354,  0.7829,  0.7129],
          ...,
          [-1.3529, -1.1779, -1.5105,  ...,  0.9405,  0.9405,  0.9230],
          [-1.4230, -1.3004, -1.5280,  ...,  0.9755,  0.9755,  0.9405],
          [-1.3704, -1.4055, -1.4930,  ...,  0.9755,  0.9755,  0.9405]],

         [[ 1.8034,  1.8905,  1.9951,  ...,  2.0300,  2.0300,  2.0474],
          [ 1.6814,  1.

       device='cuda:0')
Output:  tensor([[-4.1459e+00, -6.0718e+00, -7.3357e+00,  ..., -7.9541e+00,
         -4.6591e+00, -2.5109e+00],
        [ 2.0585e-01,  9.1996e-02, -1.4494e-02,  ..., -5.4787e-01,
         -6.6454e-01, -5.4926e-01],
        [ 1.9361e-01,  9.5719e-02,  5.1338e-03,  ..., -7.1135e-01,
         -5.9948e-01, -6.1854e-01],
        ...,
        [ 2.0585e-01,  9.1996e-02, -1.4494e-02,  ..., -5.4787e-01,
         -6.6454e-01, -5.4926e-01],
        [ 2.0585e-01,  9.1996e-02, -1.4494e-02,  ..., -5.4787e-01,
         -6.6454e-01, -5.4926e-01],
        [ 2.0585e-01,  9.1996e-02, -1.4494e-02,  ..., -5.4787e-01,
         -6.6454e-01, -5.4926e-01]], device='cuda:0')
target:  tensor([ 82,  57,  40,  74,   1, 109,  62, 121], device='cuda:0')
tensor([[5],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4],
        [4]], device='cuda:0')
Data:  tensor([[[[-1.9124, -1.9467, -2.0152,  ...,  1.5297,  1.2557,  0.8618],
          [-1.9124, -1.9467, -1.9980,  ...

Data:  tensor([[[[ 0.3481,  0.3138,  0.6049,  ...,  0.1083,  0.1768, -0.7137],
          [-0.3712, -0.0287,  0.0056,  ..., -0.8507, -0.1828, -0.5082],
          [-0.5253,  0.0056, -0.2856,  ..., -0.7822, -0.2342, -0.0629],
          ...,
          [ 2.1804,  2.1804,  2.1804,  ...,  1.8208,  1.7694,  1.7694],
          [ 2.1804,  2.1804,  2.1804,  ...,  1.8379,  1.8208,  1.7694],
          [ 2.1633,  2.1804,  2.1804,  ...,  1.6324,  1.7009,  1.6838]],

         [[ 0.4853,  0.4853,  0.8179,  ...,  0.5378,  0.6429, -0.2325],
          [-0.2325,  0.1352,  0.2052,  ..., -0.5476,  0.2577,  0.0301],
          [-0.3725,  0.1877, -0.0924,  ..., -0.5476,  0.1527,  0.4503],
          ...,
          [ 2.3060,  2.3060,  2.3060,  ...,  1.5357,  1.5007,  1.5532],
          [ 2.3060,  2.3060,  2.3060,  ...,  1.5182,  1.5182,  1.5357],
          [ 2.2885,  2.3060,  2.3060,  ...,  1.4832,  1.5182,  1.4657]],

         [[ 0.3393,  0.3568,  0.7228,  ..., -0.0615, -0.0267, -0.8981],
          [-0.4101, -0.

Data:  tensor([[[[ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489],
          [ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489],
          [ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489],
          ...,
          [ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489],
          [ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489],
          [ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489]],

         [[ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286],
          [ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286],
          [ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286],
          ...,
          [ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286],
          [ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286],
          [ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286]],

         [[ 2.6400,  2.6400,  2.6400,  ...,  2.6400,  2.6400,  2.6400],
          [ 2.6400,  2.

       device='cuda:0')
Output:  tensor([[ 2.0585e-01,  9.1996e-02, -1.4494e-02,  ..., -5.4787e-01,
         -6.6454e-01, -5.4926e-01],
        [ 2.0585e-01,  9.1996e-02, -1.4494e-02,  ..., -5.4787e-01,
         -6.6454e-01, -5.4926e-01],
        [ 1.9628e-01,  9.4907e-02,  8.5020e-04,  ..., -6.7567e-01,
         -6.1368e-01, -6.0342e-01],
        ...,
        [ 2.0585e-01,  9.1996e-02, -1.4494e-02,  ..., -5.4787e-01,
         -6.6454e-01, -5.4926e-01],
        [ 2.0585e-01,  9.1996e-02, -1.4494e-02,  ..., -5.4787e-01,
         -6.6454e-01, -5.4926e-01],
        [-9.7186e-01, -2.5791e+00, -1.4065e+00,  ..., -3.1536e+00,
         -2.7706e+00, -1.3946e+00]], device='cuda:0')
target:  tensor([ 11,  66, 100,  72,  40,  48,  41,  67], device='cuda:0')
tensor([[ 4],
        [ 4],
        [ 4],
        [ 4],
        [ 4],
        [ 4],
        [ 4],
        [48]], device='cuda:0')
Data:  tensor([[[[-0.2342, -0.2342, -0.2513,  ..., -0.3712, -0.3712, -0.3369],
          [-0.1999, -0.2171, -0.23

Data:  tensor([[[[ 1.3584,  1.4098,  1.4269,  ..., -0.0116, -0.0116, -0.0287],
          [ 1.3242,  1.3584,  1.3927,  ..., -0.6794, -0.6623, -0.6452],
          [ 1.3413,  1.3755,  1.3584,  ..., -0.7822, -0.7137, -0.6623],
          ...,
          [-1.3815, -1.3302, -1.2788,  ...,  0.4851,  0.5364,  0.5364],
          [-1.3644, -1.3130, -1.2959,  ...,  0.4851,  0.5364,  0.5536],
          [-1.3130, -1.2788, -1.2959,  ...,  0.5193,  0.5364,  0.5193]],

         [[ 1.3782,  1.4307,  1.4657,  ...,  0.1527,  0.1527,  0.1352],
          [ 1.3431,  1.3606,  1.3606,  ..., -0.4076, -0.3725, -0.3550],
          [ 1.3256,  1.3606,  1.3431,  ..., -0.5651, -0.5126, -0.4251],
          ...,
          [-1.4405, -1.4230, -1.3704,  ...,  0.4678,  0.5028,  0.5028],
          [-1.4230, -1.3880, -1.3880,  ...,  0.4678,  0.5028,  0.5203],
          [-1.3704, -1.3529, -1.3880,  ...,  0.4853,  0.5028,  0.4853]],

         [[ 1.5071,  1.5594,  1.5768,  ...,  0.1476,  0.1476,  0.1651],
          [ 1.4374,  1.

Data:  tensor([[[[-1.7069, -1.8268, -1.6727,  ..., -1.2788, -1.5699, -1.6727],
          [-1.8782, -1.9638, -1.7069,  ..., -0.6623, -0.2342, -1.1247],
          [-1.8439, -1.8439, -1.7240,  ..., -1.2959, -0.2513, -0.4226],
          ...,
          [-1.0562, -1.0390, -1.1418,  ..., -0.0458, -0.1828,  0.0398],
          [-0.7650, -1.0219, -0.9705,  ...,  0.4166, -0.1999, -0.1657],
          [-0.8849, -1.0048, -0.6794,  ...,  0.8447,  0.3652, -0.0972]],

         [[-1.5105, -1.6155, -1.4755,  ..., -0.7577, -1.0903, -1.1954],
          [-1.6681, -1.7556, -1.4930,  ..., -0.1275,  0.3102, -0.6001],
          [-1.6331, -1.6331, -1.5105,  ..., -0.7927,  0.3102,  0.1527],
          ...,
          [-0.4601, -0.3901, -0.4601,  ...,  0.7129,  0.5728,  0.8004],
          [-0.1625, -0.3725, -0.3025,  ...,  1.1856,  0.5553,  0.5903],
          [-0.2850, -0.3725, -0.0224,  ...,  1.6232,  1.1331,  0.6604]],

         [[-1.5430, -1.6302, -1.4907,  ..., -1.5081, -1.5256, -1.5081],
          [-1.6999, -1.

Data:  tensor([[[[ 0.5364,  0.6221,  0.5364,  ...,  0.5022,  0.5536,  0.5364],
          [ 0.5707,  0.4679,  0.6906,  ...,  0.5022,  0.5536,  0.5022],
          [ 0.5536,  0.5707,  0.7591,  ...,  0.4851,  0.5707,  0.6049],
          ...,
          [ 0.7077,  0.7419,  0.7591,  ...,  0.5536,  0.5022,  0.5707],
          [ 0.7591,  0.7077,  0.6906,  ...,  0.5364,  0.4679,  0.5364],
          [ 0.7762,  0.7077,  0.6734,  ...,  0.4679,  0.4337,  0.4679]],

         [[ 0.6779,  0.7654,  0.6779,  ...,  0.6254,  0.6779,  0.6604],
          [ 0.7129,  0.6078,  0.8354,  ...,  0.6254,  0.6779,  0.6254],
          [ 0.6954,  0.7129,  0.8880,  ...,  0.6078,  0.6954,  0.7304],
          ...,
          [ 0.8354,  0.8704,  0.8880,  ...,  0.6604,  0.6078,  0.6779],
          [ 0.8880,  0.8354,  0.8179,  ...,  0.6429,  0.5728,  0.6429],
          [ 0.9055,  0.8354,  0.8004,  ...,  0.5728,  0.5378,  0.5728]],

         [[ 0.8971,  0.9842,  0.8971,  ...,  0.8099,  0.8622,  0.8448],
          [ 0.9319,  0.
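
A minimal sketch of how the repeated predictions in the output above could be tallied instead of read by eye. It assumes the printed `[8, 1]` tensors are the top predicted class for each image. The helper name `tally_predictions` is hypothetical. The two prediction lists and the two target lists are copied by hand from the last two printed batches; in the notebook they would come from something like `output.argmax(dim=1).tolist()` and `target.tolist()`.

```python
from collections import Counter


def tally_predictions(batches):
    """Count how often each class index is predicted across batches."""
    counts = Counter()
    for preds in batches:
        counts.update(preds)
    return counts


# Predicted classes copied from the last two printed batches above.
# In the notebook these would come from e.g. output.argmax(dim=1).tolist().
printed_preds = [
    [5, 4, 4, 4, 4, 4, 4, 4],
    [4, 4, 4, 4, 4, 4, 4, 48],
]
# Targets copied from the matching "target:" lines.
printed_targets = [
    [82, 57, 40, 74, 1, 109, 62, 121],
    [11, 66, 100, 72, 40, 48, 41, 67],
]

counts = tally_predictions(printed_preds)
total = sum(counts.values())
top_class, top_count = counts.most_common(1)[0]

# Compare each prediction with the target at the same position in the batch.
correct = sum(
    p == t
    for preds, targets in zip(printed_preds, printed_targets)
    for p, t in zip(preds, targets)
)

print(f"class {top_class} predicted for {top_count}/{total} images")
print(f"distinct classes predicted: {len(counts)}")
print(f"correct predictions: {correct}/{total}")
```

Running the same tally over a whole epoch would show how long each run of identical predictions lasts. The identical rows in the printed `Output:` tensors suggest the logits themselves barely depend on the input image.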

### Ignore the code below this line
It's just a bunch of random pieces of code I've used at various times while trying to figure out what was going on.

In [None]:
# # Visualize the data

# import matplotlib.pyplot as plt
# %matplotlib inline

# dataiter = iter(train_loader)
# images, labels = next(dataiter)
# images = images.numpy()
# fig = plt.figure(figsize=(25,4))
# for idx in np.arange(6):
#     ax = fig.add_subplot(2, 6//2, idx+1, xticks=[], yticks=[])
#     images[idx] = images[idx]/2 +0.5
#     plt.imshow(np.transpose(images[idx], (1,2,0)))
#     ax.set_title(classes[labels[idx]])
    

In [8]:
# from PIL import ImageFile
# ImageFile.LOAD_TRUNCATED_IMAGES = True

# valid_loss_min = np.inf

# epochs = 5
# for e in range(epochs):
#     running_loss = 0.0
#     for images, labels in train_loader:
#         images, labels = images.cuda(), labels.cuda()
#         optimizer_scratch.zero_grad()
#         output = model_scratch(images)
#         loss = criterion_scratch(output, labels)
#         loss.backward()
#         optimizer_scratch.step()
#         running_loss += loss.item()
#     else:
#         print("Training loss: {:.6f}".format(running_loss/len(train_loader)))



Training loss: 9.782534
Training loss: 5.036802
Training loss: 4.944024
Training loss: 4.900999
Training loss: 5.088080


In [13]:

# import matplotlib.pyplot as plt
# %matplotlib inline
# images, labels = next(iter(train_loader))
# images, labels = images.cuda(), labels.cuda()
# with torch.no_grad():
#     logps = model_scratch(images)
# probs = torch.exp(logps)


In [15]:
# print(images[0])

tensor([[[ 0.6000,  0.6000,  0.6078,  ...,  0.5843,  0.5686,  0.5608],
         [ 0.6078,  0.6078,  0.6078,  ...,  0.5765,  0.5608,  0.5608],
         [ 0.6235,  0.6157,  0.6078,  ...,  0.5686,  0.5608,  0.5765],
         ...,
         [-0.0980, -0.2392, -0.4196,  ...,  0.0824, -0.3333, -0.4118],
         [-0.1216, -0.5294, -0.5059,  ...,  0.0667, -0.3020, -0.5608],
         [-0.0588,  0.0275, -0.0275,  ..., -0.0118, -0.2314,  0.2706]],

        [[ 0.6078,  0.6078,  0.6078,  ...,  0.5843,  0.5686,  0.5608],
         [ 0.6157,  0.6157,  0.6078,  ...,  0.5765,  0.5608,  0.5608],
         [ 0.6235,  0.6235,  0.6078,  ...,  0.5686,  0.5608,  0.5765],
         ...,
         [-0.1294, -0.2392, -0.3490,  ...,  0.1216, -0.3098, -0.3961],
         [-0.1451, -0.5294, -0.4588,  ...,  0.0980, -0.2941, -0.5608],
         [-0.0824,  0.0275,  0.0196,  ..., -0.0039, -0.2314,  0.2627]],

        [[ 0.6549,  0.6627,  0.6863,  ...,  0.6784,  0.6627,  0.6549],
         [ 0.6549,  0.6627,  0.6863,  ...,  0

In [19]:
# print(images[0][0])
# print(images[0][1])
# print(images[1][0])

tensor([[ 0.6000,  0.6000,  0.6078,  ...,  0.5843,  0.5686,  0.5608],
        [ 0.6078,  0.6078,  0.6078,  ...,  0.5765,  0.5608,  0.5608],
        [ 0.6235,  0.6157,  0.6078,  ...,  0.5686,  0.5608,  0.5765],
        ...,
        [-0.0980, -0.2392, -0.4196,  ...,  0.0824, -0.3333, -0.4118],
        [-0.1216, -0.5294, -0.5059,  ...,  0.0667, -0.3020, -0.5608],
        [-0.0588,  0.0275, -0.0275,  ..., -0.0118, -0.2314,  0.2706]],
       device='cuda:0')
tensor([[ 0.6078,  0.6078,  0.6078,  ...,  0.5843,  0.5686,  0.5608],
        [ 0.6157,  0.6157,  0.6078,  ...,  0.5765,  0.5608,  0.5608],
        [ 0.6235,  0.6235,  0.6078,  ...,  0.5686,  0.5608,  0.5765],
        ...,
        [-0.1294, -0.2392, -0.3490,  ...,  0.1216, -0.3098, -0.3961],
        [-0.1451, -0.5294, -0.4588,  ...,  0.0980, -0.2941, -0.5608],
        [-0.0824,  0.0275,  0.0196,  ..., -0.0039, -0.2314,  0.2627]],
       device='cuda:0')
tensor([[-0.0353, -0.0196,  0.0039,  ..., -0.7961, -0.7882, -0.7882],
        [-0.03

In [6]:
import numpy as np
import torch

# the following import is required for training to be robust to truncated images
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


def train(n_epochs, loaders, model, optimizer, criterion, use_cuda, save_path):
    """returns trained model"""
    # initialize tracker for minimum validation loss
    valid_loss_min = np.inf
    
    for epoch in range(1, n_epochs+1):

        # initialize variables to monitor training and validation loss
        train_loss = 0.0
        valid_loss = 0.0
        ###################
        # train the model #
        ###################
        model.train()
        for batch_idx, (data, target) in enumerate(loaders['train']):
            # move to GPU
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            ## find the loss and update the model parameters accordingly
            optimizer.zero_grad()
            output = model(data)  #output is tensor of shape([batch_size, num_classes]) where largest value is the prediction
            loss = criterion(output, target) # loss is the cross-entropy loss which measures how far the prediction is from the actual target
            loss.backward()  # calculating the gradients for all operations
            optimizer.step() #performing gradient descent step
            train_loss += loss.item()

    
        ######################    
        # validate the model #
        ######################
        model.eval()
        for batch_idx, (data, target) in enumerate(loaders['valid']):
            # move to GPU
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            with torch.no_grad():
                output = model(data)
                loss = criterion(output, target)
                valid_loss += loss.item()
            
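        # average the summed batch losses over the loaders passed in,
        # not over the global train_loader / valid_loader
        train_loss = train_loss / len(loaders['train'])
        valid_loss = valid_loss / len(loaders['valid'])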
        
            
        # print training/validation statistics 
        print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
            epoch, 
            train_loss,
            valid_loss
            ))
        
        
        ## save the model to save_path if validation loss has not increased
        if valid_loss <= valid_loss_min:
            print('Validation Loss Decreased. Saving model')
            torch.save(model.state_dict(), save_path)
            valid_loss_min = valid_loss
    # return trained model
    return model



In [7]:
# train the model
model_scratch = train(10, loaders_scratch, model_scratch, optimizer_scratch, 
                      criterion_scratch, use_cuda, 'model_scratch.pt')



Epoch: 1 	Training Loss: 5.363417 	Validation Loss: 4.876212
True
Validation Loss Decreased. Saving model
Epoch: 2 	Training Loss: 4.888124 	Validation Loss: 4.870528
True
Validation Loss Decreased. Saving model
Epoch: 3 	Training Loss: 4.909029 	Validation Loss: 4.868946
True
Validation Loss Decreased. Saving model
Epoch: 4 	Training Loss: 4.878525 	Validation Loss: 4.867552
True
Validation Loss Decreased. Saving model
Epoch: 5 	Training Loss: 4.878072 	Validation Loss: 4.867843
True
Epoch: 6 	Training Loss: 4.872267 	Validation Loss: 4.868251
True
Epoch: 7 	Training Loss: 4.866718 	Validation Loss: 4.868318
True
Epoch: 8 	Training Loss: 4.867238 	Validation Loss: 4.867780
True
Epoch: 9 	Training Loss: 4.866567 	Validation Loss: 4.868664
True
Epoch: 10 	Training Loss: 4.866593 	Validation Loss: 4.868259
True
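
The plateau around 4.867 can be compared with a chance-level baseline. A classifier that spreads its probability evenly over C classes has a cross-entropy of ln(C), so exp(loss) gives the number of equally likely classes a given loss corresponds to. The sketch below assumes the criterion is a mean cross-entropy in nats, which is what `nn.CrossEntropyLoss` computes by default. The helper name `equivalent_uniform_classes` is hypothetical, and the loss values are copied from the log above.

```python
import math


def equivalent_uniform_classes(cross_entropy_loss):
    """Number of equally likely classes that would give this cross-entropy (in nats)."""
    return math.exp(cross_entropy_loss)


# Validation losses copied from epochs 1, 4 and 10 of the log above.
logged_valid_losses = [4.876212, 4.867552, 4.868259]

for loss in logged_valid_losses:
    n = equivalent_uniform_classes(loss)
    print(f"loss {loss:.6f} -> ~{n:.1f} equally likely classes")
```

For these values the result is roughly 130. A loss at that level is about what a prediction that ignores the input image would give on a dataset with around that many classes, which fits the single-class predictions shown earlier.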


In [None]:
# load the model that got the lowest validation loss
model_scratch.load_state_dict(torch.load('model_scratch.pt'))

In [None]:
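def test(model, test_loader, use_cuda):
    """Run the model over the test set (unfinished: predictions are not scored yet)."""
    model.eval()
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)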

In [9]:


# load the model that got the lowest validation loss
model_scratch.load_state_dict(torch.load('model_scratch.pt'))

Epoch: 1 	Training Loss: 8.555840 	Validation Loss: 4.883201
Validation Loss Decreased. Saving model


KeyboardInterrupt: 