# DyLoRA-MoE Colab Training Notebook

End-to-end interactive notebook to fine-tune and extend a Dynamic LoRA Mixture-of-Experts (DyLoRA-MoE) model (Gemma 3 270M) with continual skill ingestion and novelty-triggered expert growth.

Key Features:
- Hugging Face Transformers + PEFT LoRA experts
- Dynamic expert creation upon novel skill detection
- CodeAlpaca (Python subset) + MBPP evaluation
- Sequence packing for token efficiency
- W&B logging (losses, routing metrics, expert usage)
- Optional Google Drive persistence

Before running, you can optionally set these environment variables:
- HF_TOKEN (only needed if the model is gated)
- WANDB_API_KEY (for Weights & Biases logging; if it is not set, W&B runs in offline mode)

Run cells in order. Adjust config in the Configuration cell.

In [None]:
# 1. Environment & Hardware Check
# Prints interpreter, library and GPU details so a misconfigured runtime is
# noticed before any model or dataset is loaded.
import torch, platform, os, subprocess, sys
print(f"Python: {platform.python_version()}")
print(f"Torch: {torch.__version__}")
try:
    import transformers, peft, datasets
    for label, module in (('Transformers:', transformers), ('PEFT:', peft), ('Datasets:', datasets)):
        print(label, module.__version__)
except Exception as e:
    # Best-effort report only: a broken library must not stop the hardware check.
    print('Library import issue:', e)

has_cuda = torch.cuda.is_available()
print("CUDA available:", has_cuda)
if not has_cuda:
    print("If you need a GPU: Runtime > Change runtime type > GPU")
else:
    n_devices = torch.cuda.device_count()
    print("Device count:", n_devices)
    for i in range(n_devices):
        print(f"Device {i}:", torch.cuda.get_device_name(i))
    # Memory is reported for device 0 only.
    gpu_props = torch.cuda.get_device_properties(0)
    vram_gb = gpu_props.total_memory / 1024**3
    print("Total VRAM (GB):", round(vram_gb, 2))

In [None]:
# 2. Install / Upgrade Dependencies (restart kernel if upgrading core libs)
# Lines starting with `!` are notebook shell commands, not Python; `-q` keeps
# pip's output quiet.
!pip -q install --upgrade pip
# Libraries installed for the remaining cells.
!pip -q install transformers peft datasets accelerate bitsandbytes sentencepiece wandb einops
# NOTE(review): no later cell visible here imports trl — confirm it is still needed.
!pip -q install trl
import torch, os
# Show which torch version this kernel sees after the installs above.
print('Torch:', torch.__version__)
# 3. (Optional) Mount Google Drive for persistence
# BASE_DIR is the root for everything this notebook writes. It defaults to a
# directory under /content and is redirected to Drive only when USE_DRIVE is set.
USE_DRIVE = False  # set True to persist outputs across sessions
BASE_DIR = '/content/dylora_moe_run'
if USE_DRIVE:
    # Imported lazily: google.colab is only needed when Drive is requested.
    from google.colab import drive
    drive.mount('/content/drive')
    BASE_DIR = '/content/drive/MyDrive/dylora_moe_run'
os.makedirs(BASE_DIR, exist_ok=True)
print('Base directory:', BASE_DIR)

In [None]:
# 4. Configuration & Global Setup
import random, json, time, math, numpy as np, torch
# Single source of run settings; later cells read these keys by name.
cfg = dict(
    # Base model and input shaping.
    model_name='google/gemma-3-270m',
    max_seq_len=512,
    pack_sequences=True,
    # Optimisation schedule.
    learning_rate=1e-4,
    weight_decay=0.0,
    warmup_ratio=0.1,
    num_train_epochs=3,
    gradient_accumulation_steps=4,
    train_batch_size=1,
    eval_batch_size=1,
    # LoRA adapter hyper-parameters.
    lora_r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    # Reproducibility and output location.
    seed=42,
    output_dir=BASE_DIR + '/outputs',
    # Logging / evaluation / checkpoint cadence, in optimizer steps.
    logging_steps=10,
    eval_steps=50,
    save_steps=50,
    early_stopping_patience=5,
)

