In [1]:
# # # Use the magic command without code block formatting
# %pip install mlx

In [2]:
%%bash
python -m ipykernel install --user --name mlx-distributed --display-name "MLX Distributed (arm64)"


Installed kernelspec mlx-distributed in /Users/zz/Library/Jupyter/kernels/mlx-distributed


In [3]:
%%bash
# Remove existing environment if it exists
conda env remove -n mlx-distributed -y 2>/dev/null || true

# Create fresh environment with Python 3.11 (optimal for MLX)
CONDA_SUBDIR=osx-arm64 conda create -n mlx-distributed python=3.11 -y

# Activate and configure for ARM64

conda activate mlx-distributed
conda config --env --set subdir osx-arm64

echo "Environment created successfully!"
conda info --envs | grep mlx-distributed

Channels:
 - defaults
Platform: osx-arm64
Collecting package metadata (repodata.json): done
Solving environment: done

## Package Plan ##

  environment location: /Users/zz/anaconda3/envs/mlx-distributed

  added / updated specs:
    - python=3.11


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    ca-certificates-2025.2.25  |       hca03da5_0         131 KB
    openssl-3.0.16             |       h02f6b3c_0         4.3 MB
    ------------------------------------------------------------
                                           Total:         4.4 MB

The following NEW packages will be INSTALLED:

  bzip2              pkgs/main/osx-arm64::bzip2-1.0.8-h80987f9_6 
  ca-certificates    pkgs/main/osx-arm64::ca-certificates-2025.2.25-hca03da5_0 
  expat              pkgs/main/osx-arm64::expat-2.7.1-h313beb8_0 
  libcxx             pkgs/main/osx-arm64::libcxx-17.0.6-he5c5206_4 
  libffi          


CondaError: Run 'conda init' before 'conda activate'



Environment created successfully!
mlx-distributed      * /Users/zz/anaconda3/envs/mlx-distributed


In [4]:
%%bash
# Activate environment
source ~/anaconda3/etc/profile.d/conda.sh
conda activate mlx-distributed

# Install OpenMPI via conda (not homebrew!)
conda install -c conda-forge openmpi -y

# Install mpi4py
conda install -c conda-forge mpi4py -y

# Install MLX and MLX-LM
pip install mlx mlx-lm

# Install additional utilities
pip install numpy jupyter ipykernel

# Add kernel to Jupyter
python -m ipykernel install --user --name mlx-distributed --display-name "MLX Distributed"

echo "Installation complete!"

Channels:
 - conda-forge
 - defaults
Platform: osx-arm64
Collecting package metadata (repodata.json): done
Solving environment: done

## Package Plan ##

  environment location: /Users/zz/anaconda3/envs/mlx-distributed

  added / updated specs:
    - openmpi


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    openmpi-4.1.3              |       h8b79891_4           9 KB  conda-forge
    ------------------------------------------------------------
                                           Total:           9 KB

The following NEW packages will be INSTALLED:

  mpi                conda-forge/osx-arm64::mpi-1.0-openmpi 
  openmpi            conda-forge/noarch::openmpi-4.1.3-h8b79891_4 

The following packages will be UPDATED:

  ca-certificates    pkgs/main/osx-arm64::ca-certificates-~ --> conda-forge/noarch::ca-certificates-2025.6.15-hbd8a1cb_0 
  openssl              pkgs/main::openssl-3.0.1

In [5]:
import sys
import platform
import subprocess

print("=== System Information ===")
print(f"Python: {sys.version}")
print(f"Platform: {platform.platform()}")
print(f"Architecture: {platform.machine()}")
print(f"Python executable: {sys.executable}")
print()

print("=== MLX Installation ===")
try:
    import mlx
    import mlx.core as mx
    print(f"✓ MLX version: {mlx.__version__}")
    print(f"✓ Metal available: {mx.metal.is_available()}")
    print(f"✓ Default device: {mx.default_device()}")
except Exception as e:
    print(f"✗ MLX error: {e}")
print()

print("=== MPI Installation ===")
try:
    from mpi4py import MPI
    print(f"✓ mpi4py version: {MPI.Get_version()}")
    print(f"✓ MPI vendor: {MPI.get_vendor()}")
    
    # Check MPI executable
    result = subprocess.run(['which', 'mpirun'], capture_output=True, text=True)
    print(f"✓ mpirun location: {result.stdout.strip()}")
    
    # Check MPI version - fix for f-string issue
    result = subprocess.run(['mpirun', '--version'], capture_output=True, text=True)
    first_line = result.stdout.strip().split('\n')[0]  # Move split outside f-string
    print(f"✓ MPI version: {first_line}")
except Exception as e:
    print(f"✗ MPI error: {e}")
print()

print("=== MLX-LM Installation ===")
try:
    import mlx_lm
    print("✓ mlx_lm installed successfully")
except Exception as e:
    print(f"✗ mlx_lm error: {e}")

=== System Information ===
Python: 3.11.13 (main, Jun  5 2025, 08:21:08) [Clang 14.0.6 ]
Platform: macOS-15.5-arm64-arm-64bit
Architecture: arm64
Python executable: /Users/zz/anaconda3/envs/mlx-distributed/bin/python

=== MLX Installation ===
✗ MLX error: module 'mlx' has no attribute '__version__'

