# Trace Filter Overview

A trace filter is a callable object that extracts the events matching a set of criteria from a trace DataFrame.
Trace filters are essential to trace analysis because they narrow down the data set, enabling more
efficient and targeted analysis.

A trace filter object behaves like a configurable function invoked through a consistent interface: you construct it
with the selection criteria and then call it on a trace DataFrame. Users can use an existing trace filter class or
define their own to specify which trace events should be kept or ignored.
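
As a rough illustration of the callable-object idea, the sketch below defines a hypothetical `MinDurationFilter` in plain pandas. The class name, the `min_dur` parameter, and the toy DataFrame are assumptions made for this example and are not part of the library; a filter meant for real use would presumably follow the library's own filter interface, which may require more than is shown here.

```python
import pandas as pd


class MinDurationFilter:
    """Hypothetical filter: keep events whose duration is at least `min_dur`."""

    def __init__(self, min_dur: float) -> None:
        self.min_dur = min_dur

    def __call__(self, df: pd.DataFrame) -> pd.DataFrame:
        # Return only the matching rows; the input DataFrame is not modified.
        return df.loc[df["dur"] >= self.min_dur]


# Toy stand-in for a trace DataFrame; real trace data has many more columns.
toy_df = pd.DataFrame({"ts": [100, 160, 200], "dur": [60.0, 33.0, 95.0]})

# Configure the filter once, then call it like a function.
long_events = MinDurationFilter(min_dur=50.0)(toy_df)
print(long_events)
```

Separating configuration (the constructor) from application (the call) lets the same filter object be reused across several DataFrames.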

The basic usage pattern of trace filters is as follows:

```
from hta.common.trace_filter import Filter, IterationIndexFilter, NameFilter

# Get df and trace_symbol_table from the parsed traces.
df = ...
trace_symbol_table = ...

# Extract trace events in the first and second iterations.
filter_func1 = IterationIndexFilter(iteration_index=[0, 1])
filtered_df1 = filter_func1(df)

# extract trace events whose name starts with "nccl".
filter_func2 = NameFilter(name_pattern=r"^nccl")
filtered_df2 = filter_func2(df, trace_symbol_table)
``` 

# Code Examples

## Load a trace collection

In [None]:
from hta.common.trace import Trace
import hta
import os
from pathlib import Path

# Load the traces in a folder
base_data_dir = str(Path(hta.__file__).parent.parent.joinpath("tests/data"))
trace_dir: str = os.path.join(base_data_dir, "trace_filter")
t = Trace(trace_dir=trace_dir)
t.parse_traces()

# Decode the symbol columns (i.e., `name` and `cat`) 
t.decode_symbol_ids()

df = t.get_trace(0)
symbol_table = t.symbol_table

## Filter events by name patterns

In [None]:
from hta.common.trace_filter import NameFilter

name_filter = NameFilter("aten::.*mm")
selected_df = name_filter(df, symbol_table)
print(f"Found {len(selected_df)} matching events.")
selected_df.head()[["iteration", "index", "ts", "dur", "stream", "s_cat", "s_name"]]

Found 378 matching events.


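|    | iteration | index | ts               | dur  | stream | s_cat  | s_name   |
|----|-----------|-------|------------------|------|--------|--------|----------|
| 13 | 551       | 13    | 1682725898237042 | 60.0 | -1     | cpu_op | aten::mm |
| 14 | 551       | 14    | 1682725898237110 | 33.0 | -1     | cpu_op | aten::mm |
| 74 | 551       | 74    | 1682725898239892 | 95.0 | -1     | cpu_op | aten::mm |
| 77 | 551       | 77    | 1682725898240570 | 51.0 | -1     | cpu_op | aten::mm |
| 78 | 551       | 78    | 1682725898240628 | 43.0 | -1     | cpu_op | aten::mm |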
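
To show what a pattern such as `aten::.*mm` selects, here is a minimal pandas sketch on a toy DataFrame. It assumes the pattern is applied as a regular-expression search over the decoded `s_name` column; whether `NameFilter` anchors the pattern or matches it differently is not stated here, so treat this as an approximation of the idea and not as the library's implementation. The toy event names are made up for the example.

```python
import pandas as pd

# Toy stand-in for decoded event names (hypothetical values).
toy_df = pd.DataFrame(
    {"s_name": ["aten::mm", "aten::addmm", "aten::add", "nccl:all_reduce"]}
)

pattern = r"aten::.*mm"

# Keep the rows whose name contains a match for the regular expression.
# `.*` also matches the empty string, so "aten::mm" itself is selected.
matched = toy_df[toy_df["s_name"].str.contains(pattern, regex=True)]
print(matched)
```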


## Filter events by a time range

In [None]:
from hta.common.trace_filter import TimeRangeFilter

start_time = 1682725898237042
end_time = 1682725898240570
time_filter = TimeRangeFilter((start_time, end_time))
selected_df = time_filter(df)
print(f"Found {len(selected_df)} matching events.")
selected_df.head()[["iteration", "index", "ts", "dur", "stream", "s_cat", "s_name"]]


Found 93 matching events.


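|    | iteration | index | ts               | dur  | stream | s_cat  | s_name                       |
|----|-----------|-------|------------------|------|--------|--------|------------------------------|
| 13 | 551       | 13    | 1682725898237042 | 60.0 | -1     | cpu_op | aten::mm                     |
| 14 | 551       | 14    | 1682725898237110 | 33.0 | -1     | cpu_op | aten::mm                     |
| 15 | 551       | 15    | 1682725898237194 | 57.0 | -1     | cpu_op | SoftmaxBackward0             |
| 16 | 551       | 16    | 1682725898237196 | 51.0 | -1     | cpu_op | SoftmaxBackward0             |
| 17 | 551       | 17    | 1682725898237201 | 45.0 | -1     | cpu_op | aten::_softmax_backward_data |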
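
A time-range selection can be read in more than one way, and this page does not say which rule `TimeRangeFilter` applies. The pandas sketch below contrasts two plausible interpretations on toy data: events that start inside the range, and events that overlap it at any point. Both rules and the toy values are assumptions for illustration only.

```python
import pandas as pd

# Toy events: start timestamp `ts` and duration `dur` (hypothetical values).
toy_df = pd.DataFrame(
    {"ts": [90, 100, 150, 210], "dur": [20.0, 30.0, 100.0, 5.0]}
)
start_time, end_time = 100, 200

# Interpretation A: the event starts inside [start_time, end_time] (inclusive).
starts_inside = toy_df[toy_df["ts"].between(start_time, end_time)]

# Interpretation B: the event overlaps the range at any point.
event_end = toy_df["ts"] + toy_df["dur"]
overlaps = toy_df[(toy_df["ts"] <= end_time) & (event_end >= start_time)]

print(list(starts_inside["ts"]), list(overlaps["ts"]))
```

The event starting at 90 ends at 110, so it is excluded under interpretation A but included under interpretation B; checking a case like this against the real filter is a quick way to confirm its behavior.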


## Filter events by a sequence of filters

You can also combine several trace filter objects into a single composite filter. The example below selects the GPU kernels in the first iteration whose names start with `nccl`.

In [None]:
from hta.common.trace_filter import GPUKernelFilter, IterationIndexFilter, NameFilter, CompositeFilter

nccl_kernels_in_first_iteration_filter = CompositeFilter(
    [
        IterationIndexFilter(0),
        GPUKernelFilter(),
        NameFilter("^nccl", symbol_table=symbol_table),
    ]
)
selected_df = nccl_kernels_in_first_iteration_filter(df)
print(f"Found {len(selected_df)} matching events.")
selected_df.head()[["iteration", "index", "ts", "dur", "stream", "s_cat", "s_name"]]

Found 5 matching events.


|      | iteration | index | ts               | dur     | stream | s_cat  | s_name                                     |
|------|-----------|-------|------------------|---------|--------|--------|--------------------------------------------|
| 2512 | 551       | 2512  | 1682725898094377 | 30669.0 | 84     | kernel | ncclKernel_SendRecv_RING_SIMPLE_Sum_int8_t |
| 2570 | 551       | 2570  | 1682725898125052 | 62783.0 | 84     | kernel | ncclKernel_SendRecv_RING_SIMPLE_Sum_int8_t |
| 2962 | 551       | 2962  | 1682725898295662 | 11727.0 | 84     | kernel | ncclKernel_SendRecv_RING_SIMPLE_Sum_int8_t |
| 4554 | 551       | 4554  | 1682725898348496 | 47652.0 | 84     | kernel | ncclKernel_SendRecv_RING_SIMPLE_Sum_int8_t |
| 4794 | 551       | 4794  | 1682725898585321 | 42496.0 | 203    | kernel | ncclKernel_SendRecv_RING_SIMPLE_Sum_int8_t |
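
The idea of applying filters one after another can be sketched in a few lines of plain pandas. The `SequentialFilter` class, the two helper functions, and the toy data below are hypothetical stand-ins written for this example; they are not the library's `CompositeFilter`, whose internals may differ.

```python
from typing import Callable, Sequence

import pandas as pd

FilterFn = Callable[[pd.DataFrame], pd.DataFrame]


class SequentialFilter:
    """Hypothetical composite: feed each filter's output into the next one."""

    def __init__(self, filters: Sequence[FilterFn]) -> None:
        self.filters = list(filters)

    def __call__(self, df: pd.DataFrame) -> pd.DataFrame:
        for apply_filter in self.filters:
            df = apply_filter(df)
        return df


def only_kernels(df: pd.DataFrame) -> pd.DataFrame:
    # Keep events in the "kernel" category.
    return df[df["s_cat"] == "kernel"]


def only_nccl(df: pd.DataFrame) -> pd.DataFrame:
    # Keep events whose name starts with "nccl".
    return df[df["s_name"].str.contains(r"^nccl", regex=True)]


# Toy stand-in for a decoded trace DataFrame (hypothetical values).
toy_df = pd.DataFrame(
    {
        "s_cat": ["kernel", "cpu_op", "kernel"],
        "s_name": ["ncclKernel_SendRecv", "nccl:all_reduce", "gemm_kernel"],
    }
)

nccl_kernels = SequentialFilter([only_kernels, only_nccl])(toy_df)
print(nccl_kernels)
```

Because each step only narrows the rows it receives, placing the most selective or cheapest filter first can reduce the work done by later steps.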
