In [None]:
pip -q install einops torchvision torch --upgrade

In [None]:
import numpy as np
from copy import deepcopy
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
from torchvision.transforms import v2
import matplotlib.pyplot as plt
import numpy as np
import itertools
import random
import statistics
import tqdm
import math
from torchsummary import summary
from einops.layers.torch import Rearrange, Reduce

In [None]:
# Display the installed PyTorch version so results can be tied to an environment.
torch.__version__

'2.1.1+cu121'

In [None]:
# Display whether a CUDA device is visible; the config cell picks `device` from the same check.
torch.cuda.is_available()

True

In [None]:
# Experiment configuration shared by every cell below.
img_shape = 64          # images are resized to img_shape x img_shape
batch_size = 2048
num_thread = 2          # DataLoader worker processes
patch_shape = 8         # side length of each square patch
embd_dim = 256          # token embedding size
channel = 3             # input channels (RGB)
learning_rate = 1e-4
weight_decay = 5e-4
num_patch = (img_shape//patch_shape)**2   # tokens per image
num_epoch = 50
num_class = 100
num_layer = 4           # gMLP blocks per model
seed = 42               # global RNG seed for reproducible runs

# The patch Rearrange pattern needs the image side to split evenly into patches.
assert img_shape % patch_shape == 0, "img_shape must be divisible by patch_shape"

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Seed every RNG the notebook uses. torch.manual_seed also seeds all CUDA devices.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Training pipeline. AutoAugment's colour ops (Posterize, Solarize, Equalize,
# brightness/contrast blends, ...) assume float images in [0, 1] and clamp or
# re-quantize to that range. It therefore has to run BEFORE Normalize; applied
# afterwards it would clip away the negative normalized values.
transform_train = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale = True),
    v2.Resize((img_shape, img_shape), antialias = True),
    v2.AutoAugment(transforms.autoaugment.AutoAugmentPolicy.CIFAR10),
    v2.Normalize(mean = (0.45, 0.45, 0.45), std = (0.23, 0.23, 0.23)),
    ])

# Evaluation pipeline: identical to training minus the augmentation.
transform_test = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale = True),
    v2.Resize((img_shape, img_shape), antialias = True),
    v2.Normalize(mean = (0.45, 0.45, 0.45), std = (0.23, 0.23, 0.23)),
    ])

# MNIST pipeline: float in [0, 1], resized, no augmentation and no normalization.
transform_mnist = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale = True),
    v2.Resize((img_shape, img_shape), antialias = True),
    ])

## 2.1 Utility functions

In [None]:
class Patches(nn.Module):
  """Cut an image batch into non-overlapping square patches and embed each patch.

  Input is (batch, channel, H, W) with H and W divisible by `patch`; output is
  (batch, num_patches, embd_dim), each patch flattened as (c h w) and fed
  through one shared Linear layer.

  Args:
    patch: side length of a square patch in pixels.
    img_shape: accepted for interface compatibility; not used here.
    embd_dim: embedding size produced for each patch.
    channel: number of input image channels.
  """
  def __init__(self, patch = patch_shape, img_shape = img_shape, embd_dim = embd_dim, channel = channel):
    super().__init__()
    to_tokens = Rearrange('b c (p1 h) (p2 w) -> b (p1 p2) (c h w)', h = patch, w = patch)
    flat_dim = channel * patch * patch
    embed = nn.Linear(flat_dim, embd_dim)
    self.proj = nn.Sequential(to_tokens, embed)

  def forward(self, x):
    tokens = self.proj(x)
    return tokens

In [None]:
class LambdaLayer(nn.Module):
    """Adapt an arbitrary callable into an nn.Module (e.g. for use inside nn.Sequential)."""

    def __init__(self, lambd):
        super().__init__()
        # Callable applied to the input in forward().
        self.lambd = lambd

    def forward(self, x):
        fn = self.lambd
        return fn(x)

In [None]:
def vec_norm(x, dim = -1, epsilon = 1e-6):
  """Divide `x` by its L2 norm along `dim`.

  `epsilon` is added to the norm so an all-zero slice maps to zeros instead of NaN.
  """
  length = torch.linalg.vector_norm(x, dim = dim, keepdim = True)
  return x / (length + epsilon)

def nrs(x, dim = -1):
  """ReLU, then L2-normalize along `dim`, then square elementwise.

  Every slice along `dim` is non-negative and sums to at most 1. This is used as a
  softmax substitute for attention weights.
  """
  positive = torch.relu(x)
  return torch.square(vec_norm(positive, dim = dim))

In [None]:
class att(nn.Module):
  """Tiny single-head attention branch added to the gMLP gate.

  Tokens are projected to `head`-wide queries, keys and values. Every query/key
  pair is scored and scaled by sqrt(head), and the scores are turned into
  non-negative weights with `nrs` (instead of softmax). The values are mixed with
  those weights and projected back to `embd_dim`, followed by GELU.

  `patch` is accepted for interface compatibility and is not used.
  """
  def __init__(self, head = 32, patch = num_patch, embd_dim = embd_dim):
    super().__init__()
    width = int(head)
    self.expand = nn.Linear(embd_dim, width * 3, bias = False)
    # Plain tensor attribute (not a parameter or buffer).
    self.scale = torch.sqrt(torch.tensor(head, dtype = torch.float32))
    self.proj = nn.Linear(width, embd_dim)

  def forward(self, x):
    qkv = self.expand(x)
    q, k, v = qkv.chunk(3, -1)
    scores = torch.einsum('...nc,...mc->...nm', q, k) / self.scale
    weights = nrs(scores, dim = -1)
    mixed = torch.einsum('...nm,...mf->...nf', weights, v)
    return F.gelu(self.proj(mixed))

