In [1]:
import torch
import os
import shutil
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pickle
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import time

# Seed every RNG the run draws from (weight init, DataLoader shuffling, random
# crops/flips) so results are reproducible. GPU kernels (e.g. cuDNN) may still
# introduce some nondeterminism.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Train on the first GPU when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [2]:
from google.colab import drive

# Mount Google Drive so the repository clone and saved results live on
# persistent storage rather than the ephemeral Colab VM disk.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [3]:
!pip install timm

Collecting timm
  Downloading timm-0.9.12-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m10.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: timm
Successfully installed timm-0.9.12


In [4]:
import os
import subprocess

# The clone lives on Google Drive, which persists across Colab sessions, so a
# plain `git clone` fails on re-runs ("destination path already exists").
# Only clone when the checkout is missing. An existing checkout is NOT updated;
# delete it (or `git pull` inside it) to pick up new commits on the branch.
AFNO_REPO_DIR = '/content/drive/MyDrive/Github/AFNO-transformer'
if not os.path.isdir(AFNO_REPO_DIR):
    subprocess.run(
        ['git', 'clone', '-b', 'AHNO-R2-Transformer',
         'https://github.com/jaysulk/AFNO-transformer.git', AFNO_REPO_DIR],
        check=True,  # surface clone failures instead of continuing with a missing repo
    )

Cloning into '/content/drive/MyDrive/Github/AFNO-transformer'...
remote: Enumerating objects: 422, done.
remote: Counting objects: 100% (299/299), done.
remote: Compressing objects: 100% (167/167), done.
remote: Total 422 (delta 219), reused 180 (delta 132), pack-reused 123
Receiving objects: 100% (422/422), 75.66 MiB | 14.42 MiB/s, done.
Resolving deltas: 100% (265/265), done.
Updating files: 100% (48/48), done.


In [5]:
!pip install /content/drive/MyDrive/Github/AFNO-transformer/

Processing ./drive/MyDrive/Github/AFNO-transformer
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: afno
  Building wheel for afno (setup.py) ... done
  Created wheel for afno: filename=afno-0.0.1-py3-none-any.whl size=16426 sha256=bd51cdfb3ebae2212e9612a2cb31c469b7c6b99b45125ed6a90db8f6a987b44f
  Stored in directory: /root/.cache/pip/wheels/77/48/0f/87c770d240c74cb454027fb3d5af919cb6ed10857b9caf1bda
Successfully built afno
Installing collected packages: afno
Successfully installed afno-0.0.1


In [6]:
from afno import AFNO1D,AFNO2D

In [7]:
# Directory for training checkpoints; `train(..., checkpoint_path=checkpoints)`
# writes one file per epoch here. The trailing slash is kept for callers that
# concatenate file names onto it.
checkpoints = '/content/drive/MyDrive/colab_files/imagenet64/'
# exist_ok avoids the check-then-create race and makes the cell idempotent.
os.makedirs(checkpoints, exist_ok=True)

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Assuming you've already imported or defined AFNO2D

import torch.nn as nn
import torch.nn.functional as F

import torch.nn as nn

