<a href="https://colab.research.google.com/github/eisbetterthanpi/transformer/blob/main/gpt_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# @title hf stream dataset me
!pip install -qU datasets # restart?
from datasets import load_dataset
import torch
from torch.utils.data import Dataset
import tiktoken # https://github.com/openai/tiktoken/tree/main

class StreamDataset(Dataset):
    """Map-style wrapper that serves fixed-length token windows from a streamed text dataset.

    Records are pulled from ``dataset`` one at a time, their ``"text"`` field is
    tokenized with the GPT-2 tiktoken encoding, and the tokens are appended to an
    in-memory buffer. ``__getitem__`` pops the next ``seq_len`` tokens off the front
    of that buffer; apart from ``idx == 0`` the requested index is not used.

    Args:
        dataset: iterable of records, each indexable with ``"text"``.
        seq_len: number of tokens per returned sample.
        buffer_size: target number of buffered tokens after a refill (must be ≥ seq_len).
    """
    def __init__(self, dataset, seq_len, buffer_size):
        self.enc = tiktoken.get_encoding("gpt2") # https://github.com/openai/tiktoken/blob/main/tiktoken/core.py
        self.vocab_size = self.enc.n_vocab # gpt2:50257
        self.dataset = dataset
        self.data = iter(dataset)
        self.seq_len = seq_len
        self.buffer_size = buffer_size  # must be ≥ seq_len
        self.buffer = []  # token buffer
        self.fill_buffer()

    def fill_buffer(self):
        """Tokenize records until the buffer holds ``buffer_size`` tokens or the stream ends."""
        while len(self.buffer) < self.buffer_size:
            try: x = next(self.data)
            except StopIteration: break # stream exhausted: keep what is buffered, let __getitem__ decide
            tokens = self.enc.encode(x["text"]) # tiktoken
            self.buffer.extend(tokens)

    def __len__(self):
        # Hard-coded nominal length; it is not derived from the underlying stream.
        # /4.5/(4/3)
        return 128000000
        # return self.length

    def __getitem__(self, idx):
        """Return the next ``seq_len`` buffered tokens as an int32 tensor.

        Raises:
            StopIteration: when the stream is exhausted and fewer than ``seq_len``
                tokens remain in the buffer.
        """
        # print('get', idx)
        if idx == 0: self.data = iter(self.dataset) # index 0 restarts the stream
        if len(self.buffer) < self.seq_len: self.fill_buffer()
        if len(self.buffer) < self.seq_len:
            raise StopIteration
        x = self.buffer[:self.seq_len]
        del self.buffer[:self.seq_len] # consume the window in place
        # return torch.tensor(x)
        return torch.tensor(x, dtype=torch.int32)

def collate_fn(batch):
    """Stack a sequence of equally-shaped sample tensors along a new leading batch axis."""
    stacked = torch.stack(batch, dim=0)
    return stacked

# Pick the dataset by hardware: 'Skylion007/openwebtext' when CUDA is available, otherwise 'stas/openwebtext-10k'.
name = 'Skylion007/openwebtext' if torch.cuda.is_available() else 'stas/openwebtext-10k'

# Streaming mode: records are fetched lazily instead of downloading the whole split up front.
dataset = load_dataset(name, trust_remote_code=True, split="train", streaming=True, cache_dir="/content/hf") # 8.7,3.8
# dataset = load_dataset("Skylion007/openwebtext", trust_remote_code=True, split="train", streaming=True, cache_dir="/content/hf") # 8.7,3.8
# dataset = load_dataset("deepmind/pg19", trust_remote_code=True, split="train", streaming=True, cache_dir="/content/hf") # 8.7,3.8

seq_len = 128*1 # 128
buffer_size = seq_len*1
train_data = StreamDataset(dataset, seq_len, buffer_size) # train_data = StreamDataset(dataset["train"], seq_len, buffer_size)
# del dataset

from torch.utils.data.dataloader import DataLoader
batch_size = 64 if torch.cuda.is_available() else 16 #64 512
# NOTE(review): StreamDataset.__getitem__ pops tokens sequentially and only looks at idx to test idx == 0,
# so shuffle=True does not randomise sample order. With num_workers=2 each worker holds its own copy of the
# dataset object — presumably both start from the same stream position and emit duplicate windows; TODO confirm.
train_loader = DataLoader(train_data, batch_size=batch_size, collate_fn=collate_fn, shuffle=True, pin_memory=True, num_workers=2)
del train_data # the dataset stays reachable through train_loader.dataset (used by encode/decode below)
def encode(context):
    """Tokenize text with the training dataset's tiktoken encoder.

    Args:
        context: a ``str`` (returns a ``[1, n_tokens]`` tensor on ``device``) or a
            ``list`` of strings (returns the encoder's ``encode_batch`` result).

    Raises:
        TypeError: if ``context`` is neither a ``str`` nor a ``list``.
    """
    # Validate first so a bad argument fails with a clear message before touching the loader.
    if not isinstance(context, (str, list)):
        raise TypeError(f"encode() expects a str or a list of str, got {type(context).__name__}")
    enc = train_loader.dataset.enc
    if isinstance(context, str): return torch.tensor([enc.encode(context)], device=device)
    return enc.encode_batch(context)
def decode(x):
    """Turn an iterable of token ids back into text using the training dataset's encoder."""
    tokenizer = train_loader.dataset.enc
    token_ids = list(x)
    return tokenizer.decode(token_ids)
# for x,y in train_loader:
#     break
# print(train_data.vocab_size)


In [None]:
# Rebuild the dataset and loader with a doubled window length (256 tokens), reusing the already-opened `dataset`.
seq_len = 128*2 # 128
buffer_size = seq_len*1
train_data = StreamDataset(dataset, seq_len, buffer_size) # train_data = StreamDataset(dataset["train"], seq_len, buffer_size)
# del dataset

from torch.utils.data.dataloader import DataLoader
batch_size = 64 #64 512
# NOTE(review): StreamDataset.__getitem__ only uses idx to test idx == 0, so shuffle=True does not randomise
# sample order; each of the 2 workers gets its own dataset copy — presumably yielding duplicate windows. TODO confirm.
train_loader = DataLoader(train_data, batch_size=batch_size, collate_fn=collate_fn, shuffle=True, pin_memory=True, num_workers=2)
del train_data # still reachable as train_loader.dataset
def encode(context):
    """Tokenize text with the training dataset's tiktoken encoder.

    Args:
        context: a ``str`` (returns a ``[1, n_tokens]`` tensor on ``device``) or a
            ``list`` of strings (returns the encoder's ``encode_batch`` result).

    Raises:
        TypeError: if ``context`` is neither a ``str`` nor a ``list``.
    """
    # Validate first so a bad argument fails with a clear message before touching the loader.
    if not isinstance(context, (str, list)):
        raise TypeError(f"encode() expects a str or a list of str, got {type(context).__name__}")
    enc = train_loader.dataset.enc
    if isinstance(context, str): return torch.tensor([enc.encode(context)], device=device)
    return enc.encode_batch(context)
def decode(x):
    """Turn an iterable of token ids back into text using the training dataset's encoder."""
    tokenizer = train_loader.dataset.enc
    token_ids = list(x)
    return tokenizer.decode(token_ids)


In [None]:
# @title RoPE
import torch
import torch.nn as nn
device = 'cuda' if torch.cuda.is_available() else 'cpu'