In [None]:
class spatial_proj1d(nn.Module):
  """Kernel-size-1 Conv1d that mixes the `p` token positions.

  Tokens act as Conv1d channels, so the input is (batch, p, length). Each output
  token is an affine combination of the tokens in its group
  (``groups = p // c``; the defaults give a single group). Weights start at 0 and
  biases at 1, so the layer initially outputs all ones whatever the input.
  """
  def __init__(self, p = num_patch, c = num_patch):
    super().__init__()
    n_groups = p // c
    self.proj = nn.Conv1d(p, p, 1, groups = n_groups)
    nn.init.zeros_(self.proj.weight)
    nn.init.ones_(self.proj.bias)

  def forward(self, x):
    out = self.proj(x)
    return out

In [None]:
class spatial_proj2d(nn.Module):
  """3x3 'same'-padded Conv2d over `p * c` channels, split into `c` groups.

  The input is (batch, p * c, H, W). Weights start at 0 and biases at 1, so the
  layer initially outputs all ones whatever the input.
  """
  def __init__(self,p = num_patch, c = 4):
    super().__init__()
    width = p * c
    self.proj = nn.Conv2d(width, width, 3, padding = 'same', groups = c)
    nn.init.zeros_(self.proj.weight)
    nn.init.ones_(self.proj.bias)

  def forward(self, x):
    out = self.proj(x)
    return out

In [None]:
class spatial_proj3d(nn.Module):
  """Conv3d with a (1, k, k) kernel and 'same' padding; the `p` tokens are its channels.

  The input is (batch, p, D, H, W). The kernel mixes tokens and a k x k spatial
  neighbourhood inside each of the D slices, never across slices
  (``groups = p // c``; the defaults give a single group). Weights start at 0 and
  biases at 1, so the layer initially outputs all ones whatever the input.
  """
  def __init__(self,p = num_patch, c = num_patch, k = 3):
    super().__init__()
    n_groups = p // c
    self.proj = nn.Conv3d(p, p, (1, k, k), padding = 'same', groups = n_groups)
    nn.init.zeros_(self.proj.weight)
    nn.init.ones_(self.proj.bias)

  def forward(self, x):
    out = self.proj(x)
    return out

## 2.2 gMLPs

In [None]:
class gMLP(nn.Module):
  """One gMLP block with an extra tiny-attention branch.

  With x of shape (batch, tokens, embd_dim):
    x_norm = LayerNorm(x)
    u, v   = split(GELU(Linear(x_norm)), 2)        # each (batch, tokens, embd_dim)
    out    = channel_proj(u * (gate(v) + att(x_norm))) + x

  Args:
    patch: number of tokens; used to build the default gate and the attention branch.
    embd_dim: token embedding size.
    gate: module applied to `v`. None (default) builds a fresh
      LayerNorm + spatial_proj1d gate for this block. A module passed in is
      deep-copied so every block owns independent parameters.
    channel_proj: module applied to the gated product. None (default) builds a
      fresh nn.Linear(embd_dim, embd_dim); a module passed in is deep-copied.
  """
  def __init__(self, patch = num_patch, embd_dim = embd_dim, gate = None, channel_proj = None):
    super().__init__()
    self.norm = nn.LayerNorm(embd_dim)
    self.expand = nn.Linear(embd_dim, 2 * embd_dim, bias = False)
    # Defaults are built per instance. A module used as a default argument would be
    # created once at class-definition time, and every block of every model would
    # start from a copy of the same random weights.
    if gate is None:
      self.proj_seq = nn.Sequential(nn.LayerNorm(embd_dim), spatial_proj1d(p = patch, c = patch))
    else:
      self.proj_seq = deepcopy(gate)
    if channel_proj is None:
      self.proj = nn.Linear(embd_dim, embd_dim)
    else:
      self.proj = deepcopy(channel_proj)
    self.res = att(patch = patch, embd_dim = embd_dim)

  def forward(self, x):
    x_norm = self.norm(x)
    u, v = F.gelu(self.expand(x_norm)).chunk(2, -1)
    return self.proj(u * (self.proj_seq(v) + self.res(x_norm))) + x

class gMLP_blocks(nn.Module):
  """A stack of `num_layer` gMLP blocks, applied in sequence.

  `patch`, `embd_dim`, `gate` and `channel_proj` are forwarded to every block
  (each block copies or builds its own gate/channel_proj; see gMLP).
  """
  def __init__(self, patch = num_patch, embd_dim = embd_dim, num_layer = num_layer, gate = None, channel_proj = None):
    super().__init__()
    self.model = nn.Sequential(*[gMLP(patch = patch, embd_dim = embd_dim, gate = gate, channel_proj = channel_proj) for _ in range(num_layer)])

  def forward(self, x):
    return self.model(x)

