In [1]:
import torch

# Show details
print(f"PyTorch version: {torch.__version__}, CUDA version: {torch.version.cuda}, GPU available: {torch.cuda.is_available()}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

PyTorch version: 2.6.0+cu124, CUDA version: 12.4, GPU available: True


In [2]:
import clip
import copy
import json
import kaolin as kal
import kaolin.ops.mesh as mesh
import kaolin.ops.conversions as conversions
import trimesh
import numpy as np
import os
import random
import torch.nn as nn
import torchvision
import open3d as o3d

from itertools import permutations, product
from kaolin.ops.mesh import face_normals
from Normalization import MeshNormalizer
from mesh import Mesh
from pathlib import Path
from render import Renderer
from tqdm import tqdm
from torch.autograd import grad
from torchvision import transforms
from utils import color_mesh
import pickle

Warp 1.7.1 initialized:
   CUDA Toolkit 12.8, Driver 12.2
   Devices:
     "cpu"      : "x86_64"
     "cuda:0"   : "NVIDIA GeForce RTX 3070 Ti" (8 GiB, sm_86, mempool enabled)
   Kernel cache:
     /home/ezmiron/.cache/warp/1.7.1
Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [3]:
# ================== AffordanceNet Dataset HELPER FUNCTIONS =============================

def load_dataset(path):
    """
      Load the affordance dataset from a pickle file.

      Returns a list of dicts with keys:
      - 'shape_id'
      - 'semantic class'
      - 'affordance'
      - 'data_info'
    """
    dataset = []
    with open(path, 'rb') as f:
        train_data = pickle.load(f)
        print("Loaded train_data")
        for index,info in enumerate(train_data):
            temp_info = {}
            temp_info["shape_id"] = info["shape_id"]
            temp_info["semantic class"] = info["semantic class"]
            temp_info["affordance"] = info["affordance"]
            temp_info["data_info"] = info["full_shape"]
            dataset.append(temp_info)
    return dataset

def get_coordinates(sample, device='cpu'):
    """
    Returns the point cloud coordinates from a sample as a torch.Tensor on the specified device.

    Args:
        sample: one entry from the dataset
        device: 'cpu' or 'cuda'

    Returns:
        coords: torch.Tensor of shape [N, 3]
    """
    return torch.tensor(sample["data_info"]["coordinate"], dtype=torch.float32).to(device)

def get_affordance_classes(sample):
    """
    Returns the list of affordance class names available for the given sample.
    """
    return sample["affordance"]

def is_affordance_present(sample, affordance_class):
    """
    Given a sample and an affordance class string (e.g., 'grasp'),
    returns True if there is at least one positive label for that affordance,
    otherwise False.
    """
    label = np.array(sample["data_info"]["label"][affordance_class])
    return np.any(label > 0)

def get_affordance_label(sample, affordance_class, device='cpu'):
    """
    Returns the binary label mask for a specific affordance class as a torch tensor.

    Args:
        sample: dataset sample
        affordance_class: string key of the affordance label
        device: 'cpu' or 'cuda'

    Returns:
        labels: torch.Tensor of shape [N]
    """
    label = sample["data_info"]["label"][affordance_class]
    return torch.tensor(label, dtype=torch.float32).squeeze().to(device)

In [4]:
#Takes a screenshot of the point cloud and saves it to the specified path
def render_point_cloud_to_image(pcd, image_path="output/tmp_screen.png"):
    vis = o3d.visualization.Visualizer()
    vis.create_window(visible=False)
    vis.add_geometry(pcd)
    vis.poll_events()
    vis.update_renderer()
    vis.capture_screen_image(image_path)
    vis.destroy_window()


#Shows the point cloud in a window or saves it to the specified path
def show_point_cloud(point_cloud,render_to_image=False,save_path="output/tmp_screen.png"):
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(point_cloud)
    if(render_to_image):
        render_point_cloud_to_image(pcd,save_path)
    else:
        o3d.visualization.draw_geometries([pcd])

#Shows the point cloud in a window or saves it to the specified path, coloring the points based on the probabilities
def show_point_cloud_tresholded(point_cloud,probs,treshold,render_to_image=False,save_path="output/tmp_screen.png"):

    #if probs has 2 columns, delete the second one (refer to the "no object" class)
    if probs.shape[1]==2:
        probs = probs[:,0]

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(point_cloud)
    colors = np.zeros((len(point_cloud),3))
    for i in range(len(point_cloud)):
        if probs[i]>treshold:
            colors[i] = [1,0,0]
        else:
            colors[i] = [0.5,0.5,0.5]
    pcd.colors = o3d.utility.Vector3dVector(colors)
    if render_to_image:
        render_point_cloud_to_image(pcd,save_path)
    else:
        o3d.visualization.draw([pcd])


def create_point_cloud_from_mesh(mesh_path,name):
    mesh = o3d.io.read_triangle_mesh(mesh_path)
    pcd = mesh.sample_points_uniformly(number_of_points=10000) #Tune if needed.
    o3d.io.write_point_cloud(f"output/{name}.ply", pcd)
    return pcd

In [5]:
dataset = load_dataset('data/full-shape/full_shape_train_data.pkl')

Loaded train_data


In [6]:
sample = dataset[660]

In [21]:
print(sample.keys())
print(sample["semantic class"])
print(get_affordance_classes(sample))

dict_keys(['shape_id', 'semantic class', 'affordance', 'data_info'])
Chair
['grasp', 'contain', 'lift', 'openable', 'layable', 'sittable', 'support', 'wrap_grasp', 'pourable', 'move', 'displaY', 'pushable', 'pull', 'listen', 'wear', 'press', 'cut', 'stab']


In [22]:
pointcloud = get_coordinates(sample).cpu().numpy()

In [23]:
# Render and save
show_point_cloud(pointcloud)

In [10]:
class NeuralHighlighter(nn.Module):
    def __init__(self, input_dim=3, hidden_dim=256, output_dim=2, num_layers=6):
        """
        Args:
            input_dim: usually 3 (x, y, z)
            hidden_dim: size of hidden layers
            output_dim: 2 for [highlight, gray]
            num_layers: total number of linear layers
        """
        super(NeuralHighlighter, self).__init__()

        layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.LayerNorm(hidden_dim)]

        for _ in range(num_layers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.LayerNorm(hidden_dim))

        layers.append(nn.Linear(hidden_dim, output_dim))
        layers.append(nn.Softmax(dim=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

def get_clip_model(clipmodel='ViT-L/14', jit=False):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load(clipmodel, device=device, jit=jit)
    print(f"Loaded CLIP model: {clipmodel} on {device} (jit={jit})")
    return model, preprocess


# ================== HELPER FUNCTIONS =============================
def save_final_results(log_dir, name, mesh, mlp, vertices, colors, render, background):
    mlp.eval()
    with torch.no_grad():
        probs = mlp(vertices)
        max_idx = torch.argmax(probs, 1, keepdim=True)
        # for renders
        one_hot = torch.zeros(probs.shape).to(device)
        one_hot = one_hot.scatter_(1, max_idx, 1)
        sampled_mesh = mesh

        highlight = torch.tensor([204, 255, 0]).to(device)
        gray = torch.tensor([180, 180, 180]).to(device)
        colors = torch.stack((highlight/255, gray/255)).to(device)
        color_mesh(one_hot, sampled_mesh, colors)
        rendered_images, _, _ = render.render_views(sampled_mesh, num_views=5,
                                                                        show=False,
                                                                        center_azim=0,
                                                                        center_elev=0,
                                                                        std=1,
                                                                        return_views=True,
                                                                        lighting=True,
                                                                        background=background)
        # for mesh
        final_color = torch.zeros(vertices.shape[0], 3).to(device)
        final_color = torch.where(max_idx==0, highlight, gray)
        mesh.export(os.path.join(log_dir, f"{name}.ply"), extension="ply", color=final_color)
        save_renders(log_dir, 0, rendered_images, name='final_render.jpg')


def clip_loss(rendered_images, prompt, clip_model, aug_transform, n_augs, device, tokenizer):
    """
    Computes the CLIP loss as negative cosine similarity between
    rendered image embeddings and the text prompt embedding.

    Args:
        rendered_images (torch.Tensor): shape (B, 3, H, W)
        text_prompt (str): e.g., "a gray chair with highlighted seat"
        clip_model (torch.nn.Module): preloaded CLIP model
        device (str): "cuda" or "cpu"

    Returns:
        loss (torch.Tensor): scalar CLIP loss
    """

    # Encode text
    text_encoded = tokenizer([prompt]).to(device)
    with torch.no_grad():
        text_features = clip_model.encode_text(text_encoded)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

    loss = 0.0

    for _ in range(n_augs):
      aug_image = aug_transform(rendered_images)
      image_encoded = clip_model.encode_image(aug_image)
      loss -= torch.mean(torch.cosine_similarity(image_encoded, text_features))

    return loss / n_augs

def save_renders(dir, i, rendered_images, name=None):
    if name is not None:
        torchvision.utils.save_image(rendered_images, os.path.join(dir, name))
    else:
        torchvision.utils.save_image(rendered_images, os.path.join(dir, 'renders/iter_{}.jpg'.format(i)))


In [11]:
# Constrain most sources of randomness
# (some torch backwards functions within CLIP are non-determinstic)

# ==== Set Seed for Determinism ====
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [12]:
def pointcloud_to_voxel_mesh(points, resolution=64, threshold=0.5, export_path=None):
  min_coords, _ = points.min(dim=0)
  max_coords, _ = points.max(dim=0)
  scale = max_coords - min_coords
  points_norm = (points - min_coords) / scale

  voxel_grid = conversions.pointclouds_to_voxelgrids(points_norm.unsqueeze(0), resolution=resolution).to(device)
  verts_faces = conversions.voxelgrids_to_trianglemeshes(voxel_grid, iso_value=threshold)

  verts = verts_faces[0][0].cpu() / resolution
  faces = verts_faces[1][0].cpu()

  # Denormalize
  scale = scale.cpu()
  min_coords = min_coords.cpu()
  verts = verts * scale + min_coords

  if verts.numel() == 0 or faces.numel() == 0:
      raise ValueError("Empty mesh generated from voxel grid.")

  # Create mesh
  mesh = trimesh.Trimesh(vertices=verts.numpy(), faces=faces.numpy())

  # Smoothing and export
  mesh = trimesh.smoothing.filter_laplacian(
      mesh, lamb=0.2, iterations=8,
      implicit_time_integration=False,
      volume_constraint=True
  )

  if export_path:
    mesh.export(export_path)

  return mesh

In [13]:
def load_vertices(data):
  if type(data) == str:
    mesh = o3d.io.read_triangle_mesh(data)
    mesh.compute_vertex_normals()
    vertices = mesh.sample_points_uniformly(number_of_points=4096)
    return torch.tensor(np.asarray(vertices.points), dtype=torch.float32).to(device)
  else:
    return torch.tensor(data, dtype=torch.float32).to(device)

In [14]:
# ==== Hyperparameters and Settings ====
render_res = 224
learning_rate = 0.00005
n_iter = 2200
obj_path = 'data/dog.obj'
n_augs = 1
output_dir = './output/'
clip_model_name = 'ViT-B/16'
prompt = 'A gray chair with highlighted shoes'

In [15]:
# ==== Setup Output Directory ====
Path(os.path.join(output_dir, 'renders')).mkdir(parents=True, exist_ok=True)
log_dir = output_dir

In [16]:
# ==== Load Mesh ====
objbase, extension = os.path.splitext(os.path.basename(obj_path))
render = Renderer(dim=(render_res, render_res))

# ==== CLIP ====
clip_model, preprocess = get_clip_model(clip_model_name)
tokenizer = clip.tokenize

clip_normalizer = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],std=[0.26862954, 0.26130258, 0.27577711]) #from https://github.com/openai/CLIP/issues/20

