In [None]:
# Install the third-party packages this notebook imports later on.
# CLIP is installed directly from its GitHub repository (no version pin).
!pip install git+https://github.com/openai/CLIP.git
# Kaolin 0.17.0 is resolved from NVIDIA's wheel index; the index URL names torch 2.5.1 / CUDA 12.1.
# NOTE(review): confirm the runtime's torch/CUDA build matches that index before running.
!pip install kaolin==0.17.0 -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.5.1_cu121.html
!pip install trimesh
!pip install open3d

# Downgrade numpy to a compatible version
# This runs last with --force-reinstall, presumably so the packages above cannot leave a newer numpy behind.
# NOTE(review): if numpy was already imported in this runtime, a restart may be required — confirm.
!pip install numpy==1.23.5 --force-reinstall

In [None]:
# Fetch the project sources. The local modules imported further down (Normalization, mesh,
# render, utils, utilities.point_cloud) are presumably provided by this repository.
!git clone https://github.com/ezmi234/Affordance_Highlighting_Project_2024.git

In [None]:
# Work from inside the cloned repository so relative paths (e.g. data/...) and local imports resolve there.
%cd Affordance_Highlighting_Project_2024
# Switch to the branch this notebook is written against.
!git checkout part2-pointcloud-adaptation

In [None]:
import torch

# Report the library build and whether a GPU is visible to this runtime
print(f"PyTorch version: {torch.__version__}, CUDA version: {torch.version.cuda}, GPU available: {torch.cuda.is_available()}")

# Run on the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
import clip
import kaolin as kal
import trimesh
import numpy as np
import os
import random
import torch.nn as nn
import torchvision

from datetime import datetime
from kaolin.ops.mesh import face_normals
from Normalization import MeshNormalizer
from mesh import Mesh
from pathlib import Path
from render import Renderer
from tqdm import tqdm
from torchvision import transforms
from utils import color_mesh
import pickle
from utilities.point_cloud import pointcloud_to_voxel_mesh

from google.colab import drive
drive.mount('/content/drive')
import time

