In [1]:
## Standard libraries
import os
import numpy as np
import random
import math
import json
from functools import partial

## tqdm for loading bars
from tqdm.notebook import tqdm

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim

## Torchvision
import torchvision
from torchvision.datasets import CIFAR100
from torchvision import transforms

# PyTorch Lightning
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

# Directory the datasets are read from / downloaded into (e.g. CIFAR100)
DATASET_PATH = "../data"
# Directory that holds the pretrained model checkpoints for this tutorial
CHECKPOINT_PATH = "../saved_models/tutorial6"

# Fix the global seed through Lightning so reruns are reproducible
pl.seed_everything(42)

# Use the first CUDA device when one is available, otherwise stay on the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Global seed set to 42


Device: cpu


In [2]:
def scaled_dot_product(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask=None
):
    """Compute scaled dot-product attention.

    Args:
        q: Query tensor whose last two dimensions are ``(T, d_k)``; any
            leading dimensions (batch, heads, ...) are treated as batch dims.
        k: Key tensor laid out like ``q``.
        v: Value tensor with the same leading and sequence dimensions as ``k``.
        mask: Optional tensor broadcastable to the ``(..., T, T)`` attention
            logits. Positions where ``mask == 0`` are excluded from attention.

    Returns:
        A tuple ``(values, attention)`` where ``attention`` holds the softmax
        weights over the key axis (shape ``(..., T, T)``) and ``values`` is
        ``attention @ v``.
    """
    d_k = q.size(-1)  # feature size of each query/key vector
    # (..., T, d_k) @ (..., d_k, T) -> (..., T, T)
    attn_logits = torch.matmul(q, k.transpose(-2, -1))
    # Scale by 1/sqrt(d_k). The parentheses matter: without them Python parses
    # `logits ** d_k * -0.5` as `(logits ** d_k) * -0.5`.
    attn_logits = attn_logits * (d_k ** -0.5)
    if mask is not None:
        # Use the most negative finite value of the logits' dtype so the fill
        # also works for half-precision tensors; masked slots get ~0 weight.
        fill_value = torch.finfo(attn_logits.dtype).min
        attn_logits = attn_logits.masked_fill(mask == 0, fill_value)
    attention = F.softmax(attn_logits, dim=-1)  # normalise over the key axis
    values = torch.matmul(attention, v)  # (..., T, T) @ (..., T, d_v)

    return values, attention



In [3]:
# Toy sizes: batch, sequence length, feature dimension
B, T, C = 2, 3, 2
x = torch.randn((B, T, C))

# One linear layer emits q, k and v side by side along the last axis
qkv_proj = nn.Linear(in_features=C, out_features=3 * C)
qkv_stacked = qkv_proj(x)
print("QSV Stacked", qkv_stacked.shape)
# Cut the stacked projection into three equal slices along the feature axis
q, k, v = torch.chunk(qkv_stacked, chunks=3, dim=-1)

values, attention = scaled_dot_product(q, k, v)
# Report the shape of every tensor involved, one heading per tensor
for heading, tensor in (
        ("Q\n", q),
        ("K\n", k),
        ("V\n", v),
        ("Values\n", values),
        ("Attention\n", attention),
):
    print(heading, tensor.shape)

QSV Stacked torch.Size([2, 3, 6])
Q
 torch.Size([2, 3, 2])
K
 torch.Size([2, 3, 2])
V
 torch.Size([2, 3, 2])
Values
 torch.Size([2, 3, 2])
Attention
 torch.Size([2, 3, 3])


In [4]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over an input of shape ``(B, T, input_dim)``.

    The input is projected to stacked queries/keys/values, split into
    ``n_head`` heads of size ``embed_dim // n_head``, attended per head with
    ``scaled_dot_product`` and finally mixed by an output projection.

    Args:
        input_dim: Size of the last dimension of the input.
        embed_dim: Total embedding size across all heads; also the size of
            the last dimension of the output.
        n_head: Number of attention heads. Must be positive and divide
            ``embed_dim`` evenly.

    Raises:
        ValueError: If ``n_head`` is not positive or does not divide
            ``embed_dim``.
    """

    def __init__(self, input_dim: int, embed_dim: int, n_head: int):
        super().__init__()

        # Fail fast: an indivisible embed_dim would otherwise only surface as
        # a reshape error inside forward().
        if n_head <= 0 or embed_dim % n_head != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive "
                f"n_head ({n_head})"
            )

        self.input_dim = input_dim
        self.embed_dim = embed_dim
        self.n_head = n_head
        self.head_dim = embed_dim // n_head

        # Single projection producing q, k and v for every head at once
        self.proj_qkv = nn.Linear(input_dim, 3 * embed_dim)
        self.proj_o = nn.Linear(embed_dim, embed_dim)

    def forward(self, x: torch.Tensor, mask=None, return_attention: bool = False):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: Input tensor of shape ``(B, T, input_dim)``.
            mask: Optional mask passed unchanged to ``scaled_dot_product``;
                it must broadcast against the ``(B, n_head, T, T)`` logits.
            return_attention: When true, also return the per-head attention
                weights.

        Returns:
            The output tensor of shape ``(B, T, embed_dim)``, or a tuple
            ``(output, attention)`` with ``attention`` of shape
            ``(B, n_head, T, T)`` when ``return_attention`` is true.
        """
        B, T, _ = x.size()

        qkv = self.proj_qkv(x)  # (B, T, 3*embed_dim)
        qkv = qkv.reshape(B, T, self.n_head, 3 * self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3)  # (B, n_head, T, 3*head_dim)
        q, k, v = qkv.chunk(3, dim=-1)  # each (B, n_head, T, head_dim)

        # Attention is computed independently for every head
        values, attention = scaled_dot_product(q, k, v, mask)  # (B, n_head, T, head_dim)
        values = values.permute(0, 2, 1, 3)  # (B, T, n_head, head_dim)
        values = values.reshape(B, T, self.embed_dim)  # concatenate the heads
        out = self.proj_o(values)  # (B, T, embed_dim)

        if return_attention:
            return out, attention
        return out
        


In [9]:
# Smoke test: run a random batch through the multi-head attention layer
B, T, C = 2, 8, 32
x = torch.randn((B, T, C))
print("input tensor:", x.size())

# 6 heads sharing a 72-dimensional embedding -> 12 features per head
mh = MultiHeadAttention(C, embed_dim=72, n_head=6)
o = mh(x)
print("mh output tensor:", o.size())

input tensor: torch.Size([2, 8, 32])
mh output tensor: torch.Size([2, 8, 72])
