
Chakra Execution Trace Collection ‐ A Comprehensive Guide on Merging PyTorch and Kineto Traces

Joongun Park edited this page Mar 13, 2024 · 26 revisions

Authors: Saeed Rashidi, Joongun Park, and Taekyung Heo

1. Introduction

This document outlines how to collect Chakra execution traces and simulate them for performance projection and design space exploration. It covers collecting PyTorch execution traces (ET) and Kineto traces, linking them, and converting the result into Chakra execution traces, a standardized format that captures both CPU and GPU operator information.

2. Overview of Trace Collection and Simulation Methodology

Chakra execution traces and their toolchain make it possible to simulate a workload's execution on a simulator. The figure below illustrates the end-to-end flow. The process begins by collecting traces from a PyTorch model. Two types of traces are collected: the PyTorch ET and the Kineto trace. Both are needed because each covers information the other lacks: PyTorch ETs capture CPU operators with explicit dependencies between them, while Kineto traces record GPU operators along with their start and end times. The table below summarizes how the two trace types differ and what role each plays.

After collection, a merger tool (trace_link.py) combines the two traces into a single execution trace called PyTorch ET+. This format follows the PyTorch ET schema but additionally encodes GPU operators and their dependencies. The PyTorch ET+ is then converted into the Chakra schema with the converter (et_converter.py). Finally, the Chakra trace can be run on any Chakra-compatible simulator; ASTRA-sim currently serves as the reference implementation.

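| Trace Data Category  | PyTorch ET            | Kineto Trace          |
|----------------------|-----------------------|-----------------------|
| Event Timestamps     | No                    | Yes                   |
| Host Events          | Yes                   | Yes                   |
| Device (GPU) Events  | No                    | Yes                   |
| Operator Inputs      | Yes                   | Partial               |
| Operator Outputs     | Yes                   | No                    |
| Events Hierarchy     | Explicit (call stack) | Implicit (time-based) |
| Operator Schema      | Yes                   | No                    |
| Data Dependencies    | Yes                   | No                    |
| Comms Data           | Yes                   | No                    |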

3. From Raw Traces to Chakra: A Step-by-Step Conversion Guide

This section is a step-by-step guide to collecting traces and converting them into Chakra traces. It focuses on collecting PyTorch execution traces and Kineto traces simultaneously. For clarity, the collection process for each trace type is explained individually before the simultaneous collection method is described. Note that the procedures described here have been tested with PyTorch version 2.1.2.

Collecting PyTorch Execution Traces

PyTorch execution traces are collected with the ExecutionTraceObserver implemented in PyTorch. The process has three steps: instantiate the observer, register a callback that specifies the output file, and start and stop the observer around the code to be traced. You can collect as many execution traces as you like, but for training jobs it is advisable to profile a single iteration. Below is an example script for collecting an execution trace:

from torch.profiler import ExecutionTraceObserver

et = ExecutionTraceObserver()
et.register_callback("pytorch_et.json")  # output file for the execution trace
et.start()
...  # run the code to be traced here, e.g., one training iteration
et.stop()
et.unregister_callback()

A full example of collecting execution traces with the ExecutionTraceObserver can be found in the PARAM benchmark code.
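Every start() has to be matched by stop() and unregister_callback(), including when the traced code raises an exception. A small context manager is one way to guarantee this. The sketch below is a hypothetical helper: execution_trace is not a PyTorch API. It only calls the four observer methods used in the script above, and it takes the observer as an argument, so the same wrapper can also be exercised with a stand-in object.

```python
from contextlib import contextmanager
from typing import Any, Iterator


@contextmanager
def execution_trace(observer: Any, output_file_path: str) -> Iterator[Any]:
    """Hypothetical helper: record a PyTorch execution trace for a `with` block.

    `observer` is expected to be a torch.profiler.ExecutionTraceObserver (or any
    object with the same register_callback/start/stop/unregister_callback methods).
    """
    observer.register_callback(output_file_path)  # where the trace is written
    observer.start()
    try:
        yield observer
    finally:
        # Stop recording and detach the observer even if the traced code raised.
        observer.stop()
        observer.unregister_callback()


# Usage (requires PyTorch):
#   from torch.profiler import ExecutionTraceObserver
#   with execution_trace(ExecutionTraceObserver(), "pytorch_et.json"):
#       ...  # run one training iteration here
```

Passing the observer in, rather than constructing it inside the helper, keeps the wrapper free of a hard PyTorch import.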

Collecting Kineto Traces

Next, collect Kineto traces, which capture the GPU operators executed by the model. Kineto traces are collected with torch.profiler.profile. Supply the arguments shown below so that both CPU and CUDA activities are recorded. Also call prof.step() at the end of every iteration so that the profiler can follow its schedule. The schedule has three phases:

- an optional wait phase;
- a warm-up phase, during which the profiler begins tracing but discards the results;
- an active phase, during which the profiler traces and records data.

With wait=0, warmup=10, and active=1, the first ten iterations are warm-up and the eleventh is recorded, so the loop must run at least eleven iterations. Further details can be found in the PyTorch documentation.

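import torch

def trace_handler(prof):
    # Called when the active phase ends; writes the Kineto trace in Chrome trace format.
    prof.export_chrome_trace("./kineto_trace.json")

def main():
    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        schedule=torch.profiler.schedule(
            wait=0,
            warmup=10,
            active=1),
        record_shapes=True,
        on_trace_ready=trace_handler,
    ) as prof:
        for _ in range(11):  # 10 warm-up iterations + 1 active iteration
            ...  # one training iteration goes here
            prof.step()  # mark the end of the iteration

if __name__ == "__main__":
    main()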
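The schedule determines which loop iteration the Kineto trace actually covers. This matters later, when the Kineto trace has to be aligned with an execution trace. As a sketch, the hypothetical function below mirrors the wait/warmup/active/repeat/skip_first semantics documented for torch.profiler.schedule. It assumes prof.step() is called exactly once at the end of each iteration, so that step n is the (0-based) n-th iteration. It is an illustration, not PyTorch code.

```python
from typing import List


def recorded_steps(total_steps: int, *, wait: int, warmup: int, active: int,
                   repeat: int = 0, skip_first: int = 0) -> List[int]:
    """Return the 0-based loop iterations that fall in an active (recording) phase.

    Mirrors the documented semantics of torch.profiler.schedule, assuming
    prof.step() is called exactly once at the end of every iteration.
    """
    cycle = wait + warmup + active
    steps = []
    for step in range(total_steps):
        if step < skip_first:
            continue  # skipped entirely before the first cycle
        offset = step - skip_first
        if repeat > 0 and offset // cycle >= repeat:
            continue  # all requested cycles are already done
        if offset % cycle >= wait + warmup:
            steps.append(step)  # past wait + warm-up: this iteration is recorded
    return steps


print(recorded_steps(20, wait=0, warmup=10, active=1))  # [10]
```

With the settings above, only iteration 10 (the eleventh) is recorded. When both traces are collected together, the execution trace must cover this same iteration. With repeat=0 the cycle restarts, so a longer loop would record again at iteration 21.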

Simultaneous Collection of PyTorch Execution and Kineto Traces

For the traces to be linked in the following steps, the PyTorch execution trace and the Kineto trace must be collected simultaneously and must cover the same iteration of the model. To achieve this, run the ExecutionTraceObserver inside the torch.profiler.profile context and enable it only for the iteration that the profiler's active phase records. Here is an adapted example demonstrating this method:

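import torch
from torch.profiler import ExecutionTraceObserver, profile

def trace_handler(prof):
    prof.export_chrome_trace("kineto_trace.json")

def main():
    et = ExecutionTraceObserver()
    et.register_callback("pytorch_et.json")
    # The observer is not started here; it is enabled only around the traced iteration.

    with profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        schedule=torch.profiler.schedule(wait=0, warmup=10, active=1),
        on_trace_ready=trace_handler
    ) as prof:
        for epoch in ...:  # must run at least 11 iterations
            ...  # one training iteration
            # With wait=0 and warmup=10, Kineto records iteration 10 (0-based).
            # Starting the observer at the end of iteration 9 and stopping it at
            # the end of iteration 10 makes both traces cover the same iteration.
            if epoch == 10:
                et.stop()
            if epoch == 9:
                et.start()
            prof.step()

    et.stop()
    et.unregister_callback()

if __name__ == "__main__":
    main()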
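Before linking, it can be worth confirming that a run actually produced a usable pair of files. For example, check that the observer was started at some point and that the loop reached the profiler's active step. The helper below is a hypothetical sanity check. It relies only on the top-level "nodes" key of the PyTorch ET (visible in the excerpt in Section 4) and on the "traceEvents" array of the Chrome trace format written by export_chrome_trace.

```python
import json
from pathlib import Path
from typing import Tuple


def check_trace_pair(et_path: str, kineto_path: str) -> Tuple[int, int]:
    """Hypothetical pre-link check for one PyTorch ET / Kineto trace pair.

    Returns (number of ET nodes, number of Kineto trace events) and raises
    ValueError if either file is missing the content it is expected to have.
    """
    et = json.loads(Path(et_path).read_text())
    kineto = json.loads(Path(kineto_path).read_text())
    if not et.get("nodes"):
        raise ValueError(f"{et_path}: no 'nodes' recorded; was the observer started?")
    if not kineto.get("traceEvents"):
        raise ValueError(f"{kineto_path}: no 'traceEvents'; did the active step run?")
    return len(et["nodes"]), len(kineto["traceEvents"])


# Usage:
#   print(check_trace_pair("pytorch_et.json", "kineto_trace.json"))
```

Note that json.loads raises on an execution trace file that was never finalized, for example if the process exited before the observer was stopped. This is another useful signal.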

Merging Traces with trace_link.py

Next, merge the PyTorch execution trace with the Kineto trace using trace_link.py. This tool combines a PyTorch ET and a Kineto trace into a single, unified PyTorch ET+. Note that the merge must be performed separately for each pair of PyTorch execution trace and Kineto trace. The commands below show how:

$ cd param/train/compute/python
$ python3 ./tools/trace_link.py --et-file <PyTorch ET Path> --kineto-file <Kineto Trace Path> --exact-match --annotation 'enumerate(DataLoader)#_MultiProcessingDataLoaderIter.__next__'
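Because trace_link.py handles one pair per invocation, a run that produces several pairs (for example, one per rank) needs one call per pair. The minimal Python sketch below builds those calls from exactly the flags shown above and runs them with subprocess. The helper names and the per-rank file names in the usage comment are hypothetical; they are not part of PARAM.

```python
import shlex
import subprocess
from typing import List, Sequence

# Annotation copied from the trace_link.py command above.
ANNOTATION = "enumerate(DataLoader)#_MultiProcessingDataLoaderIter.__next__"


def trace_link_commands(et_files: Sequence[str], kineto_files: Sequence[str],
                        annotation: str = ANNOTATION) -> List[List[str]]:
    """Build one trace_link.py invocation per (PyTorch ET, Kineto trace) pair."""
    if len(et_files) != len(kineto_files):
        raise ValueError("each PyTorch ET needs exactly one matching Kineto trace")
    return [
        ["python3", "./tools/trace_link.py",
         "--et-file", et_file, "--kineto-file", kineto_file,
         "--exact-match", "--annotation", annotation]
        for et_file, kineto_file in zip(et_files, kineto_files)
    ]


def run_commands(commands: List[List[str]], cwd: str) -> None:
    """Run each command from `cwd` (e.g. param/train/compute/python), stopping on failure."""
    for cmd in commands:
        print("running:", shlex.join(cmd))
        subprocess.run(cmd, cwd=cwd, check=True)


# Usage with hypothetical per-rank file names (use absolute paths, since the
# commands run from the PARAM directory):
#   ranks = range(2)
#   cmds = trace_link_commands([f"/traces/pytorch_et_{r}.json" for r in ranks],
#                              [f"/traces/kineto_trace_{r}.json" for r in ranks])
#   run_commands(cmds, cwd="param/train/compute/python")
```

Building argument lists instead of shell strings avoids quoting problems with the annotation, which contains characters such as `#` and parentheses.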

Converting to Chakra Execution Trace

Next, convert the merged PyTorch ET+ into a Chakra execution trace, which makes it suitable for simulation and analysis.

$ git clone --recurse-submodules git@github.com:mlcommons/chakra.git
$ cd chakra
$ pip3 install .
$ python3 -m chakra.et_converter.et_converter --input_type PyTorch --input_filename <PyTorch ET+ Path> --output_filename <Chakra ET Path> --num_dims 1
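When there are many merged files, the converter call can be scripted the same way. The sketch below reproduces only the flags from the command above. It uses sys.executable so that the converter runs under the current interpreter, which is assumed to be the environment where chakra was installed. The function names and file names are illustrative.

```python
import subprocess
import sys
from typing import List


def et_converter_command(et_plus_path: str, chakra_path: str, num_dims: int = 1) -> List[str]:
    """Build the et_converter call shown above for one merged PyTorch ET+ file."""
    return [
        sys.executable, "-m", "chakra.et_converter.et_converter",
        "--input_type", "PyTorch",
        "--input_filename", et_plus_path,
        "--output_filename", chakra_path,
        "--num_dims", str(num_dims),
    ]


def convert(et_plus_path: str, chakra_path: str) -> None:
    # Raises CalledProcessError if the converter exits with a non-zero status.
    subprocess.run(et_converter_command(et_plus_path, chakra_path), check=True)


# Usage with hypothetical file names:
#   convert("pytorch_et_plus.json", "chakra_et")
```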

4. Practical Example: Trace Collection for Matrix Multiplication

Collecting PyTorch Execution Traces

The following example has been tested with PyTorch version 2.1.2. It demonstrates how to collect the traces needed for a Chakra execution trace, using a simple matrix multiplication. First, we implement a matrix multiplication function in PyTorch. The function, gpu_matrix_multiplication, takes two NumPy arrays as input and multiplies them on the GPU, falling back to the CPU if no GPU is available.

import torch
import numpy as np

def gpu_matrix_multiplication(matrix1: np.ndarray, matrix2: np.ndarray) -> torch.Tensor:
    """
    Perform matrix multiplication on the GPU using PyTorch.

    Args:
        matrix1 (np.ndarray): The first input matrix as a NumPy array.
        matrix2 (np.ndarray): The second input matrix as a NumPy array.

    Returns:
        torch.Tensor: The result of the matrix multiplication, as a PyTorch tensor.

    Raises:
        ValueError: If matrices have incompatible shapes for multiplication.
    """
    if matrix1.shape[1] != matrix2.shape[0]:
        raise ValueError("Matrices have incompatible shapes for multiplication.")

    # Convert numpy arrays to PyTorch tensors and set dtype to float
    matrix1_torch = torch.tensor(matrix1, dtype=torch.float)
    matrix2_torch = torch.tensor(matrix2, dtype=torch.float)

    # Transfer tensors to GPU if available
    if torch.cuda.is_available():
        matrix1_torch = matrix1_torch.to('cuda')
        matrix2_torch = matrix2_torch.to('cuda')

    # Perform matrix multiplication using GPU
    result_gpu = torch.matmul(matrix1_torch, matrix2_torch)

    return result_gpu

if __name__ == "__main__":
    # Define two 1024x1024 matrices using NumPy
    matrix_a = np.random.rand(1024, 1024)
    matrix_b = np.random.rand(1024, 1024)

    # Multiply matrices on GPU
    result_on_gpu = gpu_matrix_multiplication(matrix_a, matrix_b)

    # The result is a PyTorch tensor on the GPU
    print("Result on GPU:", result_on_gpu)

To collect traces while the program runs, we combine the ExecutionTraceObserver with torch.profiler.profile. We instantiate the ExecutionTraceObserver and register a callback that specifies where to save the execution trace (in this case, pytorch_et.json). Inside the profiling loop, the observer is started and stopped around a single iteration. It is unregistered after the loop finishes.

import torch
import numpy as np
from torch.profiler import ExecutionTraceObserver, profile

def trace_handler(prof):
    prof.export_chrome_trace("kineto_trace.json")

def gpu_matrix_multiplication(matrix1: np.ndarray, matrix2: np.ndarray) -> torch.Tensor:
    """
    Perform matrix multiplication on the GPU using PyTorch.

    Args:
        matrix1 (np.ndarray): The first input matrix as a NumPy array.
        matrix2 (np.ndarray): The second input matrix as a NumPy array.

    Returns:
        torch.Tensor: The result of the matrix multiplication, as a PyTorch tensor.

    Raises:
        ValueError: If matrices have incompatible shapes for multiplication.
    """
    if matrix1.shape[1] != matrix2.shape[0]:
        raise ValueError("Matrices have incompatible shapes for multiplication.")

    # Convert numpy arrays to PyTorch tensors and set dtype to float
    matrix1_torch = torch.tensor(matrix1, dtype=torch.float)
    matrix2_torch = torch.tensor(matrix2, dtype=torch.float)

    # Transfer tensors to GPU if available
    if torch.cuda.is_available():
        matrix1_torch = matrix1_torch.to('cuda')
        matrix2_torch = matrix2_torch.to('cuda')

    # Perform matrix multiplication using GPU
    result_gpu = torch.matmul(matrix1_torch, matrix2_torch)

    return result_gpu

if __name__ == "__main__":
    et = ExecutionTraceObserver()
    et.register_callback("pytorch_et.json")

    # Define two 1024x1024 matrices using NumPy
    matrix_a = np.random.rand(1024, 1024)
    matrix_b = np.random.rand(1024, 1024)

    with profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        schedule=torch.profiler.schedule(wait=0, warmup=10, active=1),
        on_trace_ready=trace_handler
    ) as prof:
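        for epoch in range(20):
            result_on_gpu = gpu_matrix_multiplication(matrix_a, matrix_b)
            # Kineto records iteration 10 (wait=0, warmup=10, active=1), so the
            # observer is started right before that iteration and stopped after it.
            if epoch == 10:
                et.stop()
            if epoch == 9:
                et.start()
            prof.step()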

    et.unregister_callback()

Save the program as matmul.py, then run it in a virtual environment to collect the traces:

$ cd ~/
$ python -m venv venv
$ source venv/bin/activate
$ pip install numpy torch
$ python matmul.py

After the program finishes, pytorch_et.json will be in your working directory. Its beginning looks like the following (excerpt):

{
  "schema": "1.0.1", "pid": 2839879, "time": "2024-01-11 21:10:53", "start_ts": 1325266906,
  "nodes": [
    {
      "name": "[pytorch|profiler|execution_trace|thread]", "id": 2, "rf_id": 0, "parent": 1, "fw_parent": 0, "seq_id": -1, "scope": 7, "tid": 1, "fw_tid": 0, "op_schema": "",
      "inputs": [], "input_shapes": [], "input_types": [],
      "outputs": [], "output_shapes": [], "output_types": []
    },
    {
      "name": "aten::lift_fresh", "id": 5, "rf_id": 1, "parent": 2, "fw_parent": 0, "seq_id": 0, "scope": 0, "tid": 1, "fw_tid": 0, "op_schema": "aten::lift_fresh(Tensor(a) self) -> Tensor(a)",
      "inputs": [[3,4,0,1048576,8,"cpu"]], "input_shapes": [[1024,1024]], "input_types": ["Tensor(double)"],
      "outputs": [[3,4,0,1048576,8,"cpu"]], "output_shapes": [[1024,1024]], "output_types": ["Tensor(double)"]
    },
    {
      "name": "aten::empty_strided", "id": 8, "rf_id": 4, "parent": 7, "fw_parent": 0, "seq_id": -1, "scope": 0, "tid": 1, "fw_tid": 0, "op_schema": "aten::empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor",
      "inputs": [[1024,1024],[1024,1],6,0,"cpu",false], "input_shapes": [[[],[]],[[],[]],[],[],[],[]], "input_types": ["GenericList[Int,Int]","GenericList[Int,Int]","Int","Int","Device","Bool"],
      "outputs": [[9,10,0,1048576,4,"cpu"]], "output_shapes": [[1024,1024]], "output_types": ["Tensor(float)"]
    },
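The "parent" field in each node is what makes the PyTorch ET hierarchy explicit (see "Events Hierarchy" in the table in Section 2). For example, aten::lift_fresh (id 5) is a child of node 2. A minimal sketch that turns this field back into an indented call tree is shown below. It assumes only the "nodes", "id", "parent", and "name" fields visible in the excerpt. Nodes whose parent is not in the file, or that reference themselves, are treated as roots.

```python
from collections import defaultdict
from typing import Dict, List


def call_tree_lines(et: dict) -> List[str]:
    """Render the parent/child hierarchy of PyTorch ET nodes as indented lines."""
    nodes = et["nodes"]
    ids = {node["id"] for node in nodes}
    children: Dict[int, List[dict]] = defaultdict(list)
    roots = []
    for node in nodes:
        parent = node["parent"]
        if parent in ids and parent != node["id"]:
            children[parent].append(node)
        else:
            roots.append(node)  # parent missing from the file (or self-referencing)

    lines: List[str] = []

    def visit(node: dict, depth: int) -> None:
        lines.append("  " * depth + f'{node["name"]} (id={node["id"]})')
        for child in sorted(children[node["id"]], key=lambda n: n["id"]):
            visit(child, depth + 1)

    for root in sorted(roots, key=lambda n: n["id"]):
        visit(root, 0)
    return lines


# Usage:
#   import json
#   with open("pytorch_et.json") as f:
#       print("\n".join(call_tree_lines(json.load(f))[:20]))
```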

The Kineto trace is written to kineto_trace.json. Its beginning looks like this (excerpt):

{
  "schemaVersion": 1,
  "deviceProperties": [
    {
      "id": 0, "name": "NVIDIA H100 80GB HBM3", "totalGlobalMem": 84943110144,
      "computeMajor": 9, "computeMinor": 0,
      "maxThreadsPerBlock": 1024, "maxThreadsPerMultiprocessor": 2048,
      "regsPerBlock": 65536, "regsPerMultiprocessor": 65536, "warpSize": 32,
      "sharedMemPerBlock": 49152, "sharedMemPerMultiprocessor": 233472,
      "numSms": 132, "sharedMemPerBlockOptin": 232448
    },
    {
      "id": 1, "name": "NVIDIA H100 80GB HBM3", "totalGlobalMem": 84943110144,
      "computeMajor": 9, "computeMinor": 0,
      "maxThreadsPerBlock": 1024, "maxThreadsPerMultiprocessor": 2048,
      "regsPerBlock": 65536, "regsPerMultiprocessor": 65536, "warpSize": 32,
      "sharedMemPerBlock": 49152, "sharedMemPerMultiprocessor": 233472,
      "numSms": 132, "sharedMemPerBlockOptin": 232448
    },
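The device properties above are only the header of the file. The events themselves sit in the "traceEvents" array of the Chrome trace format, and most events carry a category in their "cat" field. The short sketch below tallies events per category and lists the recorded GPUs. This gives a quick way to confirm that GPU activity (typically categories such as kernel) was captured alongside CPU operators. The exact category names depend on the PyTorch and Kineto versions, so treat them as something to inspect rather than a fixed list. Events without a "cat" field are counted under a placeholder key.

```python
from collections import Counter
from typing import List, Tuple


def summarize_kineto(trace: dict) -> Tuple[Counter, List[str]]:
    """Count trace events per category and list the GPUs in deviceProperties."""
    categories = Counter(ev.get("cat", "<none>") for ev in trace.get("traceEvents", []))
    devices = [f'{dev["id"]}: {dev["name"]}' for dev in trace.get("deviceProperties", [])]
    return categories, devices


# Usage:
#   import json
#   with open("kineto_trace.json") as f:
#       categories, devices = summarize_kineto(json.load(f))
#   print(devices)
#   print(categories.most_common())
```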

You can also inspect the trace visually with the chrome://tracing tool: open Chrome, enter chrome://tracing in the address bar, and load the trace file.

[Screenshot: the Kineto trace opened in chrome://tracing]

Once you have the PyTorch execution trace and the Kineto trace, the remaining steps to obtain a Chakra execution trace are the same as described above: follow the steps in "Merging Traces with trace_link.py" and "Converting to Chakra Execution Trace." The result is a Chakra execution trace in the standardized format, ready for further analysis or simulation.

5. Closing Remarks

ASTRA-sim and PARAM provide example commands and traces for simulating pre-collected PyTorch execution traces on ASTRA-sim. For more information, refer to the page "Running Simulation with Chakra." Additionally, check out the code example for a practical demonstration of collecting PyTorch execution traces and Kineto traces. The required changes can be seen by comparing two files (vimdiff dlrm_main_vanilla.py dlrm_main_saeed.py).