def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Args:
        seed: Value passed unchanged to each generator's seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
# Create the output directory, apply the configured seed, and echo the
# effective configuration so it appears in the cell output.
os.makedirs(cfg['output_dir'], exist_ok=True)
set_seed(cfg['seed'])
print(json.dumps(cfg, indent=2))

# 5. Imports
import os
from datasets import load_dataset, Dataset, concatenate_datasets
from transformers import (AutoTokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling, EarlyStoppingCallback)
import wandb
from tqdm import tqdm
print('Imports loaded.')

# 6. Download Datasets (CodeAlpaca subset + MBPP) and filtering
print('Loading datasets...')


def _mentions_python(example):
    """Return True when the row's instruction contains 'python' (case-insensitive)."""
    return 'python' in example['instruction'].lower()


# CodeAlpaca: keep the Python-related rows, then hold out 10% (stored under 'test').
code_alpaca = (
    load_dataset('sahil2801/CodeAlpaca-20k', split='train')
    .filter(_mentions_python)
    .train_test_split(test_size=0.1, seed=cfg['seed'])
)

# MBPP: only the 'test' split is loaded and re-split here: 80% train, and the
# remaining 20% halved into validation and test.
mbpp_full = load_dataset('mbpp', split='test').train_test_split(test_size=0.2, seed=cfg['seed'])
mbpp_val_test = mbpp_full['test'].train_test_split(test_size=0.5, seed=cfg['seed'])
mbpp_dataset = dict(
    train=mbpp_full['train'],
    validation=mbpp_val_test['train'],
    test=mbpp_val_test['test'],
)

print('Sizes -> CodeAlpaca train:', len(code_alpaca['train']), 'val:', len(code_alpaca['test']))
print(
    'MBPP train:', len(mbpp_dataset['train']),
    'val:', len(mbpp_dataset['validation']),
    'test:', len(mbpp_dataset['test']),
)

# 7. Tokenizer
from dylo_moe.model import DyLoRA_MoE
from dylo_moe.utils import print_trainable_parameters

# HF_TOKEN is optional; when unset, None is passed through to from_pretrained.
hf_token = os.getenv('HF_TOKEN')
tokenizer = AutoTokenizer.from_pretrained(cfg['model_name'], token=hf_token)

# Later cells pad sequences, so a pad token must exist; reuse EOS if none is defined.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# 8. Utility: sequence packing
def pack_sequences(tokenizer, texts, max_length=512):
    """Greedily pack tokenized texts into fixed-length rows.

    Each text is tokenized without special tokens, cut to at most
    ``max_length`` tokens, and appended to the current row. When a text would
    overflow the row, the row is padded and emitted and the text starts a new
    one. Texts are concatenated back to back with no separator token.

    Args:
        tokenizer: Callable returning a mapping with an ``'input_ids'`` list
            for one text; must expose ``pad_token_id``.
        texts: Iterable of strings to pack, in order.
        max_length: Length of every emitted row.

    Returns:
        A pair ``(input_ids_rows, attention_mask_rows)`` of equally long lists.
        Masks are 1 for real tokens and 0 for padding.
    """
    packed_ids, packed_masks = [], []

    def _emit(buffer):
        # Right-pad the row to max_length and record the matching mask.
        filler = max_length - len(buffer)
        packed_ids.append(buffer + [tokenizer.pad_token_id] * filler)
        packed_masks.append([1] * len(buffer) + [0] * filler)

    buffer = []
    for text in texts:
        token_ids = tokenizer(text, add_special_tokens=False)['input_ids'][:max_length]
        overflows = len(buffer) + len(token_ids) > max_length
        if overflows and buffer:
            _emit(buffer)
        if overflows:
            buffer = []
        buffer.extend(token_ids)

    # Emit whatever is left in the final, partially filled row.
    if buffer:
        _emit(buffer)
    return packed_ids, packed_masks

# 9. Preprocessing helpers