=== MPI Installation ===
✗ MPI error: dlopen(/Users/zz/anaconda3/envs/mlx-distributed/lib/python3.11/site-packages/mpi4py/MPI.cpython-311-darwin.so, 0x0002): Library not loaded: @rpath/libmpi.40.dylib
  Referenced from: <A853210E-CFB8-34D9-8C29-289BC747DD98> /Users/zz/anaconda3/envs/mlx-distributed/lib/python3.11/site-packages/mpi4py/MPI.cpython-311-darwin.so
  Reason: tried: '/Users/zz/anaconda3/envs/mlx-distributed/lib/python3.11/site-packages/mpi4py/../../../libmpi.40.dylib' (no such file), '/Users/zz/anaconda3/envs/mlx-distributed/lib/python3.11/site-packages/mpi4py/../../../libmpi.40.dylib' (no such file), '/Users/zz/anaconda3/envs/mlx-distributed/bin/../lib/libmpi.40.dylib' (no such f

In [6]:
import mlx.core as mx
import time

# Set GPU as default device
mx.set_default_device(mx.gpu)

print("=== GPU Test ===")
print(f"Default device: {mx.default_device()}")
print(f"Metal available: {mx.metal.is_available()}")

# Create a large array to test GPU
size = 10000
print(f"\nCreating {size}x{size} matrix multiplication...")

# Time CPU vs GPU
start = time.time()
a = mx.random.uniform(shape=(size, size))
b = mx.random.uniform(shape=(size, size))
c = a @ b
mx.eval(c)  # Force evaluation
gpu_time = time.time() - start

print(f"GPU computation time: {gpu_time:.3f} seconds")
print(f"GPU memory used: {mx.metal.get_active_memory() / 1024**3:.2f} GB")
print(f"GPU memory cache: {mx.metal.get_cache_memory() / 1024**3:.2f} GB")

# Test small model loading
print("\n=== Testing Model Loading ===")
try:
    from mlx_lm import load
    model, tokenizer = load("mlx-community/Llama-3.2-1B-Instruct-4bit")
    print("✓ Model loaded successfully")
    
    # Quick inference test
    prompt = "Hello"
    inputs = tokenizer(prompt, return_tensors="np")
    print(f"✓ Tokenizer works: '{prompt}' -> {inputs['input_ids']}")
except Exception as e:
    print(f"✗ Model loading error: {e}")
    print("This is okay for now - we'll use a different model for distributed tests")

=== GPU Test ===
Default device: Device(gpu, 0)
Metal available: True

Creating 10000x10000 matrix multiplication...
GPU computation time: 0.306 seconds
GPU memory used: 1.12 GB
GPU memory cache: 0.00 GB

=== Testing Model Loading ===


mx.metal.get_active_memory is deprecated and will be removed in a future version. Use mx.get_active_memory instead.
mx.metal.get_cache_memory is deprecated and will be removed in a future version. Use mx.get_cache_memory instead.


Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]

✓ Model loaded successfully
✗ Model loading error: 'TokenizerWrapper' object is not callable
This is okay for now - we'll use a different model for distributed tests


In [7]:
import subprocess
import os

hosts = ["mm@mm1.local", "mm@mm2.local"]

print("=== Testing SSH Connectivity ===")
for host in hosts:
    print(f"\nTesting {host}...")
    
    # Test basic SSH
    result = subprocess.run(
        ["ssh", "-o", "BatchMode=yes", "-o", "ConnectTimeout=5", host, "echo 'SSH OK'"],
        capture_output=True, text=True
    )
    
    if result.returncode == 0:
        print(f"✓ SSH connection successful")
    else:
        print(f"✗ SSH connection failed: {result.stderr}")
        print(f"  Fix: Run 'ssh-copy-id {host}' in terminal")

# Create SSH config for faster connections
ssh_config = """
Host mm1.local
    User mm
    HostName mm1.local
    ForwardAgent yes
    ServerAliveInterval 60

Host mm2.local
    User mm
    HostName mm2.local
    ForwardAgent yes
    ServerAliveInterval 60

Host *
    AddKeysToAgent yes
    UseKeychain yes
    IdentityFile ~/.ssh/id_rsa
"""

print("\n=== Recommended SSH Config ===")
print("Add this to ~/.ssh/config:")
print(ssh_config)

=== Testing SSH Connectivity ===

Testing mm@mm1.local...
✓ SSH connection successful

Testing mm@mm2.local...
✓ SSH connection successful

=== Recommended SSH Config ===
Add this to ~/.ssh/config:

Host mm1.local
    User mm
    HostName mm1.local
    ForwardAgent yes
    ServerAliveInterval 60

Host mm2.local
    User mm
    HostName mm2.local
    ForwardAgent yes
    ServerAliveInterval 60

Host *
    AddKeysToAgent yes
    UseKeychain yes
    IdentityFile ~/.ssh/id_rsa



In [8]:
# Create a test MPI script
test_script = """
import os
import socket
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

hostname = socket.gethostname()
pid = os.getpid()

print(f"Rank {rank}/{size} on {hostname} (PID: {pid})")

# Test communication
if rank == 0:
    data = {'message': 'Hello from rank 0!'}
    comm.send(data, dest=1)
    print(f"Rank 0: Sent message to rank 1")
elif rank == 1:
    data = comm.recv(source=0)
    print(f"Rank 1: Received: {data['message']}")

comm.Barrier()
if rank == 0:
    print("All processes synchronized successfully!")
"""

# Write test script
with open('test_mpi_local.py', 'w') as f:
    f.write(test_script)

print("=== Testing Local MPI (2 processes) ===")
result = subprocess.run(
    ['mpirun', '-np', '2', sys.executable, 'test_mpi_local.py'],
    capture_output=True, text=True
)

print("Output:")
print(result.stdout)
if result.stderr:
    print("Errors:")
    print(result.stderr)

