# Building a model architecture that uses self-attention

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [44]:
"""
simulated input of categorical embeddings
"""
# Random stand-in for a batch of embedded categorical features, shape (32, 10, 3).
x = torch.rand(32, 10, 3)
x.shape

# Unpack the three axis sizes of x into module-level names; k is passed to AttentionLayerBlock later.
b, t, k = x.size()

In [106]:
class SimpleSelfAttention(nn.Module):
    """
    Single-head scaled dot-product self-attention.

    Every position of the input attends to every position of the same input,
    so the output has the same shape as the input.
    For this case, it will be for categorical input.

    This serves as a sub-layer.

    ::param k: embedding dimension (size of the last axis of the input)
    """
    def __init__(self, k):
        super().__init__()

        # Bias-free k -> k projections producing keys, queries and values
        self.keys_weights = nn.Linear(k, k, bias=False)
        self.queries_weights = nn.Linear(k, k, bias=False)
        self.values_weights = nn.Linear(k, k, bias=False)

    def forward(self, x):
        """
        ::param x: shape (b, t, k)
            b = batch
            t = categorical var
            k = embedding dimension
        ::return: attention output, shape (b, t, k)
        """
        # Grabbing input size (also enforces a 3-D input); only k is needed below
        _, _, k = x.size()

        # calculating our keys, queries, and values from our input
        keys = self.keys_weights(x)
        queries = self.queries_weights(x)
        values = self.values_weights(x)

        # calculating our normalizer, this will be used to scale down the dot_scores
        norm = math.sqrt(k)

        # performing batch matrix-multiply on queries and keys transpose -> (b, t, t),
        # then scaling down with our norm
        dot_scores = torch.bmm(queries, keys.transpose(1, 2))
        scaled_dot_scores = dot_scores / norm

        # applying softmax over the key axis so each query's weights sum to 1;
        # this step was missing, leaving `attention_pre` undefined
        attention_pre = F.softmax(scaled_dot_scores, dim=2)

        # calculating our attention: weighted sum of the values -> (b, t, k)
        attention = torch.bmm(attention_pre, values)

        return attention

In [254]:
class AttentionLayerBlock(nn.Module):
    """
    Self-attention followed by pooling over dim 1 (the categorical dimension).

    ::param k: embedding dimension handed to SimpleSelfAttention
    ::param max_pool: True -> max over dim 1, False -> mean over dim 1
    """
    def __init__(self, k, max_pool=True):
        super().__init__()
        self.max_pool = max_pool
        self.attention_layer = SimpleSelfAttention(k)

    def forward(self, x):
        """
        Apply the attention sub-layer to x, then reduce its output over dim 1.
        """
        attended = self.attention_layer(x)
        if self.max_pool:
            # max(dim=...) returns (values, indices); only the values are kept
            return attended.max(dim=1).values
        return attended.mean(dim=1)

In [260]:
# max_pool=False selects mean pooling over dim 1 in AttentionLayerBlock.forward
attn_block = AttentionLayerBlock(k, max_pool=False)

In [261]:
# Run the block on the simulated input x; the recorded result has shape (32, 3)
attention = attn_block(x)

In [262]:
attention

