In [1]:
import os
import argparse

import easydict

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T

from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchmetrics.aggregation import MeanMetric
from torchmetrics.functional.classification import accuracy

from src.models import ConvNet
from src.engines import train, evaluate
from src.utils import load_checkpoint, save_checkpoint

# Experiment configuration for the Jupyter environment. There is no command line to
# parse here, so an EasyDict provides attribute-style access (args.lr, args.device, ...)
# in place of a CLI namespace.
args = easydict.EasyDict({
        "title" : "augmentation",     # run name: used in the TensorBoard log path and passed to checkpoint helpers
        # Fall back to CPU so Restart & Run All works on machines without a GPU.
        "device" : "cuda" if torch.cuda.is_available() else "cpu",
        "root" : "data",              # CIFAR-10 download/cache directory
        "batch_size" : 128,
        "num_workers" : 2,
        "epochs" : 100,
        "lr" : 0.001,
        "weight_decay": 0.0001,
        "label_smoothing": 0.05,
        "drop_rate": 0.1,
        "logs": "logs",               # TensorBoard root directory
        "checkpoints": "checkpoints", # checkpoint directory
        "resume": False,              # load the latest checkpoint before training
        "seed": 42,                   # set to None for an unseeded run
    })

def main(args):
    """Train ConvNet on CIFAR-10, logging per-epoch metrics and saving checkpoints.

    Args:
        args: Attribute-style config (e.g. an EasyDict). Required fields:
            root, batch_size, num_workers, device, lr, epochs, logs, title,
            checkpoints, resume. Optional fields: weight_decay (default 0.0),
            label_smoothing (default 0.0), seed (default None = unseeded).

    Side effects:
        Downloads CIFAR-10 into ``args.root`` if missing, writes TensorBoard
        events under ``{args.logs}/train/{args.title}`` and
        ``{args.logs}/val/{args.title}``, and calls ``save_checkpoint`` after
        every epoch.
    """
    seed = getattr(args, "seed", None)
    if seed is not None:
        # Makes weight initialisation and DataLoader shuffling repeatable.
        torch.manual_seed(seed)

    # Build dataset
    # CIFAR-10 per-channel mean/std; shared so both splits are normalised identically.
    normalize = T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    # NOTE(review): the run is titled "augmentation" but the training pipeline applies
    # no augmentation (e.g. T.RandomCrop(32, padding=4), T.RandomHorizontalFlip()).
    # Confirm this is intended before comparing against a baseline run.
    train_transform = T.Compose([T.ToTensor(), normalize])
    train_data = CIFAR10(args.root, train=True, download=True, transform=train_transform)
    # drop_last=True: every optimisation step sees a full batch.
    train_loader = DataLoader(
        train_data, args.batch_size, shuffle=True,
        num_workers=args.num_workers, drop_last=True,
    )

    val_transform = T.Compose([T.ToTensor(), normalize])
    val_data = CIFAR10(args.root, train=False, download=True, transform=val_transform)
    val_loader = DataLoader(val_data, batch_size=args.batch_size, num_workers=args.num_workers)

    # Build model
    # TODO(review): args.drop_rate is configured but not forwarded; ConvNet's constructor
    # is not visible here, so pass it once its keyword name is confirmed.
    model = ConvNet().to(args.device)

    # Build optimizer (weight_decay from the config was previously ignored)
    optimizer = optim.Adam(
        model.parameters(),
        lr=args.lr,
        weight_decay=getattr(args, "weight_decay", 0.0),
    )

    # Build scheduler
    # T_max is counted in batches, which assumes train() steps the scheduler once per
    # batch -- TODO confirm against src.engines.train.
    # NOTE(review): load_checkpoint receives only model and optimizer, so on resume the
    # scheduler's step counter is not restored and the cosine schedule does not pick up
    # where it left off.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs * len(train_loader))

    # Build loss function (label_smoothing from the config was previously ignored)
    loss_fn = nn.CrossEntropyLoss(label_smoothing=getattr(args, "label_smoothing", 0.0))

    # Build metric function
    metric_fn = accuracy

    # Build loggers: one writer per split, both using the same scalar tags.
    train_logger = SummaryWriter(f'{args.logs}/train/{args.title}')
    val_logger = SummaryWriter(f'{args.logs}/val/{args.title}')

    try:
        # Load model
        start_epoch = 0
        if args.resume:
            start_epoch = load_checkpoint(args.checkpoints, args.title, model, optimizer)

        # Main loop
        for epoch in range(start_epoch, args.epochs):
            # train one epoch
            train_summary = train(train_loader, model, optimizer, scheduler, loss_fn, metric_fn, args.device)

            # evaluate one epoch
            val_summary = evaluate(val_loader, model, loss_fn, metric_fn, args.device)

            # write log against the 1-based count of completed epochs
            step = epoch + 1
            for logger, epoch_summary in ((train_logger, train_summary), (val_logger, val_summary)):
                logger.add_scalar('Loss', epoch_summary['loss'], step)
                logger.add_scalar('Accuracy', epoch_summary['metric'], step)

            # save model
            save_checkpoint(args.checkpoints, args.title, model, optimizer, step)
    finally:
        # Flush pending events and release file handles even if training is interrupted.
        train_logger.close()
        val_logger.close()

In [2]:
!pip install torchsummary
from torchsummary import summary as summary

model = ConvNet()
model = model.to(args.device)
summary(model, (3,32,32)) # (model, input_size)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 96, 16, 16]           2,688
              ReLU-2           [-1, 96, 16, 16]               0
            Conv2d-3          [-1, 192, 16, 16]         166,080
              ReLU-4          [-1, 192, 16, 16]               0
            Conv2d-5            [-1, 384, 8, 8]         663,936
              ReLU-6            [-1, 384, 8, 8]               0
            Conv2d-7            [-1, 384, 8, 8]       1,327,488
              ReLU-8            [-1, 384, 8, 8]               0
           Dropout-9                  [-1, 384]               0
           Linear-10                   [-1, 10]           3,850
Total params: 2,164,042
Trainable params: 2,164,042
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 1.88
Params size (MB): 8.26
Estimat

In [3]:
if __name__ == "__main__":
    # Launch the full training run with the notebook's config.
    # In a Jupyter kernel __name__ is always "__main__", so this always executes.
    main(args)

