In [1]:
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml
from PIL import Image

import utils.transform as TF
from models.model import DPTAffordanceModel

# Turn gradient tracking off globally for the rest of the notebook; the visible
# cells only run forward passes, so no autograd graph is needed.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fb294186c10>

In [None]:
# --- User configuration: fill these in before running the cells below. ---

# Root directory of the dataset; the keypoint YAML is read from
# <dataset_root_path>/<split_mode>/train_affordance_keypoint.yaml.
dataset_root_path = ""

# Path of the checkpoint file that is handed to torch.load().
resume = ""

# Which split to use: "object" or "actor". It selects both the keypoint YAML
# sub-directory and the normalization statistics.
split_mode = "object"

In [None]:
# Affordance class names. The visualization cell titles the idx-th model output
# with affordance[idx], so the order is presumably the model's output order
# (TODO confirm against the training configuration).
affordance = [
    "openable",
    "cuttable",
    "pourable",
    "containable",
    "supportable",
    "holdable",
]

# Both sizes are forwarded to the DPTAffordanceModel constructor.
# NOTE(review): 12 is presumably the number of object categories - confirm.
num_object = 12
num_affordance = len(affordance)

In [None]:
# Read the affordance keypoint annotations of the selected split.
keypoint_path = os.path.join(
    dataset_root_path,
    split_mode,
    "train_affordance_keypoint.yaml",
)
with open(keypoint_path, mode="r") as keypoint_file:
    # safe_load only constructs plain Python objects. The result is later
    # indexed by an image's file stem, so it is expected to be a mapping.
    keypoint_dict = yaml.safe_load(keypoint_file)

In [None]:
# Build the network on the GPU, restore its weights, and switch to eval mode.
model = DPTAffordanceModel(num_object, num_affordance, use_hf=True).cuda()

# Returning the storage unchanged from map_location leaves the deserialized
# tensors on the CPU; load_state_dict then copies them into the GPU parameters.
# NOTE: torch.load unpickles the file - only point `resume` at a trusted checkpoint.
ckpt = torch.load(resume, map_location=lambda storage, location: storage)

# Remove every "module." occurrence from the parameter names (presumably the
# prefix left by a DataParallel-style wrapper at training time) so that the
# keys line up with the unwrapped model.
state_dict = {
    name.replace("module.", ""): tensor
    for name, tensor in ckpt["state_dict"].items()
}
model.load_state_dict(state_dict)

model.eval()

In [None]:
# Normalization statistics per split as (mean, std); three values each,
# presumably per RGB channel on a 0-255 scale (TODO confirm).
# NOTE(review): both splits list the same std - confirm this is intentional
# and not a copy-paste leftover.
_split_stats = {
    "object": ([132.2723, 106.8666, 112.8962], [67.4025, 70.7446, 72.1553]),
    "actor": ([136.5133, 108.5417, 113.0168], [67.4025, 70.7446, 72.1553]),
}
if split_mode not in _split_stats:
    raise Exception(f"split_mode: {split_mode} is not supported.")
mean, std = _split_stats[split_mode]

# Preprocessing applied to every sample: PIL image -> tensor -> normalization.
tf = TF.Compose([
    TF.PILToTensor(),
    TF.ImageNormalizeTensor(mean=mean, std=std),
])

In [None]:
# Path of the image to run inference on. The part of its file name before the
# first "." is later used as the lookup key into keypoint_dict.
image_path = ""

In [None]:
# Load the query image and bundle it with its keypoint annotation.

# Annotation key: the file name up to (not including) its first ".".
file_name = os.path.basename(image_path).split(".")[0]

# Force three 8-bit RGB channels. The normalization uses three-value mean/std
# and the visualization blends against a three-channel mask, so grayscale,
# palette or RGBA files would otherwise fail further down. convert() reads the
# pixel data eagerly, so the file can be closed as soon as the block exits.
with Image.open(image_path) as source_image:
    image = source_image.convert("RGB")

# Sample dict consumed by the `tf` pipeline. Raises KeyError when the image
# has no entry in the loaded keypoint YAML.
data = {
    "file_name": file_name,
    "image": image,
    "point_label": keypoint_dict[file_name],
}

In [None]:
# Preprocess the sample and run a single forward pass on the GPU.
data_tf = tf(data)

# unsqueeze(0) adds a leading axis of size 1 (presumably the batch dimension
# the model expects - TODO confirm).
# NOTE: `input` shadows the Python builtin; the name is kept so that any
# existing reference to it keeps working.
image_tensor = data_tf["image"]
input = image_tensor.unsqueeze(0).cuda()

# The result is iterated over in the visualization cell, one entry per plot.
output_list = model(input)

In [None]:
# Overlay every predicted mask on the input image, one figure per model output.
#
# The cell works in RGB throughout and does not reassign `image`, so it can be
# re-run on its own. (Reassigning `image` to a channel-flipped array made a
# second run flip the channels again and display wrong colours.)
image_rgb = np.array(image)

# Cyan in RGB order - the same on-screen colour as a (255, 255, 0) overlay
# applied in BGR order and flipped back to RGB for display.
mask_color_rgb = (0, 255, 255)

for idx, output in enumerate(output_list):
    # Binarize the raw output at 0 and drop singleton axes; the indexing below
    # requires the result to be a 2-D (height, width) map.
    binary = (output > 0).cpu().numpy().astype(np.uint8).squeeze()

    mask = np.zeros((binary.shape[0], binary.shape[1], 3), dtype=np.uint8)
    mask[binary == 1] = mask_color_rgb

    # Saturating blend: image + 0.8 * mask. addWeighted needs both inputs to
    # have the same size, i.e. the output map must match the image resolution.
    image_mask = cv2.addWeighted(image_rgb, 1.0, mask, 0.8, 0.0)

    fig, ax = plt.subplots(figsize=(5, 5), dpi=300)
    ax.imshow(image_mask)
    ax.set_title(f"{affordance[idx]}")
    plt.show()
    # Release the figure so repeated runs do not accumulate open figures.
    plt.close(fig)