In [None]:
from glob import glob
import math
import os
import sys
import time
from typing import List, Tuple
sys.path.insert(0, "../src")

import matplotlib.pyplot as plt
import numpy as np
from numpy.random import MT19937, RandomState, SeedSequence
import pandas as pd
import torch
from torchvision.models.feature_extraction import create_feature_extractor

from config_manager.manager import Params
from data_loaders.dataloader import get_dataloaders
from model.losses import loss_fn
from model.cascade_net import CascadeNet
from trainers.train_engine import train_evaluate

# Seeded NumPy RNG (MT19937 bit generator); create_data_split uses it to shuffle the split.
RNG = RandomState(MT19937(SeedSequence(123456789)))
# Seed PyTorch's global RNG; TRNG keeps the value returned by manual_seed.
TRNG = torch.random.manual_seed(42)

# Request deterministic cuDNN algorithms so GPU runs are reproducible.
torch.backends.cudnn.deterministic = True

plt.style.use('dark_background')

# Enables the %tensorboard magic used (currently commented out) in the Tensorboard section.
%load_ext tensorboard

In [None]:
# Number of stacked cascades; sizes the per-cascade lists and builds the layer names below.
NUM_CASCADES = 1

param_dict = {
    # Dataset root; the split cell writes train.csv / test.csv here.
    "data_path": "../../datasets/caesar",
    # Per-run output directory, timestamped as run_<day>_<month>_<HHMMSS> (no year).
    "save_path": os.path.join("../output/experiment", time.strftime("run_%d_%m_%H%M%S")),
    # presumably data-loading options read by get_dataloaders — TODO confirm.
    "resize": 224,
    "batch_size": 8,
    "num_workers": 2,
    "pin_memory": torch.cuda.is_available(),
    # Device string; get_predictions moves the model and batches here.
    "device": "cuda:0" if torch.cuda.is_available() else "cpu",
    "epochs": 10,
    # Adam learning rate (see the Train section).
    "learning_rate": 0.001,
    # NOTE(review): model hyperparameters whose meaning is defined in CascadeNet (not visible here).
    "num_encodings": 6,
    "temperature": 1.0,
    # One entry per cascade.
    "out_channels": [1 for _ in range(NUM_CASCADES)],
    "filters": [[64, 64, 128, 256, 512, 512, 512, 512] for _ in range(NUM_CASCADES)],
    "kernels": [[3, 3, 3, 3, 3, 3, 3, 3] for _ in range(NUM_CASCADES)],
    # MultiStepLR milestones: the learning rate is multiplied by 0.1 at each entry.
    "milestones": [5],
    # presumably per-term loss weights consumed by loss_fn / train_evaluate — TODO confirm.
    "person": 0.0,
    "bg": 1.0,
    "tvl": 0.1,
    "style": 10.0,
    "perceptual": 0.05,
    "exp": 1.0,
    "grad": 0.1,
    "disc": 0.1,
    # Node names exposed by create_feature_extractor (the NODES list in the Model section).
    "style_layers": [f"net_stack.{NUM_CASCADES - 1}.encoder.{i + 1}.activation" for i in range(3)],
    # presumably VGG feature nodes used for the style loss — TODO confirm.
    "vgg_style": ["features.6", "features.13", "features.26"],
    # Outputs of every cascade except the last; empty when NUM_CASCADES == 1.
    "cascade_layers": [f"net_stack.{i}.out_conv.activation" for i in range(NUM_CASCADES - 1)],
    "embedding_layer": f"net_stack.{NUM_CASCADES - 1}.encoder.7.activation",
    # Node that get_predictions returns as the model prediction.
    "output_layer": f"net_stack.{NUM_CASCADES - 1}.out_conv.activation",
}

params = Params(param_dict)
print(params)

# Read data

