In [27]:
import torch
from datasets import load_dataset
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split, Subset
import math
import time
import os
import csv
from tqdm import tqdm


In [28]:

class OpenWebTextDataset(Dataset):
    """Word-level sliding-window dataset built from ``stas/openwebtext-10k``.

    The corpus is lower-cased, split on whitespace and mapped to integer ids
    through a vocabulary made of the sorted unique words. Item ``idx`` is the
    window ``tokens[idx : idx + max_seq_len]`` split into an input (every
    token but the last) and a target (every token but the first), so both
    returned tensors hold ``max_seq_len - 1`` ids.

    Windows are sliced out of ``self.tokens`` on demand. Storing every window
    as its own list would keep roughly ``max_seq_len`` copies of the corpus in
    memory.

    Args:
        max_seq_len (int): Number of consecutive tokens in each window.

    Attributes:
        text (list[str]): The whitespace-split, lower-cased corpus.
        vocab (dict[str, int]): Word -> id mapping.
        vocab_size (int): Number of distinct words.
        tokens (list[int]): The corpus as a flat list of ids.
    """

    def __init__(self, max_seq_len):
        # Load dataset
        dataset = load_dataset("stas/openwebtext-10k", split="train")

        # Join all documents, lower-case, then split into whitespace-delimited words
        self.text = " ".join(dataset["text"]).lower().split()

        # Create vocabulary (sorted so that ids are deterministic for a given corpus)
        vocab = sorted(set(self.text))
        self.vocab = {word: idx for idx, word in enumerate(vocab)}
        self.vocab_size = len(self.vocab)
        self.max_seq_len = max_seq_len

        # Convert text to token indices
        self.tokens = [self.vocab[word] for word in self.text]

    @property
    def data(self):
        """All windows as a list of lists.

        Kept for backward compatibility only: the full list is rebuilt on
        every access, so prefer indexing the dataset directly.
        """
        return [self.tokens[i : i + self.max_seq_len] for i in range(len(self))]

    def __len__(self):
        # Same window count as before: start positions 0 .. len(tokens) - max_seq_len - 1.
        return max(0, len(self.tokens) - self.max_seq_len)

    def __getitem__(self, idx):
        length = len(self)
        # Mirror list indexing semantics (negative indices, IndexError when out of range).
        if idx < 0:
            idx += length
        if not 0 <= idx < length:
            raise IndexError("dataset index out of range")

        window = self.tokens[idx : idx + self.max_seq_len]
        x = window[:-1]
        y = window[1:]
        return torch.tensor(x), torch.tensor(y)



In [29]:
def get_dataloaders(dataset, max_seq_len, batch_size, val_ratio=0.1, test_ratio=0.1, seed=None):
    """Split a dataset into train/validation/test DataLoaders.

    Args:
        dataset: A dataset exposing ``vocab_size``, either directly or through
            a chain of ``.dataset`` wrappers (e.g. ``torch.utils.data.Subset``).
        max_seq_len (int): Unused here; kept so existing call sites keep working.
        batch_size (int): Batch size used by all three loaders.
        val_ratio (float): Fraction of the samples assigned to validation.
        test_ratio (float): Fraction of the samples assigned to test.
        seed (int, optional): When given, the random split is drawn from a
            dedicated generator seeded with this value so the split is
            reproducible. It does not affect the shuffling order of the
            training loader.

    Returns:
        tuple: ``(train_loader, val_loader, test_loader, vocab_size)``.
    """
    # Unwrap Subset-style wrappers only as far as needed to find vocab_size,
    # so both a raw dataset and a (nested) Subset are accepted.
    source = dataset
    while not hasattr(source, "vocab_size") and hasattr(source, "dataset"):
        source = source.dataset
    vocab_size = source.vocab_size

    # Work out the size of each split; the training split takes the remainder
    total_size = len(dataset)
    test_size = int(total_size * test_ratio)
    val_size = int(total_size * val_ratio)
    train_size = total_size - val_size - test_size
    lengths = [train_size, val_size, test_size]

    # Split dataset
    if seed is None:
        train_dataset, val_dataset, test_dataset = random_split(dataset, lengths)
    else:
        split_generator = torch.Generator().manual_seed(seed)
        train_dataset, val_dataset, test_dataset = random_split(
            dataset, lengths, generator=split_generator
        )

    # Only the training loader is shuffled
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader, vocab_size

