In [None]:
import torch
from typing import Optional
import torch.functional as F

def scaled_attentioon(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        dropout: Optional[torch.nn.Module] = None
):
    """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(d_k)) @ v``.

    The (misspelt) function name is kept as-is so existing callers keep working.

    Args:
        q: Queries; the last two dimensions are ``[q_len, d_k]``. Any leading
            batch dimensions (e.g. ``[batch, ...]``) are handled by
            ``torch.matmul`` broadcasting.
        k: Keys; the last two dimensions are ``[k_len, d_k]``.
        v: Values; the last two dimensions are ``[k_len, d_v]``.
        mask: Optional tensor broadcastable to the score shape
            ``[..., q_len, k_len]``. It is converted with ``.bool()``; entries
            that are True (non-zero) mark positions that are *blocked*, i.e.
            their score is set to ``-inf`` before the softmax.
        dropout: Optional callable (typically an ``nn.Dropout`` module) that is
            applied to the attention weights.

    Returns:
        A tuple ``(output, attn)`` where ``attn`` holds the attention weights of
        shape ``[..., q_len, k_len]`` and ``output`` is ``attn @ v``.

    Note:
        A query row whose keys are all blocked by ``mask`` has only ``-inf``
        scores, so its softmax (and therefore its output) is NaN.
    """
    d_k = q.size(-1)
    # ``q.size(-1)`` is a plain Python int, so scale with ``** 0.5``;
    # ``torch.sqrt`` only accepts tensors.
    scores = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)

    # Compare against None explicitly: the truth value of a multi-element
    # tensor is ambiguous and raises at runtime.
    if mask is not None:
        scores = scores.masked_fill(mask.bool(), float("-inf"))

    # Normalise over the key axis (last dim) so every query's weights sum to 1.
    attn = torch.softmax(scores, dim=-1)
    if dropout is not None:
        attn = dropout(attn)
    output = torch.matmul(attn, v)

    return output, attn


In [3]:
import torch.nn as nn
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over a ``[batch, seq_len, embed_dim]`` input.

    The input is projected to queries, keys and values, split into
    ``num_head`` heads of size ``embed_dim // num_head``, attended per head
    with scaled dot-product attention (no masking), merged back and passed
    through an output projection.

    Args:
        embed_dim: Size of the input and output feature dimension.
        num_head: Number of attention heads; must divide ``embed_dim``.
        dropout: Dropout probability applied to the attention weights. With a
            value of 0 (the default) no dropout module is created.
    """

    def __init__(self,embed_dim:int, num_head: int, dropout: float = 0.0):
        super().__init__()
        assert embed_dim % num_head == 0, "embed_dim must be divisible by num_head"
        self.embed_dim = embed_dim
        self.num_heads = num_head
        self.head_dim =  embed_dim // self.num_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

    def _shape(self, x: torch.Tensor):
        """Split the feature dim into heads.

        ``[batch, seq_len, embed_dim]`` -> ``[batch, num_heads, seq_len, head_dim]``.
        """
        batch_size, seq_len, _ = x.size()
        x = x.view(batch_size, seq_len, self.num_heads, self.head_dim)
        return x.transpose(1, 2)

    def forward(self, X: torch.Tensor):
        """Apply self-attention to ``X``.

        Args:
            X: Tensor of shape ``[batch, seq_len, embed_dim]``.

        Returns:
            Tensor of shape ``[batch, seq_len, embed_dim]``.
        """
        batch_size, seq_len, _ = X.size()

        Q = self._shape(self.q_proj(X))
        K = self._shape(self.k_proj(X))
        V = self._shape(self.v_proj(X))

        # [batch, heads, seq_len, seq_len]
        scores = torch.matmul(Q, K.transpose(-2,-1)) / (self.head_dim ** 0.5)

        # ``dim`` is mandatory; normalise over the key axis.
        attention = torch.softmax(scores, dim=-1)

        if self.dropout is not None:
            attention = self.dropout(attention)

        context = torch.matmul(attention, V)
        # Merge heads back: [batch, heads, seq, head_dim] -> [batch, seq, embed_dim].
        # ``contiguous()`` is needed because ``view`` cannot reinterpret the
        # transposed memory layout.
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
        return self.out_proj(context)

        



In [None]:
from torch import TensorType


class SingleHeadAttention(nn.Module):
    """Single-head causal self-attention.

    The input is projected (without bias) to queries, keys and values of size
    ``attention_dim``; each position attends only to itself and to earlier
    positions.

    Args:
        embedding_dim: Feature size of the input embeddings.
        attention_dim: Feature size of the query/key/value projections and of
            the output.
    """

    def __init__(self, embedding_dim: int, attention_dim: int):
        super().__init__()
        # NOTE: this reseeds the *global* torch RNG so the projection weights
        # are identical on every construction. Kept for reproducibility of
        # existing results; be aware it also affects any random draws made
        # after the module is created.
        torch.manual_seed(0)
        self.q_proj = nn.Linear(embedding_dim, attention_dim, bias=False)
        self.k_proj = nn.Linear(embedding_dim, attention_dim, bias=False)
        self.v_proj = nn.Linear(embedding_dim, attention_dim, bias=False)

    def scaled_dot_attention(self, Q, K, V):
        """Causally masked scaled dot-product attention.

        Args:
            Q, K, V: Tensors whose last two dimensions are ``[seq_len, d]``.

        Returns:
            ``softmax(mask(Q @ K^T / sqrt(d_k))) @ V``, where the mask blocks
            every key position that lies after the query position.
        """
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)

        seq_len = scores.size(-1)
        # Lower-triangular "allowed" mask, built on the same device as the
        # scores so masked_fill does not hit a device mismatch.
        allowed = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)).bool()
        allowed = allowed.unsqueeze(0)
        scores = scores.masked_fill(~allowed, float("-inf"))

        attention_weights = torch.softmax(scores, dim=-1)
        return torch.matmul(attention_weights, V)

    def forward(self, embedded: torch.Tensor) -> torch.Tensor:
        """Run causal self-attention on ``embedded``.

        Args:
            embedded: Input tensor whose last dimension is ``embedding_dim``.

        Returns:
            The attention output, rounded to 4 decimal places.
        """
        Q = self.q_proj(embedded)
        K = self.k_proj(embedded)
        V = self.v_proj(embedded)

        output = self.scaled_dot_attention(Q, K, V)
        return torch.round(output, decimals=4)