# Setup

In [None]:
# Hyper-parameters shared by both experiments in this notebook.
batch_size, epochs = 64, 50          # mini-batch size and number of training epochs
layers, hidden = 1, 32               # hidden layers in the MLP and units per hidden layer
Dimensions = 5                       # number of weight copies mixed by the hypernetwork

In [None]:
import math
import random

import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
import matplotlib.pyplot as plt
from torch.nn.parameter import Parameter, UninitializedParameter
import torch.nn.functional as F
from torch.nn.modules.module import Module
from torch.nn import init
from tqdm.notebook import trange, tqdm
import torchvision.transforms.functional as TF
from typing import Optional, List, Tuple, Union
from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t
import itertools
import numpy as np

In [None]:
# Shared preprocessing for both splits: digits are shrunk to 16x16 and zero-padded
# by 8 px on every side (32x32 total), so translations of up to +/-8 px keep the
# 16x16 digit box inside the image.
mnist_transform = transforms.Compose([
    transforms.Resize((16, 16)),
    transforms.Pad(8),
    transforms.ToTensor(),
])

# Download training data from open datasets.
training_data = datasets.MNIST(root="data", train=True, download=True, transform=mnist_transform)

# Download test data from open datasets.
test_data = datasets.MNIST(root="data", train=False, download=True, transform=mnist_transform)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw



In [None]:
# Create data loaders (identical settings for both splits).
loader_kwargs = dict(batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
train_loader = DataLoader(training_data, **loader_kwargs)
test_loader = DataLoader(test_data, **loader_kwargs)

# Sanity-check the tensor layout on a single test batch.
first_batch = next(iter(test_loader), None)
if first_batch is not None:
    X, y = first_batch
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)



Shape of X [N, C, H, W]:  torch.Size([64, 1, 32, 32])
Shape of y:  torch.Size([64]) torch.int64


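# MLP with search in alpha-space

Here alpha is the (dx, dy) translation applied to the images. It is fed to a hypernetwork that mixes several weight copies of the MLP.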



In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def count_model_parameters(model):
    """Return the total number of scalars in ``model``'s trainable (requires_grad) parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

class HHN_MLPB(nn.Module):
    """Hypernetwork-conditioned MLP in which every layer mixes ``dimensions`` weight copies.

    A small hypernetwork (``hyper_stack``) maps the conditioning vector ``hyper_x``
    to a softmax over ``dimensions`` coefficients. Every linear layer owns
    ``dimensions`` independent weight/bias copies. The parameters used in
    ``forward`` are the coefficient-weighted sums of those copies. In this notebook
    ``hyper_x`` is the (dx, dy) translation applied to the images.

    Layout: flatten -> fc1 (32*32*n_channels -> n_units) -> ReLU,
    then ``n_layers - 1`` blocks of linear (n_units -> n_units) -> batch norm -> ReLU,
    then fc2 (n_units -> n_classes).

    Args:
        hin: length of the conditioning vector ``hyper_x``.
        dimensions: number of parameter copies mixed per layer.
        n_layers: number of hidden layers; ``n_layers - 1`` extra hidden blocks are built.
        n_units: hidden width.
        n_channels: input channels; inputs must flatten to 32*32*n_channels features.
        n_classes: number of output logits.
    """

    def __init__(self, hin, dimensions, n_layers, n_units, n_channels, n_classes=10):
        super(HHN_MLPB, self).__init__()
        self.hyper_stack = nn.Sequential(
            nn.Linear(hin, 64),
            nn.ReLU(),
            nn.Linear(64, dimensions),
            # dim=0: hyper_x is a single unbatched vector, so the coefficients are
            # normalised over the `dimensions` axis.
            nn.Softmax(dim=0)
        )

        self.dimensions = dimensions
        self.n_layers = n_layers
        self.n_units = n_units
        self.n_channels = n_channels
        self.n_classes = n_classes

        # Kept for backward compatibility only. forward() does not use it, because
        # parameters and buffers follow the module through .to()/.cpu()/.cuda().
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Running batch-norm statistics of the extra hidden layers. They are
        # registered as buffers so that model.to(...)/.cpu()/.double() move and convert
        # them together with the parameters. persistent=False keeps the state_dict keys
        # identical to the previous version of this class (where they were plain
        # attributes). running_std is passed to F.batch_norm as its running_var.
        self.register_buffer("running_mu", torch.zeros(self.n_units), persistent=False)
        self.register_buffer("running_std", torch.ones(self.n_units), persistent=False)

        self.weight_list_fc1, self.bias_list_fc1 = \
            self.create_param_combination_linear(in_features=32 * 32 * n_channels, out_features=n_units)
        # Extra hidden layers are stored flat: entries [i*dimensions, (i+1)*dimensions)
        # of self.weights / self.biases are the copies belonging to hidden layer i.
        self.weights = nn.ParameterList()
        self.biases = nn.ParameterList()
        for _ in range(n_layers - 1):
            w, b = self.create_param_combination_linear(in_features=n_units, out_features=n_units)
            self.weights += w
            self.biases += b
        self.weight_list_fc2, self.bias_list_fc2 = self.create_param_combination_linear(in_features=n_units,
                                                                                        out_features=n_classes)

    def create_param_combination_linear(self, in_features, out_features):
        """Create ``self.dimensions`` weight/bias copies for one linear layer.

        Each copy is initialised like ``nn.Linear``: kaiming-uniform weights with
        a=sqrt(5), and biases uniform in +/- 1/sqrt(fan_in).

        Returns:
            (weight_list, bias_list): two ``nn.ParameterList`` of length ``self.dimensions``.
        """
        weight_list = nn.ParameterList()
        bias_list = nn.ParameterList()
        for _ in range(self.dimensions):
            weight = Parameter(torch.empty((out_features, in_features)))
            nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
            weight_list.append(weight)

            bias = Parameter(torch.empty(out_features))
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(bias, -bound, bound)
            bias_list.append(bias)
        return weight_list, bias_list

    def calculate_weighted_sum(self, param_list: List, factors: Tensor):
        """Return ``sum_i factors[i] * param_list[i]``.

        ``zip`` pairs the elements, so the sum stops at the shorter of the two sequences.
        """
        weighted_list = [a * b for a, b in zip(param_list, factors)]
        return torch.sum(torch.stack(weighted_list), dim=0)

    def forward(self, x, hyper_x):
        """Compute class logits for ``x`` under the conditioning vector ``hyper_x``.

        Args:
            x: input batch; flattened from dim 1 onwards before the first layer.
            hyper_x: 1-D conditioning vector of length ``hin``. It is shared by the
                whole batch, not given per sample.

        Returns:
            Logits of shape (batch, n_classes).
        """
        hyper_output = self.hyper_stack(hyper_x)

        weight_fc1 = self.calculate_weighted_sum(self.weight_list_fc1, hyper_output)
        weight_fc2 = self.calculate_weighted_sum(self.weight_list_fc2, hyper_output)

        bias_fc1 = self.calculate_weighted_sum(self.bias_list_fc1, hyper_output)
        bias_fc2 = self.calculate_weighted_sum(self.bias_list_fc2, hyper_output)

        logits = torch.flatten(x, start_dim=1)
        logits = F.linear(logits, weight=weight_fc1, bias=bias_fc1)
        logits = torch.relu(logits)

        # zip(*[it] * k) is the "grouper" idiom: it yields consecutive k-tuples, i.e.
        # the `dimensions` copies of one extra hidden layer at a time.
        it_w = iter(self.weights)
        it_b = iter(self.biases)
        for w_copies, b_copies in zip(zip(*[it_w] * self.dimensions),
                                      zip(*[it_b] * self.dimensions)):
            # The copies are this module's own registered parameters and already live
            # on its device. Re-wrapping them and calling .to() is unnecessary, and it
            # moved them back to the construction-time device after model.cpu().
            w = self.calculate_weighted_sum(w_copies, hyper_output)
            b = self.calculate_weighted_sum(b_copies, hyper_output)
            logits = F.linear(logits, weight=w, bias=b)
            # NOTE(review): training=True always uses batch statistics and updates the
            # running buffers, even after model.eval(). A batch of a single sample
            # therefore cannot pass through an extra hidden layer.
            logits = F.batch_norm(logits, self.running_mu, self.running_std, training=True, momentum=0.9)
            logits = torch.relu(logits)
        logits = F.linear(logits, weight=weight_fc2, bias=bias_fc2)
        return logits

model = HHN_MLPB(hin=2, dimensions=Dimensions, n_layers=layers, n_units=hidden,
                 n_channels=1, n_classes=10)
model = model.to(device)
print(model)
print(count_model_parameters(model))

HHN_MLPB(
  (hyper_stack): Sequential(
    (0): Linear(in_features=2, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=5, bias=True)
    (3): Softmax(dim=0)
  )
  (weight_list_fc1): ParameterList(
      (0): Parameter containing: [torch.float32 of size 32x1024 (GPU 0)]
      (1): Parameter containing: [torch.float32 of size 32x1024 (GPU 0)]
      (2): Parameter containing: [torch.float32 of size 32x1024 (GPU 0)]
      (3): Parameter containing: [torch.float32 of size 32x1024 (GPU 0)]
      (4): Parameter containing: [torch.float32 of size 32x1024 (GPU 0)]
  )
  (bias_list_fc1): ParameterList(
      (0): Parameter containing: [torch.float32 of size 32 (GPU 0)]
      (1): Parameter containing: [torch.float32 of size 32 (GPU 0)]
      (2): Parameter containing: [torch.float32 of size 32 (GPU 0)]
      (3): Parameter containing: [torch.float32 of size 32 (GPU 0)]
      (4): Parameter containing: [torch.float32 of size 32 (GPU 0)]
  )
  (weights): Par

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

# Cosine similarity along dim 0 of the (unbatched) hypernetwork coefficient vectors.
cos = nn.CosineSimilarity(dim=0, eps=1e-6)

def train(dataloader, model, loss_fn, optimizer):
    """Train the hypernetwork MLP for one epoch on randomly translated batches.

    Every batch is shifted by one random translation (dx, dy) ~ U(-8, 8)^2. That
    translation is also passed to the model as ``hyper_x``. The loss is the
    cross-entropy under the true translation plus three regularisers:

    * + squared cosine similarity between the hypernetwork coefficients for the true
      translation and for a second, random ("wrong") translation,
    * + 0.01 * summed prediction entropy under the true translation (minimised),
    * - 0.01 * summed prediction entropy under the wrong translation (maximised).

    Uses the notebook-level globals ``device`` and ``cos``.
    """
    # validate() switches the model to eval mode; switch back so that every epoch
    # really trains in training mode.
    model.train()
    for X, y in tqdm(dataloader, desc='Training'):
        X, y = X.to(device), y.to(device)
        dx, dy = random.uniform(-8, 8), random.uniform(-8, 8)
        X = TF.affine(X, scale=1.0, angle=0, translate=(dx, dy), shear=0.0)
        hyper_true = Tensor([dx, dy]).to(device)

        pred = model(X, hyper_x=hyper_true)
        loss = loss_fn(pred, y)

        # Push the coefficients of two different translations apart.
        beta_true = model.hyper_stack(hyper_true)
        dx_wrong, dy_wrong = random.uniform(-8, 8), random.uniform(-8, 8)
        hyper_wrong = Tensor([dx_wrong, dy_wrong]).to(device)
        beta_wrong = model.hyper_stack(hyper_wrong)
        loss = loss + pow(cos(beta_true, beta_wrong), 2)

        # Minimise the prediction entropy under the correct translation ...
        entropy_true = F.softmax(pred, dim=1) * (-1 * F.log_softmax(pred, dim=1))
        loss = loss + 0.01 * entropy_true.sum()

        # ... and maximise it under the wrong translation.
        logits_wrong = model(X, hyper_x=hyper_wrong)
        entropy_wrong = F.softmax(logits_wrong, dim=1) * (-1 * F.log_softmax(logits_wrong, dim=1))
        loss = loss - 0.01 * entropy_wrong.sum()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def validate(dataloader, model, loss_fn):
    """Evaluate the hypernetwork MLP on randomly translated batches.

    Each batch is shifted by its own random (dx, dy) ~ U(-8, 8)^2. The true shift is
    passed to the model as ``hyper_x``.

    Returns:
        (accuracy, avg_loss): accuracy as a fraction of all samples in
        ``dataloader.dataset``, and the loss averaged over batches.

    The model's train/eval mode is restored before returning, so that a following
    ``train`` call is not silently run in eval mode. Uses the notebook global ``device``.
    """
    was_training = model.training
    model.eval()
    test_loss, correct = 0, 0
    a = b = None  # stays None if the dataloader yields no batches
    try:
        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(device), y.to(device)
                a, b = random.uniform(-8, 8), random.uniform(-8, 8)
                X = TF.affine(X, scale=1.0, angle=0, translate=(a, b), shear=0.0)

                pred = model(X, hyper_x=Tensor([a, b]).to(device))
                test_loss += loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    finally:
        model.train(was_training)
    # max(..., 1) avoids ZeroDivisionError for an empty loader/dataset.
    test_loss /= max(len(dataloader), 1)
    correct /= max(len(dataloader.dataset), 1)
    # Only the translation of the last batch is reported here.
    print(f"Test with translation={a,b}: Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f}")
    return correct, test_loss

# Alternate one training epoch with one evaluation pass.
for epoch in range(1, epochs + 1):
    print(f"=================\n Epoch: {epoch} \n=================")
    train(train_loader, model, loss_fn, optimizer)
    test_acc, test_loss = validate(test_loader, model, loss_fn)
print("Done!")

 Epoch: 1 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(2.3260463607721835, 2.6758056233708736): Accuracy: 40.9%, Avg loss: 1.728002
 Epoch: 2 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(0.9998305250234889, 3.7428690667227436): Accuracy: 53.1%, Avg loss: 1.369338
 Epoch: 3 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-0.2978151914770173, 3.533165817765049): Accuracy: 59.5%, Avg loss: 1.185801
 Epoch: 4 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(4.842590472001154, 6.500853341807126): Accuracy: 64.3%, Avg loss: 1.044691
 Epoch: 5 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(4.421783143195938, 7.304662753133368): Accuracy: 66.0%, Avg loss: 0.999523
 Epoch: 6 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(4.539536671317547, 1.512831817596478): Accuracy: 68.1%, Avg loss: 0.945852
 Epoch: 7 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(0.013288729447397785, 0.12172069103078265): Accuracy: 70.7%, Avg loss: 0.909876
 Epoch: 8 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-7.690624268479752, -6.859744330849802): Accuracy: 73.6%, Avg loss: 0.816797
 Epoch: 9 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-6.158454597330316, -6.532442757332431): Accuracy: 73.0%, Avg loss: 0.812527
 Epoch: 10 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(5.392127595573333, 3.5423643988177105): Accuracy: 73.9%, Avg loss: 0.804419
 Epoch: 11 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(0.14839434049198807, 5.019876810606325): Accuracy: 76.7%, Avg loss: 0.723139
 Epoch: 12 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(0.7607020689897315, -4.218493025487515): Accuracy: 75.3%, Avg loss: 0.748603
 Epoch: 13 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(4.332678307396565, 1.958062841946143): Accuracy: 76.2%, Avg loss: 0.753899
 Epoch: 14 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(4.914982521284795, -5.930619960774672): Accuracy: 76.8%, Avg loss: 0.744317
 Epoch: 15 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(2.4995676385706833, -6.753232655543998): Accuracy: 79.2%, Avg loss: 0.649264
 Epoch: 16 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-6.044234290173629, 4.192205388680302): Accuracy: 78.6%, Avg loss: 0.668623
 Epoch: 17 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(1.4889088343509478, -5.651714507482243): Accuracy: 80.7%, Avg loss: 0.626658
 Epoch: 18 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-6.750145831701424, -6.574751695691196): Accuracy: 81.3%, Avg loss: 0.600439
 Epoch: 19 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-3.396652867843869, 6.027616656358697): Accuracy: 80.8%, Avg loss: 0.623063
 Epoch: 20 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-5.803627798494297, -7.635112824812026): Accuracy: 79.8%, Avg loss: 0.635690
 Epoch: 21 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(3.693136878521461, -6.260335463523008): Accuracy: 80.8%, Avg loss: 0.623958
 Epoch: 22 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(1.471403707688303, -4.4806745856006085): Accuracy: 81.9%, Avg loss: 0.589567
 Epoch: 23 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-7.615502452535997, 4.146624032735337): Accuracy: 82.5%, Avg loss: 0.569379
 Epoch: 24 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(2.722630977746876, -5.817195011282761): Accuracy: 82.4%, Avg loss: 0.573082
 Epoch: 25 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-7.907271815117653, -6.535554369523435): Accuracy: 83.4%, Avg loss: 0.557823
 Epoch: 26 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(1.5567663658104607, -2.8815031308783876): Accuracy: 81.9%, Avg loss: 0.587787
 Epoch: 27 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-5.716578721745373, -5.660561703253244): Accuracy: 83.3%, Avg loss: 0.545266
 Epoch: 28 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-6.626578952286351, 3.948003912601548): Accuracy: 82.8%, Avg loss: 0.565942
 Epoch: 29 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(0.37014032843146616, 6.449340933766175): Accuracy: 84.5%, Avg loss: 0.499839
 Epoch: 30 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(3.815957258998923, -0.5166621178669644): Accuracy: 83.1%, Avg loss: 0.551014
 Epoch: 31 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(0.10862862183601507, 2.9557192141499193): Accuracy: 83.8%, Avg loss: 0.530798
 Epoch: 32 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(6.804953115145748, 3.5004114130849935): Accuracy: 83.9%, Avg loss: 0.528903
 Epoch: 33 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-3.3160175465192783, 3.245393140739928): Accuracy: 83.6%, Avg loss: 0.550808
 Epoch: 34 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(1.478853842945691, 7.703619791742193): Accuracy: 84.3%, Avg loss: 0.517498
 Epoch: 35 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(1.415390000037366, -6.869289082996049): Accuracy: 84.3%, Avg loss: 0.518023
 Epoch: 36 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(2.4309182310714466, -5.983631160686318): Accuracy: 85.1%, Avg loss: 0.484414
 Epoch: 37 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(3.0980465986287484, 7.142471762153479): Accuracy: 85.2%, Avg loss: 0.486447
 Epoch: 38 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-2.7204024246886753, -3.1791958114527787): Accuracy: 85.2%, Avg loss: 0.495190
 Epoch: 39 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-7.1888759965280915, -7.809665283567442): Accuracy: 85.0%, Avg loss: 0.489334
 Epoch: 40 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-4.886289781942347, 6.589106841560918): Accuracy: 85.4%, Avg loss: 0.489493
 Epoch: 41 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(4.8077259283410925, 6.694623549840946): Accuracy: 85.8%, Avg loss: 0.487495
 Epoch: 42 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-4.716291371797251, -6.657879614501658): Accuracy: 85.0%, Avg loss: 0.500132
 Epoch: 43 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-2.7006172833459896, 3.0283489623431255): Accuracy: 85.6%, Avg loss: 0.487584
 Epoch: 44 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-0.2459079772174544, 5.286528427801107): Accuracy: 86.3%, Avg loss: 0.471122
 Epoch: 45 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(5.920140612653531, -1.1549177919876517): Accuracy: 85.6%, Avg loss: 0.487374
 Epoch: 46 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-2.6664024664861827, 0.11757660126706426): Accuracy: 85.4%, Avg loss: 0.482634
 Epoch: 47 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-1.9585720850412098, -1.3767853511407449): Accuracy: 86.1%, Avg loss: 0.465108
 Epoch: 48 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-6.101546567503137, 3.5178696881288545): Accuracy: 85.5%, Avg loss: 0.472860
 Epoch: 49 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-2.9041162097834654, 5.802479612075638): Accuracy: 85.1%, Avg loss: 0.485140
 Epoch: 50 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-7.7357973748913444, 7.989186992718732): Accuracy: 85.4%, Avg loss: 0.497362
Done!


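## Optimize in the alpha space

For each test batch, the unknown translation (alpha) is estimated by minimizing the summed prediction entropy over the hypernetwork input. The batch is then classified with the estimated alpha.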

In [None]:
# Inference only from here on: evaluation mode, CPU execution, and every network
# weight frozen (requires_grad=False).
model = model.eval().cpu()
model.requires_grad_(False);

In [None]:
# Optional: re-create the test loader with another batch size. For example,
# batch_size=1 makes the alpha search below take much longer.
batch_size = 64
test_loader = DataLoader(
    test_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)



In [None]:
from scipy import optimize


def _to_translation(z):
    """Map unconstrained optimiser coordinates to translations in [-8, 8).

    z = 0 maps to 0. The map is periodic with period 2, so the search can never
    leave the translation range used during training.
    """
    return ((z + 1) * 8) % 16 - 8


# Objective minimised by the basin-hopping search.
def f(z, *args):
    """Summed prediction entropy of the batch ``args[0]`` under candidate translation ``z``.

    Args:
        z: length-2 array of optimiser coordinates, mapped to (dx, dy) by ``_to_translation``.
        *args: ``args[0]`` is the (already translated) image batch.

    Returns:
        The entropy summed over batch and classes, as a Python float. Lower values
        mean more confident predictions; this is the search's proxy for the
        right translation.
    """
    dx, dy = _to_translation(z)
    X = args[0]
    with torch.no_grad():
        logits = model(Tensor(X), hyper_x=Tensor([dx, dy]))
        entropy = F.softmax(logits, dim=1) * (-1 * F.log_softmax(logits, dim=1))
    return float(entropy.sum())


def findalpha(X):
    """Estimate the translation (dx, dy) that was applied to the image batch ``X``.

    Runs scipy's basin-hopping (BFGS local minimiser, 100 hops, temperature 50)
    on ``f``, starting from (0, 0). Returns the minimiser mapped into [-8, 8).
    """
    # https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.basinhopping.html
    minimizer_kwargs = {"method": "BFGS", "args": (X,)}
    res = optimize.basinhopping(f, [0.0, 0.0], minimizer_kwargs=minimizer_kwargs, niter=100, T=50)
    return _to_translation(res.x)

# For every test batch: apply a random translation the model is not told about,
# recover it with findalpha (entropy minimisation), then classify with that estimate.
# Accuracy is counted per sample and divided by the dataset size. It therefore does
# not depend on the global `batch_size` matching the loader's batch size (or on a
# short last batch).
n_correct = 0
with torch.no_grad():
    for X, y in tqdm(test_loader, desc='Testing alpha search'):
        a, b = random.uniform(-8, 8), random.uniform(-8, 8)
        X = TF.affine(X, scale=1.0, angle=0, translate=(a, b), shear=0.0)

        alpha = findalpha(X)

        # compute model prediction with the estimated translation
        logits = model(X, hyper_x=Tensor(alpha))
        n_correct += (logits.argmax(1) == y).sum().item()

result = n_correct / len(test_loader.dataset)
print(f"Test accuracy: {(100*result):>0.1f}%")

print("Done!")

Testing alpha search:   0%|          | 0/157 [00:00<?, ?it/s]

Test accuracy: 84.4%
Done!


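# One4All baseline

For comparison, a single plain MLP (no hypernetwork) is trained on randomly translated images.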

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class MLPB(nn.Module):
    """Plain MLP baseline: one set of weights for every translation.

    Layout: Flatten -> Linear(32*32*n_channels, n_units) -> ReLU,
    then ``n_layers - 1`` blocks of Linear -> BatchNorm1d(momentum=0.9) -> ReLU,
    then Linear(n_units, n_classes).
    """

    def __init__(self, n_layers, n_units, n_channels, n_classes=10):
        super().__init__()
        input_block = [nn.Flatten(), nn.Linear(32 * 32 * n_channels, n_units), nn.ReLU()]
        hidden_blocks = [
            module
            for _ in range(n_layers - 1)
            for module in (
                nn.Linear(n_units, n_units),
                nn.BatchNorm1d(n_units, momentum=0.9),
                nn.ReLU(),
            )
        ]
        head = [nn.Linear(n_units, n_classes)]
        self.linear_relu_stack = nn.Sequential(*input_block, *hidden_blocks, *head)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        return self.linear_relu_stack(x)

model = MLPB(n_layers=layers, n_units=hidden, n_channels=1, n_classes=10)
model = model.to(device)
print(model)
print(count_model_parameters(model))

MLPB(
  (linear_relu_stack): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=1024, out_features=32, bias=True)
    (2): ReLU()
    (3): Linear(in_features=32, out_features=10, bias=True)
  )
)
33130


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Plain cross-entropy with default Adam; no learning-rate scheduler is used.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

def train(dataloader, model, loss_fn, optimizer):
    """Train the baseline MLP for one epoch on randomly translated batches.

    Every batch is shifted by one random translation (dx, dy) ~ U(-8, 8)^2. The
    model is not told the shift. Uses the notebook-level global ``device``.
    """
    # validate() switches the model to eval mode; switch back so that BatchNorm
    # layers use batch statistics and update their running stats while training.
    model.train()
    for X, y in tqdm(dataloader, desc='Training'):
        X, y = X.to(device), y.to(device)
        dx, dy = random.uniform(-8, 8), random.uniform(-8, 8)
        X = TF.affine(X, scale=1.0, angle=0, translate=(dx, dy), shear=0.0)

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def validate(dataloader, model, loss_fn):
    """Evaluate the baseline MLP on randomly translated batches.

    Each batch is shifted by its own random (dx, dy) ~ U(-8, 8)^2.

    Returns:
        (accuracy, avg_loss): accuracy as a fraction of all samples in
        ``dataloader.dataset``, and the loss averaged over batches.

    The model's train/eval mode is restored before returning, so that a following
    ``train`` call is not silently run in eval mode. Uses the notebook global ``device``.
    """
    was_training = model.training
    model.eval()
    test_loss, correct = 0, 0
    a = b = None  # stays None if the dataloader yields no batches
    try:
        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(device), y.to(device)
                a, b = random.uniform(-8, 8), random.uniform(-8, 8)
                X = TF.affine(X, scale=1.0, angle=0, translate=(a, b), shear=0.0)

                pred = model(X)
                test_loss += loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    finally:
        model.train(was_training)
    # max(..., 1) avoids ZeroDivisionError for an empty loader/dataset.
    test_loss /= max(len(dataloader), 1)
    correct /= max(len(dataloader.dataset), 1)
    # Only the translation of the last batch is reported here.
    print(f"Test with translation={a,b}: Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f}")
    return correct, test_loss

# Alternate one training epoch with one evaluation pass.
for epoch in range(1, epochs + 1):
    print(f"=================\n Epoch: {epoch} \n=================")
    train(train_loader, model, loss_fn, optimizer)
    test_acc, test_loss = validate(test_loader, model, loss_fn)
print("Done!")

 Epoch: 1 


Training:   0%|          | 0/938 [00:00<?, ?it/s]



Test with translation=(2.220755577066466, 4.929664021916047): Accuracy: 39.5%, Avg loss: 1.752685
 Epoch: 2 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(3.77104921820108, 6.216031554722267): Accuracy: 53.0%, Avg loss: 1.404256
 Epoch: 3 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(2.0534735491795733, 6.906741441593727): Accuracy: 61.0%, Avg loss: 1.214778
 Epoch: 4 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-4.310417265760412, -3.6290544446692703): Accuracy: 65.0%, Avg loss: 1.089294
 Epoch: 5 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-1.6355997517232694, 3.1374492762143316): Accuracy: 67.4%, Avg loss: 1.009096
 Epoch: 6 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-5.875902783594112, 1.276188149039843): Accuracy: 69.5%, Avg loss: 0.950495
 Epoch: 7 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(4.368058044371887, -0.6478460772792403): Accuracy: 69.7%, Avg loss: 0.924104
 Epoch: 8 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-1.55105884250934, 6.1130812454292): Accuracy: 71.7%, Avg loss: 0.879989
 Epoch: 9 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(7.9419965428551755, 4.209176025687944): Accuracy: 72.6%, Avg loss: 0.862061
 Epoch: 10 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(5.6459761437516125, -7.602328816088098): Accuracy: 72.4%, Avg loss: 0.843561
 Epoch: 11 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-5.345489089304145, -3.2381783070935395): Accuracy: 73.3%, Avg loss: 0.827280
 Epoch: 12 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(7.15418855555987, -3.3141880574882183): Accuracy: 73.7%, Avg loss: 0.808063
 Epoch: 13 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(5.85440563919083, -6.99454197454885): Accuracy: 74.5%, Avg loss: 0.792623
 Epoch: 14 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-4.29355188250689, 1.1346185580111818): Accuracy: 74.0%, Avg loss: 0.814210
 Epoch: 15 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(5.183066934299431, 3.8106260496836875): Accuracy: 74.8%, Avg loss: 0.778214
 Epoch: 16 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(3.6073026623759556, -3.0738574133067136): Accuracy: 74.8%, Avg loss: 0.786383
 Epoch: 17 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-6.619996347160315, 6.374096388808301): Accuracy: 75.5%, Avg loss: 0.759170
 Epoch: 18 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-0.9080987077877367, -4.770667793889318): Accuracy: 76.2%, Avg loss: 0.750546
 Epoch: 19 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-5.564312424340205, 0.016923332692122983): Accuracy: 76.0%, Avg loss: 0.747756
 Epoch: 20 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-6.2924197020888695, 5.757802428993594): Accuracy: 75.7%, Avg loss: 0.741011
 Epoch: 21 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(1.9163904074037728, -6.7356952298028805): Accuracy: 77.1%, Avg loss: 0.721701
 Epoch: 22 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(0.2690986779159079, 3.2246899149807415): Accuracy: 76.1%, Avg loss: 0.740591
 Epoch: 23 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-3.446937636299989, -3.4402019693564316): Accuracy: 77.1%, Avg loss: 0.719941
 Epoch: 24 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(2.751895469070279, 1.3168121279544778): Accuracy: 77.0%, Avg loss: 0.720743
 Epoch: 25 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(3.534926531411992, 4.218080955788988): Accuracy: 76.9%, Avg loss: 0.721394
 Epoch: 26 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-6.6172093931553295, -2.001374585422795): Accuracy: 76.8%, Avg loss: 0.733867
 Epoch: 27 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-1.449642814940237, -7.348270530650881): Accuracy: 76.9%, Avg loss: 0.719269
 Epoch: 28 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(6.351320087912518, 2.0788917258657023): Accuracy: 76.9%, Avg loss: 0.710210
 Epoch: 29 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(4.772767937058598, -3.083893641841419): Accuracy: 77.6%, Avg loss: 0.707075
 Epoch: 30 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-2.0435619044888345, -0.8212100373835973): Accuracy: 77.5%, Avg loss: 0.705401
 Epoch: 31 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(3.013503715361569, -4.460159096243972): Accuracy: 77.4%, Avg loss: 0.694146
 Epoch: 32 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-7.787404267599687, 1.4919580421710315): Accuracy: 77.0%, Avg loss: 0.710196
 Epoch: 33 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-0.01729307055525986, 5.476194895151085): Accuracy: 77.8%, Avg loss: 0.694541
 Epoch: 34 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(2.8446565192576774, -7.7431598252686005): Accuracy: 77.9%, Avg loss: 0.682537
 Epoch: 35 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-2.3038677253019415, -4.599022260067564): Accuracy: 77.8%, Avg loss: 0.702560
 Epoch: 36 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(1.1136262009729787, -0.8196358152426644): Accuracy: 77.7%, Avg loss: 0.695461
 Epoch: 37 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(7.055339685762581, -7.693976345953349): Accuracy: 78.9%, Avg loss: 0.675993
 Epoch: 38 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(0.4358016731632155, -1.4500566086441342): Accuracy: 78.1%, Avg loss: 0.687268
 Epoch: 39 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(7.7813197602842035, 7.041740400393861): Accuracy: 79.2%, Avg loss: 0.664133
 Epoch: 40 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-2.3653098424766164, -7.048407633676602): Accuracy: 78.9%, Avg loss: 0.686194
 Epoch: 41 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-3.6372045324145006, -5.974668065245837): Accuracy: 78.9%, Avg loss: 0.664072
 Epoch: 42 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(2.476809204488019, -5.831384074769341): Accuracy: 79.0%, Avg loss: 0.660442
 Epoch: 43 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-7.330950747905582, 4.813470042555641): Accuracy: 78.7%, Avg loss: 0.669048
 Epoch: 44 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(4.543602934743303, -4.3626865370803145): Accuracy: 78.8%, Avg loss: 0.670517
 Epoch: 45 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(0.13008969110752844, -3.404411745882843): Accuracy: 78.4%, Avg loss: 0.676563
 Epoch: 46 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(3.1392735165757575, 6.940003620826962): Accuracy: 79.0%, Avg loss: 0.657696
 Epoch: 47 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-4.4654524252348455, -6.608269285520434): Accuracy: 78.5%, Avg loss: 0.669905
 Epoch: 48 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(0.4947995121703048, -3.204277199042691): Accuracy: 79.8%, Avg loss: 0.652544
 Epoch: 49 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-5.0801060276455825, -3.3254813299776345): Accuracy: 80.2%, Avg loss: 0.651436
 Epoch: 50 


Training:   0%|          | 0/938 [00:00<?, ?it/s]

Test with translation=(-4.570527886771533, -4.845838056985331): Accuracy: 79.4%, Avg loss: 0.649829
Done!
