# Overview

This notebook illustrates how to customize the attributes parsed by the Trace Parser when writing analyzers using the `hta` library.

PyTorch traces contain a large number of attributes, and most trace analyzers use only a fraction of them. To reduce memory usage when analyzing the traces of distributed training jobs, `hta` retains only a subset of these attributes by default. That default subset may not meet the needs of every trace analyzer. Rather than saving all attributes, `hta` allows users to configure which attributes are saved into the trace dataframes.
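
The idea of retaining only selected attributes can be illustrated with a minimal sketch in plain Python. This is not `hta`'s parser: the helper `keep_selected_args` is hypothetical, and the sample events and attribute names are made up for illustration, assuming the usual layout in which each trace event carries a nested `args` dictionary.

```python
import json
from typing import Any, Dict, Iterable, List


def keep_selected_args(
    events: Iterable[Dict[str, Any]], selected: Iterable[str]
) -> List[Dict[str, Any]]:
    """Flatten each event, keeping only the `args` entries named in `selected`."""
    wanted = set(selected)
    rows = []
    for event in events:
        # Keep the top-level fields (name, ts, dur, ...) but not the nested args.
        row = {key: value for key, value in event.items() if key != "args"}
        # Promote only the requested argument attributes to columns.
        for key, value in event.get("args", {}).items():
            if key in wanted:
                row[key] = value
        rows.append(row)
    return rows


# Hypothetical sample events; real traces contain many more attributes.
raw = json.loads("""
[
  {"name": "aten::mm", "ts": 10, "dur": 5,
   "args": {"Input Dims": [[2, 3], [3, 4]], "External id": 7, "Sequence number": 42}},
  {"name": "Memcpy", "ts": 20, "dur": 2,
   "args": {"bytes": 1024, "stream": 7}}
]
""")

rows = keep_selected_args(raw, selected=["Input Dims", "bytes"])
for row in rows:
    print(row)
```

Every attribute left out of `selected` is dropped before the rows reach a dataframe, which is where the memory saving comes from when a job produces one large trace per rank.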


# The APIs for Trace Parser Customization

The `ParserConfig` class provides the interface for customizing the trace parser's behavior. In the current version, `hta` mainly supports configuring the set of argument attributes; `ParserConfig` can be extended to support further customization when necessary.

## The default configuration

To remain compatible with existing trace analyzers, `hta` uses a global `ParserConfig` object to hold the default configuration. Two class methods give access to this global object:

+ `ParserConfig.get_default_cfg()`
+ `ParserConfig.set_default_cfg(cfg: ParserConfig)`

If the default configuration already meets the needs of your trace analyzer, then there is no need for the analyzer's code to interact with any `ParserConfig`.

If an analyzer needs additional attributes, insert the following lines before the `Trace` objects are parsed:

```
from hta.configs.parser_config import ParserConfig, AVAILABLE_ARGS

new_config = ParserConfig.get_default_cfg()
new_config.add_args(ParserConfig.ARGS_INPUT_SHAPE + [AVAILABLE_ARGS["sm::occupancy"]])
ParserConfig.set_default_cfg(new_config)
```
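
Because the default configuration is global, a change made for one analyzer also affects any trace parsed later in the same process. One way to limit the scope of a change is a context manager that restores the previous default on exit. The sketch below is hypothetical and not part of `hta`: `temporary_default` is an illustrative helper, and a small stand-in registry replaces `ParserConfig` so that the example runs without `hta` installed.

```python
import copy
from contextlib import contextmanager
from typing import Callable, Iterator, TypeVar

T = TypeVar("T")


@contextmanager
def temporary_default(
    get_default: Callable[[], T],
    set_default: Callable[[T], None],
    new_cfg: T,
) -> Iterator[T]:
    """Install `new_cfg` as the default, then restore the previous default."""
    # Deep-copy the previous value in case the getter returns the live object.
    previous = copy.deepcopy(get_default())
    set_default(new_cfg)
    try:
        yield new_cfg
    finally:
        set_default(previous)


# Stand-in for a global default configuration (assumption: not hta code).
class _Registry:
    default = ["stream", "bytes"]


def get_default():
    return _Registry.default


def set_default(cfg):
    _Registry.default = cfg


with temporary_default(get_default, set_default, ["stream", "bytes", "input_dims"]):
    inside = list(get_default())  # the override is active here

after = list(get_default())  # the previous default is back
print(inside, after)
```

With `hta`, the two callables would presumably be `ParserConfig.get_default_cfg` and `ParserConfig.set_default_cfg`. Whether the getter returns a copy or the live global object is not stated here, which is why the sketch deep-copies the previous value before overriding it.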

# Examples

In [1]:
from hta.configs.parser_config import ParserConfig
from hta.configs.config import HtaConfig
from hta.common.trace import Trace

In [2]:
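def check_trace(trace_dir: str):
    """Parse the traces in `trace_dir` and print the column summary of one rank."""
    t = Trace(trace_dir=trace_dir)
    t.parse_traces(use_multiprocessing=False)
    # Inspect the dataframe of the first available rank.
    rank = next(iter(t.traces))
    df = t.get_trace(rank)
    df.info()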

# Select a trace folder
test_trace_dir = HtaConfig.get_test_data_path("h100")

# Use the default configuration
# ParserConfig.set_default_cfg(ParserConfig())
check_trace(test_trace_dir)


<class 'pandas.core.frame.DataFrame'>
Index: 273916 entries, 0 to 393636
Data columns (total 15 columns):
 #   Column                             Non-Null Count   Dtype  
---  ------                             --------------   -----  
 0   index                              273916 non-null  int32  
 1   cat                                273916 non-null  int64  
 2   name                               273916 non-null  int64  
 3   pid                                273916 non-null  object 
 4   tid                                273916 non-null  object 
 5   ts                                 273916 non-null  int64  
 6   dur                                273916 non-null  float64
 7   stream                             273916 non-null  int8   
 8   correlation                        273916 non-null  int32  
 9   bytes                              273916 non-null  int32  
 10  memory_bw_gbps                     273916 non-null  float64
 11  wait_on_stream                     273916 no

In [3]:
# Add two input attributes
cfg = ParserConfig.get_default_cfg()
cfg.add_args(ParserConfig.ARGS_INPUT_SHAPE)
ParserConfig.set_default_cfg(cfg)

check_trace(test_trace_dir)

<class 'pandas.core.frame.DataFrame'>
Index: 273916 entries, 0 to 393636
Data columns (total 17 columns):
 #   Column                             Non-Null Count   Dtype  
---  ------                             --------------   -----  
 0   index                              273916 non-null  int32  
 1   cat                                273916 non-null  int64  
 2   name                               273916 non-null  int64  
 3   pid                                273916 non-null  object 
 4   tid                                273916 non-null  object 
 5   ts                                 273916 non-null  int64  
 6   dur                                273916 non-null  float64
 7   stream                             273916 non-null  int8   
 8   correlation                        273916 non-null  int32  
 9   bytes                              273916 non-null  int32  
 10  memory_bw_gbps                     273916 non-null  float64
 11  wait_on_stream                     273916 no
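
The two `df.info()` summaries differ by two columns (15 versus 17). To see which columns a configuration change added without reading the summaries side by side, the column lists can be compared directly. The helper below is a hypothetical sketch; the column names in the usage lines are placeholders, not the actual columns added by `ParserConfig.ARGS_INPUT_SHAPE`.

```python
from typing import Iterable, List


def added_columns(before: Iterable[str], after: Iterable[str]) -> List[str]:
    """Return the columns present in `after` but not in `before`, in order."""
    seen = set(before)
    return [column for column in after if column not in seen]


# Placeholder column lists; in practice pass `df.columns` from each run.
default_cols = ["index", "cat", "name", "pid", "tid", "ts", "dur"]
custom_cols = default_cols + ["extra_a", "extra_b"]

print(added_columns(default_cols, custom_cols))
```

To apply this to the notebook, `check_trace` would need to return the dataframe (or its `columns`) instead of only printing the summary.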