In [36]:
%reload_ext autoreload
%autoreload 2

import socket
import psutil
import sys 
import os
from typing import Any
from functools import partial
import json
from pprint import pprint
from collections import OrderedDict
from datetime import datetime
import os
import pytz
# lib_path = ['/fsx_0/user/tranx/experiments']
import_paths = [
    "/fsx_0/user/tranx",
    # "/fsx_0/user/tranx/rsync",
    # "/fsx_0/user/tranx/experiments/lib"
]
for path in import_paths:
    if path not in sys.path:
        sys.path.append(path)
from rsync.llm_mm_aligner.experiments.aws.launch_job import run_job as rsync_run_job
from moe.llm_mm_aligner.experiments.aws.launch_job import run_job as moe_run_job
# import utils
    
# Gather basic facts about the machine this kernel runs on, then report them.
hostname = socket.gethostname()
num_cpus = psutil.cpu_count()
total_memory = psutil.virtual_memory().total / (1024 ** 3)  # bytes -> GiB

for _label, _value in (
    ("Host name:", hostname),
    ("Number of CPUs:", num_cpus),
    ("Total memory (GB):", round(total_memory, 2)),
):
    print(_label, _value)

import torch
from torch import nn 
from torch.nn import functional as F


# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"Using device: {device}")

Host name: submit-1
Number of CPUs: 32
Total memory (GB): 247.74


In [58]:
class Launcher:
    """Submit jobs through a launch function and record them in a JSON run log.

    Args:
        launch_fn: Callable invoked as
            ``launch_fn(config_file=..., conda_env=..., overrides=...)``. Its
            return value is treated as the job id; ``None`` means that no job
            id is available and nothing is logged.
        config_base_dir: Directory that the relative ``config`` paths passed
            to :meth:`run` are joined to.
        run_log_file: Path of the JSON file that stores ``{"jobs": [...]}``
            with the most recent job first.
    """

    def __init__(self, launch_fn, config_base_dir: str, run_log_file: str):
        self.launch_fn = launch_fn
        self.config_base_dir = config_base_dir
        self.run_log_file = run_log_file
        self.run_log = self._read_run_log()

    def run(
        self,
        config: str,
        nodes: int = 1,
        overrides_dict = None,
        qos: str = "ar-ai-hipri",
        conda_env: str = "/fsx_0/user/ahmadyan/.conda/envs/aligner_20240822",
        note = None
    ):
        """Launch one job and, if a job id comes back, prepend it to the run log.

        Args:
            config: Config path relative to ``config_base_dir``.
            nodes: Value sent as the ``slurm_args.nodes`` override.
            overrides_dict: Optional mapping of extra override keys to values,
                appended after the slurm overrides.
            qos: Value sent as both ``slurm_args.qos`` and ``slurm_args.account``.
            conda_env: Conda environment path forwarded to ``launch_fn``.
            note: Free-form note stored with the log entry.
        """
        config_file = os.path.join(self.config_base_dir, config)

        slurm_overrides = [
            ("slurm_args.qos", qos),
            ("slurm_args.account", qos),
            ("slurm_args.nodes", nodes),
        ]
        if overrides_dict is not None:
            slurm_overrides.extend(overrides_dict.items())
        # launch_fn receives a list containing a single list of (key, value) pairs.
        overrides = [slurm_overrides]

        info = OrderedDict([
            ("job_id", None),
            ("nodes", nodes),
            ("config", config),
            ("config_base", self.config_base_dir),
            ("overrides", overrides_dict),
            ("conda_env", conda_env),
            ("timestamp", self._timestamp()),
            ("note", note)
        ])

        pprint(info)

        job_id = self.launch_fn(
            config_file = config_file,
            conda_env = conda_env,
            overrides = overrides
        )

        # Check for None *before* converting: int(None) raises TypeError, which
        # previously made the "no job id" case crash instead of being skipped.
        if job_id is None:
            print("No job id returned; nothing was added to the run log.")
            return

        job_id = int(job_id)
        print(job_id)
        info["job_id"] = job_id

        # insert job to top of the list
        self.run_log["jobs"] = [info] + self.run_log["jobs"]
        self._save_run_log()

    @staticmethod
    def _timestamp():
        """Return the current America/New_York time as 'YYYY-MM-DD HH:MM:SS'."""
        now = datetime.now(pytz.timezone(zone='America/New_York'))
        return now.strftime('%Y-%m-%d %H:%M:%S')

    def _read_run_log(self):
        """Load the run log from disk, falling back to an empty log.

        A missing file is the normal first-use case and is handled silently.
        An unreadable or malformed file still yields an empty log (best
        effort), but a warning is printed because the next save will
        overwrite that file.
        """
        try:
            with open(self.run_log_file, 'r') as f:
                run_log = json.load(f)
        except FileNotFoundError:
            return {"jobs": []}
        except (OSError, ValueError) as err:
            # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
            print(
                f"Warning: could not read run log {self.run_log_file!r} ({err}); "
                "starting with an empty log."
            )
            return {"jobs": []}

        if not isinstance(run_log, dict):
            print(
                f"Warning: run log {self.run_log_file!r} does not contain a JSON "
                "object; starting with an empty log."
            )
            return {"jobs": []}

        # run() indexes run_log["jobs"], so make sure the key is present.
        run_log.setdefault("jobs", [])
        return run_log

    def _save_run_log(self):
        """Write the in-memory run log to ``run_log_file`` as indented JSON."""
        # Saving happens after the job has been submitted, so create the log
        # directory if needed rather than losing the record of a launched job.
        log_dir = os.path.dirname(self.run_log_file)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        with open(self.run_log_file, 'w') as f:
            json.dump(self.run_log, f, indent=4)
            