In [30]:
def get_device():
    """Return the preferred torch device for this machine.

    Preference order:
    - "cuda" when ``torch.cuda.is_available()`` reports a usable GPU,
    - "mps" when ``torch.backends.mps.is_available()`` does,
    - "cpu" otherwise.
    """
    if torch.cuda.is_available():
        device_name = "cuda"
    elif torch.backends.mps.is_available():
        device_name = "mps"
    else:
        device_name = "cpu"
    return torch.device(device_name)

        
def count_trainable_params(model):
    """Return the total element count of the model's parameters that have ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [32]:
class Pattention(nn.Module):
    """Token-parameter attention ("Pattention") layer.

    The input tokens act as queries against ``n`` learnable key/value
    parameter tokens: the scores ``inputs @ Kp^T`` go through a
    normalisation + non-linearity and are then used to mix the rows of ``Vp``.

    Args:
        d1 (int): Input feature dimension (width of each key parameter token).
        d2 (int): Output feature dimension (width of each value parameter token).
        n (int): Number of learnable key/value parameter tokens.
        param_key_init_method (callable): In-place initializer applied to ``Kp``.
        param_value_init_method (callable): In-place initializer applied to ``Vp``.
        norm_activation_type (str): One of ``'softmax'``, ``'gelu_l2_norm'``
            or ``'l2_norm_gelu'``.
    """

    def __init__(
        self,
        d1,
        d2,
        n,
        param_key_init_method,
        param_value_init_method,
        norm_activation_type,
    ):
        super().__init__()

        self.param_token_num = n
        self.param_key_dim = d1
        self.param_value_dim = d2
        self.norm_activation_type = norm_activation_type

        self.key_param_tokens = nn.parameter.Parameter(data=torch.rand((n, d1)))  # Kp, shape (n, d1)
        self.value_param_tokens = nn.parameter.Parameter(data=torch.rand((n, d2)))  # Vp, shape (n, d2)

        # The random values above are placeholders; the supplied initializers overwrite them in place.
        param_key_init_method(self.key_param_tokens)
        param_value_init_method(self.value_param_tokens)

    def nonlinear_norm_func(self, inputs, normalize_type, dim=-1):
        """Apply the selected normalisation / non-linearity along ``dim``.

        Args:
            inputs (torch.Tensor): Raw attention scores.
            normalize_type (str): ``'softmax'``, ``'gelu_l2_norm'`` or ``'l2_norm_gelu'``.
            dim (int): Dimension that indexes the parameter tokens.

        Returns:
            torch.Tensor: Tensor with the same shape as ``inputs``.

        Raises:
            ValueError: If ``normalize_type`` is not one of the supported names.
        """
        if normalize_type == 'softmax':
            # exp(x) / ||exp(x)||_1 is exactly softmax(x); F.softmax computes it
            # without overflowing on large scores. Rescaled by the token count.
            return F.softmax(inputs, dim=dim) * inputs.shape[dim]
        if normalize_type == 'gelu_l2_norm':
            nonlinear_outputs = F.gelu(inputs)
            return (
                nonlinear_outputs
                / torch.norm(nonlinear_outputs, p=2, dim=dim, keepdim=True)
                * math.sqrt(nonlinear_outputs.shape[dim])
            )
        if normalize_type == 'l2_norm_gelu':
            norm_outputs = inputs / torch.norm(inputs, p=2, dim=dim, keepdim=True) * math.sqrt(inputs.shape[dim])
            return F.gelu(norm_outputs)
        raise ValueError(
            f"unknown normalize_type {normalize_type!r}; "
            "expected 'softmax', 'gelu_l2_norm' or 'l2_norm_gelu'"
        )

    def forward(self, inputs, dropout_p=0.0, attn_mask=None, scale=None):
        """Attend from ``inputs`` to the learnable parameter tokens.

        Args:
            inputs (torch.Tensor): Tensor whose last dimension is ``d1``.
            dropout_p (float): Dropout probability on the attention weights;
                only applied while the module is in training mode.
            attn_mask (torch.Tensor, optional): Boolean mask broadcastable to
                the score tensor; ``False`` marks parameter tokens to ignore.
            scale (float, optional): Multiplier for the raw scores (1 when omitted).

        Returns:
            torch.Tensor: Tensor whose last dimension is ``d2``.

        Raises:
            NotImplementedError: If ``attn_mask`` is given but is not boolean.
        """
        query = inputs
        key, value = self.key_param_tokens, self.value_param_tokens
        scale_factor = 1 if scale is None else scale

        attn_weight = query @ key.transpose(-2, -1) * scale_factor

        if attn_mask is not None:
            if attn_mask.dtype != torch.bool:
                raise NotImplementedError("only boolean attention masks are supported")
            # GELU variants: a zero score yields a zero weight (GELU(0) == 0).
            # Softmax: a masked score must be -inf to receive zero weight.
            fill_value = float("-inf") if self.norm_activation_type == 'softmax' else 0.0
            attn_weight = attn_weight.masked_fill(attn_mask.logical_not(), fill_value)

        # modified softmax
        attn_weight = self.nonlinear_norm_func(attn_weight, self.norm_activation_type, dim=-1)
        # Follow the module's train/eval mode so evaluation stays deterministic.
        attn_weight = F.dropout(attn_weight, p=dropout_p, training=self.training)
        output = attn_weight @ value

        return output

In [33]:
class SelfAttention(nn.Module):
    """Multi-head self-attention whose Q/K/V and output projections are Pattention layers.

    Args:
        hidden_size (int): Model width; must be divisible by ``num_attention_heads``.
        num_attention_heads (int): Number of attention heads.
        attention_dropout (float): Dropout probability on the attention probabilities.
        token_num (int): Number of parameter tokens in each Pattention projection.
    """

    def __init__(self, hidden_size, num_attention_heads, attention_dropout=0.1,token_num=10):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.head_dim = hidden_size // num_attention_heads

        assert hidden_size % num_attention_heads == 0, "hidden_size must be divisible by num_attention_heads"

        def build_projection():
            # All four projections share the same configuration.
            return Pattention(
                d1=hidden_size,
                d2=hidden_size,
                n=token_num,
                param_key_init_method=torch.nn.init.xavier_uniform_,
                param_value_init_method=torch.nn.init.xavier_uniform_,
                norm_activation_type="l2_norm_gelu",
            )

        # Creation order is kept (query, key, value, out_proj) so parameter
        # registration and random initialisation happen in the same sequence.
        self.query = build_projection()
        self.key = build_projection()
        self.value = build_projection()
        self.out_proj = build_projection()

        self.attention_dropout = nn.Dropout(attention_dropout)
        self.norm_factor = math.sqrt(self.head_dim)

    def forward(self, hidden_states, attention_mask=None):
        """Run scaled dot-product attention over ``hidden_states``.

        Args:
            hidden_states (torch.Tensor): Tensor of shape (batch, seq, hidden_size).
            attention_mask (torch.Tensor, optional): Additive mask added to the
                attention scores before the softmax.

        Returns:
            torch.Tensor: Tensor of shape (batch, seq, hidden_size).
        """
        batch_size, seq_len, _ = hidden_states.size()
        per_head_shape = (batch_size, seq_len, self.num_attention_heads, self.head_dim)

        projected_q = self.query(hidden_states).view(*per_head_shape)
        projected_k = self.key(hidden_states).view(*per_head_shape)
        projected_v = self.value(hidden_states).view(*per_head_shape)

        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        q_heads = projected_q.transpose(1, 2)
        k_heads = projected_k.transpose(1, 2)
        v_heads = projected_v.transpose(1, 2)

        scores = torch.matmul(q_heads, k_heads.transpose(-1, -2)) / self.norm_factor

        if attention_mask is not None:
            scores += attention_mask

        probs = self.attention_dropout(torch.softmax(scores, dim=-1))

        # Merge the heads back: (batch, heads, seq, head_dim) -> (batch, seq, hidden_size)
        merged = (
            torch.matmul(probs, v_heads)
            .transpose(1, 2)
            .contiguous()
            .view(batch_size, seq_len, self.hidden_size)
        )

        return self.out_proj(merged)

In [34]:
class TokenformerLayer(nn.Module):
    """Single-block Tokenformer language model.

    Embeds token ids (token + learned position embeddings), applies one
    pre-norm self-attention block and one pre-norm Pattention feed-forward
    block, each with a residual connection, and projects the result to
    vocabulary logits.
    """

    def __init__(
        self,
        hidden_size,
        vocab_size,
        num_attention_heads,
        max_seq_len,
        attention_dropout=0.1,
        hidden_dropout=0.1,
        token_num=10
    ):
        """
        Args:
            hidden_size (int): The size of the hidden dimension.
            vocab_size (int): Number of entries in the token embedding and in the output logits.
            num_attention_heads (int): Number of attention heads for multi-head attention.
            max_seq_len (int): Number of positions in the position embedding table.
            attention_dropout (float): Dropout probability for attention weights.
            hidden_dropout (float): Dropout probability applied before each residual addition.
            token_num (int): Number of parameter tokens in every Pattention layer.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.head_dim = hidden_size // num_attention_heads
        # Current parameter-token count; updated by scale_token_num().
        self.token_num = token_num

        assert (
            hidden_size % num_attention_heads == 0
        ), "hidden_size must be divisible by num_attention_heads"

        # Layer normalizations
        self.input_layernorm = nn.LayerNorm(hidden_size)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)

        # Token and positional embeddings
        self.token_embedding = nn.Embedding(vocab_size, hidden_size)
        self.position_embedding = nn.Embedding(max_seq_len, hidden_size)

        # Self-attention using Pattention
        self.attention = SelfAttention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            attention_dropout=attention_dropout,
            token_num=token_num
        )

        # Feed-forward network using Pattention
        self.mlp = Pattention(
            d1=hidden_size,
            d2=hidden_size,
            n=token_num,
            param_key_init_method=torch.nn.init.xavier_uniform_,
            param_value_init_method=torch.nn.init.xavier_uniform_,
            norm_activation_type="l2_norm_gelu"
        )

        self.hidden_dropout = hidden_dropout
        self.dropout = nn.Dropout(hidden_dropout)

        # to obtain logits (before softmax) for the vocabulary
        self.lm_head = nn.Linear(hidden_size, vocab_size)

    def _pattention_layers(self):
        """Return every Pattention layer owned by this block."""
        return (
            self.attention.query,
            self.attention.key,
            self.attention.value,
            self.attention.out_proj,
            self.mlp,
        )

    def scale_token_num(self, new_token_num):
        """Grow every Pattention layer to ``new_token_num`` parameter tokens.

        Existing key/value rows are kept. New key rows are zero, so the new
        tokens get a zero score and, through GELU(0) == 0, a zero attention
        weight right after scaling. New value rows are Xavier-initialised:
        with zero keys AND zero values neither would receive a gradient.
        Note that the ``sqrt(n)`` factor of the ``l2_norm_gelu`` normalisation
        grows with the token count, so outputs are not bit-identical across a
        scaling step.

        The key/value ``nn.Parameter`` objects are replaced, so any optimizer
        created before this call still refers to the old tensors and must be
        rebuilt from ``model.parameters()``.

        Args:
            new_token_num (int): Target number of parameter tokens per layer.

        Raises:
            ValueError: If ``new_token_num`` is smaller than the current count.
        """
        if new_token_num < self.token_num:
            raise ValueError(
                f"new_token_num ({new_token_num}) must be >= current token_num ({self.token_num})"
            )
        if new_token_num == self.token_num:
            return

        with torch.no_grad():
            for layer in self._pattention_layers():
                old_keys = layer.key_param_tokens.data
                old_values = layer.value_param_tokens.data
                extra = new_token_num - old_keys.size(0)

                # new_zeros / new_empty keep the dtype and device of the existing tensors.
                added_keys = old_keys.new_zeros(extra, old_keys.size(1))
                added_values = old_values.new_empty(extra, old_values.size(1))
                torch.nn.init.xavier_uniform_(added_values)

                layer.key_param_tokens = nn.Parameter(torch.cat([old_keys, added_keys], dim=0))
                layer.value_param_tokens = nn.Parameter(torch.cat([old_values, added_values], dim=0))
                layer.param_token_num = new_token_num

        self.token_num = new_token_num

    def forward(self, x, attention_mask=None):
        """
        Forward pass from token ids to vocabulary logits.

        Args:
            x (torch.Tensor): Integer token ids of shape [batch_size, seq_len].
            attention_mask (torch.Tensor, optional): Additive mask added to the
                attention scores. When omitted, a causal mask is used so that
                a position cannot attend to later positions; without it a
                next-token model can read its own target from the input.
                Pass an all-zero mask to get unmasked attention.

        Returns:
            torch.Tensor: Logits of shape [batch_size, seq_len, vocab_size].
        """
        # Word embedding
        x = self.token_embedding(x)  # [batch_size, seq_len, hidden_size]

        # Positional embedding
        seq_len = x.size(1)
        position_ids = torch.arange(seq_len, device=x.device).unsqueeze(0)  # [1, seq_len]
        x = x + self.position_embedding(position_ids)

        if attention_mask is None:
            # -inf strictly above the diagonal, 0 elsewhere.
            attention_mask = torch.full(
                (seq_len, seq_len), float("-inf"), dtype=x.dtype, device=x.device
            ).triu(diagonal=1)

        # Residual connection and pre-normalization for attention
        residual = x
        normed_input = self.input_layernorm(x)

        # Self-attention
        attention_output = self.attention(normed_input, attention_mask)
        attention_output = self.dropout(attention_output) + residual

        # Residual connection and pre-normalization for feed-forward
        residual = attention_output
        normed_attention_output = self.post_attention_layernorm(attention_output)

        # Feed-forward network (Pattention)
        mlp_output = self.mlp(normed_attention_output)
        output = self.dropout(mlp_output) + residual

        # Linear layer for logits
        logits = self.lm_head(output)

        return logits

