<a href="https://colab.research.google.com/github/supertramp2/Colab/blob/main/SelfAttnCNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import os
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torchvision
import torchvision.utils as utils
import torchvision.transforms as transforms
import numpy as np
import random
import matplotlib.pyplot as plt

In [2]:
def _worker_init_fn_():
    torch_seed = torch.initial_seed()
    np_seed = torch_seed // 2**32-1
    random.seed(torch_seed)
    np.random.seed(np_seed)

In [3]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
])

In [4]:
trainset = torchvision.datasets.CIFAR100(root='CIFAR100_data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size= 64, shuffle=True, num_workers=2, worker_init_fn=_worker_init_fn_())
testset = torchvision.datasets.CIFAR100(root='CIFAR100_data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=50, shuffle=False, num_workers=2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to CIFAR100_data/cifar-100-python.tar.gz


HBox(children=(FloatProgress(value=0.0, max=169001437.0), HTML(value='')))


Extracting CIFAR100_data/cifar-100-python.tar.gz to CIFAR100_data
Files already downloaded and verified


In [5]:
class ConvBlock(nn.Module):
    def __init__(self, in_features, out_features, num_conv, pool=False):
        super(ConvBlock, self).__init__()
        features = [in_features] + [out_features for i in range(num_conv)]
        layers = []
        for i in range(len(features)-1):
            layers.append(nn.Conv2d(in_channels=features[i], out_channels=features[i+1], kernel_size=3, padding=1, bias=True))
            layers.append(nn.BatchNorm2d(num_features=features[i+1], affine=True, track_running_stats=True))
            layers.append(nn.ReLU())
            if pool:
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2, padding=0))
        self.op = nn.Sequential(*layers)
    def forward(self, x):
        return self.op(x)
        

In [6]:
class ProjectorBlock(nn.Module):
    def __init__(self, in_features, out_features):
        super(ProjectorBlock, self).__init__()
        self.op = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=1, padding=0, bias=False)
    def forward(self, inputs):
        return self.op(inputs)
        

In [7]:
class LinearAttentionBlock(nn.Module):
    def __init__(self, in_features, normalize_attn=True):
        super(LinearAttentionBlock, self).__init__()
        self.normalize_attn = normalize_attn
        self.op = nn.Conv2d(in_channels=in_features, out_channels=1, kernel_size=1, padding=0, bias=False)
    def forward(self, l, g):
        N, C, W, H = l.size()
        c = self.op(l+g) # batch_sizex1xWxH
        if self.normalize_attn:
            a = F.softmax(c.view(N,1,-1), dim=2).view(N,1,W,H)
        else:
            a = torch.sigmoid(c)
        g = torch.mul(a.expand_as(l), l)
        if self.normalize_attn:
            g = g.view(N,C,-1).sum(dim=2) # batch_sizexC
        else:
            g = F.adaptive_avg_pool2d(g, (1,1)).view(N,C)
        return c.view(N,1,W,H), g
        

In [8]:
class GridAttentionBlock(nn.Module):
    def __init__(self, in_features_l, in_features_g, attn_features, up_factor, normalize_attn=False):
        super(GridAttentionBlock, self).__init__()
        self.up_factor = up_factor
        self.normalize_attn = normalize_attn
        self.W_l = nn.Conv2d(in_channels=in_features_l, out_channels=attn_features, kernel_size=1, padding=0, bias=False)
        self.W_g = nn.Conv2d(in_channels=in_features_g, out_channels=attn_features, kernel_size=1, padding=0, bias=False)
        self.phi = nn.Conv2d(in_channels=attn_features, out_channels=1, kernel_size=1, padding=0, bias=True)
    
    def forward(self, l, g):
        N, C, W, H = l.size()
        l_ = self.W_l(l)
        g_ = self.W_g(g)
        if self.up_factor > 1:
            g_ = F.interpolate(g_, scale_factor=self.up_factor, mode='bilinear', align_corners=False)
        c = self.phi(F.relu(l_ + g_)) # batch_sizex1xWxH
        # compute attn map
        if self.normalize_attn:
            a = F.softmax(c.view(N,1,-1), dim=2).view(N,1,W,H)
        else:
            a = torch.sigmoid(c)
        # re-weight the local feature
        f = torch.mul(a.expand_as(l), l) # batch_sizexCxWxH
        if self.normalize_attn:
            output = f.view(N,C,-1).sum(dim=2) # weighted sum
        else:
            output = F.adaptive_avg_pool2d(f, (1,1)).view(N,C)
            
        return c.view(N,1,W,H), output