class RoPE(nn.Module): # Rotary Positional Embeddings
    """Rotary positional embedding applied to interleaved channel pairs.

    Each pair ``(x[2i], x[2i+1])`` at position ``p`` is rotated by the angle
    ``p * base**(-2i/dim)``. A rotation preserves vector norms, leaves position 0
    untouched, and makes dot products between two encoded vectors depend only on
    their position difference.

    Args:
        dim: channel size of the input (must be even).
        seq_len: number of positions to precompute; the table grows on demand.
        base: frequency base.
    """
    def __init__(self, dim, seq_len=512, base=10000):
        super().__init__()
        self.dim, self.base = dim, base
        # Non-persistent buffer: moves with .to()/.cuda() but is not written to state_dict.
        self.register_buffer("rot_emb", self._table(seq_len), persistent=False) # [seq_len, dim // 2, 2]

    def _table(self, seq_len, device=None):
        """Build the (sin, cos) lookup table of shape [seq_len, dim // 2, 2]."""
        theta = 1.0 / (self.base ** (torch.arange(0, self.dim, step=2, device=device) / self.dim))
        pos = torch.arange(seq_len, device=device).unsqueeze(1)
        angles = pos * theta # [seq_len, 1] * [dim // 2] -> [seq_len, dim // 2]
        return torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1)

    def forward(self, x):
        """Rotate ``x`` of shape [batch, seq_len, dim]; returns a tensor of the same shape."""
        batch, seq_len, dim = x.shape
        # Rebuild when the input is longer than the table or lives on another device.
        if self.rot_emb.shape[0] < seq_len or self.rot_emb.device != x.device:
            self.rot_emb = self._table(max(seq_len, self.rot_emb.shape[0]), x.device)
        sin, cos = self.rot_emb[:seq_len].unbind(-1) # each [seq_len, dim // 2], broadcast over batch
        x0, x1 = x.reshape(batch, seq_len, dim // 2, 2).unbind(-1)
        # 2-D rotation of every pair (an elementwise product with sin/cos is not a rotation).
        rot_x = torch.stack([x0 * cos - x1 * sin, x0 * sin + x1 * cos], dim=-1)
        return rot_x.flatten(-2)

# Smoke test: encode a random [4, 64, dim] batch and print the output shape.
dim=16
seq_len=512
rope = RoPE(dim, seq_len, base=10000)
x = torch.rand(4,64,dim, device=device)
out = rope(x)

print(out.shape)


In [None]:
# @title TTLinear
# Tensor Train embedding https://arxiv.org/pdf/1901.10787
import torch
import torch.nn as nn
import torch.nn.functional as F
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def make_einsum(num_tensors):
    """Build the einsum spec that contracts a chain of tensor-train cores.

    Every core gets its own free index ('a', 'b', ...) and neighbours share the
    bond index 'z', e.g. 4 cores -> 'az,zbz,zcz,zd->abcd'.
    """
    bond = 'z'
    free = [chr(ord('a') + i) for i in range(num_tensors)]
    head = 'a' + bond
    inner = [bond + name + bond for name in free[1:-1]]
    tail = bond + chr(ord('a') + num_tensors - 1)
    operands = ','.join([head, *inner, tail])
    return f"{operands}->{''.join(free)}"

class TTLinear(nn.Module):
    """Bias-free linear map whose weight matrix is stored as a chain of tensor-train cores.

    With one factor per side the cores are sized ``[in, out]``; with several factors
    core ``k`` has size ``in_features[k] * out_features[k]``. ``weight()`` contracts
    the cores and rearranges the result to ``[prod(in_features), prod(out_features)]``.
    """
    def __init__(self, in_features=None, out_features=None, rank=256, std=1):
        super().__init__()
        n_factors = len(in_features)
        self.lfeat = n_factors
        if n_factors == 1:
            core_sizes = in_features + out_features
        elif n_factors >= 2:
            core_sizes = [n_in * n_out for n_in, n_out in zip(in_features, out_features)]
        n_cores = len(core_sizes)
        # Init scale and clamp bound both depend on the number of cores.
        gain = n_cores / rank ** (1 / (2 * (std ** .5) * n_cores))
        bound = 1 / n_cores

        def new_core(*shape):
            return nn.Parameter(torch.randn(*shape).clamp(-bound, bound) * gain)

        # Created first -> middle -> last so the random draws happen in chain order.
        cores = [new_core(core_sizes[0], rank)]
        for size in core_sizes[1:-1]:
            cores.append(new_core(rank, size, rank))
        cores.append(new_core(rank, core_sizes[-1]))
        self.params = nn.ParameterList(cores)
        self.einsum_str = make_einsum(n_cores)
        # Interleaved (in0, out0, in1, out1, ...) shape of the contracted tensor.
        interleaved = []
        for pair in zip(in_features, out_features):
            interleaved.extend(pair)
        self.shape = interleaved
        # Axis order that moves all input axes in front of all output axes.
        self.permute = [2 * i for i in range(n_factors)] + [2 * i + 1 for i in range(n_factors)]

    def weight(self):
        """Contract the cores into the dense ``[prod(in), prod(out)]`` weight matrix."""
        full = torch.einsum(self.einsum_str, *self.params)
        full = full.reshape(self.shape).permute(self.permute)
        return full.flatten(0, self.lfeat - 1).flatten(1)

    def forward(self, x):
        dense = self.weight()
        return x.to(dense.dtype) @ dense

def one_hot(x, in_dim):
    """Return a boolean one-hot encoding of integer indices.

    Replaces two same-named definitions (the first was shadowed by the second, which
    only accepted 2-D input) with a single one that accepts any input shape.

    Args:
        x: integer tensor of class indices in ``[0, in_dim)``, any shape.
        in_dim: number of classes.

    Returns:
        Bool tensor of shape ``(*x.shape, in_dim)`` on ``x``'s device.
    """
    o = torch.zeros((*x.shape, in_dim), dtype=torch.bool, device=x.device)
    return o.scatter_(-1, x.unsqueeze(-1).long(), True) # scatter_ needs int64 indices

import math
class TTEmbedding(nn.Module):
    """Token embedding backed by a tensor-train factorised table.

    Indices are one-hot encoded over ``prod(in_dim)`` classes and multiplied by the
    TTLinear weight. ``weight`` is exposed as a callable (the TTLinear method) so the
    dense table can be rebuilt on demand, e.g. for weight tying.
    """
    def __init__(self, in_dim, d_model, rank=256, std=1):
        super().__init__()
        self.num_classes = math.prod(in_dim)
        self.ttlin = TTLinear(in_dim, d_model, rank, std) # https://docs.pytorch.org/docs/stable/generated/torch.nn.Embedding.html
        self.weight = self.ttlin.weight

    def forward(self, x):
        indicator = one_hot(x, self.num_classes)
        return self.ttlin(indicator)
# self.out = lambda x: x @ self.tok_emb.weight().T  # weight tying

# # in_features=(3,4,5,6); out_features=(2,3,4,5)
# in_features=[120]; out_features=[300]
# in_features=[12]; out_features=[30]
# rank=16
# # std=.5
# lin = TTLinear(in_features, out_features, rank, std).to(device)
# # x = torch.rand(4,math.prod((3,4,5,6)))
# x = torch.rand(4,7,math.prod(in_features), device=device)
# print(lin.params[0].device)
# out = lin(x)
# print(out.shape)
# print(lin.ttlin.params[0].device)

# emb = TTEmbedding(in_features, out_features, rank).to(device)
# x = torch.randint(0, math.prod(in_features), (2, 5), device=device)
# out = emb(x)
# print(out.shape)
# print(out)

# o=lin.weight
# print(o.mean().item(), o.std().item(), o.min().item(), o.max().item())

# import matplotlib.pyplot as plt
# plt.rcParams["figure.figsize"] = (4,4)
# # plt.hist(o.flatten().tolist(), bins=20, alpha=.5, label='context mask')
# # plt.hist(o[:100,:100].flatten().tolist(), bins=20, alpha=.5, label='context mask')
# x = torch.randn(100,100)*std
# # plt.hist(x.flatten().tolist(), bins=20, alpha=.5, label='context mask')
# plt.hist([o[:100,:100].flatten().tolist(), x.flatten().tolist()], bins=20, alpha=.5, label='context mask')
# plt.show()


In [None]:
# @title sliding window Attention as_strided
import torch
import torch.nn as nn
from torch.nn import functional as F

def zero_module(module):
    """Set every parameter of ``module`` to zero in place and return the same module."""
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module

class Attention(nn.Module):
    """Causal sliding-window multi-head attention with rotary position encoding.

    Every query position attends to itself and the ``w - 1`` positions before it.
    Keys and values are left-padded with ``w - 1`` zeros and unfolded into
    per-position windows; the padded slots are masked out of the softmax.

    Args:
        query_dim: channel size of ``x``.
        cond_dim: channel size of ``cond``; ``None`` means self-attention.
        n_heads, d_head: number of heads and per-head width (output width is their product).
        drop: dropout probability on attention weights and on the output.
        w: window length.

    The ``mask`` argument of ``forward`` is accepted but not used. ``cond`` must
    have the same sequence length as ``x``.
    """
    # def __init__(self, d_model, cond_dim=None, n_heads=None, d_head=8, drop=0.): # .1
    def __init__(self, query_dim, cond_dim=None, n_heads=8, d_head=8, drop=0, w=64):
        super().__init__()
        self.d_model = d_model = d_head * n_heads
        self.d_head, self.n_heads = d_head, n_heads
        self.cond_dim = cond_dim
        self.pos_enc = RoPE(d_model, base=100) # 10000
        self.q = nn.Linear(query_dim, d_model, bias=False)
        self.kv = nn.Linear(cond_dim or query_dim, 2*d_model, bias=False)
        self.lin = zero_module(nn.Linear(d_model, d_model)) # zero-init: the layer outputs zeros until trained
        self.drop = nn.Dropout(drop) # indp before q,k,v; after linout
        self.scale = self.d_head**-.5
        # torch.nn.init.normal_(self.qkv.weight, std=.02)
        # torch.nn.init.normal_(self.q.weight, std=1/(math.sqrt(query_dim)+math.sqrt(d_model)))
        # torch.nn.init.normal_(self.kv.weight, std=1/(math.sqrt(cond_dim or query_dim)+math.sqrt(d_model)))
        self.w = w

    def forward(self, x, cond=None, mask=None): # [b,t,d], [batch, num_tok, cond_dim], [b,t,t(+p)]
        b,t = x.shape[:2]
        if self.cond_dim is None: cond=x # is self attn

        q = self.pos_enc(self.q(x)).unflatten(-1, (self.n_heads, self.d_head)).transpose(1, 2) # [batch, T, d_model] -> [batch, n_heads, T, d_head]
        k, v = self.kv(cond).chunk(2, dim=-1)
        # Split k and v into heads separately, then pair them per head. Unflattening cat([k, v]) directly
        # would give the first half of the heads only k channels and the second half only v channels.
        k = self.pos_enc(k).unflatten(-1, (self.n_heads, self.d_head)).transpose(1, 2) # [b, h, t, d]
        v = v.unflatten(-1, (self.n_heads, self.d_head)).transpose(1, 2) # [b, h, t, d]
        kv = torch.cat([k, v], dim=-1) # [b, h, t, 2d]

        kv = F.pad(kv, (0,0,self.w-1,0)) # [b, h, t+w-1, 2d]
        kv = kv.unfold(dimension=-2, size=self.w, step=1).transpose(-2,-1) # [b,h,t,w,2d]

        mk, mv = kv.chunk(2, dim=-1) # [b,h,t,w,d]

        attn = torch.einsum("bhtd,bhtwd->bhtw", q, mk) * self.scale

        # Window slot j of position i maps to source position i + j - (w-1); negative ones are padding.
        pos = torch.arange(t, device=x.device)[:, None]
        slot = torch.arange(self.w, device=x.device)[None, :]
        valid = (pos + slot) >= (self.w - 1) # [t, w]; the last slot (the position itself) is always valid
        attn = attn.masked_fill(~valid, -torch.finfo(attn.dtype).max)

        attention = torch.softmax(attn, dim=-1) # [b,h,t,w]
        out = torch.einsum("bhtw,bhtwd->bhtd", self.drop(attention), mv)

        out = out.transpose(1, 2).flatten(2)
        return self.drop(self.lin(out)) # [batch, T, d_model]

# typically: gen td*td=tt
# gen t1d*twd=tw, patch pd*tpd=p

# Timing demo for the sliding-window Attention layer on a random [b, t, d] batch.
b,t,d = 64,100,512
# mask = torch.rand(b,t,t, device=device)<.2

# causal_mask = torch.tril(torch.ones((b,t,x.shape[1]), dtype=bool, device=device)) # [1,t,n_cond+n_patch] # got cond, all can attend to cond
# # mask = causal_mask * mask.unsqueeze(1) if mask!=None else causal_mask # *[b,1,t] # mask left side, tril is lower left
# mask = causal_mask * ~torch.tril(torch.ones_like(causal_mask, dtype=bool, device=device), diagonal=-64) # sliding window mask


# Random sparse mask: for every (batch, query) row mark the w columns with the largest random scores.
msk = torch.rand(b,t,t)
w = 3
_, ind = torch.topk(msk, w, dim=-1, sorted=False)

# # print(ind.shape, ind)
mask = torch.zeros_like(msk, dtype=bool)
mask[torch.arange(b)[:,None,None], torch.arange(t)[None,:,None], ind] = True


# model = Attention(d).to(device) # 257 ms
model = Attention(d, w=3).to(device) # 257 ms
x = torch.rand(b,t,d, device=device)

import time
start = time.time()
# out = model(x)
# NOTE: Attention.forward accepts `mask` but never reads it, so the mask built above has no effect here.
out = model(x, mask=mask)
print(out.shape)
print(time.time()-start)

# midx 0.04
# swa no mask ie bht1d 8e-5
# unfold .0005

# unfold 23 + error? 6.8,7.6
# as strided 24.518250703811646 # 6.8,8.8

# unfold
# this is what they just improve thesan that the characteristic is huge behavior down the artists by our treative origins of delivery real specialist is instant connected to- (which is never symptoms, or limited nationals when established surgery for developers to help for apps clearlyosc ripple shows without the choice of documentary about them by growing opportunities to work out about
# 12300 time: 23.157159328460693 0.0018825428862463565
# strain 6.319660663604736
# this is what is not closer further is in technology should not be about desk, such firmware our comic situation it haseeper moment on that is never for duty to deal with real- weaknesses.

# “The difference among Jitt natural code super departure testing is not seem to be loaded devices to allow information and then such as �
# 12400 time: 23.128992080688477 0.0018650909247258797
# strain 6.408341884613037
# this is what these moreilitating our light and love your benefit.

# What is from all in imaryichn, which are more than your regular factors can see both takes aHTp hydro Beyonees include, intoEngine frequently potential attackers to genBOOK to own up the ha portrait higher and promptingieine itemiques athlete.



In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
# Scratch comparison of F.scaled_dot_product_attention against a hand-written masked softmax attention.
# q, k, v and mask are created here so the cell runs on its own; previously q/k/v were undefined and
# `mask` was whatever an earlier cell left behind (shape [64,100,100], incompatible with t=4).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
b,h,t,d = 2,3,4,5
q = torch.rand(b,h,t,d, device=device)
k = torch.rand(b,h,t,d, device=device)
v = torch.rand(b,h,t,d, device=device)
mask = torch.rand(b,t,t, device=device)>0.5 # True = may attend
print(mask)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask.unsqueeze(1) if mask is not None else None, dropout_p=0) # mask: [batch,len_q, len_v]
# out = F.scaled_dot_product_attention(q, k, v, dropout_p=0) # mask: [batch,len_q, len_v]
# out = F.scaled_dot_product_attention(q, k, v, is_causal=True, dropout_p=0) # mask: [batch,len_q, len_v]
print(out)
scale = d**-0.5
attn = q @ k.transpose(-2,-1) * scale # [batch, n_heads, T] # [batch, n_heads, T, T/num_tok]
# if mask != None: attn = attn.masked_fill(mask[:, None, :, None], -torch.finfo(attn.dtype).max) # [batch,T]->[batch,1,T,1]
# print(attn.shape, mask.shape)
if mask is not None: attn = attn.masked_fill(~mask.unsqueeze(1), -torch.finfo(attn.dtype).max) # [b,t,t]->[b,1,t,t]
# print(attn, mask)
print(attn)
# if mask != None: attn = attn.masked_fill(mask.unsqueeze(1), 0) # [b,t,t]->[b,1,t,t]
attention = torch.softmax(attn, dim=-1)
# out = self.drop(attention) @ v # [batch, n_heads, T, d_head]
out = attention @ v # [batch, n_heads, T, d_head]
print(out)

# [0.5788, 0.1438, 0.5807, 0.4964, 0.9362],
# [0.5788, 0.1438, 0.5807, 0.4964, 0.9362],
# [0.7777, 0.0903, 0.5304, 0.6868, 0.6610],
# [0.6023, 0.4530, 0.6115, 0.6318, 0.7990]

# [0.7057, 0.3609, 0.6557, 0.6586, 0.6254],
# [0.6914, 0.3631, 0.6448, 0.6534, 0.6553],
# [0.7140, 0.3456, 0.6367, 0.6716, 0.6285],
# [0.7121, 0.3514, 0.6353, 0.6733, 0.6311]

In [None]:
# print(attn[0])
# Recompute and display the attention weights from the masked scores `attn` left by the previous cell.
attention = torch.softmax(attn, dim=-1)
print(attention)

In [None]:
# @title Attention
import torch
import torch.nn as nn
import torch.nn.functional as F
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def zero_module(module):
    """Set every parameter of ``module`` to zero in place and return the same module."""
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module