# Clean up
os.remove('test_mpi_local.py')

=== Testing Local MPI (2 processes) ===
Output:

Errors:
Traceback (most recent call last):
  File "/Users/zz/Documents/GitHub/mlx-dist-setup/test_mpi_local.py", line 4, in <module>
Traceback (most recent call last):
  File "/Users/zz/Documents/GitHub/mlx-dist-setup/test_mpi_local.py", line 4, in <module>
    from mpi4py import MPI
    from mpi4py import MPI
ImportError: dlopen(/Users/zz/anaconda3/envs/mlx-distributed/lib/python3.11/site-packages/mpi4py/MPI.cpython-311-darwin.so, 0x0002): Library not loaded: @rpath/libmpi.40.dylib
  Referenced from: <A853210E-CFB8-34D9-8C29-289BC747DD98> /Users/zz/anaconda3/envs/mlx-distributed/lib/python3.11/site-packages/mpi4py/MPI.cpython-311-darwin.so
  Reason: tried: '/Users/zz/anaconda3/envs/mlx-distributed/lib/python3.11/site-packages/mpi4py/../../../libmpi.40.dylib' (no such file), '/Users/zz/anaconda3/envs/mlx-distributed/lib/python3.11/site-packages/mpi4py/../../../libmpi.40.dylib' (no such file), '/Users/zz/anaconda3/envs/mlx-distributed/bin

In [9]:
import sys

# Check if mpi4py is installed in current environment
try:
    import mpi4py
    print(f"✓ mpi4py is installed in current Python: {mpi4py.__file__}")
    print(f"  mpi4py version: {mpi4py.__version__}")
except ImportError:
    print("✗ mpi4py not found in current Python")

# Check which Python we're using
print(f"\nCurrent Python: {sys.executable}")
print(f"Python version: {sys.version}")

# Better way to check installed packages
try:
    import pkg_resources
    installed_packages = [d.project_name for d in pkg_resources.working_set]
    if 'mpi4py' in installed_packages:
        version = pkg_resources.get_distribution('mpi4py').version
        print(f"\n✓ mpi4py {version} is installed via pip")
    else:
        print("\n✗ mpi4py not found in pip packages")
except:
    # Alternative method
    import importlib.metadata
    try:
        version = importlib.metadata.version('mpi4py')
        print(f"\n✓ mpi4py {version} is installed")
    except:
        print("\n✗ mpi4py not installed")

# Check conda list instead
import subprocess
result = subprocess.run(['conda', 'list', 'mpi4py'], capture_output=True, text=True)
print(f"\nConda list output:\n{result.stdout}")

✓ mpi4py is installed in current Python: /Users/zz/anaconda3/envs/mlx-distributed/lib/python3.11/site-packages/mpi4py/__init__.py
  mpi4py version: 3.1.4

Current Python: /Users/zz/anaconda3/envs/mlx-distributed/bin/python
Python version: 3.11.13 (main, Jun  5 2025, 08:21:08) [Clang 14.0.6 ]

✓ mpi4py 3.1.4 is installed via pip


  import pkg_resources



Conda list output:
# packages in environment at /Users/zz/anaconda3/envs/mlx-distributed:
#
# Name                     Version          Build            Channel
mpi4py                     3.1.4            py311he4f2fd2_0



In [10]:
import os
import sys
import subprocess

# Use the -m flag to ensure Python uses the right module path
test_code = '''
import sys
print(f"Python: {sys.executable}")
try:
    import mpi4py
    print(f"mpi4py location: {mpi4py.__file__}")
    from mpi4py import MPI
    print(f"Rank {MPI.COMM_WORLD.rank}: Success!")
except Exception as e:
    print(f"Error: {e}")
'''

# Save test script
with open('simple_mpi_test.py', 'w') as f:
    f.write(test_code)

# Run with python -m mpi4py
print("=== Running with python -m mpi4py ===")
result = subprocess.run(
    [sys.executable, '-m', 'mpi4py', '-n', '2', 'simple_mpi_test.py'],
    capture_output=True, text=True
)

print("Output:", result.stdout)
if result.stderr:
    print("Errors:", result.stderr)

os.remove('simple_mpi_test.py')

=== Running with python -m mpi4py ===
Output: 
Errors: Unknown option: -n
usage: python -m mpi4py [options] <pyfile> [arg] ...
   or: python -m mpi4py [options] -m <mod> [arg] ...
   or: python -m mpi4py [options] -c <cmd> [arg] ...
   or: python -m mpi4py [options] - [arg] ...
Try `python -m mpi4py -h` for more information.



In [12]:
import subprocess
import sys
import os

print("=== Solution: Use Homebrew MPI ===")

# First, uninstall the broken mpi4py
print("1. Removing broken mpi4py...")
subprocess.run([sys.executable, '-m', 'pip', 'uninstall', 'mpi4py', '-y'])

# Install mpi4py compiled against Homebrew's MPI
print("\n2. Installing mpi4py with Homebrew MPI...")
env = os.environ.copy()
env['MPICC'] = '/opt/homebrew/bin/mpicc'
env['CC'] = '/opt/homebrew/bin/mpicc'

result = subprocess.run(
    [sys.executable, '-m', 'pip', 'install', 'mpi4py', '--no-cache-dir', '--no-binary', 'mpi4py'],
    capture_output=True, text=True, env=env
)

if result.returncode == 0:
    print("✓ mpi4py installed successfully")
else:
    print(f"Installation output: {result.stdout}")
    print(f"Errors: {result.stderr}")

