# MLX Distributed Training Notebook

This notebook provides an interactive interface for running MLX LoRA fine-tuning, either locally on a single node (the default) or distributed across multiple nodes, using the project's existing modules and utilities.

## 1. Setup and Imports

In [None]:
import os
import sys
import json
import shutil
from pathlib import Path
import subprocess
import time
from datetime import datetime

# Put the working directory on the import path so project-local packages
# (such as `utils`, imported below) resolve when the notebook runs from the repo root.
project_root = Path.cwd()
root_entry = str(project_root)
if root_entry not in sys.path:
    sys.path.insert(0, root_entry)

# Import MLX and project modules
import mlx.core as mx
from mlx_lm import load, generate
from mlx_lm.lora import run
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tuner.trainer import TrainingCallback

# Import custom utilities
from utils.data_preparation import prepare_data, validate_data_format
from utils.logging import setup_logging, get_logger
from utils.model_cache import ModelCache
from utils.training_callbacks import (
    CheckpointCallback,
    MetricsCallback,
    EarlyStoppingCallback,
    CombinedCallback
)
from utils.training_core import (
    setup_distributed,
    all_reduce_grads,
    DistributedTrainingCallback,
    run_distributed_training
)

# Import smoke check: this point is only reached if every import above resolved.
# Fall back to 'Unknown' when the MLX module exposes no version attribute.
mlx_version = getattr(mx, "__version__", "Unknown")
print("✅ All modules imported successfully!")
print(f"Project root: {project_root}")
print(f"MLX version: {mlx_version}")

## 2. Configuration

In [None]:
# Training configuration, grouped by concern and merged into one flat dict.
# The merge order follows the section order, so the key order of `config`
# (and therefore the printed JSON) is stable.
model_settings = {
    "model": "mlx-community/gemma-2-2b-it-4bit",
    "adapter_path": "models/fine-tuned",
}

data_settings = {
    "data_path": "data",
    "train_file": "train.jsonl",
    "valid_file": "valid.jsonl",
}

training_settings = {
    "batch_size": 1,
    "num_epochs": 1,
    "iters": 10,
    "learning_rate": 1e-4,
    "num_layers": 4,
    "max_seq_length": 128,
}

lora_settings = {
    "lora_rank": 8,
    "lora_alpha": 16,
    "lora_dropout": 0.0,
    "lora_scale": 10.0,
}

reporting_settings = {
    "steps_per_report": 2,
    "steps_per_eval": 5,
    "save_every": 5,
}

distributed_settings = {
    "distributed": False,  # Set to True for distributed training
    "hostfile": "hostfile",
}

config = {
    **model_settings,
    **data_settings,
    **training_settings,
    **lora_settings,
    **reporting_settings,
    **distributed_settings,
}

print("Configuration loaded:")
print(json.dumps(config, indent=2))

## 3. Data Preparation

In [None]:
# Prepare data files
data_path = Path(config["data_path"])


def _ensure_data_file(data_dir, target_name, example_name):
    """Seed `data_dir/target_name` from `data_dir/example_name` if it is missing.

    Does nothing when the target already exists.

    Raises:
        FileNotFoundError: if neither the target nor the example file exists.
    """
    target = data_dir / target_name
    if target.exists():
        return

    example = data_dir / example_name
    if not example.exists():
        # Fail with an actionable message instead of a bare copy error
        raise FileNotFoundError(
            f"Neither {target} nor {example} exists; "
            f"add a dataset to {data_dir} before continuing."
        )

    shutil.copy(example, target)
    print(f"✅ Copied {example_name} to {target_name}")


def _count_samples(jsonl_file):
    """Return the number of non-blank lines in a JSONL file (one sample per line)."""
    # Blank lines (e.g. a trailing newline run) are not samples, so skip them
    with open(jsonl_file, "r", encoding="utf-8") as f:
        return sum(1 for line in f if line.strip())


# Copy example files to expected names if they don't exist
_ensure_data_file(data_path, "train.jsonl", "example_train.jsonl")
_ensure_data_file(data_path, "valid.jsonl", "example_valid.jsonl")

# Validate data format
print("\nValidating data files...")
train_valid = validate_data_format(str(data_path / "train.jsonl"))
valid_valid = validate_data_format(str(data_path / "valid.jsonl"))

if train_valid and valid_valid:
    print("✅ Data validation passed!")
else:
    print("❌ Data validation failed!")

# Count samples
train_count = _count_samples(data_path / "train.jsonl")
valid_count = _count_samples(data_path / "valid.jsonl")

