In [1]:
import torch
import torchinfo

import sys; sys.path.append('..')
from language_models import TransformerLM, configure_optimizers
from dual_attention_transformer import DualAttnTransformerLM, configure_optimizers
from dual_attention_transformer_old import DualAttnTransformerLM as DualAttnTransformerLM_old
import time

import os
os.environ["TORCH_LOGS"] = "+dynamo"
os.environ["TORCHDYNAMO_VERBOSE"] = "1"

import torch._dynamo
torch._dynamo.config.suppress_errors = True

In [2]:
compile_model = False
fused_optim = False
use_bfloat16 = False
use_tf32_matmul = False
# use_flash_attention = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_type = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
import torch

# Check if CUDA is available
cuda_available = torch.cuda.is_available()

if cuda_available:
    # Get the number of available GPUs
    num_gpus = torch.cuda.device_count()

    # Get the type of each GPU
    gpu_types = [torch.cuda.get_device_name(i) for i in range(num_gpus)]

    print("CUDA is available")
    print(f"Number of available GPUs: {num_gpus}")
    print("GPU Types:")
    for i, gpu_type in enumerate(gpu_types):
        print(f"GPU {i}: {gpu_type}")
else:
    print("CUDA is not available")

CUDA is available
Number of available GPUs: 1
GPU Types:
GPU 0: NVIDIA H100 80GB HBM3


In [4]:
import tiktoken
data_dir = '../data/shakespeare.txt'


class DataLoaderLite:
    def __init__(self, B, T):
        self.B = B
        self.T = T

        # at init load tokens from disk and store them in memory
        with open(data_dir, 'r') as f:
            text = f.read()
        enc = tiktoken.get_encoding('gpt2')
        tokens = enc.encode(text)
        self.tokens = torch.tensor(tokens)
        print(f"loaded {len(self.tokens)} tokens")
        print(f"1 epoch = {len(self.tokens) // (B * T)} batches")

        # state
        self.current_position = 0

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.current_position : self.current_position+B*T+1]
        x = (buf[:-1]).view(B, T) # inputs
        y = (buf[1:]).view(B, T) # targets
        # advance the position in the tensor
        self.current_position += B * T
        # if loading the next batch would be out of bounds, reset
        if self.current_position + (B * T + 1) > len(self.tokens):
            self.current_position = 0
        return x, y

# train_loader = DataLoaderLite(B=16, T=1024)
train_loader = DataLoaderLite(B=8, T=256)

x, y = train_loader.next_batch()
x, y = x.to(device), y.to(device)

loaded 338024 tokens
1 epoch = 165 batches


## Parameter Count Scaling & Memory Usage

In [5]:
vocab_size = 50304 # 50257
def create_model():
    model_config = dict(vocab_size=vocab_size, d_model=768, n_layers=12, n_heads=12, dff=None, activation='relu',
        dropout_rate=0., norm_first=True, norm_type='layernorm', max_block_size=1024, bias=False, pos_enc_type='pos_emb',
        block_kwargs=dict(attn_kwargs=dict(n_kv_heads=1))) # MQA
    model = TransformerLM(**model_config)
    return model
model = create_model()
model = model.to(device)
torchinfo.summary(model, input_data=(x,y))

Layer (type:depth-idx)                        Output Shape              Param #
TransformerLM                                 [8, 256, 50304]           --
├─ModuleDict: 1-1                             --                        --
│    └─Embedding: 2-1                         [8, 256, 768]             38,633,472
│    └─Embedding: 2-2                         [256, 768]                786,432
│    └─ModuleList: 2-3                        --                        --
│    │    └─EncoderBlock: 3-1                 [8, 256, 768]             5,999,616
│    │    └─EncoderBlock: 3-2                 [8, 256, 768]             5,999,616
│    │    └─EncoderBlock: 3-3                 [8, 256, 768]             5,999,616
│    │    └─EncoderBlock: 3-4                 [8, 256, 768]             5,999,616
│    │    └─EncoderBlock: 3-5                 [8, 256, 768]             5,999,616
│    │    └─EncoderBlock: 3-6                 [8, 256, 768]             5,999,616
│    │    └─EncoderBlock: 3-7           

In [6]:
vocab_size = 50304 # 50257
def create_model():
    model_config = dict(vocab_size=vocab_size, d_model=768, n_layers=12, n_heads=12, dff=None, activation='relu',
        dropout_rate=0., norm_first=True, norm_type='layernorm', max_block_size=1024, bias=False, pos_enc_type='pos_emb')
    model = TransformerLM(**model_config)
    return model
