# Imports

In [1]:
%load_ext autoreload
%autoreload 2
import sys 

# sys.path.append('..')
from omegaconf import OmegaConf
from pprint import pprint
from dacite import from_dict
from dacite import Config as DaciteConfig
import torch

from xlstm.xlstm_block_stack import xLSTMBlockStack, xLSTMBlockStackConfig

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Init config

In [88]:
# xlstm_cfg = f""" 
# mlstm_block:
#   mlstm:
#     conv1d_kernel_size: 4
#     qkv_proj_blocksize: 1
#     num_heads: 1
#     proj_factor: 1
# slstm_block:
#   slstm:
#     backend: {'cuda' if torch.cuda.is_available() else 'vanilla'} #! only vanilla here works
#     num_heads: 1
#     conv1d_kernel_size: 0
#     bias_init: powerlaw_blockdependent
#   feedforward:
#     proj_factor: 1.2
#     act_fn: gelu
# context_length: 50
# num_blocks: 2
# embedding_dim: 1 # same as `in_features` in Pytorch LSTM
# slstm_at: [1] #[1] # for [] it also works, so if no sLSTM is in the stack
# """

xlstm_cfg = f""" 
mlstm_block:
  mlstm:
    conv1d_kernel_size: 4
    qkv_proj_blocksize: 4
    num_heads: 4
slstm_block:
  slstm:
    backend: {'cuda' if torch.cuda.is_available() else 'vanilla'} #! only vanilla here works
    num_heads: 4
    conv1d_kernel_size: 4
    bias_init: powerlaw_blockdependent
  feedforward:
    proj_factor: 1.3
    act_fn: gelu
context_length: 256
num_blocks: 7
embedding_dim: 128
add_post_blocks_norm: False
slstm_at: [1] #[1] # for [] it also works, so if no sLSTM is in the stack
"""

# Init XLSTM model

In [89]:
cfg = OmegaConf.create(xlstm_cfg)
cfg = from_dict(data_class=xLSTMBlockStackConfig, data=OmegaConf.to_container(cfg), config=DaciteConfig(strict=True))
xlstm_stack = xLSTMBlockStack(cfg)

# Inspect config

In [90]:
pprint(cfg)

