* this notebook contains the main takeaway, multihead-attention implementation 

In [2]:
!pip install torch



DEPRECATION: Loading egg at c:\python311\lib\site-packages\vboxapi-1.0-py3.11.egg is deprecated. pip 24.3 will enforce this behaviour change. A possible replacement is to use pip for package installation.. Discussion can be found at https://github.com/pypa/pip/issues/12330

[notice] A new release of pip is available: 24.0 -> 24.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [10]:
import tiktoken
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


class GPTDatasetV1(Dataset):
    def __init__(self, txt, tokenizer, max_length, stride):
        self.input_ids = []
        self.target_ids = []

        # Tokenize the entire text
        token_ids = tokenizer.encode(txt, allowed_special={'<|endoftext|>'})

        # Use a sliding window to chunk the book into overlapping sequences of max_length
        for i in range(0, len(token_ids) - max_length, stride):
            input_chunk = token_ids[i:i + max_length]
            target_chunk = token_ids[i + 1: i + max_length + 1]
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]


def create_dataloader(txt, batch_size=4, max_length=256, stride=128, shuffle=True):
    # Initialize the tokenizer
    tokenizer = tiktoken.get_encoding("gpt2")

    # Create dataset
    dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)

    # Create dataloader
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    return dataloader


with open("C:\\Users\Mukund Agarwalla\\Desktop\\NLP\\Attention Mechanism\\text.txt", "r", encoding="utf-8") as f:
    raw_text = f.read()

tokenizer = tiktoken.get_encoding("gpt2")
encoded_text = tokenizer.encode(raw_text)

vocab_size = 50257
output_dim = 256
max_len = 1024
context_length = max_len


token_embedding_layer = nn.Embedding(vocab_size, output_dim)
pos_embedding_layer = torch.nn.Embedding(context_length, output_dim)

max_length = 4
dataloader = create_dataloader(raw_text, batch_size=8, max_length=max_length, stride=max_length)

In [11]:
for batch in dataloader:
    x, y = batch

    token_embeddings = token_embedding_layer(x)
    pos_embeddings = pos_embedding_layer(torch.arange(max_length))

    input_embeddings = token_embeddings + pos_embeddings

    break

In [13]:
import math

class MultiheadAttentio(nn.Module):
    
    def __init__(self, d_model, h, dropout):
        super().__init__()
        self.d_model = d_model
        self.h = h
        self.w_k = nn.Linear(self.d_model, self.d_model, bias = False)
        self.w_q = nn.Linear(self.d_model, self.d_model, bias = False)
        self.w_v = nn.Linear(self.d_model, self.d_model, bias = False)
        self.w_o = nn.Linear(self.d_model, self.d_model, bias = False)
        self.d_k = self.d_model // self.h
        self.dropout = nn.Dropout(dropout)


    def forward(self,x):
        key = self.w_k(x)
        query = self.w_q(x)
        value = self.w_v(x)

        key = key.view(x.shape[0], x.shape[1], self.h, self.d_k).transpose(1,2)
        value = value.view(x.shape[0], x.shape[1], self.h, self.d_k).transpose(1,2)
        query = query.view(x.shape[0], x.shape[1], self.h, self.d_k).transpose(1,2)

        attention_scores = query @ key.transpose(-2, -1) // math.sqrt(self.d_k)

        mask_bool = self.mask.bool()[:x.shape[1], x.shape[1]]

        attention_scores.masked_fill_(mask_bool, -torch.inf)

        attention_scores = attention_scores.softmax(dim = -1)

        attention_scores = self.dropout(attention_scores)

        context_vec = (attention_scores @ value).transpose(1,2)

        context_vec = context_vec.contiguous().view(x.shape[0], -1, self.d_model)
        context_vec = self.w_v(context_vec)
        return context_vec



torch.manual_seed(123)

context_length = max_length
d_in = output_dim
d_out = d_in

mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)

batch = input_embeddings
context_vecs = mha(batch)

print("context_vecs.shape:", context_vecs.shape)

In [15]:

torch.manual_seed(123)

context_length = max_length
d_in = output_dim
d_out = d_in

mha = MultiheadAttentio(d_out, 2, 0.0)

batch = input_embeddings
context_vecs = mha(batch)

print("context_vecs.shape:", context_vecs.shape)

AttributeError: 'MultiheadAttentio' object has no attribute 'mask'