# MLX Distributed Training Tutorial

This notebook walks through using the MLX Distributed Training framework to train large language models across multiple Apple Silicon devices.

## What You'll Learn
1. Setting up a distributed training environment
2. Configuring the model and training parameters
3. Implementing efficient data loading
4. Managing memory and performance
5. Monitoring training progress
6. Handling distributed communication

## Prerequisites
- macOS Sonoma 14.0+
- Python 3.12+
- MLX 0.20.0+
- High-speed network connection (10 Gbps recommended)
- Multiple Apple Silicon devices

## 1. Environment Setup

First, let's import required modules and verify our environment:

In [None]:
import mlx.core as mx
from src.models import UnifiedModel
from src.distributed import DistributedTrainer
from src.monitoring import PerformanceDashboard
from src.utils.network_utils import DistributedCommunicator
from src.utils.memory_utils import AdvancedMemoryManager
import psutil
import logging

# Setup logging
logging.basicConfig(level=logging.INFO)

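# Verify MLX and Metal availability
print(f"MLX Version: {mx.__version__}")
print(f"Metal Available: {mx.metal.is_available()}")
# mlx.core.metal offers set_memory_limit() but no matching getter, so read the
# recommended working-set size (in bytes) from device_info() instead
metal_bytes = mx.metal.device_info()["max_recommended_working_set_size"]
print(f"Available Memory: {metal_bytes / (1024**3):.2f} GB")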
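
The prerequisites list minimum versions, and comparing version strings directly can mislead (as text, `"0.9.1"` sorts after `"0.20.0"`). The following is a minimal sketch of a numeric comparison; `parse_version` and `meets_minimum` are hypothetical helpers, not part of MLX or this framework.

```python
import platform


def parse_version(text):
    """Convert '0.20.0' into (0, 20, 0); stop at the first non-numeric component."""
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(found, required):
    """Return True if version string `found` is at least `required`."""
    a, b = parse_version(found), parse_version(required)
    width = max(len(a), len(b))
    # Pad with zeros so that '14' compares equal to '14.0'
    a = a + (0,) * (width - len(a))
    b = b + (0,) * (width - len(b))
    return a >= b


print("Python OK:", meets_minimum(platform.python_version(), "3.12"))
# platform.mac_ver()[0] is an empty string on non-macOS systems, which fails the check
print("macOS OK:", meets_minimum(platform.mac_ver()[0], "14.0"))
```

The same helper can be applied to `mx.__version__` with `"0.20.0"` as the minimum.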

## 2. Network Verification

Before starting distributed training, let's verify network connectivity:

In [None]:
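# NetworkConfig is assumed to live next to DistributedCommunicator;
# adjust this import to match your checkout
from src.utils.network_utils import NetworkConfig

# Initialize network communicator
network_config = NetworkConfig(
    primary_host="localhost",  # use the primary device's hostname or IP for multi-device runs
    primary_port=29500
)
communicator = DistributedCommunicator(network_config)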

# Verify connection
if communicator.verify_connection():
    print("Network verification successful")
    print(f"World Size: {communicator.world.size}")
    print(f"Current Rank: {communicator.world.rank}")
else:
    raise RuntimeError("Network verification failed")
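
When `verify_connection()` fails, it can help to test plain TCP reachability separately from the framework. The sketch below uses only the standard library; `port_is_reachable` is a hypothetical helper, and whether the framework actually listens on `primary_port` is an assumption to confirm against your setup.

```python
import socket


def port_is_reachable(host, port, timeout=2.0):
    """Return True if a TCP connection to (host, port) succeeds within `timeout` seconds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers refused connections, timeouts, and name-resolution failures
        return False


# Example: check the host and port used in the network configuration above
print(port_is_reachable("localhost", 29500, timeout=0.5))
```

A `False` result here points to a firewall, addressing, or "nothing is listening" problem rather than a problem in the training code.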

## 3. Memory Management Setup

Configure memory management for optimal performance:

In [None]:
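# MemoryConfig is assumed to live next to AdvancedMemoryManager;
# adjust this import to match your checkout
from src.utils.memory_utils import MemoryConfig

# mlx.core.metal has no get_memory_limit(); device_info() reports sizes in bytes
metal_bytes = mx.metal.device_info()["max_recommended_working_set_size"]

# Initialize memory manager
memory_config = MemoryConfig(
    max_memory_gb=metal_bytes / (1024**3) * 0.9  # Use 90% of the recommended working set
)
memory_manager = AdvancedMemoryManager(memory_config)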

# Monitor current memory usage
def print_memory_stats():
    metal_used = mx.metal.get_active_memory() / (1024**3)
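    # device_info() replaces the non-existent mx.metal.get_memory_limit()
    metal_total = mx.metal.device_info()["max_recommended_working_set_size"] / (1024**3)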
    ram_used = psutil.Process().memory_info().rss / (1024**3)
    
    print(f"Metal Memory Used: {metal_used:.2f} GB / {metal_total:.2f} GB")
    print(f"RAM Used: {ram_used:.2f} GB")

print_memory_stats()
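
The 90% rule used above is simple arithmetic, and it can be useful to see it without a Metal device attached. This is a minimal sketch; `memory_budget_gb` and `usage_ratio` are hypothetical helpers, and the 64 GiB figure is only an example input.

```python
GIB = 1024 ** 3  # bytes per GiB, the same divisor used in the cells above


def memory_budget_gb(total_bytes, fraction=0.9):
    """Return the memory budget in GiB for a given fraction of the total."""
    if not 0.0 < fraction <= 1.0:
        raise ValueError("fraction must be in (0, 1]")
    return total_bytes / GIB * fraction


def usage_ratio(used_bytes, total_bytes):
    """Return the used share of memory as a value between 0 and 1."""
    return used_bytes / total_bytes


# Example input: a device reporting 64 GiB, with 16 GiB currently in use
total = 64 * GIB
print(f"Budget: {memory_budget_gb(total):.1f} GB")
print(f"Usage: {usage_ratio(16 * GIB, total):.0%}")
```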

## 4. Model Configuration

Set up the model and training configuration:

In [None]:
config = {
    "model": {
        "num_layers": 24,
        "dims": 1024,
        "num_heads": 16,
        "vocab_size": 50257,
        "max_seq_length": 2048
    },
    "training": {
        "batch_size": {
            "primary": 32,    # Larger device (e.g., Mac Studio)
            "secondary": 16   # Smaller device (e.g., MacBook)
        },
        "gradient_accumulation_steps": 8,
        "learning_rate": 1e-4,
        "warmup_steps": 1000,
        "max_steps": 100000,
        "eval_frequency": 500,
        "save_frequency": 1000
    },
    "distributed": {
        "world_size": 2,
        "backend": "mpi",
        "sync_weights_every": 100  # Synchronize weights every N steps
    },
    "monitoring": {
        "enable_ui": True,
        "port": 8050,
        "alert_thresholds": {
            "loss": 10.0,
            "gpu_utilization": 0.95,
            "memory_usage": 0.9
        }
    }
}
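
Two quantities implied by this configuration are worth checking by hand: the per-head dimension and the number of sequences contributing to each optimizer step. The sketch below assumes one primary and one secondary device (matching `world_size: 2`) and that gradient accumulation applies on each device; `head_dim` and `sequences_per_step` are hypothetical helpers.

```python
def head_dim(dims, num_heads):
    """Return the per-head dimension; dims must divide evenly across heads."""
    if dims % num_heads != 0:
        raise ValueError("dims must be divisible by num_heads")
    return dims // num_heads


def sequences_per_step(batch_sizes, accumulation_steps):
    """Sequences contributing to one optimizer step across all devices."""
    return sum(batch_sizes) * accumulation_steps


# Values copied from the config above
per_head = head_dim(1024, 16)
effective_batch = sequences_per_step([32, 16], 8)
# Upper bound on tokens per step, assuming every sequence is max_seq_length long
max_tokens = effective_batch * 2048

print(per_head, effective_batch, max_tokens)
```

With these values each head is 64-dimensional and one optimizer step covers (32 + 16) × 8 = 384 sequences, or at most 786,432 tokens.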

## 5. Initialize Training Components

Set up the trainer, model, and monitoring:

In [None]:
# Initialize trainer
trainer = DistributedTrainer(config)

# Create and optimize model
model = UnifiedModel(config["model"])
model = memory_manager.optimize_memory_layout(model)

# Setup monitoring dashboard
dashboard = PerformanceDashboard(config["monitoring"])

print("Training components initialized successfully")
print_memory_stats()

## 6. Data Loading and Processing

Implement efficient data loading with streaming:

In [None]:
from datasets import load_dataset
from src.training.data_utils import DataManager, DataConfig

# Configure data loading
data_config = DataConfig(
    streaming=True,  # Enable streaming for large datasets
    cache_dir="./cache",
    prefetch_batches=2
)

# Load and prepare dataset
data_manager = DataManager(data_config)
dataset = load_dataset("openwebtext", split="train", streaming=True)
processed_dataset = data_manager.prepare_dataset(dataset)

# Create data loader; rank 0 is treated as the primary (larger) device
device_role = "primary" if trainer.world.rank == 0 else "secondary"
dataloader = data_manager.create_loader(
    processed_dataset,
    batch_size=config["training"]["batch_size"][device_role]
)
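
With unequal batch sizes, each rank should read a different slice of the stream so that no example is trained on twice. How `DataManager` handles this is not shown here; the following is a sketch of one possible scheme, in which every block of `sum(batch_sizes)` consecutive examples is split in rank order. `weighted_shard` is a hypothetical helper.

```python
from itertools import islice


def weighted_shard(examples, rank, batch_sizes):
    """Yield this rank's batches from a shared stream.

    The stream is cut into blocks of sum(batch_sizes) examples; within each
    block, rank r takes batch_sizes[r] consecutive examples.
    """
    start = sum(batch_sizes[:rank])
    stop = start + batch_sizes[rank]
    block_size = sum(batch_sizes)
    iterator = iter(examples)
    while True:
        block = list(islice(iterator, block_size))
        if len(block) < block_size:
            # Drop the incomplete tail so every rank yields the same number of batches
            return
        yield block[start:stop]


# Example with integers standing in for tokenized examples
for rank in (0, 1):
    batches = list(weighted_shard(range(100), rank, [32, 16]))
    print(rank, [(batch[0], batch[-1], len(batch)) for batch in batches])
```

In this sketch each rank iterates over the full stream and discards the other ranks' examples, which is simple but wasteful for expensive sources; a production loader would more likely shard at the file level.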

## 7. Training Loop with Monitoring

Run the training loop with comprehensive monitoring:

In [None]:
async def train():
    try:
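        # Training loop; the config above defines max_steps but no max_epochs,
        # so fall back to a single epoch when the key is absent
        for epoch in range(config["training"].get("max_epochs", 1)):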
            # Train one epoch
            epoch_metrics = await trainer.train_epoch(dataloader, epoch)
            
            # Log metrics
            print(f"Epoch {epoch} - Loss: {epoch_metrics['avg_loss']:.4f}, "
                  f"LR: {epoch_metrics['learning_rate']:.6f}")
            
            # Check early stopping
            if trainer.scheduler.should_stop(epoch_metrics['avg_loss'], epoch):
                print("Early stopping triggered")
                break
                
    except KeyboardInterrupt:
        print("Training interrupted")
    finally:
        # Clean shutdown
        trainer.shutdown()
        dashboard.shutdown()

# Start training (top-level await works in Jupyter; use asyncio.run(train()) in a script)
await train()
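
The loop above delegates the stopping decision to `trainer.scheduler.should_stop`, whose implementation is not shown in this notebook. As an illustration only, here is a minimal sketch of one common rule, patience-based early stopping, using the same `(loss, epoch)` call shape. `PatienceStopper`, `patience`, and `min_delta` are hypothetical names, not the framework's API.

```python
class PatienceStopper:
    """Stop when the loss has not improved for `patience` consecutive epochs."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta  # minimum decrease that counts as an improvement
        self.best_loss = float("inf")
        self.best_epoch = 0

    def should_stop(self, loss, epoch):
        if loss < self.best_loss - self.min_delta:
            # New best: remember it and keep training
            self.best_loss = loss
            self.best_epoch = epoch
            return False
        # No improvement: stop once `patience` epochs have passed since the best
        return epoch - self.best_epoch >= self.patience


# Example: the loss stalls after epoch 1, so training stops at epoch 3
stopper = PatienceStopper(patience=2)
decisions = [stopper.should_stop(loss, epoch)
             for epoch, loss in enumerate([3.0, 2.5, 2.6, 2.55])]
print(decisions)
```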

## 8. Analyzing Results

Review training metrics and performance:

In [None]:
# Get training summary
metrics_summary = dashboard.get_summary()

print("\nTraining Summary:")
for metric, stats in metrics_summary.items():
    print(f"\n{metric}:")
    for stat_name, value in stats.items():
        print(f"  {stat_name}: {value:.4f}")

# Create and display final plots
training_plot = dashboard.create_training_plot()
system_plot = dashboard.create_system_plot()

training_plot.show()
system_plot.show()
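
The printing loop above expects a nested mapping of the form `{metric: {stat_name: value}}`. To make that shape concrete, here is a sketch that builds such a summary from raw per-step histories. `summarize` and the chosen statistics are assumptions for illustration; the actual contents of `dashboard.get_summary()` may differ.

```python
from statistics import fmean


def summarize(history):
    """Map each metric name to basic statistics over its recorded values."""
    summary = {}
    for name, values in history.items():
        if not values:
            continue  # skip metrics that were never recorded
        summary[name] = {
            "mean": fmean(values),
            "min": min(values),
            "max": max(values),
            "last": values[-1],
        }
    return summary


# Example with made-up values, printed the same way as the cell above
example = summarize({"loss": [4.0, 3.0, 2.0], "tokens_per_sec": []})
for metric, stats in example.items():
    print(f"\n{metric}:")
    for stat_name, value in stats.items():
        print(f"  {stat_name}: {value:.4f}")
```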

## Next Steps

1. Experiment with different model configurations
2. Optimize batch sizes for your devices
3. Implement custom monitoring metrics
4. Explore advanced features:
   - Gradient accumulation
   - Dynamic batch sizing
   - Custom evaluation metrics
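
As a starting point for the gradient accumulation item, the idea can be sketched with plain Python lists: gradients from several micro-batches are averaged, and a single parameter update is applied. This is a framework-independent illustration that assumes equally sized micro-batches; `accumulate_gradients` and `sgd_step` are hypothetical helpers, not this framework's API.

```python
def accumulate_gradients(micro_batch_grads):
    """Average per-micro-batch gradient vectors elementwise."""
    steps = len(micro_batch_grads)
    total = [0.0] * len(micro_batch_grads[0])
    for grads in micro_batch_grads:
        for i, value in enumerate(grads):
            total[i] += value
    return [value / steps for value in total]


def sgd_step(params, grads, lr):
    """Apply one plain SGD update and return the new parameters."""
    return [p - lr * g for p, g in zip(params, grads)]


# Two micro-batches produce two gradient vectors, followed by one update
averaged = accumulate_gradients([[1.0, 2.0], [3.0, 4.0]])
updated = sgd_step([1.0, 1.0], averaged, lr=0.5)
print(averaged, updated)
```

With `gradient_accumulation_steps: 8` as configured earlier, eight micro-batch gradients would be averaged before each update.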

For more details, check out:
- [Performance Tuning Guide](../docs/performance_tuning.md)
- [API Documentation](../docs/api/)
- [Best Practices](../docs/best_practices.md)