In [35]:
def estimate_perplexity(model, dataloader, criterion, device):
    """Evaluate a language model and return ``(average loss, perplexity)``.

    The model is switched to eval mode and run without gradients over every
    batch. Each batch loss is weighted by its number of target tokens, and
    perplexity is the exponential of the resulting per-token average.
    """
    model.eval()
    loss_sum = 0
    token_count = 0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = model(inputs)  # (batch_size, seq_len, vocab_size)
            flat_logits = logits.view(-1, logits.shape[-1])  # (batch_size * seq_len, vocab_size)
            flat_targets = targets.view(-1)  # (batch_size * seq_len,)
            n_tokens = flat_targets.size(0)

            batch_loss = criterion(flat_logits, flat_targets)
            loss_sum += batch_loss.item() * n_tokens
            token_count += n_tokens

    mean_loss = loss_sum / token_count
    return mean_loss, torch.exp(torch.tensor(mean_loss)).item()



In [36]:
def train(use_metrics=False, subset_size = 100_000, val_ratio=0.1, test_ratio=0.1):
    """Train a small TokenformerLayer on a prefix of the OpenWebText windows.

    Args:
        use_metrics (bool): When True, evaluate loss and perplexity on the
            validation split after every epoch.
        subset_size (int): Maximum number of dataset windows to keep (the
            first ``subset_size`` ones).
        val_ratio (float): Fraction of the subset used for validation.
        test_ratio (float): Fraction of the subset used for testing.

    Returns:
        TokenformerLayer: The trained model. Its ``state_dict`` is also saved
        to ``tokenformer_model.pth``.
    """
    # Configuration
    device = get_device()
    print('Device:', device)

    hidden_dim = 32
    num_heads = 4
    max_seq_len = 16  # short sequences
    batch_size = 8  # small batch
    num_epochs = 5  # few epochs for quick experiments
    learning_rate = 0.001
    token_num = 32

    # Prepare the data, keeping only the first `subset_size` windows
    dataset = OpenWebTextDataset(max_seq_len=max_seq_len)
    dataset = Subset(dataset, range(min(subset_size, len(dataset))))

    # get_dataloaders also reports the vocabulary size of the underlying dataset
    train_loader, val_loader, test_loader, vocab_size = get_dataloaders(
        dataset, max_seq_len, batch_size, val_ratio, test_ratio
    )

    # Build the model
    model = TokenformerLayer(hidden_dim, vocab_size, num_heads, max_seq_len, token_num=token_num)
    model = model.to(device)
    print("Model parameters:", count_trainable_params(model))

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    print("Début de l'entraînement")
    total_training_time = 0.0

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        epoch_train_start = time.time()

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)

        for x_batch, y_batch in progress_bar:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)

            # Flatten to (batch * seq, vocab) / (batch * seq,) for the cross-entropy loss
            outputs = model(x_batch)
            outputs = outputs.view(-1, vocab_size)
            y_batch = y_batch.view(-1)

            loss = criterion(outputs, y_batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            progress_bar.set_postfix(loss=f"{loss.item():.4f}")

        # Evaluation time is measured separately and not counted as training time
        epoch_training_time = time.time() - epoch_train_start
        total_training_time += epoch_training_time

        if use_metrics:
            eval_start = time.time()
            avg_loss, perplexity = estimate_perplexity(model, val_loader, criterion, device)
            eval_time = time.time() - eval_start
            print(f"Epoch {epoch + 1}/{num_epochs}, "
                  f"Train Loss: {epoch_loss / len(train_loader):.4f}, "
                  f"Eval Loss: {avg_loss:.4f}, Perplexity: {perplexity:.4f}, "
                  f"Train Time: {epoch_training_time:.2f}s, Eval Time: {eval_time:.2f}s")
        else:
            print(f"Epoch {epoch + 1}/{num_epochs}, "
                  f"Train Loss: {epoch_loss / len(train_loader):.4f}, "
                  f"Train Time: {epoch_training_time:.2f}s")

    print(f"Total training time (excluding evaluation): {total_training_time:.2f} seconds")
    # Save the model weights once training is finished
    model_save_path = "tokenformer_model.pth"
    torch.save(model.state_dict(), model_save_path)
    print(f"Modèle sauvegardé dans : {model_save_path}")
    return model
    
# Notebook-level hyperparameter defaults. train() assigns its own local
# variables with the same names, so these values do not change what it uses.
hidden_dim, num_heads, token_num = 32, 4, 32
max_seq_len, batch_size = 16, 8  # short sequences, small batch
num_epochs = 3  # few epochs for quick experiments
learning_rate = 0.001

In [39]:
def train_with_scaling(
    initial_token_num,
    scaling_steps,
    new_tokens_per_step,
    hidden_dim,
    num_heads,
    max_seq_len,
    batch_size,
    num_epochs,
    learning_rate=0.001,
    val_ratio=0.1,
    test_ratio=0.1,
    model_base_name="tokenformer_scaled",
    subset_size = 1000
):
    """Train a TokenFormer while growing its parameter-token count between phases.

    Phase 0 trains with ``initial_token_num`` parameter tokens. Each later
    phase ``k`` first adds ``new_tokens_per_step[k - 1]`` tokens through
    ``model.scale_token_num`` and then trains for another ``num_epochs``
    epochs. After every phase the model is evaluated on the test split, a
    checkpoint is written to ``Saved_Models_Checkpoints`` and the metrics of
    that phase are written to a CSV file in ``Saved_Models_Results``.

    Args:
        initial_token_num (int): Parameter tokens per Pattention layer at the start.
        scaling_steps (int): Number of growth phases after the initial one.
        new_tokens_per_step (Sequence[int]): Tokens added at each growth phase;
            needs at least ``scaling_steps`` entries.
        hidden_dim (int): Model width.
        num_heads (int): Number of attention heads.
        max_seq_len (int): Window length used by the dataset and the position table.
        batch_size (int): Batch size of all data loaders.
        num_epochs (int): Epochs per phase.
        learning_rate (float): Adam learning rate.
        val_ratio (float): Fraction of the subset used for validation.
        test_ratio (float): Fraction of the subset used for testing.
        model_base_name (str): Prefix of the checkpoint and CSV file names.
        subset_size (int): Maximum number of dataset windows to keep.

    Raises:
        ValueError: If ``new_tokens_per_step`` has fewer than ``scaling_steps`` entries.
    """
    if len(new_tokens_per_step) < scaling_steps:
        raise ValueError(
            f"new_tokens_per_step has {len(new_tokens_per_step)} entries "
            f"but scaling_steps is {scaling_steps}"
        )

    device = get_device()
    print(f"Training on {device}")

    # Prepare the data, keeping only the first `subset_size` windows
    dataset = OpenWebTextDataset(max_seq_len=max_seq_len)
    dataset = Subset(dataset, range(min(subset_size, len(dataset))))

    train_loader, val_loader, test_loader, vocab_size = get_dataloaders(
        dataset, max_seq_len, batch_size, val_ratio, test_ratio
    )

    # Output directories
    checkpoint_dir = "Saved_Models_Checkpoints"
    result_dir = "Saved_Models_Results"
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(result_dir, exist_ok=True)

    # Build the model with `token_num = initial_token_num`
    model = TokenformerLayer(
        hidden_size=hidden_dim,
        vocab_size=vocab_size,
        num_attention_heads=num_heads,
        max_seq_len=max_seq_len,
        attention_dropout=0.1,
        hidden_dropout=0.1,
        token_num=initial_token_num
    ).to(device)

    print(f"Initial Token Num: {initial_token_num}")
    print(f"Trainable Parameters: {count_trainable_params(model):_}")

    # Training components
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    total_training_time = 0.0
    # Tracked locally so the loop does not rely on an attribute of the model.
    current_token_num = initial_token_num

    for step in range(scaling_steps + 1):  # +1 to include the initial training phase
        step_results_path = os.path.join(result_dir, f"{model_base_name}_step_{step}.csv")

        def append_row(row, mode='a'):
            # Write one CSV row to this phase's results file.
            with open(step_results_path, mode=mode, newline='') as file:
                csv.writer(file).writerow(row)

        # Start a fresh CSV file for this phase
        append_row(
            ["epoch", "train_loss", "val_loss", "val_perplexity", "test_loss", "test_perplexity", "epoch_time", "total_time"],
            mode='w',
        )

        if step > 0:
            # Grow the parameter tokens while keeping the learned weights
            new_token_num = current_token_num + new_tokens_per_step[step - 1]
            model.scale_token_num(new_token_num)
            current_token_num = new_token_num
            # The optimizer was built for the previous parameter tensors; its
            # parameter list and Adam moment buffers no longer match the
            # resized model, so build a new one.
            optimizer = optim.Adam(model.parameters(), lr=learning_rate)
            print(f"🔼 Scaling step {step}: token_num = {new_token_num}")

        for epoch in range(num_epochs):
            model.train()
            epoch_loss = 0.0
            start_time = time.time()

            # Progress bar
            progress_bar = tqdm(train_loader, desc=f"Step {step}, Epoch {epoch+1}/{num_epochs}", unit="batch", leave=False)

            for x_batch, y_batch in progress_bar:
                x_batch, y_batch = x_batch.to(device), y_batch.to(device)
                optimizer.zero_grad()
                outputs = model(x_batch)

                # Flatten to (batch * seq, vocab) / (batch * seq,) for the cross-entropy loss
                outputs = outputs.view(-1, vocab_size)
                y_batch = y_batch.view(-1)

                loss = criterion(outputs, y_batch)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()

                progress_bar.set_postfix(loss=loss.item())

            epoch_time = time.time() - start_time
            total_training_time += epoch_time

            # Validation
            val_loss, val_perplexity = estimate_perplexity(model, val_loader, criterion, device)

            print(f"Step {step} | Epoch {epoch+1}/{num_epochs} | Train Loss: {epoch_loss / len(train_loader):.4f} | "
                  f"Val Loss: {val_loss:.4f} | Val Perplexity: {val_perplexity:.4f} | Time: {epoch_time:.2f}s")

            # Record this epoch's metrics
            append_row([epoch+1, epoch_loss / len(train_loader), val_loss, val_perplexity, None, None, epoch_time, total_training_time])

        # Final evaluation for this phase
        print(f"\n🔍 Evaluation finale après Scaling Step {step}...")
        test_loss, test_perplexity = estimate_perplexity(model, test_loader, criterion, device)

        print(f"Final Evaluation for Step {step} | Test Loss: {test_loss:.4f} | Test Perplexity: {test_perplexity:.4f}")

        # Checkpoint after every phase
        model_checkpoint_path = os.path.join(checkpoint_dir, f"{model_base_name}_step_{step}.pth")
        torch.save(model.state_dict(), model_checkpoint_path)
        print(f"✅ Modèle sauvegardé : {model_checkpoint_path}")

        # Append the test results to the CSV file
        append_row(["FINAL", None, None, None, test_loss, test_perplexity, None, total_training_time])

    print(f"Total Training Time: {total_training_time:.2f}s")
    print(f"Final results saved in: {result_dir}")

In [42]:
token_num_init = 10

# One initial phase followed by four growth phases. The per-phase increments
# are 13, 14, 18 and 26 tokens, i.e. 10 -> 23 -> 37 -> 55 -> 81 parameter
# tokens when each increment is added to the running count.
train_with_scaling(
    initial_token_num=token_num_init,
    scaling_steps=4,
    new_tokens_per_step=[token_num_init + extra for extra in (3, 4, 8, 16)],
    # model size
    hidden_dim=16,
    num_heads=1,
    max_seq_len=32,
    # optimisation
    batch_size=32,
    num_epochs=2,
    learning_rate=0.001,
    # data split
    val_ratio=0.1,
    test_ratio=0.1,
)

Training on mps
Initial Token Num: 10
Trainable Parameters: 13_529_437


                                                            

RuntimeError: MPS backend out of memory (MPS allocated: 6.11 GB, other allocations: 1.52 GB, max allowed: 9.07 GB). Tried to allocate 1.51 GB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).