## Model Resnet: modules & parameters

In [1]:
import torch

torch.cuda.is_available()

True

In [2]:
import torch.nn as nn
import numpy

class channel_selection(nn.Module):
    """
    Select channels from the output of BatchNorm2d layer. It should be put directly after BatchNorm2d layer.
    The output shape of this layer is determined by the number of 1 in `self.indexes`.
    """
    def __init__(self, num_channels):
        """
        Initialize the `indexes` with all one vector with the length same as the number of channels.
        During pruning, the places in `indexes` which correpond to the channels to be pruned will be set to 0.
        """
        super(channel_selection, self).__init__()
        self.indexes = nn.Parameter(torch.ones(num_channels))

    def forward(self, input_tensor):
        """
        Parameter
        ---------
        input_tensor: (N,C,H,W). It should be the output of BatchNorm2d layer.
        """
        selected_index = np.squeeze(np.argwhere(self.indexes.data.cpu().numpy()))
        if selected_index.size == 1:
            selected_index = np.resize(selected_index, (1,)) 
        output = input_tensor[:, selected_index, :, :]
        return output

In [3]:
import math
import torch.nn as nn
import numpy as np

"""
preactivation resnet with bottleneck design.
"""


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, cfg, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.select = channel_selection(inplanes)
        self.conv1 = nn.Conv2d(cfg[0], cfg[1], kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(cfg[1])
        self.conv2 = nn.Conv2d(cfg[1], cfg[2], kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(cfg[2])
        self.conv3 = nn.Conv2d(cfg[2], planes * self.expansion, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        # group1
        out = self.bn1(x)
        out = self.select(out)
        out = self.relu(out)
        out = self.conv1(out)

        # group2
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)

        # group3
        out = self.bn3(out)
        out = self.relu(out)
        out = self.conv3(out)

        # down sample
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        return out

    def forward_bn(self, x):
        bn_value = []
        residual = x

        out = self.bn1(x)
        bn_value.append(out.clone())
        out = self.select(out)
        out = self.relu(out)
        out = self.conv1(out)

        out = self.bn2(out)
        bn_value.append(out.clone())
        out = self.relu(out)
        out = self.conv2(out)

        out = self.bn3(out)
        bn_value.append(out.clone())
        out = self.relu(out)
        out = self.conv3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        return out, bn_value

    def mask_bn(self, index, cfg_mask):
        if index == 0:
            self.bn1.weight.data.mul_(cfg_mask)
            self.bn1.bias.data.mul_(cfg_mask)
        elif index == 1:
            self.bn2.weight.data.mul_(cfg_mask)
            self.bn2.bias.data.mul_(cfg_mask)
        elif index == 2:
            self.bn3.weight.data.mul_(cfg_mask)
            self.bn3.bias.data.mul_(cfg_mask)
        else:
            raise ValueError("Index is not including.")


class resnet(nn.Module):
    def __init__(self, depth=164, dataset='cifar10', cfg=None, conv_cfg=None):
        """
        :param depth:
            164 layers => 1 conv2d + 3 layers × 18 blocks (every layer)  × 3 conv2ds (every block)  + 1 avgPool2d
            param n = (depth - 2) // 9:
                n means how many blocks in every layer
                9 = 3 layers × 3 conv2d (every block)
        :param cfg:
            if depth = 164, then len(cfg) = 164
        :param conv_cfg:
            layer block index examples (index starts at 1 & ≤ 18):
                3 indexes / layer: [4, 9, 14] or [6, 12, 18]
                2 indexes / layer: [9, 18]
                1 index / layer: [18]

        number of BatchNorm2d:
            163 = 162 (3 layers × 18 Bottlenecks × 3 BatchNorm2ds) + 1 BatchNorm2d
        """
        super(resnet, self).__init__()
        assert (depth - 2) % 9 == 0, 'depth should be 9n+2'

        n = (depth - 2) // 9  # depth = 164, n = 18
        block = Bottleneck
        self.block_cfg = conv_cfg
        self.inplanes = 16

        # model config
        if cfg is None:
            cfg = [[16, 16, 16], [64, 16, 16] * (n - 1),
                   [64, 32, 32], [128, 32, 32] * (n - 1),
                   [128, 64, 64], [256, 64, 64] * (n - 1), [256]]
            cfg = [item for sub_list in cfg for item in sub_list]

        # model feature
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)
        self.layer1 = self._make_layer(block, 16, n, cfg=cfg[0: 3 * n])
        self.layer2 = self._make_layer(block, 32, n, cfg=cfg[3 * n: 6 * n], stride=2)
        self.layer3 = self._make_layer(block, 64, n, cfg=cfg[6 * n:9 * n], stride=2)
        self.bn = nn.BatchNorm2d(64 * block.expansion)
        self.select = channel_selection(64 * block.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(8)

        # model classifier
        if dataset == 'cifar10':
            num_classes = 10
        elif dataset == 'cifar100':
            num_classes = 100
        else:
            raise ValueError('Model `dataset` parameter is Error!')
        self.fc = nn.Linear(cfg[-1], num_classes)

        # model initialize weight
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(0.5)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, cfg, stride=1):
        """
        :param block: Bottleneck item
        :param planes: record the layer's output channel size
                        planes * expansion = output_planes_size,
                        like 16 * 4 = 128 (the first block output)
        :param blocks: number of Bottleneck in layer
        :param cfg:  channel config of all blocks
                        3 cfg items / Bottleneck
        :param stride: default = 1
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False))
        layers = [block(self.inplanes, planes, cfg[0:3], stride, downsample)]
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, cfg[3 * i: 3 * (i + 1)]))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)

        block_value = []
        if self.block_cfg:
            for idx, item in enumerate(self.layer1):
                x = item(x)
                if idx + 1 in self.block_cfg:
                    block_value.append(x)
            for idx, item in enumerate(self.layer2):
                x = item(x)
                if idx + 1 in self.block_cfg:
                    block_value.append(x)
            for idx, item in enumerate(self.layer3):
                x = item(x)
                if idx + 1 in self.block_cfg:
                    block_value.append(x)
        else:
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)

        x = self.bn(x)
        x = self.select(x)
        x = self.relu(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        y = self.fc(x)

        if len(block_value):
            return y, block_value
        return y


In [4]:
import os

def resume_model(resume_file):
    if not os.path.isfile(resume_file):
        raise ValueError("Resume model file is not found at '{}'".format(resume_file))
    print("=> loading checkpoint '{}'".format(resume_file))
    checkpoint = torch.load(resume_file)
    if 'epoch' in checkpoint:
        start_epoch = checkpoint['epoch']
    else:
        start_epoch = None
        
    if 'best_prec1' in checkpoint:
        best_prec1 = checkpoint['best_prec1']
    else:
        best_prec1 = None
        
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = None
    
    if 'optimizer' in checkpoint:
        opti_dict = checkpoint['optimizer']
    else:
        opti_dict = None
        
    if 'cfg' in checkpoint:
        cfg = checkpoint['cfg']
        print("-> model cfg is loading...\n cfg: {}".format(list(cfg)))
    else:
        cfg = None
        print("-> not found model cfg...")
    print("=>  epoch {} Prec1: {}".format(start_epoch, best_prec1))
    return state_dict, opti_dict, start_epoch, best_prec1, cfg


data_path = {
    'root': r'D:\Project\Pycharm\network-slimming\logs',
    'root2': r'D:\Project\Gitee\network-slimming\logs',
    'bn': [
        'bn_prune_resnet164_cifar10_percent_0.4',
        'bn_prune_resnet164_cifar10_percent_0.6',
        'bn_prune_resnet164_cifar100_percent_0.4',
        'bn_prune_resnet164_cifar100_percent_0.6',
    ],
    'at': [
         'at_prune_resnet164_cifar10_percent_0.4',
         'at_prune_resnet164_cifar10_percent_0.6',
         'at_prune_resnet164_cifar100_percent_0.4',
         'at_prune_resnet164_cifar100_percent_0.6',
    ],
    'flie': 'model_best.pt',
    'file2': 'model_best.pth.tar',
    'file3': 'pruned.pt',
    'file4': 'pruned.pth.tar'
}

file_path = os.path.join(data_path['root2'], data_path['at'][2], data_path['file3'])
state_dict, opti_dict, start_epoch, best_prec1, cfg = resume_model(file_path)

=> loading checkpoint 'D:\Project\Gitee\network-slimming\logs\at_prune_resnet164_cifar100_percent_0.4\pruned.pt'
-> model cfg is loading...
 cfg: [5, 13, 15, 13, 14, 15, 20, 15, 15, 13, 12, 13, 16, 13, 14, 16, 15, 14, 18, 13, 13, 21, 14, 16, 14, 12, 15, 23, 15, 14, 14, 13, 13, 16, 14, 16, 2, 8, 13, 7, 14, 14, 19, 13, 14, 15, 15, 16, 3, 8, 10, 10, 14, 16, 17, 26, 32, 43, 32, 32, 39, 31, 31, 46, 30, 32, 46, 31, 32, 42, 32, 32, 49, 28, 32, 45, 32, 32, 50, 30, 30, 55, 32, 32, 60, 30, 32, 44, 29, 31, 52, 29, 32, 46, 32, 31, 57, 32, 32, 57, 32, 32, 50, 31, 32, 52, 31, 31, 78, 60, 64, 107, 63, 64, 117, 64, 63, 122, 63, 64, 118, 63, 64, 110, 63, 64, 113, 63, 64, 121, 64, 64, 121, 63, 64, 127, 64, 64, 119, 64, 64, 131, 64, 64, 130, 64, 64, 130, 64, 64, 120, 64, 64, 119, 64, 64, 119, 64, 64, 109, 64, 64, 157]
=>  epoch None Prec1: None


In [5]:
model = resnet(depth=164, dataset='cifar100', cfg=cfg, conv_cfg=None)
model.load_state_dict(state_dict)
model.eval()

resnet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (select): channel_selection()
      (conv1): Conv2d(5, 13, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(13, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(13, 15, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn3): BatchNorm2d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(15, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (1): Bottleneck(
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (select): channel_select

In [6]:
from torchstat import stat

stat(model, (3, 32, 32))

[MAdd]: channel_selection is not supported!
[Flops]: channel_selection is not supported!
[Memory]: channel_selection is not supported!
[MAdd]: channel_selection is not supported!
[Flops]: channel_selection is not supported!
[Memory]: channel_selection is not supported!
[MAdd]: channel_selection is not supported!
[Flops]: channel_selection is not supported!
[Memory]: channel_selection is not supported!
[MAdd]: channel_selection is not supported!
[Flops]: channel_selection is not supported!
[Memory]: channel_selection is not supported!
[MAdd]: channel_selection is not supported!
[Flops]: channel_selection is not supported!
[Memory]: channel_selection is not supported!
[MAdd]: channel_selection is not supported!
[Flops]: channel_selection is not supported!
[Memory]: channel_selection is not supported!
[MAdd]: channel_selection is not supported!
[Flops]: channel_selection is not supported!
[Memory]: channel_selection is not supported!
[MAdd]: channel_selection is not supported!
[Flops]: ch

In [11]:
import torch
from torchvision import datasets, transforms

kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}
test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100('./data.cifar100', train=True, download=True, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
            batch_size=32, shuffle=True, **kwargs)

Files already downloaded and verified


In [12]:
for k, (data, target) in enumerate(test_loader):
    print(k, data.shape, target)

0 torch.Size([32, 3, 32, 32]) tensor([43, 58,  8, 44, 15, 96, 89, 93, 12, 48, 16, 35, 49,  3, 31, 72, 39, 80,
        20, 70, 22, 27, 91, 10, 88, 12, 82, 29, 54, 13, 89, 96])
1 torch.Size([32, 3, 32, 32]) tensor([79, 88, 72, 66, 91, 64,  3, 30, 47, 13,  7, 17,  5, 13, 66, 69, 37,  7,
        25,  6, 52, 28, 17, 85, 82, 58, 78,  5, 34, 60,  8, 40])
2 torch.Size([32, 3, 32, 32]) tensor([20, 92, 92, 41, 82, 90, 93,  1,  2, 71,  4, 68, 95, 27, 81, 63, 45,  1,
        59, 88, 12, 62, 93, 64,  6, 59, 13,  5, 38, 26, 30, 12])
3 torch.Size([32, 3, 32, 32]) tensor([28, 70, 46, 54, 28, 90, 72, 29, 76, 13,  7, 89, 75, 32, 28, 21, 84, 23,
        49, 27, 69, 18, 29, 60, 94, 40, 51, 49, 80, 17, 27, 76])
4 torch.Size([32, 3, 32, 32]) tensor([14, 57, 28, 69,  0, 24, 36, 36, 90, 78, 54, 21, 74, 46, 67, 91, 82, 54,
        75, 38, 20, 35,  8, 75, 91, 33, 66, 54, 50, 26,  2, 33])
5 torch.Size([32, 3, 32, 32]) tensor([64, 10, 91, 96, 87, 55, 68, 73, 19, 36, 16, 42, 28, 97, 17, 35, 31, 10,
        17, 78,

51 torch.Size([32, 3, 32, 32]) tensor([62, 57, 93, 76, 34, 76,  4, 23, 62, 88, 94, 13, 75, 17, 32, 30, 49, 62,
         1, 27, 48, 87, 55, 64, 42, 10, 17, 62, 79,  4, 87, 63])
52 torch.Size([32, 3, 32, 32]) tensor([49, 46, 61, 62, 73, 74,  2, 53, 58, 56, 40, 95,  2, 25, 84, 11, 52, 66,
        89, 64, 82, 21, 17, 15, 89, 43, 12, 50, 64, 88, 94, 59])
53 torch.Size([32, 3, 32, 32]) tensor([79, 48,  2, 13,  7, 35, 39, 82, 19, 24, 67,  9, 64, 31, 50, 47, 58,  6,
         9, 75,  2, 99, 48, 59, 16, 71, 51, 68, 39, 64,  2, 75])
54 torch.Size([32, 3, 32, 32]) tensor([54, 80, 29, 37,  4, 57, 95, 75, 71,  4,  2,  7, 49, 65, 64, 28, 98, 11,
        74, 89, 37,  6, 57, 67, 18, 29, 38, 27, 92, 71,  8, 85])
55 torch.Size([32, 3, 32, 32]) tensor([44, 12, 65,  3, 60, 44, 48, 51, 74,  8, 85, 68, 74, 79, 24, 68,  5, 55,
        43, 27, 20, 38, 48, 28, 59, 14, 73, 31, 27, 16, 10, 54])
56 torch.Size([32, 3, 32, 32]) tensor([15, 47, 95, 12, 29, 32,  7, 20, 79, 63, 38, 90, 34, 60, 51,  4, 78,  1,
        3

105 torch.Size([32, 3, 32, 32]) tensor([62, 10, 40, 92, 87, 28, 90, 28, 64, 46, 31, 87, 45, 43, 35, 93, 33, 97,
        40, 64, 35, 31, 37, 84, 25, 46,  4, 52, 36, 40, 67, 42])
106 torch.Size([32, 3, 32, 32]) tensor([69, 33, 31,  2, 92, 19, 31, 55, 70, 10, 62, 50, 95, 15, 26, 38, 74,  5,
        58, 17, 91, 81, 89, 12,  6, 19,  5, 88, 60, 46, 61, 88])
107 torch.Size([32, 3, 32, 32]) tensor([22, 12, 33, 89, 92, 24, 80, 30, 71, 43, 84, 62, 76, 83,  0, 21, 37, 30,
        39, 53, 24, 53, 34,  7, 30, 83, 62, 11,  9,  2, 26, 22])
108 torch.Size([32, 3, 32, 32]) tensor([75, 10,  2, 94, 50, 67, 99, 55, 15, 53, 15, 62, 70, 43, 23, 40, 97, 64,
        12, 89, 34, 64,  2, 14, 70, 78, 51, 13, 23, 41, 55, 34])
109 torch.Size([32, 3, 32, 32]) tensor([89, 76, 97, 25, 50, 64, 64, 68,  3, 57, 15, 90, 55, 45, 89, 45, 46, 82,
        16, 75, 83, 70, 66, 46,  6, 24, 75, 33, 48, 15, 25, 37])
110 torch.Size([32, 3, 32, 32]) tensor([45, 76, 86, 41, 60,  1, 44, 43, 61, 61, 47, 34, 84, 86, 40, 65, 23, 13,
   

159 torch.Size([32, 3, 32, 32]) tensor([25,  7, 25, 98, 39, 97, 97, 75, 83, 27, 53, 17, 22, 23, 27, 66, 23, 37,
        38, 32,  0, 99, 47, 97, 10,  9, 73, 62, 81, 76, 87, 71])
160 torch.Size([32, 3, 32, 32]) tensor([55, 49, 47, 74, 54, 82, 57, 61, 84, 35, 83, 27, 28,  5,  8, 87, 40, 57,
        34, 15, 66, 34, 11, 17, 92, 44, 73,  1, 74, 45, 45, 35])
161 torch.Size([32, 3, 32, 32]) tensor([67, 10,  3,  9, 16, 90, 55, 79, 55, 59, 61,  0, 76, 14, 76, 64, 43, 22,
        17, 26, 98, 98, 52, 50, 87, 44, 32, 51, 19, 46, 24, 20])
162 torch.Size([32, 3, 32, 32]) tensor([16, 66, 42, 63, 30, 75, 72, 47, 97,  5, 50, 85, 66, 10, 99, 16, 11, 63,
        55,  0, 19, 98, 68, 95,  3, 48, 16, 34, 42, 88, 74, 43])
163 torch.Size([32, 3, 32, 32]) tensor([20,  9, 87, 33, 44, 18, 71, 65, 33, 50, 84, 97,  8, 33, 65,  4,  0, 51,
        81, 48, 70, 94, 20, 19, 37, 84, 71, 66, 31, 87, 43, 67])
164 torch.Size([32, 3, 32, 32]) tensor([88, 18, 59, 60, 57, 60, 56, 78, 95, 18, 77, 47, 93, 89, 57, 37, 74, 75,
   

        11,  0, 27, 47, 57, 44, 86, 35, 82,  0, 21, 26, 93, 14])
217 torch.Size([32, 3, 32, 32]) tensor([ 2, 31, 77, 37, 76, 51, 43, 23, 40, 28, 25, 53, 53, 98, 62,  4, 88,  7,
        72, 40,  7, 21, 54, 27,  5, 33, 61, 39, 46, 46, 23, 82])
218 torch.Size([32, 3, 32, 32]) tensor([40,  8, 11, 26,  3, 47, 69, 68, 92,  2, 70, 35, 51, 20, 19, 72, 92, 75,
        76, 65,  2, 96, 49, 84, 14, 88, 52, 39, 45,  6,  2, 90])
219 torch.Size([32, 3, 32, 32]) tensor([66, 48, 78, 11, 97, 87, 63, 19, 60, 82,  2, 50, 98,  8, 47, 57, 15, 74,
        43, 49, 22, 69, 71, 12, 51, 71, 47, 58, 98, 31, 91, 68])
220 torch.Size([32, 3, 32, 32]) tensor([71, 64, 59,  5, 94, 85, 21, 93, 17, 38, 38, 36, 41, 98, 20,  6, 96, 97,
        80, 24, 68, 33, 26,  5, 86, 32, 55, 64, 81, 82,  8, 22])
221 torch.Size([32, 3, 32, 32]) tensor([71,  3, 70, 88, 41, 54, 14, 45, 53, 89, 38, 73, 40, 76, 66, 88, 12, 86,
         9,  7, 47, 47, 85,  7, 76, 13, 28, 94, 14, 47, 37, 63])
222 torch.Size([32, 3, 32, 32]) tensor([ 0, 15,  1

267 torch.Size([32, 3, 32, 32]) tensor([96, 54, 91, 85, 82, 61, 26, 30, 13, 54, 12, 65, 18, 74, 28, 93, 46, 20,
        33, 28, 30, 84, 66, 83, 67, 51, 63, 75, 35,  4, 30, 24])
268 torch.Size([32, 3, 32, 32]) tensor([65, 15, 43, 90, 81, 45, 53, 56, 40, 46,  8, 15,  2,  8, 74, 14, 10, 44,
        17, 82,  2, 25, 95, 25, 59, 83, 74, 56, 49, 35, 41, 32])
269 torch.Size([32, 3, 32, 32]) tensor([64, 70, 83, 10, 18, 25, 86, 95,  1, 59, 61, 45, 44, 86, 96, 96,  5, 13,
        95, 20, 40, 43, 43, 96, 87, 13, 72, 14, 28, 42, 61, 83])
270 torch.Size([32, 3, 32, 32]) tensor([24,  4, 28, 27, 67, 40, 77, 31, 79, 80,  0, 94, 84, 11,  6, 80, 44,  8,
        71,  5, 57, 63, 24, 32, 22, 51, 22,  9, 22, 51, 66, 63])
271 torch.Size([32, 3, 32, 32]) tensor([ 4, 32, 32, 52, 37, 37, 61, 27,  8, 80, 41, 94, 66, 11, 95, 82, 65, 95,
        30, 49, 52, 96, 59, 48, 99, 66, 92, 46, 11,  8,  8, 41])
272 torch.Size([32, 3, 32, 32]) tensor([67, 27, 17, 91, 89, 82, 40, 55, 91, 60, 12, 41, 69, 80, 14,  0, 58, 60,
   

318 torch.Size([32, 3, 32, 32]) tensor([84, 86, 56, 55, 11, 80, 69, 45, 80, 74,  5, 37, 36, 71, 86, 63, 65, 55,
        54, 77, 50, 20, 81, 18, 72, 77, 47,  8, 49, 79, 38,  7])
319 torch.Size([32, 3, 32, 32]) tensor([23,  4, 39, 85,  0, 88, 43, 48,  6, 36, 13, 88, 67, 70, 15, 65, 86, 39,
        26, 64, 60,  9, 33, 42, 74, 77, 58, 36, 49, 26, 73, 89])
320 torch.Size([32, 3, 32, 32]) tensor([72, 10,  8, 88, 73, 95, 57, 25, 11,  3, 93, 28, 73, 94,  8, 42, 85,  6,
        14, 33, 30, 31, 10,  3, 52, 29, 74, 16, 77, 99, 67, 75])
321 torch.Size([32, 3, 32, 32]) tensor([50, 51, 76, 87, 82, 51,  4, 71, 87, 22, 45, 24,  0, 87, 73, 81, 66, 27,
        18, 28, 67, 22, 38, 20, 71, 58, 15, 28, 82, 21, 83,  5])
322 torch.Size([32, 3, 32, 32]) tensor([85, 16,  8, 37, 55, 93, 30,  9, 25, 34, 50, 28, 17, 18, 84, 32, 24, 44,
        29, 29, 85, 20,  1, 82, 62, 33, 66,  2, 93, 74, 55, 12])
323 torch.Size([32, 3, 32, 32]) tensor([50, 21, 23, 98,  7, 69, 42, 98, 27, 33, 75, 52, 40, 81, 90, 30, 97, 30,
   

370 torch.Size([32, 3, 32, 32]) tensor([46, 99, 49, 92, 31, 41,  6, 34, 87, 54, 69, 62, 67, 49, 14, 84, 60, 13,
        69, 20, 57, 98, 46, 59, 75, 96, 27, 65, 49, 55, 93, 14])
371 torch.Size([32, 3, 32, 32]) tensor([71, 90, 75, 91, 99, 64, 52, 69, 67, 12,  6, 46,  9,  9, 15, 78, 51, 25,
        57, 21, 82, 31, 88,  2, 93, 63, 64, 83, 73, 40, 25, 95])
372 torch.Size([32, 3, 32, 32]) tensor([72, 46, 90, 73, 63, 83, 85, 36, 56, 85,  4, 70, 76, 54, 55, 75, 70, 17,
        58, 42, 87, 55, 26, 70, 39, 87, 85, 31, 28, 45, 12, 63])
373 torch.Size([32, 3, 32, 32]) tensor([74, 83, 53,  0, 35, 33, 97, 47, 89,  8, 45, 71, 38, 77, 37, 13, 35, 63,
        99, 70, 52, 85,  1, 99, 44, 93, 50, 75, 17, 25,  9, 90])
374 torch.Size([32, 3, 32, 32]) tensor([91, 48, 39, 55, 77, 49, 51, 24, 38, 21, 10,  7, 82, 16, 43, 35, 25, 75,
        89, 64, 27, 44, 94, 48, 91,  8, 51, 79, 24, 68, 38, 53])
375 torch.Size([32, 3, 32, 32]) tensor([52, 10, 51, 53, 40, 53, 44, 48, 86, 75, 13, 60, 80, 52, 69, 64, 58, 39,
   

425 torch.Size([32, 3, 32, 32]) tensor([24, 41, 95, 75, 77, 63, 67, 37, 71, 92,  9, 88, 50, 24, 43, 40, 14, 59,
        76, 20,  1, 46, 97, 81, 10, 72,  0, 53, 37, 75, 56, 41])
426 torch.Size([32, 3, 32, 32]) tensor([74, 61, 85, 32,  2, 77, 21, 14, 37, 13, 89, 62, 61, 43, 57, 18, 76, 58,
        56,  1, 25, 68, 32, 53, 97,  8, 82, 72,  4, 37, 22, 92])
427 torch.Size([32, 3, 32, 32]) tensor([68, 24, 22, 40, 80, 13, 11, 75, 21, 72, 59, 60, 39, 94, 33, 27, 30,  6,
        99, 37, 96, 36, 23, 81, 26, 40, 52, 31, 12,  5, 10, 31])
428 torch.Size([32, 3, 32, 32]) tensor([95, 55,  2, 99, 49, 44, 95, 88, 26, 37, 21, 42,  7, 17, 16, 91, 87, 90,
        49,  1, 22, 70,  3,  3, 15, 97, 80, 98, 52, 79, 59, 84])
429 torch.Size([32, 3, 32, 32]) tensor([12, 75,  1, 76, 30, 41, 42, 75, 40, 23, 31,  0, 28, 55, 10, 24, 93, 97,
        52, 54, 68, 31, 63, 21, 79, 15, 46, 28, 49, 87, 45, 42])
430 torch.Size([32, 3, 32, 32]) tensor([45, 75, 50,  4, 45, 23, 19, 93, 94, 53, 56, 92,  9, 47, 91, 10, 55,  2,
   

479 torch.Size([32, 3, 32, 32]) tensor([71, 70, 23, 35, 23,  2, 81, 19, 22, 51, 83, 54, 62, 91, 73, 29, 28, 70,
        13, 30, 45, 14, 56, 22, 71, 38, 65,  2, 30, 81, 62, 79])
480 torch.Size([32, 3, 32, 32]) tensor([98, 46, 44, 47, 90, 54,  2, 21, 16, 70, 95, 13, 75, 62, 83, 64, 28, 50,
        57, 38, 81, 35, 64, 29, 45, 87, 47, 77, 79, 99,  8, 58])
481 torch.Size([32, 3, 32, 32]) tensor([41, 51, 41,  2, 93, 54, 22, 19,  3, 98, 12, 75, 16, 61, 57, 59, 56, 23,
        64, 63, 25, 28, 42, 88, 74, 14, 23,  3, 22,  2,  3, 18])
482 torch.Size([32, 3, 32, 32]) tensor([81, 85, 94, 53, 71, 62, 95, 90, 32, 80, 76, 72, 13, 53,  9, 75, 11, 27,
        62, 26, 53, 49, 54, 29,  9,  3,  0, 67, 25, 45, 31, 95])
483 torch.Size([32, 3, 32, 32]) tensor([45, 96, 20, 95, 40, 90, 11, 13, 69, 81, 43, 97, 38, 70, 97, 97,  6, 68,
        46, 70, 28, 27, 37, 25, 51, 13, 10, 64, 95, 10,  5,  4])
484 torch.Size([32, 3, 32, 32]) tensor([ 6, 62, 69, 94, 54, 49, 80, 23, 98, 76, 13,  0, 15, 15, 11, 25,  2,  3,
   

534 torch.Size([32, 3, 32, 32]) tensor([96, 28, 42, 24, 49, 18, 81, 55, 11, 67, 30, 60, 95, 15, 21, 97, 17, 57,
        39, 25, 60,  2, 96, 43, 63, 12, 78, 24, 90, 75, 40, 67])
535 torch.Size([32, 3, 32, 32]) tensor([63, 84, 80, 91, 49, 31, 68, 32, 47, 21, 53, 80, 87, 13, 92, 58, 74, 87,
        39, 23, 80, 44, 53, 19, 16,  7, 26, 62,  5,  6, 57, 16])
536 torch.Size([32, 3, 32, 32]) tensor([57, 50, 68, 83, 60, 26, 39, 20, 87, 53, 31, 19, 49, 84, 81, 90, 32, 36,
        78, 98, 75, 73, 53, 86, 84, 82, 35, 13, 44, 82, 79, 69])
537 torch.Size([32, 3, 32, 32]) tensor([30, 99, 10, 72,  9, 81, 77, 80, 97, 72, 39, 66, 84,  6, 67,  1, 16, 34,
        18, 98, 42, 52, 74, 78, 56, 68, 72, 74, 19, 37, 44, 55])
538 torch.Size([32, 3, 32, 32]) tensor([16, 63, 40, 10, 53,  1, 27, 68, 77, 35, 61, 51, 65, 13, 78, 78, 90, 85,
        53,  1, 39, 94, 70, 46, 54, 35, 82, 26, 87,  9, 40, 86])
539 torch.Size([32, 3, 32, 32]) tensor([96,  6, 55, 99, 11, 24, 95, 76,  1, 11, 50,  7, 70, 40, 32, 46, 57, 77,
   

588 torch.Size([32, 3, 32, 32]) tensor([ 8, 99, 56, 15, 16, 10, 98, 32,  2, 57, 38, 17, 45, 58, 52, 68, 30, 31,
        58, 28, 24, 67, 24, 23, 34, 88, 39, 35,  2, 53, 92, 43])
589 torch.Size([32, 3, 32, 32]) tensor([ 2, 28, 51, 92, 20, 92, 52, 54, 79, 68, 57, 72, 29, 27, 61, 83, 14, 33,
        37, 64, 27, 17, 95, 65,  6, 78, 41, 11, 89, 47, 90, 53])
590 torch.Size([32, 3, 32, 32]) tensor([77, 24, 93, 43, 18, 19, 88, 70, 50, 54, 52, 74, 63, 31, 23, 50, 65, 55,
        68, 72, 97, 57, 29, 82, 63,  2, 72, 79, 95, 49, 55, 39])
591 torch.Size([32, 3, 32, 32]) tensor([70, 96, 63, 71, 36, 75, 47,  3, 69, 26, 27, 51, 65, 49,  5, 40, 80, 93,
        16, 99, 91, 99,  4, 14, 67, 72,  3, 38, 37, 23, 10,  7])
592 torch.Size([32, 3, 32, 32]) tensor([11, 78, 95, 96, 56, 20, 68, 47, 62, 57, 55, 65, 79, 29, 61, 11, 21, 40,
        64, 88, 24, 67, 88, 59,  9, 64, 76, 79, 84, 50, 77, 35])
593 torch.Size([32, 3, 32, 32]) tensor([93, 97, 75, 47, 89, 22, 84, 61, 22, 62,  5, 40, 29, 54, 65, 37, 59, 53,
   

643 torch.Size([32, 3, 32, 32]) tensor([15, 94, 29, 53, 52, 48, 43, 50, 94, 70, 14, 58, 98, 95, 36, 17, 82, 33,
        81, 22, 62,  2, 88, 25,  4, 14, 44, 28, 35,  4, 65, 14])
644 torch.Size([32, 3, 32, 32]) tensor([79, 82, 14, 62, 46, 10, 97, 56, 36, 40, 34, 60, 66, 10, 49, 88, 11, 16,
        88, 82, 73, 98, 91, 57, 51, 41, 87, 78, 44, 37, 24, 73])
645 torch.Size([32, 3, 32, 32]) tensor([67, 89,  5, 12, 60, 53, 10,  5,  2,  7, 70, 28, 27, 67, 78, 96,  7, 60,
        42, 59, 75, 49,  4, 22, 97, 53, 30, 70, 32, 83, 88, 44])
646 torch.Size([32, 3, 32, 32]) tensor([87, 46, 41, 33, 81, 53, 52, 97, 82, 14, 49, 82, 40, 31, 54, 85,  9, 35,
        25, 55, 26, 92, 41, 76,  9, 50, 21, 36, 81, 61,  7, 76])
647 torch.Size([32, 3, 32, 32]) tensor([70, 65, 26, 95, 22, 43, 56, 64, 23, 94, 27, 16, 43, 16, 73,  3, 76, 80,
        15, 45, 48, 71, 81, 94, 35, 82, 63, 60, 45,  5,  6, 20])
648 torch.Size([32, 3, 32, 32]) tensor([30, 33,  6, 83,  9, 32, 23, 62, 23,  2, 22, 23, 89,  4, 60, 55, 71, 98,
   

        50,  8, 58, 82, 59, 44, 12,  7, 37, 70, 91, 90, 98, 80])
697 torch.Size([32, 3, 32, 32]) tensor([90, 20, 23, 15, 94,  3, 38, 94, 48, 87,  3, 99, 46, 83, 13, 69, 53, 78,
        74, 34, 90, 45,  9, 94, 47,  8, 79, 44, 54, 98, 59, 20])
698 torch.Size([32, 3, 32, 32]) tensor([ 6,  3, 57, 77, 58, 67, 33,  3,  4, 88, 15, 14, 44, 46, 38, 38, 16, 15,
        80,  6, 77, 77, 75, 28, 57, 91, 33, 56,  1, 67, 59, 11])
699 torch.Size([32, 3, 32, 32]) tensor([46, 42,  3, 91, 11, 62, 59, 96, 79, 51, 61, 64, 20,  4, 14, 77, 86, 48,
        13, 54, 67, 36, 52,  3, 66, 54, 49, 17, 41, 26, 87, 15])
700 torch.Size([32, 3, 32, 32]) tensor([87,  6, 69, 85, 32, 50,  3, 36, 50, 55, 97, 48, 99, 46, 52, 32, 63, 60,
        57, 14, 54, 24, 25, 30, 22, 51,  7, 98, 38, 43, 56, 51])
701 torch.Size([32, 3, 32, 32]) tensor([64, 76, 48, 65, 69, 32, 33, 95, 58, 96, 87, 64, 80, 50, 12, 97, 10, 67,
        83, 91, 13, 56,  5, 41, 38, 43,  2, 19, 34, 27, 13,  2])
702 torch.Size([32, 3, 32, 32]) tensor([26, 33, 30

750 torch.Size([32, 3, 32, 32]) tensor([44, 41, 20, 17, 27, 95, 76, 88, 90, 52, 71,  9, 52, 23, 17, 99, 94, 33,
        79, 47, 68, 40, 54, 11, 42, 19, 82,  7,  5, 87,  7,  1])
751 torch.Size([32, 3, 32, 32]) tensor([44, 77, 33, 67, 73, 46, 50, 91, 10, 36, 50, 66, 23, 10, 88, 49,  5,  7,
        59, 18, 84, 56, 55, 67, 93,  6, 83, 35, 76, 78,  5, 88])
752 torch.Size([32, 3, 32, 32]) tensor([32, 68, 48, 34, 96, 75,  2, 61, 56, 83, 32, 33, 44, 34, 10, 38, 39, 92,
         0, 14, 26, 43, 94, 84, 79, 56, 36, 25, 51, 42, 21, 34])
753 torch.Size([32, 3, 32, 32]) tensor([74, 67, 33, 12, 82, 38, 47, 62,  1, 89, 97, 17, 81, 51, 12, 77, 93, 34,
         4, 87, 52, 12, 57, 22, 35, 47, 39, 55, 98, 25, 72,  4])
754 torch.Size([32, 3, 32, 32]) tensor([19, 47, 13, 59, 78, 89, 42,  3, 73, 24, 50, 60, 39, 32, 29,  8, 96, 50,
        56,  4, 27, 95, 25, 86, 59, 27,  9, 47,  8, 73, 32, 43])
755 torch.Size([32, 3, 32, 32]) tensor([91, 26, 83, 59, 98, 29, 85,  7, 64,  1, 12, 79, 57, 57, 79, 52, 74, 59,
   

803 torch.Size([32, 3, 32, 32]) tensor([99, 77, 67, 24, 18, 79, 17, 27,  3, 32, 29, 64, 77, 94, 73, 76, 52,  9,
        47, 26, 13, 57,  1, 55, 61, 35, 38, 66, 95, 61, 35, 62])
804 torch.Size([32, 3, 32, 32]) tensor([79, 27, 66, 70, 64, 70, 46,  3, 77, 75, 36, 60, 10, 91, 45, 74, 17, 89,
        69, 16, 43, 92, 39, 49, 77, 65, 50, 87, 46, 74, 57, 93])
805 torch.Size([32, 3, 32, 32]) tensor([38, 71, 87, 40, 98, 33, 64, 58, 40, 77,  4, 78,  1, 61, 22, 28, 18,  4,
        17,  2, 61, 48, 66, 58, 33, 64, 13, 40, 53, 96, 66, 33])
806 torch.Size([32, 3, 32, 32]) tensor([97, 50, 10, 48, 50, 91, 45, 36, 24, 86, 85,  8, 54, 55, 11,  4, 87, 44,
        78,  5, 51,  2,  5, 68, 47,  0, 85, 89,  8, 30, 80, 38])
807 torch.Size([32, 3, 32, 32]) tensor([87, 12, 48, 82, 25,  7, 55, 30, 72, 15, 16, 99, 90, 96, 62,  3,  8, 67,
        98, 97, 68, 87, 54, 56, 79, 67, 61, 32, 85, 23, 59, 24])
808 torch.Size([32, 3, 32, 32]) tensor([98, 50, 46, 79, 73, 85, 82, 83, 89, 34, 66,  4, 84, 94, 76, 38, 57, 15,
   

857 torch.Size([32, 3, 32, 32]) tensor([15, 96, 54, 21, 40, 66, 81, 41, 30,  3, 56, 37, 79, 80, 93,  0, 32, 33,
        71, 26, 64, 52, 29, 76,  3, 52, 81, 88, 52, 67, 43, 35])
858 torch.Size([32, 3, 32, 32]) tensor([85, 88,  7, 55, 88, 80, 57, 48, 96, 12, 10, 93, 34, 49, 14, 76, 64, 98,
        86, 18, 13, 93, 10, 59, 99, 82, 81, 59, 13, 62, 93, 15])
859 torch.Size([32, 3, 32, 32]) tensor([65, 58, 59, 40, 65, 46, 12, 53, 88, 83, 21, 50, 97, 34, 10, 69, 62, 56,
        43, 25, 61, 39, 56, 14, 45, 55, 60, 86, 66, 60, 50, 76])
860 torch.Size([32, 3, 32, 32]) tensor([48, 92, 42, 90, 28, 46, 61, 26,  7, 86, 96, 14, 35, 63, 41, 54, 83, 25,
        96, 15, 24, 33, 19,  3, 74, 43, 72, 66, 88, 18, 34, 55])
861 torch.Size([32, 3, 32, 32]) tensor([71, 40, 59, 40, 90, 73, 48, 89, 31, 65, 75, 42, 63, 55, 75, 66, 58, 63,
        75, 34, 83, 19, 63, 72, 92, 15, 79, 26, 24, 53, 64, 71])
862 torch.Size([32, 3, 32, 32]) tensor([86, 33, 32, 31, 58, 34, 64, 31, 20, 62, 36, 78, 33, 26, 36, 17, 70, 81,
   

910 torch.Size([32, 3, 32, 32]) tensor([20, 13, 54, 42, 86, 23, 90, 50, 84, 28, 69, 54, 19, 15, 54, 68, 19, 32,
        12, 47, 96, 99,  1, 66, 46, 93,  9, 95, 95, 12, 47, 71])
911 torch.Size([32, 3, 32, 32]) tensor([92, 12, 98, 33, 95, 76, 47, 26, 27, 57, 86, 30, 24, 81, 89, 93, 32, 55,
        19, 12, 41, 94, 43,  9, 42,  6, 56, 79, 19, 73, 77,  5])
912 torch.Size([32, 3, 32, 32]) tensor([57, 91, 94, 70, 73, 97, 33, 69, 48, 18, 28, 36, 47, 58,  4, 52, 12, 51,
        90, 19, 14, 35, 62, 11, 21, 38, 42, 65, 75, 19, 74, 85])
913 torch.Size([32, 3, 32, 32]) tensor([29, 79, 44,  2, 90, 35, 39, 51,  3,  4,  8, 90, 40, 72, 51,  5, 40, 66,
        46, 87, 77, 62, 83, 15, 37, 14, 23, 54, 35, 43, 56, 39])
914 torch.Size([32, 3, 32, 32]) tensor([17, 87, 25, 94, 73, 74, 95, 36, 76, 55, 64, 15, 93, 74, 20, 58, 72, 28,
         2, 96, 33,  8,  0, 16, 42, 73, 24,  2, 60, 75, 64, 21])
915 torch.Size([32, 3, 32, 32]) tensor([ 5, 72, 82, 61, 91, 28, 21, 18,  7, 43, 55, 88, 91, 74, 15, 47, 76, 13,
   

965 torch.Size([32, 3, 32, 32]) tensor([63, 32, 26, 78, 23, 13, 56,  6, 33, 10, 65,  8, 38,  6, 59, 88, 20, 26,
        95, 80, 23, 80, 74, 70, 30,  2, 42, 13, 33,  7, 47,  9])
966 torch.Size([32, 3, 32, 32]) tensor([25, 37, 38, 20, 51, 25, 96, 65, 65, 78, 39, 80, 65, 11, 41, 24, 63, 56,
         2, 71, 27, 55, 61, 47, 48, 91,  8, 14, 53, 95, 61, 44])
967 torch.Size([32, 3, 32, 32]) tensor([36, 92, 53, 90, 45, 76, 48, 78, 59, 95, 25,  9, 28, 37, 31, 33, 85, 44,
        52, 37,  1, 26, 33, 84, 76, 35,  9, 84, 62, 19, 45, 36])
968 torch.Size([32, 3, 32, 32]) tensor([28, 27, 50, 17, 84,  6, 11, 18, 62, 79, 73, 91, 58, 12, 93, 76, 69, 34,
        20, 28, 13, 46, 23, 86, 15, 80, 71, 94, 46, 71, 90, 90])
969 torch.Size([32, 3, 32, 32]) tensor([12, 98, 76, 58, 59, 57, 58, 35, 75, 51,  7, 65, 66, 71, 54, 79, 91, 15,
        66, 18, 70, 84, 31, 92, 13, 37, 56, 17, 44, 64, 50, 10])
970 torch.Size([32, 3, 32, 32]) tensor([62, 30, 96, 17, 53, 22, 92, 75, 94,  4, 68, 58, 44, 17, 29,  7, 15, 29,
   

1019 torch.Size([32, 3, 32, 32]) tensor([ 7, 36, 83, 45, 94, 61, 97,  9, 41, 51, 92, 54, 52, 57, 70, 56,  7, 65,
        86, 97, 45, 95, 52,  1, 57, 11, 21, 45, 74, 40, 38, 85])
1020 torch.Size([32, 3, 32, 32]) tensor([ 3, 85, 99, 96, 95, 72, 82, 91, 64, 80, 54, 72, 90, 88, 41, 39, 57,  4,
        86, 84,  7, 39, 57, 25, 56, 28, 21, 28, 31, 22, 67,  4])
1021 torch.Size([32, 3, 32, 32]) tensor([20, 21, 41, 13, 65,  3, 98, 92, 53, 37, 10, 30, 54,  1, 40, 18, 25, 94,
        42, 82, 10, 31, 37,  6, 16, 64, 67, 53, 42, 36, 95,  0])
1022 torch.Size([32, 3, 32, 32]) tensor([99,  8, 65, 95, 35, 97,  5, 68, 46, 79, 85, 45, 90, 88, 66, 89, 46, 13,
        58, 38, 81, 75, 47, 90, 54, 93, 97, 32, 24, 87, 35, 10])
1023 torch.Size([32, 3, 32, 32]) tensor([99, 80, 12, 76,  7, 15, 62,  7, 83, 49, 52, 66, 39, 27, 87, 88, 95, 65,
        56, 37, 82, 86, 87,  1, 58, 51, 88, 46, 31, 88, 62,  5])
1024 torch.Size([32, 3, 32, 32]) tensor([41, 76, 85, 32, 50, 56, 34,  4, 13, 85, 30, 67, 63, 86, 68, 42, 28,  

1074 torch.Size([32, 3, 32, 32]) tensor([33, 69, 33, 25, 75, 13, 48, 46, 11, 41, 37, 46,  0, 77,  4, 36, 18, 92,
        31, 27, 17, 93, 18, 33, 49, 29, 83, 81, 54, 20, 37, 69])
1075 torch.Size([32, 3, 32, 32]) tensor([62, 10, 14, 60, 29, 66, 58, 81, 29, 37, 23, 98, 77, 10,  7, 99, 57, 45,
        47, 28, 96, 96, 91, 75, 90, 54, 51, 66, 59, 29, 74, 59])
1076 torch.Size([32, 3, 32, 32]) tensor([61, 27, 55, 67, 27, 95, 21, 58, 62, 90, 82, 99, 60, 82, 96, 74,  6, 96,
        26, 59, 69, 74, 51,  5, 53, 70, 79, 46, 36, 72, 65, 48])
1077 torch.Size([32, 3, 32, 32]) tensor([31,  4, 61, 93, 43, 56, 78, 23, 71, 95, 42, 78, 91, 15, 79, 68, 61,  5,
        92, 53, 64, 21, 44, 38, 76,  3, 84,  1, 47,  7, 23, 87])
1078 torch.Size([32, 3, 32, 32]) tensor([12, 90, 16, 97, 55, 79, 42, 63, 65, 27, 13, 96, 53, 80, 21, 43, 80, 80,
        26, 60, 12, 56, 52, 42, 55, 15, 11, 85, 49, 76, 62, 77])
1079 torch.Size([32, 3, 32, 32]) tensor([49, 84, 43,  1, 85, 86, 75, 32, 36, 53, 31, 10, 12, 72, 22, 60, 88, 8

1130 torch.Size([32, 3, 32, 32]) tensor([ 0, 76, 40, 62, 81, 14, 96, 12,  0, 38, 95, 30, 50, 64, 14, 99, 42,  8,
         0, 56, 19, 15, 13, 55, 75, 36, 81, 36, 43, 59, 52, 20])
1131 torch.Size([32, 3, 32, 32]) tensor([ 1, 64, 65, 87, 17, 72, 82, 84, 29, 85, 10, 59, 73,  6, 89, 26, 96, 62,
        67, 92, 28, 36,  2, 54, 51, 45, 93, 41, 68, 98, 86, 31])
1132 torch.Size([32, 3, 32, 32]) tensor([37, 99, 80, 78,  7, 26, 44, 69, 56, 51, 71, 34, 94, 63, 21, 31, 71, 90,
        20, 51,  3, 14, 43,  1, 42, 43, 31, 82,  4, 80, 90, 55])
1133 torch.Size([32, 3, 32, 32]) tensor([35, 48, 69, 47, 63, 43, 35, 87, 84, 61, 69, 95, 72, 35, 40, 49, 89, 82,
        79, 63, 47, 22, 94,  2, 80,  4, 14, 70, 24, 68, 37, 70])
1134 torch.Size([32, 3, 32, 32]) tensor([97, 66, 49,  9, 99, 22, 72, 90, 85, 66, 51, 34, 96, 87, 87, 34, 15, 15,
        66, 79, 50, 76, 84, 35,  8, 23, 73, 30, 32, 10, 21, 45])
1135 torch.Size([32, 3, 32, 32]) tensor([13, 95,  6, 70, 16, 17, 50, 65, 63, 85, 40, 45, 39, 13, 70, 35, 34,  

1186 torch.Size([32, 3, 32, 32]) tensor([76,  8, 51, 55, 31, 46, 70,  6, 85,  5, 80, 39, 90,  3, 55, 61, 32,  3,
         1, 69, 62, 85, 58, 28, 37, 43, 55, 89,  5,  9, 13, 74])
1187 torch.Size([32, 3, 32, 32]) tensor([58, 47, 98, 44, 12, 34, 36, 42, 98, 73, 22, 19, 13, 19, 20, 11, 12,  2,
        72,  1, 53, 68, 24, 64, 13, 51, 35, 96, 49, 73, 10, 62])
1188 torch.Size([32, 3, 32, 32]) tensor([88,  2, 92, 64,  3,  1, 73, 92, 38,  5, 41, 17, 72, 36, 54, 26, 38, 40,
        30, 10, 84, 50, 43, 49, 36,  2, 81, 16, 96, 95,  7, 55])
1189 torch.Size([32, 3, 32, 32]) tensor([ 4, 32, 52, 97, 75,  9, 66, 44, 19, 89,  7, 27,  5, 13, 30, 97, 65, 27,
        28, 50, 58, 30, 45,  0, 11, 97, 84, 23,  4, 37, 34, 83])
1190 torch.Size([32, 3, 32, 32]) tensor([64,  5, 80, 20, 81,  9, 30, 11, 26, 74, 16, 98, 33, 48, 68, 65, 33, 19,
        86, 78, 43,  4, 49, 19, 12, 10, 54, 50, 98, 62, 42, 42])
1191 torch.Size([32, 3, 32, 32]) tensor([57,  0, 22, 52,  8, 41, 16, 10, 29, 66, 58, 85, 75, 14, 21, 93, 69, 3

1241 torch.Size([32, 3, 32, 32]) tensor([ 7, 60, 85, 73, 24, 73, 97, 32, 99,  0, 35, 99, 60, 62, 12, 68, 12, 48,
        17,  6, 84, 71, 89, 26, 75, 42,  0, 54, 46, 33, 39, 62])
1242 torch.Size([32, 3, 32, 32]) tensor([53, 81, 92, 94, 15, 11,  9,  2, 38, 81,  2, 81, 35, 53, 19, 73, 75, 82,
        74, 92, 69, 27, 89, 24, 63, 92, 33, 46, 22, 44, 12, 20])
1243 torch.Size([32, 3, 32, 32]) tensor([40, 90, 89, 90, 48, 74, 51, 66, 64, 61, 75, 28, 32, 29, 19, 16, 33, 13,
        45, 43, 88, 47, 77, 39,  8, 78, 92, 62, 44, 15, 96, 41])
1244 torch.Size([32, 3, 32, 32]) tensor([26, 91,  4, 83, 66, 67, 26, 70, 96, 76, 74, 71, 26, 91, 88, 16,  4, 31,
        68, 14, 17, 85, 82, 68, 70, 39, 60, 16, 39,  2,  4, 83])
1245 torch.Size([32, 3, 32, 32]) tensor([20,  6,  1, 63, 71, 43, 94, 69, 65, 26, 83, 26, 99, 85, 88, 38, 74, 21,
        98, 19, 99,  8, 18,  0,  6, 83, 57, 63, 34, 60, 47, 64])
1246 torch.Size([32, 3, 32, 32]) tensor([22, 89, 98, 72, 50,  9, 96, 72, 71, 37, 96, 78, 78, 70, 37, 69, 84, 8

1297 torch.Size([32, 3, 32, 32]) tensor([27, 18, 46, 47, 68, 87, 75,  5,  0, 81, 81, 10, 71, 78, 99, 32, 28, 76,
        73, 12,  0, 81, 17, 27, 80, 23, 17, 81, 39, 22, 84, 89])
1298 torch.Size([32, 3, 32, 32]) tensor([91, 83, 21, 11, 39, 61, 41, 34, 12, 93, 56, 60, 65, 58, 80, 43,  2, 87,
        72, 36, 67, 89, 69, 34, 57, 54, 77, 49, 49, 41, 39, 33])
1299 torch.Size([32, 3, 32, 32]) tensor([62, 70, 67, 20,  1, 41, 84, 65, 79, 43, 12,  8, 26, 75, 75,  2, 32, 91,
        43,  6, 18, 72,  0,  4, 88, 53, 65, 38,  5, 49, 13, 17])
1300 torch.Size([32, 3, 32, 32]) tensor([31, 75, 70, 65, 57, 32, 79, 23, 83,  3, 49, 40, 66, 55, 29, 45,  5, 46,
        31, 19, 82, 13, 56, 22, 19, 58, 27, 11, 58, 80, 21, 75])
1301 torch.Size([32, 3, 32, 32]) tensor([72, 26, 89, 28, 87, 53, 28, 38, 72, 63,  8, 83, 55, 51,  1, 65, 68,  6,
        45, 62, 39, 92, 97, 31, 42, 66,  3, 16, 98, 79, 10, 25])
1302 torch.Size([32, 3, 32, 32]) tensor([33,  1, 33, 18, 15, 60, 47, 40, 66, 14, 10, 26, 37,  2, 19, 23, 14, 8

1352 torch.Size([32, 3, 32, 32]) tensor([ 7, 66, 47, 37, 30, 80, 51, 35, 65, 83, 71, 80, 53, 47, 98, 67,  7, 10,
        78, 58, 35, 71, 45, 20,  8, 35, 54, 19, 88, 59, 31, 51])
1353 torch.Size([32, 3, 32, 32]) tensor([96, 85, 67, 37, 98, 81,  4, 58, 54, 98, 88, 66, 33, 94, 45, 99, 35, 95,
        74, 17, 42, 39, 22, 95,  2, 37, 95, 20, 97, 56, 24,  9])
1354 torch.Size([32, 3, 32, 32]) tensor([16, 97, 55, 18, 98, 23,  8, 88, 62, 40, 60, 37, 63, 44, 89, 11, 53, 98,
        22, 39, 58, 65, 29, 55, 72, 92,  8, 72, 68, 70, 60, 50])
1355 torch.Size([32, 3, 32, 32]) tensor([73, 90, 11, 14, 32, 21, 22, 39, 85,  7, 37, 98, 46, 27, 16, 59, 50, 58,
        70, 77, 68, 99, 52, 87, 61, 26,  0, 24, 56, 13, 64, 64])
1356 torch.Size([32, 3, 32, 32]) tensor([97,  0, 60, 23, 32, 25,  4, 71,  2, 85, 17, 12, 94, 61, 38, 95, 41, 37,
        82, 95, 58, 36, 42,  8, 63, 30, 66, 24, 66, 58, 87,  0])
1357 torch.Size([32, 3, 32, 32]) tensor([73, 66, 44, 12, 12, 32, 97,  9, 11, 34, 97, 53, 35,  4, 17, 24, 83, 1

1407 torch.Size([32, 3, 32, 32]) tensor([83, 26, 60,  3,  2, 68,  0, 43, 78, 92, 52, 72, 48, 53, 59, 96, 38, 97,
        32, 58, 56, 42, 69, 39, 34, 74, 78, 11,  7, 98, 67, 44])
1408 torch.Size([32, 3, 32, 32]) tensor([77, 85, 54, 74,  2, 69, 42,  5, 61,  3, 99, 27, 61, 11, 48, 97,  0, 56,
        19, 65, 47, 93, 54, 27, 99, 73, 66,  3, 34, 18, 72, 62])
1409 torch.Size([32, 3, 32, 32]) tensor([46, 28, 68, 86,  8, 74, 21, 25, 52, 55, 84, 36, 67, 61, 94, 77, 70, 73,
        74, 31, 81, 21, 55, 28, 18, 78, 51, 66, 26, 24, 67, 66])
1410 torch.Size([32, 3, 32, 32]) tensor([40, 61, 45, 77, 63, 43, 56, 21, 34, 79,  0, 56,  5, 32, 54, 84, 90, 45,
        91, 26, 11, 34, 91,  0, 52, 33, 55, 85, 76, 89, 29, 73])
1411 torch.Size([32, 3, 32, 32]) tensor([36, 30, 75, 66, 22, 43, 17, 81, 17, 82, 69,  3, 56,  0, 85, 56, 39, 41,
        91, 58, 83, 18, 21, 21, 41, 52, 67, 63, 52, 62, 84, 90])
1412 torch.Size([32, 3, 32, 32]) tensor([90, 57, 45, 99, 32, 35, 13, 84, 54, 11, 84, 72,  6, 76, 60, 43,  8, 9

1463 torch.Size([32, 3, 32, 32]) tensor([59, 30, 77, 41, 17, 37, 64, 95,  4,  8, 40, 20, 58, 38, 16, 22, 53, 20,
        36, 28, 65, 75, 92, 76, 73, 45, 48, 84, 73, 58, 91, 24])
1464 torch.Size([32, 3, 32, 32]) tensor([47, 46, 20, 50, 78, 89, 94, 20,  0, 91, 39, 86, 75,  7, 17, 59, 16, 40,
         8, 27,  5, 21, 53, 49, 71, 65,  2, 50, 35, 90, 59, 87])
1465 torch.Size([32, 3, 32, 32]) tensor([66, 32, 94, 47, 80,  2, 79, 92, 80, 92, 35, 37, 17, 81, 76, 49, 89, 17,
        54, 11, 81, 85, 29,  8,  5, 67, 79,  1,  5, 99, 43, 20])
1466 torch.Size([32, 3, 32, 32]) tensor([43, 97, 86, 73, 20, 10, 48, 67, 37, 17, 42, 71, 39, 25, 66, 77, 95, 46,
        98, 34, 49, 25, 33, 38, 43, 96, 16, 22, 64, 95, 18, 85])
1467 torch.Size([32, 3, 32, 32]) tensor([75, 24, 81, 28, 71, 53, 84, 88, 88, 67, 69, 57, 95, 46, 16, 90, 20, 70,
        36, 76, 30, 87, 69, 77, 18, 68, 63, 86, 88, 67, 45,  3])
1468 torch.Size([32, 3, 32, 32]) tensor([11,  6, 51, 33, 76, 85, 31, 69, 73, 14, 11, 26, 55, 41,  0, 65, 47, 7

1518 torch.Size([32, 3, 32, 32]) tensor([31, 88,  0, 71, 83, 42,  7, 10, 38, 60, 83, 21, 55, 22, 24, 19,  1, 74,
        32, 64, 41, 35, 72, 27, 59, 18, 27, 18, 21, 84, 52, 52])
1519 torch.Size([32, 3, 32, 32]) tensor([56, 60, 57, 40, 95,  4, 36, 85, 62, 48,  2, 51, 53, 16, 13, 64, 12, 93,
        88, 36, 36, 98, 13, 74, 68, 76, 92,  0, 98, 48, 64, 52])
1520 torch.Size([32, 3, 32, 32]) tensor([19, 16, 86,  3, 56, 48, 43, 23, 26, 87, 85, 71, 85, 27, 98, 92, 39, 21,
        31, 16, 50, 60, 59, 43, 33, 40, 79, 94, 29, 52, 90, 50])
1521 torch.Size([32, 3, 32, 32]) tensor([81, 75, 35, 32, 83, 81,  4, 34, 29, 56, 77, 59, 56, 59, 12, 56, 24, 93,
        10,  5, 56, 67, 16, 14, 57, 55, 91, 46, 44, 12, 23, 58])
1522 torch.Size([32, 3, 32, 32]) tensor([80, 23, 49, 51, 95, 79, 87, 16, 41, 61, 69, 99,  6, 95, 87, 17, 95, 25,
        31, 29, 64, 16, 93, 47, 23,  3, 51, 89, 25, 45, 11, 24])
1523 torch.Size([32, 3, 32, 32]) tensor([78, 36, 17, 78, 70, 22, 52, 13, 10, 79,  1, 89, 39, 37, 73, 33, 35, 8

In [10]:
record_data = torch.zeros(16)
record = []
record_ones = torch.zeros(16)
for idx, (data, target) in enumerate(test_loader):
    record.append(target)
    record_data += target
    record_ones += torch.ones_like(target)
    print(idx, data.shape, target.shape, target)
    
record_data, record_ones

0 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([4, 2, 2, 0, 1, 7, 6, 1, 4, 7, 5, 8, 1, 3, 1, 4])
1 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([5, 0, 8, 0, 8, 2, 3, 1, 1, 8, 5, 5, 6, 4, 8, 2])
2 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([3, 7, 7, 0, 3, 9, 4, 2, 6, 7, 4, 2, 5, 2, 7, 5])
3 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([0, 6, 7, 8, 0, 6, 1, 8, 7, 2, 7, 6, 2, 0, 7, 6])
4 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([4, 7, 7, 0, 3, 2, 8, 6, 2, 2, 0, 3, 5, 3, 5, 2])
5 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([2, 7, 8, 3, 1, 4, 7, 4, 4, 7, 9, 7, 3, 3, 0, 9])
6 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([4, 7, 9, 3, 2, 3, 7, 2, 4, 2, 3, 8, 0, 1, 9, 8])
7 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([5, 4, 1, 0, 9, 3, 8, 1, 2, 7, 6, 8, 9, 8, 6, 5])
8 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([5, 6, 6, 1, 9, 7, 2, 1, 1, 6, 4, 9, 7, 2, 8, 3])
9 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([5, 2, 3, 

110 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([9, 6, 1, 3, 3, 9, 5, 1, 9, 1, 2, 6, 8, 3, 9, 6])
111 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([0, 3, 8, 4, 3, 2, 1, 2, 2, 2, 8, 1, 5, 3, 3, 9])
112 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([9, 5, 4, 3, 6, 6, 3, 7, 8, 6, 6, 2, 8, 8, 5, 9])
113 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([5, 8, 0, 4, 0, 6, 9, 6, 9, 2, 9, 0, 2, 7, 0, 2])
114 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([8, 7, 4, 2, 8, 7, 9, 7, 5, 9, 6, 9, 9, 5, 0, 2])
115 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([4, 3, 1, 9, 5, 4, 1, 9, 4, 5, 7, 9, 6, 7, 1, 1])
116 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([7, 3, 9, 9, 2, 4, 9, 2, 4, 8, 8, 8, 3, 8, 5, 2])
117 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([1, 7, 1, 0, 1, 3, 2, 5, 9, 7, 4, 2, 2, 2, 4, 9])
118 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([8, 9, 1, 9, 0, 8, 8, 4, 0, 0, 7, 1, 2, 4, 3, 3])
119 torch.Size([16, 3, 32, 32]) torch.Size([16

210 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([3, 0, 2, 2, 7, 5, 1, 0, 8, 8, 6, 9, 1, 7, 1, 6])
211 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([0, 1, 8, 9, 9, 8, 6, 5, 4, 8, 9, 1, 7, 6, 1, 9])
212 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([9, 6, 3, 8, 8, 3, 0, 6, 2, 0, 2, 2, 7, 8, 9, 8])
213 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([8, 1, 1, 5, 2, 4, 7, 7, 3, 4, 2, 8, 5, 3, 1, 6])
214 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([0, 1, 8, 6, 9, 8, 9, 2, 8, 4, 7, 5, 5, 0, 3, 9])
215 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([3, 5, 1, 5, 9, 5, 9, 0, 8, 2, 0, 1, 2, 8, 9, 9])
216 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([6, 7, 3, 5, 7, 3, 2, 0, 3, 4, 9, 6, 2, 1, 9, 0])
217 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([7, 4, 6, 5, 7, 4, 7, 1, 0, 5, 5, 1, 0, 2, 3, 6])
218 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([6, 6, 7, 5, 0, 2, 4, 8, 1, 6, 0, 7, 7, 4, 2, 5])
219 torch.Size([16, 3, 32, 32]) torch.Size([16

311 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([3, 6, 1, 0, 5, 5, 8, 5, 4, 4, 5, 5, 3, 2, 4, 5])
312 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([7, 7, 7, 7, 3, 3, 0, 2, 2, 7, 3, 2, 4, 4, 3, 0])
313 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([9, 3, 3, 4, 8, 8, 5, 7, 0, 7, 2, 9, 4, 1, 9, 8])
314 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([0, 2, 5, 3, 2, 6, 9, 3, 9, 2, 0, 3, 9, 1, 8, 1])
315 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([3, 6, 1, 7, 3, 7, 9, 3, 4, 6, 1, 3, 3, 0, 8, 1])
316 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([5, 0, 4, 9, 1, 6, 1, 8, 1, 3, 8, 7, 1, 5, 2, 1])
317 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([5, 5, 3, 3, 6, 4, 5, 3, 0, 0, 1, 1, 8, 1, 0, 2])
318 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([8, 9, 6, 6, 5, 2, 9, 8, 2, 6, 5, 5, 6, 6, 7, 6])
319 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([2, 8, 3, 3, 1, 7, 6, 3, 9, 8, 5, 1, 5, 6, 5, 6])
320 torch.Size([16, 3, 32, 32]) torch.Size([16

415 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([8, 0, 0, 2, 5, 6, 5, 9, 6, 4, 3, 4, 2, 6, 5, 2])
416 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([6, 2, 5, 3, 6, 8, 9, 8, 7, 4, 7, 5, 2, 2, 7, 8])
417 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([6, 2, 0, 1, 0, 3, 1, 9, 1, 6, 2, 6, 2, 8, 2, 7])
418 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([1, 5, 1, 9, 7, 9, 0, 3, 7, 4, 2, 1, 4, 6, 6, 0])
419 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([2, 8, 0, 9, 8, 4, 1, 5, 6, 6, 4, 5, 6, 6, 8, 4])
420 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([3, 9, 3, 8, 5, 5, 9, 5, 0, 7, 4, 5, 1, 4, 6, 3])
421 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([1, 5, 2, 2, 3, 2, 4, 1, 6, 2, 6, 8, 9, 2, 2, 2])
422 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([7, 8, 1, 9, 3, 5, 7, 6, 7, 7, 5, 8, 2, 7, 2, 9])
423 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([7, 6, 9, 8, 3, 1, 4, 8, 6, 1, 7, 1, 5, 3, 5, 8])
424 torch.Size([16, 3, 32, 32]) torch.Size([16

518 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([5, 6, 8, 1, 4, 8, 5, 5, 1, 3, 4, 0, 5, 4, 8, 3])
519 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([2, 8, 7, 1, 7, 3, 9, 8, 7, 6, 8, 8, 6, 7, 1, 1])
520 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([1, 1, 7, 7, 3, 0, 8, 8, 8, 6, 1, 5, 8, 8, 7, 3])
521 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([1, 1, 2, 3, 7, 8, 8, 3, 9, 0, 6, 1, 1, 6, 0, 2])
522 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([3, 9, 7, 7, 3, 8, 2, 2, 9, 1, 7, 3, 0, 8, 8, 4])
523 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([6, 6, 1, 3, 0, 2, 9, 8, 3, 4, 0, 4, 8, 1, 1, 7])
524 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([5, 1, 6, 3, 1, 1, 5, 9, 8, 2, 5, 3, 3, 9, 3, 2])
525 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([4, 7, 4, 3, 3, 1, 1, 3, 7, 4, 5, 7, 6, 4, 7, 8])
526 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([4, 1, 0, 9, 6, 7, 6, 4, 8, 1, 8, 0, 9, 2, 2, 3])
527 torch.Size([16, 3, 32, 32]) torch.Size([16

620 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([8, 3, 0, 5, 8, 1, 8, 8, 1, 6, 0, 8, 1, 7, 5, 9])
621 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([1, 0, 0, 2, 5, 2, 0, 5, 1, 1, 8, 0, 2, 5, 0, 8])
622 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([4, 9, 2, 9, 6, 4, 0, 4, 0, 3, 5, 0, 1, 8, 5, 8])
623 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([5, 4, 2, 4, 1, 3, 3, 1, 6, 1, 5, 8, 4, 3, 2, 8])
624 torch.Size([16, 3, 32, 32]) torch.Size([16]) tensor([1, 5, 3, 1, 2, 3, 1, 8, 3, 4, 3, 4, 0, 0, 6, 8])


(tensor([2985., 2769., 2730., 2883., 2748., 2889., 2751., 2788., 2787., 2828.,
         2874., 2830., 2797., 2773., 2772., 2796.]),
 tensor([625., 625., 625., 625., 625., 625., 625., 625., 625., 625., 625., 625.,
         625., 625., 625., 625.]))

In [8]:
record_data /= idx + 1
record_data

tensor([4.4864, 4.4848, 4.5344, 4.4976, 4.4832, 4.1568, 4.4112, 4.5248, 4.7024,
        4.5536, 4.5408, 4.3840, 4.6592, 4.5280, 4.3808, 4.6720])