### README



**Model Architecture and Training for ListOps using BigBird**

**Overview**  
This project trains a [BigBird](https://arxiv.org/abs/2007.14062) model on the ListOps dataset. The model is initialized from a configuration rather than from pretrained weights. The ListOps task involves evaluating nested list expressions built from the operators `MIN`, `MAX`, `MED` (median), and `SM` (sum modulo 10) over single digits. The goal is to predict the correct integer output (0-9) for each expression.
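
To make the task concrete, here is a minimal sketch of a reference evaluator for ListOps expressions. It is not part of the project code: the function name `evaluate_listops` is hypothetical, and the handling of `MED` on an even number of operands (average of the two middle values, truncated to an integer) is an assumption.

```python
import statistics

# Operator semantics; MED on an even count is an assumption (see above).
OPS = {
    "[MIN": min,
    "[MAX": max,
    "[MED": lambda xs: int(statistics.median(xs)),
    "[SM": lambda xs: sum(xs) % 10,
}

def evaluate_listops(expression: str) -> int:
    """Evaluate a ListOps expression such as '[MAX 2 9 [MIN 4 7 ] 0 ]'."""
    # Parentheses only mark tree structure, so they can be dropped here.
    tokens = expression.replace("(", " ").replace(")", " ").split()
    stack = []  # each frame is (operator_token, collected_operands)
    for tok in tokens:
        if tok in OPS:
            stack.append((tok, []))
        elif tok == "]":
            op, values = stack.pop()
            result = OPS[op](values)
            if not stack:
                return result  # outermost expression closed
            stack[-1][1].append(result)
        else:
            stack[-1][1].append(int(tok))
    raise ValueError("unbalanced expression")

print(evaluate_listops("[MAX 2 9 [MIN 4 7 ] 0 ]"))  # 9
print(evaluate_listops("[SM 5 [MED 1 8 3 ] 9 ]"))   # (5 + 3 + 9) % 10 = 7
```

The model never sees this evaluator. It has to learn the same mapping from token sequences to the digit label.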

**Dataset**  
The dataset is derived from the ListOps task of the Long Range Arena (LRA) benchmark. It consists of expressions paired with their integer results. We use three splits:
- **Training Set**
- **Validation Set**
- **Test Set**

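Each expression is tokenized against a small custom vocabulary: a padding token, the operator tokens (`[MIN`, `[MAX`, `[MED`, `[SM`), the closing bracket `]`, and the digits `0-9`. The vocabulary also reserves IDs for `(` and `)`, but the tokenizer strips parentheses before the lookup, so those IDs are never produced. We pad the sequences to the maximum length in the batch and create an attention mask that marks the non-padding positions for the model.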
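
The padding and masking step can be sketched without any dependencies. This is an illustration only: the notebook itself uses `torch.nn.utils.rnn.pad_sequence`, and the helper name `pad_batch` is hypothetical.

```python
# Vocabulary mirroring the one defined in ListOpsDataset.
VOCAB = {"PAD": 0, "[MIN": 1, "[MAX": 2, "[MED": 3, "[SM": 4, "]": 5, "(": 6, ")": 7}
VOCAB.update({str(d): d + 8 for d in range(10)})

def pad_batch(batch_ids, pad_id=0):
    """Right-pad every sequence to the batch maximum and build the mask."""
    max_len = max(len(ids) for ids in batch_ids)
    padded = [ids + [pad_id] * (max_len - len(ids)) for ids in batch_ids]
    # 1 marks a real token, 0 marks padding.
    mask = [[int(tok != pad_id) for tok in row] for row in padded]
    return padded, mask

batch = [
    [VOCAB[t] for t in "[MAX 2 9 ]".split()],
    [VOCAB[t] for t in "[SM 1 ]".split()],
]
padded, mask = pad_batch(batch)
print(padded)  # [[2, 10, 17, 5], [4, 9, 5, 0]]
print(mask)    # [[1, 1, 1, 1], [1, 1, 1, 0]]
```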

**Model**  
We use the `BigBirdForSequenceClassification` model from the Hugging Face `transformers` library. BigBird replaces full self-attention with a block-sparse pattern that combines global, sliding-window, and random blocks, so it handles long sequences more efficiently than a standard full-attention Transformer.

- **Vocabulary Size:** Determined by the set of operators, digits, and special tokens in the dataset.
- **Model Configuration:**
  - **Hidden Size:** 256  
  - **Number of Layers:** 4  
  - **Number of Attention Heads:** 4  
  - **Intermediate Size:** 1024  
  - **Attention Type:** Block-sparse with random blocks  
  - **Max Sequence Length:** Dynamically set to accommodate the longest expression in the dataset.

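These parameters can be adjusted depending on available compute resources and desired accuracy. The notebook run recorded below, for example, uses a much smaller configuration (hidden size 64, 2 layers, 2 attention heads, intermediate size 256) and caps the sequence length at 512 tokens to speed up training.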
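
To give a feel for what "block-sparse with random blocks" means, the following sketch builds a block-level attention pattern from three ingredients: global blocks (first and last), a sliding window of neighbouring blocks, and a few random blocks per query block. It is a simplified illustration, not the Hugging Face implementation, and the function name `block_sparse_pattern` is hypothetical.

```python
import random

def block_sparse_pattern(num_blocks, num_random_blocks=1, seed=0):
    """Return a num_blocks x num_blocks grid; True means query block q attends to key block k."""
    rng = random.Random(seed)
    attend = [[False] * num_blocks for _ in range(num_blocks)]
    global_blocks = {0, num_blocks - 1}
    for q in range(num_blocks):
        for k in range(num_blocks):
            is_global = q in global_blocks or k in global_blocks
            is_window = abs(q - k) <= 1  # the block itself and its two neighbours
            attend[q][k] = is_global or is_window
        # Add random blocks chosen among those not already attended.
        candidates = [k for k in range(num_blocks) if not attend[q][k]]
        for k in rng.sample(candidates, min(num_random_blocks, len(candidates))):
            attend[q][k] = True
    return attend

# 512 tokens with a block size of 64 gives 8 blocks.
pattern = block_sparse_pattern(num_blocks=8, num_random_blocks=1)
for row in pattern:
    print("".join("#" if a else "." for a in row))
print(f"{sum(map(sum, pattern))} of {8 * 8} block pairs attended")
```

With only 8 blocks the pattern is still fairly dense. The savings grow with sequence length, because the number of attended blocks per row stays roughly constant while the total number of blocks increases.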

**Training Procedure**  
We train the model with the AdamW optimizer at a learning rate of `1e-4` for several epochs (e.g., 3; the notebook run below uses 8), using mixed precision (`autocast` with `GradScaler`). During training:
- We feed batches of tokenized expressions and targets into the model.
- The model outputs classification logits, and the loss is computed against the true target class.
- We backpropagate the loss to update model parameters.

We monitor training and validation loss and accuracy at the end of each epoch to ensure the model generalizes well.
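
The two monitored quantities can be written out by hand. The sketch below is a dependency-free illustration of cross-entropy over logits and of accuracy. The notebook relies on the loss returned by the model instead, and the helper names here are hypothetical. A useful reference point: a 10-class model that outputs uniform logits has a loss of ln(10) ≈ 2.3026.

```python
import math

def cross_entropy(logits, target):
    """Negative log-probability of the target class under softmax(logits)."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]

def accuracy(batch_logits, targets):
    """Fraction of examples whose arg-max logit equals the target."""
    preds = [max(range(len(logits)), key=logits.__getitem__) for logits in batch_logits]
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)

uniform = [0.0] * 10
print(round(cross_entropy(uniform, 3), 4))                      # 2.3026
print(accuracy([[0.1, 2.0, 0.3], [1.5, 0.2, 0.1]], [1, 2]))     # 0.5
```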

**Evaluation**  
After training, we evaluate the model on the held-out test set and report loss and accuracy. In the run recorded below, the model reaches a test accuracy of 0.3700 (test loss 1.6577) after 8 epochs, with validation accuracy staying between 0.356 and 0.365 throughout.

**Checkpoints**  
Trained model weights and the model configuration are saved with `save_pretrained` to the `./model_checkpoints` directory. They can be reloaded later for further analysis or fine-tuning.
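
Reloading could look like the sketch below. It assumes the checkpoint was written by `model.save_pretrained('./model_checkpoints')` as in the notebook; the helper name `load_checkpoint` is hypothetical.

```python
def load_checkpoint(checkpoint_dir="./model_checkpoints"):
    """Reload a model written by save_pretrained and switch it to eval mode."""
    # Imported lazily so that defining this helper does not require transformers.
    from transformers import BigBirdForSequenceClassification

    model = BigBirdForSequenceClassification.from_pretrained(checkpoint_dir)
    model.eval()  # disable dropout for inference
    return model
```

The checkpoint does not contain the vocabulary, which is defined in `ListOpsDataset`, so the same token-to-ID mapping has to be used when feeding the reloaded model.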

---

**Running the Code**  
1. Place the dataset files (`basic_train.tsv`, `basic_val.tsv`, `basic_test.tsv`) in the `data_dir` location specified in the script.
2. Run the dataset loading script to generate `train_loader`, `val_loader`, and `test_loader`.
3. Run the training script. If a GPU is available, the model will train on GPU.
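
Before step 2, a quick pre-flight check can confirm that the three files are in place. This is an optional sketch; the helper name `missing_dataset_files` and the `./data` path are placeholders, not part of the project.

```python
import os

EXPECTED_FILES = ("basic_train.tsv", "basic_val.tsv", "basic_test.tsv")

def missing_dataset_files(data_dir):
    """Return the expected dataset files that are not present in data_dir."""
    return [
        name for name in EXPECTED_FILES
        if not os.path.isfile(os.path.join(data_dir, name))
    ]

data_dir = "./data"  # placeholder; use the data_dir set in the notebook
missing = missing_dataset_files(data_dir)
if missing:
    print(f"Missing in {data_dir}: {', '.join(missing)}")
else:
    print("All dataset files found.")
```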



### Code

## Importing libraries

In [1]:
import torch
import torch.nn as nn
from tqdm import tqdm
import sys
import numpy as np
import os
sys.path.append(os.path.abspath('../data'))
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence, pad_sequence
import re
import torch.optim as optim
from torch.nn.functional import cross_entropy
from transformers import BigBirdConfig, BigBirdForSequenceClassification
from torch.cuda.amp import autocast, GradScaler



## Dataset loading

In [2]:

def tokenize(expression):
    """Convert expression string to tokens, preserving operators."""
    # Replace parentheses with spaces
    expr = expression.replace('(', ' ').replace(')', ' ')

    # Add spaces around brackets that aren't part of operators
    expr = re.sub(r'\[(?!(MIN|MAX|MED|SM))', ' [ ', expr)
    expr = expr.replace(']', ' ] ')

    # Split and filter empty strings
    return [token for token in expr.split() if token]

class ListOpsDataset(Dataset):
    def __init__(self, X, y):
        """
        Args:
            X: Array of source expressions
            y: Array of target values
        """
        self.X = X
        self.y = y

        # Create vocabulary from operators and digits
        self.vocab = {
            'PAD': 0,  # Padding token
            '[MIN': 1,
            '[MAX': 2,
            '[MED': 3,
            '[SM': 4,
            ']': 5,
            '(': 6,
            ')': 7
        }
        # Add digits 0-9
        for i in range(10):
            self.vocab[str(i)] = i + 8

    def __len__(self):
        return len(self.X)

    def tokenize(self, expr):
        """Convert expression to token IDs."""
        tokens = tokenize(expr)  # Using our previous tokenize function
        return [self.vocab.get(token, 0) for token in tokens]

    def __getitem__(self, idx):
        expr = self.X[idx]
        target = self.y[idx]

        # Convert to token IDs without padding or truncating
        token_ids = self.tokenize(expr)

        return {
            'input_ids': torch.tensor(token_ids, dtype=torch.long),
            'target': torch.tensor(target, dtype=torch.long)
        }
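
A quick check of the tokenizer's behaviour is sketched below. It repeats the `tokenize` function and the vocabulary from the cell above so that it runs on its own; the helper name `to_ids` and the `[AVG` operator are made up for the example. Two points stand out: parentheses never reach the vocabulary lookup, and any unknown token is silently mapped to ID 0, the same ID as padding, so it is later masked out by the attention mask.

```python
import re

def tokenize(expression):
    """Same logic as the notebook's tokenize: drop parentheses, keep operators whole."""
    expr = expression.replace('(', ' ').replace(')', ' ')
    expr = re.sub(r'\[(?!(MIN|MAX|MED|SM))', ' [ ', expr)
    expr = expr.replace(']', ' ] ')
    return [token for token in expr.split() if token]

vocab = {'PAD': 0, '[MIN': 1, '[MAX': 2, '[MED': 3, '[SM': 4, ']': 5, '(': 6, ')': 7}
vocab.update({str(i): i + 8 for i in range(10)})

def to_ids(expression):
    # Unknown tokens fall back to 0, exactly as in ListOpsDataset.tokenize.
    return [vocab.get(token, 0) for token in tokenize(expression)]

print(tokenize("( ( ( [MAX 2 ) 9 ) ] )"))  # ['[MAX', '2', '9', ']']
print(to_ids("( ( ( [MAX 2 ) 9 ) ] )"))    # [2, 10, 17, 5]
print(to_ids("[AVG 1 ]"))                  # [0, 0, 9, 5]  ('[' and 'AVG' are unknown)
```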

In [3]:

# Define the data directory and file paths
data_dir = '/content/drive/MyDrive/LongListOps/data/output_dir'
train_file = os.path.join(data_dir, 'basic_train.tsv')
val_file = os.path.join(data_dir, 'basic_val.tsv')
test_file = os.path.join(data_dir, 'basic_test.tsv')

def load_listops_data(file_path, max_rows=None):
    """
    Load ListOps data from TSV file.

    Args:
        file_path: Path to the TSV file
        max_rows: Maximum number of rows to load (for testing)

    Returns:
        sources: Array of source expressions
        targets: Array of target values (0-9)
    """
    sources = []
    targets = []

    with open(file_path, 'r', encoding='utf-8') as f:
        next(f)  # Skip header (Source, Target)
        for i, line in enumerate(f):
            if max_rows and i >= max_rows:
                break
            if not line.strip():  # Skip empty lines
                continue
            parts = line.strip().split('\t')
            if len(parts) != 2:
                continue  # Skip lines that don't have exactly two columns
            source, target = parts
            sources.append(source)
            targets.append(int(target))  # Target is always 0-9

    # Convert to numpy arrays
    source_array = np.array(sources, dtype=object)  # Keep expressions as strings
    target_array = np.array(targets, dtype=np.int32)  # Targets are integers

    return source_array, target_array

try:
    # Load training data
    print("Loading training data...")
    X_train, y_train = load_listops_data(train_file)

    # Load validation data
    print("Loading validation data...")
    X_val, y_val = load_listops_data(val_file)

    # Load test data
    print("Loading test data...")
    X_test, y_test = load_listops_data(test_file)

    # Print dataset statistics
    print("\nDataset sizes:")
    print(f"Training: {len(X_train)} examples")
    print(f"Validation: {len(X_val)} examples")
    print(f"Test: {len(X_test)} examples")

except Exception as e:
    print(f"Error occurred: {type(e).__name__}: {str(e)}")

Loading training data...
Loading validation data...
Loading test data...

Dataset sizes:
Training: 96000 examples
Validation: 2000 examples
Test: 2000 examples


In [4]:

def collate_fn(batch):
    # Separate sequences and targets
    sequences = [item['input_ids'] for item in batch]
    targets = [item['target'] for item in batch]

    # Get lengths of each sequence
    lengths = torch.tensor([len(seq) for seq in sequences], dtype=torch.long, device=sequences[0].device)

    # Sort sequences by length in descending order for pack_padded_sequence
    lengths, sort_idx = lengths.sort(descending=True)
    sequences = [sequences[i] for i in sort_idx]
    targets = [targets[i] for i in sort_idx]

    # Pad sequences
    padded_sequences = pad_sequence(sequences, batch_first=True)

    # Convert targets to tensor
    targets = torch.stack(targets)

    return {
        'input_ids': padded_sequences,
        'target': targets,
        'lengths': lengths
    }

# Create datasets
train_dataset = ListOpsDataset(X_train, y_train)
val_dataset = ListOpsDataset(X_val, y_val)
test_dataset = ListOpsDataset(X_test, y_test)

# Create dataloaders with collate_fn
batch_size = 32
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn
)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    collate_fn=collate_fn
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    collate_fn=collate_fn
)

# Verify the data
print("Dataset sizes:")
print(f"Train: {len(train_dataset)}")
print(f"Val: {len(val_dataset)}")
print(f"Test: {len(test_dataset)}")

# Check first batch
batch = next(iter(train_loader))
print("\nFirst batch shape:")
print(f"Input IDs: {batch['input_ids'].shape}")
print(f"Targets: {batch['target'].shape}")
print(f"Sequence lengths: {batch['lengths']}")

Dataset sizes:
Train: 96000
Val: 2000
Test: 2000

First batch shape:
Input IDs: torch.Size([32, 1875])
Targets: torch.Size([32])
Sequence lengths: tensor([1875, 1778, 1722, 1612, 1455, 1445, 1407, 1345, 1218, 1168, 1000,  988,
         974,  974,  944,  939,  927,  913,  886,  846,  828,  825,  739,  693,
         692,  679,  644,  633,  612,  552,  526,  524])


## Model definition and training

In [5]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

############################################
# Preprocessing: Truncate Sequences to speed up training
max_length = 512  # reduce this if original sequences are very long

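def truncate_batch(batch):
    # Truncate sequences longer than max_length (keeps only the first max_length tokens)
    batch_input = batch['input_ids']
    if batch_input.size(1) > max_length:
        batch['input_ids'] = batch_input[:, :max_length]
        # Keep the reported lengths consistent with the truncated inputs
        batch['lengths'] = batch['lengths'].clamp(max=max_length)
    return batch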

############################################
# Recreate DataLoaders with truncated sequences

batch_size = 16  # Smaller than the earlier batch size of 32 to reduce memory use.

def faster_collate_fn(batch):
    # Same as collate_fn, but now we truncate after padding
    sequences = [item['input_ids'] for item in batch]
    targets = [item['target'] for item in batch]
    lengths = torch.tensor([len(seq) for seq in sequences], dtype=torch.long, device=sequences[0].device)

    # Sort by length (descending)
    lengths, sort_idx = lengths.sort(descending=True)
    sequences = [sequences[i] for i in sort_idx]
    targets = [targets[i] for i in sort_idx]

    padded_sequences = nn.utils.rnn.pad_sequence(sequences, batch_first=True)
    targets = torch.stack(targets)

    out_batch = {
        'input_ids': padded_sequences,
        'target': targets,
        'lengths': lengths
    }

    # Truncate here before returning
    out_batch = truncate_batch(out_batch)
    return out_batch

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=faster_collate_fn,
    pin_memory=True,
    num_workers=0
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    collate_fn=faster_collate_fn,
    pin_memory=True,
    num_workers=0
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    collate_fn=faster_collate_fn,
    pin_memory=True,
    num_workers=0
)

############################################
# Configure a much smaller BigBird model

vocab_size = len(train_dataset.vocab)

config = BigBirdConfig(
    vocab_size=vocab_size,
    hidden_size=64,            # Much smaller hidden size
    num_hidden_layers=2,       # Fewer layers
    num_attention_heads=2,     # Fewer heads
    intermediate_size=256,     # Smaller intermediate size
    max_position_embeddings=max_length,
    num_labels=10,
    attention_type="block_sparse",
    block_size=64,             # Larger block size, potentially fewer blocks to process
    num_random_blocks=1        # Fewer random blocks reduces complexity
)

model = BigBirdForSequenceClassification(config)
model.to(device)

optimizer = optim.AdamW(model.parameters(), lr=1e-4)
scaler = GradScaler()

############################################
# Training and Evaluation Functions

def train_epoch(model, dataloader, optimizer):
    model.train()
    total_loss = 0
    total_correct = 0
    total_samples = 0

    for batch in tqdm(dataloader, desc="Training"):
        input_ids = batch['input_ids'].to(device, non_blocking=True)
        targets = batch['target'].to(device, non_blocking=True)
        attention_mask = (input_ids != train_dataset.vocab['PAD']).long()

        optimizer.zero_grad(set_to_none=True)
        with autocast():
            outputs = model(input_ids, attention_mask=attention_mask, labels=targets)
            loss = outputs.loss
            logits = outputs.logits

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item() * input_ids.size(0)
        preds = logits.argmax(dim=-1)
        total_correct += (preds == targets).sum().item()
        total_samples += input_ids.size(0)

        # Free memory
        del input_ids, targets, attention_mask, outputs, loss, logits, preds
        torch.cuda.empty_cache()

    avg_loss = total_loss / total_samples
    accuracy = total_correct / total_samples
    return avg_loss, accuracy


def evaluate(model, dataloader):
    model.eval()
    total_loss = 0
    total_correct = 0
    total_samples = 0

    with torch.no_grad(), autocast():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device, non_blocking=True)
            targets = batch['target'].to(device, non_blocking=True)
            attention_mask = (input_ids != train_dataset.vocab['PAD']).long()

            outputs = model(input_ids, attention_mask=attention_mask, labels=targets)
            loss = outputs.loss
            logits = outputs.logits

            total_loss += loss.item() * input_ids.size(0)
            preds = logits.argmax(dim=-1)
            total_correct += (preds == targets).sum().item()
            total_samples += input_ids.size(0)

            del input_ids, targets, attention_mask, outputs, loss, logits, preds
            torch.cuda.empty_cache()

    avg_loss = total_loss / total_samples
    accuracy = total_correct / total_samples
    return avg_loss, accuracy

############################################
# Training Loop

num_epochs = 8  # Reduce this for a quick speed check.

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    train_loss, train_acc = train_epoch(model, train_loader, optimizer)
    val_loss, val_acc = evaluate(model, val_loader)
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"Val Loss:   {val_loss:.4f}, Val Acc:   {val_acc:.4f}")

test_loss, test_acc = evaluate(model, test_loader)
print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")

output_dir = './model_checkpoints'
os.makedirs(output_dir, exist_ok=True)
model.save_pretrained(output_dir)


Using device: cuda




Epoch 1/8


Training: 100%|██████████| 6000/6000 [06:20<00:00, 15.79it/s]
Evaluating: 100%|██████████| 125/125 [00:02<00:00, 43.13it/s]


Train Loss: 1.7564, Train Acc: 0.3582
Val Loss:   1.7261, Val Acc:   0.3600
Epoch 2/8


Training: 100%|██████████| 6000/6000 [06:14<00:00, 16.02it/s]
Evaluating: 100%|██████████| 125/125 [00:02<00:00, 42.72it/s]


Train Loss: 1.7332, Train Acc: 0.3600
Val Loss:   1.7284, Val Acc:   0.3650
Epoch 3/8


Training: 100%|██████████| 6000/6000 [06:13<00:00, 16.08it/s]
Evaluating: 100%|██████████| 125/125 [00:03<00:00, 33.35it/s]


Train Loss: 1.7331, Train Acc: 0.3586
Val Loss:   1.7246, Val Acc:   0.3630
Epoch 4/8


Training: 100%|██████████| 6000/6000 [06:20<00:00, 15.79it/s]
Evaluating: 100%|██████████| 125/125 [00:02<00:00, 44.98it/s]


Train Loss: 1.7244, Train Acc: 0.3594
Val Loss:   1.7175, Val Acc:   0.3560
Epoch 5/8


Training: 100%|██████████| 6000/6000 [06:14<00:00, 16.02it/s]
Evaluating: 100%|██████████| 125/125 [00:03<00:00, 35.99it/s]


Train Loss: 1.7114, Train Acc: 0.3609
Val Loss:   1.7034, Val Acc:   0.3650
Epoch 6/8


Training: 100%|██████████| 6000/6000 [06:19<00:00, 15.82it/s]
Evaluating: 100%|██████████| 125/125 [00:02<00:00, 44.45it/s]


Train Loss: 1.6993, Train Acc: 0.3592
Val Loss:   1.6918, Val Acc:   0.3650
Epoch 7/8


Training: 100%|██████████| 6000/6000 [06:13<00:00, 16.07it/s]
Evaluating: 100%|██████████| 125/125 [00:02<00:00, 41.95it/s]


Train Loss: 1.6929, Train Acc: 0.3599
Val Loss:   1.6844, Val Acc:   0.3620
Epoch 8/8


Training: 100%|██████████| 6000/6000 [06:06<00:00, 16.35it/s]
Evaluating: 100%|██████████| 125/125 [00:02<00:00, 45.70it/s]


Train Loss: 1.6887, Train Acc: 0.3619
Val Loss:   1.6820, Val Acc:   0.3615


Evaluating: 100%|██████████| 125/125 [00:03<00:00, 34.73it/s]

Test Loss: 1.6577, Test Acc: 0.3700



