In [1]:
import torch
from models.sashimi.sashimi_standalone import Sashimi
from tqdm import tqdm

CUDA extension for structured kernels (Cauchy and Vandermonde multiplication) not found. Install by going to extensions/kernels/ and running `python setup.py install`, for improved speed and memory efficiency. Note that the kernel changed for state-spaces 4.0 and must be recompiled.
Falling back on slow Cauchy and Vandermonde kernel. Install at least one of pykeops or the CUDA extension for better speed and memory efficiency.


In [3]:
# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

# Hyperparameters forwarded by name to the Sashimi constructor.
d_model, n_layers = 64, 8
expand, ff = 2, 2
dropout = 0.0
pool = [4, 4]

# Sizes used to build the random input batch.
batch_size, seq_len = 32, 2048
# NOTE(review): data_dim is not referenced by any cell visible here - confirm it is still needed.
data_dim = 1

Using device: cpu


In [4]:
# Instantiate the network from the hyperparameters in the configuration cell;
# the two boolean switches are fixed to False for this run.
model = Sashimi(
    d_model=d_model, n_layers=n_layers,
    pool=pool, expand=expand, ff=ff,
    dropout=dropout,
    bidirectional=False,
    unet=False,
)
# Move the freshly built model onto the selected device.
model = model.to(device)

In [5]:
# Total element count over every parameter tensor of the model, followed by
# the count restricted to parameters flagged with requires_grad.
pytorch_total_params = sum(map(torch.Tensor.numel, model.parameters()))
trainable_params = sum(
    weight.numel() for weight in model.parameters() if weight.requires_grad
)

print(f'Total number of parameters: {pytorch_total_params}')
print(f'Trainable parameters: {trainable_params}')

Total number of parameters: 5414656
Trainable parameters: 5414656


In [6]:
# Random dummy batch with dimensions (batch_size, seq_len, d_model), allocated
# directly on the target device instead of on the CPU followed by a copy.
x = torch.randn(batch_size, seq_len, d_model, device=device)
print(f'Input shape: {x.shape}')

Input shape: torch.Size([32, 2048, 64])


In [9]:
# Forward-pass smoke test. Gradients are not needed just to inspect the
# output, so autograd tracking is switched off to avoid retaining the
# intermediate activations of the whole network.
with torch.no_grad():
    state = model.default_state()
    out, state = model(x, state)

In [10]:
# Report the shape of the tensor returned by the forward pass.
out_shape = out.shape
print(f'Output shape: {out_shape}')

Output shape: torch.Size([32, 2048, 64])
