   
   <center><font size="8">Data visualization</font></center>
   

In [None]:
# Import Libraries
import os
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from src.data_loader.data_set import Data_Set
from src.data_loader.utils import get_train_val_split
from src.constants import MASTER_THESIS_DIR, FREIHAND_DATA
from src.utils import read_json
import matplotlib.pyplot as plt
from ipywidgets import interactive,GridspecLayout
import ipywidgets as widgets
from IPython.display import display
import copy
from easydict import EasyDict as edict
from src.utils import read_json
from src.visualization.visualize import plot_hand
import pandas as pd

In [None]:
# Read Data: load the training configuration, build the pairwise dataset and its
# train/val loaders, then switch matplotlib to white text/axes.
_config_path = f"{MASTER_THESIS_DIR}/src/experiments/config/training_config.json"
train_param = edict(read_json(_config_path))
train_data = Data_Set(
    config=train_param,
    transform=transforms.ToTensor(),
    train_set=True,
    experiment_type="pairwise",
)
# Shallow copy of the dataset object.
# NOTE(review): attributes are shared with train_data; confirm that
# is_training(False) only rebinds state on val_data. val_data is not used by
# any of the visible cells.
val_data = copy.copy(train_data)
val_data.is_training(False)

train_data_loader, val_data_loader = get_train_val_split(
    train_data,
    batch_size=train_param.batch_size,
    num_workers=train_param.num_workers,
)

# Draw tick labels, axis labels, axis edges and all text in white.
params = dict.fromkeys(
    ("ytick.color", "xtick.color", "axes.labelcolor", "axes.edgecolor", "text.color"),
    "w",
)
plt.rcParams.update(params)

# HTML <font size=...> value used for the widget labels of the panel below.
font_size = 5

In [None]:
# Experiment types understood by ``visualize`` (same set as the panel dropdown).
_EXPERIMENT_TYPES = (
    "supervised",
    "simclr",
    "pairwise",
    "experiment4_pretraining",
    "pairwise_ablative",
)
# Pair-based experiments that additionally show their non-image, non-joint
# sample entries as a table.
_TABLE_EXPERIMENT_TYPES = ("pairwise", "pairwise_ablative")


def _to_display_image(tensor):
    """Reverse dim 0 of an image tensor and convert it to a PIL image.

    NOTE(review): dim 0 is presumably the channel axis and the flip presumably
    turns BGR into RGB -- TODO confirm against Data_Set.
    """
    return transforms.ToPILImage()(torch.flip(tensor, (0,)))


def _show_supervised(sample):
    """Plot the image with 2D joints and the 3D / recreated 3D joints; return the figure."""
    fig = plt.figure(figsize=(20, 20))
    ax = fig.add_subplot(121)
    plt.imshow(_to_display_image(sample["image"]))
    plot_hand(ax, sample["joints"])
    ax = fig.add_subplot(122, projection="3d")
    ax.set_facecolor("black")
    plot_hand(ax, sample["joints3D"], plot_3d=True, alpha=0.2, linestyle="-", linewidth="5")
    plot_hand(ax, sample["joints3D_recreated"], plot_3d=True, linestyle=":")
    # Absolute value so that negative deviations count as errors too.
    max_error = torch.max(torch.abs(sample["joints3D"] - sample["joints3D_recreated"]))
    display(pd.DataFrame({"Max Error recreated 3D": [max_error.numpy()]}))
    return fig


def _show_image_pair(sample):
    """Plot ``transformed_image1`` and ``transformed_image2`` side by side; return the figure."""
    fig = plt.figure(figsize=(20, 20))
    for position, key, title in (
        (121, "transformed_image1", "Image 1"),
        (122, "transformed_image2", "Image 2"),
    ):
        ax = fig.add_subplot(position)
        plt.imshow(_to_display_image(sample[key]))
        ax.set_title(title)
    return fig


def _display_scalar_entries(sample):
    """Display every sample entry whose key mentions neither 'image' nor 'joints' as a one-row table."""
    display(
        pd.DataFrame(
            {
                key: [value.numpy()]
                for key, value in sample.items()
                if "image" not in key and "joints" not in key
            }
        )
    )


