<a href="https://colab.research.google.com/github/maxmatical/pytorch-projects/blob/master/Attention_Augmented_Wide_Resnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

import numpy as np

from attention_augmented_convnets import augmented_conv2d


torch.Size([16, 20, 16, 16])
16 16 256


In [0]:
def init_weights(m):
    """Initialise the parameters of one module in place.

    Meant to be used as ``model.apply(init_weights)``, which calls it on every
    submodule. Module types not listed below are left untouched.

    * ``nn.Conv2d``: Xavier-uniform weight, zero bias (when a bias exists).
    * ``nn.BatchNorm2d``: weight 1, bias 0 (when the layer is affine).
    * ``nn.Linear``: weight ~ N(0, 0.01), zero bias (when a bias exists).

    Args:
        m: the module to initialise.
    """
    if isinstance(m, nn.Conv2d):
        init.xavier_uniform_(m.weight)
        if m.bias is not None:
            init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        # BatchNorm2d(affine=False) has weight/bias set to None.
        if m.weight is not None:
            init.ones_(m.weight)
        if m.bias is not None:
            init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        init.normal_(m.weight, mean=0.0, std=0.01)
        # nn.Linear(bias=False) has bias set to None.
        if m.bias is not None:
            init.zeros_(m.bias)

In [0]:
class wide_basic(nn.Module):
    """Pre-activation wide residual block built from attention-augmented convs.

    Computes
    ``conv2(relu(bn2(dropout(conv1(relu(bn1(x))))))) + shortcut(x)``.
    ``stride`` is applied by ``conv2`` (and by the projection shortcut), so any
    spatial down-sampling happens at the end of the residual path.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the block.
        dropout_rate: dropout probability applied after ``conv1``.
        stride: stride of ``conv2`` and of the projection shortcut.
        v: fraction of ``out_channels`` passed as ``dv`` (truncated by ``int``).
        k: multiplier of ``out_channels`` passed as ``dk``.
        n_heads: number of attention heads, passed as ``Nh``.
    """

    def __init__(self, in_channels, out_channels, dropout_rate, stride=1, v=0.2, k=2, n_heads=4):
        super().__init__()
        # Attention settings shared by every augmented conv in this block.
        attn_kwargs = dict(
            dk=k * out_channels,
            dv=int(v * out_channels),
            Nh=n_heads,
            relative=True,
        )

        # Submodules are created in a fixed order (bn1, conv1, dropout, bn2,
        # conv2, shortcut) so parameter init and state_dict order stay stable.
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = augmented_conv2d(
            in_channels, out_channels, kernel_size=3, padding=1, **attn_kwargs
        )
        self.dropout = nn.Dropout(dropout_rate)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = augmented_conv2d(
            out_channels, out_channels, kernel_size=3, stride=stride, padding=1, **attn_kwargs
        )

        # A projection is needed whenever the residual path changes the
        # spatial size or the channel count; otherwise the shortcut is identity.
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                augmented_conv2d(
                    in_channels, out_channels, kernel_size=3, stride=stride, padding=1, **attn_kwargs
                ),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        h = F.relu(self.bn1(x))
        h = self.dropout(self.conv1(h))
        h = F.relu(self.bn2(h))
        h = self.conv2(h)
        return h + self.shortcut(x)

In [0]:
# Shape check for wide_basic: stride 1 keeps the 32x32 spatial size, stride 2
# halves it. Each block runs its forward pass once, without building an autograd
# graph, and the output is reused for both prints.
tmp = torch.randn((16, 3, 32, 32)).to(device)

a = wide_basic(3, 20, dropout_rate=0.1).to(device)
with torch.no_grad():
    out = a(tmp)
print(out.shape)
bs, n_channels, H, W = out.size()
print(H, W, H * W)

a2 = wide_basic(3, 20, dropout_rate=0.1, stride=2).to(device)
with torch.no_grad():
    out = a2(tmp)
print(out.shape)
bs, n_channels, H, W = out.size()
print(H, W, H * W)

torch.Size([16, 20, 32, 32])
32 32 1024
torch.Size([16, 20, 16, 16])
16 16 256


