In [9]:
import json
from statistics import mean

def load_trace(file_path):
    """Read a JSON trace file from disk and return the decoded object.

    Args:
        file_path: Path of the JSON file to open.

    Returns:
        Whatever ``json.load`` decodes from the file. The callers in this
        notebook index the result with ``'traceEvents'``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # JSON text is UTF-8; pin the encoding so decoding does not depend on the
    # platform's locale default.
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def get_formatted_durations(
    trace_data,
    forward_name="nanotron/parallel/pipeline_parallel/engine.py(26): forward",
    backward_name="nanotron/parallel/pipeline_parallel/engine.py(67): backward",
):
    """Average the durations of the forward and backward trace events.

    Every entry of ``trace_data['traceEvents']`` that has both a ``'name'``
    and a ``'dur'`` key is compared (exact string match) against
    ``forward_name`` and ``backward_name``; matching ``'dur'`` values are
    averaged per side, truncated to an integer, and rendered as text.

    Args:
        trace_data: Mapping with a ``'traceEvents'`` iterable of event dicts.
        forward_name: Exact event name identifying a forward pass. The default
            embeds a source line number, so it may need overriding when the
            profiled code differs.
        backward_name: Exact event name identifying a backward pass (same
            caveat as ``forward_name``).

    Returns:
        ``{"forward": <text>, "backward": <text>}`` where each text has the
        form ``"<ms>ms <us>μs"`` (the duration is treated as microseconds), or
        ``"n/a"`` when no event matched that side.

    Raises:
        KeyError: If ``trace_data`` has no ``'traceEvents'`` key.
    """
    forward_durations = []
    backward_durations = []

    for event in trace_data['traceEvents']:
        if 'name' in event and 'dur' in event:
            if event['name'] == forward_name:
                forward_durations.append(event['dur'])
            elif event['name'] == backward_name:
                backward_durations.append(event['dur'])

    def format_duration(duration):
        # Split a whole number of microseconds into milliseconds + remainder.
        ms, us = divmod(duration, 1000)
        return f"{ms}ms {us}μs"

    def summarize(durations):
        # statistics.mean raises on an empty sequence; report the side as
        # unavailable instead so the other side can still be shown.
        if not durations:
            return "n/a"
        return format_duration(int(mean(durations)))

    return {
        "forward": summarize(forward_durations),
        "backward": summarize(backward_durations),
    }

# Read the profiler trace of the dp-1_tp-1_pp-1_mbz-1 run and report the mean
# forward / backward durations found in it.
trace_file_path = '/fsx/ferdinandmom/ferdinand-hf/bench_cluster/results/llama-1B/8_GPUS/dp-1_tp-1_pp-1_mbz-1/20240624-092049/ip-26-0-163-220_1055475.1719220860453569001.pt.trace.json'
trace_data = load_trace(trace_file_path)
durations = get_formatted_durations(trace_data)

print(f"Average Forward duration: {durations['forward']}")
print(f"Average Backward duration: {durations['backward']}")

Average Forward duration: 72ms 853μs
Average Backward duration: 87ms 61μs


In [10]:
def analyze_communication_vs_computation(trace_data, comm_events=None):
    """Split the summed event durations of a trace into communication vs. the rest.

    Every entry of ``trace_data['traceEvents']`` that has both a ``'dur'`` and
    a ``'name'`` key is classified by case-insensitive substring match of its
    name against the keywords in ``comm_events``. Matching events add their
    duration to the communication total; all others add to the computation
    total.

    NOTE(review): durations are summed over every event regardless of any
    nesting or overlap between events, so the totals are not necessarily
    wall-clock time — confirm this is the intended metric.

    Args:
        trace_data: Mapping with a ``'traceEvents'`` iterable of event dicts.
        comm_events: Optional iterable of keyword strings marking an event as
            communication. Defaults to the built-in NCCL / memcpy list, which
            also contains the broad substrings ``'nccl'`` and ``'comm'``.

    Returns:
        Dict with ``'communication_time'`` and ``'computation_time'`` (sums in
        the same unit as the events' ``'dur'``), plus ``'communication_ratio'``
        and ``'computation_ratio'`` (fractions of the combined total, or 0
        when the total is not positive).

    Raises:
        KeyError: If ``trace_data`` has no ``'traceEvents'`` key.
    """
    if comm_events is None:
        comm_events = [
            'ncclRecv', 'ncclSend',
            'ncclDevKernel_Send', 'ncclDevKernel_Recv',
            'ncclAllReduce', 'ncclReduceScatter', 'ncclAllGather',
            'ncclBroadcast', 'ncclReduce',
            'cudaMemcpyAsync', 'cudaMemcpy',  # Include memory transfers
            'nccl', 'comm'  # Generic communication keywords
        ]

    # Lowercase the keywords once instead of once per event per keyword.
    comm_keywords = tuple(keyword.lower() for keyword in comm_events)

    comm_time = 0
    comp_time = 0

    for event in trace_data['traceEvents']:
        if 'dur' in event and 'name' in event:
            duration = event['dur']
            name = event['name'].lower()  # case-insensitive matching

            if any(keyword in name for keyword in comm_keywords):
                comm_time += duration
            else:
                comp_time += duration

    total_time = comm_time + comp_time
    # Guard against an empty trace (no timed events) to avoid dividing by zero.
    comm_ratio = comm_time / total_time if total_time > 0 else 0
    comp_ratio = comp_time / total_time if total_time > 0 else 0

    return {
        'communication_time': comm_time,
        'computation_time': comp_time,
        'communication_ratio': comm_ratio,
        'computation_ratio': comp_ratio
    }

# Load the profiler trace of the dp-1_tp-2_pp-2_mbz-256 run and report how its
# summed event durations split between communication and computation.
trace_file_path = '/fsx/ferdinandmom/ferdinand-hf/bench_cluster/results/llama-1B/8_GPUS/dp-1_tp-2_pp-2_mbz-256/20240624-095924/ip-26-0-163-220_1136603.1719223186534915479.pt.trace.json'
trace_data = load_trace(trace_file_path)
results = analyze_communication_vs_computation(trace_data)

print(f"Communication time: {results['communication_time']} microseconds")
print(f"Computation time: {results['computation_time']} microseconds")
print(f"Communication ratio: {results['communication_ratio']:.2%}")
print(f"Computation ratio: {results['computation_ratio']:.2%}")
# The analysis function guards its own ratios against a zero total; apply the
# same care here so a trace without computation events does not raise
# ZeroDivisionError.
if results['computation_time'] > 0:
    print(f"Communication to Computation ratio: {results['communication_time'] / results['computation_time']:.2f}")
else:
    print("Communication to Computation ratio: n/a (no computation time recorded)")

Communication time: 563010 microseconds
Computation time: 110993888 microseconds
Communication ratio: 0.50%
Computation ratio: 99.50%
Communication to Computation ratio: 0.01
