In [None]:
import os
import subprocess
from types import SimpleNamespace
from typing import List, Dict

import coiled
from distributed.client import Client

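# One-time setup: uncomment and run once to build the "pytorch-llm" software
# environment that the cluster cell below refers to by name.
# coiled.create_software_environment(
#     name="pytorch-llm",
#     conda="pytorch-llm.yaml",
#     gpu_enabled=True,
#     account = "nathan-ballou-gcp"
# )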

In [None]:
# Number of workers requested from the cluster and waited for before continuing.
n_workers = 4
# VM instance type handed to `worker_vm_types` when the cluster is created.
vm_worker_type = "g2-standard-4"

In [None]:
# Start the cluster that will run the training job.
cluster = coiled.Cluster(
    n_workers=n_workers,
    # Coiled documents `worker_vm_types` as a list of instance types, so wrap
    # the single configured type instead of passing a bare string.
    worker_vm_types=[vm_worker_type],
    account="nathan-ballou-gcp",
    software="pytorch-llm",
    worker_disk_size=128,
    idle_timeout="24 hours",
)

client = cluster.get_client()

# Block until every requested worker has joined; the later cells index into
# the worker list and assume it is complete.
client.wait_for_workers(n_workers=n_workers)

# Only the last expression of a cell is rendered, so show the client summary
# here rather than in the middle of the cell.
client

In [None]:
def get_worker_info(self=None) -> List[Dict]:
    """Describe every worker known to the scheduler, with a stable rank.

    Args:
        self: Where to find the Dask client. May be an object exposing a
            ``.client`` attribute (the original calling convention), a client
            object itself (anything with ``scheduler_info()``), or ``None`` to
            use the notebook-level ``client`` variable.

    Returns:
        A list of dicts with keys ``worker`` (the scheduler's key for the
        worker), ``global_rank`` (0-based position in the list), ``host`` and
        ``local_directory`` (``None`` when the scheduler does not report one).
        Entries are ordered by host, then by worker key, so the ranks are
        reproducible between calls.

    Raises:
        RuntimeError: if called without an argument and no global ``client``
            has been defined yet.
    """
    if self is None:
        # This function is called bare from a notebook cell, so fall back to
        # the notebook-global client instead of failing on a missing argument.
        dask_client = globals().get("client")
        if dask_client is None:
            raise RuntimeError(
                "get_worker_info() needs a Dask client: create `client` first "
                "or pass one explicitly."
            )
    else:
        # Accept both a holder object with `.client` and a client itself.
        dask_client = getattr(self, "client", self)

    workers = dask_client.scheduler_info()["workers"]

    # Group worker keys per host; iterating the sorted keys keeps each host's
    # list sorted as well.
    workers_by_host: Dict[str, List[str]] = {}
    for key in sorted(workers):
        workers_by_host.setdefault(workers[key]["host"], []).append(key)

    # Flatten host by host so workers sharing a machine get consecutive ranks.
    ordered_keys = [
        (host, key)
        for host in sorted(workers_by_host)
        for key in workers_by_host[host]
    ]
    return [
        dict(
            worker=key,
            global_rank=global_rank,
            host=host,
            local_directory=workers[key].get("local_directory", None),
        )
        for global_rank, (host, key) in enumerate(ordered_keys)
    ]

def detect_gpus() -> int:
    """Count the NVIDIA GPUs visible on this machine via ``nvidia-smi``.

    Returns:
        The number of lines in ``nvidia-smi --list-gpus`` output that describe
        a device (lines beginning with ``"GPU "``). Returns 0 when the command
        cannot be run, fails, or lists no devices; the error is printed rather
        than raised so the caller can still collect a result from every worker.
    """
    try:
        output = subprocess.check_output(["nvidia-smi", "--list-gpus"], text=True)
    except Exception as e:
        print(f"Error detecting GPUs: {e}")
        return 0
    # Count only device lines: splitting an empty string yields [''], which
    # would otherwise be reported as one GPU.
    return sum(1 for line in output.splitlines() if line.strip().startswith("GPU "))
    
def run_subprocess(local_directory, config_path, host, port, machine_rank, num_processes, num_machines, script_path, commands_str):
    """Run ``accelerate launch`` for one machine and stream its output.

    Args:
        local_directory: Directory that contains both ``config_path`` and
            ``script_path`` on the machine executing this function.
        config_path: Accelerate config file name, relative to ``local_directory``.
        host: Address passed as ``--main_process_ip`` (multi-machine runs only).
        port: Port passed as ``--main_process_port`` (multi-machine runs only).
        machine_rank: Value for ``--machine_rank``.
        num_processes: Value for ``--num_processes``.
        num_machines: Value for ``--num_machines``; when it is 1 the main
            process address/port flags are omitted.
        script_path: Training script name, relative to ``local_directory``.
        commands_str: Extra arguments for the script. Split on whitespace, so
            individual values must not contain spaces.

    Returns:
        ``"Worker <machine_rank>: Done."`` once the launcher exits with status 0.

    Raises:
        subprocess.CalledProcessError: if the launcher exits with a non-zero
            status, so a failed run is not reported as done.
    """
    command = [
        "accelerate", "launch",
        "--config_file", f"{local_directory}/{config_path}",
    ]
    if num_machines != 1:
        # The rendezvous address is only passed for multi-machine runs.
        command += [
            "--main_process_ip", str(host),
            "--main_process_port", str(port),
        ]
    command += [
        "--machine_rank", str(machine_rank),
        "--num_processes", str(num_processes),
        "--num_machines", str(num_machines),
        f"{local_directory}/{script_path}",
    ]
    command += commands_str.split()

    # stderr is merged into stdout so the log is printed in emission order.
    with subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as process:
        for line in process.stdout:
            # Each line already ends with a newline; don't add a second one.
            # Undecodable bytes are replaced instead of aborting the stream.
            print(line.decode('utf8', errors='replace'), end="", flush=True)

    # Leaving the `with` block waits for the child, so returncode is set here.
    if process.returncode != 0:
        raise subprocess.CalledProcessError(process.returncode, command)
    return f"Worker {machine_rank}: Done."

In [None]:
# GPU count reported by each worker; indexed below by the worker key.
worker_gpus = client.run(detect_gpus)

# get_worker_info reads the Dask client from its argument's `.client`
# attribute, so hand it a lightweight holder around the notebook's client.
all_workers = get_worker_info(SimpleNamespace(client=client))

# The first listed worker's GPU count is applied to every machine
# (assumes a homogeneous cluster).
gpus_per_worker = worker_gpus[all_workers[0]['worker']]

# The rank-0 worker's host is passed to every launcher as the main process address.
host = all_workers[0]["host"]
port = "12345"
num_machines = len(all_workers)
num_processes = num_machines * gpus_per_worker

script_path = "sft.py"
config_path = "config.yml"
output_directory = "/scratch/experiments/finetune-mixtral-8x7B"

In [None]:
# Ship the training script first, then the accelerate config, to the workers.
for shipped_file in (script_path, config_path):
    client.upload_file(shipped_file)

In [None]:
# Extra CLI flags forwarded to the training script. run_subprocess splits this
# string on whitespace, so each flag and each value must be its own
# space-free token: `--output_path = <dir>` would hand "=" to the flag and
# leave the directory as a stray argument.
command_str = f"""--model_name mistralai/Mistral-7B-v0.1
                 --dataset_name trl-lib/ultrachat_200k_chatml
                 --output_path {output_directory}
                 --batch_size 2 --gradient_accumulation_steps 1
                 --learning_rate 2e-4
                 --save_steps 200_000
                 --use_peft
                 --peft_lora_r 8
                 --peft_lora_alpha 16
                 --target_modules q_proj k_proj v_proj o_proj
                 --load_in_4bit
                 --trust_remote_code
              """

# One launcher task per worker, pinned to that worker and given its rank.
futures = []
for worker in all_workers:
    fut = client.submit(
        run_subprocess,
        local_directory=worker['local_directory'],
        config_path=config_path,
        host=host,
        port=port,
        machine_rank=worker['global_rank'],
        num_processes=num_processes,
        num_machines=num_machines,
        script_path=script_path,
        commands_str=command_str,
        workers=[worker['worker']],
        # The task is run for its side effects, so it must not be
        # deduplicated against an earlier identical submission.
        pure=False,
    )
    futures.append(fut)
futures

In [None]:
# Wait for each launcher in submission order and print what it returned.
for future in futures:
    status_line = future.result()
    print(status_line)

In [None]:
def list_files(directory_path):
    """Return the joined paths of the regular files directly inside ``directory_path``.

    Sub-directories are left out and are not descended into. Entries keep the
    order in which ``os.listdir`` reports them.
    """
    candidates = (os.path.join(directory_path, entry) for entry in os.listdir(directory_path))
    return [path for path in candidates if os.path.isfile(path)]

def read_file(file_path):
    """Read ``file_path`` in binary mode.

    Returns:
        A ``(content, name)`` tuple: the file's bytes and its base name.
    """
    with open(file_path, 'rb') as handle:
        payload = handle.read()
    name = os.path.basename(file_path)
    return payload, name

# Ask the rank-0 worker which files sit directly in the output directory.
file_list_future = client.submit(
    list_files,
    output_directory,
    workers=[all_workers[0]["worker"]],
)
file_list = file_list_future.result()

# Destination folder on the machine running this notebook.
local_directory = 'output'
os.makedirs(local_directory, exist_ok=True)

# Fetch the files one at a time from that same worker and write each one
# locally under its original base name.
for file_path in file_list:
    file_future = client.submit(
        read_file,
        file_path,
        workers=[all_workers[0]["worker"]],
    )
    content, filename = file_future.result()
    local_file_path = os.path.join(local_directory, filename)
    with open(local_file_path, 'wb') as local_file:
        local_file.write(content)