In [1]:
import torch

# Report the PyTorch / CUDA build and GPU visibility so runs can be compared.
print(
    "PyTorch version: {}, CUDA version: {}, GPU available: {}".format(
        torch.__version__, torch.version.cuda, torch.cuda.is_available()
    )
)

# Notebook-wide compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

PyTorch version: 2.6.0+cu124, CUDA version: 12.4, GPU available: True


In [15]:
import clip
import copy
import json
import kaolin as kal
import kaolin.ops.mesh as mesh
import kaolin.ops.conversions as conversions
import trimesh
import numpy as np
import os
import random
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import open3d as o3d

from datetime import datetime
from itertools import permutations, product
from kaolin.ops.mesh import face_normals
from Normalization import MeshNormalizer
from mesh import Mesh
from pathlib import Path
from render import Renderer
from tqdm import tqdm
from torch.autograd import grad
from torchvision import transforms
from utils import color_mesh
import pickle
from scipy.spatial import cKDTree
from utilities.dataset import load_dataset, get_coordinates, get_affordance_classes, get_affordance_label, is_affordance_present, split_dataset
from utilities.point_cloud import pointcloud_to_voxel_mesh, project_vertex_scores_to_pointcloud, visualize_affordance_pointcloud

In [3]:
class NeuralHighlighter(nn.Module):
    """Coordinate MLP mapping each input point to a distribution over [highlight, gray].

    The network holds exactly ``num_layers`` ``nn.Linear`` layers:
    - an input projection followed by ReLU -> LayerNorm,
    - ``num_layers - 2`` hidden Linear -> ReLU -> LayerNorm stages,
    - an output projection followed by ``nn.Softmax(dim=1)``.

    The forward pass therefore returns probabilities, not logits.
    """

    def __init__(self, input_dim=3, hidden_dim=256, output_dim=2, num_layers=6):
        """
        Args:
            input_dim: size of each input point, usually 3 (x, y, z)
            hidden_dim: width of the hidden layers
            output_dim: number of output classes, 2 for [highlight, gray]
            num_layers: total number of nn.Linear layers (must be >= 2)

        Raises:
            ValueError: if num_layers < 2. The input and output projections always
                exist, so a smaller count could not be honoured. Previously it was
                silently promoted to 2.
        """
        super().__init__()
        if num_layers < 2:
            raise ValueError(f"num_layers must be >= 2 (got {num_layers})")

        layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.LayerNorm(hidden_dim)]
        for _ in range(num_layers - 2):
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.LayerNorm(hidden_dim)]
        layers += [nn.Linear(hidden_dim, output_dim), nn.Softmax(dim=1)]  # 2-class output

        # Attribute name and layer order are kept so state_dict keys ("model.0.weight", ...) stay loadable.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return class probabilities; expects a 2-D batch (N, input_dim) since softmax is over dim 1."""
        return self.model(x)


def get_clip_model(clipmodel='ViT-L/14', jit=False):
    """Load a CLIP model and its preprocessing transform on CUDA if available, else CPU.

    Args:
        clipmodel: architecture name passed to ``clip.load``.
        jit: forwarded unchanged to ``clip.load``.

    Returns:
        ``(model, preprocess)`` exactly as returned by ``clip.load``.
    """
    target = "cpu"
    if torch.cuda.is_available():
        target = "cuda"
    loaded_model, preprocess_fn = clip.load(clipmodel, device=target, jit=jit)
    print(f"Loaded CLIP model: {clipmodel} on {target} (jit={jit})")
    return loaded_model, preprocess_fn


