<center><h1>Основы глубокого обучение</h1></center>

In [3]:
!python -V # Версия Python

Python 3.12.8


In [4]:
# Подавление предупреждений
import warnings
for warn in [UserWarning, FutureWarning]: warnings.filterwarnings("ignore", category = warn)

# Импорт необходимых библиотек
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import polars as pl
import pandas as pd
import yfinance as yf
import sklearn
import networkx as nx
import jupyterlab as jlab
import ipywidgets
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.dates as mdates

from torch import Tensor
from einops import rearrange
from typing import Tuple, Callable
from torch.autograd import Function

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import load_iris
from sklearn.preprocessing import MinMaxScaler
from mpl_toolkits.mplot3d import Axes3D

In [5]:
# Версии необходимых библиотек
packages = [
    "Torch", "NumPy", "Polars", "Pandas", "Matplotlib", "Yfinance", "Scikit-learn", "Ipywidgets", "JupyterLab"
]

package_objects = [
    torch, np, pl, pd, mpl, yf, sklearn, ipywidgets, jlab
]

versions = list(map(lambda obj: obj.__version__, package_objects))

columns_order = ["№", "Библиотека", "Версия"]
df_pkgs = (
    pl.DataFrame({
        columns_order[1]: packages,
        columns_order[2]: versions
    })
    .with_columns(pl.arange(1, pl.lit(len(packages)) + 1).alias(columns_order[0]))
    .select(columns_order)
)

display(df_pkgs)

path2reqs = "."
reqs_name = "requirements.txt"

def get_packages_and_versions():
    """Генерация строк с библиотеками и их версиями в формате: библиотека==версия"""
    
    for package, version in zip(packages, versions):
        yield f"{package.lower()}=={version}\n"

with open(os.path.join(path2reqs, reqs_name), "w", encoding = "utf-8") as f:
    f.writelines(get_packages_and_versions())

№,Библиотека,Версия
i64,str,str
1,"""Torch""","""2.2.2"""
2,"""NumPy""","""1.26.4"""
3,"""Polars""","""1.19.0"""
4,"""Pandas""","""2.2.3"""
5,"""Matplotlib""","""3.10.0"""
6,"""Yfinance""","""0.2.51"""
7,"""Scikit-learn""","""1.6.1"""
8,"""Ipywidgets""","""8.1.5"""
9,"""JupyterLab""","""4.3.4"""


# Лекция 6

4. **Современные архитектуры глубокого обучения**
    - Обзор семейства xLSTM
    - Применение данных архитектур для анализа последовательных данных, изображений и мультимодальной интеграции

## Вспоминаем LSTM

<div style="text-align: center;">
    <img src="https://raw.githubusercontent.com/DmitryRyumin/HSE_Fundamentals_of_DL_2025/refs/heads/main/tutorials/imgs/6_1.png" alt="" width="800px">
</div>

### Основные моменты

<div style="text-align: center;">
    <img src="https://raw.githubusercontent.com/DmitryRyumin/HSE_Fundamentals_of_DL_2025/refs/heads/main/tutorials/imgs/6_22.png" alt="" width="500px">
</div>

### Шаг 1

<div style="text-align: center;">
    <img src="https://raw.githubusercontent.com/DmitryRyumin/HSE_Fundamentals_of_DL_2025/refs/heads/main/tutorials/imgs/6_3.png" alt="" width="500px">
</div>

$$
f_t=\sigma\left(W_f \times\left[h_{t-1}, x_t\right]+b_f\right)
$$

### Шаг 2

<div style="text-align: center;">
    <img src="https://raw.githubusercontent.com/DmitryRyumin/HSE_Fundamentals_of_DL_2025/refs/heads/main/tutorials/imgs/6_4.png" alt="" width="500px">
</div>

$$
\begin{aligned}
i_t & =\sigma\left(W_i \times\left[h_{t-1}, x_t\right]+b_i\right) \\
\tilde{C}_t & =\tanh \left(W_C \times\left[h_{t-1}, x_t\right]+b_C\right)
\end{aligned}
$$

### Шаг 3

<div style="text-align: center;">
    <img src="https://raw.githubusercontent.com/DmitryRyumin/HSE_Fundamentals_of_DL_2025/refs/heads/main/tutorials/imgs/6_5.png" alt="" width="500px">
