In [None]:
import coiled
import os
from typing import List, Dict
from distributed.client import Client
import subprocess

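# One-time setup: build the GPU-enabled "pytorch-llm" Coiled software environment
# that the cluster below uses. Uncomment and run once per account.
# coiled.create_software_environment(
#     name="pytorch-llm",
#     conda="pytorch-llm.yaml",
#     gpu_enabled=True,
#     account = "nathan-ballou-gcp"
# )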

In [None]:
n_workers = 4

# Start an n_workers-node Coiled cluster on g2-standard-4 VMs running the
# "pytorch-llm" software environment. The generous idle timeout keeps the
# cluster from being reclaimed during long pauses between cells.
cluster_options = dict(
    account="nathan-ballou-gcp",
    software="pytorch-llm",
    worker_vm_types="g2-standard-4",
    worker_disk_size=128,
    idle_timeout="24 hours",
)
cluster = coiled.Cluster(n_workers=n_workers, **cluster_options)

# Connect and block until every requested worker has joined, so later cells
# see the full set of machines.
client = cluster.get_client()
client.wait_for_workers(n_workers=n_workers)

In [None]:
from dask.distributed import get_worker

def get_worker_info() -> Dict:
    """Describe the Dask worker this task is running on.

    Meant to be submitted as a task pinned to a specific worker (see below).

    Returns:
        A single dict with keys:
        - ``worker``: the worker's scheduler address (used to pin later tasks);
        - ``host``: the worker's IP address;
        - ``local_directory``: the worker's local directory;
        - ``gpus``: number of GPUs reported by NVML on that worker.
    """
    # Import from the ``distributed`` package itself. ``dask.distributed`` is a
    # re-export module rather than a package, so reaching submodules through
    # it is fragile.
    from distributed.diagnostics import nvml

    worker = get_worker()
    return {
        "worker": worker.address,
        "host": worker.ip,
        "local_directory": worker.local_directory,
        "gpus": nvml.device_get_count(),
    }

# Survey every worker once. Each task is pinned to its worker; pure=False stops
# Dask from merging the identical, argument-less calls into one task.
# The order of this list defines machine ranks later on (index 0 == rank 0).
workers = list(client.scheduler_info()["workers"].keys())
all_workers = client.gather(
    [client.submit(get_worker_info, workers=[worker], pure=False) for worker in workers]
)
all_workers

In [None]:
# Derive the accelerate topology from the worker survey. all_workers[0] gets
# machine_rank 0 when training is launched, so its IP is the main-process host.
gpus_per_worker = all_workers[0]["gpus"]

# accelerate's --num_processes is the total across all machines. Computing it as
# num_machines * gpus_per_worker is only valid if every worker has the same,
# non-zero number of GPUs, so fail fast here rather than inside accelerate.
assert gpus_per_worker > 0, "worker 0 reports no GPUs; check the VM type / NVML"
assert all(w["gpus"] == gpus_per_worker for w in all_workers), (
    f"heterogeneous GPU counts across workers: {[w['gpus'] for w in all_workers]}"
)

host = all_workers[0]["host"]
port = "12345"
num_machines = len(all_workers)
num_processes = num_machines * gpus_per_worker

script_path = "sft.py"
config_path = "config.yml"
# NOTE(review): this directory is named after Mixtral-8x7B, but the training
# arguments select mistralai/Mistral-7B-v0.1. Confirm the intended path.
output_directory = "/scratch/experiments/finetune-mixtral-8x7B"

In [None]:
# Ship the training script and the accelerate config to every worker.
# load=False copies the files without importing them; train() later reads them
# from each worker's local directory.
for path_to_ship in (script_path, config_path):
    client.upload_file(path_to_ship, load=False)

In [None]:
def train(local_directory, config_path, host, port, machine_rank, num_processes, num_machines, script_path, commands_str):
    """Run ``accelerate launch`` for one machine of a (possibly multi-node) job.

    Intended to run as a Dask task pinned to one worker, with one task per
    machine, each given its own ``machine_rank``.

    Args:
        local_directory: The worker's local directory, where the script and
            config were uploaded.
        config_path: Path of the accelerate config as passed to ``upload_file``.
            Only its base name is used, because uploaded files are stored
            under their base name.
        host: IP of the rank-0 machine; only passed when ``num_machines > 1``.
        port: Port of the rank-0 main process; only passed when ``num_machines > 1``.
        machine_rank: This machine's rank.
        num_processes: Total number of processes across all machines.
        num_machines: Number of machines in the job.
        script_path: Path of the training script as passed to ``upload_file``
            (base name is used, as for ``config_path``).
        commands_str: Extra arguments for the training script, separated by
            whitespace. Shell-style quoting is honoured.

    Returns:
        A completion message of the form ``"Worker <rank>: Done."``.

    Raises:
        subprocess.CalledProcessError: If the ``accelerate`` command exits non-zero.
    """
    import shlex  # local import: this function is pickled and executed on workers

    # Uploaded files land directly in local_directory under their base name,
    # so drop any directory components from the client-side paths.
    config_file = os.path.join(local_directory, os.path.basename(config_path))
    script_file = os.path.join(local_directory, os.path.basename(script_path))

    command = [
        "accelerate",
        "launch",
        "--config_file", config_file,
        "--machine_rank", str(machine_rank),
        "--num_processes", str(num_processes),
        "--num_machines", str(num_machines),
    ]
    if num_machines > 1:
        # Rendezvous info is only needed when several machines cooperate.
        command += [
            "--main_process_ip", str(host),
            "--main_process_port", str(port),
        ]
    # shlex.split tokenizes like str.split for plain words, but keeps quoted
    # arguments containing spaces intact.
    command += [script_file] + shlex.split(commands_str)
    subprocess.check_call(command)
    return f"Worker {machine_rank}: Done."

# Arguments forwarded to the training script after its path. train() splits
# this string into separate argv tokens, so the line breaks and indentation
# inside the literal are only for readability.
command_str = f"""--model_name mistralai/Mistral-7B-v0.1 
                 --dataset_name trl-lib/ultrachat_200k_chatml
                 --output_dir {output_directory}
                 --batch_size 2 --gradient_accumulation_steps 1
                 --learning_rate 2e-4
                 --save_steps 200_000
                 --use_peft
                 --peft_lora_r 8
                 --peft_lora_alpha 16
                 --target_modules q_proj k_proj v_proj o_proj
                 --load_in_4bit
                 --trust_remote_code
              """

# Launch one training task per machine. Each task is pinned to its own worker,
# so machine_rank lines up with the survey order: rank 0 runs on
# all_workers[0], whose IP is `host`.
# pure=False: training is a side effect. Without it, re-running this cell
# while the previous `futures` are still referenced would reuse the same task
# keys and return the old results instead of relaunching training.
futures = [
    client.submit(
        train,
        local_directory=worker["local_directory"],
        config_path=config_path,
        host=host,
        port=port,
        machine_rank=rank,
        num_processes=num_processes,
        num_machines=num_machines,
        script_path=script_path,
        commands_str=command_str,
        workers=[worker["worker"]],
        pure=False,
    )
    for rank, worker in enumerate(all_workers)
]
futures

In [None]:
# Block on each rank in order. .result() re-raises any exception raised inside
# train(), e.g. a non-zero exit from accelerate.
for rank_future in futures:
    outcome = rank_future.result()
    print(outcome)

In [None]:
def list_files(directory_path):
    """Return the paths of the regular files directly inside ``directory_path``.

    Subdirectories are skipped and not descended into. Each returned path is
    ``directory_path`` joined with the entry name, in ``os.listdir`` order.
    """
    found = []
    for entry_name in os.listdir(directory_path):
        candidate = os.path.join(directory_path, entry_name)
        if os.path.isfile(candidate):
            found.append(candidate)
    return found

def read_file(file_path):
    """Read a file in binary mode.

    Returns:
        A ``(content_bytes, base_name)`` tuple, where ``base_name`` is the
        final path component of ``file_path``.
    """
    base_name = os.path.basename(file_path)
    with open(file_path, 'rb') as handle:
        payload = handle.read()
    return payload, base_name

# Copy the top-level files of output_directory from all_workers[0] to this
# machine. Assumes the training script saves its outputs from the rank-0
# process, which runs on that worker. Subdirectories are not copied because
# list_files() only returns regular files.
source_worker = all_workers[0]["worker"]

# pure=False: these tasks read mutable filesystem state. A pure submit could
# return a stale cached result, e.g. an old listing still held by
# `file_list_future` from a previous run of this cell.
file_list_future = client.submit(list_files, output_directory, workers=[source_worker], pure=False)
file_list = file_list_future.result()

# Local directory to save files
local_directory = "output"
os.makedirs(local_directory, exist_ok=True)

# Fetch one file at a time so only a single file's bytes are held in memory.
for file_path in file_list:
    file_future = client.submit(read_file, file_path, workers=[source_worker], pure=False)
    content, filename = file_future.result()
    local_file_path = os.path.join(local_directory, filename)
    with open(local_file_path, "wb") as local_file:
        local_file.write(content)