aug_transform = transforms.Compose([
        transforms.RandomResizedCrop(render_res, scale=(1, 1)),
        transforms.RandomPerspective(fill=1, p=0.8, distortion_scale=0.5),
        clip_normalizer
    ])

# ==== Neural Highlighter ====
mlp = NeuralHighlighter().to(device)
optim = torch.optim.Adam(mlp.parameters(), learning_rate)

# ==== Colors and Other Constants ====
colors = torch.tensor([[204/255, 1., 0.], [180/255, 180/255, 180/255]]).to(device)
background = torch.tensor((1., 1., 1.)).to(device)
n_views = 7
losses = []

Loaded CLIP model: ViT-B/16 on cuda (jit=False)


In [None]:
# Load dataset
dataset = load_dataset("data/full-shape/full_shape_train_data.pkl")

Loaded train_data


In [17]:
sample = dataset[3915]

In [18]:
vertices = get_coordinates(sample, device)

In [None]:
affordances = get_affordance_classes(sample)

for affordance in affordances:
    print(affordance, is_affordance_present(sample, affordance))

grasp False
contain False
lift False
openable False
layable False
sittable True
support True
wrap_grasp False
pourable False
move True
displaY False
pushable False
pull False
listen False
wear False
press False
cut False
stab False


In [None]:
affordance_label = get_affordance_label(sample, affordance_class='move', device=device)
print(affordance_label.shape)

