# Import libs

In [None]:
import os
import json
import random
import pickle
import numpy as np
import pandas as pd
from itertools import product
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from datetime import datetime

import utils
from VisionModels import CustomCNN, CustomEfficientNetB3
from VisionDatasets import ContactDataset

%load_ext autoreload
%autoreload 2

# Pre-settings

In [None]:
# Settings files are resolved relative to the notebook's current working directory.
settings_root = os.getcwd()
parameters_json = os.path.join(settings_root, 'settings/parameters.json')
paths_json = os.path.join(settings_root, 'settings/paths.json')

In [None]:
# Load the hyper-parameters (`params`) and filesystem locations (`paths`) used below.
def _read_json(json_path):
    """Return the parsed contents of the JSON file at *json_path*."""
    with open(json_path, 'r') as json_file:
        return json.load(json_file)


params = _read_json(parameters_json)
paths = _read_json(paths_json)

In [None]:
random_seed = params["random_seed"]

# Environment knobs for reproducibility.
# NOTE(review): PYTHONHASHSEED only affects interpreters started after this point
# (e.g. spawned worker processes); the running kernel's hash seed was fixed at startup.
# NOTE(review): CUBLAS_WORKSPACE_CONFIG matters for determinism when
# torch.use_deterministic_algorithms(True) is enabled, which this cell does not do.
os.environ.update({
    'PYTHONHASHSEED': str(random_seed),
    'CUBLAS_WORKSPACE_CONFIG': ":4096:8",
})

