In [None]:
import time
print(f"Starting notebook: {time.time()}")
from accelerate import notebook_launcher
import torch


In [None]:
time_record = {}

In [13]:

def train_fn():
    start_time = time.time()
    # print(f"Initializing process group: {time.time()}")
    torch.distributed.init_process_group(backend='nccl')
    end_time = time.time()
    rank = torch.distributed.get_rank()
    def print_rank(message):
        print(f"[rank {rank}] {message}")
    # print_rank(f"Time taken to initialize process group: {end_time - start_time} seconds")
    print_rank(f"Running on rank {torch.distributed.get_rank()}")

    torch.cuda.set_device(rank)

    # Start benchmarking linear layers with TP.
    import torch.nn as nn
    tp = torch.distributed.get_world_size()
    head_dim = 128
    num_qo_heads = 32 // tp
    num_kv_heads = 8 // tp

    hidden_dim = head_dim * num_qo_heads
    l1 = nn.Linear(hidden_dim, hidden_dim * 4, bias=False)
    # print(l1.weight.shape)
    
    # test the forward and backward time of l1
    K = 1024
    ctx_length = 16 * K
    x = torch.randn(ctx_length, hidden_dim).requires_grad_(True)
    z = torch.randn(ctx_length, hidden_dim * 4).requires_grad_(False)

    import numpy as np

    fw_times = []
    bw_times = []

    for _ in range(10):
        fw_start_evt = torch.cuda.Event(enable_timing=True)
        fw_end_evt = torch.cuda.Event(enable_timing=True)
        bw_start_evt = torch.cuda.Event(enable_timing=True)
        bw_end_evt = torch.cuda.Event(enable_timing=True)

        fw_start_evt.record()
        y = l1(x)
        fw_end_evt.record()

        bw_start_evt.record()
        y.backward(z)
        bw_end_evt.record()
        
        torch.cuda.synchronize()
        fw_time = fw_start_evt.elapsed_time(fw_end_evt)
        bw_time = bw_start_evt.elapsed_time(bw_end_evt)
        
        fw_times.append(fw_time)
        bw_times.append(bw_time)

    fw_avg = np.mean(fw_times)
    fw_std = np.std(fw_times)
    bw_avg = np.mean(bw_times)
    bw_std = np.std(bw_times)

    print_rank(f"fw_time={fw_avg:.2f} ± {fw_std:.2f}, bw_time={bw_avg:.2f} ± {bw_std:.2f}")


# print(f"Starting notebook launcher: {time.time()}")
notebook_launcher(train_fn, num_processes=4)

Launching training on 4 GPUs.
[rank 1] Running on rank 1
[rank 0] Running on rank 0
[rank 2] Running on rank 2
[rank 3] Running on rank 3
[rank 1] fw_time=186.25 ± 48.66, bw_time=667.63 ± 66.66
[rank 0] fw_time=190.75 ± 47.66, bw_time=699.66 ± 100.15
[rank 2] fw_time=174.03 ± 37.51, bw_time=711.10 ± 228.03
[rank 3] fw_time=208.06 ± 59.54, bw_time=674.61 ± 274.36


In [104]:
import time
import torch
import torch.nn as nn
import torch.profiler