model = create_model()
model = model.to(device)
torchinfo.summary(model, input_data=(x,y))

Layer (type:depth-idx)                        Output Shape              Param #
TransformerLM                                 [8, 256, 50304]           --
├─ModuleDict: 1-1                             --                        --
│    └─Embedding: 2-1                         [8, 256, 768]             38,633,472
│    └─Embedding: 2-2                         [256, 768]                786,432
│    └─ModuleList: 2-3                        --                        --
│    │    └─EncoderBlock: 3-1                 [8, 256, 768]             7,080,960
│    │    └─EncoderBlock: 3-2                 [8, 256, 768]             7,080,960
│    │    └─EncoderBlock: 3-3                 [8, 256, 768]             7,080,960
│    │    └─EncoderBlock: 3-4                 [8, 256, 768]             7,080,960
│    │    └─EncoderBlock: 3-5                 [8, 256, 768]             7,080,960
│    │    └─EncoderBlock: 3-6                 [8, 256, 768]             7,080,960
│    │    └─EncoderBlock: 3-7           

In [7]:
vocab_size = 50304 # 50257
def create_transformer_model(use_flash_attention=False):
    model_config = dict(vocab_size=vocab_size, d_model=768, n_layers=12, n_heads=12, dff=None, activation='swiglu', # swiglu activation
        dropout_rate=0., norm_first=True, norm_type='layernorm', max_block_size=1024, bias=False, pos_enc_type='pos_emb')
    model = TransformerLM(**model_config, use_flash_attention=use_flash_attention)
    return model
model = create_transformer_model()
model = model.to(device)
torchinfo.summary(model, input_data=(x,y))

Layer (type:depth-idx)                        Output Shape              Param #
TransformerLM                                 [8, 256, 50304]           --
├─ModuleDict: 1-1                             --                        --
│    └─Embedding: 2-1                         [8, 256, 768]             38,633,472
│    └─Embedding: 2-2                         [256, 768]                786,432
│    └─ModuleList: 2-3                        --                        --
│    │    └─EncoderBlock: 3-1                 [8, 256, 768]             7,080,960
│    │    └─EncoderBlock: 3-2                 [8, 256, 768]             7,080,960
│    │    └─EncoderBlock: 3-3                 [8, 256, 768]             7,080,960
│    │    └─EncoderBlock: 3-4                 [8, 256, 768]             7,080,960
│    │    └─EncoderBlock: 3-5                 [8, 256, 768]             7,080,960
│    │    └─EncoderBlock: 3-6                 [8, 256, 768]             7,080,960
│    │    └─EncoderBlock: 3-7           

In [8]:
vocab_size = 50304 # 50257
def create_dat_model(use_flash_attention=False):
    model_config = dict(vocab_size=vocab_size, d_model=768, n_layers=12, n_heads_sa=6, n_heads_ra=6, dff=None, activation='swiglu',
        symbol_retrieval='symbolic_attention', symbol_retrieval_kwargs=dict(d_model=768, n_heads=4, n_symbols=512),
        dropout_rate=0., norm_first=True, norm_type='layernorm', max_block_size=1024, bias=False, pos_enc_type='pos_emb',
        sa_kwargs=dict(n_kv_heads=None), ra_kwargs=dict(n_kv_heads=1))
    model = DualAttnTransformerLM(**model_config) # DAT
    return model
model = create_dat_model()
model = model.to(device)
torchinfo.summary(model, input_data=(x,y))

Layer (type:depth-idx)                                  Output Shape              Param #
DualAttnTransformerLM                                   [8, 256, 50304]           --
├─ModuleDict: 1-1                                       --                        --
│    └─Embedding: 2-1                                   [8, 256, 768]             38,633,472
│    └─Embedding: 2-2                                   [256, 768]                786,432
│    └─SymbolicAttention: 2-3                           [8, 256, 768]             786,432
│    │    └─Linear: 3-1                                 [8, 256, 768]             590,592
│    └─ModuleList: 2-26                                 --                        (recursive)
│    │    └─DualAttnEncoderBlock: 3-2                   [8, 256, 768]             6,891,520
│    └─SymbolicAttention: 2-5                           [8, 256, 768]             (recursive)
│    │    └─Linear: 3-3                                 [8, 256, 768]             (recursive)
│  

