In [None]:
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


In [None]:
import sys
import io
import PIL.Image
import numpy as np
from skimage import measure
import math
import k3d
import os.path
import torch as t
import dataclasses

if "./src" not in sys.path:
    sys.path.append("./src")
from corenet.data import batched_example
from corenet import tf_model as tf_model_lib
from corenet import super_resolution as super_resolution_lib
from corenet.geometry import transformations as transformations_lib

# Distinct, saturated colors packed as 0xRRGGBB integers (the form k3d's `color`
# argument takes). Consumed in order by `plot_result`, one entry per drawn object.
color_palette = [
    0x00FFFF, 0xFFFF00, 0xFF00FF, 0x00FF80, 0xFF0000, 0x0000FF,
    0x80FF00, 0x0080FF, 0x8000FF, 0xFF8000, 0x00FF00, 0xFF0080,
    0xFFFF80, 0x80FFFF, 0xFF80FF, 0x80FF80, 0x8080FF, 0xFF8080,
]

def _box_filter_grid(obj_grid, kernel):
    """Smooths a (D, H, W) grid with a cubic averaging kernel, preserving its shape.

    `kernel` has shape (1, 1, u, u, u). Each axis is zero-padded by u // 2 voxels
    before and u - 1 - u // 2 after, so the valid convolution returns exactly
    (D, H, W) for both even and odd u.
    """
    u = kernel.shape[-1]
    before, after = u // 2, u - 1 - u // 2
    # F.pad takes (before, after) pairs starting from the last axis; the pairs are
    # the same for all three spatial axes here.
    padded = t.nn.functional.pad(obj_grid[None, None], [before, after] * 3)
    return t.conv3d(padded, kernel)[0, 0]


def _grid_to_mesh(obj_grid):
    """Extracts the 0.5 iso-surface of a (D, H, W) grid as a k3d-ready triangle mesh.

    Returns:
      (vertices, faces). `vertices` is float32 (N, 3), centered on the origin at
      roughly [-0.5, 0.5]^3 and ordered (x, z, y), where the grid axes are treated
      as (z, y, x) = (D, H, W). `faces` is uint32 (M, 3) triangle vertex indices.
    """
    d, h, w = obj_grid.shape
    # A one-voxel border of zeros closes surfaces that touch the grid boundary.
    padded = t.nn.functional.pad(obj_grid, [1] * 6).cpu().numpy()
    vertices, faces, _, _ = measure.marching_cubes(padded, level=0.5)
    # Undo the border offset, then normalize each column by the extent of the
    # array axis it indexes. marching_cubes columns follow the array axes (D, H, W).
    vertices = (vertices - 1) / (d, h, w)
    zz, yy, xx = vertices[:, 0], vertices[:, 1], vertices[:, 2]
    vertices = np.stack([xx, zz, yy], -1) - 0.5
    # Reverse the triangle winding. Presumably this makes k3d's front faces point
    # outward — TODO confirm.
    faces = np.flip(faces, axis=-1)
    return vertices.astype(np.float32), faces.astype(np.uint32)


def plot_result(src_image, grid, u=1):
    """Renders each channel of `grid` as a colored k3d mesh alongside `src_image`.

    Args:
      src_image: Image shown as textured planes in the scene. This is either a
        torch tensor (moved to CPU first) or anything `PIL.Image.fromarray`
        accepts. The notebook passes the (H, W, C) example image.
      grid: 4-D tensor (num_channels, D, H, W). Each channel is meshed at its 0.5
        iso-level. Channels whose maximum does not exceed 0.5 are skipped.
      u: Width of the cubic averaging filter applied to each channel before
        meshing. Values <= 1 disable smoothing. The notebook passes the model's
        resolution multiplier here.

    Returns:
      The configured `k3d` plot.

    Raises:
      ValueError: if `grid` is not 4-D.
    """
    if grid.dim() != 4:
        raise ValueError(
            f"Expected a 4-D grid (channels, D, H, W), got shape {tuple(grid.shape)}")
    plot = k3d.plot()
    # The averaging kernel is identical for every channel, so build it only once.
    # It takes the grid's dtype and device.
    kernel = grid.new_ones([1, 1, u, u, u]) / u**3 if u > 1 else None
    obj_idx = 0
    for obj_grid in grid:
        if kernel is not None:
            obj_grid = _box_filter_grid(obj_grid, kernel)
        # marching_cubes rejects an iso-level outside the data range. A channel
        # like that has no surface to draw anyway.
        if obj_grid.max() <= 0.5:
            continue
        vertices, faces = _grid_to_mesh(obj_grid)
        # Wrap around so that more objects than palette entries does not raise
        # IndexError.
        color = color_palette[obj_idx % len(color_palette)]
        plot += k3d.mesh(vertices, faces, color=color)
        obj_idx += 1

    # Encode the input image once and reuse the bytes for both textured planes.
    image = src_image.cpu().numpy() if isinstance(src_image, t.Tensor) else src_image
    buf = io.BytesIO()
    PIL.Image.fromarray(np.asarray(image)).save(buf, format="png")
    png_bytes = buf.getvalue()
    # Large copy of the image. The second matrix row sends the texture's local z
    # to y + 0.5. Assuming the texture spans its local xy plane, this places it at
    # y = 0.5, behind the origin as seen from the camera below.
    backdrop_scale = 2.1
    plot += k3d.texture(
        binary=png_bytes, file_format="png",
        model_matrix=[[backdrop_scale, 0, 0, 0], [0, 0, 1, 0.5],
                      [0, -backdrop_scale, 0, 0], [0, 0, 0, 1.0]])
    # Small (0.5-scaled) copy of the image, shifted to x = -1.5.
    plot += k3d.texture(
        binary=png_bytes, file_format="png",
        model_matrix=[[0.5, 0, 0, -1.5], [0, 0, 1, 0.5], [0, -0.5, 0, 0], [0, 0, 0, 1.0]])
    # k3d camera layout: [position xyz, look-at target xyz, up vector xyz].
    plot.camera = [0, -1.3666666, 0, 0, 0, 0, 0, 0, 1]
    plot.camera_auto_fit = False
    plot.grid_visible = False
    plot.camera_mode = "orbit"
    return plot
   
