In [None]:
# Optional: uncomment the line below to install MLX into this kernel's environment with the %pip magic
# %pip install mlx

In [None]:
%%bash
# Register whichever `python` is first on PATH as a user-level Jupyter kernel
# named "mlx-distributed".
# NOTE(review): when the notebook is run top to bottom this executes before the
# cell that creates the mlx-distributed environment - confirm the intended order.
python -m ipykernel install \
    --user \
    --name mlx-distributed \
    --display-name "MLX Distributed (arm64)"


In [None]:
%%bash
# Build a clean arm64 conda environment named mlx-distributed.
set -e  # stop at the first failing step so the success message below can be trusted

# Remove existing environment if it exists (the error for a missing env is ignored)
conda env remove -n mlx-distributed -y 2>/dev/null || true

# Create fresh environment with Python 3.11, forcing arm64 packages for this command
CONDA_SUBDIR=osx-arm64 conda create -n mlx-distributed python=3.11 -y

# `conda activate` is a shell function that is not defined in the non-interactive
# shell spawned by %%bash, so load it from the conda installation first.
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate mlx-distributed

# Persist the arm64 subdir in the environment's own configuration
conda config --env --set subdir osx-arm64

echo "Environment created successfully!"
conda info --envs | grep mlx-distributed

In [None]:
%%bash
# Install OpenMPI, mpi4py, MLX and the Jupyter kernel into mlx-distributed.
set -e  # abort on the first failure instead of printing "Installation complete!" regardless

# Activate environment
source ~/anaconda3/etc/profile.d/conda.sh
conda activate mlx-distributed

# Install OpenMPI via conda (not homebrew!)
conda install -c conda-forge openmpi -y

# Install mpi4py
conda install -c conda-forge mpi4py -y

# Install MLX and MLX-LM; `python -m pip` guarantees the activated
# environment's interpreter receives the packages.
python -m pip install mlx mlx-lm

# Install additional utilities
python -m pip install numpy jupyter ipykernel

# Add kernel to Jupyter
python -m ipykernel install --user --name mlx-distributed --display-name "MLX Distributed"

echo "Installation complete!"

In [None]:
import importlib.metadata
import platform
import shutil
import subprocess
import sys

# Environment sanity report: interpreter, MLX, MPI and MLX-LM.
# Every section has its own try/except so that one missing component does not
# hide the status of the others.

print("=== System Information ===")
print(f"Python: {sys.version}")
print(f"Platform: {platform.platform()}")
print(f"Architecture: {platform.machine()}")
print(f"Python executable: {sys.executable}")
print()

print("=== MLX Installation ===")
try:
    import mlx.core as mx
    # Read the version from the installed distribution's metadata so the report
    # does not depend on the top-level package exposing a __version__ attribute.
    print(f"✓ MLX version: {importlib.metadata.version('mlx')}")
    print(f"✓ Metal available: {mx.metal.is_available()}")
    print(f"✓ Default device: {mx.default_device()}")
except Exception as e:
    print(f"✗ MLX error: {e}")
print()

print("=== MPI Installation ===")
try:
    import mpi4py
    from mpi4py import MPI

    # mpi4py's own release and the MPI standard level are different numbers;
    # MPI.Get_version() reports the latter, so label each one explicitly.
    print(f"✓ mpi4py version: {mpi4py.__version__}")
    mpi_standard = ".".join(str(part) for part in MPI.Get_version())
    print(f"✓ MPI standard: {mpi_standard}")
    print(f"✓ MPI vendor: {MPI.get_vendor()}")

    # Check MPI executable; shutil.which returns None when it is not on PATH
    mpirun_path = shutil.which('mpirun')
    if mpirun_path is None:
        print("✗ mpirun not found on PATH")
    else:
        print(f"✓ mpirun location: {mpirun_path}")

        # Only the first line of the version banner is shown
        result = subprocess.run([mpirun_path, '--version'], capture_output=True, text=True)
        banner = (result.stdout or result.stderr).strip().splitlines()
        first_line = banner[0] if banner else "unknown"
        print(f"✓ MPI version: {first_line}")
except Exception as e:
    print(f"✗ MPI error: {e}")
print()

print("=== MLX-LM Installation ===")
try:
    import mlx_lm
    print("✓ mlx_lm installed successfully")
except Exception as e:
    print(f"✗ mlx_lm error: {e}")

In [None]:
import mlx.core as mx
import time