xLSTMBlockStackConfig(mlstm_block=mLSTMBlockConfig(mlstm=mLSTMLayerConfig(proj_factor=2.0,
                                                                          round_proj_up_dim_up=True,
                                                                          round_proj_up_to_multiple_of=64,
                                                                          _proj_up_dim=256,
                                                                          conv1d_kernel_size=4,
                                                                          qkv_proj_blocksize=4,
                                                                          num_heads=4,
                                                                          embedding_dim=128,
                                                                          bias=False,
                                                                          dropout=0.0,
                                                                

# Inspect layers

In [91]:
xlstm_stack

xLSTMBlockStack(
  (blocks): ModuleList(
    (0): mLSTMBlock(
      (xlstm_norm): LayerNorm()
      (xlstm): mLSTMLayer(
        (proj_up): Linear(in_features=128, out_features=512, bias=False)
        (q_proj): LinearHeadwiseExpand(in_features=256, num_heads=64, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
        (k_proj): LinearHeadwiseExpand(in_features=256, num_heads=64, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
        (v_proj): LinearHeadwiseExpand(in_features=256, num_heads=64, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
        (conv1d): CausalConv1d(
          (conv): Conv1d(256, 256, kernel_size=(4,), stride=(1,), padding=(3,), groups=256)
        )
        (conv_act_fn): SiLU()
        (mlstm_cell): mLSTMCell(
          (igate): Linear(in_features=768, out_features=4, bias=True)
          (fgate): Linear(in_features=768, out_features=4, bias=True)
          (outnorm): Mult

# Generate synthetic examlpe

In [122]:
x = torch.randn(2, 256, 128).to(device=device)

In [123]:
x.shape

torch.Size([2, 256, 128])

In [97]:
x[:, 1, :].view(2,1,128).shape

torch.Size([2, 1, 128])

# Check model's output

In [98]:
xlstm_stack = xlstm_stack.to(device=device)

In [84]:
for i in range(2):
    y, states_dict = xlstm_stack.step(x[:, i, :].view(2,1,128))

In [87]:
states_dict['block_6']

{'mlstm_state': (tensor([[[[ 5.4874e-06, -1.2764e-06, -4.2204e-06,  ..., -1.7712e-05,
             -2.4697e-06, -7.9438e-06],
            [ 4.2978e-05, -9.9968e-06, -3.3054e-05,  ..., -1.3872e-04,
             -1.9342e-05, -6.2216e-05],
            [-2.4128e-05,  5.6123e-06,  1.8557e-05,  ...,  7.7881e-05,
              1.0859e-05,  3.4929e-05],
            ...,
            [-1.0465e-05,  2.4343e-06,  8.0490e-06,  ...,  3.3780e-05,
              4.7100e-06,  1.5150e-05],
            [-3.4788e-05,  8.0919e-06,  2.6756e-05,  ...,  1.1229e-04,
              1.5657e-05,  5.0361e-05],
            [-1.0198e-05,  2.3721e-06,  7.8432e-06,  ...,  3.2916e-05,
              4.5896e-06,  1.4763e-05]],
  
           [[ 1.1881e-04, -4.0994e-04,  2.5100e-04,  ...,  1.2454e-06,
             -2.5543e-04,  3.8301e-04],
            [ 1.7321e-04, -5.9762e-04,  3.6592e-04,  ...,  1.8156e-06,
             -3.7237e-04,  5.5836e-04],
            [ 7.2518e-05, -2.5021e-04,  1.5320e-04,  ...,  7.6015e-07,
     

In [81]:
y.shape

torch.Size([2, 1, 128])

In [82]:
y

tensor([[[-0.4978,  1.5445,  1.2798, -1.4677, -2.1897,  0.7250,  0.6384,
          -0.3543, -0.6553, -0.5294,  0.9895,  0.8925,  0.0369,  0.3600,
           1.0671, -0.8967,  0.7769,  1.7208, -0.9602, -2.3994, -0.3734,
          -0.0866,  1.4061,  0.0087, -0.7891, -0.0173, -1.9538, -0.3937,
          -0.5488, -0.5184,  1.9376, -0.7643, -0.7506,  1.1501, -0.1561,
           0.4513, -1.2406, -0.9217,  0.3434,  0.6700, -0.2392,  0.1344,
          -0.0117, -0.9852, -1.2051,  0.3428, -0.1593, -1.8795, -0.0729,
           1.1581,  0.8664,  0.4295, -0.0295,  0.2280,  1.7564,  0.3873,
          -0.2153,  1.3113,  0.0999,  0.7173, -1.0597, -2.2170,  1.2809,
          -1.7322,  0.5820,  1.3519,  0.9275, -0.4028,  0.1005,  0.5075,
          -0.2481,  0.8091,  0.7399, -0.1949,  0.4818, -0.6697,  0.0183,
           0.0428, -0.2187, -1.7713,  1.3467, -1.7573,  0.5467,  1.1787,
           1.0028, -0.7336,  1.9220,  0.9596, -0.5030, -1.3327,  1.2682,
           0.0318,  0.2944, -0.4946, -1.2384, -0.16

In [83]:
x[:, 1, :].view(2,1,128)

tensor([[[-8.8171e-01,  7.5984e-01,  1.3308e+00, -8.0911e-01, -7.1669e-01,
           1.2324e-01,  6.4953e-01,  1.9917e-01, -2.8007e-02,  2.0094e-01,
           8.0642e-01,  1.5618e+00, -3.5212e-01,  3.7325e-02,  9.7363e-01,
          -6.5723e-01, -6.3216e-01, -2.7189e-01, -6.8714e-01, -2.0393e+00,
          -3.4308e-02,  9.7309e-01, -2.3616e-01, -3.9586e-01, -6.4092e-02,
           9.2531e-01, -2.7790e+00, -6.8198e-01, -1.6832e-02, -2.0715e-01,
           2.5983e+00, -7.3946e-01, -1.0449e+00,  7.0546e-01,  1.0123e+00,
           7.0937e-01, -5.6492e-01, -1.9830e+00, -2.7441e-01,  8.7571e-01,
           1.7398e-02,  9.0983e-01,  1.3253e-01,  4.8852e-02, -1.4840e+00,
          -2.9523e-02,  1.3925e+00,  1.0114e-01,  3.2748e-01,  1.0978e+00,
          -8.5351e-01, -3.0628e-01, -4.6161e-01, -3.8436e-01,  1.0335e+00,
           6.7830e-01, -3.3133e-01,  6.6896e-01, -4.6368e-01,  7.4131e-01,
          -1.6848e+00, -1.2495e+00,  1.8919e+00, -6.0618e-01,  2.2224e+00,
           9.6635e-01,  4

In [124]:
y = xlstm_stack(x)

In [125]:
x[:, -1, :]

tensor([[-0.8839,  0.6531, -0.8390, -0.1645,  1.6704,  0.5127, -0.0529,  0.2271,
         -2.5904,  1.3150,  0.3769, -0.1745,  0.2888,  0.9407, -0.1206,  0.6557,
          0.1330,  0.1118,  0.4273,  0.5373, -0.3927, -0.0368, -0.3819,  1.1103,
          0.2225,  0.5210, -0.3059, -0.2255, -0.9224, -0.4568, -2.2560,  1.0137,
         -0.1084, -0.6037,  0.5965, -0.5597, -0.5377, -0.5736,  0.8941, -2.7295,
          1.3067,  0.7518,  0.9739,  0.5398, -0.4596,  0.3005, -0.3543, -0.9072,
          0.3558,  0.1623,  0.5433, -0.9490,  0.1389, -1.0022, -1.1491, -0.3044,
         -0.1395, -0.2475, -0.1755,  0.2795,  1.3190,  0.0260, -0.0259, -0.6402,
         -0.3443, -1.1712,  0.2273, -0.3507,  1.4633, -0.2927,  1.0462,  1.0782,
          0.0841, -0.4016,  0.4079,  0.1693,  0.1225,  0.6326,  0.0118, -1.4028,
         -1.1295,  0.6233,  0.1547,  0.0442, -0.7434, -0.5714, -0.8025,  0.6202,
          0.8696,  1.3488,  0.3080,  0.0314, -0.9079, -0.9560, -0.4306,  0.1162,
         -0.6938,  0.3198, -

In [126]:
y[:, -1, :]

tensor([[-8.8777e-01,  7.0733e-01, -1.1354e+00,  6.8297e-02,  1.2867e+00,
          8.6632e-01,  8.6892e-01,  9.2484e-01, -1.2396e+00,  1.7039e+00,
         -1.0790e+00, -1.0183e+00, -6.6037e-01,  3.6460e-01,  9.4218e-01,
         -9.8400e-02, -1.5814e+00,  4.7198e-01,  1.7564e+00,  1.6153e+00,
          2.3493e-01,  2.9874e-01, -6.3176e-01,  7.8221e-01,  2.1035e+00,
          3.4908e-01, -2.4726e+00,  9.4857e-01, -6.0690e-01,  6.2530e-01,
         -4.8214e+00,  1.9170e+00, -3.0467e-01, -5.0227e-01,  8.7043e-01,
         -4.5736e-01,  4.8147e-03, -5.5708e-01,  2.0257e-01, -2.0661e+00,
          9.1403e-01,  1.8090e+00,  2.6168e-01,  1.3294e+00, -2.0494e+00,
         -1.7052e+00, -5.7928e-01, -7.9424e-01,  9.7212e-01,  2.6963e-01,
          2.0599e+00,  3.2527e-02,  1.6426e+00, -1.7454e+00,  1.4678e+00,
         -6.0697e-01,  5.2550e-01,  4.7229e-01, -5.6340e-01, -1.4185e+00,
          1.2872e+00, -1.8675e+00, -5.2358e-01, -6.4815e-01, -3.6819e-01,
         -1.1029e+00,  3.5903e-01, -6.

In [100]:
y.shape

torch.Size([2, 256, 128])

In [54]:
pprint(y)

tensor([[[ 0.8577, -1.1099, -0.3923,  ...,  0.0940,  0.0937,  0.7333],
         [ 0.8142,  0.0359,  0.3861,  ..., -0.1010, -0.7600, -0.0066],
         [ 2.4055, -1.4971,  0.7610,  ..., -0.5647,  1.4988, -1.2182],
         ...,
         [-1.9958,  0.2881,  0.2207,  ...,  0.1564, -1.6645, -0.3780],
         [-0.4549,  0.3158, -0.3409,  ...,  0.4510, -0.3832,  1.3363],
         [-0.8845, -0.6907,  0.5125,  ...,  1.1808, -1.4142,  1.1701]],

        [[ 0.8740, -2.0196,  0.3545,  ..., -0.7747, -0.1946, -1.2063],
         [-0.3995,  0.4758,  0.2324,  ...,  1.6014, -0.7852,  0.3335],
         [ 0.2124, -0.1924, -0.0830,  ...,  0.5713,  1.6000,  1.5129],
         ...,
         [-0.1742,  1.3780, -0.5989,  ...,  0.0386, -0.5372,  1.5216],
         [ 0.4534,  0.9240, -0.4873,  ...,  0.6691, -0.8149,  1.2498],
         [-0.2121, -0.4231, -1.2146,  ...,  0.4834,  1.7986, -0.7707]]],
       device='cuda:0', grad_fn=<NativeLayerNormBackward0>)


In [103]:
pprint(y[:, -1, :].view(1, -1, 128))

tensor([[[ 1.1480, -1.0595,  2.0452, -0.0863, -0.1370,  2.0175, -0.4361,
          -0.1611, -1.5872,  1.1023, -2.6816,  0.6252,  1.0867, -1.2995,
          -0.2884,  1.1804,  1.3891, -1.3155,  0.1274, -3.0621, -4.0834,
           1.0835,  0.5426,  0.4090,  1.8845,  0.7048, -2.5606, -0.6336,
          -0.5072, -0.1934, -0.1107,  4.0092, -0.6715, -0.6509, -1.7275,
           0.2459,  0.1660,  1.4571, -1.2948, -2.2396, -1.9346, -0.5495,
           1.1377, -0.1067,  0.4261, -0.1087, -0.5682,  5.4285,  0.8706,
           1.8102, -1.4242,  0.7603, -0.2171, -1.3306, -0.8624,  1.6319,
          -0.9979, -0.4409,  0.2527, -0.4009, -2.9918, -2.7180,  0.6599,
          -0.0263, -0.8976,  1.0564,  0.7353, -0.1279, -0.2253, -0.8580,
           1.5953, -1.0426,  2.8536, -2.9724, -2.6748,  0.6674,  0.2355,
           0.4488,  3.3318, -1.6286,  0.3944, -1.5674,  0.0701,  0.0965,
           2.1487, -0.0439, -0.0311, -1.9621,  2.6495,  1.4265,  1.3333,
           0.7164, -3.4484, -1.8838, -0.6050,  2.54

In [92]:
t = ()
for i in range(5):
    t += (i,)

In [93]:
t[0]

0

In [128]:
from torch import nn

rnn = nn.LSTM(10, 20, 2, batch_first=True)
inpt = torch.randn(5, 3, 10)
h0 = torch.randn(2, 5, 20)
c0 = torch.randn(2, 5, 20)
output, (hn, cn) = rnn(inpt, (h0, c0))

In [129]:
output.shape

torch.Size([5, 3, 20])

In [130]:
hn.shape

torch.Size([2, 5, 20])

In [121]:
hn[0]

tensor([[-0.2291,  0.0706, -0.2873, -0.0193,  0.0982,  0.1942,  0.0228, -0.1276,
          0.0790,  0.0730, -0.1216,  0.0562,  0.1600,  0.1633, -0.0206, -0.0394,
         -0.0367, -0.0733,  0.1518, -0.1802],
        [-0.1929,  0.1946,  0.0548, -0.1062,  0.0587,  0.0891, -0.0085, -0.0888,
          0.0230,  0.2114,  0.1167,  0.1953,  0.0377, -0.0480, -0.0815,  0.1854,
          0.0207, -0.0434, -0.0672, -0.1421],
        [-0.2902,  0.2361, -0.2112, -0.0509, -0.0608,  0.1652, -0.0098, -0.1355,
          0.1445,  0.1078,  0.0399,  0.1191,  0.0123, -0.0428,  0.0851, -0.0684,
         -0.2345, -0.0469,  0.2690,  0.1379],
        [-0.2975,  0.0864,  0.1136, -0.1358,  0.0094,  0.1324,  0.0084, -0.1017,
         -0.1106,  0.1565, -0.1292,  0.1546, -0.0853, -0.1034,  0.0358,  0.1535,
         -0.0611, -0.0737, -0.1640, -0.1349],
        [-0.0851,  0.1141,  0.2472, -0.2493, -0.1072,  0.0791, -0.1482, -0.1104,
          0.1480,  0.1121,  0.0041,  0.1225,  0.1576,  0.0570, -0.0797, -0.0238,
      

In [119]:
output[:,-1,:]

tensor([[-0.2291,  0.0706, -0.2873, -0.0193,  0.0982,  0.1942,  0.0228, -0.1276,
          0.0790,  0.0730, -0.1216,  0.0562,  0.1600,  0.1633, -0.0206, -0.0394,
         -0.0367, -0.0733,  0.1518, -0.1802],
        [-0.1929,  0.1946,  0.0548, -0.1062,  0.0587,  0.0891, -0.0085, -0.0888,
          0.0230,  0.2114,  0.1167,  0.1953,  0.0377, -0.0480, -0.0815,  0.1854,
          0.0207, -0.0434, -0.0672, -0.1421],
        [-0.2902,  0.2361, -0.2112, -0.0509, -0.0608,  0.1652, -0.0098, -0.1355,
          0.1445,  0.1078,  0.0399,  0.1191,  0.0123, -0.0428,  0.0851, -0.0684,
         -0.2345, -0.0469,  0.2690,  0.1379],
        [-0.2975,  0.0864,  0.1136, -0.1358,  0.0094,  0.1324,  0.0084, -0.1017,
         -0.1106,  0.1565, -0.1292,  0.1546, -0.0853, -0.1034,  0.0358,  0.1535,
         -0.0611, -0.0737, -0.1640, -0.1349],
        [-0.0851,  0.1141,  0.2472, -0.2493, -0.1072,  0.0791, -0.1482, -0.1104,
          0.1480,  0.1121,  0.0041,  0.1225,  0.1576,  0.0570, -0.0797, -0.0238,
      