tensor([[ 0.2706,  0.2053, -0.0248],
        [ 0.2774,  0.1599,  0.0089],
        [ 0.3043,  0.2277, -0.0225],
        [ 0.2504,  0.1806, -0.0128],
        [ 0.3104,  0.1672,  0.0205],
        [ 0.2440,  0.1314,  0.0267],
        [ 0.2508,  0.2096, -0.0313],
        [ 0.2823,  0.2493, -0.0513],
        [ 0.2803,  0.1878, -0.0079],
        [ 0.3336,  0.2708, -0.0475],
        [ 0.2890,  0.2231, -0.0283],
        [ 0.3097,  0.2456, -0.0317],
        [ 0.2539,  0.2162, -0.0407],
        [ 0.2522,  0.2006, -0.0335],
        [ 0.2520,  0.1822, -0.0137],
        [ 0.2889,  0.2277, -0.0354],
        [ 0.2202,  0.2258, -0.0583],
        [ 0.3070,  0.2606, -0.0470],
        [ 0.2477,  0.2153, -0.0409],
        [ 0.2158,  0.1679, -0.0172],
        [ 0.2634,  0.2083, -0.0318],
        [ 0.3240,  0.2701, -0.0457],
        [ 0.3022,  0.1993, -0.0041],
        [ 0.2923,  0.2954, -0.0758],
        [ 0.2878,  0.2217, -0.0250],
        [ 0.2495,  0.2085, -0.0328],
        [ 0.2800,  0.1723,  0.0025],
 

In [263]:
attention.shape

torch.Size([32, 3])

In [300]:
attention_pre.shape

torch.Size([32, 10, 10])

In [319]:
output.shape

torch.Size([32, 10, 3])

In [318]:
output[0].sum(1)

tensor([1.3297, 1.7412, 1.6062, 1.4535, 0.6836, 2.5775, 1.4052, 1.8707, 2.3253,
        1.7028], grad_fn=<SumBackward1>)

In [305]:
output[0]

tensor([[0.0759, 0.5773, 0.6765],
        [0.1951, 0.7028, 0.8433],
        [0.6848, 0.6541, 0.2674],
        [0.5399, 0.7300, 0.1835],
        [0.1834, 0.4355, 0.0648],
        [0.5853, 1.1679, 0.8243],
        [0.6805, 0.3567, 0.3680],
        [0.5705, 1.2173, 0.0829],
        [0.4692, 0.8567, 0.9994],
        [0.3824, 1.0378, 0.2826]], grad_fn=<SelectBackward>)

In [316]:
output.mean(dim=1).sum(1)

tensor([1.6696, 1.7049, 1.7893, 1.4567, 1.8641, 1.2247, 1.4176, 1.7577, 1.7115,
        2.1905, 1.7587, 1.8086, 1.5924, 1.6561, 1.4846, 1.8619, 1.3144, 1.8464,
        1.5278, 1.1941, 1.6895, 1.9896, 1.7988, 1.7594, 1.6538, 1.4546, 1.7080,
        1.6242, 1.0824, 1.6065, 1.4762, 1.7519], grad_fn=<SumBackward1>)

tensor([[[9.1694e-02, 3.4119e-01, 6.6736e-01],
         [2.1050e-01, 4.6857e-01, 8.3483e-01],
         [6.9794e-01, 4.3447e-01, 2.6346e-01],
         [5.5351e-01, 5.0826e-01, 1.7902e-01],
         [1.9837e-01, 2.0553e-01, 5.7674e-02],
         [5.9898e-01, 9.4449e-01, 8.1931e-01],
         [6.9399e-01, 1.3484e-01, 3.6328e-01],
         [5.8353e-01, 9.9917e-01, 7.9711e-02],
         [4.8362e-01, 6.2805e-01, 9.9264e-01],
         [3.9649e-01, 8.1342e-01, 2.7736e-01]],

        [[3.9637e-01, 5.1026e-01, 3.6192e-01],
         [5.7675e-01, 9.9568e-01, 2.5280e-01],
         [6.1612e-01, 5.4703e-01, 2.1302e-01],
         [5.9638e-01, 8.9310e-01, 2.0845e-01],
         [8.7639e-01, 8.2103e-01, 3.6958e-01],
         [6.8702e-01, 4.5016e-01, 1.2440e-01],
         [8.4967e-01, 5.7574e-01, 6.2913e-01],
         [2.0731e-01, 7.0952e-01, 7.0851e-02],
         [3.9546e-01, 1.1094e-01, 1.8050e-01],
         [1.6438e-01, 9.2046e-02, 8.7838e-01]],

        [[3.5679e-01, 3.1303e-01, 8.2874e-01],
         