# Test the installation
print("\n3. Testing MPI...")
test_script = """
import sys
print(f"Python: {sys.executable}")

from mpi4py import MPI
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

print(f"Rank {rank}/{size}: MPI is working!")

if rank == 0 and size > 1:
    comm.send("Hello from rank 0", dest=1)
elif rank == 1:
    msg = comm.recv(source=0)
    print(f"Rank 1 received: {msg}")
"""

with open('test_mpi_final.py', 'w') as f:
    f.write(test_script)

# Run with Homebrew's mpirun
result = subprocess.run(
    ['/opt/homebrew/bin/mpirun', '-np', '2', sys.executable, 'test_mpi_final.py'],
    capture_output=True, text=True
)

print("\nOutput:")
print(result.stdout)
if result.stderr:
    print("Errors:", result.stderr)

os.remove('test_mpi_final.py')

=== Solution: Use Homebrew MPI ===
1. Removing broken mpi4py...
Found existing installation: mpi4py 3.1.4
Uninstalling mpi4py-3.1.4:
  Successfully uninstalled mpi4py-3.1.4

2. Installing mpi4py with Homebrew MPI...
✓ mpi4py installed successfully

3. Testing MPI...

Output:
Python: /Users/zz/anaconda3/envs/mlx-distributed/bin/python
Python: /Users/zz/anaconda3/envs/mlx-distributed/bin/python
Rank 1/2: MPI is working!
Rank 0/2: MPI is working!
Rank 1 received: Hello from rank 0



In [13]:
# Create configuration for using Homebrew MPI
config_content = f"""#!/bin/bash
# MLX Distributed Configuration

# Use Homebrew MPI
export PATH="/opt/homebrew/bin:$PATH"
export MPICC=/opt/homebrew/bin/mpicc
export MPIRUN=/opt/homebrew/bin/mpirun

# Python from conda environment
export PYTHON={sys.executable}

# Function to run distributed MLX
run_mlx_dist() {{
    /opt/homebrew/bin/mpirun "$@"
}}

echo "MLX Distributed configured with:"
echo "  MPI: Homebrew OpenMPI 5.0.7"
echo "  Python: Conda environment (mlx-distributed)"
echo ""
echo "Usage: run_mlx_dist -np 4 python your_script.py"
"""

with open('mlx_dist_config.sh', 'w') as f:
    f.write(config_content)

os.chmod('mlx_dist_config.sh', 0o755)

print("\n=== Configuration Created ===")
print("Source this before running distributed jobs:")
print("  source mlx_dist_config.sh")


=== Configuration Created ===
Source this before running distributed jobs:
  source mlx_dist_config.sh


In [14]:
import os
import sys
import subprocess

print("=== Testing MLX Distributed with Working MPI ===")

# Create MLX distributed test
mlx_test = """
import mlx.core as mx
import socket

# Initialize distributed
world = mx.distributed.init()
rank = world.rank()
size = world.size()
hostname = socket.gethostname()

# Set GPU device
mx.set_default_device(mx.gpu)

print(f"Rank {rank}/{size} on {hostname}")
print(f"  Python: {mx.__file__}")
print(f"  Device: {mx.default_device()}")
print(f"  Metal available: {mx.metal.is_available()}")

# Test distributed operation
local_value = mx.array([float(rank)])
sum_value = mx.distributed.all_sum(local_value)
mx.eval(sum_value)

if rank == 0:
    print(f"\\nDistributed sum: {sum_value.item()} (expected: {sum(range(size))})")
    print("✓ MLX distributed test passed!")
"""

with open('test_mlx_distributed.py', 'w') as f:
    f.write(mlx_test)

# Run with Homebrew mpirun
result = subprocess.run(
    ['/opt/homebrew/bin/mpirun', '-np', '2', sys.executable, 'test_mlx_distributed.py'],
    capture_output=True, text=True
)

print("Output:")
print(result.stdout)
if result.stderr:
    print("\nErrors:")
    print(result.stderr)

os.remove('test_mlx_distributed.py')

=== Testing MLX Distributed with Working MPI ===
Output:
Rank 0/1 on mbp
  Python: /Users/zz/anaconda3/envs/mlx-distributed/lib/python3.11/site-packages/mlx/core.cpython-311-darwin.so
  Device: Device(gpu, 0)
  Metal available: TrueRank 0/1 on mbp
  Python: /Users/zz/anaconda3/envs/mlx-distributed/lib/python3.11/site-packages/mlx/core.cpython-311-darwin.so
  Device: Device(gpu, 0)
  Metal available: True


Distributed sum: 0.0 (expected: 0)
✓ MLX distributed test passed!
Distributed sum: 0.0 (expected: 0)
✓ MLX distributed test passed!




In [15]:
import os
import sys
import subprocess

print("=== Testing MLX Distributed with MPI Backend ===")

# Create a test that explicitly uses MPI backend
mlx_mpi_test = """
import os
import mlx.core as mx
import socket

# IMPORTANT: Set environment variable before importing
os.environ['MLX_DISTRIBUTED_BACKEND'] = 'mpi'

# Initialize distributed with explicit backend
world = mx.distributed.init(backend='mpi')
rank = world.rank()
size = world.size()
hostname = socket.gethostname()

print(f"Rank {rank}/{size} on {hostname}")
print(f"  Device: {mx.default_device()}")

# Set GPU after init
mx.set_default_device(mx.gpu)
print(f"  GPU active: {mx.metal.is_available()}")

# Test distributed operation
local_value = mx.array([float(rank)])
print(f"  Local value: {local_value.item()}")

# All-reduce sum
sum_value = mx.distributed.all_sum(local_value)
mx.eval(sum_value)

print(f"  Sum result: {sum_value.item()}")

if rank == 0:
    expected = sum(range(size))
    print(f"\\nFinal: Distributed sum = {sum_value.item()} (expected: {expected})")
    if abs(sum_value.item() - expected) < 0.001:
        print("✓ MLX distributed test PASSED!")
    else:
        print("✗ MLX distributed test FAILED!")
"""