torch.Size([2048])


In [None]:
mesh = pointcloud_to_voxel_mesh(
    vertices,  # sampled point cloud from Open3D
    resolution=16,
    threshold=0.5,
    export_path="data/chair_voxel.obj"
)

In [None]:
sampled_mesh = Mesh("data/chair_voxel.obj")
MeshNormalizer(sampled_mesh)()
vertices = torch.tensor(sampled_mesh.vertices, dtype=torch.float32, device=device)

  vertices = torch.tensor(sampled_mesh.vertices, dtype=torch.float32, device=device)


In [None]:
from scipy.spatial import cKDTree

# Step 1: Original coordinates and labels (on point cloud)
orig_coords = get_coordinates(sample, device='cpu')  # [N, 3]
orig_labels = get_affordance_label(sample, 'move', device='cpu')  # [N]

In [None]:
colors = torch.tensor([[204/255, 1., 0.], [180/255, 180/255, 180/255]]).to(device)
background = torch.tensor((1., 1., 1.)).to(device)

# Step 2: Voxelized mesh face centers
voxel_faces = sampled_mesh.faces  # [F, 3]
voxel_vertices = torch.tensor(sampled_mesh.vertices, dtype=torch.float32).clone().detach()
face_centers = voxel_vertices[voxel_faces].mean(dim=1)  # [F, 3]