def preprocess_eval(dataset, tokenizer, max_length=None):
    """Tokenize an evaluation dataset into fixed-length, padded inputs.

    Two row layouts are supported:

    * a ``text`` column, whose values may be a plain string or a sequence of
      strings (a sequence is joined with newlines);
    * ``instruction`` and ``output`` columns, joined with a single newline.

    Args:
        dataset: Object exposing ``map(fn, batched=True, remove_columns=...)``
            and ``column_names``.
        tokenizer: Callable applied to the list of texts with
            ``padding='max_length'``, ``truncation=True`` and ``max_length``.
        max_length: Target sequence length. Defaults to ``cfg['max_seq_len']``.

    Returns:
        Whatever ``dataset.map`` returns, with the original columns removed.
    """
    limit = cfg['max_seq_len'] if max_length is None else max_length

    def _as_text(value):
        # A plain string must be used as-is: joining it would iterate its
        # characters and put a newline between every one of them.
        if isinstance(value, str):
            return value
        return '\n'.join(value)

    def tok_fn(examples):
        if 'text' in examples:
            processed = [_as_text(x) for x in examples['text']]
        else:
            processed = [f"{ins}\n{out}" for ins, out in zip(examples['instruction'], examples['output'])]
        return tokenizer(processed, padding='max_length', truncation=True, max_length=limit)

    return dataset.map(tok_fn, batched=True, remove_columns=dataset.column_names)

print('Setup complete.')

In [None]:
# 10. Build DyLoRA-MoE Model
# The model starts with one expert; the continual-learning loop later calls
# model.add_new_skill(...) on each skill's batches.
model = DyLoRA_MoE(
    cfg['model_name'],
    token=hf_token,
    num_experts=1,
    lora_dropout=cfg['lora_dropout'],
    lora_alpha=cfg['lora_alpha'],
    lora_r=cfg['lora_r'],
)
print_trainable_parameters(model)

# 11. Prepare Evaluation Sets
python_eval = preprocess_eval(code_alpaca['test'], tokenizer)
mbpp_eval = preprocess_eval(mbpp_dataset['validation'], tokenizer)

# Tag every row with its source (0 = CodeAlpaca hold-out, 1 = MBPP validation)
# and stack both into the set used for in-training evaluation.
# NOTE(review): training uses remove_unused_columns=False, so 'eval_domain'
# presumably reaches the collator/model as an extra field — confirm the model
# forward tolerates it.
combined_eval = concatenate_datasets([
    split.add_column('eval_domain', [domain_id] * len(split))
    for domain_id, split in enumerate((python_eval, mbpp_eval))
])

# 12. Training Arguments
# Every tunable value is read from cfg, so edits in the Configuration cell
# take effect here.
from transformers import TrainingArguments
training_args = TrainingArguments(
    output_dir=cfg['output_dir'],
    num_train_epochs=cfg['num_train_epochs'],
    per_device_train_batch_size=cfg['train_batch_size'],
    per_device_eval_batch_size=cfg['eval_batch_size'],
    gradient_accumulation_steps=cfg['gradient_accumulation_steps'],
    gradient_checkpointing=True,
    learning_rate=cfg['learning_rate'],
    # Previously declared in cfg but never forwarded, so changing them had no
    # effect on training. The current cfg values (0.0 and 42) match the library
    # defaults, so this does not alter existing runs.
    weight_decay=cfg['weight_decay'],
    seed=cfg['seed'],
    lr_scheduler_type='cosine',
    warmup_ratio=cfg['warmup_ratio'],
    # Mixed precision only when a GPU is present.
    fp16=torch.cuda.is_available(),
    logging_dir=cfg['output_dir'] + '/logs',
    logging_steps=cfg['logging_steps'],
    logging_strategy='steps',
    eval_strategy='steps',
    eval_steps=cfg['eval_steps'],
    save_strategy='steps',
    save_steps=cfg['save_steps'],
    save_total_limit=2,
    # Best checkpoint is chosen by lowest eval loss.
    load_best_model_at_end=True,
    metric_for_best_model='eval_loss',
    greater_is_better=False,
    report_to=['wandb'],
    # Keep every dataset column; the datasets here are already tokenized.
    remove_unused_columns=False,
)

# 13. Data Collator
from transformers import DataCollatorForLanguageModeling
# mlm=False selects causal-LM collation instead of masked-LM.
collator = DataCollatorForLanguageModeling(mlm=False, tokenizer=tokenizer)

