In [None]:
import glob
import os
import sys
from typing import List, Tuple
sys.path.insert(0, "../src")

import matplotlib.pyplot as plt
import numpy as np
import open3d as o3d
from PIL import Image
import torch
import torchvision.transforms as tvt

from config_manager.manager import Params
import model.layers as layer
from model.unet import UNet
from utils.utils import read_json


# Apply matplotlib's dark theme to every figure drawn after this point.
plt.style.use("dark_background")

In [None]:
# Notebook configuration. `height`/`width` are the size predictions are resized
# back to, `resize`/`scale` drive the pre-processing, `filters`/`kernels` are
# passed to the UNet, and the node names select which layer outputs are exposed.
param_dict = dict(
    data_path="../output/dataset_20230117-152714",
    save_path="../output/experiment",
    height=720,
    width=1280,
    resize=224,
    scale=2.0,
    filters=[32, 64, 128, 256],
    kernels=[5, 3, 3, 3],
    style_nodes=["encoder.0.activation", "encoder.1.activation", "encoder.2.activation"],
    output_layer="out_conv.activation",
)

params = Params(param_dict)
print(params)

# Read images

In [None]:
# glob yields matches in a filesystem-dependent order; sorting makes
# `img_list[0]` (the sample used in the cells below) the same file on every run.
img_list = sorted(glob.glob(params.data_path + "/**/*.png", recursive=True))
# Fail here with a readable message instead of an IndexError on the next line.
assert img_list, f"No PNG images found under {params.data_path!r}"
img_list[0]

In [None]:
def load_img(img_path: str, params: Params) -> Tuple[np.ndarray, torch.Tensor]:
    """Load an image and build the normalised tensor used for inference.

    Pixel values are min-max scaled to [0, 1], passed through an affine
    transform with ``scale=params.scale`` and resized to
    ``params.resize`` x ``params.resize``.

    Args:
        img_path: Path of the image file to open.
        params: Configuration object providing ``scale`` and ``resize``.

    Returns:
        A tuple ``(img, tensor)``: ``img`` is the raw image as a NumPy array and
        ``tensor`` is the pre-processed float32 tensor with two leading
        singleton dimensions added.
    """
    # Release the file handle as soon as the pixel data has been copied out.
    with Image.open(img_path) as pil_img:
        img = np.asarray(pil_img)

    # Cast in NumPy first so the conversion does not depend on torch supporting
    # the image's native dtype.
    img_tensor = torch.from_numpy(img.astype(np.float32))

    # Python floats keep `max - min` from being evaluated in the image's
    # fixed-width integer dtype.
    lo, hi = float(img.min()), float(img.max())
    value_range = hi - lo
    if value_range > 0:
        result = (img_tensor - lo) / value_range
    else:
        # Constant image: the plain formula would be 0 / 0 and produce NaNs.
        result = torch.zeros_like(img_tensor)
    result = result.unsqueeze(0)

    result = tvt.functional.affine(result, shear=0.0, scale=params.scale, translate=(0, 0), angle=0.0)
    rsz = tvt.Resize((params.resize, params.resize))
    result = rsz(result)

    return img, result.unsqueeze(0)

In [None]:
# Pre-process the first image found; the raw array is kept alongside the tensor.
img_np, img_tensor = load_img(img_path=img_list[0], params=params)

In [None]:
# Value range and dtype of the raw array, a blank line, then the same for the tensor.
print(img_np.max(), img_np.min(), img_np.dtype, sep="\n")
print()
print(img_tensor.max(), img_tensor.min(), img_tensor.dtype, sep="\n")

In [None]:
# Raw image as read from disk, with a colour bar showing its value range.
fig, ax = plt.subplots()
raw_im = ax.imshow(img_np, cmap="gray")
fig.colorbar(raw_im, ax=ax);

In [None]:
# Pre-processed network input (singleton dimensions dropped for display).
fig, ax = plt.subplots()
input_im = ax.imshow(img_tensor.squeeze().numpy(), cmap="gray")
fig.colorbar(input_im, ax=ax);

# Load model

