# DIMBA Training on Google Colab/Kaggle

This notebook trains DIMBA (Diffusion-based Mamba) on GPU with:
- mamba-ssm optimized library
- BPE tokenizer
- HuggingFace datasets

Works on Google Colab and Kaggle with T4/A100 GPUs.

## Setup & Installation

In [None]:
# Install dependencies
# NOTE: the leading `!` is IPython shell-escape syntax, so this cell only runs
# inside a notebook (Colab/Kaggle/Jupyter), not as a plain Python script.
!pip install -q torch pytorch-lightning transformers datasets
!pip install -q tokenizers  # For BPE tokenizer
# causal-conv1d is installed before mamba-ssm because mamba-ssm requires it.
!pip install -q causal-conv1d  # Required for mamba-ssm
!pip install -q mamba-ssm  # Optimized Mamba library

print("✓ Dependencies installed")

In [None]:
# Clone or mount repo
import os
from pathlib import Path

# Detect the runtime: the google.colab module is only importable on Colab.
# Only the import is guarded, so a failure while mounting Drive is reported
# as-is instead of being mistaken for "not running on Colab".
try:
    from google.colab import drive
except ImportError:
    # Kaggle or local
    IS_COLAB = False
    repo_path = '/kaggle/input/dimba-lib-exp' if os.path.exists('/kaggle') else '.'
else:
    # For Colab: mount Google Drive
    IS_COLAB = True
    drive.mount('/content/drive')
    repo_path = '/content/drive/MyDrive/dimba-lib-exp'  # Adjust path as needed

# Fail early with an actionable message; the later cells rely on relative
# paths ('src', 'config.yaml') resolved against this directory.
if not os.path.isdir(repo_path):
    raise FileNotFoundError(
        f"Repository directory not found: {repo_path!r}. "
        "Adjust repo_path to where dimba-lib-exp is located."
    )

os.chdir(repo_path)
print(f"Working directory: {os.getcwd()}")
print(f"Is Colab: {IS_COLAB}")

In [None]:
# Add src to path and import
import sys
sys.path.insert(0, 'src')

import torch
import yaml
from dimba import DIMBA
from dimba.tokenizers import BPETokenizer, SimpleCharacterTokenizer

# Report the PyTorch build and, when CUDA is usable, the name of GPU 0.
cuda_ready = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ready}")
if cuda_ready:
    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU: {gpu_name}")

## Configuration & Setup

