In [None]:
# install the proper version of pytorch3d

import sys
import torch

pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
version_str="".join([
    f"py3{sys.version_info.minor}_cu",
    torch.version.cuda.replace(".",""),
    f"_pyt{pyt_version_str}"
])
!pip install fvcore
!pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html


In [None]:
import argparse
import math
import os
import random

import imageio
import numpy as np
import torch
import tqdm
from hydra import compose, initialize
from hydra.utils import instantiate
from omegaconf import OmegaConf
from pytorch3d.renderer import MeshRasterizer, RasterizationSettings
from pytorch3d.structures import Meshes
from pytorch3d.utils import cameras_from_opencv_projection

from tava.datasets.animal_parser import SubjectParser

device = "cuda:0"  # device that every tensor and the PyTorch3D cameras below are moved to


In [None]:
# Hydra overrides selecting the dataset. The paths default to the original author's
# machine; set TAVA_DATA_ROOT / TAVA_OUTPUT_ROOT to run elsewhere.
# NOTE: `hydra.run.dir` points at the *narf* output folder, and predictions are read
# from <run dir>/eval_imgs/, so pairing these with ARGS_TAVA_ANIMAL would still
# evaluate the narf run -- adjust the folder when switching methods.
DATA_ROOT = os.environ.get(
    "TAVA_DATA_ROOT", "/home/ruilongli/data/forest_and_friends_rendering/"
)
OUTPUT_ROOT = os.environ.get(
    "TAVA_OUTPUT_ROOT", "/home/ruilongli/workspace/TAVA/outputs/release/"
)

ARGS_ANIMAL_WOLF = [
    "dataset=animal_wolf", f"dataset.root_fp={DATA_ROOT}",
    "hydra.run.dir=" + os.path.join(OUTPUT_ROOT, "animal_wolf", "Wolf_cub_full_RM_2", "narf", ""),
]
ARGS_ANIMAL_HARE = [
    "dataset=animal_hare", f"dataset.root_fp={DATA_ROOT}",
    "hydra.run.dir=" + os.path.join(OUTPUT_ROOT, "animal_hare", "Hare_male_full_RM", "narf", ""),
]

# Hydra overrides selecting the method.
ARGS_TAVA_ANIMAL = ["pos_enc=snarf", "loss_bone_w_mult=1.0", "pos_enc.offset_net_enabled=false", "model.shading_mode=null"]
ARGS_NARF = ["pos_enc=narf", "model.shading_mode=null"]


In [None]:
# Select the experiment to evaluate: the hare subject trained with NARF, resuming
# from its existing run. Swap the ARGS_* lists to evaluate another subject/method.
overrides = ["resume=True"] + ARGS_ANIMAL_HARE + ARGS_NARF
# Evaluation split; predictions are read from <run dir>/eval_imgs/<split>/ later on.
split = "val_ood"

# Compose the config. `return_hydra_config=True` exposes `cfg.hydra`, which is
# resolved here so the run directory string can be read.
with initialize(config_path="../configs"):
    cfg = compose(config_name="mipnerf_dyn", overrides=overrides, return_hydra_config=True)
    OmegaConf.resolve(cfg.hydra)
    save_dir = cfg.hydra.run.dir
    eval_imgs_dir = os.path.join(save_dir, "eval_imgs")

# Initialize the dataset. NOTE(review): num_rays=None / cache_n_repeat=None
# presumably disable ray sub-sampling and caching -- confirm in the dataset class.
dataset = instantiate(
    cfg.dataset, split=split, num_rays=None, cache_n_repeat=None,
)
# Per-action mesh data; "faces", "pose_verts" and "rest_verts" are read from it below.
meta_data_dict = {
    action: dataset.parser.load_meta_data(action)
    for action in dataset.parser.actions
}

In [None]:
# Ground-truth correspondences: pixel -> 3D canonical point.
# For every view, rasterize the posed mesh and, for each covered pixel, interpolate
# the rest-pose vertices of the hit face with its barycentric weights. Each map is
# written as a 4-channel EXR (xyz + mask) under <root>/<action>/correspondence/.

image_size = int(800 * dataset.resize_factor)
rasterizer = MeshRasterizer(
    raster_settings=RasterizationSettings(image_size=image_size)
)