with open('test_mlx_mpi.py', 'w') as f:
    f.write(mlx_mpi_test)

# Run with environment variable
env = os.environ.copy()
env['MLX_DISTRIBUTED_BACKEND'] = 'mpi'

result = subprocess.run(
    ['/opt/homebrew/bin/mpirun', '-np', '2', sys.executable, 'test_mlx_mpi.py'],
    capture_output=True, text=True, env=env
)

print("Output:")
print(result.stdout)
if result.stderr:
    print("\nErrors:")
    print(result.stderr)

os.remove('test_mlx_mpi.py')

=== Testing MLX Distributed with MPI Backend ===
Output:
Rank 0/1 on mbp
  Device: Device(gpu, 0)
  GPU active: TrueRank 0/1 on mbp
  Device: Device(gpu, 0)
  GPU active: True
  Local value: 0.0

  Local value: 0.0
  Sum result: 0.0
  Sum result: 0.0

Final: Distributed sum = 0.0 (expected: 0)

Final: Distributed sum = 0.0 (expected: 0)
✓ MLX distributed test PASSED!✓ MLX distributed test PASSED!




In [18]:
import os
import sys
import subprocess

print("=== Testing with mlx.launch ===")

# Create a proper MLX distributed test
mlx_test = """
import mlx.core as mx
import socket

# Initialize distributed - mlx.launch will handle the backend
world = mx.distributed.init()
rank = world.rank()
size = world.size()
hostname = socket.gethostname()

# Set GPU
mx.set_default_device(mx.gpu)

print(f"Rank {rank}/{size} on {hostname}")
print(f"  GPU: {mx.metal.is_available()}")

# Test distributed operation
local_value = mx.array([float(rank)])
sum_value = mx.distributed.all_sum(local_value)
mx.eval(sum_value)

if rank == 0:
    print(f"\\nDistributed sum: {sum_value.item()} (expected: {sum(range(size))})")
    print("✓ MLX distributed works correctly!")
"""

with open('test_mlx_launch.py', 'w') as f:
    f.write(mlx_test)

# Test locally first
print("1. Testing locally with 2 processes:")
result = subprocess.run(
    [sys.executable, '-m', 'mlx.launch', '--np', '2', 'test_mlx_launch.py'],
    capture_output=True, text=True
)

print(result.stdout)
if result.stderr:
    print("Errors:", result.stderr)

# Create test for multiple hosts
print("\n2. Testing with multiple hosts:")

# Create hostfile
hosts = "mbp.local,mm1.local,mm2.local"
print(f"Hosts: {hosts}")

# Command for distributed run
cmd = [sys.executable, '-m', 'mlx.launch', '--hosts', hosts, 'test_mlx_launch.py']
print(f"\nCommand: {' '.join(cmd)}")
print("\nRun this command to test distributed:")
print(' '.join(cmd))

os.remove('test_mlx_launch.py')

=== Testing with mlx.launch ===
1. Testing locally with 2 processes:

Errors: /Users/zz/anaconda3/envs/mlx-distributed/bin/python: No module named mlx.launch


2. Testing with multiple hosts:
Hosts: mbp.local,mm1.local,mm2.local

Command: /Users/zz/anaconda3/envs/mlx-distributed/bin/python -m mlx.launch --hosts mbp.local,mm1.local,mm2.local test_mlx_launch.py

Run this command to test distributed:
/Users/zz/anaconda3/envs/mlx-distributed/bin/python -m mlx.launch --hosts mbp.local,mm1.local,mm2.local test_mlx_launch.py


In [19]:
import os
import subprocess
import sys

print("=== Checking for MLX launch tools ===")

# Check if mlx.launch exists as a command
result = subprocess.run(['which', 'mlx.launch'], capture_output=True, text=True)
if result.stdout:
    print(f"Found mlx.launch command at: {result.stdout.strip()}")
else:
    print("mlx.launch command not found in PATH")

# Check MLX installation for launch capabilities
print("\n=== Checking MLX modules ===")
try:
    import mlx
    print(f"MLX location: {mlx.__file__}")
    
    # List MLX submodules
    mlx_dir = os.path.dirname(mlx.__file__)
    print(f"\nMLX modules in {mlx_dir}:")
    for item in os.listdir(mlx_dir):
        if not item.startswith('_') and (item.endswith('.py') or os.path.isdir(os.path.join(mlx_dir, item))):
            print(f"  {item}")
except Exception as e:
    print(f"Error: {e}")

# Check for mlx_lm.launch which was mentioned earlier
print("\n=== Checking mlx_lm.launch ===")
try:
    result = subprocess.run(
        [sys.executable, '-m', 'mlx_lm.launch', '--help'],
        capture_output=True, text=True
    )
    if result.returncode == 0:
        print("✓ mlx_lm.launch is available!")
        print("Usage:", result.stdout.split('\n')[0])
    else:
        print("mlx_lm.launch error:", result.stderr)
except Exception as e:
    print(f"Error: {e}")

=== Checking for MLX launch tools ===
Found mlx.launch command at: /Users/zz/anaconda3/envs/mlx-distributed/bin/mlx.launch