In [None]:
# Load config
# Explicit encoding so the file is read the same way regardless of the
# runtime's locale settings.
with open('config.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# safe_load returns None for an empty file; catch that (and any non-mapping
# document) here rather than failing on the first config[...] lookup.
if not isinstance(config, dict):
    raise ValueError("config.yaml is empty or does not contain a YAML mapping")

print("Config loaded:")
print(f"  Model: d_model={config['model']['d_model']}, layers={config['model']['num_denoiser_layers']}")
print(f"  Data: {config['data']['type']}")
print(f"  Tokenizer: {config['tokenizer']['type']}, vocab_size={config['tokenizer']['vocab_size']}")
print(f"  Training: lr={config['training']['learning_rate']}, epochs={config['training']['num_epochs']}")

In [None]:
# Setup device
# `or {}` also covers a `device:` key that is present but empty in the YAML
# (which loads as None and would break the .get() calls below).
device_config = config.get('device') or {}
use_gpu = device_config.get('use_gpu', True) and torch.cuda.is_available()
device = 'cuda' if use_gpu else 'cpu'

if use_gpu:
    # Cap this process's share of GPU memory; optional `memory_fraction`
    # config key, defaulting to 80%.
    memory_fraction = device_config.get('memory_fraction', 0.8)
    torch.cuda.set_per_process_memory_fraction(memory_fraction)
    if device_config.get('benchmark', True):
        torch.backends.cudnn.benchmark = True

print(f"Device: {device}")
if use_gpu:
    print(f"  Memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

## Create Tokenizer

In [None]:
tokenizer_type = config['tokenizer']['type']
vocab_size = config['tokenizer']['vocab_size']

# Create appropriate tokenizer: 'bpe' selects BPETokenizer, any other value
# falls back to SimpleCharacterTokenizer.
if tokenizer_type == 'bpe':
    print(f"Creating BPE tokenizer (vocab_size={vocab_size})...")
    # Nothing is loaded from disk here: a new BPETokenizer is constructed.
    # A construction error is left to propagate, since retrying the identical
    # call would fail the same way.
    tokenizer = BPETokenizer(vocab_size=vocab_size)
    print("  Note: Using untrained tokenizer. For real training, train on your dataset first.")
else:
    print(f"Creating SimpleCharacterTokenizer (vocab_size={vocab_size})...")
    tokenizer = SimpleCharacterTokenizer(vocab_size=vocab_size)

print("✓ Tokenizer created")
print(f"  vocab_size: {tokenizer.vocab_size}")
print(f"  pad_token_id: {tokenizer.pad_token_id}")
print(f"  unk_token_id: {tokenizer.unk_token_id}")

## Create Model

In [None]:
# Create DIMBA model
model_config = config['model']

# Hyperparameters that must be present in the config, listed in the order
# they are handed to the DIMBA constructor.
required_keys = (
    'd_model',
    'd_prompt',
    'num_diffusion_steps',
    'num_denoiser_layers',
    'd_state',
    'd_conv',
    'expand',
    'conditioning_type',
    'dropout',
    'use_weight_tying',
)
model_kwargs = {key: model_config[key] for key in required_keys}
# Optional flag; when absent it defaults to False.
model_kwargs['use_simple_mamba'] = model_config.get('use_simple_mamba', False)  # Use mamba-ssm on GPU

model = DIMBA(vocab_size=vocab_size, **model_kwargs)

model = model.to(device)

# Count parameters in a single pass over the model
total_params = 0
trainable_params = 0
for param in model.parameters():
    param_count = param.numel()
    total_params += param_count
    if param.requires_grad:
        trainable_params += param_count

print(f"✓ Model created")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
if use_gpu:
    allocated_gb = torch.cuda.memory_allocated() / 1e9
    print(f"  GPU memory allocated: {allocated_gb:.2f} GB")

## Create Data

In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset # Directly import load_dataset
import importlib
import dimba.data.dataset # For collate_fn if needed
from dimba.data import DummyDataset, collate_fn # Keep collate_fn, potentially DummyDataset

data_config = config['data']
data_type = data_config['type']
batch_size = data_config['batch_size']
max_length = data_config['max_length']

print(f"Creating {data_type} dataset...")

if data_type == 'dummy':
    # Validation split is one tenth the size of the training split.
    train_dataset = DummyDataset(
        size=data_config.get('num_examples', 1000),
        vocab_size=vocab_size,
        seq_length=max_length,
    )
    val_dataset = DummyDataset(
        size=data_config.get('num_examples', 1000) // 10,
        vocab_size=vocab_size,
        seq_length=max_length,
    )
elif data_type == 'huggingface':
    dataset_name_from_config = data_config.get('dataset_name', 'wikitext')
    dataset_config_from_config = data_config.get('dataset_config', 'wikitext-2-raw-v1') # Default to wikitext-2-raw-v1

    print(f"Loading Hugging Face dataset: {dataset_name_from_config} with config: {dataset_config_from_config}")

    # Load raw datasets directly using datasets.load_dataset
    raw_train_dataset = load_dataset(dataset_name_from_config, name=dataset_config_from_config, split='train', streaming=False)
    raw_val_dataset = load_dataset(dataset_name_from_config, name=dataset_config_from_config, split='validation', streaming=False)

    # Limit number of examples if specified. The limit is clamped to the
    # split size so a limit larger than the split keeps the whole split
    # instead of failing on out-of-range indices.
    num_train_examples = data_config.get('num_examples', 10000)
    if num_train_examples:
        raw_train_dataset = raw_train_dataset.select(
            range(min(num_train_examples, len(raw_train_dataset)))
        )

    # Optional `num_val_examples` config key; defaults to 1000. A falsy value
    # (0 / None) keeps the full validation split.
    num_val_examples = data_config.get('num_val_examples', 1000)
    if num_val_examples:
        raw_val_dataset = raw_val_dataset.select(
            range(min(num_val_examples, len(raw_val_dataset)))
        )

    # Define a simple wrapper for tokenization and compatibility with DataLoader
    class CustomHuggingFaceDataset(Dataset):
        """Map-style dataset that tokenizes one text example per index.

        Args:
            hf_dataset: Indexable dataset whose items support ``.get(column)``.
            tokenizer: Object with an ``encode`` method accepting the keyword
                arguments used in ``__getitem__``.
            max_length: Passed to the tokenizer as the truncation length.
            text_column: Name of the field holding the raw text.
        """

        def __init__(self, hf_dataset, tokenizer, max_length, text_column='text'):
            self.hf_dataset = hf_dataset
            self.tokenizer = tokenizer
            self.max_length = max_length
            self.text_column = text_column

        def __len__(self):
            return len(self.hf_dataset)

        def __getitem__(self, idx):
            """Return ``{'input_ids': ...}`` for the example at ``idx``."""
            example = self.hf_dataset[idx]
            text = example.get(self.text_column)

            # Normalise the text field: join list values (skipping None
            # entries) and replace a missing value with an empty string so the
            # tokenizer always receives a str.
            if isinstance(text, list):
                text = ' '.join([str(item) for item in text if item is not None])
            elif text is None:
                text = "" # Ensure text is a string for tokenizer

            tokenized = self.tokenizer.encode(
                text,
                padding=False, # Let collate_fn handle padding
                truncation=True,
                max_length=self.max_length,
                add_special_tokens=True
            )
            # Handle both list and BatchEncoding returns
            if isinstance(tokenized, list):
                input_ids = tokenized
            else:
                input_ids = tokenized['input_ids']

            return {'input_ids': input_ids}

    # Instantiate custom datasets
    train_dataset = CustomHuggingFaceDataset(raw_train_dataset, tokenizer, max_length)
    val_dataset = CustomHuggingFaceDataset(raw_val_dataset, tokenizer, max_length)
else:
    # Without this branch an unsupported type would only surface later as a
    # NameError on train_dataset.
    raise ValueError(
        f"Unsupported data type {data_type!r}; expected 'dummy' or 'huggingface'"
    )

# Both loaders share every setting except shuffling: only the training loader
# shuffles, the validation loader keeps DataLoader's default order.
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=0,
    pin_memory=use_gpu,
    collate_fn=collate_fn,
)

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, **loader_kwargs)