# ================== HELPER FUNCTIONS =============================
def save_final_results(log_dir, name, mesh, mlp, vertices, colors, render, background):
    """Render the final hard segmentation and export it as a colored PLY.

    Steps:
    1. Runs ``mlp`` on ``vertices`` in eval mode without gradients.
    2. Turns the per-vertex probabilities into a one-hot argmax assignment.
    3. Paints ``mesh`` with ``colors`` via ``color_mesh`` and renders 5 views.
    4. Writes ``<log_dir>/<name>.ply`` and ``<log_dir>/final_render.jpg``.

    Args:
        log_dir: output directory (not created here).
        name: base file name of the exported mesh.
        mesh: project ``Mesh`` object, colored via ``color_mesh`` and exported.
        mlp: highlighter network; its output is argmax-ed over dim 1
            (expected (N, 2) probabilities).
        vertices: vertex tensor fed to ``mlp``.
        colors: (2, 3) RGB palette in [0, 1]. Row 0 is used for vertices whose
            argmax is class 0 (highlight); row 1 for the rest (gray).
        render: project ``Renderer`` used for the final views.
        background: background color forwarded to the renderer.
    """
    mlp.eval()
    with torch.no_grad():
        probs = mlp(vertices)
        max_idx = torch.argmax(probs, 1, keepdim=True)

        # Hard (one-hot) assignment for the renders, created on the predictions' device
        # instead of relying on the notebook-global `device`.
        one_hot = torch.zeros_like(probs).scatter_(1, max_idx, 1)
        # Use the caller's palette; previously this argument was silently overwritten.
        colors = colors.to(probs.device)
        color_mesh(one_hot, mesh, colors)
        rendered_images, _, _ = render.render_views(
            mesh,
            num_views=5,
            show=False,
            center_azim=0,
            center_elev=0,
            std=1,
            return_views=True,
            lighting=True,
            background=background,
        )

        # The PLY export uses 0-255 integer RGB derived from the same palette as the renders.
        palette_255 = (colors * 255).round().long()
        final_color = torch.where(max_idx == 0, palette_255[0], palette_255[1])
        mesh.export(os.path.join(log_dir, f"{name}.ply"), extension="ply", color=final_color)
        save_renders(log_dir, 0, rendered_images, name='final_render.jpg')

def clip_loss(rendered_images, text_prompt, clip_transform, clip_model, tokenizer, device, aug_transform=None, n_augs=0):
    """Negative mean CLIP cosine similarity between rendered views and a text prompt.

    The text is encoded once without gradients. The images are encoded with
    gradients, so the loss can be back-propagated into whatever produced them.

    Args:
        rendered_images: batch of rendered views accepted by the transforms and by
            ``clip_model.encode_image``.
        text_prompt: prompt string.
        clip_transform: preprocessing used when ``n_augs <= 0`` (single pass).
        clip_model: model exposing ``encode_text`` and ``encode_image``.
        tokenizer: callable mapping a list of strings to a token tensor
            (e.g. ``clip.tokenize``).
        device: device the text tokens are moved to.
        aug_transform: random augmentation applied once per pass when ``n_augs > 0``.
        n_augs: number of augmented passes to average over.

    Returns:
        Scalar tensor in [-1, 1]; lower means the renders match the prompt better.

    Raises:
        ValueError: if ``n_augs > 0`` but ``aug_transform`` is None.
    """
    if n_augs > 0 and aug_transform is None:
        raise ValueError("aug_transform is required when n_augs > 0")

    # Encode text
    text_tokens = tokenizer([text_prompt]).to(device)
    with torch.no_grad():
        text_features = clip_model.encode_text(text_tokens).float()
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)  # L2 norm

    def _similarity(images):
        # Cast to float32 in both branches so (possibly fp16) CLIP image features
        # match the dtype of the float32 text features.
        image_features = clip_model.encode_image(images).float()
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return torch.mean(torch.cosine_similarity(image_features, text_features))

    if n_augs <= 0:
        # A negative count would otherwise yield a plain float that cannot be back-propagated.
        return -_similarity(clip_transform(rendered_images))

    loss = 0.0
    for _ in range(n_augs):
        loss = loss - _similarity(aug_transform(rendered_images))
    return loss / n_augs


def save_renders(dir, i, rendered_images, name=None):
    """Write a batch of renders to one image file under ``dir``.

    The file is ``dir/name`` when ``name`` is given. Otherwise it is
    ``dir/renders/iter_<i>.jpg``; the ``renders`` subfolder is not created here.
    """
    filename = name if name is not None else 'renders/iter_{}.jpg'.format(i)
    torchvision.utils.save_image(rendered_images, os.path.join(dir, filename))

In [4]:
# Pin every RNG the notebook relies on. Full determinism is still not guaranteed,
# because some torch backward kernels used inside CLIP are non-deterministic.
seed = 42

# Python and NumPy global generators.
random.seed(seed)
np.random.seed(seed)

