In [None]:
from glob import glob
import math
import os
import sys
sys.path.insert(0, "../src")

import matplotlib.pyplot as plt
import numpy as np
from numpy.random import MT19937, RandomState, SeedSequence
import pandas as pd
from PIL import Image
import torch
from tqdm import tqdm

from config_manager.manager import Params
from data_loaders.dataloader import get_dataloaders
from model.net import loss_fn, ReconstructionModel
from trainers.train_engine import train_evaluate

RNG = RandomState(MT19937(SeedSequence(123456789)))
TRNG = torch.random.manual_seed(42)

plt.style.use('dark_background')

%load_ext tensorboard

In [None]:
param_dict = {
    # --- data locations and train/test split ---
    "data_path": "../output/dataset_20221121-115643",
    "save_path": "../output/experiment",
    "tb_path": "../output/experiment/runs",  # TensorBoard log dir (see Tensorboard section)
    "split": 300,  # first `split` shuffled images -> train.csv, the rest -> test.csv
    "resize": 256,
    # --- data loading / hardware ---
    "batch_size": 128,
    "num_workers": 1,
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    "pin_memory": torch.cuda.is_available(),  # only useful when batches are copied to a GPU
    # --- optimisation ---
    "epochs": 30,
    "learning_rate": 0.001,
    # --- architecture ---
    "filters": [32, 32, 64, 128, 128, 256],
    "kernels": [5, 5, 5, 3, 3, 3],
    "embeddings": 1024,
    # --- LR schedule: MultiStepLR multiplies the LR by 0.1 at each milestone ---
    # NOTE(review): if train_evaluate steps the scheduler once per epoch, the
    # milestone at 30 (== epochs) is only reached after the final epoch.
    "milestones": [15, 30],
    # --- loss-term weights (presumably consumed by loss_fn; confirm in model.net) ---
    "mse": 1.0,
    "tvl": 0.01,
    "style": 1.0,
    "content": 1.0,
    # --- model nodes whose outputs the forward pass returns by name ---
    "style_nodes": ["encoder.0.2", "encoder.1.2", "encoder.2.2"],
    "output_layer": "dconv2.1",  # key of the reconstruction in the model's output mapping
}

params = Params(param_dict)
print(params)

# Read data

In [None]:
def create_data_split(params: Params) -> None:
    """Randomly split the "front" avatar images into train and test lists.

    Every ``*.png`` under ``params.data_path`` (searched recursively) whose file
    name contains ``"front"`` is collected. The names are sorted before the
    seeded ``RNG`` permutation is applied, so the split does not depend on the
    filesystem's listing order. The first ``params.split`` names are written to
    ``train.csv`` and the remainder to ``test.csv`` inside ``params.data_path``,
    one base file name per row, without header or index.

    Args:
        params: Hyperparameters; only ``data_path`` and ``split`` are read.

    Raises:
        ValueError: If no matching images are found, or if ``params.split``
            would leave the train or the test set empty.
    """
    pattern = os.path.join(params.data_path, "**", "*.png")
    # Filter on the base name: testing the full path would select every image
    # whenever a directory in data_path happens to contain "front".
    names = (os.path.basename(path) for path in glob(pattern, recursive=True))
    avatars = sorted(name for name in names if "front" in name)

    # Fail loudly: empty CSVs would otherwise be treated as an existing split
    # on the next run and silently break the data loaders.
    if not avatars:
        raise ValueError(f"No '*front*.png' images found under {params.data_path}")
    if not 0 < params.split < len(avatars):
        raise ValueError(
            f"split={params.split} must satisfy 0 < split < {len(avatars)} "
            "so that both the train and test sets are non-empty"
        )

    shuffled = RNG.permutation(avatars)
    subsets = {
        "train.csv": shuffled[:params.split],
        "test.csv": shuffled[params.split:],
    }
    for filename, subset in subsets.items():
        pd.DataFrame(subset).to_csv(
            os.path.join(params.data_path, filename), header=False, index=False
        )

In [None]:
# Create the split only once so re-running the notebook keeps the same
# train/test assignment; rebuild it if either CSV is missing.
split_csvs = [os.path.join(params.data_path, name) for name in ("train.csv", "test.csv")]
if all(os.path.exists(path) for path in split_csvs):
    print(f"Data split already exists at {params.data_path}")
else:
    create_data_split(params)

# Data loaders

In [None]:
def visualize(outputs: np.ndarray, ncols: int = 3) -> None:
    """Plot a grid of grayscale images, each with its own colorbar.

    Args:
        outputs: Stack of images shaped ``(num, height, width)``. A single
            ``(height, width)`` image (e.g. a squeezed batch of one) is also
            accepted and shown as a one-image grid.
        ncols: Number of grid columns (default 3). Unused cells are hidden.

    Raises:
        ValueError: If ``outputs`` is empty, is not 2-D or 3-D, or if
            ``ncols`` is smaller than 1.
    """
    images = np.asarray(outputs)
    if images.ndim == 2:
        images = images[np.newaxis, ...]
    if images.ndim != 3:
        raise ValueError(f"expected a (num, height, width) array, got shape {images.shape}")
    if ncols < 1:
        raise ValueError(f"ncols must be >= 1, got {ncols}")

    num = images.shape[0]
    if num == 0:
        raise ValueError("nothing to visualize: outputs is empty")

    nrow = math.ceil(num / ncols)
    # squeeze=False always yields a 2-D axes array, even for a single row/column.
    fig, axes = plt.subplots(nrow, ncols, figsize=(4 * ncols, nrow * 3), squeeze=False)
    for i, ax in enumerate(axes.flat):
        if i < num:
            mappable = ax.imshow(images[i], cmap="gray")
            fig.colorbar(mappable, ax=ax)
        else:
            ax.axis("off")
    fig.tight_layout()