# Faces and rest-pose vertices depend only on the action: upload them once.
faces_by_action = {
    action: torch.from_numpy(meta["faces"]).long().to(device)
    for action, meta in meta_data_dict.items()
}
rest_verts_by_action = {
    action: torch.from_numpy(meta["rest_verts"]).float().to(device)
    for action, meta in meta_data_dict.items()
}

for index in tqdm.tqdm(dataset.index_list):
    action, frame_id, camera_id = index

    K, c2w = dataset.parser.load_camera(action, frame_id, camera_id)
    K = torch.from_numpy(K).float().to(device)
    c2w = torch.from_numpy(c2w).float().to(device)
    # cameras_from_opencv_projection expects world->camera extrinsics.
    w2c = torch.linalg.inv(c2w)

    cameras = cameras_from_opencv_projection(
        R=w2c[None, :3, :3],
        tvec=w2c[None, :3, 3],
        camera_matrix=K[None],
        image_size=torch.tensor([[image_size, image_size]]),
    ).to(device)

    faces = faces_by_action[action]
    rest_verts = rest_verts_by_action[action]
    verts = torch.from_numpy(
        meta_data_dict[action]["pose_verts"][frame_id]).float().to(device)
    meshes = Meshes(verts=[verts], faces=[faces])

    fragments = rasterizer(meshes, cameras=cameras)

    # PyTorch3D Fragments: pix_to_face is (N, H, W, K) with -1 for background,
    # bary_coords is (N, H, W, K, 3); here N == 1 and K == 1.
    pix_to_face = fragments.pix_to_face
    barycentric = fragments.bary_coords
    # Background pixels index faces[-1] (an arbitrary face); they are zeroed by the mask.
    pix_to_coord = torch.einsum(
        "bhwnvi,bhwnv->bhwni", rest_verts[faces[pix_to_face]], barycentric)
    masks = (pix_to_face != -1).squeeze(-1).squeeze(0)  # [h, w]
    coords = pix_to_coord.squeeze(-2).squeeze(0) * masks[:, :, None]  # [h, w, 3]

    image_to_save = torch.cat([coords, masks[:, :, None].float()], dim=-1)  # [h, w, 4]
    image_path = os.path.join(
        dataset.parser.root_dir, action, "correspondence", camera_id, "%08d.exr" % frame_id
    )
    os.makedirs(os.path.dirname(image_path), exist_ok=True)
    imageio.imwrite(image_path, image_to_save.cpu().numpy())

In [None]:
# Pixel-coordinate grids sized like the correspondence maps (image_size is derived
# from dataset.resize_factor in the ground-truth rasterization cell):
# x[i, j] = j (column), y[i, j] = i (row).
x, y = torch.meshgrid(
    torch.arange(image_size, device=device),  # X-Axis (columns)
    torch.arange(image_size, device=device),  # Y-Axis (rows)
    indexing="xy",
)


@torch.no_grad()
def matching(warp_src, wrap_dst, chunk_size=8):
    """Nearest-neighbour search from every row of `warp_src` into `wrap_dst`.

    Args:
        warp_src: [N_src, D] query points.
        wrap_dst: [N_dst, D] candidate points.
        chunk_size: number of query points handled per step; bounds the size of
            the temporary [N_dst, chunk_size, D] difference tensor.

    Returns:
        errors: [N_src] Euclidean distance from each query to its nearest
            candidate (``inf`` for every query when `wrap_dst` is empty).
        indices: [N_src] long row index into `wrap_dst` of that candidate
            (``-1`` for every query when `wrap_dst` is empty).
    """
    assert warp_src.dim() == wrap_dst.dim() == 2
    num_src = warp_src.shape[0]
    if num_src == 0 or wrap_dst.shape[0] == 0:
        # torch.cat([]) and min() over an empty dim would both raise here.
        errors = torch.full(
            (num_src,), float("inf"),
            dtype=torch.promote_types(warp_src.dtype, wrap_dst.dtype),
            device=warp_src.device,
        )
        indices = torch.full(
            (num_src,), -1, dtype=torch.long, device=warp_src.device
        )
        return errors, indices

    errors, indices = [], []
    for start in range(0, num_src, chunk_size):
        chunk = warp_src[start: start + chunk_size]
        nearest = torch.linalg.norm(
            wrap_dst[:, None, :] - chunk[None, :, :], dim=-1
        ).min(dim=0)  # reduce over candidates -> [len(chunk),]
        errors.append(nearest.values)
        indices.append(nearest.indices)
    errors = torch.cat(errors, dim=0)
    indices = torch.cat(indices, dim=0)
    assert errors.shape[0] == indices.shape[0] == num_src
    return errors, indices


