In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
import os 

# os.chdir('/content/drive/My Drive/data_augmentation_techniques')

### Load libraries and packages

In [None]:
import os
import PIL 
import time 
import torch 
import torchvision 
import random

import numpy as np 
import torch.nn as nn 
import matplotlib.pyplot as plt 
import torch.nn.functional as F 

#data transforms
from torch.autograd import Variable 
from torchvision import datasets, transforms 
from PIL import Image, ImageEnhance

#data aug
from augmentation.autoaugment   import CIFAR10Policy 
from augmentation.AugMix.AugMix import AugMixDataset 
from augmentation.cutout        import Cutout 
from augmentation.RandAugment   import RandAugment

#optim and activation
from optim.deepmemory import DeepMemory
from optim.lookahead  import Lookahead 
from optim.radam      import RAdam 
from adamod           import AdaMod
from activations      import Mish

from loss_func.cross_entropy import CrossEntropyLoss #https://github.com/eladhoffer/utils.pytorch
from metrics import AverageMeter, accuracy

### Set Seed for Reproducibility

In [None]:
def set_seed(seed):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU plus every visible CUDA device), and switches cuDNN to deterministic
    kernel selection so that GPU runs can be repeated.

    Args:
        seed (int): Value used to seed all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Explicitly seed all GPUs; this call is silently ignored without CUDA.
    torch.cuda.manual_seed_all(seed)
    # cuDNN autotuning can pick different (and non-deterministic) kernels from
    # run to run; give up a little speed in exchange for repeatable results.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # NOTE: hash randomisation of the running interpreter is fixed at start-up,
    # so this only takes effect in processes that inherit the environment.
    os.environ['PYTHONHASHSEED'] = str(seed)


set_seed(72)

### Preprocess and Load Data

In [None]:
# Shared tail of both pipelines: tensor conversion followed by per-channel
# normalisation.
preprocess = transforms.Compose(
    [
        transforms.ToTensor(),
        # CIFAR10 statistics. For CIFAR100 swap in:
        #   mean=(0.507, 0.487, 0.441), std=(0.267, 0.256, 0.276)
        transforms.Normalize(
            mean=(0.4914, 0.4822, 0.4465),
            std=(0.2023, 0.1994, 0.2010),
        ),
        # Cutout(n_holes=1, length=16),  # CutOut
    ]
)

# Training pipeline: random flip and padded random crop, then `preprocess`.
train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(size=32, padding=4, fill=128),
        # CIFAR10Policy(),  # AutoAugment
        preprocess,
    ]
)

# RandAugment (optional): prepend it to the training pipeline.
# train_transform.transforms.insert(0, RandAugment(1, 5))

# The test images only go through the shared preprocessing.
test_transform = preprocess

# test_transform = transforms.Compose([
#     transforms.ToTensor(),
#      transforms.Normalize((0.4914, 0.4822, 0.4465), (.2023, .1994, .2010)),  #CIFAR10
#      transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276))      #CIFAR100
# ])

In [None]:
batch_size = 4

# Both CIFAR10 splits are stored under ./data (downloaded when requested).
train_data = datasets.CIFAR10(
    root="./data",
    train=True,
    transform=train_transform,
    download=True,
)
test_data = datasets.CIFAR10(
    root="./data",
    train=False,
    transform=test_transform,
    download=True,
)

# Training batches are reshuffled every epoch.
# For AugMix, pass AugMixDataset(train_data, preprocess, no_jsd=True) as the
# dataset instead of train_data.
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)

# Test batches keep the dataset order.
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=batch_size,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
)

# Human-readable class names, used by the per-class accuracy cell.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

In [None]:
# Number of batches per epoch for each split.
print(
    f'Length of train loader is {len(train_loader)}',
    f'Length of test loader is {len(test_loader)}',
    sep='\n',
)

### Build Model