# Smoke test: run one large matmul on the GPU, report timing and memory,
# then try to load a small model through mlx_lm.

# Set GPU as default device
mx.set_default_device(mx.gpu)

print("=== GPU Test ===")
print(f"Default device: {mx.default_device()}")
print(f"Metal available: {mx.metal.is_available()}")

# Create a large array to test GPU
size = 10000
print(f"\nCreating {size}x{size} matrix multiplication...")

# MLX evaluates lazily: materialise the random inputs first so the timed
# region below covers only the matrix multiplication itself.
a = mx.random.uniform(shape=(size, size))
b = mx.random.uniform(shape=(size, size))
mx.eval(a, b)

# Time the GPU matmul with a monotonic, high-resolution clock
start = time.perf_counter()
c = a @ b
mx.eval(c)  # Force evaluation
gpu_time = time.perf_counter() - start

print(f"GPU computation time: {gpu_time:.3f} seconds")
# NOTE(review): newer MLX releases may expose these counters at the top level
# of mlx.core instead of mx.metal - confirm against the installed version.
print(f"GPU memory used: {mx.metal.get_active_memory() / 1024**3:.2f} GB")
print(f"GPU memory cache: {mx.metal.get_cache_memory() / 1024**3:.2f} GB")

# Test small model loading
print("\n=== Testing Model Loading ===")
try:
    from mlx_lm import load
    model, tokenizer = load("mlx-community/Llama-3.2-1B-Instruct-4bit")
    print("✓ Model loaded successfully")

    # Quick tokenizer round-trip (no generation is performed here)
    prompt = "Hello"
    inputs = tokenizer(prompt, return_tensors="np")
    print(f"✓ Tokenizer works: '{prompt}' -> {inputs['input_ids']}")
except Exception as e:
    # Best-effort: a failed download/load must not stop the notebook
    print(f"✗ Model loading error: {e}")
    print("This is okay for now - we'll use a different model for distributed tests")

In [None]:
import subprocess
import os

hosts = ["mm@mm1.local", "mm@mm2.local"]


def _probe_ssh(target):
    """Run `echo 'SSH OK'` on `target` over non-interactive SSH.

    BatchMode disables password prompts and ConnectTimeout caps the wait at
    five seconds. Returns the CompletedProcess with captured text output.
    """
    command = ["ssh", "-o", "BatchMode=yes", "-o", "ConnectTimeout=5", target, "echo 'SSH OK'"]
    return subprocess.run(command, capture_output=True, text=True)


print("=== Testing SSH Connectivity ===")
for host in hosts:
    print(f"\nTesting {host}...")
    result = _probe_ssh(host)

    # Report the failure (with a suggested fix) first, success otherwise
    if result.returncode != 0:
        print(f"✗ SSH connection failed: {result.stderr}")
        print(f"  Fix: Run 'ssh-copy-id {host}' in terminal")
        continue
    print("✓ SSH connection successful")

# Recommended ~/.ssh/config: one stanza per worker plus a wildcard stanza.
# The text is only printed below; nothing is written to disk.
_host_stanza = (
    "Host {name}\n"
    "    User mm\n"
    "    HostName {name}\n"
    "    ForwardAgent yes\n"
    "    ServerAliveInterval 60\n"
)
_wildcard_stanza = (
    "Host *\n"
    "    AddKeysToAgent yes\n"
    "    UseKeychain yes\n"
    "    IdentityFile ~/.ssh/id_rsa\n"
)
ssh_config = "\n" + "\n".join(
    [_host_stanza.format(name=worker) for worker in ("mm1.local", "mm2.local")]
    + [_wildcard_stanza]
)

print("\n=== Recommended SSH Config ===")
print("Add this to ~/.ssh/config:")
print(ssh_config)

In [None]:
import importlib.metadata
import subprocess
import sys

# Diagnose where (and whether) mpi4py is installed for the running kernel.

# Check if mpi4py is importable in the current environment
try:
    import mpi4py
    print(f"✓ mpi4py is installed in current Python: {mpi4py.__file__}")
    print(f"  mpi4py version: {mpi4py.__version__}")
except ImportError:
    print("✗ mpi4py not found in current Python")

# Check which Python we're using
print(f"\nCurrent Python: {sys.executable}")
print(f"Python version: {sys.version}")

# Look the distribution up in the installed-package metadata. This sees packages
# installed by pip and by conda alike, so the message makes no claim about the installer.
try:
    version = importlib.metadata.version('mpi4py')
