# Import Libraries

In [None]:
from src.models.pairwise_model import PairwiseModel
import os
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from src.data_loader.data_set import Data_Set
from src.data_loader.utils import get_train_val_split
from src.constants import MASTER_THESIS_DIR, FREIHAND_DATA, PAIRWISE_CONFIG, DATA_PATH
from src.utils import read_json
import matplotlib.pyplot as plt
from ipywidgets import interact
import ipywidgets as widgets
from IPython.display import display
import copy
from easydict import EasyDict as edict
from src.utils import read_json, get_console_logger
from src.visualization.visualize import plot_hand
from pytorch_lightning import Trainer, seed_everything

from src.models.callbacks.upload_comet_logs import UploadCometLogs
from pytorch_lightning.loggers import CometLogger
from pytorch_lightning.callbacks import LearningRateMonitor

In [None]:
# Load the training hyper-parameters and build the pairwise training dataset.
train_param = edict(
    read_json(f"{MASTER_THESIS_DIR}/src/experiments/config/training_config.json")
)
train_data = Data_Set(
    config=train_param,
    transform=transforms.ToTensor(),
    train_set=True,
    experiment_type='pairwise',
)

# Shallow copy of the dataset with is_training(False) applied.
# NOTE(review): val_data is never passed to get_train_val_split below (it only
# receives train_data) — confirm whether this copy is still needed.
val_data = copy.copy(train_data)
val_data.is_training(False)

# Train/validation loaders are both derived from train_data by the project
# helper; batch size and worker count come from the training config.
train_data_loader, val_data_loader = get_train_val_split(
    train_data,
    batch_size=train_param.batch_size,
    num_workers=train_param.num_workers,
)


In [None]:
# Build the pairwise model from its JSON config.
model_param = edict(read_json(PAIRWISE_CONFIG))
# The training-set size is passed to the model through its config
# (presumably for step/schedule computations — TODO confirm in PairwiseModel).
model_param.num_samples = len(train_data)
# Apply every config override BEFORE constructing the model. Previously
# warmup_epochs was assigned after PairwiseModel(model_param) had run, so it
# only took effect if the model read the (shared, mutable) config lazily
# instead of in __init__.
model_param.warmup_epochs = 1
model = PairwiseModel(model_param)

In [None]:
# Experiment tracking with Comet. The API key is read from the COMET_API_KEY
# environment variable (None if it is unset); local files go to DATA_PATH/models.
comet_logger = CometLogger(
    api_key=os.environ.get("COMET_API_KEY"),
    project_name="master-thesis",
    workspace="dahiyaaneesh",
    save_dir=os.path.join(DATA_PATH, "models"),
)
# Project callback; arguments are passed positionally exactly as before.
upload_comet_logs = UploadCometLogs(
    "step", get_console_logger("callback"), "pairwise"
)
# Record the learning rate at every optimizer step.
lr_monitor = LearningRateMonitor(logging_interval="step")
# 16-bit precision training on one GPU.
# `gpus` is an int on purpose: Lightning has parsed *string* values of `gpus`
# differently across releases (device count vs. device index), whereas an int
# always means "number of GPUs to use".
trainer = Trainer(
    precision=16,
    logger=comet_logger,
    callbacks=[upload_comet_logs, lr_monitor],
    gpus=1,
    max_epochs=train_param.epochs,
)

In [None]:
# Train the model; validation runs on val_data_loader. Loaders are passed
# positionally because the keyword names of Trainer.fit changed between
# Lightning versions.
trainer.fit(
    model,
    train_data_loader,
    val_data_loader,
)