In [None]:
!pip install torchnet
!pip install utils

!pip install tqdm
!pip install nested_dict
!pip install numpy
!pip install torch
!pip install torchnet
!pip install torchvision

Collecting torchnet
  Downloading https://files.pythonhosted.org/packages/b7/b2/d7f70a85d3f6b0365517782632f150e3bbc2fb8e998cd69e27deba599aae/torchnet-0.0.4.tar.gz
Collecting visdom
[?25l  Downloading https://files.pythonhosted.org/packages/c9/75/e078f5a2e1df7e0d3044749089fc2823e62d029cc027ed8ae5d71fafcbdc/visdom-0.1.8.9.tar.gz (676kB)
[K     |████████████████████████████████| 686kB 5.1MB/s 
Collecting jsonpatch
  Downloading https://files.pythonhosted.org/packages/82/53/73ca86f2a680c705dcd1708be4887c559dfe9ed250486dd3ccd8821b8ccb/jsonpatch-1.25-py2.py3-none-any.whl
Collecting torchfile
  Downloading https://files.pythonhosted.org/packages/91/af/5b305f86f2d218091af657ddb53f984ecbd9518ca9fe8ef4103a007252c9/torchfile-0.1.0.tar.gz
Collecting websocket-client
[?25l  Downloading https://files.pythonhosted.org/packages/4c/5f/f61b420143ed1c8dc69f9eaec5ff1ac36109d52c80de49d66e0c36c3dfdf/websocket_client-0.57.0-py2.py3-none-any.whl (200kB)
[K     |████████████████████████████████| 204kB 41.8

In [None]:
from nested_dict import nested_dict
from functools import partial
import torch
from torch.nn.init import kaiming_normal_
from torch.nn.parallel._functions import Broadcast
from torch.nn.parallel import scatter, parallel_apply, gather
import torch.nn.functional as F


def distillation(y, teacher_scores, labels, T, alpha):
    p = F.log_softmax(y/T, dim=1)
    q = F.softmax(teacher_scores/T, dim=1)
    l_kl = F.kl_div(p, q, size_average=False) * (T**2) / y.shape[0]
    l_ce = F.cross_entropy(y, labels)
    return l_kl * alpha + l_ce * (1. - alpha)


def at(x):
    return F.normalize(x.pow(2).mean(1).view(x.size(0), -1))


def at_loss(x, y):
    print("INSIDE AT_LOSS")
    print(x.shape)
    print(y.shape)
    print(at(x).shape)
    print(at(y).shape)
    return (at(x) - at(y)).pow(4).mean()


def cast(params, dtype='float'):
    if isinstance(params, dict):
        return {k: cast(v, dtype) for k,v in params.items()}
    else:
        return getattr(params.cuda() if torch.cuda.is_available() else params, dtype)()


def conv_params(ni, no, k=1):
    return kaiming_normal_(torch.Tensor(no, ni, k, k))


def linear_params(ni, no):
    return {'weight': kaiming_normal_(torch.Tensor(no, ni)), 'bias': torch.zeros(no)}


def bnparams(n):
    return {'weight': torch.rand(n),
            'bias': torch.zeros(n),
            'running_mean': torch.zeros(n),
            'running_var': torch.ones(n)}


def data_parallel(f, input, params, mode, device_ids, output_device=None):
    device_ids = list(device_ids)
    if output_device is None:
        output_device = device_ids[0]

    if len(device_ids) == 1:
        return f(input, params, mode)

    params_all = Broadcast.apply(device_ids, *params.values())
    params_replicas = [{k: params_all[i + j*len(params)] for i, k in enumerate(params.keys())}
                       for j in range(len(device_ids))]

    replicas = [partial(f, params=p, mode=mode)
                for p in params_replicas]
    inputs = scatter([input], device_ids)
    outputs = parallel_apply(replicas, inputs)
    return gather(outputs, output_device)


def flatten(params):
    return {'.'.join(k): v for k, v in nested_dict(params).items_flat() if v is not None}


def batch_norm(x, params, base, mode):
    return F.batch_norm(x, weight=params[base + '.weight'],
                        bias=params[base + '.bias'],
                        running_mean=params[base + '.running_mean'],
                        running_var=params[base + '.running_var'],
                        training=mode)


def print_tensor_dict(params):
    kmax = max(len(key) for key in params.keys())
    for i, (key, v) in enumerate(params.items()):
        print(str(i).ljust(5), key.ljust(kmax + 3), str(tuple(v.shape)).ljust(23), torch.typename(v), v.requires_grad)


def set_requires_grad_except_bn_(params):
    for k, v in params.items():
        if not k.endswith('running_mean') and not k.endswith('running_var'):
            v.requires_grad = True

In [None]:
import argparse
import os
import json
import numpy as np
from tqdm import tqdm
import torch
from torch.optim import SGD
import torchvision.transforms as T
from torchvision import datasets
import torch.nn.functional as F
import torchnet as tnt
from torchnet.engine import Engine
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
import utils
from easydict import EasyDict

cudnn.benchmark = True

In [None]:
def model_parameters(args):
  epoch_step = json.loads(args.epoch_step)
  num_classes = args.classes
  
  def create_dataset(args, train):

    transform = T.Compose([
        T.ToTensor(),
        T.Normalize(np.array([125.3, 123.0, 113.9]) / 255.0,
                    np.array([63.0, 62.1, 66.7]) / 255.0),
    ])
    if args.channel==0:
      transform = T.Compose([
        T.ToTensor(),
        T.Normalize((0.5,), (0.5,))
      ])
    if train:
        transform = T.Compose([
            T.Pad(4, padding_mode='reflect'),
            T.RandomHorizontalFlip(),
            T.RandomCrop(32),
            transform
        ])
    return getattr(datasets, args.dataset)(args.dataroot, train=train, download=True, transform=transform)

  def create_iterator(mode):
          return DataLoader(create_dataset(args, mode), args.batch_size, shuffle=mode,
                            num_workers=args.nthread, pin_memory=torch.cuda.is_available())

  train_loader = create_iterator(True)
  test_loader = create_iterator(False)
  
  c=0
  if args.channel==0:
    c=1
  else:
    c=3
  
  def resnet(depth, width, num_classes):
      assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
      n = (depth - 4) // 6
      widths = [int(v * width) for v in (16, 32, 64)]

      def gen_block_params(ni, no):
          return {
              'conv0': conv_params(ni, no, 3),
              'conv1': conv_params(no, no, 3),
              'bn0': bnparams(ni),
              'bn1': bnparams(no),
              'convdim': conv_params(ni, no, 1) if ni != no else None,
          }

      def gen_group_params(ni, no, count):
          return {'block%d' % i: gen_block_params(ni if i == 0 else no, no)
                  for i in range(count)}
      flat_params = cast(flatten({
          'conv0': conv_params(c, 16, 3),
          'group0': gen_group_params(16, widths[0], n),
          'group1': gen_group_params(widths[0], widths[1], n),
          'group2': gen_group_params(widths[1], widths[2], n),
          'bn': bnparams(widths[2]),
          'fc': linear_params(widths[2], num_classes),
      }))

      set_requires_grad_except_bn_(flat_params)

      def block(x, params, base, mode, stride):
          #print("INSIDE BLOCK")
          #print(x.shape)
          o1 = F.relu(batch_norm(x, params, base + '.bn0', mode), inplace=True)
          #print(o1.shape)
          y = F.conv2d(o1, params[base + '.conv0'], stride=stride, padding=1)
          #print(y.shape)
          o2 = F.relu(batch_norm(y, params, base + '.bn1', mode), inplace=True)
          #print(o2.shape)
          z = F.conv2d(o2, params[base + '.conv1'], stride=1, padding=1)
          #print(z.shape)
          if base + '.convdim' in params:
              #print("HERE in block")
              #print(z.shape)
              #print(F.conv2d(o1, params[base + '.convdim'], stride=stride).shape)
              return z + F.conv2d(o1, params[base + '.convdim'], stride=stride)
          else:
              return z + x

      def group(o, params, base, mode, stride):
          for i in range(n):
              o = block(o, params, f'{base}.block{i}', mode, stride if i == 0 else 1)
          return o

      def f(input, params, mode, base=''):
          #print("HERE F")
          #print(input.shape)
          x = F.conv2d(input, params[f'{base}conv0'], padding=1)
          #print("x")
          #print(x.shape)
          #print(x.shape)
          g0 = group(x, params, f'{base}group0', mode, 1)
          #print(g0.shape)
          g1 = group(g0, params, f'{base}group1', mode, 2)
          #print(g1.shape)
          g2 = group(g1, params, f'{base}group2', mode, 2)
          #print(g2.shape)
          o = F.relu(batch_norm(g2, params, f'{base}bn', mode))
          #print(o.shape)
          o = F.avg_pool2d(o, 8, 1, 0)
          #print(o.shape)
          o = o.view(o.size(0), -1)
          #print(o.shape)
          o = F.linear(o, params[f'{base}fc.weight'], params[f'{base}fc.bias'])
          #print(o.shape)
          return o, (g0, g1, g2)

      return f, flat_params




  f_s, params_s = resnet(args.depth, args.width, num_classes)
  if args.teacher_id:
          with open(os.path.join('logs', args.teacher_id, 'log.txt'), 'r') as ff:
              line = ff.readline()
              r = line.find('json_stats')
              info = json.loads(line[r + 12:])
          f_t = resnet(info['depth'], info['width'], num_classes)[0]
          model_data = torch.load(os.path.join('logs', args.teacher_id, 'model.pt7'))
          params_t = model_data['params']

          # merge teacher and student params
          params = {'student.' + k: v for k, v in params_s.items()}
          for k, v in params_t.items():
              params['teacher.' + k] = v.detach().requires_grad_(False)

          def f(inputs, params, mode):
              y_s, g_s = f_s(inputs, params, mode, 'student.')
              with torch.no_grad():
                  y_t, g_t = f_t(inputs, params, False, 'teacher.')
              return y_s, y_t, [at_loss(x, y) for x, y in zip(g_s, g_t)]
  else:
    f, params = f_s, params_s

  def create_optimizer(args, lr):
        print('creating optimizer with lr = ', lr)
        return SGD((v for v in params.values() if v.requires_grad), lr,
                   momentum=0.9, weight_decay=args.weight_decay)
  
  optimizer = create_optimizer(args, args.lr)
  
  
  epoch = 0
  
  if args.resume != '':
    state_dict = torch.load(args.resume)
    epoch = state_dict['epoch']
    params_tensors = state_dict['params']
    for k, v in params.items():
        v.data.copy_(params_tensors[k])
    optimizer.load_state_dict(state_dict['optimizer'])
    print("HERE")

  print('\nParameters:')
  print_tensor_dict(params)
  n_parameters = sum(p.numel() for p in list(params_s.values()))
  print('\nTotal number of parameters:', n_parameters)


  meter_loss = tnt.meter.AverageValueMeter()
  classacc = tnt.meter.ClassErrorMeter(accuracy=True)
  timer_train = tnt.meter.TimeMeter('s')
  timer_test = tnt.meter.TimeMeter('s')
  meters_at = [tnt.meter.AverageValueMeter() for i in range(3)]

  if not os.path.exists(args.save):
    os.mkdir(args.save)

  def h(sample):
        inputs = cast(sample[0], args.dtype).detach()
        targets = cast(sample[1], 'long')
        #print("INSIDE H")
        #print(inputs.shape)
        if inputs.shape[2]!=32:
          inputs = torch.nn.functional.pad(inputs, (3, 1, 3, 1))
        #print(inputs.shape)
        #print(targets.shape)
        if args.teacher_id != '':
            y_s, y_t, loss_groups = data_parallel(f, inputs, params, sample[2], range(1))
            print("LOSS _ GROUPS BEFORE")
            print(loss_groups)
            print(len(loss_groups))
            loss_groups = [v.sum() for v in loss_groups]
            print("LOSS _ GROUPS AFTER")
            print(loss_groups)
            [m.add(v.item()) for m, v in zip(meters_at, loss_groups)]
            return distillation(y_s, y_t, targets, args.temperature, args.alpha) \
                   + args.beta * sum(loss_groups), y_s
        else:
            y = data_parallel(f, inputs, params, sample[2], range(1))[0]
            return F.cross_entropy(y, targets), y

  def log(t, state):
      torch.save(dict(params={k: v.data for k, v in params.items()},
                      optimizer=state['optimizer'].state_dict(),
                      epoch=t['epoch']),
                  os.path.join(args.save, 'model.pt7'))
      z = vars(args).copy(); z.update(t)
      logname = os.path.join(args.save, 'log.txt')
      with open(logname, 'a') as f:
          f.write('json_stats: ' + json.dumps(z) + '\n')
      print(z)

  def on_sample(state):
      state['sample'].append(state['train'])

  def on_forward(state):
      classacc.add(state['output'].data, state['sample'][1])
      meter_loss.add(state['loss'].item())  

  def on_start(state):
        state['epoch'] = epoch

  def on_start_epoch(state):
      classacc.reset()
      meter_loss.reset()
      timer_train.reset()
      [meter.reset() for meter in meters_at]
      state['iterator'] = tqdm(train_loader)

      epoch = state['epoch'] + 1
      if epoch in epoch_step:
          lr = state['optimizer'].param_groups[0]['lr']
          state['optimizer'] = create_optimizer(args, lr * args.lr_decay_ratio)

  def on_end_epoch(state):
      train_loss = meter_loss.mean
      train_acc = classacc.value()
      train_time = timer_train.value()
      meter_loss.reset()
      classacc.reset()
      timer_test.reset()

      #print("IT CAME HERE")
      engine.test(h, test_loader)
      #print("FINISHED TESTIG")
      test_acc = classacc.value()[0]
      print(log({
          "train_loss": train_loss,
          "train_acc": train_acc[0],
          "test_loss": meter_loss.mean,
          "test_acc": test_acc,
          "epoch": state['epoch'],
          "num_classes": num_classes,
          "n_parameters": n_parameters,
          "train_time": train_time,
          "test_time": timer_test.value(),
          "at_losses": [m.value() for m in meters_at],
          }, state))
      print('==> id: %s (%d/%d), test_acc: \33[91m%.2f\033[0m' % \
                      (args.save, state['epoch'], args.epochs, test_acc))


  engine = Engine()
  engine.hooks['on_sample'] = on_sample
  engine.hooks['on_forward'] = on_forward
  engine.hooks['on_start_epoch'] = on_start_epoch
  engine.hooks['on_end_epoch'] = on_end_epoch
  engine.hooks['on_start'] = on_start
  engine.train(h, train_loader, args.epochs, optimizer)

  return engine

In [None]:
teacher_args = EasyDict({
    "channel": 0,
    "classes": 10,
    "depth": 16,
    "width": 2,
    "dataset" : 'FashionMNIST',
    "dataroot" : '.',
    "dtype" : 'float',
    "nthread" : 0,
    "teacher_id" : '',
    "batch_size" : 128,
    "lr" : 0.1,
    "epochs" : 1,
    "weight_decay" : 0.0005,
    "epoch_step" : '[60, 120, 160]',
    "lr_decay_ratio" : 0.2,
    "resume" : '',
    "randomcrop_pad" : 4,
    "temperature" : 4,
    "alpha" : 0,
    "beta" : 0,
    "gpu_id" : '0',
    "save" : '/content/logs/resnet_52_2_teacher'
})

print('parsed options:', teacher_args)
teacher = model_parameters(teacher_args)



parsed options: {'channel': 0, 'classes': 10, 'depth': 16, 'width': 2, 'dataset': 'FashionMNIST', 'dataroot': '.', 'dtype': 'float', 'nthread': 0, 'teacher_id': '', 'batch_size': 128, 'lr': 0.1, 'epochs': 1, 'weight_decay': 0.0005, 'epoch_step': '[60, 120, 160]', 'lr_decay_ratio': 0.2, 'resume': '', 'randomcrop_pad': 4, 'temperature': 4, 'alpha': 0, 'beta': 0, 'gpu_id': '0', 'save': '/content/logs/resnet_52_2_teacher'}
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./FashionMNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting ./FashionMNIST/raw/train-images-idx3-ubyte.gz to ./FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./FashionMNIST/raw/train-labels-idx1-ubyte.gz



HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting ./FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./FashionMNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting ./FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting ./FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./FashionMNIST/raw
Processing...
Done!








  0%|          | 0/469 [00:00<?, ?it/s]

creating optimizer with lr =  0.1

Parameters:
0     conv0                             (16, 1, 3, 3)           torch.cuda.FloatTensor True
1     group0.block0.conv0               (32, 16, 3, 3)          torch.cuda.FloatTensor True
2     group0.block0.conv1               (32, 32, 3, 3)          torch.cuda.FloatTensor True
3     group0.block0.bn0.weight          (16,)                   torch.cuda.FloatTensor True
4     group0.block0.bn0.bias            (16,)                   torch.cuda.FloatTensor True
5     group0.block0.bn0.running_mean    (16,)                   torch.cuda.FloatTensor False
6     group0.block0.bn0.running_var     (16,)                   torch.cuda.FloatTensor False
7     group0.block0.bn1.weight          (32,)                   torch.cuda.FloatTensor True
8     group0.block0.bn1.bias            (32,)                   torch.cuda.FloatTensor True
9     group0.block0.bn1.running_mean    (32,)                   torch.cuda.FloatTensor False
10    group0.block0.bn1.runnin

100%|██████████| 469/469 [00:41<00:00, 11.43it/s]


{'channel': 0, 'classes': 10, 'depth': 16, 'width': 2, 'dataset': 'FashionMNIST', 'dataroot': '.', 'dtype': 'float', 'nthread': 0, 'teacher_id': '', 'batch_size': 128, 'lr': 0.1, 'epochs': 1, 'weight_decay': 0.0005, 'epoch_step': '[60, 120, 160]', 'lr_decay_ratio': 0.2, 'resume': '', 'randomcrop_pad': 4, 'temperature': 4, 'alpha': 0, 'beta': 0, 'gpu_id': '0', 'save': '/content/logs/resnet_52_2_teacher', 'train_loss': 0.5896742646373924, 'train_acc': 78.38000000000001, 'test_loss': 2.275277966185461, 'test_acc': 44.46, 'epoch': 1, 'num_classes': 10, 'n_parameters': 693210, 'train_time': 41.05589175224304, 'test_time': 2.5025694370269775, 'at_losses': [(nan, nan), (nan, nan), (nan, nan)]}
None
==> id: /content/logs/resnet_52_2_teacher (1/1), test_acc: [91m44.46[0m


In [None]:
student_trained_args = EasyDict({
    "channel": 0,
    "classes": 10,
    "depth": 10,
    "width": 1,
    "dataset" : 'FashionMNIST',
    "dataroot" : '.',
    "dtype" : 'float',
    "nthread" : 4,
    "teacher_id" : 'resnet_52_2_teacher',
    "batch_size" : 128,
    "lr" : 0.1,
    "epochs" : 3,
    "weight_decay" : 0.0005,
    "epoch_step" : '[60, 120, 160]',
    "lr_decay_ratio" : 0.2,
    "resume" : '',
    "randomcrop_pad" : 4,
    "temperature" : 4,
    "alpha" : 0,
    "beta" : 1000,
    "gpu_id" : '0',
    "save" : '/content/logs/at_52_2_16_1'
})

print('parsed options:', student_trained_args)
student_trained = model_parameters(student_trained_args)






  0%|          | 0/469 [00:00<?, ?it/s][A[A[A

parsed options: {'channel': 0, 'classes': 10, 'depth': 10, 'width': 1, 'dataset': 'FashionMNIST', 'dataroot': '.', 'dtype': 'float', 'nthread': 4, 'teacher_id': 'resnet_52_2_teacher', 'batch_size': 128, 'lr': 0.1, 'epochs': 3, 'weight_decay': 0.0005, 'epoch_step': '[60, 120, 160]', 'lr_decay_ratio': 0.2, 'resume': '', 'randomcrop_pad': 4, 'temperature': 4, 'alpha': 0, 'beta': 1000, 'gpu_id': '0', 'save': '/content/logs/at_52_2_16_1'}
creating optimizer with lr =  0.1

Parameters:
0     student.conv0                             (16, 1, 3, 3)           torch.cuda.FloatTensor True
1     student.group0.block0.conv0               (16, 16, 3, 3)          torch.cuda.FloatTensor True
2     student.group0.block0.conv1               (16, 16, 3, 3)          torch.cuda.FloatTensor True
3     student.group0.block0.bn0.weight          (16,)                   torch.cuda.FloatTensor True
4     student.group0.block0.bn0.bias            (16,)                   torch.cuda.FloatTensor True
5     student.g




  0%|          | 1/469 [00:00<02:34,  3.03it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(6.1291e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.2115e-05, device='cuda:0', grad_fn=<MeanBackward0>), tensor(0.0001, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(6.1291e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.2115e-05, device='cuda:0', grad_fn=<SumBackward0>), tensor(0.0001, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([1




  1%|          | 4/469 [00:00<01:54,  4.08it/s][A[A[A


  1%|▏         | 7/469 [00:00<01:25,  5.41it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(6.6436e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(9.3570e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.1323e-05, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(6.6436e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(9.3570e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.1323e-05, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




  2%|▏         | 10/469 [00:00<01:05,  7.03it/s][A[A[A


  3%|▎         | 13/469 [00:00<00:51,  8.92it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(7.3571e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(8.5506e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(3.6087e-05, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(7.3571e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(8.5506e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(3.6087e-05, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




  3%|▎         | 16/469 [00:00<00:41, 10.91it/s][A[A[A


  4%|▍         | 19/469 [00:01<00:35, 12.85it/s]

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(6.9841e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(8.6992e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(3.9154e-05, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(6.9841e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(8.6992e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(3.9154e-05, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch

[A[A[A


  5%|▍         | 22/469 [00:01<00:29, 15.10it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(6.8383e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(8.4580e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(3.4751e-05, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(6.8383e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(8.4580e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(3.4751e-05, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




  5%|▌         | 25/469 [00:01<00:26, 16.87it/s][A[A[A


  6%|▌         | 28/469 [00:01<00:23, 18.91it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(6.4828e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(7.1991e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.9696e-05, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(6.4828e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(7.1991e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.9696e-05, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




  7%|▋         | 31/469 [00:01<00:21, 19.94it/s][A[A[A

torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(5.7319e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.4032e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.1153e-05, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(5.7319e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.4032e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.1153e-05, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([12




  7%|▋         | 34/469 [00:01<00:21, 20.48it/s][A[A[A


  8%|▊         | 37/469 [00:01<00:20, 21.38it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(6.1742e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(6.2604e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(3.0896e-05, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(6.1742e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(6.2604e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(3.0896e-05, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




  9%|▊         | 40/469 [00:02<00:19, 22.20it/s][A[A[A


  9%|▉         | 43/469 [00:02<00:18, 23.22it/s][A[A[A

torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(5.8965e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.4496e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.1207e-05, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(5.8965e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.4496e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.1207e-05, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([12




 10%|▉         | 46/469 [00:02<00:18, 22.64it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(6.3314e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(6.6339e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.4249e-05, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(6.3314e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(6.6339e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.4249e-05, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 10%|█         | 49/469 [00:02<00:18, 22.73it/s][A[A[A


 11%|█         | 52/469 [00:02<00:19, 21.86it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(5.9474e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.9268e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.1605e-05, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(5.9474e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.9268e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.1605e-05, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 12%|█▏        | 55/469 [00:02<00:18, 22.18it/s][A[A[A

LOSS _ GROUPS BEFORE
[tensor(5.6393e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(4.7221e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.1455e-05, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(5.6393e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(4.7221e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.1455e-05, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(5.3420e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(3.8069e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.0421e-05, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AF




 12%|█▏        | 58/469 [00:02<00:19, 21.57it/s][A[A[A


 13%|█▎        | 61/469 [00:02<00:18, 22.02it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(5.2903e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(4.0826e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.6627e-05, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(5.2903e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(4.0826e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.6627e-05, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 14%|█▎        | 64/469 [00:03<00:17, 22.68it/s][A[A[A


 14%|█▍        | 67/469 [00:03<00:17, 23.30it/s][A[A[A

[tensor(5.7318e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.4197e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.9593e-05, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(5.4816e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(6.2629e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.4550e-05, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(5.4816e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(6.2629e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.4550e-05, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32




 15%|█▍        | 70/469 [00:03<00:17, 23.46it/s][A[A[A

[tensor(3.7253e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(7.2061e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.3765e-05, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(4.7976e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(6.2106e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.2060e-05, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(4.7976e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(6.2106e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.2060e-05, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32




 16%|█▌        | 73/469 [00:03<00:17, 22.52it/s][A[A[A


 16%|█▌        | 76/469 [00:03<00:16, 23.16it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(5.3549e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(7.4180e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.4894e-05, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(5.3549e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(7.4180e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.4894e-05, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 17%|█▋        | 79/469 [00:03<00:17, 22.20it/s][A[A[A


 17%|█▋        | 82/469 [00:03<00:17, 22.41it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(5.1955e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.5142e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.5993e-05, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(5.1955e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.5142e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.5993e-05, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 18%|█▊        | 85/469 [00:03<00:16, 23.26it/s][A[A[A


 19%|█▉        | 88/469 [00:04<00:16, 23.50it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(4.2286e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.4279e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.2482e-05, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(4.2286e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.4279e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.2482e-05, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 19%|█▉        | 91/469 [00:04<00:15, 24.02it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(4.5218e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.3780e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.2499e-05, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(4.5218e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.3780e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.2499e-05, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 20%|██        | 94/469 [00:04<00:15, 23.97it/s][A[A[A


 21%|██        | 97/469 [00:04<00:15, 24.31it/s][A[A[A

[tensor(4.5327e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.5322e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(9.7595e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(4.5327e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.5322e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(9.7595e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(3.7862e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.0119e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.0484e-05, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(3.7862e-0




 21%|██▏       | 100/469 [00:04<00:15, 23.84it/s][A[A[A


 22%|██▏       | 103/469 [00:04<00:15, 23.81it/s][A[A[A

torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(3.5631e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(3.6809e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.0148e-05, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(3.5631e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(3.6809e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.0148e-05, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(3.4002e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(4.0212e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(9.8193e-06, device='cuda:0', 




 23%|██▎       | 106/469 [00:04<00:15, 23.59it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(3.2315e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(4.0293e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(8.7095e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(3.2315e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(4.0293e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(8.7095e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 23%|██▎       | 109/469 [00:04<00:15, 22.93it/s][A[A[A


 24%|██▍       | 112/469 [00:05<00:15, 23.57it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(3.7422e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(3.9465e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(9.1810e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(3.7422e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(3.9465e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(9.1810e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 25%|██▍       | 115/469 [00:05<00:15, 23.19it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(3.2875e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(3.1954e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.1079e-05, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(3.2875e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(3.1954e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.1079e-05, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 25%|██▌       | 118/469 [00:05<00:15, 22.65it/s][A[A[A


 26%|██▌       | 121/469 [00:05<00:15, 22.38it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(3.5300e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(4.1996e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(9.6241e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(3.5300e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(4.1996e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(9.6241e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 26%|██▋       | 124/469 [00:05<00:15, 22.27it/s][A[A[A


 27%|██▋       | 127/469 [00:05<00:15, 22.67it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(3.4112e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(3.7149e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(8.1109e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(3.4112e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(3.7149e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(8.1109e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 28%|██▊       | 130/469 [00:05<00:14, 23.38it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(3.3949e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.4673e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(7.9804e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(3.3949e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.4673e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(7.9804e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 28%|██▊       | 133/469 [00:06<00:14, 22.94it/s][A[A[A


 29%|██▉       | 136/469 [00:06<00:14, 22.93it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(3.5390e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.3219e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(8.8812e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(3.5390e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.3219e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(8.8812e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 30%|██▉       | 139/469 [00:06<00:14, 23.18it/s][A[A[A


 30%|███       | 142/469 [00:06<00:14, 22.83it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(3.9470e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.4121e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.1259e-05, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(3.9470e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.4121e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.1259e-05, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 31%|███       | 145/469 [00:06<00:14, 22.99it/s][A[A[A


 32%|███▏      | 148/469 [00:06<00:13, 23.57it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(4.4500e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(3.4778e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.0474e-05, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(4.4500e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(3.4778e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.0474e-05, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 32%|███▏      | 151/469 [00:06<00:13, 23.29it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(3.0208e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(3.5006e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.0031e-05, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(3.0208e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(3.5006e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.0031e-05, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 33%|███▎      | 154/469 [00:06<00:13, 22.94it/s][A[A[A


 33%|███▎      | 157/469 [00:07<00:13, 23.70it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(4.0682e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(3.4574e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.0215e-05, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(4.0682e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(3.4574e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.0215e-05, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 34%|███▍      | 160/469 [00:07<00:13, 23.66it/s][A[A[A


 35%|███▍      | 163/469 [00:07<00:12, 24.25it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(3.5213e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(3.3016e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(9.9643e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(3.5213e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(3.3016e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(9.9643e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size(




 35%|███▌      | 166/469 [00:07<00:12, 24.31it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.8870e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(3.1086e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(9.6264e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.8870e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(3.1086e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(9.6264e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 36%|███▌      | 169/469 [00:07<00:12, 24.01it/s][A[A[A


 37%|███▋      | 172/469 [00:07<00:12, 24.34it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.6910e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.8095e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(7.9921e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.6910e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.8095e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(7.9921e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 37%|███▋      | 175/469 [00:07<00:12, 24.37it/s][A[A[A

torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.7559e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(3.2917e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(7.2574e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.7559e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(3.2917e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(7.2574e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([




 38%|███▊      | 178/469 [00:07<00:12, 23.52it/s][A[A[A


 39%|███▊      | 181/469 [00:08<00:12, 23.73it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.7440e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.7007e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(7.7661e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.7440e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.7007e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(7.7661e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 39%|███▉      | 184/469 [00:08<00:12, 23.60it/s][A[A[A


 40%|███▉      | 187/469 [00:08<00:11, 23.58it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.4421e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.9989e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(7.9981e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.4421e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.9989e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(7.9981e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 41%|████      | 190/469 [00:08<00:11, 23.26it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.3965e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.6615e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(9.6413e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.3965e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.6615e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(9.6413e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 41%|████      | 193/469 [00:08<00:11, 23.49it/s][A[A[A


 42%|████▏     | 196/469 [00:08<00:12, 22.46it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.6287e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(3.2498e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(9.0678e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.6287e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(3.2498e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(9.0678e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 42%|████▏     | 199/469 [00:08<00:11, 23.33it/s][A[A[A


 43%|████▎     | 202/469 [00:08<00:11, 23.13it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.6831e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.6935e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(7.2725e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.6831e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.6935e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(7.2725e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 44%|████▎     | 205/469 [00:09<00:10, 24.14it/s][A[A[A


 44%|████▍     | 208/469 [00:09<00:11, 23.50it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(3.0451e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.7672e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(7.8229e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(3.0451e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.7672e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(7.8229e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 45%|████▍     | 211/469 [00:09<00:10, 24.13it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(3.2210e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.4280e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(9.4874e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(3.2210e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.4280e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(9.4874e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 46%|████▌     | 214/469 [00:09<00:10, 24.60it/s][A[A[A


 46%|████▋     | 217/469 [00:09<00:10, 24.32it/s][A[A[A

[tensor(2.8347e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.8458e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(7.1952e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(3.1645e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.2769e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(8.0144e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(3.1645e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.2769e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(8.0144e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32




 47%|████▋     | 220/469 [00:09<00:10, 24.09it/s][A[A[A


 48%|████▊     | 223/469 [00:09<00:10, 23.71it/s]

torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(3.0565e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(3.1893e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(8.2305e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(3.0565e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(3.1893e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(8.2305e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(3.029

[A[A[A


 48%|████▊     | 226/469 [00:09<00:10, 23.89it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.7908e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.7287e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(8.8457e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.7908e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.7287e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(8.8457e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 49%|████▉     | 229/469 [00:10<00:10, 23.79it/s][A[A[A


 49%|████▉     | 232/469 [00:10<00:09, 24.16it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.1318e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.5581e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(8.6641e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.1318e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.5581e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(8.6641e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 50%|█████     | 235/469 [00:10<00:09, 23.69it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.2733e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.7084e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.1245e-05, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.2733e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.7084e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.1245e-05, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 51%|█████     | 238/469 [00:10<00:09, 23.33it/s][A[A[A


 51%|█████▏    | 241/469 [00:10<00:09, 24.12it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.2884e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.7241e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(6.5632e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.2884e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.7241e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(6.5632e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 52%|█████▏    | 244/469 [00:10<00:09, 23.48it/s][A[A[A


 53%|█████▎    | 247/469 [00:10<00:09, 23.40it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.7706e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.7838e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(9.4068e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.7706e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.7838e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(9.4068e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 53%|█████▎    | 250/469 [00:10<00:09, 23.13it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.7684e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.7117e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(9.3610e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.7684e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.7117e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(9.3610e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 54%|█████▍    | 253/469 [00:11<00:09, 23.24it/s][A[A[A


 55%|█████▍    | 256/469 [00:11<00:08, 23.73it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(1.9119e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.8437e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(8.7815e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(1.9119e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.8437e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(8.7815e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 55%|█████▌    | 259/469 [00:11<00:08, 23.35it/s][A[A[A


 56%|█████▌    | 262/469 [00:11<00:09, 21.99it/s][A[A[A

torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.0676e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.5357e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(6.1727e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.0676e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.5357e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(6.1727e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([12




 57%|█████▋    | 265/469 [00:11<00:09, 22.10it/s][A[A[A


[tensor(2.1892e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.6929e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.0536e-05, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.1892e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.6929e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.0536e-05, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.3633e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.0417e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(8.6128e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.3633e-




 57%|█████▋    | 268/469 [00:11<00:09, 21.76it/s][A[A[A


 58%|█████▊    | 271/469 [00:11<00:08, 22.06it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.3029e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.6819e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(7.4810e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.3029e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.6819e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(7.4810e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 58%|█████▊    | 274/469 [00:12<00:08, 21.74it/s][A[A[A

[tensor(2.1347e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.9430e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(8.1866e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.1347e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.9430e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(8.1866e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.3889e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.0483e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(9.5104e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.3889e-0




 59%|█████▉    | 277/469 [00:12<00:08, 21.68it/s][A[A[A


 60%|█████▉    | 280/469 [00:12<00:08, 21.96it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.7467e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.2530e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(8.8513e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.7467e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.2530e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(8.8513e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 60%|██████    | 283/469 [00:12<00:08, 21.88it/s][A[A[A

[tensor(2.5225e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.5844e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(8.5280e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.5225e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.5844e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(8.5280e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.8859e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(3.1579e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(7.7364e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.8859e-0




 61%|██████    | 286/469 [00:12<00:08, 22.12it/s][A[A[A


 62%|██████▏   | 289/469 [00:12<00:07, 22.70it/s][A[A[A


torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(3.1085e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(3.0282e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.9486e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(3.1085e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(3.0282e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.9486e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 25




 62%|██████▏   | 292/469 [00:12<00:08, 21.68it/s][A[A[A

[tensor(4.4413e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(3.1224e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(6.0928e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(3.1477e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(3.3505e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(7.0004e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(3.1477e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(3.3505e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(7.0004e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32




 63%|██████▎   | 295/469 [00:13<00:08, 21.71it/s][A[A[A


 64%|██████▎   | 298/469 [00:13<00:07, 22.29it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(1.9228e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.4835e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.2831e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(1.9228e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.4835e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.2831e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 64%|██████▍   | 301/469 [00:13<00:07, 22.32it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.3121e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.0974e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(6.7275e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.3121e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.0974e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(6.7275e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 65%|██████▍   | 304/469 [00:13<00:07, 22.24it/s][A[A[A


 65%|██████▌   | 307/469 [00:13<00:07, 21.34it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.4890e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.9121e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(6.4097e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.4890e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.9121e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(6.4097e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 66%|██████▌   | 310/469 [00:13<00:07, 22.14it/s][A[A[A

torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.6502e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.0651e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.6280e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.6502e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.0651e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.6280e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(3.0998e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.1815e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.7078e-06, device='cuda:0', 




 67%|██████▋   | 313/469 [00:13<00:07, 21.39it/s][A[A[A


 67%|██████▋   | 316/469 [00:13<00:06, 22.38it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.9783e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.7460e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(6.0316e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.9783e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.7460e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(6.0316e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 68%|██████▊   | 319/469 [00:14<00:06, 23.28it/s][A[A[A


 69%|██████▊   | 322/469 [00:14<00:06, 23.05it/s][A[A[A

torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(3.2051e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.1739e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(7.0593e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(3.2051e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.1739e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(7.0593e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(3.2155e-07, device='cuda:0', grad_f




 69%|██████▉   | 325/469 [00:14<00:06, 22.95it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.9090e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.4154e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(7.0012e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.9090e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.4154e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(7.0012e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 70%|██████▉   | 328/469 [00:14<00:06, 22.72it/s][A[A[A


 71%|███████   | 331/469 [00:14<00:05, 23.93it/s][A[A[A

[tensor(2.4855e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.9875e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(7.5449e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.4626e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.0586e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(8.7885e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.4626e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.0586e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(8.7885e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32




 71%|███████   | 334/469 [00:14<00:05, 24.37it/s][A[A[A


 72%|███████▏  | 337/469 [00:14<00:05, 23.53it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.6403e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.0813e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(6.1051e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.6403e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.0813e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(6.1051e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 72%|███████▏  | 340/469 [00:15<00:05, 22.90it/s][A[A[A


 73%|███████▎  | 343/469 [00:15<00:05, 23.66it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.3918e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.3326e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(6.6275e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.3918e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.3326e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(6.6275e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 74%|███████▍  | 346/469 [00:15<00:05, 22.67it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.1698e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.5945e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(8.0238e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.1698e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.5945e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(8.0238e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 74%|███████▍  | 349/469 [00:15<00:05, 21.97it/s][A[A[A


 75%|███████▌  | 352/469 [00:15<00:05, 22.41it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.4760e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.8543e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(7.7070e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.4760e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.8543e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(7.7070e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 76%|███████▌  | 355/469 [00:15<00:05, 22.00it/s][A[A[A

[tensor(3.4998e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.6737e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(8.0111e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(3.4998e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.6737e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(8.0111e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(3.6754e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(3.1794e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(6.7835e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(3.6754e-0




 76%|███████▋  | 358/469 [00:15<00:04, 22.38it/s][A[A[A


 77%|███████▋  | 361/469 [00:15<00:04, 22.13it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(3.7396e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.4688e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(8.1585e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(3.7396e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.4688e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(8.1585e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 78%|███████▊  | 364/469 [00:16<00:04, 22.39it/s][A[A[A

[tensor(2.8084e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.6280e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(7.6419e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.9711e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.6925e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(8.9845e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.9711e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.6925e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(8.9845e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32




 78%|███████▊  | 367/469 [00:16<00:04, 22.65it/s][A[A[A


 79%|███████▉  | 370/469 [00:16<00:04, 23.17it/s][A[A[A

torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(3.0011e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.8763e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(6.3754e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(3.0011e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.8763e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(6.3754e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(3.0778e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.8697e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(6.0262e-06, device='cuda:0', 




 80%|███████▉  | 373/469 [00:16<00:04, 22.92it/s][A[A[A


 80%|████████  | 376/469 [00:16<00:03, 23.91it/s][A[A[A


torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.4527e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.8386e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(6.3177e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.4527e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.8386e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(6.3177e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.3838e-07, device='cuda:0', grad_




 81%|████████  | 379/469 [00:16<00:03, 23.99it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.7983e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.7797e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(6.2251e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.7983e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.7797e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(6.2251e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 81%|████████▏ | 382/469 [00:16<00:03, 23.53it/s][A[A[A


 82%|████████▏ | 385/469 [00:16<00:03, 23.43it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.4618e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.6079e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.7385e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.4618e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.6079e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.7385e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 83%|████████▎ | 388/469 [00:17<00:03, 22.87it/s][A[A[A


 83%|████████▎ | 391/469 [00:17<00:03, 23.18it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(3.2977e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.9507e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(8.0856e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(3.2977e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.9507e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(8.0856e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 84%|████████▍ | 394/469 [00:17<00:03, 23.16it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.4991e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.5560e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.2444e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.4991e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.5560e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.2444e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 85%|████████▍ | 397/469 [00:17<00:03, 22.91it/s][A[A[A


 85%|████████▌ | 400/469 [00:17<00:02, 23.77it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.1490e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.2858e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.7750e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.1490e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.2858e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.7750e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 86%|████████▌ | 403/469 [00:17<00:02, 23.96it/s][A[A[A


 87%|████████▋ | 406/469 [00:17<00:02, 24.03it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.0613e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.4271e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(8.9258e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.0613e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.4271e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(8.9258e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 87%|████████▋ | 409/469 [00:17<00:02, 24.23it/s][A[A[A


torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.4420e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.6097e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(6.0321e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.4420e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.6097e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(6.0321e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOS




 88%|████████▊ | 412/469 [00:18<00:02, 23.19it/s][A[A[A


 88%|████████▊ | 415/469 [00:18<00:02, 23.54it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.8354e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.1847e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(6.3439e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.8354e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.1847e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(6.3439e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 89%|████████▉ | 418/469 [00:18<00:02, 23.53it/s][A[A[A


 90%|████████▉ | 421/469 [00:18<00:01, 24.64it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.8781e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.3984e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(6.6122e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.8781e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.3984e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(6.6122e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 90%|█████████ | 424/469 [00:18<00:01, 24.28it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(3.4995e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.6828e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(8.3057e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(3.4995e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.6828e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(8.3057e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 91%|█████████ | 427/469 [00:18<00:01, 23.80it/s][A[A[A


 92%|█████████▏| 430/469 [00:18<00:01, 23.04it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(3.5145e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.4452e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.1471e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(3.5145e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.4452e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.1471e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 92%|█████████▏| 433/469 [00:19<00:01, 21.57it/s][A[A[A

torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.3347e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.3937e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.4678e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.3347e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.3937e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.4678e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.3076e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.2091e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.9318e-06, device='cuda:0', 




 93%|█████████▎| 436/469 [00:19<00:01, 21.71it/s][A[A[A


 94%|█████████▎| 439/469 [00:19<00:01, 23.19it/s][A[A[A

[tensor(2.1948e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.3708e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(6.4503e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.1948e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.3708e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(6.4503e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.1222e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.1771e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(6.2365e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.1222e-0




 94%|█████████▍| 442/469 [00:19<00:01, 22.47it/s][A[A[A

torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.2857e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.4254e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.7017e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.2857e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.4254e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.7017e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.4713e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.6991e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(4.8664e-06, device='cuda:0', grad_fn=<MeanBackward0




 95%|█████████▍| 445/469 [00:19<00:01, 23.07it/s][A[A[A


 96%|█████████▌| 448/469 [00:19<00:00, 23.22it/s][A[A[A


torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.3550e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.5176e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.0188e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.3550e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.5176e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.0188e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.1496e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.7259e-06, device='cuda:0', grad




 96%|█████████▌| 451/469 [00:19<00:00, 22.19it/s][A[A[A

[tensor(2.5245e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.4294e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.8302e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.5245e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.4294e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.8302e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.5224e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.2448e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.1149e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.5224e-0




 97%|█████████▋| 454/469 [00:19<00:00, 22.31it/s][A[A[A


 97%|█████████▋| 457/469 [00:20<00:00, 22.33it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.5187e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.1656e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(4.4285e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.5187e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.1656e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(4.4285e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 98%|█████████▊| 460/469 [00:20<00:00, 21.97it/s][A[A[A

LOSS _ GROUPS AFTER
[tensor(2.5673e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.2351e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(4.7813e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.4487e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.3738e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(6.1225e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.4487e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.3738e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(6.1225e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.S




 99%|█████████▉| 464/469 [00:20<00:00, 24.35it/s][A[A[A


100%|█████████▉| 468/469 [00:20<00:00, 26.35it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.7263e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.7379e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.9235e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.7263e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.7379e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.9235e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size(

100%|██████████| 469/469 [00:20<00:00, 22.75it/s]


INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.4530e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.3466e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.0654e-05, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.4530e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.3466e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.0654e-05, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




  0%|          | 0/469 [00:00<?, ?it/s][A[A[A

INSIDE AT_LOSS
torch.Size([16, 16, 32, 32])
torch.Size([16, 32, 32, 32])
torch.Size([16, 1024])
torch.Size([16, 1024])
INSIDE AT_LOSS
torch.Size([16, 32, 16, 16])
torch.Size([16, 64, 16, 16])
torch.Size([16, 256])
torch.Size([16, 256])
INSIDE AT_LOSS
torch.Size([16, 64, 8, 8])
torch.Size([16, 128, 8, 8])
torch.Size([16, 64])
torch.Size([16, 64])
LOSS _ GROUPS BEFORE
[tensor(2.4820e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.0063e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.8565e-05, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.4820e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.0063e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.8565e-05, device='cuda:0', grad_fn=<SumBackward0>)]
{'channel': 0, 'classes': 10, 'depth': 10, 'width': 1, 'dataset': 'FashionMNIST', 'dataroot': '.', 'dtype': 'float', 'nthread': 4, 'teacher_id': 'resnet_52_2_teacher', 'batch_size': 128, 'lr': 0.1, 'epochs': 3, 'weight_decay': 0.0005, 'ep




  0%|          | 1/469 [00:00<02:28,  3.15it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.1217e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.7440e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(6.1471e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.1217e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.7440e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(6.1471e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




  1%|          | 4/469 [00:00<01:49,  4.23it/s][A[A[A


  1%|▏         | 7/469 [00:00<01:22,  5.62it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.1413e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.2930e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.7054e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.1413e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.2930e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.7054e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




  2%|▏         | 10/469 [00:00<01:02,  7.30it/s][A[A[A


  3%|▎         | 13/469 [00:00<00:49,  9.19it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.0772e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.4142e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(4.7975e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.0772e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.4142e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(4.7975e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




  3%|▎         | 16/469 [00:00<00:39, 11.42it/s][A[A[A


  4%|▍         | 19/469 [00:01<00:33, 13.48it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.0628e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.3883e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.2281e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.0628e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.3883e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.2281e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.0377e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.4138e-06, device=




  5%|▍         | 22/469 [00:01<00:29, 15.30it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.1117e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.4183e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(6.1277e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.1117e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.4183e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(6.1277e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




  5%|▌         | 25/469 [00:01<00:26, 16.85it/s][A[A[A


  6%|▌         | 28/469 [00:01<00:23, 18.54it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(1.3902e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.0785e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(4.1747e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(1.3902e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.0785e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(4.1747e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




  7%|▋         | 31/469 [00:01<00:21, 20.08it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(1.8134e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.5969e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.3396e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(1.8134e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.5969e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.3396e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




  7%|▋         | 34/469 [00:01<00:21, 20.66it/s][A[A[A


  8%|▊         | 37/469 [00:01<00:20, 21.55it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(3.5072e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.2713e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(4.5846e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(3.5072e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.2713e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(4.5846e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




  9%|▊         | 40/469 [00:01<00:19, 21.91it/s][A[A[A


  9%|▉         | 43/469 [00:02<00:19, 22.09it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.4252e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.6564e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.3695e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.4252e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.6564e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.3695e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 10%|▉         | 46/469 [00:02<00:18, 22.55it/s][A[A[A

LOSS _ GROUPS AFTER
[tensor(2.2593e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.4947e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.4601e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.2692e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.3941e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.5461e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.2692e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.3941e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.5461e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.S




 10%|█         | 49/469 [00:02<00:18, 22.62it/s][A[A[A


 11%|█         | 52/469 [00:02<00:18, 23.16it/s][A[A[A

[tensor(2.8833e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.3117e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(4.1224e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.7963e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.4377e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.4176e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.7963e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.4377e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.4176e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32




 12%|█▏        | 55/469 [00:02<00:17, 23.50it/s][A[A[A


 12%|█▏        | 58/469 [00:02<00:17, 23.32it/s][A[A[A


3
LOSS _ GROUPS AFTER
[tensor(2.6442e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.5817e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.0758e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.4859e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.4786e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.8653e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.4859e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.4786e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.8653e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torc




 13%|█▎        | 61/469 [00:02<00:18, 22.00it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.0407e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.5264e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(6.4760e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.0407e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.5264e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(6.4760e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 14%|█▎        | 64/469 [00:03<00:18, 21.38it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(1.9478e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.4507e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.6273e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(1.9478e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.4507e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.6273e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 14%|█▍        | 67/469 [00:03<00:18, 21.86it/s][A[A[A


 15%|█▍        | 70/469 [00:03<00:18, 21.81it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(1.4799e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.2332e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.2163e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(1.4799e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.2332e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.2163e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 16%|█▌        | 73/469 [00:03<00:17, 22.47it/s][A[A[A


 16%|█▌        | 76/469 [00:03<00:17, 22.52it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(1.5215e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.5880e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.0866e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(1.5215e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.5880e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.0866e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 17%|█▋        | 79/469 [00:03<00:17, 22.24it/s][A[A[A

[tensor(1.4628e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.4422e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(4.8381e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(1.4191e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.4035e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.3053e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(1.4191e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.4035e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.3053e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32




 17%|█▋        | 82/469 [00:03<00:17, 22.29it/s][A[A[A


 18%|█▊        | 85/469 [00:04<00:17, 21.94it/s][A[A[A

torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(1.4834e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.6948e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(4.8588e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(1.4834e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.6948e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(4.8588e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(1.4379e-07, device='cuda:0', grad_fn=<MeanBackward0>), ten




 19%|█▉        | 88/469 [00:04<00:17, 21.77it/s][A[A[A


 19%|█▉        | 91/469 [00:04<00:16, 22.92it/s]

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(1.5276e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.7102e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.1663e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(1.5276e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.7102e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.1663e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch

[A[A[A


 20%|██        | 94/469 [00:04<00:16, 22.65it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.0740e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.2736e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.4796e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.0740e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.2736e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.4796e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 21%|██        | 97/469 [00:04<00:16, 22.31it/s][A[A[A


 21%|██▏       | 100/469 [00:04<00:16, 22.81it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.4186e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.1847e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(4.7289e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.4186e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.1847e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(4.7289e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 22%|██▏       | 103/469 [00:04<00:15, 23.04it/s][A[A[A

[tensor(1.6051e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.6487e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.4291e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(1.6808e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.4391e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.2233e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(1.6808e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.4391e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.2233e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32




 23%|██▎       | 106/469 [00:04<00:16, 22.14it/s][A[A[A


 23%|██▎       | 109/469 [00:05<00:15, 22.74it/s][A[A[A

torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.1499e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.1485e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.9394e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.1499e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.1485e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.9394e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256




 24%|██▍       | 112/469 [00:05<00:15, 22.58it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.7487e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(2.0366e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(4.8648e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.7487e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(2.0366e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(4.8648e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 25%|██▍       | 115/469 [00:05<00:15, 22.65it/s][A[A[A


 25%|██▌       | 118/469 [00:05<00:15, 22.76it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.7418e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.6038e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(4.6691e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.7418e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.6038e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(4.6691e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 26%|██▌       | 121/469 [00:05<00:14, 23.21it/s][A[A[A


 26%|██▋       | 124/469 [00:05<00:14, 23.56it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.1896e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.4453e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.6044e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.1896e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.4453e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.6044e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 27%|██▋       | 127/469 [00:05<00:14, 23.24it/s][A[A[A


 28%|██▊       | 130/469 [00:05<00:14, 23.66it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(1.8224e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.1189e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(6.0360e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(1.8224e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.1189e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(6.0360e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 28%|██▊       | 133/469 [00:06<00:14, 23.37it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(1.9239e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.7695e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(6.0952e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(1.9239e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.7695e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(6.0952e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 29%|██▉       | 136/469 [00:06<00:14, 23.00it/s][A[A[A


 30%|██▉       | 139/469 [00:06<00:13, 23.59it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(2.1172e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.6462e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.1897e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(2.1172e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.6462e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.1897e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 30%|███       | 142/469 [00:06<00:13, 23.93it/s][A[A[A


 31%|███       | 145/469 [00:06<00:13, 24.63it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(1.8918e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.6108e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(6.1388e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(1.8918e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.6108e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(6.1388e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch




 32%|███▏      | 148/469 [00:06<00:13, 23.78it/s][A[A[A

[tensor(1.6422e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.2813e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.2561e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(1.6422e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.2813e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.2561e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(1.5203e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.1031e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.0023e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(1.5203e-0




 32%|███▏      | 151/469 [00:06<00:13, 23.43it/s][A[A[A


 33%|███▎      | 154/469 [00:06<00:13, 23.56it/s][A[A[A

INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch.Size([128, 256])
INSIDE AT_LOSS
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 8, 8])
torch.Size([128, 64])
torch.Size([128, 64])
LOSS _ GROUPS BEFORE
[tensor(1.6050e-07, device='cuda:0', grad_fn=<MeanBackward0>), tensor(1.8373e-06, device='cuda:0', grad_fn=<MeanBackward0>), tensor(5.9982e-06, device='cuda:0', grad_fn=<MeanBackward0>)]
3
LOSS _ GROUPS AFTER
[tensor(1.6050e-07, device='cuda:0', grad_fn=<SumBackward0>), tensor(1.8373e-06, device='cuda:0', grad_fn=<SumBackward0>), tensor(5.9982e-06, device='cuda:0', grad_fn=<SumBackward0>)]
INSIDE AT_LOSS
torch.Size([128, 16, 32, 32])
torch.Size([128, 32, 32, 32])
torch.Size([128, 1024])
torch.Size([128, 1024])
INSIDE AT_LOSS
torch.Size([128, 32, 16, 16])
torch.Size([128, 64, 16, 16])
torch.Size([128, 256])
torch

KeyboardInterrupt: ignored

In [None]:
student_args = EasyDict({
    "channel": 0,
    "classes": 10,
    "depth": 10,
    "width": 1,
    "dataset" : 'FashionMNIST',
    "dataroot" : '.',
    "dtype" : 'float',
    "nthread" : 4,
    "teacher_id" : '',
    "batch_size" : 128,
    "lr" : 0.1,
    "epochs" : 10,
    "weight_decay" : 0.0005,
    "epoch_step" : '[60, 120, 160]',
    "lr_decay_ratio" : 0.2,
    "resume" : '',
    "randomcrop_pad" : 4,
    "temperature" : 4,
    "alpha" : 0,
    "beta" : 0,
    "gpu_id" : '0',
    "save" : '/content/logs/at_10_1'
})

print('parsed options:', student_args)
student = model_parameters(student_args)




  0%|          | 0/469 [00:00<?, ?it/s]

parsed options: {'channel': 0, 'classes': 10, 'depth': 10, 'width': 1, 'dataset': 'FashionMNIST', 'dataroot': '.', 'dtype': 'float', 'nthread': 4, 'teacher_id': '', 'batch_size': 128, 'lr': 0.1, 'epochs': 10, 'weight_decay': 0.0005, 'epoch_step': '[60, 120, 160]', 'lr_decay_ratio': 0.2, 'resume': '', 'randomcrop_pad': 4, 'temperature': 4, 'alpha': 0, 'beta': 0, 'gpu_id': '0', 'save': '/content/logs/at_10_1'}
creating optimizer with lr =  0.1

Parameters:
0     conv0                             (16, 1, 3, 3)           torch.cuda.FloatTensor True
1     group0.block0.conv0               (16, 16, 3, 3)          torch.cuda.FloatTensor True
2     group0.block0.conv1               (16, 16, 3, 3)          torch.cuda.FloatTensor True
3     group0.block0.bn0.weight          (16,)                   torch.cuda.FloatTensor True
4     group0.block0.bn0.bias            (16,)                   torch.cuda.FloatTensor True
5     group0.block0.bn0.running_mean    (16,)                   torch.cuda.FloatT

100%|██████████| 469/469 [00:13<00:00, 34.65it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

{'channel': 0, 'classes': 10, 'depth': 10, 'width': 1, 'dataset': 'FashionMNIST', 'dataroot': '.', 'dtype': 'float', 'nthread': 4, 'teacher_id': '', 'batch_size': 128, 'lr': 0.1, 'epochs': 10, 'weight_decay': 0.0005, 'epoch_step': '[60, 120, 160]', 'lr_decay_ratio': 0.2, 'resume': '', 'randomcrop_pad': 4, 'temperature': 4, 'alpha': 0, 'beta': 0, 'gpu_id': '0', 'save': '/content/logs/at_10_1', 'train_loss': 0.7520362459011934, 'train_acc': 72.68166666666667, 'test_loss': 1.4189174643045739, 'test_acc': 60.11, 'epoch': 1, 'num_classes': 10, 'n_parameters': 78042, 'train_time': 13.536037921905518, 'test_time': 1.1514101028442383, 'at_losses': [(nan, nan), (nan, nan), (nan, nan)]}
None
==> id: /content/logs/at_10_1 (1/10), test_acc: [91m60.11[0m


100%|██████████| 469/469 [00:13<00:00, 34.62it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

{'channel': 0, 'classes': 10, 'depth': 10, 'width': 1, 'dataset': 'FashionMNIST', 'dataroot': '.', 'dtype': 'float', 'nthread': 4, 'teacher_id': '', 'batch_size': 128, 'lr': 0.1, 'epochs': 10, 'weight_decay': 0.0005, 'epoch_step': '[60, 120, 160]', 'lr_decay_ratio': 0.2, 'resume': '', 'randomcrop_pad': 4, 'temperature': 4, 'alpha': 0, 'beta': 0, 'gpu_id': '0', 'save': '/content/logs/at_10_1', 'train_loss': 0.4356433651975985, 'train_acc': 84.16833333333334, 'test_loss': 1.672426276569125, 'test_acc': 49.660000000000004, 'epoch': 2, 'num_classes': 10, 'n_parameters': 78042, 'train_time': 13.54922342300415, 'test_time': 1.1172783374786377, 'at_losses': [(nan, nan), (nan, nan), (nan, nan)]}
None
==> id: /content/logs/at_10_1 (2/10), test_acc: [91m49.66[0m


100%|██████████| 469/469 [00:13<00:00, 34.74it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

{'channel': 0, 'classes': 10, 'depth': 10, 'width': 1, 'dataset': 'FashionMNIST', 'dataroot': '.', 'dtype': 'float', 'nthread': 4, 'teacher_id': '', 'batch_size': 128, 'lr': 0.1, 'epochs': 10, 'weight_decay': 0.0005, 'epoch_step': '[60, 120, 160]', 'lr_decay_ratio': 0.2, 'resume': '', 'randomcrop_pad': 4, 'temperature': 4, 'alpha': 0, 'beta': 0, 'gpu_id': '0', 'save': '/content/logs/at_10_1', 'train_loss': 0.3798833021731264, 'train_acc': 86.21166666666666, 'test_loss': 1.7928225239620932, 'test_acc': 65.67, 'epoch': 3, 'num_classes': 10, 'n_parameters': 78042, 'train_time': 13.503987550735474, 'test_time': 1.1379954814910889, 'at_losses': [(nan, nan), (nan, nan), (nan, nan)]}
None
==> id: /content/logs/at_10_1 (3/10), test_acc: [91m65.67[0m


100%|██████████| 469/469 [00:13<00:00, 34.67it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

{'channel': 0, 'classes': 10, 'depth': 10, 'width': 1, 'dataset': 'FashionMNIST', 'dataroot': '.', 'dtype': 'float', 'nthread': 4, 'teacher_id': '', 'batch_size': 128, 'lr': 0.1, 'epochs': 10, 'weight_decay': 0.0005, 'epoch_step': '[60, 120, 160]', 'lr_decay_ratio': 0.2, 'resume': '', 'randomcrop_pad': 4, 'temperature': 4, 'alpha': 0, 'beta': 0, 'gpu_id': '0', 'save': '/content/logs/at_10_1', 'train_loss': 0.3513042730118419, 'train_acc': 87.27333333333334, 'test_loss': 0.6467514196528661, 'test_acc': 76.03, 'epoch': 4, 'num_classes': 10, 'n_parameters': 78042, 'train_time': 13.531157732009888, 'test_time': 1.1443262100219727, 'at_losses': [(nan, nan), (nan, nan), (nan, nan)]}
None
==> id: /content/logs/at_10_1 (4/10), test_acc: [91m76.03[0m


100%|██████████| 469/469 [00:13<00:00, 34.30it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

{'channel': 0, 'classes': 10, 'depth': 10, 'width': 1, 'dataset': 'FashionMNIST', 'dataroot': '.', 'dtype': 'float', 'nthread': 4, 'teacher_id': '', 'batch_size': 128, 'lr': 0.1, 'epochs': 10, 'weight_decay': 0.0005, 'epoch_step': '[60, 120, 160]', 'lr_decay_ratio': 0.2, 'resume': '', 'randomcrop_pad': 4, 'temperature': 4, 'alpha': 0, 'beta': 0, 'gpu_id': '0', 'save': '/content/logs/at_10_1', 'train_loss': 0.33300570771892474, 'train_acc': 88.07666666666667, 'test_loss': 0.7118821566617943, 'test_acc': 74.95, 'epoch': 5, 'num_classes': 10, 'n_parameters': 78042, 'train_time': 13.675455331802368, 'test_time': 1.1399693489074707, 'at_losses': [(nan, nan), (nan, nan), (nan, nan)]}
None
==> id: /content/logs/at_10_1 (5/10), test_acc: [91m74.95[0m


100%|██████████| 469/469 [00:13<00:00, 34.77it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

{'channel': 0, 'classes': 10, 'depth': 10, 'width': 1, 'dataset': 'FashionMNIST', 'dataroot': '.', 'dtype': 'float', 'nthread': 4, 'teacher_id': '', 'batch_size': 128, 'lr': 0.1, 'epochs': 10, 'weight_decay': 0.0005, 'epoch_step': '[60, 120, 160]', 'lr_decay_ratio': 0.2, 'resume': '', 'randomcrop_pad': 4, 'temperature': 4, 'alpha': 0, 'beta': 0, 'gpu_id': '0', 'save': '/content/logs/at_10_1', 'train_loss': 0.3229086582086231, 'train_acc': 88.325, 'test_loss': 0.8934873342514039, 'test_acc': 71.8, 'epoch': 6, 'num_classes': 10, 'n_parameters': 78042, 'train_time': 13.48927927017212, 'test_time': 1.1384053230285645, 'at_losses': [(nan, nan), (nan, nan), (nan, nan)]}
None
==> id: /content/logs/at_10_1 (6/10), test_acc: [91m71.80[0m


100%|██████████| 469/469 [00:13<00:00, 34.77it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

{'channel': 0, 'classes': 10, 'depth': 10, 'width': 1, 'dataset': 'FashionMNIST', 'dataroot': '.', 'dtype': 'float', 'nthread': 4, 'teacher_id': '', 'batch_size': 128, 'lr': 0.1, 'epochs': 10, 'weight_decay': 0.0005, 'epoch_step': '[60, 120, 160]', 'lr_decay_ratio': 0.2, 'resume': '', 'randomcrop_pad': 4, 'temperature': 4, 'alpha': 0, 'beta': 0, 'gpu_id': '0', 'save': '/content/logs/at_10_1', 'train_loss': 0.3109663669297944, 'train_acc': 88.64, 'test_loss': 1.0020114651209187, 'test_acc': 68.48, 'epoch': 7, 'num_classes': 10, 'n_parameters': 78042, 'train_time': 13.490202188491821, 'test_time': 1.1207621097564697, 'at_losses': [(nan, nan), (nan, nan), (nan, nan)]}
None
==> id: /content/logs/at_10_1 (7/10), test_acc: [91m68.48[0m


100%|██████████| 469/469 [00:13<00:00, 34.55it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

{'channel': 0, 'classes': 10, 'depth': 10, 'width': 1, 'dataset': 'FashionMNIST', 'dataroot': '.', 'dtype': 'float', 'nthread': 4, 'teacher_id': '', 'batch_size': 128, 'lr': 0.1, 'epochs': 10, 'weight_decay': 0.0005, 'epoch_step': '[60, 120, 160]', 'lr_decay_ratio': 0.2, 'resume': '', 'randomcrop_pad': 4, 'temperature': 4, 'alpha': 0, 'beta': 0, 'gpu_id': '0', 'save': '/content/logs/at_10_1', 'train_loss': 0.304773916472504, 'train_acc': 89.15333333333334, 'test_loss': 1.5166614870481854, 'test_acc': 60.28, 'epoch': 8, 'num_classes': 10, 'n_parameters': 78042, 'train_time': 13.576308965682983, 'test_time': 1.1312522888183594, 'at_losses': [(nan, nan), (nan, nan), (nan, nan)]}
None
==> id: /content/logs/at_10_1 (8/10), test_acc: [91m60.28[0m


 19%|█▉        | 88/469 [00:02<00:11, 33.11it/s]