class gMLP_classifier(gMLP_blocks):
  """Image classifier: patch embedding -> gMLP stack -> mean over tokens -> linear head.

  Args:
    label: number of output classes.
    ps: patch side length in pixels.
    imgs: input image side length in pixels.
    patch: number of tokens. None (default) derives it as (imgs // ps) ** 2, so it
      always matches the patch embedding.
    embd_dim, num_layer, gate, channel_proj: see gMLP / gMLP_blocks.
    channel: number of input image channels.
  """
  def __init__(self, label = num_class, ps = patch_shape, imgs = img_shape, patch = None, embd_dim = embd_dim, channel = channel, num_layer = num_layer, gate = None, channel_proj = None):
    if patch is None:
      patch = (imgs // ps) ** 2
    super().__init__(patch, embd_dim, num_layer, gate, channel_proj)
    self.patch = Patches(ps, imgs, embd_dim, channel)
    self.proj = nn.Linear(embd_dim, label)

  def forward(self, x):
    # Average the token embeddings, then map them to class logits.
    return self.proj(self.model(self.patch(x)).mean(dim = 1))

## 3. Optimizer (experimental; `train_loop` uses `optim.RAdam`)

In [None]:
class n_BiNPRAdam(torch.optim.Adam):
    """Experimental Adam variant. train_loop in this notebook uses optim.RAdam instead.

    torch.optim.Adam is reused only for parameter-group bookkeeping (lr, betas, eps)
    and the per-parameter `self.state`; `step` is fully custom. Per parameter it
    keeps first/second moment estimates ("m", "v") and a running product of
    momentum coefficients ("nesterov"). It combines:
      * an RAdam-style variance-rectification term (`pt`, `rt`) built from beta2,
      * a Nadam-style momentum estimate (`mlr`) with the schedule
        mu_t = beta1 * (1 - 0.96 ** (t * momentum_decay) / 2),
      * a sigmoid blend (`sig_decay`) between a constant step and the adaptive step.
    The resulting update is divided by the L2 norm of the parameter along its last dim.
    """

    def __init__(self, params, lr = 5e-4, betas=(0.93, 0.999), momentum_decay=1.7e-2):
        """
        Args:
            params: parameters or parameter groups, as for torch.optim.Adam.
            lr: base learning rate.
            betas: (beta1, beta2) decay rates of the first and second moment estimates.
            momentum_decay: rate at which the momentum coefficient mu_t moves from
                beta1 / 2 (t = 0) toward beta1.
        """

        super().__init__(params, lr=lr, betas=betas)

        self.md = momentum_decay

        # rho_inf from RAdam: 2 / (1 - beta2) - 1.
        self.pin = 2/(1 - betas[1]) -1
        # Hand-picked constant playing the role of rho_inf for the momentum rectification.
        self.min = 2.85


    def step(self):
        """Apply one update to every parameter that has a gradient.

        NOTE(review): unlike torch.optim optimizers, this takes no `closure` and is not
        decorated with @torch.no_grad(); it writes through `p.data` instead.
        """

        for group in self.param_groups:

            for p in group['params']:

                if p.grad is None:

                    continue

                grad = p.grad.data

                if grad.is_sparse:

                    raise RuntimeError("Adam does not support sparse gradients")



                state = self.state[p]



                # State initialization

                if len(state) == 0:

                    state["step"] = 0


                    # first order momentum

                    state["m"] = torch.zeros_like(p.data)

                    # second order momentum

                    state["v"] = torch.zeros_like(p.data)

                    # nesterov momentum: running product of mu_1 * ... * mu_t
                    state["nesterov"] = 1


                #m, v = state["m"], state["v"]
                mRate, vRate = group["betas"]

                state["step"] += 1

                # beta1^t and beta2^t for bias correction.
                mM, vM = mRate ** state["step"], vRate ** state["step"]

                # u = mu_t and u1 = mu_{t+1} of the Nadam momentum schedule.
                u, u1 = mRate*(1 - (0.96 ** (state["step"] * self.md))/2), mRate*(1 - (0.96 ** ((1+state["step"]) * self.md))/2)
                # sigmoid(5 - t / 5000): ~0.993 at t = 0, 0.5 at t = 25000, then -> 0.
                sig_decay = torch.sigmoid(torch.tensor(5 - state["step"] / 5000))

                # Tensor filled with the group lr, shaped like the gradient.
                c = torch.ones_like(grad) * group["lr"]

                state["nesterov"] *= u

                # RAdam's rho_t = rho_inf - 2 t beta2^t / (1 - beta2^t).
                pt = self.pin - 2 * state["step"] * vM / (1 - vM);
                # Same form for the momentum, with self.min as rho_inf and the momentum product in place of beta2^t.
                mt = self.min - 2 * state["step"] * state["nesterov"]/ (1-state["nesterov"])

                # Decay the first and second moment running average coefficient

                state["m"].mul_(mRate).add_(grad, alpha = 1 - mRate)

                state["v"].mul_(vRate).addcmul_(grad, grad, value = 1 - vRate)


                # rt: RAdam rectification term. `lr` (shadowing the name) is the adaptive
                # denominator (v / (1 - beta2^t)) ** e + eps, with e = 0.5 * sigmoid(-t / 5000) + 0.25,
                # which goes from 0.5 at t = 0 toward 0.25. Both fall back to 1 while pt <= 5.
                rt, lr = (torch.sqrt(torch.tensor((pt - 4) * (pt - 2) * self.pin / ((self.pin - 4) * (self.pin - 2) * pt))), torch.pow(state["v"] / (1-vM), 0.5 * torch.sigmoid(torch.tensor(-state["step"]/ 5000)) + 0.25).add_(group["eps"])) if pt > 5.0 else (1, 1)
                # mrt: analogous rectification for the momentum, built from mt. mlr: Nadam-style
                # bias-corrected momentum. Falls back to (1, raw m) otherwise.
                # NOTE(review): this branch is gated on `pt`, not `mt` — confirm that is intended.
                mrt, mlr = (torch.sqrt(torch.tensor((mt - 2.3) * (mt - 1.03) * self.min/ ((self.min - 2.3) * (self.min - 1.03) * mt))), (u1 * state["m"] / (1.0 - mM) / (1.0 - state["nesterov"] * u1) + (1.0 - u) * grad / (1.0 - state["nesterov"]))) if pt > 2.3 else (1, state["m"])

                # A negative ratio under the sqrt yields NaN; fall back to no rectification.
                if math.isnan(rt):
                  rt = 1
                if math.isnan(mrt):
                  mrt = 1
                t = torch.tensor(state["step"])
                # Step size: lr * sqrt(0.5 * (beta2^(sqrt(t)/2) - beta1^(sqrt(t)/2)) + 0.5), times a blend of
                # the constant term (sig_decay * lr) and the rectified adaptive term mlr / lr * rt / mrt.
                # NOTE(review): the constant term does not depend on the gradient, so it moves p
                # even when grad is zero — confirm intended.
                step_size = group["lr"] * math.sqrt(0.5 * (torch.pow(vRate, 0.5 * torch.sqrt(t)) - pow(mRate, 0.5 * torch.sqrt(t))) + 0.5) *(sig_decay * c + (1.0 + (1.0 - sig_decay) * self.md * p.data) *mlr / lr * rt / mrt)
                # Scale the update by the parameter's L2 norm along its last dim (1e-6 avoids division by zero).
                norm = torch.norm(p.data, dim = -1, keepdim = True) +1e-6
                #print(norm.shape)
                #step_size = torch.div(step_size.transpose(1,0), norm).transpose(1,0)
                p.data.add_(-step_size/norm)

## 4. Train the network

`train_loop` trains a model with RAdam on the training loader and evaluates it on the
validation loader after every epoch. Once training finishes, it reports accuracy on
the test loader.

In [None]:
def _evaluate(net, loader, criterion):
  """Run `net` over every batch of `loader` in eval mode without tracking gradients.

  Batches are moved to the notebook-level `device`.

  Returns:
    (mean_loss, correct, total): the per-sample mean of `criterion` (assumed to
    return a batch mean, as the default nn.CrossEntropyLoss does), the number of
    correct top-1 predictions, and the number of samples seen. mean_loss is nan
    when `loader` yields no samples.
  """
  net.eval()
  loss_sum, correct, total = 0.0, 0, 0
  with torch.no_grad():
    for images, labels in loader:
      images, labels = images.to(device), labels.to(device)
      outputs = net(images)
      _, predicted = torch.max(outputs, 1)
      n = labels.size(0)
      # Weight each batch mean by batch size so a short final batch is not over-counted.
      loss_sum += criterion(outputs, labels).item() * n
      correct += (predicted == labels).sum().item()
      total += n
  mean_loss = loss_sum / total if total else float('nan')
  return mean_loss, correct, total


def train_loop(model, trainloader, valloader, testloader, chan = 3):
  """Train `model` for `num_epoch` epochs, validate after each epoch, then print test accuracy.

  Reads the notebook config (`img_shape`, `learning_rate`, `weight_decay`,
  `num_epoch`, `device`). The optimizer is RAdam with weight decay, and gradients
  are clipped to norm 0.4. Two LR schedulers are used:
  CosineAnnealingWarmRestarts stepped every batch, and ReduceLROnPlateau stepped
  on the validation loss once per epoch.

  Args:
    model: network to train in place; expected to already be on `device`.
    trainloader, valloader, testloader: iterables of (inputs, labels) batches.
    chan: number of input channels; only used for the torchsummary printout.
  """
  net = model
  # torchsummary defaults to device="cuda"; pass the real device so CPU-only runs work.
  summary(net, (chan, img_shape, img_shape), device = device.type)

  criterion = nn.CrossEntropyLoss()
  optimizer = optim.RAdam(net.parameters(), lr = learning_rate, weight_decay = weight_decay)
  # Cosine schedule with warm restarts, stepped per batch (T_0 = 512 batches, T_mult = 16).
  scheduler1 = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 512, 16, learning_rate/32)
  # Halves the LR after 5 epochs without validation-loss improvement.
  # NOTE(review): scheduler1 recomputes the LR from its own base LRs on every batch,
  # so reductions made here are likely overwritten — confirm intended.
  scheduler2 = lr_scheduler.ReduceLROnPlateau(optimizer, factor = 0.5, patience = 5)

  val_loss, val_acc = 10., 0.  # placeholders shown until the first validation pass
  with tqdm.trange(num_epoch) as pbar:
    for _ in pbar:
      net.train()
      loss_sum, n_batches = 0.0, 0
      avg_batch_loss = float('nan')
      lr = optimizer.param_groups[0]["lr"]
      for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        # Clip after backward(): before it, zero_grad() has just cleared the gradients
        # and there is nothing to clip.
        torch.nn.utils.clip_grad_norm_(net.parameters(), 0.4)
        optimizer.step()
        scheduler1.step()

        # Running mean of this epoch's batch losses.
        loss_sum += loss.item()
        n_batches += 1
        avg_batch_loss = loss_sum / n_batches
        lr = optimizer.param_groups[0]["lr"]
        pbar.set_postfix(val_loss = "%.3e" % val_loss, loss = "%.3e" % avg_batch_loss, learning_rate = "%.4e" % lr, val_acc = "%.2f" % val_acc)

      val_loss, correct, total = _evaluate(net, valloader, criterion)
      val_acc = 100 * correct / total if total else 0.0
      scheduler2.step(val_loss)
      pbar.set_postfix(val_loss = "%.3e" % val_loss, loss = "%.3e" % avg_batch_loss, learning_rate = "%.4e" % lr, val_acc = "%.2f" % val_acc)
  print('Finished Training')

  _, correct, total = _evaluate(net, testloader, criterion)
  accuracy = 100 * correct // total if total else 0
  print(f'Accuracy of the network on the {total} test images: {accuracy} %')