</div>

$$
C_t=f_t * C_{t-1}+i_t * \tilde{C}_t
$$

### Шаг 4

<div style="text-align: center;">
    <img src="https://raw.githubusercontent.com/DmitryRyumin/HSE_Fundamentals_of_DL_2025/refs/heads/main/tutorials/imgs/6_6.png" alt="" width="500px">
</div>

$$
\begin{aligned}
& o_t=\sigma\left(W_o\left[h_{t-1}, x_t\right]+b_o\right) \\
& h_t=o_t * \tanh \left(C_t\right)
\end{aligned}
$$

### Проблемы LSTM

1. Ограниченная способность пересматривать решения о хранении информации
2. Ограниченная способность хранить информацию
3. Невозможно распараллеливать вычисления

## sLSTM

<div style="text-align: center;">
    <img src="https://raw.githubusercontent.com/DmitryRyumin/HSE_Fundamentals_of_DL_2025/refs/heads/main/tutorials/imgs/6_7.png" alt="" width="900px">
</div>

$$
n_t=\mathrm{f}_t n_{t-1}+\mathrm{i}_t
$$

### Способ вычисления выходов гейтов забывания и входного состояния

#### LSTM

$$
\begin{array}{llr}
\mathrm{i}_t=\sigma\left(\tilde{\mathrm{i}}_t\right), & \tilde{\mathrm{i}}_t=\boldsymbol{w}_{\mathrm{i}}^{\top} \boldsymbol{x}_t+r_{\mathrm{i}} h_{t-1}+b_{\mathrm{i}} & \text { input gate } \\
\mathrm{f}_t=\sigma\left(\tilde{\mathrm{f}}_t\right), & \tilde{\mathrm{f}}_t=\boldsymbol{w}_{\mathrm{f}}^{\top} \boldsymbol{x}_t+r_{\mathrm{f}} h_{t-1}+b_{\mathrm{f}} & \text { forget gate }
\end{array}
$$

#### sLSTM

$$
\begin{array}{rlr}
m_t & =\max \left(\log \left(\mathrm{f}_t\right)+m_{t-1}, \log \left(\mathrm{i}_t\right)\right) & \text { stabilizer state } \\
\mathrm{i}_t^{\prime} & =\exp \left(\log \left(\mathrm{i}_t\right)-m_t\right)=\exp \left(\tilde{\mathrm{i}}_t-m_t\right) & \text { stabil. input gate } \\
\mathrm{f}_t^{\prime} & =\exp \left(\log \left(\mathrm{f}_t\right)+m_{t-1}-m_t\right) & \text { stabil. forget gate }
\end{array}
$$

#### mLSTM

##### Правило обновления ковариаций

$$
\boldsymbol{C}_t=\boldsymbol{C}_{t-1}+\boldsymbol{v}_t \boldsymbol{k}_t^{\top}
$$

##### Гейты вычисляются также с помощью экспоненты, но без использования предыдущих скрытых состояний

###### LSTM

$$
\begin{array}{llr}
\mathrm{i}_t=\sigma\left(\tilde{\mathrm{i}}_t\right), & \tilde{\mathrm{i}}_t=\boldsymbol{w}_{\mathrm{i}}^{\top} \boldsymbol{x}_t+\underline{r_{\mathrm{i}} h_{t-1}}+b_{\mathrm{i}} & \text { input gate } \\
\mathrm{f}_t=\sigma\left(\tilde{\mathrm{f}}_t\right), & \tilde{\mathrm{f}}_t=\boldsymbol{w}_{\mathrm{f}}^{\top} \boldsymbol{x}_t+\underline{r_{\mathrm{f}} h_{t-1}}+b_{\mathrm{f}} & \text { forget gate } \\
\mathrm{o}_t=\sigma\left(\tilde{\mathrm{o}}_t\right), & \tilde{\mathrm{o}}_t=\boldsymbol{w}_{\mathrm{o}}^{\top} \boldsymbol{x}_t+\underline{r_{\mathrm{o}} h_{t-1}}+b_{\mathrm{o}} & \text { output gate }
\end{array}
$$

