# Tutorial 6: TensorFilter - imbalanced training

When a dataset is imbalanced, the training data needs to maintain a certain distribution so that minority classes are not omitted during training. In FastEstimator, `TensorFilter` is designed for that purpose.

`TensorFilter` is a tensor operator that is used in `Pipeline` alongside other tensor operators such as `Minmax` and `Resize`.

There are only two differences between `TensorFilter` and `TensorOp`:
1. `TensorFilter` does not have outputs.
2. The forward function of `TensorFilter` produces a boolean value that indicates whether to keep the data.
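
The contract described above can be illustrated without FastEstimator at all. The following is a minimal NumPy sketch, not the library's implementation: `keep_example` is a hypothetical stand-in for a filter's forward function, and the labels are made-up values. It only shows that a filter maps each example to a boolean and that examples mapped to `False` are dropped.

```python
import numpy as np


def keep_example(label):
    """Hypothetical filter predicate: True keeps the example, False drops it."""
    return bool(label >= 5)


# Made-up labels, used only for illustration.
labels = np.array([3, 7, 5, 0, 9, 4])

# Evaluate the predicate once per example, then keep only the passing examples.
mask = np.array([keep_example(y) for y in labels])
kept = labels[mask]

print(kept)  # [7 5 9]
```

Note that the predicate produces no new feature; it only decides which existing examples survive, which is why a filter has no outputs.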

## Prepare data (same as tutorial 1)

In [None]:
import numpy as np
import tensorflow as tf
import fastestimator as fe
from fastestimator.pipeline.processing import Minmax

(x_train, y_train), (x_eval, y_eval) = tf.keras.datasets.mnist.load_data()
train_data = {"x": np.expand_dims(x_train, -1), "y": y_train}
eval_data = {"x": np.expand_dims(x_eval, -1), "y": y_eval}
data = {"train": train_data, "eval": eval_data}

## Customize your own Filter
In this example, let's get rid of all images that have a label less than 5.

In [3]:
from fastestimator.pipeline.processing import TensorFilter

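class MyFilter(TensorFilter):
    def forward(self, data, state):
        # `data` is the feature selected by inputs="y" (the label);
        # return True to keep the example, False to drop it
        pass_filter = data >= 5
        return pass_filter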

pipeline = fe.Pipeline(batch_size=32, data=data, ops=[MyFilter(inputs="y"), Minmax(inputs="x", outputs="x")])

In [7]:
results = pipeline.show_results()
print("filtering out all data with label less than 5, the label of current batch is:")
print(results[0]["y"])

filtering out all data with label less than 5, the label of current batch is:
tf.Tensor([9 9 6 9 7 9 6 8 9 5 7 9 5 5 9 9 9 6 9 8 8 6 6 8 9 5 5 6 7 8 5 5], shape=(32,), dtype=uint8)


## Using pre-built ScalarFilter

In FastEstimator, if you need to filter on a scalar value with a given probability, you can use the pre-built filter `ScalarFilter`. Let's filter out even-numbered labels with 50% probability:

In [8]:
from fastestimator.pipeline.processing import ScalarFilter

pipeline = fe.Pipeline(batch_size=32, 
                       data=data, 
                       ops=[ScalarFilter(inputs="y", filter_value=[0, 2, 4, 6, 8], keep_prob=[0.5, 0.5, 0.5, 0.5, 0.5]), 
                            Minmax(inputs="x", outputs="x")])
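
As a rough illustration of what "keep with probability" means here, the sketch below mimics the behavior in plain NumPy. It is not FastEstimator's implementation: `scalar_keep` is a hypothetical helper, the labels are random rather than MNIST, and it assumes that each `keep_prob` entry applies to the `filter_value` entry at the same position while labels not listed are always kept.

```python
import numpy as np


def scalar_keep(label, filter_value, keep_prob, rng):
    """Hypothetical helper: keep a listed label with its paired probability."""
    for value, prob in zip(filter_value, keep_prob):
        if label == value:
            return bool(rng.random() < prob)
    # Labels that are not listed in filter_value are always kept.
    return True


rng = np.random.default_rng(0)
labels = rng.integers(0, 10, size=10000)  # random stand-in for digit labels

mask = np.array([scalar_keep(y, [0, 2, 4, 6, 8], [0.5] * 5, rng) for y in labels])
kept = labels[mask]

# Odd labels pass untouched; roughly half of the even labels are dropped,
# so even labels end up at roughly one third of the kept examples.
even_fraction = np.mean(kept % 2 == 0)
print("fraction of even labels after filtering: {:.2f}".format(even_fraction))
```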

In [18]:
results = pipeline.show_results(num_steps=10)
for idx in range(10):
    batch_label = results[idx]["y"].numpy()
    even_count = 0
    odd_count = 0
    for elem in batch_label:
        if elem % 2 == 0:
            even_count += 1
        else:
            odd_count += 1
    print("in batch number {}, there are {} odd labels and {} even labels".format(idx, odd_count, even_count))

in batch number 0, there are 20 odd labels and 12 even labels
in batch number 1, there are 26 odd labels and 6 even labels
in batch number 2, there are 25 odd labels and 7 even labels
in batch number 3, there are 27 odd labels and 5 even labels
in batch number 4, there are 19 odd labels and 13 even labels
in batch number 5, there are 19 odd labels and 13 even labels
in batch number 6, there are 21 odd labels and 11 even labels
in batch number 7, there are 17 odd labels and 15 even labels
in batch number 8, there are 20 odd labels and 12 even labels
in batch number 9, there are 21 odd labels and 11 even labels
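
The counts above can be sanity-checked with a little arithmetic. The sketch below assumes that about half of the labels are even before filtering (an approximation, not a measured figure) and that filtered examples are simply replaced by later ones until the batch is full; `expected_even_per_batch` is a hypothetical helper written for this check.

```python
def expected_even_per_batch(p_even, keep_prob, batch_size):
    """Expected number of even labels per batch after probabilistic filtering."""
    kept_even = p_even * keep_prob  # share of the stream that is even and kept
    kept_odd = 1.0 - p_even         # odd labels are always kept
    return batch_size * kept_even / (kept_even + kept_odd)


# Assumption: roughly 50% of labels are even before filtering.
expected = expected_even_per_batch(p_even=0.5, keep_prob=0.5, batch_size=32)

# Even-label counts copied from the ten batches printed above.
observed = [12, 6, 7, 5, 13, 13, 11, 15, 12, 11]
observed_mean = sum(observed) / len(observed)

print(round(expected, 2))  # 10.67
print(observed_mean)       # 10.5
```

Under these assumptions the expected value is 32 × 0.25 / 0.75, about 10.67 even labels per batch, which is close to the observed average of 10.5.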
