In [83]:
%reload_ext autoreload
%autoreload 2

import socket
import psutil
import sys 
import os
from typing import Any
from functools import partial
import json
from pprint import pprint
from collections import OrderedDict
from datetime import datetime
import glob
import json 

In [89]:
def find_most_recent_config_json(directory, file_name='config.json', max_depth=2):
    """
    Find the most recently written config.json file in a directory tree, up to 2 levels deep.
    Args:
        directory (str): The root directory to search in.
    Returns:
        str: The path to the most recently written config.json file, or None if no such file is found.
    """
    most_recent_file = None
    most_recent_timestamp = 0
    for root, dirs, files in os.walk(directory):
        depth = root.replace(directory, '').count(os.sep)
        if max_depth is not None and depth > max_depth:
            continue
        for file in files:
            if file == file_name:
                file_path = os.path.join(root, file)
                timestamp = os.path.getmtime(file_path)
                if timestamp > most_recent_timestamp:
                    most_recent_timestamp = timestamp
                    most_recent_file = file_path
                    
    if most_recent_file is not None:
        with open(most_recent_file, 'r') as f:
            data = json.load(f)
            return data
        
    return None


def get_trainer_state(output_dir: str):
    checkpoint_dirs = glob.glob(f"{output_dir}/checkpoint-*")
    checkpoints = [int(d.split("/")[-1].split("-")[-1]) for d in checkpoint_dirs]
    if len(checkpoints) == 0:
        return None, None 
    
    checkpoint_latest = max(checkpoints)
    # print(f"{checkpoint_latest=}")

    trainer_state_file = os.path.join(
        output_dir,
        f"checkpoint-{checkpoint_latest}",
        "trainer_state.json"
    )

    with open(trainer_state_file, 'r') as file:
        state = json.load(file)
        
    return checkpoint_latest, state


def get_train_summary(state, config):
    grad_accumulation_steps = config['trainer_args']['gradient_accumulation_steps']
    num_nodes = config['slurm_args']['nodes']
    num_gpus = config['slurm_args']['gpus_per_task']

    mp = config['trainer_args'].get("model_parallel_size", 1)
    cp = config['trainer_args'].get("context_parallel_size", 1)
    pp = config['trainer_args'].get("pipeline_parallel_size", 1)

    num_batches_per_step =  state['logging_steps'] * (num_gpus * num_nodes * grad_accumulation_steps) / (mp * pp * cp)
    num_input_tokens = 0
    num_modality_tokens = 0

    for record in state['log_history'][1:]:
        # skipping the first record, which always report step=1
        # the subsequent record are for step = 10, 20, 30, ... (assuming logging_step = 10)
        num_input_tokens += num_batches_per_step * record['num_input_tokens_per_batch']
        num_modality_tokens += num_batches_per_step * record['num_modality_embs_per_batch']
        
    num_input_tokens, num_modality_tokens
    total_tokens = num_input_tokens + num_modality_tokens

    # calculate VE tokens
    chunk_size = config['trainer_args']['chunk_size']
    patch_size = 14
    cls_token = False
    n_prefix_embs = config['trainer_args']['n_prefix_embs']

    compression_ratio = ((chunk_size/patch_size)**2 + int(cls_token)) / n_prefix_embs

    summary = {
        "num_steps": state['global_step'],
        "billion_text_tokens": int(num_input_tokens/1e9),
        "billion_modality_tokens": int(num_modality_tokens/1e9),
        "billion_llm_tokens": int(total_tokens/1e9),
        "billion_visual_tokens": int(compression_ratio * num_modality_tokens/1e9)
    }

    return summary

In [4]:
base_dir = "/fsx_3/bucket/tranx/checkpoints/perceiver_sizing"
output_dirs = glob.glob(f"{base_dir}/Llama3*")
for d in output_dirs:
    print(d)
    name = d.split("/")[-1]
    checkpoint_latest, state = get_trainer_state(d)
    
    print(f"{name}, {checkpoint_latest=}")

/fsx_3/bucket/tranx/checkpoints/perceiver_sizing/Llama3.1_70B_ViTG_layers18_dim4096_heads32_latents64_bz64_step4
Llama3.1_70B_ViTG_layers18_dim4096_heads32_latents64_bz64_step4, checkpoint_latest=10600
/fsx_3/bucket/tranx/checkpoints/perceiver_sizing/Llama3.1_70B_ViTG_layers22_dim4096_heads32_latents64_bz64_step8
Llama3.1_70B_ViTG_layers22_dim4096_heads32_latents64_bz64_step8, checkpoint_latest=11200
/fsx_3/bucket/tranx/checkpoints/perceiver_sizing/Llama3.1_70B_ViTG_layers22_dim4096_heads32_latents128_bz32_step8
Llama3.1_70B_ViTG_layers22_dim4096_heads32_latents128_bz32_step8, checkpoint_latest=10100
/fsx_3/bucket/tranx/checkpoints/perceiver_sizing/Llama3.1_70B_ViTG_layers26_dim4096_heads32_latents64_bz64_step4
Llama3.1_70B_ViTG_layers26_dim4096_heads32_latents64_bz64_step4, checkpoint_latest=10300
/fsx_3/bucket/tranx/checkpoints/perceiver_sizing/Llama3.1_70B_ViTG_layers14_dim4096_heads32_latents64_bz64_step4
Llama3.1_70B_ViTG_layers14_dim4096_heads32_latents64_bz64_step4, checkpoint_l