# PyTorch CPU generator, the current CUDA device, and every visible CUDA device.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Deterministic cuDNN kernels; autotuning is disabled since it may pick different algorithms per run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [5]:
def load_vertices(data, num_points=4096):
    """Return ``data`` as a float32 point tensor on the notebook-global ``device``.

    Args:
        data: either a path to a mesh file readable by Open3D (``str`` or any
            ``os.PathLike`` such as ``pathlib.Path``), or array-like point data.
        num_points: number of points sampled uniformly from the mesh surface
            when ``data`` is a path.

    Returns:
        torch.Tensor (float32). A path yields the sampled points, shape
        (num_points, 3). Array-like input keeps its own shape.
    """
    # isinstance (not `type(...) ==`) so str subclasses and Path objects take the file branch.
    if isinstance(data, (str, os.PathLike)):
        o3d_mesh = o3d.io.read_triangle_mesh(os.fspath(data))
        o3d_mesh.compute_vertex_normals()
        sampled = o3d_mesh.sample_points_uniformly(number_of_points=num_points)
        return torch.tensor(np.asarray(sampled.points), dtype=torch.float32).to(device)
    return torch.tensor(data, dtype=torch.float32).to(device)

In [6]:
# ==== Hyperparameters and Settings ====
clip_model_name = 'ViT-B/16'   # CLIP backbone used for the loss
render_res = 224               # square render resolution (also the CLIP input size)
n_iter = 2200                  # optimization steps per sample
n_augs = 1                     # augmented CLIP passes per step (0 = plain clip_transform)
learning_rate = 5e-5           # Adam step size
output_dir = './output/'       # root folder for per-run results

In [7]:
# ==== Renderer ====
render = Renderer(dim=(render_res, render_res))

# ==== CLIP ====
clip_model, preprocess = get_clip_model(clip_model_name)
tokenizer = clip.tokenize

# ==== Normalization and Augmentation ====
# Per-channel RGB mean / std used for CLIP image preprocessing.
clip_normalizer = transforms.Normalize(
    mean=(0.48145466, 0.4578275, 0.40821073),
    std=(0.26862954, 0.26130258, 0.27577711),
)

# Deterministic path (n_augs == 0): resize to the CLIP input size, then normalize.
clip_transform = transforms.Compose([transforms.Resize((render_res, render_res)), clip_normalizer])

# Random path (n_augs > 0): full-area resized crop, then a random perspective warp
# applied 80% of the time (fill=1, i.e. white for [0, 1] images), then normalize.
augment_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(render_res, scale=(1, 1)),
        transforms.RandomPerspective(fill=1, p=0.8, distortion_scale=0.5),
        clip_normalizer,
    ]
)

# ==== Colors and Other Constants ====
# Row 0 = highlight color, row 1 = gray; RGB in [0, 1].
colors = torch.tensor([[204/255, 1., 0.], [180/255, 180/255, 180/255]]).to(device)
background = torch.tensor((1., 1., 1.)).to(device)  # white
n_views = 7   # views rendered per optimization step
losses = []

Loaded CLIP model: ViT-B/16 on cuda (jit=False)


In [8]:
# Load the full-shape training data (.pkl); the validation/test subsets are split from it.
# NOTE(review): if load_dataset unpickles this file, only point it at trusted data.
original_dataset = load_dataset(
    "data/full-shape/full_shape_train_data.pkl"
)

Loaded train_data


In [9]:
# Affordance names of interest.
# NOTE(review): not read by the training loop, which builds its own per-sample list.
affordances = ['grasp', 'wrap grasp', 'pull']

# Split off validation / test subsets (val_ratio=0.01, fixed seed for reproducibility).
val_set, test_set = split_dataset(original_dataset, val_ratio=0.01, seed=42)
print("Validation set size: {}, Test set size: {}".format(len(val_set), len(test_set)))

Validation set size: 160, Test set size: 160


In [10]:
# Tally how many samples of each semantic class ended up in each split.
val_counts = {}
test_counts = {}
for subset, tally in ((val_set, val_counts), (test_set, test_counts)):
    for item in subset:
        label = item["semantic class"]
        tally[label] = tally.get(label, 0) + 1

# Most frequent class first; ties keep first-seen order because sorted() is stable.
sorted_val = sorted(val_counts.items(), key=lambda pair: -pair[1])
sorted_test = sorted(test_counts.items(), key=lambda pair: -pair[1])

print("Validation Set Semantic Class Counts:")
for cls, count in sorted_val:
    print(f"{cls}: {count}")

print("\nTest Set Semantic Class Counts:")
for cls, count in sorted_test:
    print(f"{cls}: {count}")