In [9]:
vocab_size = 50304 # 50257
def create_dat_old_model(use_flash_attention=False):
    model_config = dict(vocab_size=vocab_size, d_model=768, n_layers=12, n_heads_sa=6, n_heads_ra=6, dff=None, activation='swiglu',
        symbol_retrieval='symbolic_attention', symbol_retrieval_kwargs=dict(d_model=768, n_heads=4, n_symbols=512),
        dropout_rate=0., norm_first=True, norm_type='layernorm', max_block_size=1024, bias=False, pos_enc_type='pos_emb',
        sa_kwargs=dict(n_kv_heads=None), ra_kwargs=dict(n_kv_heads=1))
    model = DualAttnTransformerLM_old(**model_config) # DAT
    return model
model = create_dat_old_model()
model = model.to(device)
torchinfo.summary(model, input_data=(x,y))

Layer (type:depth-idx)                                  Output Shape              Param #
DualAttnTransformerLM                                   [8, 256, 50304]           --
├─ModuleDict: 1-1                                       --                        --
│    └─Embedding: 2-1                                   [8, 256, 768]             38,633,472
│    └─Embedding: 2-2                                   [256, 768]                786,432
│    └─SymbolicAttention: 2-3                           [8, 256, 768]             786,432
│    │    └─Linear: 3-1                                 [8, 256, 768]             590,592
│    └─ModuleList: 2-26                                 --                        (recursive)
│    │    └─DualAttnEncoderBlock: 3-2                   [8, 256, 768]             6,891,520
│    └─SymbolicAttention: 2-5                           [8, 256, 768]             (recursive)
│    │    └─Linear: 3-3                                 [8, 256, 768]             (recursive)
│  

NOTE: before new implementation of relational attention, a 162M DAT has "Forward/backward pass size (MB): 22216.70". This is compared to a 163M Transformer with only Forward/backward pass size (MB): 2713.19". This is almost a 10X difference. Here, $T = 256$ only. It would be much larger if $T$ was larger (quadratically).

By comparison, with the new implementation, the 162M DAT has a "Forward/backward pass size" of only "Forward/backward pass size (MB): 2889.35". Only slightly larger than a Transformer!

Total # of Mult-Adds (i.e., # of operations) is roughly the same in comparably sized Transformer (163M; Total mult-adds (Units.GIGABYTES): 1.50) vs DAT (162M; Total mult-adds (Units.GIGABYTES): 1.54). 

## Evaluate Training Speed

In [10]:
def eval_train_speed(model, n_steps=50, fused_optim=fused_optim, use_bfloat16=use_bfloat16, verbose=True):
    optimizer = optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, fused=fused_optim)

    dts = []
    toks_per_sec = []

    for i in range(n_steps):
        t0 = time.time()
        x, y = train_loader.next_batch()
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()

        if use_bfloat16:
            with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                logits, loss = model(x, y)
        else:
            logits, loss = model(x, y)

        loss.backward()
        optimizer.step()
        torch.cuda.synchronize() # wait for the GPU to finish work
        t1 = time.time()
        dt = t1 - t0 # time difference in seconds
        tokens_processed = train_loader.B * train_loader.T
        tokens_per_sec = tokens_processed / dt
        dts.append(dt)
        toks_per_sec.append(tokens_per_sec)

        if verbose:
            print(f"step {i:4d} | loss: {loss.item():.6f} | dt: {dt*1000:.2f}ms | tok/sec: {tokens_per_sec:.2f}")

    return dts, toks_per_sec

In [11]:
def eval_setup_speed(
        create_model, use_flash_attention=False, use_tf32_matmul=False, compile_model=False, fused_optim=False, use_bfloat16=False, verbose=False, n_steps=26):
    print('='*100)
    print(f'use_flash_attention: {use_flash_attention}, use_tf32_matmul: {use_tf32_matmul}, compile_model: {compile_model}, fused_optim: {fused_optim}, use_bfloat16: {use_bfloat16}')
    print('='*100)

    model = create_model(use_flash_attention=use_flash_attention) # use_flash_attention is ignored for DAT
    model = model.to(device)

    if use_tf32_matmul:
        torch.set_float32_matmul_precision('high')

    if compile_model:
        model = torch.compile(model)

    dts, tokens_per_sec = eval_train_speed(model, n_steps=n_steps, fused_optim=fused_optim, use_bfloat16=use_bfloat16, verbose=verbose)
    print('*'*100)
    mean_dt = sum(dts[1:])/len(dts[1:])*1000 # exclude first step cuz of compilation overhead
    mean_toks_per_sec = sum(tokens_per_sec[1:])/len(tokens_per_sec[1:])
    print(f"mean dt: {sum(dts[1:])/len(dts[1:])*1000:.2f}ms | mean tok/sec: {sum(tokens_per_sec)/len(tokens_per_sec):.0f}")
    print('*'*100)
    print('='*100)
    return mean_dt, mean_toks_per_sec