In [None]:
class CustomModel(nn.Module):
    """Grouped-convolution CNN with a squeeze-and-excitation gate and a linear head.

    The feature extractor stacks nine ``Conv2d -> Mish -> BatchNorm2d`` stages
    with three 2x2 max-pools, so a 32x32 input leaves it as a 192x4x4 map,
    which is the size ``self.fc`` is built for. The layer order inside
    ``self.layers`` (and therefore every ``state_dict`` key) is the same as in
    the hand-written version of this network.
    """

    @staticmethod
    def _conv_block(in_channels, out_channels, kernel_size, groups):
        """Return the ``[Conv2d, Mish, BatchNorm2d]`` modules of one stage."""
        return [
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=kernel_size, groups=groups, stride=1,
                      padding=kernel_size // 2,  # keeps the spatial size for k=1 and k=3
                      bias=False),
            Mish(),
            # PyTorch's momentum is the weight of the *new* batch statistic
            # (running = (1 - m) * running + m * batch). 0.99 is the Keras
            # convention and would make the running stats track only the last
            # mini-batch; the PyTorch equivalent is 0.01.
            nn.BatchNorm2d(num_features=out_channels, eps=1e-3, momentum=0.01),
        ]

    def __init__(self):
        super().__init__()

        block = self._conv_block

        # Grouped convolutions (groups < channels, so not strictly depthwise)
        self.layers = nn.Sequential(
            *block(1 * 3, 16 * 3, kernel_size=3, groups=1),
            *block(16 * 3, 96, kernel_size=1, groups=8),
            *block(96, 128, kernel_size=3, groups=8),
            nn.MaxPool2d(2, 2),

            *block(128, 192, kernel_size=1, groups=16),
            *block(192, 256, kernel_size=3, groups=16),
            *block(256, 512, kernel_size=1, groups=32),
            nn.MaxPool2d(2, 2),

            *block(512, 512, kernel_size=3, groups=64),
            *block(512, 256, kernel_size=1, groups=16),
            *block(256, 192, kernel_size=3, groups=16),
            nn.MaxPool2d(2, 2),
        )

        # squeeze and excitation
        self.se_reduce = nn.Conv2d(in_channels=192, out_channels=128, kernel_size=1)
        self.se_act = Mish()  # created once instead of on every forward pass
        self.se_expand = nn.Conv2d(in_channels=128, out_channels=192, kernel_size=1)

        # fully connected layer (expects the 192x4x4 map produced from 32x32 inputs)
        self.fc = nn.Linear(in_features=192 * 4 * 4, out_features=10)

    def forward(self, x):
        """Return class logits of shape ``(batch, 10)`` for a batch of images ``x``."""
        x = self.layers(x)

        # Squeeze: pool each channel down to a single value. Pooling to the
        # map's own size (as before) is an identity and squeezes nothing.
        gate = F.adaptive_avg_pool2d(x, 1)
        # Excite: bottleneck -> per-channel gate, broadcast over H x W.
        gate = self.se_expand(self.se_act(self.se_reduce(gate)))
        x = torch.sigmoid(gate) * x

        x = torch.flatten(x, 1)
        return self.fc(x)


model = CustomModel()

In [None]:
# Use the GPU when one is visible, otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Move the weights to the chosen device, then wrap the network in DataParallel
# (the underlying network stays reachable as `model.module`).
model = torch.nn.DataParallel(model.to(device))

In [None]:
# Total number of scalar entries across all parameter tensors of the model.
print('Number of model parameters: {}'.format(sum(p.numel() for p in model.parameters())))

### Fresh Training

In [None]:
# Fresh run: start at the first epoch with no best accuracy recorded yet.
start_epoch = 0
best_top1 = 0  # train from start

In [None]:
num_epochs = 20

# Cross-entropy with label smoothing, placed on the same device as the model.
criterion = CrossEntropyLoss(smooth_eps=0.1).to(device)
params = list(model.parameters())

