# Model construction notebook (Colab)

**Context, principle and proposed architecture:**

* Emergent World Representations: Exploring a Sequence Model Trained on a Synthetic Task (https://arxiv.org/abs/2210.13382) --> in the end, image captioning!
* Vision Mamba: Efficient Visual Representation Learning with Bidirectional State Space Model (https://arxiv.org/abs/2401.09417)
* Bi-Mamba+: Bidirectional Mamba for Time Series Forecasting (https://arxiv.org/abs/2404.15772)
* MambaByte: Token-free Selective State Space Model (https://arxiv.org/abs/2401.13660)

The idea is that the model generates both text and pixels as sequences. When generating an image, a line break (ASCII 0x0A) is always emitted when the image starts, and again each time the edge of the image (the end of a row) is reached. If the rows generated next do not all have the same length, they produce two different images. The difficulty is that Mamba must learn the "copy" task, and that the interpreter must logically reconstruct images as well as text (a bit like ASCII art, but taken further). --> Use attention layers? To be checked whether bidirectionality improves copying.
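The interpreter rule above can be sketched in plain Python. This is a hypothetical decoder, not code from the notebook: it assumes token ids 0..255 are raw bytes and 256..767 are pixels (matching the vocabulary used later), treats a line made only of pixel tokens as an image row, and closes the current image whenever the row length changes. The names `interpret`, `split_lines` and `PIXEL_OFFSET` are illustrative.

```python
from typing import List, Union

NEWLINE = 0x0A       # line feed: separates text lines and image rows
PIXEL_OFFSET = 256   # assumption: ids 0..255 are bytes, 256..767 are RGB pixels

Segment = Union[str, List[List[int]]]  # decoded text, or an image as rows of pixel indices


def split_lines(tokens: List[int]) -> List[List[int]]:
    """Cut the token stream on 0x0A (the newline itself is not kept)."""
    lines: List[List[int]] = [[]]
    for t in tokens:
        if t == NEWLINE:
            lines.append([])
        else:
            lines[-1].append(t)
    return lines


def interpret(tokens: List[int]) -> List[Segment]:
    """Hypothetical interpreter: text lines are decoded as UTF-8, pixel-only lines are
    image rows, and a change of row length closes the current image and opens a new one."""
    segments: List[Segment] = []
    text_lines: List[bytes] = []
    image: List[List[int]] = []
    for line in split_lines(tokens):
        if line and all(t >= PIXEL_OFFSET for t in line):  # image row
            if text_lines:
                segments.append(b"\n".join(text_lines).decode("utf-8", errors="replace"))
                text_lines = []
            if image and len(line) != len(image[0]):  # different width -> new image
                segments.append(image)
                image = []
            image.append([t - PIXEL_OFFSET for t in line])
        else:  # text line (pixel tokens mixed into text are dropped in this sketch)
            if image:
                segments.append(image)
                image = []
            text_lines.append(bytes(t for t in line if t < PIXEL_OFFSET))
    if text_lines:
        segments.append(b"\n".join(text_lines).decode("utf-8", errors="replace"))
    if image:
        segments.append(image)
    return segments


# "hi", then two rows of width 2 (one image), one row of width 3 (a second image), then "ok"
stream = list(b"hi\n") + [256, 257, NEWLINE, 258, 259, NEWLINE, 260, 261, 262, NEWLINE] + list(b"ok")
print(interpret(stream))  # ['hi', [[0, 1], [2, 3]], [[4, 5, 6]], 'ok']
```

The width change alone is what separates the two images here, which is exactly the behaviour the model would have to reproduce consistently when it "copies" a row length.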

**Code inspiration:**

* https://github.com/hustvl/Vim
* https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py
* https://huggingface.co/JunxiongWang/MambaByte_Arxiv


In [None]:
!pip install -q mamba-ssm causal-conv1d

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/85.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.4/85.4 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m307.2/307.2 kB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for mamba-ssm (setup.py) ... [?25l[?25hdone
  Building wheel for causal-conv1d (setup.py) ... [?25l[?25hdone


In [None]:
import torch, math
import numpy as np
from torch import nn
from mamba_ssm.modules.mamba_simple import Mamba

In [None]:
from dataclasses import dataclass

@dataclass
class MambaConfig:
    dim: int  # Model (embedding) dimension D of the input tensor.
    d_state: int = 16  # State dimension N of the state space model.
    d_conv: int = 4  # Kernel width of the causal depthwise convolution.
    expand: int = 2  # Expansion factor E (inner width = E * dim), as in the paper.
    depth: int = 8  # Number of intermediate residual Mamba (S6) layers.
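To make the config fields concrete, here is a rough per-block parameter count. It is a sketch that assumes the default layout of `mamba_ssm`'s `Mamba` block as we understand it (`dt_rank = ceil(dim / 16)`, biases only on `conv1d` and `dt_proj`); `mamba_block_params` is a hypothetical helper, and its result should be checked against `sum(p.numel() for p in model.parameters())` on the real module.

```python
import math


def mamba_block_params(dim: int, d_state: int = 16, d_conv: int = 4, expand: int = 2) -> int:
    """Rough parameter count of one Mamba block, assuming mamba_ssm's default layout."""
    d_inner = expand * dim                       # E * D: width of the inner SSM branch
    dt_rank = math.ceil(dim / 16)                # assumed default for dt_rank="auto"
    in_proj = dim * 2 * d_inner                  # x branch + gate z, no bias
    conv1d = d_inner * d_conv + d_inner          # depthwise causal conv (window d_conv) + bias
    x_proj = d_inner * (dt_rank + 2 * d_state)   # input-dependent dt, B and C, no bias
    dt_proj = dt_rank * d_inner + d_inner        # dt back to d_inner channels, with bias
    a_and_d = d_inner * d_state + d_inner        # A_log (d_inner x d_state) and skip term D
    out_proj = d_inner * dim                     # back to the model width, no bias
    return in_proj + conv1d + x_proj + dt_proj + a_and_d + out_proj


per_block = mamba_block_params(dim=64, d_state=16, d_conv=4, expand=2)
total_mamba = per_block * (8 + 2)  # depth intermediate layers + in_mamba + out_mamba
print(per_block, total_mamba)  # 32640 326400
```

Most of the weight sits in `in_proj` and `out_proj`, which scale with `expand`; `d_state` and `d_conv` only add a few thousand parameters at this width.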

In [None]:
d_model, d_state, n_layers = 64, 16, 8
config = MambaConfig(dim=d_model, d_state=d_state, depth=n_layers)

In [None]:
batch, length = 2, 64
x = torch.randn(batch, length, d_model).to("cuda")
x.shape

torch.Size([2, 64, 64])

In [None]:
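class BysMamba(nn.Module):
    def __init__(self, config: MambaConfig, patch_size=(3, 3)):
        super().__init__()
        self.config = config
        # text & image embedding
        self.vocab_size = 256 + 512  # 256 byte values (UTF-8 text) + 512 RGB colours (8*8*8)
        self.linear_embedding = nn.Embedding(self.vocab_size, config.dim)
        # 2D neighbourhood embedding (3D in future): maps an (M, N) window of token
        # embeddings to one D-vector. patch_size is an assumed default and must equal (M, N).
        self.patch_embedding = nn.Conv2d(config.dim, config.dim, kernel_size=patch_size)
        # mamba part (in_mamba / out_mamba share their weights between both scan directions)
        self.in_mamba = Mamba(d_model=config.dim, d_state=config.d_state, d_conv=config.d_conv, expand=config.expand)
        self.layers = nn.ModuleList([Mamba(d_model=config.dim, d_state=config.d_state, d_conv=config.d_conv, expand=config.expand) for _ in range(config.depth)])
        self.out_mamba = Mamba(d_model=config.dim, d_state=config.d_state, d_conv=config.d_conv, expand=config.expand)
        # output
        self.lm_head = nn.Linear(config.dim, self.vocab_size, bias=False)

    def embed(self, x):
        if x.is_floating_point():  # already embedded, e.g. by a separate nn.Embedding: (B, L, D)
            return x
        if x.dim() == 2:  # token ids: (B, L)
            return self.linear_embedding(x)
        # token grid: (B, M, N, L), each token sits at the centre of its (M, N) window
        B, M, N, L = x.shape
        xl = self.linear_embedding(x[:, M // 2, N // 2, :])  # centre tokens: (B, L, D)
        e = self.linear_embedding(x)  # (B, M, N, L, D)
        e = e.permute(0, 3, 4, 1, 2).reshape(B * L, -1, M, N)  # (B*L, D, M, N)
        xp = self.patch_embedding(e).reshape(B, L, -1)  # (B, L, D)
        return xl + xp

    def forward(self, x):
        x = self.embed(x)  # (B, L, D)
        # bidirectional mamba input (out-of-place adds: `x += ...` would overwrite
        # a tensor saved for backward, or the caller's input tensor)
        x = x + self.in_mamba(x) + self.in_mamba(x.flip([1])).flip([1])
        # mamba intermediate layers (residual)
        for layer in self.layers:
            x = x + layer(x)
        # bidirectional mamba output
        x = x + self.out_mamba(x) + self.out_mamba(x.flip([1])).flip([1])
        # prediction output: unnormalised logits (apply softmax for probabilities)
        return self.lm_head(x)  # (B, L, vocab_size)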
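The bidirectional pattern used for `in_mamba` and `out_mamba` (`x + f(x) + flip(f(flip(x)))`) can be illustrated without a GPU. In this sketch a toy causal recurrence `h_t = a * h_{t-1} + x_t` stands in for a Mamba block; `causal_scan` and `bidirectional` are hypothetical names, and the recurrence is not Mamba's selective scan, only something with the same causal property.

```python
from typing import Callable, List


def causal_scan(xs: List[float], a: float = 0.5) -> List[float]:
    """Toy causal recurrence h_t = a * h_{t-1} + x_t, a stand-in for one Mamba block."""
    h, out = 0.0, []
    for x in xs:
        h = a * h + x
        out.append(h)
    return out


def bidirectional(block: Callable[[List[float]], List[float]], xs: List[float]) -> List[float]:
    """x + block(x) + flip(block(flip(x))): same block, scanned forward and backward."""
    fwd = block(xs)
    bwd = block(xs[::-1])[::-1]
    return [x + f + b for x, f, b in zip(xs, fwd, bwd)]


print(causal_scan([0, 0, 0, 1]))                 # [0.0, 0.0, 0.0, 1.0]: position 0 ignores the future
print(bidirectional(causal_scan, [0, 0, 0, 1]))  # [0.125, 0.25, 0.5, 3.0]: position 0 now sees it
```

Because position 0 already depends on the last input after one bidirectional block, training this model with a next-token objective would expose the targets; the bidirectional blocks may therefore fit the encoding side (e.g. captioning) better than autoregressive generation.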

In [None]:
# test
# bytes([i for i in range(16**3)]).decode('utf-8')  # fails: bytes() only accepts values in range(256)
# instead, UTF-8 encode the first 16**3 Unicode code points and read the resulting bytes
s = [chr(i) for i in range(16**3)]
text_byte = np.frombuffer("".join(s).encode('utf-8'), dtype=np.uint8)
# validation
embedding = nn.Embedding(256+512, d_model).to("cuda") # 256 byte values for text and 512 RGB colours for images
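The 512 image ids correspond to 8 levels per RGB channel (8 * 8 * 8). A minimal sketch of one possible mapping, assuming pixel ids start right after the 256 byte ids, channels are packed in R, G, B order, and each 8-bit channel keeps its top 3 bits; `rgb_to_token` and `token_to_rgb` are hypothetical helpers, not part of the notebook.

```python
from typing import Tuple

PIXEL_OFFSET = 256  # ids 0..255 stay raw bytes; pixel ids start after them (assumed layout)
LEVELS = 8          # 8 levels per channel -> 8 * 8 * 8 = 512 colours


def rgb_to_token(r: int, g: int, b: int) -> int:
    """Keep the top 3 bits of each 8-bit channel and pack them into one id in [256, 768)."""
    r3, g3, b3 = r >> 5, g >> 5, b >> 5  # 0..255 -> 0..7
    return PIXEL_OFFSET + (r3 * LEVELS + g3) * LEVELS + b3


def token_to_rgb(token: int) -> Tuple[int, int, int]:
    """Inverse mapping, returning the centre of each quantisation bin."""
    r3, rest = divmod(token - PIXEL_OFFSET, LEVELS * LEVELS)
    g3, b3 = divmod(rest, LEVELS)
    return (r3 * 32 + 16, g3 * 32 + 16, b3 * 32 + 16)


tok = rgb_to_token(200, 100, 50)
print(tok, token_to_rgb(tok))  # 665 (208, 112, 48)
```

Every id produced this way falls in [256, 768), so it indexes the image part of the `256+512` embedding table without colliding with text bytes.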

In [None]:
#text = r"\documentclass[12pt]{article}"  # raw string: "\d" is an invalid escape otherwise
#text_byte = np.frombuffer(text.encode('utf-8'), dtype=np.uint8)
input_ids = torch.from_numpy(text_byte[None, :]).long().cuda()
input_ids.shape

torch.Size([1, 10112])
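The sequence length is 10112 rather than 4096 because UTF-8 uses a variable number of bytes per code point: 1 byte below U+0080, 2 bytes below U+0800, and 3 bytes for the rest of this range. A quick check of that arithmetic (128·1 + 1920·2 + 2048·3):

```python
from collections import Counter

# Bytes per character when the first 16**3 = 4096 code points are UTF-8 encoded
counts = Counter(len(chr(i).encode("utf-8")) for i in range(16**3))
total = sum(n_bytes * n_chars for n_bytes, n_chars in counts.items())
print(sorted(counts.items()), total)  # [(1, 128), (2, 1920), (3, 2048)] 10112
```

All resulting ids are byte values (0..255), so this test input only exercises the text part of the vocabulary.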

In [None]:

x_ = embedding(input_ids)
x_.shape

torch.Size([1, 10112, 64])

In [None]:
model = BysMamba(config).to("cuda")
y = model(x_)
y.shape

torch.Size([1, 10112, 768])