In [None]:
from decoder import TransformerConfig, RLTransformer
import torch

In [None]:
# Shared model config used by most cells below. All *_pdrop (dropout) probabilities
# are 0.0 so repeated forward passes are comparable with tight tolerances.
# Seed torch once so parameter init and the torch.rand inputs below are reproducible
# under Restart & Run All.
torch.manual_seed(0)

config = TransformerConfig(max_seq_len=100,
                           attention='causal',
                           attention_impl='scaled_dot',  # the last cell compares this against attention_impl='original'
                           num_layers=3,
                           num_heads=4,
                           embed_dim=128,
                           embed_pdrop=0.0,
                           resid_pdrop=0.0,
                           attn_pdrop=0.0,)

In [None]:
# Instantiate the model from the shared config; the bare expression on the last
# line makes the notebook render the model's repr as the cell output.
transformer = RLTransformer(config)
transformer

In [None]:
# Smoke test: one forward pass on random input of shape (batch=3, seq=73, embed_dim).
# The width comes from the config so this cell keeps working if embed_dim changes.
output, new_kv = transformer(torch.rand(3, 73, config.embed_dim))
assert output.shape == (3, 73, config.embed_dim)
# keys are returned per layer: (num_layers, B, num_heads, T, head_dim)
assert new_kv['k'].shape == (len(transformer.blocks), 3, config.num_heads, 73, config.embed_dim // config.num_heads)

In [None]:
# ----- test no randomness in the forward -----
# Two forward passes of the same model on the same input must agree
# (all dropout probabilities in `config` are 0.0).
transformer = RLTransformer(config)
x = torch.rand(3, 73, config.embed_dim)
output1, kv1 = transformer(x)
output2, kv2 = transformer(x)
assert torch.allclose(output1, output2)
for key in ('k', 'v'):
    assert torch.allclose(kv1[key], kv2[key]), f"kv[{key!r}] differs between identical forward passes"

In [None]:
# Display the output of the first forward pass from the determinism check above.
output1

In [None]:
# ----- test the kv cache vs no kv cache, should be the same -----
# Run the whole sequence in one forward pass, then replay it one token at a time,
# feeding the accumulated per-layer KV cache back in. Both the outputs and the
# final cached keys/values must match the single-pass results.

# no kv cache
transformer = RLTransformer(config)
B = 17  # batch size
T = 33  # sequence length
D = config.embed_dim  # embedding dimension, kept in sync with the model config
num_layers = len(transformer.blocks)
head_dim = D // config.num_heads
x = torch.rand(B, T, D)
output1, kv1 = transformer(x)  # output1: (B, T, D)

# with kv cache
kv_cache = transformer.generate_empty_keys_values(n=B, max_tokens=T)
# feed the x one by one
output2s = []
for t in range(T):
    # stack the per-layer caches on a new leading layer axis: (num_layers, B, num_heads, t, head_dim)
    past_ks = torch.stack([layer_cache.get()[0] for layer_cache in kv_cache._keys_values], dim=0)
    past_vs = torch.stack([layer_cache.get()[1] for layer_cache in kv_cache._keys_values], dim=0)
    output2, this_kvs = transformer(x[:, t:t+1], past_keys_values=(past_ks, past_vs))  # output2: (B, 1, D)
    assert this_kvs['k'].shape == (num_layers, B, config.num_heads, 1, head_dim)
    assert this_kvs['v'].shape == (num_layers, B, config.num_heads, 1, head_dim)

    # append this step's keys/values to every layer's cache
    for j, layer_cache in enumerate(kv_cache._keys_values):  # looping over caches for layers
        assert layer_cache._k_cache.shape == (B, config.num_heads, t, head_dim)
        assert layer_cache._v_cache.shape == (B, config.num_heads, t, head_dim)
        layer_cache.update(this_kvs['k'][j], this_kvs['v'][j])

    output2s.append(output2)

output2 = torch.cat(output2s, dim=1)
# check the equivalent of the outputs
print(f"mean error: {torch.mean(torch.abs(output1 - output2))}, shapes: {output1.shape}, {output2.shape}")
assert torch.allclose(output1, output2, atol=1e-6)
# check the equivalent of the kv
past_ks = torch.stack([layer_cache.get()[0] for layer_cache in kv_cache._keys_values], dim=0)
past_vs = torch.stack([layer_cache.get()[1] for layer_cache in kv_cache._keys_values], dim=0)
print(f"mean error: {torch.mean(torch.abs(kv1['k'] - past_ks))}, shapes: {kv1['k'].shape}, {past_ks.shape}")
print(f"mean error: {torch.mean(torch.abs(kv1['v'] - past_vs))}, shapes: {kv1['v'].shape}, {past_vs.shape}")
assert torch.allclose(kv1['k'], past_ks, atol=1e-6)
assert torch.allclose(kv1['v'], past_vs, atol=1e-6)

In [None]:
# ----- test the vmapped function forward with kv cache and no kv cache -----
import copy
# test the vmapped forward method
from torch.func import functional_call, stack_module_state
def call_models_forward(models, xs, past_ks_vs=None):
    """Run ``models[i]`` on ``xs[i]`` for every i with a single ``torch.vmap`` call.

    The parameters and buffers of all models are stacked along a new leading
    axis and applied through a stateless copy of ``models[0]`` moved to the
    ``meta`` device (structure only, no real storage).

    Args:
        models: sequence of modules sharing the same structure.
        xs: per-model inputs stacked along dim 0.
        past_ks_vs: optional cache forwarded as the second positional argument
            of each model's forward; every tensor inside it is mapped over
            dim 0. When omitted, each model receives ``None`` instead.

    Returns:
        The model's forward result, with every tensor stacked along a leading
        per-model axis.
    """
    stacked_params, stacked_buffers = stack_module_state(models)
    # Skeleton of the first model; functional_call substitutes the stacked weights.
    skeleton = copy.deepcopy(models[0]).to('meta')

    def single_forward(params, buffers, x, past_ks_vs=None):
        # Third positional forward argument is the `batched` flag, fixed to False
        # here (presumably because each vmapped slice is a single sample — confirm).
        return functional_call(skeleton, (params, buffers), (x, past_ks_vs, False))

    if past_ks_vs is None:
        mapped_args = (stacked_params, stacked_buffers, xs)
    else:
        mapped_args = (stacked_params, stacked_buffers, xs, past_ks_vs)
    vectorized = torch.vmap(single_forward, in_dims=(0,) * len(mapped_args), randomness='different')
    return vectorized(*mapped_args)
# no kv cache and no vmapped forward
print("-"*50)
print("no kv cache and no vmapped forward")
transformer = RLTransformer(config)
B = 17 # batch size
T = 33 # sequence length
D = config.embed_dim # embedding dimension, kept in sync with the model config
num_layers = len(transformer.blocks)
head_dim = D // config.num_heads
xs = torch.rand(B, T, D)
output1, kv1 = transformer(xs)

# no kv cache but with vmapped forward
print("-"*50)
print("no kv cache but with vmapped forward")
# B references to the same model: every vmapped "agent" has identical weights,
# so vmapping over them must reproduce the single batched forward above.
formers = [transformer for _ in range(B)]
output2, kv2 = call_models_forward(formers, xs.unsqueeze(1)) # unsqueeze to add a fake batch dimension
output2 = output2.squeeze(1) # remove the fake batch dimension, later refactor so that the fake batch dimension is removed already
print(f"shapes of kvs: {kv1['k'].shape}, {kv2['k'].shape}")
kv2['k'] = kv2['k'].transpose(0, 1) # (B, num_layers, num_heads, T, D // num_heads) -> (num_layers, B, num_heads, T, D // num_heads)
kv2['v'] = kv2['v'].transpose(0, 1) # (B, num_layers, num_heads, T, D // num_heads) -> (num_layers, B, num_heads, T, D // num_heads)
print(f"mean error: {torch.mean(torch.abs(output1 - output2))}, shapes: {output1.shape}, {output2.shape}")
assert torch.allclose(output1, output2, atol=1e-6)
print(f"mean error: {torch.mean(torch.abs(kv1['k'] - kv2['k']))}, shapes: {kv1['k'].shape}, {kv2['k'].shape}")
assert torch.allclose(kv1['k'], kv2['k'], atol=1e-6)
# values were transposed above as well; compare them too
print(f"mean error: {torch.mean(torch.abs(kv1['v'] - kv2['v']))}, shapes: {kv1['v'].shape}, {kv2['v'].shape}")
assert torch.allclose(kv1['v'], kv2['v'], atol=1e-6)

# with kv cache and vmapped forward 
print("-"*50)
print("with kv cache and vmapped forward")
kv_cache = formers[0].generate_empty_keys_values(n=B, max_tokens=T)
output3_s = []
for t in range(T):
    past_ks = torch.stack([layer_cache.get()[0] for layer_cache in kv_cache._keys_values], dim=1) # stack in dimension 1 so vmap works on dimension 0 (batch dimension)
    past_vs = torch.stack([layer_cache.get()[1] for layer_cache in kv_cache._keys_values], dim=1)
    assert past_ks.shape == (B, num_layers, config.num_heads, t, head_dim)
    assert past_vs.shape == (B, num_layers, config.num_heads, t, head_dim)

    output3, this_kvs = call_models_forward(formers, xs[:, t:t+1].unsqueeze(1), (past_ks, past_vs))
    output3 = output3.squeeze(1) # remove the fake batch dimension so (B, 1, 1, D) -> (B, 1, D); 1 is the sequence length in the second shape
    output3_s.append(output3)
    assert this_kvs['k'].shape == (B, num_layers, config.num_heads, 1, head_dim)
    assert this_kvs['v'].shape == (B, num_layers, config.num_heads, 1, head_dim)

    # append this step's keys/values to every layer's cache (layer axis is dim 1 here)
    for j, layer_cache in enumerate(kv_cache._keys_values):  # looping over caches for layers
        layer_cache.update(this_kvs['k'][:, j], this_kvs['v'][:, j])

output3 = torch.cat(output3_s, dim=1)
# check the equivalent of the outputs
print(f"mean error: {torch.mean(torch.abs(output1 - output3))}, shapes: {output1.shape}, {output3.shape}")
assert torch.allclose(output1, output3, atol=1e-6)
# check the equivalent of the kv
past_ks = torch.stack([layer_cache.get()[0] for layer_cache in kv_cache._keys_values], dim=1)
past_vs = torch.stack([layer_cache.get()[1] for layer_cache in kv_cache._keys_values], dim=1)
past_ks = past_ks.transpose(0, 1) # (B, num_layers, num_heads, T, D // num_heads) -> (num_layers, B, num_heads, T, D // num_heads)
past_vs = past_vs.transpose(0, 1) # (B, num_layers, num_heads, T, D // num_heads) -> (num_layers, B, num_heads, T, D // num_heads)
print(f"mean error: {torch.mean(torch.abs(kv1['k'] - past_ks))}, shapes: {kv1['k'].shape}, {past_ks.shape}")
print(f"mean error: {torch.mean(torch.abs(kv1['v'] - past_vs))}, shapes: {kv1['v'].shape}, {past_vs.shape}")
assert torch.allclose(kv1['k'], past_ks, atol=1e-6)
assert torch.allclose(kv1['v'], past_vs, atol=1e-6)

In [None]:
# test the scaled_dot_product_attention function and compare it with the manual implementation
# Build two models that differ only in `attention_impl`, copy the weights of the
# 'original' model into the 'scaled_dot' one, and check that outputs and
# keys/values agree.

shared_config_kwargs = dict(max_seq_len=100,
                            attention='causal',
                            num_layers=3,
                            num_heads=4,
                            embed_dim=128,
                            embed_pdrop=0.0,
                            resid_pdrop=0.0,
                            attn_pdrop=0.0)

config_original = TransformerConfig(attention_impl='original', **shared_config_kwargs)
transformer_original = RLTransformer(config_original)

config_scaled_dot = TransformerConfig(attention_impl='scaled_dot', **shared_config_kwargs)
transformer_scaled_dot = RLTransformer(config_scaled_dot)
# use state_dict of original so both models share identical weights
transformer_scaled_dot.load_state_dict(transformer_original.state_dict())

x = torch.rand(3, 73, config_original.embed_dim)
# original
output_original, kv_original = transformer_original(x)
# scaled dot
output_scaled_dot, kv_scaled_dot = transformer_scaled_dot(x)

print(f"mean error: {torch.mean(torch.abs(output_original - output_scaled_dot))}, shapes: {output_original.shape}, {output_scaled_dot.shape}")
assert torch.allclose(output_original, output_scaled_dot, atol=1e-6)
for key in ('k', 'v'):
    print(f"mean error: {torch.mean(torch.abs(kv_original[key] - kv_scaled_dot[key]))}, shapes: {kv_original[key].shape}, {kv_scaled_dot[key].shape}")
    assert torch.allclose(kv_original[key], kv_scaled_dot[key], atol=1e-6)