# Seed every random number generator the notebook may touch, in the original order.
for seed_rng in (random.seed, np.random.seed, torch.manual_seed,
                 torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    seed_rng(random_seed)

# Prefer reproducible cuDNN kernel selection over autotuned speed.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Recording IDs follow the pattern M{n}_{L|R}_V{v}_{k}.
# Training uses the V1 recordings of M1/M3/M5. Testing uses the V2 recordings of
# those same IDs, plus both V1 and V2 of M2/M4, which never appear in training.
def _recordings(groups):
    """Expand (M-number, V-number) pairs into recording IDs, L before R, _1 before _2."""
    return [f"M{m}_{side}_V{v}_{take}"
            for m, v in groups
            for side in ("L", "R")
            for take in (1, 2)]


train_set = _recordings([(1, 1), (3, 1), (5, 1)])

test_set = _recordings([(1, 2), (3, 2), (5, 2),
                        (2, 1), (4, 1),
                        (2, 2), (4, 2)])

In [None]:
# Experiment grid: every model architecture is paired with every label source,
# architecture varying slowest.
data_set = 'realistic'
model_set = ['CustomCNN', 'EfficientNet']
label_set = ['GT', 'MTurk']
model_combinations = [(model_id, label_id)
                      for model_id in model_set
                      for label_id in label_set]

# Import Data

In [None]:
# Label table for the selected data set; the first CSV row is the header.
labels_csv = paths[data_set]['labels']
data = pd.read_csv(labels_csv, header=0)

In [None]:
# Split rows by their recording ID (the 'dataset' column) into train and test frames.
recording_ids = data['dataset']
train_data = data.loc[recording_ids.isin(train_set)]
test_data = data.loc[recording_ids.isin(test_set)]

In [None]:
# Build one training DataLoader and one weighted CrossEntropyLoss per label source.
train_dataloader = {}
# NOTE(review): never populated in this notebook; later cells must not assume a key exists.
val_dataloader = {}
loss_fn = {}

image_col = 'original filepath'
# Column of the label CSV holding each label source's targets.
label_columns = {'GT': 'relabGT', 'MTurk': 'Label'}
# os.cpu_count() may return None; fall back to 1 and cap the pool at 16 workers.
num_loader_workers = min(16, os.cpu_count() or 1)

for label_name in label_set:
    # Raises KeyError for an unknown label source instead of using an unbound/stale column.
    label_col = label_columns[label_name]

    train_dataset = ContactDataset(
        images=train_data[image_col].tolist(),
        labels=train_data[label_col].to_numpy(),
        coords=list(zip(
            train_data['x'].astype(int),
            train_data['y'].astype(int))),
        jitter=True)

    # Shuffling DataLoader over the dataset just built.
    train_dataloader[label_name] = DataLoader(
        dataset=train_dataset,
        batch_size=params['batch_size'],
        num_workers=num_loader_workers,
        pin_memory=True,
        shuffle=True)

    # Loss weights come from the dataset (presumably per-class weights derived from the
    # training labels — confirm in ContactDataset.getWeights).
    weights = train_dataset.getWeights().to(device)
    loss_fn[label_name] = nn.CrossEntropyLoss(weight=weights)

# Training

In [None]:
# Fine-tune every (architecture, label source) pair, starting from previously saved weights,
# then save its metrics and weights through utils.
model_classes = {'CustomCNN': CustomCNN, 'EfficientNet': CustomEfficientNetB3}
# NOTE(review): machine-specific absolute path; consider moving it into settings/paths.json.
pretrained_models_path = '/home/sxy841/ERIE/silicone/nn-contact/models'

for model_name, label_name in model_combinations:
    # Fresh model per combination. An unknown name raises KeyError rather than silently
    # reusing the previous iteration's `model`.
    model = model_classes[model_name]()

    # Per-architecture optimiser hyper-parameters from settings/parameters.json.
    optimizer = optim.Adam(
        model.parameters(),
        lr=params[model_name]['learning_rate'],
        weight_decay=params[model_name]['weight_decay'])

    # Load the starting weights for this combination from the pre-trained model directory.
    utils.load_state_dict(model, model_name, label_name,
                          load_path=pretrained_models_path)

    utils.train(
        model=model,
        optimizer=optimizer,
        loss_fn=loss_fn[label_name],
        dataloader=train_dataloader[label_name],
        device=device,
        use_tqdm=False,
        epochs=params['epochs'])

    utils.save_metrics(model, model_name, label_name)
    utils.save_state_dict(model, model_name, label_name)

# Predict

In [None]:
# Filled by the prediction cell:
# (model_name, label_name, set_name) -> {"Prediction": ..., "Ground Truth": ...}
results = dict()

In [None]:
# Run every trained (architecture, label source) model on every test recording and
# store predictions next to force-derived ground truth in `results`.
model_classes = {'CustomCNN': CustomCNN, 'EfficientNet': CustomEfficientNetB3}
# Frames whose force-vector norm exceeds this are labelled 1, otherwise 0.
force_threshold = 0.2
# os.cpu_count() may return None; fall back to 1 and cap the pool at 16 workers.
num_loader_workers = min(16, os.cpu_count() or 1)

for set_name in tqdm(test_set):
    # Per-recording file locations.
    images_path = os.path.join(
        paths[data_set]['image_set'], f'output_{set_name}')
    label_path = os.path.join(images_path, 'labels_PSM2.txt')
    keypoints_model = paths[data_set]['keypoints_model']
    coordinates_path = os.path.join(
        paths[data_set]['keypoints'],
        f"{set_name}_L_h264{keypoints_model}.h5")

    # Columns 1-3 of the headerless label file (presumably a 3-D force vector per frame —
    # confirm). Named `set_forces` so the `test_data` DataFrame from the split cell is not
    # overwritten.
    set_forces = pd.read_csv(label_path, header=None).iloc[:, 1:4].to_numpy()
    coordinates = pd.read_hdf(coordinates_path).loc[:, [
        (keypoints_model, 'Mid_1', 'x'),
        (keypoints_model, 'Mid_1', 'y')]].to_numpy()

    # Row i of the label file pairs with img_{i}.jpg.
    test_images = [os.path.join(images_path, f'img_{index}.jpg')
                   for index in range(len(set_forces))]
    test_labels = (np.linalg.norm(set_forces, axis=1) > force_threshold).astype(int).tolist()

    test_dataset = ContactDataset(
        images=test_images,
        labels=test_labels,
        coords=coordinates.astype(int).tolist(),
        jitter=False)

    # No shuffling, so stored predictions stay in frame order and can be mapped back to
    # img_{i} (assuming utils.predict concatenates batches in iteration order).
    test_dataloader = DataLoader(
        dataset=test_dataset,
        batch_size=512,
        num_workers=num_loader_workers,
        pin_memory=True,
        shuffle=False)

    for model_name, label_name in model_combinations:
        # Unknown names raise KeyError instead of reusing a previous `model`.
        model = model_classes[model_name]()
        utils.load_state_dict(model, model_name, label_name)
        predictions, ground_truth = utils.predict(
            model=model,
            dataloader=test_dataloader,
            device=device)

        results[(model_name, label_name, set_name)] = {
            "Prediction": predictions,
            "Ground Truth": ground_truth,
        }

In [None]:
# Persist all predictions. The file name records the data set and a timestamp.
# NOTE(review): the "noPreTrain" tag looks inconsistent with the training cell, which
# starts from weights loaded with utils.load_state_dict(..., load_path=...) — confirm.
os.makedirs('labels', exist_ok=True)  # open() below fails if the directory is missing
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
with open(f'labels/{data_set}_noPreTrain_{timestamp}.pkl', 'wb') as file:
    pickle.dump(results, file)

# Test

In [None]:
# Pick one combination to inspect or keep training in the cells below.
model_name, label_name = 'EfficientNet', 'MTurk'
model = CustomEfficientNetB3()

In [None]:
# Load the saved weights for the selected (model_name, label_name) into `model` from
# utils.load_state_dict's default location (presumably where utils.save_state_dict writes — confirm).
utils.load_state_dict(model, model_name, label_name)

In [None]:
# Pool predictions over all test recordings and print one classification report per
# (architecture, label source) combination.
from sklearn.metrics import classification_report
print(data_set)
# Distinct loop names keep the `model_name` / `label_name` selected earlier intact for
# the following cells.
for eval_model, eval_label in model_combinations:
    binary_predictions = []
    y = []
    for ts in test_set:
        entry = results[(eval_model, eval_label, ts)]
        y.extend(entry["Ground Truth"])
        # Threshold at 0.5. Assumes predictions are numpy arrays of scores — confirm with utils.predict.
        binary_predictions.extend((entry["Prediction"] > 0.5).astype(int))
    print((eval_model, eval_label),
          classification_report(y, binary_predictions, output_dict=True))

In [None]:
# Continue training the selected model for a few extra epochs, then save it.
optimizer = optim.Adam(
    model.parameters(),
    lr=params[model_name]['learning_rate'],
    weight_decay=params[model_name]['weight_decay'])

# val_dataloader is never filled in this notebook, so indexing it unconditionally raised
# KeyError. Pass a validation loader only when one exists.
extra_train_kwargs = {}
if label_name in val_dataloader:
    extra_train_kwargs['val_dataloader'] = val_dataloader[label_name]

utils.train(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn[label_name],
    dataloader=train_dataloader[label_name],
    device=device,
    use_tqdm=True,
    epochs=10,
    **extra_train_kwargs)

utils.save_metrics(model, model_name, label_name)
utils.save_state_dict(model, model_name, label_name)

In [None]:
# Tile images from the most recent test_dataloader into a contact sheet and show it.
from PIL import Image, ImageDraw
import torchvision
width, height = 936, 702
merged_image = Image.new('RGB', (width, height))

# Tile size. Assumes each image tensor is 234x234 — confirm against ContactDataset.
small_image_width, small_image_height = 234, 234
n_cols = width // small_image_width
n_rows = height // small_image_height
# Top-left corner of every tile, row by row (same layout as left-to-right wrapping).
tile_origins = [(col * small_image_width, row * small_image_height)
                for row in range(n_rows) for col in range(n_cols)]

to_pil = torchvision.transforms.ToPILImage()
# Stream images lazily. zip stops once every tile is filled, instead of pasting the
# rest of the dataset off-canvas.
sample_images = (image for images, _ in test_dataloader for image in images)
for origin, image in zip(tile_origins, sample_images):
    merged_image.paste(to_pil(image), origin)

merged_image.show()