class MultiHeadAttention(nn.Module):
    """Multi-head softmax attention with rotary position encoding on queries and keys.

    Args:
        query_dim: channel size of ``x``.
        cond_dim: channel size of ``cond``; ``None`` means self-attention (keys and values come from ``x``).
        n_heads, d_head: number of heads and per-head width (output width is their product).
        drop: dropout probability on attention weights and on the output.

    ``forward`` takes ``x`` [b, t, query_dim], optional ``cond`` [b, s, cond_dim] and an
    optional boolean ``mask`` [b, t, s] where True means "may attend".
    """
    # def __init__(self, d_model, n_heads=None, d_head=8, cond_dim=None, drop=0.): # .1
    def __init__(self, query_dim, cond_dim=None, n_heads=8, d_head=64, drop=0):
        super().__init__()
        # self.d_model, self.n_heads, self.d_head = d_model, n_heads, d_model // n_heads
        d_model = d_head * n_heads
        self.d_head, self.n_heads = d_head, n_heads
        self.cond_dim = cond_dim
        self.pos_enc = RoPE(d_model, base=10000) # 10000
        self.q = nn.Linear(query_dim, d_model, bias=False)
        # In self-attention k/v are projected from x, whose width is query_dim (not d_model).
        self.kv = nn.Linear(cond_dim or query_dim, 2*d_model, bias=False)
        self.lin = zero_module(nn.Linear(d_model, d_model)) # zero-init: the layer outputs zeros until trained
        self.drop = nn.Dropout(drop) # indp before q,k,v; after linout
        self.scale = self.d_head**-.5
        # torch.nn.init.normal_(self.q.weight, std=.02)
        # torch.nn.init.normal_(self.kv.weight, std=.02)
        # torch.nn.init.normal_(self.q.weight, std=1/(math.sqrt(query_dim)+math.sqrt(d_model)))
        # torch.nn.init.normal_(self.kv.weight, std=1/(math.sqrt(cond_dim or query_dim)+math.sqrt(d_model)))

    def forward(self, x, cond=None, mask=None): # [batch, T, d_model]=[batch, h*w, c], [batch, num_tok, cond_dim], [batch,T]
        if self.cond_dim is None: cond=x # is self attn

        q = self.pos_enc(self.q(x)).unflatten(-1, (self.n_heads, self.d_head)).transpose(1, 2) # [batch, T, d_model] -> [batch, n_heads, T, d_head]
        k, v = self.kv(cond).chunk(2, dim=-1)
        k = self.pos_enc(k).unflatten(-1, (self.n_heads, self.d_head)).transpose(1, 2)
        v = v.unflatten(-1, (self.n_heads, self.d_head)).transpose(1, 2)

        # # (quadratic) attention # Softmax(q @ k.T) @ v
        # out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask.unsqueeze(1) if mask != None else None, dropout_p=0) # mask: [batch,len_q, len_v]
        # # out = F.scaled_dot_product_attention(q, k, v, is_causal=True, dropout_p=0) # mask: [batch,len_q, len_v]
        attn = q @ k.transpose(-2,-1) * self.scale # [batch, n_heads, T] # [batch, n_heads, T, T/num_tok]
        # Disallowed positions get the most negative finite value so softmax gives them ~0 weight.
        if mask is not None: attn = attn.masked_fill(~mask.unsqueeze(1), -torch.finfo(attn.dtype).max) # [b,t,t]->[b,1,t,t]
        attention = torch.softmax(attn, dim=-1)
        out = self.drop(attention) @ v # [batch, n_heads, T, d_head]

        out = out.transpose(1,2).flatten(2)
        return self.drop(self.lin(out)) # [batch, T, d_model]


class SwiGLU(nn.Module): # https://arxiv.org/pdf/2002.05202
    """Gated feed-forward layer: one projection yields a value half and a gate half,
    the value is multiplied by SiLU(gate) and projected back by a zero-initialised layer."""
    def __init__(self, d_model, ff_dim): # d_model * 3*ff_dim params
        super().__init__()
        self.lin0 = nn.Linear(d_model, 2*ff_dim, bias=False)
        self.lin1 = zero_module(nn.Linear(ff_dim, d_model, bias=False))
        # torch.nn.init.normal_(self.lin0.weight, std=.02)
        init_std = 1/(math.sqrt(d_model)+math.sqrt(ff_dim))
        torch.nn.init.normal_(self.lin0.weight, std=init_std)

    def forward(self, x): # [b,t,d]
        value, gate = self.lin0(x).chunk(2, dim=-1)
        gated = value*F.silu(gate)
        return self.lin1(gated)

# 2048*2
# 2048*7
# ff: d_model*2 *ff_dim params

