In [83]:
%reload_ext autoreload
%autoreload 2

import socket
import psutil
import sys 
import os
from typing import Any
from functools import partial
import json
from pprint import pprint
from collections import OrderedDict
from datetime import datetime
import glob
import json 

In [89]:
def find_most_recent_config_json(directory, file_name='config.json', max_depth=2):
    """
    Find the most recently written config.json file in a directory tree, up to 2 levels deep.
    Args:
        directory (str): The root directory to search in.
    Returns:
        str: The path to the most recently written config.json file, or None if no such file is found.
    """
    most_recent_file = None
    most_recent_timestamp = 0
    for root, dirs, files in os.walk(directory):
        depth = root.replace(directory, '').count(os.sep)
        if max_depth is not None and depth > max_depth:
            continue
        for file in files:
            if file == file_name:
                file_path = os.path.join(root, file)
                timestamp = os.path.getmtime(file_path)
                if timestamp > most_recent_timestamp:
                    most_recent_timestamp = timestamp
                    most_recent_file = file_path
                    
    if most_recent_file is not None:
        with open(most_recent_file, 'r') as f:
            data = json.load(f)
            return data
        
    return None


def get_trainer_state(output_dir: str):
    checkpoint_dirs = glob.glob(f"{output_dir}/checkpoint-*")
    checkpoints = [int(d.split("/")[-1].split("-")[-1]) for d in checkpoint_dirs]
    if len(checkpoints) == 0:
        return None, None 
    
    checkpoint_latest = max(checkpoints)
    # print(f"{checkpoint_latest=}")

    trainer_state_file = os.path.join(
        output_dir,
        f"checkpoint-{checkpoint_latest}",
        "trainer_state.json"
    )

    with open(trainer_state_file, 'r') as file:
        state = json.load(file)
        
    return checkpoint_latest, state


def get_train_summary(state, config):
    grad_accumulation_steps = config['trainer_args']['gradient_accumulation_steps']
    num_nodes = config['slurm_args']['nodes']
    num_gpus = config['slurm_args']['gpus_per_task']

    mp = config['trainer_args'].get("model_parallel_size", 1)
    cp = config['trainer_args'].get("context_parallel_size", 1)
    pp = config['trainer_args'].get("pipeline_parallel_size", 1)

    num_batches_per_step =  state['logging_steps'] * (num_gpus * num_nodes * grad_accumulation_steps) / (mp * pp * cp)
    num_input_tokens = 0
    num_modality_tokens = 0

    for record in state['log_history'][1:]:
        # skipping the first record, which always report step=1
        # the subsequent record are for step = 10, 20, 30, ... (assuming logging_step = 10)
        num_input_tokens += num_batches_per_step * record['num_input_tokens_per_batch']
        num_modality_tokens += num_batches_per_step * record['num_modality_embs_per_batch']
        
    num_input_tokens, num_modality_tokens
    total_tokens = num_input_tokens + num_modality_tokens

    # calculate VE tokens
    chunk_size = config['trainer_args']['chunk_size']
    patch_size = 14
    cls_token = False
    n_prefix_embs = config['trainer_args']['n_prefix_embs']

    compression_ratio = ((chunk_size/patch_size)**2 + int(cls_token)) / n_prefix_embs

    summary = {
        "num_steps": state['global_step'],
        "billion_text_tokens": int(num_input_tokens/1e9),
        "billion_modality_tokens": int(num_modality_tokens/1e9),
        "billion_llm_tokens": int(total_tokens/1e9),
        "billion_visual_tokens": int(compression_ratio * num_modality_tokens/1e9)
    }

    return summary

In [4]:
base_dir = "/fsx_3/bucket/tranx/checkpoints/perceiver_sizing"
output_dirs = glob.glob(f"{base_dir}/Llama3*")
for d in output_dirs:
    print(d)
    name = d.split("/")[-1]
    checkpoint_latest, state = get_trainer_state(d)
    
    print(f"{name}, {checkpoint_latest=}")

/fsx_3/bucket/tranx/checkpoints/perceiver_sizing/Llama3.1_70B_ViTG_layers18_dim4096_heads32_latents64_bz64_step4
Llama3.1_70B_ViTG_layers18_dim4096_heads32_latents64_bz64_step4, checkpoint_latest=10600
/fsx_3/bucket/tranx/checkpoints/perceiver_sizing/Llama3.1_70B_ViTG_layers22_dim4096_heads32_latents64_bz64_step8
Llama3.1_70B_ViTG_layers22_dim4096_heads32_latents64_bz64_step8, checkpoint_latest=11200
/fsx_3/bucket/tranx/checkpoints/perceiver_sizing/Llama3.1_70B_ViTG_layers22_dim4096_heads32_latents128_bz32_step8
Llama3.1_70B_ViTG_layers22_dim4096_heads32_latents128_bz32_step8, checkpoint_latest=10100
/fsx_3/bucket/tranx/checkpoints/perceiver_sizing/Llama3.1_70B_ViTG_layers26_dim4096_heads32_latents64_bz64_step4
Llama3.1_70B_ViTG_layers26_dim4096_heads32_latents64_bz64_step4, checkpoint_latest=10300
/fsx_3/bucket/tranx/checkpoints/perceiver_sizing/Llama3.1_70B_ViTG_layers14_dim4096_heads32_latents64_bz64_step4
Llama3.1_70B_ViTG_layers14_dim4096_heads32_latents64_bz64_step4, checkpoint_l

In [90]:
output_dir = "/fsx_3/bucket/tranx/checkpoints/perceiver_sizing/Llama3.1_70B_ViTG_layers18_dim4096_heads32_latents64_bz64_step4"

checkpoint_latest, state = get_trainer_state(d)
config = find_most_recent_config_json(output_dir)
summary = get_train_summary(state, config)
summary['last_checkpoint'] = checkpoint_latest
summary

{'num_steps': 10300,
 'billion_text_tokens': 57,
 'billion_modality_tokens': 51,
 'billion_llm_tokens': 109,
 'billion_visual_tokens': 616,
 'last_checkpoint': 10300}

In [None]:
# To-dos:
- count GPU hours total
- count GPU hours per step 
- smoothing loss
- fit losses/evals on log scale