@torch.no_grad()
def matching_pairs(map_src, mask_src, map_dst, mask_dst, thre=1e-4):
    """Pair pixels of two images whose per-pixel values (nearly) coincide.

    Args:
        map_src, map_dst: [H, W, D] per-pixel values (e.g. canonical coordinates).
        mask_src, mask_dst: [H, W] boolean masks of the pixels to consider.
        thre: a source pixel is kept only if its nearest destination value is
            strictly closer than this distance.

    Returns:
        coord_src, coord_dst: [M, 2] long tensors of (x, y) = (column, row)
        pixel coordinates; row i of coord_src matches row i of coord_dst.
        Source pixels appear in row-major order.
    """
    # (row, col) of every masked pixel. torch.nonzero enumerates in row-major
    # order, exactly like boolean indexing `map[mask]`, so rows line up with the
    # gathered values. Deriving them from the mask works for any image size.
    pix_src = torch.nonzero(mask_src)
    pix_dst = torch.nonzero(mask_dst)
    errors, indices = matching(map_src[mask_src], map_dst[mask_dst])
    selector = errors < thre
    # Select before gathering so unmatched indices are never dereferenced.
    matched_src = pix_src[selector]
    matched_dst = pix_dst[indices[selector]]
    # (row, col) -> (x, y)
    coord_src = matched_src.flip(-1)
    coord_dst = matched_dst.flip(-1)
    return coord_src, coord_dst

In [None]:
# Ground-truth correspondences: pixel -> pixel.
# Sub-sample about 100 views, build every ordered pair of distinct views, shuffle
# the pairs with a fixed seed and, for the first NUM_GT_PAIRS of them, match the
# pixels whose canonical coordinates agree. Matches are cached as .npz files
# under /tmp/tava_corr/<subject_id>/<split>/gt/.

NUM_GT_PAIRS = 2000  # number of shuffled pairs for which GT matches are computed

render_every = math.ceil(len(dataset) / 100)
index_list = dataset.index_list[::render_every]
print("index list", len(index_list))

# Position of each sampled view in the full index (it names the prediction files
# later). Looked up once per view instead of twice per pair.
sampled_ids = [dataset.index_list.index(index) for index in index_list]

pair_list = []
for id_src, index_src in zip(sampled_ids, index_list):
    for id_dst, index_dst in zip(sampled_ids, index_list):
        if index_src == index_dst:
            continue
        pair_list.append([id_src, index_src, id_dst, index_dst])
print("pair list", len(pair_list))


def _load_corr_map(action, frame_id, camera_id):
    """Load a saved pixel -> canonical-point map as a [H, W, 4] tensor (xyz, mask)."""
    path = os.path.join(
        dataset.parser.root_dir, action, "correspondence", camera_id, "%08d.exr" % frame_id
    )
    return torch.from_numpy(imageio.imread(path)).to(device)


random.seed(42)
random.shuffle(pair_list)
for id1, index1, id2, index2 in tqdm.tqdm(pair_list[:NUM_GT_PAIRS]):
    action1, frame_id1, camera_id1 = index1
    action2, frame_id2, camera_id2 = index2

    gt_map1 = _load_corr_map(action1, frame_id1, camera_id1)
    gt_map2 = _load_corr_map(action2, frame_id2, camera_id2)

    # Channels 0-2: canonical xyz; channel 3: foreground mask stored as 0/1.
    map_src = gt_map1[:, :, 0:3]
    mask_src = gt_map1[:, :, 3] > 0.5
    map_dst = gt_map2[:, :, 0:3]
    mask_dst = gt_map2[:, :, 3] > 0.5
    coord_src, coord_dst = matching_pairs(map_src, mask_src, map_dst, mask_dst, thre=1e-4)

    cache_path = os.path.join(
        "/tmp", "tava_corr", cfg.dataset.subject_id, split, "gt",
        f"{action1}_{frame_id1}_{camera_id1}___{action2}_{frame_id2}_{camera_id2}.npz"
    )
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
    np.savez(
        cache_path,
        coord_src=coord_src.cpu().numpy(),
        coord_dst=coord_dst.cpu().numpy(),
    )