print(f"✓ Datasets created")
train_size = len(train_dataset)
val_size = len(val_dataset)
print(f"  Train samples: {train_size}")
print(f"  Val samples: {val_size}")
print(f"  Batch size: {batch_size}")

## Setup Training

In [None]:
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR

train_config = config['training']

# AdamW over all model parameters, with learning rate and weight decay taken
# from the training config.
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=train_config['learning_rate'],
    weight_decay=train_config['weight_decay'],
)

# Cosine schedule spanning the whole run: the scheduler is stepped once per
# batch, so its period is epochs * batches-per-epoch.
num_epochs = train_config['num_epochs']
steps_per_epoch = len(train_loader)
total_steps = num_epochs * steps_per_epoch
scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=total_steps)

print(f"✓ Optimizer & scheduler created")
print(f"  Learning rate: {train_config['learning_rate']}")
print(f"  Total training steps: {total_steps}")

## Training Loop

In [None]:
import numpy as np
from tqdm.notebook import tqdm

def train_epoch(model, train_loader, optimizer, scheduler, device, epoch,
                log_interval=100, gradient_clip=None):
    """Train for one epoch and return the mean batch loss.

    For every batch a random diffusion timestep is drawn per sample, the model
    predicts from ``(input_ids, t)``, and the prediction is compared with the
    token embeddings of the clean input using MSE. The scheduler is stepped
    once per batch.

    Args:
        model: Model exposing ``num_diffusion_steps``, ``token_embed`` and a
            forward call ``model(input_ids, t)`` returning a 2-tuple whose
            first element is the prediction.
        train_loader: Iterable of batches with an ``'input_ids'`` tensor.
        optimizer: Optimizer updated after every batch.
        scheduler: LR scheduler stepped after every batch.
        device: Device the batches are moved to.
        epoch: Zero-based epoch index (used for the progress-bar label).
        log_interval: Progress-bar refresh period in batches; values below 1
            are treated as 1.
        gradient_clip: Max gradient norm. When ``None``, falls back to
            ``train_config['gradient_clip']`` if a module-level
            ``train_config`` exists, so existing calls behave as before.
            A falsy value disables clipping.

    Returns:
        Mean of the per-batch losses, or NaN when the loader yields no batches.
    """
    if gradient_clip is None:
        # Backward compatibility with callers that rely on the notebook-level
        # training config instead of passing the value explicitly.
        gradient_clip = (globals().get('train_config') or {}).get('gradient_clip')
    # Guard the modulo below against a zero or negative interval.
    log_interval = max(1, log_interval)

    model.train()
    losses = []

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
    for batch_idx, batch in enumerate(pbar):
        input_ids = batch['input_ids'].to(device)

        # Sample random timesteps, one per sequence in the batch
        t = torch.randint(0, model.num_diffusion_steps, (input_ids.shape[0],), device=device)

        # Forward pass; the second return value is not used by this loss
        x_pred, _ = model(input_ids, t)

        # Target: embeddings of the clean tokens.
        # NOTE(review): the target is not detached, so gradients also reach
        # token_embed through it — confirm this is intended.
        x_0 = model.token_embed(input_ids)

        # MSE loss
        loss = torch.nn.functional.mse_loss(x_pred, x_0)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        if gradient_clip:
            torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)
        optimizer.step()
        scheduler.step()

        losses.append(loss.item())

        if batch_idx % log_interval == 0:
            # Running average over the most recent window of batches
            avg_loss = np.mean(losses[-log_interval:])
            pbar.set_postfix({'loss': f'{avg_loss:.4f}'})

    # np.mean([]) would emit a RuntimeWarning; return NaN explicitly instead.
    return np.mean(losses) if losses else float('nan')