In [None]:
# Loaders keyed by split name (indexed as dataloaders["train"] / dataloaders["test"]).
# NOTE(review): presumably built from the train.csv / test.csv files written by
# create_data_split; confirm in data_loaders.dataloader.
dataloaders = get_dataloaders(["train", "test"], params)

In [None]:
# Pull one training batch for a visual sanity check: the first element is the
# model input and the second the target (same layout as the test batch below).
# NOTE(review): the _f/_b suffixes suggest front/back views; confirm against the dataset.
tmp_f, tmp_b = next(iter(dataloaders["train"]))

In [None]:
# First 10 input images of the batch. Squeeze only axis 1 (presumably the single
# channel axis of a (batch, 1, H, W) tensor, as in get_predictions) so the batch
# axis survives even when the batch holds a single image.
visualize(tmp_f[:10].squeeze(1).numpy())

In [None]:
# First 10 target images of the batch. Squeeze only axis 1 (presumably the single
# channel axis of a (batch, 1, H, W) tensor, as in get_predictions) so the batch
# axis survives even when the batch holds a single image.
visualize(tmp_b[:10].squeeze(1).numpy())

# Model

In [None]:
# Node names whose outputs the model returns by name: the style-loss nodes plus
# the reconstruction output (read back via outputs[params.output_layer]).
NODES = list(params.style_nodes)
NODES.append(params.output_layer)
net = ReconstructionModel(params.filters, params.kernels, params.embeddings, NODES)
net

# Train

In [None]:
criterion = loss_fn
optimizer = torch.optim.Adam(net.parameters(), lr=params.learning_rate)
# Multiply the learning rate by 0.1 each time the scheduler's step count hits a
# value in params.milestones (once per epoch if train_evaluate steps it per epoch).
# NOTE(review): `verbose` is deprecated in recent PyTorch releases; drop it and log
# optimizer.param_groups[0]["lr"] instead when upgrading.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer=optimizer,
    milestones=params.milestones,
    gamma=0.1,
    verbose=True,
)

In [None]:
# Run the training loop; returns train and test loss histories, presumably one
# value per epoch (they are plotted against epochs below).
train_loss, test_loss = train_evaluate(net, dataloaders, criterion, optimizer, scheduler, params)

In [None]:
# Plot each curve against its own length (1-based epochs) so a run that logs a
# different number of points than params.epochs does not raise a size mismatch.
fig, ax = plt.subplots()
ax.plot(range(1, len(train_loss) + 1), train_loss, "o-", label="Train loss")
ax.plot(range(1, len(test_loss) + 1), test_loss, "o-", label="Test loss")
# Use ax.set_yscale("log") if the losses span several orders of magnitude.
ax.set(title="Training curves", xlabel="Epochs", ylabel="Loss")
ax.legend();

# Visualize

In [None]:
def get_predictions(net: torch.nn.Module, images: torch.Tensor, params: Params) -> np.ndarray:
    """Run ``net`` on a batch and rescale its output to each input's value range.

    The output stored under ``params.output_layer`` is mapped back to the
    per-sample ``[min, max]`` of the corresponding input image via
    ``result * (max - min) + min``. The model is moved to ``params.device`` and
    left in eval mode.

    Args:
        net: Model whose forward pass returns a mapping from node name to tensor.
        images: Input batch; the first dimension indexes samples.
        params: Params; reads ``device`` and ``output_layer``.

    Returns:
        Rescaled predictions as a CPU numpy array, with axis 1 squeezed
        (presumably the single channel axis).
    """
    net = net.to(params.device)
    # Move the inputs first so the per-sample range lives on the same device as
    # the model output: mixing an (N, 1, 1, 1) CPU tensor with a CUDA result raises.
    images = images.to(params.device)
    num = images.shape[0]
    # reshape (unlike view) also accepts non-contiguous inputs.
    flat = images.reshape(num, -1)
    min_ = flat.amin(dim=1).reshape(-1, 1, 1, 1)
    max_ = flat.amax(dim=1).reshape(-1, 1, 1, 1)

    net.eval()
    with torch.no_grad():
        outputs = net(images)
        result = outputs[params.output_layer]
        out = result * (max_ - min_) + min_
    return out.squeeze(1).cpu().numpy()

In [None]:
# Predict on the first test batch. Squeeze only axis 1 of the targets (the same
# axis get_predictions squeezes) so both arrays keep their batch axis, even for
# a batch of one, and can be interleaved below.
imgs, labels = next(iter(dataloaders["test"]))
labels = labels.squeeze(1).cpu().numpy()

outputs = get_predictions(net, imgs, params)
assert outputs.shape == labels.shape, (
    f"prediction shape {outputs.shape} != target shape {labels.shape}"
)
print(f"min: {outputs.min()}, max: {outputs.max()}")

In [None]:
# Interleave predictions and targets along the first axis so the grid reads
# prediction, target, prediction, target, ...
n, h, w = outputs.shape
combined = np.empty((n * 2, h, w), outputs.dtype)
combined[::2] = outputs   # even slots: predictions
combined[1::2] = labels   # odd slots: matching targets

In [None]:
# Show only the first 5 prediction/target pairs, mirroring the [:10] used for the
# training batch: the full test batch (2 * batch_size images, each with a colorbar)
# renders as a very tall, slow figure.
visualize(combined[:10])

# Tensorboard

In [None]:
# Uncomment to launch TensorBoard on the run logs. IPython interpolates Python
# expressions into magics with `{expr}` (not `${expr}`).
# %tensorboard --logdir {params.tb_path}