In [None]:
import coiled
import os
from typing import List, Dict
from distributed.client import Client
import subprocess

# coiled.create_software_environment(
#     name="pytorch-llm",
#     conda="pytorch-llm.yaml",
#     gpu_enabled=True,
#     account = "nathan-ballou-gcp"
# )

In [None]:
n_workers = 4

# Start the Coiled cluster that hosts the Dask workers used for training.
# The commented-out options are alternative VM types / regions kept for reference.
cluster = coiled.Cluster(
    n_workers=n_workers,
    worker_vm_types="g2-standard-4",
    # worker_vm_types="a2-highgpu-1g",
    # worker_vm_types="a2-ultragpu-1g",
    # scheduler_vm_types="g2-standard-4",
    account="nathan-ballou-gcp",
    software="pytorch-llm",
    worker_disk_size=256,
    idle_timeout="2 hours",
    # region = "us-west1"
)

client = cluster.get_client()

# Block until every requested worker has joined; later cells assign one rank
# per connected worker, so they need the full set.
client.wait_for_workers(n_workers=n_workers)

# A bare expression is only rendered when it is the last statement of the cell,
# so the client summary goes at the very end.
client

In [None]:
def _get_worker_info(client: Client) -> List[Dict]:
    """
    Describe every connected Dask worker and assign each one a global rank.

    Workers are grouped by host. Hosts are visited in sorted order and, within
    a host, workers are visited in sorted order of their scheduler key. Ranks
    0, 1, 2, ... are handed out in that order, so the first entry of the
    returned list (rank 0) identifies the main ("master") host.

    Parameters
    ----------
    client : Client
        Connected client; only ``client.scheduler_info()["workers"]`` is read.

    Returns
    -------
    List[Dict]
        One dict per worker with the keys ``worker`` (the worker's key in the
        scheduler info), ``global_rank``, ``host`` and ``local_directory``
        (``None`` when the scheduler info has no such entry). The list is
        empty when no workers are connected.
    """
    workers = client.scheduler_info()["workers"]

    # Group worker keys by host; iterating the keys in sorted order keeps the
    # per-host lists sorted as well.
    workers_by_host: Dict[str, List[str]] = {}
    for key in sorted(workers):
        workers_by_host.setdefault(workers[key]["host"], []).append(key)

    # Flatten host by host so that ranks on the same host are contiguous.
    ordered_keys = [
        key
        for host in sorted(workers_by_host)
        for key in workers_by_host[host]
    ]

    return [
        dict(
            worker=key,
            global_rank=global_rank,
            host=workers[key]["host"],
            local_directory=workers[key].get("local_directory", None),
        )
        for global_rank, key in enumerate(ordered_keys)
    ]

In [None]:
# Snapshot the connected workers and give each one a global rank.
all_workers = _get_worker_info(client)

In [None]:
# Show the worker -> rank/host assignment for inspection.
all_workers

In [None]:
from distributed.diagnostics.plugin import UploadFile

# Ship the training script and the accelerate config to the workers.
# run_subprocess later reads both from each worker's local_directory.
# NOTE(review): load=False presumably uploads the files without importing them — confirm.
client.register_plugin(UploadFile("sft.py", load = False))
client.register_plugin(UploadFile("config.yml", load = False))

In [None]:
# Launch settings shared by every worker.
# The main-process address is the host of the first (rank-0) entry of all_workers.
host = all_workers[0]["host"]
port = 12345  # forwarded as --main_process_port
# One entry per Dask worker.
# NOTE(review): assumes exactly one worker per machine — confirm.
num_machines = len(all_workers)
gpus_per_worker = 1  # multiplied by num_machines to get --num_processes
# NOTE(review): the directory name says "mixtral-8x7B" while run_subprocess
# launches mistralai/Mistral-7B-v0.1 — confirm the intended name.
output_directory = "/scratch/experiments/finetune-mixtral-8x7B"