###### mLSTM

$$
\begin{array}{lrl}
\mathrm{i}_t=\exp \left(\tilde{\mathrm{i}}_t\right), & \tilde{\mathrm{i}}_t=\boldsymbol{w}_{\mathrm{i}}^{\top} \boldsymbol{x}_t+b_{\mathrm{i}} & \text { input gate } \\
\mathrm{f}_t=\sigma\left(\tilde{\mathrm{f}}_t\right) \text { OR } \exp \left(\tilde{\mathrm{f}}_t\right), & \tilde{\mathrm{f}}_t=\boldsymbol{w}_{\mathrm{f}}^{\top} \boldsymbol{x}_t+b_{\mathrm{f}} & \text { forget gate } \\
\mathbf{o}_t=\sigma\left(\tilde{\mathbf{o}}_t\right), & \tilde{\mathbf{o}}_t=\boldsymbol{W}_{\mathbf{o}} \boldsymbol{x}_t+\boldsymbol{b}_{\mathbf{o}} & \text { output gate }
\end{array}
$$

##### Нормализация скрытого состояния

$$
\boldsymbol{h}_t=\mathbf{o}_t \odot \tilde{\boldsymbol{h}}_t, \quad \quad \tilde{\boldsymbol{h}}_t=\boldsymbol{C}_t \boldsymbol{q}_t / \max \left\{\left|\boldsymbol{n}_t^{\top} \boldsymbol{q}_t\right|, 1\right\} \quad \text { hidden state }
$$

$$
\begin{aligned}
\boldsymbol{q}_t & =\boldsymbol{W}_q \boldsymbol{x}_t+\boldsymbol{b}_q \\
\boldsymbol{k}_t & =\frac{1}{\sqrt{d}} \boldsymbol{W}_k \boldsymbol{x}_t+\boldsymbol{b}_k \\
\boldsymbol{v}_t & =\boldsymbol{W}_v \boldsymbol{x}_t+\boldsymbol{b}_v
\end{aligned}
$$

## Vision-LSTM

<div style="text-align: center;">
    <img src="https://raw.githubusercontent.com/DmitryRyumin/HSE_Fundamentals_of_DL_2025/refs/heads/main/tutorials/imgs/6_8.png" alt="" width="900px">
</div>

# Семинар 6

In [6]:
class CausalConv1D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, **kwargs):
        super(CausalConv1D, self).__init__()
        self.padding = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=self.padding, dilation=dilation, **kwargs)

    def forward(self, x):
        x = self.conv(x)
        return x[:, :, :-self.padding]