=== Checking MLX modules ===
MLX location: None
Error: expected str, bytes or os.PathLike object, not NoneType

=== Checking mlx_lm.launch ===
mlx_lm.launch error: /Users/zz/anaconda3/envs/mlx-distributed/bin/python: No module named mlx_lm.launch



In [26]:
import subprocess
import os

# Use the actual mlx.launch command
mlx_launch = '/Users/zz/anaconda3/envs/mlx-distributed/bin/mlx.launch'

print("=== Testing with Official mlx.launch ===")

# First, let's see what options mlx.launch supports
print("1. Getting mlx.launch options:")
result = subprocess.run([mlx_launch, '--help'], capture_output=True, text=True)
print(result.stdout if result.stdout else result.stderr)

# Create a proper test script
test_script = """
import mlx.core as mx
import socket

# Initialize distributed
world = mx.distributed.init()
rank = world.rank()
size = world.size()
hostname = socket.gethostname()

print(f"Rank {rank}/{size} on {hostname}")

# Set GPU
mx.set_default_device(mx.gpu)
print(f"  GPU: {mx.metal.is_available()}")
print(f"  Device: {mx.default_device()}")

# Test distributed computation
if size > 1:
    local_value = mx.array([float(rank)])
    print(f"  Local value: {local_value.item()}")
    
    # All-reduce sum
    sum_value = mx.distributed.all_sum(local_value)
    mx.eval(sum_value)
    
    if rank == 0:
        expected = sum(range(size))
        print(f"\\nAll-reduce sum: {sum_value.item()} (expected: {expected})")
        success = abs(sum_value.item() - expected) < 0.001
        print(f"{'✓' if success else '✗'} Test {'PASSED' if success else 'FAILED'}!")
else:
    print("\\n⚠️  Only 1 process - need multiple processes to test distributed ops")
"""

with open('test_mlx_dist.py', 'w') as f:
    f.write(test_script)

# Test with mlx.launch
print("\n2. Running with mlx.launch (2 processes):")
cmd = [mlx_launch, '--np', '2', 'test_mlx_dist.py']
print(f"Command: {' '.join(cmd)}")
print("-" * 50)

result = subprocess.run(cmd, capture_output=True, text=True)
print(result.stdout)
if result.stderr:
    print("Errors:", result.stderr)

=== Testing with Official mlx.launch ===
1. Getting mlx.launch options:
usage: mlx.launch [-h] [--print-python] [--verbose] [--hosts HOSTS]
                  [--repeat-hosts REPEAT_HOSTS] [--hostfile HOSTFILE]
                  [--backend {ring,mpi}] [--env ENV] [--mpi-arg MPI_ARG]
                  [--connections-per-ip CONNECTIONS_PER_IP]
                  [--starting-port STARTING_PORT] [--cwd CWD]

Launch an MLX distributed program

options:
  -h, --help            show this help message and exit
  --print-python        Print the path to the current python executable and
                        exit
  --verbose             Print debug messages in stdout
  --hosts HOSTS         A comma separated list of hosts
  --repeat-hosts REPEAT_HOSTS, -n REPEAT_HOSTS
                        Repeat each host a given number of times
  --hostfile HOSTFILE   The file containing the hosts
  --backend {ring,mpi}  Which distributed backend to launch
  --env ENV             Set environment variables fo

In [27]:
print("\n=== Trying Different mlx.launch Syntax ===")

# Common MPI launcher syntaxes to try
test_commands = [
    [mlx_launch, '-n', '2', 'test_mlx_dist.py'],
    [mlx_launch, '-np', '2', 'test_mlx_dist.py'], 
    [mlx_launch, '--nproc', '2', 'test_mlx_dist.py'],
    [mlx_launch, '--backend', 'mpi', '-np', '2', 'test_mlx_dist.py'],
]

for i, cmd in enumerate(test_commands):
    print(f"\nTrying command {i+1}: {' '.join(cmd)}")
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode == 0 and "Rank 1/2" in result.stdout:
        print("✓ This syntax works!")
        print(result.stdout)
        break
    elif result.returncode != 0:
        print(f"✗ Failed: {result.stderr.strip()[:100]}")


=== Trying Different mlx.launch Syntax ===

Trying command 1: /Users/zz/anaconda3/envs/mlx-distributed/bin/mlx.launch -n 2 test_mlx_dist.py
✓ This syntax works!
Rank 1/2 on mbp
  GPU: True
Rank 0/2 on mbp
  GPU: True
  Device: Device(gpu, 0)
  Device: Device(gpu, 0)
  Local value: 1.0
  Local value: 0.0

All-reduce sum: 1.0 (expected: 1)
✓ Test PASSED!



In [29]:
import subprocess
import os

mlx_launch = '/Users/zz/anaconda3/envs/mlx-distributed/bin/mlx.launch'

print("=== Using Correct mlx.launch Syntax ===")

# For local runs, we need to specify localhost and repeat it
print("1. Testing local with 2 processes:")
cmd = [mlx_launch, '--hosts', 'localhost', '-n', '2', 'test_mlx_dist.py']
print(f"Command: {' '.join(cmd)}")
print("-" * 50)

result = subprocess.run(cmd, capture_output=True, text=True)
print(result.stdout)
if result.stderr:
    print("Errors:", result.stderr)

# Try with explicit backend
print("\n2. Testing with MPI backend explicitly:")
cmd = [mlx_launch, '--backend', 'mpi', '--hosts', 'localhost', '-n', '2', 'test_mlx_dist.py']
print(f"Command: {' '.join(cmd)}")
print("-" * 50)