In [None]:
# Build the UNet and wrap it so the style nodes and the output layer are exposed.
base_model = UNet(params.filters, params.kernels)
NODES = [*params.style_nodes, params.output_layer]
net = layer.ReconstructionModel(base_model, nodes=NODES)

# NOTE(review): torch.load deserialises pickled data; only load checkpoints
# that come from a trusted source.
weights_path = os.path.join(params.save_path, "best_net.pt")
state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
net.load_state_dict(state_dict)

# Evaluate

In [None]:
# Switch to evaluation mode and run one forward pass without tracking gradients.
net.eval()
with torch.no_grad():
    pred = net(img_tensor)

# The model output is indexed by node name; keep only the reconstruction.
out = pred[params.output_layer]
print(out.max(), out.min(), out.dtype, sep="\n")

In [None]:
# Network output before un-normalisation.
fig, ax = plt.subplots()
pred_im = ax.imshow(out.squeeze().numpy(), cmap="gray")
fig.colorbar(pred_im, ax=ax);

# Un-normalize

In [None]:
def unnormalize_prediction(img: torch.Tensor, max_: int, min_: int, params: Params) -> np.ndarray:
    """Map a normalised prediction back to the original image size and value range.

    Args:
        img: Network output; values are clipped to [0, 1] before rescaling.
        max_: Maximum value of the original image.
        min_: Minimum value of the original image.
        params: Configuration object providing ``height``, ``width`` and ``scale``.

    Returns:
        The rescaled prediction as an ``int16`` NumPy array with singleton
        dimensions removed.
    """
    # Detach and move to the CPU so `.numpy()` also works for tensors that
    # track gradients or live on another device.
    clipped = img.detach().cpu().clip(0.0, 1.0)

    # Resize to the configured frame size, then apply the inverse zoom factor.
    restored = tvt.Resize((params.height, params.width))(clipped)
    restored = tvt.functional.affine(
        restored, shear=0.0, scale=1 / params.scale, translate=(0, 0), angle=0.0
    )

    # Plain floats keep NumPy integer scalars from wrapping in `max_ - min_`.
    lo, hi = float(min_), float(max_)
    restored = restored * (hi - lo) + lo

    # Round to the nearest integer; a bare integer cast truncates toward zero.
    return np.rint(restored.squeeze().numpy()).astype("int16")

In [None]:
# Bring the prediction back to the raw image's value range and frame size.
res = unnormalize_prediction(img=out, max_=img_np.max(), min_=img_np.min(), params=params)
print(res.max(), res.min(), res.dtype, sep="\n")

In [None]:
# Un-normalised prediction, in the same units as the raw image.
fig, ax = plt.subplots()
res_im = ax.imshow(res, cmap="gray")
fig.colorbar(res_im, ax=ax);

# View 3D

In [None]:
def _first_match(paths: List[str], keyword: str) -> str:
    """Return the first path in ``paths`` that contains ``keyword``.

    Raises:
        FileNotFoundError: If no path contains ``keyword``.
    """
    for path in paths:
        if keyword in path:
            return path
    raise FileNotFoundError(f"No JSON file containing {keyword!r} found under {params.data_path!r}")


# Sorted so the selected file is deterministic when several paths match.
json_files = sorted(glob.glob(params.data_path + "/**/*.json", recursive=True))

# Camera calibration and depth scale stored next to the dataset images.
intrinsic = o3d.camera.PinholeCameraIntrinsic(**read_json(_first_match(json_files, "intrinsic")))
extrinsic = read_json(_first_match(json_files, "extrinsic"))["extrinsics"]
scale = read_json(_first_match(json_files, "scale"))["depth_scale"]

In [None]:
# Back-project the predicted depth map into a point cloud.
# To inspect the raw input instead, build the image from
# `img_np.astype("uint16")` rather than `res`.
depth_image = o3d.geometry.Image(res)
pcd = o3d.geometry.PointCloud.create_from_depth_image(
    depth_image,
    intrinsic,
    extrinsic,
    depth_scale=scale,
    depth_trunc=4000.0,
)

# Opens Open3D's viewer window with the reconstructed cloud.
o3d.visualization.draw_geometries([pcd])