In [None]:
# Enable IPython's autoreload extension in mode 2 so that imported modules are
# re-imported before each cell runs; edits to project .py files are then picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import torch
import hydra
import numpy as np
from hydra import initialize, initialize_config_module, initialize_config_dir, compose
from omegaconf import OmegaConf

from lightning import Trainer
from lightning_modules import PortraitNetModule, PortraitDataModule

from matplotlib import pyplot as plt
%matplotlib inline

### Load config, instantiate model, datamodule and trainer

In [None]:
# Path of the checkpoint handed to `trainer.test` further down.
# NOTE(review): this is an empty placeholder -- presumably it has to be set to a
# real checkpoint file before the test cell is executed; confirm.
ckpt_path = ""

# Hydra overrides applied on top of the base evaluation config.
experiment_overrides = ["experiment=v2-eg-full"]

# Compose the evaluation config from the local ./configs directory. The Hydra
# context is only needed while composing, so it is scoped with `with`.
with initialize(
    version_base="1.3",
    config_path="./configs",
):
    cfg = compose(
        config_name="eval.yaml",
        overrides=experiment_overrides,
    )
    print(cfg)

def concat_images(images):
    """Tile a batch of channel-first images side by side into one image.

    Args:
        images: array with four dimensions, laid out as (N, C, H, W).

    Returns:
        A channel-last array of shape (H, N * W, C) in which the N images
        appear left to right in batch order.

    Raises:
        ValueError: if ``images`` does not have exactly four dimensions.
    """
    if len(images.shape) != 4:
        raise ValueError("image shape should be (N, C, H, W)")

    batch_count = images.shape[0]

    # Cut the batch into N chunks of shape (1, C, H, W) and drop the leading
    # singleton axis of each chunk so every tile is (C, H, W).
    tiles = []
    for chunk in np.split(images, batch_count, axis=0):
        tiles.append(chunk[0])

    # Join along the width axis, then move channels last for plotting.
    strip = np.concatenate(tiles, axis=-1)
    return np.transpose(strip, (1, 2, 0))

In [None]:
# GPU index (or indices) handed to the Trainer. Index 7 only exists on hosts
# with at least eight GPUs, so adjust this to a device available locally.
GPU_DEVICES = [7]

# Build the Lightning module and datamodule from the composed Hydra config.
model = PortraitNetModule(cfg.model)
datamodule = PortraitDataModule(cfg.data)

# `logger=False` switches experiment logging off explicitly. Passing `None`
# is ambiguous: recent Lightning releases treat it as "use the default logger",
# which would create log directories during a pure evaluation run.
trainer = Trainer(accelerator="gpu", devices=GPU_DEVICES, logger=False)

In [None]:
# Evaluate `model` via `trainer.test`, passing the datamodule and the checkpoint
# path defined earlier in the notebook.
# NOTE(review): `ckpt_path` is assigned an empty string where it is defined --
# presumably it must be set to a real checkpoint file first; confirm how
# `Trainer.test` treats an empty path.
trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)

### Plot images in the dataset

In [None]:
def plot_rawdata_with_titles(image_list, column_titles, save_path="dataset.pdf"):
    """Plot a grid of images and save the figure to disk.

    Each entry of ``image_list`` is one row of the grid. Column titles are
    drawn underneath the images of the last row only.

    Args:
        image_list: sequence of rows; every row is a sequence of images
            accepted by ``Axes.imshow``. The number of columns is taken from
            the first row.
        column_titles: one title per column.
        save_path: file the figure is written to. Defaults to the path that
            was previously hard-coded, so existing calls behave as before.

    Raises:
        ValueError: if ``image_list`` is empty or its first row is empty.
    """
    if len(image_list) == 0 or len(image_list[0]) == 0:
        raise ValueError("image_list must contain at least one non-empty row")

    num_images = len(image_list)
    num_columns = len(image_list[0])

    # squeeze=False keeps `axes` two-dimensional even for a single row or a
    # single column; otherwise matplotlib returns a 1-D array (or a bare Axes)
    # and the nested iteration below fails.
    fig, axes = plt.subplots(
        num_images,
        num_columns,
        figsize=(12, 3 * num_images + 0.1),
        squeeze=False,
    )

    for i, row in enumerate(axes):
        for j, ax in enumerate(row):
            ax.imshow(image_list[i][j])
            ax.axis("off")
            # Only the bottom row carries titles; y=-0.1 places them below
            # the image instead of above it.
            if i == num_images - 1:
                ax.set_title(column_titles[j], y=-0.1)

    plt.tight_layout()
    plt.savefig(save_path)

# Maximum number of samples (figure rows) to show.
NUM_PREVIEW_ROWS = 4

datamodule.setup()
# The samples come from the *test* loader, so the variable is named accordingly.
test_dataloader = datamodule.test_dataloader()

batch = next(iter(test_dataloader))
input_ori, input_aug, boundary, mask = batch

# Never index past the end of the batch: slicing beyond it yields an empty
# batch, which `concat_images` cannot split.
# (assumes `input_ori` is batch-first and supports len() -- TODO confirm)
num_rows = min(NUM_PREVIEW_ROWS, len(input_ori))

# One figure row per sample: both input variants, then mask and boundary labels.
img_list = []
for i in range(num_rows):
    img_ori = concat_images(model.tensor2image(input_ori[i:i+1]))
    img_aug = concat_images(model.tensor2image(input_aug[i:i+1]))
    img_mask = concat_images(model.label2image(mask[i:i+1]))
    img_boundary = concat_images(model.label2image(boundary[i:i+1]))
    img_list.append([img_ori, img_aug, img_mask, img_boundary])

