In [None]:
import auto_compyute as ac
import auto_compyute.nn.functional as F
from auto_compyute import nn

import numpy as np
import torch

# Seed every RNG library this notebook imports so Restart & Run All reproduces
# the same weights and batches. auto_compyute is seeded last so its own call is
# the final word on whatever state it manages.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)
ac.backends.set_random_seed(SEED)

In [None]:
# Model / data hyperparameters shared by the cells below.
ctx_len, emb_dim = 256, 384   # context length (tokens) and embedding width
n_heads, n_blocks = 6, 6      # attention heads per block, number of blocks
batch_size = 64

In [None]:
class Transformer(nn.Module):
    """Transformer language model: token + learned position embeddings, a stack of
    pre-norm blocks, a final layer norm, and an output head tied to the token
    embedding matrix.

    Args:
        n_emb: Number of token embeddings (vocabulary size); also the width of
            the output logits.
        emb_dim: Embedding / residual-stream width.
        seq_len: Number of learned position embeddings, i.e. the maximum
            sequence length accepted by ``forward``.
        n_heads: Attention heads per block.
        n_layers: Number of blocks.
        mask: Attention mask forwarded unchanged to every block.
        dropout: Dropout probability used inside the blocks.
    """

    def __init__(self, n_emb, emb_dim, seq_len, n_heads, n_layers, mask, dropout=0) -> None:
        super().__init__()
        # NOTE: construction order is kept as-is; each layer's random init
        # consumes RNG, so reordering would change the initial weights.
        self.wte = nn.Embedding(n_emb, emb_dim)
        self.wpe = nn.Embedding(seq_len, emb_dim)
        # shrink both embedding inits by 1/sqrt(emb_dim)
        self.wte.w.data *= emb_dim**-0.5
        self.wpe.w.data *= emb_dim**-0.5

        # 1/sqrt(2 * n_layers): each block has two residual branches (attn + MLP)
        out_scale = (2 * n_layers) ** -0.5
        self.blocks = nn.Modulelist(Block(emb_dim, n_heads, mask, dropout, out_scale) for _ in range(n_layers))

        # normalized shape must be a tuple, as in Block; `(emb_dim)` would be a bare int
        self.head_ln = nn.Layernorm((emb_dim,))
        self.head = nn.Linear(emb_dim, n_emb, bias=False)
        # weight tying: the output projection shares the token embedding matrix
        self.head.w = self.wte.w

        # position ids 0..seq_len-1 as a (1, seq_len) buffer, sliced per forward call
        self.pos = nn.Buffer(ac.arange(seq_len).view((1, -1)))

    def forward(self, x):
        """Map integer token ids ``x`` to logits over the ``n_emb`` vocabulary.

        The last axis of ``x`` is treated as the sequence axis and its length must
        not exceed ``seq_len`` (assumed layout (batch, seq) — TODO confirm).
        """
        x = self.wte(x) + self.wpe(self.pos[:, : x.shape[-1]])
        for block in self.blocks:
            x = block(x)
        return self.head(self.head_ln(x))


class Block(nn.Module):
    """Pre-norm transformer block: layer-normed self-attention and a layer-normed
    MLP, each added back onto the residual stream through dropout.

    Args:
        emb_dim: Residual-stream width.
        n_heads: Number of attention heads.
        mask: Attention mask forwarded to ``nn.MultiHeadSelfAttention``.
        dropout: Probability passed to the attention module and to both
            residual dropouts.
        out_scale: Multiplier applied to selected weights right after init.
    """

    def __init__(self, emb_dim, n_heads, mask, dropout, out_scale) -> None:
        super().__init__()
        # Sub-modules are created in a fixed order: their random inits consume RNG.

        # --- attention branch ---
        self.attn_ln = nn.Layernorm((emb_dim,))
        self.attn = nn.MultiHeadSelfAttention(emb_dim, n_heads, mask, dropout)
        # NOTE(review): GPT-2-style residual scaling usually targets the attention
        # *output* projection (the layer writing into the residual stream), the way
        # `mlp.down` is scaled below. Here the qkv input projection is scaled
        # instead — confirm this is intended; if not, scale the attention module's
        # output-projection weight.
        self.attn.qkv.w.data *= out_scale
        self.attn_dropout = nn.Dropout(dropout)

        # --- feed-forward branch ---
        self.mlp_ln = nn.Layernorm((emb_dim,))
        self.mlp = MLP(emb_dim)
        self.mlp.down.w.data *= out_scale  # the MLP's residual output projection
        self.mlp_dropout = nn.Dropout(dropout)

    def forward(self, x):
        attn_out = self.attn(self.attn_ln(x))
        h = x + self.attn_dropout(attn_out)
        mlp_out = self.mlp(self.mlp_ln(h))
        return h + self.mlp_dropout(mlp_out)