print("\nDataset statistics:")
print(f"- Training samples: {train_count}")
print(f"- Validation samples: {valid_count}")

## 4. Model Setup

In [None]:
# Create output directories. The adapter directory is taken from the config
# because `config["adapter_path"]` is the location handed to training, so the
# folder created here always matches the one the run is pointed at.
os.makedirs(config["adapter_path"], exist_ok=True)
os.makedirs("models/checkpoints", exist_ok=True)
os.makedirs("logs", exist_ok=True)

# Setup logging
logger = setup_logging("logs/training.log")
logger.info("Starting MLX training setup")

# Initialize model cache
model_cache = ModelCache()
print("✅ Model cache initialized")

# Report whether the configured model is already in the cache
model_name = config["model"]
if model_cache.is_cached(model_name):
    print(f"✅ Model {model_name} is already cached")
else:
    print(f"ℹ️ Model {model_name} will be downloaded on first use")

## 5. Training Functions

In [None]:
def run_local_training(config):
    """Run LoRA fine-tuning on a single node via `mlx_lm.lora.run`.

    Args:
        config: Mapping providing the keys read below: "model", "data_path",
            "adapter_path", "batch_size", "iters", "num_layers",
            "max_seq_length", "learning_rate", "steps_per_report",
            "steps_per_eval", "save_every", "lora_rank", "lora_alpha",
            "lora_dropout" and "lora_scale".

    Returns:
        True if training finished without raising, False otherwise. A failure
        is printed and logged, including its traceback, through the
        notebook-level `logger`; it is not re-raised.

    Raises:
        KeyError: if `config` lacks one of the keys listed above (the lookups
            happen before the guarded training call).
    """
    # Local imports keep this cell self-contained
    import traceback
    from types import SimpleNamespace

    print("Starting local training...")

    lora_parameters = {
        "rank": config["lora_rank"],
        "alpha": config["lora_alpha"],
        "dropout": config["lora_dropout"],
        "scale": config["lora_scale"],
    }

    # Create training arguments: an attribute bag passed to `run` in place of
    # parsed command-line arguments
    args = SimpleNamespace(
        model=config["model"],
        train=True,
        data=config["data_path"],
        adapter_path=config["adapter_path"],
        batch_size=config["batch_size"],
        iters=config["iters"],
        num_layers=config["num_layers"],
        max_seq_length=config["max_seq_length"],
        learning_rate=config["learning_rate"],
        steps_per_report=config["steps_per_report"],
        steps_per_eval=config["steps_per_eval"],
        save_every=config["save_every"],
        lora_parameters=lora_parameters,
        fine_tune_type="lora",
        val_batches=5,
        seed=42,
        resume_adapter_file=None,
        test=False,
        test_batches=100,
        grad_checkpoint=False,
        lr_schedule=None,
        wandb=None,
        optimizer="adam",
        optimizer_config={"adam": {}, "adamw": {"weight_decay": 0.01}},
    )

    # Create callbacks
    checkpoint_callback = CheckpointCallback(
        checkpoint_dir="models/checkpoints",
        save_every=config["save_every"],
    )
    metrics_callback = MetricsCallback()
    combined_callback = CombinedCallback([checkpoint_callback, metrics_callback])

    # Run training; perf_counter is monotonic, so the reported duration cannot
    # be skewed by system clock adjustments
    start_time = time.perf_counter()

    try:
        run(args, training_callback=combined_callback)
    except Exception as e:
        print(f"❌ Training failed with error: {e}")
        # Keep the full traceback in the log; str(e) alone rarely says where
        # the failure happened
        logger.error(f"Training failed: {e}\n{traceback.format_exc()}")
        return False

    elapsed = time.perf_counter() - start_time
    print(f"\n✅ Training completed in {elapsed:.2f} seconds")

    # Print final metrics (the second field of the most recent entry is shown)
    if metrics_callback.train_losses:
        print(f"Final training loss: {metrics_callback.train_losses[-1][1]:.4f}")
    if metrics_callback.val_losses:
        print(f"Final validation loss: {metrics_callback.val_losses[-1][1]:.4f}")

    return True


