# Import Libraries

In [None]:
import os

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import torch
from IPython.display import display
from ipywidgets import interact
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CometLogger
from torch.utils.data import DataLoader
from torchvision import transforms

from src.constants import MASTER_THESIS_DIR, FREIHAND_DATA
from src.data_loader.freihand_loader import F_DB
from src.data_loader.utils import convert_2_5D_to_3D
from src.models.baseline_model import BaselineModel
from src.utils import read_json
from src.visualization.visualize import plot_hand

# Training Configuration

In [None]:
# Hyper-parameters for this run, read from src/experiments/training_config.json
# under the thesis root. Keys used below: train_ratio, batch_size,
# resnet_trainable, epochs.
_training_config_path = os.path.join(
    MASTER_THESIS_DIR, "src", "experiments", "training_config.json"
)
training_hyper_param = read_json(_training_config_path)

# Dataset 

In [None]:
# FreiHAND training set. Only ToTensor is applied; the ImageNet normalisation
# is left commented out so it can be toggled back on quickly.
f_db = F_DB(
    root_dir=os.path.join(FREIHAND_DATA, "training", "rgb"),
    labels_path=os.path.join(FREIHAND_DATA, "training_xyz.json"),
    camera_param_path=os.path.join(FREIHAND_DATA, "training_K.json"),
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
            # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    ),
)

# Derive the split sizes straight from the ratio. Going through
# `int(train_ratio * 100)` truncated floating-point error (0.29 * 100 is
# 28.999..., which int() turns into 28) and could only express whole percents.
train_size = round(len(f_db) * training_hyper_param["train_ratio"])
val_size = len(f_db) - train_size
train, val = torch.utils.data.random_split(f_db, [train_size, val_size])

# random_split permutes the indices only once; shuffle the training loader so
# each epoch sees the samples in a different order. Validation order is kept.
train_data_loader = DataLoader(
    train, batch_size=training_hyper_param["batch_size"], shuffle=True
)
val_data_loader = DataLoader(val, batch_size=training_hyper_param["batch_size"])

# Logger

In [None]:
# Comet.ml experiment logger for the "master-thesis" project.
# The API key comes from the COMET_API_KEY environment variable; it is None
# when the variable is unset (how CometLogger then behaves is up to the library).
comet_api_key = os.getenv("COMET_API_KEY")
comet_save_dir = os.path.join(MASTER_THESIS_DIR, "models")
comet_logger = CometLogger(
    api_key=comet_api_key,
    project_name="master-thesis",
    workspace="dahiyaaneesh",
    save_dir=comet_save_dir,
)

# Model Training

In [None]:
# The config flag says whether the ResNet backbone should be *trainable*, while
# BaselineModel takes a *freeze* flag: the two are opposites, so negate it.
# NOTE(review): confirm against BaselineModel / training_config.json.
model = BaselineModel(freeze_resnet=not training_hyper_param["resnet_trainable"])
trainer = Trainer(max_epochs=training_hyper_param["epochs"], logger=comet_logger)
trainer.fit(model, train_data_loader, val_data_loader)

# Visualizations

In [None]:
# IntSlider's `max` is inclusive, so the last valid index is len(f_db) - 1.
@interact(id=widgets.IntSlider(min=0, max=len(f_db) - 1, step=1, value=0))
def visualize_sample(id):
    """Display the RGB image and the 2D / 3D hand poses of one ``f_db`` sample.

    Args:
        id: Index into ``f_db``; driven by the slider, bounded to
            ``[0, len(f_db) - 1]``.
    """
    # Index the dataset once: each ``f_db[id]`` re-runs F_DB.__getitem__
    # (presumably re-reading the image and re-applying the transform).
    sample = f_db[id]
    display(transforms.ToPILImage("RGB")(sample["image"]))

    fig = plt.figure(figsize=(10, 10))
    grid = fig.add_gridspec(5, 5)

    # Left: 2D keypoints in a smaller panel; right: full-height 3D view.
    ax_2d = fig.add_subplot(grid[1:3, :2])
    ax_2d.set_title("2d pose")
    plot_hand(ax_2d, np.array(sample["joints"]))

    ax_3d = fig.add_subplot(grid[:, 2:], projection="3d")
    ax_3d.set_title("3D pose")
    plot_hand(ax_3d, np.array(sample["joints_3D"]), plot_3d=True)

    plt.show()