In [None]:
#MNIST

trainset = torchvision.datasets.MNIST(root='./data', train = True,
                                        download = True, transform = transform_mnist
                                         )
# Seeded generator: the same 54k / 6k train/val split on every run.
train_set, val_set = torch.utils.data.random_split(trainset, [54000, 6000],
                                                   generator = torch.Generator().manual_seed(0))

trainloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,
                                          shuffle = True, num_workers = num_thread, pin_memory = True)

valloader = torch.utils.data.DataLoader(val_set, batch_size = batch_size,
                                          shuffle = False, num_workers = num_thread, pin_memory = True)

testset = torchvision.datasets.MNIST(root='./data', train = False,
                                       download = True, transform = transform_mnist
                                         )
testloader = torch.utils.data.DataLoader(testset, batch_size = batch_size,
                                         shuffle = False, num_workers = num_thread, pin_memory = True)
model_list = list()

# Gate variant: view each token's 256 features as 4 x 8 x 8 and mix with a (1, 3, 3) Conv3d.
gate = nn.Sequential(
       nn.LayerNorm(embd_dim),
       Rearrange('b s (c h w) -> b s c h w', c = 4, h = patch_shape, w = patch_shape),
       spatial_proj3d(),
       Rearrange('b s c h w -> b s (c h w)'),
       )