except importlib.metadata.PackageNotFoundError:
    print("\n✗ mpi4py not installed")
else:
    print(f"\n✓ mpi4py {version} is installed")

# Cross-check with conda; the executable may be absent from the kernel's PATH
try:
    result = subprocess.run(['conda', 'list', 'mpi4py'], capture_output=True, text=True)
except FileNotFoundError:
    print("\n✗ conda executable not found on PATH")
else:
    print(f"\nConda list output:\n{result.stdout}")

In [None]:
import subprocess
import sys
import os

# Rebuild mpi4py from source against Homebrew's MPI compiler wrapper, then run a
# two-process send/recv smoke test with Homebrew's mpirun.

HOMEBREW_MPICC = '/opt/homebrew/bin/mpicc'
HOMEBREW_MPIRUN = '/opt/homebrew/bin/mpirun'
TEST_SCRIPT_PATH = 'test_mpi_final.py'

print("=== Solution: Use Homebrew MPI ===")

# First, uninstall the broken mpi4py
print("1. Removing broken mpi4py...")
subprocess.run([sys.executable, '-m', 'pip', 'uninstall', 'mpi4py', '-y'])

# Install mpi4py compiled against Homebrew's MPI
print("\n2. Installing mpi4py with Homebrew MPI...")
env = os.environ.copy()
env['MPICC'] = HOMEBREW_MPICC
env['CC'] = HOMEBREW_MPICC

# --no-binary forces a source build so the compiler wrapper above is actually used;
# --no-cache-dir prevents pip from reusing a previously built wheel.
result = subprocess.run(
    [sys.executable, '-m', 'pip', 'install', 'mpi4py', '--no-cache-dir', '--no-binary', 'mpi4py'],
    capture_output=True, text=True, env=env
)

if result.returncode == 0:
    print("✓ mpi4py installed successfully")
else:
    print(f"Installation output: {result.stdout}")
    print(f"Errors: {result.stderr}")

# Test the installation
print("\n3. Testing MPI...")
test_script = """
import sys
print(f"Python: {sys.executable}")

from mpi4py import MPI
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

print(f"Rank {rank}/{size}: MPI is working!")

if rank == 0 and size > 1:
    comm.send("Hello from rank 0", dest=1)
elif rank == 1:
    msg = comm.recv(source=0)
    print(f"Rank 1 received: {msg}")
"""

# The finally clause removes the temporary script even when mpirun is missing
# or the run raises, so no stray file is left in the working directory.
try:
    with open(TEST_SCRIPT_PATH, 'w') as f:
        f.write(test_script)

    # Run with Homebrew's mpirun
    result = subprocess.run(
        [HOMEBREW_MPIRUN, '-np', '2', sys.executable, TEST_SCRIPT_PATH],
        capture_output=True, text=True
    )

    print("\nOutput:")
    print(result.stdout)
    if result.stderr:
        print("Errors:", result.stderr)
except FileNotFoundError:
    print(f"✗ {HOMEBREW_MPIRUN} not found - is Open MPI installed via Homebrew?")
finally:
    if os.path.exists(TEST_SCRIPT_PATH):
        os.remove(TEST_SCRIPT_PATH)

In [None]:
# Imported here so the cell does not depend on names leaked from earlier cells.
import os
import shlex
import sys

# Create configuration for using Homebrew MPI.
# The generated shell file puts Homebrew's bin directory first on PATH, exports
# the MPI tool locations and this kernel's interpreter, and defines the
# run_mlx_dist wrapper around Homebrew's mpirun.
# shlex.quote keeps the export valid when the interpreter path contains spaces
# or other shell metacharacters.
config_content = f"""#!/bin/bash
# MLX Distributed Configuration

# Use Homebrew MPI
export PATH="/opt/homebrew/bin:$PATH"
export MPICC=/opt/homebrew/bin/mpicc
export MPIRUN=/opt/homebrew/bin/mpirun

# Python from conda environment
export PYTHON={shlex.quote(sys.executable)}

# Function to run distributed MLX
run_mlx_dist() {{
    /opt/homebrew/bin/mpirun "$@"
}}

echo "MLX Distributed configured with:"
echo "  MPI: Homebrew OpenMPI 5.0.7"
echo "  Python: Conda environment (mlx-distributed)"
echo ""
echo "Usage: run_mlx_dist -np 4 python your_script.py"
"""