# Step 3: KDTree from original coords
tree = cKDTree(orig_coords.numpy())
_, indices = tree.query(face_centers.cpu().numpy(), k=1)

# Step 4: Assign labels from nearest points
face_labels = orig_labels[indices]  # shape [F]
face_labels = (face_labels > 0.5).long()  # Binarize

# Step 5: Convert to one-hot for color_mesh
pred_class = torch.nn.functional.one_hot(face_labels, num_classes=2).float().to(device)

num_background = int(pred_class[:, 0].sum())
num_highlighted = int(pred_class[:, 1].sum())
num_highlighted_affordances = int(affordance_label.sum())

print(f"Background points: {num_background}")
print(f"Highlighted points: {num_highlighted}")
print(f"Highlighted affordances: {num_highlighted_affordances}")

Background points: 2583
Highlighted points: 485
Highlighted affordances: 416


  voxel_vertices = torch.tensor(sampled_mesh.vertices, dtype=torch.float32).clone().detach()


In [None]:
color_mesh(pred_class, sampled_mesh, colors)
mlp = NeuralHighlighter().to(device)
render = Renderer(dim=(render_res, render_res))
save_renders(log_dir, 0, render.render_views(
    sampled_mesh,
    num_views=5,
    show=False,
    center_azim=0,
    center_elev=0,
    std=1,
    return_views=True,
    lighting=True,
    background=background
)[0], name='initial_render.jpg')

In [19]:
labels = np.array(sample["data_info"]["label"][sample["affordance"][5]])
# vertices = load_vertices(data)
temp_obj_path = "outputDemo.obj"

mesh = pointcloud_to_voxel_mesh(
    vertices,  # sampled point cloud from Open3D
    resolution=16,
    threshold=0.5,
    export_path=temp_obj_path
)