result = subprocess.run(cmd, capture_output=True, text=True)
print(result.stdout)
if result.stderr:
    print("Errors:", result.stderr)

=== Using Correct mlx.launch Syntax ===
1. Testing local with 2 processes:
Command: /Users/zz/anaconda3/envs/mlx-distributed/bin/mlx.launch --hosts localhost -n 2 test_mlx_dist.py
--------------------------------------------------

Errors: usage: mlx.launch [-h] [--print-python] [--verbose] [--hosts HOSTS]
                  [--repeat-hosts REPEAT_HOSTS] [--hostfile HOSTFILE]
                  [--backend {ring,mpi}] [--env ENV] [--mpi-arg MPI_ARG]
                  [--connections-per-ip CONNECTIONS_PER_IP]
                  [--starting-port STARTING_PORT] [--cwd CWD]
mlx.launch: error: The ring backend requires IPs to be provided instead of hostnames


2. Testing with MPI backend explicitly:
Command: /Users/zz/anaconda3/envs/mlx-distributed/bin/mlx.launch --backend mpi --hosts localhost -n 2 test_mlx_dist.py
--------------------------------------------------
Rank 0/2 on mbp
Rank 1/2 on mbp
  GPU: True
  Device: Device(gpu, 0)
  GPU: True
  Device: Device(gpu, 0)
  Local value: 0.0
  Loc

In [30]:
print("\n=== Creating Correct Run Scripts ===")

# Local run script
local_script = f"""#!/bin/bash
# Run MLX distributed locally

# Number of processes (default 2)
NP="${{1:-2}}"
SCRIPT="${{2:-test_mlx_dist.py}}"

echo "Running MLX with $NP local processes..."
echo "Script: $SCRIPT"
echo ""

# Use localhost repeated NP times
{mlx_launch} --hosts localhost -n "$NP" "$SCRIPT"
"""

with open('run_mlx_local.sh', 'w') as f:
    f.write(local_script)
os.chmod('run_mlx_local.sh', 0o755)

# Distributed run script for your cluster
distributed_script = f"""#!/bin/bash
# Run MLX distributed across your cluster

SCRIPT="${{1:-test_mlx_dist.py}}"

# Option 1: Using comma-separated hosts with repetition
echo "=== Running distributed with 2 processes per host ==="
{mlx_launch} --hosts mbp.local,mm1.local,mm2.local -n 2 "$SCRIPT"
"""

with open('run_mlx_cluster.sh', 'w') as f:
    f.write(distributed_script)
os.chmod('run_mlx_cluster.sh', 0o755)

# Create a hostfile version
hostfile_content = """mbp.local
mbp.local
mm1.local
mm1.local
mm2.local
mm2.local
"""

with open('mlx_hostfile.txt', 'w') as f:
    f.write(hostfile_content)

hostfile_script = f"""#!/bin/bash
# Run MLX using hostfile

SCRIPT="${{1:-test_mlx_dist.py}}"
HOSTFILE="${{2:-mlx_hostfile.txt}}"

echo "Running MLX with hostfile: $HOSTFILE"
cat "$HOSTFILE"
echo ""

{mlx_launch} --hostfile "$HOSTFILE" "$SCRIPT"
"""

with open('run_mlx_hostfile.sh', 'w') as f:
    f.write(hostfile_script)
os.chmod('run_mlx_hostfile.sh', 0o755)

print("Created scripts:")
print("1. ./run_mlx_local.sh [processes]    - Run locally")
print("2. ./run_mlx_cluster.sh              - Run on cluster (2 per host)")
print("3. ./run_mlx_hostfile.sh             - Run with hostfile")


=== Creating Correct Run Scripts ===
Created scripts:
1. ./run_mlx_local.sh [processes]    - Run locally
2. ./run_mlx_cluster.sh              - Run on cluster (2 per host)
3. ./run_mlx_hostfile.sh             - Run with hostfile


In [32]:
import subprocess
import os

mlx_launch = '/Users/zz/anaconda3/envs/mlx-distributed/bin/mlx.launch'

print("=== Creating Final Working Scripts ===")

# Local run script with MPI backend
local_script = f"""#!/bin/bash
# Run MLX distributed locally with MPI backend

NP="${{1:-2}}"
SCRIPT="${{2:-test_mlx_dist.py}}"

echo "Running MLX locally with $NP processes (MPI backend)..."
echo "Script: $SCRIPT"
echo ""

{mlx_launch} --backend mpi --hosts localhost -n "$NP" "$SCRIPT"
"""

with open('run_mlx_local.sh', 'w') as f:
    f.write(local_script)
os.chmod('run_mlx_local.sh', 0o755)

# Distributed run script for your cluster
distributed_script = f"""#!/bin/bash
# Run MLX distributed across your Mac cluster

SCRIPT="${{1:-test_mlx_dist.py}}"
PROCESSES_PER_HOST="${{2:-2}}"

echo "Running MLX distributed (MPI backend)"
echo "Hosts: mbp.local, mm1.local, mm2.local"
echo "Processes per host: $PROCESSES_PER_HOST"
echo "Script: $SCRIPT"
echo ""

{mlx_launch} --backend mpi \\
    --hosts mbp.local,mm1.local,mm2.local \\
    -n "$PROCESSES_PER_HOST" \\
    "$SCRIPT"
"""

with open('run_mlx_distributed.sh', 'w') as f:
    f.write(distributed_script)
os.chmod('run_mlx_distributed.sh', 0o755)