class BlockDiagonal(nn.Module):
    def __init__(self, in_features, out_features, num_blocks):
        super(BlockDiagonal, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_blocks = num_blocks

        assert out_features % num_blocks == 0
        
        block_out_features = out_features // num_blocks
        
        self.blocks = nn.ModuleList([
            nn.Linear(in_features, block_out_features)
            for _ in range(num_blocks)
        ])
        
    def forward(self, x):
        x = [block(x) for block in self.blocks]
        x = torch.cat(x, dim=-1)
        return x

class sLSTMBlock(nn.Module):
    def __init__(self, input_size, head_size, num_heads, proj_factor=4/3):
        super(sLSTMBlock, self).__init__()
        self.input_size = input_size
        self.head_size = head_size
        self.hidden_size = head_size * num_heads
        self.num_heads = num_heads
        self.proj_factor = proj_factor

        assert proj_factor > 0

        self.layer_norm = nn.LayerNorm(input_size)
        self.causal_conv = CausalConv1D(1, 1, 4)

        self.Wz = BlockDiagonal(input_size, self.hidden_size, num_heads)
        self.Wi = BlockDiagonal(input_size, self.hidden_size, num_heads)
        self.Wf = BlockDiagonal(input_size, self.hidden_size, num_heads)
        self.Wo = BlockDiagonal(input_size, self.hidden_size, num_heads)

        self.Rz = BlockDiagonal(self.hidden_size, self.hidden_size, num_heads)
        self.Ri = BlockDiagonal(self.hidden_size, self.hidden_size, num_heads)
        self.Rf = BlockDiagonal(self.hidden_size, self.hidden_size, num_heads)
        self.Ro = BlockDiagonal(self.hidden_size, self.hidden_size, num_heads)

        self.group_norm = nn.GroupNorm(num_heads, self.hidden_size)

        self.up_proj_left = nn.Linear(self.hidden_size, int(self.hidden_size * proj_factor))
        self.up_proj_right = nn.Linear(self.hidden_size, int(self.hidden_size * proj_factor))
        self.down_proj = nn.Linear(int(self.hidden_size * proj_factor), input_size)

    def forward(self, x, prev_state):
        assert x.size(-1) == self.input_size
        h_prev, c_prev, n_prev, m_prev = prev_state

        h_prev = h_prev.to(x.device)
        c_prev = c_prev.to(x.device)
        n_prev = n_prev.to(x.device)
        m_prev = m_prev.to(x.device)
        
        x_norm = self.layer_norm(x)
        x_conv = F.silu(self.causal_conv(x_norm.unsqueeze(1)).squeeze(1))

        z = torch.tanh(self.Wz(x_norm) + self.Rz(h_prev))
        o = torch.sigmoid(self.Wo(x_norm) + self.Ro(h_prev))
        i_tilde = self.Wi(x_conv) + self.Ri(h_prev)
        f_tilde = self.Wf(x_conv) + self.Rf(h_prev)

        # TODO !!!!
        m_t = torch.max(f_tilde + m_prev, i_tilde)
        i = torch.exp(i_tilde - m_t) # Всегда <= 1.0
        f = torch.exp(f_tilde + m_prev - m_t) # Всегда <= 1.0

        c_t = f * c_prev + i * z
        n_t = f * n_prev + i
        h_t = o * c_t / n_t

        output = h_t
        output_norm = self.group_norm(output)
        output_left = self.up_proj_left(output_norm)
        output_right = self.up_proj_right(output_norm)
        output_gated = F.gelu(output_right)
        output = output_left * output_gated
        output = self.down_proj(output)
        final_output = output + x

        return final_output, (h_t, c_t, n_t, m_t)
    
class sLSTM(nn.Module):
    # Add bias, dropout, bidirectional
    def __init__(self, input_size, head_size, num_heads, num_layers=1, batch_first=False, proj_factor=4/3):
        super(sLSTM, self).__init__()
        self.input_size = input_size
        self.head_size = head_size
        self.hidden_size = head_size * num_heads
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.proj_factor_slstm = proj_factor

        self.layers = nn.ModuleList([sLSTMBlock(input_size, head_size, num_heads, proj_factor) for _ in range(num_layers)])

    def forward(self, x, state=None):
        assert x.ndim == 3
        if self.batch_first: x = x.transpose(0, 1)
        seq_len, batch_size, _ = x.size()
        
        if state is not None:
            state = torch.stack(list(state)).to(x.device)
            assert state.ndim == 4
            num_hidden, state_num_layers, state_batch_size, state_input_size = state.size()
            assert num_hidden == 4
            assert state_num_layers == self.num_layers
            assert state_batch_size == batch_size
            assert state_input_size == self.input_size
            state = state.transpose(0, 1)
        else:
            state = torch.zeros(self.num_layers, 4, batch_size, self.hidden_size, device=x.device)

        output = []
        for t in range(seq_len):
            x_t = x[t]
            for layer in range(self.num_layers):
                x_t, state_tuple = self.layers[layer](x_t, tuple(state[layer].clone()))
                state[layer] = torch.stack(list(state_tuple))
            output.append(x_t)
        
        output = torch.stack(output)
        if self.batch_first:
            output = output.transpose(0, 1)
        state = tuple(state.transpose(0, 1))
        return output, state

class mLSTMBlock(nn.Module):
    def __init__(self, input_size, head_size, num_heads, proj_factor=2):
        super(mLSTMBlock, self).__init__()
        self.input_size = input_size
        self.head_size = head_size
        self.hidden_size = head_size * num_heads
        self.num_heads = num_heads
        self.proj_factor = proj_factor

        assert proj_factor > 0

        self.layer_norm = nn.LayerNorm(input_size)
        self.up_proj_left = nn.Linear(input_size, int(input_size * proj_factor))
        self.up_proj_right = nn.Linear(input_size, self.hidden_size)
        self.down_proj = nn.Linear(self.hidden_size, input_size)

        self.causal_conv = CausalConv1D(1, 1, 4)
        self.skip_connection = nn.Linear(int(input_size * proj_factor), self.hidden_size)

        self.Wq = BlockDiagonal(int(input_size * proj_factor), self.hidden_size, num_heads)
        self.Wk = BlockDiagonal(int(input_size * proj_factor), self.hidden_size, num_heads)
        self.Wv = BlockDiagonal(int(input_size * proj_factor), self.hidden_size, num_heads)
        self.Wi = nn.Linear(int(input_size * proj_factor), self.hidden_size)
        self.Wf = nn.Linear(int(input_size * proj_factor), self.hidden_size)
        self.Wo = nn.Linear(int(input_size * proj_factor), self.hidden_size)

        self.group_norm = nn.GroupNorm(num_heads, self.hidden_size)

    def forward(self, x, prev_state):
        h_prev, c_prev, n_prev, m_prev = prev_state

        h_prev = h_prev.to(x.device)
        c_prev = c_prev.to(x.device)
        n_prev = n_prev.to(x.device)
        m_prev = m_prev.to(x.device)
        
        assert x.size(-1) == self.input_size
        x_norm = self.layer_norm(x)
        x_up_left = self.up_proj_left(x_norm)
        x_up_right = self.up_proj_right(x_norm)

        x_conv = F.silu(self.causal_conv(x_up_left.unsqueeze(1)).squeeze(1))
        x_skip = self.skip_connection(x_conv)

        q = self.Wq(x_conv)
        k = self.Wk(x_conv) / (self.head_size ** 0.5)
        v = self.Wv(x_up_left)

        i_tilde = self.Wi(x_conv)
        f_tilde = self.Wf(x_conv)
        o = torch.sigmoid(self.Wo(x_up_left))

        m_t = torch.max(f_tilde + m_prev, i_tilde)
        i = torch.exp(i_tilde - m_t)
        f = torch.exp(f_tilde + m_prev - m_t)

        c_t = f * c_prev + i * (v * k) # v @ k.T
        n_t = f * n_prev + i * k
        h_t = o * (c_t * q) / torch.max(torch.abs(n_t.T @ q), 1)[0] # o * (c @ q) / max{|n.T @ q|, 1}

        output = h_t
        output_norm = self.group_norm(output)
        output = output_norm + x_skip
        output = output * F.silu(x_up_right)
        output = self.down_proj(output)
        final_output = output + x

        return final_output, (h_t, c_t, n_t, m_t)
    
class mLSTM(nn.Module):
    # Add bias, dropout, bidirectional
    def __init__(self, input_size, head_size, num_heads, num_layers=1, batch_first=False, proj_factor=2):
        super(mLSTM, self).__init__()
        self.input_size = input_size
        self.head_size = head_size
        self.hidden_size = head_size * num_heads
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.proj_factor_slstm = proj_factor

        self.layers = nn.ModuleList([mLSTMBlock(input_size, head_size, num_heads, proj_factor) for _ in range(num_layers)])

    def forward(self, x, state=None):
        assert x.ndim == 3
        if self.batch_first: x = x.transpose(0, 1)
        seq_len, batch_size, _ = x.size()
        
        if state is not None:
            state = torch.stack(list(state)).to(x.device)
            assert state.ndim == 4
            num_hidden, state_num_layers, state_batch_size, state_input_size = state.size()
            assert num_hidden == 4
            assert state_num_layers == self.num_layers
            assert state_batch_size == batch_size
            assert state_input_size == self.input_size
            state = state.transpose(0, 1)
        else:
            state = torch.zeros(self.num_layers, 4, batch_size, self.hidden_size, device=x.device)

        output = []
        for t in range(seq_len):
            x_t = x[t]
            for layer in range(self.num_layers):
                x_t, state_tuple = self.layers[layer](x_t, tuple(state[layer].clone()))
                state[layer] = torch.stack(list(state_tuple))
            output.append(x_t)
        
        output = torch.stack(output)
        if self.batch_first:
            output = output.transpose(0, 1)
        state = tuple(state.transpose(0, 1))
        return output, state

class xLSTM(nn.Module):
    # Add bias, dropout, bidirectional
    def __init__(self, input_size, head_size, num_heads, layers, batch_first=False, proj_factor_slstm=4/3, proj_factor_mlstm=2):
        super(xLSTM, self).__init__()
        self.input_size = input_size
        self.head_size = head_size
        self.hidden_size = head_size * num_heads
        self.num_heads = num_heads
        self.layers = layers
        self.num_layers = len(layers)
        self.batch_first = batch_first
        self.proj_factor_slstm = proj_factor_slstm
        self.proj_factor_mlstm = proj_factor_mlstm

        self.layers = nn.ModuleList()
        for layer_type in layers:
            if layer_type == 's':
                layer = sLSTMBlock(input_size, head_size, num_heads, proj_factor_slstm)
            elif layer_type == 'm':
                layer = mLSTMBlock(input_size, head_size, num_heads, proj_factor_mlstm)
            else:
                raise ValueError(f"Invalid layer type")
            self.layers.append(layer)

    def forward(self, x, state=None):
        assert x.ndim == 3
        if self.batch_first: x = x.transpose(0, 1)
        seq_len, batch_size, _ = x.size()
        
        if state is not None:
            state = torch.stack(list(state)).to(x.device)
            assert state.ndim == 4
            num_hidden, state_num_layers, state_batch_size, state_input_size = state.size()
            assert num_hidden == 4
            assert state_num_layers == self.num_layers
            assert state_batch_size == batch_size
            assert state_input_size == self.input_size
            state = state.transpose(0, 1)
        else:
            state = torch.zeros(self.num_layers, 4, batch_size, self.hidden_size, device=x.device)

        output = []
        for t in range(seq_len):
            x_t = x[t]
            for layer in range(self.num_layers):
                x_t, state_tuple = self.layers[layer](x_t, tuple(state[layer].clone()))
                state[layer] = torch.stack(list(state_tuple))
            output.append(x_t)
        
        output = torch.stack(output)
        if self.batch_first:
            output = output.transpose(0, 1)
        state = tuple(state.transpose(0, 1))
        return output, state

In [18]:
# Инициализация модели
model = xLSTM(input_size=512, head_size=512, num_heads=2, layers="msm")
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Генерация случайных входных данных
input_tensor = torch.randn(32, 128, 512)
target_tensor = torch.randn(32, 128, 512)

# Обучение модели
num_epochs = 20
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    
    output = model(input_tensor)
    if isinstance(output, tuple):
        output = output[0]
    loss = criterion(output, target_tensor)
    
    loss.backward()
    optimizer.step()
    
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.6f}")

