In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad
from dataclasses import dataclass
from tqdm import tqdm 

In [5]:
@dataclass
class GPTConfig:
    """Hyperparameters handed to the attention layers exercised in this notebook.

    An instance is passed unchanged to the attention class constructor
    (``attn_cls(config)``) by ``test_attn_causality``, which itself reads
    ``batch_size``, ``n_embd`` and ``block_size`` to size the random test
    input. How the remaining fields are interpreted is up to the imported
    attention implementations and is not visible from this notebook.
    """
    block_size: int = 256   # length of the sequence axis of the test input
    vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_head: int = 2         # reported as "h" in the test printout; presumably the number of attention heads -- confirm in the model code
    batch_size: int = 2     # leading (batch) axis of the test input
    n_embd: int = 128       # channel/embedding axis of the test input
    dropout: float = 0.0    # NOTE(review): presumably the dropout probability used by the attention layer -- confirm in the model code
    bias: bool = True       # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
    TK_kernel: bool = False # True: use the TK kernel, False: use standard flash attention

# Single shared configuration instance used by the test cells of this notebook.
config = GPTConfig()

In [6]:
import sys
sys.path.append("../")
from src.model import CausalSelfAttention
from einops import rearrange

def test_attn_causality(
    attn_cls: type,
    config = None,
):
    # set seed
    torch.manual_seed(42)
    layer = attn_cls(config)

    u = torch.randn(config.batch_size, config.n_embd, config.block_size, requires_grad=True)
    y = layer(u.transpose(1,2)).permute(0, 2, 1)  # batch, dim, l
    print(y.shape)

    # gradients from the future must be approx. zero
    for i in tqdm(range(config.block_size - 1)):
        for j in range(config.n_embd):
            g = grad(y[0, j, i], u, retain_graph=True, allow_unused=True)[0]
            check = torch.abs(g[0, :, i + 1 :])
            assert (
                torch.max(check) < 1e-7
            ), "function is not causal at position {}, {}".format(i, check)
    print(f"Passed causality at l={config.block_size}, d_model={config.n_embd}, batch_size={config.batch_size}!\n\n")

# Run the causality check against the reference attention implementation.
test_attn_causality(attn_cls=CausalSelfAttention, config=config)

Running flash attention
torch.Size([2, 128, 256])


  0%|          | 0/255 [00:00<?, ?it/s]

100%|██████████| 255/255 [01:17<00:00,  3.28it/s]

Passed causality at l=256, d_model=128, batch_size=2!







In [13]:
import sys
sys.path.append("../")
from src.custom_model import CustomAttention
from einops import rearrange

def test_attn_causality(
    attn_cls: type,
    config = None,
):
    # set seed
    torch.manual_seed(42)
    layer = attn_cls(config).cuda().to(dtype=torch.bfloat16)
    u = torch.randn(config.batch_size, config.n_embd, config.block_size, requires_grad=True).cuda().to(dtype=torch.bfloat16)
    y = layer(u.transpose(1,2)).permute(0, 2, 1)  # batch, dim, l
    print(y.shape) 
    print("bs: %d, h: %d, n: %d, d: %d" % (config.batch_size, config.n_head, config.block_size, config.n_embd))

    # gradients from the future must be approx. zero
    for i in tqdm(range(config.block_size - 1)):
        for j in range(config.n_embd):
            g = grad(y[0, j, i], u, retain_graph=True, allow_unused=True)[0]
            check = torch.abs(g[0, :, i + 1 :])
            assert (
                torch.max(check) < 1e-7
            ), "function is not causal at position {}, {}".format(i, check)
    print(f"Passed causality at l={config.block_size}, d_model={config.n_embd}, batch_size={config.batch_size}!\n\n")

# Run the causality check against the custom attention implementation.
test_attn_causality(attn_cls=CustomAttention, config=config)

torch.Size([2, 128, 256])
bs: 2, h: 2, n: 256, d: 128


  0%|          | 0/255 [00:00<?, ?it/s]

100%|██████████| 255/255 [00:20<00:00, 12.45it/s]

Passed causality at l=256, d_model=128, batch_size=2!





