# PixelByte: Catching Insights in Unified Multimodal Sequences

This notebook presents **PixelByte**, a model designed to generate text and images together, pixel by pixel, as a single sequence. The goal is to explore a unified embedding that enables coherent multimodal generation.

## Background and Proposed Architecture

### Theoretical Foundations
- **Image Transformer**: [Pixel-by-pixel image generation](https://arxiv.org/abs/1802.05751)
- **Bi-Mamba+**: [Bidirectional model for time-series forecasting](https://arxiv.org/abs/2404.15772)
- **MambaByte**: [Token-free selective state space model](https://arxiv.org/abs/2401.13660)

### Key Concept
The PixelByte model generates mixed sequences of text and images. It must:
- Handle transitions between text and image using newline characters (ASCII 0x0A).
- Keep the dimensions of generated images consistent.
- Learn the "copy" task in order to reproduce complex patterns.
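The exact encoding used by PixelBytes is not described here, so the following is only a minimal sketch of the idea under explicit assumptions: text is kept as raw UTF-8 bytes, each pixel becomes a hypothetical palette token offset past the 256 byte values (`PIXEL_OFFSET` is an illustrative name, not the library's), and every image row ends with the newline byte 0x0A. Under those assumptions, checking that a generated image has consistent dimensions amounts to checking that every pixel row between two newlines has the same width.

```python
NEWLINE = 0x0A       # ASCII line feed, used as the row/segment separator
PIXEL_OFFSET = 256   # hypothetical: pixel tokens start after the 256 byte values


def encode(text, image):
    """Serialize text followed by an image into one token list (illustrative scheme)."""
    tokens = list(text.encode("utf-8")) + [NEWLINE]
    for row in image:
        tokens.extend(PIXEL_OFFSET + p for p in row)  # p: palette index of one pixel
        tokens.append(NEWLINE)                        # end of image row
    return tokens


def image_rows(tokens):
    """Split tokens at each newline and keep only segments made of pixel tokens."""
    rows, current = [], []
    for t in tokens:
        if t == NEWLINE:
            if current and all(x >= PIXEL_OFFSET for x in current):
                rows.append(current)
            current = []
        else:
            current.append(t)
    return rows


def has_consistent_width(tokens):
    """True if every pixel row has the same number of pixels."""
    widths = {len(r) for r in image_rows(tokens)}
    return len(widths) <= 1


seq = encode("Pikachu", [[0, 1, 2], [3, 4, 5]])
print(has_consistent_width(seq))  # True
```

Offsetting the pixel tokens keeps them from colliding with the newline byte, which is what lets 0x0A act as an unambiguous separator in this sketch.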

This notebook runs on Kaggle's 2× T4 GPUs to experiment with more advanced architectures and larger datasets, with the goal of tackling the challenges of unified multimodal generation.

## Project Resources

### Dataset
For this project we use the **PixelBytes-Pokemon** dataset, built specifically for this multimodal generation task. The dataset, created by the author of this notebook, is available on Hugging Face: [PixelBytes-Pokemon](https://huggingface.co/datasets/ffurfaro/PixelBytes-Pokemon). It contains Pokémon text and image sequences, encoded so that our PixelByte model can be trained on multimodal data.

### Implementation
The model implementation and training scripts are available in the **Mamba-Bys** GitHub repository: [Mamba-Bys](https://github.com/fabienfrfr/Mamba-Bys). The repository contains the source code needed to reproduce the experiments, along with detailed instructions on configuring and using the PixelByte model.

In [1]:
```python
!pip install -q git+https://github.com/fabienfrfr/PixelBytes.git@main
```

## To test

- The loss and accuracy computation (odd that accuracy is already around 40% from the very start).
- The PxByEmbed embedding created for our model: it has inconsistencies and allocates too much memory!
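The first item can be checked independently of the library. The plain-Python sketch below assumes you can extract, for a batch, the per-position logits and the next-token targets (how to get them out of `pixelbytes.train` is not shown and depends on its internals). `token_metrics` computes reference values for mean cross-entropy and accuracy to compare with what the `Trainer` reports, and `majority_baseline` gives the accuracy of always predicting the most frequent target token. Both helper names are hypothetical. If a few tokens (for example a background pixel value or the 0x0A separator) dominate the targets, a high accuracy in the first epochs may simply reflect that imbalance.

```python
import math
from collections import Counter


def token_metrics(logits, targets, ignore_index=None):
    """Mean cross-entropy and accuracy over positions.

    logits[i] is the list of scores used to predict targets[i]
    (the input/target shift is assumed to be already applied).
    Positions whose target equals ignore_index (e.g. padding) are skipped.
    """
    total_loss, correct, count = 0.0, 0, 0
    for scores, target in zip(logits, targets):
        if target == ignore_index:
            continue
        m = max(scores)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
        total_loss += log_sum_exp - scores[target]  # -log softmax(scores)[target]
        prediction = max(range(len(scores)), key=scores.__getitem__)  # argmax
        correct += int(prediction == target)
        count += 1
    if count == 0:
        return 0.0, 0.0
    return total_loss / count, correct / count


def majority_baseline(target_sequences, ignore_index=None):
    """Accuracy of always predicting the single most frequent target token."""
    counts = Counter(t for seq in target_sequences for t in seq if t != ignore_index)
    total = sum(counts.values())
    if total == 0:
        return 0.0
    return counts.most_common(1)[0][1] / total


# Uniform scores over 4 classes: the loss is log(4), and the argmax (index 0) misses target 2.
print(token_metrics([[0.0, 0.0, 0.0, 0.0]], [2]))
# Token 0 makes up 6 of the 10 targets, so the baseline accuracy is 0.6.
print(majority_baseline([[0, 0, 0, 5, 7], [0, 0, 0, 9, 10]]))
```

Comparing the reported accuracy against `majority_baseline` on the same targets tells you whether the model is doing better than a constant prediction.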


In [2]:
```python
from pixelbytes.mambabys import *
from pixelbytes.train import *
from pixelbytes.dataset import *
from pixelbytes.tokenizer import *

from datasets import load_dataset
```



In [3]:
```python
# Load the dataset, hold out 10% for testing, and cut sequences into windows
hf_dataset = load_dataset("ffurfaro/PixelBytes-Pokemon")
ds = hf_dataset["train"].train_test_split(test_size=0.1)

train_dataset = PxByDataset(ds["train"]["pixelbyte"], seq_length=256, stride=32)
test_dataset = PxByDataset(ds["test"]["pixelbyte"], seq_length=256, stride=32)

pixelbyte = PixelBytesTokenizer()
```
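`PxByDataset` is called with `seq_length=256` and `stride=32`. Assuming it cuts each example into overlapping windows in the usual way (an input of `seq_length` tokens every `stride` tokens, with the target being the same span shifted by one token), the windowing can be sketched as follows. This illustrates the idea only and is not the library's code; `make_windows` is a hypothetical helper.

```python
def make_windows(tokens, seq_length, stride):
    """Cut one token sequence into overlapping (input, target) windows.

    Each input has seq_length tokens; the target is the same span shifted by
    one token, so the model learns next-token prediction.
    """
    windows = []
    # A window starting at `start` needs tokens up to index start + seq_length.
    for start in range(0, len(tokens) - seq_length, stride):
        x = tokens[start:start + seq_length]
        y = tokens[start + 1:start + seq_length + 1]
        windows.append((x, y))
    return windows


tokens = list(range(300))
w = make_windows(tokens, seq_length=256, stride=32)
print(len(w))  # 2
```

With these settings consecutive windows overlap by 224 tokens, so a token can appear in up to 8 windows: this yields many training pairs from a small dataset, at the cost of highly correlated samples.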




In [7]:
```python
vocab_size = len(pixelbyte.vocab)
embedding_dim = 16
hidden_dim = 64
# Choose the model (simple, attention, or Mamba)
#model = SimpleSeqModel(vocab_size, embedding_dim, hidden_dim)
model = SimpleAttentionModel(vocab_size, embedding_dim, hidden_dim)
# Mamba configuration
d_model = 64  # 256 is too large!
d_state = 16
n_layers = 8
config = MambaConfig(dim=d_model, d_state=d_state, depth=n_layers, vocab_size=vocab_size)
# Mamba model (uncomment to use it instead)
#model = BysMamba(config)
```

In [8]:
```python
# Alternative configurations (use together with the matching model above):
# - SimpleSeqModel: same arguments, model_name="SimpleSeqModel"
# - BysMamba: batch_size=2, model_name="Mambabys"
trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    batch_size=32,
    learning_rate=0.001,
    num_epochs=200,
    save_dir='model_checkpoints',
    compile_model=False,  # or True
    model_name="SimpleAttentionModel",
    dataset_name="PixelBytes-Pokemon",
    eval_every=5,  # evaluate every 5 epochs
)

trainer.train()
```

Each epoch takes about 31 s (728 batches at about 23 it/s).

```text
Epoch 1/200, Average Training Loss: 2.3358
Epoch 2/200, Average Training Loss: 1.9939
Epoch 3/200, Average Training Loss: 1.8812
Epoch 4/200, Average Training Loss: 1.8000
Epoch 5/200, Average Training Loss: 1.7308
Test Accuracy: 45.62%, Test Loss: 1.8565
Checkpoint saved for epoch 5 (Best Model)
Epoch 6/200, Average Training Loss: 1.6697
Epoch 7/200, Average Training Loss: 1.6194
Epoch 8/200, Average Training Loss: 1.5748
Epoch 9/200, Average Training Loss: 1.5276
Epoch 10/200, Average Training Loss: 1.4798
Test Accuracy: 45.36%, Test Loss: 1.9054
Checkpoint saved for epoch 10
Epoch 11/200, Average Training Loss: 1.4411
Epoch 12/200, Average Training Loss: 1.4007
Epoch 13/200, Average Training Loss: 1.3648
Epoch 14/200, Average Training Loss: 1.3315
Epoch 15/200, Average Training Loss: 1.2965
Test Accuracy: 42.83%, Test Loss: 2.0594
Checkpoint saved for epoch 15
Epoch 16/200, Average Training Loss: 1.2700
Epoch 17/200, Average Training Loss: 1.2367
Epoch 18/200, Average Training Loss: 1.2103
Epoch 19/200, Average Training Loss: 1.1929
Epoch 20/200, Average Training Loss: 1.1719
Test Accuracy: 43.21%, Test Loss: 2.2347
Checkpoint saved for epoch 20
Epoch 21/200, Average Training Loss: 1.1512
Epoch 22/200, Average Training Loss: 1.1340
Epoch 23/200, Average Training Loss: 1.1133
Epoch 24/200, Average Training Loss: 1.1008
Epoch 25/200, Average Training Loss: 1.0798
Test Accuracy: 41.32%, Test Loss: 2.4690
Checkpoint saved for epoch 25
Epoch 26/200, Average Training Loss: 1.0684
Epoch 27/200, Average Training Loss: 1.0593
Epoch 28/200, Average Training Loss: 1.0440
Epoch 29/200, Average Training Loss: 1.0281
Epoch 30/200, Average Training Loss: 1.0189
Test Accuracy: 41.43%, Test Loss: 2.5646
Checkpoint saved for epoch 30
Epoch 31/200, Average Training Loss: 1.0151
Epoch 32/200, Average Training Loss: 0.9983
Epoch 33/200, Average Training Loss: 0.9815
Epoch 34/200, Average Training Loss: 0.9847
Epoch 35/200, Average Training Loss: 0.9655
Test Accuracy: 40.57%, Test Loss: 2.6671
Checkpoint saved for epoch 35
Epoch 36/200, Average Training Loss: 0.9720
Epoch 37/200, Average Training Loss: 0.9461
Epoch 38/200, Average Training Loss: 0.9433
Epoch 39/200, Average Training Loss: 0.9398
Epoch 40/200, Average Training Loss: 0.9298
Test Accuracy: 40.98%, Test Loss: 2.8299
Checkpoint saved for epoch 40
Epoch 41/200, Average Training Loss: 0.9232
Epoch 42/200, Average Training Loss: 0.9149
Epoch 43/200, Average Training Loss: 0.9123
Epoch 44/200, Average Training Loss: 0.8971
Epoch 45/200, Average Training Loss: 0.9022
Test Accuracy: 40.53%, Test Loss: 2.9406
Checkpoint saved for epoch 45
Epoch 46/200, Average Training Loss: 0.8860
Epoch 47/200, Average Training Loss: 0.8862
Epoch 48/200, Average Training Loss: 0.8839
Epoch 49/200, Average Training Loss: 0.8739
Epoch 50/200, Average Training Loss: 0.8619
Test Accuracy: 39.17%, Test Loss: 2.9995
Checkpoint saved for epoch 50
Epoch 51/200, Average Training Loss: 0.8538
Epoch 52/200, Average Training Loss: 0.8598
Epoch 53/200, Average Training Loss: 0.8452
Epoch 54/200, Average Training Loss: 0.8533
Epoch 55/200, Average Training Loss: 0.8381
Test Accuracy: 39.51%, Test Loss: 3.0694
Checkpoint saved for epoch 55
Epoch 56/200, Average Training Loss: 0.8319
Epoch 57/200, Average Training Loss: 0.8272
Epoch 58/200, Average Training Loss: 0.8233
Epoch 59/200, Average Training Loss: 0.8123
Epoch 60/200, Average Training Loss: 0.8228
Test Accuracy: 38.98%, Test Loss: 3.2202
Checkpoint saved for epoch 60
Epoch 61/200, Average Training Loss: 0.8155
Epoch 62/200, Average Training Loss: 0.8079
Epoch 63/200, Average Training Loss: 0.7988
Epoch 64/200, Average Training Loss: 0.7982
Epoch 65/200, Average Training Loss: 0.7973
Test Accuracy: 41.21%, Test Loss: 3.1992
Checkpoint saved for epoch 65
Epoch 66/200, Average Training Loss: 0.7892
Epoch 67/200, Average Training Loss: 0.7810
Epoch 68/200, Average Training Loss: 0.7820
Epoch 69/200, Average Training Loss: 0.7824
Epoch 70/200, Average Training Loss: 0.7734
Test Accuracy: 39.62%, Test Loss: 3.3090
Checkpoint saved for epoch 70
Epoch 71/200, Average Training Loss: 0.7699
Epoch 72/200, Average Training Loss: 0.7658
Epoch 73/200, Average Training Loss: 0.7631
Epoch 74/200, Average Training Loss: 0.7599
Epoch 75/200, Average Training Loss: 0.7648
Test Accuracy: 39.40%, Test Loss: 3.3294
Checkpoint saved for epoch 75
Epoch 76/200, Average Training Loss: 0.7605
Epoch 77/200, Average Training Loss: 0.7568
Epoch 78/200, Average Training Loss: 0.7469
Epoch 79/200, Average Training Loss: 0.7390
Epoch 80/200, Average Training Loss: 0.7528
Test Accuracy: 39.89%, Test Loss: 3.4614
Checkpoint saved for epoch 80
Epoch 81/200, Average Training Loss: 0.7466
Epoch 82/200, Average Training Loss: 0.7307
Epoch 83/200, Average Training Loss: 0.7256
Epoch 84/200, Average Training Loss: 0.7405
Epoch 85/200, Average Training Loss: 0.7333
Test Accuracy: 39.28%, Test Loss: 3.4750
Checkpoint saved for epoch 85
Epoch 86/200, Average Training Loss: 0.7292
Epoch 87/200, Average Training Loss: 0.7255
Epoch 88/200, Average Training Loss: 0.7236
Epoch 89/200, Average Training Loss: 0.7237
Epoch 90/200, Average Training Loss: 0.7125
Test Accuracy: 38.68%, Test Loss: 3.4513
Checkpoint saved for epoch 90
Epoch 91/200, Average Training Loss: 0.7117
Epoch 92/200, Average Training Loss: 0.7094
Epoch 93/200, Average Training Loss: 0.7160
Epoch 94/200, Average Training Loss: 0.7087
Epoch 95/200, Average Training Loss: 0.7051
Test Accuracy: 37.96%, Test Loss: 3.5900
Checkpoint saved for epoch 95
Epoch 96/200, Average Training Loss: 0.6942
Epoch 97/200, Average Training Loss: 0.7023
Epoch 98/200, Average Training Loss: 0.6989
Epoch 99/200, Average Training Loss: 0.7004
Epoch 100/200, Average Training Loss: 0.6914
Test Accuracy: 38.38%, Test Loss: 3.5958
Checkpoint saved for epoch 100
Epoch 101/200, Average Training Loss: 0.6866
Epoch 102/200, Average Training Loss: 0.6784
Epoch 103/200, Average Training Loss: 0.6977
Epoch 104/200, Average Training Loss: 0.6914
Epoch 105/200, Average Training Loss: 0.6723
Test Accuracy: 39.77%, Test Loss: 3.6395
Checkpoint saved for epoch 105
Epoch 106/200, Average Training Loss: 0.6748
Epoch 107/200, Average Training Loss: 0.6724
Epoch 108/200, Average Training Loss: 0.6785
Epoch 109/200, Average Training Loss: 0.6778
Epoch 110/200:  49%|████▊     | 354/728 [00:15<00:16, 22.80it/s]

KeyboardInterrupt: 
```
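The log shows the training loss falling steadily (2.3358 at epoch 1, 0.6778 at epoch 109) while the test loss rises from 1.8565 at epoch 5 to 3.6395 at epoch 105, and only the epoch-5 checkpoint is marked as the best model: the model overfits after the first few epochs. We do not know whether `Trainer` supports early stopping, so the following is a hypothetical `EarlyStopping` helper, fed with the test losses printed every `eval_every=5` epochs, that signals a stop once the test loss has not improved for `patience` consecutive evaluations.

```python
class EarlyStopping:
    """Signal a stop when the monitored loss has not improved for `patience` evaluations."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta      # minimum decrease that counts as an improvement
        self.best_loss = float("inf")
        self.best_epoch = None
        self.bad_evals = 0              # consecutive evaluations without improvement

    def step(self, epoch, loss):
        """Record one evaluation; return True if training should stop."""
        if loss < self.best_loss - self.min_delta:
            self.best_loss, self.best_epoch = loss, epoch
            self.bad_evals = 0
        else:
            self.bad_evals += 1
        return self.bad_evals >= self.patience


# Test losses from the log above, evaluated every 5 epochs.
test_losses = {5: 1.8565, 10: 1.9054, 15: 2.0594, 20: 2.2347, 25: 2.4690}
stopper = EarlyStopping(patience=3)
stop_epoch = None
for epoch, loss in test_losses.items():
    if stopper.step(epoch, loss):
        stop_epoch = epoch
        break
print(f"stop at epoch {stop_epoch}, best epoch {stopper.best_epoch}")  # stop at epoch 20, best epoch 5
```

With these values the helper stops at epoch 20 and points back to the epoch-5 model, instead of training for roughly 90 more epochs at about 31 s each.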