def run_distributed_training_notebook(config):
    """Launch distributed training by running the project's launcher script.

    The hostfile path is only checked for existence here; it is not passed to
    the script. The script path is resolved relative to the current working
    directory.

    Args:
        config: Mapping providing "hostfile", the path verified before launch.

    Returns:
        True if the script exits with status 0. False if the hostfile or the
        script is missing, the script exits non-zero, or launching it raises.
    """
    print("Starting distributed training...")

    # Check if hostfile exists
    hostfile = config["hostfile"]
    if not os.path.exists(hostfile):
        print(f"❌ Hostfile not found: {hostfile}")
        return False

    # Check the launcher up front so a missing script gets a clear message
    # instead of an opaque shell error
    cmd = "./run_distributed_fine_tune_fixed.sh"
    if not os.path.exists(cmd):
        print(f"❌ Training script not found: {cmd}")
        return False

    try:
        # `cmd` is a fixed literal, so running it through the shell has no
        # injection surface. Output is captured and printed after the script exits.
        result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
        print(result.stdout)
        if result.stderr:
            print("Errors:", result.stderr)
        return result.returncode == 0
    except Exception as e:
        print(f"❌ Failed to run distributed training: {e}")
        return False


print("✅ Training functions defined")

## 6. Run Training

In [None]:
# Pick the banner and the training entry point for the configured mode
banner, trainer = (
    ("🚀 Running distributed training...", run_distributed_training_notebook)
    if config["distributed"]
    else ("🚀 Running local training...", run_local_training)
)

print(banner)
success = trainer(config)

# Report the outcome; `success` is reused by the model test cell below
outcome = "\n✅ Training completed successfully!" if success else "\n❌ Training failed!"
print(outcome)

## 7. Test the Fine-tuned Model

In [None]:
# Smoke-test the fine-tuned model.
# The weights file is looked up under the configured adapter directory, which is
# the same location passed to `load` below.
adapter_weights = Path(config["adapter_path"]) / "adapters.safetensors"

if success and adapter_weights.exists():
    print("Loading fine-tuned model for testing...")

    try:
        # Load model with adapters
        model, tokenizer = load(
            config["model"],
            adapter_path=config["adapter_path"]
        )

        print("✅ Model loaded successfully!")

        # Test prompts
        test_prompts = [
            "turn on the living room lights",
            "what's the temperature in the bedroom",
            "play music in the kitchen"
        ]

        # mlx_lm's `generate` takes sampling settings through a sampler object;
        # it does not accept a `temperature` keyword
        sampler = make_sampler(temp=0.7)

        print("\nTesting model with sample prompts:\n")

        for prompt in test_prompts:
            print(f"Prompt: {prompt}")

            response = generate(
                model,
                tokenizer,
                prompt=prompt,
                max_tokens=100,
                sampler=sampler
            )

            print(f"Response: {response}")
            print("-" * 50)

    except Exception as e:
        # Best-effort check: report the failure without aborting the notebook
        print(f"❌ Failed to test model: {e}")
else:
    print("⚠️ No fine-tuned model found to test")

## 8. Model Information and Statistics

In [None]:
# Summarise what is on disk: adapter files, adapter configuration and the log tail
adapter_path = Path(config["adapter_path"])

if not adapter_path.exists():
    print("⚠️ No model files found")
else:
    print("📊 Model Information:\n")

    # Adapter files with their on-disk size
    adapter_files = list(adapter_path.glob("*.safetensors"))
    print(f"Adapter files ({len(adapter_files)}):")
    bytes_per_mb = 1024 * 1024
    for weights_file in adapter_files:
        size_mb = weights_file.stat().st_size / bytes_per_mb
        print(f"  - {weights_file.name}: {size_mb:.2f} MB")

    # Adapter configuration, shown only when the file is present
    config_file = adapter_path / "adapter_config.json"
    if config_file.exists():
        with open(config_file, 'r') as f:
            adapter_config = json.load(f)
        print("\nAdapter Configuration:")
        print(json.dumps(adapter_config, indent=2))

    # Tail of the training log, shown only when the file is present
    log_file = Path("logs/training.log")
    if log_file.exists():
        print("\n📝 Recent training logs:")
        with open(log_file, 'r') as f:
            # Keep only the last 10 lines
            recent_lines = f.readlines()[-10:]
        for line in recent_lines:
            print(line.strip())

## 9. Cleanup (Optional)

In [None]:
# Optional: clean up the dataset files used by this notebook.
# Cleanup is opt-in: set RUN_CLEANUP to True to delete the files.
# NOTE: this removes train.jsonl / valid.jsonl from the data directory whether
# or not they were copied from the example files, so enable it with care.
RUN_CLEANUP = False

if RUN_CLEANUP:
    # Remove copied data files
    for file_name in ("train.jsonl", "valid.jsonl"):
        copied_file = data_path / file_name
        if copied_file.exists():
            os.remove(copied_file)
            print(f"Removed {file_name}")

print("\n✅ Notebook execution complete!")