# Проверка модели
with torch.no_grad():
    output = model(input_tensor)
    if isinstance(output, tuple):  
        output = output[0]
    
    print("Output shape after training:", output.shape)

Epoch 1/20, Loss: 2.029963
Epoch 2/20, Loss: 1.965565
Epoch 3/20, Loss: 1.911696
Epoch 4/20, Loss: 1.851583
Epoch 5/20, Loss: 1.782440
Epoch 6/20, Loss: 1.706850
Epoch 7/20, Loss: 1.632636
Epoch 8/20, Loss: 1.551116
Epoch 9/20, Loss: 1.476998
Epoch 10/20, Loss: 1.411536
Epoch 11/20, Loss: 1.358058
Epoch 12/20, Loss: 1.298240
Epoch 13/20, Loss: 1.248318
Epoch 14/20, Loss: 1.230115
Epoch 15/20, Loss: 1.182001
Epoch 16/20, Loss: 1.143725
Epoch 17/20, Loss: 1.117498
Epoch 18/20, Loss: 1.085795
Epoch 19/20, Loss: 1.056862
Epoch 20/20, Loss: 1.045573
Output shape after training: torch.Size([32, 128, 512])


## Домашнее задание: Обучение и визуализация xLSTM

### Цель задания

1. Обучить простую модель семейства xLSTM для обработки данных, которые были выбраны для обучения Mamba из предыдущего семинара
2. Визуализировать веса внимания и интерпретировать их
3. Сделать выводы о том, как модель воспринимает данные и принимает решения и лучше она или хуже Mamba и в чем?