In [1]:
## Standard libraries
import os
import numpy as np
import random
import math
import json
from functools import partial

## tqdm for loading bars
from tqdm.notebook import tqdm

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim

## Torchvision
import torchvision
from torchvision.datasets import CIFAR100
from torchvision import transforms

# PyTorch Lightning
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10)
DATASET_PATH = "../data"
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = "../saved_models/tutorial6"

# Setting the seed
pl.seed_everything(42)


device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

Global seed set to 42


Device: cpu


In [15]:
def scaled_dot_product(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask=None
):
    # q, k v shapes --> B, T, C
    d_k = q.size()[-1] # hidden dimensionality C
    # B, T, C @ B, C, T ---> B, T, T
    attn_logits = torch.matmul(q, k.transpose(-2, -1)) # output: tensor.shape(T, T)
    attn_logits = attn_logits ** d_k*-0.5 # math.sqrt(d_k) # scaling the dot product
    if mask is not None:
        attn_logits = attn_logits.masked_fill(mask == 0, -9e15)
    attention = F.softmax(attn_logits, dim=-1) # (T, T)
    values = attention @ v  # (T, T) x (T, dk) -> (T, dk)

    return values, attention



In [13]:
B, T, C = 2, 3, 2
q = torch.randn(B, T, C)
k = torch.randn(B, T, C)
v = torch.randn(B, T, C)
values, attention = scaled_dot_product(q, k, v)
print("Q\n", q)
print("K\n", k)
print("V\n", v)
print("Values\n", values)
print("Attention\n", attention)

Q
 tensor([[[ 0.6388,  1.4924],
         [-2.0636,  0.7094],
         [ 1.2972,  0.1024]],

        [[-0.3947,  1.7133],
         [-1.0739, -0.2111],
         [ 1.4848,  0.4541]]])
K
 tensor([[[-0.2800,  2.5416],
         [-1.0840,  0.7681],
         [ 0.0133,  0.4032]],

        [[-1.0955,  0.0551],
         [-1.5506,  1.4060],
         [-0.9186,  2.5174]]])
V
 tensor([[[ 0.5408,  1.0435],
         [ 1.2338,  0.1211],
         [ 0.3465, -0.9395]],

        [[-0.1556,  0.7459],
         [-0.0947, -0.0220],
         [-0.5177,  1.4805]]])
Values
 tensor([[[ 0.8084, -0.3860],
         [ 0.3751, -0.8070],
         [ 0.5795,  0.0624]],

        [[-0.1549,  0.7369],
         [-0.3236,  0.9465],
         [-0.3805,  1.0947]]])
Attention
 tensor([[[8.4072e-04, 5.2034e-01, 4.7882e-01],
         [5.6141e-02, 1.9939e-02, 9.2392e-01],
         [4.1320e-01, 1.7211e-01, 4.1469e-01]],

        [[9.8814e-01, 1.1844e-02, 2.0326e-05],
         [2.8173e-01, 2.1768e-01, 5.0059e-01],
         [1.8443e-01, 1

In [27]:
class MultiHeadAttention(nn.Module):
    def __init__(self, input_dim: int, embed_dim: int, n_head: int):
        super(MultiHeadAttention, self).__init__()

        self.input_dim = input_dim
        self.embed_dim = embed_dim
        self.n_head = n_head
        self.head_dim = embed_dim // n_head

        self.proj_qkv = nn.Linear(input_dim, 3*embed_dim)
        self.proj_o = nn.Linear(embed_dim, embed_dim)


    def forward(self, x: torch.Tensor, mask = None):

        B, T, C = x.size()
        qkv = self.proj_qkv(x)  # (B, T, 3*embed_dim)
        qkv = qkv.reshape(B, T, self.n_head, 3*self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3)
        q, k, v = torch.chunk(input=qkv, chunks=3, dim=-1) # (B, n_head, T, head_dim) * 3
        # Attention
        values, attention = scaled_dot_product(q, k, v, mask)  # (B, n_head, T, head_dim)
        values = values.permute(0, 2, 1, 3)
        values = values.reshape(B, T, self.embed_dim)
        out = self.proj_o(values)
        return out, attention

In [28]:
B, T, C = 2, 16, 32
x = torch.randn(B, T, C)

mh = MultiHeadAttention(input_dim=C, embed_dim=64, n_head=8)
o, attn = mh(x)
o.shape

torch.Size([16, 64])

In [35]:
class EncoderBlock(nn.Module):

    def __init__(self, input_dim, num_heads, dim_feedforward, dropout=0.0):
        """
        Inputs:
            input_dim - Dimensionality of the input
            num_heads - Number of heads to use in the attention block
            dim_feedforward - Dimensionality of the hidden layer in the MLP
            dropout - Dropout probability to use in the dropout layers
        """
        super().__init__()

        # Attention layer
        self.self_attn = MultiHeadAttention(input_dim, input_dim, num_heads)

        # Two-layer MLP
        self.linear_net = nn.Sequential(
            nn.Linear(input_dim, dim_feedforward),
            nn.Dropout(dropout),
            nn.ReLU(inplace=True),
            nn.Linear(dim_feedforward, input_dim)
        )

        # Layers to apply in between the main layers
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Attention part
        attn_out = self.self_attn(x, mask=mask)
        x = x + self.dropout(attn_out)
        x = self.norm1(x)

        # MLP part
        linear_out = self.linear_net(x)
        x = x + self.dropout(linear_out)
        x = self.norm2(x)

        return x

In [37]:
encoder = EncoderBlock(input_dim=C, num_heads=8, dim_feedforward=64*4)

encoder(x).shape

torch.Size([2, 16, 32])

In [31]:
class MultiHeadAttention(nn.Module):

    def __init__(self, input_dim, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be 0 modulo number of heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Stack all weight matrices 1..h together for efficiency
        # Note that in many implementations you see "bias=False" which is optional
        self.qkv_proj = nn.Linear(input_dim, 3*embed_dim, bias=False)
        self.o_proj = nn.Linear(embed_dim, embed_dim)

        self._reset_parameters()

    def _reset_parameters(self):
        # Original Transformer initialization, see PyTorch documentation
        nn.init.xavier_uniform_(self.qkv_proj.weight)
        nn.init.xavier_uniform_(self.o_proj.weight)
        # self.o_proj.bias.data.fill(0)

    def forward(self, x, mask=None, return_attention=False):
        B, T, C = x.size()
        qkv = self.qkv_proj(x) # B, T, 3*EmbedDim

        # Separate Q, K, V from linear output
        qkv = qkv.reshape(B, T, self.num_heads, 3*self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3) # [B, hn, T, C]
        q, k, v = qkv.chunk(3, dim=-1)

        # Determine value outputs
        values, attention = scaled_dot_product(q, k, v, mask=mask)
        values = values.permute(0, 2, 1, 3) # [B, T, Head, C]
        values = values.reshape(B, T, self.embed_dim)
        o = self.o_proj(values)

        if return_attention:
            return o, attention
        else:
            return o




In [33]:
# Create an example input tensor
x = torch.randn(2, 4, 6)  # (B, T, C)
mha = MultiHeadAttention(input_dim=6, embed_dim=12, num_heads=3)
output = mha(x)
output.shape

torch.Size([2, 4, 12])

In [39]:
m = nn.Linear(6, 12*3)
o = m(x)
o.shape

torch.Size([2, 4, 36])

In [42]:
o.reshape(2, 4, 3, 3*4).shape

torch.Size([2, 4, 3, 12])