In [12]:
results = []
setups = [
    dict(use_flash_attention=False, use_tf32_matmul=False, compile_model=False, fused_optim=False, use_bfloat16=False),
    dict(use_flash_attention=False, use_tf32_matmul=True, compile_model=False, fused_optim=False, use_bfloat16=False),
    dict(use_flash_attention=False, use_tf32_matmul=True, compile_model=False, fused_optim=False, use_bfloat16=True),
    dict(use_flash_attention=False, use_tf32_matmul=True, compile_model=True, fused_optim=False, use_bfloat16=True),
    dict(use_flash_attention=False, use_tf32_matmul=True, compile_model=True, fused_optim=True, use_bfloat16=True),
    dict(use_flash_attention=True, use_tf32_matmul=True, compile_model=True, fused_optim=True, use_bfloat16=True),
]

### B = 4, T = 256

In [13]:
train_loader = DataLoaderLite(B=4, T=256)
# train_loader = DataLoaderLite(B=8, T=256)
# train_loader = DataLoaderLite(B=16, T=1024)
# train_loader = DataLoaderLite(B=2, T=1024)

loaded 338024 tokens
1 epoch = 330 batches


In [14]:
# results for Transformer-LM
for setup in setups:
    mean_dt, mean_toks_per_sec = eval_setup_speed(create_model=create_transformer_model, **setup, verbose=False)
    print()
    results.append((setup, mean_dt, mean_toks_per_sec))

use_flash_attention: False, use_tf32_matmul: False, compile_model: False, fused_optim: False, use_bfloat16: False
****************************************************************************************************
mean dt: 30.58ms | mean tok/sec: 32381
****************************************************************************************************

use_flash_attention: False, use_tf32_matmul: True, compile_model: False, fused_optim: False, use_bfloat16: False
****************************************************************************************************
mean dt: 17.23ms | mean tok/sec: 57394
****************************************************************************************************

use_flash_attention: False, use_tf32_matmul: True, compile_model: False, fused_optim: False, use_bfloat16: True
****************************************************************************************************
mean dt: 17.28ms | mean tok/sec: 57217
*************************************

In [15]:
# results for DAT-Old-LM
for setup in setups:
    mean_dt, mean_toks_per_sec = eval_setup_speed(create_dat_old_model, **setup, verbose=False)
    print()
    results.append((setup, mean_dt, mean_toks_per_sec))

use_flash_attention: False, use_tf32_matmul: False, compile_model: False, fused_optim: False, use_bfloat16: False


****************************************************************************************************
mean dt: 52.56ms | mean tok/sec: 18957
****************************************************************************************************

use_flash_attention: False, use_tf32_matmul: True, compile_model: False, fused_optim: False, use_bfloat16: False
****************************************************************************************************
mean dt: 52.51ms | mean tok/sec: 19483
****************************************************************************************************

use_flash_attention: False, use_tf32_matmul: True, compile_model: False, fused_optim: False, use_bfloat16: True
****************************************************************************************************
mean dt: 41.88ms | mean tok/sec: 23708
****************************************************************************************************

use_flash_attention: False, use_tf32_matmul: True

In [16]:
# results for DAT-LM
for setup in setups:
    mean_dt, mean_toks_per_sec = eval_setup_speed(create_dat_model, **setup, verbose=False)
    print()
    results.append((setup, mean_dt, mean_toks_per_sec))

use_flash_attention: False, use_tf32_matmul: False, compile_model: False, fused_optim: False, use_bfloat16: False
****************************************************************************************************
mean dt: 34.10ms | mean tok/sec: 29844
****************************************************************************************************

use_flash_attention: False, use_tf32_matmul: True, compile_model: False, fused_optim: False, use_bfloat16: False
****************************************************************************************************
mean dt: 34.14ms | mean tok/sec: 29833
****************************************************************************************************