with open('mlx_dist_config.sh', 'w') as f:
    f.write(config_content)

# rwxr-xr-x: the file is meant to be sourced, but is also marked executable
os.chmod('mlx_dist_config.sh', 0o755)

print("\n=== Configuration Created ===")
print("Source this before running distributed jobs:")
print("  source mlx_dist_config.sh")

In [None]:
import subprocess
import os
import shlex
import shutil
import sys

# Generate three launcher scripts (local, multi-host, hostfile-based) plus the
# MPI hostfile in the current working directory.

# Cluster layout shared by every generated file, so hosts and counts cannot drift apart.
CLUSTER_HOSTS = ["mbp.local", "mm1.local", "mm2.local"]
DEFAULT_PROCS_PER_HOST = 2

# Locate mlx.launch without hard-coding a user-specific absolute path: prefer the
# copy installed next to the running interpreter, then fall back to PATH.
_beside_python = os.path.join(os.path.dirname(sys.executable), 'mlx.launch')
if os.path.exists(_beside_python):
    mlx_launch = _beside_python
else:
    mlx_launch = shutil.which('mlx.launch') or _beside_python
launch_cmd = shlex.quote(mlx_launch)  # safe to embed in the shell scripts below

hosts_csv = ",".join(CLUSTER_HOSTS)
hosts_pretty = ", ".join(CLUSTER_HOSTS)


def _write_executable(path, content):
    """Write `content` to `path` and mark the file executable (mode 0o755)."""
    with open(path, 'w') as f:
        f.write(content)
    os.chmod(path, 0o755)


print("=== Creating Final Working Scripts ===")

# Local run script with MPI backend.
# In the f-strings below, doubled braces are literal shell braces (${1:-2} etc.).
local_script = f"""#!/bin/bash
# Run MLX distributed locally with MPI backend

NP="${{1:-2}}"
SCRIPT="${{2:-test_mlx_dist.py}}"

echo "Running MLX locally with $NP processes (MPI backend)..."
echo "Script: $SCRIPT"
echo ""

{launch_cmd} --backend mpi --hosts localhost -n "$NP" "$SCRIPT"
"""

_write_executable('run_mlx_local.sh', local_script)

# Distributed run script for your cluster
distributed_script = f"""#!/bin/bash
# Run MLX distributed across your Mac cluster

SCRIPT="${{1:-test_mlx_dist.py}}"
PROCESSES_PER_HOST="${{2:-{DEFAULT_PROCS_PER_HOST}}}"

echo "Running MLX distributed (MPI backend)"
echo "Hosts: {hosts_pretty}"
echo "Processes per host: $PROCESSES_PER_HOST"
echo "Script: $SCRIPT"
echo ""

{launch_cmd} --backend mpi \\
    --hosts {hosts_csv} \\
    -n "$PROCESSES_PER_HOST" \\
    "$SCRIPT"
"""

_write_executable('run_mlx_distributed.sh', distributed_script)

# Create hostfile for MPI backend: each host listed once per process slot.
# NOTE(review): confirm the installed mlx.launch accepts a plain
# one-host-per-line hostfile; the expected format may differ between releases.
hostfile_content = "".join(
    f"{host}\n" for host in CLUSTER_HOSTS for _ in range(DEFAULT_PROCS_PER_HOST)
)

with open('mlx_hostfile.txt', 'w') as f:
    f.write(hostfile_content)

# Hostfile version
hostfile_script = f"""#!/bin/bash
# Run MLX using hostfile (MPI backend)

SCRIPT="${{1:-test_mlx_dist.py}}"
HOSTFILE="${{2:-mlx_hostfile.txt}}"

echo "Running MLX with hostfile (MPI backend)"
echo "Hostfile: $HOSTFILE"
echo "Script: $SCRIPT"
echo ""

{launch_cmd} --backend mpi --hostfile "$HOSTFILE" "$SCRIPT"
"""

_write_executable('run_mlx_hostfile.sh', hostfile_script)

total_procs = DEFAULT_PROCS_PER_HOST * len(CLUSTER_HOSTS)
print("Created working scripts!")
print("\n✅ Test locally first:")
print("   ./run_mlx_local.sh 4")
print("\n✅ Then run distributed:")
print("   ./run_mlx_distributed.sh")
print(
    f"   # This will run {DEFAULT_PROCS_PER_HOST} processes on each of your "
    f"{len(CLUSTER_HOSTS)} Macs ({total_procs} total)"
)