In [None]:
def run_subprocess(
    host_ip,
    host_port,
    machine_rank,
    num_processes,
    num_machines,
    local_directory,
    output_directory,
):
    """
    Run ``sft.py`` through ``accelerate launch`` and stream its output.

    The child's stdout and stderr are merged and echoed line by line with
    ``print``.

    Parameters
    ----------
    host_ip : str
        Passed as ``--main_process_ip``.
    host_port : int or str
        Passed as ``--main_process_port``.
    machine_rank : int
        Passed as ``--machine_rank``; also used in the returned message.
    num_processes : int
        Passed as ``--num_processes``.
    num_machines : int
        Passed as ``--num_machines``.
    local_directory : str
        Directory that contains ``config.yml`` and ``sft.py``.
    output_directory : str
        Passed to the training script as ``--output_dir``.

    Returns
    -------
    str
        ``"Worker {machine_rank}: Done."`` when the launcher exits with status 0.

    Raises
    ------
    subprocess.CalledProcessError
        If the launcher exits with a non-zero status.
    """
    command = [
        "accelerate", "launch",
        "--config_file", f"{local_directory}/config.yml",
        "--main_process_ip", host_ip,
        "--main_process_port", str(host_port),
        "--machine_rank", str(machine_rank),
        "--num_processes", str(num_processes),
        "--num_machines", str(num_machines),
        f"{local_directory}/sft.py",
        '--model_name', 'mistralai/Mistral-7B-v0.1',
        '--dataset_name', "trl-lib/ultrachat_200k_chatml",
        '--batch_size', '2',
        '--gradient_accumulation_steps', '1',
        '--learning_rate', '2e-4',
        '--save_steps', '200_000',
        '--use_peft',
        '--peft_lora_r', '8',
        '--peft_lora_alpha', '16',
        '--target_modules', "q_proj", "k_proj", "v_proj", "o_proj",
        '--load_in_4bit',
        '--output_dir', output_directory,
        # '--gradient_checkpointing',
        '--trust_remote_code',
    ]

    with subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as process:
        for line in process.stdout:
            # Lines already end with a newline, so strip it to avoid a
            # double-spaced log. errors="replace" keeps a stray non-UTF-8 byte
            # in the child's output from aborting the whole task.
            print(line.decode("utf8", errors="replace").rstrip("\r\n"), flush=True)
        returncode = process.wait()

    # A failed launch must not be reported as success.
    if returncode != 0:
        raise subprocess.CalledProcessError(returncode, command)

    return f"Worker {machine_rank}: Done."

In [None]:
# Start one launcher per worker, each pinned to its own worker.
# run_subprocess is all side effects, so it is submitted with pure=False:
# with the default pure=True the task key is derived from the arguments, and
# re-running this cell while old futures are still referenced (e.g. through
# the notebook's Out[] history) would hand back those old futures instead of
# launching again.
futures = []
for worker in all_workers:
    fut = client.submit(
        run_subprocess,
        workers=[worker['worker']],
        pure=False,
        host_ip=host,
        host_port=port,
        machine_rank=worker['global_rank'],
        num_processes=num_machines * gpus_per_worker,
        num_machines=num_machines,
        local_directory=worker['local_directory'],
        output_directory=output_directory,
    )
    futures.append(fut)
futures

In [None]:
def foo():
    """Return the result of ``nvml.real_time()`` from ``distributed.diagnostics``."""
    # Imported inside the function so the import happens wherever it is called.
    from distributed.diagnostics import nvml

    gpu_metrics = nvml.real_time()
    return gpu_metrics


# Run foo on the workers and display what each one returned.
client.run(foo)

In [None]:
# Wait for each launcher in submission order and print its status message.
for future in futures:
    status_message = future.result()
    print(status_message)

In [None]:
# Cancel every submitted launcher future.
for future in futures:
    future.cancel()

In [None]:
def list_files(directory_path):
    """
    Return the paths of the regular files directly inside ``directory_path``.

    Each path is ``directory_path`` joined with the entry name. Sub-directories
    are skipped and not descended into.
    """
    file_paths = []
    for entry_name in os.listdir(directory_path):
        candidate = os.path.join(directory_path, entry_name)
        if os.path.isfile(candidate):
            file_paths.append(candidate)
    return file_paths

def read_file(file_path):
    """Return ``(raw_bytes, base_name)`` for the file at ``file_path``."""
    base_name = os.path.basename(file_path)
    # Binary mode: the content is returned as bytes, unmodified.
    with open(file_path, mode='rb') as handle:
        payload = handle.read()
    return payload, base_name

# Get list of files in the directory, as seen by the rank-0 worker.
# pure=False: a directory listing (and, below, a file read) depends on what is
# on disk, not only on its arguments. With the default pure=True, re-running
# this cell while file_list_future / file_future are still referenced would
# return the previous, possibly stale, result instead of reading again.
file_list_future = client.submit(
    list_files,
    output_directory,
    workers=[all_workers[0]["worker"]],
    pure=False,
)
file_list = file_list_future.result()

# Local directory to save files
local_directory = 'output'
os.makedirs(local_directory, exist_ok=True)

# Pull each file's bytes from the same worker and write them locally under the
# file's base name.
for file_path in file_list:
    file_future = client.submit(
        read_file,
        file_path,
        workers=[all_workers[0]["worker"]],
        pure=False,
    )
    content, filename = file_future.result()
    local_file_path = os.path.join(local_directory, filename)
    with open(local_file_path, 'wb') as local_file:
        local_file.write(content)


In [None]:
# Shut the cluster down once the outputs have been copied locally.
cluster.shutdown()

In [None]:
# client.restart()

In [None]:
# # Function to search for 'config.yml' in the worker's file system
# def find_train_py():
#     for root, dirs, files in os.walk('/scratch'):
#         if 'config.yml' in files:
#             return os.path.join(root, 'config.yml')
#     return "train.py not found"

# # Run the function on all workers
# futures = client.run(find_train_py)

# # Collect results
# results = client.gather(futures)

# # Print the results
# for worker, path in results.items():
#     print(f"Worker {worker} found 'config.yml' at: {path}")