# Create hostfile for MPI backend
hostfile_content = """mbp.local
mbp.local
mm1.local
mm1.local
mm2.local
mm2.local
"""

with open('mlx_hostfile.txt', 'w') as f:
    f.write(hostfile_content)

# Hostfile version
hostfile_script = f"""#!/bin/bash
# Run MLX using hostfile (MPI backend)

SCRIPT="${{1:-test_mlx_dist.py}}"
HOSTFILE="${{2:-mlx_hostfile.txt}}"

echo "Running MLX with hostfile (MPI backend)"
echo "Hostfile: $HOSTFILE"
echo "Script: $SCRIPT"
echo ""

{mlx_launch} --backend mpi --hostfile "$HOSTFILE" "$SCRIPT"
"""

with open('run_mlx_hostfile.sh', 'w') as f:
    f.write(hostfile_script)
os.chmod('run_mlx_hostfile.sh', 0o755)

print("Created working scripts!")
print("\n✅ Test locally first:")
print("   ./run_mlx_local.sh 4")
print("\n✅ Then run distributed:")
print("   ./run_mlx_distributed.sh")
print("   # This will run 2 processes on each of your 3 Macs (6 total)")

=== Creating Final Working Scripts ===
Created working scripts!

✅ Test locally first:
   ./run_mlx_local.sh 4

✅ Then run distributed:
   ./run_mlx_distributed.sh
   # This will run 2 processes on each of your 3 Macs (6 total)


In [33]:
# Create comprehensive distributed test
comprehensive_test = """
import mlx.core as mx
import mlx.nn as nn
import socket
import time
import os

# Initialize distributed
world = mx.distributed.init()
rank = world.rank()
size = world.size()
hostname = socket.gethostname()
pid = os.getpid()

# Set GPU
mx.set_default_device(mx.gpu)

print(f"[Rank {rank}/{size}] Process {pid} on {hostname}")
print(f"[Rank {rank}] GPU: {mx.metal.is_available()}")
print(f"[Rank {rank}] Device: {mx.default_device()}")

# Synchronize before tests
mx.eval(mx.distributed.all_sum(mx.array([1.0])))

if rank == 0:
    print("\\n" + "="*50)
    print("Running MLX Distributed Tests")
    print("="*50)

# Test 1: Basic all-reduce
if rank == 0:
    print("\\n1. Testing all-reduce...")
    
local_value = mx.array([float(rank)])
sum_result = mx.distributed.all_sum(local_value)
mx.eval(sum_result)

if rank == 0:
    expected = sum(range(size))
    print(f"   All-reduce sum: {sum_result.item()} (expected: {expected})")
    print(f"   {'✓ PASSED' if abs(sum_result.item() - expected) < 0.001 else '✗ FAILED'}")

# Test 2: Model parameter synchronization
if rank == 0:
    print("\\n2. Testing model parameter sync...")

model = nn.Linear(100, 10)
mx.eval(model.parameters())

# Get initial param sum
param_sum_before = sum(p.sum().item() for _, p in model.parameters())
print(f"[Rank {rank}] Initial param sum: {param_sum_before:.6f}")

# Synchronize parameters
for _, p in model.parameters():
    p_synced = mx.distributed.all_sum(p) / size
    p[:] = p_synced

mx.eval(model.parameters())
param_sum_after = sum(p.sum().item() for _, p in model.parameters())

# All ranks should have same param sum now
all_sums = mx.distributed.all_sum(mx.array([param_sum_after]))
mx.eval(all_sums)

if rank == 0:
    print(f"   Synchronized param sum: {param_sum_after:.6f}")
    print(f"   {'✓ PASSED' if all_sums.item() == param_sum_after * size else '✗ FAILED'}")

# Test 3: Bandwidth test
if rank == 0:
    print("\\n3. Testing bandwidth...")

size_mb = 10
data = mx.random.uniform(shape=(size_mb * 1024 * 1024 // 4,))

start = time.time()
result = mx.distributed.all_sum(data)
mx.eval(result)
elapsed = time.time() - start

bandwidth = size_mb * size / elapsed
if rank == 0:
    print(f"   Data size: {size_mb}MB per rank")
    print(f"   Time: {elapsed:.3f}s")
    print(f"   Bandwidth: {bandwidth:.1f} MB/s")

# Final status
mx.eval(mx.distributed.all_sum(mx.array([1.0])))  # Sync
if rank == 0:
    print("\\n" + "="*50)
    print("✓ All tests completed successfully!")
    print("="*50)
"""

with open('test_mlx_comprehensive.py', 'w') as f:
    f.write(comprehensive_test)

print("\n=== Setup Complete! ===")
print("\n🎉 MLX distributed is working correctly!")
print("\nNext steps:")
print("1. Test comprehensive script locally:")
print("   ./run_mlx_local.sh 4 test_mlx_comprehensive.py")
print("\n2. Deploy environment to mm1.local and mm2.local")
print("   (They need the same mlx-distributed conda environment)")
print("\n3. Run distributed across your cluster:")
print("   ./run_mlx_distributed.sh test_mlx_comprehensive.py")
print("\nThis will run 6 processes total (2 on each Mac)")


=== Setup Complete! ===

🎉 MLX distributed is working correctly!

Next steps:
1. Test comprehensive script locally:
   ./run_mlx_local.sh 4 test_mlx_comprehensive.py

2. Deploy environment to mm1.local and mm2.local
   (They need the same mlx-distributed conda environment)

3. Run distributed across your cluster:
   ./run_mlx_distributed.sh test_mlx_comprehensive.py

This will run 6 processes total (2 on each Mac)
