# Import libs

In [None]:
import os
import json
import random
import numpy as np
import pandas as pd
from itertools import product
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from PIL import Image
from datetime import datetime

import utils
from VisionModels import CustomCNN, CustomEfficientNetB3
from VisionDatasets import ContactDataset

%load_ext autoreload
%autoreload 2

# Settings and reproducibility

In [None]:
# Absolute locations of the two JSON configuration files, resolved against the
# notebook's current working directory.
_cwd = os.getcwd()
parameters_json, paths_json = (
    os.path.join(_cwd, relative_path)
    for relative_path in ('settings/parameters.json', 'settings/paths.json')
)

In [None]:
# Load the hyper-parameter settings and the path configuration.
# JSON text is UTF-8 by specification; pass the encoding explicitly so decoding
# does not depend on the platform's locale default (e.g. cp1252 on Windows),
# which would mis-decode any non-ASCII path or value.
with open(parameters_json, 'r', encoding='utf-8') as json_file:
    params = json.load(json_file)

with open(paths_json, 'r', encoding='utf-8') as json_file:
    paths = json.load(json_file)

In [None]:
random_seed = params["random_seed"]

# Environment variables for reproducibility.
# NOTE: PYTHONHASHSEED only influences interpreters started after this point;
# the running process fixed its hash seed at startup.
# The cuBLAS workspace setting is the one PyTorch's reproducibility notes ask
# for to get deterministic cuBLAS results; it must be set before cuBLAS work.
os.environ.update({
    'PYTHONHASHSEED': str(random_seed),
    'CUBLAS_WORKSPACE_CONFIG': ":4096:8",
})

# Seed every RNG the pipeline may draw from: Python, NumPy, torch (CPU/CUDA).
# (torch.manual_seed already covers CUDA devices; the explicit CUDA calls are
# kept as belt-and-braces.)
for seed_fn in (random.seed,
                np.random.seed,
                torch.manual_seed,
                torch.cuda.manual_seed,
                torch.cuda.manual_seed_all):
    seed_fn(random_seed)

# Give up cuDNN autotuning in favour of deterministic kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
def _recording_ids(prefixes, visit):
    """Build IDs of the form ``{prefix}_{side}_{visit}_{index}``.

    Every prefix is combined with side ``L`` then ``R`` and index ``1`` then
    ``2``, ordered prefix-major (so M1_L_..._1, M1_L_..._2, M1_R_..._1, ...).
    """
    return [f"{prefix}_{side}_{visit}_{index}"
            for prefix, side, index in product(prefixes, "LR", "12")]


# Training split: M1/M3/M5 with the V1 tag.
train_set = _recording_ids(["M1", "M3", "M5"], "V1")

# Test split: M1/M3/M5 with the V2 tag, then M2/M4 (which never appear in
# training) with both V1 and V2 tags.
test_set = (_recording_ids(["M1", "M3", "M5"], "V2")
            + _recording_ids(["M2", "M4"], "V1")
            + _recording_ids(["M2", "M4"], "V2"))

In [None]:
# Every (model name, label name) pair to train; only 'GT' is listed so far.
model_combinations = [(model_name, label_name)
                      for model_name in ('CustomCNN', 'EfficientNet')
                      for label_name in ('GT',)]

# Import Data

In [None]:
# Label table whose path is configured under paths['realistic']['labels'];
# the first CSV row holds the column names.
labels_csv_path = paths['realistic']['labels']
data = pd.read_csv(labels_csv_path, header=0)

In [None]:
# Split the label table by the value of its 'dataset' column: rows whose ID is
# listed in train_set / test_set go to the respective split.
train_mask = data['dataset'].isin(train_set)
test_mask = data['dataset'].isin(test_set)
train_data = data.loc[train_mask]
test_data = data.loc[test_mask]

# Training

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _load_image(image_path):
    """Open ``image_path`` and decode its pixel data immediately.

    ``Image.open`` is lazy and keeps the underlying file open until the pixels
    are loaded. Calling ``load()`` up front lets Pillow close the file for
    single-frame images, so building the dataset does not hold one open file
    handle per row, which can exhaust the OS descriptor limit.
    """
    image = Image.open(image_path)
    image.load()
    return image


# add the training rows to a Dataset
train_dataset = ContactDataset(
    images=[_load_image(image_path)
            for image_path in train_data['original filepath'].tolist()],
    labels=train_data['relabGT'].to_numpy(),
    transform=None,
    coords=list(zip(
        train_data['x'].astype(int),
        train_data['y'].astype(int))),
    jitter=True)