In [None]:
def create_data_split(
    data_path: str,
    train_fraction: float = 0.95,
    rng: "RandomState | None" = None,
) -> None:
    """Split the front-view PNGs under ``data_path`` into train/test lists and save them.

    Every ``*.png`` found recursively under ``data_path`` whose path contains
    ``"front"`` is selected. Its basename is written to ``train.csv`` or
    ``test.csv`` inside ``data_path``, one name per row, with no header or index.

    Args:
        data_path: Path to dataset root.
        train_fraction: Fraction of shuffled images assigned to the train split
            (the first ``int(len * train_fraction)`` names).
        rng: Object with a ``permutation`` method used for shuffling. Defaults
            to the notebook-level ``RNG`` so the split is reproducible.

    Raises:
        ValueError: If ``train_fraction`` is outside [0, 1], or if no matching
            images are found. Empty CSVs would otherwise be written and then
            reused forever by the "split already exists" check.
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(f"train_fraction must be in [0, 1], got {train_fraction}")

    pattern = os.path.join(data_path, "**", "*.png")
    # "front" is matched against the full path (directory names included).
    front_paths = [path for path in glob(pattern, recursive=True) if "front" in path]
    if not front_paths:
        raise ValueError(f"No front-view PNG files found under {data_path!r}")

    # glob() returns files in filesystem-dependent order. Sort first so the
    # seeded shuffle produces the same split on every machine.
    avatars = sorted(os.path.basename(path) for path in front_paths)

    shuffled = (RNG if rng is None else rng).permutation(avatars)
    num_train = int(len(shuffled) * train_fraction)

    pd.DataFrame(shuffled[:num_train]).to_csv(
        os.path.join(data_path, "train.csv"), header=False, index=False
    )
    pd.DataFrame(shuffled[num_train:]).to_csv(
        os.path.join(data_path, "test.csv"), header=False, index=False
    )

In [None]:
# Build the train/test split only once; later runs reuse the saved CSV files.
# Check both files so a half-written split (train.csv without test.csv) is regenerated.
split_files = [os.path.join(params.data_path, name) for name in ("train.csv", "test.csv")]
if all(os.path.exists(path) for path in split_files):
    print(f"Data split already exists at {params.data_path}")
else:
    create_data_split(params.data_path)

# Data loaders

In [None]:
def visualize(outputs: np.ndarray, row_titles: List[str], ncols: int = 3) -> None:
    """Visualize a batch of images as a grid of grayscale panels with colorbars.

    Args:
        outputs: Stack of images indexed along axis 0; each ``outputs[i]`` is
            passed to ``imshow`` as-is.
        row_titles: Panel titles, cycled across panels. A single title labels
            every panel; ``ncols`` titles label the columns of each row. The
            list is not modified.
        ncols: Number of columns in the grid.
    """
    num = outputs.shape[0]
    if num == 0:
        return  # nothing to draw; plt.subplots rejects a zero-row grid
    nrow = math.ceil(num / ncols)
    # squeeze=False always yields a 2-D array of Axes, so a 1x1 grid works too.
    fig, axes = plt.subplots(nrow, ncols, figsize=(15, nrow * 3), squeeze=False)
    for i, ax in enumerate(axes.ravel()):
        if i >= num:
            ax.axis("off")  # hide unused cells in the last row
            continue
        image = ax.imshow(outputs[i], cmap="gray")
        # Cycle titles by index instead of extending the list, so the caller's
        # list is never mutated.
        if row_titles:
            ax.set_title(row_titles[i % len(row_titles)])
        ax.grid()
        fig.colorbar(image, ax=ax)
    fig.tight_layout()

def plot_hist(
    data: List[np.ndarray], bins: int, titles: List[str], figsize: Tuple[int, int]
) -> None:
    """Plot side-by-side histograms that share a log-scaled y-axis.

    Args:
        data: Arrays to histogram; ``np.histogram`` flattens each one.
        bins: Bins for the histogram (passed to ``np.histogram``).
        titles: Title for each histogram, aligned with ``data``.
        figsize: Width and height of the figure.
    """
    # squeeze=False keeps a 2-D Axes array even when a single array is plotted,
    # where plt.subplots(1, 1) would otherwise return a bare, non-iterable Axes.
    fig, axes = plt.subplots(1, len(data), figsize=figsize, sharey=True, squeeze=False)
    for i, (axi, values) in enumerate(zip(axes[0], data)):
        counts, edges = np.histogram(values, bins=bins)
        axi.stairs(counts, edges, fill=True)
        axi.set_title(titles[i])
        axi.set_ylabel("Log( counts )")
        axi.set_xlabel("Values")
        axi.grid()
        axi.semilogy()
    fig.tight_layout()

In [None]:
# Build the "train" and "test" loaders from params (presumably reading the split CSVs created above — confirm).
dataloaders = get_dataloaders(["train", "test"], params)

In [None]:
# One training batch kept for inspection below: front, back and back mask (as labelled in the next cell).
tmp_f, tmp_b, tmp_m = next(iter(dataloaders["train"]))

In [None]:
# Sanity-check the value range of each tensor in the inspected training batch.
print("\n".join(
    f"{label} min: {batch.min()}, max: {batch.max()}"
    for label, batch in (("Front", tmp_f), ("Back", tmp_b), ("Back mask", tmp_m))
))

In [None]:
# Show the batch's front images; permute moves axis 1 (presumably channels) last for imshow.
visualize(tmp_f.permute(0, 2, 3, 1).numpy()[:10, ...], ["Front"])

In [None]:
# Show the back images of the same training batch.
visualize(tmp_b.permute(0, 2, 3, 1).numpy()[:10, ...], ["Back"])

In [None]:
# Show the back masks of the same training batch.
visualize(tmp_m.permute(0, 2, 3, 1).numpy()[:10, ...], ["Back mask"])

In [None]:
# Compare the value distributions of the batch's front (model input) and back images.
plot_hist(
    data=[tmp_f.squeeze().numpy(), tmp_b.squeeze().numpy()],
    bins=30,
    titles=["Front", "Back"],
    figsize=(15, 5),
)

# Model

In [None]:
# Instantiate CascadeNet from the configuration above.
base_model = CascadeNet(params)

In [None]:
# Expose the style, cascade and embedding activations plus the final output as named outputs;
# net's forward returns a mapping keyed by these node names (get_predictions indexes it by output_layer).
NODES = params.style_layers + params.cascade_layers + [params.embedding_layer, params.output_layer]
net = create_feature_extractor(base_model, return_nodes=NODES)
net

# Train

In [None]:
criterion = loss_fn
# Adam over the feature extractor's parameters. NOTE(review): presumably these are the
# same tensors as base_model's (the extractor wraps its submodules) — confirm.
optimizer = torch.optim.Adam(net.parameters(), lr=params.learning_rate)
# Multiply the learning rate by gamma=0.1 at each step index listed in params.milestones.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases; consider logging
# scheduler.get_last_lr() instead.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=params.milestones, gamma=0.1, verbose=True
)

In [None]:
# Train and evaluate; train_evaluate receives both the raw model and its feature extractor.
train_evaluate(base_model, net, dataloaders, criterion, optimizer, scheduler, params)

# Visualize

In [None]:
def get_predictions(
    net: torch.nn.Module,
    images: torch.Tensor,
    mask: torch.Tensor,
    params: "Params",
) -> np.ndarray:
    """Run ``net`` on one batch in inference mode and return its prediction.

    The model is moved to ``params.device`` in place and evaluated with
    autograd disabled. Its previous train/eval mode is restored afterwards, so
    calling this between training runs does not leave the model in eval mode.

    Args:
        net: Neural network model whose forward takes ``(images, mask)`` and
            returns a mapping of named outputs.
        images: Input batch of images.
        mask: Input mask of back; cast to float32 before the forward pass.
        params: Params providing ``device`` and ``output_layer``.

    Returns:
        The ``params.output_layer`` output as a NumPy array on the CPU.
    """
    device = torch.device(params.device)
    net = net.to(device)

    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            outputs = net(images.to(device), mask.to(device, dtype=torch.float32))
    finally:
        # eval() mutates the caller's model, so put its original mode back.
        net.train(was_training)

    return outputs[params.output_layer].cpu().numpy()

In [None]:
# Predict on one test batch; labels become a NumPy array for the error map and plots below.
imgs, labels, mask = next(iter(dataloaders["test"]))
labels = labels.cpu().numpy()

outputs = get_predictions(net, imgs, mask, params)
print(f"Outputs min: {outputs.min()}, max: {outputs.max()}")
print(f"Labels min: {labels.min()}, max: {labels.max()}")

In [None]:
# Interleave per-sample panels (input, prediction, label, |error|) into one image stack,
# moving axis 1 of each NCHW-style array last so every panel is (H, W, C).
panels = [
    imgs.cpu().permute(0, 2, 3, 1).numpy(),
    np.moveaxis(outputs, 1, -1),
    np.moveaxis(labels, 1, -1),
    np.abs(np.moveaxis(outputs - labels, 1, -1)),
]
p = len(panels)
n, c, h, w = outputs.shape
combined = np.empty((p * n, h, w, c), dtype=outputs.dtype)
for offset, panel in enumerate(panels):
    combined[offset::p] = panel

In [None]:
# One grid row per test sample: input, prediction, label, |prediction - label|.
visualize(combined, ["Input", "Pred", "Label", "Error"], p)

In [None]:
# Compare the value distributions of the ground truth and the predictions.
plot_hist([labels, outputs], 30, ["GT", "Pred"], (15, 5))

# Tensorboard

In [None]:
#os.environ["TENSORBOARD_BINARY"] = "/opt/conda/envs/cv3d-env/bin/tensorboard"

In [None]:
# NOTE(review): param_dict above defines no "tb_path" key — confirm Params provides it before enabling.
#%tensorboard --logdir $params.tb_path