In [1]:
from torchvision import models
from torchsummary import summary 
from torchvision import models
import torchvision
import torchvision.transforms as transforms
import torch
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### Loading CIFAR10 dataset

In [3]:
batch_size = 64
all_transforms = transforms.Compose([transforms.Resize((32,32)),
                                     transforms.ToTensor(),
                                     transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                                          std=[0.2023, 0.1994, 0.2010])
                                     ])
# Create Training dataset
train_dataset = torchvision.datasets.CIFAR10(root = './data',
                                             train = True,
                                             transform = all_transforms,
                                             download = True)

# Create Testing dataset
test_dataset = torchvision.datasets.CIFAR10(root = './data',
                                            train = False,
                                            transform = all_transforms,
                                            download=True)

# Instantiate loader objects to facilitate processing
train_loader = torch.utils.data.DataLoader(dataset = train_dataset,
                                           batch_size = batch_size,
                                           shuffle = True)


test_loader = torch.utils.data.DataLoader(dataset = test_dataset,
                                           batch_size = batch_size,
                                           shuffle = True)

Files already downloaded and verified
Files already downloaded and verified


### 2D SelfAttention model

In [4]:
## Source of attention model is 
__all__ = ["SelfAttention2d"]


class SelfAttention2d(nn.Module):
    r"""Self Attention Module as proposed in the paper `"Self-Attention Generative Adversarial
    Networks by Han Zhang et. al." <https://arxiv.org/abs/1805.08318>`_
    .. math:: attention = softmax((query(x))^T * key(x))
    .. math:: output = \gamma * value(x) * attention + x
    where
    - :math:`query` : 2D Convolution Operation
    - :math:`key` : 2D Convolution Operation
    - :math:`value` : 2D Convolution Operation
    - :math:`x` : Input
    Args:
        input_dims (int): The input channel dimension in the input ``x``.
        output_dims (int, optional): The output channel dimension. If ``None`` the output
            channel value is computed as ``input_dims // 8``. So if the ``input_dims`` is **less
            than 8** then the layer will give an error.
        return_attn (bool, optional): Set it to ``True`` if you want the attention values to be
            returned.
    """

    def __init__(self, input_dims, output_dims=None, return_attn=False):
        output_dims = input_dims // 8 if output_dims is None else output_dims
        if output_dims == 0:
            raise Exception(
                "The output dims corresponding to the input dims is 0. Increase the input\
                            dims to 8 or more. Else specify output_dims"
            )
        super(SelfAttention2d, self).__init__()
        self.query = nn.Conv2d(input_dims, output_dims, 1)
        self.key = nn.Conv2d(input_dims, output_dims, 1)
        self.value = nn.Conv2d(input_dims, input_dims, 1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.return_attn = return_attn

    def forward(self, x):
        #print(x.shape)  # [bs,64,5,5]
        r"""Computes the output of the Self Attention Layer
        Args:
            x (torch.Tensor): A 4D Tensor with the channel dimension same as ``input_dims``.
        Returns:
            A tuple of the ``output`` and the ``attention`` if ``return_attn`` is set to ``True``
            else just the ``output`` tensor.
        """
        dims = (x.size(0), -1, x.size(2) * x.size(3)) # [2,-1,5*5]
        out_query = self.query(x).view(dims)
        #print('this is the query',out_query.shape) #  torch.Size([2, 8, 25])
        out_key = self.key(x).view(dims).permute(0, 2, 1) 
        #print('this is the key',out_key.shape) #  torch.Size([2, 25, 8])
        attn = F.softmax(torch.bmm(out_key, out_query), dim=-1) #  torch.Size([2, 25, 25])
        #print('shape of attn var', attn.shape)
        out_value = self.value(x).view(dims)
        #print('shape of out_value var', out_value.shape)

        out_value = torch.bmm(out_value, attn).view(x.size())
        out = self.gamma * out_value + x
        if self.return_attn:
            return out, attn
        return out

In [6]:


# Creating a CNN class
class CNN_At(nn.Module):
	#  Determine what layers and their order in CNN object 
    def __init__(self, num_classes=10):
        super(CNN_At, self).__init__()
        self.conv_layer1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3,padding = 0) 
        self.conv_layer2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3)
        self.max_pool1 = nn.MaxPool2d(kernel_size = 2, stride = 2)
        
        self.conv_layer3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.conv_layer4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3)
        self.max_pool2 = nn.MaxPool2d(kernel_size = 2, stride = 2)
        self.atten = SelfAttention2d(64)
        self.bn = nn.BatchNorm2d(64)

        
        self.fc1 = nn.Linear(1600, 128)# [64 * 5 * 5]
        self.dropout = nn.Dropout(p=0.1)

       
        self.relu1 = nn.ReLU()
        
        self.fc2 = nn.Linear(128, num_classes)
    
    # Progresses data across layers    
    def forward(self, x):
        out = self.conv_layer1(x)
        out = self.conv_layer2(out)
        out = self.max_pool1(out)
        
        out = self.conv_layer3(out)
        out = self.conv_layer4(out)
        out = self.max_pool2(out)
        out = self.atten(out)
        out = self.bn(out)

                
        out = out.reshape(out.size(0), -1)
        
        out = self.fc1(out)
        out = self.dropout(out)

        out = self.relu1(out)
        
        out = self.fc2(out)
        return out
    
    