In [None]:
# Evaluate predicted correspondences on the first NUM_EVAL_PAIRS shuffled pairs.
# For every ground-truth-matched source pixel, find the destination pixel whose
# prediction is nearest to the source prediction, and report the mean pixel
# distance ("p2p") to the ground-truth destination pixel.

NUM_EVAL_PAIRS = 100  # must not exceed the number of pairs cached in the GT cell


def _corr_cache_path(kind, index1, index2):
    """Cache file for the matches of one (src, dst) view pair; `kind` is "gt" or "pred"."""
    action1, frame_id1, camera_id1 = index1
    action2, frame_id2, camera_id2 = index2
    return os.path.join(
        "/tmp", "tava_corr", cfg.dataset.subject_id, split, kind,
        f"{action1}_{frame_id1}_{camera_id1}___{action2}_{frame_id2}_{camera_id2}.npz"
    )


def _load_pred_map(item_id, action, frame_id, camera_id):
    """Load the saved prediction for dataset item `item_id`.

    NOTE(review): presumably [H, W, 3] predicted canonical coordinates, with
    background pixels summing to 0 -- confirm against the code that writes them.
    """
    meta_id = dataset.encode_meta_id(action, frame_id)
    return torch.from_numpy(np.load(
        os.path.join(
            eval_imgs_dir, split,
            f"{item_id:04d}_{cfg.dataset.subject_id}_{meta_id}_{camera_id}.npy"
        )
    )).to(device)


metric = {"p2p": []}
for id1, index1, id2, index2 in tqdm.tqdm(pair_list[:NUM_EVAL_PAIRS]):
    # Ground truth: plain integer arrays, so pickle loading stays disabled.
    with np.load(_corr_cache_path("gt", index1, index2)) as gt_data:
        gt_coord_src = torch.from_numpy(gt_data["coord_src"]).to(device)
        gt_coord_dst = torch.from_numpy(gt_data["coord_dst"]).to(device)
    if gt_coord_src.numel() == 0:
        continue

    pred_map_src = _load_pred_map(id1, *index1)
    pred_map_dst = _load_pred_map(id2, *index2)

    # Restrict the source side to the GT-matched pixels so both sides are
    # evaluated over exactly the same source pixels.
    pred_mask_src = torch.zeros_like(pred_map_src[..., 0], dtype=torch.bool)
    pred_mask_src[gt_coord_src[:, 1], gt_coord_src[:, 0]] = True
    # Pixels whose predicted values sum to 0 are excluded as candidates.
    pred_mask_dst = pred_map_dst.sum(dim=-1) != 0
    # thre=1e10 effectively keeps every source pixel (its nearest prediction wins).
    pred_coord_src, pred_coord_dst = matching_pairs(
        pred_map_src, pred_mask_src, pred_map_dst, pred_mask_dst, thre=1e10
    )
    # Both lists enumerate source pixels in row-major order, so they must agree exactly.
    assert torch.equal(pred_coord_src, gt_coord_src), "source pixels of pred and GT differ"
    p2p = torch.linalg.norm(
        pred_coord_dst.float() - gt_coord_dst.float(), dim=-1
    ).mean()
    metric["p2p"].append(p2p.item())

    cache_path = _corr_cache_path("pred", index1, index2)
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
    # Named keys, matching the GT cache; passing a dict positionally would store
    # one pickled object array under "arr_0" instead.
    np.savez(
        cache_path,
        coord_src=pred_coord_src.cpu().numpy(),
        coord_dst=pred_coord_dst.cpu().numpy(),
    )

# Average each metric over the evaluated pairs (NaN if every pair was skipped).
for key, value in metric.items():
    metric[key] = sum(value) / len(value) if value else float("nan")

print(metric)