In [59]:
import numpy as np
import pandas as pd
import torch.nn as nn
import torch

In [62]:
inputs_emb = torch.tensor(
[[0.43, 0.15, 0.89], # Your (x^1)
[0.55, 0.87, 0.66], # journey (x^2)
[0.57, 0.85, 0.64], # starts (x^3)
[0.22, 0.58, 0.33], # with (x^4)
[0.77, 0.25, 0.10], # one (x^5)
[0.05, 0.80, 0.55]] # step (x^6)
)
batch = torch.stack((inputs_emb,inputs_emb))
batch.shape

torch.Size([2, 6, 3])

In [100]:
class CasualAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super(CasualAttention, self).__init__()
        self.d_out=d_out

        # Trainable weights
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )
    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries@keys.transpose(1,2)
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens],-torch.inf )

        # keep in note that attn_weights are actualy non trainable weights. the are just scaling the value matrices. 
        attn_weights  = torch.softmax(attn_scores/keys.shape[0]**0.5, dim=1)
        
        #Apply an additional dropout mask
        # (upper right) to zero out additional attention weights to reduce overfitting during training
        attn_weights=self.dropout(attn_weights)
        
        context_vec = attn_weights@values
        return context_vec, attn_weights, self.W_query


In [112]:
torch.manual_seed(234)
context_length = batch.shape[1]
d_in = batch.shape[-1]
d_out=10
ca = CasualAttention(d_in, d_out, context_length,0.0)
context_vectors, attn_w, trainable_w= ca(batch)


In [113]:
context_vectors

tensor([[[-0.0597,  0.0536,  0.0659, -0.0358,  0.0436,  0.0069, -0.0105,
           0.0557,  0.0589, -0.0707],
         [-0.1557,  0.1229,  0.1643, -0.0980,  0.0727,  0.0099, -0.0601,
           0.0760,  0.1096, -0.0888],
         [-0.2828,  0.2121,  0.2956, -0.1768,  0.1122,  0.0156, -0.1239,
           0.1118,  0.1775, -0.1210],
         [-0.5215,  0.4060,  0.5396, -0.3458,  0.2033,  0.0199, -0.2383,
           0.1585,  0.3250, -0.1803],
         [-0.9754,  0.4949,  1.0260, -0.4512,  0.1846,  0.0889, -0.5248,
           0.4315,  0.3808, -0.2649],
         [-1.0276,  0.8671,  1.0486, -0.7517,  0.4156,  0.0121, -0.4834,
           0.1836,  0.6621, -0.2632]],

        [[-0.0597,  0.0536,  0.0659, -0.0358,  0.0436,  0.0069, -0.0105,
           0.0557,  0.0589, -0.0707],
         [-0.1557,  0.1229,  0.1643, -0.0980,  0.0727,  0.0099, -0.0601,
           0.0760,  0.1096, -0.0888],
         [-0.2828,  0.2121,  0.2956, -0.1768,  0.1122,  0.0156, -0.1239,
           0.1118,  0.1775, -0.1210],

In [114]:
attn_w

tensor([[[0.1543, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1374, 0.1487, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1403, 0.1524, 0.1794, 0.0000, 0.0000, 0.0000],
         [0.1795, 0.2198, 0.2585, 0.3259, 0.0000, 0.0000],
         [0.2457, 0.3119, 0.3649, 0.3919, 0.5741, 0.0000],
         [0.1428, 0.1672, 0.1972, 0.2822, 0.4259, 1.0000]],

        [[0.1543, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1374, 0.1487, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1403, 0.1524, 0.1794, 0.0000, 0.0000, 0.0000],
         [0.1795, 0.2198, 0.2585, 0.3259, 0.0000, 0.0000],
         [0.2457, 0.3119, 0.3649, 0.3919, 0.5741, 0.0000],
         [0.1428, 0.1672, 0.1972, 0.2822, 0.4259, 1.0000]]],
       grad_fn=<SoftmaxBackward0>)

In [115]:
trainable_w.weight

Parameter containing:
tensor([[ 0.4938,  0.0943, -0.4663],
        [-0.5758,  0.0860,  0.2456],
        [-0.1564, -0.5161, -0.4219],
        [ 0.4813,  0.4558, -0.3578],
        [ 0.3532,  0.5202,  0.2513],
        [-0.1710,  0.5219,  0.0773],
        [ 0.2211, -0.4895, -0.0674],
        [ 0.4343, -0.0683, -0.3715],
        [ 0.5294, -0.5342, -0.1516],
        [-0.0368,  0.4616,  0.5172]], requires_grad=True)