In [9]:
import torch
import math
# Transformer
#  Input_id (BxS) -> embedding lookup + position encoding (BxSxd) -> encoder (BxSxd) -> decoder (BxTxd) -> liniear -> softmax (BXT)

# Encoder
#  - self attention
#      - Q=Q*qw, K=K*kw, V=V*vw
#      - attention_weight = softmax ( Q * K.T ) / sqrt(d) 
#      - attention_weight * V
#  - residual + normalization
#  - FF (linear ; Dense)
#  - residual + normalization

# Decoder
#  - self attention (Q,K,V == x)
#  - residual + normalization
#  - cross attention (Q = decoder x, K,V = encoder x)
#  - residual + normalization
#  - FF (linear ; Dense)
#  - residual + normalization


# assume singlehead
# B: batch size
# S: source length
# T: target length
#   (in case of self-attention, S==T)
# d: dimension
# numheads:

# input: K (BxSxd), Q (BxTxd), V (BxSxd), attn_mask (T x S; float -inf for masking)
# internally
#           B x numheads x S or T x d/numheads

# output:  BxTxd
# 
class Attention(torch.nn.Module):
    def __init__(self, dim, numheads, dropout):
        super().__init__()
        self.dim = dim
        self.Wq = torch.nn.Linear(self.dim,self.dim)
        self.Wk = torch.nn.Linear(self.dim,self.dim)
        self.Wv = torch.nn.Linear(self.dim,self.dim)
        self.Wo = torch.nn.Linear(self.dim,self.dim)
        self.numheads = numheads
        self.dropout = dropout
        self._attn_weight_norm = 1 / math.sqrt(self.dim/self.numheads)


    def forward(self, key, value, query, attn_mask=None):
        query = self.Wq(query)
        key = self.Wk(key)
        value = self.Wv(value)

        query = Attention.split_heads(query, self.numheads)
        key = Attention.split_heads(value, self.numheads)
        value = Attention.split_heads(value, self.numheads)

        attn_weight = query @ key.transpose(-2, -1) * self._attn_weight_norm
        if attn_mask is not None:
            attn_weight += attn_mask
        attn_weight = torch.softmax(attn_weight, dim=-1)
        attn_weight = torch.dropout(attn_weight, self.dropout if self.training else 0, train=self.training)

        output = attn_weight @ value        
        output = Attention.combine_heads(output)
        output = self.Wo(output)

        return output

    @staticmethod
    def split_heads(x, heads):
        x = x.reshape(x.shape[0], x.shape[1], heads, -1)
        x = x.transpose(1, 2)
        return x

    @staticmethod
    def combine_heads(x):
        x = x.transpose(1, 2)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        return x

In [181]:
dim=10
numheads=2
dropout=0.1
a=Attention(dim, numheads, dropout)

Batch=7
S=15
T=3
key=torch.randn(Batch,S,dim)
value=torch.randn(Batch,S,dim)
query=torch.randn(Batch,T,dim)

# a.forward(key,value,query).shape

print([p.shape for p in a.parameters()])


ref=torch.randn(Batch,T,dim)


loss_fn = torch.nn.CrossEntropyLoss()
optim = torch.optim.Adam(a.parameters(), betas=(0.9, 0.999), lr=0.001)

#training loop
for i in range(1):
    print('train iter', i)
    a.train()
    optim.zero_grad()

    # effective batch
    for input, ref in zip([(key, value, query)], [ref]):
        key, value, query = input
        pred = a(key,value,query)
        loss = loss_fn(pred, ref)

    loss.backward()
    optim.step()

    print("Loss = ", loss.item())


# test
a.eval()
with torch.no_grad():
    for input, ref in zip([(key, value, query)], [ref]):
        key, value, query = input
        pred = a(key,value,query)
        loss = loss_fn(pred, ref)
        print("Test loss = ", loss.item())



[torch.Size([10, 10]), torch.Size([10]), torch.Size([10, 10]), torch.Size([10]), torch.Size([10, 10]), torch.Size([10]), torch.Size([10, 10]), torch.Size([10])]
train iter 0
Loss =  0.19376671314239502
Test loss =  0.2021348476409912