Validation Set Semantic Class Counts:
Table: 52
Chair: 39
Vase: 14
Bottle: 10
StorageFurniture: 7
Refrigerator: 6
Faucet: 4
TrashCan: 4
Display: 3
Door: 3
Knife: 3
Bed: 3
Clock: 3
Dishwasher: 2
Hat: 2
Keyboard: 1
Bowl: 1
Mug: 1
Laptop: 1
Earphone: 1

Test Set Semantic Class Counts:
Table: 62
Chair: 41
StorageFurniture: 18
Display: 6
Clock: 5
Vase: 4
TrashCan: 3
Laptop: 3
Refrigerator: 2
Microwave: 2
Faucet: 2
Mug: 2
Dishwasher: 2
Keyboard: 1
Bed: 1
Bag: 1
Hat: 1
Door: 1
Scissors: 1
Bowl: 1
Knife: 1


In [11]:
def build_prompt(semantic_class, affordance):
    """Build the CLIP text prompt for a class/affordance pair (both lower-cased).

    Example: ("Chair", "Grasp") -> "a gray chair with highlighted grasp region".
    """
    return "a gray {} with highlighted {} region".format(
        str(semantic_class).lower(), str(affordance).lower()
    )

In [16]:
def get_vertex_scores(pred_class: torch.Tensor, positive_class: int = 1, from_logits: bool = False):
    """
    Returns vertex-wise confidence scores for the positive class.

    NeuralHighlighter already ends in nn.Softmax, so by default its output is used
    as-is. Re-applying softmax to 2-class probabilities squashes them into roughly
    [0.27, 0.73], which makes most thresholds of an IoU sweep degenerate.

    Args:
        pred_class (torch.Tensor): shape [N, C]; class probabilities, or raw
            logits when ``from_logits`` is True.
        positive_class (int): which class index should be interpreted as affordance.
        from_logits (bool): apply a softmax over dim 1 first (for unnormalized outputs).

    Returns:
        torch.Tensor: shape [N], score of ``positive_class`` per vertex.
    """
    probs = F.softmax(pred_class, dim=1) if from_logits else pred_class
    return probs[:, positive_class]

def compute_mIoU(pred_labels: torch.Tensor, gt_labels: torch.Tensor):
    """
    Intersection over Union of two binary masks (a single IoU, not a class mean).

    Args:
        pred_labels (torch.Tensor): shape [N], nonzero = positive
        gt_labels (torch.Tensor): shape [N], nonzero = positive

    Returns:
        float: |pred AND gt| / |pred OR gt|, or NaN when both masks are empty.
    """
    pred_mask, gt_mask = pred_labels.bool(), gt_labels.bool()
    union_count = torch.logical_or(pred_mask, gt_mask).sum().float()
    if union_count <= 0:
        return float('nan')
    inter_count = torch.logical_and(pred_mask, gt_mask).sum().float()
    return (inter_count / union_count).item()

def optimize_mIoU_threshold(projected_scores, gt_labels, thresholds=None, gt_threshold=0.0):
    """
    Sweeps thresholds on the predicted scores and keeps the one with the highest IoU.

    Each threshold's IoU is printed. On ties, the first (lowest) threshold wins.
    Thresholds whose IoU is undefined (NaN, i.e. both prediction and GT are empty)
    are skipped instead of being compared.

    Args:
        projected_scores (torch.Tensor): shape [N], soft prediction per point
        gt_labels (torch.Tensor): shape [N], soft or binary GT labels
        thresholds (list or tensor): thresholds to test (default: 0.1 to 0.9, step 0.1)
        gt_threshold (float): GT values strictly above this count as positive

    Returns:
        (float, float): best threshold and its IoU. If no threshold gives a defined
        IoU, returns (0.5, nan) rather than a misleading -1.
    """
    import math

    if thresholds is None:
        thresholds = torch.linspace(0.1, 0.9, steps=9)

    gt_binary = (gt_labels > gt_threshold).long()

    best_iou = float('nan')
    best_thresh = 0.5

    for t in thresholds:
        pred_binary = (projected_scores > t).long()
        iou = compute_mIoU(pred_binary, gt_binary)
        print(f"Threshold {float(t):.2f} → IoU: {iou:.4f}")
        if math.isnan(iou):
            continue
        if math.isnan(best_iou) or iou > best_iou:
            best_iou = iou
            best_thresh = float(t)

    return best_thresh, best_iou