# os.cpu_count() returns None when the count cannot be determined; fall back
# to one worker in that case, and never use more than 4.
num_workers = min(4, os.cpu_count() or 1)

# create a DataLoader over the training Dataset
# NOTE(review): shuffle is left at its default (False), so batches follow the
# row order of train_data. The prediction cell compares predictions from this
# loader against train_data['relabGT'] in order, so if shuffling is wanted for
# training, use a separate non-shuffled loader for that comparison.
# NOTE(review): the batch size comes from the 'CustomCNN' settings although
# this loader is also used to train EfficientNet; confirm that is intended.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=params['CustomCNN']['batch_size'],
    # Pinned host memory only speeds up host-to-GPU copies; on a CPU-only
    # machine PyTorch would just warn and skip pinning.
    pin_memory=(device.type == 'cuda'),
    num_workers=num_workers)

# set up the loss function, using ContactDataset.getWeights() as the
# per-class weights of the cross-entropy loss
weights = train_dataset.getWeights().to(device)
loss_fn = nn.CrossEntropyLoss(weight=weights)

In [None]:
# Constructor for every model name that may appear in model_combinations.
model_factories = {
    'CustomCNN': CustomCNN,
    'EfficientNet': CustomEfficientNetB3,
}

# open() does not create folders; create the output folder up front so a
# missing directory fails immediately instead of after a long training run.
os.makedirs('results', exist_ok=True)

for model_name, label_name in model_combinations:
    # select the model; an unknown name is an error rather than a silent reuse
    # of the model left over from a previous iteration
    try:
        model_factory = model_factories[model_name]
    except KeyError:
        raise ValueError(f"Unknown model name: {model_name!r}") from None
    model = model_factory()

    # NOTE(review): label_name is not used below; every model is trained with
    # the same train_dataloader / loss_fn regardless of the label tag.

    # set up the optimizer (hyper-parameters)
    optimizer = optim.Adam(
        model.parameters(),
        lr=params[model_name]['learning_rate'],
        weight_decay=params[model_name]['weight_decay'])

    # train and retrieve the metrics
    metrics = utils.train(
        model=model,
        optimizer=optimizer,
        loss_fn=loss_fn,
        dataloader=train_dataloader,
        device=device,
        use_tqdm=True,
        epochs=params['epochs'])

    # Include the date so runs on different days cannot overwrite each other.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    with open(f'results/metrics_{model_name}_{timestamp}.json', 'w') as f:
        json.dump(metrics, f)

# Temp: single CustomCNN training run

In [None]:
# Single CustomCNN training run, reusing the loss function and training
# DataLoader built in the Training section.
model_name = 'CustomCNN'
model = CustomCNN()

optimizer = optim.Adam(
    model.parameters(),
    lr=params[model_name]['learning_rate'],
    weight_decay=params[model_name]['weight_decay'])

# Create the output folder before training so a missing directory fails fast.
os.makedirs('results', exist_ok=True)

metrics = utils.train(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    dataloader=train_dataloader,
    device=device,
    use_tqdm=True,
    epochs=params['epochs'])

# Date-and-time stamp so runs on different days do not overwrite each other.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# Name the file after the model, like the training loop does.
with open(f'results/metrics_{model_name}_{timestamp}.json', 'w') as f:
    json.dump(metrics, f)

In [None]:
# Save whatever `metrics` currently holds (from the last training cell run)
# under a date-and-time stamped filename.
now = datetime.now()
timestamp = now.strftime("%Y%m%d_%H%M%S")

# open() does not create missing folders.
os.makedirs('results', exist_ok=True)
with open(f'results/metrics_{timestamp}.json', 'w') as f:
    json.dump(metrics, f)

# Test (predictions on the training set)

In [None]:
# Run the current model over the training DataLoader.
# NOTE(review): this predicts on the training data (train_dataloader was built
# with jitter=True), not on test_data; confirm whether a test loader was meant.
predict_kwargs = dict(model=model, dataloader=train_dataloader, device=device)
pred = utils.predict(**predict_kwargs)

In [None]:
# Show the training labels and the predictions on consecutive lines for a
# quick visual comparison.
ground_truth = train_data['relabGT'].to_numpy()
print(ground_truth, pred, sep='\n')