<a href="https://colab.research.google.com/github/benjaminpolydeq/ARSLM/blob/main/ARSLM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class ARSCell(nn.Module):
    """
    ARSCell: computes the next hidden state h_t from h_{t-2}, h_{t-1}, and the
    current input embedding x, using an 'adapt' mechanism inspired by BenPolySeq:
      h_t = LayerNorm(Dropout(h_{t-1} + gate * candidate + 0.1 * res_proj(x)))
    where candidate and gate are both computed from concat([h_{t-1}, h_{t-2}, x]).
    """
    def __init__(self, emb_dim: int, hidden_dim: int, dropout_prob: float = 0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        # transform candidate from concat([h_{t-1}, h_{t-2}, x])
        self.candidate_mlp = nn.Sequential(
            nn.Linear(hidden_dim*2 + emb_dim, hidden_dim*2),
            nn.ReLU(),
            nn.Linear(hidden_dim*2, hidden_dim)
        )
        # gate network produces scalar gating factor in (0,1)
        self.gate_net = nn.Sequential(
            nn.Linear(hidden_dim*2 + emb_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        # small residual projection
        self.res_proj = nn.Linear(emb_dim, hidden_dim)
        # Dropout layer
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, h_prev2, h_prev1, x_embed):
        # h_prev2, h_prev1: (batch, hidden_dim)
        # x_embed: (batch, emb_dim)
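        # difference signal: kept for inspection only; it is not passed to the
        # networks below, which see both states through ctx instead
        diff = h_prev1 - h_prev2    # (batch, hidden_dim)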
        # context vector (concat)
        ctx = torch.cat([h_prev1, h_prev2, x_embed], dim=-1)
        candidate = self.candidate_mlp(ctx)                       # (batch, hidden_dim)
        gate = self.gate_net(ctx).squeeze(-1)                     # (batch,)
        # incorporate residual from input embedding
        residual = self.res_proj(x_embed)
        # ARS-style update: adaptive addition of candidate scaled by gate
        h_t = h_prev1 + gate.unsqueeze(-1) * candidate + 0.1 * residual
        # Apply dropout
        h_t = self.dropout(h_t)
        # layer normalization without learned scale or shift (no weight/bias passed)
        h_t = F.layer_norm(h_t, (self.hidden_dim,))
        return h_t, gate

## Detailed explanation of the ARSLM model

ARSLM (Adaptive Recurrent State Language Model) is a prototype language model built on adaptive recurrent states (ARS), a concept inspired by the BenPolySeq adaptive sequences. It is designed to be lightweight, explainable, and extensible, which makes it suitable for research and experimentation.

Its key components are broken down below:

### 1. Advanced tokenizer (`AdvancedTokenizer`)

This component converts text into sequences of token IDs and back. In this improved version it relies on the Hugging Face `transformers` library, specifically `BertTokenizer`.

*   **Why an advanced tokenizer?** Unlike a simple tokenizer that only splits text on whitespace and handles a limited vocabulary, a subword-based tokenizer (such as the one used by BERT) can:
    *   Handle unknown words better by breaking them into subunits (for example, "tokenization" could be split into "token", "iza", "tion").
    *   Reduce the overall vocabulary size while still representing a wider variety of words.
    *   Handle punctuation, upper/lower case, and other linguistic subtleties.
*   **How it works:** `AdvancedTokenizer` uses a pre-trained vocabulary (by default, that of `bert-base-uncased`). It provides an `encode` method to convert text into IDs and a `decode` method to convert IDs back into text. It also exposes special-token IDs for padding (`pad_token_id`), unknown words (`unk_token_id`), beginning of sequence (`bos_token_id`, mapped to BERT's `[CLS]`), and end of sequence (`eos_token_id`, mapped to BERT's `[SEP]`).
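
The subword idea described above can be illustrated with a minimal sketch of greedy longest-match splitting, the general strategy behind WordPiece-style tokenizers. This is not the `transformers` implementation: `toy_wordpiece` and `TOY_VOCAB` are hypothetical names, and the vocabulary is made up for the example.

```python
# Toy greedy longest-match-first subword splitter (WordPiece-style).
# TOY_VOCAB is a made-up vocabulary for illustration, not BERT's real one.
TOY_VOCAB = {"token", "##ization", "##ize", "##s", "un", "##known", "[UNK]"}

def toy_wordpiece(word, vocab=TOY_VOCAB, unk="[UNK]"):
    """Split one lowercase word into the longest vocabulary pieces, left to right."""
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        piece = None
        # Shrink the candidate substring until it matches a vocabulary entry.
        while end > start:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry a ## prefix
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk]  # no piece fits: the whole word maps to the unknown token
        pieces.append(piece)
        start = end
    return pieces

print(toy_wordpiece("tokenization"))
print(toy_wordpiece("tokens"))
print(toy_wordpiece("zebra"))
```

With this toy vocabulary, a word that is absent as a whole is still covered by smaller pieces, and only a word with no matching pieces falls back to the unknown token.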

### 2. ARS cell (`ARSCell`)

This is the adaptive core of the model. An `ARSCell` computes the next hidden state (`h_t`) from the two previous hidden states (`h_{t-2}`, `h_{t-1}`) and the embedding of the current input (`x_embed`).

*   **Adaptation mechanism:** The main idea, inspired by BenPolySeq, is that the hidden state is not overwritten as in a basic RNN but updated additively: a "candidate" scaled by a "gate" is added to the previous state. LSTM and GRU cells also use learned gates; what differs here is that the cell conditions on the two previous states.
    *   **Difference signal:** The cell computes the difference between the two previous hidden states (`h_{t-1} - h_{t-2}`), which is meant to capture how the recurrent state is evolving. In the code in this notebook, this `diff` tensor is computed but not passed to either network; the gate and the candidate can only recover it indirectly, because both states appear in the context vector.
    *   **Context:** A context vector is built by concatenating the two previous hidden states and the input embedding (`[h_{t-1}, h_{t-2}, x_embed]`).
    *   **Candidate network (`candidate_mlp`):** An MLP (multi-layer perceptron) takes the context as input and produces a "candidate" for the hidden-state update.
    *   **Gate network (`gate_net`):** A second network, ending in a sigmoid, also takes the context as input and outputs one scalar in (0, 1) per example. This value acts as a gate that controls how much the candidate influences the update.
    *   **Residual:** A small residual from the input embedding (a linear projection, `res_proj`) is added to help propagate input information directly.
    *   **Update:** The next hidden state is computed as `h_t = h_{t-1} + gate * candidate + 0.1 * residual`. The gate makes the size of the update depend on the context, that is, on both previous states and the current input.
    *   **Dropout and normalization:** Dropout is applied for regularization, followed by layer normalization (`F.layer_norm`) to stabilize training.
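
To make the update rule concrete, here is a minimal pure-Python sketch of one step on a 4-dimensional state. It assumes the candidate and residual vectors are already given (they stand in for the outputs of `candidate_mlp` and `res_proj`), omits dropout, and uses hypothetical helper names `layer_norm` and `ars_update`.

```python
import math

def layer_norm(v, eps=1e-5):
    """Normalize a vector to zero mean and unit variance (no learned scale or shift)."""
    mean = sum(v) / len(v)
    var = sum((x - mean) ** 2 for x in v) / len(v)
    return [(x - mean) / math.sqrt(var + eps) for x in v]

def ars_update(h_prev1, candidate, gate, residual, residual_scale=0.1):
    """h_t = LayerNorm(h_{t-1} + gate * candidate + 0.1 * residual); dropout omitted."""
    raw = [h + gate * c + residual_scale * r
           for h, c, r in zip(h_prev1, candidate, residual)]
    return layer_norm(raw)

h_prev1 = [0.5, -0.5, 1.0, 0.0]
candidate = [1.0, 1.0, -1.0, 2.0]   # stand-in for candidate_mlp(ctx)
residual = [0.0, 2.0, 0.0, -2.0]    # stand-in for res_proj(x_embed)

# gate = 0 ignores the candidate entirely; gate = 1 adds it in full.
closed = ars_update(h_prev1, candidate, gate=0.0, residual=residual)
opened = ars_update(h_prev1, candidate, gate=1.0, residual=residual)
print([round(x, 3) for x in closed])
print([round(x, 3) for x in opened])
```

Even with the gate fully closed, the state still moves slightly because the scaled residual is always added.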

### 3. ARSLM model (`ARSLM`)

This module wraps the whole architecture, combining the embedding, several `ARSCell` layers, an attention mechanism, and an output layer (head).

*   **Embedding (`nn.Embedding`):** Converts token IDs into dense vectors of dimension `emb_dim`.
*   **ARSCell layers (`nn.ModuleList`):** The model stacks `num_layers` ARSCells. The first cell takes the embedding as input; each subsequent cell takes the hidden state produced by the layer below at the same time step.
*   **Attention mechanism:** A simple additive-style attention is applied after the last ARSCell layer.
    *   **Causal attention:** During the forward pass the attention is causal: each position can only attend to itself and to earlier positions in the sequence.
    *   **Computation:** At each time step `t`, attention is computed over the history of last-layer hidden states up to and including step `t` (`all_last_layer_hidden_states`). A small network (`self.attention`) scores each state in the history, and the scores are normalized with a softmax. A context vector is obtained as the weighted sum of the history. Note that the code builds a `query` from the current hidden state (`last_layer_h_t`) but never uses it: each score depends only on the state being scored, not on the current state.
    *   **Combination:** The context vector is added to the current last-layer hidden state (`attended_h_t = last_layer_h_t + context_vector`), which lets the model draw on relevant information from earlier in the sequence.
*   **Output layer (`head`):** A linear layer that projects the attention-combined hidden state onto the vocabulary size (`self.vocab_size`). Its outputs are the logits: unnormalized scores that a softmax turns into a probability for each vocabulary token being the next token.
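
The causal pooling step can be sketched without PyTorch as follows. This is an illustration under simplifying assumptions: `toy_score` is a hypothetical stand-in for the learned scoring network, and `causal_pool` is a name introduced here, not a function from the notebook.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def causal_pool(history, score_fn):
    """For each step t, pool states 0..t with softmax weights and add the result to state t."""
    outputs = []
    for t in range(len(history)):
        visible = history[: t + 1]                       # only past and current states
        weights = softmax([score_fn(h) for h in visible])
        dim = len(history[t])
        context = [sum(w * h[d] for w, h in zip(weights, visible)) for d in range(dim)]
        outputs.append([history[t][d] + context[d] for d in range(dim)])
    return outputs

# toy_score is a hypothetical stand-in for the learned scoring network.
def toy_score(h):
    return sum(h)

history = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
attended = causal_pool(history, toy_score)
print(attended[0])  # at t = 0 the only visible state is the current one
print(attended[1])
```

Because step `t` only ever sees `history[:t + 1]`, changing a later state cannot alter an earlier output, which is the property the causal constraint is meant to guarantee.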

### 4. The `forward` function

This function describes how a sequence of token IDs (`input_ids`) is processed to produce the logits and the gates. It iterates over the time steps of the input sequence; at each step it computes the hidden state of every ARSCell layer in turn, then applies the attention and the output layer to the last layer's state.
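
The state bookkeeping in that double loop (two remembered states per layer, shifted after every step) can be sketched with scalars. `toy_cell` is a hypothetical stand-in for `ARSCell` with arbitrary fixed coefficients; only the loop structure is meant to mirror the model.

```python
def toy_cell(h_prev2, h_prev1, x):
    """Hypothetical scalar stand-in for ARSCell: a function of the two past states and the input."""
    return 0.5 * h_prev1 + 0.25 * h_prev2 + x

def run_stack(inputs, num_layers=2):
    """Run the time-step and layer loops and return the last layer's state at every step."""
    h_prev2 = [0.0] * num_layers   # h_{t-2} per layer, zero-initialized
    h_prev1 = [0.0] * num_layers   # h_{t-1} per layer, zero-initialized
    outputs = []
    for x_t in inputs:
        layer_input = x_t                      # layer 0 reads the embedding
        for layer in range(num_layers):
            h_t = toy_cell(h_prev2[layer], h_prev1[layer], layer_input)
            # Shift this layer's two-step history forward by one position.
            h_prev2[layer], h_prev1[layer] = h_prev1[layer], h_t
            layer_input = h_t                  # the next layer reads this layer's output
        outputs.append(layer_input)
    return outputs

print(run_stack([1.0, 0.0, 0.0]))
```

A single nonzero input at the first step keeps influencing later outputs, since each layer carries its own two-step memory forward.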

### 5. The `generate` function

This function implements text generation. Starting from an initial context sequence (`idx`), the model generates new tokens one at a time. For each new token:
*   It takes the embedding of the most recently generated token.
*   It uses the hidden states carried over from the previous steps to compute the new hidden state of each ARSCell layer.
*   It applies the attention mechanism over the history of last-layer hidden states (including the new state).
*   It uses the output layer to obtain the logits for the next token.
*   It samples the next token from the probabilities derived from the logits (here, `torch.multinomial` with a temperature and optionally `top_k`).
*   The new token is appended to the output sequence, and the process repeats until the maximum number of new tokens is reached.
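
The sampling step can be sketched in pure Python as below. This is an illustrative stand-in for the `torch.multinomial` call, not the notebook's code: `sample_next` is a hypothetical helper, and the top-k filter shown is one common way to implement that option.

```python
import math
import random

def sample_next(logits, temperature=1.0, top_k=None, rng=random):
    """Sample a token index from logits after temperature scaling and optional top-k filtering."""
    scaled = [l / max(1e-6, temperature) for l in logits]
    if top_k is not None:
        # Keep the top_k largest logits; every other token gets probability zero.
        ranked = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
        keep = set(ranked[:top_k])
    else:
        keep = set(range(len(scaled)))
    m = max(scaled[i] for i in keep)  # subtract the max for numerical stability
    weights = [math.exp(scaled[i] - m) if i in keep else 0.0 for i in range(len(scaled))]
    # random.choices normalizes the weights, so they need not sum to 1.
    return rng.choices(range(len(scaled)), weights=weights, k=1)[0]

rng = random.Random(0)
logits = [2.0, 1.0, 0.1, -1.0]
samples = [sample_next(logits, temperature=0.8, top_k=2, rng=rng) for _ in range(200)]
print(sorted(set(samples)))
```

Lower temperatures sharpen the distribution toward the highest logit, while `top_k` removes the low-probability tail altogether.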

### 6. Training/demo utilities (`collate_batch`, `train_demo`)

*   **`collate_batch`:** Prepares text data for training by encoding it with the tokenizer, adding the beginning/end-of-sequence tokens, and padding so that all sequences in a batch have the same length.
*   **`train_demo`:** A demonstration function that initializes the tokenizer, the model, the optimizer, and the loss function. It loads a small corpus (or a file if one is specified), prepares the data, trains the model for a given number of epochs, prints the loss and the learning rate, computes the perplexity, and runs a short text-generation demo.
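
The data preparation described above (wrap in BOS/EOS, pad, shift by one for next-token prediction) and the perplexity computation can be sketched with plain lists. The helper names and the ID constants are placeholders chosen for illustration; perplexity is taken, as is standard, to be the exponential of the mean cross-entropy.

```python
import math

# Placeholder special-token IDs for illustration only.
PAD_ID, BOS_ID, EOS_ID = 0, 101, 102

def pad_batch(sequences, pad_id=PAD_ID):
    """Wrap each sequence in BOS/EOS and right-pad to the longest length in the batch."""
    wrapped = [[BOS_ID] + seq + [EOS_ID] for seq in sequences]
    max_len = max(len(seq) for seq in wrapped)
    return [seq + [pad_id] * (max_len - len(seq)) for seq in wrapped]

def shift_for_lm(batch):
    """Inputs drop the last position, targets drop the first: position t predicts token t+1."""
    inputs = [row[:-1] for row in batch]
    targets = [row[1:] for row in batch]
    return inputs, targets

def perplexity(mean_cross_entropy):
    """Perplexity is the exponential of the mean per-token cross-entropy (in nats)."""
    return math.exp(mean_cross_entropy)

batch = pad_batch([[7, 8, 9], [5]])
inputs, targets = shift_for_lm(batch)
print(batch)
print(inputs)
print(targets)
print(perplexity(0.0))
```

Padded target positions carry `PAD_ID`, which is why the loss in `train_demo` is built with `ignore_index=tokenizer.pad_token_id`.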

In summary, ARSLM combines adaptive recurrent states with a multi-layer architecture and a causal attention mechanism to build a language model that can learn sequential dependencies flexibly. Its distinctive feature is the adaptive gate inside `ARSCell`, which may let the model adjust its behavior according to the internal dynamics of its recurrent memory.

In [None]:
# Install PyTorch
!pip install torch

In [None]:

import torch

# Create a tensor
x = torch.rand(5, 3)
print(x)

# Perform an operation
y = torch.rand(5, 3)
print(x + y)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List

# Assuming AdvancedTokenizer and ARSCell are defined or imported elsewhere
# For a complete working script, please refer to cell 64161967

# ----------------------------
# ARSLM model: embed -> ARSCell(s) -> attention -> head
# ----------------------------
class ARSLM(nn.Module):
    # Modified to accept a tokenizer instance, include attention, and support multiple layers
    def __init__(self, tokenizer, emb_dim=64, hidden_dim=128, num_layers=2):
        super().__init__()
        self.tokenizer = tokenizer # Store tokenizer instance
        self.vocab_size = len(tokenizer) # Get vocab size from tokenizer
        self.num_layers = num_layers
        self.emb = nn.Embedding(self.vocab_size, emb_dim)
        # Create a list of ARSCells
        self.cells = nn.ModuleList([ARSCell(emb_dim if i == 0 else hidden_dim, hidden_dim) for i in range(num_layers)])

        # Simple additive attention mechanism, applied after the last layer's hidden states
        self.attention = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1)
        )
        self.head = nn.Linear(hidden_dim, self.vocab_size)

    def forward(self, input_ids):
        """
        input_ids: (batch, seq_len)
        We compute autoregressively, passing hidden states through layers.
        Returns logits (batch, seq_len, vocab).
        """
        bsz, seq_len = input_ids.shape
        emb = self.emb(input_ids)  # (b, seq, emb)
        device = emb.device
        hidden_dim = self.cells[0].hidden_dim

        # Initialize hidden states for each layer
        # We need two previous states for each layer's ARSCell
        h_prev2_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]
        h_prev1_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]

        # Store hidden states for causal attention (from the last layer)
        all_last_layer_hidden_states = []
        logits = []
        gates = [] # Store gates from the last layer

        for t in range(seq_len):
            x_t = emb[:, t, :] # embedding at position t
            h_t_input = x_t # Input to the first layer is the embedding

            current_layer_hidden_states = []

            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                # The input to subsequent layers is the hidden state of the previous layer
                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1] # Use the output of the previous layer

                h_t, gate = cell(h_prev2, h_prev1, h_t_input)

                # Store this layer's hidden state; it is the input to the next layer at the same time step
                current_layer_hidden_states.append(h_t)

                # Update previous hidden states for the next time step for this layer
                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1: # Only store gates from the last layer
                    gates.append(gate.unsqueeze(1))

            # After processing all layers for time step t, the output is the hidden state of the last layer
            last_layer_h_t = current_layer_hidden_states[-1]
            all_last_layer_hidden_states.append(last_layer_h_t.unsqueeze(1))

            # Apply attention and head after the last layer's hidden state
            # Causal Attention: attend over hidden states up to current time step t from the last layer
            last_layer_history = torch.cat(all_last_layer_hidden_states, dim=1) # (b, t+1, hidden_dim)
            # NOTE: query is kept for reference but is not used below; each score depends only on the state being scored
            query = last_layer_h_t.unsqueeze(1) # (b, 1, hidden_dim)

            # Calculate attention scores
            scores = self.attention(last_layer_history) # (b, t+1, 1)
            attention_weights = F.softmax(scores, dim=1) # (b, t+1, 1)

            # Apply attention weights
            context_vector = torch.sum(attention_weights * last_layer_history, dim=1) # (b, hidden_dim)

            # Combine last layer's current hidden state with context vector before the head
            attended_h_t = last_layer_h_t + context_vector

            logit = self.head(attended_h_t) # predict token at t (or next token)
            logits.append(logit.unsqueeze(1))


        logits = torch.cat(logits, dim=1)          # (b, seq, vocab)
        gates = torch.cat(gates, dim=1)            # (b, seq) # Gates from the last layer
        return logits, gates

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=20, temperature=1.0):
        """
        idx: (batch, current_seq_len) - input context (token ids)
        Returns extended token ids.
        Modified for multiple layers and attention during generation.
        """
        self.eval()
        bsz, seq_len = idx.shape
        device = next(self.parameters()).device
        hidden_dim = self.cells[0].hidden_dim

        # Initialize hidden states for each layer
        h_prev2_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]
        h_prev1_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]

        # Store hidden states for causal attention (from the last layer)
        all_last_layer_hidden_states = []

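        # Warm up the hidden states on the context, excluding its last token:
        # the generation loop below feeds that token itself, so processing it
        # here as well would pass it through the cells twice.
        input_ids = idx.clone()
        emb = self.emb(input_ids) # (b, seq_len, emb_dim)

        for t in range(seq_len - 1):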
            x_t = emb[:,t,:] # embedding at position t
            h_t_input = x_t

            current_layer_hidden_states = []
            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1]

                h_t, _ = cell(h_prev2, h_prev1, h_t_input)
                current_layer_hidden_states.append(h_t)

                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1:
                    # Add the new last layer hidden state to the history for attention
                    all_last_layer_hidden_states.append(h_t.unsqueeze(1))


        out_ids = input_ids.tolist()

        for _ in range(max_new_tokens):
            # Input for the first layer's ARSCell is the embedding of the last generated token
            last_token_ids = torch.tensor([ids[-1] for ids in out_ids], dtype=torch.long, device=device).unsqueeze(1)
            x_embed = self.emb(last_token_ids).squeeze(1) # (b, emb_dim)

            h_t_input = x_embed
            current_layer_hidden_states = []

            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1]

                h_t, _ = cell(h_prev2, h_prev1, h_t_input)
                current_layer_hidden_states.append(h_t)

                # Update previous hidden states for the next generation step for this layer
                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1:
                    # Add the new last layer hidden state to the history for attention
                    all_last_layer_hidden_states.append(h_t.unsqueeze(1))


            # After processing all layers for the new token, apply attention and head
            last_layer_h_t = current_layer_hidden_states[-1] # Hidden state of the last layer for the new token

            # Causal Attention over all last layer hidden states history
            last_layer_history = torch.cat(all_last_layer_hidden_states, dim=1) # (b, current_seq_len + 1, hidden_dim)
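            # NOTE: query is kept for reference but is not used below; each score depends only on the state being scored
            query = last_layer_h_t.unsqueeze(1) # (b, 1, hidden_dim)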

            scores = self.attention(last_layer_history) # (b, current_seq_len + 1, 1)
            attention_weights = F.softmax(scores, dim=1) # (b, current_seq_len + 1, 1)

            context_vector = torch.sum(attention_weights * last_layer_history, dim=1) # (b, hidden_dim)

            attended_h_t = last_layer_h_t + context_vector

            logits = self.head(attended_h_t) / max(1e-6, temperature)
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)

            # append the newly generated token
            for i in range(bsz):
                out_ids[i].append(int(next_token[i].item()))

        return out_ids

# ----------------------------
# Training / demo utilities
# ----------------------------
# Modified to accept and use the tokenizer instance
def collate_batch(tokenizer, texts: List[str], device):
    # Encode each text and wrap it in BOS/EOS.
    encoded = []
    for t in texts:
        # AdvancedTokenizer.encode calls the underlying tokenizer with
        # add_special_tokens=False, so [CLS] (used as BOS) and [SEP] (used as EOS)
        # are added manually here, as the original SimpleTokenizer version did.
        ids = [tokenizer.bos_token_id] + tokenizer.encode(t) + [tokenizer.eos_token_id]
        encoded.append(torch.tensor(ids, dtype=torch.long, device=device))
    # pad to max len
    max_len = max([x.size(0) for x in encoded])
    batch = torch.full((len(encoded), max_len), tokenizer.pad_token_id, dtype=torch.long, device=device) # Use pad_token_id
    for i, x in enumerate(encoded):
        batch[i, :x.size(0)] = x
    return batch

# Modified to use AdvancedTokenizer and handle multiple layers
def train_demo():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # toy corpus
    texts = [
        "hello world this is ars",
        "the system adapts to its history",
        "benpolyseq demonstrates adaptive sequences",
        "ars can inspire new network protocols",
        "self optimizing systems are possible"
    ]
    # Instantiate AdvancedTokenizer
    tokenizer = AdvancedTokenizer()
    # Vocab size is now determined by the tokenizer
    vocab_size = len(tokenizer)
    # Instantiate ARSLM with multiple layers
    model = ARSLM(tokenizer, emb_dim=64, hidden_dim=128, num_layers=2).to(device) # Added num_layers
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    # Use ignore_index from the tokenizer
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    # Build next-token training pairs by shifting the batch by one position
    n_epochs = 200
    batch = collate_batch(tokenizer, texts, device) # Pass tokenizer to collate_batch
    # inputs are tokens 0..n-2, targets are tokens 1..n-1: position t predicts token t+1
    inputs = batch[:, :-1]
    targets = batch[:, 1:]
    for epoch in range(n_epochs):
        model.train()
        logits, gates = model(inputs)
        # flatten
        b, seq, v = logits.shape
        loss = loss_fn(logits.view(b*seq, v), targets.reshape(b*seq))
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if (epoch+1) % 50 == 0 or epoch==0:
            print(f"Epoch {epoch+1}/{n_epochs} — loss: {loss.item():.4f}")

    # demo generation
    context = "hello world"
    # Use the tokenizer's encode method for the context
    idx = torch.tensor([ [tokenizer.bos_token_id] + tokenizer.encode(context) ], dtype=torch.long, device=device)
    out_ids = model.generate(idx, max_new_tokens=15, temperature=1.0)[0]
    print("=== Generated ===")
    # Use the tokenizer's decode method for the output
    print(tokenizer.decode(out_ids))
    return model, tokenizer

# ----------------------------
# Run demo if main
# ----------------------------
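if __name__ == "__main__":
    # Only PyTorch's RNG is used in this cell, so seeding torch is sufficient
    # (the earlier random.seed(0) call raised NameError: 'random' was never imported).
    torch.manual_seed(0)
    model, tokenizer = train_demo()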

Add `%load_ext cudf.pandas` before importing pandas to speed up operations using GPU

In [None]:
%load_ext cudf.pandas
import pandas as pd
import numpy as np

# Randomly generated dataset of parking violations
# Define the number of rows
num_rows = 1000000

states = ["NY", "NJ", "CA", "TX"]
violations = ["Double Parking", "Expired Meter", "No Parking",
              "Fire Hydrant", "Bus Stop"]
vehicle_types = ["SUBN", "SDN"]

# Create a date range
start_date = "2022-01-01"
end_date = "2022-12-31"
dates = pd.date_range(start=start_date, end=end_date, freq='D')

# Generate random data
data = {
    "Registration State": np.random.choice(states, size=num_rows),
    "Violation Description": np.random.choice(violations, size=num_rows),
    "Vehicle Body Type": np.random.choice(vehicle_types, size=num_rows),
    "Issue Date": np.random.choice(dates, size=num_rows),
    "Ticket Number": np.random.randint(1000000000, 9999999999, size=num_rows)
}

# Create a DataFrame
df = pd.DataFrame(data)

# Which parking violation is most commonly committed by vehicles from various U.S states?

(df[["Registration State", "Violation Description"]]  # get only these two columns
 .value_counts()  # get the count of offences per state and per type of offence
 .groupby("Registration State")  # group by state
 .head(1)  # get the first row in each group (the type of offence with the largest count)
 .sort_index()  # sort by state name
 .reset_index()
)

In [None]:
"""
ARSLM prototype (toy engine)
Author: Benjamin Amaad Kama (concept)
Requirements: Python 3.8+, PyTorch, transformers

This is a minimal, fully-contained prototype of an ARS-based language model.
It is for research/experimentation: lightweight, explainable, and extendable.
"""

import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
from transformers import BertTokenizer # Import BertTokenizer
from torch.optim.lr_scheduler import StepLR # Import StepLR
import os # Import os for file operations

# ----------------------------
# Advanced Tokenizer (using transformers)
# ----------------------------
class AdvancedTokenizer:
    def __init__(self, vocab_file=None):
        # Using a pre-trained tokenizer's vocabulary for demonstration
        # In a real scenario, you might train a new tokenizer on your specific corpus
        if vocab_file:
            self.tokenizer = BertTokenizer(vocab_file)
        else:
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        self.vocab = list(self.tokenizer.vocab.keys())
        self.stoi = self.tokenizer.vocab
        self.itos = {i: w for w, i in self.stoi.items()}

        # Define special tokens explicitly for clarity and compatibility
        # These might vary depending on the chosen tokenizer
        self.pad_token = self.tokenizer.pad_token
        self.unk_token = self.tokenizer.unk_token
        self.bos_token = self.tokenizer.cls_token # Using [CLS] as BOS for BERT-like tokenizers
        self.eos_token = self.tokenizer.sep_token # Using [SEP] as EOS for BERT-like tokenizers

        self.pad_token_id = self.tokenizer.pad_token_id
        self.unk_token_id = self.tokenizer.unk_token_id
        self.bos_token_id = self.tokenizer.cls_token_id
        self.eos_token_id = self.tokenizer.sep_token_id


    def encode(self, text: str) -> List[int]:
        # add_special_tokens=False: return only the subword IDs; callers add
        # BOS/EOS ([CLS]/[SEP]) themselves where the model needs them
        return self.tokenizer.encode(text, add_special_tokens=False)

    def decode(self, ids: List[int]) -> str:
        # skip_special_tokens=True prevents decoding [CLS], [SEP], [PAD] etc.
        # We might adjust this based on how the model handles special tokens
        return self.tokenizer.decode(ids, skip_special_tokens=True)

    def __len__(self):
        return len(self.vocab)


# ----------------------------
# ARSCell: core adaptive cell
# ----------------------------
class ARSCell(nn.Module):
    """
    ARSCell: computes the next hidden state h_t from h_{t-2}, h_{t-1}, and the
    current input embedding x, using an 'adapt' mechanism inspired by BenPolySeq:
      h_t = LayerNorm(Dropout(h_{t-1} + gate * candidate + 0.1 * res_proj(x)))
    where candidate and gate are both computed from concat([h_{t-1}, h_{t-2}, x]).
    """
    def __init__(self, emb_dim: int, hidden_dim: int, dropout_prob: float = 0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        # transform candidate from concat([h_{t-1}, h_{t-2}, x])
        self.candidate_mlp = nn.Sequential(
            nn.Linear(hidden_dim*2 + emb_dim, hidden_dim*2),
            nn.ReLU(),
            nn.Linear(hidden_dim*2, hidden_dim)
        )
        # gate network produces scalar gating factor in (0,1)
        self.gate_net = nn.Sequential(
            nn.Linear(hidden_dim*2 + emb_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        # small residual projection
        self.res_proj = nn.Linear(emb_dim, hidden_dim)
        # Dropout layer
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, h_prev2, h_prev1, x_embed):
        # h_prev2, h_prev1: (batch, hidden_dim)
        # x_embed: (batch, emb_dim)
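        # difference signal: kept for inspection only; it is not passed to the
        # networks below, which see both states through ctx instead
        diff = h_prev1 - h_prev2    # (batch, hidden_dim)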
        # context vector (concat)
        ctx = torch.cat([h_prev1, h_prev2, x_embed], dim=-1)
        candidate = self.candidate_mlp(ctx)                       # (batch, hidden_dim)
        gate = self.gate_net(ctx).squeeze(-1)                     # (batch,)
        # incorporate residual from input embedding
        residual = self.res_proj(x_embed)
        # ARS-style update: adaptive addition of candidate scaled by gate
        h_t = h_prev1 + gate.unsqueeze(-1) * candidate + 0.1 * residual
        # Apply dropout
        h_t = self.dropout(h_t)
        # layer normalization without learned scale or shift (no weight/bias passed)
        h_t = F.layer_norm(h_t, (self.hidden_dim,))
        return h_t, gate

# ----------------------------
# ARSLM model: embed -> ARSCell(s) -> attention -> head
# ----------------------------
class ARSLM(nn.Module):
    # Modified to accept a tokenizer instance, include attention, and support multiple layers
    def __init__(self, tokenizer, emb_dim=64, hidden_dim=128, num_layers=2):
        super().__init__()
        self.tokenizer = tokenizer # Store tokenizer instance
        self.vocab_size = len(tokenizer) # Get vocab size from tokenizer
        self.num_layers = num_layers
        self.emb = nn.Embedding(self.vocab_size, emb_dim)
        # Create a list of ARSCells
        # Pass dropout_prob to ARSCell constructor
        self.cells = nn.ModuleList([ARSCell(emb_dim if i == 0 else hidden_dim, hidden_dim, dropout_prob=0.1) for i in range(num_layers)])

        # Simple additive attention mechanism, applied after the last layer's hidden states
        self.attention = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1)
        )
        self.head = nn.Linear(hidden_dim, self.vocab_size)

    def forward(self, input_ids):
        """
        input_ids: (batch, seq_len)
        We compute autoregressively, passing hidden states through layers.
        Returns logits (batch, seq_len, vocab).
        """
        bsz, seq_len = input_ids.shape
        emb = self.emb(input_ids)  # (b, seq, emb)
        device = emb.device
        hidden_dim = self.cells[0].hidden_dim

        # Initialize hidden states for each layer
        # We need two previous states for each layer's ARSCell
        h_prev2_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]
        h_prev1_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]

        # Store hidden states for causal attention (from the last layer)
        all_last_layer_hidden_states = []
        logits = []
        gates = [] # Store gates from the last layer

        for t in range(seq_len):
            x_t = emb[:, t, :] # embedding at position t
            h_t_input = x_t # Input to the first layer is the embedding

            current_layer_hidden_states = []

            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                # The input to subsequent layers is the hidden state of the previous layer
                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1] # Use the output of the previous layer

                h_t, gate = cell(h_prev2, h_prev1, h_t_input)

                # Keep this layer's hidden state; it is the input to the next layer at the same time step
                current_layer_hidden_states.append(h_t)

                # Update previous hidden states for the next time step for this layer
                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1: # Only store gates from the last layer
                    gates.append(gate.unsqueeze(1))

            # After processing all layers for time step t, the output is the hidden state of the last layer
            last_layer_h_t = current_layer_hidden_states[-1]
            all_last_layer_hidden_states.append(last_layer_h_t.unsqueeze(1))

            # Apply attention and head after the last layer's hidden state.
            # Causal attention pooling: attend over the last layer's hidden states up to time step t.
            # The scores depend only on each stored state; no separate query vector is used.
            last_layer_history = torch.cat(all_last_layer_hidden_states, dim=1) # (b, t+1, hidden_dim)

            # Calculate attention scores
            scores = self.attention(last_layer_history) # (b, t+1, 1)
            attention_weights = F.softmax(scores, dim=1) # (b, t+1, 1)

            # Apply attention weights
            context_vector = torch.sum(attention_weights * last_layer_history, dim=1) # (b, hidden_dim)

            # Combine last layer's current hidden state with context vector before the head
            attended_h_t = last_layer_h_t + context_vector

            logit = self.head(attended_h_t) # logits at position t predict the token at position t+1
            logits.append(logit.unsqueeze(1))


        logits = torch.cat(logits, dim=1)          # (b, seq, vocab)
        gates = torch.cat(gates, dim=1)            # (b, seq) # Gates from the last layer
        return logits, gates

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=20, temperature=1.0, top_k=None):
        """
        idx: (batch, current_seq_len) - input context (token ids)
        Returns the extended token ids as a list of lists, one per batch row.
        Supports multiple layers, attention during generation, and optional top-k sampling.
        """
        self.eval()
        bsz, seq_len = idx.shape
        device = next(self.parameters()).device
        hidden_dim = self.cells[0].hidden_dim

        # Initialize hidden states for each layer
        h_prev2_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]
        h_prev1_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]

        # Store hidden states for causal attention (from the last layer)
        all_last_layer_hidden_states = []

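        # Process the context to warm up the hidden states for each layer.
        # The last context token is skipped here because the generation loop below
        # feeds it through the cells as its first step; processing it twice would
        # put a duplicate state into the attention history.
        input_ids = idx.clone()
        emb = self.emb(input_ids) # (b, seq_len, emb_dim)

        for t in range(seq_len - 1):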
            x_t = emb[:,t,:] # embedding at position t
            h_t_input = x_t

            current_layer_hidden_states = []
            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1]

                h_t, _ = cell(h_prev2, h_prev1, h_t_input)
                current_layer_hidden_states.append(h_t)

                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1:
                    # Add the new last layer hidden state to the history for attention
                    all_last_layer_hidden_states.append(h_t.unsqueeze(1))


        out_ids = input_ids.tolist()

        for _ in range(max_new_tokens):
            # Input for the first layer's ARSCell is the embedding of the last generated token
            last_token_ids = torch.tensor([ids[-1] for ids in out_ids], dtype=torch.long, device=device).unsqueeze(1)
            x_embed = self.emb(last_token_ids).squeeze(1) # (b, emb_dim)

            h_t_input = x_embed
            current_layer_hidden_states = []

            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1]

                h_t, _ = cell(h_prev2, h_prev1, h_t_input)
                current_layer_hidden_states.append(h_t)

                # Update previous hidden states for the next generation step for this layer
                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1:
                    # Add the new last layer hidden state to the history for attention
                    all_last_layer_hidden_states.append(h_t.unsqueeze(1))


            # After processing all layers for the new token, apply attention and head
            last_layer_h_t = current_layer_hidden_states[-1] # Hidden state of the last layer for the new token

            # Causal attention pooling over the full last-layer history (no separate query vector is used)
            last_layer_history = torch.cat(all_last_layer_hidden_states, dim=1) # (b, history_len, hidden_dim)

            scores = self.attention(last_layer_history) # (b, history_len, 1)
            attention_weights = F.softmax(scores, dim=1) # (b, history_len, 1)

            context_vector = torch.sum(attention_weights * last_layer_history, dim=1) # (b, hidden_dim)

            attended_h_t = last_layer_h_t + context_vector

            logits = self.head(attended_h_t) / max(1e-6, temperature)

            # Apply top-k sampling
            if top_k is not None:
                # get top k logits
                top_k = min(max(top_k, 1), logits.size(-1))  # Clamp k to be at least 1 and at most vocab size
                # Remove all tokens with a probability less than the top-k'th token's probability
                v, _ = torch.topk(logits, top_k)
                logits[logits < v[:, [-1]]] = -float('Inf')

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)


            # append the newly generated token
            for i in range(bsz):
                out_ids[i].append(int(next_token[i].item()))

            # The newly generated token is fed through the cells at the top of the next
            # iteration (it is read back from out_ids), so the hidden states must not be
            # advanced again here. Doing so would process every generated token twice.


        return out_ids

# ----------------------------
# Training / demo utilities
# ----------------------------
# Modified to accept and use the tokenizer instance
def collate_batch(tokenizer, texts: List[str], device):
    # encode, add bos/eos
    encoded = []
    for t in texts:
        # AdvancedTokenizer.encode() does not add special tokens, so BOS ([CLS]) and
        # EOS ([SEP]) are added manually here, as the original SimpleTokenizer pipeline did.
        ids = [tokenizer.bos_token_id] + tokenizer.encode(t) + [tokenizer.eos_token_id]
        encoded.append(torch.tensor(ids, dtype=torch.long, device=device))
    # pad to max len
    max_len = max([x.size(0) for x in encoded])
    batch = torch.full((len(encoded), max_len), tokenizer.pad_token_id, dtype=torch.long, device=device) # Use pad_token_id
    for i, x in enumerate(encoded):
        batch[i, :x.size(0)] = x
    return batch

# Modified to use AdvancedTokenizer and handle multiple layers
def train_demo():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # toy corpus
    # texts = [
    #     "hello world this is ars",
    #     "the system adapts to its history",
    #     "benpolyseq demonstrates adaptive sequences",
    #     "ars can inspire new network protocols",
    #     "self optimizing systems are possible"
    # ]
    # --- Load text from a file (example) ---
    corpus_file = "/tmp/large_corpus.txt" # Replace with your corpus file path
    texts = []
    if os.path.exists(corpus_file):
        print(f"Loading corpus from {corpus_file}...")
        with open(corpus_file, 'r', encoding='utf-8') as f:
            # Read a limited number of lines for demo, for full training, iterate or use data loaders
            for i, line in enumerate(f):
                if i < 1000: # Load first 1000 lines as an example
                    texts.append(line.strip())
                else:
                    break
        print(f"Loaded {len(texts)} lines.")
    else:
        print(f"Corpus file not found at {corpus_file}. Using a small toy corpus instead.")
        texts = [
            "hello world this is ars",
            "the system adapts to its history",
            "benpolyseq demonstrates adaptive sequences",
            "ars can inspire new network protocols",
            "self optimizing systems are possible"
        ]


    # Instantiate AdvancedTokenizer
    tokenizer = AdvancedTokenizer()
    # Vocab size is now determined by the tokenizer
    vocab_size = len(tokenizer)
    # Instantiate ARSLM with multiple layers
    model = ARSLM(tokenizer, emb_dim=64, hidden_dim=128, num_layers=2).to(device) # Added num_layers
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    # Add a StepLR scheduler
    scheduler = StepLR(optimizer, step_size=50, gamma=0.5) # Reduce LR by half every 50 epochs
    # Use ignore_index from the tokenizer
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    # create training batches by shifting: predict next token at each position
    n_epochs = 200
    batch = collate_batch(tokenizer, texts, device) # Pass tokenizer to collate_batch
    # targets are the inputs shifted left by one position (next-token prediction)
    inputs = batch[:, :-1]
    targets = batch[:, 1:]

    # Simple check if batch is empty due to no text loaded
    if inputs.numel() == 0:
        print("No training data loaded. Skipping training.")
        return model, tokenizer


    for epoch in range(n_epochs):
        model.train()
        logits, gates = model(inputs)
        # flatten
        b, seq, v = logits.shape
        loss = loss_fn(logits.view(b*seq, v), targets.reshape(b*seq))
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        # Step the scheduler
        scheduler.step()
        if (epoch+1) % 50 == 0 or epoch==0:
            print(f"Epoch {epoch+1}/{n_epochs} — loss: {loss.item():.4f}")
            print(f"Current learning rate: {scheduler.get_last_lr()[0]:.6f}") # Print current LR

    print("\n=== Evaluation ===")
    model.eval()
    with torch.no_grad():
        logits, _ = model(inputs) # Use the same inputs for simplicity in this demo
        # Calculate perplexity
        # Perplexity = exp(average negative log-likelihood)
        loss = loss_fn(logits.view(-1, v), targets.reshape(-1))
        perplexity = torch.exp(loss)
        print(f"Perplexity on training data: {perplexity.item():.4f}")


    # demo generation
    context = "hello world"
    # Use the tokenizer's encode method for the context
    idx = torch.tensor([ [tokenizer.bos_token_id] + tokenizer.encode(context) ], dtype=torch.long, device=device)
    # Use top_k sampling during generation
    out_ids = model.generate(idx, max_new_tokens=15, temperature=1.0, top_k=50)[0] # Added top_k=50
    print("\n=== Generated ===")
    # Use the tokenizer's decode method for the output
    print(tokenizer.decode(out_ids))
    return model, tokenizer

# ----------------------------
# Run demo if main
# ----------------------------
if __name__ == "__main__":
    random.seed(0)
    torch.manual_seed(0)
    model, tokenizer = train_demo()

Corpus file not found at /tmp/large_corpus.txt. Using a small toy corpus instead.


Epoch 1/200 — loss: 10.9862
Current learning rate: 0.001000
Epoch 50/200 — loss: 0.8037
Current learning rate: 0.000500
Epoch 100/200 — loss: 0.3318
Current learning rate: 0.000250
Epoch 150/200 — loss: 0.2478
Current learning rate: 0.000125
Epoch 200/200 — loss: 0.2542
Current learning rate: 0.000063

=== Evaluation ===
Perplexity on training data: 1.2631

=== Generated ===
hello world is
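
The numbers in the log above can be cross-checked by hand. The sketch below assumes the usual closed form of a step-decay schedule, `lr = base_lr * gamma ** (steps // step_size)`, with one scheduler step per epoch as in `train_demo`, and uses the fact that perplexity is the exponential of the mean cross-entropy. The helper names `step_lr` and `loss_from_perplexity` are illustrative and not part of the notebook.

```python
import math

def step_lr(base_lr: float, gamma: float, step_size: int, steps_taken: int) -> float:
    """Learning rate after `steps_taken` scheduler steps under a step-decay schedule."""
    return base_lr * gamma ** (steps_taken // step_size)

def loss_from_perplexity(perplexity: float) -> float:
    """Invert perplexity = exp(mean cross-entropy)."""
    return math.log(perplexity)

# Hyperparameters taken from train_demo: lr=1e-3, step_size=50, gamma=0.5
for epoch in (1, 50, 100, 150, 200):
    print(f"epoch {epoch:3d}: lr = {step_lr(1e-3, 0.5, 50, epoch):.7f}")

# Mean cross-entropy implied by the reported training perplexity of 1.2631
eval_loss = loss_from_perplexity(1.2631)
print(f"mean cross-entropy = {eval_loss:.4f}")
```

The implied evaluation loss (about 0.234) is slightly below the last logged training loss (0.2542). A likely reason is that dropout is active during `model.train()` but disabled under `model.eval()`.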


In [None]:
# Modified to accept and use the tokenizer instance
def collate_batch(tokenizer, texts: List[str], device):
    # encode, add bos/eos
    encoded = []
    for t in texts:
        # AdvancedTokenizer.encode() does not add special tokens, so BOS ([CLS]) and
        # EOS ([SEP]) are added manually here, as the original SimpleTokenizer pipeline did.
        ids = [tokenizer.bos_token_id] + tokenizer.encode(t) + [tokenizer.eos_token_id]
        encoded.append(torch.tensor(ids, dtype=torch.long, device=device))
    # pad to max len
    max_len = max([x.size(0) for x in encoded])
    batch = torch.full((len(encoded), max_len), tokenizer.pad_token_id, dtype=torch.long, device=device) # Use pad_token_id
    for i, x in enumerate(encoded):
        batch[i, :x.size(0)] = x
    return batch
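
To make the padding and the input/target shift used with `collate_batch` concrete, here is a dependency-free sketch on plain Python lists. The ids are made up for illustration (101 for BOS, 102 for EOS, 0 for padding), and the helpers `pad_batch` and `shift_for_lm` are hypothetical names, not functions from the notebook.

```python
from typing import List, Tuple

PAD = 0  # illustrative pad id; the notebook uses tokenizer.pad_token_id

def pad_batch(seqs: List[List[int]], pad_id: int = PAD) -> List[List[int]]:
    """Right-pad every sequence to the length of the longest one."""
    max_len = max(len(s) for s in seqs)
    return [s + [pad_id] * (max_len - len(s)) for s in seqs]

def shift_for_lm(batch: List[List[int]]) -> Tuple[List[List[int]], List[List[int]]]:
    """Inputs drop the last column and targets drop the first,
    so targets[i][t] is the token that follows inputs[i][t]."""
    inputs = [row[:-1] for row in batch]
    targets = [row[1:] for row in batch]
    return inputs, targets

# Two toy sequences already wrapped in BOS (101) and EOS (102)
batch = pad_batch([[101, 7, 8, 9, 102], [101, 5, 102]])
inputs, targets = shift_for_lm(batch)
print(inputs)   # [[101, 7, 8, 9], [101, 5, 102, 0]]
print(targets)  # [[7, 8, 9, 102], [5, 102, 0, 0]]
```

Target positions that hold the pad id contribute nothing to the loss in `train_demo`, because `nn.CrossEntropyLoss` is created with `ignore_index=tokenizer.pad_token_id`.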

**Reasoning**:
Now that the `transformers` library is installed, I will import a suitable tokenizer, such as `BertTokenizer`, and create a new class that wraps this tokenizer to maintain compatibility with the existing code's `encode` and `decode` methods. I will also define the special tokens needed for the model.

In [None]:
from transformers import BertTokenizer
from typing import List

class AdvancedTokenizer:
    def __init__(self, vocab_file=None):
        # Using a pre-trained tokenizer's vocabulary for demonstration
        # In a real scenario, you might train a new tokenizer on your specific corpus
        if vocab_file:
            self.tokenizer = BertTokenizer(vocab_file)
        else:
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        self.vocab = list(self.tokenizer.vocab.keys())
        self.stoi = self.tokenizer.vocab
        self.itos = {i: w for w, i in self.stoi.items()}

        # Define special tokens explicitly for clarity and compatibility
        # These might vary depending on the chosen tokenizer
        self.pad_token = self.tokenizer.pad_token
        self.unk_token = self.tokenizer.unk_token
        self.bos_token = self.tokenizer.cls_token # Using [CLS] as BOS for BERT-like tokenizers
        self.eos_token = self.tokenizer.sep_token # Using [SEP] as EOS for BERT-like tokenizers

        self.pad_token_id = self.tokenizer.pad_token_id
        self.unk_token_id = self.tokenizer.unk_token_id
        self.bos_token_id = self.tokenizer.cls_token_id
        self.eos_token_id = self.tokenizer.sep_token_id


    def encode(self, text: str) -> List[int]:
        # add_special_tokens=False returns only the sub-word ids;
        # callers add BOS/EOS ([CLS]/[SEP]) themselves where the model needs them
        return self.tokenizer.encode(text, add_special_tokens=False)

    def decode(self, ids: List[int]) -> str:
        # skip_special_tokens=True prevents decoding [CLS], [SEP], [PAD] etc.
        # We might adjust this based on how the model handles special tokens
        return self.tokenizer.decode(ids, skip_special_tokens=True)

    def __len__(self):
        return len(self.vocab)

# Example usage:
# tokenizer = AdvancedTokenizer()
# text = "Hello, world! This is a test sentence."
# encoded_text = tokenizer.encode(text)
# print("Encoded:", encoded_text)
# decoded_text = tokenizer.decode(encoded_text)
# print("Decoded:", decoded_text)
# print("Vocab size:", len(tokenizer))
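
For intuition about the sub-word splitting that a BERT-style tokenizer performs, the following is a toy sketch of greedy longest-match-first WordPiece splitting for a single word. It is not the `transformers` implementation: the vocabulary `toy_vocab` is made up, the function name `wordpiece` is hypothetical, and the lowercasing and punctuation handling that the real tokenizer applies first are left out.

```python
from typing import List, Optional, Set

def wordpiece(word: str, vocab: Set[str], unk: str = "[UNK]") -> List[str]:
    """Greedy longest-match-first split of one word.
    Pieces that continue a word carry a '##' prefix."""
    pieces: List[str] = []
    start = 0
    while start < len(word):
        end = len(word)
        match: Optional[str] = None
        # Try the longest remaining substring first, then shrink from the right
        while start < end:
            piece = word[start:end]
            if start > 0:
                piece = "##" + piece
            if piece in vocab:
                match = piece
                break
            end -= 1
        if match is None:
            return [unk]  # no piece fits, so the whole word maps to the unknown token
        pieces.append(match)
        start = end
    return pieces

toy_vocab = {"adapt", "##ive", "##s", "seq", "##uence"}
print(wordpiece("adaptive", toy_vocab))   # ['adapt', '##ive']
print(wordpiece("sequences", toy_vocab))  # ['seq', '##uence', '##s']
print(wordpiece("xyz", toy_vocab))        # ['[UNK]']
```

With a sub-word vocabulary, rare words are split into known pieces rather than collapsing to an unknown token, which is the main practical difference from the whitespace-based `SimpleTokenizer`.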

**Reasoning**:
The first step is to install a more robust tokenizer library. Hugging Face's `transformers` library is a good choice because it provides access to many pre-trained tokenizers, including sub-word tokenizers. The standalone `tokenizers` library would also work, but `transformers` is more widely used and already provides the tokenizers needed here.

In [None]:
!pip install transformers



# New section

In [None]:
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
from transformers import BertTokenizer # Import BertTokenizer here
from torch.optim.lr_scheduler import StepLR # Import StepLR
import os # Import os for file operations

# ----------------------------
# Simple tokenizer / vocab (to be replaced)
# ----------------------------
# class SimpleTokenizer:
#     def __init__(self, texts: List[str], min_freq=1):
#         tokens = []
#         for t in texts:
#             tokens += t.strip().split()
#         freq = {}
#         for w in tokens:
#             freq[w] = freq.get(w, 0) + 1
#         self.vocab = ['<pad>', '<unk>', '<bos>', '<eos>']
#         for w, c in sorted(freq.items(), key=lambda x: (-x[1], x[0])):
#             if c >= min_freq and w not in self.vocab:
#                 self.vocab.append(w)
#         self.stoi = {w:i for i,w in enumerate(self.vocab)}
#         self.itos = {i:w for w,i in self.stoi.items()}

#     def encode(self, text: str) -> List[int]:
#         toks = text.strip().split()
#         ids = [self.stoi.get(t, self.stoi['<unk>']) for t in toks]
#         return ids

#     def decode(self, ids: List[int]) -> str:
#         tokens = [self.itos.get(i, '<unk>') for i in ids]
#         return ' '.join(tokens)

# ----------------------------
# Advanced Tokenizer (using transformers)
# ----------------------------
class AdvancedTokenizer:
    def __init__(self, vocab_file=None):
        # Using a pre-trained tokenizer's vocabulary for demonstration
        # In a real scenario, you might train a new tokenizer on your specific corpus
        if vocab_file:
            self.tokenizer = BertTokenizer(vocab_file)
        else:
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        self.vocab = list(self.tokenizer.vocab.keys())
        self.stoi = self.tokenizer.vocab
        self.itos = {i: w for w, i in self.stoi.items()}

        # Define special tokens explicitly for clarity and compatibility
        # These might vary depending on the chosen tokenizer
        self.pad_token = self.tokenizer.pad_token
        self.unk_token = self.tokenizer.unk_token
        self.bos_token = self.tokenizer.cls_token # Using [CLS] as BOS for BERT-like tokenizers
        self.eos_token = self.tokenizer.sep_token # Using [SEP] as EOS for BERT-like tokenizers

        self.pad_token_id = self.tokenizer.pad_token_id
        self.unk_token_id = self.tokenizer.unk_token_id
        self.bos_token_id = self.tokenizer.cls_token_id
        self.eos_token_id = self.tokenizer.sep_token_id


    def encode(self, text: str) -> List[int]:
        # add_special_tokens=False returns only the sub-word ids;
        # callers add BOS/EOS ([CLS]/[SEP]) themselves where the model needs them
        return self.tokenizer.encode(text, add_special_tokens=False)

    def decode(self, ids: List[int]) -> str:
        # skip_special_tokens=True prevents decoding [CLS], [SEP], [PAD] etc.
        # We might adjust this based on how the model handles special tokens
        return self.tokenizer.decode(ids, skip_special_tokens=True)

    def __len__(self):
        return len(self.vocab)

# ----------------------------
# ARSCell: core adaptive cell
# ----------------------------
class ARSCell(nn.Module):
    """
    ARSCell: computes the next hidden state from h_{t-2}, h_{t-1}, and the current input x_t
    Implements an 'adapt' mechanism inspired by BenPolySeq:
      h_t = LayerNorm(Dropout(h_{t-1} + gate * transform(h_{t-1}, h_{t-2}, x_t) + 0.1 * res_proj(x_t)))
    where gate is a scalar in (0, 1) computed from the concatenated context [h_{t-1}, h_{t-2}, x_t].
    """
    def __init__(self, emb_dim: int, hidden_dim: int, dropout_prob: float = 0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        # transform candidate from concat([h_{t-1}, h_{t-2}, x])
        self.candidate_mlp = nn.Sequential(
            nn.Linear(hidden_dim*2 + emb_dim, hidden_dim*2),
            nn.ReLU(),
            nn.Linear(hidden_dim*2, hidden_dim)
        )
        # gate network produces scalar gating factor in (0,1)
        self.gate_net = nn.Sequential(
            nn.Linear(hidden_dim*2 + emb_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        # small residual projection
        self.res_proj = nn.Linear(emb_dim, hidden_dim)
        # Dropout layer
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, h_prev2, h_prev1, x_embed):
        # h_prev2, h_prev1: (batch, hidden_dim)
        # x_embed: (batch, emb_dim)
        # Note: the gate does not receive an explicit difference signal (h_prev1 - h_prev2);
        # it sees both states through ctx and can learn to compare them.
        # context vector (concat)
        ctx = torch.cat([h_prev1, h_prev2, x_embed], dim=-1)
        candidate = self.candidate_mlp(ctx)                       # (batch, hidden_dim)
        gate = self.gate_net(ctx).squeeze(-1)                     # (batch,)
        # incorporate residual from input embedding
        residual = self.res_proj(x_embed)
        # ARS-style update: adaptive addition of candidate scaled by gate
        h_t = h_prev1 + gate.unsqueeze(-1) * candidate + 0.1 * residual
        # Apply dropout
        h_t = self.dropout(h_t)
        # optional normalization
        h_t = F.layer_norm(h_t, (self.hidden_dim,))
        return h_t, gate

# ----------------------------
# ARSLM model: embed -> ARSCell(s) -> attention -> head
# ----------------------------
class ARSLM(nn.Module):
    # Modified to accept a tokenizer instance, include attention, and support multiple layers
    def __init__(self, tokenizer, emb_dim=64, hidden_dim=128, num_layers=2):
        super().__init__()
        self.tokenizer = tokenizer # Store tokenizer instance
        self.vocab_size = len(tokenizer) # Get vocab size from tokenizer
        self.num_layers = num_layers
        self.emb = nn.Embedding(self.vocab_size, emb_dim)
        # Create a list of ARSCells
        # Pass dropout_prob to ARSCell constructor
        self.cells = nn.ModuleList([ARSCell(emb_dim if i == 0 else hidden_dim, hidden_dim, dropout_prob=0.1) for i in range(num_layers)])

        # Simple additive attention mechanism, applied after the last layer's hidden states
        self.attention = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1)
        )
        self.head = nn.Linear(hidden_dim, self.vocab_size)

    def forward(self, input_ids):
        """
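        input_ids: (batch, seq_len)
        Processes the sequence one time step at a time, passing hidden states up through the layers.
        Returns (logits, gates): logits is (batch, seq_len, vocab); gates is (batch, seq_len), taken from the last layer.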
        """
        bsz, seq_len = input_ids.shape
        emb = self.emb(input_ids)  # (b, seq, emb)
        device = emb.device
        hidden_dim = self.cells[0].hidden_dim

        # Initialize hidden states for each layer
        # We need two previous states for each layer's ARSCell
        h_prev2_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]
        h_prev1_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]

        # Store hidden states for causal attention (from the last layer)
        all_last_layer_hidden_states = []
        logits = []
        gates = [] # Store gates from the last layer

        for t in range(seq_len):
            x_t = emb[:, t, :] # embedding at position t
            h_t_input = x_t # Input to the first layer is the embedding

            current_layer_hidden_states = []

            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                # The input to subsequent layers is the hidden state of the previous layer
                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1] # Use the output of the previous layer

                h_t, gate = cell(h_prev2, h_prev1, h_t_input)

                # Keep this layer's hidden state; it is the input to the next layer at the same time step
                current_layer_hidden_states.append(h_t)

                # Update previous hidden states for the next time step for this layer
                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1: # Only store gates from the last layer
                    gates.append(gate.unsqueeze(1))

            # After processing all layers for time step t, the output is the hidden state of the last layer
            last_layer_h_t = current_layer_hidden_states[-1]
            all_last_layer_hidden_states.append(last_layer_h_t.unsqueeze(1))

            # Apply attention and head after the last layer's hidden state.
            # Causal attention pooling: attend over the last layer's hidden states up to time step t.
            # The scores depend only on each stored state; no separate query vector is used.
            last_layer_history = torch.cat(all_last_layer_hidden_states, dim=1) # (b, t+1, hidden_dim)

            # Calculate attention scores
            scores = self.attention(last_layer_history) # (b, t+1, 1)
            attention_weights = F.softmax(scores, dim=1) # (b, t+1, 1)

            # Apply attention weights
            context_vector = torch.sum(attention_weights * last_layer_history, dim=1) # (b, hidden_dim)

            # Combine last layer's current hidden state with context vector before the head
            attended_h_t = last_layer_h_t + context_vector

            logit = self.head(attended_h_t) # logits at position t predict the token at position t+1
            logits.append(logit.unsqueeze(1))


        logits = torch.cat(logits, dim=1)          # (b, seq, vocab)
        gates = torch.cat(gates, dim=1)            # (b, seq) # Gates from the last layer
        return logits, gates

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=20, temperature=1.0, top_k=None):
        """
        idx: (batch, current_seq_len) - input context (token ids)
        Returns the extended token ids as a list of lists, one per batch row.
        Supports multiple layers, attention during generation, and optional top-k sampling.
        """
        self.eval()
        bsz, seq_len = idx.shape
        device = next(self.parameters()).device
        hidden_dim = self.cells[0].hidden_dim

        # Initialize hidden states for each layer
        h_prev2_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]
        h_prev1_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]

        # Store hidden states for causal attention (from the last layer)
        all_last_layer_hidden_states = []

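        # Process the context to warm up the hidden states for each layer.
        # The last context token is skipped here because the generation loop below
        # feeds it through the cells as its first step; processing it twice would
        # put a duplicate state into the attention history.
        input_ids = idx.clone()
        emb = self.emb(input_ids) # (b, seq_len, emb_dim)

        for t in range(seq_len - 1):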
            x_t = emb[:,t,:] # embedding at position t
            h_t_input = x_t

            current_layer_hidden_states = []
            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1]

                h_t, _ = cell(h_prev2, h_prev1, h_t_input)
                current_layer_hidden_states.append(h_t)

                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1:
                    # Add the new last layer hidden state to the history for attention
                    all_last_layer_hidden_states.append(h_t.unsqueeze(1))


        out_ids = input_ids.tolist()

        for _ in range(max_new_tokens):
            # Input for the first layer's ARSCell is the embedding of the last generated token
            last_token_ids = torch.tensor([ids[-1] for ids in out_ids], dtype=torch.long, device=device).unsqueeze(1)
            x_embed = self.emb(last_token_ids).squeeze(1) # (b, emb_dim)

            h_t_input = x_embed
            current_layer_hidden_states = []

            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1]

                h_t, _ = cell(h_prev2, h_prev1, h_t_input)
                current_layer_hidden_states.append(h_t)

                # Update previous hidden states for the next generation step for this layer
                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1:
                    # Add the new last layer hidden state to the history for attention
                    all_last_layer_hidden_states.append(h_t.unsqueeze(1))


            # After processing all layers for the new token, apply attention and head
            last_layer_h_t = current_layer_hidden_states[-1] # Hidden state of the last layer for the new token

            # Causal Attention over all last layer hidden states history
            last_layer_history = torch.cat(all_last_layer_hidden_states, dim=1) # (b, current_seq_len + 1, hidden_dim)
            query = last_layer_h_t.unsqueeze(1) # (b, 1, hidden_dim); kept for reference, not used by self.attention below

            scores = self.attention(last_layer_history) # (b, current_seq_len + 1, 1)
            attention_weights = F.softmax(scores, dim=1) # (b, current_seq_len + 1, 1)

            context_vector = torch.sum(attention_weights * last_layer_history, dim=1) # (b, hidden_dim)

            attended_h_t = last_layer_h_t + context_vector

            logits = self.head(attended_h_t) / max(1e-6, temperature)

            # Apply top-k sampling
            if top_k is not None:
                # get top k logits
                top_k = min(max(top_k, 1), logits.size(-1))  # Clamp k to be at least 1 and at most vocab size
                # Remove all tokens with a probability less than the top-k'th token's probability
                v, _ = torch.topk(logits, top_k)
                logits[logits < v[:, [-1]]] = -float('Inf')

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)


            # append the newly generated token
            for i in range(bsz):
                out_ids[i].append(int(next_token[i].item()))


        return out_ids

# ----------------------------
# Training / demo utilities
# ----------------------------
# Modified to accept and use the tokenizer instance
def collate_batch(tokenizer, texts: List[str], device):
    # Encode each text and wrap it in BOS/EOS ids
    encoded = []
    for t in texts:
        # tokenizer.encode is expected to return ids without special tokens, so
        # BOS ([CLS]) and EOS ([SEP]) are added by hand here, as the original
        # SimpleTokenizer did.
        ids = [tokenizer.bos_token_id] + tokenizer.encode(t) + [tokenizer.eos_token_id]
        encoded.append(torch.tensor(ids, dtype=torch.long, device=device))
    # pad to max len
    max_len = max([x.size(0) for x in encoded])
    batch = torch.full((len(encoded), max_len), tokenizer.pad_token_id, dtype=torch.long, device=device) # Use pad_token_id
    for i, x in enumerate(encoded):
        batch[i, :x.size(0)] = x
    return batch

# Modified to use AdvancedTokenizer and handle multiple layers
def train_demo():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # toy corpus
    # texts = [
    #     "hello world this is ars",
    #     "the system adapts to its history",
    #     "benpolyseq demonstrates adaptive sequences",
    #     "ars can inspire new network protocols",
    #     "self optimizing systems are possible"
    # ]
    # --- Load text from a file (example) ---
    corpus_file = "/tmp/large_corpus.txt" # Replace with your corpus file path
    texts = []
    if os.path.exists(corpus_file):
        print(f"Loading corpus from {corpus_file}...")
        with open(corpus_file, 'r', encoding='utf-8') as f:
            # Read a limited number of lines for demo, for full training, iterate or use data loaders
            for i, line in enumerate(f):
                if i < 1000: # Load first 1000 lines as an example
                    texts.append(line.strip())
                else:
                    break
        print(f"Loaded {len(texts)} lines.")
    else:
        print(f"Corpus file not found at {corpus_file}. Using a small toy corpus instead.")
        texts = [
            "hello world this is ars",
            "the system adapts to its history",
            "benpolyseq demonstrates adaptive sequences",
            "ars can inspire new network protocols",
            "self optimizing systems are possible"
        ]


    # Instantiate AdvancedTokenizer
    tokenizer = AdvancedTokenizer()
    # Vocab size is now determined by the tokenizer
    vocab_size = len(tokenizer)
    # Instantiate ARSLM with multiple layers
    model = ARSLM(tokenizer, emb_dim=64, hidden_dim=128, num_layers=2).to(device) # Added num_layers
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    # Add a StepLR scheduler
    scheduler = StepLR(optimizer, step_size=50, gamma=0.5) # Reduce LR by half every 50 epochs
    # Use ignore_index from the tokenizer
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    # Check for an empty corpus before batching (collate_batch cannot pad an empty list)
    if not texts:
        print("No training data loaded. Skipping training.")
        return model, tokenizer

    # Build one full batch; the targets are the inputs shifted left by one position,
    # so the model predicts the next token at each position
    n_epochs = 200
    batch = collate_batch(tokenizer, texts, device) # Pass tokenizer to collate_batch
    inputs = batch[:, :-1]
    targets = batch[:, 1:]


    for epoch in range(n_epochs):
        model.train()
        logits, gates = model(inputs)
        # flatten
        b, seq, v = logits.shape
        loss = loss_fn(logits.view(b*seq, v), targets.reshape(b*seq))
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        # Step the scheduler
        scheduler.step()
        if (epoch+1) % 50 == 0 or epoch==0:
            print(f"Epoch {epoch+1}/{n_epochs} — loss: {loss.item():.4f}")
            print(f"Current learning rate: {scheduler.get_last_lr()[0]:.6f}") # Print current LR

    print("\n=== Evaluation ===")
    model.eval()
    with torch.no_grad():
        logits, _ = model(inputs) # Use the same inputs for simplicity in this demo
        # Calculate perplexity
        # Perplexity = exp(average negative log-likelihood)
        loss = loss_fn(logits.view(-1, logits.size(-1)), targets.reshape(-1))
        perplexity = torch.exp(loss)
        print(f"Perplexity on training data: {perplexity.item():.4f}")


    # demo generation
    context = "hello world"
    # Use the tokenizer's encode method for the context
    idx = torch.tensor([ [tokenizer.bos_token_id] + tokenizer.encode(context) ], dtype=torch.long, device=device)
    # Use top_k sampling during generation
    out_ids = model.generate(idx, max_new_tokens=15, temperature=1.0, top_k=50)[0] # Added top_k=50
    print("\n=== Generated ===")
    # Use the tokenizer's decode method for the output
    print(tokenizer.decode(out_ids))
    return model, tokenizer

# ----------------------------
# Run demo if main
# ----------------------------
if __name__ == "__main__":
    random.seed(0)
    torch.manual_seed(0)
    model, tokenizer = train_demo()

Corpus file not found at /tmp/large_corpus.txt. Using a small toy corpus instead.
Epoch 1/200 — loss: 10.9862
Current learning rate: 0.001000
Epoch 50/200 — loss: 0.8037
Current learning rate: 0.000500
Epoch 100/200 — loss: 0.3318
Current learning rate: 0.000250
Epoch 150/200 — loss: 0.2478
Current learning rate: 0.000125
Epoch 200/200 — loss: 0.2542
Current learning rate: 0.000063

=== Evaluation ===
Perplexity on training data: 1.2631

=== Generated ===
hello world is ars
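
The learning rates printed above follow a step-decay rule. As a rough check, the sketch below recomputes them in plain Python, assuming the rate is `base_lr * gamma ** (n_steps // step_size)` after `n_steps` scheduler steps. The helper name `step_lr` is hypothetical and is not part of PyTorch.

```python
def step_lr(base_lr, gamma, step_size, n_steps):
    """Learning rate after n_steps scheduler steps under a step-decay schedule."""
    return base_lr * gamma ** (n_steps // step_size)


# Same settings as the demo: lr=1e-3, gamma=0.5, step_size=50.
for epoch in (1, 50, 100, 150, 200):
    print(epoch, step_lr(1e-3, 0.5, 50, epoch))
```

The log prints the rate after `scheduler.step()` has run, which is why epoch 50 already shows the halved value.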


In [None]:
class ARSCell(nn.Module):
    """
    ARSCell: computes next hidden state h_t from h_{t-2}, h_{t-1}, and the input embedding x_t
    Implements an 'adapt' mechanism inspired by BenPolySeq:
      h_t = layer_norm(dropout(h_{t-1} + gate * transform(h_{t-1}, h_{t-2}, x) + 0.1 * res_proj(x)))
    where gate is a scalar in (0,1) computed from the concatenated context [h_{t-1}, h_{t-2}, x].
    """
    def __init__(self, emb_dim: int, hidden_dim: int, dropout_prob: float = 0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        # transform candidate from concat([h_{t-1}, h_{t-2}, x])
        self.candidate_mlp = nn.Sequential(
            nn.Linear(hidden_dim*2 + emb_dim, hidden_dim*2),
            nn.ReLU(),
            nn.Linear(hidden_dim*2, hidden_dim)
        )
        # gate network produces scalar gating factor in (0,1)
        self.gate_net = nn.Sequential(
            nn.Linear(hidden_dim*2 + emb_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        # small residual projection
        self.res_proj = nn.Linear(emb_dim, hidden_dim)
        # Dropout layer
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, h_prev2, h_prev1, x_embed):
        # h_prev2, h_prev1: (batch, hidden_dim)
        # x_embed: (batch, emb_dim)
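        # difference signal (computed for reference only: it is not used below,
        # the candidate and gate networks see both states through ctx instead)
        diff = h_prev1 - h_prev2    # (batch, hidden_dim)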
        # context vector (concat)
        ctx = torch.cat([h_prev1, h_prev2, x_embed], dim=-1)
        candidate = self.candidate_mlp(ctx)                       # (batch, hidden_dim)
        gate = self.gate_net(ctx).squeeze(-1)                     # (batch,)
        # incorporate residual from input embedding
        residual = self.res_proj(x_embed)
        # ARS-style update: adaptive addition of candidate scaled by gate
        h_t = h_prev1 + gate.unsqueeze(-1) * candidate + 0.1 * residual
        # Apply dropout
        h_t = self.dropout(h_t)
        # optional normalization
        h_t = F.layer_norm(h_t, (self.hidden_dim,))
        return h_t, gate

# Task
Propose improvements to the ARSLM model to make it more efficient and usable in practice.

## Tokenizer improvement

### Subtask:
Replace the `SimpleTokenizer` with a more robust tokenizer that handles punctuation, upper/lower case, and subwords (such as Byte Pair Encoding (BPE) or WordPiece), so that unknown words are handled better and the vocabulary stays smaller.


**Reasoning**:
The first step is to install a more robust tokenizer library. Hugging Face's `transformers` library is a good choice as it provides access to many pre-trained tokenizers including those using sub-word techniques. The `tokenizers` library is also a good choice, but `transformers` is more commonly used and includes the tokenizers.



In [None]:
!pip install transformers



## Detailed explanation of the ARSLM model

ARSLM (Adaptive Recurrent State Language Model) is a prototype language model built on the concept of adaptive recurrent states (ARS), inspired by BenPolySeq adaptive sequences. It is designed to be lightweight, explainable, and extensible, which makes it suited to research and experimentation.

Its key components are broken down below:

### 1. Advanced tokenizer (`AdvancedTokenizer`)

This component converts text into sequences of token ids and back. In this improved version it uses the Hugging Face `transformers` library, specifically `BertTokenizer`.

*   **Why an advanced tokenizer?** A simple tokenizer might only split text on whitespace and keep a limited vocabulary. A subword tokenizer (like the one BERT uses) makes it possible to:
    *   Handle unknown words better by breaking them into subunits (for example, "tokenization" could be split into pieces such as "token" and "##ization", where WordPiece marks word-internal pieces with `##`).
    *   Reduce the overall vocabulary size while still representing a wider variety of words.
    *   Handle punctuation, upper/lower case (`bert-base-uncased` lowercases its input), and other linguistic details.
*   **How it works:** `AdvancedTokenizer` uses a pretrained vocabulary (by default that of `bert-base-uncased`). It provides `encode` to convert text into ids and `decode` to convert ids back into text. It also exposes special-token ids for padding (`[PAD]`), unknown words (`[UNK]`), beginning of sequence (BOS, mapped to BERT's `[CLS]`), and end of sequence (EOS, mapped to BERT's `[SEP]`).
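
To make the subword idea concrete, here is a minimal sketch of greedy longest-match-first splitting, in the spirit of WordPiece. It is not the `BertTokenizer` implementation: the function name `wordpiece_split` and the toy vocabulary are hypothetical, chosen only for illustration.

```python
def wordpiece_split(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first split of one lowercase word into subword pieces."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest remaining substring first, then shrink it.
        while end > start:
            piece = word[start:end]
            if start > 0:
                piece = "##" + piece  # word-internal pieces carry the ## prefix
            if piece in vocab:
                match = piece
                break
            end -= 1
        if match is None:
            return [unk]  # no piece fits: the whole word maps to the unknown token
        pieces.append(match)
        start = end
    return pieces


# Hypothetical toy vocabulary, for illustration only.
toy_vocab = {"token", "##ization", "##s", "adapt", "##ive"}

print(wordpiece_split("tokenization", toy_vocab))  # ['token', '##ization']
print(wordpiece_split("adaptive", toy_vocab))      # ['adapt', '##ive']
print(wordpiece_split("xyz", toy_vocab))           # ['[UNK]']
```

A word that is absent from the vocabulary as a whole can still be represented by known pieces, which is what keeps the vocabulary small.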

### 2. ARS cell (`ARSCell`)

This is the adaptive core of the model. An `ARSCell` computes the next hidden state (`h_t`) from the two previous hidden states (`h_{t-2}`, `h_{t-1}`) and the embedding of the current input (`x_embed`).

*   **Adaptation mechanism:** The main idea, inspired by BenPolySeq, is that the hidden state is not overwritten at each step, as in a basic RNN. Instead, a "candidate" update weighted by a learned "gate" is added to the previous state. Unlike the per-dimension gates of an LSTM, the gate here is a single scalar per example.
    *   **Difference signal:** The cell computes the difference between the two previous hidden states (`h_{t-1} - h_{t-2}`), which is meant to capture the dynamics of the recurrent state. In the code as written, `diff` is computed but not passed to any sub-network; the networks receive both states through the context vector instead.
    *   **Context:** A context vector is built by concatenating the two previous hidden states and the input embedding (`[h_{t-1}, h_{t-2}, x_embed]`).
    *   **Candidate network (`candidate_mlp`):** An MLP (multi-layer perceptron) takes the context as input and produces a "candidate" for the hidden-state update.
    *   **Gate network (`gate_net`):** Another network, ending in a sigmoid, also takes the context as input and produces a scalar between 0 and 1. This value acts as a "gate" that controls how much the candidate influences the update.
    *   **Residual:** A small residual projection of the input embedding is added, scaled by 0.1, to help propagate input information directly.
    *   **Update:** The next hidden state is computed as `h_t = h_{t-1} + gate * candidate + 0.1 * residual`. Because the gate depends on the context, the size of each update adapts to the two previous states and the current input.
    *   **Dropout and normalization:** Dropout is applied for regularization, followed by layer normalization (`F.layer_norm`) to stabilize training.
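
The update rule can be worked through by hand on small vectors. The sketch below uses plain Python lists rather than tensors and omits dropout. The input values are made up, and `ars_update` is a hypothetical helper, not the `ARSCell` code itself.

```python
import math


def layer_norm(v, eps=1e-5):
    """Normalize a vector to zero mean and unit variance (no learned scale or shift)."""
    mean = sum(v) / len(v)
    var = sum((x - mean) ** 2 for x in v) / len(v)
    return [(x - mean) / math.sqrt(var + eps) for x in v]


def ars_update(h_prev1, candidate, gate, residual):
    """h_t = layer_norm(h_{t-1} + gate * candidate + 0.1 * residual); dropout omitted."""
    raw = [h + gate * c + 0.1 * r for h, c, r in zip(h_prev1, candidate, residual)]
    return layer_norm(raw)


# Made-up values for illustration; in the model they come from the sub-networks.
h_prev1 = [0.5, -0.5, 1.0, 0.0]
candidate = [1.0, 1.0, -1.0, 2.0]
residual = [0.0, 1.0, 0.0, -1.0]

closed = ars_update(h_prev1, candidate, gate=0.0, residual=residual)
opened = ars_update(h_prev1, candidate, gate=1.0, residual=residual)
print(closed)  # gate closed: the state changes only through the small residual term
print(opened)  # gate open: the full candidate is added before normalization
```

With the gate near 0 the state is carried forward almost unchanged; with the gate near 1 the candidate dominates the update.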

### 3. ARSLM model (`ARSLM`)

This module wraps the whole architecture: the embedding, several stacked `ARSCell` layers, an attention mechanism, and an output layer (head).

*   **Embedding (`nn.Embedding`):** Converts token ids into dense vectors of dimension `emb_dim`.
*   **ARSCell layers (`nn.ModuleList`):** The model uses a list of `num_layers` stacked ARSCells. The input to the first cell is the embedding, and the input to each subsequent cell is the hidden state of the layer below at the same time step.
*   **Attention mechanism:** A simple additive attention mechanism is applied after the last ARSCell layer.
    *   **Causal attention:** During the forward pass the attention is causal: position `t` can only attend to hidden states at positions up to and including `t`.
    *   **Computation:** For each time step `t`, attention runs over the history of last-layer hidden states up to step `t` (`all_last_layer_hidden_states`). A small network (`self.attention`) scores each state in the history, and the scores are normalized with a softmax. The context vector is the history weighted by these attention weights. The code also builds a `query` from `last_layer_h_t`, but it is not used when scoring, so the weights depend only on the history states.
    *   **Combination:** The context vector is added to the current last-layer hidden state (`attended_h_t = last_layer_h_t + context_vector`), which lets the model draw on relevant information from earlier in the sequence.
*   **Output layer (`head`):** A linear layer that takes the attention-combined hidden state of the last layer and projects it onto the vocabulary size (`self.vocab_size`). Its outputs are the logits: unnormalized scores that a softmax turns into the probability of each vocabulary token being the next token.
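
The attention step amounts to a softmax-weighted average over the visible history, added back to the current state. A minimal pure-Python sketch follows. The scoring function here is a hypothetical stand-in (the sum of each state) for the learned `self.attention` network, and the state values are made up.

```python
import math


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_pool(history, score_fn):
    """Weighted sum of the states in `history`, weights = softmax of per-state scores."""
    weights = softmax([score_fn(h) for h in history])
    dim = len(history[0])
    context = [sum(w * h[d] for w, h in zip(weights, history)) for d in range(dim)]
    return context, weights


# Hypothetical stand-in for the learned scoring network: score = sum of the state.
score_fn = sum

states = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
for t in range(len(states)):
    history = states[: t + 1]  # causal: only positions 0..t are visible
    context, weights = attention_pool(history, score_fn)
    attended = [h + c for h, c in zip(states[t], context)]
    print(t, [round(w, 3) for w in weights], attended)
```

At `t = 0` the history holds a single state, so its weight is 1 and the context equals that state.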

### 4. `forward` function

This function describes how a sequence of ids (`input_ids`) is processed by the model to produce the logits and the gates. It iterates over each time step of the input sequence, computing the hidden state of each ARSCell layer in turn, then applies attention and the output layer to the last layer's state at every step.
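
The loop structure of the forward pass, with two remembered states per layer, can be sketched with scalars. `run_stack` and `toy_cell` are hypothetical names; `toy_cell` merely stands in for `ARSCell` so that the state bookkeeping is easy to follow.

```python
def run_stack(inputs, num_layers, cell):
    """Run a stack of two-state cells over a sequence and return the top-layer states."""
    h_prev2 = [0.0] * num_layers  # h_{t-2} for each layer
    h_prev1 = [0.0] * num_layers  # h_{t-1} for each layer
    top_states = []
    for x in inputs:
        layer_input = x
        for layer in range(num_layers):
            h_t = cell(h_prev2[layer], h_prev1[layer], layer_input)
            # Shift the window: old h_{t-1} becomes h_{t-2}, h_t becomes h_{t-1}.
            h_prev2[layer], h_prev1[layer] = h_prev1[layer], h_t
            layer_input = h_t  # feeds the next layer at the same time step
        top_states.append(layer_input)
    return top_states


# Hypothetical scalar cell, standing in for ARSCell.
def toy_cell(h2, h1, x):
    return 0.5 * h1 + 0.25 * h2 + x


print(run_stack([1.0, 0.0, 0.0], num_layers=2, cell=toy_cell))  # [1.0, 1.0, 1.25]
```

Each layer keeps its own pair of previous states, and the output of one layer becomes the input of the next within the same time step.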

### 5. `generate` function

This function implements text generation. Starting from an initial context sequence (`idx`), the model generates new tokens one at a time. For each new token:
*   It takes the embedding of the most recent token.
*   It uses the hidden states carried over from previous steps to compute the new hidden state of each ARSCell layer.
*   It applies the attention mechanism over the history of last-layer hidden states (including the new state).
*   It uses the output layer to obtain the logits for the next token.
*   It samples the next token (here with `torch.multinomial`, a temperature, and optionally `top_k`) from the probabilities derived from the logits.
*   The new token is appended to the output sequence, and the process repeats until the maximum number of new tokens is reached.
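
The sampling step can be illustrated without tensors. The sketch below applies temperature scaling and top-k filtering to a list of made-up logits, then draws one index. It mirrors the logic of the generation loop but uses the standard-library `random` module in place of `torch.multinomial`; `sample_top_k` is a hypothetical helper name.

```python
import math
import random


def sample_top_k(logits, temperature=1.0, top_k=None, rng=random):
    """Sample one index from logits after temperature scaling and optional top-k filtering."""
    scaled = [l / max(1e-6, temperature) for l in logits]
    if top_k is not None:
        k = min(max(top_k, 1), len(scaled))
        threshold = sorted(scaled, reverse=True)[k - 1]
        # Everything below the k-th largest logit gets probability zero.
        scaled = [s if s >= threshold else float("-inf") for s in scaled]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    return rng.choices(range(len(probs)), weights=probs, k=1)[0], probs


rng = random.Random(0)  # seeded for repeatability
logits = [2.0, 1.0, 0.5, -1.0]  # made-up values
token, probs = sample_top_k(logits, temperature=0.8, top_k=2, rng=rng)
print(token, [round(p, 3) for p in probs])
```

A lower temperature sharpens the distribution toward the largest logit, while `top_k` removes the long tail of unlikely tokens before sampling.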

### 6. Training/demo utilities (`collate_batch`, `train_demo`)

*   **`collate_batch`:** Prepares text data for training by encoding it with the tokenizer, adding the beginning/end-of-sequence tokens, and padding so that all sequences in a batch have the same length.
*   **`train_demo`:** A demonstration function that initializes the tokenizer, the model, the optimizer, and the loss function. It loads a small corpus (or a file if one is specified), prepares the data, trains the model for a set number of epochs, prints the loss and the learning rate, computes the perplexity on the training data, and runs a short text-generation demo.
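
Perplexity is the exponential of the mean negative log-likelihood over the target tokens that are not padding. A minimal sketch in plain Python, with made-up log-probabilities and a hypothetical `perplexity` helper:

```python
import math


def perplexity(token_log_probs, targets, pad_id):
    """exp of the mean negative log-likelihood over non-padding targets."""
    kept = [-lp for lp, tgt in zip(token_log_probs, targets) if tgt != pad_id]
    return math.exp(sum(kept) / len(kept))


# Made-up log-probabilities that a model assigned to each target token.
log_probs = [math.log(0.5), math.log(0.25), math.log(0.9)]
targets = [17, 42, 0]  # the last position is padding (pad_id = 0) and is ignored
print(perplexity(log_probs, targets, pad_id=0))  # about 2.83, the square root of 8
```

Skipping padded positions corresponds to the `ignore_index` argument of `nn.CrossEntropyLoss` used in the demo.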

In summary, ARSLM combines the idea of adaptive recurrent states with a multi-layer architecture and a causal attention mechanism to build a language model that can learn sequential dependencies flexibly. The "adaptive" aspect, provided by the gate mechanism in `ARSCell`, is its distinguishing feature: it potentially lets the model adjust its behavior according to the internal dynamics of its recurrent memory.

**Reasoning**:
Now that the `transformers` library is installed, I will import a suitable tokenizer, such as `BertTokenizer`, and create a new class that wraps this tokenizer to maintain compatibility with the existing code's `encode` and `decode` methods. I will also define the special tokens needed for the model.



# Task
Refactor the data loading and preparation process by creating a `TextDataset` class that integrates with `torch.utils.data.DataLoader`, modifying `collate_batch` to support batched data loading, and implementing training, validation, and test data splits within the `train_demo` function.

## Refactoring data loading and preparation

### Subtask:
Create a `TextDataset` class for more structured handling of text data and tokenization. Modify the `collate_batch` function so that it is compatible with `torch.utils.data.DataLoader`, and add a split of the data into training, validation, and test sets in `train_demo`. The goal is efficient data loading and reliable evaluation.
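
As an illustration of the split, the sketch below shuffles a list of texts with a fixed seed and cuts it into three parts. It is a plain-Python stand-in rather than the notebook's own code. The helper `split_texts` and the 80/10/10 fractions are assumptions made for this example.

```python
import random


def split_texts(texts, val_frac=0.1, test_frac=0.1, seed=0):
    """Shuffle a copy of texts and cut it into train/validation/test lists."""
    shuffled = list(texts)
    random.Random(seed).shuffle(shuffled)
    n_test = int(len(shuffled) * test_frac)
    n_val = int(len(shuffled) * val_frac)
    test = shuffled[:n_test]
    val = shuffled[n_test:n_test + n_val]
    train = shuffled[n_test + n_val:]
    return train, val, test


texts = [f"sentence {i}" for i in range(20)]
train, val, test = split_texts(texts)
print(len(train), len(val), len(test))  # 16 2 2
```

With a corpus as small as the five-sentence toy corpus, these fractions would leave the validation and test sets empty, so the split only becomes meaningful on a larger file.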


**Reasoning**:
I will define the `TextDataset` class which will encapsulate the text data and the tokenizer, and implement its `__len__` and `__getitem__` methods as specified in the instructions to prepare for data loading with `DataLoader`.



In [None]:
import torch
from torch.utils.data import Dataset # Import Dataset

# ----------------------------
# TextDataset class for structured data handling
# ----------------------------
class TextDataset(Dataset):
    def __init__(self, texts: list, tokenizer):
        self.texts = texts
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.texts[idx]

**Reasoning**:
I will modify the `collate_batch` function to accept a list of raw text strings and the tokenizer, then encode, pad, and prepare input and target tensors for language modeling, as required by the subtask.



In [None]:
import torch
from typing import List

def collate_batch(batch_texts: List[str], tokenizer):
    # Encode each text, add BOS/EOS, and convert to tensor
    encoded = []
    for t in batch_texts:
        ids = [tokenizer.bos_token_id] + tokenizer.encode(t) + [tokenizer.eos_token_id]
        encoded.append(torch.tensor(ids, dtype=torch.long))

    # Pad to max len within the batch
    max_len = max(x.size(0) for x in encoded)
    # Tensors are built on the CPU so that collate_batch stays device-agnostic;
    # the training loop is expected to move each batch to the target device.
    padded_batch = torch.full((len(encoded), max_len), tokenizer.pad_token_id, dtype=torch.long)
    for i, x in enumerate(encoded):
        padded_batch[i, :x.size(0)] = x

    # Prepare inputs and targets for language modeling
    inputs = padded_batch[:, :-1] # All tokens except the last one
    targets = padded_batch[:, 1:]  # All tokens except the first one

    return inputs, targets
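
The padding and shifting done by `collate_batch` can be checked on plain lists of ids. The sketch below is a pure-Python analogue, with a hypothetical helper `pad_and_shift` and hypothetical ids (101 for BOS, 102 for EOS, 0 for padding).

```python
def pad_and_shift(sequences, pad_id):
    """Right-pad id lists to a common length, then split into inputs and targets."""
    max_len = max(len(s) for s in sequences)
    padded = [s + [pad_id] * (max_len - len(s)) for s in sequences]
    inputs = [row[:-1] for row in padded]   # every position except the last
    targets = [row[1:] for row in padded]   # the same rows shifted left by one
    return inputs, targets


# Hypothetical ids: 101 = BOS, 102 = EOS, 0 = PAD.
batch = [[101, 7, 8, 9, 102], [101, 5, 102]]
inputs, targets = pad_and_shift(batch, pad_id=0)
print(inputs)   # [[101, 7, 8, 9], [101, 5, 102, 0]]
print(targets)  # [[7, 8, 9, 102], [5, 102, 0, 0]]
```

Because `collate_batch` takes the tokenizer as a second argument, it has to be bound before being handed to a `DataLoader` through its `collate_fn` parameter, for example with `functools.partial(collate_batch, tokenizer=tokenizer)`.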

**Reasoning**:
I will update the `train_demo` function to split the data into training, validation, and test sets, create `TextDataset` and `DataLoader` instances for each, and modify the training loop to use the DataLoader for batching. This will align with the subtask's requirements for refactoring data loading.



In [None]:
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
from transformers import BertTokenizer # Import BertTokenizer here
from torch.optim.lr_scheduler import StepLR # Import StepLR
import os # Import os for file operations
from torch.utils.data import Dataset, DataLoader # Import Dataset and DataLoader
from sklearn.model_selection import train_test_split # Import train_test_split

# ----------------------------
# Advanced Tokenizer (using transformers)
# ----------------------------
class AdvancedTokenizer:
    def __init__(self, vocab_file=None):
        # Using a pre-trained tokenizer's vocabulary for demonstration
        # In a real scenario, you might train a new tokenizer on your specific corpus
        if vocab_file:
            self.tokenizer = BertTokenizer(vocab_file)
        else:
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        self.vocab = list(self.tokenizer.vocab.keys())
        self.stoi = self.tokenizer.vocab
        self.itos = {i: w for w, i in self.stoi.items()}

        # Define special tokens explicitly for clarity and compatibility
        # These might vary depending on the chosen tokenizer
        self.pad_token = self.tokenizer.pad_token
        self.unk_token = self.tokenizer.unk_token
        self.bos_token = self.tokenizer.cls_token # Using [CLS] as BOS for BERT-like tokenizers
        self.eos_token = self.tokenizer.sep_token # Using [SEP] as EOS for BERT-like tokenizers

        self.pad_token_id = self.tokenizer.pad_token_id
        self.unk_token_id = self.tokenizer.unk_token_id
        self.bos_token_id = self.tokenizer.cls_token_id
        self.eos_token_id = self.tokenizer.sep_token_id


    def encode(self, text: str) -> List[int]:
        # The tokenizer handles BOS/EOS internally with add_special_tokens=True
        # We can also manually add them if needed for specific model architectures
        return self.tokenizer.encode(text, add_special_tokens=False)

    def decode(self, ids: List[int]) -> str:
        # skip_special_tokens=True prevents decoding [CLS], [SEP], [PAD] etc.
        # We might adjust this based on how the model handles special tokens
        return self.tokenizer.decode(ids, skip_special_tokens=True)

    def __len__(self):
        return len(self.vocab)

# ----------------------------
# ARSCell: core adaptive cell
# ----------------------------
class ARSCell(nn.Module):
    """
    ARSCell: computes next hidden state h_t from h_{t-2}, h_{t-1}, and the input embedding x_t
    Implements an 'adapt' mechanism inspired by BenPolySeq:
      h_t = layer_norm(dropout(h_{t-1} + gate * transform(h_{t-1}, h_{t-2}, x) + 0.1 * res_proj(x)))
    where gate is a scalar in (0,1) computed from the concatenated context [h_{t-1}, h_{t-2}, x].
    """
    def __init__(self, emb_dim: int, hidden_dim: int, dropout_prob: float = 0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        # transform candidate from concat([h_{t-1}, h_{t-2}, x])
        self.candidate_mlp = nn.Sequential(
            nn.Linear(hidden_dim*2 + emb_dim, hidden_dim*2),
            nn.ReLU(),
            nn.Linear(hidden_dim*2, hidden_dim)
        )
        # gate network produces scalar gating factor in (0,1)
        self.gate_net = nn.Sequential(
            nn.Linear(hidden_dim*2 + emb_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        # small residual projection
        self.res_proj = nn.Linear(emb_dim, hidden_dim)
        # Dropout layer
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, h_prev2, h_prev1, x_embed):
        # h_prev2, h_prev1: (batch, hidden_dim)
        # x_embed: (batch, emb_dim)
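        # difference signal (computed for reference only: it is not used below,
        # the candidate and gate networks see both states through ctx instead)
        diff = h_prev1 - h_prev2    # (batch, hidden_dim)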
        # context vector (concat)
        ctx = torch.cat([h_prev1, h_prev2, x_embed], dim=-1)
        candidate = self.candidate_mlp(ctx)                       # (batch, hidden_dim)
        gate = self.gate_net(ctx).squeeze(-1)                     # (batch,)
        # incorporate residual from input embedding
        residual = self.res_proj(x_embed)
        # ARS-style update: adaptive addition of candidate scaled by gate
        h_t = h_prev1 + gate.unsqueeze(-1) * candidate + 0.1 * residual
        # Apply dropout
        h_t = self.dropout(h_t)
        # optional normalization
        h_t = F.layer_norm(h_t, (self.hidden_dim,))
        return h_t, gate

# ----------------------------
# ARSLM model: embed -> ARSCell(s) -> attention -> head
# ----------------------------
class ARSLM(nn.Module):
    # Modified to accept a tokenizer instance, include attention, and support multiple layers
    def __init__(self, tokenizer, emb_dim=64, hidden_dim=128, num_layers=2):
        super().__init__()
        self.tokenizer = tokenizer # Store tokenizer instance
        self.vocab_size = len(tokenizer) # Get vocab size from tokenizer
        self.num_layers = num_layers
        self.emb = nn.Embedding(self.vocab_size, emb_dim)
        # Create a list of ARSCells
        # Pass dropout_prob to ARSCell constructor
        self.cells = nn.ModuleList([ARSCell(emb_dim if i == 0 else hidden_dim, hidden_dim, dropout_prob=0.1) for i in range(num_layers)])

        # Simple additive attention mechanism, applied after the last layer's hidden states
        self.attention = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1)
        )
        self.head = nn.Linear(hidden_dim, self.vocab_size)

    def forward(self, input_ids):
        """
        input_ids: (batch, seq_len)
        We compute autoregressively, passing hidden states through layers.
        Returns logits (batch, seq_len, vocab) and last-layer gates (batch, seq_len).
        """
        bsz, seq_len = input_ids.shape
        emb = self.emb(input_ids)  # (b, seq, emb)
        device = emb.device
        hidden_dim = self.cells[0].hidden_dim

        # Initialize hidden states for each layer
        # We need two previous states for each layer's ARSCell
        h_prev2_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]
        h_prev1_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]

        # Store hidden states for causal attention (from the last layer)
        all_last_layer_hidden_states = []
        logits = []
        gates = [] # Store gates from the last layer

        for t in range(seq_len):
            x_t = emb[:, t, :] # embedding at position t
            h_t_input = x_t # Input to the first layer is the embedding

            current_layer_hidden_states = []

            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                # The input to subsequent layers is the hidden state of the previous layer
                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1] # Use the output of the previous layer

                h_t, gate = cell(h_prev2, h_prev1, h_t_input)

                # Store current layer's hidden state as input for the next layer at this same time step
                current_layer_hidden_states.append(h_t)

                # Update previous hidden states for the next time step for this layer
                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1: # Only store gates from the last layer
                    gates.append(gate.unsqueeze(1))

            # After processing all layers for time step t, the output is the hidden state of the last layer
            last_layer_h_t = current_layer_hidden_states[-1]
            all_last_layer_hidden_states.append(last_layer_h_t.unsqueeze(1))

            # Apply attention and head after the last layer's hidden state
            # Causal Attention: attend over hidden states up to current time step t from the last layer
            last_layer_history = torch.cat(all_last_layer_hidden_states, dim=1) # (b, t+1, hidden_dim)
            query = last_layer_h_t.unsqueeze(1) # (b, 1, hidden_dim); kept for reference, not used by self.attention below

            # Calculate attention scores
            scores = self.attention(last_layer_history) # (b, t+1, 1)
            attention_weights = F.softmax(scores, dim=1) # (b, t+1, 1)

            # Apply attention weights
            context_vector = torch.sum(attention_weights * last_layer_history, dim=1) # (b, hidden_dim)

            # Combine last layer's current hidden state with context vector before the head
            attended_h_t = last_layer_h_t + context_vector

            logit = self.head(attended_h_t) # logits for the next token (position t+1), given tokens 0..t
            logits.append(logit.unsqueeze(1))


        logits = torch.cat(logits, dim=1)          # (b, seq, vocab)
        gates = torch.cat(gates, dim=1)            # (b, seq) # Gates from the last layer
        return logits, gates

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=20, temperature=1.0, top_k=None):
        """
        idx: (batch, current_seq_len) - input context (token ids)
        Returns extended token ids.
        Modified for multiple layers, attention during generation, and optional top-k sampling.
        """
        self.eval()
        bsz, seq_len = idx.shape
        device = next(self.parameters()).device
        hidden_dim = self.cells[0].hidden_dim

        # Initialize hidden states for each layer
        h_prev2_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]
        h_prev1_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]

        # Store hidden states for causal attention (from the last layer)
        all_last_layer_hidden_states = []

        # Process initial context to get starting hidden states for each layer
        input_ids = idx.clone()
        emb = self.emb(input_ids) # (b, seq_len, emb_dim)

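        # Stop one token short: the generation loop below feeds the last context token itself,
        # so processing it here as well would apply it to the hidden states twice
        for t in range(seq_len - 1):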
            x_t = emb[:,t,:] # embedding at position t
            h_t_input = x_t

            current_layer_hidden_states = []
            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1]

                h_t, _ = cell(h_prev2, h_prev1, h_t_input)
                current_layer_hidden_states.append(h_t)

                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1:
                    # Add the new last layer hidden state to the history for attention
                    all_last_layer_hidden_states.append(h_t.unsqueeze(1))


        out_ids = input_ids.tolist()

        for _ in range(max_new_tokens):
            # Input for the first layer's ARSCell is the embedding of the last generated token
            last_token_ids = torch.tensor([ids[-1] for ids in out_ids], dtype=torch.long, device=device).unsqueeze(1)
            x_embed = self.emb(last_token_ids).squeeze(1) # (b, emb_dim)

            h_t_input = x_embed
            current_layer_hidden_states = []

            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1]

                h_t, _ = cell(h_prev2, h_prev1, h_t_input)
                current_layer_hidden_states.append(h_t)

                # Update previous hidden states for the next generation step for this layer
                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1:
                    # Add the new last layer hidden state to the history for attention
                    all_last_layer_hidden_states.append(h_t.unsqueeze(1))


            # After processing all layers for the new token, apply attention and head
            last_layer_h_t = current_layer_hidden_states[-1] # Hidden state of the last layer for the new token

            # Causal attention pooling over the full last-layer history (context + generated tokens)
            last_layer_history = torch.cat(all_last_layer_hidden_states, dim=1) # (b, history_len, hidden_dim)

            scores = self.attention(last_layer_history) # (b, current_seq_len + 1, 1)
            attention_weights = F.softmax(scores, dim=1) # (b, current_seq_len + 1, 1)

            context_vector = torch.sum(attention_weights * last_layer_history, dim=1) # (b, hidden_dim)

            attended_h_t = last_layer_h_t + context_vector

            logits = self.head(attended_h_t) / max(1e-6, temperature)

            # Apply top-k sampling
            if top_k is not None:
                # get top k logits
                top_k = min(max(top_k, 1), logits.size(-1))  # Clamp k to be at least 1 and at most vocab size
                # Mask every token whose logit falls below the k-th largest logit
                v, _ = torch.topk(logits, top_k)
                logits[logits < v[:, [-1]]] = -float('Inf')

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)


            # append the newly generated token
            for i in range(bsz):
                out_ids[i].append(int(next_token[i].item()))


        return out_ids

# ----------------------------
# TextDataset class for structured data handling
# ----------------------------
class TextDataset(Dataset):
    def __init__(self, texts: list, tokenizer):
        self.texts = texts
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.texts[idx]

# Modified to accept and use the tokenizer instance
def collate_batch(batch_texts: List[str], tokenizer):
    # Encode each text, add BOS/EOS, and convert to tensor
    encoded = []
    for t in batch_texts:
        ids = [tokenizer.bos_token_id] + tokenizer.encode(t) + [tokenizer.eos_token_id]
        encoded.append(torch.tensor(ids, dtype=torch.long))

    # Pad to max len within the batch
    max_len = max([x.size(0) for x in encoded])
    padded_batch = torch.full((len(encoded), max_len), tokenizer.pad_token_id, dtype=torch.long)
    for i, x in enumerate(encoded):
        padded_batch[i, :x.size(0)] = x

    # Prepare inputs and targets for language modeling
    inputs = padded_batch[:, :-1]
    targets = padded_batch[:, 1:]

    return inputs, targets


# Modified to use AdvancedTokenizer and handle multiple layers, and DataLoaders
def train_demo():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    corpus_file = "/tmp/large_corpus.txt"
    texts = []
    if os.path.exists(corpus_file):
        print(f"Loading corpus from {corpus_file}...")
        with open(corpus_file, 'r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                if i < 1000:
                    texts.append(line.strip())
                else:
                    break
        print(f"Loaded {len(texts)} lines.")
    else:
        print(f"Corpus file not found at {corpus_file}. Using a small toy corpus instead.")
        texts = [
            "hello world this is ars",
            "the system adapts to its history",
            "benpolyseq demonstrates adaptive sequences",
            "ars can inspire new network protocols",
            "self optimizing systems are possible"
        ]

    tokenizer = AdvancedTokenizer()
    vocab_size = len(tokenizer)

    # Split data into train, validation, and test sets
    train_texts, test_texts = train_test_split(texts, test_size=0.2, random_state=42)
    val_texts, test_texts = train_test_split(test_texts, test_size=0.5, random_state=42) # 0.1 of total

    print(f"Train size: {len(train_texts)}, Validation size: {len(val_texts)}, Test size: {len(test_texts)}")

    # Create datasets
    train_dataset = TextDataset(train_texts, tokenizer)
    val_dataset = TextDataset(val_texts, tokenizer)
    test_dataset = TextDataset(test_texts, tokenizer)

    batch_size = 8 # Define a batch size

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                collate_fn=lambda b: collate_batch(b, tokenizer))
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                              collate_fn=lambda b: collate_batch(b, tokenizer))
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                               collate_fn=lambda b: collate_batch(b, tokenizer))

    model = ARSLM(tokenizer, emb_dim=64, hidden_dim=128, num_layers=2).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = StepLR(optimizer, step_size=50, gamma=0.5)
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    n_epochs = 200
    if len(train_texts) == 0:
        print("No training data available after split. Skipping training.")
        return model, tokenizer

    for epoch in range(n_epochs):
        model.train()
        total_loss = 0
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            logits, gates = model(inputs)
            b, seq, v = logits.shape
            loss = loss_fn(logits.view(b*seq, v), targets.reshape(b*seq))
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_loader)
        scheduler.step()

        if (epoch+1) % 50 == 0 or epoch==0:
            print(f"Epoch {epoch+1}/{n_epochs} — Avg Train Loss: {avg_train_loss:.4f}")
            print(f"Current learning rate: {scheduler.get_last_lr()[0]:.6f}")

            # Validation step
            model.eval()
            total_val_loss = 0
            with torch.no_grad():
                for inputs_val, targets_val in val_loader:
                    inputs_val, targets_val = inputs_val.to(device), targets_val.to(device)
                    logits_val, _ = model(inputs_val)
                    b_val, seq_val, v_val = logits_val.shape
                    loss_val = loss_fn(logits_val.view(b_val*seq_val, v_val), targets_val.reshape(b_val*seq_val))
                    total_val_loss += loss_val.item()
            avg_val_loss = total_val_loss / len(val_loader)
            print(f"Validation Loss: {avg_val_loss:.4f}")

    print("\n=== Evaluation ===")
    model.eval()
    total_test_loss = 0
    with torch.no_grad():
        for inputs_test, targets_test in test_loader:
            inputs_test, targets_test = inputs_test.to(device), targets_test.to(device)
            logits_test, _ = model(inputs_test)
            b_test, seq_test, v_test = logits_test.shape
            loss_test = loss_fn(logits_test.view(b_test*seq_test, v_test), targets_test.reshape(b_test*seq_test))
            total_test_loss += loss_test.item()

    avg_test_loss = total_test_loss / len(test_loader)
    # Perplexity is exp(mean loss); averaging per-batch exp(loss) instead would overstate it
    avg_test_perplexity = torch.exp(torch.tensor(avg_test_loss)).item()
    print(f"Test Loss: {avg_test_loss:.4f}")
    print(f"Perplexity on test data: {avg_test_perplexity:.4f}")

    # demo generation
    context = "hello world"
    idx = torch.tensor([ [tokenizer.bos_token_id] + tokenizer.encode(context) ], dtype=torch.long, device=device)
    out_ids = model.generate(idx, max_new_tokens=15, temperature=1.0, top_k=50)[0]
    print("\n=== Generated ===")
    print(tokenizer.decode(out_ids))
    return model, tokenizer

# ----------------------------
# Run demo if main
# ----------------------------
if __name__ == "__main__":
    random.seed(0)
    torch.manual_seed(0)
    model, tokenizer = train_demo()

Corpus file not found at /tmp/large_corpus.txt. Using a small toy corpus instead.


ValueError: With n_samples=1, test_size=0.5 and train_size=None, the resulting train set will be empty. Adjust any of the aforementioned parameters.

**Reasoning**:
The previous `ValueError` occurred because the toy corpus has only five lines: the first `train_test_split` call (`test_size=0.2`) holds out a single sample, and the second call (`test_size=0.5`) cannot divide that one sample into validation and test sets. I will modify the data splitting logic in `train_demo` to handle very small corpora gracefully, calling `train_test_split` only when there are enough samples and avoiding empty splits. Additionally, I will check `len(loader)` before calculating average losses to prevent a `ZeroDivisionError` if a DataLoader is empty.
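
The failure can be reproduced with arithmetic alone. The sketch below mimics how the split sizes appear to be derived, assuming scikit-learn rounds the test count up and gives the remainder to the training set. That assumption is consistent with the error message above but is not taken from the library's source, and `split_sizes` is a hypothetical helper, not a scikit-learn function.

```python
import math
from typing import Tuple

def split_sizes(n_samples: int, test_size: float) -> Tuple[int, int]:
    """Return (n_train, n_test) for a fractional test_size.

    Assumption: the test count is rounded up and the rest goes to train.
    """
    n_test = math.ceil(test_size * n_samples)
    n_train = n_samples - n_test
    return n_train, n_test

# Toy corpus: 5 lines, first split with test_size=0.2
first = split_sizes(5, 0.2)
# Second split: the single held-out line, test_size=0.5
second = split_sizes(first[1], 0.5)

print(first)   # (4, 1)
print(second)  # (0, 1): the "train" side is empty, hence the ValueError
```

Under this reading, any corpus whose held-out portion is a single line will fail at the second split, so the check has to happen before `train_test_split` is called.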



# Task
The next step is to introduce a configuration system (using a dictionary or a simple class) to manage model hyperparameters, training parameters (learning rate, number of epochs, etc.), and file paths. This will centralize settings and make the model more easily configurable and adaptable.
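
One possible shape for such a configuration, sketched as a dataclass. The class name `ARSLMConfig` and its field names are hypothetical; the default values are copied from the constants currently hard-coded in `train_demo`, `ARSLM`, and the `__main__` block.

```python
from dataclasses import dataclass, asdict

@dataclass
class ARSLMConfig:
    """Hypothetical central config; defaults mirror the hard-coded values."""
    # Model hyperparameters
    emb_dim: int = 64
    hidden_dim: int = 128
    num_layers: int = 2
    dropout_prob: float = 0.1
    # Training parameters
    batch_size: int = 8
    n_epochs: int = 200
    lr: float = 1e-3
    lr_step_size: int = 50   # StepLR step_size
    lr_gamma: float = 0.5    # StepLR gamma
    grad_clip: float = 1.0   # max norm for clip_grad_norm_
    seed: int = 0
    # File paths and data limits
    corpus_file: str = "/tmp/large_corpus.txt"
    max_lines: int = 1000

# Override only what changes; everything else keeps its default
cfg = ARSLMConfig(n_epochs=20, hidden_dim=256)
print(asdict(cfg)["hidden_dim"])  # 256
```

`train_demo` could then accept a `cfg` argument and read `cfg.lr`, `cfg.batch_size`, and so on instead of local constants; that signature change is a suggestion, not something the notebook does yet.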

## Refactoring Data Loading and Preparation

### Subtask:
Create a `TextDataset` class for more structured handling of text data and tokenization. Modify the `collate_batch` function so that it is compatible with `torch.utils.data.DataLoader`, and add a train/validation/test split to `train_demo`. The goal is efficient data loading and reliable evaluation.
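
The core of `collate_batch` is padding each batch to a common length and then shifting by one position to obtain inputs and targets. A minimal sketch of that step on plain Python lists, without tensors; `pad_and_shift` is a hypothetical name and the token ids are illustrative.

```python
from typing import List, Tuple

def pad_and_shift(seqs: List[List[int]], pad_id: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Right-pad to the longest sequence, then build next-token pairs."""
    max_len = max(len(s) for s in seqs)
    padded = [s + [pad_id] * (max_len - len(s)) for s in seqs]
    inputs = [row[:-1] for row in padded]   # everything except the last position
    targets = [row[1:] for row in padded]   # everything except the first position
    return inputs, targets

# Illustrative ids: 101 = BOS, 102 = EOS, 0 = PAD
batch = [[101, 7, 8, 102], [101, 9, 102]]
inputs, targets = pad_and_shift(batch, pad_id=0)
print(inputs)   # [[101, 7, 8], [101, 9, 102]]
print(targets)  # [[7, 8, 102], [9, 102, 0]]
```

The trailing `0` in the second target row is padding; the notebook's loss is built with `ignore_index=tokenizer.pad_token_id`, so such positions do not contribute to training.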





# Task
It looks like the previous execution failed due to a `ValueError` in `train_test_split` when the corpus was too small, and there's also a potential for `ZeroDivisionError` if a DataLoader ends up being empty.

The next step is to refine the `train_demo` function to make the data splitting more robust, especially for small datasets, and to ensure that loss calculations only occur when DataLoaders contain data. This involves adjusting the `train_test_split` parameters and adding checks before calculating average losses or perplexity.
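
A minimal sketch of the guard described above, assuming the running loss total and the batch count are tracked separately. `mean_loss` and `report` are hypothetical helpers, not functions from the notebook.

```python
import math
from typing import Optional

def mean_loss(total_loss: float, num_batches: int) -> Optional[float]:
    """Return the mean batch loss, or None when no batch was seen."""
    if num_batches == 0:
        return None
    return total_loss / num_batches

def report(name: str, total_loss: float, num_batches: int) -> str:
    """Format a loss line, skipping the division for an empty loader."""
    avg = mean_loss(total_loss, num_batches)
    if avg is None:
        return f"{name}: skipped (no data)"
    return f"{name} loss: {avg:.4f}, perplexity: {math.exp(avg):.4f}"

print(report("Validation", 0.0, 0))  # Validation: skipped (no data)
print(report("Test", 6.0, 3))        # Test loss: 2.0000, perplexity: 7.3891
```

In `train_demo`, `num_batches` would be `len(val_loader)` or `len(test_loader)`, both of which are zero when the corresponding split is empty.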

## Refactoring Data Loading and Preparation

### Subtask:
Fix the `train_demo` function so that splitting the data (train, validation, test) no longer raises a `ValueError` on small corpora, and make sure each `DataLoader` is created correctly and used only when it contains data. This includes adjusting the `train_test_split` calls and guarding the average-loss calculations against division by zero.


**Reasoning**:
The previous `ValueError` occurred because `train_test_split` was called with too few samples to create distinct validation and test sets. I will modify the `train_demo` function to include robust data splitting logic, handling cases with small corpora by assigning all data to the training set and ensuring `val_texts` and `test_texts` are empty. I will also add checks for empty `DataLoader` instances before calculating average losses to prevent `ZeroDivisionError`.
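
The fallback behaviour can be illustrated without scikit-learn. The sketch below is a pure-Python alternative, not the `train_test_split` logic used in the notebook: `three_way_split` is a hypothetical helper, and its default fractions (10% validation, 20% test) are chosen to match the proportions in `train_demo`.

```python
import random
from typing import List, Tuple

def three_way_split(texts: List[str], val_frac: float = 0.1, test_frac: float = 0.2,
                    seed: int = 42) -> Tuple[List[str], List[str], List[str]]:
    """Split into (train, val, test); corpora under 3 lines go entirely to train."""
    if len(texts) < 3:
        return list(texts), [], []
    shuffled = list(texts)
    random.Random(seed).shuffle(shuffled)
    n = len(shuffled)
    # Give val and test at least one sample each; with the default
    # fractions this still leaves at least one training sample for n >= 3
    n_test = max(1, int(n * test_frac))
    n_val = max(1, int(n * val_frac))
    n_train = n - n_val - n_test
    return (shuffled[:n_train],
            shuffled[n_train:n_train + n_val],
            shuffled[n_train + n_val:])

toy = ["a b", "c d", "e f", "g h", "i j"]
train, val, test = three_way_split(toy)
print(len(train), len(val), len(test))  # 3 1 1
print(three_way_split(["only one"]))    # (['only one'], [], [])
```

Because the validation and test lists can be empty, any code that averages over their loaders still needs the empty-loader check.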



In [None]:
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
from transformers import BertTokenizer # Import BertTokenizer here
from torch.optim.lr_scheduler import StepLR # Import StepLR
import os # Import os for file operations
from torch.utils.data import Dataset, DataLoader # Import Dataset and DataLoader
from sklearn.model_selection import train_test_split # Import train_test_split

# ----------------------------
# Advanced Tokenizer (using transformers)
# ----------------------------
class AdvancedTokenizer:
    def __init__(self, vocab_file=None):
        # Using a pre-trained tokenizer's vocabulary for demonstration
        # In a real scenario, you might train a new tokenizer on your specific corpus
        if vocab_file:
             self.tokenizer = BertTokenizer(vocab_file)
        else:
             self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        self.vocab = list(self.tokenizer.vocab.keys())
        self.stoi = self.tokenizer.vocab
        self.itos = {i: w for w, i in self.stoi.items()}

        # Define special tokens explicitly for clarity and compatibility
        # These might vary depending on the chosen tokenizer
        self.pad_token = self.tokenizer.pad_token
        self.unk_token = self.tokenizer.unk_token
        self.bos_token = self.tokenizer.cls_token # Using [CLS] as BOS for BERT-like tokenizers
        self.eos_token = self.tokenizer.sep_token # Using [SEP] as EOS for BERT-like tokenizers

        self.pad_token_id = self.tokenizer.pad_token_id
        self.unk_token_id = self.tokenizer.unk_token_id
        self.bos_token_id = self.tokenizer.cls_token_id
        self.eos_token_id = self.tokenizer.sep_token_id


    def encode(self, text: str) -> List[int]:
        # add_special_tokens=False returns only the ids of the text itself;
        # BOS/EOS ([CLS]/[SEP]) are added explicitly in collate_batch
        return self.tokenizer.encode(text, add_special_tokens=False)

    def decode(self, ids: List[int]) -> str:
        # skip_special_tokens=True prevents decoding [CLS], [SEP], [PAD] etc.
        # We might adjust this based on how the model handles special tokens
        return self.tokenizer.decode(ids, skip_special_tokens=True)

    def __len__(self):
        return len(self.vocab)

# ----------------------------
# ARSCell: core adaptive cell
# ----------------------------
class ARSCell(nn.Module):
    """
    ARSCell: computes next hidden state from h_{t-2}, h_{t-1}, and input embedding x_{t-1}
    Implements an 'adapt' mechanism inspired by BenPolySeq:
      h_t = h_{t-1} + gate * transform(h_{t-1}, h_{t-2}, x)
    where gate is a function of the difference (h_{t-1} - h_{t-2}) and context.
    """
    def __init__(self, emb_dim: int, hidden_dim: int, dropout_prob: float = 0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        # transform candidate from concat([h_{t-1}, h_{t-2}, x])
        self.candidate_mlp = nn.Sequential(
            nn.Linear(hidden_dim*2 + emb_dim, hidden_dim*2),
            nn.ReLU(),
            nn.Linear(hidden_dim*2, hidden_dim)
        )
        # gate network produces scalar gating factor in (0,1)
        self.gate_net = nn.Sequential(
            nn.Linear(hidden_dim*2 + emb_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        # small residual projection
        self.res_proj = nn.Linear(emb_dim, hidden_dim)
        # Dropout layer
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, h_prev2, h_prev1, x_embed):
        # h_prev2, h_prev1: (batch, hidden_dim)
        # x_embed: (batch, emb_dim)
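        # difference signal; currently unused, since the gate and candidate networks
        # see h_prev1 and h_prev2 directly through ctx
        diff = h_prev1 - h_prev2    # (batch, hidden_dim)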
        # context vector (concat)
        ctx = torch.cat([h_prev1, h_prev2, x_embed], dim=-1)
        candidate = self.candidate_mlp(ctx)                       # (batch, hidden_dim)
        gate = self.gate_net(ctx).squeeze(-1)                     # (batch,)
        # incorporate residual from input embedding
        residual = self.res_proj(x_embed)
        # ARS-style update: adaptive addition of candidate scaled by gate
        h_t = h_prev1 + gate.unsqueeze(-1) * candidate + 0.1 * residual
        # Apply dropout
        h_t = self.dropout(h_t)
        # optional normalization
        h_t = F.layer_norm(h_t, (self.hidden_dim,))
        return h_t, gate

# ----------------------------
# ARSLM model: embed -> ARSCell(s) -> attention -> head
# ----------------------------
class ARSLM(nn.Module):
    # Modified to accept a tokenizer instance, include attention, and support multiple layers
    def __init__(self, tokenizer, emb_dim=64, hidden_dim=128, num_layers=2):
        super().__init__()
        self.tokenizer = tokenizer # Store tokenizer instance
        self.vocab_size = len(tokenizer) # Get vocab size from tokenizer
        self.num_layers = num_layers
        self.emb = nn.Embedding(self.vocab_size, emb_dim)
        # Create a list of ARSCells
        # Pass dropout_prob to ARSCell constructor
        self.cells = nn.ModuleList([ARSCell(emb_dim if i == 0 else hidden_dim, hidden_dim, dropout_prob=0.1) for i in range(num_layers)])

        # Simple additive attention mechanism, applied after the last layer's hidden states
        self.attention = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1)
        )
        self.head = nn.Linear(hidden_dim, self.vocab_size)

    def forward(self, input_ids):
        """
        input_ids: (batch, seq_len)
        We compute autoregressively, passing hidden states through layers.
        Returns logits (batch, seq_len, vocab).
        """
        bsz, seq_len = input_ids.shape
        emb = self.emb(input_ids)  # (b, seq, emb)
        device = emb.device
        hidden_dim = self.cells[0].hidden_dim

        # Initialize hidden states for each layer
        # We need two previous states for each layer's ARSCell
        h_prev2_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]
        h_prev1_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]

        # Store hidden states for causal attention (from the last layer)
        all_last_layer_hidden_states = []
        logits = []
        gates = [] # Store gates from the last layer

        for t in range(seq_len):
            x_t = emb[:, t, :] # embedding at position t
            h_t_input = x_t # Input to the first layer is the embedding

            current_layer_hidden_states = []

            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                # The input to subsequent layers is the hidden state of the previous layer
                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1] # Use the output of the previous layer

                h_t, gate = cell(h_prev2, h_prev1, h_t_input)

                # Store this layer's hidden state; it is the input to the next layer at the same time step
                current_layer_hidden_states.append(h_t)

                # Update previous hidden states for the next time step for this layer
                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1: # Only store gates from the last layer
                    gates.append(gate.unsqueeze(1))

            # After processing all layers for time step t, the output is the hidden state of the last layer
            last_layer_h_t = current_layer_hidden_states[-1]
            all_last_layer_hidden_states.append(last_layer_h_t.unsqueeze(1))

            # Apply attention and head after the last layer's hidden state
            # Causal attention pooling: score every last-layer state up to time step t.
            # Scores depend only on each history state, so no separate query vector is needed.
            last_layer_history = torch.cat(all_last_layer_hidden_states, dim=1) # (b, t+1, hidden_dim)

            # Calculate attention scores
            scores = self.attention(last_layer_history) # (b, t+1, 1)
            attention_weights = F.softmax(scores, dim=1) # (b, t+1, 1)

            # Apply attention weights
            context_vector = torch.sum(attention_weights * last_layer_history, dim=1) # (b, hidden_dim)

            # Combine last layer's current hidden state with context vector before the head
            attended_h_t = last_layer_h_t + context_vector

            logit = self.head(attended_h_t) # logits at step t predict the token at position t+1
            logits.append(logit.unsqueeze(1))


        logits = torch.cat(logits, dim=1)          # (b, seq, vocab)
        gates = torch.cat(gates, dim=1)            # (b, seq) # Gates from the last layer
        return logits, gates

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=20, temperature=1.0, top_k=None):
        """
        idx: (batch, current_seq_len) - input context (token ids)
        Returns extended token ids.
        Modified for multiple layers, attention during generation, and optional top-k sampling.
        """
        self.eval()
        bsz, seq_len = idx.shape
        device = next(self.parameters()).device
        hidden_dim = self.cells[0].hidden_dim

        # Initialize hidden states for each layer
        h_prev2_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]
        h_prev1_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]

        # Store hidden states for causal attention (from the last layer)
        all_last_layer_hidden_states = []

        # Process initial context to get starting hidden states for each layer
        input_ids = idx.clone()
        emb = self.emb(input_ids) # (b, seq_len, emb_dim)

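        # Stop one token short: the generation loop below feeds the last context token itself,
        # so processing it here as well would apply it to the hidden states twice
        for t in range(seq_len - 1):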
            x_t = emb[:,t,:] # embedding at position t
            h_t_input = x_t

            current_layer_hidden_states = []
            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1]

                h_t, _ = cell(h_prev2, h_prev1, h_t_input)
                current_layer_hidden_states.append(h_t)

                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1:
                    # Add the new last layer hidden state to the history for attention
                    all_last_layer_hidden_states.append(h_t.unsqueeze(1))


        out_ids = input_ids.tolist()

        for _ in range(max_new_tokens):
            # Input for the first layer's ARSCell is the embedding of the last generated token
            last_token_ids = torch.tensor([ids[-1] for ids in out_ids], dtype=torch.long, device=device).unsqueeze(1)
            x_embed = self.emb(last_token_ids).squeeze(1) # (b, emb_dim)

            h_t_input = x_embed
            current_layer_hidden_states = []

            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1]

                h_t, _ = cell(h_prev2, h_prev1, h_t_input)
                current_layer_hidden_states.append(h_t)

                # Update previous hidden states for the next generation step for this layer
                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1:
                    # Add the new last layer hidden state to the history for attention
                    all_last_layer_hidden_states.append(h_t.unsqueeze(1))


            # After processing all layers for the new token, apply attention and head
            last_layer_h_t = current_layer_hidden_states[-1] # Hidden state of the last layer for the new token

            # Causal attention pooling over the full last-layer history (context + generated tokens)
            last_layer_history = torch.cat(all_last_layer_hidden_states, dim=1) # (b, history_len, hidden_dim)

            scores = self.attention(last_layer_history) # (b, current_seq_len + 1, 1)
            attention_weights = F.softmax(scores, dim=1) # (b, current_seq_len + 1, 1)

            context_vector = torch.sum(attention_weights * last_layer_history, dim=1) # (b, hidden_dim)

            attended_h_t = last_layer_h_t + context_vector

            logits = self.head(attended_h_t) / max(1e-6, temperature)

            # Apply top-k sampling
            if top_k is not None:
                # get top k logits
                top_k = min(max(top_k, 1), logits.size(-1))  # Clamp k to be at least 1 and at most vocab size
                # Mask every token whose logit falls below the k-th largest logit
                v, _ = torch.topk(logits, top_k)
                logits[logits < v[:, [-1]]] = -float('Inf')

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)


            # append the newly generated token
            for i in range(bsz):
                out_ids[i].append(int(next_token[i].item()))


        return out_ids

# ----------------------------
# TextDataset class for structured data handling
# ----------------------------
class TextDataset(Dataset):
    def __init__(self, texts: list, tokenizer):
        self.texts = texts
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.texts[idx]

# Modified to accept and use the tokenizer instance
def collate_batch(batch_texts: List[str], tokenizer):
    # Encode each text, add BOS/EOS, and convert to tensor
    encoded = []
    for t in batch_texts:
        ids = [tokenizer.bos_token_id] + tokenizer.encode(t) + [tokenizer.eos_token_id]
        encoded.append(torch.tensor(ids, dtype=torch.long))

    # Pad to max len within the batch
    max_len = max([x.size(0) for x in encoded])
    padded_batch = torch.full((len(encoded), max_len), tokenizer.pad_token_id, dtype=torch.long)
    for i, x in enumerate(encoded):
        padded_batch[i, :x.size(0)] = x

    # Prepare inputs and targets for language modeling
    inputs = padded_batch[:, :-1]
    targets = padded_batch[:, 1:]

    return inputs, targets


# Modified to use AdvancedTokenizer and handle multiple layers, and DataLoaders
def train_demo():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    corpus_file = "/tmp/large_corpus.txt"
    texts = []
    if os.path.exists(corpus_file):
        print(f"Loading corpus from {corpus_file}...")
        with open(corpus_file, 'r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                if i < 1000:
                    texts.append(line.strip())
                else:
                    break
        print(f"Loaded {len(texts)} lines.")
    else:
        print(f"Corpus file not found at {corpus_file}. Using a small toy corpus instead.")
        texts = [
            "hello world this is ars",
            "the system adapts to its history",
            "benpolyseq demonstrates adaptive sequences",
            "ars can inspire new network protocols",
            "self optimizing systems are possible"
        ]

    tokenizer = AdvancedTokenizer()
    vocab_size = len(tokenizer)

    # Modified Data Splitting Logic
    num_samples = len(texts)
    if num_samples < 3:
        print(f"Warning: Corpus size ({num_samples}) is too small for a proper train/validation/test split. Using all data for training.")
        train_texts = texts
        val_texts = []
        test_texts = []
    else:
        # First split: train_val and test
        train_val_texts, test_texts = train_test_split(texts, test_size=0.2, random_state=42)
        # Defensive check for test_texts being empty
        if not test_texts and train_val_texts:
            test_texts.append(train_val_texts.pop())

        # Second split: train and validation from train_val
        # Calculate test_size for validation set: 0.1 / (1-0.2) = 0.125
        if len(train_val_texts) >= 2: # Need at least two samples to split into train and val
            val_split_size = 0.1 / (1.0 - 0.2)
            train_texts, val_texts = train_test_split(train_val_texts, test_size=val_split_size, random_state=42)
            # Defensive check for val_texts being empty
            if not val_texts and train_texts:
                val_texts.append(train_texts.pop())
        else:
            train_texts = train_val_texts
            val_texts = []

    print(f"Train size: {len(train_texts)}, Validation size: {len(val_texts)}, Test size: {len(test_texts)}")

    # Create datasets
    train_dataset = TextDataset(train_texts, tokenizer)
    val_dataset = TextDataset(val_texts, tokenizer)
    test_dataset = TextDataset(test_texts, tokenizer)

    batch_size = 8 # Define a batch size

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                collate_fn=lambda b: collate_batch(b, tokenizer))
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                              collate_fn=lambda b: collate_batch(b, tokenizer))
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                               collate_fn=lambda b: collate_batch(b, tokenizer))

    model = ARSLM(tokenizer, emb_dim=64, hidden_dim=128, num_layers=2).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = StepLR(optimizer, step_size=50, gamma=0.5)
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    n_epochs = 200
    if len(train_texts) == 0: # Nothing to train on
        print("No training data available. Skipping training.")
        return model, tokenizer

    for epoch in range(n_epochs):
        model.train()
        total_loss = 0
        num_batches = len(train_loader)
        if num_batches > 0:
            for batch_idx, (inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(device), targets.to(device)
                logits, gates = model(inputs)
                b, seq, v = logits.shape
                loss = loss_fn(logits.view(b*seq, v), targets.reshape(b*seq))
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                total_loss += loss.item()
            avg_train_loss = total_loss / num_batches
        else:
            avg_train_loss = float('inf') # Set to inf if no training batches

        scheduler.step()

        if (epoch+1) % 50 == 0 or epoch==0:
            print(f"Epoch {epoch+1}/{n_epochs} — Avg Train Loss: {avg_train_loss:.4f}")
            print(f"Current learning rate: {scheduler.get_last_lr()[0]:.6f}")

            # Validation step with check for empty loader
            if len(val_loader) > 0:
                model.eval()
                total_val_loss = 0
                with torch.no_grad():
                    for inputs_val, targets_val in val_loader:
                        inputs_val, targets_val = inputs_val.to(device), targets_val.to(device)
                        logits_val, _ = model(inputs_val)
                        b_val, seq_val, v_val = logits_val.shape
                        loss_val = loss_fn(logits_val.view(b_val*seq_val, v_val), targets_val.reshape(b_val*seq_val))
                        total_val_loss += loss_val.item()
                avg_val_loss = total_val_loss / len(val_loader)
                print(f"Validation Loss: {avg_val_loss:.4f}")
            else:
                print("No validation data available. Skipping validation.")
                avg_val_loss = float('inf')

    print("\n=== Evaluation ===")
    model.eval()
    total_test_loss = 0
    if len(test_loader) > 0:
        with torch.no_grad():
            for inputs_test, targets_test in test_loader:
                inputs_test, targets_test = inputs_test.to(device), targets_test.to(device)
                logits_test, _ = model(inputs_test)
                b_test, seq_test, v_test = logits_test.shape
                loss_test = loss_fn(logits_test.view(b_test*seq_test, v_test), targets_test.reshape(b_test*seq_test))
                total_test_loss += loss_test.item()

        avg_test_loss = total_test_loss / len(test_loader)
        # Perplexity is exp(mean loss); averaging per-batch exp(loss) would overstate it
        avg_test_perplexity = torch.exp(torch.tensor(avg_test_loss)).item()
        print(f"Test Loss: {avg_test_loss:.4f}")
        print(f"Perplexity on test data: {avg_test_perplexity:.4f}")
    else:
        print("No test data available. Skipping test evaluation.")
        avg_test_loss = float('inf')
        avg_test_perplexity = float('inf')

    # demo generation
    context = "hello world"
    idx = torch.tensor([ [tokenizer.bos_token_id] + tokenizer.encode(context) ], dtype=torch.long, device=device)
    out_ids = model.generate(idx, max_new_tokens=15, temperature=1.0, top_k=50)[0]
    print("\n=== Generated ===")
    print(tokenizer.decode(out_ids))
    return model, tokenizer

# ----------------------------
# Run demo if main
# ----------------------------
if __name__ == "__main__":
    random.seed(0)
    torch.manual_seed(0)
    model, tokenizer = train_demo()

Corpus file not found at /tmp/large_corpus.txt. Using a small toy corpus instead.
Train size: 3, Validation size: 1, Test size: 1
Epoch 1/200 — Avg Train Loss: 11.1917
Current learning rate: 0.001000
Validation Loss: 10.5004
Epoch 50/200 — Avg Train Loss: 0.4314
Current learning rate: 0.000500
Validation Loss: 13.7606
Epoch 100/200 — Avg Train Loss: 0.1822
Current learning rate: 0.000250
Validation Loss: 14.0828
Epoch 150/200 — Avg Train Loss: 0.1426
Current learning rate: 0.000125
Validation Loss: 14.6507
Epoch 200/200 — Avg Train Loss: 0.2472
Current learning rate: 0.000063
Validation Loss: 14.6519

=== Evaluation ===
Test Loss: 15.3464
Perplexity on test data: 4622208.0000

=== Generated ===
hello world is ars


**Reasoning**:
I will define a `Config` class to centralize all hyperparameters and settings, making the model more configurable. This class will hold model parameters, training parameters, and file paths.
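
One way to keep such a centralized configuration easy to override and to store next to a checkpoint is a dataclass that round-trips through JSON. The sketch below is illustrative only: `DemoConfig`, `config_to_json`, and `config_from_json` are hypothetical names that do not appear in this notebook, and only a handful of fields are shown.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class DemoConfig:
    """Hypothetical stand-in for a centralized config; a few illustrative fields only."""
    emb_dim: int = 64
    hidden_dim: int = 128
    num_layers: int = 2
    learning_rate: float = 1e-3
    n_epochs: int = 200
    batch_size: int = 8


def config_to_json(cfg: DemoConfig) -> str:
    """Serialize the config to a JSON string (plain data, no pickling)."""
    return json.dumps(asdict(cfg), indent=2, sort_keys=True)


def config_from_json(text: str) -> DemoConfig:
    """Rebuild a config from JSON; unexpected keys raise TypeError."""
    return DemoConfig(**json.loads(text))


# Override one field, serialize, and restore
cfg = DemoConfig(hidden_dim=256)
restored = config_from_json(config_to_json(cfg))
print(restored == cfg)
```

Storing plain JSON rather than a pickled object is a design choice, not a requirement: it keeps the saved settings human-readable and avoids depending on the class definition being importable at load time.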



In [None]:
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
from transformers import BertTokenizer # Import BertTokenizer here
from torch.optim.lr_scheduler import StepLR # Import StepLR
import os # Import os for file operations
from torch.utils.data import Dataset, DataLoader # Import Dataset and DataLoader
from sklearn.model_selection import train_test_split # Import train_test_split

# ----------------------------
# Configuration Class
# ----------------------------
class Config:
    """Centralized configuration for the ARSLM model and training."""
    def __init__(self):
        # Model Parameters
        self.emb_dim = 64
        self.hidden_dim = 128
        self.num_layers = 2
        self.dropout_prob = 0.1

        # Training Parameters
        self.learning_rate = 1e-3
        self.n_epochs = 200
        self.batch_size = 8
        self.lr_scheduler_step_size = 50
        self.lr_scheduler_gamma = 0.5
        self.clip_grad_norm = 1.0

        # Generation Parameters
        self.max_new_tokens = 15
        self.temperature = 1.0
        self.top_k = 50
        self.generation_context = "hello world"

        # Data Parameters
        self.corpus_file = "/tmp/large_corpus.txt"
        self.max_corpus_lines = 1000 # Max lines to load from corpus_file for demo
        self.train_test_split_ratio = 0.2
        self.train_val_split_ratio = 0.1 / (1.0 - self.train_test_split_ratio) # 0.125 of train_val, i.e. 0.1 of total
        self.random_state = 42

        # Saving/Loading Paths
        self.model_save_dir = "./model_checkpoint"
        self.tokenizer_save_dir = "./tokenizer_checkpoint"
        self.config_save_path = os.path.join(self.model_save_dir, "config.pt") # Path to save config

# ----------------------------
# Advanced Tokenizer (using transformers)
# ----------------------------
class AdvancedTokenizer:
    def __init__(self, vocab_file=None, pretrained_model_name_or_path='bert-base-uncased'):
        if vocab_file:
            self.tokenizer = BertTokenizer(vocab_file)
        else:
            # from_pretrained accepts either a hub model name or a local directory of a saved tokenizer
            self.tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path)

        self.vocab = list(self.tokenizer.vocab.keys())
        self.stoi = self.tokenizer.vocab
        self.itos = {i: w for w, i in self.stoi.items()}

        self.pad_token = self.tokenizer.pad_token
        self.unk_token = self.tokenizer.unk_token
        self.bos_token = self.tokenizer.cls_token # Using [CLS] as BOS for BERT-like tokenizers
        self.eos_token = self.tokenizer.sep_token # Using [SEP] as EOS for BERT-like tokenizers

        self.pad_token_id = self.tokenizer.pad_token_id
        self.unk_token_id = self.tokenizer.unk_token_id
        self.bos_token_id = self.tokenizer.cls_token_id
        self.eos_token_id = self.tokenizer.sep_token_id


    def encode(self, text: str) -> List[int]:
        return self.tokenizer.encode(text, add_special_tokens=False)

    def decode(self, ids: List[int]) -> str:
        return self.tokenizer.decode(ids, skip_special_tokens=True)

    def __len__(self):
        return len(self.vocab)

# ----------------------------
# ARSCell: core adaptive cell
# ----------------------------
class ARSCell(nn.Module):
    """
    ARSCell: computes next hidden state from h_{t-2}, h_{t-1}, and input embedding x_{t-1}
    Implements an 'adapt' mechanism inspired by BenPolySeq:
      h_t = h_{t-1} + gate * transform(h_{t-1}, h_{t-2}, x)
    where gate is a function of the difference (h_{t-1} - h_{t-2}) and context.
    """
    def __init__(self, emb_dim: int, hidden_dim: int, dropout_prob: float = 0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        # transform candidate from concat([h_{t-1}, h_{t-2}, x])
        self.candidate_mlp = nn.Sequential(
            nn.Linear(hidden_dim*2 + emb_dim, hidden_dim*2),
            nn.ReLU(),
            nn.Linear(hidden_dim*2, hidden_dim)
        )
        # gate network produces scalar gating factor in (0,1)
        self.gate_net = nn.Sequential(
            nn.Linear(hidden_dim*2 + emb_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        # small residual projection
        self.res_proj = nn.Linear(emb_dim, hidden_dim)
        # Dropout layer
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, h_prev2, h_prev1, x_embed):
        # h_prev2, h_prev1: (batch, hidden_dim)
        # x_embed: (batch, emb_dim)
        # The difference (h_prev1 - h_prev2) is not computed explicitly:
        # both states are part of ctx, so the gate and candidate networks can learn it
        # context vector (concat)
        ctx = torch.cat([h_prev1, h_prev2, x_embed], dim=-1)
        candidate = self.candidate_mlp(ctx)                       # (batch, hidden_dim)
        gate = self.gate_net(ctx).squeeze(-1)                     # (batch,)
        # incorporate residual from input embedding
        residual = self.res_proj(x_embed)
        # ARS-style update: adaptive addition of candidate scaled by gate
        h_t = h_prev1 + gate.unsqueeze(-1) * candidate + 0.1 * residual
        # Apply dropout
        h_t = self.dropout(h_t)
        # optional normalization
        h_t = F.layer_norm(h_t, (self.hidden_dim,))
        return h_t, gate

# ----------------------------
# ARSLM model: embed -> ARSCell(s) -> attention -> head
# ----------------------------
class ARSLM(nn.Module):
    # Modified to accept a tokenizer instance, include attention, and support multiple layers
    def __init__(self, tokenizer, emb_dim=64, hidden_dim=128, num_layers=2, dropout_prob=0.1):
        super().__init__()
        self.tokenizer = tokenizer # Store tokenizer instance
        self.vocab_size = len(tokenizer) # Get vocab size from tokenizer
        self.num_layers = num_layers
        self.emb = nn.Embedding(self.vocab_size, emb_dim)
        # Create a list of ARSCells
        # Pass dropout_prob to ARSCell constructor
        self.cells = nn.ModuleList([ARSCell(emb_dim if i == 0 else hidden_dim, hidden_dim, dropout_prob=dropout_prob) for i in range(num_layers)])

        # Simple additive attention mechanism, applied after the last layer's hidden states
        self.attention = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1)
        )
        self.head = nn.Linear(hidden_dim, self.vocab_size)

    def forward(self, input_ids):
        """
        input_ids: (batch, seq_len)
        We compute autoregressively, passing hidden states through layers.
        Returns logits (batch, seq_len, vocab).
        """
        bsz, seq_len = input_ids.shape
        emb = self.emb(input_ids)  # (b, seq, emb)
        device = emb.device
        hidden_dim = self.cells[0].hidden_dim

        # Initialize hidden states for each layer
        # We need two previous states for each layer's ARSCell
        h_prev2_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]
        h_prev1_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]

        # Store hidden states for causal attention (from the last layer)
        all_last_layer_hidden_states = []
        logits = []
        gates = [] # Store gates from the last layer

        for t in range(seq_len):
            x_t = emb[:, t, :] # embedding at position t
            h_t_input = x_t # Input to the first layer is the embedding

            current_layer_hidden_states = []

            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1] # Use the output of the previous layer

                h_t, gate = cell(h_prev2, h_prev1, h_t_input)

                # Store current layer's hidden state as input for the next layer at this time step
                current_layer_hidden_states.append(h_t)

                # Update previous hidden states for the next time step for this layer
                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1: # Only store gates from the last layer
                    gates.append(gate.unsqueeze(1))

            # After processing all layers for time step t, the output is the hidden state of the last layer
            last_layer_h_t = current_layer_hidden_states[-1]
            all_last_layer_hidden_states.append(last_layer_h_t.unsqueeze(1))

            # Apply attention and head after the last layer's hidden state
            # Causal Attention: attend over hidden states up to current time step t from the last layer
            last_layer_history = torch.cat(all_last_layer_hidden_states, dim=1) # (b, t+1, hidden_dim)
            # Note: the scores below depend only on the stored states; no separate query vector is used

            # Calculate attention scores
            scores = self.attention(last_layer_history) # (b, t+1, 1)
            attention_weights = F.softmax(scores, dim=1) # (b, t+1, 1)

            # Apply attention weights
            context_vector = torch.sum(attention_weights * last_layer_history, dim=1) # (b, hidden_dim)

            # Combine last layer's current hidden state with context vector before the head
            attended_h_t = last_layer_h_t + context_vector

            logit = self.head(attended_h_t) # predict token at t (or next token)
            logits.append(logit.unsqueeze(1))


        logits = torch.cat(logits, dim=1)          # (b, seq, vocab)
        gates = torch.cat(gates, dim=1)            # (b, seq) # Gates from the last layer
        return logits, gates

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=20, temperature=1.0, top_k=None):
        """
        idx: (batch, current_seq_len) - input context (token ids)
        Returns extended token ids.
        Modified for multiple layers, attention during generation, and optional top-k sampling.
        """
        self.eval()
        bsz, seq_len = idx.shape
        device = next(self.parameters()).device
        hidden_dim = self.cells[0].hidden_dim

        # Initialize hidden states for each layer
        h_prev2_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]
        h_prev1_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]

        # Store hidden states for causal attention (from the last layer)
        all_last_layer_hidden_states = []

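        # Prime the hidden states with every context token except the last one.
        # The last token is fed by the generation loop below, so each token is
        # consumed exactly once, matching forward().
        input_ids = idx.clone()
        emb = self.emb(input_ids) # (b, seq_len, emb_dim)

        for t in range(seq_len - 1):
            x_t = emb[:,t,:] # embedding at position t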
            h_t_input = x_t

            current_layer_hidden_states = []
            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1]

                h_t, _ = cell(h_prev2, h_prev1, h_t_input)
                current_layer_hidden_states.append(h_t)

                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1:
                    # Add the new last layer hidden state to the history for attention
                    all_last_layer_hidden_states.append(h_t.unsqueeze(1))


        out_ids = input_ids.tolist()

        for _ in range(max_new_tokens):
            # Input for the first layer's ARSCell is the embedding of the last generated token
            last_token_ids = torch.tensor([ids[-1] for ids in out_ids], dtype=torch.long, device=device).unsqueeze(1)
            x_embed = self.emb(last_token_ids).squeeze(1) # (b, emb_dim)

            h_t_input = x_embed
            current_layer_hidden_states = []

            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1]

                h_t, _ = cell(h_prev2, h_prev1, h_t_input)
                current_layer_hidden_states.append(h_t)

                # Update previous hidden states for the next generation step for this layer
                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1:
                    # Add the new last layer hidden state to the history for attention
                    all_last_layer_hidden_states.append(h_t.unsqueeze(1))


            # After processing all layers for the new token, apply attention and head
            last_layer_h_t = current_layer_hidden_states[-1] # Hidden state of the last layer for the new token

            # Causal Attention over all last layer hidden states history
            last_layer_history = torch.cat(all_last_layer_hidden_states, dim=1) # (b, current_seq_len + 1, hidden_dim)
            # Note: the scores below depend only on the stored states; no separate query vector is used

            scores = self.attention(last_layer_history) # (b, current_seq_len + 1, 1)
            attention_weights = F.softmax(scores, dim=1) # (b, current_seq_len + 1, 1)

            context_vector = torch.sum(attention_weights * last_layer_history, dim=1) # (b, hidden_dim)

            attended_h_t = last_layer_h_t + context_vector

            logits = self.head(attended_h_t) / max(1e-6, temperature)

            # Apply top-k sampling
            if top_k is not None:
                # get top k logits
                top_k = min(max(top_k, 1), logits.size(-1))  # Clamp k to be at least 1 and at most vocab size
                # Remove all tokens with a probability less than the top-k'th token's probability
                v, _ = torch.topk(logits, top_k)
                logits[logits < v[:, [-1]]] = -float('Inf')

            # Check if all probabilities are zero after filtering/temperature, if so, fall back to argmax
            if torch.all(logits == -float('Inf')):
                # This means all tokens were filtered out or resulted in 0 probability. Choose a random valid token or a default.
                # For now, let's just pick a generic unk_token or a random one if unk is not available/appropriate
                if self.tokenizer.unk_token_id is not None:
                    next_token = torch.tensor([self.tokenizer.unk_token_id] * bsz, device=device)
                else:
                    # Fallback to random if unk token is not defined
                    next_token = torch.randint(0, self.vocab_size, (bsz,), device=device)
            else:
                probs = F.softmax(logits, dim=-1)
                # Ensure that if probabilities sum to zero (e.g., all logits were -inf before soft max),
                # we handle it gracefully, though the above check should prevent it.
                if torch.sum(probs) == 0: # This case should ideally not happen after the -inf check
                    if self.tokenizer.unk_token_id is not None:
                        next_token = torch.tensor([self.tokenizer.unk_token_id] * bsz, device=device)
                    else:
                        next_token = torch.randint(0, self.vocab_size, (bsz,), device=device)
                else:
                    next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)


            # append the newly generated token
            for i in range(bsz):
                out_ids[i].append(int(next_token[i].item()))


        return out_ids

# ----------------------------
# TextDataset class for structured data handling
# ----------------------------
class TextDataset(Dataset):
    def __init__(self, texts: list, tokenizer):
        self.texts = texts
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.texts[idx]

# Modified to accept and use the tokenizer instance
def collate_batch(batch_texts: List[str], tokenizer):
    # Encode each text, add BOS/EOS, and convert to tensor
    encoded = []
    for t in batch_texts:
        ids = [tokenizer.bos_token_id] + tokenizer.encode(t) + [tokenizer.eos_token_id]
        encoded.append(torch.tensor(ids, dtype=torch.long))

    # Pad to max len within the batch
    max_len = max([x.size(0) for x in encoded])
    padded_batch = torch.full((len(encoded), max_len), tokenizer.pad_token_id, dtype=torch.long)
    for i, x in enumerate(encoded):
        padded_batch[i, :x.size(0)] = x

    # Prepare inputs and targets for language modeling
    inputs = padded_batch[:, :-1]
    targets = padded_batch[:, 1:]

    return inputs, targets


# Modified to use AdvancedTokenizer and handle multiple layers, and DataLoaders
def train_demo():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize configuration
    cfg = Config()

    corpus_file = cfg.corpus_file
    texts = []
    if os.path.exists(corpus_file):
        print(f"Loading corpus from {corpus_file}...")
        with open(corpus_file, 'r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                if i < cfg.max_corpus_lines:
                    texts.append(line.strip())
                else:
                    break
        print(f"Loaded {len(texts)} lines.")
    else:
        print(f"Corpus file not found at {corpus_file}. Using a small toy corpus instead.")
        texts = [
            "hello world this is ars",
            "the system adapts to its history",
            "benpolyseq demonstrates adaptive sequences",
            "ars can inspire new network protocols",
            "self optimizing systems are possible"
        ]

    tokenizer = AdvancedTokenizer() # Initialize AdvancedTokenizer
    vocab_size = len(tokenizer)

    # Modified Data Splitting Logic
    num_samples = len(texts)
    if num_samples < 3:
        print(f"Warning: Corpus size ({num_samples}) is too small for a proper train/validation/test split. Using all data for training.")
        train_texts = texts
        val_texts = []
        test_texts = []
    else:
        # First split: train_val and test
        train_val_texts, test_texts = train_test_split(texts, test_size=cfg.train_test_split_ratio, random_state=cfg.random_state)
        # Defensive check for test_texts being empty
        if not test_texts and train_val_texts:
            test_texts.append(train_val_texts.pop()) # Move one item to test if test is empty and train_val is not

        # Second split: train and validation from train_val
        if len(train_val_texts) >= 2: # Need at least two samples to split into train and val
            val_split_size = cfg.train_val_split_ratio
            train_texts, val_texts = train_test_split(train_val_texts, test_size=val_split_size, random_state=cfg.random_state)
            # Defensive check for val_texts being empty
            if not val_texts and train_texts:
                val_texts.append(train_texts.pop()) # Move one item to val if val is empty and train is not
        else:
            train_texts = train_val_texts
            val_texts = []

    print(f"Train size: {len(train_texts)}, Validation size: {len(val_texts)}, Test size: {len(test_texts)}")

    # Create datasets
    train_dataset = TextDataset(train_texts, tokenizer)
    val_dataset = TextDataset(val_texts, tokenizer)
    test_dataset = TextDataset(test_texts, tokenizer)

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True,
                                collate_fn=lambda b: collate_batch(b, tokenizer))
    val_loader = DataLoader(val_dataset, batch_size=cfg.batch_size, shuffle=False,
                              collate_fn=lambda b: collate_batch(b, tokenizer))
    test_loader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False,
                               collate_fn=lambda b: collate_batch(b, tokenizer))

    model = ARSLM(tokenizer, emb_dim=cfg.emb_dim, hidden_dim=cfg.hidden_dim, num_layers=cfg.num_layers, dropout_prob=cfg.dropout_prob).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate)
    scheduler = StepLR(optimizer, step_size=cfg.lr_scheduler_step_size, gamma=cfg.lr_scheduler_gamma)
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    n_epochs = cfg.n_epochs
    if len(train_texts) == 0: # Nothing to train on
        print("No training data available. Skipping training.")
        return model, tokenizer

    # Initialize best validation loss for saving the best model
    best_val_loss = float('inf')
    # Create directories for saving if they don't exist
    os.makedirs(cfg.model_save_dir, exist_ok=True)
    os.makedirs(cfg.tokenizer_save_dir, exist_ok=True)

    for epoch in range(n_epochs):
        model.train()
        total_loss = 0
        num_batches = len(train_loader)
        if num_batches > 0:
            for batch_idx, (inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(device), targets.to(device)
                logits, gates = model(inputs)
                b, seq, v = logits.shape
                loss = loss_fn(logits.view(b*seq, v), targets.reshape(b*seq))
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_grad_norm)
                optimizer.step()
                total_loss += loss.item()
            avg_train_loss = total_loss / num_batches
        else:
            avg_train_loss = float('inf') # Set to inf if no training batches

        scheduler.step()

        if (epoch+1) % 50 == 0 or epoch==0:
            print(f"Epoch {epoch+1}/{n_epochs} — Avg Train Loss: {avg_train_loss:.4f}")
            print(f"Current learning rate: {scheduler.get_last_lr()[0]:.6f}")

            # Validation step with check for empty loader
            if len(val_loader) > 0:
                model.eval()
                total_val_loss = 0
                with torch.no_grad():
                    for inputs_val, targets_val in val_loader:
                        inputs_val, targets_val = inputs_val.to(device), targets_val.to(device)
                        logits_val, _ = model(inputs_val)
                        b_val, seq_val, v_val = logits_val.shape
                        loss_val = loss_fn(logits_val.view(b_val*seq_val, v_val), targets_val.reshape(b_val*seq_val))
                        total_val_loss += loss_val.item()
                avg_val_loss = total_val_loss / len(val_loader)
                print(f"Validation Loss: {avg_val_loss:.4f}")

                # Save the model if this is the best validation loss so far
                if avg_val_loss < best_val_loss:
                    best_val_loss = avg_val_loss
                    print(f"Saving best model with validation loss: {best_val_loss:.4f}")
                    torch.save(model.state_dict(), os.path.join(cfg.model_save_dir, "best_model.pt"))
                    tokenizer.tokenizer.save_pretrained(cfg.tokenizer_save_dir)
                    # Save the configuration as well
                    torch.save(cfg, cfg.config_save_path)

            else:
                print("No validation data available. Skipping validation.")
                avg_val_loss = float('inf')

    print("\n=== Evaluation ===")
    model.eval()
    total_test_loss = 0
    if len(test_loader) > 0:
        with torch.no_grad():
            for inputs_test, targets_test in test_loader:
                inputs_test, targets_test = inputs_test.to(device), targets_test.to(device)
                logits_test, _ = model(inputs_test)
                b_test, seq_test, v_test = logits_test.shape
                loss_test = loss_fn(logits_test.view(b_test*seq_test, v_test), targets_test.reshape(b_test*seq_test))
                total_test_loss += loss_test.item()

        avg_test_loss = total_test_loss / len(test_loader)
        # Perplexity is exp(mean loss); averaging per-batch exp(loss) would overstate it
        avg_test_perplexity = math.exp(avg_test_loss)
        print(f"Test Loss: {avg_test_loss:.4f}")
        print(f"Perplexity on test data: {avg_test_perplexity:.4f}")
    else:
        print("No test data available. Skipping test evaluation.")
        avg_test_loss = float('inf')
        avg_test_perplexity = float('inf')

    # demo generation
    context = cfg.generation_context
    idx = torch.tensor([ [tokenizer.bos_token_id] + tokenizer.encode(context) ], dtype=torch.long, device=device)
    out_ids = model.generate(idx, max_new_tokens=cfg.max_new_tokens, temperature=cfg.temperature, top_k=cfg.top_k)[0]
    print("\n=== Generated ===")
    print(tokenizer.decode(out_ids))
    return model, tokenizer

def load_model_and_tokenizer(model_path: str, tokenizer_path: str, config_path: str, device):
    """
    Loads a pre-trained ARSLM model and its tokenizer along with its configuration.
    """
    # Add Config class to safe globals for unpickling
    torch.serialization.add_safe_globals([Config])
    # Load the configuration first
    cfg = torch.load(config_path)

    loaded_tokenizer = AdvancedTokenizer(pretrained_model_name_or_path=tokenizer_path)

    # Instantiate the model using the loaded configuration
    loaded_model = ARSLM(loaded_tokenizer,
                         emb_dim=cfg.emb_dim,
                         hidden_dim=cfg.hidden_dim,
                         num_layers=cfg.num_layers,
                         dropout_prob=cfg.dropout_prob).to(device)

    loaded_model.load_state_dict(torch.load(model_path, map_location=device))
    loaded_model.eval() # Set model to evaluation mode

    print(f"Model loaded from {model_path}")
    print(f"Tokenizer loaded from {tokenizer_path}")
    print(f"Configuration loaded from {config_path}")

    return loaded_model, loaded_tokenizer, cfg

def deploy_model(model, tokenizer, config, device, input_text: str):
    """
    Function to deploy the model for inference (text generation).
    """
    model.eval()
    with torch.no_grad():
        idx = torch.tensor([ [tokenizer.bos_token_id] + tokenizer.encode(input_text) ], dtype=torch.long, device=device)
        out_ids = model.generate(idx, max_new_tokens=config.max_new_tokens, temperature=config.temperature, top_k=config.top_k)[0]
        generated_text = tokenizer.decode(out_ids)
    return generated_text

# ----------------------------
# Run demo if main
# ----------------------------
if __name__ == "__main__":
    random.seed(0)
    torch.manual_seed(0)
    model, tokenizer = train_demo()

    # Demonstrate loading the saved model and tokenizer
    print("\n=== Demonstrating Model Loading and Deployment ===")
    cfg_instance = Config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    try:
        loaded_model, loaded_tokenizer, loaded_config = load_model_and_tokenizer(
            os.path.join(cfg_instance.model_save_dir, "best_model.pt"),
            cfg_instance.tokenizer_save_dir,
            cfg_instance.config_save_path,
            device
        )
        # Test generation with loaded model using the deploy_model function
        context = "loaded model generates"
        generated_output = deploy_model(loaded_model, loaded_tokenizer, loaded_config, device, context)
        print("Generated with loaded model:")
        print(generated_output)
    except FileNotFoundError:
        print("No saved model found. Please ensure training completed successfully and saved a model.")


Corpus file not found at /tmp/large_corpus.txt. Using a small toy corpus instead.
Train size: 2, Validation size: 2, Test size: 1
Epoch 1/200 — Avg Train Loss: 11.5174
Current learning rate: 0.001000
Validation Loss: 10.5243
Saving best model with validation loss: 10.5243
Epoch 50/200 — Avg Train Loss: 0.3560
Current learning rate: 0.000500
Validation Loss: 13.7622
Epoch 100/200 — Avg Train Loss: 0.1040
Current learning rate: 0.000250
Validation Loss: 14.4926
Epoch 150/200 — Avg Train Loss: 0.1516
Current learning rate: 0.000125
Validation Loss: 14.9816
Epoch 200/200 — Avg Train Loss: 0.0833
Current learning rate: 0.000063
Validation Loss: 15.3237

=== Evaluation ===
Test Loss: 15.1545
Perplexity on test data: 3815011.5000

=== Generated ===
hello world is ars

=== Demonstrating Model Loading and Deployment ===
Model loaded from ./model_checkpoint/best_model.pt
Tokenizer loaded from ./tokenizer_checkpoint
Configuration loaded from ./model_checkpoint/config.pt
Generated with loaded mode
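
The perplexity line in the log above is the exponential of the cross-entropy loss. The sketch below, in plain Python, illustrates that exponentiating the mean loss and averaging per-batch exponentials agree for a single batch but diverge when there are several; the former is the usual definition. The two-element `losses` list is a hypothetical example, not data from this run, and both helper names are made up for illustration.

```python
import math

def perplexity(batch_losses):
    """exp of the mean per-batch loss (the usual definition)."""
    return math.exp(sum(batch_losses) / len(batch_losses))

def mean_batch_perplexity(batch_losses):
    """Mean of per-batch exp(loss); never smaller than perplexity()."""
    return sum(math.exp(l) for l in batch_losses) / len(batch_losses)

# Hypothetical per-batch losses, for illustration only
losses = [2.0, 4.0]
print(perplexity(losses))             # exp(3.0)
print(mean_batch_perplexity(losses))  # (exp(2.0) + exp(4.0)) / 2

# Single batch, using the test loss printed in the log above
print(perplexity([15.1545]))          # roughly 3.8 million
```

With a single test batch, as in this run, both computations give the same value, which matches the figure of about 3.8 million reported for a loss of 15.15.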

### Deploying the model for inference

To deploy the model and generate text, follow these steps:
1.  Load the model, tokenizer, and configuration with the `load_model_and_tokenizer` function.
2.  Call the `deploy_model` function with your input text (`input_text`).

In [None]:
# Initialize configuration (to get paths)
cfg = Config()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

try:
    # Load the trained model, tokenizer, and configuration
    loaded_model, loaded_tokenizer, loaded_config = load_model_and_tokenizer(
        os.path.join(cfg.model_save_dir, "best_model.pt"),
        cfg.tokenizer_save_dir,
        cfg.config_save_path,
        device
    )

    # Define your input text for generation
    input_text = "Colab is a great tool for"

    # Deploy the model to generate text
    generated_output = deploy_model(loaded_model, loaded_tokenizer, loaded_config, device, input_text)

    print(f"Input Text: {input_text}")
    print(f"Generated Output: {generated_output}")

except FileNotFoundError:
    print("Error: the model, tokenizer, or configuration was not found. Make sure training has run and the files were saved.")
except Exception as e:
    print(f"An unexpected error occurred during deployment: {e}")


Model loaded from ./model_checkpoint/best_model.pt
Tokenizer loaded from ./tokenizer_checkpoint
Configuration loaded from ./model_checkpoint/config.pt
Input Text: Colab is a great tool for
Generated Output: colab is a great tool for [unused276] modernism publishes [unused276] detrimental cote undeveloped [unused714]lok vanity deemed [unused714] fisheries marriage choke
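
The unrelated tokens in the generated output are consistent with the training log earlier in this notebook: validation loss was lowest at epoch 1 and rose afterwards, so `best_model.pt` holds the weights saved after the first epoch. A minimal sketch of the same keep-the-best rule, replayed on the validation losses printed in that log (`best_checkpoint` is a hypothetical helper, not part of the notebook):

```python
def best_checkpoint(val_losses):
    """Return (epoch, loss) of the checkpoint a 'save on improvement' rule keeps."""
    best_epoch, best_loss = None, float("inf")
    for epoch, loss in val_losses:
        if loss < best_loss:  # same comparison as in train_demo
            best_epoch, best_loss = epoch, loss
    return best_epoch, best_loss

# (epoch, validation loss) pairs copied from the training log
val_log = [(1, 10.5243), (50, 13.7622), (100, 14.4926), (150, 14.9816), (200, 15.3237)]
print(best_checkpoint(val_log))  # (1, 10.5243)
```

On the five-sentence toy corpus the training loss falls while the validation loss climbs, so the saved checkpoint has seen only one epoch of updates. Training on a larger corpus is probably a better remedy than changing the saving rule.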


## Implementing Model Saving and Loading

### Subtask:
Add functionality to `train_demo` to save the weights of the best trained model (based on a metric such as validation loss or perplexity) together with its tokenizer. Also write a separate function that loads a pre-trained model and its tokenizer, which is essential for production inference.


**Reasoning**:
I will define `model_save_dir` and `tokenizer_save_dir` in the `Config` class to specify where the model and tokenizer should be saved.



In [None]:
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
from transformers import BertTokenizer # Import BertTokenizer here
from torch.optim.lr_scheduler import StepLR # Import StepLR
import os # Import os for file operations
from torch.utils.data import Dataset, DataLoader # Import Dataset and DataLoader
from sklearn.model_selection import train_test_split # Import train_test_split

# ----------------------------
# Configuration Class
# ----------------------------
class Config:
    """Centralized configuration for the ARSLM model and training."""
    def __init__(self):
        # Model Parameters
        self.emb_dim = 64
        self.hidden_dim = 128
        self.num_layers = 2
        self.dropout_prob = 0.1

        # Training Parameters
        self.learning_rate = 1e-3
        self.n_epochs = 200
        self.batch_size = 8
        self.lr_scheduler_step_size = 50
        self.lr_scheduler_gamma = 0.5
        self.clip_grad_norm = 1.0

        # Generation Parameters
        self.max_new_tokens = 15
        self.temperature = 1.0
        self.top_k = 50
        self.generation_context = "hello world"

        # Data Parameters
        self.corpus_file = "/tmp/large_corpus.txt"
        self.max_corpus_lines = 1000 # Max lines to load from corpus_file for demo
        self.train_test_split_ratio = 0.2
        self.train_val_split_ratio = 0.5 # Validation share of the full corpus (rescaled to a train_val share in train_demo)
        self.random_state = 42

        # Saving/Loading Paths
        self.model_save_dir = "./model_checkpoint"
        self.tokenizer_save_dir = "./tokenizer_checkpoint"

# ----------------------------
# Advanced Tokenizer (using transformers)
# ----------------------------
class AdvancedTokenizer:
    def __init__(self, vocab_file=None, pretrained_model_name_or_path='bert-base-uncased'):
        if vocab_file:
            self.tokenizer = BertTokenizer(vocab_file)
        else:
            # from_pretrained accepts either a hub model name or a local directory
            # containing a tokenizer saved with save_pretrained
            self.tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path)

        self.vocab = list(self.tokenizer.vocab.keys())
        self.stoi = self.tokenizer.vocab
        self.itos = {i: w for w, i in self.stoi.items()}

        self.pad_token = self.tokenizer.pad_token
        self.unk_token = self.tokenizer.unk_token
        self.bos_token = self.tokenizer.cls_token # Using [CLS] as BOS for BERT-like tokenizers
        self.eos_token = self.tokenizer.sep_token # Using [SEP] as EOS for BERT-like tokenizers

        self.pad_token_id = self.tokenizer.pad_token_id
        self.unk_token_id = self.tokenizer.unk_token_id
        self.bos_token_id = self.tokenizer.cls_token_id
        self.eos_token_id = self.tokenizer.sep_token_id


    def encode(self, text: str) -> List[int]:
        return self.tokenizer.encode(text, add_special_tokens=False)

    def decode(self, ids: List[int]) -> str:
        return self.tokenizer.decode(ids, skip_special_tokens=True)

    def __len__(self):
        return len(self.vocab)

# ----------------------------
# ARSCell: core adaptive cell
# ----------------------------
class ARSCell(nn.Module):
    """
    ARSCell: computes next hidden state from h_{t-2}, h_{t-1}, and input embedding x_{t-1}
    Implements an 'adapt' mechanism inspired by BenPolySeq:
      h_t = h_{t-1} + gate * transform(h_{t-1}, h_{t-2}, x)
    where gate is a function of the difference (h_{t-1} - h_{t-2}) and context.
    """
    def __init__(self, emb_dim: int, hidden_dim: int, dropout_prob: float = 0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        # transform candidate from concat([h_{t-1}, h_{t-2}, x])
        self.candidate_mlp = nn.Sequential(
            nn.Linear(hidden_dim*2 + emb_dim, hidden_dim*2),
            nn.ReLU(),
            nn.Linear(hidden_dim*2, hidden_dim)
        )
        # gate network produces scalar gating factor in (0,1)
        self.gate_net = nn.Sequential(
            nn.Linear(hidden_dim*2 + emb_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        # small residual projection
        self.res_proj = nn.Linear(emb_dim, hidden_dim)
        # Dropout layer
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, h_prev2, h_prev1, x_embed):
        # h_prev2, h_prev1: (batch, hidden_dim)
        # x_embed: (batch, emb_dim)
        # difference signal (currently unused: the gate sees both states through ctx)
        diff = h_prev1 - h_prev2    # (batch, hidden_dim)
        # context vector (concat)
        ctx = torch.cat([h_prev1, h_prev2, x_embed], dim=-1)
        candidate = self.candidate_mlp(ctx)                       # (batch, hidden_dim)
        gate = self.gate_net(ctx).squeeze(-1)                     # (batch,)
        # incorporate residual from input embedding
        residual = self.res_proj(x_embed)
        # ARS-style update: adaptive addition of candidate scaled by gate
        h_t = h_prev1 + gate.unsqueeze(-1) * candidate + 0.1 * residual
        # Apply dropout
        h_t = self.dropout(h_t)
        # optional normalization
        h_t = F.layer_norm(h_t, (self.hidden_dim,))
        return h_t, gate

# ----------------------------
# ARSLM model: embed -> ARSCell(s) -> attention -> head
# ----------------------------
class ARSLM(nn.Module):
    # Modified to accept a tokenizer instance, include attention, and support multiple layers
    def __init__(self, tokenizer, emb_dim=64, hidden_dim=128, num_layers=2, dropout_prob=0.1):
        super().__init__()
        self.tokenizer = tokenizer # Store tokenizer instance
        self.vocab_size = len(tokenizer) # Get vocab size from tokenizer
        self.num_layers = num_layers
        self.emb = nn.Embedding(self.vocab_size, emb_dim)
        # Create a list of ARSCells
        # Pass dropout_prob to ARSCell constructor
        self.cells = nn.ModuleList([ARSCell(emb_dim if i == 0 else hidden_dim, hidden_dim, dropout_prob=dropout_prob) for i in range(num_layers)])

        # Simple additive attention mechanism, applied after the last layer's hidden states
        self.attention = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1)
        )
        self.head = nn.Linear(hidden_dim, self.vocab_size)

    def forward(self, input_ids):
        """
        input_ids: (batch, seq_len)
        We compute autoregressively, passing hidden states through layers.
        Returns logits (batch, seq_len, vocab).
        """
        bsz, seq_len = input_ids.shape
        emb = self.emb(input_ids)  # (b, seq, emb)
        device = emb.device
        hidden_dim = self.cells[0].hidden_dim

        # Initialize hidden states for each layer
        # We need two previous states for each layer's ARSCell
        h_prev2_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]
        h_prev1_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]

        # Store hidden states for causal attention (from the last layer)
        all_last_layer_hidden_states = []
        logits = []
        gates = [] # Store gates from the last layer

        for t in range(seq_len):
            x_t = emb[:, t, :] # embedding at position t
            h_t_input = x_t # Input to the first layer is the embedding

            current_layer_hidden_states = []

            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                # The input to subsequent layers is the hidden state of the previous layer
                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1] # Use the output of the previous layer

                h_t, gate = cell(h_prev2, h_prev1, h_t_input)

                # Keep this layer's output: it is the input to the next layer at the same time step
                current_layer_hidden_states.append(h_t)

                # Update previous hidden states for the next time step for this layer
                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1: # Only store gates from the last layer
                    gates.append(gate.unsqueeze(1))

            # After processing all layers for time step t, the output is the hidden state of the last layer
            last_layer_h_t = current_layer_hidden_states[-1]
            all_last_layer_hidden_states.append(last_layer_h_t.unsqueeze(1))

            # Apply attention and head after the last layer's hidden state
            # Causal Attention: attend over hidden states up to current time step t from the last layer
            last_layer_history = torch.cat(all_last_layer_hidden_states, dim=1) # (b, t+1, hidden_dim)
            query = last_layer_h_t.unsqueeze(1) # (b, 1, hidden_dim) - query is the current last layer hidden state

            # Calculate attention scores
            scores = self.attention(last_layer_history) # (b, t+1, 1)
            attention_weights = F.softmax(scores, dim=1) # (b, t+1, 1)

            # Apply attention weights
            context_vector = torch.sum(attention_weights * last_layer_history, dim=1) # (b, hidden_dim)

            # Combine last layer's current hidden state with context vector before the head
            attended_h_t = last_layer_h_t + context_vector

            logit = self.head(attended_h_t) # logits for the token that follows position t
            logits.append(logit.unsqueeze(1))


        logits = torch.cat(logits, dim=1)          # (b, seq, vocab)
        gates = torch.cat(gates, dim=1)            # (b, seq) # Gates from the last layer
        return logits, gates

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=20, temperature=1.0, top_k=None):
        """
        idx: (batch, current_seq_len) - input context (token ids)
        Returns extended token ids.
        Modified for multiple layers, attention during generation, and optional top-k sampling.
        """
        self.eval()
        bsz, seq_len = idx.shape
        device = next(self.parameters()).device
        hidden_dim = self.cells[0].hidden_dim

        # Initialize hidden states for each layer
        h_prev2_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]
        h_prev1_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]

        # Store hidden states for causal attention (from the last layer)
        all_last_layer_hidden_states = []

        # Process initial context to get starting hidden states for each layer
        input_ids = idx.clone()
        emb = self.emb(input_ids) # (b, seq_len, emb_dim)

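        # Stop before the last context token: the generation loop below feeds it
        # through the cells itself, so processing it here would apply it twice.
        for t in range(seq_len - 1):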
            x_t = emb[:,t,:] # embedding at position t
            h_t_input = x_t

            current_layer_hidden_states = []
            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1]

                h_t, _ = cell(h_prev2, h_prev1, h_t_input)
                current_layer_hidden_states.append(h_t)

                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1:
                    # Add the new last layer hidden state to the history for attention
                    all_last_layer_hidden_states.append(h_t.unsqueeze(1))


        out_ids = input_ids.tolist()

        for _ in range(max_new_tokens):
            # Input for the first layer's ARSCell is the embedding of the last generated token
            last_token_ids = torch.tensor([ids[-1] for ids in out_ids], dtype=torch.long, device=device).unsqueeze(1)
            x_embed = self.emb(last_token_ids).squeeze(1) # (b, emb_dim)

            h_t_input = x_embed
            current_layer_hidden_states = []

            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1]

                h_t, _ = cell(h_prev2, h_prev1, h_t_input)
                current_layer_hidden_states.append(h_t)

                # Update previous hidden states for the next generation step for this layer
                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1:
                    # Add the new last layer hidden state to the history for attention
                    all_last_layer_hidden_states.append(h_t.unsqueeze(1))


            # After processing all layers for the new token, apply attention and head
            last_layer_h_t = current_layer_hidden_states[-1] # Hidden state of the last layer for the new token

            # Causal Attention over all last layer hidden states history
            last_layer_history = torch.cat(all_last_layer_hidden_states, dim=1) # (b, current_seq_len + 1, hidden_dim)
            query = last_layer_h_t.unsqueeze(1) # (b, 1, hidden_dim)

            scores = self.attention(last_layer_history) # (b, current_seq_len + 1, 1)
            attention_weights = F.softmax(scores, dim=1) # (b, current_seq_len + 1, 1)

            context_vector = torch.sum(attention_weights * last_layer_history, dim=1) # (b, hidden_dim)

            attended_h_t = last_layer_h_t + context_vector

            logits = self.head(attended_h_t) / max(1e-6, temperature)

            # Apply top-k sampling
            if top_k is not None:
                # get top k logits
                top_k = min(max(top_k, 1), logits.size(-1))  # Clamp k to be at least 1 and at most vocab size
                # Remove all tokens with a probability less than the top-k'th token's probability
                v, _ = torch.topk(logits, top_k)
                logits[logits < v[:, [-1]]] = -float('Inf')

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)


            # append the newly generated token
            for i in range(bsz):
                out_ids[i].append(int(next_token[i].item()))


        return out_ids

# ----------------------------
# TextDataset class for structured data handling
# ----------------------------
class TextDataset(Dataset):
    def __init__(self, texts: list, tokenizer):
        self.texts = texts
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.texts[idx]

# Modified to accept and use the tokenizer instance
def collate_batch(batch_texts: List[str], tokenizer):
    # Encode each text, add BOS/EOS, and convert to tensor
    encoded = []
    for t in batch_texts:
        ids = [tokenizer.bos_token_id] + tokenizer.encode(t) + [tokenizer.eos_token_id]
        encoded.append(torch.tensor(ids, dtype=torch.long))

    # Pad to max len within the batch
    max_len = max([x.size(0) for x in encoded])
    padded_batch = torch.full((len(encoded), max_len), tokenizer.pad_token_id, dtype=torch.long)
    for i, x in enumerate(encoded):
        padded_batch[i, :x.size(0)] = x

    # Prepare inputs and targets for language modeling
    inputs = padded_batch[:, :-1]
    targets = padded_batch[:, 1:]

    return inputs, targets


# Modified to use AdvancedTokenizer and handle multiple layers, and DataLoaders
def train_demo():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize configuration
    cfg = Config()

    corpus_file = cfg.corpus_file
    texts = []
    if os.path.exists(corpus_file):
        print(f"Loading corpus from {corpus_file}...")
        with open(corpus_file, 'r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                if i < cfg.max_corpus_lines:
                    texts.append(line.strip())
                else:
                    break
        print(f"Loaded {len(texts)} lines.")
    else:
        print(f"Corpus file not found at {corpus_file}. Using a small toy corpus instead.")
        texts = [
            "hello world this is ars",
            "the system adapts to its history",
            "benpolyseq demonstrates adaptive sequences",
            "ars can inspire new network protocols",
            "self optimizing systems are possible"
        ]

    tokenizer = AdvancedTokenizer()
    vocab_size = len(tokenizer)

    # Modified Data Splitting Logic
    num_samples = len(texts)
    if num_samples < 3:
        print(f"Warning: Corpus size ({num_samples}) is too small for a proper train/validation/test split. Using all data for training.")
        train_texts = texts
        val_texts = []
        test_texts = []
    else:
        # First split: train_val and test
        train_val_texts, test_texts = train_test_split(texts, test_size=cfg.train_test_split_ratio, random_state=cfg.random_state)
        # Defensive check for test_texts being empty
        if not test_texts and train_val_texts:
            test_texts.append(train_val_texts.pop())

        # Second split: train and validation from train_val
        # train_val_split_ratio is a share of the full corpus; rescale it to a share of train_val, e.g. 0.5 / (1 - 0.2) = 0.625
        if len(train_val_texts) >= 2: # Need at least two samples to split into train and val
            val_split_size = cfg.train_val_split_ratio / (1.0 - cfg.train_test_split_ratio)
            train_texts, val_texts = train_test_split(train_val_texts, test_size=val_split_size, random_state=cfg.random_state)
            # Defensive check for val_texts being empty
            if not val_texts and train_texts:
                val_texts.append(train_texts.pop())
        else:
            train_texts = train_val_texts
            val_texts = []

    print(f"Train size: {len(train_texts)}, Validation size: {len(val_texts)}, Test size: {len(test_texts)}")

    # Create datasets
    train_dataset = TextDataset(train_texts, tokenizer)
    val_dataset = TextDataset(val_texts, tokenizer)
    test_dataset = TextDataset(test_texts, tokenizer)

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True,
                                collate_fn=lambda b: collate_batch(b, tokenizer))
    val_loader = DataLoader(val_dataset, batch_size=cfg.batch_size, shuffle=False,
                              collate_fn=lambda b: collate_batch(b, tokenizer))
    test_loader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False,
                               collate_fn=lambda b: collate_batch(b, tokenizer))

    model = ARSLM(tokenizer, emb_dim=cfg.emb_dim, hidden_dim=cfg.hidden_dim, num_layers=cfg.num_layers, dropout_prob=cfg.dropout_prob).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate)
    scheduler = StepLR(optimizer, step_size=cfg.lr_scheduler_step_size, gamma=cfg.lr_scheduler_gamma)
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    n_epochs = cfg.n_epochs
    if len(train_texts) == 0: # Nothing to train on
        print("No training data available. Skipping training.")
        return model, tokenizer

    # Initialize best validation loss for saving the best model
    best_val_loss = float('inf')
    # Create directories for saving if they don't exist
    os.makedirs(cfg.model_save_dir, exist_ok=True)
    os.makedirs(cfg.tokenizer_save_dir, exist_ok=True)

    for epoch in range(n_epochs):
        model.train()
        total_loss = 0
        num_batches = len(train_loader)
        if num_batches > 0:
            for batch_idx, (inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(device), targets.to(device)
                logits, gates = model(inputs)
                b, seq, v = logits.shape
                loss = loss_fn(logits.view(b*seq, v), targets.reshape(b*seq))
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_grad_norm)
                optimizer.step()
                total_loss += loss.item()
            avg_train_loss = total_loss / num_batches
        else:
            avg_train_loss = float('inf') # Set to inf if no training batches

        scheduler.step()

        if (epoch+1) % 50 == 0 or epoch==0:
            print(f"Epoch {epoch+1}/{n_epochs} — Avg Train Loss: {avg_train_loss:.4f}")
            print(f"Current learning rate: {scheduler.get_last_lr()[0]:.6f}")

            # Validation step with check for empty loader
            if len(val_loader) > 0:
                model.eval()
                total_val_loss = 0
                with torch.no_grad():
                    for inputs_val, targets_val in val_loader:
                        inputs_val, targets_val = inputs_val.to(device), targets_val.to(device)
                        logits_val, _ = model(inputs_val)
                        b_val, seq_val, v_val = logits_val.shape
                        loss_val = loss_fn(logits_val.view(b_val*seq_val, v_val), targets_val.reshape(b_val*seq_val))
                        total_val_loss += loss_val.item()
                avg_val_loss = total_val_loss / len(val_loader)
                print(f"Validation Loss: {avg_val_loss:.4f}")

                # Save the model if this is the best validation loss so far
                if avg_val_loss < best_val_loss:
                    best_val_loss = avg_val_loss
                    print(f"Saving best model with validation loss: {best_val_loss:.4f}")
                    torch.save(model.state_dict(), os.path.join(cfg.model_save_dir, "best_model.pt"))
                    tokenizer.tokenizer.save_pretrained(cfg.tokenizer_save_dir)

            else:
                print("No validation data available. Skipping validation.")
                avg_val_loss = float('inf')

    print("\n=== Evaluation ===")
    model.eval()
    total_test_loss = 0
    total_perplexity = 0
    if len(test_loader) > 0:
        with torch.no_grad():
            for inputs_test, targets_test in test_loader:
                inputs_test, targets_test = inputs_test.to(device), targets_test.to(device)
                logits_test, _ = model(inputs_test)
                b_test, seq_test, v_test = logits_test.shape
                loss_test = loss_fn(logits_test.view(b_test*seq_test, v_test), targets_test.reshape(b_test*seq_test))
                total_test_loss += loss_test.item()

        avg_test_loss = total_test_loss / len(test_loader)
        # Perplexity is exp(mean loss); averaging per-batch exp(loss) would overstate it
        avg_test_perplexity = torch.exp(torch.tensor(avg_test_loss)).item()
        print(f"Test Loss: {avg_test_loss:.4f}")
        print(f"Perplexity on test data: {avg_test_perplexity:.4f}")
    else:
        print("No test data available. Skipping test evaluation.")
        avg_test_loss = float('inf')
        avg_test_perplexity = float('inf')

    # demo generation
    context = cfg.generation_context
    idx = torch.tensor([ [tokenizer.bos_token_id] + tokenizer.encode(context) ], dtype=torch.long, device=device)
    out_ids = model.generate(idx, max_new_tokens=cfg.max_new_tokens, temperature=cfg.temperature, top_k=cfg.top_k)[0]
    print("\n=== Generated ===")
    print(tokenizer.decode(out_ids))
    return model, tokenizer

def load_model_and_tokenizer(model_path: str, tokenizer_path: str, device, config: Config):
    """
    Loads a pre-trained ARSLM model and its tokenizer.
    """
    # Load tokenizer first to get vocab size for model initialization
    loaded_tokenizer = AdvancedTokenizer(pretrained_model_name_or_path=tokenizer_path)

    # Instantiate the model with the same configuration as the saved model
    loaded_model = ARSLM(loaded_tokenizer,
                         emb_dim=config.emb_dim,
                         hidden_dim=config.hidden_dim,
                         num_layers=config.num_layers,
                         dropout_prob=config.dropout_prob).to(device)

    # Load the model's state_dict
    loaded_model.load_state_dict(torch.load(model_path, map_location=device))
    loaded_model.eval() # Set model to evaluation mode

    print(f"Model loaded from {model_path}")
    print(f"Tokenizer loaded from {tokenizer_path}")

    return loaded_model, loaded_tokenizer

# ----------------------------
# Run demo if main
# ----------------------------
if __name__ == "__main__":
    random.seed(0)
    torch.manual_seed(0)
    model, tokenizer = train_demo()

    # Demonstrate loading the saved model and tokenizer
    print("\n=== Demonstrating Model Loading ===")
    cfg = Config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    try:
        loaded_model, loaded_tokenizer = load_model_and_tokenizer(
            os.path.join(cfg.model_save_dir, "best_model.pt"),
            cfg.tokenizer_save_dir,
            device,
            cfg
        )
        # Test generation with loaded model
        context = "loaded model generates"
        idx = torch.tensor([ [loaded_tokenizer.bos_token_id] + loaded_tokenizer.encode(context) ], dtype=torch.long, device=device)
        out_ids_loaded = loaded_model.generate(idx, max_new_tokens=15, temperature=1.0, top_k=50)[0]
        print("Generated with loaded model:")
        print(loaded_tokenizer.decode(out_ids_loaded))
    except FileNotFoundError:
        print("No saved model found. Please ensure training completed successfully and saved a model.")


Corpus file not found at /tmp/large_corpus.txt. Using a small toy corpus instead.
Train size: 1, Validation size: 3, Test size: 1
Epoch 1/200 — Avg Train Loss: 11.0367
Current learning rate: 0.001000
Validation Loss: 10.9150
Saving best model with validation loss: 10.9150
Epoch 50/200 — Avg Train Loss: 0.0229
Current learning rate: 0.000500
Validation Loss: 15.2596
Epoch 100/200 — Avg Train Loss: 0.0019
Current learning rate: 0.000250
Validation Loss: 16.5266
Epoch 150/200 — Avg Train Loss: 0.0027
Current learning rate: 0.000125
Validation Loss: 16.8605
Epoch 200/200 — Avg Train Loss: 0.0017
Current learning rate: 0.000063
Validation Loss: 16.9671

=== Evaluation ===
Test Loss: 17.3669
Perplexity on test data: 34861304.0000

=== Generated ===
hello world is ars

=== Demonstrating Model Loading ===
Model loaded from ./model_checkpoint/best_model.pt
Tokenizer loaded from ./tokenizer_checkpoint
Generated with loaded model:
loaded model generates knockoutal [unused485] modifyingpers mammot
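
The evaluation figures above are linked by perplexity = exp(cross-entropy loss). The sketch below is a quick sanity check that uses only numbers printed in this log. The 30,522-token vocabulary size is an assumption: it is the size of `bert-base-uncased`, the tokenizer this notebook uses by default. It recovers the reported perplexity from the test loss and compares that loss with a uniform guess over the vocabulary:

```python
import math

# Figures copied from the training log above.
test_loss = 17.3669
final_train_loss = 0.0017

# Assumption: the run used the bert-base-uncased vocabulary (30,522 tokens).
vocab_size = 30522

# Perplexity is the exponential of the mean cross-entropy loss.
perplexity = math.exp(test_loss)

# A model that guesses uniformly over the vocabulary has loss ln(vocab_size).
uniform_loss = math.log(vocab_size)

print(f"perplexity from test loss: {perplexity:,.0f}")
print(f"uniform-guess loss:        {uniform_loss:.4f}")
print(f"test loss exceeds uniform baseline: {test_loss > uniform_loss}")
print(f"train/test loss gap:       {test_loss - final_train_loss:.4f}")
```

A test loss well above the uniform baseline, next to a near-zero training loss, suggests the model memorized its single training sentence. The log also shows the "best" checkpoint being saved only at epoch 1, which likely explains why the reloaded model emits near-random tokens.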

# Task
**Task:** Download the "wikitext-2-raw-v1" dataset from Hugging Face, extract the text, and save the first 10,000 lines into the `/tmp/large_corpus.txt` file. Then, modify the `Config` class to set `max_corpus_lines` to 10,000 to utilize this larger corpus for training.

## Find and Download a Text Corpus

### Subtask:
Download the 'wikitext-2-raw-v1' dataset from Hugging Face, extract the text, and save the first 10,000 lines into the `/tmp/large_corpus.txt` file.


**Reasoning**:
The subtask requires downloading a dataset from Hugging Face, which needs the `datasets` library. Therefore, the first step is to install this library.



In [None]:
%pip install datasets



**Reasoning**:
Now that the `datasets` library is confirmed to be installed, I will proceed to download the 'wikitext-2-raw-v1' dataset, extract its training text, filter out empty lines, and save the first 10,000 non-empty lines to `/tmp/large_corpus.txt` as specified in the subtask.



In [None]:
from datasets import load_dataset
import os

# Define the output file path
corpus_file_path = "/tmp/large_corpus.txt"

# Load the 'wikitext-2-raw-v1' dataset
print("Loading wikitext-2-raw-v1 dataset from Hugging Face...")
dataset = load_dataset('wikitext', 'wikitext-2-raw-v1')

# Access the 'train' split and extract 'text' column
train_texts = dataset['train']['text']

# Filter out empty strings and strip whitespace
filtered_texts = [line.strip() for line in train_texts if line.strip()]

# Ensure the directory exists
os.makedirs(os.path.dirname(corpus_file_path), exist_ok=True)

# Write the first 10,000 non-empty lines to the specified file
print(f"Writing first 10,000 lines to {corpus_file_path}...")
with open(corpus_file_path, 'w', encoding='utf-8') as f:
    for line in filtered_texts[:10000]:
        f.write(line + '\n')
print("Dataset extraction and saving complete.")

Loading wikitext-2-raw-v1 dataset from Hugging Face...



Writing first 10,000 lines to /tmp/large_corpus.txt...
Dataset extraction and saving complete.
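
The same filter-and-truncate step can be written so that it stops reading once enough lines are collected, instead of stripping the whole split first. A minimal sketch, using a small in-memory list in place of the dataset column; `take_nonempty` is a hypothetical helper, not part of the notebook:

```python
from itertools import islice
from typing import Iterable, List


def take_nonempty(lines: Iterable[str], limit: int) -> List[str]:
    """Return the first `limit` non-empty lines, stripped of surrounding whitespace."""
    stripped = (line.strip() for line in lines)
    nonempty = (line for line in stripped if line)
    return list(islice(nonempty, limit))


# Stand-in for dataset['train']['text']; the blank entries mimic rows the notebook filters out.
sample = ["", " = Heading = \n", "  ", "First sentence .\n", "Second sentence .\n"]
kept = take_nonempty(sample, 2)
print(kept)
```

Because both generators are lazy, `islice` pulls only as many rows as needed to fill the limit.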


**Reasoning**:
The subtask required saving the first 10,000 lines of the dataset to `/tmp/large_corpus.txt`. Now that this file is created, I need to update the `Config` class to reflect this by setting `max_corpus_lines` to 10,000. This will ensure that the `train_demo` function, when called later, correctly loads the larger corpus for training.
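
With 10,000 lines loaded, the two split ratios in `Config` compose rather than add: `train_test_split_ratio` is taken from the whole corpus, and `train_val_split_ratio` from what remains. A rough sketch of the resulting sizes, assuming simple rounding (scikit-learn's own rounding may differ by a sample when the sizes do not divide evenly); `split_sizes` is a hypothetical helper:

```python
from typing import Tuple


def split_sizes(n: int, test_ratio: float, val_ratio: float) -> Tuple[int, int, int]:
    """Return (train, val, test) sizes for a two-stage split."""
    n_test = round(n * test_ratio)      # first split: held out from the whole corpus
    n_pool = n - n_test                 # remaining train+val pool
    n_val = round(n_pool * val_ratio)   # second split: held out from the pool
    return n_pool - n_val, n_val, n_test


sizes = split_sizes(10_000, test_ratio=0.2, val_ratio=0.5)
print(sizes)
```

With ratios 0.2 and 0.5 the validation set is 40% of the corpus, not 10%. Holding out 10% of the total for validation would need a second ratio of 0.125 (0.1 / 0.8).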



In [None]:
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
from transformers import BertTokenizer # Import BertTokenizer here
from torch.optim.lr_scheduler import StepLR # Import StepLR
import os # Import os for file operations
from torch.utils.data import Dataset, DataLoader # Import Dataset and DataLoader
from sklearn.model_selection import train_test_split # Import train_test_split

# ----------------------------
# Configuration Class
# ----------------------------
class Config:
    """Centralized configuration for the ARSLM model and training."""
    def __init__(self):
        # Model Parameters
        self.emb_dim = 64
        self.hidden_dim = 128
        self.num_layers = 2
        self.dropout_prob = 0.1

        # Training Parameters
        self.learning_rate = 1e-3
        self.n_epochs = 200
        self.batch_size = 8
        self.lr_scheduler_step_size = 50
        self.lr_scheduler_gamma = 0.5
        self.clip_grad_norm = 1.0

        # Generation Parameters
        self.max_new_tokens = 15
        self.temperature = 1.0
        self.top_k = 50
        self.generation_context = "hello world"

        # Data Parameters
        self.corpus_file = "/tmp/large_corpus.txt"
        self.max_corpus_lines = 10000 # Max lines to load from corpus_file
        self.train_test_split_ratio = 0.2 # Fraction of all lines held out for testing
        self.train_val_split_ratio = 0.5 # Fraction of the remaining 80% used for validation (0.4 of total)
        self.random_state = 42

        # Saving/Loading Paths
        self.model_save_dir = "./model_checkpoint"
        self.tokenizer_save_dir = "./tokenizer_checkpoint"
        self.config_save_path = os.path.join(self.model_save_dir, "config.pt") # Path to save config

# ----------------------------
# Advanced Tokenizer (using transformers)
# ----------------------------
class AdvancedTokenizer:
    def __init__(self, vocab_file=None, pretrained_model_name_or_path='bert-base-uncased'):
        if vocab_file:
             self.tokenizer = BertTokenizer(vocab_file)
        else:
             # from_pretrained accepts either a hub model name or a directory holding a saved tokenizer
             self.tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path)

        self.vocab = list(self.tokenizer.vocab.keys())
        self.stoi = self.tokenizer.vocab
        self.itos = {i: w for w, i in self.stoi.items()}

        self.pad_token = self.tokenizer.pad_token
        self.unk_token = self.tokenizer.unk_token
        self.bos_token = self.tokenizer.cls_token # Using [CLS] as BOS for BERT-like tokenizers
        self.eos_token = self.tokenizer.sep_token # Using [SEP] as EOS for BERT-like tokenizers

        self.pad_token_id = self.tokenizer.pad_token_id
        self.unk_token_id = self.tokenizer.unk_token_id
        self.bos_token_id = self.tokenizer.cls_token_id
        self.eos_token_id = self.tokenizer.sep_token_id


    def encode(self, text: str) -> List[int]:
        return self.tokenizer.encode(text, add_special_tokens=False)

    def decode(self, ids: List[int]) -> str:
        return self.tokenizer.decode(ids, skip_special_tokens=True)

    def __len__(self):
        return len(self.vocab)

# ----------------------------
# ARSCell: core adaptive cell
# ----------------------------
class ARSCell(nn.Module):
    """
    ARSCell: computes the next hidden state h_t from h_{t-2}, h_{t-1}, and the current input embedding x_t
    Implements an 'adapt' mechanism inspired by BenPolySeq:
      h_t = LayerNorm(Dropout(h_{t-1} + gate * transform(h_{t-1}, h_{t-2}, x_t) + 0.1 * res_proj(x_t)))
    where gate is a scalar in (0,1) computed from the concatenated context [h_{t-1}, h_{t-2}, x_t].
    """
    def __init__(self, emb_dim: int, hidden_dim: int, dropout_prob: float = 0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        # transform candidate from concat([h_{t-1}, h_{t-2}, x])
        self.candidate_mlp = nn.Sequential(
            nn.Linear(hidden_dim*2 + emb_dim, hidden_dim*2),
            nn.ReLU(),
            nn.Linear(hidden_dim*2, hidden_dim)
        )
        # gate network produces scalar gating factor in (0,1)
        self.gate_net = nn.Sequential(
            nn.Linear(hidden_dim*2 + emb_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        # small residual projection
        self.res_proj = nn.Linear(emb_dim, hidden_dim)
        # Dropout layer
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, h_prev2, h_prev1, x_embed):
        # h_prev2, h_prev1: (batch, hidden_dim)
        # x_embed: (batch, emb_dim)
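        # difference signal (computed for reference only; nothing below consumes it,
        # because candidate_mlp and gate_net receive h_prev1 and h_prev2 directly via ctx)
        diff = h_prev1 - h_prev2    # (batch, hidden_dim)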
        # context vector (concat)
        ctx = torch.cat([h_prev1, h_prev2, x_embed], dim=-1)
        candidate = self.candidate_mlp(ctx)                       # (batch, hidden_dim)
        gate = self.gate_net(ctx).squeeze(-1)                     # (batch,)
        # incorporate residual from input embedding
        residual = self.res_proj(x_embed)
        # ARS-style update: adaptive addition of candidate scaled by gate
        h_t = h_prev1 + gate.unsqueeze(-1) * candidate + 0.1 * residual
        # Apply dropout
        h_t = self.dropout(h_t)
        # optional normalization
        h_t = F.layer_norm(h_t, (self.hidden_dim,))
        return h_t, gate

# ----------------------------
# ARSLM model: embed -> ARSCell(s) -> attention -> head
# ----------------------------
class ARSLM(nn.Module):
    # Modified to accept a tokenizer instance, include attention, and support multiple layers
    def __init__(self, tokenizer, emb_dim=64, hidden_dim=128, num_layers=2, dropout_prob=0.1):
        super().__init__()
        self.tokenizer = tokenizer # Store tokenizer instance
        self.vocab_size = len(tokenizer) # Get vocab size from tokenizer
        self.num_layers = num_layers
        self.emb = nn.Embedding(self.vocab_size, emb_dim)
        # Create a list of ARSCells
        # Pass dropout_prob to ARSCell constructor
        self.cells = nn.ModuleList([ARSCell(emb_dim if i == 0 else hidden_dim, hidden_dim, dropout_prob=dropout_prob) for i in range(num_layers)])

        # Simple additive attention mechanism, applied after the last layer's hidden states
        self.attention = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1)
        )
        self.head = nn.Linear(hidden_dim, self.vocab_size)

    def forward(self, input_ids):
        """
        input_ids: (batch, seq_len)
        We compute autoregressively, passing hidden states through layers.
        Returns logits (batch, seq_len, vocab).
        """
        bsz, seq_len = input_ids.shape
        emb = self.emb(input_ids)  # (b, seq, emb)
        device = emb.device
        hidden_dim = self.cells[0].hidden_dim

        # Initialize hidden states for each layer
        # We need two previous states for each layer's ARSCell
        h_prev2_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]
        h_prev1_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]

        # Store hidden states for causal attention (from the last layer)
        all_last_layer_hidden_states = []
        logits = []
        gates = [] # Store gates from the last layer

        for t in range(seq_len):
            x_t = emb[:, t, :] # embedding at position t
            h_t_input = x_t # Input to the first layer is the embedding

            current_layer_hidden_states = []

            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1] # Use the output of the previous layer

                h_t, gate = cell(h_prev2, h_prev1, h_t_input)

                # Store this layer's hidden state; it is the input to the next layer at the same time step
                current_layer_hidden_states.append(h_t)

                # Update previous hidden states for the next time step for this layer
                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1: # Only store gates from the last layer
                    gates.append(gate.unsqueeze(1))

            # After processing all layers for time step t, the output is the hidden state of the last layer
            last_layer_h_t = current_layer_hidden_states[-1]
            all_last_layer_hidden_states.append(last_layer_h_t.unsqueeze(1))

            # Apply attention and head after the last layer's hidden state
            # Causal Attention: attend over hidden states up to current time step t from the last layer
            last_layer_history = torch.cat(all_last_layer_hidden_states, dim=1) # (b, t+1, hidden_dim)
            query = last_layer_h_t.unsqueeze(1) # (b, 1, hidden_dim) - currently unused: the scores below depend only on each history state

            # Calculate attention scores
            scores = self.attention(last_layer_history) # (b, t+1, 1)
            attention_weights = F.softmax(scores, dim=1) # (b, t+1, 1)

            # Apply attention weights
            context_vector = torch.sum(attention_weights * last_layer_history, dim=1) # (b, hidden_dim)

            # Combine last layer's current hidden state with context vector before the head
            attended_h_t = last_layer_h_t + context_vector

            logit = self.head(attended_h_t) # logits for the token at position t+1 (targets are the inputs shifted by one)
            logits.append(logit.unsqueeze(1))


        logits = torch.cat(logits, dim=1)          # (b, seq, vocab)
        gates = torch.cat(gates, dim=1)            # (b, seq) # Gates from the last layer
        return logits, gates

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=20, temperature=1.0, top_k=None):
        """
        idx: (batch, current_seq_len) - input context (token ids)
        Returns extended token ids.
        Modified for multiple layers, attention during generation, and optional top-k sampling.
        """
        self.eval()
        bsz, seq_len = idx.shape
        device = next(self.parameters()).device
        hidden_dim = self.cells[0].hidden_dim

        # Initialize hidden states for each layer
        h_prev2_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]
        h_prev1_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]

        # Store hidden states for causal attention (from the last layer)
        all_last_layer_hidden_states = []

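        # Warm up the hidden states on the context, excluding its last token:
        # the generation loop below feeds that token itself, so including it here would process it twice.
        input_ids = idx.clone()
        emb = self.emb(input_ids) # (b, seq_len, emb_dim)

        for t in range(seq_len - 1):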
            x_t = emb[:,t,:] # embedding at position t
            h_t_input = x_t

            current_layer_hidden_states = []
            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1]

                h_t, _ = cell(h_prev2, h_prev1, h_t_input)
                current_layer_hidden_states.append(h_t)

                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1:
                    # Add the new last layer hidden state to the history for attention
                    all_last_layer_hidden_states.append(h_t.unsqueeze(1))


        out_ids = input_ids.tolist()

        for _ in range(max_new_tokens):
            # Input for the first layer's ARSCell is the embedding of the last generated token
            last_token_ids = torch.tensor([ids[-1] for ids in out_ids], dtype=torch.long, device=device).unsqueeze(1)
            x_embed = self.emb(last_token_ids).squeeze(1) # (b, emb_dim)

            h_t_input = x_embed
            current_layer_hidden_states = []

            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1]

                h_t, _ = cell(h_prev2, h_prev1, h_t_input)
                current_layer_hidden_states.append(h_t)

                # Update previous hidden states for the next generation step for this layer
                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1:
                    # Add the new last layer hidden state to the history for attention
                    all_last_layer_hidden_states.append(h_t.unsqueeze(1))


            # After processing all layers for the new token, apply attention and head
            last_layer_h_t = current_layer_hidden_states[-1] # Hidden state of the last layer for the new token

            # Causal Attention over all last layer hidden states history
            last_layer_history = torch.cat(all_last_layer_hidden_states, dim=1) # (b, current_seq_len + 1, hidden_dim)
            query = last_layer_h_t.unsqueeze(1) # (b, 1, hidden_dim) - currently unused: the scores below depend only on each history state

            scores = self.attention(last_layer_history) # (b, current_seq_len + 1, 1)
            attention_weights = F.softmax(scores, dim=1) # (b, current_seq_len + 1, 1)

            context_vector = torch.sum(attention_weights * last_layer_history, dim=1) # (b, hidden_dim)

            attended_h_t = last_layer_h_t + context_vector

            logits = self.head(attended_h_t) / max(1e-6, temperature)

            # Apply top-k sampling
            if top_k is not None:
                # get top k logits
                top_k = min(max(top_k, 1), logits.size(-1))  # Clamp k to be at least 1 and at most vocab size
                # Remove all tokens with a probability less than the top-k'th token's probability
                v, _ = torch.topk(logits, top_k)
                logits[logits < v[:, [-1]]] = -float('Inf')

            # Check if all probabilities are zero after filtering/temperature, if so, fall back to argmax
            if torch.all(logits == -float('Inf')):
                # This means all tokens were filtered out or resulted in 0 probability. Choose a random valid token or a default.
                # For now, let's just pick a generic unk_token or a random one if unk is not available/appropriate
                if self.tokenizer.unk_token_id is not None:
                    next_token = torch.tensor([self.tokenizer.unk_token_id] * bsz, device=device)
                else:
                    # Fallback to random if unk token is not defined
                    next_token = torch.randint(0, self.vocab_size, (bsz,), device=device)
            else:
                probs = F.softmax(logits, dim=-1)
                # Ensure that if probabilities sum to zero (e.g., all logits were -inf before soft max),
                # we handle it gracefully, though the above check should prevent it.
                if torch.sum(probs) == 0: # This case should ideally not happen after the -inf check
                    if self.tokenizer.unk_token_id is not None:
                        next_token = torch.tensor([self.tokenizer.unk_token_id] * bsz, device=device)
                    else:
                        next_token = torch.randint(0, self.vocab_size, (bsz,), device=device)
                else:
                    next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)


            # append the newly generated token
            for i in range(bsz):
                out_ids[i].append(int(next_token[i].item()))


        return out_ids

# ----------------------------
# TextDataset class for structured data handling
# ----------------------------
class TextDataset(Dataset):
    def __init__(self, texts: list, tokenizer):
        self.texts = texts
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.texts[idx]

# Modified to accept and use the tokenizer instance
def collate_batch(batch_texts: List[str], tokenizer):
    # Encode each text, add BOS/EOS, and convert to tensor
    encoded = []
    for t in batch_texts:
        ids = [tokenizer.bos_token_id] + tokenizer.encode(t) + [tokenizer.eos_token_id]
        encoded.append(torch.tensor(ids, dtype=torch.long))

    # Pad to max len within the batch
    max_len = max([x.size(0) for x in encoded])
    padded_batch = torch.full((len(encoded), max_len), tokenizer.pad_token_id, dtype=torch.long)
    for i, x in enumerate(encoded):
        padded_batch[i, :x.size(0)] = x

    # Prepare inputs and targets for language modeling
    inputs = padded_batch[:, :-1]
    targets = padded_batch[:, 1:]

    return inputs, targets


# Modified to use AdvancedTokenizer and handle multiple layers, and DataLoaders
def train_demo():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize configuration
    cfg = Config()

    corpus_file = cfg.corpus_file
    texts = []
    if os.path.exists(corpus_file):
        print(f"Loading corpus from {corpus_file}...")
        with open(corpus_file, 'r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                if i < cfg.max_corpus_lines:
                    texts.append(line.strip())
                else:
                    break
        print(f"Loaded {len(texts)} lines.")
    else:
        print(f"Corpus file not found at {corpus_file}. Using a small toy corpus instead.")
        texts = [
            "hello world this is ars",
            "the system adapts to its history",
            "benpolyseq demonstrates adaptive sequences",
            "ars can inspire new network protocols",
            "self optimizing systems are possible"
        ]

    tokenizer = AdvancedTokenizer() # Initialize AdvancedTokenizer
    vocab_size = len(tokenizer)

    # Modified Data Splitting Logic
    num_samples = len(texts)
    if num_samples < 3:
        print(f"Warning: Corpus size ({num_samples}) is too small for a proper train/validation/test split. Using all data for training.")
        train_texts = texts
        val_texts = []
        test_texts = []
    else:
        # First split: train_val and test
        train_val_texts, test_texts = train_test_split(texts, test_size=cfg.train_test_split_ratio, random_state=cfg.random_state)
        # Defensive check for test_texts being empty
        if not test_texts and train_val_texts:
            test_texts.append(train_val_texts.pop()) # Move one item to test if test is empty and train_val is not

        # Second split: train and validation from train_val
        if len(train_val_texts) >= 2: # Need at least two samples to split into train and val
            val_split_size = cfg.train_val_split_ratio
            train_texts, val_texts = train_test_split(train_val_texts, test_size=val_split_size, random_state=cfg.random_state)
            # Defensive check for val_texts being empty
            if not val_texts and train_texts:
                val_texts.append(train_texts.pop()) # Move one item to val if val is empty and train is not
        else:
            train_texts = train_val_texts
            val_texts = []

    print(f"Train size: {len(train_texts)}, Validation size: {len(val_texts)}, Test size: {len(test_texts)}")

    # Create datasets
    train_dataset = TextDataset(train_texts, tokenizer)
    val_dataset = TextDataset(val_texts, tokenizer)
    test_dataset = TextDataset(test_texts, tokenizer)

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True,
                                collate_fn=lambda b: collate_batch(b, tokenizer))
    val_loader = DataLoader(val_dataset, batch_size=cfg.batch_size, shuffle=False,
                              collate_fn=lambda b: collate_batch(b, tokenizer))
    test_loader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False,
                               collate_fn=lambda b: collate_batch(b, tokenizer))

    model = ARSLM(tokenizer, emb_dim=cfg.emb_dim, hidden_dim=cfg.hidden_dim, num_layers=cfg.num_layers, dropout_prob=cfg.dropout_prob).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate)
    scheduler = StepLR(optimizer, step_size=cfg.lr_scheduler_step_size, gamma=cfg.lr_scheduler_gamma)
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    n_epochs = cfg.n_epochs
    if len(train_texts) == 0: # Nothing to train on
        print("No training data available. Skipping training.")
        return model, tokenizer

    # Initialize best validation loss for saving the best model
    best_val_loss = float('inf')
    # Create directories for saving if they don't exist
    os.makedirs(cfg.model_save_dir, exist_ok=True)
    os.makedirs(cfg.tokenizer_save_dir, exist_ok=True)

    for epoch in range(n_epochs):
        model.train()
        total_loss = 0
        num_batches = len(train_loader)
        if num_batches > 0:
            for batch_idx, (inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(device), targets.to(device)
                logits, gates = model(inputs)
                b, seq, v = logits.shape
                loss = loss_fn(logits.view(b*seq, v), targets.reshape(b*seq))
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_grad_norm)
                optimizer.step()
                total_loss += loss.item()
            avg_train_loss = total_loss / num_batches
        else:
            avg_train_loss = float('inf') # Set to inf if no training batches

        scheduler.step()

        if (epoch+1) % 50 == 0 or epoch==0:
            print(f"Epoch {epoch+1}/{n_epochs} — Avg Train Loss: {avg_train_loss:.4f}")
            print(f"Current learning rate: {scheduler.get_last_lr()[0]:.6f}")

            # Validation step with check for empty loader
            if len(val_loader) > 0:
                model.eval()
                total_val_loss = 0
                with torch.no_grad():
                    for inputs_val, targets_val in val_loader:
                        inputs_val, targets_val = inputs_val.to(device), targets_val.to(device)
                        logits_val, _ = model(inputs_val)
                        b_val, seq_val, v_val = logits_val.shape
                        loss_val = loss_fn(logits_val.view(b_val*seq_val, v_val), targets_val.reshape(b_val*seq_val))
                        total_val_loss += loss_val.item()
                avg_val_loss = total_val_loss / len(val_loader)
                print(f"Validation Loss: {avg_val_loss:.4f}")

                # Save the model if this is the best validation loss so far
                if avg_val_loss < best_val_loss:
                    best_val_loss = avg_val_loss
                    print(f"Saving best model with validation loss: {best_val_loss:.4f}")
                    torch.save(model.state_dict(), os.path.join(cfg.model_save_dir, "best_model.pt"))
                    tokenizer.tokenizer.save_pretrained(cfg.tokenizer_save_dir)
                    # Save the configuration as well
                    torch.save(cfg, cfg.config_save_path)

            else:
                print("No validation data available. Skipping validation.")
                avg_val_loss = float('inf')

    print("\n=== Evaluation ===")
    model.eval()
    total_test_loss = 0
    if len(test_loader) > 0:
        with torch.no_grad():
            for inputs_test, targets_test in test_loader:
                inputs_test, targets_test = inputs_test.to(device), targets_test.to(device)
                logits_test, _ = model(inputs_test)
                b_test, seq_test, v_test = logits_test.shape
                loss_test = loss_fn(logits_test.view(b_test*seq_test, v_test), targets_test.reshape(b_test*seq_test))
                total_test_loss += loss_test.item()

        avg_test_loss = total_test_loss / len(test_loader)
        # Perplexity is exp(mean loss); averaging per-batch exp(loss) would overstate it.
        avg_test_perplexity = math.exp(avg_test_loss)
        print(f"Test Loss: {avg_test_loss:.4f}")
        print(f"Perplexity on test data: {avg_test_perplexity:.4f}")
    else:
        print("No test data available. Skipping test evaluation.")
        avg_test_loss = float('inf')
        avg_test_perplexity = float('inf')

    # demo generation
    context = cfg.generation_context
    idx = torch.tensor([ [tokenizer.bos_token_id] + tokenizer.encode(context) ], dtype=torch.long, device=device)
    out_ids = model.generate(idx, max_new_tokens=cfg.max_new_tokens, temperature=cfg.temperature, top_k=cfg.top_k)[0]
    print("\n=== Generated ===")
    print(tokenizer.decode(out_ids))
    return model, tokenizer

def load_model_and_tokenizer(model_path: str, tokenizer_path: str, config_path: str, device):
    """
    Loads a pre-trained ARSLM model and its tokenizer along with its configuration.
    """
    # Add Config class to safe globals for unpickling
    torch.serialization.add_safe_globals([Config])
    # Load the configuration first
    cfg = torch.load(config_path)

    loaded_tokenizer = AdvancedTokenizer(pretrained_model_name_or_path=tokenizer_path)

    # Instantiate the model using the loaded configuration
    loaded_model = ARSLM(loaded_tokenizer,
                         emb_dim=cfg.emb_dim,
                         hidden_dim=cfg.hidden_dim,
                         num_layers=cfg.num_layers,
                         dropout_prob=cfg.dropout_prob).to(device)

    loaded_model.load_state_dict(torch.load(model_path, map_location=device))
    loaded_model.eval() # Set model to evaluation mode

    print(f"Model loaded from {model_path}")
    print(f"Tokenizer loaded from {tokenizer_path}")
    print(f"Configuration loaded from {config_path}")

    return loaded_model, loaded_tokenizer, cfg

def deploy_model(model, tokenizer, config, device, input_text: str):
    """
    Function to deploy the model for inference (text generation).
    """
    model.eval()
    with torch.no_grad():
        idx = torch.tensor([ [tokenizer.bos_token_id] + tokenizer.encode(input_text) ], dtype=torch.long, device=device)
        out_ids = model.generate(idx, max_new_tokens=config.max_new_tokens, temperature=config.temperature, top_k=config.top_k)[0]
        generated_text = tokenizer.decode(out_ids)
    return generated_text

# ----------------------------
# Run demo if main
# ----------------------------
if __name__ == "__main__":
    random.seed(0)
    torch.manual_seed(0)
    model, tokenizer = train_demo()

    # Demonstrate loading the saved model and tokenizer
    print("\n=== Demonstrating Model Loading and Deployment ===")
    cfg_instance = Config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    try:
        loaded_model, loaded_tokenizer, loaded_config = load_model_and_tokenizer(
            os.path.join(cfg_instance.model_save_dir, "best_model.pt"),
            cfg_instance.tokenizer_save_dir,
            cfg_instance.config_save_path,
            device
        )
        # Test generation with loaded model using the deploy_model function
        context = "loaded model generates"
        generated_output = deploy_model(loaded_model, loaded_tokenizer, loaded_config, device, context)
        print("Generated with loaded model:")
        print(generated_output)
    except FileNotFoundError:
        print("No saved model found. Please ensure training completed successfully and saved a model.")

Loading corpus from /tmp/large_corpus.txt...
Loaded 10000 lines.
Train size: 4000, Validation size: 4000, Test size: 2000


Token indices sequence length is longer than the specified maximum sequence length for this model (819 > 512). Running this sequence through the model will result in indexing errors


Epoch 1/200 — Avg Train Loss: 7.2199
Current learning rate: 0.001000
Validation Loss: 6.8119
Saving best model with validation loss: 6.8119


In [None]:
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
from transformers import BertTokenizer # Import BertTokenizer here
from torch.optim.lr_scheduler import StepLR # Import StepLR
import os # Import os for file operations
from torch.utils.data import Dataset, DataLoader # Import Dataset and DataLoader
from sklearn.model_selection import train_test_split # Import train_test_split

# ----------------------------
# Configuration Class
# ----------------------------
class Config:
    """Centralized configuration for the ARSLM model and training."""
    def __init__(self):
        # Model Parameters
        self.emb_dim = 64
        self.hidden_dim = 128
        self.num_layers = 2
        self.dropout_prob = 0.1
        self.max_seq_len = 512 # Added max_seq_len

        # Training Parameters
        self.learning_rate = 1e-3
        self.n_epochs = 200
        self.batch_size = 8
        self.lr_scheduler_step_size = 50
        self.lr_scheduler_gamma = 0.5
        self.clip_grad_norm = 1.0

        # Generation Parameters
        self.max_new_tokens = 15
        self.temperature = 1.0
        self.top_k = 50
        self.generation_context = "hello world"

        # Data Parameters
        self.corpus_file = "/tmp/large_corpus.txt"
        self.max_corpus_lines = 10000 # Max lines to load from corpus_file for demo - UPDATED TO 10,000
        self.train_test_split_ratio = 0.2
        self.train_val_split_ratio = 0.5 # fraction of the remaining train+val pool used for validation (0.4 of total)
        self.random_state = 42

        # Saving/Loading Paths
        self.model_save_dir = "./model_checkpoint"
        self.tokenizer_save_dir = "./tokenizer_checkpoint"
        self.config_save_path = os.path.join(self.model_save_dir, "config.pt") # Path to save config

# ----------------------------
# Advanced Tokenizer (using transformers)
# ----------------------------
class AdvancedTokenizer:
    def __init__(self, vocab_file=None, pretrained_model_name_or_path='bert-base-uncased'):
        if vocab_file:
            self.tokenizer = BertTokenizer(vocab_file)
        else:
            # from_pretrained accepts either a hub model name or a local directory
            # containing a previously saved tokenizer
            self.tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path)

        self.vocab = list(self.tokenizer.vocab.keys())
        self.stoi = self.tokenizer.vocab
        self.itos = {i: w for w, i in self.stoi.items()}

        self.pad_token = self.tokenizer.pad_token
        self.unk_token = self.tokenizer.unk_token
        self.bos_token = self.tokenizer.cls_token # Using [CLS] as BOS for BERT-like tokenizers
        self.eos_token = self.tokenizer.sep_token # Using [SEP] as EOS for BERT-like tokenizers

        self.pad_token_id = self.tokenizer.pad_token_id
        self.unk_token_id = self.tokenizer.unk_token_id
        self.bos_token_id = self.tokenizer.cls_token_id
        self.eos_token_id = self.tokenizer.sep_token_id


    def encode(self, text: str) -> List[int]:
        return self.tokenizer.encode(text, add_special_tokens=False)

    def decode(self, ids: List[int]) -> str:
        return self.tokenizer.decode(ids, skip_special_tokens=True)

    def __len__(self):
        return len(self.vocab)

# ----------------------------
# ARSCell: core adaptive cell
# ----------------------------
class ARSCell(nn.Module):
    """
    ARSCell: computes next hidden state from h_{t-2}, h_{t-1}, and input embedding x_{t-1}
    Implements an 'adapt' mechanism inspired by BenPolySeq:
      h_t = h_{t-1} + gate * transform(h_{t-1}, h_{t-2}, x)
    where gate is a function of the difference (h_{t-1} - h_{t-2}) and context.
    """
    def __init__(self, emb_dim: int, hidden_dim: int, dropout_prob: float = 0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        # transform candidate from concat([h_{t-1}, h_{t-2}, x])
        self.candidate_mlp = nn.Sequential(
            nn.Linear(hidden_dim*2 + emb_dim, hidden_dim*2),
            nn.ReLU(),
            nn.Linear(hidden_dim*2, hidden_dim)
        )
        # gate network produces scalar gating factor in (0,1)
        self.gate_net = nn.Sequential(
            nn.Linear(hidden_dim*2 + emb_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        # small residual projection
        self.res_proj = nn.Linear(emb_dim, hidden_dim)
        # Dropout layer
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, h_prev2, h_prev1, x_embed):
        # h_prev2, h_prev1: (batch, hidden_dim)
        # x_embed: (batch, emb_dim)
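        # difference signal (not consumed below; the gate and candidate networks
        # see h_prev1 and h_prev2 directly through ctx)
        diff = h_prev1 - h_prev2    # (batch, hidden_dim)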
        # context vector (concat)
        ctx = torch.cat([h_prev1, h_prev2, x_embed], dim=-1)
        candidate = self.candidate_mlp(ctx)                       # (batch, hidden_dim)
        gate = self.gate_net(ctx).squeeze(-1)                     # (batch,)
        # incorporate residual from input embedding
        residual = self.res_proj(x_embed)
        # ARS-style update: adaptive addition of candidate scaled by gate
        h_t = h_prev1 + gate.unsqueeze(-1) * candidate + 0.1 * residual
        # Apply dropout
        h_t = self.dropout(h_t)
        # optional normalization
        h_t = F.layer_norm(h_t, (self.hidden_dim,))
        return h_t, gate

# ----------------------------
# ARSLM model: embed -> ARSCell(s) -> attention -> head
# ----------------------------
class ARSLM(nn.Module):
    # Modified to accept a tokenizer instance, include attention, and support multiple layers
    def __init__(self, tokenizer, emb_dim=64, hidden_dim=128, num_layers=2, dropout_prob=0.1):
        super().__init__()
        self.tokenizer = tokenizer # Store tokenizer instance
        self.vocab_size = len(tokenizer) # Get vocab size from tokenizer
        self.num_layers = num_layers
        self.emb = nn.Embedding(self.vocab_size, emb_dim)
        # Create a list of ARSCells
        # Pass dropout_prob to ARSCell constructor
        self.cells = nn.ModuleList([ARSCell(emb_dim if i == 0 else hidden_dim, hidden_dim, dropout_prob=dropout_prob) for i in range(num_layers)])

        # Simple additive attention mechanism, applied after the last layer's hidden states
        self.attention = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1)
        )
        self.head = nn.Linear(hidden_dim, self.vocab_size)

    def forward(self, input_ids):
        """
        input_ids: (batch, seq_len)
        We compute autoregressively, passing hidden states through layers.
        Returns logits (batch, seq_len, vocab).
        """
        bsz, seq_len = input_ids.shape
        emb = self.emb(input_ids)  # (b, seq, emb)
        device = emb.device
        hidden_dim = self.cells[0].hidden_dim

        # Initialize hidden states for each layer
        # We need two previous states for each layer's ARSCell
        h_prev2_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]
        h_prev1_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]

        # Store hidden states for causal attention (from the last layer)
        all_last_layer_hidden_states = []
        logits = []
        gates = [] # Store gates from the last layer

        for t in range(seq_len):
            x_t = emb[:, t, :] # embedding at position t
            h_t_input = x_t # Input to the first layer is the embedding

            current_layer_hidden_states = []

            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1] # Use the output of the previous layer

                h_t, gate = cell(h_prev2, h_prev1, h_t_input)

                # Store current layer's hidden state as input for the next layer in the next time step
                current_layer_hidden_states.append(h_t)

                # Update previous hidden states for the next time step for this layer
                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1: # Only store gates from the last layer
                    gates.append(gate.unsqueeze(1))

            # After processing all layers for time step t, the output is the hidden state of the last layer
            last_layer_h_t = current_layer_hidden_states[-1]
            all_last_layer_hidden_states.append(last_layer_h_t.unsqueeze(1))

            # Apply attention and head after the last layer's hidden state
            # Causal Attention: attend over hidden states up to current time step t from the last layer
            last_layer_history = torch.cat(all_last_layer_hidden_states, dim=1) # (b, t+1, hidden_dim)
            query = last_layer_h_t.unsqueeze(1) # (b, 1, hidden_dim) - query is the current last layer hidden state

            # Calculate attention scores
            scores = self.attention(last_layer_history) # (b, t+1, 1)
            attention_weights = F.softmax(scores, dim=1) # (b, t+1, 1)

            # Apply attention weights
            context_vector = torch.sum(attention_weights * last_layer_history, dim=1) # (b, hidden_dim)

            # Combine last layer's current hidden state with context vector before the head
            attended_h_t = last_layer_h_t + context_vector

            logit = self.head(attended_h_t) # predict token at t (or next token)
            logits.append(logit.unsqueeze(1))


        logits = torch.cat(logits, dim=1)          # (b, seq, vocab)
        gates = torch.cat(gates, dim=1)            # (b, seq) # Gates from the last layer
        return logits, gates

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=20, temperature=1.0, top_k=None):
        """
        idx: (batch, current_seq_len) - input context (token ids)
        Returns extended token ids.
        Modified for multiple layers, attention during generation, and optional top-k sampling.
        """
        self.eval()
        bsz, seq_len = idx.shape
        device = next(self.parameters()).device
        hidden_dim = self.cells[0].hidden_dim

        # Initialize hidden states for each layer
        h_prev2_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]
        h_prev1_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]

        # Store hidden states for causal attention (from the last layer)
        all_last_layer_hidden_states = []

        # Process initial context to get starting hidden states for each layer
        input_ids = idx.clone()
        emb = self.emb(input_ids) # (b, seq_len, emb_dim)

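        # Prefill on all but the last context token; the generation loop below
        # feeds that last token itself, so including it here would process it twice.
        for t in range(seq_len - 1):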
            x_t = emb[:,t,:] # embedding at position t
            h_t_input = x_t

            current_layer_hidden_states = []
            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1]

                h_t, _ = cell(h_prev2, h_prev1, h_t_input)
                current_layer_hidden_states.append(h_t)

                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1:
                    # Add the new last layer hidden state to the history for attention
                    all_last_layer_hidden_states.append(h_t.unsqueeze(1))


        out_ids = input_ids.tolist()

        for _ in range(max_new_tokens):
            # Input for the first layer's ARSCell is the embedding of the last generated token
            last_token_ids = torch.tensor([ids[-1] for ids in out_ids], dtype=torch.long, device=device).unsqueeze(1)
            x_embed = self.emb(last_token_ids).squeeze(1) # (b, emb_dim)

            h_t_input = x_embed
            current_layer_hidden_states = []

            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1]

                h_t, _ = cell(h_prev2, h_prev1, h_t_input)
                current_layer_hidden_states.append(h_t)

                # Update previous hidden states for the next generation step for this layer
                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1:
                    # Add the new last layer hidden state to the history for attention
                    all_last_layer_hidden_states.append(h_t.unsqueeze(1))


            # After processing all layers for the new token, apply attention and head
            last_layer_h_t = current_layer_hidden_states[-1] # Hidden state of the last layer for the new token

            # Causal Attention over all last layer hidden states history
            last_layer_history = torch.cat(all_last_layer_hidden_states, dim=1) # (b, current_seq_len + 1, hidden_dim)
            query = last_layer_h_t.unsqueeze(1) # (b, 1, hidden_dim)

            scores = self.attention(last_layer_history) # (b, current_seq_len + 1, 1)
            attention_weights = F.softmax(scores, dim=1) # (b, current_seq_len + 1, 1)

            context_vector = torch.sum(attention_weights * last_layer_history, dim=1) # (b, hidden_dim)

            attended_h_t = last_layer_h_t + context_vector

            logits = self.head(attended_h_t) / max(1e-6, temperature)

            # Apply top-k sampling
            if top_k is not None:
                # get top k logits
                top_k = min(max(top_k, 1), logits.size(-1))  # Clamp k to be at least 1 and at most vocab size
                # Remove all tokens with a probability less than the top-k'th token's probability
                v, _ = torch.topk(logits, top_k)
                logits[logits < v[:, [-1]]] = -float('Inf')

            # Check if all probabilities are zero after filtering/temperature, if so, fall back to argmax
            if torch.all(logits == -float('Inf')):
                # This means all tokens were filtered out or resulted in 0 probability. Choose a random valid token or a default.
                # For now, let's just pick a generic unk_token or a random one if unk is not available/appropriate
                if self.tokenizer.unk_token_id is not None:
                    next_token = torch.tensor([self.tokenizer.unk_token_id] * bsz, device=device)
                else:
                    # Fallback to random if unk token is not defined
                    next_token = torch.randint(0, self.vocab_size, (bsz,), device=device)
            else:
                probs = F.softmax(logits, dim=-1)
                # Ensure that if probabilities sum to zero (e.g., all logits were -inf before soft max),
                # we handle it gracefully, though the above check should prevent it.
                if torch.sum(probs) == 0: # This case should ideally not happen after the -inf check
                    if self.tokenizer.unk_token_id is not None:
                        next_token = torch.tensor([self.tokenizer.unk_token_id] * bsz, device=device)
                    else:
                        next_token = torch.randint(0, self.vocab_size, (bsz,), device=device)
                else:
                    next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)


            # append the newly generated token
            for i in range(bsz):
                out_ids[i].append(int(next_token[i].item()))


        return out_ids

# ----------------------------
# TextDataset class for structured data handling
# ----------------------------
class TextDataset(Dataset):
    def __init__(self, texts: list, tokenizer):
        self.texts = texts
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.texts[idx]

# Modified to accept and use the tokenizer instance
def collate_batch(batch_texts: List[str], tokenizer):
    # Encode each text, add BOS/EOS, and convert to tensor
    encoded = []
    for t in batch_texts:
        ids = [tokenizer.bos_token_id] + tokenizer.encode(t) + [tokenizer.eos_token_id]
        encoded.append(torch.tensor(ids, dtype=torch.long))

    # Pad to max len within the batch
    max_len = max([x.size(0) for x in encoded])
    padded_batch = torch.full((len(encoded), max_len), tokenizer.pad_token_id, dtype=torch.long)
    for i, x in enumerate(encoded):
        padded_batch[i, :x.size(0)] = x

    # Prepare inputs and targets for language modeling
    inputs = padded_batch[:, :-1]
    targets = padded_batch[:, 1:]

    return inputs, targets


# Modified to use AdvancedTokenizer and handle multiple layers, and DataLoaders
def train_demo():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize configuration
    cfg = Config()

    corpus_file = cfg.corpus_file
    texts = []
    if os.path.exists(corpus_file):
        print(f"Loading corpus from {corpus_file}...")
        with open(corpus_file, 'r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                if i < cfg.max_corpus_lines:
                    texts.append(line.strip())
                else:
                    break
        print(f"Loaded {len(texts)} lines.")
    else:
        print(f"Corpus file not found at {corpus_file}. Using a small toy corpus instead.")
        texts = [
            "hello world this is ars",
            "the system adapts to its history",
            "benpolyseq demonstrates adaptive sequences",
            "ars can inspire new network protocols",
            "self optimizing systems are possible"
        ]

    tokenizer = AdvancedTokenizer() # Initialize AdvancedTokenizer
    vocab_size = len(tokenizer)

    # Modified Data Splitting Logic
    num_samples = len(texts)
    if num_samples < 3:
        print(f"Warning: Corpus size ({num_samples}) is too small for a proper train/validation/test split. Using all data for training.")
        train_texts = texts
        val_texts = []
        test_texts = []
    else:
        # First split: train_val and test
        train_val_texts, test_texts = train_test_split(texts, test_size=cfg.train_test_split_ratio, random_state=cfg.random_state)
        # Defensive check for test_texts being empty
        if not test_texts and train_val_texts:
            test_texts.append(train_val_texts.pop()) # Move one item to test if test is empty and train_val is not

        # Second split: train and validation from train_val
        if len(train_val_texts) >= 2: # Need at least two samples to split into train and val
            val_split_size = cfg.train_val_split_ratio
            train_texts, val_texts = train_test_split(train_val_texts, test_size=val_split_size, random_state=cfg.random_state)
            # Defensive check for val_texts being empty
            if not val_texts and train_texts:
                val_texts.append(train_texts.pop()) # Move one item to val if val is empty and train is not
        else:
            train_texts = train_val_texts
            val_texts = []

    print(f"Train size: {len(train_texts)}, Validation size: {len(val_texts)}, Test size: {len(test_texts)}")

    # Create datasets
    train_dataset = TextDataset(train_texts, tokenizer)
    val_dataset = TextDataset(val_texts, tokenizer)
    test_dataset = TextDataset(test_texts, tokenizer)

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True,
                                collate_fn=lambda b: collate_batch(b, tokenizer))
    val_loader = DataLoader(val_dataset, batch_size=cfg.batch_size, shuffle=False,
                              collate_fn=lambda b: collate_batch(b, tokenizer))
    test_loader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False,
                               collate_fn=lambda b: collate_batch(b, tokenizer))

    model = ARSLM(tokenizer, emb_dim=cfg.emb_dim, hidden_dim=cfg.hidden_dim, num_layers=cfg.num_layers, dropout_prob=cfg.dropout_prob).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate)
    scheduler = StepLR(optimizer, step_size=cfg.lr_scheduler_step_size, gamma=cfg.lr_scheduler_gamma)
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    n_epochs = cfg.n_epochs
    if len(train_texts) == 0: # Nothing to train on
        print("No training data available. Skipping training.")
        return model, tokenizer

    # Initialize best validation loss for saving the best model
    best_val_loss = float('inf')
    # Create directories for saving if they don't exist
    os.makedirs(cfg.model_save_dir, exist_ok=True)
    os.makedirs(cfg.tokenizer_save_dir, exist_ok=True)

    for epoch in range(n_epochs):
        model.train()
        total_loss = 0
        num_batches = len(train_loader)
        if num_batches > 0:
            for batch_idx, (inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(device), targets.to(device)
                logits, gates = model(inputs)
                b, seq, v = logits.shape
                loss = loss_fn(logits.view(b*seq, v), targets.reshape(b*seq))
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_grad_norm)
                optimizer.step()
                total_loss += loss.item()
            avg_train_loss = total_loss / num_batches
        else:
            avg_train_loss = float('inf') # Set to inf if no training batches

        scheduler.step()

        if (epoch+1) % 50 == 0 or epoch==0:
            print(f"Epoch {epoch+1}/{n_epochs} — Avg Train Loss: {avg_train_loss:.4f}")
            print(f"Current learning rate: {scheduler.get_last_lr()[0]:.6f}")

            # Validation step with check for empty loader
            if len(val_loader) > 0:
                model.eval()
                total_val_loss = 0
                with torch.no_grad():
                    for inputs_val, targets_val in val_loader:
                        inputs_val, targets_val = inputs_val.to(device), targets_val.to(device)
                        logits_val, _ = model(inputs_val)
                        b_val, seq_val, v_val = logits_val.shape
                        loss_val = loss_fn(logits_val.view(b_val*seq_val, v_val), targets_val.reshape(b_val*seq_val))
                        total_val_loss += loss_val.item()
                avg_val_loss = total_val_loss / len(val_loader)
                print(f"Validation Loss: {avg_val_loss:.4f}")

                # Save the model if this is the best validation loss so far
                if avg_val_loss < best_val_loss:
                    best_val_loss = avg_val_loss
                    print(f"Saving best model with validation loss: {best_val_loss:.4f}")
                    torch.save(model.state_dict(), os.path.join(cfg.model_save_dir, "best_model.pt"))
                    tokenizer.tokenizer.save_pretrained(cfg.tokenizer_save_dir)
                    # Save the configuration as well
                    torch.save(cfg, cfg.config_save_path)

            else:
                print("No validation data available. Skipping validation.")
                avg_val_loss = float('inf')

    print("\n=== Evaluation ===")
    model.eval()
    total_test_loss = 0
    total_perplexity = 0
    if len(test_loader) > 0:
        with torch.no_grad():
            for inputs_test, targets_test in test_loader:
                inputs_test, targets_test = inputs_test.to(device), targets_test.to(device)
                logits_test, _ = model(inputs_test)
                b_test, seq_test, v_test = logits_test.shape
                loss_test = loss_fn(logits_test.view(b_test*seq_test, v_test), targets_test.reshape(b_test*seq_test))
                total_test_loss += loss_test.item()

        avg_test_loss = total_test_loss / len(test_loader)
        # Perplexity is exp of the mean loss (here the mean over batches);
        # averaging per-batch exp(loss) values would overstate it.
        avg_test_perplexity = math.exp(avg_test_loss)
        print(f"Test Loss: {avg_test_loss:.4f}")
        print(f"Perplexity on test data: {avg_test_perplexity:.4f}")
    else:
        print("No test data available. Skipping test evaluation.")
        avg_test_loss = float('inf')
        avg_test_perplexity = float('inf')

    # demo generation
    context = cfg.generation_context
    idx = torch.tensor([ [tokenizer.bos_token_id] + tokenizer.encode(context) ], dtype=torch.long, device=device)
    out_ids = model.generate(idx, max_new_tokens=cfg.max_new_tokens, temperature=cfg.temperature, top_k=cfg.top_k)[0]
    print("\n=== Generated ===")
    print(tokenizer.decode(out_ids))
    return model, tokenizer

def load_model_and_tokenizer(model_path: str, tokenizer_path: str, config_path: str, device):
    """
    Loads a pre-trained ARSLM model and its tokenizer along with its configuration.
    """
    # Add Config class to safe globals for unpickling
    torch.serialization.add_safe_globals([Config])
    # Load the configuration first
    cfg = torch.load(config_path)

    loaded_tokenizer = AdvancedTokenizer(pretrained_model_name_or_path=tokenizer_path)

    # Instantiate the model using the loaded configuration
    loaded_model = ARSLM(loaded_tokenizer,
                         emb_dim=cfg.emb_dim,
                         hidden_dim=cfg.hidden_dim,
                         num_layers=cfg.num_layers,
                         dropout_prob=cfg.dropout_prob).to(device)

    loaded_model.load_state_dict(torch.load(model_path, map_location=device))
    loaded_model.eval() # Set model to evaluation mode

    print(f"Model loaded from {model_path}")
    print(f"Tokenizer loaded from {tokenizer_path}")
    print(f"Configuration loaded from {config_path}")

    return loaded_model, loaded_tokenizer, cfg

def deploy_model(model, tokenizer, config, device, input_text: str):
    """
    Function to deploy the model for inference (text generation).
    """
    model.eval()
    with torch.no_grad():
        idx = torch.tensor([ [tokenizer.bos_token_id] + tokenizer.encode(input_text) ], dtype=torch.long, device=device)
        out_ids = model.generate(idx, max_new_tokens=config.max_new_tokens, temperature=config.temperature, top_k=config.top_k)[0]
        generated_text = tokenizer.decode(out_ids)
    return generated_text


cfg_instance = Config()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

try:
    # Load the trained model, tokenizer, and configuration
    loaded_model, loaded_tokenizer, loaded_config = load_model_and_tokenizer(
        os.path.join(cfg_instance.model_save_dir, "best_model.pt"),
        cfg_instance.tokenizer_save_dir,
        cfg_instance.config_save_path,
        device
    )

    # Create a fresh Config instance and shorten max_new_tokens
    short_generation_config = Config()
    short_generation_config.max_new_tokens = 5 # Set to a shorter value

    # Input text for generation
    input_text = "Le modèle linguistique apprend"

    # Run the model to generate text with the modified configuration
    generated_output = deploy_model(loaded_model, loaded_tokenizer, short_generation_config, device, input_text)

    print(f"Input text: {input_text}")
    print(f"Generated text (5 tokens): {generated_output}")

except FileNotFoundError:
    print("Error: the model, tokenizer, or configuration was not found. Make sure training has run and the files were saved.")
except Exception as e:
    print(f"An unexpected error occurred during deployment: {e}")
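
The `generate` method above filters logits with top-k before sampling. The following is a minimal pure-Python sketch of that filtering step on a single logit vector, written without torch so the arithmetic is easy to follow. `top_k_filter` and `softmax` are illustrative helpers, not functions from the notebook. The sketch applies the temperature inside the softmax rather than before the filter; for a positive temperature this selects the same tokens.

```python
import math
from typing import List

def top_k_filter(logits: List[float], k: int) -> List[float]:
    """Keep the k largest logits and set the rest to -inf (ties at the threshold are kept)."""
    k = min(max(k, 1), len(logits))  # clamp k to [1, vocab size]
    threshold = sorted(logits, reverse=True)[k - 1]
    return [x if x >= threshold else -math.inf for x in logits]

def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Temperature-scaled softmax; -inf entries receive probability 0."""
    scaled = [x / max(1e-6, temperature) for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

filtered = top_k_filter([1.0, 3.0, 2.0, 0.5], k=2)
probs = softmax(filtered)
print(filtered)  # [-inf, 3.0, 2.0, -inf]
```

Because the filter always leaves at least one finite logit when the inputs are finite, the "all logits are -inf" fallback branches in `generate` should rarely, if ever, be reached.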

In [None]:
cfg_instance = Config()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

try:
    # Load the trained model, tokenizer, and configuration
    loaded_model, loaded_tokenizer, loaded_config = load_model_and_tokenizer(
        os.path.join(cfg_instance.model_save_dir, "best_model.pt"),
        cfg_instance.tokenizer_save_dir,
        cfg_instance.config_save_path,
        device
    )

    # New input text for generation
    new_input_text = "L'intelligence artificielle est"

    # Run the model to generate text
    generated_output = deploy_model(loaded_model, loaded_tokenizer, loaded_config, device, new_input_text)

    print(f"Input text: {new_input_text}")
    print(f"Generated text: {generated_output}")

except FileNotFoundError:
    print("Error: the model, tokenizer, or configuration was not found. Make sure training has run and the files were saved.")
except Exception as e:
    print(f"An unexpected error occurred during deployment: {e}")

In [None]:
import os
import torch

# Assuming Config class is available in the environment or can be re-imported if necessary.
# For a standalone check, we can re-define a minimal Config class or instantiate the one from the previous cell.
# Given it's a colab environment, we can assume the Config class is already defined.

cfg = Config()

model_path = os.path.join(cfg.model_save_dir, "best_model.pt")
tokenizer_path = cfg.tokenizer_save_dir
config_path = cfg.config_save_path

print("Checking saved files:")

# Check for model file
if os.path.exists(model_path):
    print(f"Model file found at: {model_path}")
else:
    print(f"Model file NOT found at: {model_path}")

# Check for tokenizer directory
if os.path.isdir(tokenizer_path):
    print(f"Tokenizer directory found at: {tokenizer_path}")
else:
    print(f"Tokenizer directory NOT found at: {tokenizer_path}")

# Check for config file
if os.path.exists(config_path):
    print(f"Configuration file found at: {config_path}")
else:
    print(f"Configuration file NOT found at: {config_path}")

# Additionally, load the config to confirm it is readable
try:
    if os.path.exists(config_path):
        # Register Config for unpickling if it is not already registered.
        # get_safe_globals() returns the registered objects themselves, so compare
        # against the class rather than the string 'Config'.
        if Config not in torch.serialization.get_safe_globals():
            torch.serialization.add_safe_globals([Config])
        loaded_cfg = torch.load(config_path)
        print(f"Configuration loaded successfully. emb_dim: {loaded_cfg.emb_dim}")
except Exception as e:
    print(f"Error while loading the configuration: {e}")
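
The check above has to register `Config` with `torch.serialization.add_safe_globals` because the configuration was saved as a pickled Python object. One possible alternative, sketched here under the assumption that every config field is a plain JSON-serializable value (as the fields of `Config` appear to be), is to store the attributes as JSON. `DemoConfig`, `save_config_json`, and `load_config_json` are hypothetical names used only for illustration.

```python
import json
import os
import tempfile

class DemoConfig:
    """Stand-in for the notebook's Config with a few of its fields."""
    def __init__(self):
        self.emb_dim = 64
        self.hidden_dim = 128
        self.num_layers = 2

def save_config_json(cfg, path: str) -> None:
    """Write the instance attributes of cfg to path as JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(vars(cfg), f, indent=2)

def load_config_json(path: str, cfg_cls=DemoConfig):
    """Build a default cfg_cls instance and overwrite its attributes from JSON."""
    cfg = cfg_cls()
    with open(path, "r", encoding="utf-8") as f:
        for key, value in json.load(f).items():
            setattr(cfg, key, value)
    return cfg

# Round-trip through a temporary directory so the sketch leaves no files behind.
with tempfile.TemporaryDirectory() as tmp_dir:
    cfg_path = os.path.join(tmp_dir, "config.json")
    original = DemoConfig()
    original.hidden_dim = 256
    save_config_json(original, cfg_path)
    restored = load_config_json(cfg_path)

print(restored.hidden_dim)  # 256
```

A JSON file is also readable without the class definition being importable, which can make a checkpoint directory easier to inspect by hand.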


In [None]:
cfg_instance = Config()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

try:
    # Load the trained model, tokenizer, and configuration
    loaded_model, loaded_tokenizer, loaded_config = load_model_and_tokenizer(
        os.path.join(cfg_instance.model_save_dir, "best_model.pt"),
        cfg_instance.tokenizer_save_dir,
        cfg_instance.config_save_path,
        device
    )

    # Define the input text for generation
    input_text = "The language model learns"

    # Deploy the model to generate text
    generated_output = deploy_model(loaded_model, loaded_tokenizer, loaded_config, device, input_text)

    print(f"Input text: {input_text}")
    print(f"Generated text: {generated_output}")

except FileNotFoundError:
    print("Error: the model, tokenizer, or configuration was not found. Make sure training has been run and the files have been saved.")
except Exception as e:
    print(f"An unexpected error occurred during deployment: {e}")
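
The deployment call relies on `temperature` and `top_k` from the loaded configuration. The filtering that `generate` applies before sampling can be sketched in plain Python, without tensors. This is an illustrative sketch only; the helper name `top_k_probs` and the example logits are assumptions, not part of the notebook:

```python
import math
import random


def top_k_probs(logits, top_k, temperature=1.0):
    """Return sampling probabilities after temperature scaling and top-k filtering."""
    scaled = [x / max(1e-6, temperature) for x in logits]
    k = min(max(top_k, 1), len(scaled))  # clamp k to [1, vocab size]
    threshold = sorted(scaled, reverse=True)[k - 1]  # k-th largest logit
    # Logits below the threshold are masked out (ties at the threshold are kept)
    kept = [x if x >= threshold else float("-inf") for x in scaled]
    peak = max(kept)
    weights = [math.exp(x - peak) for x in kept]  # masked entries become 0.0
    total = sum(weights)
    return [w / total for w in weights]


# Hypothetical logits over a four-token vocabulary
probs = top_k_probs([2.0, 1.0, 0.5, -1.0], top_k=2)
rng = random.Random(0)
next_token = rng.choices(range(len(probs)), weights=probs, k=1)[0]
print(probs, next_token)
```

Because at least one logit always survives the clamp on `k`, the probabilities here can never all be zero.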

## Re-run Training

### Subtask:
Re-run the `train_demo` function with the new corpus. This tests whether the model benefits from a richer dataset.
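
Before re-running, it can help to confirm that the corpus file is present and how many lines the run would load, since `train_demo` falls back to a five-sentence toy corpus when the file is missing. A minimal sketch, assuming a hypothetical helper `count_loadable_lines` and a temporary file standing in for the real corpus:

```python
import os
import tempfile


def count_loadable_lines(corpus_file, max_lines):
    """Return how many lines would be read, or None if the file is missing."""
    if not os.path.exists(corpus_file):
        return None
    count = 0
    with open(corpus_file, "r", encoding="utf-8") as f:
        for _ in f:
            if count >= max_lines:
                break
            count += 1
    return count


# Temporary file standing in for the real corpus (hypothetical content)
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False, encoding="utf-8") as tmp:
    tmp.write("line one\nline two\nline three\n")
    demo_corpus = tmp.name

capped = count_loadable_lines(demo_corpus, max_lines=2)
full = count_loadable_lines(demo_corpus, max_lines=10000)
missing = count_loadable_lines(demo_corpus + ".missing", max_lines=10000)
os.remove(demo_corpus)
print(capped, full, missing)
```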


**Reasoning**:
I will execute the `train_demo()` function, which now uses the updated `Config` class pointing to the larger corpus, to initiate the training process and observe its performance with the richer dataset.
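
With `train_test_split_ratio = 0.2` and `train_val_split_ratio = 0.5`, the two successive splits give roughly 40% train, 40% validation, and 20% test. The arithmetic can be checked with a small sketch. The helper `split_sizes` is hypothetical, and it assumes the held-out count is rounded up at each split, which is my understanding of how scikit-learn treats a float `test_size`:

```python
import math


def split_sizes(n_samples, test_ratio, val_ratio):
    """Sizes of train/val/test after two successive splits.

    Assumption: the held-out count is rounded up at each split.
    """
    n_test = math.ceil(n_samples * test_ratio)
    n_train_val = n_samples - n_test
    n_val = math.ceil(n_train_val * val_ratio)
    n_train = n_train_val - n_val
    return n_train, n_val, n_test


train_n, val_n, test_n = split_sizes(10000, test_ratio=0.2, val_ratio=0.5)
print(train_n, val_n, test_n)
```

If a larger training share is wanted, lowering `train_val_split_ratio` is the relevant knob; for example, 0.125 would hold out about 10% of the full corpus for validation.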



In [None]:
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
from transformers import BertTokenizer
from torch.optim.lr_scheduler import StepLR
import os
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

# ----------------------------
# Configuration Class
# ----------------------------
class Config:
    """Centralized configuration for the ARSLM model and training."""
    def __init__(self):
        # Model Parameters
        self.emb_dim = 64
        self.hidden_dim = 128
        self.num_layers = 2
        self.dropout_prob = 0.1
        self.max_seq_len = 512 # Maximum tokens per sequence, including BOS/EOS

        # Training Parameters
        self.learning_rate = 1e-3
        self.n_epochs = 200
        self.batch_size = 8
        self.lr_scheduler_step_size = 50
        self.lr_scheduler_gamma = 0.5
        self.clip_grad_norm = 1.0

        # Generation Parameters
        self.max_new_tokens = 15
        self.temperature = 1.0
        self.top_k = 50
        self.generation_context = "hello world"

        # Data Parameters
        self.corpus_file = "/tmp/large_corpus.txt"
        self.max_corpus_lines = 10000 # Max lines to load from corpus_file for the demo
        self.train_test_split_ratio = 0.2
        self.train_val_split_ratio = 0.5 # Fraction of the train+val pool held out for validation (0.4 of total)
        self.random_state = 42

        # Saving/Loading Paths
        self.model_save_dir = "./model_checkpoint"
        self.tokenizer_save_dir = "./tokenizer_checkpoint"
        self.config_save_path = os.path.join(self.model_save_dir, "config.pt") # Path to save config

# ----------------------------
# Advanced Tokenizer (using transformers)
# ----------------------------
class AdvancedTokenizer:
    def __init__(self, vocab_file=None, pretrained_model_name_or_path='bert-base-uncased'):
        if vocab_file:
            self.tokenizer = BertTokenizer(vocab_file)
        else:
            # from_pretrained accepts either a hub model name or a local directory
            # holding a saved tokenizer, so no separate branch is needed
            self.tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path)

        self.vocab = list(self.tokenizer.vocab.keys())
        self.stoi = self.tokenizer.vocab
        self.itos = {i: w for w, i in self.stoi.items()}

        self.pad_token = self.tokenizer.pad_token
        self.unk_token = self.tokenizer.unk_token
        self.bos_token = self.tokenizer.cls_token # Using [CLS] as BOS for BERT-like tokenizers
        self.eos_token = self.tokenizer.sep_token # Using [SEP] as EOS for BERT-like tokenizers

        self.pad_token_id = self.tokenizer.pad_token_id
        self.unk_token_id = self.tokenizer.unk_token_id
        self.bos_token_id = self.tokenizer.cls_token_id
        self.eos_token_id = self.tokenizer.sep_token_id


    def encode(self, text: str) -> List[int]:
        return self.tokenizer.encode(text, add_special_tokens=False)

    def decode(self, ids: List[int]) -> str:
        return self.tokenizer.decode(ids, skip_special_tokens=True)

    def __len__(self):
        return len(self.vocab)

# ----------------------------
# ARSCell: core adaptive cell
# ----------------------------
class ARSCell(nn.Module):
    """
    ARSCell: computes next hidden state from h_{t-2}, h_{t-1}, and input embedding x_{t-1}
    Implements an 'adapt' mechanism inspired by BenPolySeq:
      h_t = h_{t-1} + gate * transform(h_{t-1}, h_{t-2}, x)
    where gate is a function of the difference (h_{t-1} - h_{t-2}) and context.
    """
    def __init__(self, emb_dim: int, hidden_dim: int, dropout_prob: float = 0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        # transform candidate from concat([h_{t-1}, h_{t-2}, x])
        self.candidate_mlp = nn.Sequential(
            nn.Linear(hidden_dim*2 + emb_dim, hidden_dim*2),
            nn.ReLU(),
            nn.Linear(hidden_dim*2, hidden_dim)
        )
        # gate network produces scalar gating factor in (0,1)
        self.gate_net = nn.Sequential(
            nn.Linear(hidden_dim*2 + emb_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        # small residual projection
        self.res_proj = nn.Linear(emb_dim, hidden_dim)
        # Dropout layer
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, h_prev2, h_prev1, x_embed):
        # h_prev2, h_prev1: (batch, hidden_dim)
        # x_embed: (batch, emb_dim)
        # The difference h_prev1 - h_prev2 is not computed explicitly: both states
        # are in ctx, so the gate's first linear layer can learn it.
        # context vector (concat)
        ctx = torch.cat([h_prev1, h_prev2, x_embed], dim=-1)
        candidate = self.candidate_mlp(ctx)                       # (batch, hidden_dim)
        gate = self.gate_net(ctx).squeeze(-1)                     # (batch,)
        # incorporate residual from input embedding
        residual = self.res_proj(x_embed)
        # ARS-style update: adaptive addition of candidate scaled by gate
        h_t = h_prev1 + gate.unsqueeze(-1) * candidate + 0.1 * residual
        # Apply dropout
        h_t = self.dropout(h_t)
        # optional normalization
        h_t = F.layer_norm(h_t, (self.hidden_dim,))
        return h_t, gate

# ----------------------------
# ARSLM model: embed -> ARSCell(s) -> attention -> head
# ----------------------------
class ARSLM(nn.Module):
    # Modified to accept a tokenizer instance, include attention, and support multiple layers
    def __init__(self, tokenizer, emb_dim=64, hidden_dim=128, num_layers=2, dropout_prob=0.1):
        super().__init__()
        self.tokenizer = tokenizer # Store tokenizer instance
        self.vocab_size = len(tokenizer) # Get vocab size from tokenizer
        self.num_layers = num_layers
        self.emb = nn.Embedding(self.vocab_size, emb_dim)
        # Create a list of ARSCells
        # Pass dropout_prob to ARSCell constructor
        self.cells = nn.ModuleList([ARSCell(emb_dim if i == 0 else hidden_dim, hidden_dim, dropout_prob=dropout_prob) for i in range(num_layers)])

        # Simple additive attention mechanism, applied after the last layer's hidden states
        self.attention = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1)
        )
        self.head = nn.Linear(hidden_dim, self.vocab_size)

    def forward(self, input_ids):
        """
        input_ids: (batch, seq_len)
        We compute autoregressively, passing hidden states through layers.
        Returns logits (batch, seq_len, vocab).
        """
        bsz, seq_len = input_ids.shape
        emb = self.emb(input_ids)  # (b, seq, emb)
        device = emb.device
        hidden_dim = self.cells[0].hidden_dim

        # Initialize hidden states for each layer
        # We need two previous states for each layer's ARSCell
        h_prev2_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]
        h_prev1_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]

        # Store hidden states for causal attention (from the last layer)
        all_last_layer_hidden_states = []
        logits = []
        gates = [] # Store gates from the last layer

        for t in range(seq_len):
            x_t = emb[:, t, :] # embedding at position t
            h_t_input = x_t # Input to the first layer is the embedding

            current_layer_hidden_states = []

            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1] # Use the output of the previous layer

                h_t, gate = cell(h_prev2, h_prev1, h_t_input)

                # Store current layer's hidden state as input for the next layer in the next time step
                current_layer_hidden_states.append(h_t)

                # Update previous hidden states for the next time step for this layer
                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1: # Only store gates from the last layer
                    gates.append(gate.unsqueeze(1))

            # After processing all layers for time step t, the output is the hidden state of the last layer
            last_layer_h_t = current_layer_hidden_states[-1]
            all_last_layer_hidden_states.append(last_layer_h_t.unsqueeze(1))

            # Apply attention and head after the last layer's hidden state
            # Causal attention: pool over the last layer's hidden states up to time step t.
            # Scores are computed from each stored state alone, so no query vector is needed.
            last_layer_history = torch.cat(all_last_layer_hidden_states, dim=1) # (b, t+1, hidden_dim)

            # Calculate attention scores
            scores = self.attention(last_layer_history) # (b, t+1, 1)
            attention_weights = F.softmax(scores, dim=1) # (b, t+1, 1)

            # Apply attention weights
            context_vector = torch.sum(attention_weights * last_layer_history, dim=1) # (b, hidden_dim)

            # Combine last layer's current hidden state with context vector before the head
            attended_h_t = last_layer_h_t + context_vector

            logit = self.head(attended_h_t) # logits for the token that follows position t
            logits.append(logit.unsqueeze(1))


        logits = torch.cat(logits, dim=1)          # (b, seq, vocab)
        gates = torch.cat(gates, dim=1)            # (b, seq) # Gates from the last layer
        return logits, gates

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=20, temperature=1.0, top_k=None):
        """
        idx: (batch, current_seq_len) - input context (token ids)
        Returns extended token ids.
        Modified for multiple layers, attention during generation, and optional top-k-sampling.
        """
        self.eval()
        bsz, seq_len = idx.shape
        device = next(self.parameters()).device
        hidden_dim = self.cells[0].hidden_dim

        # Initialize hidden states for each layer
        h_prev2_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]
        h_prev1_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]

        # Store hidden states for causal attention (from the last layer)
        all_last_layer_hidden_states = []

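        # Process the context to get starting hidden states for each layer.
        # The last context token is left to the generation loop below, which feeds
        # ids[-1] through the cells; running it here as well would apply it twice.
        input_ids = idx.clone()
        emb = self.emb(input_ids) # (b, seq_len, emb_dim)

        for t in range(seq_len - 1):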
            x_t = emb[:,t,:] # embedding at position t
            h_t_input = x_t

            current_layer_hidden_states = []
            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1]

                h_t, _ = cell(h_prev2, h_prev1, h_t_input)
                current_layer_hidden_states.append(h_t)

                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1:
                    # Add the new last layer hidden state to the history for attention
                    all_last_layer_hidden_states.append(h_t.unsqueeze(1))


        out_ids = input_ids.tolist()

        for _ in range(max_new_tokens):
            # Input for the first layer's ARSCell is the embedding of the last generated token
            last_token_ids = torch.tensor([ids[-1] for ids in out_ids], dtype=torch.long, device=device).unsqueeze(1)
            x_embed = self.emb(last_token_ids).squeeze(1) # (b, emb_dim)

            h_t_input = x_embed
            current_layer_hidden_states = []

            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1]

                h_t, _ = cell(h_prev2, h_prev1, h_t_input)
                current_layer_hidden_states.append(h_t)

                # Update previous hidden states for the next generation step for this layer
                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1:
                    # Add the new last layer hidden state to the history for attention
                    all_last_layer_hidden_states.append(h_t.unsqueeze(1))


            # After processing all layers for the new token, apply attention and head
            last_layer_h_t = current_layer_hidden_states[-1] # Hidden state of the last layer for the new token

            # Causal attention over the full history of last-layer hidden states
            last_layer_history = torch.cat(all_last_layer_hidden_states, dim=1) # (b, steps_so_far, hidden_dim)

            scores = self.attention(last_layer_history) # (b, current_seq_len + 1, 1)
            attention_weights = F.softmax(scores, dim=1) # (b, current_seq_len + 1, 1)

            context_vector = torch.sum(attention_weights * last_layer_history, dim=1) # (b, hidden_dim)

            attended_h_t = last_layer_h_t + context_vector

            logits = self.head(attended_h_t) / max(1e-6, temperature)

            # Apply top-k sampling
            if top_k is not None:
                # get top k logits
                top_k = min(max(top_k, 1), logits.size(-1))  # Clamp k to be at least 1 and at most vocab size
                # Remove all tokens with a probability less than the top-k'th token's probability
                v, _ = torch.topk(logits, top_k)
                logits[logits < v[:, [-1]]] = -float('Inf')

            # If every logit was filtered to -inf, sampling is impossible; fall back to
            # the unk token, or to a random token if the tokenizer defines no unk token.
            if torch.all(logits == -float('Inf')):
                if self.tokenizer.unk_token_id is not None:
                    next_token = torch.tensor([self.tokenizer.unk_token_id] * bsz, device=device)
                else:
                    # Fallback to random if unk token is not defined
                    next_token = torch.randint(0, self.vocab_size, (bsz,), device=device)
            else:
                probs = F.softmax(logits, dim=-1)
                # Ensure that if probabilities sum to zero (e.g., all logits were -inf before soft max),
                # we handle it gracefully, though the above check should prevent it.
                if torch.sum(probs) == 0: # This case should ideally not happen after the -inf check
                    if self.tokenizer.unk_token_id is not None:
                        next_token = torch.tensor([self.tokenizer.unk_token_id] * bsz, device=device)
                    else:
                        next_token = torch.randint(0, self.vocab_size, (bsz,), device=device)
                else:
                    next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)


            # append the newly generated token
            for i in range(bsz):
                out_ids[i].append(int(next_token[i].item()))


        return out_ids

# ----------------------------
# TextDataset class for structured data handling
# ----------------------------
class TextDataset(Dataset):
    def __init__(self, texts: list, tokenizer):
        self.texts = texts
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.texts[idx]

# Modified to accept and use the tokenizer instance AND max_seq_len
def collate_batch(batch_texts: List[str], tokenizer, max_seq_len: int):
    # Encode each text, add BOS/EOS, and convert to tensor
    encoded = []
    for t in batch_texts:
        ids = [tokenizer.bos_token_id] + tokenizer.encode(t) + [tokenizer.eos_token_id]
        # Truncate if longer than max_seq_len, keeping EOS as the final token when there is room
        if len(ids) > max_seq_len:
            if max_seq_len > 1:
                ids = ids[:max_seq_len - 1] + [tokenizer.eos_token_id]
            else:
                ids = ids[:max_seq_len]
        encoded.append(torch.tensor(ids, dtype=torch.long))

    # Pad to max_len within the batch, which will be at most max_seq_len due to truncation
    max_len_in_batch = max([x.size(0) for x in encoded])
    padded_batch = torch.full((len(encoded), max_len_in_batch), tokenizer.pad_token_id, dtype=torch.long)
    for i, x in enumerate(encoded):
        padded_batch[i, :x.size(0)] = x

    # Prepare inputs and targets for language modeling
    inputs = padded_batch[:, :-1]
    targets = padded_batch[:, 1:]

    return inputs, targets


# Modified to use AdvancedTokenizer and handle multiple layers, and DataLoaders
def train_demo():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize configuration
    cfg = Config()

    corpus_file = cfg.corpus_file
    texts = []
    if os.path.exists(corpus_file):
        print(f"Loading corpus from {corpus_file}...")
        with open(corpus_file, 'r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                if i < cfg.max_corpus_lines:
                    texts.append(line.strip())
                else:
                    break
        print(f"Loaded {len(texts)} lines.")
    else:
        print(f"Corpus file not found at {corpus_file}. Using a small toy corpus instead.")
        texts = [
            "hello world this is ars",
            "the system adapts to its history",
            "benpolyseq demonstrates adaptive sequences",
            "ars can inspire new network protocols",
            "self optimizing systems are possible"
        ]

    tokenizer = AdvancedTokenizer() # Initialize AdvancedTokenizer
    vocab_size = len(tokenizer)

    # Modified Data Splitting Logic
    num_samples = len(texts)
    if num_samples < 3:
        print(f"Warning: Corpus size ({num_samples}) is too small for a proper train/validation/test split. Using all data for training.")
        train_texts = texts
        val_texts = []
        test_texts = []
    else:
        # First split: train_val and test
        train_val_texts, test_texts = train_test_split(texts, test_size=cfg.train_test_split_ratio, random_state=cfg.random_state)
        # Defensive check for test_texts being empty
        if not test_texts and train_val_texts:
            test_texts.append(train_val_texts.pop()) # Move one item to test if test is empty and train_val is not

        # Second split: train and validation from train_val
        if len(train_val_texts) >= 2: # Need at least two samples to split into train and val
            val_split_size = cfg.train_val_split_ratio
            train_texts, val_texts = train_test_split(train_val_texts, test_size=val_split_size, random_state=cfg.random_state)
            # Defensive check for val_texts being empty
            if not val_texts and train_texts:
                val_texts.append(train_texts.pop()) # Move one item to val if val is empty and train is not
        else:
            train_texts = train_val_texts
            val_texts = []

    print(f"Train size: {len(train_texts)}, Validation size: {len(val_texts)}, Test size: {len(test_texts)}")

    # Create datasets
    train_dataset = TextDataset(train_texts, tokenizer)
    val_dataset = TextDataset(val_texts, tokenizer)
    test_dataset = TextDataset(test_texts, tokenizer)

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True,
                                collate_fn=lambda b: collate_batch(b, tokenizer, cfg.max_seq_len))
    val_loader = DataLoader(val_dataset, batch_size=cfg.batch_size, shuffle=False,
                              collate_fn=lambda b: collate_batch(b, tokenizer, cfg.max_seq_len))
    test_loader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False,
                               collate_fn=lambda b: collate_batch(b, tokenizer, cfg.max_seq_len))

    model = ARSLM(tokenizer, emb_dim=cfg.emb_dim, hidden_dim=cfg.hidden_dim, num_layers=cfg.num_layers, dropout_prob=cfg.dropout_prob).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate)
    scheduler = StepLR(optimizer, step_size=cfg.lr_scheduler_step_size, gamma=cfg.lr_scheduler_gamma)
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    n_epochs = cfg.n_epochs
    if len(train_texts) == 0: # Check if train_texts is empty, not inputs.numel()
        print("No training data available. Skipping training.")
        return model, tokenizer

    # Initialize best validation loss for saving the best model
    best_val_loss = float('inf')
    # Create directories for saving if they don't exist
    os.makedirs(cfg.model_save_dir, exist_ok=True)
    os.makedirs(cfg.tokenizer_save_dir, exist_ok=True)

    for epoch in range(n_epochs):
        model.train()
        total_loss = 0
        num_batches = len(train_loader)
        if num_batches > 0:
            for batch_idx, (inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(device), targets.to(device)
                logits, gates = model(inputs)
                b, seq, v = logits.shape
                loss = loss_fn(logits.view(b*seq, v), targets.reshape(b*seq))
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_grad_norm)
                optimizer.step()
                total_loss += loss.item()
            avg_train_loss = total_loss / num_batches
        else:
            avg_train_loss = float('inf') # Set to inf if no training batches

        scheduler.step()

        if (epoch+1) % 50 == 0 or epoch==0:
            print(f"Epoch {epoch+1}/{n_epochs} — Avg Train Loss: {avg_train_loss:.4f}")
            print(f"Current learning rate: {scheduler.get_last_lr()[0]:.6f}")

            # Validation step with check for empty loader
            if len(val_loader) > 0:
                model.eval()
                total_val_loss = 0
                with torch.no_grad():
                    for inputs_val, targets_val in val_loader:
                        inputs_val, targets_val = inputs_val.to(device), targets_val.to(device)
                        logits_val, _ = model(inputs_val)
                        b_val, seq_val, v_val = logits_val.shape
                        loss_val = loss_fn(logits_val.view(b_val*seq_val, v_val), targets_val.reshape(b_val*seq_val))
                        total_val_loss += loss_val.item()
                avg_val_loss = total_val_loss / len(val_loader)
                print(f"Validation Loss: {avg_val_loss:.4f}")

                # Save the model if this is the best validation loss so far
                if avg_val_loss < best_val_loss:
                    best_val_loss = avg_val_loss
                    print(f"Saving best model with validation loss: {best_val_loss:.4f}")
                    torch.save(model.state_dict(), os.path.join(cfg.model_save_dir, "best_model.pt"))
                    tokenizer.tokenizer.save_pretrained(cfg.tokenizer_save_dir)
                    # Save the configuration as well
                    torch.save(cfg, cfg.config_save_path)

            else:
                print("No validation data available. Skipping validation.")
                avg_val_loss = float('inf')

    print("\n=== Evaluation ===")
    model.eval()
    total_test_loss = 0
    if len(test_loader) > 0:
        with torch.no_grad():
            for inputs_test, targets_test in test_loader:
                inputs_test, targets_test = inputs_test.to(device), targets_test.to(device)
                logits_test, _ = model(inputs_test)
                b_test, seq_test, v_test = logits_test.shape
                loss_test = loss_fn(logits_test.view(b_test*seq_test, v_test), targets_test.reshape(b_test*seq_test))
                total_test_loss += loss_test.item()

        avg_test_loss = total_test_loss / len(test_loader)
        # Perplexity is exp of the mean cross-entropy, not the mean of per-batch exponentials
        avg_test_perplexity = math.exp(avg_test_loss)
        print(f"Test Loss: {avg_test_loss:.4f}")
        print(f"Perplexity on test data: {avg_test_perplexity:.4f}")
    else:
        print("No test data available. Skipping test evaluation.")
        avg_test_loss = float('inf')
        avg_test_perplexity = float('inf')

    # demo generation
    context = cfg.generation_context
    idx = torch.tensor([ [tokenizer.bos_token_id] + tokenizer.encode(context) ], dtype=torch.long, device=device)
    out_ids = model.generate(idx, max_new_tokens=cfg.max_new_tokens, temperature=cfg.temperature, top_k=cfg.top_k)[0]
    print("\n=== Generated ===")
    print(tokenizer.decode(out_ids))
    return model, tokenizer

def load_model_and_tokenizer(model_path: str, tokenizer_path: str, config_path: str, device):
    """
    Loads a pre-trained ARSLM model and its tokenizer along with its configuration.
    """
    # Add Config class to safe globals for unpickling
    torch.serialization.add_safe_globals([Config])
    # Load the configuration first
    cfg = torch.load(config_path)

    loaded_tokenizer = AdvancedTokenizer(pretrained_model_name_or_path=tokenizer_path)

    # Instantiate the model using the loaded configuration
    loaded_model = ARSLM(loaded_tokenizer,
                         emb_dim=cfg.emb_dim,
                         hidden_dim=cfg.hidden_dim,
                         num_layers=cfg.num_layers,
                         dropout_prob=cfg.dropout_prob).to(device)

    loaded_model.load_state_dict(torch.load(model_path, map_location=device))
    loaded_model.eval() # Set model to evaluation mode

    print(f"Model loaded from {model_path}")
    print(f"Tokenizer loaded from {tokenizer_path}")
    print(f"Configuration loaded from {config_path}")

    return loaded_model, loaded_tokenizer, cfg

def deploy_model(model, tokenizer, config, device, input_text: str):
    """
    Function to deploy the model for inference (text generation).
    """
    model.eval()
    with torch.no_grad():
        idx = torch.tensor([ [tokenizer.bos_token_id] + tokenizer.encode(input_text) ], dtype=torch.long, device=device)
        out_ids = model.generate(idx, max_new_tokens=config.max_new_tokens, temperature=config.temperature, top_k=config.top_k)[0]
        generated_text = tokenizer.decode(out_ids)
    return generated_text

# ----------------------------
# Run demo if main
# ----------------------------
if __name__ == "__main__":
    random.seed(0)
    torch.manual_seed(0)
    model, tokenizer = train_demo()
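
Perplexity is conventionally the exponential of the mean cross-entropy. Averaging per-batch values of `exp(loss)` instead gives a number that is never smaller and is usually larger (Jensen's inequality), so the two should not be compared directly. A minimal sketch with hypothetical per-batch losses; both function names are assumptions introduced for illustration:

```python
import math


def perplexity_from_losses(batch_losses):
    """Perplexity as exp of the mean cross-entropy over batches."""
    mean_loss = sum(batch_losses) / len(batch_losses)
    return math.exp(mean_loss)


def mean_of_batch_perplexities(batch_losses):
    """Average of per-batch exp(loss); shown only for comparison."""
    return sum(math.exp(loss) for loss in batch_losses) / len(batch_losses)


# Hypothetical per-batch cross-entropy values (natural log)
losses = [2.0, 4.0]
ppl = perplexity_from_losses(losses)
ppl_avg = mean_of_batch_perplexities(losses)
print(f"exp(mean loss) = {ppl:.2f}, mean(exp(loss)) = {ppl_avg:.2f}")
```

Note that a mean over batches still weights a short final batch the same as a full one; an exact token-level figure would accumulate summed loss and token counts instead.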

**Reasoning**:
I will execute the provided code block which contains the entire model definition and the `train_demo()` function. This will re-run the training and evaluation process, now utilizing the larger corpus as configured in the `Config` class, and display the results.
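
For reference, the update that `ARSCell.forward` applies at each step can be written compactly. The symbols are notation introduced here for illustration: $u_t$ is the concatenated context, $c_t$ the candidate, $g_t$ the scalar gate, and $W_r, b_r$ the residual projection `res_proj`. The layer norm has no learnable scale or shift, and dropout is active only in training mode.

$$
\begin{aligned}
u_t &= [\,h_{t-1};\ h_{t-2};\ x_t\,] \\
c_t &= \mathrm{MLP}_c(u_t), \qquad g_t = \sigma\big(\mathrm{MLP}_g(u_t)\big) \in (0,1) \\
h_t &= \mathrm{LayerNorm}\Big(\mathrm{Dropout}\big(h_{t-1} + g_t\, c_t + 0.1\,(W_r x_t + b_r)\big)\Big)
\end{aligned}
$$

Here $\mathrm{MLP}_c$ is Linear, ReLU, Linear (`candidate_mlp`) and $\mathrm{MLP}_g$ is Linear, Tanh, Linear (`gate_net` before its final sigmoid). For layers above the first, $x_t$ is the hidden state produced by the layer below at the same step.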



In [None]:
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
from transformers import BertTokenizer
from torch.optim.lr_scheduler import StepLR
import os
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

# ----------------------------
# Configuration Class
# ----------------------------
class Config:
    """Centralized configuration for the ARSLM model and training."""
    def __init__(self):
        # Model Parameters
        self.emb_dim = 64
        self.hidden_dim = 128
        self.num_layers = 2
        self.dropout_prob = 0.1

        # Training Parameters
        self.learning_rate = 1e-3
        self.n_epochs = 200
        self.batch_size = 8
        self.lr_scheduler_step_size = 50
        self.lr_scheduler_gamma = 0.5
        self.clip_grad_norm = 1.0

        # Generation Parameters
        self.max_new_tokens = 15
        self.temperature = 1.0
        self.top_k = 50
        self.generation_context = "hello world"

        # Data Parameters
        self.corpus_file = "/tmp/large_corpus.txt"
        self.max_corpus_lines = 10000 # Max lines to load from corpus_file for the demo
        self.train_test_split_ratio = 0.2
        self.train_val_split_ratio = 0.5 # Fraction of the train+val pool held out for validation (0.4 of total)
        self.random_state = 42

        # Saving/Loading Paths
        self.model_save_dir = "./model_checkpoint"
        self.tokenizer_save_dir = "./tokenizer_checkpoint"
        self.config_save_path = os.path.join(self.model_save_dir, "config.pt") # Path to save config

# ----------------------------
# Advanced Tokenizer (using transformers)
# ----------------------------
class AdvancedTokenizer:
    def __init__(self, vocab_file=None, pretrained_model_name_or_path='bert-base-uncased'):
        if vocab_file:
            self.tokenizer = BertTokenizer(vocab_file)
        else:
            # from_pretrained accepts either a hub model name or a local directory
            # holding a saved tokenizer, so no separate branch is needed
            self.tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path)

        self.vocab = list(self.tokenizer.vocab.keys())
        self.stoi = self.tokenizer.vocab
        self.itos = {i: w for w, i in self.stoi.items()}

        self.pad_token = self.tokenizer.pad_token
        self.unk_token = self.tokenizer.unk_token
        self.bos_token = self.tokenizer.cls_token # Using [CLS] as BOS for BERT-like tokenizers
        self.eos_token = self.tokenizer.sep_token # Using [SEP] as EOS for BERT-like tokenizers

        self.pad_token_id = self.tokenizer.pad_token_id
        self.unk_token_id = self.tokenizer.unk_token_id
        self.bos_token_id = self.tokenizer.cls_token_id
        self.eos_token_id = self.tokenizer.sep_token_id


    def encode(self, text: str) -> List[int]:
        return self.tokenizer.encode(text, add_special_tokens=False)

    def decode(self, ids: List[int]) -> str:
        return self.tokenizer.decode(ids, skip_special_tokens=True)

    def __len__(self):
        return len(self.vocab)

# ----------------------------
# ARSCell: core adaptive cell
# ----------------------------
class ARSCell(nn.Module):
    """
    ARSCell: computes next hidden state from h_{t-2}, h_{t-1}, and input embedding x_{t-1}
    Implements an 'adapt' mechanism inspired by BenPolySeq:
      h_t = h_{t-1} + gate * transform(h_{t-1}, h_{t-2}, x)
    where gate is a function of the difference (h_{t-1} - h_{t-2}) and context.
    """
    def __init__(self, emb_dim: int, hidden_dim: int, dropout_prob: float = 0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        # transform candidate from concat([h_{t-1}, h_{t-2}, x])
        self.candidate_mlp = nn.Sequential(
            nn.Linear(hidden_dim*2 + emb_dim, hidden_dim*2),
            nn.ReLU(),
            nn.Linear(hidden_dim*2, hidden_dim)
        )
        # gate network produces scalar gating factor in (0,1)
        self.gate_net = nn.Sequential(
            nn.Linear(hidden_dim*2 + emb_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        # small residual projection
        self.res_proj = nn.Linear(emb_dim, hidden_dim)
        # Dropout layer
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, h_prev2, h_prev1, x_embed):
        # h_prev2, h_prev1: (batch, hidden_dim)
        # x_embed: (batch, emb_dim)
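        # difference signal; computed for inspection only, it is not fed to the
        # gate or candidate networks, which see h_prev1 and h_prev2 through ctx
        diff = h_prev1 - h_prev2    # (batch, hidden_dim)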
        # context vector (concat)
        ctx = torch.cat([h_prev1, h_prev2, x_embed], dim=-1)
        candidate = self.candidate_mlp(ctx)                       # (batch, hidden_dim)
        gate = self.gate_net(ctx).squeeze(-1)                     # (batch,)
        # incorporate residual from input embedding
        residual = self.res_proj(x_embed)
        # ARS-style update: adaptive addition of candidate scaled by gate
        h_t = h_prev1 + gate.unsqueeze(-1) * candidate + 0.1 * residual
        # Apply dropout
        h_t = self.dropout(h_t)
        # optional normalization
        h_t = F.layer_norm(h_t, (self.hidden_dim,))
        return h_t, gate

# ----------------------------
# ARSLM model: embed -> ARSCell(s) -> attention -> head
# ----------------------------
class ARSLM(nn.Module):
    # Modified to accept a tokenizer instance, include attention, and support multiple layers
    def __init__(self, tokenizer, emb_dim=64, hidden_dim=128, num_layers=2, dropout_prob=0.1):
        super().__init__()
        self.tokenizer = tokenizer # Store tokenizer instance
        self.vocab_size = len(tokenizer) # Get vocab size from tokenizer
        self.num_layers = num_layers
        self.emb = nn.Embedding(self.vocab_size, emb_dim)
        # Create a list of ARSCells
        # Pass dropout_prob to ARSCell constructor
        self.cells = nn.ModuleList([ARSCell(emb_dim if i == 0 else hidden_dim, hidden_dim, dropout_prob=dropout_prob) for i in range(num_layers)])

        # Simple additive attention mechanism, applied after the last layer's hidden states
        self.attention = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1)
        )
        self.head = nn.Linear(hidden_dim, self.vocab_size)

    def forward(self, input_ids):
        """
        input_ids: (batch, seq_len)
        We compute autoregressively, passing hidden states through layers.
        Returns logits (batch, seq_len, vocab).
        """
        bsz, seq_len = input_ids.shape
        emb = self.emb(input_ids)  # (b, seq, emb)
        device = emb.device
        hidden_dim = self.cells[0].hidden_dim

        # Initialize hidden states for each layer
        # We need two previous states for each layer's ARSCell
        h_prev2_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]
        h_prev1_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]

        # Store hidden states for causal attention (from the last layer)
        all_last_layer_hidden_states = []
        logits = []
        gates = [] # Store gates from the last layer

        for t in range(seq_len):
            x_t = emb[:, t, :] # embedding at position t
            h_t_input = x_t # Input to the first layer is the embedding

            current_layer_hidden_states = []

            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                # The input to subsequent layers is the hidden state of the previous layer
                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1] # Use the output of the previous layer

                h_t, gate = cell(h_prev2, h_prev1, h_t_input)

                # Store this layer's hidden state; it is the input to the next layer at the same time step
                current_layer_hidden_states.append(h_t)

                # Update previous hidden states for the next time step for this layer
                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1: # Only store gates from the last layer
                    gates.append(gate.unsqueeze(1))

            # After processing all layers for time step t, the output is the hidden state of the last layer
            last_layer_h_t = current_layer_hidden_states[-1]
            all_last_layer_hidden_states.append(last_layer_h_t.unsqueeze(1))

            # Apply attention and head after the last layer's hidden state
            # Causal Attention: attend over hidden states up to current time step t from the last layer
            last_layer_history = torch.cat(all_last_layer_hidden_states, dim=1) # (b, t+1, hidden_dim)
            query = last_layer_h_t.unsqueeze(1) # (b, 1, hidden_dim) - query is the current last layer hidden state

            # Calculate attention scores
            scores = self.attention(last_layer_history) # (b, t+1, 1)
            attention_weights = F.softmax(scores, dim=1) # (b, t+1, 1)

            # Apply attention weights
            context_vector = torch.sum(attention_weights * last_layer_history, dim=1) # (b, hidden_dim)

            # Combine last layer's current hidden state with context vector before the head
            attended_h_t = last_layer_h_t + context_vector

            logit = self.head(attended_h_t) # logits for the token at position t+1 (targets are the inputs shifted by one)
            logits.append(logit.unsqueeze(1))


        logits = torch.cat(logits, dim=1)          # (b, seq, vocab)
        gates = torch.cat(gates, dim=1)            # (b, seq) # Gates from the last layer
        return logits, gates

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=20, temperature=1.0, top_k=None):
        """
        idx: (batch, current_seq_len) - input context (token ids)
        Returns extended token ids.
        Modified for multiple layers, attention during generation, and optional top-k sampling.
        """
        self.eval()
        bsz, seq_len = idx.shape
        device = next(self.parameters()).device
        hidden_dim = self.cells[0].hidden_dim

        # Initialize hidden states for each layer
        h_prev2_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]
        h_prev1_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]

        # Store hidden states for causal attention (from the last layer)
        all_last_layer_hidden_states = []

        # Process initial context to get starting hidden states for each layer
        input_ids = idx.clone()
        emb = self.emb(input_ids) # (b, seq_len, emb_dim)

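        # Prime on every context token except the last one: the generation loop below
        # feeds the last token itself, so including it here would process it twice
        for t in range(seq_len - 1):
            x_t = emb[:,t,:] # embedding at position t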
            h_t_input = x_t

            current_layer_hidden_states = []
            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1]

                h_t, _ = cell(h_prev2, h_prev1, h_t_input)
                current_layer_hidden_states.append(h_t)

                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1:
                    # Add the new last layer hidden state to the history for attention
                    all_last_layer_hidden_states.append(h_t.unsqueeze(1))


        out_ids = input_ids.tolist()

        for _ in range(max_new_tokens):
            # Input for the first layer's ARSCell is the embedding of the last generated token
            last_token_ids = torch.tensor([ids[-1] for ids in out_ids], dtype=torch.long, device=device).unsqueeze(1)
            x_embed = self.emb(last_token_ids).squeeze(1) # (b, emb_dim)

            h_t_input = x_embed
            current_layer_hidden_states = []

            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1]

                h_t, _ = cell(h_prev2, h_prev1, h_t_input)
                current_layer_hidden_states.append(h_t)

                # Update previous hidden states for the next generation step for this layer
                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1:
                    # Add the new last layer hidden state to the history for attention
                    all_last_layer_hidden_states.append(h_t.unsqueeze(1))


            # After processing all layers for the new token, apply attention and head
            last_layer_h_t = current_layer_hidden_states[-1] # Hidden state of the last layer for the new token

            # Causal Attention over all last layer hidden states history
            last_layer_history = torch.cat(all_last_layer_hidden_states, dim=1) # (b, current_seq_len + 1, hidden_dim)
            query = last_layer_h_t.unsqueeze(1) # (b, 1, hidden_dim)

            scores = self.attention(last_layer_history) # (b, current_seq_len + 1, 1)
            attention_weights = F.softmax(scores, dim=1) # (b, current_seq_len + 1, 1)

            context_vector = torch.sum(attention_weights * last_layer_history, dim=1) # (b, hidden_dim)

            attended_h_t = last_layer_h_t + context_vector

            logits = self.head(attended_h_t) / max(1e-6, temperature)

            # Apply top-k sampling
            if top_k is not None:
                # get top k logits
                top_k = min(max(top_k, 1), logits.size(-1))  # Clamp k to be at least 1 and at most vocab size
                # Remove all tokens with a probability less than the top-k'th token's probability
                v, _ = torch.topk(logits, top_k)
                logits[logits < v[:, [-1]]] = -float('Inf')

            # If every logit was filtered to -inf, sampling is impossible; fall back to the
            # unk token, or to a random token when no unk token is defined
            if torch.all(logits == -float('Inf')):
                if self.tokenizer.unk_token_id is not None:
                    next_token = torch.tensor([self.tokenizer.unk_token_id] * bsz, device=device)
                else:
                    # Fallback to random if unk token is not defined
                    next_token = torch.randint(0, self.vocab_size, (bsz,), device=device)
            else:
                probs = F.softmax(logits, dim=-1)
                # Ensure that if probabilities sum to zero (e.g., all logits were -inf before soft max),
                # we handle it gracefully, though the above check should prevent it.
                if torch.sum(probs) == 0: # This case should ideally not happen after the -inf check
                    if self.tokenizer.unk_token_id is not None:
                        next_token = torch.tensor([self.tokenizer.unk_token_id] * bsz, device=device)
                    else:
                        next_token = torch.randint(0, self.vocab_size, (bsz,), device=device)
                else:
                    next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)


            # append the newly generated token
            for i in range(bsz):
                out_ids[i].append(int(next_token[i].item()))


        return out_ids

# ----------------------------
# TextDataset class for structured data handling
# ----------------------------
class TextDataset(Dataset):
    def __init__(self, texts: list, tokenizer):
        self.texts = texts
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.texts[idx]

# Modified to accept and use the tokenizer instance
def collate_batch(batch_texts: List[str], tokenizer):
    # Encode each text, add BOS/EOS, and convert to tensor
    encoded = []
    for t in batch_texts:
        ids = [tokenizer.bos_token_id] + tokenizer.encode(t) + [tokenizer.eos_token_id]
        encoded.append(torch.tensor(ids, dtype=torch.long))

    # Pad to max len within the batch
    max_len = max([x.size(0) for x in encoded])
    padded_batch = torch.full((len(encoded), max_len), tokenizer.pad_token_id, dtype=torch.long)
    for i, x in enumerate(encoded):
        padded_batch[i, :x.size(0)] = x

    # Prepare inputs and targets for language modeling
    inputs = padded_batch[:, :-1]
    targets = padded_batch[:, 1:]

    return inputs, targets


# Modified to use AdvancedTokenizer and handle multiple layers, and DataLoaders
def train_demo():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize configuration
    cfg = Config()

    corpus_file = cfg.corpus_file
    texts = []
    if os.path.exists(corpus_file):
        print(f"Loading corpus from {corpus_file}...")
        with open(corpus_file, 'r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                if i < cfg.max_corpus_lines:
                    texts.append(line.strip())
                else:
                    break
        print(f"Loaded {len(texts)} lines.")
    else:
        print(f"Corpus file not found at {corpus_file}. Using a small toy corpus instead.")
        texts = [
            "hello world this is ars",
            "the system adapts to its history",
            "benpolyseq demonstrates adaptive sequences",
            "ars can inspire new network protocols",
            "self optimizing systems are possible"
        ]

    tokenizer = AdvancedTokenizer() # Initialize AdvancedTokenizer
    vocab_size = len(tokenizer)

    # Modified Data Splitting Logic
    num_samples = len(texts)
    if num_samples < 3:
        print(f"Warning: Corpus size ({num_samples}) is too small for a proper train/validation/test split. Using all data for training.")
        train_texts = texts
        val_texts = []
        test_texts = []
    else:
        # First split: train_val and test
        train_val_texts, test_texts = train_test_split(texts, test_size=cfg.train_test_split_ratio, random_state=cfg.random_state)
        # Defensive check for test_texts being empty
        if not test_texts and train_val_texts:
            test_texts.append(train_val_texts.pop()) # Move one item to test if test is empty and train_val is not

        # Second split: train and validation from train_val
        if len(train_val_texts) >= 2: # Need at least two samples to split into train and val
            val_split_size = cfg.train_val_split_ratio
            train_texts, val_texts = train_test_split(train_val_texts, test_size=val_split_size, random_state=cfg.random_state)
            # Defensive check for val_texts being empty
            if not val_texts and train_texts:
                val_texts.append(train_texts.pop()) # Move one item to val if val is empty and train is not
        else:
            train_texts = train_val_texts
            val_texts = []

    print(f"Train size: {len(train_texts)}, Validation size: {len(val_texts)}, Test size: {len(test_texts)}")

    # Create datasets
    train_dataset = TextDataset(train_texts, tokenizer)
    val_dataset = TextDataset(val_texts, tokenizer)
    test_dataset = TextDataset(test_texts, tokenizer)

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True,
                                collate_fn=lambda b: collate_batch(b, tokenizer))
    val_loader = DataLoader(val_dataset, batch_size=cfg.batch_size, shuffle=False,
                              collate_fn=lambda b: collate_batch(b, tokenizer))
    test_loader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False,
                               collate_fn=lambda b: collate_batch(b, tokenizer))

    model = ARSLM(tokenizer, emb_dim=cfg.emb_dim, hidden_dim=cfg.hidden_dim, num_layers=cfg.num_layers, dropout_prob=cfg.dropout_prob).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate)
    scheduler = StepLR(optimizer, step_size=cfg.lr_scheduler_step_size, gamma=cfg.lr_scheduler_gamma)
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    n_epochs = cfg.n_epochs
    if len(train_texts) == 0: # Check if train_texts is empty, not inputs.numel()
        print("No training data available. Skipping training.")
        return model, tokenizer

    # Initialize best validation loss for saving the best model
    best_val_loss = float('inf')
    # Create directories for saving if they don't exist
    os.makedirs(cfg.model_save_dir, exist_ok=True)
    os.makedirs(cfg.tokenizer_save_dir, exist_ok=True)

    for epoch in range(n_epochs):
        model.train()
        total_loss = 0
        num_batches = len(train_loader)
        if num_batches > 0:
            for batch_idx, (inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(device), targets.to(device)
                logits, gates = model(inputs)
                b, seq, v = logits.shape
                loss = loss_fn(logits.view(b*seq, v), targets.reshape(b*seq))
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_grad_norm)
                optimizer.step()
                total_loss += loss.item()
            avg_train_loss = total_loss / num_batches
        else:
            avg_train_loss = float('inf') # Set to inf if no training batches

        scheduler.step()

        if (epoch+1) % 50 == 0 or epoch==0:
            print(f"Epoch {epoch+1}/{n_epochs} — Avg Train Loss: {avg_train_loss:.4f}")
            print(f"Current learning rate: {scheduler.get_last_lr()[0]:.6f}")

            # Validation step with check for empty loader
            if len(val_loader) > 0:
                model.eval()
                total_val_loss = 0
                with torch.no_grad():
                    for inputs_val, targets_val in val_loader:
                        inputs_val, targets_val = inputs_val.to(device), targets_val.to(device)
                        logits_val, _ = model(inputs_val)
                        b_val, seq_val, v_val = logits_val.shape
                        loss_val = loss_fn(logits_val.view(b_val*seq_val, v_val), targets_val.reshape(b_val*seq_val))
                        total_val_loss += loss_val.item()
                avg_val_loss = total_val_loss / len(val_loader)
                print(f"Validation Loss: {avg_val_loss:.4f}")

                # Save the model if this is the best validation loss so far
                if avg_val_loss < best_val_loss:
                    best_val_loss = avg_val_loss
                    print(f"Saving best model with validation loss: {best_val_loss:.4f}")
                    torch.save(model.state_dict(), os.path.join(cfg.model_save_dir, "best_model.pt"))
                    tokenizer.tokenizer.save_pretrained(cfg.tokenizer_save_dir)
                    # Save the configuration as well
                    torch.save(cfg, cfg.config_save_path)

            else:
                print("No validation data available. Skipping validation.")
                avg_val_loss = float('inf')

    print("\n=== Evaluation ===")
    model.eval()
    total_test_loss = 0
    if len(test_loader) > 0:
        with torch.no_grad():
            for inputs_test, targets_test in test_loader:
                inputs_test, targets_test = inputs_test.to(device), targets_test.to(device)
                logits_test, _ = model(inputs_test)
                b_test, seq_test, v_test = logits_test.shape
                loss_test = loss_fn(logits_test.view(b_test*seq_test, v_test), targets_test.reshape(b_test*seq_test))
                total_test_loss += loss_test.item()

        avg_test_loss = total_test_loss / len(test_loader)
        # Perplexity is the exponential of the mean loss; averaging per-batch
        # exp(loss) values instead would overstate it
        avg_test_perplexity = torch.exp(torch.tensor(avg_test_loss)).item()
        print(f"Test Loss: {avg_test_loss:.4f}")
        print(f"Perplexity on test data: {avg_test_perplexity:.4f}")
    else:
        print("No test data available. Skipping test evaluation.")
        avg_test_loss = float('inf')
        avg_test_perplexity = float('inf')

    # demo generation
    context = cfg.generation_context
    idx = torch.tensor([ [tokenizer.bos_token_id] + tokenizer.encode(context) ], dtype=torch.long, device=device)
    out_ids = model.generate(idx, max_new_tokens=cfg.max_new_tokens, temperature=cfg.temperature, top_k=cfg.top_k)[0]
    print("\n=== Generated ===")
    print(tokenizer.decode(out_ids))
    return model, tokenizer

def load_model_and_tokenizer(model_path: str, tokenizer_path: str, config_path: str, device):
    """
    Loads a pre-trained ARSLM model and its tokenizer along with its configuration.
    """
    # Add Config class to safe globals for unpickling
    torch.serialization.add_safe_globals([Config])
    # Load the configuration first
    cfg = torch.load(config_path)

    loaded_tokenizer = AdvancedTokenizer(pretrained_model_name_or_path=tokenizer_path)

    # Instantiate the model using the loaded configuration
    loaded_model = ARSLM(loaded_tokenizer,
                         emb_dim=cfg.emb_dim,
                         hidden_dim=cfg.hidden_dim,
                         num_layers=cfg.num_layers,
                         dropout_prob=cfg.dropout_prob).to(device)

    loaded_model.load_state_dict(torch.load(model_path, map_location=device))
    loaded_model.eval() # Set model to evaluation mode

    print(f"Model loaded from {model_path}")
    print(f"Tokenizer loaded from {tokenizer_path}")
    print(f"Configuration loaded from {config_path}")

    return loaded_model, loaded_tokenizer, cfg

def deploy_model(model, tokenizer, config, device, input_text: str):
    """
    Function to deploy the model for inference (text generation).
    """
    model.eval()
    with torch.no_grad():
        idx = torch.tensor([ [tokenizer.bos_token_id] + tokenizer.encode(input_text) ], dtype=torch.long, device=device)
        out_ids = model.generate(idx, max_new_tokens=config.max_new_tokens, temperature=config.temperature, top_k=config.top_k)[0]
        generated_text = tokenizer.decode(out_ids)
    return generated_text

# ----------------------------
# Run demo if main
# ----------------------------
if __name__ == "__main__":
    random.seed(0)
    torch.manual_seed(0)
    model, tokenizer = train_demo()

    # Demonstrate loading the saved model and tokenizer
    print("\n=== Demonstrating Model Loading and Deployment ===")
    cfg_instance = Config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    try:
        loaded_model, loaded_tokenizer, loaded_config = load_model_and_tokenizer(
            os.path.join(cfg_instance.model_save_dir, "best_model.pt"),
            cfg_instance.tokenizer_save_dir,
            cfg_instance.config_save_path,
            device
        )
        # Test generation with loaded model using the deploy_model function
        context = "loaded model generates"
        generated_output = deploy_model(loaded_model, loaded_tokenizer, loaded_config, device, context)
        print("Generated with loaded model:")
        print(generated_output)
    except FileNotFoundError:
        print("No saved model found. Please ensure training completed successfully and saved a model.")


**Reasoning**:
The user wants to execute the `train_demo()` function again. The provided code block in the user's prompt contains the complete and updated Python script, including all necessary class definitions and the `train_demo()` function call within the `if __name__ == "__main__":` block. Executing this will re-run the training process with the larger corpus.



# Task
Add a `max_seq_len` attribute to the `Config` class with a default value of 512, and modify the `collate_batch` function to truncate sequences to this maximum length when encoding texts.
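
The truncation step described above can be sketched in plain Python, without tensors, to make the length bookkeeping explicit. This is a minimal illustration, not the notebook's `collate_batch`: the helper names `encode_with_limit` and `pad_to_longest` are hypothetical, the ids 101/102/0 are toy stand-ins for BOS/EOS/PAD, and keeping EOS as the final token after truncation is an assumption the task statement does not settle.

```python
from typing import List

def encode_with_limit(token_ids: List[int], bos_id: int, eos_id: int,
                      max_seq_len: int) -> List[int]:
    """Wrap token_ids in BOS/EOS and cap the total length at max_seq_len."""
    if max_seq_len < 2:
        raise ValueError("max_seq_len must leave room for BOS and EOS")
    # Reserve two positions for the special tokens, then cut the body.
    body = token_ids[: max_seq_len - 2]
    return [bos_id] + body + [eos_id]

def pad_to_longest(seqs: List[List[int]], pad_id: int) -> List[List[int]]:
    """Right-pad every sequence to the length of the longest one in the batch."""
    longest = max(len(s) for s in seqs)
    return [s + [pad_id] * (longest - len(s)) for s in seqs]

# Toy ids (assumed values): 101/102 stand in for BOS/EOS, 0 for padding.
batch = [encode_with_limit(ids, 101, 102, max_seq_len=5)
         for ids in ([7, 8, 9, 10, 11, 12], [7, 8])]
padded = pad_to_longest(batch, pad_id=0)
```

Truncating the body before adding the special tokens keeps every sequence within `max_seq_len` while still ending in EOS. Truncating after wrapping would instead drop EOS from long sequences, which may or may not be what the model should see.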

## Add max_seq_len to Config

### Subtask:
Add a `max_seq_len` attribute to the `Config` class with a default value of 512 (or another appropriate value based on the model's constraints).


**Reasoning**:
I will update the `Config` class to include `max_seq_len` as requested by the subtask. This centralizes the sequence length setting, making it easier to manage model constraints.
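
One way a centralized setting like this might be consumed is to bind it once to a batch function. The sketch below is illustrative only: `DemoConfig` and `clip_batch` are hypothetical stand-ins, not the notebook's `Config` or `collate_batch`.

```python
from dataclasses import dataclass
from functools import partial
from typing import List

@dataclass
class DemoConfig:
    """Hypothetical stand-in for the notebook's Config; only the new field is shown."""
    max_seq_len: int = 512

def clip_batch(batch_ids: List[List[int]], max_seq_len: int) -> List[List[int]]:
    """Cut every id sequence in the batch to at most max_seq_len entries."""
    return [ids[:max_seq_len] for ids in batch_ids]

cfg = DemoConfig(max_seq_len=4)
# Bind the limit once so callers never pass the number around by hand.
collate = partial(clip_batch, max_seq_len=cfg.max_seq_len)
clipped = collate([[1, 2, 3, 4, 5, 6], [7, 8]])
```

A bound callable of this kind could be passed as `collate_fn` to a `DataLoader` in place of a lambda, so changing the limit only requires editing the config.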



In [None]:
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
from transformers import BertTokenizer # Import BertTokenizer here
from torch.optim.lr_scheduler import StepLR # Import StepLR
import os # Import os for file operations
from torch.utils.data import Dataset, DataLoader # Import Dataset and DataLoader
from sklearn.model_selection import train_test_split # Import train_test_split

# ----------------------------
# Configuration Class
# ----------------------------
class Config:
    """Centralized configuration for the ARSLM model and training."""
    def __init__(self):
        # Model Parameters
        self.emb_dim = 64
        self.hidden_dim = 128
        self.num_layers = 2
        self.dropout_prob = 0.1
        self.max_seq_len = 512 # Maximum sequence length used when truncating encoded texts

        # Training Parameters
        self.learning_rate = 1e-3
        self.n_epochs = 200
        self.batch_size = 8
        self.lr_scheduler_step_size = 50
        self.lr_scheduler_gamma = 0.5
        self.clip_grad_norm = 1.0

        # Generation Parameters
        self.max_new_tokens = 15
        self.temperature = 1.0
        self.top_k = 50
        self.generation_context = "hello world"

        # Data Parameters
        self.corpus_file = "/tmp/large_corpus.txt"
        self.max_corpus_lines = 1000 # Max lines to load from corpus_file for demo
        self.train_test_split_ratio = 0.2
        self.train_val_split_ratio = 0.5 # Fraction of the train+val portion held out for validation (0.4 of the total)
        self.random_state = 42

        # Saving/Loading Paths
        self.model_save_dir = "./model_checkpoint"
        self.tokenizer_save_dir = "./tokenizer_checkpoint"
        self.config_save_path = os.path.join(self.model_save_dir, "config.pt") # Path to save config

# ----------------------------
# Advanced Tokenizer (using transformers)
# ----------------------------
class AdvancedTokenizer:
    def __init__(self, vocab_file=None, pretrained_model_name_or_path='bert-base-uncased'):
        if vocab_file:
             self.tokenizer = BertTokenizer(vocab_file)
        else:
             # from_pretrained accepts either a local directory holding a saved
             # tokenizer or a model name, so both cases share one call
             self.tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path)

        self.vocab = list(self.tokenizer.vocab.keys())
        self.stoi = self.tokenizer.vocab
        self.itos = {i: w for w, i in self.stoi.items()}

        self.pad_token = self.tokenizer.pad_token
        self.unk_token = self.tokenizer.unk_token
        self.bos_token = self.tokenizer.cls_token # Using [CLS] as BOS for BERT-like tokenizers
        self.eos_token = self.tokenizer.sep_token # Using [SEP] as EOS for BERT-like tokenizers

        self.pad_token_id = self.tokenizer.pad_token_id
        self.unk_token_id = self.tokenizer.unk_token_id
        self.bos_token_id = self.tokenizer.cls_token_id
        self.eos_token_id = self.tokenizer.sep_token_id


    def encode(self, text: str) -> List[int]:
        return self.tokenizer.encode(text, add_special_tokens=False)

    def decode(self, ids: List[int]) -> str:
        return self.tokenizer.decode(ids, skip_special_tokens=True)

    def __len__(self):
        return len(self.vocab)

# ----------------------------
# ARSCell: core adaptive cell
# ----------------------------
class ARSCell(nn.Module):
    """
    ARSCell: computes next hidden state from h_{t-2}, h_{t-1}, and the current input x
    (the token embedding for the first layer, the lower layer's hidden state otherwise).
    Implements an 'adapt' mechanism inspired by BenPolySeq:
      h_t = LayerNorm(Dropout(h_{t-1} + gate * transform(h_{t-1}, h_{t-2}, x) + 0.1 * res_proj(x)))
    where gate is a scalar in (0, 1) computed from concat([h_{t-1}, h_{t-2}, x]).
    """
    def __init__(self, emb_dim: int, hidden_dim: int, dropout_prob: float = 0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        # transform candidate from concat([h_{t-1}, h_{t-2}, x])
        self.candidate_mlp = nn.Sequential(
            nn.Linear(hidden_dim*2 + emb_dim, hidden_dim*2),
            nn.ReLU(),
            nn.Linear(hidden_dim*2, hidden_dim)
        )
        # gate network produces scalar gating factor in (0,1)
        self.gate_net = nn.Sequential(
            nn.Linear(hidden_dim*2 + emb_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        # small residual projection
        self.res_proj = nn.Linear(emb_dim, hidden_dim)
        # Dropout layer
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, h_prev2, h_prev1, x_embed):
        # h_prev2, h_prev1: (batch, hidden_dim)
        # x_embed: (batch, emb_dim)
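        # difference signal; computed for inspection only, it is not fed to the
        # gate or candidate networks, which see h_prev1 and h_prev2 through ctx
        diff = h_prev1 - h_prev2    # (batch, hidden_dim)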
        # context vector (concat)
        ctx = torch.cat([h_prev1, h_prev2, x_embed], dim=-1)
        candidate = self.candidate_mlp(ctx)                       # (batch, hidden_dim)
        gate = self.gate_net(ctx).squeeze(-1)                     # (batch,)
        # incorporate residual from input embedding
        residual = self.res_proj(x_embed)
        # ARS-style update: adaptive addition of candidate scaled by gate
        h_t = h_prev1 + gate.unsqueeze(-1) * candidate + 0.1 * residual
        # Apply dropout
        h_t = self.dropout(h_t)
        # optional normalization
        h_t = F.layer_norm(h_t, (self.hidden_dim,))
        return h_t, gate

# ----------------------------
# ARSLM model: embed -> ARSCell(s) -> attention -> head
# ----------------------------
class ARSLM(nn.Module):
    # Modified to accept a tokenizer instance, include attention, and support multiple layers
    def __init__(self, tokenizer, emb_dim=64, hidden_dim=128, num_layers=2, dropout_prob=0.1):
        super().__init__()
        self.tokenizer = tokenizer # Store tokenizer instance
        self.vocab_size = len(tokenizer) # Get vocab size from tokenizer
        self.num_layers = num_layers
        self.emb = nn.Embedding(self.vocab_size, emb_dim)
        # Create a list of ARSCells
        # Pass dropout_prob to ARSCell constructor
        self.cells = nn.ModuleList([ARSCell(emb_dim if i == 0 else hidden_dim, hidden_dim, dropout_prob=dropout_prob) for i in range(num_layers)])

        # Simple additive attention mechanism, applied after the last layer's hidden states
        self.attention = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1)
        )
        self.head = nn.Linear(hidden_dim, self.vocab_size)

    def forward(self, input_ids):
        """
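        input_ids: (batch, seq_len)
        Processes the sequence one position at a time, passing hidden states up through the layers.
        Returns (logits, gates): logits is (batch, seq_len, vocab) and gates is
        (batch, seq_len), taken from the last layer.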
        """
        bsz, seq_len = input_ids.shape
        emb = self.emb(input_ids)  # (b, seq, emb)
        device = emb.device
        hidden_dim = self.cells[0].hidden_dim

        # Initialize hidden states for each layer
        # We need two previous states for each layer's ARSCell
        h_prev2_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]
        h_prev1_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]

        # Store hidden states for causal attention (from the last layer)
        all_last_layer_hidden_states = []
        logits = []
        gates = [] # Store gates from the last layer

        for t in range(seq_len):
            x_t = emb[:, t, :] # embedding at position t
            h_t_input = x_t # Input to the first layer is the embedding

            current_layer_hidden_states = []

            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                # The input to subsequent layers is the hidden state of the previous layer
                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1] # Use the output of the previous layer

                h_t, gate = cell(h_prev2, h_prev1, h_t_input)

                # Keep this layer's hidden state: it is the input to the next layer at the same time step
                current_layer_hidden_states.append(h_t)

                # Update previous hidden states for the next time step for this layer
                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1: # Only store gates from the last layer
                    gates.append(gate.unsqueeze(1))

            # After processing all layers for time step t, the output is the hidden state of the last layer
            last_layer_h_t = current_layer_hidden_states[-1]
            all_last_layer_hidden_states.append(last_layer_h_t.unsqueeze(1))

            # Causal attention pooling: score every last-layer hidden state up to and
            # including step t. The scores depend only on each history entry (there is no
            # separate query vector), so this is a learned weighted average of the past.
            last_layer_history = torch.cat(all_last_layer_hidden_states, dim=1) # (b, t+1, hidden_dim)

            # Calculate attention scores
            scores = self.attention(last_layer_history) # (b, t+1, 1)
            attention_weights = F.softmax(scores, dim=1) # (b, t+1, 1)

            # Apply attention weights
            context_vector = torch.sum(attention_weights * last_layer_history, dim=1) # (b, hidden_dim)

            # Combine last layer's current hidden state with context vector before the head
            attended_h_t = last_layer_h_t + context_vector

            logit = self.head(attended_h_t) # logits for the token that follows position t (targets are the inputs shifted by one)
            logits.append(logit.unsqueeze(1))


        logits = torch.cat(logits, dim=1)          # (b, seq, vocab)
        gates = torch.cat(gates, dim=1)            # (b, seq) # Gates from the last layer
        return logits, gates

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=20, temperature=1.0, top_k=None):
        """
        idx: (batch, current_seq_len) - input context (token ids)
        Returns the extended token ids as a list of lists, one per batch row.
        Supports multiple layers, attention over the running hidden-state history,
        and optional top-k sampling.
        """
        self.eval()
        bsz, seq_len = idx.shape
        device = next(self.parameters()).device
        hidden_dim = self.cells[0].hidden_dim

        # Initialize hidden states for each layer
        h_prev2_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]
        h_prev1_list = [torch.zeros(bsz, hidden_dim, device=device) for _ in range(self.num_layers)]

        # Store hidden states for causal attention (from the last layer)
        all_last_layer_hidden_states = []

        # Process initial context to get starting hidden states for each layer
        input_ids = idx.clone()
        emb = self.emb(input_ids) # (b, seq_len, emb_dim)

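        # Stop one token short: the generation loop below feeds the last context
        # token itself, so consuming it here as well would process it twice.
        for t in range(seq_len - 1):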
            x_t = emb[:,t,:] # embedding at position t
            h_t_input = x_t

            current_layer_hidden_states = []
            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1]

                h_t, _ = cell(h_prev2, h_prev1, h_t_input)
                current_layer_hidden_states.append(h_t)

                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1:
                    # Add the new last layer hidden state to the history for attention
                    all_last_layer_hidden_states.append(h_t.unsqueeze(1))


        out_ids = input_ids.tolist()

        for _ in range(max_new_tokens):
            # Input for the first layer's ARSCell is the embedding of the last generated token
            last_token_ids = torch.tensor([ids[-1] for ids in out_ids], dtype=torch.long, device=device).unsqueeze(1)
            x_embed = self.emb(last_token_ids).squeeze(1) # (b, emb_dim)

            h_t_input = x_embed
            current_layer_hidden_states = []

            for layer in range(self.num_layers):
                cell = self.cells[layer]
                h_prev2 = h_prev2_list[layer]
                h_prev1 = h_prev1_list[layer]

                if layer > 0:
                    h_t_input = current_layer_hidden_states[-1]

                h_t, _ = cell(h_prev2, h_prev1, h_t_input)
                current_layer_hidden_states.append(h_t)

                # Update previous hidden states for the next generation step for this layer
                h_prev2_list[layer] = h_prev1
                h_prev1_list[layer] = h_t

                if layer == self.num_layers - 1:
                    # Add the new last layer hidden state to the history for attention
                    all_last_layer_hidden_states.append(h_t.unsqueeze(1))


            # After processing all layers for the new token, apply attention and head
            last_layer_h_t = current_layer_hidden_states[-1] # Hidden state of the last layer for the new token

            # Causal attention pooling over the full last-layer history (context plus generated tokens)
            last_layer_history = torch.cat(all_last_layer_hidden_states, dim=1) # (b, history_len, hidden_dim)

            scores = self.attention(last_layer_history) # (b, current_seq_len + 1, 1)
            attention_weights = F.softmax(scores, dim=1) # (b, current_seq_len + 1, 1)

            context_vector = torch.sum(attention_weights * last_layer_history, dim=1) # (b, hidden_dim)

            attended_h_t = last_layer_h_t + context_vector

            logits = self.head(attended_h_t) / max(1e-6, temperature)

            # Apply top-k sampling
            if top_k is not None:
                # get top k logits
                top_k = min(max(top_k, 1), logits.size(-1))  # Clamp k to be at least 1 and at most vocab size
                # Remove all tokens with a probability less than the top-k'th token's probability
                v, _ = torch.topk(logits, top_k)
                logits[logits < v[:, [-1]]] = -float('Inf')

            probs = F.softmax(logits, dim=-1)
            # A row whose logits are all -inf (or contain NaN) yields NaN probabilities,
            # which torch.multinomial rejects. Detect such rows individually, sample them
            # uniformly so the call succeeds, then overwrite with the unk token if one exists.
            bad_rows = ~torch.isfinite(probs).all(dim=-1)  # (b,)
            if bad_rows.any():
                probs[bad_rows] = 1.0 / probs.size(-1)
            next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)
            if bad_rows.any() and self.tokenizer.unk_token_id is not None:
                next_token[bad_rows] = self.tokenizer.unk_token_id


            # append the newly generated token
            for i in range(bsz):
                out_ids[i].append(int(next_token[i].item()))


        return out_ids
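
# Illustrative sketch only (not called by ARSLM): the temperature and top-k steps from
# generate(), written with plain Python lists so the resulting distribution can be
# inspected without torch. `sketch_top_k_probs` is a hypothetical helper introduced
# for this example; it mirrors the clamping and thresholding used above.
import math

def sketch_top_k_probs(logits, temperature=1.0, top_k=None):
    """Return sampling probabilities after temperature scaling and top-k filtering."""
    scaled = [x / max(1e-6, temperature) for x in logits]
    if top_k is not None:
        k = min(max(top_k, 1), len(scaled))               # clamp k to [1, vocab size]
        threshold = sorted(scaled, reverse=True)[k - 1]   # k-th largest logit
        # Logits strictly below the threshold are masked; ties at the threshold survive.
        scaled = [x if x >= threshold else -math.inf for x in scaled]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]           # math.exp(-inf) is 0.0
    total = sum(exps)
    return [e / total for e in exps]

# Because k is clamped to at least 1, at least one logit always survives the filter,
# so the returned probabilities always sum to 1.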

# ----------------------------
# TextDataset class for structured data handling
# ----------------------------
class TextDataset(Dataset):
    def __init__(self, texts: list, tokenizer):
        self.texts = texts
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.texts[idx]

# Collate function: encodes a batch of raw strings with the given tokenizer and
# builds next-token input/target tensors
def collate_batch(batch_texts: List[str], tokenizer):
    # Encode each text, add BOS/EOS, and convert to tensor
    encoded = []
    for t in batch_texts:
        ids = [tokenizer.bos_token_id] + tokenizer.encode(t) + [tokenizer.eos_token_id]
        encoded.append(torch.tensor(ids, dtype=torch.long))

    # Pad to max len within the batch
    max_len = max([x.size(0) for x in encoded])
    padded_batch = torch.full((len(encoded), max_len), tokenizer.pad_token_id, dtype=torch.long)
    for i, x in enumerate(encoded):
        padded_batch[i, :x.size(0)] = x

    # Prepare inputs and targets for language modeling
    inputs = padded_batch[:, :-1]
    targets = padded_batch[:, 1:]

    return inputs, targets
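
# Illustrative sketch only (not used by the pipeline): the pad-then-shift step of
# collate_batch on plain Python lists. `sketch_pad_and_shift` and the default pad id 0
# are assumptions made for this example; the real pad id is tokenizer.pad_token_id.
def sketch_pad_and_shift(sequences, pad_id=0):
    """Right-pad id lists to equal length, then split into inputs and next-token targets."""
    max_len = max(len(seq) for seq in sequences)
    padded = [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]
    inputs = [row[:-1] for row in padded]    # every position except the last
    targets = [row[1:] for row in padded]    # every position except the first
    return inputs, targets

# Example: sketch_pad_and_shift([[1, 5, 6, 2], [1, 7, 2]]) gives
#   inputs  [[1, 5, 6], [1, 7, 2]]
#   targets [[5, 6, 2], [7, 2, 0]]
# The trailing pad id in the second target row is the kind of position that
# CrossEntropyLoss(ignore_index=pad_token_id) skips during training.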


# Trains a small multi-layer ARSLM with AdvancedTokenizer and DataLoaders, evaluates it,
# and prints a sample generation
def train_demo():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize configuration
    cfg = Config()

    corpus_file = cfg.corpus_file
    texts = []
    if os.path.exists(corpus_file):
        print(f"Loading corpus from {corpus_file}...")
        with open(corpus_file, 'r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                if i < cfg.max_corpus_lines:
                    texts.append(line.strip())
                else:
                    break
        print(f"Loaded {len(texts)} lines.")
    else:
        print(f"Corpus file not found at {corpus_file}. Using a small toy corpus instead.")
        texts = [
            "hello world this is ars",
            "the system adapts to its history",
            "benpolyseq demonstrates adaptive sequences",
            "ars can inspire new network protocols",
            "self optimizing systems are possible"
        ]

    tokenizer = AdvancedTokenizer() # Initialize AdvancedTokenizer
    vocab_size = len(tokenizer)

    # Split the corpus into train / validation / test sets
    num_samples = len(texts)
    if num_samples < 3:
        print(f"Warning: Corpus size ({num_samples}) is too small for a proper train/validation/test split. Using all data for training.")
        train_texts = texts
        val_texts = []
        test_texts = []
    else:
        # First split: train_val and test
        train_val_texts, test_texts = train_test_split(texts, test_size=cfg.train_test_split_ratio, random_state=cfg.random_state)
        # Defensive check for test_texts being empty
        if not test_texts and train_val_texts:
            test_texts.append(train_val_texts.pop()) # Move one item to test if test is empty and train_val is not

        # Second split: train and validation from train_val
        if len(train_val_texts) >= 2: # Need at least two samples to split into train and val
            val_split_size = cfg.train_val_split_ratio
            train_texts, val_texts = train_test_split(train_val_texts, test_size=val_split_size, random_state=cfg.random_state)
            # Defensive check for val_texts being empty
            if not val_texts and train_texts:
                val_texts.append(train_texts.pop()) # Move one item to val if val is empty and train is not
        else:
            train_texts = train_val_texts
            val_texts = []

    print(f"Train size: {len(train_texts)}, Validation size: {len(val_texts)}, Test size: {len(test_texts)}")

    # Create datasets
    train_dataset = TextDataset(train_texts, tokenizer)
    val_dataset = TextDataset(val_texts, tokenizer)
    test_dataset = TextDataset(test_texts, tokenizer)

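    # Create DataLoaders. Shuffling is disabled for an empty training set because
    # RandomSampler rejects a dataset of length zero; the empty case is handled below.
    train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=len(train_dataset) > 0,
                                collate_fn=lambda b: collate_batch(b, tokenizer))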
    val_loader = DataLoader(val_dataset, batch_size=cfg.batch_size, shuffle=False,
                              collate_fn=lambda b: collate_batch(b, tokenizer))
    test_loader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False,
                               collate_fn=lambda b: collate_batch(b, tokenizer))

    model = ARSLM(tokenizer, emb_dim=cfg.emb_dim, hidden_dim=cfg.hidden_dim, num_layers=cfg.num_layers, dropout_prob=cfg.dropout_prob).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate)
    scheduler = StepLR(optimizer, step_size=cfg.lr_scheduler_step_size, gamma=cfg.lr_scheduler_gamma)
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    n_epochs = cfg.n_epochs
    if len(train_texts) == 0: # Nothing to train on
        print("No training data available. Skipping training.")
        return model, tokenizer

    # Initialize best validation loss for saving the best model
    best_val_loss = float('inf')
    # Create directories for saving if they don't exist
    os.makedirs(cfg.model_save_dir, exist_ok=True)
    os.makedirs(cfg.tokenizer_save_dir, exist_ok=True)

    for epoch in range(n_epochs):
        model.train()
        total_loss = 0
        num_batches = len(train_loader)
        if num_batches > 0:
            for batch_idx, (inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(device), targets.to(device)
                logits, gates = model(inputs)
                b, seq, v = logits.shape
                loss = loss_fn(logits.view(b*seq, v), targets.reshape(b*seq))
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_grad_norm)
                optimizer.step()
                total_loss += loss.item()
            avg_train_loss = total_loss / num_batches
        else:
            avg_train_loss = float('inf') # Set to inf if no training batches

        scheduler.step()

        if (epoch+1) % 50 == 0 or epoch==0:
            print(f"Epoch {epoch+1}/{n_epochs} — Avg Train Loss: {avg_train_loss:.4f}")
            print(f"Current learning rate: {scheduler.get_last_lr()[0]:.6f}")

            # Validation step with check for empty loader
            if len(val_loader) > 0:
                model.eval()
                total_val_loss = 0
                with torch.no_grad():
                    for inputs_val, targets_val in val_loader:
                        inputs_val, targets_val = inputs_val.to(device), targets_val.to(device)
                        logits_val, _ = model(inputs_val)
                        b_val, seq_val, v_val = logits_val.shape
                        loss_val = loss_fn(logits_val.view(b_val*seq_val, v_val), targets_val.reshape(b_val*seq_val))
                        total_val_loss += loss_val.item()
                avg_val_loss = total_val_loss / len(val_loader)
                print(f"Validation Loss: {avg_val_loss:.4f}")

                # Save the model if this is the best validation loss so far
                if avg_val_loss < best_val_loss:
                    best_val_loss = avg_val_loss
                    print(f"Saving best model with validation loss: {best_val_loss:.4f}")
                    torch.save(model.state_dict(), os.path.join(cfg.model_save_dir, "best_model.pt"))
                    tokenizer.tokenizer.save_pretrained(cfg.tokenizer_save_dir)
                    # Save the configuration as well
                    torch.save(cfg, cfg.config_save_path)

            else:
                print("No validation data available. Skipping validation.")
                avg_val_loss = float('inf')

    print("\n=== Evaluation ===")
    model.eval()
    total_test_loss = 0
    if len(test_loader) > 0:
        with torch.no_grad():
            for inputs_test, targets_test in test_loader:
                inputs_test, targets_test = inputs_test.to(device), targets_test.to(device)
                logits_test, _ = model(inputs_test)
                b_test, seq_test, v_test = logits_test.shape
                loss_test = loss_fn(logits_test.view(b_test*seq_test, v_test), targets_test.reshape(b_test*seq_test))
                total_test_loss += loss_test.item()

        avg_test_loss = total_test_loss / len(test_loader)
        # Perplexity is exp(mean loss); averaging exp(loss) per batch would overstate it.
        # This uses the mean of per-batch losses, which is not token-weighted.
        avg_test_perplexity = torch.exp(torch.tensor(avg_test_loss)).item()
        print(f"Test Loss: {avg_test_loss:.4f}")
        print(f"Perplexity on test data: {avg_test_perplexity:.4f}")
    else:
        print("No test data available. Skipping test evaluation.")
        avg_test_loss = float('inf')
        avg_test_perplexity = float('inf')

    # demo generation
    context = cfg.generation_context
    idx = torch.tensor([ [tokenizer.bos_token_id] + tokenizer.encode(context) ], dtype=torch.long, device=device)
    out_ids = model.generate(idx, max_new_tokens=cfg.max_new_tokens, temperature=cfg.temperature, top_k=cfg.top_k)[0]
    print("\n=== Generated ===")
    print(tokenizer.decode(out_ids))
    return model, tokenizer
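
# Illustrative sketch only (not called by train_demo): corpus-level perplexity as the
# exponential of the token-weighted mean negative log-likelihood. `sketch_perplexity`
# is a hypothetical helper; it assumes each batch contributes a pair
# (summed_nll, token_count) that counts only non-pad target tokens.
import math

def sketch_perplexity(batch_stats):
    """Return exp(total NLL / total tokens), or inf when there are no tokens."""
    total_nll = sum(nll for nll, _ in batch_stats)
    total_tokens = sum(count for _, count in batch_stats)
    if total_tokens == 0:
        return float("inf")
    return math.exp(total_nll / total_tokens)

# Weighting by token count keeps a short batch from counting as much as a long one,
# which a plain mean of per-batch losses would allow.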

def load_model_and_tokenizer(model_path: str, tokenizer_path: str, config_path: str, device):
    """
    Loads a pre-trained ARSLM model and its tokenizer along with its configuration.
    """
    # Add Config class to safe globals for unpickling
    torch.serialization.add_safe_globals([Config])
    # Load the configuration first
    cfg = torch.load(config_path)

    loaded_tokenizer = AdvancedTokenizer(pretrained_model_name_or_path=tokenizer_path)

    # Instantiate the model using the loaded configuration
    loaded_model = ARSLM(loaded_tokenizer,
                         emb_dim=cfg.emb_dim,
                         hidden_dim=cfg.hidden_dim,
                         num_layers=cfg.num_layers,
                         dropout_prob=cfg.dropout_prob).to(device)

    loaded_model.load_state_dict(torch.load(model_path, map_location=device))
    loaded_model.eval() # Set model to evaluation mode

    print(f"Model loaded from {model_path}")
    print(f"Tokenizer loaded from {tokenizer_path}")
    print(f"Configuration loaded from {config_path}")

    return loaded_model, loaded_tokenizer, cfg

def deploy_model(model, tokenizer, config, device, input_text: str):
    """
    Function to deploy the model for inference (text generation).
    """
    model.eval()
    with torch.no_grad():
        idx = torch.tensor([ [tokenizer.bos_token_id] + tokenizer.encode(input_text) ], dtype=torch.long, device=device)
        out_ids = model.generate(idx, max_new_tokens=config.max_new_tokens, temperature=config.temperature, top_k=config.top_k)[0]
        generated_text = tokenizer.decode(out_ids)
    return generated_text

# ----------------------------
# Run demo if main
# ----------------------------
if __name__ == "__main__":
    random.seed(0)
    torch.manual_seed(0)
    model, tokenizer = train_demo()

    # Demonstrate loading the saved model and tokenizer
    print("\n=== Demonstrating Model Loading and Deployment ===")
    cfg_instance = Config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    try:
        loaded_model, loaded_tokenizer, loaded_config = load_model_and_tokenizer(
            os.path.join(cfg_instance.model_save_dir, "best_model.pt"),
            cfg_instance.tokenizer_save_dir,
            cfg_instance.config_save_path,
            device
        )
        # Test generation with loaded model using the deploy_model function
        context = "loaded model generates"
        generated_output = deploy_model(loaded_model, loaded_tokenizer, loaded_config, device, context)
        print("Generated with loaded model:")
        print(generated_output)
    except FileNotFoundError:
        print("No saved model found. Please ensure training completed successfully and saved a model.")