In [17]:
def optimize_highlighting(sample, affordance, mlp, optimizer, render, clip_model, tokenizer, clip_transform, augment_transform, n_augs, n_iter, colors, background, output_dir, device, num_views=None):
    """
    Optimizes the highlighter MLP for one sample and one affordance with a CLIP loss.

    Pipeline:
    1. Voxelize the sample's point cloud into a temporary mesh.
    2. Predict per-vertex class probabilities with ``mlp`` and color the mesh.
    3. Render several views.
    4. Minimize ``clip_loss`` against the prompt from ``build_prompt``.

    Renders, logs and final results go to a new ``run_<timestamp>`` folder under
    ``output_dir``.

    Args:
        sample (dict): dataset item; must contain a string "semantic class"
            (coordinates are read via ``get_coordinates``).
        affordance: Affordance to be highlighted.
        mlp (nn.Module): Neural network model for highlighting.
        optimizer (torch.optim.Optimizer): Optimizer for training the model.
        render (Renderer): Renderer for generating views of the mesh.
        clip_model (nn.Module): CLIP model for computing loss.
        tokenizer (function): Tokenizer for text prompts.
        clip_transform (transforms.Compose): Transformations for CLIP input images.
        augment_transform (transforms.Compose): Augmentation transformations for images.
        n_augs (int): Number of augmentations for CLIP loss.
        n_iter (int): Number of optimization iterations.
        colors (torch.Tensor): Tensor of colors for highlighting.
        background (torch.Tensor): Background color tensor.
        output_dir (str): Root directory; a timestamped run folder is created inside.
        device (torch.device): Device to run computations on.
        num_views (int, optional): views rendered per iteration. When omitted,
            the notebook-level ``n_views`` setting is used.

    Returns:
        tuple: ``(trimesh_mesh, pred_class, export_path)``:
        - ``trimesh_mesh``: the object returned by ``pointcloud_to_voxel_mesh``,
        - ``pred_class``: detached per-vertex class probabilities from the final weights,
        - ``export_path``: the run directory.

    Raises:
        ValueError: if "semantic class" is missing or not a string.
    """
    # Validate first so an invalid sample does not leave an empty run directory behind.
    if "semantic class" not in sample or not isinstance(sample["semantic class"], str):
        raise ValueError(f"Error: Missing or invalid 'semantic class' field in sample: {sample}")

    # Fall back to the notebook-level setting so existing callers behave unchanged.
    views = n_views if num_views is None else num_views

    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    export_path = os.path.join(output_dir, f"run_{timestamp}")
    os.makedirs(os.path.join(export_path, "renders"), exist_ok=True)
    temp_obj_path = "data/outputDemo.obj"

    prompt = build_prompt(sample["semantic class"], affordance)

    try:
        points = get_coordinates(sample, device)
        trimesh_mesh = pointcloud_to_voxel_mesh(
            points,
            resolution=16,
            threshold=0.5,
            export_path=temp_obj_path
        )

        sampled_mesh = Mesh(temp_obj_path)
        MeshNormalizer(sampled_mesh)()
        vertices = sampled_mesh.vertices.clone().detach().to(device).float()

        losses = []

        # Optimization loop
        for i in tqdm(range(n_iter)):
            optimizer.zero_grad()
            pred_class = mlp(vertices)  # Predict highlight probabilities
            color_mesh(pred_class, sampled_mesh, colors)  # Color mesh

            # Render and compute loss
            rendered_images, _, _ = render.render_views(
                sampled_mesh,
                num_views=views,
                show=False,
                center_azim=0,
                center_elev=0,
                std=1,
                return_views=True,
                lighting=True,
                background=background
            )
            loss = clip_loss(rendered_images, prompt, clip_transform, clip_model, tokenizer, device, augment_transform, n_augs)
            # NOTE(review): retain_graph may be unnecessary since the graph is rebuilt every
            # iteration; kept because color_mesh/render internals are not visible here.
            loss.backward(retain_graph=True)
            optimizer.step()

            losses.append(loss.item())

            # Log and save intermediate results
            if i % 100 == 0:
                print(f"Last 100 CLIP score: {np.mean(losses[-100:])}")
                save_renders(export_path, i, rendered_images)
                with open(os.path.join(export_path, "training_info.txt"), "a") as f:
                    f.write(f"Iter {i}: Prompt: {prompt}, Avg CLIP score: {np.mean(losses[-100:])}, CLIP score: {loss.item()}\n")

        # The last in-loop prediction predates the final optimizer.step(); recompute with the
        # final weights (this also defines pred_class when n_iter == 0).
        with torch.no_grad():
            pred_class = mlp(vertices)

        # Final save
        save_final_results(export_path, sample["semantic class"], sampled_mesh, mlp, vertices, colors, render, background)
        with open(os.path.join(export_path, "prompt.txt"), "w") as f:
            f.write(prompt)
    finally:
        # Always remove the shared temporary OBJ, even when optimization fails midway.
        if os.path.exists(temp_obj_path):
            os.remove(temp_obj_path)

    return trimesh_mesh, pred_class, export_path