# === Load the voxel mesh from disk ===
sampled_mesh = Mesh(temp_obj_path)
MeshNormalizer(sampled_mesh)()
vertices = torch.tensor(sampled_mesh.vertices, dtype=torch.float32, device=device)

  vg = torch.sparse.FloatTensor(
Attribute "face_normals" has not been set and failed to be computed due to: index -1 is out of bounds for dimension 1 with size 0
Unexpected type passed to requires_grad None
Attribute "vertex_normals" has not been set and failed to be computed due to: 'NoneType' object has no attribute 'unsqueeze'
Attribute "face_normals" has not been set and failed to be computed due to: index -1 is out of bounds for dimension 1 with size 0
  vertices = torch.tensor(sampled_mesh.vertices, dtype=torch.float32, device=device)


In [20]:
# Optimization loop
for i in tqdm(range(n_iter)):
    optim.zero_grad()

    # === Predict highlight probabilities ===
    pred_class = mlp(vertices)

    # === Color mesh ===
    color_mesh(pred_class, sampled_mesh, colors)

    # === Render the mesh ===
    rendered_images, elev, azim = render.render_views(
        sampled_mesh,
        num_views=n_views,
        show=False,
        center_azim=0,
        center_elev=0,
        std=1,
        return_views=True,
        lighting=True,
        background=background
    )

    # === Compute CLIP loss ===
    loss = clip_loss(rendered_images, prompt, clip_model, aug_transform, n_augs, device, tokenizer)
    loss.backward(retain_graph=True)
    optim.step()

    # === Save and log results ===
    with torch.no_grad():
        losses.append(loss.item())

    if i % 100 == 0:
        print(f"Last 100 CLIP score: {np.mean(losses[-100:])}")
        save_renders(log_dir, i, rendered_images)
        with open(os.path.join(log_dir, "training_info.txt"), "a") as f:
            f.write(f"For iteration {i}... Prompt: {prompt}, Last 100 avg CLIP score: {np.mean(losses[-100:])}, CLIP score {losses[-1]}\n")

# Remove generated obj
os.remove(temp_obj_path)

# Final save
save_final_results(log_dir, objbase, sampled_mesh, mlp, vertices, colors, render, background)

# Save prompt
with open(os.path.join(output_dir, prompt), "w") as f:
    f.write('')

  0%|          | 2/2200 [00:00<10:07,  3.62it/s]

Last 100 CLIP score: -0.2020263671875


  5%|▍         | 102/2200 [00:07<02:21, 14.82it/s]

Last 100 CLIP score: -0.239324951171875


  9%|▉         | 202/2200 [00:14<02:14, 14.87it/s]

Last 100 CLIP score: -0.251121826171875


 14%|█▎        | 302/2200 [00:22<03:02, 10.39it/s]

Last 100 CLIP score: -0.251353759765625


 18%|█▊        | 404/2200 [00:30<02:18, 12.99it/s]

Last 100 CLIP score: -0.25096435546875


 23%|██▎       | 502/2200 [00:37<01:54, 14.80it/s]

Last 100 CLIP score: -0.250516357421875


 27%|██▋       | 604/2200 [00:43<01:47, 14.78it/s]

Last 100 CLIP score: -0.250628662109375


 32%|███▏      | 704/2200 [00:50<01:33, 15.92it/s]

Last 100 CLIP score: -0.254281005859375


 37%|███▋      | 804/2200 [00:56<01:28, 15.86it/s]

Last 100 CLIP score: -0.2530712890625


 41%|████      | 904/2200 [01:03<01:22, 15.74it/s]

Last 100 CLIP score: -0.2530908203125


 46%|████▌     | 1004/2200 [01:09<01:15, 15.77it/s]

Last 100 CLIP score: -0.253240966796875


 50%|█████     | 1102/2200 [01:15<01:13, 14.92it/s]

Last 100 CLIP score: -0.253057861328125


 55%|█████▍    | 1204/2200 [01:22<01:05, 15.18it/s]

Last 100 CLIP score: -0.253887939453125


 59%|█████▉    | 1304/2200 [01:29<00:56, 15.84it/s]

Last 100 CLIP score: -0.2531494140625


 64%|██████▍   | 1404/2200 [01:35<00:50, 15.90it/s]

Last 100 CLIP score: -0.254544677734375


 68%|██████▊   | 1504/2200 [01:42<00:47, 14.64it/s]

Last 100 CLIP score: -0.253458251953125


 73%|███████▎  | 1604/2200 [01:49<00:38, 15.61it/s]

Last 100 CLIP score: -0.25267578125


 77%|███████▋  | 1702/2200 [01:55<00:33, 14.81it/s]

Last 100 CLIP score: -0.252388916015625


 82%|████████▏ | 1804/2200 [02:02<00:26, 15.14it/s]

Last 100 CLIP score: -0.2532177734375


 86%|████████▋ | 1902/2200 [02:08<00:19, 14.97it/s]

Last 100 CLIP score: -0.254705810546875


 91%|█████████ | 2004/2200 [02:15<00:12, 15.29it/s]

Last 100 CLIP score: -0.25244873046875


 96%|█████████▌| 2102/2200 [02:21<00:06, 14.76it/s]

Last 100 CLIP score: -0.254456787109375


100%|██████████| 2200/2200 [02:28<00:00, 14.85it/s]


In [None]:
print(dataset[3915]["data_info"].keys())

dict_keys(['coordinate', 'label'])


In [None]:
sample = dataset[3915]
coords = np.array(sample["data_info"]["coordinate"])
labels = np.array(sample["data_info"]["label"][sample["affordance"][5]])
print(labels.shape)

(2048, 1)


In [None]:
%rm -rf output/

In [None]:
%rm -rf output/renders/