In [90]:
output_dir = "/fsx_3/bucket/tranx/checkpoints/perceiver_sizing/Llama3.1_70B_ViTG_layers18_dim4096_heads32_latents64_bz64_step4"

checkpoint_latest, state = get_trainer_state(d)
config = find_most_recent_config_json(output_dir)
summary = get_train_summary(state, config)
summary['last_checkpoint'] = checkpoint_latest
summary

{'num_steps': 10300,
 'billion_text_tokens': 57,
 'billion_modality_tokens': 51,
 'billion_llm_tokens': 109,
 'billion_visual_tokens': 616,
 'last_checkpoint': 10300}

In [None]:
# To-dos:
- count GPU hours total
- count GPU hours per step 
- smoothing loss
- fit losses/evals on log scale

# FLOPS calculator

In [64]:
def flops_metaclip(tokens, depth, width, mlp):
    
    N = depth*(width*(width + 1)*4 + 4*width + 2*(width + 1)*mlp)
    print(f"Number of parameters: {N:e}")
    
    flops = tokens*depth*(4*width*width + 2*tokens*width + 2*mlp*width)
    
    return flops
    
flops = flops_metaclip(
    tokens=(392/14)**2,
    depth=50,
    width=1536,
    mlp=1536*5.833333334
)

print(f"flops: {flops:.2g}")

Number of parameters: 1.849626e+09
flops: 1.5e+12


In [55]:
def calculate_detailed_flops(L, d_in, d_proj, d_kv, d_ff, num_layers):
    # Input Projector FLOPs
    flops_in_projector = 2 * L * d_in * d_proj

    # Attention FLOPs per layer
    flops_q = 2 * L * d_proj * d_proj
    flops_kv = 2 * L * d_proj * d_kv * 2  # Key and Value
    flops_attention_scores = L**2 * d_kv  # QK^T
    flops_attention_output = 2 * L * d_kv * d_proj  # Weighted sum and output projection

    flops_attention_total = flops_q + flops_kv + flops_attention_scores + flops_attention_output

    # Feed-Forward Network (SwiGLU) FLOPs per layer
    flops_ff_expand = 2 * L * d_proj * d_ff  # Expansion to d_ff
    flops_ff_reduce = 2 * L * (d_ff // 2) * d_proj  # Reduction back to d_proj

    flops_ff_total = flops_ff_expand + flops_ff_reduce

    # Total FLOPs per layer
    flops_per_layer = flops_attention_total + flops_ff_total

    # Total FLOPs for all layers
    flops_all_layers = num_layers * flops_per_layer

    # Total FLOPs for the entire architecture
    return flops_in_projector + flops_all_layers

def calculate_simplified_flops(L, d_in, d_proj, d_kv, d_ff, num_layers):
    # Calculate number of parameters (N)
    # Input projector parameters
    params_in_projector = d_in * d_proj

    # Attention mechanism parameters per layer
    params_q = d_proj * d_proj  # Query
    params_kv = 2 * d_proj * d_kv  # Key and Value
    params_output = d_kv * d_proj  # Output projection

    params_attention = params_q + params_kv + params_output

    # Feed-forward network parameters per layer (SwiGLU)
    params_ff_expand = d_proj * d_ff  # Expansion
    params_ff_reduce = (d_ff // 2) * d_proj  # Reduction
    params_ff = params_ff_expand + params_ff_reduce

    # Total parameters per layer
    params_per_layer = params_attention + params_ff

    # Total parameters for all layers
    params_all_layers = num_layers * params_per_layer

    # Total parameters in the model
    total_params = params_in_projector + params_all_layers
    print(f"total_params: {total_params:e}")

    # FLOPs estimation using 2 * N * D
    return 2 * total_params * L

# Example Usage
L = 784  # Sequence length (example)
d_in = 1536  # Input dimension
d_proj = 4096  # Projected dimension
d_kv = 1024  # Key/Value dimension
d_ff = 24576  # Feed-forward expansion
num_layers = 22  # Number of layers

# Calculate FLOPs
flops_detailed = calculate_detailed_flops(L, d_in, d_proj, d_kv, d_ff, num_layers)
flops_simplified = calculate_simplified_flops(L, d_in, d_proj, d_kv, d_ff, num_layers)

# Print results
print(f"Detailed Method FLOPs: {flops_detailed:e}")
print(f"Simplified Method FLOPs: {flops_simplified:e}")


total_params: 3.974103e+09
Detailed Method FLOPs: 6.245241e+12
Simplified Method FLOPs: 6.231394e+12


In [53]:
perceiver_N = 784
perceiver_D = 4*1e9

perceiver_D*perceiver_N /1e12


3.136