In [None]:
# Best IoU per evaluated sample. Each uses its own best threshold against GT, so the mean is optimistic.
val_ious = []

for sample in val_set:
    # ==== Neural Highlighter ====
    # Created before the affordance check: moving it would change the torch RNG stream
    # and therefore the initial weights of later samples.
    mlp = NeuralHighlighter().to(device)
    optim = torch.optim.Adam(mlp.parameters(), learning_rate)
    print(f"Training Neural Highlighter for {sample['semantic class']}")

    # Separate name so the notebook-level `affordances` list is not overwritten.
    present_affordances = [
        elem for elem in get_affordance_classes(sample) if is_affordance_present(sample, elem)
    ]

    if not present_affordances:
        print(f"Warning: No affordances found for sample {sample['semantic class']}")
        continue

    # Re-seed so the choice depends only on this sample's affordance list.
    random.seed(seed)
    affordance = random.choice(present_affordances)
    print(f"Selected affordance: {affordance} of {present_affordances}")

    # `voxel_mesh` rather than `trimesh`, which would shadow the imported trimesh module.
    voxel_mesh, pred_class, export_path = optimize_highlighting(
        sample=sample,
        affordance=affordance,
        mlp=mlp,
        optimizer=optim,
        render=render,
        clip_model=clip_model,
        tokenizer=tokenizer,
        clip_transform=clip_transform,
        augment_transform=augment_transform,
        n_augs=n_augs,
        n_iter=n_iter,
        colors=colors,
        background=background,
        output_dir=output_dir,
        device=device
    )

    # Project vertex scores to point cloud (evaluation only, no autograd graph needed)
    with torch.no_grad():
        pointcloud = get_coordinates(sample, device)
        gt_labels = get_affordance_label(sample, affordance, device)
        vertex_scores = get_vertex_scores(pred_class, positive_class=0)  # class 0 = affordance
        projected_scores = project_vertex_scores_to_pointcloud(voxel_mesh, vertex_scores, pointcloud, device)

    # Optimize IoU threshold
    best_thresh, best_iou = optimize_mIoU_threshold(projected_scores, gt_labels)
    print(f"Best IoU threshold: {best_thresh}, Best IoU: {best_iou}")
    val_ious.append(best_iou)

    # Save results
    with open(os.path.join(export_path, "IoU.txt"), "a") as f:
        f.write(f"Best IoU threshold: {best_thresh}, Best IoU: {best_iou}\n")
    # Save projected scores
    with open(os.path.join(export_path, "projected_scores.pkl"), "wb") as f:
        pickle.dump(projected_scores.detach().cpu().numpy(), f)
    # Save ground truth labels
    with open(os.path.join(export_path, "gt_labels.pkl"), "wb") as f:
        pickle.dump(gt_labels.detach().cpu().numpy(), f)
    # Save point cloud
    with open(os.path.join(export_path, "pointcloud.pkl"), "wb") as f:
        pickle.dump(pointcloud.detach().cpu().numpy(), f)

# Summary. `v >= 0` drops undefined (NaN) IoUs and any negative sentinel values.
valid_ious = [v for v in val_ious if v >= 0]
if valid_ious:
    print(f"Mean best IoU over {len(valid_ious)}/{len(val_ious)} samples: {np.mean(valid_ious):.4f}")


In [19]:
# Visual sanity check: ground-truth labels for the "support" affordance on the first validation sample.
affordance = 'support'
pointcloud = get_coordinates(val_set[0], device)
gt_labels = get_affordance_label(val_set[0], affordance, device)
visualize_affordance_pointcloud(
    pointcloud.cpu().numpy(),
    gt_labels.detach().cpu().numpy(),
    point_size=8.0,
)