In [None]:
import auto_compyute as ac
import auto_compyute.nn.functional as F
from auto_compyute import nn

# Fix the backend random seed before any parameters are created.
ac.backends.set_random_seed(0)
# Pick the GPU when the backend reports one is available, otherwise the CPU.
device = ac.cuda if ac.backends.gpu_available() else ac.cpu

In [None]:
# Model / data hyperparameters.
# ctx_len: length of each training sequence; also passed as `seq_len` to the model.
ctx_len = 256
# emb_dim: embedding width passed to the model.
emb_dim = 384
# n_heads: number of attention heads passed to the model.
n_heads = 6
# n_blocks: number of transformer blocks (passed as `n_layers`).
n_blocks = 6
# batch_size: batch size handed to the dataloaders.
batch_size = 32

In [None]:
import requests

# load data
DATA_URL = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
# A timeout keeps the cell from hanging forever on a stalled connection.
response = requests.get(DATA_URL, timeout=30)
# Fail loudly on HTTP errors instead of silently tokenizing an error page.
response.raise_for_status()
data = response.text

# tokenization (character level)
chars = sorted(set(data))
vocab = dict(enumerate(chars))  # token id -> character
ivocab = {c: i for i, c in vocab.items()}  # character -> token id


def encode(text):
    """Convert a string into a list of token ids.

    Raises:
        KeyError: if `text` contains a character that is not in the corpus.
    """
    return [ivocab[t] for t in text]


def decode(token_ids):
    """Convert an iterable of token ids back into a string.

    Raises:
        KeyError: if an id is not part of the vocabulary.
    """
    return "".join(vocab[token_id] for token_id in token_ids)


vocab_size = len(chars)

In [None]:
# prepare data
data_enc = ac.tensor(encode(data), dtype=ac.int32)

# Targets are the inputs shifted by one token, so the last sequence needs one
# extra token after it. Counting sequences from `len - 1` keeps every target
# slice at full length even when the corpus length is an exact multiple of
# ctx_len (otherwise the final target would be one token short).
n_seqs = (len(data_enc) - 1) // ctx_len
assert n_seqs > 0, "corpus is too short to build a single (input, target) pair"

X = ac.stack(*[data_enc[i * ctx_len : (i + 1) * ctx_len] for i in range(n_seqs)])
y = ac.stack(*[data_enc[i * ctx_len + 1 : (i + 1) * ctx_len + 1] for i in range(n_seqs)])

# 90 % of the sequences for training, the remainder for validation
n = int(len(X) * 0.9)
X_train = X.int()[:n]
y_train = y.int()[:n]
X_val = X.int()[n:]
y_val = y.int()[n:]

In [None]:
class Transformer(nn.Module):
    """Decoder-style transformer producing logits over `n_emb` classes.

    The model sums an input embedding (`wte`) and a position embedding (`wpe`),
    runs the result through `n_layers` `Block` modules, applies a final layer
    norm and projects back to `n_emb` with a linear head whose weight is shared
    with `wte`.

    Args:
        n_emb: Number of embeddings (vocabulary size) and output classes.
        emb_dim: Embedding width.
        seq_len: Maximum sequence length supported by the position embedding.
        n_heads: Number of attention heads per block.
        n_layers: Number of blocks.
        mask: Attention mask forwarded to every block.
        dropout: Dropout value forwarded to every block.
    """

    def __init__(
        self,
        n_emb: int,
        emb_dim: int,
        seq_len: int,
        n_heads: int,
        n_layers: int,
        mask,
        dropout: float = 0,
    ) -> None:
        super().__init__()
        self.wte = nn.Embedding(n_emb, emb_dim)
        self.wpe = nn.Embedding(seq_len, emb_dim)
        # shrink the initial embedding weights by 1 / sqrt(emb_dim)
        self.wte.w.data *= emb_dim**-0.5
        self.wpe.w.data *= emb_dim**-0.5

        # scale factor handed to each block for its weight re-scaling
        out_scale = (2 * n_layers)**-0.5
        self.blocks = nn.Modulelist(Block(emb_dim, n_heads, mask, dropout, out_scale) for _ in range(n_layers))

        # `(emb_dim,)` is a 1-tuple; the previous `(emb_dim)` was just an int
        # and inconsistent with the layer norms created in `Block`.
        self.head_ln = nn.Layernorm((emb_dim,))
        self.head = nn.Linear(emb_dim, n_emb, bias=False)
        # weight tying: the output head reuses the input embedding matrix
        self.head.w = self.wte.w

        # position indices 0..seq_len-1 with shape (1, seq_len)
        self.pos = nn.Buffer(ac.arange(seq_len).view((1, -1)))

    def forward(self, x):
        """Compute logits for `x`.

        Args:
            x: Input ids; the last axis is treated as the sequence axis
                (assumed shape (batch, seq) - TODO confirm).

        Returns:
            The output of the linear head applied to the normalised block output.
        """
        # slice the position indices down to the actual sequence length
        x = self.wte(x) + self.wpe(self.pos[:, : x.shape[-1]])
        for block in self.blocks:
            x = block(x)
        x = self.head(self.head_ln(x))
        return x