# Launcher bound to the run_job function imported from the "rsync" package;
# configs are resolved under the rsync tree and jobs are logged to
# run_log_rsync.json.
rsync_launcher = Launcher(
    launch_fn=rsync_run_job,
    config_base_dir="/fsx_0/user/tranx/rsync/llm_mm_aligner/experiments/aws",
    run_log_file="/fsx_0/user/tranx/experiments/run_log/run_log_rsync.json"
)

# Same setup for the "moe" package: its own run_job, config tree, and log file.
moe_launcher = Launcher(
    launch_fn=moe_run_job,
    config_base_dir="/fsx_0/user/tranx/moe/llm_mm_aligner/experiments/aws",
    run_log_file="/fsx_0/user/tranx/experiments/run_log/run_log_moe.json"
)

In [59]:
# Submit the stage-1 "resume" config on 64 nodes through the rsync launcher.
# qos, conda_env, overrides, and note are left at Launcher.run's defaults.
rsync_launcher.run(
    config="mm10.1/stage1/MH22final_70B_ViTH_336px_R1_recap_20241024_resume.json",
    nodes=64
)



OrderedDict([('job_id', None),
             ('nodes', 64),
             ('config',
              'mm10.1/stage1/MH22final_70B_ViTH_336px_R1_recap_20241024_resume.json'),
             ('config_base',
              '/fsx_0/user/tranx/rsync/llm_mm_aligner/experiments/aws'),
             ('overrides', None),
             ('conda_env', '/fsx_0/user/ahmadyan/.conda/envs/aligner_20240822'),
             ('timestamp', '2024-10-27 22:18:40'),
             ('note', None)])
61131


In [17]:
# NOTE(review): this looks like a leftover scratch cell for exercising
# _save_run_log. The first line has no visible effect (it is not the cell's
# last expression, so nothing is displayed). The next two lines append the
# placeholder values "1", 2, 3 to the job list and write them to the run-log
# file on disk, and re-running the cell appends them again. Remove this cell
# (or clean the log file) before relying on the log's contents.
rsync_launcher.run_log
rsync_launcher.run_log["jobs"].extend(["1", 2, 3])
rsync_launcher._save_run_log()