model_atten = CNN_At()
model_atten = model_atten.to(device=device,dtype=torch.float)
summary(model_atten,(3,32,32))



----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 30, 30]             896
            Conv2d-2           [-1, 32, 28, 28]           9,248
         MaxPool2d-3           [-1, 32, 14, 14]               0
            Conv2d-4           [-1, 64, 12, 12]          18,496
            Conv2d-5           [-1, 64, 10, 10]          36,928
         MaxPool2d-6             [-1, 64, 5, 5]               0
            Conv2d-7              [-1, 8, 5, 5]             520
            Conv2d-8              [-1, 8, 5, 5]             520
            Conv2d-9             [-1, 64, 5, 5]           4,160
  SelfAttention2d-10             [-1, 64, 5, 5]               0
      BatchNorm2d-11             [-1, 64, 5, 5]             128
           Linear-12                  [-1, 128]         204,928
          Dropout-13                  [-1, 128]               0
             ReLU-14                  [

### Training

In [7]:
batch_size = 64
num_classes = 10
learning_rate = 0.001
num_epochs = 20

model_atten = model_atten.to(device=device,dtype=torch.float)
# Set Loss function with criterion
criterion = nn.CrossEntropyLoss()

# Set optimizer with optimizer
#optimizer = torch.optim.SGD(model_atten.parameters(), lr=learning_rate, weight_decay = 0.005, momentum = 0.9)  
#optimizer = torch.optim.Adam(model_atten.parameters(), lr=learning_rate, weight_decay = 0.005)  
optimizer = torch.optim.Adam(model_atten.parameters(),lr=0.0001)
total_step = len(train_loader)

# We use the pre-defined number of epochs to determine how many iterations to train the network on
for epoch in range(num_epochs):
	#Load in the data in batches using the train_loader object
    for i, (images, labels) in enumerate(train_loader):  
        # Move tensors to the configured device
        images = images.to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model_atten(images)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))

       


Epoch [1/20], Loss: 1.0352
Epoch [2/20], Loss: 1.1540
Epoch [3/20], Loss: 0.6718
Epoch [4/20], Loss: 0.8419
Epoch [5/20], Loss: 0.5113
Epoch [6/20], Loss: 1.0275
Epoch [7/20], Loss: 1.5056
Epoch [8/20], Loss: 0.6464
Epoch [9/20], Loss: 0.5607
Epoch [10/20], Loss: 0.4222
Epoch [11/20], Loss: 0.3412
Epoch [12/20], Loss: 0.7515
Epoch [13/20], Loss: 0.4577
Epoch [14/20], Loss: 0.7822
Epoch [15/20], Loss: 0.8287
Epoch [16/20], Loss: 0.2697
Epoch [17/20], Loss: 0.3677
Epoch [18/20], Loss: 0.3275
Epoch [19/20], Loss: 0.4244
Epoch [20/20], Loss: 0.4315


### Testing

In [9]:
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model_atten(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    
    print('Accuracy of the network on the {} train images: {} %'.format(50000, 100 * correct / total))



Accuracy of the network on the 50000 train images: 90.266 %