# Titles follow the column order used when building each row above.
column_titles = ["Deformation augmentations", "Texture augmentations", "Mask", "Boundary"]

plot_rawdata_with_titles(img_list, column_titles)

### Visualize feature maps

In [None]:
def visualize_feature_map(feature_map_tensor):
    """Collapse a single-sample feature map into a grayscale image in [0, 1].

    The channels are averaged and the result is min-max normalized.

    Args:
        feature_map_tensor: tensor-like object exposing
            ``.detach().cpu().numpy()`` whose array has shape
            (1, num_channels, height, width).

    Returns:
        A float64 array of shape (height, width). A constant feature map
        yields all zeros rather than NaN.

    Raises:
        ValueError: if the batch dimension is not 1 (or, from the unpacking,
            if the array is not four-dimensional).
    """
    feature_map_array = feature_map_tensor.detach().cpu().numpy()

    batch_size, num_channels, height, width = feature_map_array.shape
    # Explicit exception instead of `assert`, which is stripped under -O.
    if batch_size != 1:
        raise ValueError(
            f"expected a single-sample feature map, got batch size {batch_size}"
        )

    # Channel-wise mean, accumulated in float64 regardless of the input dtype.
    grayscale_image = feature_map_array[0].mean(axis=0, dtype=np.float64)

    # Min-max normalize. If every pixel is equal, the range is zero; skip the
    # division so the result stays at 0 instead of becoming 0/0 = NaN.
    grayscale_image = grayscale_image - np.min(grayscale_image)
    peak = np.max(grayscale_image)
    if peak > 0:
        grayscale_image = grayscale_image / peak

    return grayscale_image


def plot_featuremap_with_titles(
    image_list, column_titles, save_path="feature-map-after-training.pdf"
):
    """Plot a titled grid of feature-map images and save the figure to disk.

    Args:
        image_list: sequence of rows; every row is a sequence of images
            accepted by ``Axes.imshow``. The number of columns is taken from
            the first row.
        column_titles: nested sequence with the same layout as
            ``image_list``; ``column_titles[i][j]`` is the title of the image
            in row ``i``, column ``j``.
        save_path: file the figure is written to. Defaults to the path that
            was previously hard-coded, so existing calls behave as before.

    Raises:
        ValueError: if ``image_list`` is empty or its first row is empty.
    """
    if len(image_list) == 0 or len(image_list[0]) == 0:
        raise ValueError("image_list must contain at least one non-empty row")

    num_images = len(image_list)
    num_columns = len(image_list[0])

    # squeeze=False keeps `axes` two-dimensional even for a single row or a
    # single column; otherwise the nested iteration below fails.
    fig, axes = plt.subplots(
        num_images,
        num_columns,
        figsize=(3 * num_columns, 3 * num_images),
        squeeze=False,
    )

    for i, row in enumerate(axes):
        for j, ax in enumerate(row):
            ax.imshow(image_list[i][j])
            ax.axis("off")
            ax.set_title(column_titles[i][j])

    plt.tight_layout()
    plt.savefig(save_path)

datamodule.setup()
# The sample comes from the *test* loader, so the variable is named accordingly.
test_dataloader = datamodule.test_dataloader()

batch = next(iter(test_dataloader))
input_ori, input_aug, boundary, mask = batch

# Run the forward pass in inference mode: eval() stops layers such as
# BatchNorm/Dropout from using training-time behavior on a batch of one, and
# no_grad() avoids building an autograd graph that is never used. The
# module's previous mode is restored afterwards so later cells are unaffected.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        # Encoder features at five scales for the first sample of the batch.
        feature2x, feature4x, feature8x, feature16x, feature32x = model.model.encoder(
            input_ori[:1]
        )
        # Decoder: each stage adds the matching encoder feature to the
        # upsampled output of the previous stage.
        up16x = model.model.upsample32x(model.model.d_block32x(feature32x))
        up8x = model.model.upsample16x(model.model.d_block16x(feature16x + up16x))
        up4x = model.model.upsample8x(model.model.d_block8x(feature8x + up8x))
        up2x = model.model.upsample4x(model.model.d_block4x(feature4x + up4x))
        up1x = model.model.upsample2x(model.model.d_block2x(feature2x + up2x))

        mask_logits = model.model.mask_conv(up1x)
finally:
    model.train(was_training)

img_input = concat_images(model.tensor2image(input_ori[:1]))
img_mask = concat_images(model.logits2image(mask_logits))

# Top row: input image followed by the encoder features.
encoder_imgs = [
    img_input,
    visualize_feature_map(feature2x),
    visualize_feature_map(feature4x),
    visualize_feature_map(feature8x),
    visualize_feature_map(feature16x),
    visualize_feature_map(feature32x),
]
encoder_titles = [
    "Input image",
    "Encoder-feature 2x",
    "Encoder-feature 4x",
    "Encoder-feature 8x",
    "Encoder-feature 16x",
    "Encoder-feature 32x",
]

# Bottom row: predicted mask followed by the decoder features.
decoder_imgs = [
    img_mask,
    visualize_feature_map(up1x),
    visualize_feature_map(up2x),
    visualize_feature_map(up4x),
    visualize_feature_map(up8x),
    visualize_feature_map(up16x),
]
decoder_titles = [
    "Mask prediction",
    "Decoder-feature 1x",
    "Decoder-feature 2x",
    "Decoder-feature 4x",
    "Decoder-feature 8x",
    "Decoder-feature 16x"
]

plot_featuremap_with_titles([encoder_imgs, decoder_imgs], [encoder_titles, decoder_titles])