# Channel-projection variant: factorized Linear over (4 groups) x (patch_shape**2 features).
channel_proj = nn.Sequential(
       Rearrange('b s (c p) -> b s c p', c = 4),
       nn.Linear(patch_shape**2, patch_shape**2),
       Rearrange('b s c p -> b s p c'),
       nn.Linear(4, 4),
       Rearrange('b s p c -> b s (c p)')
       )

#config for gMLP

model_list.append(gMLP_classifier(label = 10, channel = 1).to(device))


#config for reforming the gate by Conv3D

model_list.append(gMLP_classifier(label = 10, channel = 1, gate = gate).to(device))


#config for grouped connect channel proj

model_list.append(gMLP_classifier(label = 10, channel = 1, channel_proj = channel_proj ).to(device))


#config for reforming the gate by Conv3D + grouped connect channel proj

model_list.append(gMLP_classifier(label = 10, channel = 1, gate = gate, channel_proj = channel_proj ).to(device))

# Pop each model before training it. `del m` alone frees nothing while
# `model_list` still references the model; popping lets a finished model be reclaimed.
while model_list:
  m = model_list.pop(0)
  train_loop(m, trainloader, valloader, testloader, 1)
  del m
del trainloader
del valloader
del testloader

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         Rearrange-1               [-1, 64, 64]               0
            Linear-2              [-1, 64, 256]          16,640
           Patches-3              [-1, 64, 256]               0
         LayerNorm-4              [-1, 64, 256]             512
            Linear-5              [-1, 64, 512]         131,072
         LayerNorm-6              [-1, 64, 256]             512
            Conv1d-7              [-1, 64, 256]           4,160
    spatial_proj1d-8              [-1, 64, 256]               0
            Linear-9               [-1, 64, 96]          24,576
           Linear-10              [-1, 64, 256]           8,448
              att-11              [-1, 64, 256]               0
           Linear-12              [-1, 64, 256]          65,792
             gMLP-13              [-1, 64, 256]               0
        LayerNorm-14              [-1, 

100%|█████| 50/50 [24:25<00:00, 29.31s/it, learning_rate=9.7520e-05, loss=8.348e-02, val_acc=96.82, val_loss=1.050e-01]


Finished Training
Accuracy of the network on the 10000 test images: 96 %
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         Rearrange-1               [-1, 64, 64]               0
            Linear-2              [-1, 64, 256]          16,640
           Patches-3              [-1, 64, 256]               0
         LayerNorm-4              [-1, 64, 256]             512
            Linear-5              [-1, 64, 512]         131,072
         LayerNorm-6              [-1, 64, 256]             512
         Rearrange-7          [-1, 64, 4, 8, 8]               0
            Conv3d-8          [-1, 64, 4, 8, 8]          36,928
    spatial_proj3d-9          [-1, 64, 4, 8, 8]               0
        Rearrange-10              [-1, 64, 256]               0
           Linear-11               [-1, 64, 96]          24,576
           Linear-12              [-1, 64, 256]           8,448
              att-13          

100%|█████| 50/50 [24:45<00:00, 29.72s/it, learning_rate=9.7520e-05, loss=5.275e-02, val_acc=96.90, val_loss=1.089e-01]


Finished Training
Accuracy of the network on the 10000 test images: 96 %
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         Rearrange-1               [-1, 64, 64]               0
            Linear-2              [-1, 64, 256]          16,640
           Patches-3              [-1, 64, 256]               0
         LayerNorm-4              [-1, 64, 256]             512
            Linear-5              [-1, 64, 512]         131,072
         LayerNorm-6              [-1, 64, 256]             512
            Conv1d-7              [-1, 64, 256]           4,160
    spatial_proj1d-8              [-1, 64, 256]               0
            Linear-9               [-1, 64, 96]          24,576
           Linear-10              [-1, 64, 256]           8,448
              att-11              [-1, 64, 256]               0
        Rearrange-12            [-1, 64, 4, 64]               0
           Linear-13          

100%|█████| 50/50 [24:39<00:00, 29.60s/it, learning_rate=9.7520e-05, loss=1.802e-01, val_acc=94.37, val_loss=1.943e-01]


Finished Training
Accuracy of the network on the 10000 test images: 94 %
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         Rearrange-1               [-1, 64, 64]               0
            Linear-2              [-1, 64, 256]          16,640
           Patches-3              [-1, 64, 256]               0
         LayerNorm-4              [-1, 64, 256]             512
            Linear-5              [-1, 64, 512]         131,072
         LayerNorm-6              [-1, 64, 256]             512
         Rearrange-7          [-1, 64, 4, 8, 8]               0
            Conv3d-8          [-1, 64, 4, 8, 8]          36,928
    spatial_proj3d-9          [-1, 64, 4, 8, 8]               0
        Rearrange-10              [-1, 64, 256]               0
           Linear-11               [-1, 64, 96]          24,576
           Linear-12              [-1, 64, 256]           8,448
              att-13          

100%|█████| 50/50 [24:46<00:00, 29.74s/it, learning_rate=9.7520e-05, loss=9.133e-02, val_acc=96.60, val_loss=1.068e-01]


Finished Training
Accuracy of the network on the 10000 test images: 96 %


NameError: name 'train' is not defined

In [None]:
#CIFAR10

# Two views of the same 50k training images: augmented for training, plain for
# validation. The validation accuracy then measures clean images, not AutoAugmented ones.
trainset = torchvision.datasets.CIFAR10(root='./data', train = True,
                                        download = True, transform = transform_train
                                         )
valset_full = torchvision.datasets.CIFAR10(root='./data', train = True,
                                        download = True, transform = transform_test
                                         )
# Split indices once (seeded) and apply them to both views.
train_idx, val_idx = torch.utils.data.random_split(range(len(trainset)), [45000, 5000],
                                                   generator = torch.Generator().manual_seed(0))
train_set = torch.utils.data.Subset(trainset, train_idx.indices)
val_set = torch.utils.data.Subset(valset_full, val_idx.indices)

trainloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,
                                          shuffle = True, num_workers = num_thread, pin_memory = True)

