In [None]:
import torch
import torch.nn as nn

from transformer.model import EncoderLayer
from transformer.visualization import show_heatmaps

In [None]:
import matplotlib.pyplot as plt
# From notebook
class LabelSmoothing(nn.Module):
    """Label-smoothing loss: KL divergence against a smoothed target distribution.

    ``forward`` builds the smoothed target distribution. It can plot each
    construction step, and it returns ``nn.KLDivLoss(reduction="sum")``
    between ``x`` and that target.

    Args:
        size: number of classes; ``x.size(1)`` must equal this.
        padding_idx: class index of the padding token. Its column receives no
            target mass, and rows whose target is ``padding_idx`` are zeroed.
        smoothing: mass spread over the non-target, non-padding classes; the
            target class receives ``1 - smoothing``.
        visualize: if True (the default, matching the original debugging
            behaviour), ``forward`` opens a 1x6 matplotlib figure with one
            panel per construction step and prints the padding mask.
    """

    def __init__(self, size, padding_idx, smoothing=0.1, visualize=True):
        super(LabelSmoothing, self).__init__()
        self.criterion = nn.KLDivLoss(reduction="sum")
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.visualize = visualize
        # Smoothed target distribution from the last forward() call, kept for inspection.
        self.true_dist = None

    def _show(self, panel, tensor, title):
        """Draw ``tensor`` into panel ``panel`` of the debug figure (no-op unless visualize)."""
        if not self.visualize:
            return
        plt.subplot(1, 6, panel)
        # detach/cpu so tensors that require grad (or live on GPU) can be drawn.
        plt.imshow(tensor.detach().cpu())
        plt.title(title)

    def forward(self, x, target):
        """Return the summed KL divergence between ``x`` and the smoothed target.

        Args:
            x: ``(n, size)`` tensor. ``nn.KLDivLoss`` treats its input as
                log-probabilities, so pass e.g. ``probs.log()``.
            target: ``(n,)`` integer tensor of class indices.

        Returns:
            Scalar loss tensor. The smoothed target is stored in ``self.true_dist``.
        """
        assert x.size(1) == self.size
        if self.visualize:
            plt.figure(figsize=(10, 20))
        self._show(1, x, "True dist X(clone)")

        # The smoothing mass is spread over size - 2 classes: every class
        # except the target class and the padding class.
        true_dist = torch.ones_like(x) * (self.smoothing / (self.size - 2))
        self._show(2, true_dist, "True dist (fill)")

        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        self._show(3, true_dist, "True dist (scatter)")

        # The padding class never receives probability mass.
        true_dist[:, self.padding_idx] = 0
        self._show(4, true_dist, "True dist (padding)")

        # Row indices whose target is padding, shape (num_padded, 1).
        # torch.nonzero always returns a 2-D tensor, so check numel() (not
        # dim()) to detect "no padded rows".
        mask = torch.nonzero(target == self.padding_idx)
        if self.visualize:
            plt.subplot(1, 6, 5)
            if mask.numel() > 0:
                plt.imshow(mask.detach().cpu())
            plt.title("Mask")
            print(f"Mask: {mask}")
        if mask.numel() > 0:
            # view(-1) keeps the index 1-D even for a single padded row
            # (squeeze() would turn it into a 0-D tensor).
            true_dist.index_fill_(0, mask.view(-1), 0.0)
        self._show(6, true_dist, "True dist (index fill)")
        if self.visualize:
            plt.colorbar()

        self.true_dist = true_dist
        return self.criterion(x, self.true_dist.clone().detach())

In [None]:
criterion = LabelSmoothing(5, padding_idx=0, smoothing=0.4)

# Predicted class probabilities (rows sum to 1); kept in probability space
# because later cells reuse `predict` directly.
predict = torch.FloatTensor(
    [
        [0, 0.2, 0.7, 0.1, 0],
        [0, 0.2, 0.7, 0.1, 0],
        [0, 0.2, 0.7, 0.1, 0],
        [0, 0.2, 0.7, 0.1, 0],
        [0, 0.2, 0.7, 0.1, 0],
    ]
)

# LabelSmoothing wraps nn.KLDivLoss, which expects log-probabilities as input.
# Clamp first so zero probabilities become a large negative log-prob instead
# of -inf (0 * -inf would turn the loss into nan).
v = criterion(predict.clamp(min=1e-9).log(), torch.LongTensor([2, 1, 0, 3, 3]))
print(v)

In [None]:
# Smoothed target distribution stored by the last criterion(...) call.
print(criterion.true_dist)