Files already downloaded and verified
Files already downloaded and verified
tensor([[[[ 0.2848,  0.2073,  0.0328,  ..., -1.5761, -1.6537, -1.7118],
          [ 0.9439,  1.0021,  1.1571,  ..., -1.5180, -1.5567, -1.5761],
          [ 1.4673,  1.0602,  1.2541,  ..., -1.0527, -1.0527, -1.3047],
          ...,
          [-1.7118, -1.6343, -1.4404,  ..., -1.8475, -1.8669, -1.9057],
          [-1.8475, -1.7700, -1.6149,  ..., -1.8669, -1.8087, -1.8669],
          [-2.0608, -2.0026, -2.0414,  ..., -1.9832, -2.0220, -2.0608]],

         [[ 0.3941,  0.3351,  0.2761,  ..., -0.5499, -0.7662, -1.0809],
          [ 0.7874,  0.8071,  0.9841,  ..., -0.6286, -0.7662, -0.9432],
          [ 1.2398,  0.7284,  0.9054,  ..., -0.2942, -0.4712, -0.7466],
          ...,
          [-1.6316, -1.5136, -1.2776,  ..., -1.9856, -2.0053, -2.0643],
          [-1.9069, -1.7692, -1.5726,  ..., -1.9856, -1.9069, -1.9463],
          [-2.1429, -2.0446, -2.0053,  ..., -2.0446, -2.1036, -2.1429]],

         [[ 1.5831,  1.310

tensor([[[[-0.6069, -0.7813, -1.1690,  ..., -1.5180, -1.5180, -1.5180],
          [-0.9170, -0.8589, -1.0915,  ..., -1.5180, -1.5180, -1.5374],
          [-0.8977, -0.7813, -1.0334,  ..., -1.4986, -1.4986, -1.5374],
          ...,
          [-1.0334, -1.5180, -1.3629,  ..., -1.4404, -1.4598, -1.5374],
          [-0.8783, -1.0140, -1.2854,  ..., -1.4017, -1.4211, -1.4986],
          [-0.6844, -0.3161, -0.7232,  ..., -1.4017, -1.4211, -1.4792]],

         [[-0.7072, -0.6876, -0.7859,  ..., -1.1596, -1.1792, -1.1792],
          [-1.0219, -0.7466, -0.7072,  ..., -1.1792, -1.1792, -1.1989],
          [-1.0022, -0.6679, -0.6089,  ..., -1.1596, -1.1596, -1.1989],
          ...,
          [-1.1792, -1.4349, -1.3759,  ..., -1.1202, -1.1596, -1.2382],
          [-1.0809, -1.1006, -1.3956,  ..., -1.0809, -1.1202, -1.1989],
          [-0.8842, -0.5696, -0.9432,  ..., -1.0612, -1.1202, -1.1792]],

         [[-1.4800, -1.8897, -2.1824,  ..., -2.2019, -2.2214, -2.2214],
          [-1.7727, -2.0068, -

tensor([[[[ 6.3375e-01,  4.5929e-01,  5.7560e-01,  ...,  6.1437e-01,
            7.1129e-01,  7.3068e-01],
          [ 7.3068e-01,  6.9191e-01,  7.5006e-01,  ...,  6.7252e-01,
            8.8576e-01,  7.5006e-01],
          [ 7.6945e-01,  9.2452e-01,  6.1437e-01,  ...,  8.2760e-01,
            8.0822e-01,  7.1129e-01],
          ...,
          [ 5.3683e-01,  5.7560e-01,  7.5006e-01,  ...,  1.1184e+00,
            1.2347e+00,  1.1765e+00],
          [ 6.3375e-01,  5.3683e-01,  7.1129e-01,  ...,  1.1184e+00,
            1.4673e+00,  1.0408e+00],
          [ 6.7252e-01,  6.7252e-01,  7.6945e-01,  ...,  1.2153e+00,
            1.0990e+00,  8.4699e-01]],

         [[ 5.7111e-01,  3.9410e-01,  4.9244e-01,  ...,  5.5144e-01,
            6.4977e-01,  6.3011e-01],
          [ 6.3011e-01,  5.9077e-01,  6.6944e-01,  ...,  5.9077e-01,
            8.2677e-01,  6.4977e-01],
          [ 6.8911e-01,  8.4644e-01,  5.3177e-01,  ...,  7.0877e-01,
            6.8911e-01,  5.7111e-01],
          ...,
     

tensor([[[[ 0.7113,  0.5562,  0.6144,  ...,  0.6144,  0.6144,  0.6144],
          [ 0.8470,  0.6919,  0.7694,  ...,  0.6531,  0.6531,  0.6338],
          [ 0.7501,  0.6144,  0.6919,  ...,  0.6919,  0.6919,  0.6531],
          ...,
          [-2.1189, -2.1577, -2.0220,  ..., -1.1497, -1.1109, -1.1884],
          [-2.1189, -2.2158, -2.2352,  ..., -1.2272, -1.0915, -1.1690],
          [-2.1577, -2.3515, -2.3128,  ..., -1.4211, -1.2854, -1.3047]],

         [[ 1.2594,  1.2201,  1.2398,  ...,  1.2201,  1.2594,  1.1808],
          [ 1.3184,  1.2988,  1.3184,  ...,  1.2988,  1.3184,  1.2594],
          [ 1.3578,  1.3381,  1.3971,  ...,  1.3184,  1.3184,  1.2398],
          ...,
          [-1.5726, -1.4742, -1.3956,  ..., -1.0612, -1.0809, -1.1596],
          [-1.6119, -1.5529, -1.6512,  ..., -1.1399, -1.0809, -1.1399],
          [-1.7496, -1.7692, -1.7889,  ..., -1.3169, -1.2186, -1.2382]],

         [[ 2.1099,  2.0904,  1.9928,  ...,  2.0123,  1.9928,  1.9928],
          [ 2.0904,  2.0709,  

tensor([[[[-9.3643e-01, -9.5581e-01, -9.1704e-01,  ..., -1.1109e+00,
           -1.3047e+00, -1.4986e+00],
          [-1.0527e+00, -1.1303e+00, -1.0915e+00,  ..., -1.3241e+00,
           -1.5761e+00, -1.7312e+00],
          [-1.3629e+00, -1.4211e+00, -1.3047e+00,  ..., -1.2272e+00,
           -1.4017e+00, -1.4986e+00],
          ...,
          [-1.2078e+00, -1.2272e+00, -1.1303e+00,  ..., -1.2078e+00,
           -1.3047e+00, -1.3629e+00],
          [-7.0381e-01, -6.6504e-01, -7.0381e-01,  ..., -1.1690e+00,
           -1.1109e+00, -1.1497e+00],
          [-4.3242e-01, -3.5488e-01, -6.4565e-01,  ..., -1.0915e+00,
           -9.3643e-01, -9.3643e-01]],

         [[-1.9590e-01, -2.1557e-01, -1.7623e-01,  ..., -3.7290e-01,
           -5.6957e-01, -7.6624e-01],
          [-3.1390e-01, -4.1224e-01, -3.5324e-01,  ..., -5.8924e-01,
           -8.4491e-01, -1.0022e+00],
          [-6.2857e-01, -6.8757e-01, -5.6957e-01,  ..., -4.9090e-01,
           -6.6791e-01, -7.6624e-01],
          ...,
     

tensor([[[[ 0.9051,  1.0021,  1.1571,  ...,  0.9827,  0.9051,  0.6725],
          [ 0.2654,  0.2848,  0.4399,  ...,  0.2073,  0.2654,  0.6919],
          [-0.7620, -0.8201, -0.6650,  ..., -1.1884, -0.8783,  0.5562],
          ...,
          [ 0.8082,  1.0021,  1.1765,  ...,  0.5756,  0.5368,  0.7694],
          [ 0.7694,  0.9827,  1.0990,  ...,  0.7113,  0.7113,  0.8276],
          [ 0.7888,  1.0408,  1.3316,  ...,  0.9245,  0.8664,  1.0214]],

         [[ 0.9644,  1.0824,  1.2004,  ...,  1.0824,  0.9644,  0.7481],
          [-0.1566, -0.0779,  0.0204,  ...,  0.0204,  0.1778,  0.7284],
          [-1.6316, -1.6316, -1.5136,  ..., -1.7102, -1.1399,  0.5711],
          ...,
          [-0.4516, -0.2156, -0.1566,  ..., -0.7269, -0.7662, -0.5106],
          [-0.5302, -0.2352, -0.1566,  ..., -0.5499, -0.6482, -0.4909],
          [-0.8252, -0.4909, -0.2156,  ..., -0.5892, -0.7466, -0.7466]],

         [[ 1.1149,  1.2905,  1.3685,  ...,  1.2319,  1.2905,  1.1539],
          [-0.0948,  0.0028,  

tensor([[[[-5.0996e-01, -3.5488e-01, -5.4873e-01,  ..., -2.7734e-01,
           -1.8042e-01, -4.4721e-02],
          [-4.7119e-01, -5.4873e-01, -9.3643e-01,  ..., -2.1919e-01,
           -2.1919e-01, -1.4165e-01],
          [-6.0688e-01, -8.7827e-01, -1.1303e+00,  ..., -2.5796e-01,
           -1.0288e-01, -5.9512e-03],
          ...,
          [-9.7520e-01, -9.7520e-01, -9.7520e-01,  ..., -1.1303e+00,
           -1.1303e+00, -1.0527e+00],
          [-1.2078e+00, -1.1497e+00, -1.1690e+00,  ..., -1.0721e+00,
           -1.1690e+00, -1.1109e+00],
          [-1.2854e+00, -1.2272e+00, -1.2660e+00,  ..., -9.7520e-01,
           -1.1303e+00, -1.1497e+00]],

         [[-3.1390e-01, -1.5657e-01, -4.3190e-01,  ..., -7.7900e-02,
            7.6703e-04,  1.5810e-01],
          [-2.5490e-01, -3.7290e-01, -8.4491e-01,  ..., -5.8233e-02,
           -5.8233e-02,  2.0434e-02],
          [-4.7124e-01, -7.8591e-01, -1.1202e+00,  ..., -1.1723e-01,
            2.0434e-02,  1.3844e-01],
          ...,
     

tensor([[[[-1.2466, -0.9946, -1.1884,  ...,  0.4593,  0.3817,  0.2654],
          [-1.4211, -1.1690, -0.8977,  ...,  0.4399,  0.3817,  0.2654],
          [-1.4404, -1.4017, -1.0721,  ...,  0.3430,  0.2848,  0.2073],
          ...,
          [-1.7118, -1.7118, -1.6924,  ..., -0.8783, -0.8977, -0.8589],
          [-1.7312, -1.7118, -1.6924,  ..., -0.9558, -0.8977, -0.9558],
          [-1.7506, -1.7312, -1.6924,  ..., -1.0334, -0.9946, -1.0140]],

         [[-1.5136, -1.2579, -1.4546,  ..., -0.6286, -0.6679, -0.7662],
          [-1.6512, -1.3956, -1.1202,  ..., -0.6089, -0.6679, -0.7269],
          [-1.6316, -1.5922, -1.2776,  ..., -0.6679, -0.7072, -0.7466],
          ...,
          [-1.8086, -1.8086, -1.7692,  ..., -1.4152, -1.3562, -1.2972],
          [-1.8086, -1.7889, -1.7692,  ..., -1.4742, -1.4349, -1.4349],
          [-1.8282, -1.8086, -1.7889,  ..., -1.5332, -1.5136, -1.4939]],

         [[-1.2264, -0.9727, -1.1678,  ..., -0.7776, -0.8167, -0.9142],
          [-1.3434, -1.0898, -

tensor([[[[-1.4792e+00, -1.4211e+00, -1.2272e+00,  ..., -1.0140e+00,
           -1.1497e+00, -1.4211e+00],
          [-1.4404e+00, -1.3823e+00, -1.3241e+00,  ..., -9.1704e-01,
           -1.0140e+00, -1.3435e+00],
          [-1.3823e+00, -1.3241e+00, -1.3047e+00,  ..., -8.0073e-01,
           -8.7827e-01, -1.2272e+00],
          ...,
          [-4.1303e-01,  7.5006e-01, -1.2272e+00,  ..., -1.3629e+00,
           -1.0140e+00, -7.6196e-01],
          [-5.6811e-01,  1.2928e+00,  3.0421e-01,  ..., -1.2660e+00,
           -1.0721e+00, -9.5581e-01],
          [-7.2319e-01,  1.2347e+00,  1.6999e+00,  ..., -1.1690e+00,
           -1.1497e+00, -1.1303e+00]],

         [[-1.4349e+00, -1.3759e+00, -1.2776e+00,  ..., -1.0219e+00,
           -1.1202e+00, -1.3759e+00],
          [-1.3956e+00, -1.3169e+00, -1.2776e+00,  ..., -9.2357e-01,
           -9.8258e-01, -1.2972e+00],
          [-1.3366e+00, -1.2579e+00, -1.2382e+00,  ..., -8.2524e-01,
           -8.4491e-01, -1.1989e+00],
          ...,
     

tensor([[[[-1.8087e+00, -1.7894e+00, -1.7506e+00,  ..., -1.4598e+00,
           -1.3435e+00, -1.2660e+00],
          [-1.8087e+00, -1.7894e+00, -1.7506e+00,  ..., -1.4211e+00,
           -1.3241e+00, -1.2466e+00],
          [-1.8087e+00, -1.7894e+00, -1.7506e+00,  ..., -1.4017e+00,
           -1.3047e+00, -1.2078e+00],
          ...,
          [-1.3629e+00, -1.4211e+00, -1.4404e+00,  ..., -1.7894e+00,
           -1.8281e+00, -1.8087e+00],
          [-1.3435e+00, -1.4017e+00, -1.4404e+00,  ..., -1.8087e+00,
           -1.8281e+00, -1.8475e+00],
          [-1.3047e+00, -1.4017e+00, -1.4598e+00,  ..., -1.8281e+00,
           -1.8475e+00, -1.8863e+00]],

         [[-1.6906e+00, -1.6709e+00, -1.6316e+00,  ..., -1.1596e+00,
           -1.0809e+00, -1.0416e+00],
          [-1.6906e+00, -1.6709e+00, -1.6316e+00,  ..., -1.1202e+00,
           -1.0416e+00, -1.0219e+00],
          [-1.6906e+00, -1.6709e+00, -1.6316e+00,  ..., -1.1006e+00,
           -1.0416e+00, -9.8258e-01],
          ...,
     

tensor([[[[-1.9832e+00, -2.1964e+00, -2.2158e+00,  ..., -2.0801e+00,
           -2.0220e+00, -1.9251e+00],
          [-1.9832e+00, -2.2352e+00, -2.2546e+00,  ..., -2.1577e+00,
           -2.1189e+00, -1.9832e+00],
          [-2.1771e+00, -2.4097e+00, -2.3709e+00,  ..., -1.9057e+00,
           -1.8669e+00, -1.7312e+00],
          ...,
          [ 1.4673e+00,  1.5448e+00,  2.0295e+00,  ..., -1.9832e+00,
           -1.6537e+00, -1.5761e+00],
          [ 1.9519e+00,  1.7193e+00,  1.7387e+00,  ..., -8.5889e-01,
           -9.7520e-01, -1.8087e+00],
          [ 4.5929e-01,  4.2052e-01, -2.5336e-02,  ..., -3.7426e-01,
           -1.0915e+00, -1.4598e+00]],

         [[-1.9659e+00, -2.1823e+00, -2.2019e+00,  ..., -2.0446e+00,
           -2.0053e+00, -1.9069e+00],
          [-1.9659e+00, -2.2216e+00, -2.2413e+00,  ..., -2.1233e+00,
           -2.0839e+00, -1.9659e+00],
          [-2.1429e+00, -2.3986e+00, -2.3593e+00,  ..., -1.8086e+00,
           -1.6709e+00, -1.5332e+00],
          ...,
     

tensor([[[[ 4.2052e-01,  5.2204e-02,  1.1036e-01,  ...,  1.4913e-01,
            1.6851e-01,  2.4605e-01],
          [-3.3549e-01, -8.3950e-01, -7.8135e-01,  ..., -4.3242e-01,
           -4.1303e-01, -3.7426e-01],
          [-2.7734e-01, -6.6504e-01, -7.8135e-01,  ..., -6.8442e-01,
           -6.6504e-01, -6.0688e-01],
          ...,
          [-6.2627e-01, -1.1690e+00, -1.1109e+00,  ..., -1.1303e+00,
           -1.1884e+00, -1.1303e+00],
          [-4.1303e-01, -9.1704e-01, -8.7827e-01,  ..., -8.9766e-01,
           -9.5581e-01, -8.9766e-01],
          [ 3.8175e-01, -2.5336e-02,  1.3434e-02,  ...,  3.2819e-02,
           -2.5336e-02,  3.2819e-02]],

         [[ 5.9077e-01,  3.1544e-01,  2.9577e-01,  ...,  2.5644e-01,
            2.5644e-01,  2.7610e-01],
          [-1.7623e-01, -6.0891e-01, -6.0891e-01,  ..., -2.5490e-01,
           -2.7457e-01, -2.9424e-01],
          [-7.7900e-02, -3.9257e-01, -5.6957e-01,  ..., -4.1224e-01,
           -4.3190e-01, -4.5157e-01],
          ...,
     

tensor([[[[ 2.5141,  2.4753,  2.4365,  ...,  2.4172,  2.3978,  2.4947],
          [ 2.5141,  1.7581,  1.3704,  ...,  1.4867,  1.7968,  2.3590],
          [ 2.4172,  0.9827,  0.4981,  ...,  0.2654,  0.9439,  2.0101],
          ...,
          [ 2.4365,  1.1184, -0.1804,  ...,  0.4205,  1.0602,  2.0682],
          [ 2.4947,  1.5642,  0.8664,  ...,  0.6725,  1.2541,  2.1652],
          [ 2.4947,  2.1652,  2.0682,  ...,  1.7581,  2.0489,  2.3590]],

         [[ 2.5968,  2.4985,  2.5378,  ...,  2.4985,  2.4788,  2.5771],
          [ 2.5575,  1.8101,  1.5938,  ...,  1.7118,  1.9478,  2.4788],
          [ 2.4591,  1.1021,  0.8858,  ...,  0.6891,  1.2594,  2.1641],
          ...,
          [ 2.5378,  1.4168,  0.2761,  ...,  0.8661,  1.3971,  2.2231],
          [ 2.5968,  1.7708,  1.1808,  ...,  0.9448,  1.4758,  2.2821],
          [ 2.5968,  2.2821,  2.2428,  ...,  1.9281,  2.1838,  2.4591]],

         [[ 2.7537,  2.6172,  2.5781,  ...,  2.5976,  2.6172,  2.7537],
          [ 2.7537,  1.8563,  

tensor([[[[-1.9638e+00, -1.8863e+00, -1.8281e+00,  ..., -1.2466e+00,
           -1.2660e+00, -1.3047e+00],
          [-1.9444e+00, -1.8475e+00, -1.7894e+00,  ..., -1.1884e+00,
           -1.1884e+00, -1.2272e+00],
          [-1.9638e+00, -1.8475e+00, -1.7506e+00,  ..., -1.1109e+00,
           -1.1109e+00, -1.1497e+00],
          ...,
          [-1.0721e+00, -9.9458e-01, -1.0334e+00,  ..., -6.4565e-01,
           -7.0381e-01, -6.8442e-01],
          [-7.2319e-01, -4.1303e-01, -5.2934e-01,  ..., -6.8442e-01,
           -7.4258e-01, -7.4258e-01],
          [-9.9458e-01, -7.2319e-01, -7.4258e-01,  ..., -7.6196e-01,
           -8.2012e-01, -8.2012e-01]],

         [[-1.9856e+00, -1.9463e+00, -1.9266e+00,  ..., -1.3562e+00,
           -1.3759e+00, -1.4152e+00],
          [-1.9856e+00, -1.9266e+00, -1.8873e+00,  ..., -1.2972e+00,
           -1.2972e+00, -1.3366e+00],
          [-1.9856e+00, -1.9069e+00, -1.8479e+00,  ..., -1.2776e+00,
           -1.2776e+00, -1.2579e+00],
          ...,
     

tensor([[[[-6.0688e-01, -6.0688e-01, -5.8750e-01,  ..., -1.1497e+00,
           -1.1303e+00, -1.1690e+00],
          [-7.8135e-01, -7.4258e-01, -7.0381e-01,  ..., -1.0334e+00,
           -1.0334e+00, -1.0721e+00],
          [-1.0334e+00, -1.0140e+00, -9.7520e-01,  ..., -1.0140e+00,
           -9.9458e-01, -9.9458e-01],
          ...,
          [-8.7827e-01, -8.7827e-01, -8.7827e-01,  ..., -8.7827e-01,
           -8.9766e-01, -9.1704e-01],
          [-8.5889e-01, -8.0073e-01, -7.8135e-01,  ..., -1.0527e+00,
           -1.0334e+00, -1.0140e+00],
          [-8.7827e-01, -8.2012e-01, -7.8135e-01,  ..., -1.1884e+00,
           -1.2078e+00, -1.1690e+00]],

         [[-4.3190e-01, -4.3190e-01, -4.1224e-01,  ..., -1.0219e+00,
           -1.0022e+00, -1.0022e+00],
          [-6.2857e-01, -5.8924e-01, -5.4990e-01,  ..., -9.0391e-01,
           -9.0391e-01, -9.0391e-01],
          [-8.8424e-01, -8.6457e-01, -8.2524e-01,  ..., -8.8424e-01,
           -8.6457e-01, -8.6457e-01],
          ...,
     

tensor([[[[ 0.1685,  1.4867,  0.9051,  ..., -2.4291, -0.9170,  0.5368],
          [ 0.6338,  1.5836,  1.0408,  ..., -2.4097, -1.2078, -0.0253],
          [ 0.8082,  0.9633,  0.7501,  ..., -2.3515, -1.3047, -0.0447],
          ...,
          [ 2.4753,  2.4172,  2.4172,  ..., -1.2854, -1.2660, -1.2854],
          [ 2.3978,  2.4172,  2.4559,  ..., -1.3629, -1.3047, -1.3629],
          [ 1.7581,  1.7968,  1.8356,  ..., -1.4211, -1.3823, -1.3241]],

         [[ 0.0991,  1.4561,  0.8661,  ..., -2.3593, -0.9236,  0.6104],
          [ 0.6301,  1.6134,  1.0824,  ..., -2.2806, -1.2579, -0.1762],
          [ 0.8268,  0.9251,  0.7284,  ..., -2.2806, -1.2776, -0.0976],
          ...,
          [ 2.4985,  2.4395,  2.4591,  ..., -1.1202, -1.1792, -1.3366],
          [ 2.4001,  2.4395,  2.4985,  ..., -1.1006, -1.0809, -1.1399],
          [ 1.7904,  1.8691,  1.9675,  ..., -1.2186, -1.1792, -1.1202]],

         [[ 0.3149,  1.6221,  1.0563,  ..., -2.2214, -0.9922, -0.0167],
          [ 0.8807,  1.9538,  

tensor([[[[-0.3743, -0.3355, -0.4906,  ..., -1.3435, -1.8087, -2.0026],
          [-0.6844, -0.4906, -0.1029,  ..., -1.5567, -1.9638, -1.8475],
          [-0.6457, -0.3161, -0.1029,  ..., -1.6731, -1.9638, -1.8475],
          ...,
          [ 0.0134,  0.4399, -0.2386,  ..., -0.5487, -1.3047, -1.3047],
          [-0.6069, -0.2386,  0.5950,  ..., -0.4324, -0.8589, -0.5487],
          [-0.8395, -0.9946, -0.3161,  ..., -0.2773, -0.6650, -0.3743]],

         [[-0.3336, -0.3926, -0.4909,  ..., -1.4152, -1.9069, -2.1233],
          [-0.5696, -0.4712, -0.1762,  ..., -1.5332, -1.9856, -1.8282],
          [-0.4516, -0.2549, -0.2156,  ..., -1.6709, -2.0053, -1.7692],
          ...,
          [-0.0779,  0.1581, -0.5499,  ..., -0.9826, -1.5922, -1.5726],
          [-0.6286, -0.3926,  0.3154,  ..., -0.7269, -1.0612, -0.7072],
          [-0.9236, -1.1596, -0.6482,  ..., -0.5302, -0.9039, -0.5302]],

         [[-1.0118, -1.0313, -1.2654,  ..., -1.6556, -1.9873, -2.1238],
          [-1.3825, -1.2459, -

tensor([[[[-1.2078, -1.2078, -1.2272,  ..., -1.0140, -1.0334, -1.0334],
          [-1.2466, -1.2660, -1.2660,  ..., -0.9364, -0.9752, -0.9946],
          [-1.2660, -1.3241, -1.3435,  ..., -0.8201, -0.8395, -0.8589],
          ...,
          [ 0.6919,  0.5756,  0.5562,  ..., -1.0140, -0.4712,  0.0522],
          [ 0.5950,  0.5562,  0.5562,  ..., -0.5487, -0.3161, -0.0253],
          [ 0.6144,  0.5368,  0.5368,  ..., -0.1998, -0.2386, -0.0835]],

         [[-0.9432, -0.9236, -0.8842,  ..., -0.6286, -0.6482, -0.6286],
          [-0.9039, -0.8842, -0.8646,  ..., -0.4909, -0.4909, -0.5106],
          [-0.8646, -0.8646, -0.8646,  ..., -0.2746, -0.2942, -0.3139],
          ...,
          [ 0.5711,  0.4531,  0.3941,  ..., -1.1399, -0.5499, -0.0189],
          [ 0.5514,  0.5121,  0.5318,  ..., -0.7466, -0.5106, -0.1959],
          [ 0.6104,  0.5121,  0.5318,  ..., -0.4122, -0.4516, -0.2746]],

         [[-0.8167, -0.7971, -0.7776,  ..., -0.5435, -0.5630, -0.5435],
          [-0.6996, -0.6996, -

tensor([[[[ 1.2153,  0.7113,  0.7888,  ...,  0.9827,  0.9245,  1.4479],
          [-0.6650, -1.8669, -1.6731,  ..., -1.1884, -1.3047, -0.1610],
          [-0.8589, -2.1383, -1.9638,  ..., -1.4792, -1.5567, -0.3936],
          ...,
          [ 0.1104, -0.8589, -0.6844,  ..., -0.0835, -0.6844,  0.2267],
          [ 2.2039,  2.0489,  2.0876,  ...,  2.1458,  2.0489,  2.2233],
          [ 2.5141,  2.4559,  2.4753,  ...,  2.4559,  2.4559,  2.5141]],

         [[ 1.2988,  0.7874,  0.8661,  ...,  1.0234,  0.9644,  1.5151],
          [-0.6089, -1.8282, -1.6316,  ..., -1.1792, -1.2972, -0.1369],
          [-0.8056, -2.1036, -1.9266,  ..., -1.4546, -1.5332, -0.3729],
          ...,
          [ 0.0794, -0.9236, -0.7466,  ..., -0.1566, -0.7662,  0.1974],
          [ 2.2231,  2.0461,  2.0855,  ...,  2.1248,  2.0265,  2.2231],
          [ 2.5771,  2.4985,  2.5181,  ...,  2.4788,  2.4788,  2.5771]],

         [[ 1.3685,  0.8612,  0.9393,  ...,  1.2905,  1.2124,  1.7392],
          [-0.5240, -1.7141, -

tensor([[[[-1.0140, -1.1690, -1.3629,  ..., -1.6924, -1.5567, -1.5180],
          [ 0.4787, -0.2192, -0.3355,  ..., -1.9832, -1.7894, -1.5567],
          [ 1.4285,  0.7307,  0.1491,  ..., -2.0801, -2.0995, -1.8475],
          ...,
          [-0.1416, -0.8201, -0.6069,  ...,  1.4285, -0.0447, -1.2078],
          [-0.1416, -0.1029, -0.1610,  ...,  1.8744,  0.6338, -0.9752],
          [ 1.3510,  1.4091,  0.4399,  ...,  1.3704,  0.5368, -0.3743]],

         [[-0.9629, -1.1006, -1.2776,  ..., -1.1792, -1.0219, -0.9432],
          [ 0.3744, -0.2549, -0.2746,  ..., -1.4349, -1.1989, -1.0022],
          [ 1.2004,  0.6104,  0.1384,  ..., -1.7299, -1.7102, -1.4546],
          ...,
          [-0.4516, -1.0416, -0.7662,  ...,  1.3381, -0.0976, -1.0416],
          [-0.1959, -0.0976, -0.0779,  ...,  1.8101,  0.6498, -0.8056],
          [ 1.3578,  1.4561,  0.5121,  ...,  1.2988,  0.5711, -0.2352]],

         [[-1.0118, -1.1288, -1.2654,  ..., -1.9092, -1.7727, -1.6751],
          [ 0.3149, -0.3094, -

tensor([[[[ 2.4172,  2.3396,  2.3396,  ...,  2.3590,  2.3590,  2.3590],
          [ 2.4947,  2.3978,  2.3978,  ...,  2.4172,  2.4172,  2.4172],
          [ 2.4947,  2.4172,  2.3978,  ...,  2.4172,  2.4172,  2.3978],
          ...,
          [ 2.0489,  1.9907,  1.9713,  ...,  1.8356,  1.8356,  1.8356],
          [ 2.0489,  1.9907,  1.9907,  ...,  1.8550,  1.8550,  1.8550],
          [ 2.0489,  1.9907,  1.9907,  ...,  1.8744,  1.8744,  1.8744]],

         [[ 2.5968,  2.4985,  2.4985,  ...,  2.4001,  2.4001,  2.4001],
          [ 2.5968,  2.5378,  2.5378,  ...,  2.4591,  2.4591,  2.4591],
          [ 2.5968,  2.5181,  2.4985,  ...,  2.4591,  2.4591,  2.4395],
          ...,
          [ 2.0855,  2.0265,  2.0068,  ...,  1.8691,  1.8691,  1.8691],
          [ 2.0855,  2.0265,  2.0265,  ...,  1.8888,  1.8888,  1.8888],
          [ 2.0855,  2.0265,  2.0265,  ...,  1.9085,  1.9085,  1.9085]],

         [[ 2.7537,  2.6757,  2.6757,  ...,  2.6172,  2.6172,  2.5976],
          [ 2.7537,  2.7342,  

tensor([[[[-0.6844, -0.7232, -0.7232,  ..., -1.7506, -1.8863, -2.1189],
          [-0.5487, -0.6069, -0.5681,  ..., -1.7894, -1.8475, -2.0801],
          [-0.5487, -0.5681, -0.5487,  ..., -1.8087, -1.8475, -2.0414],
          ...,
          [-0.4324,  0.4205,  1.5642,  ..., -2.3515, -2.3515, -2.3709],
          [ 1.4285,  1.9907,  2.2621,  ..., -2.3321, -2.3515, -2.3515],
          [ 2.2427,  2.2427,  2.2039,  ..., -2.3321, -2.3515, -2.3515]],

         [[-1.2579, -1.3366, -1.2776,  ..., -1.6512, -1.8282, -2.0446],
          [-1.2972, -1.3169, -1.1792,  ..., -1.7102, -1.7889, -2.0446],
          [-1.3169, -1.2382, -1.1399,  ..., -1.7102, -1.7889, -2.0053],
          ...,
          [-0.7662,  0.1581,  1.4168,  ..., -2.3396, -2.3396, -2.3593],
          [ 1.3184,  1.9675,  2.2035,  ..., -2.3199, -2.3396, -2.3396],
          [ 2.2428,  2.2625,  2.1641,  ..., -2.3199, -2.3396, -2.3396]],

         [[-1.3044, -1.3629, -1.3629,  ..., -1.4410, -1.6166, -1.8702],
          [-1.3044, -1.3629, -

tensor([[[[ 7.6945e-01,  7.6945e-01,  8.0822e-01,  ...,  7.1129e-01,
            5.9498e-01,  5.5621e-01],
          [ 6.7252e-01,  7.5006e-01,  8.4699e-01,  ...,  6.5314e-01,
            6.9191e-01,  6.5314e-01],
          [ 7.6945e-01,  7.8883e-01,  8.0822e-01,  ...,  7.5006e-01,
            6.3375e-01,  6.9191e-01],
          ...,
          [ 3.0421e-01,  2.6544e-01,  9.8268e-01,  ...,  5.9498e-01,
            5.7560e-01,  6.5314e-01],
          [ 9.2452e-01,  4.7867e-01,  8.6637e-01,  ...,  6.7252e-01,
            4.9806e-01,  5.1744e-01],
          [ 1.0214e+00,  8.4699e-01,  8.6637e-01,  ...,  7.3068e-01,
            6.1437e-01,  4.7867e-01]],

         [[ 5.9768e-02,  7.6703e-04,  7.6703e-04,  ..., -9.7567e-02,
           -1.9590e-01, -2.3524e-01],
          [-3.8567e-02, -1.8900e-02,  5.9768e-02,  ..., -1.3690e-01,
           -9.7567e-02, -1.3690e-01],
          [ 5.9768e-02,  2.0434e-02,  7.6703e-04,  ..., -3.8567e-02,
           -1.5657e-01, -9.7567e-02],
          ...,
     

tensor([[[[-2.4097e+00, -2.4097e+00, -2.3903e+00,  ..., -2.3321e+00,
           -2.3515e+00, -2.3321e+00],
          [-2.4097e+00, -2.3903e+00, -2.4097e+00,  ..., -2.3321e+00,
           -2.3515e+00, -2.3515e+00],
          [-2.4097e+00, -2.4097e+00, -2.4097e+00,  ..., -2.3515e+00,
           -2.3321e+00, -2.3515e+00],
          ...,
          [-2.3709e+00, -2.3709e+00, -2.3903e+00,  ..., -2.3321e+00,
           -2.3321e+00, -2.3709e+00],
          [-2.3903e+00, -2.3903e+00, -2.3709e+00,  ..., -2.3709e+00,
           -2.3515e+00, -2.3321e+00],
          [-2.3903e+00, -2.3903e+00, -2.3709e+00,  ..., -2.3903e+00,
           -2.3709e+00, -2.3321e+00]],

         [[-1.1006e+00, -1.0809e+00, -1.0219e+00,  ..., -1.6316e+00,
           -1.6709e+00, -1.7299e+00],
          [-1.0809e+00, -1.0219e+00, -1.0022e+00,  ..., -1.5922e+00,
           -1.6512e+00, -1.7102e+00],
          [-1.0416e+00, -1.0022e+00, -9.6291e-01,  ..., -1.5922e+00,
           -1.6119e+00, -1.6709e+00],
          ...,
     

tensor([[[[ 1.3316,  0.8858,  0.8664,  ...,  2.4365,  2.4365,  2.5141],
          [-0.6069, -0.2967,  0.4399,  ...,  2.4947,  2.4947,  2.5141],
          [ 0.0522, -0.3743,  0.1297,  ...,  2.4559,  2.4559,  2.5141],
          ...,
          [ 2.5141,  2.4753,  2.4947,  ...,  2.1845,  2.4172,  2.5141],
          [ 2.5141,  2.5141,  2.5141,  ...,  2.4753,  2.4559,  2.5141],
          [ 2.5141,  2.4559,  2.4753,  ...,  2.4365,  2.4172,  2.4947]],

         [[ 1.4168,  0.9644,  0.9841,  ...,  2.4985,  2.5378,  2.5968],
          [-0.5499, -0.1566,  0.6694,  ...,  2.5575,  2.5968,  2.5968],
          [ 0.1188, -0.1959,  0.4334,  ...,  2.5181,  2.5575,  2.5968],
          ...,
          [ 2.5968,  2.5575,  2.5771,  ...,  2.2625,  2.4985,  2.5968],
          [ 2.5968,  2.5968,  2.5968,  ...,  2.5575,  2.5575,  2.5968],
          [ 2.5968,  2.5378,  2.5575,  ...,  2.5181,  2.4985,  2.5771]],

         [[ 1.5246,  1.0954,  1.0758,  ...,  2.6367,  2.6952,  2.7537],
          [-0.4264,  0.0418,  

tensor([[[[-4.5180e-01, -4.7119e-01, -4.9057e-01,  ..., -3.3549e-01,
           -3.7426e-01, -4.1303e-01],
          [-4.1303e-01, -4.3242e-01, -4.3242e-01,  ..., -3.3549e-01,
           -3.5488e-01, -3.9365e-01],
          [-3.5488e-01, -3.5488e-01, -3.3549e-01,  ..., -2.9672e-01,
           -3.1611e-01, -3.3549e-01],
          ...,
          [-1.4792e+00, -1.4598e+00, -1.1497e+00,  ..., -1.1303e+00,
           -1.3047e+00, -1.9832e+00],
          [-1.5955e+00, -1.4986e+00, -1.1884e+00,  ..., -1.3241e+00,
           -1.5955e+00, -2.2740e+00],
          [-1.6924e+00, -1.5567e+00, -1.2272e+00,  ..., -1.5374e+00,
           -1.7118e+00, -2.0801e+00]],

         [[ 2.5644e-01,  2.9577e-01,  3.7444e-01,  ...,  5.7111e-01,
            5.5144e-01,  5.1211e-01],
          [ 2.9577e-01,  3.1544e-01,  4.1377e-01,  ...,  5.9077e-01,
            5.7111e-01,  5.5144e-01],
          [ 3.7444e-01,  3.7444e-01,  3.9410e-01,  ...,  6.6944e-01,
            6.4977e-01,  6.1044e-01],
          ...,
     

tensor([[[[-0.1029, -0.1998, -0.2773,  ..., -0.1804, -0.1610, -0.0447],
          [-0.4130, -0.8783, -1.0721,  ..., -0.4518, -0.4324, -0.3549],
          [-1.0915, -1.2854, -1.3435,  ..., -1.1884, -1.2078, -1.1884],
          ...,
          [-2.3128, -2.2934, -2.2740,  ..., -2.2352, -2.2546, -2.2740],
          [-2.3128, -2.2934, -2.3128,  ..., -2.2158, -2.2158, -2.2546],
          [-2.2934, -2.2934, -2.2740,  ..., -2.2546, -2.2546, -2.2546]],

         [[ 0.4924,  0.3744,  0.2761,  ...,  0.3744,  0.3941,  0.5121],
          [ 0.2958, -0.0976, -0.2549,  ...,  0.2171,  0.2171,  0.3154],
          [-0.2156, -0.3336, -0.2942,  ..., -0.4516, -0.4712, -0.4319],
          ...,
          [-2.2019, -2.1823, -2.1626,  ..., -2.1429, -2.1626, -2.2019],
          [-2.2216, -2.2019, -2.2216,  ..., -2.1626, -2.1626, -2.2019],
          [-2.2413, -2.2413, -2.2413,  ..., -2.2019, -2.2216, -2.2216]],

         [[ 0.9783,  0.8612,  0.7637,  ...,  0.8807,  0.9198,  1.0173],
          [ 0.8807,  0.4515,  

tensor([[[[ 0.7501,  1.6612,  2.2621,  ...,  2.4365,  2.4365,  2.4365],
          [ 0.7501,  1.2153,  1.1765,  ...,  2.5141,  2.5141,  2.5141],
          [ 0.2073,  0.2654, -0.1804,  ...,  2.4753,  2.4753,  2.4753],
          ...,
          [-0.8007, -0.9752, -0.9364,  ..., -0.4324, -0.8201, -1.0334],
          [-0.7232, -0.9558, -0.7813,  ..., -0.7038, -0.8589, -0.9170],
          [-0.8201, -0.8783, -1.0334,  ..., -1.1109, -1.1690, -1.0334]],

         [[ 0.8071,  1.6921,  2.2428,  ...,  2.5181,  2.5181,  2.5181],
          [ 0.7481,  1.0431,  0.7481,  ...,  2.5968,  2.5968,  2.5968],
          [-0.1762, -0.2549, -0.8449,  ...,  2.5575,  2.5575,  2.5575],
          ...,
          [-0.7662, -0.9432, -0.9039,  ..., -0.3926, -0.7859, -1.0022],
          [-0.6876, -0.9236, -0.7466,  ..., -0.6679, -0.8252, -0.8842],
          [-0.7859, -0.8449, -1.0022,  ..., -1.0809, -1.1399, -1.0022]],

         [[ 0.9783,  1.8758,  2.4611,  ...,  2.6757,  2.6757,  2.6757],
          [ 0.9393,  1.2905,  

tensor([[[[-2.2546e+00, -2.1964e+00, -2.1771e+00,  ..., -2.0995e+00,
           -2.0995e+00, -2.0995e+00],
          [-2.1771e+00, -2.1771e+00, -2.1771e+00,  ..., -2.0801e+00,
           -2.0801e+00, -2.0801e+00],
          [-2.1189e+00, -2.1771e+00, -2.1771e+00,  ..., -2.0801e+00,
           -2.0801e+00, -2.0801e+00],
          ...,
          [-1.0334e+00, -8.9766e-01, -7.6196e-01,  ..., -9.5581e-01,
           -7.2319e-01, -6.0688e-01],
          [-1.3823e+00, -1.2272e+00, -1.0334e+00,  ..., -5.6811e-01,
           -5.2934e-01, -5.4873e-01],
          [-1.6537e+00, -1.4986e+00, -1.3241e+00,  ..., -4.5180e-01,
           -4.9057e-01, -6.0688e-01]],

         [[-1.2382e+00, -1.2186e+00, -1.1989e+00,  ..., -9.8258e-01,
           -9.8258e-01, -1.0022e+00],
          [-1.1989e+00, -1.1596e+00, -1.1399e+00,  ..., -9.4324e-01,
           -9.4324e-01, -9.4324e-01],
          [-1.1989e+00, -1.1399e+00, -1.1006e+00,  ..., -9.2357e-01,
           -9.0391e-01, -9.2357e-01],
          ...,
     

tensor([[[[-0.8395, -0.5681, -1.0334,  ..., -0.3743, -0.4324, -0.0060],
          [-1.0334, -0.9558, -0.9170,  ..., -0.4712, -0.4324, -0.6457],
          [-0.9946, -0.8395, -1.1497,  ..., -0.5875, -1.0915, -1.0721],
          ...,
          [-1.9251, -1.9251, -1.8863,  ..., -1.2466, -1.1884, -1.1884],
          [-1.8669, -1.8087, -1.8087,  ..., -1.1497, -1.2660, -1.2466],
          [-1.9057, -1.9251, -1.8863,  ..., -1.2854, -1.0915, -0.9558]],

         [[-0.5106, -0.1959, -0.6679,  ...,  0.0204, -0.1369,  0.2368],
          [-0.7072, -0.6089, -0.5499,  ..., -0.1172, -0.2156, -0.4712],
          [-0.6679, -0.4712, -0.7859,  ..., -0.2549, -0.8842, -0.8842],
          ...,
          [-1.9069, -1.9069, -1.8676,  ..., -1.2579, -1.2776, -1.2972],
          [-1.9069, -1.8479, -1.8479,  ..., -1.2382, -1.3366, -1.3169],
          [-1.9659, -1.9856, -1.9463,  ..., -1.3956, -1.1399, -1.0022]],

         [[-1.4215, -1.3629, -1.7922,  ..., -1.4215, -1.1288, -0.4069],
          [-1.5971, -1.7336, -

tensor([[[[ 2.0876e+00,  2.1264e+00,  2.1845e+00,  ...,  3.2819e-02,
           -1.6924e+00, -2.0608e+00],
          [ 2.0876e+00,  2.1458e+00,  2.2039e+00,  ...,  1.2974e-01,
           -1.7118e+00, -2.0220e+00],
          [ 2.1070e+00,  2.1458e+00,  2.1652e+00,  ...,  5.2204e-02,
           -1.7700e+00, -2.0026e+00],
          ...,
          [-8.7827e-01, -1.3047e+00, -1.5761e+00,  ..., -4.5180e-01,
           -4.1303e-01, -2.9672e-01],
          [-8.0073e-01, -1.2272e+00, -1.4986e+00,  ..., -3.5488e-01,
           -3.5488e-01, -2.3857e-01],
          [-7.6196e-01, -1.1884e+00, -1.4017e+00,  ..., -3.3549e-01,
           -2.7734e-01, -1.8042e-01]],

         [[ 2.1641e+00,  2.2035e+00,  2.2625e+00,  ...,  2.0434e-02,
           -1.6316e+00, -2.0249e+00],
          [ 2.1641e+00,  2.2231e+00,  2.2821e+00,  ...,  1.3844e-01,
           -1.6316e+00, -1.9856e+00],
          [ 2.1838e+00,  2.2231e+00,  2.2428e+00,  ...,  4.0101e-02,
           -1.7102e+00, -1.9659e+00],
          ...,
     

tensor([[[[-1.9251e+00, -1.9251e+00, -1.8863e+00,  ..., -1.5567e+00,
           -1.8281e+00, -1.9638e+00],
          [-1.8669e+00, -1.8087e+00, -1.6924e+00,  ..., -1.6149e+00,
           -1.8475e+00, -1.8475e+00],
          [-1.7700e+00, -1.4986e+00, -1.1690e+00,  ..., -1.5955e+00,
           -1.8863e+00, -1.7118e+00],
          ...,
          [-1.4404e+00, -1.3435e+00, -1.3047e+00,  ..., -1.5955e+00,
           -1.1884e+00, -6.4565e-01],
          [-1.6149e+00, -1.4404e+00, -1.3047e+00,  ..., -1.6537e+00,
           -9.7520e-01, -3.9365e-01],
          [-1.7700e+00, -1.5374e+00, -1.3241e+00,  ..., -1.3241e+00,
           -7.0381e-01, -5.4873e-01]],

         [[-2.0053e+00, -2.0053e+00, -1.9659e+00,  ..., -1.6512e+00,
           -1.9069e+00, -1.9659e+00],
          [-1.9659e+00, -1.8873e+00, -1.7496e+00,  ..., -1.6709e+00,
           -1.8873e+00, -1.8282e+00],
          [-1.8479e+00, -1.5726e+00, -1.2382e+00,  ..., -1.5922e+00,
           -1.8873e+00, -1.7496e+00],
          ...,
     

tensor([[[[ 1.2347,  1.1765,  1.2153,  ...,  1.1571,  1.1184,  1.0990],
          [ 1.2735,  1.2347,  1.2541,  ...,  1.2153,  1.1765,  1.1571],
          [ 1.2541,  1.1959,  1.2347,  ...,  1.1571,  1.1184,  1.0990],
          ...,
          [-1.3047, -1.2854, -1.1497,  ...,  1.8744,  1.6418,  1.4673],
          [-1.2466, -1.4211, -1.5180,  ...,  0.2267, -0.1610, -0.2967],
          [-1.3241, -1.3629, -1.2854,  ..., -1.3435, -1.2272, -1.1497]],

         [[ 2.3608,  2.3018,  2.3215,  ...,  2.2625,  2.2428,  2.2231],
          [ 2.4198,  2.3608,  2.3805,  ...,  2.3411,  2.3018,  2.2821],
          [ 2.3805,  2.3018,  2.3215,  ...,  2.2821,  2.2428,  2.2231],
          ...,
          [-0.2549, -0.1959, -0.0779,  ...,  2.3411,  2.2231,  2.1248],
          [-0.1369, -0.2942, -0.4516,  ...,  0.9644,  0.6498,  0.4531],
          [-0.1959, -0.2352, -0.2352,  ..., -0.4319, -0.2352, -0.2156]],

         [[ 2.5586,  2.4611,  2.4806,  ...,  2.4221,  2.4025,  2.3830],
          [ 2.6172,  2.5196,  

tensor([[[[-8.3491e-02, -4.4721e-02, -5.9512e-03,  ...,  2.2667e-01,
            5.2204e-02, -2.3857e-01],
          [ 1.3434e-02,  1.8790e-01, -1.0288e-01,  ...,  3.4298e-01,
            3.6236e-01,  3.2819e-02],
          [ 2.8482e-01,  3.2359e-01,  1.3434e-02,  ..., -1.9980e-01,
           -1.4165e-01,  2.0728e-01],
          ...,
          [ 1.7581e+00,  1.8938e+00,  1.9132e+00,  ...,  1.9132e+00,
            1.9325e+00,  1.8744e+00],
          [ 1.8938e+00,  1.7775e+00,  1.8550e+00,  ...,  1.8938e+00,
            1.8356e+00,  1.6999e+00],
          [ 1.9519e+00,  2.0876e+00,  2.0101e+00,  ...,  1.9519e+00,
            1.9519e+00,  1.7387e+00]],

         [[-5.8233e-02,  2.0434e-02,  9.9101e-02,  ...,  2.7610e-01,
            9.9101e-02, -1.3690e-01],
          [ 5.9768e-02,  2.5644e-01,  7.6703e-04,  ...,  4.1377e-01,
            4.9244e-01,  1.3844e-01],
          [ 3.5477e-01,  3.9410e-01,  1.1877e-01,  ..., -9.7567e-02,
            4.0101e-02,  3.3510e-01],
          ...,
     

tensor([[[[-2.2158e+00, -2.2352e+00, -2.2546e+00,  ..., -2.2546e+00,
           -2.2352e+00, -2.2158e+00],
          [-2.2158e+00, -2.2352e+00, -2.2546e+00,  ..., -2.2546e+00,
           -2.2546e+00, -2.2158e+00],
          [-2.2158e+00, -2.2352e+00, -2.2546e+00,  ..., -2.2546e+00,
           -2.2546e+00, -2.2158e+00],
          ...,
          [-2.2546e+00, -2.1383e+00, -1.8087e+00,  ..., -2.2158e+00,
           -2.2158e+00, -2.2158e+00],
          [-1.4211e+00, -1.0527e+00, -7.8135e-01,  ..., -2.2158e+00,
           -2.2158e+00, -2.2158e+00],
          [-7.2319e-01, -6.6504e-01, -5.4873e-01,  ..., -2.2158e+00,
           -2.2158e+00, -2.2158e+00]],

         [[-2.2019e+00, -2.2216e+00, -2.2413e+00,  ..., -2.2413e+00,
           -2.2413e+00, -2.2019e+00],
          [-2.2019e+00, -2.2216e+00, -2.2413e+00,  ..., -2.2413e+00,
           -2.2413e+00, -2.2019e+00],
          [-2.2019e+00, -2.2216e+00, -2.2413e+00,  ..., -2.2413e+00,
           -2.2413e+00, -2.2019e+00],
          ...,
     

tensor([[[[ 3.4298e-01,  3.4298e-01,  3.6236e-01,  ...,  1.9132e+00,
            1.9132e+00,  1.9132e+00],
          [ 3.0421e-01,  3.2359e-01,  3.6236e-01,  ...,  1.8356e+00,
            1.8744e+00,  1.9325e+00],
          [ 3.0421e-01,  4.2052e-01,  4.3990e-01,  ...,  1.6999e+00,
            1.7387e+00,  1.8356e+00],
          ...,
          [ 2.2233e+00,  2.2815e+00,  2.2039e+00,  ...,  2.2039e+00,
            2.2039e+00,  2.2427e+00],
          [ 2.1264e+00,  2.1070e+00,  2.1070e+00,  ...,  2.2233e+00,
            2.2233e+00,  2.2427e+00],
          [ 2.0682e+00,  2.0295e+00,  2.0876e+00,  ...,  2.1845e+00,
            2.1845e+00,  2.2039e+00]],

         [[ 1.5810e-01,  1.5810e-01,  1.7777e-01,  ...,  1.9478e+00,
            1.9675e+00,  1.9675e+00],
          [ 1.1877e-01,  1.3844e-01,  1.7777e-01,  ...,  1.9085e+00,
            1.9478e+00,  2.0068e+00],
          [ 1.3844e-01,  2.3677e-01,  2.5644e-01,  ...,  1.8298e+00,
            1.8691e+00,  1.9478e+00],
          ...,
     

tensor([[[[ 2.5141e+00,  2.4947e+00,  2.4947e+00,  ...,  2.4947e+00,
            2.4947e+00,  2.4947e+00],
          [ 2.5141e+00,  2.5141e+00,  2.5141e+00,  ...,  2.5141e+00,
            2.5141e+00,  2.5141e+00],
          [ 2.5141e+00,  2.4947e+00,  2.5141e+00,  ...,  2.5141e+00,
            2.5141e+00,  2.5141e+00],
          ...,
          [ 2.5141e+00,  2.4947e+00,  2.5141e+00,  ...,  2.4365e+00,
            2.4947e+00,  2.5141e+00],
          [ 2.5141e+00,  2.4947e+00,  2.5141e+00,  ...,  2.4947e+00,
            2.5141e+00,  2.5141e+00],
          [ 2.5141e+00,  2.4947e+00,  2.5141e+00,  ...,  2.5141e+00,
            2.5141e+00,  2.5141e+00]],

         [[ 2.5968e+00,  2.5771e+00,  2.5771e+00,  ...,  2.5771e+00,
            2.5771e+00,  2.5771e+00],
          [ 2.5968e+00,  2.5968e+00,  2.5968e+00,  ...,  2.5968e+00,
            2.5968e+00,  2.5968e+00],
          [ 2.5968e+00,  2.5771e+00,  2.5968e+00,  ...,  2.5968e+00,
            2.5968e+00,  2.5968e+00],
          ...,
     

tensor([[[[-2.1383, -1.9638, -1.4792,  ..., -2.2158, -2.3515, -2.3128],
          [-1.7118, -1.5955, -1.3629,  ..., -2.2158, -2.2934, -2.1771],
          [-1.6924, -1.5761, -1.7894,  ..., -1.5955, -1.9444, -1.6924],
          ...,
          [-0.0253,  0.0328,  0.0328,  ...,  0.3430,  0.2073,  0.0328],
          [-0.1029, -0.1223, -0.1610,  ..., -0.0253, -0.1610, -0.1998],
          [-0.3549, -0.1416, -0.3936,  ..., -0.0835, -0.1610, -0.4712]],

         [[-2.0053, -1.9463, -1.3956,  ..., -2.1036, -2.3199, -2.2806],
          [-1.4742, -1.3562, -1.0416,  ..., -2.0643, -2.2019, -2.1036],
          [-1.3759, -1.2186, -1.4349,  ..., -1.4939, -1.7889, -1.5332],
          ...,
          [ 0.5318,  0.5711,  0.5908,  ...,  0.9054,  0.7678,  0.6498],
          [ 0.4138,  0.3941,  0.3548,  ...,  0.5318,  0.3941,  0.3941],
          [ 0.1581,  0.3744,  0.1384,  ...,  0.4728,  0.3941,  0.0598]],

         [[-2.1238, -2.0263, -1.5971,  ..., -2.1238, -2.2214, -2.1824],
          [-1.8117, -1.7141, -

tensor([[[[-0.6263, -0.6263, -0.5100,  ...,  2.4559,  2.4559,  2.4559],
          [-0.9558, -0.9558, -0.7038,  ...,  2.4753,  2.4753,  2.4753],
          [-0.8395, -0.9558, -0.5875,  ...,  2.4753,  2.4559,  2.4559],
          ...,
          [ 1.8550,  1.0796, -1.4017,  ..., -1.5374, -1.5761, -1.5761],
          [ 1.6612,  1.5836,  0.3817,  ..., -1.0140, -1.0915, -1.3047],
          [ 2.0295,  2.0295,  1.9132,  ..., -1.2272, -1.0334, -0.9558]],

         [[-0.3532, -0.3532, -0.2352,  ...,  2.5378,  2.5378,  2.5378],
          [-0.7662, -0.7269, -0.4516,  ...,  2.5575,  2.5378,  2.5575],
          [-0.7072, -0.7859, -0.3729,  ...,  2.5575,  2.5378,  2.5378],
          ...,
          [ 1.9478,  1.0824, -1.4546,  ..., -1.3562, -1.3169, -1.2186],
          [ 1.6921,  1.4758,  0.2564,  ..., -0.7859, -0.7466, -0.8056],
          [ 2.0068,  1.9085,  1.8298,  ..., -0.9629, -0.7072, -0.4319]],

         [[-0.5630, -0.5825, -0.4655,  ...,  2.6952,  2.6952,  2.6952],
          [-0.7971, -0.9142, -

tensor([[[[ 3.6236e-01,  3.2359e-01,  3.2359e-01,  ..., -5.2934e-01,
           -5.6811e-01, -6.0688e-01],
          [ 3.4298e-01,  2.8482e-01,  2.6544e-01,  ..., -4.9057e-01,
           -5.2934e-01, -5.8750e-01],
          [ 3.0421e-01,  2.2667e-01,  2.2667e-01,  ..., -4.9057e-01,
           -5.0996e-01, -5.6811e-01],
          ...,
          [-1.0288e-01, -2.1919e-01, -1.4165e-01,  ..., -5.8750e-01,
           -6.2627e-01, -6.4565e-01],
          [-2.5796e-01, -3.3549e-01, -1.8042e-01,  ..., -7.2319e-01,
           -7.4258e-01, -7.4258e-01],
          [-4.4721e-02, -4.4721e-02,  3.2819e-02,  ..., -7.0381e-01,
           -7.2319e-01, -7.2319e-01]],

         [[ 4.7277e-01,  4.3344e-01,  4.3344e-01,  ..., -3.5324e-01,
           -3.9257e-01, -4.3190e-01],
          [ 4.5310e-01,  3.9410e-01,  3.7444e-01,  ..., -2.9424e-01,
           -3.3357e-01, -3.7290e-01],
          [ 4.1377e-01,  3.3510e-01,  3.3510e-01,  ..., -2.5490e-01,
           -2.7457e-01, -3.3357e-01],
          ...,
     

tensor([[[[-1.9444e+00, -1.3629e+00,  1.6851e-01,  ..., -2.1771e+00,
           -2.0608e+00, -2.0414e+00],
          [-2.1189e+00, -1.4211e+00,  9.2452e-01,  ..., -1.8863e+00,
           -1.9444e+00, -1.9444e+00],
          [-2.1189e+00, -1.6343e+00,  7.5006e-01,  ..., -1.9444e+00,
           -2.1383e+00, -2.1189e+00],
          ...,
          [-2.3515e+00, -2.2546e+00, -2.0801e+00,  ..., -1.6149e+00,
           -1.5955e+00, -1.4986e+00],
          [-2.3709e+00, -2.2934e+00, -2.1771e+00,  ..., -1.6731e+00,
           -1.6924e+00, -1.5180e+00],
          [-2.3709e+00, -2.3321e+00, -2.2934e+00,  ..., -1.7118e+00,
           -1.7312e+00, -1.6537e+00]],

         [[-1.9659e+00, -1.3759e+00, -2.5490e-01,  ..., -2.1036e+00,
           -1.9463e+00, -1.9069e+00],
          [-2.1036e+00, -1.4546e+00,  3.3510e-01,  ..., -1.9463e+00,
           -1.9266e+00, -1.8873e+00],
          [-2.0643e+00, -1.6709e+00, -1.8900e-02,  ..., -1.9659e+00,
           -2.1036e+00, -2.0839e+00],
          ...,
     

KeyboardInterrupt: 