## Домашнее задание

### Цель задания

1. Обучить простую модель семейства Mamba для обработки любых данных (на выбор: текст, аудио, видео, изображения и т. д.)
2. Визуализировать веса внимания и интерпретировать их
3. Сделать выводы о том, как модель воспринимает данные и принимает решения

### Решение

#### Импорты

In [1]:
# Подавление предупреждений
import warnings
for warn in [UserWarning, FutureWarning]: warnings.filterwarnings("ignore", category = warn)

# Импорт необходимых библиотек
import math
import torch
import torch.nn as nn
import torch.optim as optim

from torch import Tensor
from einops import rearrange
from typing import Tuple, Callable
from torch.autograd import Function

In [2]:
class PScan(Function):
    """Parallel scan for the linear recurrence ``H[t] = A[t] * H[t-1] + X[t]`` (with ``H[-1] = 0``).

    Inputs and output use the layout ``(b, l, d, s)``; the recurrence runs along
    dim 1 (``l``). Any length ``l >= 0`` is supported.
    """

    @staticmethod
    def forward(ctx, A_inp: Tensor, X_inp: Tensor) -> Tensor:
        """Return ``H`` with the same ``(b, l, d, s)`` layout as the inputs.

        The caller's tensors are never modified: the scan works on clones.
        """
        # (b, l, d, s) -> (b, d, l, s) so the scanned axis is dim 2, then clone
        # because _forward works in place.
        A = A_inp.transpose(1, 2).clone()
        X = X_inp.transpose(1, 2).clone()
        PScan._forward(A, X)
        # Backward needs the ORIGINAL coefficients (A was just overwritten with
        # running products), plus the scanned states H, which are now in X.
        ctx.save_for_backward(A_inp, X)
        return X.transpose(1, 2)

    @staticmethod
    def backward(ctx, grad_inp: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(dL/dA, dL/dX)``, both in ``(b, l, d, s)`` layout."""
        A_inp, X = ctx.saved_tensors  # X: scanned states H in (b, d, l, s)
        A = A_inp.transpose(1, 2)
        # The adjoint obeys G[t] = grad[t] + A[t+1] * G[t+1]. Running a forward
        # scan over flipped time needs coefficient A[l-j] at flipped position j.
        # Slot 0 is never used as a multiplier, so its value does not matter.
        # torch.cat builds a new tensor, so the saved A_inp is not touched.
        A = torch.cat((A[:, :, :1], A[:, :, 1:].flip(2)), dim=2)
        # flip() copies, so the in-place scan never writes into grad_inp.
        grad_out = grad_inp.transpose(1, 2).flip(2)
        PScan._forward(A, grad_out)
        grad_out = grad_out.flip(2)  # grad_out[t] == G[t] == dL/dX[t]
        # dL/dA[t] = G[t] * H[t-1]; H[-1] = 0, so position 0 stays zero.
        Q = torch.zeros_like(X)
        Q[:, :, 1:].add_(X[:, :, :-1] * grad_out[:, :, 1:])
        return Q.transpose(1, 2), grad_out.transpose(1, 2)

    @staticmethod
    def _forward(A: Tensor, X: Tensor) -> None:
        """Scan (b, d, l, s) tensors along dim 2, in place.

        On return ``X[:, :, t]`` holds ``H[t]`` and ``A[:, :, t]`` holds the
        running product ``A[0] * ... * A[t]``. An up-sweep combines pairs of
        nodes level by level; a down-sweep then propagates the finished
        prefixes back to the remaining positions. At each level, an odd
        trailing node is left out of the up-sweep and folded in during the
        down-sweep.
        """
        b, d, l, s = A.shape
        if l < 2:
            return  # H[0] = X[0]: nothing to combine (also covers l == 0)
        num_steps = l.bit_length() - 1  # exact floor(log2(l)), no float rounding

        # Up-sweep. Level i keeps the nodes ending each block of 2**i positions.
        # `view` (not `reshape`) guarantees every in-place op writes into A/X.
        Av, Xv = A, X
        for _ in range(num_steps):
            T = 2 * (Xv.size(2) // 2)  # drop an odd trailing node at this level
            Av = Av[:, :, :T].view(b, d, T // 2, 2, s)
            Xv = Xv[:, :, :T].view(b, d, T // 2, 2, s)
            Xv[:, :, :, 1].add_(Av[:, :, :, 1].mul(Xv[:, :, :, 0]))
            Av[:, :, :, 1].mul_(Av[:, :, :, 0])
            Av, Xv = Av[:, :, :, 1], Xv[:, :, :, 1]

        # Down-sweep. On entry to level k, every odd-position node at that level
        # already holds its final prefix value.
        for k in range(num_steps - 1, -1, -1):
            stride = 2 ** k
            Av = A[:, :, stride - 1 : l : stride]
            Xv = X[:, :, stride - 1 : l : stride]
            T = 2 * (Xv.size(2) // 2)
            if T < Xv.size(2):
                # Fold in the trailing node that the up-sweep skipped.
                Xv[:, :, -1].add_(Av[:, :, -1].mul(Xv[:, :, -2]))
                Av[:, :, -1].mul_(Av[:, :, -2])
            Av = Av[:, :, :T].view(b, d, T // 2, 2, s)
            Xv = Xv[:, :, :T].view(b, d, T // 2, 2, s)
            Xv[:, :, 1:, 0].add_(Av[:, :, 1:, 0].mul(Xv[:, :, :-1, 1]))
            Av[:, :, 1:, 0].mul_(Av[:, :, :-1, 1])

# Functional entry point: pscan(A, X) -> H, all shaped (b, l, d, s).
pscan: Callable[[Tensor, Tensor], Tensor] = PScan.apply

class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned per-feature gain.

    ``forward`` computes ``x * rsqrt(mean(x**2, dim=-1) + eps) * weight``.
    """

    def __init__(self, d_model: int, eps: float = 1e-8) -> None:
        super().__init__()
        self.eps = eps  # added under the square root to avoid division by zero
        self.weight = nn.Parameter(torch.ones(d_model))  # gain, initialised to 1

    def forward(self, x: Tensor) -> Tensor:
        mean_sq = x.square().mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        return x * inv_rms * self.weight

class MambaBlock(nn.Module):
    """Project to ``d_model``, add two extra linear views of the projection, project back.

    ``forward(x) = out_proj(h + s_B(h) + s_C(h))`` with ``h = in_proj(x)``; the
    output has the same trailing size (``d_input``) as the input.

    NOTE(review): every operation here is an ``nn.Linear`` or an addition, applied
    along the last dimension only. There is no state-space recurrence, no scan,
    no gating and no nonlinearity, so positions never interact and a stack of
    these blocks is still one affine map. This block does not call ``pscan`` or
    ``RMSNorm``.
    """

    def __init__(self, d_input: int, d_model: int) -> None:
        super().__init__()
        # Creation order is kept so seeded initialisation is unchanged.
        self.in_proj = nn.Linear(d_input, d_model)
        self.s_B = nn.Linear(d_model, d_model)
        self.s_C = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_input)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.in_proj(x)
        b_term = self.s_B(hidden)
        c_term = self.s_C(hidden)
        mixed = hidden + b_term + c_term
        return self.out_proj(mixed)