In [None]:
# Inspect the padding mask used by LabelSmoothing. torch.nonzero returns a
# 2-D (num_matches, 1) tensor, so mask.ndim is 2 whether or not anything
# matched; mask.numel() is what tells "no padding" apart.
target = torch.tensor([2, 1, 0, 3, 3], dtype=torch.long)
padding_idx = 0
mask = (target == padding_idx).nonzero()
mask.ndim

In [None]:
import pandas as pd

def loss(x, crit):
    """Score one prediction row against target class 1 with ``crit``.

    The row is ``[0.001, x/d, 1/d, 1/d, 1/d]`` with ``d = x + 3``, so class 1
    gets a larger share as ``x`` grows. The row is passed to ``crit`` in
    log-space, and the detached result is returned.
    """
    denom = x + 3
    probs = torch.FloatTensor([[0.001, x / denom, 1 / denom, 1 / denom, 1 / denom]])
    gold = torch.LongTensor([1])
    return crit(probs.log(), gold).data

crit = LabelSmoothing(5, 0, 0.1)
loss_data = pd.DataFrame(
    {
        "Loss": [loss(x, crit) for x in range(1, 100)],
        "Steps": list(range(99)),
    }
).astype("float")
# Each crit(...) call above opens its own debug figure. Close them so the
# inline backend does not render 99 extra figures (and matplotlib stops
# warning about too many open figures); only the loss curve is shown.
plt.close("all")

loss_data.plot(x="Steps", y="Loss");

In [None]:
# KL(p || p) should be exactly 0. `predict` contains zeros, and log(0) = -inf
# makes KLDivLoss compute 0 * -inf = nan there, so clamp before taking the log.
crit_ = nn.KLDivLoss(reduction="sum")
crit_(predict.clamp(min=1e-9).log(), predict.clone().detach())

In [None]:
# 2x3 sample from a standard normal distribution.
torch.normal(mean=0.0, std=1.0, size=(2, 3))

In [None]:
# Toy inputs, presumably (batch, n, dim): 2 batches, 1 query, 10 key/value
# pairs. Drawn in the order queries, keys, values.
queries, keys, values = (
    torch.normal(0, 1, shape) for shape in ((2, 1, 2), (2, 10, 2), (2, 10, 4))
)

# NOTE(review): DotProductAttention is not defined or imported anywhere in this
# notebook (only EncoderLayer and show_heatmaps are imported), so this raises
# NameError on a fresh kernel. Confirm where it lives and import it.
attn = DotProductAttention(dropout_prob=0.0)
attn(queries, keys, values)[0]

In [None]:
# Unscaled attention scores: (2, 1, 2) x (2, 2, 10) -> (2, 1, 10).
torch.matmul(queries, keys.transpose(-2, -1))

In [None]:
# Same scores via batched matmul; only the shape is displayed.
queries.bmm(keys.transpose(-2, -1)).size()

In [None]:
# Swapping the last two axes of keys.
print(keys.transpose(-1, -2).size())

In [None]:
# NOTE(review): `attention` is only defined in a later cell (the multi-head
# section), so this raises NameError under Restart & Run All. Move that
# definition above this cell.
attention(query=queries, key=keys, value=values)

In [None]:
# Lay the weights out as a single 2 x 10 heatmap (reshape to 1 x 1 x 2 x 10).
# Presumably the 2 rows are the 2 batch items, each with 1 query; confirm.
show_heatmaps(
    attn.attention_weights.reshape(1, 1, 2, 10),
    xlabel="Keys",
    ylabel="Queries",
)

In [None]:
attn.attention_weights.size()

In [None]:
vocab_size = 100

d_model = 64
n_heads = 4
d_k = d_model // n_heads  # per-head width: 64 // 4 = 16

# Query / key / value projections, each d_model -> d_model.
WQ = nn.Linear(d_model, d_model)
WK = nn.Linear(d_model, d_model)
WV = nn.Linear(d_model, d_model)

# Size the embedding table from vocab_size so the two cannot drift apart.
emb = nn.Embedding(vocab_size, d_model)

In [None]:
 = torch.randint(0, vocab_size, (1, 10))
print(x.shape, x)

In [None]:
# Shape of the embedded tokens.
emb(x).size()

In [None]:
# nn.Linear stores its weight as (out_features, in_features).
WQ.weight.size()

In [None]:
# Query projection of the embedded tokens: token ids -> emb -> WQ.
Q = WQ(emb(x))

In [None]:
Q.size()