In [0]:
class WideResnet(nn.Module):
    """Wide ResNet whose convolutions are attention-augmented.

    Pipeline: stem ``augmented_conv2d`` -> three stages of ``wide_basic``
    blocks (stride 1, 2, 2) -> BatchNorm + ReLU -> concatenated global
    average- and max-pooled features -> linear classifier.
    """

    def __init__(self, depth, widen_factor, dropout_rate, n_classes, layers):
        """
        Args:
            depth: total network depth; must satisfy ``depth = 6n + 4``, which
                gives ``n`` ``wide_basic`` blocks per stage.
            widen_factor: multiplier applied to ``layers[1:]`` to get each
                stage's output channels.
            dropout_rate: dropout probability inside every ``wide_basic`` block.
            n_classes: number of output logits.
            layers: list of length 4, e.g. ``[20, 20, 40, 80]``. ``layers[0]``
                is the stem conv's output width (not widened); ``layers[1..3]``
                are multiplied by ``widen_factor`` for the three stages.
                NOTE(review): the original author notes these must be large
                enough (around >= 20) for ``split_heads_2d`` in
                ``attention_augmented_convnets``; presumably ``int(0.2 * c)``
                must be divisible by the 4 heads. Confirm there.
        """
        super(WideResnet, self).__init__()

        assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4'
        n = (depth - 4) // 6  # number of wide_basic blocks in each stage
        k = widen_factor

        # Attention hyper-parameters for the stem conv (the wide_basic
        # blocks use their own matching defaults v=0.2, k=2, n_heads=4).
        dv_v = 0.2
        dk_k = 2
        Nh = 4

        self.conv1 = augmented_conv2d(3, out_channels=layers[0], kernel_size=3, dk=dk_k * layers[0],
                                      dv=int(dv_v * layers[0]), Nh=Nh, relative=True)
        # The first stage consumes the stem's output. Previously this was
        # hard-coded to 20, which broke any layers[0] != 20.
        self.in_channels = layers[0]
        # Each stage is wrapped in an extra nn.Sequential. This is kept so
        # existing state_dict keys (blockN.0.i....) still load.
        self.block1 = nn.Sequential(self.make_layer(wide_basic, layers[1] * k, n, dropout_rate, stride=1),)  # keeps spatial size
        self.block2 = nn.Sequential(self.make_layer(wide_basic, layers[2] * k, n, dropout_rate, stride=2),)  # halves spatial size
        self.block3 = nn.Sequential(self.make_layer(wide_basic, layers[3] * k, n, dropout_rate, stride=2),)  # halves spatial size
        # NOTE(review): in PyTorch, momentum weights the *new* batch statistic,
        # so 0.9 makes the running stats track mostly the latest batch (the
        # default is 0.1). This looks like a port of the Keras convention;
        # it is kept unchanged for compatibility. Confirm it is intended.
        self.bn1 = nn.BatchNorm2d(layers[3] * k, momentum=0.9)
        self.linear = nn.Linear(layers[3] * k * 2, n_classes)  # *2: avg-pool and max-pool features are concatenated
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.apply(init_weights)

    def make_layer(self, block, out_channels, n_blocks, dropout_rate, stride):
        """Stack blocks into one stage; only the first block uses ``stride``.

        Advances ``self.in_channels`` to ``out_channels`` so the next stage
        chains on. Note that ``n_blocks == 0`` still yields one block.
        """
        strides = [stride] + [1] * (n_blocks - 1)
        blocks = []
        for block_stride in strides:
            blocks.append(block(self.in_channels, out_channels, dropout_rate, stride=block_stride))
            self.in_channels = out_channels
        return nn.Sequential(*blocks)

    def forward(self, x):
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = F.relu(self.bn1(self.block3(out)))
        # Global average- and max-pooling, each flattened to (batch, C), then
        # concatenated to (batch, 2 * C) for the classifier.
        avg = torch.flatten(self.avg_pool(out), 1)
        mx = torch.flatten(self.max_pool(out), 1)
        out = torch.cat([avg, mx], 1)
        return self.linear(out)
      

In [0]:
# Smoke test: push a small random batch through a depth-28, widen-factor-10
# model and check that it yields one logit per class.
tmp = torch.randn((4, 3, 32, 32)).to(device)
layers = [20, 20, 40, 40]

model = WideResnet(depth=28, widen_factor=10, dropout_rate=0.3, n_classes=10, layers=layers)
model = model.to(device)
logits = model(tmp)
print(logits.shape)

torch.Size([4, 10])


In [0]:
# Depth-28, widen-factor-4 model with 10 classes, used for the layer summary.
net = WideResnet(
    depth=28,
    widen_factor=4,
    dropout_rate=0.3,
    n_classes=10,
    layers=[20, 20, 40, 80],
).to(device)

In [0]:
from torchsummary import summary


In [0]:
# torchsummary's summary() defaults to device="cuda". Pass the device the model
# actually lives on ("cuda" or "cpu") so this cell also works on CPU-only machines.
summary(net, (3, 32, 32), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 32, 32]             448
            Conv2d-2           [-1, 84, 32, 32]           2,352
            Conv2d-3            [-1, 4, 32, 32]              20
  augmented_conv2d-4           [-1, 20, 32, 32]               0
       BatchNorm2d-5           [-1, 20, 32, 32]              40
            Conv2d-6           [-1, 64, 32, 32]          11,584
            Conv2d-7          [-1, 336, 32, 32]          60,816
            Conv2d-8           [-1, 16, 32, 32]             272
  augmented_conv2d-9           [-1, 80, 32, 32]               0
          Dropout-10           [-1, 80, 32, 32]               0
      BatchNorm2d-11           [-1, 80, 32, 32]             160
           Conv2d-12           [-1, 64, 32, 32]          46,144
           Conv2d-13          [-1, 336, 32, 32]         242,256
           Conv2d-14           [-1, 16,