class ModelWithAFNO(nn.Module):
    """Convolutional feature extractor, AFNO2D token mixer and a linear classifier head.

    Args:
        num_classes: number of output logits. Defaults to 1000, the original
            hard-coded head size.
    """

    def __init__(self, num_classes=1000):
        super(ModelWithAFNO, self).__init__()

        # Maps (B, 3, H_in, W_in) to (B, 512, H, W). Each 2x2 max-pool halves the
        # spatial size, so a 32x32 image gives an 8x8 feature grid.
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            # Additional layers for increased depth
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        )

        # hidden_size must equal the channel count produced by `features` (512).
        # NOTE(review): AFNO2D presumably infers a square H == W grid from the
        # token count, so inputs should give square feature maps. TODO confirm.
        self.afno = AFNO2D(hidden_size=512, num_blocks=8, sparsity_threshold=0.01,
                           hard_thresholding_fraction=1, hidden_size_factor=1)

        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Dropout(0.7),  # Increased dropout
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Map a batch of images (B, 3, H_in, W_in) to logits (B, num_classes)."""
        x = self.features(x)  # (B, C, H, W), channels first
        B, C, H, W = x.shape
        # AFNO2D consumes a token sequence (B, N, C) with channels LAST. A plain
        # reshape of a channels-first map would interleave channel and spatial
        # values, so move channels to the end before flattening the grid.
        x = x.permute(0, 2, 3, 1).reshape(B, H * W, C)
        x = self.afno(x)  # assumes AFNO2D returns the same (B, N, C) shape
        # Inverse mapping: tokens -> (B, H, W, C) -> channels-first (B, C, H, W).
        x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
        return self.classifier(x)


# Model instantiation. This notebook trains and evaluates on CIFAR-10, which has
# 10 classes, so size the head accordingly instead of using the 1000-way default.
model = ModelWithAFNO(num_classes=10)

In [9]:
def train(net, dataloader, epochs=1, start_epoch=0, lr=0.01, momentum=0.9, decay=0.0005,
          verbose=1, print_every=10, state=None, schedule=None, checkpoint_path=None):
    """Train `net` with SGD and cross-entropy loss, optionally resuming and checkpointing.

    Args:
        net: model to train. It is moved to the global `device` and put in train mode.
        dataloader: iterable of batches where batch[0] is the inputs and batch[1] the labels.
        epochs: epoch index to stop before, i.e. the total epoch count (not additional epochs).
        start_epoch: first epoch to run. Overridden by state['epoch'] when resuming.
        lr, momentum, decay: SGD learning rate, momentum and weight decay.
        verbose: when truthy, print the mean loss of each `print_every` mini-batches.
        print_every: logging interval in mini-batches (the running sum resets each epoch).
        state: dict previously saved by this function ('epoch', 'net', 'optimizer',
            'losses') to resume from. It is not modified.
        schedule: optional {epoch: lr} map. The lr is applied at the start of each listed epoch.
        checkpoint_path: directory in which to save 'checkpoint-<epoch>.pkl' after every
            epoch. None disables checkpointing.

    Returns:
        List of per-mini-batch loss values, including any restored from `state`.
    """
    if schedule is None:
        schedule = {}  # avoid a shared mutable default argument
    net.to(device)
    net.train()
    losses = []
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum, weight_decay=decay)

    # Load previous training state
    if state:
        net.load_state_dict(state['net'])
        optimizer.load_state_dict(state['optimizer'])
        start_epoch = state['epoch']
        losses = list(state['losses'])  # copy so the caller's state dict is not mutated

    # Fast-forward the lr schedule through already trained epochs
    for epoch in range(start_epoch):
        if epoch in schedule:
            print("Learning rate: %f" % schedule[epoch])
            for g in optimizer.param_groups:
                g['lr'] = schedule[epoch]

    for epoch in range(start_epoch, epochs):
        sum_loss = 0.0

        # Update the learning rate when scheduled
        if epoch in schedule:
            print("Learning rate: %f" % schedule[epoch])
            for g in optimizer.param_groups:
                g['lr'] = schedule[epoch]

        for i, batch in enumerate(dataloader):
            inputs, labels = batch[0].to(device), batch[1].to(device)

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()   # compute gradients
            optimizer.step()  # take a step in the gradient direction

            # A single .item() per step: each call synchronizes with the device.
            loss_value = loss.item()
            losses.append(loss_value)
            sum_loss += loss_value

            if i % print_every == print_every - 1:  # every `print_every` mini-batches
                if verbose:
                    print('[%d, %5d] loss: %.3f' % (epoch, i + 1, sum_loss / print_every))
                sum_loss = 0.0

        if checkpoint_path:
            state = {'epoch': epoch + 1, 'net': net.state_dict(),
                     'optimizer': optimizer.state_dict(), 'losses': losses}
            # os.path.join works whether or not checkpoint_path ends with a separator.
            torch.save(state, os.path.join(checkpoint_path, 'checkpoint-%d.pkl' % (epoch + 1)))
    return losses

def accuracy(net, dataloader):
    """Fraction of samples whose highest-scoring output index equals the label.

    Moves `net` to the global `device` and leaves it in eval mode. Each batch
    must expose the inputs at batch[0] and the integer labels at batch[1].
    """
    net.to(device)
    net.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch in dataloader:
            images = batch[0].to(device)
            labels = batch[1].to(device)
            predicted = net(images).max(dim=1).indices
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()
    return n_correct / n_seen

def smooth(x, size):
    """Moving average of `x` over a window of `size` samples, at 'valid' positions only.

    Returns an array of length len(x) - size + 1. If `x` is shorter than the
    window the result is empty. Without that guard, np.convolve would swap its
    arguments and return a meaningless array.

    Raises:
        ValueError: if size < 1.
    """
    if size < 1:
        raise ValueError("smoothing window size must be >= 1, got %r" % (size,))
    values = np.asarray(x, dtype=float)
    if values.size < size:
        return np.empty(0)
    return np.convolve(values, np.ones(size) / size, mode='valid')

In [11]:
def get_cifar10_data(augmentation=0, root='./data', batch_size=128, num_workers=2):
    """Build CIFAR-10 train/test DataLoaders, downloading the dataset into `root` if needed.

    Args:
        augmentation: when truthy, the training split uses random 32x32 crops of
            edge-padded (4 px per side) images plus random horizontal flips.
            The test split is never augmented.
        root: dataset directory passed to torchvision.datasets.CIFAR10.
        batch_size: mini-batch size for both loaders.
        num_workers: number of DataLoader worker processes for both loaders.

    Returns:
        Dict with 'train' (shuffled DataLoader), 'test' (unshuffled DataLoader)
        and 'classes' (10 short class names, presumably in label-index order).
    """
    # Data augmentation transformations. Not for testing!
    if augmentation:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4, padding_mode='edge'),  # 32x32 crops from 40x40 padded images
            transforms.RandomHorizontalFlip(),  # flip along the y-axis 50% of the time
            transforms.ToTensor(),
        ])
    else:
        transform_train = transforms.ToTensor()

    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])

    trainset = torchvision.datasets.CIFAR10(root=root, train=True, download=True,
                                            transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True,
                                              num_workers=num_workers)

    testset = torchvision.datasets.CIFAR10(root=root, train=False, download=True,
                                           transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False,
                                             num_workers=num_workers)
    classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    return {'train': trainloader, 'test': testloader, 'classes': classes}

# Train loader uses random crops + flips; the test loader is left unaugmented.
cifar_data = get_cifar10_data(augmentation=True)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Train for 20 epochs with a step-decay lr schedule, then evaluate on the test split.
# CIFAR-10's 50k training images at batch_size=128 give 391 mini-batches per epoch,
# so the previous print_every=1000 never logged anything. 100 logs 3 times per epoch.
start_time = time.perf_counter()
cifar_losses = train(model, cifar_data['train'], epochs=20,
                     schedule={0: .01, 5: .001, 15: .0001}, print_every=100)
# Stop the clock before evaluating so `training_time` measures training only.
training_time = time.perf_counter() - start_time

test_accuracy = accuracy(model, cifar_data['test'])
print("Testing accuracy: %f" % test_accuracy)

Learning rate: 0.010000


In [None]:
# Elapsed wall-clock seconds recorded by the training cell; shown as this cell's output.
training_time

In [None]:
# Per-mini-batch training loss, smoothed with a 50-batch moving average.
fig, ax = plt.subplots()
ax.plot(smooth(cifar_losses, 50))
ax.set(title='CIFAR-10 training loss (50-batch moving average)',
       xlabel='mini-batch', ylabel='cross-entropy loss')
plt.show()

In [None]:
# NOTE(review): despite its name, this file holds the per-mini-batch training
# losses returned by `train`, not accuracies. The filename is left unchanged.
# Uses the same '/content/drive/MyDrive' mount path as the rest of the notebook
# (the original 'My Drive' spelling was inconsistent with it).
with open('/content/drive/MyDrive/AHNO_accuracy.pkl', 'wb') as file:
    pickle.dump(cifar_losses, file)