In [None]:
print(Q.shape)
n_batches = 1
n_heads = 4
d_k = d_model // n_heads  # 16 for d_model = 64

# Split the projected queries into heads:
# (batch, seq, d_model) -> (batch, seq, n_heads, d_k) -> (batch, n_heads, seq, d_k).
# Use Q (= WQ(emb(x))). Applying WQ to x directly would feed integer token ids
# into nn.Linear, which fails.
Q.view(n_batches, -1, n_heads, d_k).transpose(1, 2).shape

In [None]:
class EncoderDecoder(nn.Module):
    """Minimal encoder-decoder wrapper; only the encoder path is wired up so far.

    Args:
        encoder: module called as ``encoder(x, src_mask)`` on embedded source tokens.
        decoder: stored as ``self.decoder``; not used by ``forward`` yet.
        encoder_embedding: currently unused; kept so existing callers keep working.
        src_vocab_size: number of source token ids (rows of ``src_embedding``).
        tgt_vocab_size: number of target token ids (rows of ``tgt_embedding``).
        d_model: embedding width; defaults to 64, the value used in this notebook.
    """

    def __init__(self, encoder: nn.Module, decoder: nn.Module,
                 encoder_embedding, src_vocab_size: int, tgt_vocab_size: int,
                 d_model: int = 64) -> None:
        super().__init__()
        # nn.Embedding(num_embeddings, embedding_dim): vocabulary size first, width second.
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Encode ``src``; ``tgt`` and ``tgt_mask`` are accepted but not used yet."""
        return self.encode(src, src_mask)

    def encode(self, src, src_mask):
        """Embed source token ids and run them through the encoder."""
        x = self.src_embedding(src)
        return self.encoder(x, src_mask)


def make_model(src_vocab_size, tgt_vocab_size, n_blocks=2, d_model=64, n_heads=4):
    """Build and return an encoder-only EncoderDecoder.

    ``Encoder`` is resolved when this function is called. It is the notebook's
    ``Encoder(n_blocks, d_model, num_heads)`` class.
    """
    encoder = Encoder(n_blocks=n_blocks, d_model=d_model, num_heads=n_heads)
    return EncoderDecoder(
        encoder,
        decoder=None,
        encoder_embedding=None,
        src_vocab_size=src_vocab_size,
        tgt_vocab_size=tgt_vocab_size,
        d_model=d_model,
    )

In [None]:
# One source sequence of token ids 1..10 and an all-ones mask of shape (1, 1, 10).
src = torch.tensor([list(range(1, 11))], dtype=torch.long)
src_mask = torch.ones(1, 1, 10)



In [None]:
import copy 

from transformer.model import EncoderLayer

class Encoder(nn.Module):
    """Stack of ``n_blocks`` EncoderLayer modules followed by a final LayerNorm.

    Args:
        n_blocks: number of EncoderLayer blocks.
        d_model: model width, passed to each layer and to the LayerNorm.
        num_heads: attention heads, passed to each EncoderLayer.
    """

    def __init__(self, n_blocks: int, d_model: int, num_heads: int) -> None:
        super().__init__()
        blocks = []
        for _ in range(n_blocks):
            blocks.append(EncoderLayer(d_model, num_heads))
        self.layers = nn.ModuleList(blocks)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, src_mask):
        """Run ``x`` through every layer with the same ``src_mask``, then normalize."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, src_mask)
        return self.norm(hidden)

In [None]:
d_model = 64
Wq = nn.Linear(in_features=d_model, out_features=d_model)


In [None]:
src_vocab_size = 11
d_model = 64
emb = nn.Embedding(num_embeddings=src_vocab_size, embedding_dim=d_model)
# Two source sequences of 10 token ids (the second repeats token 8).
src = torch.tensor(
    [
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        [1, 2, 3, 4, 5, 6, 7, 8, 8, 10],
    ],
    dtype=torch.long,
)
x = emb(src)  # (2, 10, d_model)
enc = Encoder(n_blocks=2, d_model=d_model, num_heads=4)
enc(x, src_mask=None).size()

In [None]:
# 1000-entry vocabulary, 128-dim vectors; look up two token ids.
embedding = nn.Embedding(num_embeddings=1000, embedding_dim=128)
embedding(torch.tensor([3, 4], dtype=torch.long)).size()