class Mamba(nn.Module):
    """Stack of ``num_layers`` independent ``MambaBlock(d_input, d_model)`` modules.

    Blocks are applied one after another, with no residual connections or
    normalisation between them; input and output share the trailing size
    ``d_input``.
    """

    def __init__(self, num_layers, d_input, d_model):
        super().__init__()
        blocks = [MambaBlock(d_input, d_model) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)

    def forward(self, seq):
        hidden = seq
        for block in self.layers:
            hidden = block(hidden)
        return hidden

# Smoke test: one forward and one backward pass on random data, then report
# which parameters of the stacked model received a gradient.
model = Mamba(num_layers=6, d_input=512, d_model=256)

input_tensor = torch.randn(32, 128, 512)
output = model(input_tensor)
print("Output shape:", output.shape)

target = torch.randn(32, 128, 512)
criterion = nn.MSELoss()
loss = criterion(output, target)
loss.backward()

for name, param in model.named_parameters():
    print(f"Gradients calculated for {name}" if param.grad is not None else f"No gradients for {name}")

Output shape: torch.Size([32, 128, 512])
Gradients calculated for layers.0.in_proj.weight
Gradients calculated for layers.0.in_proj.bias
Gradients calculated for layers.0.s_B.weight
Gradients calculated for layers.0.s_B.bias
Gradients calculated for layers.0.s_C.weight
Gradients calculated for layers.0.s_C.bias
Gradients calculated for layers.0.out_proj.weight
Gradients calculated for layers.0.out_proj.bias
Gradients calculated for layers.1.in_proj.weight
Gradients calculated for layers.1.in_proj.bias
Gradients calculated for layers.1.s_B.weight
Gradients calculated for layers.1.s_B.bias
Gradients calculated for layers.1.s_C.weight
Gradients calculated for layers.1.s_C.bias
Gradients calculated for layers.1.out_proj.weight
Gradients calculated for layers.1.out_proj.bias
Gradients calculated for layers.2.in_proj.weight
Gradients calculated for layers.2.in_proj.bias
Gradients calculated for layers.2.s_B.weight
Gradients calculated for layers.2.s_B.bias
Gradients calculated for layers.2.s