def visualize(
    idx,
    experiment_type,
    random_crop,
    crop,
    color_jitter,
    cut_out,
    resize,
    color_drop,
    rotate,
    gaussian_blur,
    gaussian_noise,
    sobel_filter,
):
    """Apply the selected augmentations and plot sample ``idx`` of ``train_data``.

    Mutates module-level state shared with the rest of the notebook:
    ``train_param.augmentation_flags``, the augmenter installed through
    ``train_data.update_augmenter`` and ``train_data.experiment_type``.

    Args:
        idx: Index of the sample to fetch from ``train_data``.
        experiment_type: One of ``_EXPERIMENT_TYPES``.
        random_crop, crop, color_jitter, cut_out, resize, color_drop, rotate,
        gaussian_blur, gaussian_noise, sobel_filter: Augmentation toggles,
            copied verbatim into ``train_param.augmentation_flags``.

    Returns:
        The matplotlib figure that was drawn.

    Raises:
        ValueError: If ``experiment_type`` is not one of ``_EXPERIMENT_TYPES``.
    """
    # Validate before touching shared state. Without this check an unknown
    # type would fail later with an unbound ``fig``.
    if experiment_type not in _EXPERIMENT_TYPES:
        raise ValueError(
            f"Unknown experiment_type {experiment_type!r}; expected one of {_EXPERIMENT_TYPES}"
        )

    flags = train_param.augmentation_flags
    flags.random_crop = random_crop
    flags.crop = crop
    flags.color_jitter = color_jitter
    flags.cut_out = cut_out
    flags.resize = resize
    flags.color_drop = color_drop
    flags.gaussian_blur = gaussian_blur
    flags.rotate = rotate
    flags.sobel_filter = sobel_filter
    flags.gaussian_noise = gaussian_noise
    train_data.update_augmenter(
        train_param.augmentation_params, train_param.augmentation_flags
    )

    # The experiment type decides which keys train_data[idx] returns.
    train_data.experiment_type = experiment_type
    sample = train_data[idx]

    if experiment_type == "supervised":
        return _show_supervised(sample)

    fig = _show_image_pair(sample)
    if experiment_type in _TABLE_EXPERIMENT_TYPES:
        _display_scalar_entries(sample)
    return fig
def _font_label(text):
    """Wrap ``text`` in the sized HTML <font> tag used for every panel label."""
    return f"<font size='{font_size}'>{text}</font>"


def _flag_box(enabled, text):
    """Create an augmentation checkbox with initial state ``enabled`` and label ``text``."""
    return widgets.Checkbox(value=enabled, description=_font_label(text))


visualization_panel = interactive(
    visualize,
    idx=widgets.IntSlider(min=1, max=3000, step=5, value=3),
    experiment_type=widgets.Dropdown(
        options=["supervised", "simclr", "pairwise", "experiment4_pretraining", "pairwise_ablative"],
        value="pairwise",
        description=_font_label("Model"),
        disabled=False,
    ),
    random_crop=_flag_box(False, "Random crop"),
    crop=_flag_box(True, "Crop"),
    color_jitter=_flag_box(False, "Color jitter"),
    cut_out=_flag_box(False, "Cut out"),
    resize=_flag_box(True, "Resize"),
    color_drop=_flag_box(False, "Color drop"),
    rotate=_flag_box(False, "Rotate"),
    gaussian_blur=_flag_box(False, "Blur(gaussian)"),
    gaussian_noise=_flag_box(False, "Noise(gaussian)"),
    sobel_filter=_flag_box(False, "sobel_filter"),
)

# Panel children follow visualize's parameter order: index slider, model
# dropdown, the ten augmentation checkboxes, then the output area (last).
augmentation_checkboxes = visualization_panel.children[:-1]
index_slider, model_dropdown = augmentation_checkboxes[0], augmentation_checkboxes[1]

grid = GridspecLayout(22, 6)
grid[0, 1] = index_slider
index_slider.description = _font_label("Index")
grid[0, 3] = model_dropdown
# Place the checkboxes five per row on grid rows 2 and 3.
for offset, box in enumerate(augmentation_checkboxes[2:12]):
    row, col = divmod(offset, 5)
    grid[2 + row, col] = box
# The plot output fills the rest of the grid.
grid[4:, 1:] = visualization_panel.children[-1]
display(grid)

In [None]:
## Visualize sample in batch
# @interact(
#     idx=widgets.IntSlider(min=0, max=31, step=1, value=10),
#     experiment_type=widgets.Dropdown(
#         options=["supervised", "simclr", "pairwise"],
#         value="pairwise",
#         description="Experiment type:",
#         disabled=False,
#     ),
# )
# def vis(idx, experiment_type):
#     train_data.experiment_type = experiment_type
#     for i, elem in enumerate(train_data_loader):
#         if experiment_type == "supervised":
#             sample = elem
#             joints = sample["joints"][idx]
#             img = sample["image"][idx]
#             fig = plt.figure(figsize=(5, 5))
#             ax = fig.add_subplot(111)
#             plt.imshow(transforms.ToPILImage()(img))
#             plot_hand(ax, joints)
#             plt.show()
#         elif experiment_type == "simclr":
#             sample = elem
#             fig = plt.figure(figsize=(10, 10))
#             ax = fig.add_subplot(121)
#             plt.imshow(transforms.ToPILImage()(sample["transformed_image1"][idx]))
#             ax.set_title("Image 1")
#             ax = fig.add_subplot(122)
#             plt.imshow(transforms.ToPILImage()(sample["transformed_image2"][idx]))
#             ax.set_title("Image 2")
#         elif experiment_type == "pairwise":
#             train_data.experiment_type = "pairwise"
#             sample = train_data[idx]
#             fig = plt.figure(figsize=(10, 10))
#             ax = fig.add_subplot(121)
#             plt.imshow(transforms.ToPILImage()(sample["transformed_image1"]))
#             ax.set_title("Image 1")
#             ax = fig.add_subplot(122)
#             plt.imshow(transforms.ToPILImage()(sample["transformed_image2"]))
#             ax.set_title("Image 2")
#         break