In [None]:
class PositionwiseFeedForward(nn.Module):
    """Position-wise feed-forward network (FFN, eq. 2 of "Attention Is All You Need").

    Computes ``w_2(dropout(relu(w_1(x))))`` along the last axis:
    d_model -> d_ff -> d_model.
    """

    def __init__(self, d_model: int, d_ff: int, dropout_prob=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)  # expand to the inner width
        self.w_2 = nn.Linear(d_ff, d_model)  # project back to d_model
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        hidden = torch.relu(self.w_1(x))
        hidden = self.dropout(hidden)
        return self.w_2(hidden)

In [None]:
# NOTE(review): Encoder builds EncoderLayer(d_model, num_heads), but here the
# second and a third positional argument are both None. Confirm EncoderLayer
# (from transformer.model) accepts None for these.
enc = EncoderLayer(d_model, None, None)

In [None]:
# Single-head self-attention over the embedded source tokens.
# NOTE(review): `attention` is defined in a later cell; on a fresh kernel, run
# that definition first or move it above this cell.
Wq = nn.Linear(d_model, d_model)
Wk = nn.Linear(d_model, d_model)
Wv = nn.Linear(d_model, d_model)
Wo = nn.Linear(d_model, d_model)

# Query, key and value all come from the same embeddings: (batch, seq, d_model).
query = emb(src)

Q = Wq(query)

key = emb(src)
K = Wk(key)

value = emb(src)
V = Wv(value)

x_out, _ = attention(Q, K, V)


print("after attn: ", x_out.shape)
n_batches = query.shape[0]
# With a single head, x_out is already (batch, seq, d_model). The
# transpose(1, 2) from the multi-head recipe swaps the head and sequence axes
# of a 4-D tensor. On this 3-D tensor it would swap sequence and feature axes,
# and the following view would silently scramble the values. Only reshape here.
x_out = x_out.reshape(n_batches, -1, d_model)
print(x_out.shape)
X_out = Wo(x_out)

print(X_out.shape)

In [None]:
x_out = attention(Q, K, V)[0]  # keep the attended values, drop the weights

In [None]:
x_out.transpose(2, 1).size()

In [None]:
# NOTE(review): if x_out is 3-D (single head), this transpose swaps the seq and
# feature axes. The resulting shape matches, but the values are reordered.
x_out.transpose(1, 2).reshape(n_batches, -1, d_model).size()


In [None]:
print(query.shape)
n_heads = 4
d_k = d_model // n_heads  # per-head width (16 for d_model = 64)

# Split every projection into heads:
# (batch, seq, d_model) -> (batch, seq, n_heads, d_k) -> (batch, n_heads, seq, d_k).
Q = Wq(query).view(n_batches, -1, n_heads, d_k).transpose(1, 2)
print(Wq(query).shape)

K = Wk(key).view(n_batches, -1, n_heads, d_k).transpose(1, 2)
V = Wv(value).view(n_batches, -1, n_heads, d_k).transpose(1, 2)


def attention(query, key, value, mask=None, dropout_prob=None):
    """Scaled dot-product attention: softmax(query @ key^T / sqrt(d_k)) @ value.

    matmul broadcasts over leading dims. This notebook calls it with 3-D
    (batch, seq, d_model) and 4-D (batch, heads, seq, d_k) tensors.

    Args:
        query: (..., n_queries, d_k) tensor; d_k is read from its last axis.
        key: (..., n_keys, d_k) tensor.
        value: (..., n_keys, d_v) tensor.
        mask: optional tensor broadcastable to the score shape
            (..., n_queries, n_keys); positions where ``mask == 0`` get score -1e9.
        dropout_prob: if not None, a freshly built nn.Dropout (in training mode)
            is applied to the attention weights.

    Returns:
        ``(output, p_attn)``: output is (..., n_queries, d_v); p_attn holds the
        (..., n_queries, n_keys) attention weights.

    Prints d_k and the score/attention shapes on every call (debug output).
    """
    d_k = query.size(-1)
    print(d_k)

    scale = d_k ** 0.5
    scores = query @ key.transpose(-2, -1) / scale
    print(f"scores: {scores.shape}, key: {key.shape}, query: {query.shape}")

    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = torch.softmax(scores, dim=-1)
    print(f"Attention shape: {p_attn.shape}")

    if dropout_prob is not None:
        # A newly constructed Dropout is in training mode, so it always drops.
        drop = nn.Dropout(dropout_prob)
        p_attn = drop(p_attn)
    return p_attn @ value, p_attn

In [None]:
x, probs = attention(Q, K, V)
print(x.shape)

# Merge the heads back: (batch, n_heads, seq, d_k) -> (batch, seq, n_heads * d_k).
y_ = x.transpose(1, 2).reshape(n_batches, -1, d_model)
print(y_.shape)