In [1]:
import coiled
import os
from typing import List, Dict
from distributed.client import Client
import subprocess

# coiled.create_software_environment(
#     name="pytorch-llm",
#     conda="pytorch-llm.yaml",
#     gpu_enabled=True,
#     account = "nathan-ballou-gcp"
# )

In [3]:
# Number of workers to request; reused below when waiting for the cluster.
n_workers = 4

# Create the Coiled cluster used for the rest of the notebook.
# The commented-out keyword lines are alternative VM types / settings kept for reference.
cluster = coiled.Cluster(
    n_workers=n_workers,
    worker_vm_types="g2-standard-4",
    # worker_vm_types="a2-highgpu-1g",
    # worker_vm_types="a2-ultragpu-1g",
    # scheduler_vm_types="g2-standard-4",
    account="nathan-ballou-gcp",
    software="pytorch-llm",
    worker_disk_size=256,  # NOTE(review): unit presumably GiB — confirm against the Coiled docs
    idle_timeout = "2 hours",
    # region = "us-west1"
)

client = cluster.get_client()
# Bare expression in the middle of a cell: only a cell's last expression is
# displayed, so this line produces no output here.
client

# Wait until the requested number of workers has joined before continuing.
client.wait_for_workers(n_workers=n_workers)

INFO:coiled:Creating Cluster (name: nathan-ballou-gcp-4d454d0f, https://cloud.coiled.io/clusters/361325?account=nathan-ballou-gcp ). This usually takes 1-2 minutes...


In [4]:
def _get_worker_info(client: "Client") -> List[Dict]:
    """
    Describe every worker known to the scheduler and assign each a global rank.

    Workers are grouped by host. Hosts are visited in sorted (string) order and
    the workers of each host in sorted address order; ranks are handed out
    sequentially from 0 in that order. Only the list is returned: callers that
    need the master host read the ``host`` of the rank-0 entry.

    Args:
        client: Object exposing ``scheduler_info()`` whose result contains a
            ``"workers"`` mapping of worker address -> info dict. Each info
            dict must have a ``"host"`` key and may have ``"local_directory"``.

    Returns:
        A list of dicts with the keys ``worker``, ``global_rank``, ``host`` and
        ``local_directory`` (``None`` when the worker does not report one).
        The list is empty when the scheduler reports no workers.
    """
    workers = client.scheduler_info()["workers"]

    # Group worker addresses by host; iterating the addresses in sorted order
    # keeps each per-host list sorted as well.
    workers_by_host: Dict[str, List[str]] = {}
    for key in sorted(workers):
        workers_by_host.setdefault(workers[key]["host"], []).append(key)

    all_workers = []
    for host in sorted(workers_by_host):
        for worker in workers_by_host[host]:
            all_workers.append(
                dict(
                    worker=worker,
                    # Ranks are consecutive, so the next rank is the current count.
                    global_rank=len(all_workers),
                    host=host,
                    local_directory=workers[worker].get("local_directory", None),
                )
            )
    return all_workers

In [5]:
# Snapshot of the cluster's workers with their assigned global ranks; later
# cells use it to pin one task to each worker.
all_workers = _get_worker_info(client)

In [6]:
# Display the per-worker records (address, global rank, host, local directory).
all_workers

[{'worker': 'tls://10.0.48.101:40393',
  'global_rank': 0,
  'host': '10.0.48.101',
  'local_directory': '/scratch/dask-scratch-space/worker-aqwy00u3'},
 {'worker': 'tls://10.0.48.102:41283',
  'global_rank': 1,
  'host': '10.0.48.102',
  'local_directory': '/scratch/dask-scratch-space/worker-a0xa_svn'},
 {'worker': 'tls://10.0.48.103:44893',
  'global_rank': 2,
  'host': '10.0.48.103',
  'local_directory': '/scratch/dask-scratch-space/worker-o59r50pe'},
 {'worker': 'tls://10.0.48.104:36759',
  'global_rank': 3,
  'host': '10.0.48.104',
  'local_directory': '/scratch/dask-scratch-space/worker-ibs481ar'},
 {'worker': 'tls://10.0.48.105:40639',
  'global_rank': 4,
  'host': '10.0.48.105',
  'local_directory': '/scratch/dask-scratch-space/worker-gyaxl6p1'},
 {'worker': 'tls://10.0.48.27:43935',
  'global_rank': 5,
  'host': '10.0.48.27',
  'local_directory': '/scratch/dask-scratch-space/worker-rhpxiy0b'},
 {'worker': 'tls://10.0.48.28:43761',
  'global_rank': 6,
  'host': '10.0.48.28',
  

In [7]:
# NOTE(review): notebook convention is to keep all imports in the first cell;
# consider moving this import there.
from distributed.diagnostics.plugin import UploadFile

# Upload the training script and the accelerate config to the workers. The
# launch cell later reads both from each worker's local directory.
# NOTE(review): load=False presumably means "copy only, do not import" — confirm.
client.register_plugin(UploadFile("sft.py", load = False))
client.register_plugin(UploadFile("config.yml", load = False))

{'tls://10.0.48.101:40393': {'status': 'OK'},
 'tls://10.0.48.102:41283': {'status': 'OK'},
 'tls://10.0.48.103:44893': {'status': 'OK'},
 'tls://10.0.48.104:36759': {'status': 'OK'},
 'tls://10.0.48.105:40639': {'status': 'OK'},
 'tls://10.0.48.27:43935': {'status': 'OK'},
 'tls://10.0.48.28:43761': {'status': 'OK'},
 'tls://10.0.48.30:38081': {'status': 'OK'},
 'tls://10.0.48.32:36903': {'status': 'OK'},
 'tls://10.0.48.33:37337': {'status': 'OK'},
 'tls://10.0.48.35:44745': {'status': 'OK'},
 'tls://10.0.48.46:39473': {'status': 'OK'},
 'tls://10.0.48.48:45537': {'status': 'OK'},
 'tls://10.0.48.50:36967': {'status': 'OK'},
 'tls://10.0.48.90:33579': {'status': 'OK'},
 'tls://10.0.48.97:36181': {'status': 'OK'}}

In [8]:
# Rendezvous settings for the multi-machine launch: the rank-0 worker's host is
# passed to every machine as the main process address.
host = all_workers[0]["host"]
port = 12345
# One machine per worker record.
num_machines = len(all_workers)
gpus_per_worker = 1
# NOTE(review): the directory name mentions Mixtral-8x7B while the launch
# command passes mistralai/Mistral-7B-v0.1 — confirm the intended name.
output_directory = "/scratch/experiments/finetune-mixtral-8x7B"

In [9]:
def run_subprocess(
        host_ip,
        host_port,
        machine_rank,
        num_processes,
        num_machines,
        local_directory,
        output_directory
        ):
    """
    Run one machine's share of the ``accelerate launch`` fine-tuning job.

    Builds the ``accelerate launch`` command line for ``sft.py``, runs it as a
    child process, and echoes the child's combined stdout/stderr line by line.

    Args:
        host_ip: Value passed to ``--main_process_ip``.
        host_port: Value passed to ``--main_process_port``.
        machine_rank: Value passed to ``--machine_rank``.
        num_processes: Value passed to ``--num_processes``.
        num_machines: Value passed to ``--num_machines``.
        local_directory: Directory expected to contain ``config.yml`` and
            ``sft.py``.
        output_directory: Value passed to the script's ``--output_dir``.

    Returns:
        The string ``"Worker {machine_rank}: Done."`` when the child exits
        with status 0.

    Raises:
        subprocess.CalledProcessError: If the child exits with a non-zero
            status.
    """
    command = [
        "accelerate", "launch",
        "--config_file", f"{local_directory}/config.yml",
        "--main_process_ip", host_ip,
        "--main_process_port", str(host_port),
        "--machine_rank", str(machine_rank),
        "--num_processes", str(num_processes),
        "--num_machines", str(num_machines),
        f"{local_directory}/sft.py",
        '--model_name', 'mistralai/Mistral-7B-v0.1',
        '--dataset_name', "trl-lib/ultrachat_200k_chatml",
        '--batch_size', '2',
        '--gradient_accumulation_steps', '1',
        '--learning_rate', '2e-4',
        '--save_steps', '200_000',
        '--use_peft',
        '--peft_lora_r', '8',
        '--peft_lora_alpha', '16',
        '--target_modules', "q_proj", "k_proj", "v_proj", "o_proj",
        '--load_in_4bit',
        '--output_dir', output_directory,
        # '--gradient_checkpointing',
        '--trust_remote_code',
        ]

    # stderr is merged into stdout so a single loop echoes everything in order.
    with subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as process:
        for line in process.stdout:
            # Each line already ends with its own newline, so print adds none;
            # errors='replace' keeps a stray non-UTF-8 byte from aborting the echo.
            print(line.decode('utf8', errors='replace'), end='')
        returncode = process.wait()

    if returncode != 0:
        # Surface a failed launch instead of reporting success.
        raise subprocess.CalledProcessError(returncode, command)

    return f"Worker {machine_rank}: Done."

In [10]:
# Submit one launch task per worker record, each pinned to its own worker so
# that every machine runs exactly the rank it was assigned.
futures = [
    client.submit(
        run_subprocess,
        workers=[entry['worker']],
        host_ip=host,
        host_port=port,
        machine_rank=entry['global_rank'],
        num_processes=num_machines * gpus_per_worker,
        num_machines=num_machines,
        local_directory=entry['local_directory'],
        output_directory=output_directory,
    )
    for entry in all_workers
]
futures

[<Future: pending, key: run_subprocess-4b34f93a75c9577cd5084eea5b0be532>,
 <Future: pending, key: run_subprocess-35b041bd595d62e0af62b81017ee2a94>,
 <Future: pending, key: run_subprocess-4ff268f2d2bd77591b8ff01b9bf88cd2>,
 <Future: pending, key: run_subprocess-be19e2ca5482f759cef86ffd32216502>,
 <Future: pending, key: run_subprocess-9a4285cfa5de4aab598b656ad0422881>,
 <Future: pending, key: run_subprocess-e2815619e40f42ae9fbc2a6d6015e001>,
 <Future: pending, key: run_subprocess-e78ddf333f7eedbc2f9bc88607aaca7c>,
 <Future: pending, key: run_subprocess-e632a15d14ee80507da26434fd8f7b02>,
 <Future: pending, key: run_subprocess-41dd1289a05404f83a664459a2135137>,
 <Future: pending, key: run_subprocess-1ba5ad780c3529a35229dab4706b448b>,
 <Future: pending, key: run_subprocess-4f5a48c520b0ac522be91bbcb1d67d68>,
 <Future: pending, key: run_subprocess-d302e47040f8acbc0cd308e5bc25574e>,
 <Future: pending, key: run_subprocess-dfc197b9549733003cb3ac75aa6a1403>,
 <Future: pending, key: run_subprocess

In [11]:
def foo():
    """Return the current reading from distributed's NVML diagnostics."""
    # Imported inside the function so the import happens where the function runs.
    from distributed.diagnostics.nvml import real_time

    return real_time()


# Run on the cluster's workers; the result is keyed by worker address.
client.run(foo)

{'tls://10.0.48.101:40393': {'utilization': 100, 'memory-used': 23855366144},
 'tls://10.0.48.102:41283': {'utilization': 100, 'memory-used': 23855366144},
 'tls://10.0.48.103:44893': {'utilization': 100, 'memory-used': 23855366144},
 'tls://10.0.48.104:36759': {'utilization': 100, 'memory-used': 23855366144},
 'tls://10.0.48.105:40639': {'utilization': 100, 'memory-used': 23855366144},
 'tls://10.0.48.27:43935': {'utilization': 100, 'memory-used': 23855366144},
 'tls://10.0.48.28:43761': {'utilization': 100, 'memory-used': 23855366144},
 'tls://10.0.48.30:38081': {'utilization': 100, 'memory-used': 23855366144},
 'tls://10.0.48.32:36903': {'utilization': 100, 'memory-used': 23855366144},
 'tls://10.0.48.33:37337': {'utilization': 100, 'memory-used': 23855366144},
 'tls://10.0.48.35:44745': {'utilization': 100, 'memory-used': 23855366144},
 'tls://10.0.48.46:39473': {'utilization': 100, 'memory-used': 23855366144},
 'tls://10.0.48.48:45537': {'utilization': 100, 'memory-used': 23855366

In [None]:
# Block on each launch in submission order and print the string it returned.
for future in futures:
    print(future.result())

Worker 0: Done.
Worker 1: Done.
Worker 2: Done.
Worker 3: Done.
Worker 4: Done.
Worker 5: Done.
Worker 6: Done.
Worker 7: Done.


In [None]:
# Cancel every launch future (e.g. to abandon a run).
# NOTE(review): confirm whether cancelling also stops an accelerate subprocess
# that a worker has already started.
for future in futures:
    future.cancel()

In [None]:
def list_files(directory_path):
    """
    Return the paths of the files directly inside ``directory_path``.

    Each returned path is ``directory_path`` joined with the entry name.
    Entries for which ``os.path.isfile`` is false (e.g. subdirectories) are
    skipped; nothing is listed recursively.
    """
    file_paths = []
    for entry_name in os.listdir(directory_path):
        candidate = os.path.join(directory_path, entry_name)
        if os.path.isfile(candidate):
            file_paths.append(candidate)
    return file_paths

def read_file(file_path):
    """
    Read a file in binary mode.

    Returns:
        A ``(content, name)`` tuple: the file's full contents as bytes and the
        base name of ``file_path``.
    """
    # Binary mode so the bytes come back untouched.
    with open(file_path, 'rb') as handle:
        payload = handle.read()
    name = os.path.basename(file_path)
    return payload, name

# List the files in the output directory as seen by the rank-0 worker (the
# task is pinned to that worker).
file_list_future = client.submit(list_files, output_directory, workers = [all_workers[0]["worker"]])
file_list = file_list_future.result()

# Directory, relative to where this notebook runs, in which the copies are saved.
local_directory = 'output'
os.makedirs(local_directory, exist_ok=True)

# Fetch the files one at a time: each is read in full on the rank-0 worker,
# returned as bytes through the future, and written locally under its base name.
for file_path in file_list:
    file_future = client.submit(read_file, file_path, workers = [all_workers[0]["worker"]])
    content, filename = file_future.result()
    local_file_path = os.path.join(local_directory, filename)
    with open(local_file_path, 'wb') as local_file:
        local_file.write(content)


In [None]:
# Shut the cluster down once the outputs have been copied.
cluster.shutdown()

In [None]:
# client.restart()

0,1
Connection method: Cluster object,Cluster type: coiled.Cluster
Dashboard: https://cluster-nklbs.dask.host/oy2KD2OSE-xXOMlW/status,

0,1
Dashboard: https://cluster-nklbs.dask.host/oy2KD2OSE-xXOMlW/status,Workers: 1
Total threads: 12,Total memory: 82.20 GiB

0,1
Comm: tls://10.0.64.3:8786,Workers: 1
Dashboard: http://10.0.64.3:8787/status,Total threads: 12
Started: 2 hours ago,Total memory: 82.20 GiB

0,1
Comm: tls://10.0.64.2:39095,Total threads: 12
Dashboard: http://10.0.64.2:8787/status,Memory: 82.20 GiB
Nanny: tls://10.0.64.2:43467,
Local directory: /scratch/dask-scratch-space/worker-ik_4c7cf,Local directory: /scratch/dask-scratch-space/worker-ik_4c7cf


In [None]:
# # Function to search for 'config.yml' in the worker's file system
# def find_train_py():
#     for root, dirs, files in os.walk('/scratch'):
#         if 'config.yml' in files:
#             return os.path.join(root, 'config.yml')
#     return "train.py not found"

# # Run the function on all workers
# futures = client.run(find_train_py)

# # Collect results
# results = client.gather(futures)

# # Print the results
# for worker, path in results.items():
#     print(f"Worker {worker} found 'config.yml' at: {path}")
