# Import Libraries

In [None]:
import os
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from src.data_loader.data_set import Data_Set
from src.constants import MASTER_THESIS_DIR, FREIHAND_DATA
from src.utils import read_json
import matplotlib.pyplot as plt
from ipywidgets import interact
import ipywidgets as widgets
from IPython.display import display
import copy
from easydict import EasyDict as edict
from src.utils import read_json
from src.visualization.visualize import plot_hand

# Read Data

In [None]:
# Load the training configuration from JSON and wrap it in an EasyDict so the
# nested settings can be read and overridden with attribute access.
train_param = edict(
    read_json(f"{MASTER_THESIS_DIR}/src/experiments/training_config.json")
)
print(train_param)

In [None]:
# Switch on every augmentation flag for this notebook session; the flags are
# set one after another in the order listed.
for _flag_name in (
    "random_crop",
    "crop",
    "color_jitter",
    "cut_out",
    "resize",
    "color_drop",
    "gaussian_blur",
):
    setattr(train_param.augmentation_flags, _flag_name, True)

# Override the crop-margin range and the cut-out fraction range from the config.
train_param.augmentation_params.crop_margin_range = [0.8, 1.5]
train_param.augmentation_params.cut_out_fraction = [0.05, 0.2]

In [None]:
# Training dataset built from the (overridden) config; no transform is passed.
# An alternative previously tried here:
# transform=transforms.Compose([transforms.ToTensor()])
train_data = Data_Set(config=train_param, transform=None, train_set=True)

# The validation dataset is a shallow copy of the training dataset with
# is_training(False) applied to the copy.
val_data = copy.copy(train_data)
val_data.is_training(False)

# One loader per dataset, both using the batch size and worker count from the
# config. The training loader is created first, then the validation loader.
train_data_loader, val_data_loader = (
    DataLoader(
        split,
        batch_size=train_param.batch_size,
        num_workers=train_param.num_workers,
    )
    for split in (train_data, val_data)
)

In [None]:
# Inspect the shape of the image in the first training sample.
train_data[0]["image"].shape

# Data visualization

## Visualize sample

## Visualize sample in batch

In [None]:
@interact(
    idx=widgets.IntSlider(min=1, max=100, step=5, value=3),
    experiment_type=widgets.Dropdown(
        options=["supervised", "simclr"],
        value="supervised",
        description="Experiment type:",
        disabled=False,
    ),
)
def visualize(idx, experiment_type):
    """Plot a single sample of ``train_data`` for the chosen experiment type.

    Args:
        idx: Index of the sample to read from ``train_data``.
        experiment_type: ``"supervised"`` draws the sample image with its
            joints overlaid by ``plot_hand``; ``"simclr"`` draws the two
            transformed views of the sample side by side. Any other value
            plots nothing.
    """
    # Set before indexing: the keys read from the sample below differ per
    # experiment type.
    train_data.experiment_type = experiment_type
    to_pil = transforms.ToPILImage()

    if experiment_type == "supervised":
        sample = train_data[idx]
        fig = plt.figure(figsize=(5, 5))
        ax = fig.add_subplot(111)
        ax.imshow(to_pil(sample["image"]))
        plot_hand(ax, sample["joints"])
        plt.show()
    elif experiment_type == "simclr":
        sample = train_data[idx]
        fig = plt.figure(figsize=(10, 10))
        # Panel titles follow the key numbering, so "Image 1" really shows
        # transformed_image1 and "Image 2" shows transformed_image2.
        for position, number in ((121, 1), (122, 2)):
            ax = fig.add_subplot(position)
            ax.imshow(to_pil(sample[f"transformed_image{number}"]))
            ax.set_title(f"Image {number}")
        # Explicit show so the figure is rendered inside the interact output.
        plt.show()

In [None]:
@interact(
    # NOTE(review): max=31 assumes the first batch holds at least 32 samples —
    # confirm against train_param.batch_size.
    idx=widgets.IntSlider(min=0, max=31, step=1, value=10),
    experiment_type=widgets.Dropdown(
        options=["supervised", "simclr"],
        value="supervised",
        description="Experiment type:",
        disabled=False,
    )
)
def vis(idx, experiment_type):
    """Plot one element of the first batch produced by ``train_data_loader``.

    Args:
        idx: Position of the element inside the batch.
        experiment_type: ``"supervised"`` draws the image with its joints
            overlaid by ``plot_hand``; ``"simclr"`` draws the two transformed
            views side by side. Any other value plots nothing.
    """
    # Set before the loader is iterated: the keys read from the batch below
    # differ per experiment type.
    train_data.experiment_type = experiment_type

    # Only the first batch is needed; an empty loader plots nothing.
    batch = next(iter(train_data_loader), None)
    if batch is None:
        return

    # NOTE(review): ToPILImage reads a NumPy array as H x W x C (a tensor would
    # be read as C x H x W) — confirm this matches the batch layout.
    to_pil = transforms.ToPILImage()

    if experiment_type == "supervised":
        fig = plt.figure(figsize=(5, 5))
        ax = fig.add_subplot(111)
        ax.imshow(to_pil(batch["image"][idx].numpy()))
        plot_hand(ax, batch["joints"][idx])
        plt.show()
    elif experiment_type == "simclr":
        fig = plt.figure(figsize=(10, 10))
        # Panel titles follow the key numbering, so "Image 1" really shows
        # transformed_image1 and "Image 2" shows transformed_image2.
        for position, number in ((121, 1), (122, 2)):
            ax = fig.add_subplot(position)
            ax.imshow(to_pil(batch[f"transformed_image{number}"][idx].numpy()))
            ax.set_title(f"Image {number}")
        # Explicit show, matching the supervised branch.
        plt.show()