# 14. Trainer Stub (we'll override train_dataset dynamically)
from transformers import Trainer, EarlyStoppingCallback

def compute_metrics(_):
    """Return an empty metrics mapping; the argument is ignored."""
    return dict()

# Trainer without a train_dataset: the continual-learning loop assigns one per
# skill before calling train().
#
# `compute_metrics` is deliberately NOT passed. When a metrics function is
# supplied, Trainer.evaluate() collects the logits of every evaluation sample
# in order to call it. The function in this notebook returns {} regardless, so
# that collection only costs memory and time. Without it, evaluation is
# loss-only and still reports the eval loss that early stopping and best-model
# selection rely on.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collator,
    eval_dataset=combined_eval,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=cfg['early_stopping_patience'])],
)

# 15. W&B Init
import wandb
if not os.environ.get('WANDB_API_KEY'):
    # No API key available: switch W&B to offline mode before starting the run.
    os.environ['WANDB_MODE'] = 'offline'
    wandb.init(project='dylora-moe-colab', mode='offline')
else:
    wandb.init(project='dylora-moe-colab')

# 16. Skill Streams (initial + synthetic domains)
# Each skill is a list of raw code strings; the loop below consumes them in order.

# Up to the first 800 CodeAlpaca training outputs.
python_skill = [
    example['output']
    for example in code_alpaca['train'].select(range(min(800, len(code_alpaca['train']))))
]

# Two hand-written snippets per synthetic domain.
requests_skill = [
    (
        "import requests\n"
        "resp = requests.get('https://httpbin.org/get')\n"
        "print(resp.status_code)"
    ),
    (
        "import requests\n"
        "resp = requests.post('https://httpbin.org/post', data={'a':1})\n"
        "print(resp.json())"
    ),
]
flask_skill = [
    (
        "from flask import Flask\n"
        "app=Flask(__name__)\n"
        "@app.route('/')\n"
        "def home():\n"
        "    return 'hi'"
    ),
    (
        "from flask import Flask, request\n"
        "app=Flask(__name__)\n"
        "@app.route('/echo', methods=['POST'])\n"
        "def echo():\n"
        "    return request.data.decode()"
    ),
]

skills = [python_skill, requests_skill, flask_skill]

# 17. Preprocessing function for training skills
def preprocess_train(texts):
    """Turn raw training texts into a Dataset of model inputs.

    With ``cfg['pack_sequences']`` set, texts are packed into rows of
    ``cfg['max_seq_len']`` tokens via ``pack_sequences``. Otherwise each text
    becomes its own row, truncated to ``cfg['max_seq_len']`` and padded by the
    tokenizer.

    Args:
        texts: List of strings to tokenize.

    Returns:
        A ``Dataset`` with ``input_ids``, ``attention_mask`` and ``labels``
        columns, where ``labels`` holds the same values as ``input_ids``.
    """
    if not cfg['pack_sequences']:
        encoded = tokenizer(texts, padding=True, truncation=True, max_length=cfg['max_seq_len'])
        ids, mask = encoded['input_ids'], encoded['attention_mask']
    else:
        ids, mask = pack_sequences(tokenizer, texts, max_length=cfg['max_seq_len'])
    return Dataset.from_dict({'input_ids': ids, 'attention_mask': mask, 'labels': ids})

# 18. Initial Evaluation
# MBPP validation loss before any skill has been trained.
print('Initial MBPP validation eval:')
initial_mbpp = trainer.evaluate(eval_dataset=mbpp_eval, metric_key_prefix='eval_mbpp')
print(initial_mbpp)
wandb.log(dict(initial_mbpp_loss=initial_mbpp['eval_mbpp_loss']))