In [9]:
def weights_init_xavierNormal(module):
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.xavier_normal_(m.weight, gain=np.sqrt(2))
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.normal_(m.weight, 0, 0.01)
            nn.init.constant_(m.bias, val=0.)
        
        elif isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight, gain=np.sqrt(2))
            if m.bias is not None:
                nn.init.constant_(m.bias, val=0.)

In [10]:
class AttnVGG(nn.Module):
  def __init__(self, im_size, num_classes, attention=True, normalize_attn=True):
    super(AttnVGG, self).__init__()
    self.attention = attention

    self.cv1 = ConvBlock(3,64,2)
    self.cv2 = ConvBlock(64,128, 2)
    self.cv3 = ConvBlock(128,256,3)
    self.cv4 = ConvBlock(256,512,3)
    self.cv5 = ConvBlock(512,512,3)
    self.cv6 = ConvBlock(512,512,2, pool=True)
    self.dense = nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = int(im_size/32), padding = 0, bias = True)
    #Attention = True
    self.projector = ProjectorBlock(256, 512)
    self.attn1 = LinearAttentionBlock(in_features=512, normalize_attn= normalize_attn)
    self.attn2 = LinearAttentionBlock(in_features=512, normalize_attn= normalize_attn)
    self.attn3 = LinearAttentionBlock(in_features=512, normalize_attn= normalize_attn)      
    #Final Classification Layer
    self.classify = nn.Linear(in_features = 512 * 3, out_features = num_classes, bias = True)
    #weight = U [-(1/sqrt(n)), 1/sqrt(n)]
    weights_init_xavierNormal(self)
  
  def forward(self, x):
    x = self.cv1(x)
    x = self.cv2(x)

    l1 = self.cv3(x)
    x = F.max_pool2d(l1, kernel_size = 2, stride = 2, padding = 0)

    l2 = self.cv4(x)
    x = F.max_pool2d(l2, kernel_size = 2, stride = 2, padding = 0)

    l3 = self.cv5(x)
    x = F.max_pool2d(l3, kernel_size = 2, stride = 2, padding = 0)

    x = self.cv6(x)
    g = self.dense(x)

    #Attention part
    c1, g1 = self.attn1(self.projector(l1), g)
    c2, g2 = self.attn2(l2, g)
    c3, g3 = self.attn3(l3, g)
    g = torch.cat((g1,g2,g3), dim=1) # batch_sizexC
    
    # classification layer
    x = self.classify(g) # batch_sizexnum_classes

    return [x, c1, c2, c3]


In [11]:
%mkdir logs

In [12]:
%matplotlib inline
def show(img):
    npimg = img.cpu().numpy()
    plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest')