In [9]:
def run(
    config: str,
    overrides_dict = None,
    nodes: int = 1,
    config_base: str = "/fsx_0/user/tranx/rsync/llm_mm_aligner/experiments/aws",
    qos: str = "ar-ai-hipri",
    conda_env: str = "/fsx_0/user/ahmadyan/.conda/envs/aligner_20240822",
    launcher = None
):
    """Submit one job by calling ``launcher.run_job`` with slurm overrides.

    Args:
        config: Config path relative to ``config_base``.
        overrides_dict: Optional mapping of extra override keys to values,
            appended after the slurm overrides. May be omitted.
        nodes: Value sent as the ``slurm_args.nodes`` override.
        config_base: Directory that ``config`` is joined to.
        qos: Value sent as both ``slurm_args.qos`` and ``slurm_args.account``.
        conda_env: Conda environment path forwarded to ``run_job``.
        launcher: Object exposing ``run_job(config_file=..., conda_env=...,
            overrides=...)``. When ``None``, the rsync checkout's
            ``launch_job`` module is imported and used.
    """
    if launcher is None:
        # Resolve the default lazily: a default that names a module at
        # definition time raises NameError on a fresh kernel when that module
        # was never imported under that name.
        from rsync.llm_mm_aligner.experiments.aws import launch_job as launch_job_rsync
        launcher = launch_job_rsync

    config_file = os.path.join(config_base, config)

    slurm_overrides = [
        ("slurm_args.qos", qos),
        ("slurm_args.account", qos),
        ("slurm_args.nodes", nodes),
    ]
    if overrides_dict is not None:
        slurm_overrides.extend(overrides_dict.items())

    launcher.run_job(
        config_file = config_file,
        conda_env = conda_env,
        # run_job receives a list containing a single list of (key, value) pairs.
        overrides = [slurm_overrides]
    )
    
# launch_job_moe is not imported in any of the visible cells, so bring the MoE
# launch module into scope here; without it this line raises NameError on a
# fresh kernel.
from moe.llm_mm_aligner.experiments.aws import launch_job as launch_job_moe

# Variant of run() that submits through the MoE checkout's launch module.
# NOTE(review): config_base still defaults to the rsync checkout's directory;
# confirm whether MoE runs should also pass the moe config directory.
run_moe = partial(run, launcher=launch_job_moe)
    

functools.partial(<function run at 0x7f7a4a2653a0>, launcher=<module 'moe.llm_mm_aligner.experiments.aws.launch_job' from '/fsx_0/user/tranx/moe/llm_mm_aligner/experiments/aws/launch_job.py'>)

# Stage 1: Test limit on batch_size

In [37]:
# Launch the bz_8x1 MoE config on 8 nodes, overriding only the trainer's
# output directory.
# NOTE(review): this calls run() with its default launcher and config_base
# rather than run_moe — confirm that is intended for a MoE config.
run(
    config="mm10.1_moe/moe/MH22final_70B_ViTH_336px_R1_moe_22x8x2_bz_8x1.json",
    nodes=8,
    overrides_dict={
        "trainer_args.output_dir": "/fsx_0/checkpoints/tranx/moe/70B_moe_22x8x2_n8_bz_8x1"
    }
)



In [40]:
# Config file (relative to run()'s config_base) reused by the next experiment cell.
base_config = "mm10.1_moe/moe/MH22final_70B_ViTH_336px_R1_moe_22x8x2_bz_8x1.json"

In [41]:
# Same base config on 8 nodes, but with gradient_accumulation_steps overridden
# to 2 and a separate ("bz_8x2") output directory.
run(
    config=base_config,
    nodes=8,
    overrides_dict={
        "trainer_args.gradient_accumulation_steps": 2,
        "trainer_args.output_dir": "/fsx_0/checkpoints/tranx/moe/70B_moe_22x8x2_n8_bz_8x2"
    }
)



In [None]:
# Sweep the `bz` setting encoded in the config file name. Each run gets its
# own output directory: a fixed "..._bz_8x1" path would make both sweep runs
# (and the earlier 8x1 run) write into the same directory.
for bz in [16, 32]:
    run(
        config=f"mm10.1_moe/moe/MH22final_70B_ViTH_336px_R1_moe_22x8x2_bz_{bz}x1.json",
        nodes=8,
        overrides_dict={
            "trainer_args.output_dir": f"/fsx_0/checkpoints/tranx/moe/70B_moe_22x8x2_n8_bz_{bz}x1"
        }
    )