valloader = torch.utils.data.DataLoader(val_set, batch_size = batch_size,
                                          shuffle = False, num_workers = num_thread, pin_memory = True)

testset = torchvision.datasets.CIFAR10(root='./data', train = False,
                                       download = True, transform = transform_test
                                         )
testloader = torch.utils.data.DataLoader(testset, batch_size = batch_size,
                                         shuffle = False, num_workers = num_thread, pin_memory = True)
model_list = list()

# Defined here so this cell does not depend on an earlier dataset cell having been run.
gate = nn.Sequential(
       nn.LayerNorm(embd_dim),
       Rearrange('b s (c h w) -> b s c h w', c = 4, h = patch_shape, w = patch_shape),
       spatial_proj3d(),
       Rearrange('b s c h w -> b s (c h w)'),
       )

channel_proj = nn.Sequential(
       Rearrange('b s (c p) -> b s c p', c = 4),
       nn.Linear(patch_shape**2, patch_shape**2),
       Rearrange('b s c p -> b s p c'),
       nn.Linear(4, 4),
       Rearrange('b s p c -> b s (c p)')
       )

#config for gMLP

model_list.append(gMLP_classifier(label = 10).to(device))


#config for reforming the gate by Conv3D

model_list.append(gMLP_classifier(label = 10, gate = gate).to(device))

#config for reforming grouped connect channel proj

model_list.append(gMLP_classifier(label = 10, channel_proj = channel_proj).to(device))


#config for reforming the gate by Conv3D + grouped connect channel proj

model_list.append(gMLP_classifier(label = 10, gate = gate, channel_proj = channel_proj).to(device))


# Pop each model before training it so a finished model is no longer referenced.
while model_list:
  m = model_list.pop(0)
  train_loop(m, trainloader, valloader, testloader)
  del m
