In [1]:
import torch
from fm.scan.fm_scan import fm_scan, fm_memory # c++ cuda kernel

Using /home/lexion/.cache/torch_extensions/py38_cu121 as PyTorch extensions root...
Creating extension directory /home/lexion/.cache/torch_extensions/py38_cu121/fm_scan...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/lexion/.cache/torch_extensions/py38_cu121/fm_scan/build.ninja...
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Building extension module fm_scan...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
Loading extension module fm_scan...


## Usage

In [None]:
batch_size, dim, seqlen, mem = 1, 2048, 8192, 32

# testing just the scan kernel
gates = 0.9 + 0.1 * torch.rand(batch_size, mem, seqlen, device="cuda", dtype=torch.bfloat16)
tokens = torch.rand(batch_size, dim, seqlen, device="cuda", dtype=torch.bfloat16)
memory_states = fm_scan(gates, tokens, initial_state=None)
print(memory_states.shape)

# testing the memory layer
alpha = torch.nn.functional.softmax(torch.rand(batch_size, seqlen, mem, device="cuda", dtype=torch.bfloat16, requires_grad=True), dim=-1)
update_scale = torch.rand(batch_size, seqlen, 1, device="cuda", dtype=torch.bfloat16, requires_grad=True) + 0.001
output_scale = torch.rand(batch_size, seqlen, 1, device="cuda", dtype=torch.bfloat16, requires_grad=True) + 0.001
inputs = torch.rand(batch_size, seqlen, dim, device="cuda", dtype=torch.bfloat16, requires_grad=True)

memory_output, memory_states = fm_memory(alpha, update_scale, output_scale, inputs, initial_state=None, mem_norm=True, norm_eps=1e-6)
print(memory_states.shape)
print(memory_output.shape)

## Correctness test

In [None]:
# check with pytorch
from fm.scan.fm_pytorch import fm_scan_pytorch

batch_size, dim, seqlen, mem = 1, 5, 1024, 1
gates = 0.5 + 0.5 * torch.rand(batch_size, mem, seqlen, device="cuda", dtype=torch.bfloat16)
tokens = torch.rand(batch_size, dim, seqlen, device="cuda", dtype=torch.bfloat16)

tokens_cuda = tokens.clone().detach().requires_grad_()
gates_cuda = gates.clone().detach().requires_grad_()
tokens_pytorch = tokens.clone().detach().requires_grad_()
gates_pytorch = gates.clone().detach().requires_grad_()

# cuda val
memory_states = fm_scan(gates_cuda, tokens_cuda, initial_state=None)
loss_cuda = memory_states.sum()
loss_cuda.backward()

gates_cuda32 = gates_cuda.clone().to(torch.float32)
tokens_cuda32 = tokens_cuda.clone().to(torch.float32)

memory_states32 = fm_scan(gates_cuda, tokens_cuda, initial_state=None)
loss_cuda32 = memory_states32.sum()
loss_cuda32.backward()

# pytorch val
memory_states_pytorch = fm_scan_pytorch(gates_pytorch, tokens_pytorch, initial_state=None)
loss_pytorch = memory_states_pytorch.sum()
loss_pytorch.backward()

print(torch.mean(torch.abs(memory_states - memory_states_pytorch)))
print(torch.mean(torch.abs(memory_states32 - memory_states)))
print(torch.mean(torch.abs(tokens_cuda.grad - tokens_pytorch.grad)))
print(torch.mean(torch.abs(gates_cuda.grad - gates_pytorch.grad)))

tensor(0.0012, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)
tensor(0., device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)
tensor(0.9961, device='cuda:0', dtype=torch.bfloat16)
tensor(2.3125, device='cuda:0', dtype=torch.bfloat16)