use_flash_attention: False, use_tf32_matmul: True, compile_model: False, fused_optim: False, use_bfloat16: True
****************************************************************************************************
mean dt: 38.68ms | mean tok/sec: 26336
*************************************

### B = 16, T = 1024

In [14]:
# train_loader = DataLoaderLite(B=4, T=256)
# train_loader = DataLoaderLite(B=8, T=256)
train_loader = DataLoaderLite(B=16, T=1024)
# train_loader = DataLoaderLite(B=2, T=1024)

loaded 338024 tokens
1 epoch = 20 batches


In [18]:
torch.cuda.empty_cache()
# results for Transformer-LM
for setup in setups:
    mean_dt, mean_toks_per_sec = eval_setup_speed(create_model=create_transformer_model, **setup, verbose=False, n_steps=26)
    print()
    results.append((setup, mean_dt, mean_toks_per_sec))

use_flash_attention: False, use_tf32_matmul: False, compile_model: False, fused_optim: False, use_bfloat16: False


****************************************************************************************************
mean dt: 163.40ms | mean tok/sec: 100176
****************************************************************************************************

use_flash_attention: False, use_tf32_matmul: True, compile_model: False, fused_optim: False, use_bfloat16: False
****************************************************************************************************
mean dt: 163.39ms | mean tok/sec: 100258
****************************************************************************************************

use_flash_attention: False, use_tf32_matmul: True, compile_model: False, fused_optim: False, use_bfloat16: True
****************************************************************************************************
mean dt: 144.45ms | mean tok/sec: 113393
****************************************************************************************************

use_flash_attention: False, use_tf32_matmul

In [19]:
# runs out of memory because old implementation is too memory hungry
torch.cuda.empty_cache()
# results for DAT-Old-LM
for setup in setups:
    mean_dt, mean_toks_per_sec = eval_setup_speed(create_dat_old_model, **setup, verbose=False, n_steps=26)
    print()
    results.append((setup, mean_dt, mean_toks_per_sec))

use_flash_attention: False, use_tf32_matmul: False, compile_model: False, fused_optim: False, use_bfloat16: False


OutOfMemoryError: CUDA out of memory. Tried to allocate 24.00 GiB. GPU 0 has a total capacity of 79.10 GiB of which 10.30 GiB is free. Including non-PyTorch memory, this process has 68.79 GiB memory in use. Of the allocated memory 66.59 GiB is allocated by PyTorch, and 1.48 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [15]:

# results for DAT-LM
for setup in setups:
    mean_dt, mean_toks_per_sec = eval_setup_speed(create_dat_model, **setup, verbose=False, n_steps=26)
    print()
    results.append((setup, mean_dt, mean_toks_per_sec))

use_flash_attention: False, use_tf32_matmul: False, compile_model: False, fused_optim: False, use_bfloat16: False
****************************************************************************************************
mean dt: 280.18ms | mean tok/sec: 58425
****************************************************************************************************

use_flash_attention: False, use_tf32_matmul: True, compile_model: False, fused_optim: False, use_bfloat16: False
****************************************************************************************************
mean dt: 280.32ms | mean tok/sec: 58441
****************************************************************************************************

use_flash_attention: False, use_tf32_matmul: True, compile_model: False, fused_optim: False, use_bfloat16: True
****************************************************************************************************
mean dt: 277.55ms | mean tok/sec: 59029
**********************************

Seems compile without flash attention is 145ms while flash attention without compile is 124ms. not too different. (with all other optimizations turned on)

using tf32 or bfloat16 makes big difference. with both turned on and no flash attention or compilation, avg time per step is 257ms. This is down from 1033ms without tf32 or bfloat16.

with compile + flash attention but no tf32 or bfloat16, the avg time per step is 246ms.

With both, all optimizations, we are down to 105ms.

With all optimizations, RoPE is the same as positional embeddings (105 ms)

With all optimizations, RMSNorm is the marginally faster than LayerNorm (103.47ms)

With all optimizations, SwiGLU activation is 121ms.

norm_first = False or is marginally faster (for some reason) than pre-LN (101.80ms)

In [None]:
# more to check speed
# effect of rmsnorm vs layernorm
# effect of bias vs no bias
# effect of pos emb vs RoPE
# effect of activation
# effect of MQA vs MHA
# clean custom loop vs pytorch lightning

# how much slower is DAT for comparable configs that work for both (e.g., no compile but tf32 matmul and bfloat16, etc.)