In [None]:
# Automatically re-import edited modules (such as the local plstm package) before each cell runs.
%load_ext autoreload
%autoreload 2

# pLSTM-1D

In [None]:
from plstm.nnx.plstm_1d import pLSTM1D_jax
from plstm.util import log2
import torch
import jax.numpy as jnp
import jax
from functools import partial

In [None]:
# Benchmark configuration and synthetic inputs for the pLSTM-1D comparison.
# B and T are the leading two dims of every per-step input built below.
# NOTE(review): BT is only consumed as `levels=log2(BT)` in later cells; presumably
# the chunk/block length of the parallel scan — TODO confirm.
DEVICE = "cuda"
SEED = 0
B = 64
T = 1024
BT = 256
DHQK = 128
DHHV = 128
JQ = 1
JT = 1
JV = 1
JK = 1
JO = 1
DTYPE = torch.float32

# Seed so the random part of V (and any rand_factor > 0 perturbation) is reproducible.
torch.manual_seed(SEED)

# rand_factor scales the random perturbation of the inputs; 0.0 makes all of them
# constant except V, which always carries a random term.
rand_factor = 0.0
Q = 1.0 + rand_factor * torch.randn([B, T, DHQK, JQ], dtype=DTYPE, device=DEVICE) / DHQK
K = 1.0 + rand_factor * torch.randn([B, T, DHQK, JK], dtype=DTYPE, device=DEVICE)
# Cast the integer ramp offset to DTYPE so V keeps DTYPE even when it is not float32.
V = torch.randn([B, T, DHHV, JV], dtype=DTYPE, device=DEVICE) + (
    0.1 * torch.arange(JV * DHHV, device=DEVICE).reshape(1, 1, DHHV, JV)
).to(DTYPE)
S0 = 0.1 + rand_factor * torch.randn([B, T, JT, JK, JV], dtype=DTYPE, device=DEVICE)
# Identity transition per step; eye is created with DTYPE/DEVICE so the product
# does not get promoted to torch's default dtype.
T0 = torch.eye(JT, dtype=DTYPE, device=DEVICE)[None, None, :, :] * torch.ones(
    [B, T, 1, 1], dtype=DTYPE, device=DEVICE
)
T0 = T0 + 0.01 * rand_factor * torch.randn_like(T0)
M0 = 1.0 + rand_factor * torch.randn([B, T, JO, JQ, JT], dtype=DTYPE, device=DEVICE)
D0 = 1.0 + rand_factor * torch.randn([B, T, JO, JQ, JK, JV], dtype=DTYPE, device=DEVICE)
# NOTE(review): C_initial is not referenced by the visible cells.
C_initial = 0.0 * torch.randn([B, DHQK, DHHV, JT], dtype=DTYPE, device=DEVICE)

print(Q.shape)

# D0 is left without requires_grad, as in the original setup — NOTE(review): intentional?
Q.requires_grad_(True)
K.requires_grad_(True)
V.requires_grad_(True)
S0.requires_grad_(True)
T0.requires_grad_(True)
M0.requires_grad_(True)

# NOTE(review): S0mag / T0mag are not referenced by the visible cells; they are CPU tensors.
S0mag = (
    0.01
    # + 0.1* rand_factor * torch.randn([B, T])
    - 0.1 * torch.arange(T)[None, :]
)
T0mag = (
    -0.01 + 0.0 * 0.1 * rand_factor * 0.01 * torch.randn([B, T])
    # + 0.01 * torch.arange(T)[None, :]
)

In [None]:
%timeit pLSTM1D_fwbw(Q, K, V, S0, T0, M0, D0, levels=log2(BT))

In [None]:
def _torch_to_jax(tensor):
    """Detach a torch tensor, copy it to host memory, and return it as a JAX array on the default device."""
    host_array = tensor.detach().cpu().numpy()
    return jax.device_put(jnp.array(host_array))


# Mirror every pLSTM-1D input into JAX so both implementations see identical values.
Q_jnp, K_jnp, V_jnp, S0_jnp, T0_jnp, M0_jnp, D0_jnp = (
    _torch_to_jax(source) for source in (Q, K, V, S0, T0, M0, D0)
)

In [None]:
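# Optional JIT variant (disabled). If enabled, bind it to a new name instead of shadowing
# `pLSTM1D_jax`, and call it without `levels=`, which the partial already fixes.
# pLSTM1D_jax = jax.jit(partial(pLSTM1D_jax, levels=log2(BT)))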

In [None]:
# Run the JAX implementation once and display its result before timing it below.
jax_inputs_1d = (Q_jnp, K_jnp, V_jnp, S0_jnp, T0_jnp, M0_jnp, D0_jnp)
pLSTM1D_jax(*jax_inputs_1d, levels=log2(BT))

In [None]:
%timeit pLSTM1D_jax(Q_jnp, K_jnp, V_jnp, S0_jnp, T0_jnp, M0_jnp, D0_jnp, levels=log2(BT))

# pLSTM-2D

In [None]:
from plstm.nnx.plstm_2d import pLSTM2D_jax
from plstm.nnx.plstm_1d import pLSTM1D_jax
from plstm.util import log2
import torch
import jax.numpy as jnp
import jax

In [None]:
# Grid and head sizes. NOTE(review): DB/MX/MY presumably are batch and the two grid
# axes (the PyTorch section below calls them B, X, Y) — TODO confirm.
DB, MX, MY = 32, 32, 32
DHQK, DHHV = 128, 128
JQ, JK, JV, JT, JO = 1, 1, 1, 1, 1


def _grid_ones(scale, *trailing_dims):
    """Return `scale * ones` with shape (DB, MX, MY, *trailing_dims)."""
    return scale * jnp.ones([DB, MX, MY, *trailing_dims])


# Constant-valued query / key / value inputs.
Q = _grid_ones(0.1, DHQK, JQ)
K = _grid_ones(0.2, DHQK, JK)
V = _grid_ones(0.3, DHHV, JV)

# Source terms (right / down).
S0r = _grid_ones(0.1, JT, JK, JV)
S0d = _grid_ones(0.2, JT, JK, JV)

# Transition terms.
T0rl = _grid_ones(0.2, JT, JT)
T0du = _grid_ones(0.1, JT, JT)
T0dl = _grid_ones(0.2, JT, JT)
T0ru = _grid_ones(0.3, JT, JT)

# Mark terms (left / up) and the direct term.
M0l = _grid_ones(0.3, JO, JQ, JT)
M0u = _grid_ones(0.2, JO, JQ, JT)
D0 = jnp.ones([DB, MX, MY, JO, JQ, JK, JV])


levels = 5

# Call with all four transition inputs (T0rl, T0du, T0dl, T0ru) supplied.
res = pLSTM2D_jax(
    Q, K, V,
    S0r, S0d,
    T0rl, T0du, T0dl, T0ru,
    M0l, M0u,
    D0,
    levels=levels,
)
print(res)

# Same call with T0dl replaced by None; this result is displayed as the cell output.
pLSTM2D_jax(
    Q, K, V,
    S0r, S0d,
    T0rl, T0du, None, T0ru,
    M0l, M0u,
    D0,
    levels=levels,
)

In [None]:
%timeit pLSTM2D_jax(Q, K, V, S0r, S0d, T0rl, T0du, None, T0dl, M0l, M0u, D0, levels=5)

In [None]:
from functools import partial

# Import the eager implementation under a private alias so re-running this cell always
# wraps the original function, never an already-jitted wrapper of itself.
from plstm.nnx.plstm_2d import pLSTM2D_jax as _pLSTM2D_jax_eager

# `levels` defaults to 5 via partial. Declaring it static means callers that still pass
# `levels=...` explicitly (e.g. the eager cells above, if re-run after this one) hand jit
# a hashable Python int instead of a traced value.
pLSTM2D_jax = jax.jit(partial(_pLSTM2D_jax_eager, levels=5), static_argnames=("levels",))

In [None]:
# Display the shape of the direct-term input.
jnp.shape(D0)

In [None]:
# The first call of the jitted function triggers compilation, so the %timeit below measures
# steady-state execution. T0ru belongs in positional slot 9 (matching the first
# pLSTM2D_jax call); T0dl is the input being omitted (None).
pLSTM2D_jax(Q, K, V, S0r, S0d, T0rl, T0du, None, T0ru, M0l, M0u, D0)

In [None]:
%timeit pLSTM2D_jax(Q, K, V, S0r, S0d, T0rl, T0du, None, T0dl, M0l, M0u, D0,)

In [None]:
# Display the shape of the direct-term input again after the jitted runs.
jnp.shape(D0)

### Torch

In [None]:
import torch


# Problem size for the PyTorch pLSTM-2D benchmark. This rebinds several names used by
# earlier sections (B, DHQK, DHHV, DTYPE, DEVICE, levels, Q, K, V, D0, ...) to torch values.
B, X, Y = 32, 32, 32
DHQK, DHHV = 128, 128
JQ, JT, JV, JK, JO = 1, 1, 1, 1, 1
DTYPE = torch.float32
DEVICE = "cuda"

levels = 5

# phi0 weights the *_r / *_rl / *_ru inputs; phi1 = 1 - phi0 weights the *_d / *_du / *_dl inputs.
phi0 = 0.5
phi1 = 1 - phi0


def _zero_noise(*trailing_dims):
    """Return `0.0 * randn` with shape (B, X, Y, *trailing_dims): a zero-scaled noise template."""
    return 0.0 * torch.randn([B, X, Y, *trailing_dims], dtype=DTYPE, device=DEVICE)


Q = 1.0 + _zero_noise(DHQK, JQ) / DHQK
K = 1.0 + _zero_noise(DHQK, JK)
V = 1.0 + _zero_noise(DHHV, JV)
S0_r = phi0 * (1.0 + _zero_noise(JT, JK, JV))
S0_d = phi1 * (1.0 + _zero_noise(JT, JK, JV))

# Identity transition matrix broadcast to every (B, X, Y) position.
T00 = torch.eye(JT, dtype=DTYPE, device=DEVICE).reshape(1, 1, 1, JT, JT) * torch.ones(
    [B, X, Y, 1, 1], dtype=DTYPE, device=DEVICE
)
T0_rl = phi0 * T00 + 0.0 * torch.randn_like(T00, dtype=DTYPE, device=DEVICE)
T0_du = phi1 * T00 + 0.0 * torch.randn_like(T00, dtype=DTYPE, device=DEVICE)
T0_dl = phi1 * T00 + 0.0 * torch.randn_like(T00, dtype=DTYPE, device=DEVICE)
T0_ru = phi0 * T00 + 0.0 * torch.randn_like(T00, dtype=DTYPE, device=DEVICE)

M0_l = 1.0 + _zero_noise(JO, JQ, JT)
M0_u = 1.0 + _zero_noise(JO, JQ, JT)
D0 = 1 + _zero_noise(JO, JQ, JK, JV)

In [None]:
%timeit pLSTM2D_fwbw(Q, K, V, S0_r, S0_d, T0_rl, T0_du, T0_dl, T0_ru, M0_l, M0_u, D0, None, levels=levels)