In [3]:
# Train a fresh model with Adam to map one fixed random input batch onto one
# fixed random target batch (full-batch updates, one per epoch).
# NOTE(review): input and target are independent Gaussian noise, so the model
# can at best fit this single pair; the loss curve says nothing about
# generalisation to real data.
model = Mamba(num_layers=6, d_input=512, d_model=256)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

input_tensor = torch.randn(32, 128, 512)
target_tensor = torch.randn(32, 128, 512)

num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()

    output = model(input_tensor)             # forward pass
    loss = criterion(output, target_tensor)  # MSE against the fixed target
    loss.backward()                          # backward pass
    optimizer.step()                         # parameter update

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")

# Example use of the trained model: inference without gradient tracking.
model.eval()
with torch.no_grad():
    output = model(input_tensor)
    print("Output after training:", output.shape)

Epoch 1/100, Loss: 1.0038691759109497
Epoch 2/100, Loss: 1.0039128065109253
Epoch 3/100, Loss: 1.0013190507888794
Epoch 4/100, Loss: 1.00005042552948
Epoch 5/100, Loss: 0.9989610314369202
Epoch 6/100, Loss: 0.9978038668632507
Epoch 7/100, Loss: 0.9966843128204346
Epoch 8/100, Loss: 0.9954712390899658
Epoch 9/100, Loss: 0.9943345189094543
Epoch 10/100, Loss: 0.9932175874710083
Epoch 11/100, Loss: 0.9920615553855896
Epoch 12/100, Loss: 0.9909035563468933
Epoch 13/100, Loss: 0.9897960424423218
Epoch 14/100, Loss: 0.9886513948440552
Epoch 15/100, Loss: 0.9875175356864929
Epoch 16/100, Loss: 0.986386775970459
Epoch 17/100, Loss: 0.9852705597877502
Epoch 18/100, Loss: 0.9841755032539368
Epoch 19/100, Loss: 0.9831169247627258
Epoch 20/100, Loss: 0.9820706844329834
Epoch 21/100, Loss: 0.981052815914154
Epoch 22/100, Loss: 0.9800412654876709
Epoch 23/100, Loss: 0.9790753126144409
Epoch 24/100, Loss: 0.9781175255775452
Epoch 25/100, Loss: 0.9771913886070251
Epoch 26/100, Loss: 0.9762938618659973