# Model name -> example input image stored in `data_dir`. The model graph itself
# is loaded from "<name>.pb" in the same directory by the inference cell.
pretrained_models = dict(
    m7="example_pair_m7_bd2d7cd9ebcc03f0691ac421b36085911c29f420d288bdf3d8533dcfe74a414f.webp",
    h5="example_single_h5_a7862633361da9b8cb9e2ad0cb23c7b490c324686e185a53045910819a07f824.webp",
    h7="example_single_h7_a7862633361da9b8cb9e2ad0cb23c7b490c324686e185a53045910819a07f824.webp",
    y1="example_single_y1_6976ae6754872044ce00af95aafeb09fae4b733983da5fa25377fd5b9cf7dee8.webp",
    m9="example_triplet_m9_efe6cd70d5f1a2636c39834fcff598d1fa4525f8d73b115d5eb2f2210527f047.webp",
)

# Batch of one 4x4 camera matrix, shape (1, 4, 4), float32, passed to the model.
# NOTE(review): the +/-1.7320507 (~= sqrt(3) = cot(30 deg)) diagonal terms and
# the [0, 0, 1, .] last row look like a perspective projection with a ~60 deg
# field of view. Confirm the exact convention against corenet.
camera_transform = t.as_tensor([
    [1.7320507,  0.0,        0.0,      -0.8660253],
    [0.0,       -1.7320507,  0.0,       0.8660253],
    [0.0,        0.0,        1.00002,   0.8664839],
    [0.0,        0.0,        1.0,       0.86666656],
])[None]


In [None]:
import importlib

# Re-execute the project modules so edits under ./src take effect without a
# kernel restart. This is harmless on a fresh kernel.
importlib.reload(tf_model_lib)
importlib.reload(super_resolution_lib)

model_name = "h7"  # The model to show, must be one of the keys of `pretrained_models`
data_dir = "data/paper_tf_models"  # Root directory containing the model and the example image
output_resolution = (128,)*3
# output_resolution = (32,)*3  # Uncomment to display at native resolution for model y1
device = "cuda"  # NOTE(review): running the TF-converted model on CPU is untested here.

if model_name not in pretrained_models:
    raise KeyError(
        f"Unknown model {model_name!r}; expected one of {sorted(pretrained_models)}")

# Load the example image as an (H, W, C) tensor. The context manager closes the
# underlying file once the pixels have been copied out by np.array.
with PIL.Image.open(os.path.join(data_dir, pretrained_models[model_name])) as pil_image:
    example_image = t.as_tensor(np.array(pil_image))

inference_fn = tf_model_lib.super_resolution_from_tf_model(
    os.path.join(data_dir, model_name + ".pb"))
input_image = example_image.permute([2, 0, 1])[None].to(device)  # (H, W, C) -> (1, C, H, W)
# Use a new name for the device copy so the previous cell's `camera_transform`
# is left unchanged and re-running this cell stays idempotent.
camera_transform_device = camera_transform.to(device)
w2x_transform = transformations_lib.scale(output_resolution)[None].to(device)
grid_sampling_offset = t.full([1, 3], 0.5, device=device)
# [0, 1:] selects the first batch element and drops channel 0. Channel 0 is
# presumably the background / empty-space class — TODO confirm in corenet.
grid_pmf = inference_fn(
    input_image, camera_transform_device, w2x_transform, grid_sampling_offset,
    output_resolution)[0, 1:]
plot_result(example_image, grid_pmf, inference_fn.get_resolution_multiplier(output_resolution))