In [None]:
class NeuralHighlighter(nn.Module):
    """MLP that maps each input point to a probability distribution over classes.

    The network is a stack of ``Linear -> ReLU -> LayerNorm`` blocks followed by a
    final ``Linear`` and a softmax over the class (last) dimension, so every output
    row sums to 1.
    """

    def __init__(self, input_dim=3, hidden_dim=256, output_dim=2, num_layers=6):
        """
        Args:
            input_dim: usually 3 (x, y, z)
            hidden_dim: size of hidden layers
            output_dim: 2 for [highlight, gray]
            num_layers: total number of linear layers. The first and last linear
                layers are always created, so values below 2 still yield 2 layers.
        """
        super().__init__()

        # Input block
        layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.LayerNorm(hidden_dim)]

        # Hidden blocks: everything between the first and the last linear layer
        for _ in range(num_layers - 2):
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.LayerNorm(hidden_dim)]

        layers.append(nn.Linear(hidden_dim, output_dim))
        # Normalize over the class axis. Using the last dimension (instead of a fixed
        # dim=1) gives the same result for the usual (N, input_dim) input and also
        # handles a single (input_dim,) point or batched (B, N, input_dim) input.
        layers.append(nn.Softmax(dim=-1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return class probabilities with shape ``x.shape[:-1] + (output_dim,)``."""
        return self.model(x)


def get_clip_model(clipmodel='ViT-L/14', jit=False):
    """Load a CLIP checkpoint and report where it was placed.

    Args:
        clipmodel: name of the CLIP variant handed to ``clip.load``.
        jit: forwarded to ``clip.load`` unchanged.

    Returns:
        The ``(model, preprocess)`` pair produced by ``clip.load``.
    """
    # Pick the target device by name; the GPU wins whenever torch can see one.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    model, preprocess = clip.load(clipmodel, device=device, jit=jit)
    print(f"Loaded CLIP model: {clipmodel} on {device} (jit={jit})")
    return model, preprocess


# ================== HELPER FUNCTIONS =============================
def save_final_results(log_dir, name, mesh, mlp, vertices, colors, render, background):
    """Export the final hard-labelled mesh and a render of it.

    Each vertex is assigned the class with the highest predicted probability
    (index 0 = highlight, index 1 = gray). The mesh is recoloured with these hard
    labels, rendered from 5 views, exported as ``<log_dir>/<name>.ply`` and the
    renders are stored as ``<log_dir>/final_render.jpg``.

    Args:
        log_dir: existing directory that receives the outputs.
        name: base file name (without extension) of the exported ``.ply``.
        mesh: project mesh object; it is recoloured in place via ``color_mesh``.
        mlp: trained highlighter network. It is switched to eval mode and left there.
        vertices: tensor of vertex positions fed to ``mlp``.
        colors: accepted for interface compatibility but not read; the function
            always uses its own fixed highlight/gray palette defined below.
        render: renderer exposing ``render_views``.
        background: background colour forwarded to ``render.render_views``.
    """
    mlp.eval()
    with torch.no_grad():
        probs = mlp(vertices)
        max_idx = torch.argmax(probs, 1, keepdim=True)
        # Keep every tensor on the device of the predictions instead of relying on
        # a module-level `device` global.
        target_device = probs.device

        # for renders: one-hot encode the winning class of every vertex
        one_hot = torch.zeros(probs.shape, device=target_device).scatter_(1, max_idx, 1)

        highlight = torch.tensor([204, 255, 0]).to(target_device)
        gray = torch.tensor([180, 180, 180]).to(target_device)
        palette = torch.stack((highlight / 255, gray / 255))
        color_mesh(one_hot, mesh, palette)
        rendered_images, _, _ = render.render_views(mesh, num_views=5,
                                                    show=False,
                                                    center_azim=0,
                                                    center_elev=0,
                                                    std=1,
                                                    return_views=True,
                                                    lighting=True,
                                                    background=background)

        # for mesh: (N, 1) class mask broadcast against the (3,) colours -> (N, 3)
        final_color = torch.where(max_idx == 0, highlight, gray)
        mesh.export(os.path.join(log_dir, f"{name}.ply"), extension="ply", color=final_color)
        save_renders(log_dir, 0, rendered_images, name='final_render.jpg')

def clip_loss(rendered_images, text_prompt, clip_transform, clip_model, tokenizer, device, aug_transform=None, n_augs=0):
    """
    Computes the CLIP loss as negative cosine similarity between
    rendered image embeddings and the text prompt embedding.

    Both the plain and the augmented path score images the same way: image
    embeddings are cast to float32 and L2-normalised before the similarity is
    taken, so the loss has the same dtype and scale regardless of ``n_augs``.

    Args:
        rendered_images (torch.Tensor): shape (B, 3, H, W)
        text_prompt (str): e.g., "a gray chair with highlighted seat"
        clip_transform (torchvision.transforms): preprocessing for CLIP
        clip_model (torch.nn.Module): preloaded CLIP model
        tokenizer (callable): CLIP tokenizer
        device (str): "cuda" or "cpu"
        aug_transform (torchvision.transforms): augmentation for CLIP. When it is
            None the un-augmented path is used even if ``n_augs`` is positive.
        n_augs (int): number of augmentations to apply; values <= 0 disable
            augmentation.
    Returns:
        loss (torch.Tensor): scalar CLIP loss
    """

    # Encode text (no gradient is needed through the text tower)
    text_tokens = tokenizer([text_prompt]).to(device)
    with torch.no_grad():
        text_features = clip_model.encode_text(text_tokens).float()
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)  # L2 norm

    def _negative_similarity(transform):
        """Mean negative cosine similarity of the transformed renders to the prompt."""
        image_features = clip_model.encode_image(transform(rendered_images)).float()
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return -torch.mean(torch.cosine_similarity(image_features, text_features))

    # Without a usable augmentation setup, score the plainly preprocessed renders once.
    if n_augs <= 0 or aug_transform is None:
        return _negative_similarity(clip_transform)

    # Average the loss over independently augmented copies of the renders.
    loss = 0.0
    for _ in range(n_augs):
        loss = loss + _negative_similarity(aug_transform)

    return loss / n_augs


def save_renders(dir, i, rendered_images, name=None):
    """Write a batch of rendered images to disk as a single image file.

    Args:
        dir: output directory (the parameter name shadows the builtin but is kept
            because it is part of the public signature).
        i: iteration index; only used to build the default file name.
        rendered_images: image tensor accepted by ``torchvision.utils.save_image``.
        name: file name relative to ``dir``. When None the image is written to
            ``renders/iter_<i>.jpg`` inside ``dir``.
    """
    relative_path = name if name is not None else 'renders/iter_{}.jpg'.format(i)
    target = os.path.join(dir, relative_path)

    # save_image does not create missing folders, so make sure the parent exists
    # (this covers the `renders/` subfolder used by the default name).
    parent = os.path.dirname(target)
    if parent:
        os.makedirs(parent, exist_ok=True)

    torchvision.utils.save_image(rendered_images, target)

In [None]:
def get_sample(object_class="Knife", seed=42):
    """Return one dataset entry of the requested semantic class.

    The training pickle is read from ``data/full-shape/full_shape_train_data.pkl``
    (relative to the current working directory); when it is missing, the archive
    is downloaded with ``gdown`` and unpacked with ``unzip`` first.

    Args:
        object_class: semantic class to select, compared case-insensitively
            against each entry's ``"semantic class"`` field.
        seed: seed for the shuffle that picks the entry. Note that this reseeds
            Python's global ``random`` generator.

    Returns:
        dict with the keys ``"shape_id"``, ``"semantic class"``, ``"affordance"``
        and ``"data_info"`` (the entry's ``"full_shape"`` payload).

    Raises:
        ValueError: if no entry of ``object_class`` exists in the dataset.
        subprocess.CalledProcessError: if the download or extraction command fails.
    """
    import subprocess  # local import: only needed on the first-run download path

    path = "data/full-shape/full_shape_train_data.pkl"
    if not os.path.exists(path):
        print("Local dataset not found. Downloading...")
        # check=True makes a failed download/extraction stop here instead of
        # surfacing later as a confusing FileNotFoundError.
        subprocess.run(
            ["gdown", "1siZtGusB1LfQVapTvNOiYi8aeKKAgcDF", "--output", "full-shape.zip"],
            check=True,
        )
        subprocess.run(["unzip", "-q", "full-shape.zip", "-d", "data/full-shape"], check=True)
    else:
        print("Local dataset found. Skipping download.")

    # SECURITY: pickle.load can execute arbitrary code; only use this with the
    # trusted dataset archive referenced above.
    with open(path, 'rb') as f:
        train_data = pickle.load(f)
        print("Loaded sample")

    # Keep only the entries of the requested class, in dataset order.
    wanted = str(object_class).lower()
    matches = [
        {
            "shape_id": info["shape_id"],
            "semantic class": info["semantic class"],
            "affordance": info["affordance"],
            "data_info": info["full_shape"],
        }
        for info in train_data
        if str(info["semantic class"]).lower() == wanted
    ]
    if not matches:
        raise ValueError(f"No sample with semantic class {object_class!r} found in {path}")

    # Seeded shuffle so the same (class, seed) pair always yields the same entry.
    random.seed(seed)
    random.shuffle(matches)

    return matches[0]

In [None]:
# Target object category and the affordance to highlight. Both are interpolated into the CLIP
# text prompt; semantic_class is also passed to get_sample and used as the exported mesh name.
semantic_class = "Table"
affordance = "support"

In [None]:
# Draw one dataset entry for the chosen category and echo its class as a sanity check.
# NOTE(review): confirm the printed class equals semantic_class before training — the text
# prompt is built from semantic_class, not from the sample itself.
sample = get_sample(semantic_class)
print(sample["semantic class"])

In [None]:
# Pin every source of randomness we control so repeated runs start from the same state
# (some torch backward functions within CLIP are non-deterministic, so runs may still differ)

# ==== Set Seed for Determinism ====
seed = 42

# cuDNN: request deterministic kernels and switch off the autotuner
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Python and NumPy generators
np.random.seed(seed)
random.seed(seed)

# PyTorch generators: all GPUs, the current GPU, then the default generator
torch.cuda.manual_seed_all(seed)
torch.cuda.manual_seed(seed)
torch.manual_seed(seed)

In [None]:
# ==== Hyperparameters and Settings ====
# Side length (pixels) of the square renders; also used as the resize/crop size of the CLIP transforms.
render_res = 224
# Step size handed to the Adam optimizer.
learning_rate = 0.00005
# Number of optimization steps in the training loop.
n_iter = 2200
temp_obj_path = "data/temp.obj"  # Temporary path for the sampled mesh from point cloud
# Number of augmented passes per step; any value > 0 routes clip_loss through augment_transform.
n_augs = 1
# NOTE(review): output_dir is not referenced by the cells visible here (results go to the Drive path) — confirm it is still needed.
output_dir = './output/'
# Overrides get_clip_model's default of 'ViT-L/14'.
clip_model_name = 'ViT-B/16'
# Text prompt scored against the renders; the class name is lower-cased, the affordance is used verbatim.
prompt = f"A gray {str(semantic_class).lower()} with highlighted {affordance}"

# Same device selection as in the environment cell, repeated so this cell can be re-run on its own.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# ==== Load Mesh ====
# Square renderer at the configured resolution.
render = Renderer(dim=(render_res, render_res))

# Point coordinates of the selected sample as a float32 tensor on the compute device.
points =  torch.tensor(sample["data_info"]["coordinate"], dtype=torch.float32).to(device)
# Convert the point cloud into a voxel-based mesh via the project helper.
# NOTE(review): presumably the helper writes the mesh to export_path; the returned
# trimesh_mesh is not used by the cells visible here — confirm.
trimesh_mesh = pointcloud_to_voxel_mesh(
    points,
    resolution=16,
    threshold=0.5,
    export_path=temp_obj_path,
)

# Load the mesh back from the temporary OBJ file.
mesh = Mesh(temp_obj_path)
# NOTE(review): presumably normalizes the mesh in place — confirm against Normalization.py.
MeshNormalizer(mesh)()
# Detached float copy of the vertex positions; this is the MLP input.
vertices = mesh.vertices.clone().detach().to(device).float()

# ==== CLIP ====
clip_model, preprocess = get_clip_model(clip_model_name)
tokenizer = clip.tokenize

# ==== Normalization and Augmentation ====
# Per-channel mean / std; presumably the statistics CLIP's own preprocessing uses — confirm.
clip_normalizer = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))

# Plain preprocessing: resize + normalize. clip_loss uses it only when n_augs == 0.
clip_transform = transforms.Compose([
    transforms.Resize((render_res, render_res)),
    clip_normalizer
])

# Augmented preprocessing applied on each of the n_augs passes inside clip_loss.
# fill=1 pads the perspective warp with 1s, the same value as the background colour below.
augment_transform = transforms.Compose([
    transforms.RandomResizedCrop(render_res, scale=(1, 1)),
    transforms.RandomPerspective(fill=1, p=0.8, distortion_scale=0.5),
    clip_normalizer
])

# ==== Neural Highlighter ====
# Default architecture: 3 inputs, hidden size 256, 2 output classes, 6 linear layers.
mlp = NeuralHighlighter().to(device)
optim = torch.optim.Adam(mlp.parameters(), learning_rate)

# ==== Colors and Other Constants ====
# Row 0 is the highlight colour, row 1 is gray (RGB in [0, 1]).
colors = torch.tensor([[204/255, 1., 0.], [180/255, 180/255, 180/255]]).to(device)
# Background colour for the renderer.
background = torch.tensor((1., 1., 1.)).to(device)
# Number of views rendered per training step.
n_views = 7
# Per-iteration loss values, appended to by the training loop.
losses = []

In [None]:
# ==== Setup Output Directory ====
# Each run gets its own timestamped folder.
# NOTE(review): the timestamp contains ':' characters, which some file systems / sync clients reject — confirm this is acceptable.
timestamp = datetime.now().strftime("%Y-%m-%d__%H:%M:%S")

# Results are written to Google Drive, so the drive.mount call in the imports cell must have succeeded.
base_path = '/content/drive/MyDrive/affordance_outputs_on_pointcloud'
export_path = os.path.join(base_path, f"run_{timestamp}")
# Creating the nested 'renders' folder also creates the run folder itself.
Path(os.path.join(export_path, 'renders')).mkdir(parents=True, exist_ok=True)

# Wall-clock start; the total reported later covers everything executed after this point.
start_time = time.time()
# ==== Training Loop ====
# Each step: predict per-vertex class probabilities, colour and render the mesh,
# score the renders against the text prompt with CLIP, and update the MLP.
for i in tqdm(range(n_iter)):
    optim.zero_grad()

    # predict highlight probabilities
    pred_class = mlp(vertices)

    # color and render mesh
    # sampled_mesh is an alias of mesh (no copy), so colouring modifies the shared mesh object.
    sampled_mesh = mesh
    color_mesh(pred_class, sampled_mesh, colors)
    rendered_images, elev, azim = render.render_views(sampled_mesh,
                                                      num_views=n_views,
                                                      show=False,
                                                      center_azim=0,
                                                      center_elev=0,
                                                      std=1,
                                                      return_views=True,
                                                      lighting=True,
                                                      background=background)

    # compute CLIP loss
    loss = clip_loss(rendered_images, prompt, clip_transform, clip_model, tokenizer, device, augment_transform, n_augs)
    # NOTE(review): predictions are recomputed from mlp(vertices) every iteration, so retain_graph=True
    # may be unnecessary — confirm color_mesh / the renderer do not reuse graph tensors across steps.
    loss.backward(retain_graph=True)
    optim.step()

    # Store the scalar loss value for the running average below.
    with torch.no_grad():
        losses.append(loss.item())

    # report
    # Every 100 iterations: print progress, save the current renders, and append a log line.
    # The value labelled "CLIP score" is the mean of the stored loss values (negative cosine
    # similarity) over at most the last 100 iterations.
    if i % 100 == 0:
        print(f"Iter {i} | Last 100 CLIP score: {np.mean(losses[-100:])} | Current loss: {loss.item()}")
        save_renders(export_path, i, rendered_images)
        with open(os.path.join(export_path, "training_info.txt"), "a") as f:
            f.write(f"Iter {i} | Prompt: {prompt}, Last 100 avg CLIP score: {np.mean(losses[-100:]):.4f}, Current loss: {loss.item():.4f}\n")

# ==== Save Final Results ====
# Exports the hard-labelled mesh as <semantic_class>.ply plus final_render.jpg into the run folder.
save_final_results(export_path, semantic_class, mesh, mlp, vertices, colors, render, background)

# Elapsed wall-clock time since start_time; this includes the final export above.
total_time = time.time() - start_time
minutes, seconds = divmod(total_time, 60)

# ==== Save Run Summary ====
# Records the prompt and the settings of this run next to its outputs.
# "Final CLIP score" is the mean of the last (up to) 100 stored loss values;
# "Final loss" is the loss of the last training iteration.
with open(os.path.join(export_path, "summary.txt"), "w") as f:
    f.write(f"Prompt: {prompt}\n")
    f.write(f"CLIP model: {clip_model_name}\n")
    f.write(f"Learning rate: {learning_rate}\n")
    f.write(f"Number of iterations: {n_iter}\n")
    f.write(f"Number of views: {n_views}\n")
    f.write(f"Number of augmentations: {n_augs}\n")
    f.write(f"Final CLIP score: {np.mean(losses[-100:]):.4f}\n")
    f.write(f"Final loss: {loss.item():.4f}\n")
    f.write(f"Total time: {int(minutes)}m {int(seconds)}s\n")