del trainloader
del valloader
del testloader

Files already downloaded and verified
Files already downloaded and verified
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         Rearrange-1              [-1, 64, 192]               0
            Linear-2              [-1, 64, 256]          49,408
           Patches-3              [-1, 64, 256]               0
         LayerNorm-4              [-1, 64, 256]             512
            Linear-5              [-1, 64, 512]         131,072
         LayerNorm-6              [-1, 64, 256]             512
            Conv1d-7              [-1, 64, 256]           4,160
    spatial_proj1d-8              [-1, 64, 256]               0
            Linear-9               [-1, 64, 96]          24,576
           Linear-10              [-1, 64, 256]           8,448
              att-11              [-1, 64, 256]               0
           Linear-12              [-1, 64, 256]          65,792
             gMLP-13       

100%|█████| 50/50 [43:48<00:00, 52.57s/it, learning_rate=9.8774e-05, loss=1.737e+00, val_acc=36.22, val_loss=1.748e+00]


Finished Training
Accuracy of the network on the 10000 test images: 47 %
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         Rearrange-1              [-1, 64, 192]               0
            Linear-2              [-1, 64, 256]          49,408
           Patches-3              [-1, 64, 256]               0
         LayerNorm-4              [-1, 64, 256]             512
            Linear-5              [-1, 64, 512]         131,072
         LayerNorm-6              [-1, 64, 256]             512
         Rearrange-7          [-1, 64, 4, 8, 8]               0
            Conv3d-8          [-1, 64, 4, 8, 8]          36,928
    spatial_proj3d-9          [-1, 64, 4, 8, 8]               0
        Rearrange-10              [-1, 64, 256]               0
           Linear-11               [-1, 64, 96]          24,576
           Linear-12              [-1, 64, 256]           8,448
              att-13          

100%|█████| 50/50 [43:21<00:00, 52.03s/it, learning_rate=9.8774e-05, loss=1.532e+00, val_acc=44.84, val_loss=1.542e+00]


Finished Training
Accuracy of the network on the 10000 test images: 54 %
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         Rearrange-1              [-1, 64, 192]               0
            Linear-2              [-1, 64, 256]          49,408
           Patches-3              [-1, 64, 256]               0
         LayerNorm-4              [-1, 64, 256]             512
            Linear-5              [-1, 64, 512]         131,072
         LayerNorm-6              [-1, 64, 256]             512
            Conv1d-7              [-1, 64, 256]           4,160
    spatial_proj1d-8              [-1, 64, 256]               0
            Linear-9               [-1, 64, 96]          24,576
           Linear-10              [-1, 64, 256]           8,448
              att-11              [-1, 64, 256]               0
        Rearrange-12            [-1, 64, 4, 64]               0
           Linear-13          

100%|█████| 50/50 [44:51<00:00, 53.84s/it, learning_rate=9.8774e-05, loss=1.821e+00, val_acc=35.06, val_loss=1.817e+00]


Finished Training
Accuracy of the network on the 10000 test images: 42 %
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         Rearrange-1              [-1, 64, 192]               0
            Linear-2              [-1, 64, 256]          49,408
           Patches-3              [-1, 64, 256]               0
         LayerNorm-4              [-1, 64, 256]             512
            Linear-5              [-1, 64, 512]         131,072
         LayerNorm-6              [-1, 64, 256]             512
         Rearrange-7          [-1, 64, 4, 8, 8]               0
            Conv3d-8          [-1, 64, 4, 8, 8]          36,928
    spatial_proj3d-9          [-1, 64, 4, 8, 8]               0
        Rearrange-10              [-1, 64, 256]               0
           Linear-11               [-1, 64, 96]          24,576
           Linear-12              [-1, 64, 256]           8,448
              att-13          

100%|█████| 50/50 [45:32<00:00, 54.64s/it, learning_rate=9.8774e-05, loss=1.664e+00, val_acc=38.28, val_loss=1.666e+00]


Finished Training
Accuracy of the network on the 10000 test images: 50 %


In [None]:
#CIFAR100

# Two views of the same 50k training images: augmented for training, plain for
# validation. The validation accuracy then measures clean images, not AutoAugmented ones.
trainset = torchvision.datasets.CIFAR100(root='./data', train = True,
                                        download = True, transform = transform_train
                                         )
valset_full = torchvision.datasets.CIFAR100(root='./data', train = True,
                                        download = True, transform = transform_test
                                         )
# Split indices once (seeded) and apply them to both views.
train_idx, val_idx = torch.utils.data.random_split(range(len(trainset)), [45000, 5000],
                                                   generator = torch.Generator().manual_seed(0))
train_set = torch.utils.data.Subset(trainset, train_idx.indices)
val_set = torch.utils.data.Subset(valset_full, val_idx.indices)

trainloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,
                                          shuffle = True, num_workers = num_thread, pin_memory = True)

valloader = torch.utils.data.DataLoader(val_set, batch_size = batch_size,
                                          shuffle = False, num_workers = num_thread, pin_memory = True)

testset = torchvision.datasets.CIFAR100(root='./data', train = False,
                                       download = True, transform = transform_test
                                         )
testloader = torch.utils.data.DataLoader(testset, batch_size = batch_size, shuffle = False,
                                         num_workers = num_thread, pin_memory = True)