def validate(model, val_loader, device):
    """Compute the mean MSE loss over the validation loader.

    Puts the model in eval mode, draws a random diffusion timestep per sample,
    and compares the model's prediction with the token embeddings of the
    input. Gradient tracking is disabled for the whole pass.

    Returns:
        The mean of the per-batch losses (via ``np.mean``).
    """
    model.eval()
    mse = torch.nn.functional.mse_loss
    batch_losses = []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validating"):
            tokens = batch['input_ids'].to(device)
            n_samples = tokens.shape[0]

            # One random timestep per sequence in the batch
            steps = torch.randint(0, model.num_diffusion_steps, (n_samples,), device=device)
            prediction, _noise = model(tokens, steps)
            target = model.token_embed(tokens)
            batch_losses.append(mse(prediction, target).item())

    return np.mean(batch_losses)

print("Training functions defined")

In [None]:
# Train
num_epochs = train_config['num_epochs']
best_val_loss = float('inf')

# Checkpoints normally go under the working directory. On Kaggle the working
# directory is under /kaggle/input, which is mounted read-only, so fall back
# to Kaggle's writable output directory when the cwd is not writable.
checkpoint_dir = './checkpoints'
if not os.access('.', os.W_OK) and os.path.isdir('/kaggle/working'):
    checkpoint_dir = '/kaggle/working/checkpoints'

for epoch in range(num_epochs):
    train_loss = train_epoch(
        model, train_loader, optimizer, scheduler, device, epoch,
        log_interval=train_config.get('log_interval', 100)
    )

    val_loss = validate(model, val_loader, device)

    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print(f"  Train loss: {train_loss:.4f}")
    print(f"  Val loss: {val_loss:.4f}")

    # Save checkpoint only when validation loss improves on the best so far.
    # The epoch number in the file name is zero-based.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        checkpoint_path = f"{checkpoint_dir}/dimba-epoch={epoch:02d}-val_loss={val_loss:.4f}.ckpt"
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(model.state_dict(), checkpoint_path)
        print(f"  ✓ Saved checkpoint: {checkpoint_path}")

## Save Final Model

In [None]:
# Save final model
# Write under the working directory when possible. On Kaggle the working
# directory is under /kaggle/input, which is mounted read-only, so fall back
# to Kaggle's writable output directory when the cwd is not writable.
final_dir = './checkpoints'
if not os.access('.', os.W_OK) and os.path.isdir('/kaggle/working'):
    final_dir = '/kaggle/working/checkpoints'
os.makedirs(final_dir, exist_ok=True)

# NOTE: the weights saved here are the model's current state; the file name
# carries the best validation loss seen during training.
final_checkpoint = f"{final_dir}/dimba-final-val_loss={best_val_loss:.4f}.ckpt"
torch.save(model.state_dict(), final_checkpoint)
print(f"✓ Saved final model: {final_checkpoint}")

# Save tokenizer next to the checkpoints
tokenizer_path = f"{final_dir}/tokenizer.json"
tokenizer.save(tokenizer_path)
print(f"✓ Saved tokenizer: {tokenizer_path}")

if IS_COLAB:
    print("\nTo download checkpoints:")
    print(f"  1. Go to Google Drive/dimba-lib-exp/checkpoints/")
    print(f"  2. Download .ckpt files")