# 19. Continual Learning Loop
# For each skill stream: build its training set, ask the model whether the
# skill is novel, train only when it is, then log routing, throughput and
# per-domain evaluation metrics.
last_log_time = time.time(); tokens_processed = 0
for idx, skill in enumerate(skills):
    print(f"\n=== Skill {idx+1}/{len(skills)} ===")
    dataset = preprocess_train(skill)

    # Padding fraction, taken from the attention masks instead of counting pad
    # ids. The pad id can equal the EOS id (the tokenizer falls back to EOS
    # when it has no pad token), and without packing the rows are not
    # necessarily max_seq_len long.
    attention_masks = list(dataset['attention_mask'])
    real_tokens = sum(sum(mask) for mask in attention_masks)
    total_tokens = sum(len(mask) for mask in attention_masks)
    pad_tokens = total_tokens - real_tokens
    wandb.log({'padding_fraction': pad_tokens/total_tokens if total_tokens else 0})

    # Novelty detection: feed batches to add_new_skill.
    # Every row is passed through with no early exit, since add_new_skill may
    # do more than return a flag — NOTE(review): confirm against dylo_moe.
    device = trainer.args.device
    is_novel = False
    for row in dataset['input_ids']:
        batch = torch.tensor(row).unsqueeze(0).to(device)
        if model.add_new_skill(batch):
            is_novel = True
    if is_novel:
        print('Novel skill detected -> training expert')
        # Trainer only builds an optimizer when it has none, so one left over
        # from an earlier skill would not include parameters created since.
        # Dropping it (and its scheduler) makes train() rebuild both.
        # NOTE(review): assumes add_new_skill can add trainable parameters.
        trainer.optimizer = None
        trainer.lr_scheduler = None
        trainer.train_dataset = dataset
        trainer.train()
        # Flag the most recently added expert via the router.
        model.router.set_expert_maturity(model.expert_manager.num_experts - 1, 1)
        wandb.log({'num_experts': model.expert_manager.num_experts})
    else:
        print('Skill not novel; skipping training step')

    # Routing metrics (if >1 expert). Skipped for an empty dataset, which has
    # no first row to sample.
    if model.router.num_experts > 1 and len(dataset) > 0:
        first_row = dataset[0]
        sample = torch.tensor(first_row['input_ids']).unsqueeze(0).to(device)
        # Use the mask produced with the data rather than rebuilding it from
        # pad ids, which would also hide real tokens that share the pad id.
        sample_mask = torch.tensor(first_row['attention_mask']).unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = model.foundation_model(sample, attention_mask=sample_mask, output_hidden_states=True)
            hidden_states = outputs.hidden_states[-1]
            routing_weights = model.router(hidden_states)
            # Mean entropy of the routing weights; the clamp avoids log(0).
            entropy = -(routing_weights * (routing_weights.clamp(min=1e-8).log())).sum(-1).mean().item()
            expert_usage = routing_weights.mean(dim=(0,1)).detach().cpu().tolist()
        wandb.log({'routing_entropy': entropy, **{f'expert_usage_{i}':v for i,v in enumerate(expert_usage)}})

    # Tokens/sec approximate: non-padding tokens of each skill are counted
    # (whether or not it was trained) and flushed at most every 30 seconds.
    tokens_processed += real_tokens
    now = time.time()
    if now - last_log_time >= 30:
        tps = tokens_processed/(now - last_log_time)
        wandb.log({'tokens_per_second': tps})
        tokens_processed = 0; last_log_time = now

    # Per-domain eval each skill
    py_loss = trainer.evaluate(python_eval, metric_key_prefix='eval_python')['eval_python_loss']
    mbpp_loss = trainer.evaluate(mbpp_eval, metric_key_prefix='eval_mbpp')['eval_mbpp_loss']
    wandb.log({'eval_python_loss': py_loss, 'eval_mbpp_loss': mbpp_loss})

# 20. Final Evaluation on MBPP Test
mbpp_test = preprocess_eval(mbpp_dataset['test'], tokenizer)
final_mbpp = trainer.evaluate(eval_dataset=mbpp_test, metric_key_prefix='final_mbpp')
print('Final MBPP test metrics:', final_mbpp)
wandb.log(dict(final_mbpp_loss=final_mbpp['final_mbpp_loss']))

# 21. Save Model
# Model and tokenizer go to the same directory.
best_model_dir = cfg['output_dir'] + '/best_model'
trainer.save_model(best_model_dir)
tokenizer.save_pretrained(best_model_dir)
print('Saved model and tokenizer to', best_model_dir)

# 22. Finish
wandb.finish()
print('Done.')