In [17]:
def train():
  net = AttnVGG(im_size= 32, num_classes=100)
  criterion = nn.CrossEntropyLoss()
  epochs = 300
  device = torch.device("cuda")
  device_ids = [0,]
  model = nn.DataParallel(net, device_ids=device_ids).to(device)
  criterion.to(device)
  optimizer = optim.SGD(model.parameters(), lr= 0.1, momentum=0.9, weight_decay=5e-4)
  lr_lambda = lambda epoch : np.power(0.5, int(epoch/25))
  scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
  
  step = 0
  running_avg_accuracy = 0

  for epoch in range(epochs):
    images_disp = []
    print("\nepoch %d learning rate %f\n" % (epoch, optimizer.param_groups[0]['lr']))
    for aug in range(3):
      for i, data in enumerate(trainloader, 0):
        model.train()
        model.zero_grad()
        optimizer.zero_grad()
        
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        if aug == 0 and i == 0:
          images_disp.append(inputs[0:36, :,:,:])
          # forward
          pred, __, __, __ = model(inputs)
          # backward
          loss = criterion(pred, labels)
          loss.backward()
          optimizer.step()
          # display results
          if i % 10 == 0:
              model.eval()
              pred, __, __, __ = model(inputs)
              predict = torch.argmax(pred, 1)
              total = labels.size(0)
              correct = torch.eq(predict, labels).sum().double().item()
              accuracy = correct / total
              running_avg_accuracy = 0.9*running_avg_accuracy + 0.1*accuracy
              
              print("[epoch %d][aug %d/%d][%d/%d] loss %.4f accuracy %.2f%% running avg accuracy %.2f%%"
                  % (epoch, aug, 2, i, len(trainloader)-1, loss.item(), (100*accuracy), (100*running_avg_accuracy)))
          step += 1        

    torch.save(model.state_dict(), os.path.join("logs", 'net.pth'))

    if epoch == 150:
      torch.save(model.state_dict(), os.path.join("logs", 'net%d.pth' % epoch))
    
    model.eval()

    total = 0
    correct = 0
    with torch.no_grad():
        # log scalars
        for i, data in enumerate(testloader, 0):
            images_test, labels_test = data
            images_test, labels_test = images_test.to(device), labels_test.to(device)
            if i == 0: # archive images in order to save to logs
                images_disp.append(inputs[0:36,:,:,:])
            pred_test, __, __, __ = model(images_test)
            predict = torch.argmax(pred_test, 1)
            total += labels_test.size(0)
            correct += torch.eq(predict, labels_test).sum().double().item()
        
        print("\n[epoch %d] accuracy on test data: %.2f%%\n" % (epoch, 100*correct/total))
      
          #I_train = utils.make_grid(images_disp[0], nrow=6, normalize=True, scale_each=True)
          #show(I_train)
          #if epoch == 0:
                    #I_test = utils.make_grid(images_disp[1], nrow=6, normalize=True, scale_each=True)
                    #show(I_test)

In [None]:
train()


epoch 0 learning rate 0.100000

[epoch 0][aug 0/2][0/781] loss 4.6051 accuracy 6.25% running avg accuracy 0.62%

[epoch 0] accuracy on test data: 1.00%


epoch 1 learning rate 0.100000

[epoch 1][aug 0/2][0/781] loss 4.6045 accuracy 3.12% running avg accuracy 0.88%

[epoch 1] accuracy on test data: 1.00%


epoch 2 learning rate 0.100000

[epoch 2][aug 0/2][0/781] loss 4.6080 accuracy 0.00% running avg accuracy 0.79%

[epoch 2] accuracy on test data: 1.00%


epoch 3 learning rate 0.100000

[epoch 3][aug 0/2][0/781] loss 4.6051 accuracy 0.00% running avg accuracy 0.71%

[epoch 3] accuracy on test data: 1.00%


epoch 4 learning rate 0.100000

[epoch 4][aug 0/2][0/781] loss 4.6120 accuracy 0.00% running avg accuracy 0.64%

[epoch 4] accuracy on test data: 1.00%


epoch 5 learning rate 0.100000

[epoch 5][aug 0/2][0/781] loss 4.6020 accuracy 1.56% running avg accuracy 0.73%

[epoch 5] accuracy on test data: 1.00%


epoch 6 learning rate 0.100000

[epoch 6][aug 0/2][0/781] loss 4.5988 accur