### NOTE: This is a sample notebook. Please make a copy of it for yourself and try it out.

<a id='top'></a>
This notebook is a follow-up tutorial. Please go through [DataSetManagement-Basic](./DataSetManagement-Basic.ipynb) before trying out this notebook.

This tutorial shows how to filter out unwanted images and labels at three stages: when creating a dataset, when applying transformations to create a new version, and when creating the splits used to train models. To show how the filtering works, we use a small sample of the [cifar_10 dataset](https://en.wikipedia.org/wiki/CIFAR-10), which contains pictures of animals, airplanes, ships and so on.


This tutorial covers the following:
- [Filtering Out Data When Creating Dataset](#filter_at_create)
- [Filtering Out Data While Applying Transformations](#filter_at_trans)
- [Filtering Out Data When Creating Dataset Training/Test/Validation Splits](#filter_at_split)



In [None]:
from sbrain.dataset import DataSetImageClassification,DataSetVersion,DataSetSplit
from sbrain.dataset import DataSetStatus,JobStatus,DataSetSplitStatus,DataSetVersionStatus
from sbrain.dataset import Transformation,TransformationSet
import numpy as np
import cv2
import uuid
import time
from IPython.display import clear_output

#### Please set the username you use to log in to the sbrain UI in the following cell

In [None]:
user_name = "admin"

In [None]:
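import time

def unique_id():
    # Current Unix time in seconds, used as a suffix so repeated runs create uniquely named datasets, versions and splits
    return str(int(time.time()))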

The sample dataset has images belonging to the following ten classes.

In [None]:
classes = {
    'airplane': 0,
    'automobile': 1,
    'bird': 2,
    'cat': 3,
    'deer': 4,
    'dog': 5,
    'frog': 6,
    'horse': 7,
    'ship': 8,
    'truck': 9
}
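The labels stored in the dataset and printed later in this notebook are these integer ids, not the class names. A small reverse lookup, sketched below, makes that output easier to read. The name `id_to_class` is our own choice for illustration, not part of sbrain.

```python
# Class-name -> label-id mapping, copied from the cell above
classes = {
    'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4,
    'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9,
}

# Every label id must be unique, otherwise inverting the dict would silently drop entries
assert len(set(classes.values())) == len(classes)

# Hypothetical helper: label id -> class name
id_to_class = {label_id: name for name, label_id in classes.items()}

print(id_to_class[2])
```

With this, a printed label such as `5` can be read as `id_to_class[5]`, i.e. `'dog'`.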

The following code lists the files in the cifar10_small dataset that we are going to use.

In [None]:
import glob
import os

# Print the base name of every file in the sample dataset directory
files = glob.glob("../demo-data/cifar10_small/*.*")
print("\n".join(os.path.basename(f) for f in files))
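The filters in this notebook rely on the class name being embedded in each file name. The sketch below applies the same parsing rule that `iterator_labels` uses further down (the text between the first `_` and the first `.`) and counts images per class, which shows in advance how many images each filter should remove. The file names in `sample_names` are hypothetical; substitute the names printed by the cell above.

```python
from collections import Counter

def label_from_filename(img_name):
    # Same rule as iterator_labels: the class name is the text between
    # the first '_' and the first '.' of the file name
    return img_name[img_name.index('_') + 1:img_name.index('.')]

# Hypothetical file names, for illustration only
sample_names = ["0001_airplane.png", "0002_bird.png", "0003_cat.png", "0004_bird.png"]

# Number of images per class name
counts = Counter(label_from_filename(n) for n in sample_names)
print(counts)
```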

<a id='filter_at_create'></a>
## Filtering Out Data When Creating Dataset

To filter out data when creating a dataset, write the image_iterator and/or label_iterator functions so that they skip the unwanted items, as shown below.

In the following example, we use the image iterator to filter out airplane images,
and the label iterator to filter out images belonging to class 2, i.e. the bird images.


NOTE: This example shows how to use image_iterator and label_iterator together to filter in a single call.
You don't have to use both: image_iterator alone can filter out both airplane and bird images, and so can label_iterator alone.
<div align="right"><a href="#top">BackToTheTop</a></div>

In [None]:
def iterator_images(data_root_path):
    import glob
    files = glob.glob("{}/*.*".format(data_root_path))
    for f in files:
        # Skip airplane images, identified by "_airplane" in the file name
        if "_airplane" not in f:
            yield f

def iterator_labels(data_root_path):
    import glob
    import os
    files = glob.glob("{}/*.*".format(data_root_path))
    # Class map repeated locally so the function does not depend on notebook globals
    classes = {
        'airplane': 0,
        'automobile': 1,
        'bird': 2,
        'cat': 3,
        'deer': 4,
        'dog': 5,
        'frog': 6,
        'horse': 7,
        'ship': 8,
        'truck': 9
    }
    for f in files:
        img_name = os.path.basename(f)
        # The class name is the text between the first '_' and the first '.' of the file name
        lbl_str = img_name[img_name.index('_') + 1:img_name.index('.')]
        lbl_id = classes[lbl_str]
        # Skip images with label 2, i.e. class 'bird'
        if lbl_id != 2:
            yield (img_name, lbl_id)
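As the NOTE above says, one iterator is enough. The sketch below is a hypothetical alternative, `iterator_images_only`, that drops both airplane and bird images by file name alone; it assumes the same `<id>_<class>.<ext>` naming as the sample data and matches on the base name rather than the full path, so a directory name can never trigger the filter by accident.

```python
import glob
import os

def iterator_images_only(data_root_path):
    # Hypothetical alternative: exclude airplane and bird images by file name only,
    # so label_iterator would not need to filter anything
    excluded = ("_airplane", "_bird")
    for f in glob.glob("{}/*.*".format(data_root_path)):
        name = os.path.basename(f)
        if not any(tag in name for tag in excluded):
            yield f
```

It would be passed as `image_iterator=iterator_images_only` in the `create` call below, together with a label iterator that does no filtering of its own.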
    

In [None]:
dataset_name = "cifar10-demo-{}".format(unique_id())

job = DataSetImageClassification(name=dataset_name).create(
    description="cifar 10 dataset",
    source_archive_path="shared-dir/sample-notebooks/demo-data/cifar10_small",
    classes=classes,
    collection_date="06-19-2018",
    image_iterator=iterator_images,
    label_iterator=iterator_labels
)

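# Poll until the job completes or fails, pausing between checks instead of busy-waiting
while job.status != JobStatus.COMPLETE.value and job.status != JobStatus.FAILED.value:
    clear_output(wait=True)
    job = job.get_status()
    time.sleep(2)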
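This notebook repeats the same polling pattern for the dataset, transformation and split jobs. A generic helper with a timeout, sketched below, is one way to factor it out so that a stuck job cannot block the notebook forever. `wait_until` and its parameters are our own names, not an sbrain API.

```python
import time

def wait_until(is_done, poll_seconds=2.0, timeout_seconds=600.0):
    # Hypothetical helper: call is_done() every poll_seconds until it returns True,
    # or raise TimeoutError once timeout_seconds have elapsed
    deadline = time.monotonic() + timeout_seconds
    while not is_done():
        if time.monotonic() >= deadline:
            raise TimeoutError("job did not finish within {} s".format(timeout_seconds))
        time.sleep(poll_seconds)
```

For the job above, `is_done` could be a small function that refreshes `job = job.get_status()` and checks `job.status` against `JobStatus.COMPLETE.value` and `JobStatus.FAILED.value`.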

In [None]:
DataSetImageClassification.search(name=dataset_name)
ds = DataSetImageClassification.lookup(dataset_name)
ds.search_versions(version_name="v1")

In [None]:
ds_get = DataSetImageClassification.lookup(dataset_name)
ds_version = ds_get.version("v1")

Now that the dataset is created, let's check that the airplane and bird images were filtered out of it.

In [None]:
version_iterator = ds_version.get_iterator()
result = version_iterator.get_all()
print("image_name, label\n")
for k,v in result:
    k = k.split("/")[-1]
    print("{}, {}".format(k,v))
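Reading the printed list works for a small sample, but the check can also be made explicit. The sketch below, using a hypothetical helper `find_excluded`, returns any `(image_name, label)` pair whose label should have been filtered out; it assumes `get_all()` yields such pairs, as the loop above suggests, and compares labels as integers in case they come back as strings.

```python
def find_excluded(pairs, excluded_labels):
    # Return the (image_name, label) pairs whose label is in excluded_labels.
    # Labels are compared as ints because they may be returned as strings.
    excluded = {int(l) for l in excluded_labels}
    return [(name, label) for name, label in pairs if int(label) in excluded]
```

For this dataset, `find_excluded(result, [0, 2])` should return an empty list, since classes 0 (airplane) and 2 (bird) were filtered out.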

<a id='filter_at_trans'></a>
## Filtering Out Data While Applying Transformations
<div align="right"><a href="#top">BackToTheTop</a></div>

The following is a simple transformation that flips an image horizontally.

In [None]:
class Flip(Transformation):
    def __init__(self, name):
        super().__init__(name)

    def process(self, arr_in):
        # flipCode=1 flips around the vertical axis, i.e. mirrors the image left to right
        flipped_image = cv2.flip(arr_in, 1)
        return flipped_image
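To see exactly what `cv2.flip(arr_in, 1)` does to the pixel array, the NumPy sketch below performs the same left-to-right mirror by reversing the width axis. `flip_horizontal` is a hypothetical name; `np.ascontiguousarray` is used because the reversed slice is a view with negative strides, which some OpenCV functions do not accept.

```python
import numpy as np

def flip_horizontal(arr_in):
    # Reverse the column (width) axis: for an H x W or H x W x C image this mirrors it
    # left to right, the same result as cv2.flip(arr_in, 1)
    return np.ascontiguousarray(arr_in[:, ::-1])

# Tiny 2 x 3 image with 2 channels
img = np.arange(12, dtype=np.uint8).reshape(2, 3, 2)
flipped = flip_horizontal(img)

# The first column of the flipped image is the last column of the original
print(flipped[:, 0], img[:, 2])
```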


In [None]:
flip = Flip(name="flip-{}".format(unique_id())).create(author="jasmine",description="flip images")

Data can be filtered out while running a transformation job in either of the following two ways:

1. **data_exclude_function**: filters out data based on the image file. The function receives the absolute path to an image and returns a Boolean indicating whether to exclude that image.

2. **label_exclude_function**: filters out data based on the label of the image. The function receives the label and returns a Boolean indicating whether to exclude that image.

In the following example, we pass label_exclude_func as label_exclude_function to filter out images with label 5, i.e. class 'dog', and image_exclude_func as data_exclude_function to filter out images with "_cat" in their name, i.e. class 'cat'.

NOTE: The following example uses both functions, but either one on its own also works.

In [None]:
def label_exclude_func(label):
    # Filter out label 5, i.e. dog
    return int(label) == 5

In [None]:
def image_exclude_func(img_path):
    # Filter out cat
    return "_cat" in img_path
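Before launching the job, the two predicates can be tried locally. The sketch below models the behaviour described above, where an image is dropped if either function returns `True`; `apply_exclusions` and the sample records are hypothetical and only illustrate the documented semantics, not sbrain's implementation.

```python
def apply_exclusions(records, data_exclude_function=None, label_exclude_function=None):
    # Local model of the documented behaviour: a record is dropped when either
    # predicate returns True. Illustration only, not sbrain's implementation.
    kept = []
    for img_path, label in records:
        if data_exclude_function is not None and data_exclude_function(img_path):
            continue
        if label_exclude_function is not None and label_exclude_function(label):
            continue
        kept.append((img_path, label))
    return kept

# Hypothetical (absolute image path, label) records
sample = [("/data/1_cat.png", "3"), ("/data/2_dog.png", "5"), ("/data/3_frog.png", "6")]
kept_sample = apply_exclusions(sample,
                               data_exclude_function=lambda p: "_cat" in p,
                               label_exclude_function=lambda l: int(l) == 5)
print(kept_sample)
```

In the notebook, `image_exclude_func` and `label_exclude_func` can be passed directly in place of the two lambdas.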

In [None]:
version_flipped_1_name = "flipped-{}".format(unique_id())
tj = ds_version.transform(flip).run(target_version=version_flipped_1_name, 
                                    num_workers=2, 
                                    data_exclude_function=image_exclude_func, 
                                    label_exclude_function=label_exclude_func)

# Check job status; also stop on failure so the loop cannot run forever
status = tj.get_status().lower()
while status not in ('complete', str(JobStatus.FAILED.value).lower()):
    clear_output(wait=True)
    status = tj.get_status().lower()
    time.sleep(2)

In [None]:
ds.search_versions(version_name=version_flipped_1_name)
version_flipped_1 = ds_get.version(version_name=version_flipped_1_name)

Now let's look at the images in the "flipped" version we just created.

It should not contain any cat or dog images.

In [None]:
version_flipped_1_iterator = version_flipped_1.get_iterator()
result = version_flipped_1_iterator.get_all()
print("image_name, label\n")
for k,v in result:
    k = k.split("/")[-1]
    print("{}, {}".format(k,v))
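Another way to confirm the effect of the filters is to compare per-label counts between the source version and the flipped version. The hypothetical helper below reports every label whose count changed; since flipping itself does not add or remove images, only the excluded classes should show up, each as `(n, 0)`.

```python
from collections import Counter

def label_count_changes(before_pairs, after_pairs):
    # Count labels in two (image_name, label) listings and report the labels
    # whose counts differ, as {label: (count_before, count_after)}
    before = Counter(int(label) for _, label in before_pairs)
    after = Counter(int(label) for _, label in after_pairs)
    # Counter returns 0 for missing keys, so a removed class shows up as (n, 0)
    return {label: (before[label], after[label])
            for label in sorted(set(before) | set(after))
            if before[label] != after[label]}
```

Called with `ds_version.get_iterator().get_all()` and `version_flipped_1.get_iterator().get_all()`, the result would be expected to contain only labels 3 (cat) and 5 (dog).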

<a id='filter_at_split'></a>
## Filtering Out Data When Creating Dataset Training/Test/Validation Splits
<div align="right"><a href="#top">BackToTheTop</a></div>

To filter out data when creating dataset splits, we can pass the same data_exclude_function and/or label_exclude_function parameters that we used above for the transformation job.

In the following example, we use label_exclude_func_2 to filter out images with label 9, i.e. class 'truck', and image_exclude_func_2 to filter out images with "_ship" in their name, i.e. class 'ship'.

NOTE: The following example uses both functions, but either one on its own also works.
 

In [None]:
def label_exclude_func_2(label):
    # Filtering out truck images
    return int(label) == 9

def image_exclude_func_2(img_path):
    # Filtering out ship images
    return "_ship" in img_path

In [None]:
split_name = "flipped-split-60-30-10--{}".format(unique_id())
split_job = version_flipped_1.create_data_split(split_name=split_name,
                                                split_percentages=[60,30,10],
                                                data_exclude_function=image_exclude_func_2,
                                                label_exclude_function=label_exclude_func_2,
                                                description="example split with filter functions"
                                            )


#Check job status
while split_job.status != JobStatus.COMPLETE.value and split_job.status != JobStatus.FAILED.value:
    clear_output(wait=True)
    split_job = split_job.get_status()
    time.sleep(2)
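The sketch below illustrates one way a `[60, 30, 10]` split of already-filtered data could be computed: shuffle with a fixed seed, then cut consecutive slices by percentage. Filtering happens first here, so excluded images never reach any split. `split_by_percentages` is hypothetical; sbrain's own split may shuffle, round or stratify differently.

```python
import random

def split_by_percentages(items, percentages, seed=0):
    # Illustrative only: shuffle a copy of the (already filtered) items with a fixed seed,
    # then cut it into consecutive slices sized by the given percentages
    if sum(percentages) != 100:
        raise ValueError("percentages must add up to 100")
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    splits, start = [], 0
    for i, pct in enumerate(percentages):
        # The last split takes the remainder so no item is lost to rounding
        if i == len(percentages) - 1:
            end = len(shuffled)
        else:
            end = start + len(shuffled) * pct // 100
        splits.append(shuffled[start:end])
        start = end
    return splits
```

For example, `train, test, val = split_by_percentages(kept_records, [60, 30, 10])`, where `kept_records` is whatever remains after applying the exclude functions.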

In [None]:
version_flipped_1.search_splits(split_name=split_name)

In [None]:
cifar_10_split = DataSetSplit.lookup(dataset_name=dataset_name, 
                                     dataset_version_name=version_flipped_1_name, 
                                     split_name=split_name)

The split creates 3 TensorFlow TFRecord files, one each for train, test and validation, which are read here as TFRecordDatasets.
The following cells show how to list the image names and labels stored in each TFRecord file.
No "ship" or "truck" images should be present in any of them.


In [None]:
import tensorflow as tf
sess = tf.Session()

# Print up to 50 (image name, label) records from the training split
dataset = cifar_10_split.train_tfrecord(all_fields=True)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

for i in range(50):
    try:
        data, lbl, name, h, w, d = sess.run(next_element)
    except tf.errors.OutOfRangeError:
        # The split has fewer than 50 records
        break
    name = name.decode('utf-8')
    label = int(lbl.decode('utf-8'))
    print("{} : {}".format(name, label))

In [None]:
import tensorflow as tf
sess = tf.Session()

# Print up to 50 (image name, label) records from the test split
dataset = cifar_10_split.test_tfrecord(all_fields=True)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

for i in range(50):
    try:
        data, lbl, name, h, w, d = sess.run(next_element)
    except tf.errors.OutOfRangeError:
        # The split has fewer than 50 records
        break
    name = name.decode('utf-8')
    label = int(lbl.decode('utf-8'))
    print("{} : {}".format(name, label))

In [None]:
import tensorflow as tf
sess = tf.Session()

# Print up to 50 (image name, label) records from the validation (eval) split
dataset = cifar_10_split.eval_tfrecord(all_fields=True)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

for i in range(50):
    try:
        data, lbl, name, h, w, d = sess.run(next_element)
    except tf.errors.OutOfRangeError:
        # The split has fewer than 50 records
        break
    name = name.decode('utf-8')
    label = int(lbl.decode('utf-8'))
    print("{} : {}".format(name, label))
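Instead of only printing, each reading cell could append its decoded `(name, label)` pairs to a list, and the lists could then be checked together. The hypothetical helper below flags any record with an excluded label and any image name that appears in more than one split. Note that the cells above read at most 50 records per split, so a check on their output covers only those records.

```python
def check_splits(splits, excluded_labels):
    # splits maps a split name (e.g. "train", "test", "eval") to a list of
    # (image_name, label) pairs decoded as in the cells above.
    # Returns a list of human-readable problems; an empty list means all checks passed.
    problems = []
    excluded = set(excluded_labels)
    seen = {}  # image name -> first split it was found in
    for split_name, records in splits.items():
        for name, label in records:
            if label in excluded:
                problems.append("{}: {} has excluded label {}".format(split_name, name, label))
            if name in seen and seen[name] != split_name:
                problems.append("{} appears in both {} and {}".format(name, seen[name], split_name))
            seen.setdefault(name, split_name)
    return problems
```

For this split, `check_splits({"train": train_pairs, "test": test_pairs, "eval": eval_pairs}, [8, 9])` should return an empty list, where the three lists are the collected pairs.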

## **_<font color="green">Congratulations !!! You completed the tutorial successfully.</font>_**