# Exactly one optimizer is active. The alternatives are kept commented out so
# that no throw-away optimizer is built and immediately overwritten.
# optimizer = AdaMod(params, lr=0.1, betas=(0.999, 0.9999), weight_decay=1e-5)
# optimizer = DeepMemory(params, betas=(0.999, 0.9999), len_memory=len(train_data.data)//batch_size, weight_decay=1e-4)
optimizer = Lookahead(DeepMemory(params, len_memory=len(train_data.data)//batch_size))
# optimizer = Lookahead(AdaMod(params, betas=(0.999, 0.9999), weight_decay=1e-5))
# optimizer = Lookahead(RAdam(params, lr=0.0015, weight_decay=0.0))          # rectified adam with lookahead

# train() drives the scheduler with lr_scheduler.step(epoch), i.e. the schedule
# is indexed by epoch. T_max therefore has to be counted in epochs as well;
# with the number of iterations per epoch here, the cosine would barely move
# over the whole run and the learning rate would stay effectively constant.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

### Load Checkpoints to resume training

In [None]:
# checkpoint = torch.load('./checkpoint/CustomModel_standard_dmla_ckpt.pth')

# model.module.load_state_dict(checkpoint['model_state_dict'], strict=False)
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# start_epoch = checkpoint['epoch']
# best_top1 = checkpoint['top1']   # resume training
# best_top5 = checkpoint['top5']  

In [None]:
# print(f'Loaded checkpoint with \n {best_top1}% Top-1 Accuracy, {best_top5}% Top-5 Accuracy, after training for {start_epoch} epochs.')

### Train Model

In [None]:
def train(train_loader, model, criterion, optimizer, epoch, print_freq=1500):
    """Train ``model`` for one epoch and print running loss / accuracy.

    Relies on the notebook-level globals ``device`` (where batches are moved)
    and ``lr_scheduler`` (stepped after every optimizer update with the
    current epoch index).

    Args:
        train_loader: Iterable yielding ``(images, target)`` batches.
        model: Network to optimise; switched to train mode here.
        criterion: Callable ``criterion(output, target)`` returning the loss.
        optimizer: Optimizer whose ``step`` accepts a closure.
        epoch (int): Index of the current epoch (used for logging and for the
            scheduler).
        print_freq (int): Print a progress line every ``print_freq`` batches;
            a falsy value disables the per-batch lines.

    Returns:
        tuple: ``(top1, top5, losses)`` meters for the epoch, in the same
        order as returned by ``validate``.
    """
    print('Training model...\n')
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()

    for i, (images, target) in enumerate(train_loader):

        # measure data loading time
        data_time.update(time.time() - end)

        # Move the whole batch explicitly, as validate() does, rather than
        # relying on the DataParallel wrapper to place the images.
        images = images.to(device)
        target = target.to(device)
        batch_size = images.size(0)

        def closure():
            """Re-evaluate the model on the current batch and return the loss."""
            return criterion(model(images), target)

        optimizer.zero_grad()

        # One forward pass provides the predictions, the logged loss and the
        # gradients (previously the batch was pushed through the model again
        # for each of these).
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.detach(), target, topk=(1, 5))
        losses.update(loss.item(), batch_size)
        top1.update(prec1.item(), batch_size)
        top5.update(prec5.item(), batch_size)

        loss.backward()
        # The closure is still handed over unchanged; whether the wrapped
        # optimizers actually invoke it is not visible from this notebook.
        optimizer.step(closure)

        # Scheduler is indexed by epoch, so the rate is constant within an epoch.
        lr_scheduler.step(epoch)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if print_freq and i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                epoch, i, len(train_loader), batch_time=batch_time,
                data_time=data_time, loss=losses, top1=top1, top5=top5))

    print(' * Acc@1 {top1.avg:.3f} Acc@1 Error {top1_err:.3f}\n'
              ' * Acc@5 {top5.avg:.3f} Acc@5 Error {top5_err:.3f}'
              .format(top1=top1, top1_err=100-top1.avg, top5=top5, top5_err=100-top5.avg))

    return top1, top5, losses