def profile_linear(
    device: int = 0,
    input_dim: int = 128 * (32 // 1),
    output_dim: int = 128 * (32 // 1) * 4,
    seq_len: int = 1024,
    warmup: int = 2,
    active: int = 10,
    log_dir: str = "./profiler_logs/single_gpu",
):
    """
    Profiles a single nn.Linear(hidden_dim, hidden_dim*4) forward+backward on one GPU.

    Args:
      device:       CUDA device index.
      hidden_dim:   Input feature size.
      seq_len:      Sequence length (number of rows).
      warmup:       Number of warm‑up iterations (unrecorded).
      active:       Number of profiled iterations.
      log_dir:      Directory where TensorBoard/Chrome‑trace logs go.
    """
    # 1) Set up device & model
    torch.cuda.set_device(device)
    model = torch.compile(nn.Linear(input_dim, output_dim, bias=False, device=device))

    # 2) Dummy inputs
    x = torch.randn(seq_len, input_dim, device=device, requires_grad=True)
    z = torch.randn(seq_len, output_dim, device=device)

    # 3) Profiler schedule & handler
    schedule = torch.profiler.schedule(
        wait=warmup,
        warmup=warmup,
        active=active,
        repeat=1,
    )
    tb_handler = torch.profiler.tensorboard_trace_handler(
        dir_name=log_dir
    )
    from torch.cuda import nvtx

    fw_times = []
    bw_times = []

    with torch.profiler.profile(
        schedule=schedule,
        on_trace_ready=tb_handler,
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
        with_flops=True,
        # use_cuda=True,
    ) as prof:
        total_steps = warmup + active
        for step in range(total_steps):
            # Use NVTX for marking regions
            nvtx.range_push(f"Step {step}")

            # Timing for forward pass
            forward_start_event = torch.cuda.Event(enable_timing=True)
            forward_end_event = torch.cuda.Event(enable_timing=True)

            forward_start_event.record()
            nvtx.range_push("Forward Pass")
            y = model(x)
            nvtx.range_pop()
            forward_end_event.record()

            torch.cuda.synchronize()
            forward_time = forward_start_event.elapsed_time(forward_end_event)

            # Timing for backward pass
            backward_start_event = torch.cuda.Event(enable_timing=True)
            backward_end_event = torch.cuda.Event(enable_timing=True)

            backward_start_event.record()
            nvtx.range_push("Backward Pass")
            y.backward(z)
            nvtx.range_pop()
            backward_end_event.record()

            torch.cuda.synchronize()
            backward_time = backward_start_event.elapsed_time(backward_end_event)

            # if step == warmup - 1:
            #     print(f"[device {device}] warm‑up done, starting profiling")
            # if (step >= warmup):
            #     total_time = forward_time + backward_time
            #     print(f"[device {device}] prof step {step-warmup}/{active}, forward: {forward_time:.2f} ms, backward: {backward_time:.2f} ms, total: {total_time:.2f} ms")

            nvtx.range_pop()
            prof.step()
            if (step >= warmup):
                fw_times.append(forward_time)
                bw_times.append(backward_time)
    
    # print(f"[device {device}] profiling complete. Logs in: {log_dir}")
    return fw_times, bw_times

In [110]:

import numpy as np

# Initialize lists to store forward and backward times for each seq_len
all_fw_times = []
all_bw_times = []

# Loop over different sequence lengths
tp = 1
K = 1024
for seq_len_factor in range(1, 10):
    seq_len = K * (2 ** seq_len_factor)
    fw_times, bw_times = profile_linear(
        device=0,
        input_dim=128 * (32 // tp),
        output_dim=128 * (32 // tp) * 4,
        seq_len=seq_len,
        # warmup=10,
        warmup=5,
        active=8,
        log_dir=f"./my_logs_seq_len_{seq_len}",
    )
    all_fw_times.append(fw_times)
    all_bw_times.append(bw_times)

    fw_avg = np.mean(fw_times)
    fw_std = np.std(fw_times)
    bw_avg = np.mean(bw_times)
    bw_std = np.std(bw_times)

    print(f"seq_len=2 ** {seq_len_factor}K, fw_time={fw_avg:.2f} ± {fw_std:.2f}, bw_time={bw_avg:.2f} ± {bw_std:.2f}")


  with torch.profiler.profile(


seq_len=2 ** 1K, fw_time=14.80 ± 0.51, bw_time=29.52 ± 0.07
seq_len=2 ** 2K, fw_time=29.15 ± 0.12, bw_time=58.10 ± 0.19
seq_len=2 ** 3K, fw_time=58.08 ± 0.10, bw_time=115.63 ± 0.09


  with torch.profiler.profile(


seq_len=2 ** 4K, fw_time=116.00 ± 0.16, bw_time=232.47 ± 0.13
seq_len=2 ** 5K, fw_time=231.76 ± 0.15, bw_time=461.06 ± 0.21
seq_len=2 ** 6K, fw_time=462.61 ± 0.21, bw_time=920.40 ± 0.21
seq_len=2 ** 7K, fw_time=924.86 ± 0.49, bw_time=1838.88 ± 0.35
seq_len=2 ** 8K, fw_time=1851.89 ± 0.14, bw_time=3755.85 ± 1.12


OutOfMemoryError: CUDA out of memory. Tried to allocate 32.00 GiB. GPU 0 has a total capacity of 79.14 GiB of which 30.35 GiB is free. Including non-PyTorch memory, this process has 48.78 GiB memory in use. Of the allocated memory 40.27 GiB is allocated by PyTorch, and 8.00 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [107]:
torch.cuda.empty_cache()

In [108]:
"""
seq_len=2 ** 1K, fw_time=14.80 ± 0.51, bw_time=29.52 ± 0.07
seq_len=2 ** 2K, fw_time=29.15 ± 0.12, bw_time=58.10 ± 0.19
seq_len=2 ** 3K, fw_time=58.08 ± 0.10, bw_time=115.63 ± 0.09
seq_len=2 ** 4K, fw_time=116.00 ± 0.16, bw_time=232.47 ± 0.13
seq_len=2 ** 5K, fw_time=231.76 ± 0.15, bw_time=461.06 ± 0.21
seq_len=2 ** 6K, fw_time=462.61 ± 0.21, bw_time=920.40 ± 0.21
seq_len=2 ** 7K, fw_time=924.86 ± 0.49, bw_time=1838.88 ± 0.35
seq_len=2 ** 8K, fw_time=1851.89 ± 0.14, bw_time=3755.85 ± 1.12
"""

'\nseq_len=2 ** 1, fw_time=7.40 ± 0.09, bw_time=15.54 ± 0.05\nseq_len=2 ** 2, fw_time=14.55 ± 0.07, bw_time=29.48 ± 0.04\nseq_len=2 ** 3, fw_time=21.76 ± 0.09, bw_time=44.28 ± 0.07\nseq_len=2 ** 4, fw_time=29.09 ± 0.08, bw_time=58.08 ± 0.06\nseq_len=2 ** 5, fw_time=36.10 ± 0.09, bw_time=72.75 ± 0.05\nseq_len=2 ** 6, fw_time=43.57 ± 0.08, bw_time=87.49 ± 0.06\nseq_len=2 ** 7, fw_time=50.46 ± 0.08, bw_time=101.72 ± 0.04\nseq_len=2 ** 8, fw_time=58.03 ± 0.08, bw_time=115.67 ± 0.07\nseq_len=2 ** 9, fw_time=65.68 ± 0.09, bw_time=131.83 ± 0.17\n'