class AttentionBlock(nn.Module):
    """Residual transformer block: RMS-normalised attention followed by a normalised ReLU feed-forward.

    With ``cond_dim`` set, the block cross-attends to an RMS-normalised ``cond``;
    otherwise it self-attends. The final feed-forward projection is zero-initialised.
    """
    # def __init__(self, d_model, cond_dim=None, d_head, ff_dim=None, drop=0.):
    def __init__(self, d_model, n_heads, cond_dim=None, ff_dim=None, drop=0):
        super().__init__()
        self.d_model, self.cond_dim = d_model, cond_dim
        self.norm1 = nn.RMSNorm(d_model) # LayerNorm RMSNorm
        if cond_dim is not None:
            self.norm2 = nn.RMSNorm(cond_dim)
        self.attn = MultiHeadAttention(d_model, cond_dim, n_heads=n_heads, d_head=d_model//n_heads, drop=drop)
        # self.attn = Attention(d_model, cond_dim, n_heads=n_heads, d_head=d_model//n_heads, drop=drop, w=64)
        hidden = d_model*4 if ff_dim is None else ff_dim
        ff_layers = [
            nn.RMSNorm(d_model),
            nn.Linear(d_model, hidden),
            nn.ReLU(),
            nn.RMSNorm(hidden),
            nn.Dropout(drop),
            zero_module(nn.Linear(hidden, d_model)),
        ]
        self.ff = nn.Sequential(*ff_layers)
        # Re-initialise the first feed-forward projection with a fan-dependent std.
        torch.nn.init.normal_(self.ff[1].weight, std=1/(math.sqrt(d_model)+math.sqrt(hidden)))
        # self.ff = SwiGLU(d_model, ff_dim)

    def forward(self, x, cond=None, mask=None): # [b,c,h,w], [batch, num_tok, cond_dim], [batch,T]
        normed = self.norm1(x)
        if self.cond_dim is None:
            attended = self.attn(normed, mask=mask)
        else:
            attended = self.attn(normed, self.norm2(cond), mask) # maybe no res for decoder
        x = x + attended
        return x + self.ff(x) # maybe no ff for decoder?

import inspect
class Seq(nn.Sequential):
    """``nn.Sequential`` that forwards ``cond`` / ``mask`` only to layers whose ``forward`` declares them.

    Parameter names are matched exactly and the extras are passed by keyword, so a
    parameter such as ``attn_mask`` is not mistaken for ``mask`` and the position of
    ``cond`` / ``mask`` in a layer's signature does not matter.
    """
    def __init__(self, *args):
        super().__init__(*args)
        for layer in self:
            # Cache the exact parameter names of each layer's forward().
            layer._fwdparams = tuple(inspect.signature(layer.forward).parameters)

    def forward(self, x, cond=None, mask=None):
        for layer in self:
            extra = {}
            if 'cond' in layer._fwdparams: extra['cond'] = cond
            if 'mask' in layer._fwdparams: extra['mask'] = mask
            x = layer(x, **extra)
        return x

# Smoke test: run one AttentionBlock on a random batch with a boolean [b,t,t] attention mask.
b,t,d = 2,5,16
x = torch.rand(b,t,d, device=device)
mask = torch.rand(b,t,t, device=device)>0
model = AttentionBlock(d_model=d, n_heads=4, ff_dim=16).to(device)
# model = nn.Sequential(*[AttentionBlock(d_model=d, n_heads=4, ff_dim=16) for _ in range(2)])
# Pass the mask by keyword: the second positional parameter of forward() is `cond`, not `mask`.
out =  model(x, mask=mask)
# out =  model(x)
print(out.shape)


# # model = MultiHeadAttention(d).to(device) # 257 ms
# x = torch.rand(b,t,d, device=device)

# import time
# start = time.time()
# # out = model(x)
# out = model(x, mask=mask)
# print(out.shape)
# print(time.time()-start)



In [None]:
# @title GPT
import torch
import torch.nn as nn
device = 'cuda' if torch.cuda.is_available() else 'cpu'

class GPT(nn.Module):
    """Decoder-only language model: tensor-train token embedding, causal attention blocks, tied output head.

    Args:
        in_dim: vocabulary size. NOTE: not used — the embedding is factorised as 29 * 1733 = 50257 rows.
        d_model: model width (also the embedding width).
        out_dim: not used; logits are produced with the transposed embedding table (weight tying).
        n_heads, n_layers: attention heads per block and number of blocks.
        ff_dim: hidden width of each block's feed-forward.
        dropout: dropout probability inside each block.

    ``forward`` takes token ids [b, t] and an optional key mask [b, t] (True = may be
    attended to) and returns logits [b, t, 50257].
    """
    def __init__(self, in_dim, d_model=64, out_dim=None, n_heads=8, n_layers=1, ff_dim=256, dropout=0):
        super().__init__()
        self.d_model = d_model
        self.pos_enc = RoPE(d_model, base=10000)
        # self.tok_emb = nn.Embedding(in_dim, d_model)
        # Output factors follow d_model so the embedding width always matches the blocks.
        self.tok_emb = TTEmbedding([29, 1733], [d_model, 1], rank=min(d_model,256))
        # ff_dim and dropout are forwarded so the constructor arguments actually take effect.
        self.encoder = Seq(*[AttentionBlock(d_model, n_heads=n_heads, cond_dim=d_model, ff_dim=ff_dim, drop=dropout) for _ in range(n_layers)])
        # self.out = nn.Linear(d_model, out_dim)
        self.out = lambda x: x @ self.tok_emb.weight().T  # weight tying
        self.w = 64

    def forward(self, x, mask=None): # [b,t], [b,t]
        # x = self.pos_enc(self.tok_emb(x))
        x = self.tok_emb(x)
        b,t,d = x.shape

        # Lower-triangular mask built on the input's device: position i may attend to positions <= i.
        causal_mask = torch.tril(torch.ones((b,t,t), dtype=bool, device=x.device))
        if mask is not None: mask = causal_mask & mask.unsqueeze(1).bool() # [b,t,t] & [b,1,t]
        else: mask = causal_mask

        x = self.encoder(x, x, mask=mask) # [b,t,d], [nlyr,b,d]
        x = self.out(x)
        return x

# gpt 2
# Parameters Layers dmodel
# 117M 12 768 gpt1
# 345M 24 1024
# 762M 36 1280
# 1542M 48 1600


# Take the vocabulary size from the data loader when it exists, otherwise fall back to a toy value.
try: vocab_size=train_loader.dataset.vocab_size#50
except NameError: vocab_size=50
# model = GPT(input_size, d_model=512, out_dim=num_classes, n_layers=6).to(device)
# model = GPT(vocab_size, d_model=64, out_dim=vocab_size, n_layers=3).to(device)
# NOTE: GPT.__init__ does not read in_dim/out_dim; its embedding is fixed at 29*1733 = 50257 rows.
model = GPT(vocab_size, d_model=64, out_dim=vocab_size, n_layers=1).to(device)
print(sum(p.numel() for p in model.parameters() if p.requires_grad)) # 19683
optim = torch.optim.AdamW(model.parameters(), 1e-3)

# x = torch.randint(0, input_size, (2, 5), device=device)
# x = torch.randint(0, input_size, (64, 128), device=device)
# out = model(x)
# print(out.shape)
# print(out)

# strain 5.660660743713379
# this is what we've done it. I'm able to be something strong. I think their information can be highly specific to a busy intertwined. But why "Effects"/ who shouldn't do," suppose it would improve them.Earlier open in both countries is sent by documentation, who were involved by custom ways to cutting the carbon gas
# strain 5.939525604248047
# this is what I’d like he’s very important to learn, it’s a bug my liquid had nothing once his streaming user in their own. It’s enough to run the survey, if there’s also no very big interested in a very $2. It’s sufficient's
# strain 5.96616792678833
# this is what manyatform are neither even documented and looked down their intentions in writers in so far savigTS-ve nuite was never known that some dioxide the sizes celebrate was fighting, as guilty or innovation.672 Expressia lives to Islamic Cheney of your existence takes a free background rest: The book she probably read it or find
# strain 5.811507225036621
# this is what it has not seen in intelligence and have believed in a romantic order and possible and rapidly now. Many growing economic ear is unclear whether they can do that the country's job. According to a proposals, the majority of domestic trade on which are the greatest mission with rich law by people, violated stockrol, which they would


# b64,l128 8.0ram,10.0gpu l128*2:oom

# b64,l128 w64 8.2ram,6.8gpu
# b64,l128*2 w64 8.1ram,14.0gpu

# gpt 1lyr mha ropein100 @attn # low loss but not learning
# gpt 1lyr mha ropeout10000 F.attn # 19.3s
# gpt 1lyr mha ropein10000 @attn # not learning
# gpt 1lyr mha ropein10000 F.attn # this is what you are something else all in course, you're simply went back over their news, the reaction of the French source that is proven as we did not change any life at each power because they have the which shall noteded in order to change.




In [None]:
# Ask PyTorch's CUDA allocator to release its cached, currently unused memory.
torch.cuda.empty_cache()

In [None]:
# @title wandb
!pip install -q wandb
import wandb # https://docs.wandb.ai/quickstart
# Log in to Weights & Biases without embedding a secret in the notebook.
# The key is read from the WANDB_API_KEY environment variable; when it is unset, key=None lets
# wandb fall back to its own credential lookup / interactive prompt.
# SECURITY: the API key that used to be hard-coded here was committed to source and should be revoked.
import os
wandb.login(key=os.environ.get("WANDB_API_KEY"))
# Close a run left over from a previous execution of this cell, if any.
try: run.finish()
except NameError: pass
run = wandb.init(project="gpt", config={"model": "res18",}) #

In [None]:
# @title train test generate
import torch
from torch.nn import functional as F
# Allow TF32 in CUDA matmuls and cuDNN ops.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Module-level gradient scaler; strain() below scales the loss and steps the optimizer through it.
scaler = torch.GradScaler()

# https://www.comet.com/site/blog/perplexity-for-llm-evaluation/
def Perplexity(logits, target): # [b,t,vocab_size], [b,t]
    """Return exp(mean negative log-likelihood) of ``target`` under ``logits`` as a 0-dim tensor."""
    token_log_probs = F.log_softmax(logits, dim=-1).gather(dim=-1, index=target.unsqueeze(-1))
    nll = token_log_probs.squeeze(-1).neg() # [b,t]
    return torch.exp(nll.mean())

import time

def strain(model, dataloader, optimizer, scheduler=None): # train function with automatic mixed precision
    """Train ``model`` for one pass over ``dataloader`` with bfloat16 autocast.

    Each batch of token ids [b, t] is split into inputs ``x[:, :-1]`` and next-token
    targets ``x[:, 1:]``. Every 100 steps the loss, a generated sample and timing are
    printed: seconds since the previous report and average seconds per step since the
    start of the call.

    Args:
        model: network mapping token ids [b, t] to logits [b, t, vocab].
        dataloader: iterable of token-id batches.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: optional LR scheduler, stepped once per optimizer step.
    """
    begin = last_report = time.time()
    model.train()
    for i, x in enumerate(dataloader):
        x, y = x[:,:-1], x[:,1:]
        x, y = x.to(device), y.to(device)
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            # causal_mask = torch.ones(x.size(1), x.size(1), dtype=bool, device=device).tril(diagonal=0).repeat(x.shape[0],1,1) # for F.scaled_dot_product_attention
            # # causal_mask = ~torch.ones(x.size(1), x.size(1), dtype=bool, device=device).tril(diagonal=0).repeat(x.shape[0],1,1)
            # logits = model(x, mask=causal_mask) #output = [batch size, trg len - 1, output dim]
            logits = model(x) #output = [batch size, trg len - 1, output dim]
            loss = F.cross_entropy(logits.flatten(0,1), y.flatten().to(int)) # [b*t,d], [b*t]
            # loss = F.cross_entropy(logits.flatten(0,1), y.flatten()) # [b*t,d], [b*t]

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        # scaler.unscale_(optim)
        # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        scaler.step(optimizer)
        scaler.update()
        if scheduler is not None: scheduler.step()
        loss_val = loss.item()
        if i % 100 == 0:
            print("strain",loss_val)
            print(generate(model, "this is what"))
            model.train() # generate() switches the model to eval mode
            # perplexity = Perplexity(logits.detach(), y).item()
            now = time.time()
            # `begin` is never reset, so the second figure is a true per-step average.
            print(i, 'time:',now - last_report, (now-begin)/(i+1))
            last_report = now
        # cross_entropy already returns the mean over tokens; log it as is.
        try: wandb.log({"train loss": loss_val})
        except NameError: pass # wandb was not imported in this session

def generate(model, context, max_steps=64, temperature=1):
    """Sample ``max_steps`` tokens after ``context`` and return the decoded text.

    The model is put in eval mode; at each step the last position's logits are divided
    by ``temperature``, turned into probabilities and sampled with ``torch.multinomial``.
    """
    tokens = encode(context)#.to(device)
    model.eval()
    for _ in range(max_steps):
        with torch.no_grad():
            logits = model(tokens)
        last_logits = logits[:, -1] / temperature # logits of the final position only
        probs = F.softmax(last_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1) # rand sample by output distribution
        tokens = torch.cat((tokens, next_token), dim=1)
    return decode(tokens.squeeze(0))

# import time
# start = begin = time.time()
# Run a single training pass; strain() prints its own progress, samples and timing.
for i in range(1):
    # train_loss = strain(model, train_loader, optim, scheduler=None)
    strain(model, train_loader, optim, scheduler=None)
    # print(generate(model, "this is what"))
    # print(i, 'time:',time.time() - start, (time.time()-begin)/(i+1))
    # start = time.time()



In [None]:
# Print one sampled continuation of a fixed prompt from the current model.
print(generate(model, "this is what"))

## drawer

In [None]:
# @title RNN pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"

class RNN(nn.Module):
    """Token-level GRU language model: embedding -> GRU -> linear head.

    Args:
        in_dim: vocabulary size of the input tokens.
        d_model: embedding and hidden width.
        out_dim: number of output classes (defaults to ``in_dim``).
        num_layers: number of stacked GRU layers.

    ``forward`` returns ``(logits [batch, seq, out_dim], hidden [num_layers, batch, d_model])``.
    """
    def __init__(self, in_dim, d_model, out_dim=None, num_layers=1):
        super().__init__()
        if out_dim is None: out_dim = in_dim
        self.num_layers = num_layers
        self.d_model = d_model
        self.emb = nn.Embedding(in_dim, d_model)
        # self.rnn = nn.RNN(d_model, d_model, num_layers, batch_first=True)
        self.rnn = nn.GRU(d_model, d_model, num_layers, batch_first=True)
        # self.lstm = nn.LSTM(d_model, d_model, num_layers, batch_first=True)
        self.fc = nn.Linear(d_model, out_dim)

        # for p in self.parameters():
        #     if p.dim() > 1:
        #         nn.init.xavier_uniform_(p)

    # The earlier LSTM variant of forward() was removed: it was shadowed by the definition
    # below and referenced self.lstm, which is never created.
    def forward(self, x, h=None): # rnn/gru
        """Run the GRU over token ids ``x`` [batch, seq], starting from hidden state ``h`` (zeros if None)."""
        x = self.emb(x)# * self.d_model**.5
        # Build the initial state on the input's device so the module works wherever it was moved.
        if h is None: h0 = torch.zeros((self.num_layers, x.size(0), self.d_model), device=x.device)
        else: h0 = h
        # print(x.shape, h0.shape)
        out, h = self.rnn(x, h0)
        # out = out[:, -1, :] # out: (n, 128)
        out = self.fc(out) # out: (n, 10)
        return out, h

# Build a small GRU model on a toy 50-token vocabulary and an AdamW optimizer for it.
# Note: this rebinds the global `model` and `optim` names used by the GPT cells above.
hidden_size = 64 #128
num_layers = 1#2
input_size = num_classes = 50#train_data.vocab_size#65

model = RNN(input_size, hidden_size, num_classes, num_layers).to(device)
# print(model)
print(sum(p.numel() for p in model.parameters() if p.requires_grad)) # 19683
optim = torch.optim.AdamW(model.parameters(), 1e-3)

# 128,2
# Test Loss: 6.360389362062727
# this is what ween new york is well it a sign more directly into simply shares
# 0 time: 5.910429954528809 5.910431623458862
# 64,2
# Test Loss: 6.167358561924526
# this is what bull morp on its aftant opereals of a b. the hamally plans beces
# 0 time: 5.655158996582031 5.655160903930664
# 64,1
# Test Loss: 5.823708357129778
# this is what achan agrive
#  the guinesst on of promjects cl funds that jound
# 0 time: 4.788918495178223

# Test Loss: 8.247572830745153
# this is what that has months with unfinsings lide to N by well
#  next offited
# 29 time: 3.902247905731201 4.377542002995809

# Smoke test: feed random token ids and a random initial hidden state, then print the output shapes.
b,t=2,5

x = torch.randint(0,input_size, (b,t), device=device)
h = torch.rand(num_layers, b, hidden_size, device=device)
out, h = model(x, h)
print(out.shape, h.shape)


In [None]:
# @title mask translate
PAD_IDX=0
def make_src_mask(src):
    """Return a boolean mask that is True wherever ``src`` is not the padding index.

    For ``src`` of shape [batch, src_len] the result is [batch, 1, 1, src_len] on ``device``.
    """
    not_pad = src.ne(PAD_IDX)
    return not_pad.unsqueeze(1).unsqueeze(2).to(device) # [batch_size, 1, 1, src_len]?

# attn = attn.masked_fill(mask == 0, -1e10) # [batch, n_heads, seq_len, seq_len]
def make_trg_mask(trg):
    """Combine the padding mask with a lower-triangular (causal) mask.

    For a [batch, trg_len] input the padding part is [batch, 1, 1, trg_len]
    and the causal part is [trg_len, trg_len]; their AND broadcasts to
    [batch, 1, trg_len, trg_len]. Both parts live on the global `device`.
    """
    steps = trg.shape[1]
    # Position i may only look at positions <= i.
    causal = torch.ones((steps, steps), device=device).tril().bool()
    not_pad = (trg != PAD_IDX)[:, None, None, :].to(device)
    return not_pad & causal

def translate(model, src_sentence):
    """Greedily decode a translation of `src_sentence`.

    The whole model is re-run on the growing target prefix at every step,
    for at most ``source length + 5`` steps, stopping early at EOS.

    Args:
        model: callable as ``model(src, trg, src_mask, trg_mask)`` returning
            logits indexed [batch, step, vocab].
        src_sentence: raw source string, tokenised by ``de_transform``.

    Returns:
        The predicted target tokens joined with single spaces (no BOS/EOS).

    Note: the model is left in eval mode.
    """
    model.eval()
    src = de_transform(src_sentence).view(1,-1).to(device)
    src_mask = make_src_mask(src)  # depends only on src, so build it once
    trg_indexes = [BOS_IDX]
    max_len = src.shape[1]+5
    with torch.no_grad():
        for _ in range(max_len):
            trg_tensor = torch.tensor(trg_indexes, dtype=torch.long, device=device).unsqueeze(0)
            trg_mask = make_trg_mask(trg_tensor)
            output = model(src, trg_tensor, src_mask, trg_mask)
            pred_token = output.argmax(2)[:,-1].item() # most likely token at the last step
            # Stop *before* storing EOS. The previous version always sliced off
            # the final element, which threw away a real token whenever the loop
            # ran out of steps without the model emitting EOS.
            if pred_token == EOS_IDX: break
            trg_indexes.append(pred_token)
    return " ".join(en_vocab.lookup_tokens(trg_indexes[1:]))

def translate_fast(model, src_sentence):
    """Greedily decode a translation, encoding the source only once.

    Unlike ``translate`` the encoder runs a single time; each step only
    re-runs the decoder on the growing target prefix against that memory.

    Args:
        model: object exposing ``encode(src, src_mask)`` and
            ``decode(trg, memory, trg_mask, src_mask)``.
        src_sentence: raw source string, tokenised by ``de_transform``.

    Returns:
        The predicted target tokens joined with single spaces (no BOS/EOS).

    Note: the model is left in eval mode.
    """
    model.eval()
    with torch.no_grad():
        src = de_transform(src_sentence).view(1,-1).to(device)
        trg_indexes = [BOS_IDX]
        max_len = src.shape[1]+5
        src_mask = make_src_mask(src)
        # Keep the encoder output: it is the memory the decoder attends to.
        memory = model.encode(src, src_mask)
        for _ in range(max_len):
            trg_tensor = torch.tensor(trg_indexes, dtype=torch.long, device=device).unsqueeze(0)
            trg_mask = make_trg_mask(trg_tensor)
            # Argument order follows decode(trg, memory, trg_mask, src_mask).
            # The previous call passed (src, trg_tensor, src_mask, trg_mask),
            # i.e. decoded the source against raw target ids and never used
            # the encoder output.
            output = model.decode(trg_tensor, memory, trg_mask, src_mask)
            pred_token = output.argmax(2)[:,-1].item() # most likely token at the last step
            # Stop before storing EOS so that a run which hits max_len without
            # EOS does not lose its final real token.
            if pred_token == EOS_IDX: break
            trg_indexes.append(pred_token)
        return " ".join(en_vocab.lookup_tokens(trg_indexes[1:]))

# UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3 # unknown, pad, beginning, end of sentence
# print(translate(model, "Eine Gruppe von Menschen steht vor einem Iglu ."))

# Shape check for the mask helpers on a random batch of 64 sequences of length 10.
src, trg = torch.randint(0, 100, (64, 10)), torch.randint(0, 100, (64, 10))
sm = make_src_mask(src)
tm = make_trg_mask(trg)
# print(sm.shape, tm.shape) # [64, 1, 1, 10], [64, 1, 10, 10]


In [None]:
# @title translate train test

def train(model, dataloader, optimizer, loss_fn):
    """Run one teacher-forced training epoch and return the mean batch loss.

    Args:
        model: called as ``model(src, trg_input, src_mask, trg_mask)``.
        dataloader: iterable of ``(src, trg)`` batches, batch dimension first.
        optimizer: optimiser stepped once per batch.
        loss_fn: called with logits flattened to [N, vocab] and targets to [N].

    Returns:
        Sum of per-batch losses divided by the number of batches, or NaN when
        the dataloader yields no batches.
    """
    model.train()
    total_loss = 0.0
    n_batches = 0
    for src, trg in dataloader:
        src, trg = src.to(device), trg.to(device) #trg = [batch size, trg len]
        # Teacher forcing: feed tokens [0, L-1) and predict tokens [1, L).
        trg_input = trg[:,:-1]
        src_mask, trg_mask = make_src_mask(src), make_trg_mask(trg_input)
        output = model(src, trg_input, src_mask, trg_mask) #output = [batch size, trg len - 1, output dim]
        optimizer.zero_grad()
        loss = loss_fn(output.reshape(-1, output.shape[-1]), trg[:,1:].reshape(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1
    # Count batches while iterating: len(list(dataloader)) walked the whole
    # dataset (tokenising and collating every batch) a second time.
    return total_loss / n_batches if n_batches else float("nan")

def test(model, dataloader, loss_fn):
    """Evaluate the model with teacher forcing and return the mean batch loss.

    Args:
        model: called as ``model(src, trg_input, src_mask, trg_mask)``.
        dataloader: iterable of ``(src, trg)`` batches, batch dimension first.
        loss_fn: called with logits flattened to [N, vocab] and targets to [N].

    Returns:
        Sum of per-batch losses divided by the number of batches, or NaN when
        the dataloader yields no batches. No gradients are computed.
    """
    model.eval()
    epoch_loss = 0.0
    n_batches = 0
    with torch.no_grad():
        for src, trg in dataloader:
            src, trg = src.to(device), trg.to(device) #trg = [batch size, trg len]
            trg_input = trg[:,:-1]
            src_mask, trg_mask = make_src_mask(src), make_trg_mask(trg_input)
            output = model(src, trg_input, src_mask, trg_mask) #output = [batch size, trg len - 1, output dim]
            loss = loss_fn(output.reshape(-1, output.shape[-1]), trg[:,1:].reshape(-1))
            epoch_loss += loss.item()
            n_batches += 1
    # Count batches while iterating instead of materialising the loader again.
    return epoch_loss / n_batches if n_batches else float("nan")

# @title run
import time

# Targets equal to PAD_IDX do not contribute to the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index = PAD_IDX)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9) # lr=0.0001

# Train/validate for the configured number of epochs, then print a few sample
# translations as a qualitative check (expected English is in the trailing comments).
# for epoch in range(20):
for epoch in range(1):
    start_time = time.time()
    train_loss = train(model, train_loader, optimizer, loss_fn)
    val_loss = test(model, val_loader, loss_fn)
    end_time = time.time()
    print((f"Epoch: {epoch+1}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))
    # print(translate(model, "Eine Gruppe von Menschen steht vor einem Iglu .")) # A group of people standing in front of an igloo .
    # @title inference
    print(translate(model, "Eine Gruppe von Menschen steht vor einem Iglu .")) # A group of people stand in front of an igloo .
    print(translate(model, "Ein Koch in weißer Uniform bereitet Essen in einer Restaurantküche zu .")) # A chef in a white uniform prepares food in a restaurant kitchen .
    print(translate(model, "Zwei junge Mädchen spielen Fußball auf einem Feld. .")) # Two young girls play soccer on a field. .
    print(translate(model, "Eine Frau mit Hut und Sonnenbrille steht am Strand .")) # A woman wearing a hat and sunglasses stands on the beach .
    print(translate(model, "Zwei Freunde lachen und genießen ein Eis auf einer wunderschönen Wiese .")) # Two friends laugh and enjoy ice cream on a beautiful meadow .


In [None]:
'''
Epoch: 1, Train loss: 5.402, Val loss: 4.186, Epoch time = 41.608s
A group of people are are are are are are in a .
Epoch: 2, Train loss: 3.898, Val loss: 3.545, Epoch time = 41.068s
A group of people are standing in a crowd of people .
Epoch: 3, Train loss: 3.353, Val loss: 3.125, Epoch time = 41.566s
A group of people standing in front of a crowd .
Epoch: 4, Train loss: 2.944, Val loss: 2.830, Epoch time = 40.756s
A group of people standing in front of a building .
Epoch: 5, Train loss: 2.630, Val loss: 2.596, Epoch time = 41.468s
A group of people standing in front of a crowd .
Epoch: 6, Train loss: 2.375, Val loss: 2.429, Epoch time = 41.023s
A group of people standing in front of a house .
Epoch: 7, Train loss: 2.166, Val loss: 2.307, Epoch time = 41.604s
A group of people stand in front of a house .
Epoch: 8, Train loss: 1.984, Val loss: 2.210, Epoch time = 40.876s
A group of people stand in front of an audience .
Epoch: 9, Train loss: 1.834, Val loss: 2.131, Epoch time = 41.496s
A group of people are standing in front of an audience .
Epoch: 10, Train loss: 1.698, Val loss: 2.079, Epoch time = 41.052s
A group of people are standing in front of an empty house .
Epoch: 11, Train loss: 1.576, Val loss: 2.038, Epoch time = 41.570s
A group of people are standing in front of an audience .
Epoch: 12, Train loss: 1.475, Val loss: 2.033, Epoch time = 41.067s
A group of people stand in front of an audience .
Epoch: 13, Train loss: 1.381, Val loss: 2.017, Epoch time = 41.576s
A group of people stand in front of an operation .
Epoch: 14, Train loss: 1.292, Val loss: 1.977, Epoch time = 40.779s
A group of people stand in front of an operation .
Epoch: 15, Train loss: 1.213, Val loss: 1.948, Epoch time = 41.461s
A group of people stand in front of an operation .
Epoch: 16, Train loss: 1.139, Val loss: 1.940, Epoch time = 40.800s
A group of people stand in front of an igloo
Epoch: 17, Train loss: 1.073, Val loss: 1.940, Epoch time = 41.533s
A group of people stand in front of an igloo
Epoch: 18, Train loss: 1.006, Val loss: 1.939, Epoch time = 40.856s
A group of people stand in front of an igloo
Epoch: 19, Train loss: 0.944, Val loss: 1.947, Epoch time = 41.345s
A group of people stand in front of an igloo
Epoch: 20, Train loss: 0.893, Val loss: 1.956, Epoch time = 40.935s
A group of people stand in front of an igloo
'''

In [None]:
# @title bleu
from torchtext.data.metrics import bleu_score

def calculate_bleu(data, src_field, trg_field, model, device, max_len = 50):
    """Translate every example in `data` one by one and score the corpus with BLEU.

    Each example must expose `src` and `trg` attributes. The last element of
    every prediction is dropped (the original comment says this is the <eos>
    token), and each target is wrapped in a list as the single reference.
    """
    references, hypotheses = [], []
    for example in data:
        fields = vars(example)
        source, target = fields['src'], fields['trg']
        prediction, _ = translate_sentence(source, src_field, trg_field, model, device, max_len)
        hypotheses.append(prediction[:-1])  # cut off <eos> token
        references.append([target])
    return bleu_score(hypotheses, references)
# Store the result under its own name: assigning it to `bleu_score` replaced the
# imported torchtext function with a float, so any later call to calculate_bleu
# failed with "'float' object is not callable".
test_bleu = calculate_bleu(test_data, SRC, TRG, model, device)
print(f'BLEU score = {test_bleu*100:.2f}')
# 36.52, which beats the ~34 of the convolutional sequence-to-sequence model and ~28 of the attention based RNN model.

def translate_sentence_vectorized(src_tensor, src_field, trg_field, model, device, max_len=50):
    """Greedily translate a whole batch of source sequences at once.

    Args:
        src_tensor: batch of source token ids (first dimension is the batch).
        src_field: unused; kept for signature compatibility.
        trg_field: target field exposing ``vocab.stoi``, ``vocab.itos``,
            ``init_token`` and ``eos_token``.
        model: object exposing ``make_src_mask``, ``encoder``,
            ``make_trg_mask`` and ``decoder`` (returning ``(logits, attention)``).
        device: device on which the growing target tensor is created.
        max_len: maximum number of decoding steps.

    Returns:
        (pred_sentences, attention): one list of target tokens per example,
        truncated before the first <eos>, and the attention returned by the
        last decoder call (``None`` if no step was executed).
    """
    assert isinstance(src_tensor, torch.Tensor)

    model.eval()
    vocab = trg_field.vocab
    init_idx = vocab.stoi[trg_field.init_token]
    eos_idx = vocab.stoi[trg_field.eos_token]
    batch_size = len(src_tensor)
    attention = None  # stays None when max_len == 0

    with torch.no_grad():
        src_mask = model.make_src_mask(src_tensor)
        enc_src = model.encoder(src_tensor, src_mask)
        # enc_src = [batch_sz, src_len, hid_dim]

        trg_indexes = [[init_idx] for _ in range(batch_size)]
        # Finished examples are still fed through the model because the batch
        # is decoded as a unit; we stop once every example has produced <eos>.
        translations_done = [False] * batch_size
        for _ in range(max_len):
            trg_tensor = torch.tensor(trg_indexes, dtype=torch.long, device=device)
            trg_mask = model.make_trg_mask(trg_tensor)
            output, attention = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)
            # Convert to plain ints so the Python lists never hold 0-dim
            # (possibly GPU) tensors, which made rebuilding trg_tensor and the
            # vocabulary lookups below fragile and slow.
            next_tokens = output.argmax(2)[:, -1].tolist()
            for row, token in enumerate(next_tokens):
                trg_indexes[row].append(token)
                if token == eos_idx:
                    translations_done[row] = True
            if all(translations_done):
                break

    # Drop the initial token and everything from the first <eos> onwards.
    pred_sentences = []
    for sequence in trg_indexes:
        words = []
        for token in sequence[1:]:
            if token == eos_idx:
                break
            words.append(vocab.itos[token])
        pred_sentences.append(words)
    return pred_sentences, attention

from torchtext.data.metrics import bleu_score

def calculate_bleu_alt(iterator, src_field, trg_field, model, device, max_len = 50):
    """Compute corpus BLEU using batched greedy decoding.

    Args:
        iterator: yields batches exposing ``src`` and ``trg``.
        src_field, trg_field: fields; ``trg_field`` must expose ``vocab.stoi``,
            ``vocab.itos``, ``eos_token`` and ``pad_token``.
        model, device: forwarded to ``translate_sentence_vectorized``.
        max_len: maximum number of decoding steps per batch.

    Returns:
        (pred_trgs, trgs, bleu): predicted token lists, reference token lists
        (each wrapped in a one-element list) and the BLEU score.
    """
    vocab = trg_field.vocab
    eos_idx = vocab.stoi[trg_field.eos_token]
    pad_idx = vocab.stoi[trg_field.pad_token]
    trgs = []
    pred_trgs = []
    with torch.no_grad():
        for batch in iterator:
            src, trg = batch.src, batch.trg
            for sentence in trg:
                words = []
                # Skip the first (start) token; targets are padded, so stop at
                # the first <eos> or padding token.
                for token in sentence[1:]:
                    if token == eos_idx or token == pad_idx:
                        break
                    words.append(vocab.itos[token])
                trgs.append([words])
            # Forward max_len: it used to be accepted but silently ignored, so
            # decoding always ran with the callee's default.
            pred_trg, _ = translate_sentence_vectorized(src, src_field, trg_field, model, device, max_len)
            pred_trgs += pred_trg
    return pred_trgs, trgs, bleu_score(pred_trgs, trgs)


## old

In [None]:
# Attention Is All You Need https://arxiv.org/pdf/1706.03762.pdf
# https://github.com/bentrevett/pytorch-seq2seq/blob/master/6%20-%20Attention%20is%20All%20You%20Need.ipynb
# https://colab.research.google.com/github/bentrevett/pytorch-seq2seq/blob/master/6%20-%20Attention%20is%20All%20You%20Need.ipynb
# https://www.mihaileric.com/posts/transformers-attention-in-disguise/
# https://jalammar.github.io/illustrated-transformer/
# http://nlp.seas.harvard.edu/2018/04/03/attention.html

# The position embedding works like a token embedding with a "vocabulary" of positions: a table of size 100 means the model can accept sentences of up to 100 tokens.
# learned positional encoding, warm-up and cool-down steps, label smoothing

In [None]:
# @title setup

# https://pytorch.org/tutorials/beginner/translation_transformer.html
# https://colab.research.google.com/github/pytorch/tutorials/blob/gh-pages/_downloads/c64c91cf87c13c0e83586b8e66e4d74e/translation_transformer.ipynb

# https://github.com/pytorch/data
# %pip install portalocker
# %pip install torchdata

# Create source and target language tokenizer. Make sure to install the dependencies.
# The lines starting with "!" are notebook shell commands, not Python.
!pip install -qU torchdata torchtext
!pip install -qU spacy
!python -m spacy download en_core_web_sm
!python -m spacy download de_core_news_sm

# Local clone of the Multi30k dataset repository.
!git clone --recursive https://github.com/multi30k/dataset.git multi30k-dataset

# !pip list | grep torch
!pip list | grep python

import torchtext.datasets as datasets

# Load the Multi30k dataset
train_iter, valid_iter, test_iter = datasets.Multi30k(split=('train', 'valid', 'test'))

# Print the first (source, target) pair as a sanity check, then stop.
for src, tgt in train_iter:
    print(f"Source: {src}")
    print(f"Target: {tgt}")
    break


In [None]:
# @title data

from torchtext.datasets import multi30k, Multi30k
# modify the URLs for the dataset since the links to the original dataset are broken https://github.com/pytorch/text/issues/1756#issuecomment-1163664163
multi30k.URL["train"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz"
multi30k.URL["valid"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz"

SRC_LANGUAGE = 'de'
TRG_LANGUAGE = 'en'

from torchtext.data.utils import get_tokenizer
de_tokenizer = get_tokenizer('spacy', language='de_core_news_sm')
en_tokenizer = get_tokenizer('spacy', language='en_core_web_sm')


# Reserved ids; their order must match `special_symbols` below because the
# vocabularies are built with special_first=True.
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3 # unknown, pad, beginning, end of sentence
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

from torchtext.vocab import build_vocab_from_iterator
train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TRG_LANGUAGE))

# Tokenise the training split once per language (index 0 = source/German,
# index 1 = target/English) to feed the vocabulary builders.
de_tokens = [de_tokenizer(data_sample[0]) for data_sample in train_iter]
en_tokens = [en_tokenizer(data_sample[1]) for data_sample in train_iter]

de_vocab = build_vocab_from_iterator(de_tokens, min_freq=1, specials=special_symbols, special_first=True)
en_vocab = build_vocab_from_iterator(en_tokens, min_freq=1, specials=special_symbols, special_first=True)
# Out-of-vocabulary tokens map to <unk>.
de_vocab.set_default_index(UNK_IDX)
en_vocab.set_default_index(UNK_IDX)

import torch

def de_transform(o):
    """Tokenise a German sentence and return its vocab ids wrapped in BOS/EOS.

    Returns a 1-D int64 tensor ``[BOS_IDX, *ids, EOS_IDX]``.
    """
    token_ids = de_vocab(de_tokenizer(o))
    # Build the tensor in one go with an explicit integer dtype. The previous
    # torch.cat of three tensors turned the whole sequence into floats when the
    # sentence had no tokens, because torch.tensor([]) defaults to float32.
    return torch.tensor([BOS_IDX, *token_ids, EOS_IDX], dtype=torch.long)

def en_transform(o):
    """Tokenise an English sentence and return its vocab ids wrapped in BOS/EOS.

    Returns a 1-D int64 tensor ``[BOS_IDX, *ids, EOS_IDX]``.
    """
    token_ids = en_vocab(en_tokenizer(o))
    # Build the tensor in one go with an explicit integer dtype. The previous
    # torch.cat of three tensors turned the whole sequence into floats when the
    # sentence had no tokens, because torch.tensor([]) defaults to float32.
    return torch.tensor([BOS_IDX, *token_ids, EOS_IDX], dtype=torch.long)


from torch.nn.utils.rnn import pad_sequence
# Collate raw (source, target) string pairs into two padded batch tensors.
def collate_fn(batch):
    """Turn a list of (German, English) strings into padded id tensors.

    Trailing newlines are stripped, each side is converted with its transform,
    and both sides are right-padded with PAD_IDX, batch dimension first.
    """
    encoded = [
        (de_transform(src_text.rstrip("\n")), en_transform(trg_text.rstrip("\n")))
        for src_text, trg_text in batch
    ]
    src_seqs = [pair[0] for pair in encoded]
    trg_seqs = [pair[1] for pair in encoded]
    padded_src = pad_sequence(src_seqs, batch_first=True, padding_value=PAD_IDX)
    padded_trg = pad_sequence(trg_seqs, batch_first=True, padding_value=PAD_IDX)
    return padded_src, padded_trg


# Fix the RNG seed for reproducibility.
torch.manual_seed(0)

# Fresh iterators over the train/validation splits, batched and padded by collate_fn.
train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TRG_LANGUAGE))
val_iter = Multi30k(split='valid', language_pair=(SRC_LANGUAGE, TRG_LANGUAGE))
batch_size = 128 # 128
train_loader = torch.utils.data.DataLoader(train_iter, batch_size=batch_size, collate_fn=collate_fn)
val_loader = torch.utils.data.DataLoader(val_iter, batch_size=batch_size, collate_fn=collate_fn)

# vocab_transform = {SRC_LANGUAGE:de_vocab, TRG_LANGUAGE:en_vocab}
# text_transform = {SRC_LANGUAGE:de_transform, TRG_LANGUAGE:en_transform}


In [None]:
# !pip install -q datasets
import numpy as np
# Re-create the `np.float_` / `np.complex_` aliases; presumably a workaround for
# a dependency that still uses names removed in NumPy 2.0 -- TODO confirm which one.
np.float_ = np.float64
np.complex_ = np.complex128
!python -m spacy download de_core_news_sm


In [None]:
# @title hf data
import torch
from datasets import load_dataset

# Load a translation dataset (e.g., WMT English-German)
dataset = load_dataset("wmt14", "de-en")

# Access train and validation splits
train_data = dataset['train']
val_data = dataset['validation']

# # Example of accessing data
# # for example in train_data:
# #     print(example['translation'])
# batch_size = 128 # 128
# train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, collate_fn=collate_fn)
# val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, collate_fn=collate_fn)



# from torchtext.datasets import multi30k, Multi30k
# # modify the URLs for the dataset since the links to the original dataset are broken https://github.com/pytorch/text/issues/1756#issuecomment-1163664163
# multi30k.URL["train"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz"
# multi30k.URL["valid"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz"

SRC_LANGUAGE = 'de'
TRG_LANGUAGE = 'en'

# from torchtext.data.utils import get_tokenizer
# de_tokenizer = get_tokenizer('spacy', language='de_core_news_sm')
# en_tokenizer = get_tokenizer('spacy', language='en_core_web_sm')

# from spacy.tokenizer import Tokenizer
# from spacy.lang.en import English
# nlp = English()
# # Create a blank Tokenizer with just the English vocab
# tokenizer = Tokenizer(nlp.vocab)

# # Construction 2
# from spacy.lang.en import English
# nlp = English()
# tokenizer = nlp.tokenizer

import numpy as np
# Re-create aliases removed from recent NumPy before importing spaCy
# (presumably a compatibility workaround -- TODO confirm it is still needed).
np.float_ = np.float64
np.complex_ = np.complex128
import spacy
# NOTE(review): these are full spaCy pipelines, not the torchtext tokenizer
# callables used in the earlier data cell; calling them presumably returns a
# spaCy Doc rather than a list of strings -- verify downstream code expects that.
en_tokenizer = spacy.load("en_core_web_sm")
de_tokenizer = spacy.load("de_core_news_sm") # https://spacy.io/models/de



UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3 # unknown, pad, beginning, end of sentence
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

# from torchtext.vocab import build_vocab_from_iterator
# train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TRG_LANGUAGE))

# NOTE(review): `train_iter` is not defined in this cell (its definition above
# is commented out), so these lines depend on a value left over from an earlier
# cell and do not read the `train_data` loaded at the top of this cell.
de_tokens = [de_tokenizer(data_sample[0]) for data_sample in train_iter]
en_tokens = [en_tokenizer(data_sample[1]) for data_sample in train_iter]

# NOTE(review): vocabulary construction is disabled here, so `de_vocab` /
# `en_vocab` must already exist from an earlier cell.
# de_vocab = build_vocab_from_iterator(de_tokens, min_freq=1, specials=special_symbols, special_first=True)
# en_vocab = build_vocab_from_iterator(en_tokens, min_freq=1, specials=special_symbols, special_first=True)
# de_vocab.set_default_index(UNK_IDX)
# en_vocab.set_default_index(UNK_IDX)


# Debug output: dumps every tokenised English training sentence.
print("en_tokens", en_tokens)




import torch

def de_transform(o):
    """Tokenise a German sentence and return its vocab ids wrapped in BOS/EOS.

    Returns a 1-D int64 tensor ``[BOS_IDX, *ids, EOS_IDX]``.
    """
    token_ids = de_vocab(de_tokenizer(o))
    # Build the tensor in one go with an explicit integer dtype. The previous
    # torch.cat of three tensors turned the whole sequence into floats when the
    # sentence had no tokens, because torch.tensor([]) defaults to float32.
    return torch.tensor([BOS_IDX, *token_ids, EOS_IDX], dtype=torch.long)

def en_transform(o):
    """Tokenise an English sentence and return its vocab ids wrapped in BOS/EOS.

    Returns a 1-D int64 tensor ``[BOS_IDX, *ids, EOS_IDX]``.
    """
    token_ids = en_vocab(en_tokenizer(o))
    # Build the tensor in one go with an explicit integer dtype. The previous
    # torch.cat of three tensors turned the whole sequence into floats when the
    # sentence had no tokens, because torch.tensor([]) defaults to float32.
    return torch.tensor([BOS_IDX, *token_ids, EOS_IDX], dtype=torch.long)


from torch.nn.utils.rnn import pad_sequence
# Collate raw (source, target) string pairs into two padded batch tensors.
def collate_fn(batch):
    """Turn a list of (German, English) strings into padded id tensors.

    Trailing newlines are stripped, each side is converted with its transform,
    and both sides are right-padded with PAD_IDX, batch dimension first.
    """
    encoded = [
        (de_transform(src_text.rstrip("\n")), en_transform(trg_text.rstrip("\n")))
        for src_text, trg_text in batch
    ]
    src_seqs = [pair[0] for pair in encoded]
    trg_seqs = [pair[1] for pair in encoded]
    padded_src = pad_sequence(src_seqs, batch_first=True, padding_value=PAD_IDX)
    padded_trg = pad_sequence(trg_seqs, batch_first=True, padding_value=PAD_IDX)
    return padded_src, padded_trg


# Fix the RNG seed for reproducibility.
torch.manual_seed(0)

# NOTE(review): with the two lines below commented out, `train_iter` and
# `val_iter` must come from an earlier cell; the WMT14 splits loaded above are
# not what these loaders iterate over.
# train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TRG_LANGUAGE))
# val_iter = Multi30k(split='valid', language_pair=(SRC_LANGUAGE, TRG_LANGUAGE))
batch_size = 128 # 128
train_loader = torch.utils.data.DataLoader(train_iter, batch_size=batch_size, collate_fn=collate_fn)
val_loader = torch.utils.data.DataLoader(val_iter, batch_size=batch_size, collate_fn=collate_fn)




In [None]:
# @title model
import torch
import torch.nn as nn
import math
# Global compute device: the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class PositionalEncoder(nn.Module):
    """Fixed sinusoidal positional encoding followed by dropout.

    Even feature indices hold sin(pos * w_k) and odd indices cos(pos * w_k)
    with w_k = 10000^(-2k / d_model). The table is stored as the buffer `pe`
    of shape [1, max_seq_length, d_model].

    Args:
        d_model: feature size of the inputs (even or odd).
        dropout: dropout probability applied after adding the encoding.
        max_seq_length: longest sequence the table supports.
    """
    def __init__(self, d_model, dropout=0.1, max_seq_length=512):
        super().__init__()
        self.drop = nn.Dropout(dropout)
        pe = torch.zeros(max_seq_length, d_model)
        pos = torch.arange(0, max_seq_length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        angles = pos * div_term  # [max_seq_length, ceil(d_model / 2)]
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one cosine column fewer than sine columns;
        # trimming here avoids the shape mismatch the assignment used to raise.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add the encoding for the first x.size(1) positions to x ([batch, seq, d_model])."""
        return self.drop(x + self.pe[:, : x.size(1)])

class LearntPosEnc(nn.Module): # learnt positional embeddings
    """Learned positional embeddings added to the input, followed by dropout.

    Args:
        d_model: feature size of the inputs.
        dropout: dropout probability applied after adding the embedding.
        max_length: number of positions in the embedding table; longer
            sequences raise an index error in the embedding lookup.
    """
    def __init__(self, d_model, dropout=0.1, max_length=512):
        super().__init__()
        self.pos_embedding = nn.Embedding(max_length, d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        """Add position embeddings to x of shape [batch, seq_len, d_model]."""
        # Only the sequence length is needed. Unpacking all of x.shape into two
        # names failed for the 3-D embedded input this module receives.
        src_len = x.shape[1]
        # Create positions on x's device (not the notebook-global `device`) and
        # rely on broadcasting over the batch instead of repeating the indices.
        pos = torch.arange(src_len, device=x.device).unsqueeze(0) # [1, src len]
        return self.drop(x + self.pos_embedding(pos))


class RoPE(nn.Module): # Rotary Positional Embeddings
    """Rotary positional embedding applied over the last dimension.

    Consecutive feature pairs (x[2k], x[2k+1]) at position p are rotated by the
    angle p * base^(-2k / dim). The rotation keeps vector norms unchanged and
    makes dot products between rotated vectors depend only on the position
    offset.

    Args:
        dim: feature size of the inputs; must be even.
        seq_len: number of positions precomputed in the table.
        base: frequency base (10000 reproduces the previous table).
    """
    def __init__(self, dim, seq_len=512, base=10000):
        super().__init__()
        if dim % 2:
            raise ValueError(f"RoPE needs an even dim, got {dim}")
        theta = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        pos = torch.arange(seq_len).unsqueeze(1)
        angles = pos * theta # [seq_len, 1] * [dim // 2] = [seq_len, dim // 2]
        # Registered as a buffer so the table follows the module through
        # .to(device); as a plain attribute it stayed on the CPU and forward
        # failed once the model was moved to the GPU. Non-persistent keeps
        # state_dict keys unchanged.
        self.register_buffer('rot_emb', torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1), persistent=False) # [seq_len, dim]

    def forward(self, x):
        """Rotate x of shape [batch, seq_len, dim]; returns the same shape."""
        batch, seq_len, dim = x.shape
        if seq_len > self.rot_emb.shape[0]:
            raise ValueError(f"sequence length {seq_len} exceeds RoPE table size {self.rot_emb.shape[0]}")
        half = dim // 2
        sin = self.rot_emb[:seq_len, :half] # [seq_len, dim // 2]
        cos = self.rot_emb[:seq_len, half:]
        x_even, x_odd = x[..., 0::2], x[..., 1::2] # [batch, seq_len, dim // 2] each
        # Proper 2-D rotation of every pair. The previous element-wise product
        # with the sin/cos table was a per-feature scaling that zeroed features
        # wherever sin was 0 (e.g. at position 0).
        rotated = torch.stack((x_even * cos - x_odd * sin, x_even * sin + x_odd * cos), dim=-1)
        return rotated.flatten(-2)


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        d_model: model width; split evenly across heads
            (head_dim = d_model // n_heads).
        n_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
    """
    def __init__(self, d_model, n_heads, dropout=0):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q = nn.Linear(d_model, d_model, bias=False)
        self.k = nn.Linear(d_model, d_model, bias=False)
        self.v = nn.Linear(d_model, d_model, bias=False)
        self.lin = nn.Linear(d_model, d_model)
        self.drop = nn.Dropout(dropout)
        # A plain float works on any device and dtype. The previous tensor was
        # pinned to the notebook-global `device` at construction time and was
        # not moved by .to(...), since it was neither a parameter nor a buffer.
        self.scale = math.sqrt(self.head_dim)

    def forward(self, query, key, value, mask=None):
        """Attend from `query` to `key`/`value`.

        Args:
            query, key, value: [batch, len, d_model] tensors (key and value
                share their length).
            mask: optional tensor broadcastable to
                [batch, n_heads, query_len, key_len]; positions where it
                equals 0 are masked out.

        Returns:
            (x, attention): output of shape [batch, query_len, d_model] and the
            attention weights of shape [batch, n_heads, query_len, key_len].
        """
        batch_size = query.shape[0]
        # [batch, len, d_model] -> [batch, n_heads, len, head_dim]
        Q = self.q(query).view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        K = self.k(key).view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        V = self.v(value).view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        attn = Q @ K.transpose(2, 3) / self.scale
        if mask is not None:
            # Large negative score -> zero weight after the softmax.
            attn = attn.masked_fill(mask == 0, -1e10) # [batch, n_heads, seq_len, seq_len]
        attention = torch.softmax(attn, dim=-1)
        x = self.drop(attention) @ V
        # Merge the heads back: [batch, n_heads, len, head_dim] -> [batch, len, d_model]
        x = x.transpose(1, 2).reshape(batch_size, -1, self.d_model)
        x = self.lin(x)
        return x, attention


class EncoderLayer(nn.Module):
    """One post-norm encoder block: self-attention then feed-forward.

    Each sub-layer output passes through dropout, is added to its input and
    normalised with RMSNorm. The attention module itself is built with
    dropout 0.
    """
    def __init__(self, d_model, n_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.norm1 = nn.RMSNorm(d_model)
        self.norm2 = nn.RMSNorm(d_model)
        self.drop = nn.Dropout(dropout)
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout=0)
        ff_layers = [
            nn.Linear(d_model, ff_dim),
            nn.ReLU(), # ReLU GELU
            nn.Dropout(dropout),
            nn.Linear(ff_dim, d_model),
        ]
        self.ff = nn.Sequential(*ff_layers)

    def forward(self, src, src_mask):
        """Apply the block to src using src_mask for self-attention."""
        attended, _ = self.self_attn(src, src, src, src_mask)
        hidden = self.norm1(src + self.drop(attended))
        fed = self.ff(hidden)
        return self.norm2(hidden + self.drop(fed))

class Encoder(nn.Module):
    """Stack of `n_layers` EncoderLayer blocks applied one after another."""
    def __init__(self, d_model, n_layers, n_heads, ff_dim, dropout=0.1):
        super().__init__()
        blocks = []
        for _ in range(n_layers):
            blocks.append(EncoderLayer(d_model, n_heads, ff_dim, dropout))
        self.layers = nn.ModuleList(blocks)

    def forward(self, src, src_mask):
        """Pass src through every layer in order, reusing the same mask."""
        hidden = src
        for block in self.layers:
            hidden = block(hidden, src_mask)
        return hidden

class DecoderLayer(nn.Module):
    """One post-norm decoder block: self-attention, encoder attention, feed-forward.

    Each sub-layer output passes through dropout, is added to its input and
    normalised with RMSNorm. Both attention modules are built with dropout 0.
    """
    def __init__(self, d_model, n_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.norm1 = nn.RMSNorm(d_model)
        self.norm2 = nn.RMSNorm(d_model)
        self.norm3 = nn.RMSNorm(d_model)
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout=0)
        self.enc_attn = MultiHeadAttention(d_model, n_heads, dropout=0)
        ff_layers = [
            nn.Linear(d_model, ff_dim),
            nn.ReLU(), # ReLU GELU
            nn.Dropout(dropout),
            nn.Linear(ff_dim, d_model),
        ]
        self.ff = nn.Sequential(*ff_layers)
        self.drop = nn.Dropout(dropout)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        """Apply the block to trg, attending to the encoder output enc_src."""
        self_out, _ = self.self_attn(trg, trg, trg, trg_mask)
        hidden = self.norm1(trg + self.drop(self_out))
        cross_out, _ = self.enc_attn(hidden, enc_src, enc_src, src_mask)
        hidden = self.norm2(hidden + self.drop(cross_out))
        fed = self.ff(hidden)
        return self.norm3(hidden + self.drop(fed))

class Decoder(nn.Module):
    """Stack of `n_layers` DecoderLayer blocks applied one after another."""
    def __init__(self, d_model, n_layers, n_heads, ff_dim, dropout=0.1):
        super().__init__()
        blocks = []
        for _ in range(n_layers):
            blocks.append(DecoderLayer(d_model, n_heads, ff_dim, dropout))
        self.layers = nn.ModuleList(blocks)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        """Pass trg through every layer in order against the same encoder output."""
        hidden = trg
        for block in self.layers:
            hidden = block(hidden, enc_src, trg_mask, src_mask)
        return hidden

class Seq2Seq(nn.Module):
    """Encoder-decoder transformer mapping source token ids to target logits.

    Token embeddings are scaled by sqrt(d_model) and passed through the shared
    positional module before entering the encoder / decoder stacks. All
    parameters with more than one dimension are Xavier-uniform initialised.
    """
    def __init__(self, in_dim, out_dim, d_model=512, nhead=8, enc_layers=3, dec_layers=3, ff_dim=512, dropout=0.1):
        super().__init__()
        self.encoder = Encoder(d_model, enc_layers, nhead, ff_dim, dropout)
        self.decoder = Decoder(d_model, dec_layers, nhead, ff_dim, dropout)
        # Alternative positional modules:
        # self.pos_enc = PositionalEncoder(d_model, dropout=dropout)
        # self.pos_enc = LearntPosEnc(d_model, dropout=dropout)
        self.pos_enc = RoPE(d_model)
        self.src_tok_emb = nn.Embedding(in_dim, d_model)
        self.trg_tok_emb = nn.Embedding(out_dim, d_model)
        self.d_model = d_model
        self.lin = nn.Linear(d_model, out_dim)
        matrices = (param for param in self.parameters() if param.dim() > 1)
        for weight in matrices:
            nn.init.xavier_uniform_(weight)

    def _embed(self, table, tokens):
        """Look tokens up in `table`, scale by sqrt(d_model) and add positions."""
        return self.pos_enc(table(tokens) * math.sqrt(self.d_model))

    def forward(self, src, trg, src_mask=None, trg_mask=None):
        """Return logits over the output vocabulary for every target position."""
        src_emb = self._embed(self.src_tok_emb, src)
        trg_emb = self._embed(self.trg_tok_emb, trg)
        memory = self.encoder(src_emb, src_mask)
        hidden = self.decoder(trg_emb, memory, trg_mask, src_mask)
        return self.lin(hidden)

    def encode(self, src, src_mask=None):
        """Embed the source ids and run only the encoder stack."""
        src_emb = self._embed(self.src_tok_emb, src)
        return self.encoder(src_emb, src_mask)

    def decode(self, trg, memory, trg_mask=None, src_mask=None):
        """Embed the target ids, decode against `memory` and project to logits."""
        trg_emb = self._embed(self.trg_tok_emb, trg)
        hidden = self.decoder(trg_emb, memory, trg_mask, src_mask)
        return self.lin(hidden)


# Placeholder vocabulary sizes for the source and target embeddings / output layer.
in_dim = 50#
out_dim = 50
# Build the model and move it to the global device.
model = Seq2Seq(in_dim, out_dim, d_model=512, nhead=8, enc_layers=3, dec_layers=3, ff_dim=512, dropout=0.1).to(device)


In [None]:
# @title Attention with kvcache window
import torch
import torch.nn as nn
from torch.nn import functional as F

def zero_module(module):
    """Set every parameter of `module` to zero in place and return the module."""
    for weight in module.parameters():
        # detach() shares storage, so the in-place fill is not recorded by autograd
        weight.detach().zero_()
    return module

# mask = torch.rand(b,t,t)<.2
def to_midx(mask): # [b,t,t+p] bool mask
    """Convert a boolean mask into a padded table of the True column indices.

    For every (b, t) row of `mask` the last-axis indices of its True entries
    are written left-aligned into the output row; unused slots hold -1.

    Returns a tensor of shape [b, t, c] where c is the largest number of True
    entries found in any row. Intermediate tensors are created on the global
    `device`.

    NOTE(review): a mask with no True entry at all looks unsupported
    (`out.max()` on an empty tensor) -- confirm callers never pass one.
    NOTE(review): the star-unpacking inside the subscript below needs
    Python 3.11+.
    """
    b,t = mask.shape[:2]
    # One index tensor per mask dimension, listing the True entries in row-major order.
    idx = mask.nonzero(as_tuple=True)
    # Leading (all but last) coordinates of every True entry, one row per entry.
    tk = torch.stack(idx[:-1], dim=-1).to(device)
    # True where an entry starts a new row, i.e. its leading coordinates differ
    # from the previous entry's; the first entry always starts a row.
    tk = torch.cat([torch.tensor([True], device=device), (tk[:-1]!=tk[1:]).any(-1)], dim=0)
    cc = len(idx[0]) # total number of True entries
    positions = torch.arange(cc, device=device)
    # Record each row's starting position, then propagate it forward with a
    # running maximum so every entry knows where its row began.
    last_reset_pos = torch.zeros(cc, dtype=int, device=device)
    last_reset_pos[tk] = positions[tk]
    # Rank of each entry within its own row (0, 1, 2, ...).
    out = positions - last_reset_pos.cummax(0).values
    c=out.max()+1 # widest row
    y = torch.full((b,t,c), -1, device=device)
    # Scatter every True entry's last-axis index into slot `rank` of its row.
    y[*idx[:-1],out] = idx[-1]
    return y # [b,t,c]

class Attention(nn.Module):
    """Multi-head attention restricted to a sliding window of the last `w` key/value positions.

    Queries are projected from `x`; keys and values come from `cond`, or from
    `x` itself when `cond_dim` is None. `self.pos_enc` (RoPE, defined elsewhere
    in this file) is applied to the projected queries and keys before the
    head split.

    Args:
        query_dim: feature size of `x`.
        cond_dim: feature size of `cond`; None selects self-attention.
        n_heads: number of heads.
        d_head: width of each head; the internal width is d_head * n_heads.
        drop: dropout probability, used on the attention weights and on the output.
        w: window length, i.e. number of key/value slots each query sees.
    """
    # def __init__(self, d_model, cond_dim=None, n_heads=None, d_head=8, drop=0.): # .1
    def __init__(self, query_dim, cond_dim=None, n_heads=8, d_head=8, drop=0, w=64):
        super().__init__()
        self.d_model = d_model = d_head * n_heads
        self.d_head, self.n_heads = d_head, n_heads
        self.cond_dim = cond_dim
        self.pos_enc = RoPE(d_model, base=100) # 10000
        self.q = nn.Linear(query_dim, d_model, bias=False)
        self.kv = nn.Linear(cond_dim or query_dim, 2*d_model, bias=False)
        self.lin = zero_module(nn.Linear(d_model, d_model)) # zero-init: the layer outputs 0 until trained
        self.drop = nn.Dropout(drop) # indp before q,k,v; after linout
        self.scale = self.d_head**-.5
        # torch.nn.init.normal_(self.q.weight, std=1/(math.sqrt(query_dim)+math.sqrt(d_model)))
        # torch.nn.init.normal_(self.kv.weight, std=1/(math.sqrt(cond_dim or query_dim)+math.sqrt(d_model)))
        t=128
        # window index table for the gather-based variant (not read by forward); row i holds i-w+1 .. i
        self.midx = (torch.arange(w, dtype=torch.int32).repeat(t,1) + torch.arange(1-w, t-w+1, dtype=torch.int32).unsqueeze(-1)).to(device) # [t,w]
        self.w = w

    def forward(self, x, cond=None, mask=None): # [b,t,d], [batch, num_tok, cond_dim], [b,t,t(+p)]
        """Attend from every position of `x` to the `w` most recent key/value positions.

        Args:
            x: query input [b, t, query_dim].
            cond: key/value input; replaced by `x` when `self.cond_dim` is None.
            mask: accepted for interface compatibility; currently not used.

        Returns:
            Tensor [b, t, d_head * n_heads].
        """
        b,t = x.shape[:2]
        if self.cond_dim is None: cond=x # is self attn
        # NOTE(review): the strided window below assumes cond has the same sequence length as x — TODO confirm for cross-attention

        q = self.pos_enc(self.q(x)).unflatten(-1, (self.n_heads, self.d_head)).transpose(1, 2) # [b,t,d_model] -> [b,h,t,d_head]
        k, v = self.kv(cond).chunk(2, dim=-1) # each [b,t,d_model]
        # Split k and v into heads *before* concatenating, so each head receives its own
        # (rotated) key slice and its own value slice. Concatenating first and unflattening
        # to (h, 2*d_head) would give the first half of the heads only key features and
        # the second half only value features.
        k = self.pos_enc(k).unflatten(-1, (self.n_heads, self.d_head)) # [b,t,h,d_head]
        v = v.unflatten(-1, (self.n_heads, self.d_head)) # [b,t,h,d_head]
        kv = torch.cat([k, v], dim=-1).transpose(1, 2) # [b,h,t,2*d_head]

        # left-pad w-1 zero rows, then view overlapping windows by repeating the sequence stride
        kv = F.pad(kv, (0,0,self.w-1,0)) # [b,h,t+w-1,2*d_head]
        kv = kv.as_strided((b,self.n_heads,t,self.w,2*self.d_head), kv.stride()[:-1] + kv.stride()[-2:]) # [b,h,t,w,2*d_head]
        mk, mv = kv.chunk(2, dim=-1) # [b,h,t,w,d_head] each

        attn = torch.einsum("bhtd,bhtwd->bhtw", q, mk) * self.scale # [b,h,t,w]
        # Slot j of query i looks at key position i-(w-1)+j; negative positions are the zero
        # padding and must not receive probability mass (their logit would otherwise be 0).
        # Slot w-1 is the query's own position, so no row is ever fully masked.
        key_pos = torch.arange(t, device=attn.device).unsqueeze(-1) + torch.arange(1-self.w, 1, device=attn.device) # [t,w]
        attn = attn.masked_fill(key_pos < 0, -torch.finfo(attn.dtype).max)
        attention = torch.softmax(attn, dim=-1) # [b,h,t,w]
        out = torch.einsum("bhtw,bhtwd->bhtd", self.drop(attention), mv) # [b,h,t,d_head]

        out = out.transpose(1, 2).flatten(2) # [b,t,h*d_head]
        return self.drop(self.lin(out)) # [batch, T, d_model]

# td*dt->tt t(2t+d)
# td*twd->tw t(~wd-(w+2)d)
# iff 2t>wd
# gen: xqkv, patch: ~xkv ~pq, gen: xqkv pkv
# calc xqkv, calc ~xkv ~pq, get xqkv calc pkv

# typically: gen td*td=tt
# gen t1d*twd=tw, patch pd*tpd=p


# Quick wall-clock benchmark of the windowed Attention layer on random input.
b,t,d = 64,100,512
# mask = torch.rand(b,t,t, device=device)<.2


# causal_mask = torch.tril(torch.ones((b,t,x.shape[1]), dtype=bool, device=device)) # [1,t,n_cond+n_patch] # got cond, all can attend to cond
# # mask = causal_mask * mask.unsqueeze(1) if mask!=None else causal_mask # *[b,1,t] # mask left side, tril is lower left
# mask = causal_mask * ~torch.tril(torch.ones_like(causal_mask, dtype=bool, device=device), diagonal=-64) # sliding window mask


# random mask with exactly w True entries per (batch, query) row
msk = torch.rand(b,t,t)
w = 3
_, ind = torch.topk(msk, w, dim=-1, sorted=False)

# # print(ind.shape, ind)
mask = torch.zeros_like(msk, dtype=bool)
mask[torch.arange(b)[:,None,None], torch.arange(t)[None,:,None], ind] = True


# model = Attention(d).to(device) # 257 ms
model = Attention(d, w=w).to(device) # 257 ms
x = torch.rand(b,t,d, device=device)

import time
# CUDA kernels are launched asynchronously: wait for pending work before and after the
# timed region, otherwise only the launch overhead is measured.
if x.is_cuda: torch.cuda.synchronize()
start = time.time()
# out = model(x)
out = model(x, mask=mask)
if x.is_cuda: torch.cuda.synchronize()
print(out.shape)
print(time.time()-start)


In [None]:
# @title torch.profiler  as_strided unfold

import torch
# Profile the gather-based way of building per-query key/value windows.
b,t = 2,8
mask = torch.rand(b,t,t)<.2
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build the window index table here so the cell does not rely on names left over
# from earlier cells. Row i lists positions i-w+1 .. i; the negative entries in the
# first w-1 rows index from the end of the sequence when used below.
w = 3
midx = torch.arange(w).repeat(t,1) + torch.arange(1-w, t-w+1).unsqueeze(-1) # [t,w]

from torch.profiler import profile, record_function, ProfilerActivity
with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with record_function("model_inference"):
        # to_midx(mask)
        midx = midx[:t,-min(w,t):] # [t,w]
        h=4
        d=5
        kv = torch.rand(b,h,t,d)
        # advanced indexing: index tensors broadcast to [b,h,t,w], result is [b,h,t,w,d]
        kv = kv[torch.arange(b)[:,None,None,None], torch.arange(h)[None,:,None,None], midx[None,None,...]]

# print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
print(prof.key_averages().table())

# import time
# start = time.time()

# print(time.time()-start)

# Sliding-window attention on random q/k/v using as_strided views of the padded k/v.
import math
import torch.nn.functional as F

b,h,t,d = 2,4,8,5
x = torch.rand(b,t,t)
w = 3
_, ind = torch.topk(x, w, dim=-1, sorted=False)

# # print(c)
k = torch.rand(b,h,t,d)
v = torch.rand(b,h,t,d)
q = torch.rand(b,h,t,d)

# left-pad the sequence axis with w-1 zero rows so every query has a full window
k = F.pad(k, (0, 0, w - 1, 0))  # (B, H, T + W - 1, D)
v = F.pad(v, (0, 0, w - 1, 0))

b,h,l,d = k.shape # l = t+w-1
# window j of query i reads padded row i+j: reuse the sequence stride for the window axis
k_strided = k.as_strided(size=(b,h,t,w,d), stride=(k.stride(0), k.stride(1), k.stride(2), k.stride(2), k.stride(3)))
v_strided = v.as_strided(size=(b,h,t,w,d), stride=(v.stride(0), v.stride(1), v.stride(2), v.stride(2), v.stride(3)))

# same view as k_strided, written with the generic "repeat the stride at w's dim" form
o = k.as_strided((b,h,t,w,d), k.stride()[:-1] + k.stride()[-2:]) # [b,h,t,w,d]

# NOTE: the zero-padded slots are not masked here; they enter the softmax with score 0
attn_scores = torch.einsum("bhtd,bhtwd->bhtw", q, k_strided) / math.sqrt(d)
attn_weights = torch.softmax(attn_scores, dim=-1)
out = torch.einsum("bhtw,bhtwd->bhtd", attn_weights, v_strided)

# # b,h,t,d = 2,4,8,5
# k = torch.rand(b,h,t,d)
# # # print(k.stride(0), k.stride(1), k.stride(2), k.stride(2), k.stride(3))
# print(k.stride())
# # print(torch.rand(2,3,5,7).stride())
# # 3*5*7,5*7,7,1

import torch
import torch.nn.functional as F

# Build sliding windows over the sequence axis with Tensor.unfold.
x = torch.rand(2, 3, 5)
print(x)
b, t, d = x.shape
w = 5
# left-pad the sequence axis with w-1 zero rows
k = F.pad(x, (0, 0, w - 1, 0))  # [b, t+w-1, d]
print(k.shape)
# unfold appends the window axis last ([b, t, d, w]); swap it in front of d -> [b, t, w, d]
# (an as_strided view gives the same result: k.as_strided((b,t,w,d), k.stride()[:-1] + k.stride()[-2:]))
o = k.unfold(-2, w, 1)
o = o.transpose(-1, -2)

print(o.shape, o)