In [None]:
# Create comprehensive distributed test.
# The string below is written verbatim to test_mlx_comprehensive.py; it is meant
# to be started by the launcher scripts, not executed inside this kernel.
# Doubled backslashes ("\\n") become a literal \n escape in the generated file.
comprehensive_test = """
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_map
import socket
import time
import os

# Initialize distributed
world = mx.distributed.init()
rank = world.rank()
size = world.size()
hostname = socket.gethostname()
pid = os.getpid()

# Set GPU
mx.set_default_device(mx.gpu)

print(f"[Rank {rank}/{size}] Process {pid} on {hostname}")
print(f"[Rank {rank}] GPU: {mx.metal.is_available()}")
print(f"[Rank {rank}] Device: {mx.default_device()}")


def param_total(module):
    # model.parameters() is a nested dict, so flatten it into (name, array)
    # pairs before summing every parameter value.
    return sum(p.sum().item() for _, p in tree_flatten(module.parameters()))


# Synchronize before tests
mx.eval(mx.distributed.all_sum(mx.array([1.0])))

if rank == 0:
    print("\\n" + "="*50)
    print("Running MLX Distributed Tests")
    print("="*50)

# Test 1: Basic all-reduce
if rank == 0:
    print("\\n1. Testing all-reduce...")

local_value = mx.array([float(rank)])
sum_result = mx.distributed.all_sum(local_value)
mx.eval(sum_result)

if rank == 0:
    expected = sum(range(size))
    print(f"   All-reduce sum: {sum_result.item()} (expected: {expected})")
    print(f"   {'✓ PASSED' if abs(sum_result.item() - expected) < 0.001 else '✗ FAILED'}")

# Test 2: Model parameter synchronization
if rank == 0:
    print("\\n2. Testing model parameter sync...")

model = nn.Linear(100, 10)
mx.eval(model.parameters())

# Get initial param sum
param_sum_before = param_total(model)
print(f"[Rank {rank}] Initial param sum: {param_sum_before:.6f}")

# Synchronize parameters: replace every leaf with its average across ranks
model.update(tree_map(lambda p: mx.distributed.all_sum(p) / size, model.parameters()))

mx.eval(model.parameters())
param_sum_after = param_total(model)

# All ranks should have same param sum now
all_sums = mx.distributed.all_sum(mx.array([param_sum_after]))
mx.eval(all_sums)

if rank == 0:
    # Compare with a tolerance: the value round-trips through float32
    expected_total = param_sum_after * size
    tolerance = 1e-3 * max(1.0, abs(expected_total))
    print(f"   Synchronized param sum: {param_sum_after:.6f}")
    print(f"   {'✓ PASSED' if abs(all_sums.item() - expected_total) <= tolerance else '✗ FAILED'}")

# Test 3: Bandwidth test
if rank == 0:
    print("\\n3. Testing bandwidth...")

size_mb = 10
data = mx.random.uniform(shape=(size_mb * 1024 * 1024 // 4,))
mx.eval(data)  # materialise the payload so only the all-reduce is timed

start = time.perf_counter()
result = mx.distributed.all_sum(data)
mx.eval(result)
elapsed = time.perf_counter() - start

bandwidth = size_mb * size / elapsed
if rank == 0:
    print(f"   Data size: {size_mb}MB per rank")
    print(f"   Time: {elapsed:.3f}s")
    print(f"   Bandwidth: {bandwidth:.1f} MB/s")

# Final status
mx.eval(mx.distributed.all_sum(mx.array([1.0])))  # Sync
if rank == 0:
    print("\\n" + "="*50)
    print("✓ All tests completed successfully!")
    print("="*50)
"""

with open('test_mlx_comprehensive.py', 'w') as f:
    f.write(comprehensive_test)

print("\n=== Setup Complete! ===")
print("\n🎉 MLX distributed is working correctly!")
print("\nNext steps:")
print("1. Test comprehensive script locally:")
print("   ./run_mlx_local.sh 4 test_mlx_comprehensive.py")
print("\n2. Deploy environment to mm1.local and mm2.local")
print("   (They need the same mlx-distributed conda environment)")
print("\n3. Run distributed across your cluster:")
print("   ./run_mlx_distributed.sh test_mlx_comprehensive.py")
print("\nThis will run 6 processes total (2 on each Mac)")