# Fisher Exponential

Now that we have made several advances in the FisherFIFO algorithm, we can try to carry some of these improvements over to FisherExponential as well.

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import time
import math
import os
import json

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from skopt import gp_minimize

from scipy import stats

In [2]:
# list every file available under the Kaggle input directory
for root_dir, _subdirs, file_names in os.walk('/kaggle/input'):
    for file_name in file_names:
        print(os.path.join(root_dir, file_name))

In [3]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader

import torchvision
import torchvision.datasets as datasets

In [4]:
def get_device():
    """Return the torch device to train on, preferring CUDA when it is available.

    Prints the GPU name and properties when CUDA is used, otherwise prints the CPU device.
    """
    if not torch.cuda.is_available():
        cpu = torch.device('cpu')
        print(cpu)
        return cpu

    gpu = torch.device('cuda')
    print( torch.cuda.get_device_name(gpu) )
    print( torch.cuda.get_device_properties(gpu) )
    return gpu

In [5]:
!pip install torchsummary
import torchsummary

Collecting torchsummary
  Downloading torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1
[0m

In [6]:
class cfg:
    """Global experiment configuration (3-channel 32x32 images, 100 classes as in CIFAR100)."""
    img_size = (32, 32)
    img_channels = 3
    ## flattened input size used by the fully-connected network; derived from the image shape
    ## so it cannot drift out of sync with `img_size` / `img_channels`
    n_features = img_channels * img_size[0] * img_size[1]
    n_classes = 100  ## we have 100 classes in CIFAR100
    
    # device = torch.device('cpu')
    device = get_device()
    
    ## NOTE(review): `max_loss` is not referenced by any of the visible cells
    max_loss = 20.0

Tesla P100-PCIE-16GB
_CudaDeviceProperties(name='Tesla P100-PCIE-16GB', major=6, minor=0, total_memory=16280MB, multi_processor_count=56)


# create the dataset

In [7]:
def generate_dataset_mnist(batch_size):
    """Build MNIST DataLoaders: a shuffled training loader and an ordered test loader.

    Images are converted with `ToTensor`; the data is downloaded to ./data if missing.

    Returns:
        (train_dataloader, test_dataloader)
    """
    print(f'generating MNIST data with {cfg.n_classes} classes')

    to_tensor = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

    splits = [datasets.MNIST(root='./data', train=is_train, download=True, transform=to_tensor)
              for is_train in (True, False)]

    train_loader = DataLoader(dataset=splits[0], batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(dataset=splits[1], batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

In [8]:
def generate_dataset_cifar10(batch_size):
    """Build CIFAR10 DataLoaders: a shuffled training loader and an ordered test loader.

    Images are converted with `ToTensor`; the data is downloaded to ./data if missing.

    Returns:
        (train_dataloader, test_dataloader)
    """
    print(f'generating CIFAR10 data with {cfg.n_classes} classes')

    to_tensor = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

    splits = [datasets.CIFAR10(root='./data', train=is_train, download=True, transform=to_tensor)
              for is_train in (True, False)]

    train_loader = DataLoader(dataset=splits[0], batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(dataset=splits[1], batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

In [9]:
def generate_dataset_cifar100(batch_size):
    """Build CIFAR100 DataLoaders: a shuffled training loader and an ordered test loader.

    Images are converted with `ToTensor`; the data is downloaded to ./data if missing.

    Returns:
        (train_dataloader, test_dataloader)
    """
    print(f'generating CIFAR100 data with {cfg.n_classes} classes')

    to_tensor = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

    splits = [datasets.CIFAR100(root='./data', train=is_train, download=True, transform=to_tensor)
              for is_train in (True, False)]

    train_loader = DataLoader(dataset=splits[0], batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(dataset=splits[1], batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

## Declaring the network architectures

In [10]:
def get_default_network(c=16, device=cfg.device):
    """Build a fully-connected classifier (3 hidden ReLU layers of width `c`) on flattened images.

    Args:
        c: width of every hidden Linear layer.
        device: accepted for interface compatibility; not used here.

    Returns:
        The `nn.Sequential` network (a summary is printed as a side effect).
    """
    ## `cfg.n_features` may be missing from the config; fall back to the flattened image size
    ## instead of raising AttributeError
    n_features = getattr(cfg, 'n_features', cfg.img_channels * cfg.img_size[0] * cfg.img_size[1])

    net = nn.Sequential(
        nn.Flatten(),
        nn.Linear(in_features=n_features, out_features=c),
        nn.ReLU(),
        nn.Linear(in_features=c, out_features=c),
        nn.ReLU(),
        nn.Linear(in_features=c, out_features=c),
        nn.ReLU(),
        nn.Linear(in_features=c, out_features=cfg.n_classes)
    )
    
    torchsummary.summary(net, input_size=[[n_features]], device='cpu')
    
    return net

In [11]:
def get_cnn_network(in_channels=cfg.img_channels, c=16, p_drop=0.1, device=cfg.device):
    """Build a small CNN (3 conv layers, 3 linear layers) and print its summary.

    Args:
        in_channels: channels of the input images; used by the first convolution and the summary.
        c: base width; conv layers have c, 2c, 4c channels and the hidden linear layers 8c, 4c units.
        p_drop: dropout probability for every Dropout / Dropout2d layer.
        device: accepted for interface compatibility; not used here.

    Returns:
        The `nn.Sequential` network.
    """
    ## the stride-2 conv and the two max-pools each halve the spatial size (8x overall);
    ## the `// 8` is exact when both image dimensions are multiples of 8
    img_flat_size = (4 * c * (cfg.img_size[0] // 8) * (cfg.img_size[1] // 8) )

    net = nn.Sequential(
        ## honour `in_channels` rather than hard-coding 3 input channels
        nn.Conv2d(in_channels=in_channels, out_channels=c, kernel_size=5, stride=2, padding=2),
        nn.ReLU(),

        nn.Conv2d(in_channels=c, out_channels=(2 * c), kernel_size=3, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),
        nn.Dropout2d(p=p_drop),
        
        nn.Conv2d(in_channels=(2 * c), out_channels=(4 * c), kernel_size=3, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),
        nn.Dropout2d(p=p_drop),
        
        nn.Flatten(),
        
        nn.Linear(in_features=img_flat_size, out_features=(8 * c) ),
        nn.ReLU(),
        nn.Dropout(p=p_drop),
        
        nn.Linear(in_features=(8 * c), out_features=(4 * c) ),
        nn.ReLU(),
        nn.Dropout(p=p_drop),
        
        nn.Linear(in_features=(4 * c), out_features=cfg.n_classes)
    )
    
    torchsummary.summary(net, input_size=[[in_channels, *cfg.img_size]], device='cpu')
    
    return net

In [12]:
def get_cnn_network_v2(in_channels=cfg.img_channels, p_drop=0.1, device=cfg.device):
    """Build a 4-conv CNN (96-80-96-64 channels, 5x5 kernels) with one 256-unit hidden layer.

    Args:
        in_channels: channels of the input images; used by the first convolution and the summary.
        p_drop: dropout probability for every Dropout / Dropout2d layer.
        device: accepted for interface compatibility; not used here.

    Returns:
        The `nn.Sequential` network.
    """
    ## the padded 5x5 convolutions keep the spatial size; only the two max-pools shrink it (4x overall)
    flat_size = (cfg.img_size[0] // 4) * (cfg.img_size[1] // 4) * 64

    net = nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=96, kernel_size=5, padding=2),
        nn.MaxPool2d(kernel_size=2),
        nn.ReLU(),

        nn.Conv2d(in_channels=96, out_channels=80, kernel_size=5, padding=2),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),
        nn.Dropout2d(p=p_drop),
        
        nn.Conv2d(in_channels=80, out_channels=96, kernel_size=5, padding=2),
        nn.ReLU(),
        nn.Dropout2d(p=p_drop),
        
        nn.Conv2d(in_channels=96, out_channels=64, kernel_size=5, padding=2),
        nn.ReLU(),
        nn.Dropout2d(p=p_drop),
        
        nn.Flatten(),
        
        nn.Linear(in_features=flat_size, out_features=256 ),
        nn.ReLU(),
        nn.Dropout(p=p_drop),
        
        nn.Linear(in_features=256, out_features=cfg.n_classes)
    )
    
    ## summarise with the same channel count the first convolution expects
    torchsummary.summary(net, input_size=[[in_channels, *cfg.img_size]], device='cpu')
    
    return net

In [13]:
def get_resnet18(device=cfg.device):
    """Return a torchvision ResNet-18 with `cfg.n_classes` outputs and print its summary.

    `device` is accepted for interface compatibility but is not used here.
    """
    summary_shape = [[cfg.img_channels, *cfg.img_size]]
    model = torchvision.models.resnet18(num_classes=cfg.n_classes)
    torchsummary.summary(model, input_size=summary_shape, device='cpu')
    return model

# Object for recording the training metrics

In [14]:
class Metrics():
    """Stores (value, step, elapsed-seconds) records under string keys.

    Args:
        value_round: if not None, values are rounded to this many decimals when added.
        time_round: if not None, elapsed times are rounded to this many decimals.
    """
    def __init__(self, value_round=None, time_round=None):
        self.metrics_dict = {}
        self.set_initial_time()
        self.val_round = value_round
        self.time_round = time_round

    def set_initial_time(self):
        """Reset the reference instant used for elapsed-time stamps."""
        self.init_time = time.time()

    def get_time(self):
        """Seconds elapsed since the last `set_initial_time` call."""
        return time.time() - self.init_time

    def add(self, key, value, step=None):
        """Append `value` under `key`; a missing `step` is recorded as NaN."""
        series = self.metrics_dict.setdefault(key, [])

        elapsed = self.get_time()
        if self.time_round is not None:
            elapsed = round(elapsed, ndigits=self.time_round)

        if self.val_round is not None:
            value = round(value, ndigits=self.val_round)

        series.append( (value, np.nan if step is None else step, elapsed) )

    def add_(self, dict_, step=None):
        """Call `add` for every (key, value) pair of `dict_`, all with the same step."""
        for name, val in dict_.items():
            self.add(name, val, step)

    def get(self, key, get_step=False, get_time=False):
        """Return the (steps, values, times) lists recorded under `key`.

        `get_step` and `get_time` are accepted but ignored: all three lists are always
        returned. Raises KeyError for an unknown key.
        """
        values, steps, times = map(list, zip(*self.metrics_dict[key]))
        return steps, values, times

# Fisher Information calculation objects

In [15]:
class FisherExp():
    """Block-diagonal inverse-Fisher gradient preconditioner with exponential averaging.

    Every parameter's flattened gradient is cut into partitions of `partition_size`
    entries (see `PartitionerFisher`), and each partition owns a
    `partition_size x partition_size` inverse stored in `self.fisher_inv` (initialised
    to the identity). On every `step()` a random contiguous run of partitions per
    parameter is selected, and three things happen:

      1. Each selected inverse is updated with Sherman-Morrison to the inverse of
         `beta * F + (1 - beta) * g g^T`.
      2. The updated inverse is clipped to Frobenius norm <= 10 * sqrt(partition_size).
      3. The corresponding gradient slice in `p.grad` is replaced by `F^{-1} g`.

    Args:
        named_params: iterable of (name, parameter) pairs, e.g. `net.named_parameters()`.
        beta: exponential-averaging factor of the Fisher estimate.
        partition_size: gradient entries per partition (an int; it sizes the buffers).
        block_updates: maximum number of partitions updated per parameter per step.
    """
    def __init__(self,
                 named_params,
                 beta,
                 partition_size,
                 block_updates):
        
        self.beta = beta
        self.partition_size = partition_size
        self.block_updates = block_updates
        
        ## the identity of this size has Frobenius norm sqrt(partition_size); inverses are clipped at 10x that
        self.max_inv_norm = 10 * math.sqrt(partition_size)
        
        named_params = list(named_params)
        
        ## entries: (param, its partitioner, index of its first partition in self.fisher_inv,
        ## running total of block updates before it — the last item is not read by the methods below)
        self.partition_fisher_list = []
        total_partitions, total_block_upd = 0, 0
        for n, p in named_params:
            part_fisher = PartitionerFisher(param = p,
                                            name = n,
                                            partition_size = partition_size,
                                            block_updates = block_updates)
            
            self.partition_fisher_list.append( (p, part_fisher, total_partitions, total_block_upd) )
            
            total_partitions += part_fisher.num_part
            total_block_upd += part_fisher.block_updates
            
        self.num_part = total_partitions
        self.total_block_updates = total_block_upd
        
        print(f'total partitions: {self.num_part} - effective block updates: {self.total_block_updates}')
        
        ## pre-allocated buffer holding the selected (zero-padded) gradient blocks, shape (blocks, partition_size, 1)
        self.g = torch.zeros(size=[self.total_block_updates, partition_size, 1], dtype=torch.float, device=cfg.device)
        
        ## pre-allocated storage for the inverses of all partitions
        self.fisher_inv = torch.zeros(size=[self.num_part, partition_size, partition_size], dtype=torch.float, device=cfg.device)        
    
        print('initializing inverses...')
        ## FisherExponential starts every inverse at the identity; for a short (last) partition only the
        ## leading n x n corner is set, the padded rows/columns stay zero
        i = 0
        for _, part_fisher, _, _ in self.partition_fisher_list:
            for _, start, end in part_fisher.ind_fisher_list:

                if i == 0 or ( (i + 1) % 10000 ) == 0 or i == (self.num_part - 1):
                    print(f'partition {i+1}/{self.num_part}')

                n = end - start
                self.fisher_inv[i, :n, :n] = torch.eye(n = n, dtype=torch.float, device=cfg.device)

                i += 1
                

    def get_idx_lists(self):
        """Choose the partitions to update in this step.

        Returns:
            run_enc_list: one tuple per parameter, in `partition_fisher_list` order:
                (first global partition, last global partition (inclusive),
                 start gradient index, end gradient index (exclusive)).
            default_idx: 1-D numpy array of every selected global partition index, in the
                same order as the blocks are laid out in `self.g`.
        """
        run_enc_list = []
        default_idx_list = []
        for _, part_fisher, part_offset, _ in self.partition_fisher_list:
            init_block, end_block, g_init_idx, g_end_idx = part_fisher.get_random_blocks()

            run_enc_list.append( (part_offset + init_block, part_offset + end_block, g_init_idx, g_end_idx) )
            default_idx_list.append( np.arange(start=part_offset + init_block, stop=part_offset + end_block + 1) )
            
        return run_enc_list, np.concatenate(default_idx_list)
    
    
    def read_gradients(self, idx):
        """Copy the selected gradient slices into `self.g`.

        Each parameter's slice is padded with zeros up to a multiple of `partition_size`,
        so every block in `self.g` lines up with one partition.
        """
        self_g_start = 0
        for i, (_, _, g_start, g_end) in enumerate(idx):
            n_grad = g_end - g_start
            self_g_end = self_g_start + n_grad
            
            p, _, _, _ = self.partition_fisher_list[i]
            
            self.g.view(-1)[self_g_start:self_g_end] = p.grad.view(-1)[g_start:g_end]
            
            if (n_grad % self.partition_size) > 0:
                extra_zeros = self.partition_size - (n_grad % self.partition_size)
                self.g.view(-1)[self_g_end:(self_g_end + extra_zeros)] = 0.0
            else:
                extra_zeros = 0

            self_g_start = self_g_end + extra_zeros
            

    def write_gradients(self, idx):
        """Copy `self.g` back into the selected gradient slices, skipping the zero padding."""
        self_g_start = 0
        for i, (_, _, g_start, g_end) in enumerate(idx):
            n_grad = g_end - g_start
            self_g_end = self_g_start + n_grad
            
            p, _, _, _ = self.partition_fisher_list[i]
            p.grad.view(-1)[g_start:g_end] = self.g.view(-1)[self_g_start:self_g_end]

            if (n_grad % self.partition_size) > 0:
                extra_zeros = self.partition_size - (n_grad % self.partition_size)
            else:
                extra_zeros = 0
            
            self_g_start = self_g_end + extra_zeros

    
    def step(self):
        """Update the selected inverses and precondition the matching parts of `p.grad` in place."""
        ## select a contiguous run of partitions per parameter
        run_enc_idx, default_idx = self.get_idx_lists()
        
        ## copy the selected gradient blocks into self.g
        self.read_gradients(run_enc_idx)
        
        ## advanced indexing returns a copy, so stored inverses change only when written back below
        inv = self.fisher_inv[default_idx, ...]

        ## inv(beta*F + (1-beta) g g^T): Sherman-Morrison on inv(beta*F) = inv / beta with g scaled by sqrt(1-beta)
        _, new_inv = self.upd_inverse(g = math.sqrt(1 - self.beta) * self.g, inverse = inv / self.beta)
        
        ## clip every inverse to Frobenius norm <= self.max_inv_norm
        new_inv_norm = torch.sqrt( torch.sum(new_inv**2, dim=[1, 2], keepdim=True) )
        norm_coefs = torch.where( new_inv_norm > self.max_inv_norm, self.max_inv_norm / new_inv_norm, torch.ones_like(new_inv_norm) )
        new_inv = new_inv * norm_coefs

        ## precondition with the updated (clipped) inverses
        self.g = self.modify_grad(self.g, new_inv)
        
        ## store the inverses back
        self.fisher_inv[default_idx, ...] = new_inv
        
        ## write the preconditioned gradients back into the parameters
        self.write_gradients(run_enc_idx)


    def upd_inverse(self, g, inverse, type_='sum'):
        """Sherman-Morrison rank-1 update of a batch of inverses.

        Args:
            g: (B, n, 1) batch of column vectors.
            inverse: (B, n, n) batch of inverses of symmetric matrices A; modified in place.
            type_: 'sum' yields inv(A + g g^T), 'sub' yields inv(A - g g^T).

        Returns:
            (A^{-1} g, updated inverse); the second item is the `inverse` tensor itself.

        Raises:
            ValueError: if `type_` is neither 'sum' nor 'sub' (previously this printed a
                message and silently returned the un-updated inverse).
        """
        if type_ not in ('sum', 'sub'):
            raise ValueError('incorrect rank-1 update type: ' + str(type_))

        f_inv_g = torch.bmm(inverse, g)
        ## (B, n, 1) * (B, 1, n) broadcasts to the outer product A^{-1} g g^T A^{-1}
        ## (valid because A, and hence A^{-1}, is symmetric)
        outer = f_inv_g * torch.transpose(f_inv_g, 1, 2)
        g_f_inv_g = torch.sum(g * f_inv_g, dim=[1, 2], keepdim=True)

        if type_ == 'sum':
            inverse[:] = inverse - outer / (1 + g_f_inv_g)
        else:
            inverse[:] = inverse + outer / (1 - g_f_inv_g)
        
        return f_inv_g, inverse


    def modify_grad(self, g, inverse):
        """Return the batched product `inverse @ g` (the preconditioned gradient blocks)."""
        return torch.bmm(inverse, g)

In [16]:
class PartitionerFisher():
    """Splits one parameter's flattened gradient into fixed-size partitions.

    Args:
        param: the parameter tensor (stored; its `numel()` determines the layout).
        name: parameter name, kept for reference.
        partition_size: entries per partition; None means a single partition covering the parameter.
        block_updates: partitions to update per step; None means all of them, otherwise capped at `num_part`.
    """
    def __init__(self,
                 param,
                 name,
                 partition_size,
                 block_updates):

        self.param = param
        self.name = name

        self.partition_size = param.numel() if partition_size is None else partition_size

        ## every partition holds `partition_size` entries except possibly the last one
        self.param_size = param.numel()
        self.num_part = math.ceil(self.param_size / self.partition_size)

        ## updating fewer than `num_part` partitions per step makes the preconditioner cheaper
        self.block_updates = self.num_part if block_updates is None else min(block_updates, self.num_part)

        print(f'FisherPartitioner: param: {self.param_size} - partition: {self.partition_size} - nº part: {self.num_part} - block updates: {self.block_updates}')

        ## (partition index, start, end) triples, `end` exclusive
        bounds = []
        for idx in range(self.num_part):
            lo = idx * self.partition_size
            bounds.append( (idx, lo, min(lo + self.partition_size, self.param_size)) )
        self.ind_fisher_list = bounds

    def get_random_blocks(self, num_part=None, block_upd=None):
        """Pick a random contiguous run of `block_upd` partitions.

        Returns:
            (first partition, last partition (inclusive), start gradient index, end gradient index (exclusive))
        """
        n_parts = self.num_part if num_part is None else num_part
        n_upd = self.block_updates if block_upd is None else block_upd

        ## a contiguous run maps to one contiguous gradient slice, which is chosen for performance
        first = np.random.choice(n_parts - n_upd + 1)
        last = first + n_upd - 1

        g_start = self.ind_fisher_list[first][1]
        g_end = self.ind_fisher_list[last][2]
        return first, last, g_start, g_end

---

# Utility functions for training

In [17]:
def accuracy_score_tns(y_true, y_pred):
    """Fraction of positions where `y_true` equals `y_pred`, returned as a Python float."""
    hits = (y_true == y_pred).float()
    return hits.mean().cpu().item()

In [18]:
def train_iteration(x, y, net, optim, loss, fisher=None):
    """Run one optimization step on the batch (x, y).

    Gradients are cleared and the loss is back-propagated. The gradients are then
    optionally rewritten by `fisher.step()` before `optim.step()` applies them.

    Returns:
        (loss value, batch accuracy) as Python floats.
    """
    net.train()
    net.zero_grad()

    logits = net(x)
    batch_loss = loss(logits, y)
    batch_loss.backward()

    ## the preconditioner edits p.grad in place, so it must run before the optimizer step
    if fisher is not None:
        fisher.step()
    optim.step()

    predicted = logits.argmax(dim=1).view(-1)
    return batch_loss.item(), accuracy_score_tns( y.view(-1), predicted )

In [19]:
def evaluate(net, dataloader, loss):
    """Evaluate `net` on every batch of `dataloader` without tracking gradients.

    Args:
        net: the model; it is switched to eval mode (and left there).
        dataloader: yields (x, y) batches, which are moved to `cfg.device`.
        loss: callable(y_pred, y) returning a scalar loss, assumed to be a mean over the batch.

    Returns:
        (loss averaged over all samples, accuracy over all samples). Both are NaN when the
        dataloader yields no batches.
    """
    net.eval()

    total_loss, total_count = 0.0, 0
    y_pred_list, y_label_list = [], []
    with torch.no_grad():
        for x, y in dataloader:
            
            x = x.to(cfg.device)
            y = y.to(cfg.device)

            y_pred = net(x)
            batch_n = y.shape[0]
            ## weight each batch-mean loss by its size so that a smaller final batch does not
            ## count as much as a full one
            total_loss += loss(y_pred, y).item() * batch_n
            total_count += batch_n

            y_pred_list.append( y_pred.argmax(dim=1).view(-1) )
            y_label_list.append( y.view(-1) )

    if total_count == 0:
        ## nothing to evaluate: torch.cat would fail on an empty list
        return float('nan'), float('nan')

    y_pred_all = torch.cat(y_pred_list).view(-1)
    y_label_all = torch.cat(y_label_list).view(-1)

    return total_loss / total_count, accuracy_score_tns(y_label_all, y_pred_all)

# training

In [20]:
def train_network_fisher_optimization(batch_size = 32,
                                      lr = 1e-3,
                                      momentum = 0.9,
                                      epochs = 30,
                                      beta = 0.9,
                                      partition_size = 256,
                                      block_updates = 4,
                                      net_params = {'c':16, 'p':0.1},
                                      apply_fisher = True,
                                      # gpu_memory_check = 20,
                                      time_limit_secs = 600,
                                      interval_print = 100):
    """Train a ResNet-18 on CIFAR100 with SGD, optionally preconditioned by `FisherExp`.

    Training stops after `epochs` epochs or once more than `time_limit_secs` seconds have
    elapsed, whichever comes first. A report is produced every `interval_print` batches,
    at the end of each epoch, and when time runs out. Each report evaluates the whole
    test set and prints a progress line.

    Args:
        batch_size: DataLoader batch size; the loss is also scaled by sqrt(batch_size).
        lr, momentum: SGD hyper-parameters.
        epochs: maximum number of passes over the training set.
        beta, partition_size, block_updates: forwarded to `FisherExp` when `apply_fisher` is True.
        net_params: only referenced by the commented-out `get_cnn_network_v2` line; unused otherwise.
        apply_fisher: whether to precondition the gradients with `FisherExp`.
        time_limit_secs: wall-clock training budget in seconds.
        interval_print: number of batches between reports.

    Returns:
        (Metrics with 'train-loss', 'train-acc', 'test-loss', 'test-acc' series,
         the FisherExp object or None).
    """

    ## declare (instantiate) the dataset
    # train_dataloader, test_dataloader = generate_dataset_cifar10(batch_size = batch_size)
    # train_dataloader, test_dataloader = generate_dataset_mnist(batch_size = batch_size)
    train_dataloader, test_dataloader = generate_dataset_cifar100(batch_size = batch_size)

    ## instantiate the network
    # net = get_cnn_network_v2(p_drop = net_params['p']).to(device=cfg.device)
    net = get_resnet18().to(device=cfg.device)
    
    ## optional gradient preconditioner; `train_iteration` calls its `step()` between
    ## `backward()` and the optimizer step
    if apply_fisher:
        fisher_obj = FisherExp(named_params = net.named_parameters(),
                                beta = beta,
                                partition_size = partition_size,
                                block_updates = block_updates)
    else:
        fisher_obj = None

    ## create loss object: we multiply by our constant to stabilize norms
    # cross_entropy = nn.CrossEntropyLoss(reduction='mean') # standard version
    cross_entropy_standard = nn.CrossEntropyLoss(reduction='mean')
    ## batch-mean cross entropy scaled by sqrt(batch_size); the same scaled loss is also used for evaluation
    cross_entropy = lambda y_pred, y: math.sqrt(batch_size) * cross_entropy_standard(y_pred, y)
    
    ## create optimize objects
    optim = torch.optim.SGD(params=net.parameters(), lr=lr, momentum=momentum)

    ## metric values rounded to 3 decimals, timestamps to 2
    default_metrics = Metrics(value_round=3, time_round=2)

    ini_time = time.time()

    ## global batch counter across epochs, used as the metrics step
    step = 0
    ## set once the time budget is exceeded; stops both the batch and the epoch loop
    training_finished = False
    for epc in range(1, epochs + 1):
        
        if training_finished:
            break
        
        print(f'starting epoch: {epc}/{epochs}')

        for nbt, (x, y) in enumerate(train_dataloader):

            if training_finished:
                break

            x = x.to(cfg.device)
            y = y.to(cfg.device)

            train_loss, train_acc = train_iteration(x, y, net, optim, cross_entropy, fisher_obj)
            default_metrics.add_({'train-loss': train_loss, 'train-acc': train_acc}, step=step)
            
            ## check time limit
            t = int(time.time() - ini_time)
            if t > time_limit_secs:
                print('time is up! finishing training')
                training_finished = True

            ## report every `interval_print` batches, at the end of each epoch, and when time runs out
            if ( (nbt + 1) % interval_print ) == 0 or (nbt + 1) == len(train_dataloader) or training_finished:
                ## means over the last `interval_print` recorded training batches (the window may span epochs)
                avg_train_loss = np.mean( default_metrics.get('train-loss')[1][-interval_print:] )
                avg_train_acc = np.mean( default_metrics.get('train-acc')[1][-interval_print:] )
                
                ## full pass over the test set on every report
                test_loss, test_acc = evaluate(net, test_dataloader, cross_entropy)
                default_metrics.add_({'test-loss': test_loss, 'test-acc': test_acc}, step=step)

                m, s = t // 60, t % 60

                print(f'batch: {nbt + 1}/{len(train_dataloader)}', end='')
                print(f' - train loss: {avg_train_loss:.4f} - test loss: {test_loss:.4f}', end='')
                print(f' - train acc: {avg_train_acc:.4f} - test acc: {test_acc:.4f}', end='')
                print(f' - {m}m {s}s')
                
            step += 1

        ## check for GPU memory consumption (reported once per epoch)
        if torch.cuda.is_available():
            mem_alloc_gb = torch.cuda.memory_allocated(cfg.device) / 1024**3
            mem_res_gb = torch.cuda.memory_reserved(cfg.device) / 1024**3
            max_mem_alloc_gb = torch.cuda.max_memory_allocated(cfg.device) / 1024**3
            max_mem_res_gb = torch.cuda.max_memory_reserved(cfg.device) / 1024**3

            print(f'GPU memory used: {mem_alloc_gb:.2f} GB - max: {max_mem_alloc_gb:.2f} GB - memory reserved: {mem_res_gb:.2f} GB - max: {max_mem_res_gb:.2f} GB')

            # torch.cuda.empty_cache()

    return default_metrics, fisher_obj

In [21]:
# net = torchvision.models.resnet18()
# torchsummary.summary(net, input_size=[[3, 64, 64]], device='cpu')

In [22]:
## step of the most recently written results file (None before the first save)
last_step_saved = None

def results_list_to_json(results_list, out_dir='/kaggle/working', step=0):
    """Write all results gathered so far to `<out_dir>/results_step_<step>.json`.

    Only the newest snapshot is kept: after writing, the file from the previous call
    (tracked by the module-level `last_step_saved`) is removed, unless it is the very
    file that was just written.

    Args:
        results_list: list of (Metrics, beta, partition_size, block_updates) tuples.
        out_dir: directory the JSON file is written to.
        step: snapshot index used in the file name.
    """
    global last_step_saved

    json_results = [
        {
            'beta': bt,
            'partition-size': ps,
            'blocks-updates': bu,
            'metrics': metrics.metrics_dict
        }
        for metrics, bt, ps, bu in results_list
    ]

    with open( os.path.join(out_dir, f'results_step_{step}.json'), 'w' ) as fp:
        json.dump(json_results, fp)
    
    ## saving twice with the same step would otherwise delete the file that was just written
    if last_step_saved is not None and last_step_saved != step:
        old_file = os.path.join(out_dir, f'results_step_{last_step_saved}.json')
        if os.path.exists(old_file):
            os.remove(old_file)
    
    last_step_saved = step

In [23]:
def get_min_test_loss(metrics):
    """Return the smallest value recorded under 'test-loss' in a `Metrics` object."""
    (_steps, losses, _times) = metrics.get('test-loss')
    return min(losses)

## Running FisherExp on CIFAR100

In [24]:
results_list = []
ci = 0

beta = 0.9
partition_size = 8
block_updates = 64

n_runs = 10
## run `r` is seeded with `base_seed + r`: the runs differ from each other, but each one is reproducible
base_seed = 0

for run_idx in range(n_runs):

    ## seed numpy (random block selection in PartitionerFisher.get_random_blocks) and torch
    ## (weight initialisation, DataLoader shuffling); cuDNN kernels may still be non-deterministic
    seed = base_seed + run_idx
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print(f'testing - beta: {beta} - part size: {partition_size} - block upd: {block_updates} - combination nº: {ci + 1}')
    ci += 1

    default_metrics, _ = train_network_fisher_optimization(apply_fisher = True,
                                                           beta = beta,
                                                           partition_size = partition_size,
                                                           block_updates = block_updates,
                                                           epochs = 100,
                                                           time_limit_secs = 20 * 60)

    results_list.append( (default_metrics, beta, partition_size, block_updates) )
    ## save a JSON snapshot after every run so partial results survive an interrupted session
    results_list_to_json(results_list, step=ci)

testing - beta: 0.9 - part size: 8 - block upd: 64 - combination nº: 1
generating CIFAR100 data with 100 classes
Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


  0%|          | 0/169001437 [00:00<?, ?it/s]

Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 16, 16]           9,408
       BatchNorm2d-2           [-1, 64, 16, 16]             128
              ReLU-3           [-1, 64, 16, 16]               0
         MaxPool2d-4             [-1, 64, 8, 8]               0
            Conv2d-5             [-1, 64, 8, 8]          36,864
       BatchNorm2d-6             [-1, 64, 8, 8]             128
              ReLU-7             [-1, 64, 8, 8]               0
            Conv2d-8             [-1, 64, 8, 8]          36,864
       BatchNorm2d-9             [-1, 64, 8, 8]             128
             ReLU-10             [-1, 64, 8, 8]               0
       BasicBlock-11             [-1, 64, 8, 8]               0
           Conv2d-12             [-1, 64, 8, 8]          36,864
      BatchNo