# Advanced Tutorial 4: Trace

## Overview
In this tutorial, we will discuss:
* **Customizing Trace**
    * Example
* **More about Trace**
    * Inputs, outputs and mode
    * Data
    * System
* **Trace communication**
* **Other Trace usages**
    * Debugging/Monitoring

Let's create a function that generates the pipeline, model, and network used throughout this tutorial.

```python
import fastestimator as fe
from fastestimator.architecture.tensorflow import LeNet
from fastestimator.dataset.data import mnist
from fastestimator.op.numpyop.univariate import ExpandDims, Minmax
from fastestimator.op.tensorop.loss import CrossEntropy
from fastestimator.op.tensorop.model import ModelOp, UpdateOp


def get_pipeline_model_network(model_name="LeNet", batch_size=32):
    train_data, eval_data = mnist.load_data()
    # Move half of the evaluation data into a separate test set
    test_data = eval_data.split(0.5)

    # Add a channel dimension and scale pixel values to [0, 1]
    pipeline = fe.Pipeline(train_data=train_data,
                           eval_data=eval_data,
                           test_data=test_data,
                           batch_size=batch_size,
                           ops=[ExpandDims(inputs="x", outputs="x"),
                                Minmax(inputs="x", outputs="x")])

    model = fe.build(model_fn=LeNet, optimizer_fn="adam", model_names=model_name)

    # Forward pass, cross-entropy loss, and weight update
    network = fe.Network(ops=[
        ModelOp(model=model, inputs="x", outputs="y_pred"),
        CrossEntropy(inputs=("y_pred", "y"), outputs="ce"),
        UpdateOp(model=model, loss_name="ce")
    ])

    return pipeline, model, network
```

## Customizing Trace
In [Tutorial 7](https://github.com/TortoiseHam/fastestimator/blob/tutorials/summary/tutorial/beginner/t07_estimator.ipynb) of the beginner section, we covered the basic concept and structure of a Trace and used a few of FastEstimator's built-in Traces. We can also write a custom Trace to suit our needs. Let's look at an example of a custom Trace implementation.

### Example
We can use Traces to calculate any custom metric needed for monitoring or controlling training. Below, we implement a Trace that calculates the F-beta score of our model.

```python
from fastestimator.backend import to_number
from fastestimator.trace import Trace
from sklearn.metrics import fbeta_score
import numpy as np

class FBetaScore(Trace):
    def __init__(self, true_key, pred_key, beta=2, output_name="f_beta_score", mode=("eval", "test")):
        super().__init__(inputs=(true_key, pred_key), outputs=output_name, mode=mode)
        self.true_key = true_key
        self.pred_key = pred_key
        self.beta = beta
        self.y_true = []
        self.y_pred = []

    def on_epoch_begin(self, data):
        # Reset the accumulated labels at the start of every epoch
        self.y_true = []
        self.y_pred = []

    def on_batch_end(self, data):
        # Convert tensors to NumPy arrays and turn class probabilities into class indices
        y_true, y_pred = to_number(data[self.true_key]), to_number(data[self.pred_key])
        y_pred = np.argmax(y_pred, axis=-1)
        self.y_pred.extend(y_pred.ravel())
        self.y_true.extend(y_true.ravel())

    def on_epoch_end(self, data):
        # Score the whole epoch at once and log it under the output key
        score = fbeta_score(self.y_true, self.y_pred, beta=self.beta, average="weighted")
        data.write_with_log(self.outputs[0], score)
```

Here we calculate the F2 score (β = 2), which gives more weight to recall than to precision.
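For reference, the F-beta score combines precision $P$ and recall $R$ as

$$F_\beta = \frac{(1+\beta^2)\,P\,R}{\beta^2 P + R}$$

so $\beta > 1$ shifts the weight toward recall and $\beta = 1$ gives the plain harmonic mean (F1). The short sketch below uses a hypothetical helper, `f_beta`, which is not part of FastEstimator or scikit-learn, to show the effect: swapping the precision and recall values changes the F2 score noticeably.

```python
def f_beta(precision: float, recall: float, beta: float = 2.0) -> float:
    """Hypothetical helper: F-beta score from a single precision/recall pair.

    Returns 0.0 when both precision and recall are 0 to avoid dividing by zero.
    """
    b2 = beta ** 2
    denom = b2 * precision + recall
    if denom == 0:
        return 0.0
    return (1 + b2) * precision * recall / denom


# Same two values, swapped: F2 rewards the high-recall case more.
high_recall = f_beta(precision=0.5, recall=1.0, beta=2)
high_precision = f_beta(precision=1.0, recall=0.5, beta=2)
f1 = f_beta(precision=0.5, recall=1.0, beta=1)  # harmonic mean
print(round(high_recall, 3), round(high_precision, 3), round(f1, 3))  # 0.833 0.556 0.667
```

Note that `FBetaScore` passes `average="weighted"`, so scikit-learn computes this score per class and averages the results weighted by each class's support.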

```python
pipeline, model, network = get_pipeline_model_network()

traces = FBetaScore(true_key="y", pred_key="y_pred", beta=2, output_name="f2_score", mode="eval")
estimator = fe.Estimator(pipeline=pipeline, network=network, epochs=4, traces=traces, log_steps=1000)

estimator.fit()
```

```
FastEstimator-Warn: No ModelSaver Trace detected. Models will not be saved.
    ______           __  ______     __  _                 __            
   / ____/___ ______/ /_/ ____/____/ /_(_)___ ___  ____ _/ /_____  _____
  / /_  / __ `/ ___/ __/ __/ / ___/ __/ / __ `__ \/ __ `/ __/ __ \/ ___/
 / __/ / /_/ (__  ) /_/ /___(__  ) /_/ / / / / / / /_/ / /_/ /_/ / /    
/_/    \__,_/____/\__/_____/____/\__/_/_/ /_/ /_/\__,_/\__/\____/_/     
                                                                        

FastEstimator-Start: step: 1; LeNet_lr: 0.001; 
FastEstimator-Train: step: 1; ce: 2.3143477; 
FastEstimator-Train: step: 1000; ce: 0.015297858; steps/sec: 254.94; 
FastEstimator-Train: step: 1875; epoch: 1; epoch_time: 10.92 sec; 
FastEstimator-Eval: step: 1875; epoch: 1; ce: 0.043765396; min_ce: 0.043765396; since_best: 0; f2_score: 0.9861824320950051; 
FastEstimator-Train: step: 2000; ce: 0.04872855; steps/sec: 229.85; 
FastEstimator-Train: step: 3000; ce: 0.044155814; steps/sec
```

## More about Trace
Now that we have seen a custom Trace implementation, let's take a closer look at the structure of a Trace.

### Inputs, Outputs and Mode
These Trace arguments work much like those of an Operator. To recap: `inputs` specifies the keys from the data dictionary that the Trace requires, `outputs` specifies the keys that the Trace writes into the system buffer, and `mode` specifies the mode(s) in which the Trace executes.
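To make the role of `mode` and the event hooks concrete, here is a simplified, self-contained sketch. `ToyTrace`, `ToyAccuracy`, and `run_toy_epoch` are hypothetical stand-ins written for illustration only; they are not FastEstimator classes, and the real training loop does much more. The sketch assumes, for its own purposes, that `mode=None` means "run in every mode". Like `FBetaScore`, the toy metric resets in `on_epoch_begin`, accumulates in `on_batch_end`, and computes its result in `on_epoch_end`.

```python
from typing import Dict, Iterable, List, Optional, Set, Union


class ToyTrace:
    """Hypothetical, simplified stand-in for a Trace. NOT FastEstimator's implementation."""

    def __init__(self, inputs: Iterable[str] = (), outputs: Iterable[str] = (),
                 mode: Optional[Union[str, Iterable[str]]] = None):
        self.inputs = list(inputs)
        self.outputs = list(outputs)
        # Assumption for this sketch: mode=None means "run in every mode".
        if mode is None:
            self.mode: Optional[Set[str]] = None
        elif isinstance(mode, str):
            self.mode = {mode}
        else:
            self.mode = set(mode)

    # Event hooks are no-ops by default; subclasses override the ones they need.
    def on_epoch_begin(self, data: Dict) -> None:
        pass

    def on_batch_end(self, data: Dict) -> None:
        pass

    def on_epoch_end(self, data: Dict) -> None:
        pass


class ToyAccuracy(ToyTrace):
    """Toy metric: accumulates over batches and computes once per epoch."""

    def __init__(self, true_key: str, pred_key: str, output_name: str = "accuracy", mode="eval"):
        super().__init__(inputs=(true_key, pred_key), outputs=(output_name,), mode=mode)
        self.correct = 0
        self.total = 0

    def on_epoch_begin(self, data: Dict) -> None:
        self.correct, self.total = 0, 0  # reset accumulators at every epoch

    def on_batch_end(self, data: Dict) -> None:
        y_true, y_pred = data[self.inputs[0]], data[self.inputs[1]]
        self.correct += sum(int(t == p) for t, p in zip(y_true, y_pred))
        self.total += len(y_true)

    def on_epoch_end(self, data: Dict) -> None:
        data[self.outputs[0]] = self.correct / self.total


def run_toy_epoch(traces: List[ToyTrace], batches: List[Dict], mode: str) -> Dict:
    """Fire the hooks in order for one epoch, skipping traces whose mode does not match."""
    active = [t for t in traces if t.mode is None or mode in t.mode]
    epoch_data: Dict = {}
    for t in active:
        t.on_epoch_begin(epoch_data)
    for batch in batches:
        for t in active:
            t.on_batch_end(batch)
    for t in active:
        t.on_epoch_end(epoch_data)
    return epoch_data


acc = ToyAccuracy(true_key="y", pred_key="y_pred", mode="eval")
batches = [{"y": [1, 0, 2], "y_pred": [1, 0, 1]}, {"y": [3], "y_pred": [3]}]
print(run_toy_epoch([acc], batches, mode="train"))  # {} (trace skipped in train mode)
print(run_toy_epoch([acc], batches, mode="eval"))   # {'accuracy': 0.75}
```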

### Data
Through the `data` argument of each event method (such as `on_batch_end` or `on_epoch_end`), a Trace has access to the current data dictionary. You can use the keys passed through the `inputs` argument to read information from it.
We can write outputs into the `Data` dictionary with or without logging, using the `write_with_log` and `write_without_log` methods respectively.
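The difference between the two write methods can be illustrated with a toy stand-in. `ToyData` and its `log_line` method are hypothetical and only mimic the behavior described above: both methods store the value so that later Traces can read it, but only keys written with `write_with_log` end up in the printed log line. This is not how FastEstimator's `Data` class is implemented internally.

```python
class ToyData(dict):
    """Hypothetical stand-in for FastEstimator's Data; only mimics the two write methods."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._logged_keys = []  # keys a logger would print, in write order

    def write_with_log(self, key, value):
        self[key] = value
        if key not in self._logged_keys:
            self._logged_keys.append(key)

    def write_without_log(self, key, value):
        self[key] = value  # stored for other traces, but never printed

    def log_line(self):
        return "; ".join(f"{k}: {self[k]}" for k in self._logged_keys)


data = ToyData({"y": [1, 0], "y_pred": [1, 1]})
data.write_with_log("f2_score", 0.98)
data.write_without_log("confusion", [[1, 0], [1, 0]])
print(data.log_line())      # f2_score: 0.98
print("confusion" in data)  # True: still readable by later traces
```

Writing without logging is handy for intermediate values, such as a confusion matrix, that another Trace consumes but that would clutter the training log.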

### System
A Trace also has access to the current `System` instance through `self.system`. `System` holds information about the network and the training state; its fields are listed below:
* global_step
* num_devices
* log_steps
* total_epochs
* epoch_idx
* batch_idx
* stop_training
* network
* max_steps_per_epoch
* summary
* experiment_time

We showcase `System` usage in the **Other Trace usages** section of this tutorial.
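Besides reading `System` fields, a Trace can also act on them. The sketch below is a toy illustration of that idea using `stop_training` from the list above: a Trace requests an early stop once a monitored loss has not improved for `patience` epochs. `ToySystem`, `ToyPatience`, and `toy_fit` are hypothetical stand-ins that keep only the two fields they need; this is not FastEstimator's `System` class or its built-in early-stopping logic.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToySystem:
    """Hypothetical stand-in exposing two of the System fields listed above."""
    epoch_idx: int = 0
    stop_training: bool = False


class ToyPatience:
    """Toy trace: request a stop when `monitor` fails to improve for `patience` epochs."""

    def __init__(self, system: ToySystem, monitor: str = "ce", patience: int = 2):
        self.system = system
        self.monitor = monitor
        self.patience = patience
        self.best = float("inf")
        self.since_best = 0

    def on_epoch_end(self, data: dict) -> None:
        value = data[self.monitor]
        if value < self.best:
            self.best, self.since_best = value, 0
        else:
            self.since_best += 1
        if self.since_best >= self.patience:
            self.system.stop_training = True  # the toy loop below checks this flag


def toy_fit(eval_losses: List[float], patience: int = 2) -> int:
    """Run toy epochs until the losses run out or a stop is requested; return epochs run."""
    system = ToySystem()
    trace = ToyPatience(system, monitor="ce", patience=patience)
    for loss in eval_losses:
        system.epoch_idx += 1
        trace.on_epoch_end({"ce": loss})
        if system.stop_training:
            break
    return system.epoch_idx


print(toy_fit([0.5, 0.3, 0.31, 0.32, 0.1]))  # 4: two epochs in a row without improvement
```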

## Trace communication
We can use multiple Traces together, where the output of one Trace is consumed by another, as depicted below:

<img src="../resources/t04_advanced_trace_communication.png" alt="drawing" width="500"/>

Below, we use the outputs of the `Precision` and `Recall` Traces to compute the F1 score.

```python
from fastestimator.trace.metric import Precision, Recall

class CustomF1Score(Trace):
    def __init__(self, precision_key, recall_key, mode=("eval", "test"), output_name="f1_score"):
        super().__init__(inputs=(precision_key, recall_key), outputs=output_name, mode=mode)
        self.precision_key = precision_key
        self.recall_key = recall_key

    def on_epoch_end(self, data):
        # Read the values written by the Precision and Recall Traces during this epoch
        precision = data[self.precision_key]
        recall = data[self.recall_key]
        # Harmonic mean; applied element-wise when precision and recall are per-class arrays
        score = 2 * (precision * recall) / (precision + recall)
        data.write_with_log(self.outputs[0], score)


pipeline, model, network = get_pipeline_model_network()

traces = [
    Precision(true_key="y", pred_key="y_pred", mode=["eval", "test"], output_name="precision"),
    Recall(true_key="y", pred_key="y_pred", mode=["eval", "test"], output_name="recall"),
    CustomF1Score(precision_key="precision", recall_key="recall", mode=["eval", "test"], output_name="f1_score")
]
estimator = fe.Estimator(pipeline=pipeline, network=network, epochs=2, traces=traces, log_steps=1000)
```

```
FastEstimator-Warn: No ModelSaver Trace detected. Models will not be saved.
```
```python
estimator.fit()
```

```
FastEstimator-Start: step: 1; LeNet_lr: 0.001; 
FastEstimator-Train: step: 1; ce: 2.2954829; 
FastEstimator-Train: step: 1000; ce: 0.02208437; steps/sec: 261.23; 
FastEstimator-Train: step: 1875; epoch: 1; epoch_time: 7.6 sec; 
FastEstimator-Eval: step: 1875; epoch: 1; ce: 0.04879124; min_ce: 0.04879124; since_best: 0; 
precision:
[0.982     ,0.98813559,0.98452611,0.99804305,0.98807157,0.97550111,
 0.98951782,0.98571429,0.96781116,0.99396378];
recall:
[0.99392713,0.99828767,0.97884615,0.96958175,0.994     ,0.98871332,
 0.98128898,0.98773006,0.98471616,0.9
```

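`Note:` Precision, recall and F1 score are displayed for each class, because `Precision` and `Recall` output one value per class and `CustomF1Score` combines them element-wise.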
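The element-wise computation can be spelled out in plain Python. The hypothetical helper `per_class_f1` below is only an illustration, not part of FastEstimator; it applies F1 = 2PR / (P + R) class by class and, unlike the one-line formula in `CustomF1Score`, returns 0.0 for a class whose precision and recall are both zero instead of dividing by zero.

```python
from typing import List, Sequence


def per_class_f1(precision: Sequence[float], recall: Sequence[float]) -> List[float]:
    """Hypothetical helper: element-wise F1 = 2PR / (P + R), returning 0.0 where P + R == 0."""
    if len(precision) != len(recall):
        raise ValueError("precision and recall must have one entry per class")
    scores = []
    for p, r in zip(precision, recall):
        denom = p + r
        scores.append(2 * p * r / denom if denom > 0 else 0.0)
    return scores


print([round(s, 4) for s in per_class_f1([0.5, 1.0, 0.0], [1.0, 0.5, 0.0])])
# [0.6667, 0.6667, 0.0]
```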

## Other Trace usages 

### Debugging/Monitoring
Here, we implement a custom Trace that monitors the predictions. With it, any deviation from the expected behavior can be spotted and the relevant corrections made.

```python
class MonitorPred(Trace):
    def __init__(self, true_key, pred_key, mode="train"):
        super().__init__(inputs=(true_key, pred_key), mode=mode)
        self.true_key = true_key
        self.pred_key = pred_key

    def on_batch_end(self, data):
        # Training progress comes from the System instance attached to every Trace
        print("Global Step Index: ", self.system.global_step)
        print("Batch Index: ", self.system.batch_idx)
        print("Epoch: ", self.system.epoch_idx)
        # Batch-level contents of the data dictionary
        print("Batch data has the following keys: ", list(data.keys()))
        print("Batch true labels: ", data[self.true_key])
        print("Batch predictions: ", data[self.pred_key])

# A small batch size and only 2 steps per epoch keep the printed output short
pipeline, model, network = get_pipeline_model_network(batch_size=4)

traces = MonitorPred(true_key="y", pred_key="y_pred")
estimator = fe.Estimator(pipeline=pipeline, network=network, epochs=2, traces=traces, max_steps_per_epoch=2, log_steps=None)
```

```
FastEstimator-Warn: No ModelSaver Trace detected. Models will not be saved.
```

```python
estimator.fit()
```

```
Global Step Index:  1
Batch Index:  1
Epoch:  1
Batch data has the following keys:  ['ce', 'y_pred', 'y', 'x']
Batch true labels:  tf.Tensor([1 0 6 1], shape=(4,), dtype=uint8)
Batch predictions:  tf.Tensor(
[[0.10357114 0.10485268 0.1035656  0.09670787 0.089131   0.10154787
  0.1033091  0.10102929 0.09534299 0.10094254]
 [0.09291946 0.09567764 0.11450008 0.0905508  0.08430735 0.09910616
  0.11450168 0.10635708 0.09766228 0.10441744]
 [0.09475411 0.10670866 0.11174129 0.08387841 0.0786593  0.09865108
  0.12429795 0.10309361 0.09057966 0.10763595]
 [0.10064
```

As you can see, our Trace lets us inspect batch-level information such as the global step, batch index, epoch, keys in the data dictionary, true labels, and predictions.
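Printing full probability tensors gets verbose as the batch size grows. As a sketch, a hypothetical helper like `summarize_batch` below (not a FastEstimator function) could condense each batch into predicted class indices and a batch accuracy. It picks the first maximum on ties, which matches NumPy's `argmax` convention.

```python
from typing import List, Sequence, Tuple


def summarize_batch(labels: Sequence[int],
                    probs: Sequence[Sequence[float]]) -> Tuple[List[int], float]:
    """Hypothetical helper: reduce per-class probabilities to predicted labels and batch accuracy."""
    # Index of the largest probability in each row (first one wins on ties)
    preds = [max(range(len(row)), key=lambda i: row[i]) for row in probs]
    correct = sum(int(p == y) for p, y in zip(preds, labels))
    return preds, correct / len(labels)


preds, acc = summarize_batch(
    labels=[1, 0, 2],
    probs=[[0.1, 0.8, 0.1], [0.3, 0.3, 0.4], [0.2, 0.2, 0.6]],
)
print(preds, acc)  # [1, 2, 2] 0.6666666666666666
```

Inside `MonitorPred.on_batch_end`, one untested option would be to convert the tensors first, for example `summarize_batch(to_number(data[self.true_key]).tolist(), to_number(data[self.pred_key]).tolist())`, using the same `to_number` import as `FBetaScore`.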