__<font color='red'>
Note: This feature is under development and is not yet fully functional.
The APIs shown here are experimental and subject to change.
</font>__

## Setup and loading traces
To run this demo notebook on your laptop:
1. Clone the repo: `git clone https://github.com/facebookresearch/HolisticTraceAnalysis.git`
1. [Optional and recommended] Set up a conda environment. See the README for details.
1. Set the `trace_dir` parameter in the next cell to the location of the `tests/data/critical_path/simple_add` folder in your local `HolisticTraceAnalysis` installation.

In [1]:
from hta.trace_analysis import TraceAnalysis

trace_dir = "<UPDATE THIS PATH>/tests/data/critical_path/simple_add/"
analyzer = TraceAnalysis(trace_dir=trace_dir)

2023-10-09 15:27:39,303 - hta - trace.py:L414 - INFO - /Users/bcoutinho/Work/hta/HolisticTraceAnalysis2/tests/data/critical_path/simple_add
2023-10-09 15:27:39,333 - hta - trace_file.py:L94 - INFO - Rank to trace file map:
{0: '/Users/bcoutinho/Work/hta/HolisticTraceAnalysis2/tests/data/critical_path/simple_add/benchmark_result_493459_1694039959_trace.json.gz'}
2023-10-09 15:27:39,337 - hta - trace.py:L560 - INFO - ranks=[0]
2023-10-09 15:27:39,358 - hta - trace.py:L142 - INFO - Parsed /Users/bcoutinho/Work/hta/HolisticTraceAnalysis2/tests/data/critical_path/simple_add/benchmark_result_493459_1694039959_trace.json.gz time = 0.02 seconds 


# Critical Path Analysis

Critical path analysis is a commonly applied technique in HPC and AI/ML optimization. It can be leveraged in two ways:
1. **Efficiency opportunities:** Operations/kernels on the critical path should be the target of performance analysis and optimizations. They can provide the “best bang for the buck” for performance improvements. This is not limited to CPU/GPU kernels: delays in launching or executing CUDA kernels can constitute a significant portion of the critical path as well. These could be optimized by operator fusion and CUDA graphs, for example.
2. **Simulating improvements / gains:** After identifying the critical path, we can estimate improvements to the execution time without actually running anything. This helps us estimate the possible gains from various performance optimizations.

## Approach
In a nutshell, computing the critical path involves 1) constructing a weighted DAG connecting all the operations and 2) finding the longest path in this DAG. The challenging part is constructing the DAG.
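
Step 2 can be illustrated with a minimal, standard-library sketch of a longest-path search over a weighted DAG. The function name `longest_path` and the toy edge list are hypothetical and are not part of HTA, which builds its graph with networkx as described later in this notebook.

```python
from collections import defaultdict, deque


def longest_path(edges):
    """Return (total_weight, path) for the longest path in a weighted DAG.

    `edges` is a list of (src, dst, weight) tuples with non-negative weights.
    """
    graph = defaultdict(list)
    indegree = defaultdict(int)
    nodes = set()
    for src, dst, weight in edges:
        graph[src].append((dst, weight))
        indegree[dst] += 1
        nodes.update((src, dst))

    # Kahn's algorithm: visit nodes in topological order and relax edges.
    queue = deque(sorted(n for n in nodes if indegree[n] == 0))
    dist = {n: 0 for n in nodes}
    parent = {}
    order = []
    while queue:
        node = queue.popleft()
        order.append(node)
        for nxt, weight in graph[node]:
            if dist[node] + weight >= dist[nxt]:
                dist[nxt] = dist[node] + weight
                parent[nxt] = node
            indegree[nxt] -= 1
            if indegree[nxt] == 0:
                queue.append(nxt)
    if len(order) != len(nodes):
        raise ValueError("graph contains a cycle")

    # Walk back from the node with the largest accumulated weight.
    end = max(order, key=dist.__getitem__)
    path = [end]
    while path[-1] in parent:
        path.append(parent[path[-1]])
    return dist[end], path[::-1]


# Toy DAG: two branches between "start" and "end".
edges = [("start", "a", 3), ("start", "b", 2), ("a", "end", 4), ("b", "end", 10)]
total, path = longest_path(edges)
```

Here the branch through `b` dominates, so it forms the critical path even though `a` starts with the heavier edge.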

**Nodes**: The nodes in the critical path graph represent points in time. Each operator/kernel thus has two nodes: a begin node and an end node. For nested operators, we also link the nodes in the order they appear in the call stack. An example of nested operators is shown below:
```
          |----------------------- Op A ----------------------|
                    |--- Op B ---|        |-- Op C--|
Critical Path Graph
         (OpA.b)--->(OpB.b)----->(OpB.e)->(OpC.b)-->(OpC.e)->(OpA.e)
```
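
The diagram above can be reproduced with a small sketch that turns nested operator intervals into an ordered chain of begin/end nodes. The helper names `node_chain` and `timing_edges` and the timestamps are made up for illustration; the sketch assumes properly nested operators on a single thread and does not handle zero-duration operators.

```python
def node_chain(ops):
    """Return [(node_label, timestamp)] for nested operators, ordered by time.

    `ops` is a list of (name, start, end) tuples from one thread.
    """
    points = []
    for name, start, end in ops:
        duration = end - start
        # At equal timestamps: end nodes sort before begin nodes, an outer
        # (longer) operator begins first, and an inner (shorter) one ends first.
        points.append((start, 1, -duration, f"{name}.b"))
        points.append((end, 0, duration, f"{name}.e"))
    points.sort()
    return [(label, ts) for ts, _, _, label in points]


def timing_edges(chain):
    """Link consecutive nodes; the weight is the time elapsed between them."""
    return [(src, dst, t1 - t0) for (src, t0), (dst, t1) in zip(chain, chain[1:])]


ops = [("OpA", 0, 100), ("OpB", 20, 40), ("OpC", 60, 80)]
chain = node_chain(ops)
labels = [label for label, _ in chain]
edges = timing_edges(chain)
```

For these intervals, `labels` follows the same order as the graph in the diagram.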

**Edges** in this DAG can be one of two types:
1. Timing edges (weight = time): these include durations of the operators/kernels as well as delays in launching operators between CPU and GPU.
1. Dependency edges (weight = 0): these have no time component but show a dependency between the operations themselves. This includes data dependencies and synchronization between CPU and GPU.
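
One way to picture the two edge types is the following sketch. `EdgeType` and `Edge` are hypothetical names chosen for this example and do not mirror HTA's internal classes; the durations are invented.

```python
from dataclasses import dataclass
from enum import Enum


class EdgeType(Enum):
    TIMING = "timing"
    DEPENDENCY = "dependency"


@dataclass(frozen=True)  # frozen makes edges hashable
class Edge:
    src: str
    dst: str
    kind: EdgeType
    duration: int = 0  # only meaningful for timing edges

    @property
    def weight(self) -> int:
        # Dependency edges never contribute time to a path.
        return self.duration if self.kind is EdgeType.TIMING else 0


edges = [
    Edge("launch.b", "launch.e", EdgeType.TIMING, 5),    # CPU operator duration
    Edge("launch.b", "kernel.b", EdgeType.TIMING, 12),   # kernel launch delay
    Edge("kernel.b", "kernel.e", EdgeType.TIMING, 40),   # GPU kernel duration
    Edge("kernel.e", "sync.e", EdgeType.DEPENDENCY),     # GPU -> CPU synchronization
]
weights = [edge.weight for edge in edges]
```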

## Lightweight Critical Path
Our initial implementation of critical path analysis does not consider dependencies between PyTorch operators. Doing so requires combining the trace with Chakra Execution Traces and processing tensor information.

The key idea here is to simplify the dependency analysis between PyTorch operators:
- We assume that each CPU operator depends serially on the last operator that ran on the same CPU thread. This frees us from attempting to correlate with Execution Traces.
- The operator dependency part can be added back later.
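
The serial-dependency assumption can be sketched as follows. The function `serial_dependency_edges`, the tuple layout, and the sample operators are hypothetical; the sketch assumes top-level, non-overlapping operators on each thread and models each link as a zero-weight dependency edge.

```python
from collections import defaultdict


def serial_dependency_edges(ops):
    """Link each CPU operator to the previous operator on the same thread.

    `ops` is a list of (name, thread_id, start, end) tuples. Returns
    (src_node, dst_node, weight) tuples with weight 0.
    """
    by_thread = defaultdict(list)
    for name, tid, start, end in ops:
        by_thread[tid].append((start, end, name))

    edges = []
    for tid in sorted(by_thread):
        events = sorted(by_thread[tid])  # order by start time within a thread
        for (_, _, prev), (_, _, cur) in zip(events, events[1:]):
            edges.append((f"{prev}.e", f"{cur}.b", 0))
    return edges


ops = [
    ("aten::add", 1, 0, 10),
    ("aten::mul", 1, 15, 30),
    ("aten::relu", 2, 5, 12),  # alone on thread 2, so it gets no edge
    ("aten::sum", 1, 35, 40),
]
edges = serial_dependency_edges(ops)
```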




## Using this notebook

We can demonstrate critical path analysis on most PyTorch traces. However, traces currently lack information about CUDA synchronization by default.

This is fixed in this [PR](https://github.com/pytorch/pytorch/pull/105187), but the fix is not enabled by default. Please follow the documentation in this [PR](https://github.com/pytorch/pytorch/pull/105187) to enable CUDA synchronization events and get the best results from this analysis.

In [2]:
analyzer.critical_path_analysis?

Signature:
analyzer.critical_path_analysis(
    rank: int,
    annotation: str,
    instance_id: Optional[int],
) -> Tuple[hta.analyzers.critical_path_analysis.CPGraph, bool]
Docstring:
Perform critical path analysis for trace events within a rank.
We further reduce the region of interest by selecting
a trace annotation and instance id. This will
limit the analysis to events within the time range of that annotation.
This will include GPU kernels launched by the cpu operators in that
time duration.
For example, you can use this to limit the a

In [3]:
annotation = "[param|pytorch.model.alex_net|0|0|0|measure|forward]"
instance_id = 1

In [4]:
cp_graph, success = analyzer.critical_path_analysis(
    rank=0, annotation=annotation, instance_id=instance_id
)
success

2023-10-09 15:27:39,473 - hta - critical_path_analysis.py:L493 - INFO - Looking up events under 1 instance of '[param|pytorch.model.alex_net|0|0|0|measure|forward]' annotation
2023-10-09 15:27:39,483 - hta - critical_path_analysis.py:L520 - INFO - Clipped dataframe has 212 events
2023-10-09 15:27:39,518 - hta - critical_path_analysis.py:L321 - INFO - CUDA Sync event eid = 975.0, name = Context Sync
2023-10-09 15:27:39,519 - hta - critical_path_analysis.py:L321 - INFO - CUDA Sync event eid = 1029.0, name = Stream Wait Event
2023-10-09 15:27:39,520 - hta - critical_path_analysis.py:L321 - INFO - CUDA Sync event eid = 1075.0, name = Stream Wait Event
2023-10-09 15:27:39,525 - hta - critical_path_analysis.py:L321 - INFO - CUDA Sync event eid = 1279.0, name = Context Sync


True

## Overlay and visualize the Critical Path
The `overlay_critical_path_analysis()` function exposes the critical path on the original trace file.
There are two modes for the output:
1. When `only_show_critical_events=True` (the default), the output trace contains only the CPU operators and GPU events on the critical path. One can compare it with the original trace to see the critical path identified by the algorithm.
1. When `only_show_critical_events=False`, search for "critical" in the output trace file to highlight events on the critical path.

Edges in the critical path graph will be shown using arrows or flow events. Critical edges are marked using the "critical" flag in the args property.

The category names of flow events begin with "critical_path_". They have the following meaning:
* `critical_path_dependency`: an inter-operator dependency.
* `critical_path_operator`: an edge between start and end of a CPU operator or GPU kernel.
* `critical_path_kernel_launch_delay`: delay to launch a GPU kernel that is likely to be on the critical path.
* `critical_path_kernel_kernel_delay`: delay between running successive GPU kernels.
* `critical_path_sync_dependency`: these edges denote synchronization or control dependencies between events such as GPU->CPU, GPU->GPU synchronization.
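
As a rough sketch, the flow events in an overlaid trace could be tallied by category as shown below. It assumes the overlaid file is a Chrome-trace JSON document with a top-level `traceEvents` list, that each flow is written as a start (`"ph": "s"`) and finish (`"ph": "f"`) pair, and that the flag is stored as `args["critical"]`. The helper names and the sample events are invented for illustration.

```python
import gzip
import json
from collections import Counter


def load_trace(path):
    """Load a gzip-compressed trace JSON file."""
    with gzip.open(path, "rt") as f:
        return json.load(f)


def summarize_critical_path_edges(trace):
    """Count flow events per (category, is_critical) pair."""
    counts = Counter()
    for event in trace.get("traceEvents", []):
        cat = event.get("cat", "")
        # Count only the start half of each flow pair to avoid double counting.
        if cat.startswith("critical_path_") and event.get("ph") == "s":
            is_critical = bool(event.get("args", {}).get("critical", False))
            counts[(cat, is_critical)] += 1
    return counts


# Hypothetical in-memory sample instead of a real overlaid trace file.
sample_trace = {
    "traceEvents": [
        {"ph": "X", "cat": "cpu_op", "name": "aten::add", "ts": 0, "dur": 10},
        {"ph": "s", "cat": "critical_path_operator", "id": 1, "ts": 0,
         "args": {"critical": True}},
        {"ph": "f", "cat": "critical_path_operator", "id": 1, "ts": 10,
         "args": {"critical": True}},
        {"ph": "s", "cat": "critical_path_kernel_launch_delay", "id": 2, "ts": 2,
         "args": {"critical": False}},
        {"ph": "f", "cat": "critical_path_kernel_launch_delay", "id": 2, "ts": 14,
         "args": {"critical": False}},
    ]
}
counts = summarize_critical_path_edges(sample_trace)
```

With a real file, `summarize_critical_path_edges(load_trace(path))` would apply the same tally, provided the assumptions about the file layout hold.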

In [5]:
analyzer.overlay_critical_path_analysis?

Signature:
analyzer.overlay_critical_path_analysis(
    rank: int,
    critical_path_graph: hta.analyzers.critical_path_analysis.CPGraph,
    output_dir: str,
    only_show_critical_events: bool = True,
    show_all_edges: bool = False,
) -> str
Docstring:
Overlay the identified critical path on top of the trace file
for visualization.

Args:
    rank (int): rank to generate the time series for.
    critical_path_graph: Critical Path Graph object generated 

In [6]:
# Make sure '~/HolisticTraceAnalysis/tests/data/critical_path/simple_add/overlaid' exists, or choose a different directory
# to write your trace to.
analyzer.overlay_critical_path_analysis(
    0, cp_graph, output_dir='~/HolisticTraceAnalysis/tests/data/critical_path/simple_add/overlaid', show_all_edges=True)

2023-10-09 15:28:45,797 - hta - trace.py:L142 - INFO - Parsed /Users/bcoutinho/Work/hta/HolisticTraceAnalysis2/tests/data/critical_path/simple_add/benchmark_result_493459_1694039959_trace.json.gz time = 0.01 seconds 


'/Users/bcoutinho/Work/hta/critical_path/simple_add/overlaid/overlaid_critical_path_benchmark_result_493459_1694039959_trace.json.gz'
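
The output directory mentioned in the cell above can be created ahead of time with a small helper such as this one. `ensure_output_dir` is a hypothetical name, and expanding `~` explicitly is a defensive assumption, since this notebook does not say whether HTA expands it. The demo writes to a temporary directory rather than to your home folder.

```python
import os
import tempfile


def ensure_output_dir(path):
    """Expand '~', create the directory if needed, and return its absolute path."""
    full_path = os.path.abspath(os.path.expanduser(path))
    os.makedirs(full_path, exist_ok=True)
    return full_path


# Demonstrate on a throwaway location.
with tempfile.TemporaryDirectory() as tmp:
    out = ensure_output_dir(os.path.join(tmp, "overlaid"))
    created = os.path.isdir(out)
```

The returned path can then be passed as `output_dir` to `overlay_critical_path_analysis()`.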

## More details on the CPGraph object

In [7]:
from hta.analyzers.critical_path_analysis import CPGraph
CPGraph?

Init signature: CPGraph(t: 'Trace', rank: int) -> None
Docstring:
Critical path analysis graph representation for trace from one rank.
This object constructs a graph that can be analyzed using networkx library.

We maintain a mapping between node ids -> CPNode objects
and use the integer as a node in the networkx graph datastructure.
Edges are directly used as the type is hashable.

Attributes:
    trace_df (pd.DataFrame): dataframe of trace events used to construct this graph.
    symbol_table (TraceSymbolTable): a symbol table used to encode the symbols in the trace.
    node_list (List[int]): list of critical path node objects, index in this list is always the node id.
    critical_path_nodes (List[int]): list of node ids on the critical path.
    critical_path_events_set (Set[int]): set of event ids correspondi