class Block(nn.Module):
    """Residual block: layer norm + self-attention, then layer norm + MLP.

    Each sub-layer is applied to a layer-normalised copy of the input, passed
    through dropout and added back onto the input.

    Args:
        emb_dim: Embedding width.
        n_heads: Number of attention heads.
        mask: Attention mask forwarded to the attention module.
        dropout: Dropout value for the attention module and both dropout layers.
        out_scale: Factor multiplied into `attn.qkv.w` and `mlp.down.w`.
    """

    def __init__(self, emb_dim, n_heads, mask, dropout, out_scale) -> None:
        super().__init__()
        norm_shape = (emb_dim,)

        # attention sub-layer
        self.attn_ln = nn.Layernorm(norm_shape)
        self.attn = nn.MultiHeadSelfAttention(emb_dim, n_heads, mask, dropout)
        # NOTE(review): GPT-2 style init scales the projection that writes back
        # into the residual stream; here the qkv projection is scaled - confirm
        # this is intended.
        self.attn.qkv.w.data *= out_scale
        self.attn_dropout = nn.Dropout(dropout)

        # feed-forward sub-layer
        self.mlp_ln = nn.Layernorm(norm_shape)
        self.mlp = MLP(emb_dim)
        self.mlp.down.w.data *= out_scale
        self.mlp_dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the attention and MLP sub-layers with residual connections."""
        attn_out = self.attn(self.attn_ln(x))
        x = x + self.attn_dropout(attn_out)

        mlp_out = self.mlp(self.mlp_ln(x))
        return x + self.mlp_dropout(mlp_out)


class MLP(nn.Module):
    """Two-layer feed-forward network with a GELU in between.

    Args:
        n_emb: Input and output width; the hidden layer is four times as wide.
    """

    def __init__(self, n_emb) -> None:
        super().__init__()
        hidden_dim = 4 * n_emb
        self.up = nn.Linear(n_emb, hidden_dim)
        self.down = nn.Linear(hidden_dim, n_emb)

    def forward(self, x):
        """Expand, apply GELU, and project back to the input width."""
        return self.down(F.gelu(self.up(x)))

In [None]:
model = Transformer(
    n_emb=vocab_size,
    emb_dim=emb_dim,
    seq_len=ctx_len,
    n_heads=n_heads,
    n_layers=n_blocks,
    # (ctx_len, ctx_len) mask filled with -inf, reduced to its upper triangle
    # above the main diagonal via `triu(1)`
    mask=ac.full((ctx_len, ctx_len), float("-inf")).triu(1)
)
# Move to the device selected at the top of the notebook instead of hard-coding
# CUDA, so the notebook also runs on machines without a GPU.
model.to(device)

In [None]:
# Optional debugging aid (disabled): draw the compute graph of the loss for one training batch.
# loss = F.cross_entropy(model(X_train[:batch_size]), y_train[:batch_size])
# ac.autograd.draw_compute_graph(loss)

In [None]:
# training
# Batch iterators over the training and validation splits; `batch_size` and
# `device` are passed positionally.
# NOTE(review): the trailing positional `False` for `val_dl` presumably disables
# shuffling - confirm against the nn.Dataloader signature.
train_dl = nn.Dataloader((X_train, y_train), batch_size, device)
val_dl = nn.Dataloader((X_val, y_val), batch_size, device, False)
# AdamW over all model parameters with a learning rate of 3e-4.
optim = nn.optimizers.AdamW(model.parameters(), learning_rate=3e-4)

In [None]:
# training parameters
# NOTE(review): `step` is reassigned by the training loop before it is read, so
# this initial value has no effect.
step = 1
# presumably the intended number of optimizer steps - verify the training loop honours it
max_steps = 2500
# presumably the number of steps between validation runs; no validation code is
# visible in this part of the notebook
val_interval = 250

In [None]:
import time

# Train for `max_steps` optimizer steps, iterating over `train_dl` as many times
# as needed (a single pass over the dataloader ends long before `max_steps`).
# NOTE(review): `val_dl` and `val_interval` are defined but no validation is run
# here - add it once the library's eval / no-grad API is confirmed.
model.train()
step = 0
while step < max_steps:
    steps_at_epoch_start = step

    # Batch names differ from the notebook-level `y` so the full target tensor
    # is not overwritten by the last batch.
    for x_batch, y_batch in train_dl():
        if step >= max_steps:
            break

        start = time.perf_counter()
        loss = F.cross_entropy(model(x_batch), y_batch)
        loss.backward()
        optim.step()
        optim.zero_grad()
        dt = time.perf_counter() - start

        # use the real batch shape so a smaller final batch is not over-counted
        tok_per_s = x_batch.shape[0] * x_batch.shape[1] / dt
        print(f"step {step+1:4} | loss {loss.item():.4f} | dt {dt:.4f} s | {tok_per_s:.1f} tokens/s")
        step += 1

    # guard against spinning forever if the dataloader produces nothing
    if step == steps_at_epoch_start:
        raise RuntimeError("train_dl yielded no batches")