model_list = list()

gate = nn.Sequential(
       nn.LayerNorm(embd_dim),
       Rearrange('b s (c h w) -> b s c h w', c = 4, h = patch_shape, w = patch_shape),
       spatial_proj3d(),
       Rearrange('b s c h w -> b s (c h w)'),
       )

channel_proj = nn.Sequential(
       Rearrange('b s (c p) -> b s c p', c = 4),
       nn.Linear(patch_shape**2, patch_shape**2),
       Rearrange('b s c p -> b s p c'),
       nn.Linear(4, 4),
       Rearrange('b s p c -> b s (c p)')
       )

#config for gMLP

model_list.append(gMLP_classifier().to(device))


#config for reforming the gate by Conv3D

model_list.append(gMLP_classifier(gate = gate).to(device))


#config for grouped connect channel proj

model_list.append(gMLP_classifier(channel_proj = channel_proj).to(device))


#config for reforming the gate by Conv3D + grouped connect channel proj

model_list.append(gMLP_classifier(gate = gate, channel_proj = channel_proj).to(device))

# Pop each model before training it so a finished model is no longer referenced.
while model_list:
  m = model_list.pop(0)
  train_loop(m, trainloader, valloader, testloader)
  del m
del trainloader
del valloader
del testloader

Files already downloaded and verified
Files already downloaded and verified
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         Rearrange-1              [-1, 64, 192]               0
            Linear-2              [-1, 64, 256]          49,408
           Patches-3              [-1, 64, 256]               0
         LayerNorm-4              [-1, 64, 256]             512
            Linear-5              [-1, 64, 512]         131,072
         LayerNorm-6              [-1, 64, 256]             512
            Conv1d-7              [-1, 64, 256]           4,160
    spatial_proj1d-8              [-1, 64, 256]               0
            Linear-9               [-1, 64, 96]          24,576
           Linear-10              [-1, 64, 256]           8,448
              att-11              [-1, 64, 256]               0
           Linear-12              [-1, 64, 256]          65,792
             gMLP-13       

100%|██████| 50/50 [44:35<00:00, 53.51s/it, learning_rate=9.8774e-05, loss=3.972e+00, val_acc=9.80, val_loss=3.956e+00]


Finished Training
Accuracy of the network on the 10000 test images: 14 %
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         Rearrange-1              [-1, 64, 192]               0
            Linear-2              [-1, 64, 256]          49,408
           Patches-3              [-1, 64, 256]               0
         LayerNorm-4              [-1, 64, 256]             512
            Linear-5              [-1, 64, 512]         131,072
         LayerNorm-6              [-1, 64, 256]             512
         Rearrange-7          [-1, 64, 4, 8, 8]               0
            Conv3d-8          [-1, 64, 4, 8, 8]          36,928
    spatial_proj3d-9          [-1, 64, 4, 8, 8]               0
        Rearrange-10              [-1, 64, 256]               0
           Linear-11               [-1, 64, 96]          24,576
           Linear-12              [-1, 64, 256]           8,448
              att-13          

100%|█████| 50/50 [43:12<00:00, 51.85s/it, learning_rate=9.8774e-05, loss=3.713e+00, val_acc=14.42, val_loss=3.727e+00]


Finished Training
Accuracy of the network on the 10000 test images: 22 %
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         Rearrange-1              [-1, 64, 192]               0
            Linear-2              [-1, 64, 256]          49,408
           Patches-3              [-1, 64, 256]               0
         LayerNorm-4              [-1, 64, 256]             512
            Linear-5              [-1, 64, 512]         131,072
         LayerNorm-6              [-1, 64, 256]             512
            Conv1d-7              [-1, 64, 256]           4,160
    spatial_proj1d-8              [-1, 64, 256]               0
            Linear-9               [-1, 64, 96]          24,576
           Linear-10              [-1, 64, 256]           8,448
              att-11              [-1, 64, 256]               0
        Rearrange-12            [-1, 64, 4, 64]               0
           Linear-13          

100%|██████| 50/50 [43:07<00:00, 51.75s/it, learning_rate=9.8774e-05, loss=3.993e+00, val_acc=8.58, val_loss=4.004e+00]


Finished Training
Accuracy of the network on the 10000 test images: 14 %
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         Rearrange-1              [-1, 64, 192]               0
            Linear-2              [-1, 64, 256]          49,408
           Patches-3              [-1, 64, 256]               0
         LayerNorm-4              [-1, 64, 256]             512
            Linear-5              [-1, 64, 512]         131,072
         LayerNorm-6              [-1, 64, 256]             512
         Rearrange-7          [-1, 64, 4, 8, 8]               0
            Conv3d-8          [-1, 64, 4, 8, 8]          36,928
    spatial_proj3d-9          [-1, 64, 4, 8, 8]               0
        Rearrange-10              [-1, 64, 256]               0
           Linear-11               [-1, 64, 96]          24,576
           Linear-12              [-1, 64, 256]           8,448
              att-13          

100%|█████| 50/50 [43:37<00:00, 52.35s/it, learning_rate=9.8774e-05, loss=3.904e+00, val_acc=11.06, val_loss=3.893e+00]


Finished Training
Accuracy of the network on the 10000 test images: 16 %