class MLP(nn.Module):
    """Feed-forward sub-layer: Linear(n_emb -> 4*n_emb), GELU, Linear(4*n_emb -> n_emb)."""

    def __init__(self, n_emb) -> None:
        super().__init__()
        hidden_dim = 4 * n_emb  # 4x expansion of the hidden layer
        self.up = nn.Linear(n_emb, hidden_dim)
        self.down = nn.Linear(hidden_dim, n_emb)

    def forward(self, x):
        return self.down(F.gelu(self.up(x)))

In [None]:
# Causal mask: -inf strictly above the diagonal (triu(1)), 0 elsewhere, so a
# position cannot attend to later positions — assuming the attention adds the
# mask to its scores (TODO confirm in nn.MultiHeadSelfAttention).
causal_mask = ac.full((ctx_len, ctx_len), float("-inf")).triu(1)

model = Transformer(
    n_emb=256,  # vocabulary size; keep in sync with the token-id range sampled below
    emb_dim=emb_dim,
    seq_len=ctx_len,
    n_heads=n_heads,
    n_layers=n_blocks,
    mask=causal_mask,
)
model.to(ac.cuda)

In [None]:
# Synthetic batch: random token ids between 0 and 256 (upper bound presumably
# exclusive) for both inputs and targets — only used to get a loss whose
# backward graph can be inspected. Drawn in order: inputs first, then targets.
x, y = (
    ac.randi((batch_size, ctx_len), 0, 256, device=ac.cuda, dtype=ac.int32)
    for _ in range(2)
)
loss = F.cross_entropy(model(x), y)

In [None]:
# Collect the autograd nodes behind `loss`; the next cell iterates over them and
# reads each node's cached values.
# NOTE(review): the `[]` and `set()` arguments look like an output queue and a
# visited set for the graph traversal — confirm against the
# ac.autograd.build_backward_queue signature.
queue = ac.autograd.build_backward_queue(loss, [], set())

In [None]:
import sys

# Report what each op in the backward queue cached for its backward pass.
# The same object can be cached by more than one op (possibly e.g. the causal
# mask shared by every attention block), so the raw sum over-counts memory.
# We therefore also keep a de-duplicated total keyed on object identity. Every
# cached object is alive while we iterate, so id() values are unique here.
total_bytes = 0   # raw sum: an object is counted once per op that caches it
unique_bytes = 0  # each distinct object counted once
seen_ids = set()

for node in queue:
    print(node.ctx.name)
    cached_vals = node.ctx.cache.vals
    for v in cached_vals if cached_vals is not None else []:
        v_id = id(v)
        if isinstance(v, ac.backends.Array):
            v_dtype, v_shape, v_nbytes = v.dtype, v.shape, v.nbytes
        else:
            # non-array cache entries: approximate with the Python object size
            v_dtype, v_shape, v_nbytes = type(v), 1, sys.getsizeof(v)
        total_bytes += v_nbytes
        if v_id not in seen_ids:
            seen_ids.add(v_id)
            unique_bytes += v_nbytes
        print("    ", v_id, v_dtype, v_shape, f"{v_nbytes:_}")

print(f"total {total_bytes:_}")
print(f"unique {unique_bytes:_}")

In [None]:
# for p in model.parameters():
#     print(p.grad)