### Accuracy on test data

In [None]:
def validate(test_loader, model, criterion, epoch):
    """Evaluate ``model`` on ``test_loader`` without tracking gradients.

    Batches are moved to the notebook-level ``device``. A progress line is
    printed every 250 batches and a summary at the end.

    Args:
        test_loader: Iterable yielding ``(images, target)`` batches.
        model: Network to evaluate; switched to eval mode here.
        criterion: Callable ``criterion(output, target)`` returning the loss.
        epoch: Accepted for symmetry with ``train``; not used.

    Returns:
        tuple: ``(top1, top5, losses)`` meters accumulated over the loader.
    """
    # switch to evaluate mode
    model.eval()

    print('Evaluating model on test data...\n')

    timer = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    tick = time.time()
    with torch.no_grad():
        for step, (images, labels) in enumerate(test_loader):
            images, labels = images.to(device), labels.to(device)
            count = images.size(0)

            # forward pass and loss for this batch
            logits = model(images)
            batch_loss = criterion(logits, labels)

            # update the running loss and top-1 / top-5 accuracy
            acc1, acc5 = accuracy(logits, labels, topk=(1, 5))
            loss_meter.update(batch_loss.item(), count)
            acc1_meter.update(acc1.item(), count)
            acc5_meter.update(acc5.item(), count)

            # time spent on this batch
            now = time.time()
            timer.update(now - tick)
            tick = time.time()

            if step % 250 != 0:
                continue

            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   step, len(test_loader), batch_time=timer, loss=loss_meter,
                   top1=acc1_meter, top5=acc5_meter))

    print(' * Acc@1 {top1.avg:.3f} Acc@1 Error {top1_err:.3f}\n'
          ' * Acc@5 {top5.avg:.3f} Acc@5 Error {top5_err:.3f}'
          .format(top1=acc1_meter, top1_err=100-acc1_meter.avg,
                  top5=acc5_meter, top5_err=100-acc5_meter.avg))

    return acc1_meter, acc5_meter, loss_meter

### Train, Evaluate and Checkpoint Model

In [None]:
%%time
# Train for the remaining epochs, evaluating after each one and checkpointing
# whenever the test top-1 accuracy improves.
for epoch in range(start_epoch, num_epochs):
    train(train_loader, model, criterion, optimizer, epoch)
    top1, top5, losses = validate(test_loader, model, criterion, epoch)

    if top1.avg > best_top1:
        print('Saving checkpoint')
        state = {
            # weights of the network inside the DataParallel wrapper
            'model_state_dict': model.module.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Number of completed epochs, i.e. the index of the next epoch to
            # run. The resume cell assigns this value to start_epoch; storing
            # the bare loop index would make a resumed run repeat this epoch.
            'epoch': epoch + 1,
            'loss': losses.avg,
            'top1': top1.avg,
            'top5': top5.avg,
        }
        # Creates the folder on first use; a no-op when it already exists.
        os.makedirs('checkpoint', exist_ok=True)
        torch.save(state, './checkpoint/seed0_ckpt.pth')

        best_top1 = top1.avg

### Accuracy for each class

In [None]:
# class_correct = list(0. for i in range(100))
# class_total = list(0. for i in range(100))
# with torch.no_grad(): 
#     for data in test_loader:
#         images, labels = data
#         outputs = model(images)
#         _, predicted = torch.max(outputs, 1)
#         c = (predicted == labels).squeeze()
#         for i in range(4):
#             label = labels[i]
#             class_correct[label] += c[i].item()
#             class_total[label] += 1


# for i in range(10):
#     print('Accuracy of %5s : %2d %%' % (
#         classes[i], 100 * class_correct[i] / class_total[i]))