# Explore the dataset


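In this notebook, we will perform an EDA (Exploratory Data Analysis) on the processed Waymo dataset (data in the `processed` folder). In the first part, you will create a function to display an image together with its ground-truth bounding boxes; you will then use it to look at random samples and carry out additional analysis of the data.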

In [None]:
from utils import get_dataset
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
%matplotlib inline

#import matplotlib as mpl
#mpl.use("TkAgg")

In [None]:
# Wildcard pattern selecting every processed Waymo tfrecord file.
PROCESSED_TFRECORDS = "/home/workspace/data/processed/*.tfrecord"
dataset = get_dataset(PROCESSED_TFRECORDS)

## Write a function to display an image and the bounding boxes

Implement the `display_instances` function below. This function takes a batch as input and displays an image with its corresponding bounding boxes. The only requirement is that the classes should be color coded (e.g., vehicles in red, pedestrians in blue, cyclists in green).

In [None]:
def display_instances(batch):
    """
    Display one image from the dataset with its ground-truth bounding boxes.

    Boxes are drawn unfilled and colour-coded by class id
    (1 -> red, 2 -> green, 4 -> blue; ids as in download_process.py).

    Args:
        batch: one dataset element, read as a mapping with the keys
            "image", "groundtruth_classes" and "groundtruth_boxes".
            Each box is unpacked as (y1, x1, y2, x2) and scaled by the image
            height/width, i.e. boxes are assumed to be normalized to [0, 1]
            (TODO confirm against utils.get_dataset).

    Raises:
        KeyError: if a class id has no entry in the colormap.
    """
    # from download_process.py
    # mapping = {1: 'vehicle', 2: 'pedestrian', 4: 'cyclist'}
    colormap = {1: [1, 0, 0], 2: [0, 1, 0], 4: [0, 0, 1]}

    img = batch["image"]
    # Scale by the actual image size rather than a hard-coded 640 so the
    # function keeps working if the images are resized.
    height, width = img.shape[0], img.shape[1]

    fig, ax = plt.subplots(figsize=(20, 10))
    ax.imshow(img)
    # Hide the axes once, also for frames that contain no boxes.
    ax.axis('off')

    for cl, bb in zip(batch["groundtruth_classes"], batch["groundtruth_boxes"]):
        y1, x1, y2, x2 = (float(v) for v in bb)
        rec = Rectangle((x1 * width, y1 * height),
                        (x2 - x1) * width, (y2 - y1) * height,
                        facecolor='none', edgecolor=colormap[int(cl)])
        ax.add_patch(rec)

    plt.tight_layout()
    plt.show()
    # Release the figure so calling this in a loop does not accumulate
    # open figures.
    plt.close(fig)

## Display 10 images 

Using the dataset created in the second cell and the function you just coded, display 10 random images with the associated bounding boxes. You can use the methods `take` and `shuffle` on the dataset.

In [None]:
## STUDENT SOLUTION HERE
# Display NUM_DISPLAY random frames with their boxes. The fixed seed makes the
# selection reproducible on Restart & Run All. Note that shuffle() samples from
# a buffer of SHUFFLE_BUFFER elements, so the frames come from roughly the first
# SHUFFLE_BUFFER + NUM_DISPLAY elements of the dataset stream.
SHUFFLE_BUFFER = 100
SHUFFLE_SEED = 42
NUM_DISPLAY = 10
for batch in dataset.shuffle(SHUFFLE_BUFFER, seed=SHUFFLE_SEED).take(NUM_DISPLAY):
    display_instances(batch)


## Additional EDA

In this last part, you are free to perform any additional analysis of the dataset. What else would you like to know about the data?
For example, think about data distribution. So far, you have only looked at a single file...

In [None]:
def _draw_instances(ax, batch, colormap):
    """Draw batch["image"] on ax with its ground-truth boxes coloured by class id."""
    img = batch["image"]
    # Boxes are unpacked as (y1, x1, y2, x2) and scaled by the image size,
    # i.e. assumed normalized to [0, 1] (TODO confirm against utils.get_dataset).
    height, width = img.shape[0], img.shape[1]
    ax.imshow(img)
    for cl, bb in zip(batch["groundtruth_classes"], batch["groundtruth_boxes"]):
        y1, x1, y2, x2 = (float(v) for v in bb)
        ax.add_patch(Rectangle((x1 * width, y1 * height),
                               (x2 - x1) * width, (y2 - y1) * height,
                               facecolor='none', edgecolor=colormap[int(cl)]))


def display_instances_EDA(max_batch,max_num,min_batch,min_num,class_name=None):
    """
    Show two frames side by side, each with its ground-truth boxes:
    the frame with the most labels (left) and the one with the fewest (right).

    Args:
        max_batch: dataset element with the largest label count, or None.
        max_num: that label count, shown in the left title.
        min_batch: dataset element with the smallest label count, or None.
        min_num: that label count, shown in the right title.
        class_name: class id (1, 2 or 4) the counts refer to, or None when
            the counts cover all classes. Only used to build the titles.

    Raises:
        KeyError: if class_name or a label in a batch is not a known class id.
    """
    # from download_process.py
    mapping = {1: 'vehicle', 2: 'pedestrian', 4: 'cyclist'}
    colormap = {1: [1, 0, 0], 2: [0, 1, 0], 4: [0, 0, 1]}

    subject = 'all classes' if class_name is None else mapping[class_name]

    fig, axes = plt.subplots(1, 2, figsize=(20, 10))
    panels = [
        (axes[0], max_batch, f'max number of {subject} is {max_num}'),
        (axes[1], min_batch, f'min number of {subject} is {min_num}'),
    ]
    for ax, batch, title in panels:
        ax.set_title(title)
        ax.axis('off')
        if batch is None:
            # No frame was recorded for this extreme (e.g. nothing was scanned).
            ax.text(0.5, 0.5, 'no image', ha='center', va='center',
                    transform=ax.transAxes)
            continue
        _draw_instances(ax, batch, colormap)

    plt.tight_layout()
    plt.show()
    plt.close(fig)

In [None]:
# Scan the dataset once and collect:
#   * the total number of labels per class (class_vehicle / class_pedestrian / class_cyclist)
#   * the frame with the most and the fewest labels, over all classes
#     (max_class / min_class) and per class (max_vehicle / min_vehicle, ...),
#     together with that frame (*_batch)
# Class ids from download_process.py:
# mapping = {1: 'vehicle', 2: 'pedestrian', 4: 'cyclist'}

# NOTE(review): if get_dataset() returns a dataset that repeats indefinitely,
# iterating over it never ends -- set this to an int to bound the scan.
EDA_SAMPLE_LIMIT = None

class_vehicle = 0
class_pedestrian = 0
class_cyclist = 0

# Max trackers start below any possible count so the first frame is always
# recorded; otherwise a class that never occurs would leave its *_batch as None.
max_vehicle = -1
max_vehicle_batch = None
max_pedestrian = -1
max_pedestrian_batch = None
max_cyclist = -1
max_cyclist_batch = None
max_class = -1
max_class_batch = None

# Min trackers start above any possible count.
min_vehicle = float('inf')
min_vehicle_batch = None
min_pedestrian = float('inf')
min_pedestrian_batch = None
min_cyclist = float('inf')
min_cyclist_batch = None
min_class = float('inf')
min_class_batch = None

eda_samples = dataset if EDA_SAMPLE_LIMIT is None else dataset.take(EDA_SAMPLE_LIMIT)

for batch in eda_samples:
    groundtruth_classes = batch["groundtruth_classes"].numpy()
    num_labels = len(groundtruth_classes)

    # max and min number of labels over all classes
    if num_labels > max_class:
        max_class = num_labels
        max_class_batch = batch

    if num_labels < min_class:
        min_class = num_labels
        min_class_batch = batch

    num_vehicle = 0
    num_pedestrian = 0
    num_cyclist = 0

    for cl in groundtruth_classes:
        # Compare by value: `cl` is a numpy integer, and an identity check
        # (`is`) against an int literal is never True for it.
        if cl == 1:
            num_vehicle += 1
        elif cl == 2:
            num_pedestrian += 1
        elif cl == 4:
            num_cyclist += 1
        else:
            raise ValueError(f'unexpected label {cl}')

    # running totals per class
    class_vehicle += num_vehicle
    class_pedestrian += num_pedestrian
    class_cyclist += num_cyclist

    # max and min number of vehicle
    if num_vehicle > max_vehicle:
        max_vehicle = num_vehicle
        max_vehicle_batch = batch

    if num_vehicle < min_vehicle:
        min_vehicle = num_vehicle
        min_vehicle_batch = batch

    # max and min number of pedestrian
    if num_pedestrian > max_pedestrian:
        max_pedestrian = num_pedestrian
        max_pedestrian_batch = batch

    if num_pedestrian < min_pedestrian:
        min_pedestrian = num_pedestrian
        min_pedestrian_batch = batch

    # max and min number of cyclist
    if num_cyclist > max_cyclist:
        max_cyclist = num_cyclist
        max_cyclist_batch = batch

    if num_cyclist < min_cyclist:
        min_cyclist = num_cyclist
        min_cyclist_batch = batch


In [None]:
display_instances_EDA(
    max_batch=max_class_batch,
    max_num=max_class,
    min_batch=min_class_batch,
    min_num=min_class,
)

In [None]:
display_instances_EDA(
    max_batch=max_vehicle_batch,
    max_num=max_vehicle,
    min_batch=min_vehicle_batch,
    min_num=min_vehicle,
    class_name=1,  # vehicle
)

In [None]:
display_instances_EDA(
    max_batch=max_pedestrian_batch,
    max_num=max_pedestrian,
    min_batch=min_pedestrian_batch,
    min_num=min_pedestrian,
    class_name=2,  # pedestrian
)

In [None]:
display_instances_EDA(
    max_batch=max_cyclist_batch,
    max_num=max_cyclist,
    min_batch=min_cyclist_batch,
    min_num=min_cyclist,
    class_name=4,  # cyclist
)

In [None]:
# Total number of labels per class over every scanned frame; each class gets
# exactly one bar (vehicle, pedestrian, cyclist).
data = [class_vehicle, class_pedestrian, class_cyclist]
labels = ['num_vehicle', 'num_pedestrian', 'num_cyclist']
fig, ax = plt.subplots()
ax.bar(range(len(data)), data, tick_label=labels)
ax.set_title('the total number of labels in all image')
ax.set_ylabel('number of labels')
plt.show()