In [1]:
# Clone the project repository. Later cells change into this directory and
# import local modules (Normalization, mesh, render, utils) — presumably
# provided by this checkout; confirm against the repository contents.
!git clone https://github.com/ezmi234/Affordance_Highlighting_Project_2024.git

Cloning into 'Affordance_Highlighting_Project_2024'...
remote: Enumerating objects: 65, done.[K
remote: Counting objects: 100% (65/65), done.[K
remote: Compressing objects: 100% (51/51), done.[K
remote: Total 65 (delta 25), reused 52 (delta 12), pack-reused 0 (from 0)[K
Receiving objects: 100% (65/65), 1.84 MiB | 20.67 MiB/s, done.
Resolving deltas: 100% (25/25), done.


In [2]:
# Work from the repository root so that relative paths used later
# (e.g. 'data/dog.obj') and the local-module imports resolve.
%cd Affordance_Highlighting_Project_2024

/content/Affordance_Highlighting_Project_2024


In [1]:
!pip install git+https://github.com/openai/CLIP.git
!pip install kaolin==0.17.0 -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.5.1_cu121.html

# Downgrade numpy to a compatible version
!pip install numpy==1.23.5 --force-reinstall

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-mnq412s1
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-mnq412s1
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting ftfy (from clip==1.0)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->clip==1.0)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->clip==1.0)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->clip==1.0)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting 

In [35]:
import torch

# Report the runtime configuration
torch_version = torch.__version__
cuda_version = torch.version.cuda
gpu_available = torch.cuda.is_available()
print(f"PyTorch version: {torch_version}, CUDA version: {cuda_version}, GPU available: {gpu_available}")

# Use the GPU when one is visible, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

PyTorch version: 2.6.0+cu124, CUDA version: 12.4, GPU available: True


In [36]:
import clip
import copy
import json
import kaolin as kal
import kaolin.ops.mesh
import numpy as np
import os
import random
import torch.nn as nn
import torchvision
from datetime import datetime
from google.colab import drive
# Mount Google Drive at /content/drive (Colab only); the training cell writes
# its run outputs under /content/drive/MyDrive.
drive.mount('/content/drive')

from itertools import permutations, product
from Normalization import MeshNormalizer
from mesh import Mesh
from pathlib import Path
from render import Renderer
from tqdm import tqdm
from torch.autograd import grad
from torchvision import transforms
from utils import color_mesh
import time

In [37]:
class NeuralHighlighter(nn.Module):
    """MLP that maps point coordinates to a probability distribution over classes.

    The network is a stack of ``Linear -> ReLU -> LayerNorm`` blocks followed by
    a final ``Linear`` and a softmax over the class (last) dimension.
    """

    def __init__(self, input_dim=3, hidden_dim=256, output_dim=2, num_layers=6):
        """
        Args:
            input_dim: size of each input point, usually 3 (x, y, z)
            hidden_dim: width of the hidden layers
            output_dim: number of classes; 2 for [highlight, gray]
            num_layers: total number of linear layers. Values below 2 still
                build the input and output linear layers (no hidden blocks).
        """
        super().__init__()

        # Input block. Modules are created in the same order as before so that
        # seeded initialisation and state_dict keys are unchanged.
        layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.LayerNorm(hidden_dim)]

        # Hidden blocks: the input and output linear layers account for two of
        # the `num_layers` linear layers.
        for _ in range(num_layers - 2):
            layers.extend([
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.LayerNorm(hidden_dim),
            ])

        layers.append(nn.Linear(hidden_dim, output_dim))
        # Softmax over the last (class) dimension: identical to dim=1 for the
        # usual (N, input_dim) input, and also correct for batched
        # (..., N, input_dim) input, where dim=1 would normalise the wrong axis.
        layers.append(nn.Softmax(dim=-1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return class probabilities of shape ``x.shape[:-1] + (output_dim,)``."""
        return self.model(x)


def get_clip_model(clipmodel='ViT-L/14', jit=False):
    """Load a CLIP model onto the GPU when available, otherwise onto the CPU.

    Args:
        clipmodel: model name forwarded to ``clip.load``
        jit: forwarded to ``clip.load``

    Returns:
        The ``(model, preprocess)`` pair returned by ``clip.load``.
    """
    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"

    loaded = clip.load(clipmodel, device=target_device, jit=jit)
    model, preprocess = loaded
    print(f"Loaded CLIP model: {clipmodel} on {target_device} (jit={jit})")
    return model, preprocess


# ================== HELPER FUNCTIONS =============================
def save_final_results(log_dir, name, mesh, mlp, vertices, colors, render, background):
    """Export the final hard-labelled mesh and a multi-view render of it.

    Runs ``mlp`` on ``vertices`` without gradients, assigns each vertex to its
    arg-max class, recolours ``mesh`` with the highlight/gray palette, renders
    five views, and writes two files into ``log_dir``:
    ``<name>.ply`` (via ``mesh.export``) and ``final_render.jpg`` (via
    ``save_renders``).

    Args:
        log_dir: directory the outputs are written to
        name: base name of the exported ``.ply`` file
        mesh: mesh object; recoloured in place by ``color_mesh``
        mlp: trained highlighter network; left in eval mode on return
        vertices: vertex tensor fed to ``mlp``
        colors: accepted for interface compatibility; not read — the fixed
            highlight/gray palette below is always used
        render: renderer exposing ``render_views``
        background: background colour forwarded to ``render_views``
    """
    mlp.eval()
    with torch.no_grad():
        probs = mlp(vertices)
        # Allocate on the prediction's device rather than relying on a
        # module-level `device` global.
        target_device = probs.device
        max_idx = torch.argmax(probs, 1, keepdim=True)

        # for renders: hard one-hot class assignment per vertex
        one_hot = torch.zeros(probs.shape, device=target_device).scatter_(1, max_idx, 1)

        # Palette in 0-255 integer RGB; class 0 is the highlight colour.
        highlight = torch.tensor([204, 255, 0], device=target_device)
        gray = torch.tensor([180, 180, 180], device=target_device)
        palette = torch.stack((highlight / 255, gray / 255))
        color_mesh(one_hot, mesh, palette)
        rendered_images, _, _ = render.render_views(mesh, num_views=5,
                                                    show=False,
                                                    center_azim=0,
                                                    center_elev=0,
                                                    std=1,
                                                    return_views=True,
                                                    lighting=True,
                                                    background=background)

        # for mesh: per-vertex 0-255 colour, broadcast from the (N, 1) class index
        final_color = torch.where(max_idx == 0, highlight, gray)
        mesh.export(os.path.join(log_dir, f"{name}.ply"), extension="ply", color=final_color)
        save_renders(log_dir, 0, rendered_images, name='final_render.jpg')

def clip_loss(rendered_images, text_prompt, clip_transform, clip_model, tokenizer, device, aug_transform=None, n_augs=0):
    """Negative mean cosine similarity between rendered views and a text prompt.

    Args:
        rendered_images: batch of rendered views passed to the image transform
        text_prompt: prompt string to compare the renders against
        clip_transform: transform applied to the renders when ``n_augs == 0``
        clip_model: model exposing ``encode_text`` and ``encode_image``
        tokenizer: callable turning a list of strings into a token tensor
        device: device the token tensor is moved to
        aug_transform: transform applied to the renders once per augmentation
            pass; required when ``n_augs > 0``
        n_augs: number of augmentation passes; 0 disables augmentation

    Returns:
        Scalar tensor: the negated mean image/text cosine similarity, averaged
        over the augmentation passes when ``n_augs > 0``.

    Raises:
        ValueError: if ``n_augs`` is negative, or positive without an
            ``aug_transform``.
    """
    if n_augs < 0:
        raise ValueError(f"n_augs must be >= 0, got {n_augs}")
    if n_augs > 0 and aug_transform is None:
        raise ValueError("aug_transform is required when n_augs > 0")

    # Encode text (no gradient is needed through the text branch)
    text_tokens = tokenizer([text_prompt]).to(device)
    with torch.no_grad():
        text_features = clip_model.encode_text(text_tokens).float()
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)  # L2 norm

    def _neg_similarity(images):
        # Cast to float32 in both branches so image and text features share a
        # dtype. cosine_similarity normalises its inputs, so no explicit image
        # normalisation is needed.
        image_features = clip_model.encode_image(images).float()
        return -torch.mean(torch.cosine_similarity(image_features, text_features))

    if n_augs == 0:
        return _neg_similarity(clip_transform(rendered_images))

    loss = 0.0
    for _ in range(n_augs):
        loss = loss + _neg_similarity(aug_transform(rendered_images))

    return loss / n_augs


def save_renders(dir, i, rendered_images, name=None):
    """Write a batch of rendered images to disk with torchvision.

    Args:
        dir: output directory
        i: iteration index, used in the default file name
        rendered_images: image tensor passed to ``torchvision.utils.save_image``
        name: explicit file name relative to ``dir``; when omitted the image
            is written to ``renders/iter_<i>.jpg`` under ``dir``
    """
    filename = 'renders/iter_{}.jpg'.format(i) if name is None else name
    torchvision.utils.save_image(rendered_images, os.path.join(dir, filename))


In [38]:
# Constrain most sources of randomness
# (some torch backward functions within CLIP are non-deterministic)

# ==== Seed every RNG in use: torch (CPU + CUDA), random, numpy ====
seed = 42
for _seed_fn in (
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
    random.seed,
    np.random.seed,
):
    _seed_fn(seed)

# cuDNN: deterministic mode on, benchmark mode off
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [42]:
# ==== Run Configuration ====
# Input mesh and text prompt
obj_path = 'data/dog.obj'
prompt = 'A gray dog with highlighted belt'

# CLIP backbone and render resolution (pixels per side)
clip_model_name = 'ViT-B/16'
render_res = 224

# Optimisation schedule
learning_rate = 5e-5
n_iter = 2200
n_augs = 1

output_dir = './output/'

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [43]:
# ==== Load Mesh ====
# `objbase` is the mesh file name without its extension; it is later used as
# the base name of the exported .ply file.
objbase, extension = os.path.splitext(os.path.basename(obj_path))
# Square renderer at the configured resolution.
render = Renderer(dim=(render_res, render_res))
mesh = Mesh(obj_path)
# Applies the project's mesh normalisation — presumably in place, since the
# return value is discarded; see Normalization.py to confirm.
MeshNormalizer(mesh)()

# ==== CLIP ====
# NOTE(review): `preprocess` is not used in the cells visible here; the
# tensor transforms defined below are applied instead.
clip_model, preprocess = get_clip_model(clip_model_name)
tokenizer = clip.tokenize

# ==== Normalization and Augmentation ====
# Per-channel mean / std shared by both transform pipelines below.
clip_normalizer = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))

# Non-augmented pipeline: used by clip_loss when n_augs == 0.
clip_transform = transforms.Compose([
    transforms.Resize((render_res, render_res)),
    clip_normalizer
])

# Augmented pipeline: used by clip_loss once per augmentation pass when n_augs > 0.
# scale=(1, 1) pins the sampled crop-area fraction at 1; fill=1 presumably
# matches the white background defined below — TODO confirm.
augment_transform = transforms.Compose([
    transforms.RandomResizedCrop(render_res, scale=(1, 1)),
    transforms.RandomPerspective(fill=1, p=0.8, distortion_scale=0.5),
    clip_normalizer
])

# ==== Neural Highlighter ====
# Created after seeding, so the initial weights depend on the RNG state at
# this point; reordering model construction changes the initialisation.
mlp = NeuralHighlighter().to(device)
optim = torch.optim.Adam(mlp.parameters(), learning_rate)

# ==== Colors and Other Constants ====
# Row 0 is the highlight colour, row 1 is gray (RGB in [0, 1]).
colors = torch.tensor([[204/255, 1., 0.], [180/255, 180/255, 180/255]]).to(device)
# White background for the renders.
background = torch.tensor((1., 1., 1.)).to(device)
# Deep copy so the network input is a separate tensor from mesh.vertices.
vertices = copy.deepcopy(mesh.vertices).to(device)
# Number of views rendered per training iteration.
n_views = 7
# Per-iteration loss history, appended to by the training loop.
losses = []


Loaded CLIP model: ViT-B/16 on cuda (jit=False)


In [45]:
# ==== Setup Output Directory ====
# One timestamped folder per run on the mounted Drive, with a renders/ subfolder.
base_path = '/content/drive/MyDrive/affordance_outputs'
timestamp = datetime.now().strftime("%Y-%m-%d__%H:%M:%S")
export_path = os.path.join(base_path, f"run_{timestamp}")
os.makedirs(os.path.join(export_path, 'renders'), exist_ok=True)

training_info_path = os.path.join(export_path, "training_info.txt")
report_every = 100  # iterations between progress reports / saved renders

start_time = time.time()
# ==== Training Loop ====
for i in tqdm(range(n_iter)):
    optim.zero_grad()

    # forward pass: per-vertex class probabilities
    pred_class = mlp(vertices)

    # paint the mesh with the soft predictions and render it from several views
    sampled_mesh = mesh
    color_mesh(pred_class, sampled_mesh, colors)
    rendered_images, elev, azim = render.render_views(
        sampled_mesh,
        num_views=n_views,
        show=False,
        center_azim=0,
        center_elev=0,
        std=1,
        return_views=True,
        lighting=True,
        background=background,
    )

    # CLIP loss between the renders and the prompt, then one optimiser step
    loss = clip_loss(rendered_images, prompt, clip_transform, clip_model, tokenizer, device,
                     aug_transform=augment_transform, n_augs=n_augs)
    # NOTE(review): retain_graph=True keeps the autograd graph alive after
    # backward; the graph appears to be rebuilt every iteration, so this may be
    # unnecessary — confirm against color_mesh / render before removing.
    loss.backward(retain_graph=True)
    optim.step()

    current_loss = loss.item()
    losses.append(current_loss)

    # report only every `report_every` iterations
    if i % report_every != 0:
        continue

    recent_mean = np.mean(losses[-100:])
    print(f"Iter {i} | Last 100 CLIP score: {recent_mean} | Current loss: {current_loss}")
    save_renders(export_path, i, rendered_images)
    with open(training_info_path, "a") as f:
        f.write(f"Iter {i} | Prompt: {prompt}, Last 100 avg CLIP score: {recent_mean:.4f}, Current loss: {current_loss:.4f}\n")

# ==== Save Final Results ====
save_final_results(export_path, objbase, mesh, mlp, vertices, colors, render, background)

total_time = time.time() - start_time
minutes, seconds = divmod(total_time, 60)

# ==== Save Prompt ====
# Run summary: configuration, final metrics and wall-clock time.
summary_lines = [
    f"Prompt: {prompt}\n",
    f"CLIP model: {clip_model_name}\n",
    f"Learning rate: {learning_rate}\n",
    f"Number of iterations: {n_iter}\n",
    f"Number of views: {n_views}\n",
    f"Number of augmentations: {n_augs}\n",
    f"Final CLIP score: {np.mean(losses[-100:]):.4f}\n",
    f"Final loss: {loss.item():.4f}\n",
    f"Total time: {int(minutes)}m {int(seconds)}s\n",
]
with open(os.path.join(export_path, "summary.txt"), "w") as f:
    f.writelines(summary_lines)

  0%|          | 3/2200 [00:00<03:36, 10.17it/s]

Iter 0 | Last 100 CLIP score: -0.29602759808301926 | Current loss: -0.3008759617805481


  5%|▍         | 103/2200 [00:09<03:05, 11.31it/s]

Iter 100 | Last 100 CLIP score: -0.2942989248037338 | Current loss: -0.2898792326450348


  9%|▉         | 202/2200 [00:18<03:49,  8.71it/s]

Iter 200 | Last 100 CLIP score: -0.2987414388358593 | Current loss: -0.2868339419364929


 14%|█▍        | 303/2200 [00:29<02:55, 10.81it/s]

Iter 300 | Last 100 CLIP score: -0.292738111615181 | Current loss: -0.30887728929519653


 18%|█▊        | 402/2200 [00:39<02:50, 10.57it/s]

Iter 400 | Last 100 CLIP score: -0.2946107742190361 | Current loss: -0.27164527773857117


 23%|██▎       | 503/2200 [00:50<02:37, 10.76it/s]

Iter 500 | Last 100 CLIP score: -0.2925588798522949 | Current loss: -0.29289358854293823


 27%|██▋       | 603/2200 [00:58<02:22, 11.17it/s]

Iter 600 | Last 100 CLIP score: -0.29168457970023154 | Current loss: -0.2606664299964905


 32%|███▏      | 703/2200 [01:07<02:12, 11.34it/s]

Iter 700 | Last 100 CLIP score: -0.29567648366093635 | Current loss: -0.2775677442550659


 36%|███▋      | 803/2200 [01:16<02:00, 11.55it/s]

Iter 800 | Last 100 CLIP score: -0.29638672202825544 | Current loss: -0.3324246108531952


 41%|████      | 902/2200 [01:25<02:12,  9.80it/s]

Iter 900 | Last 100 CLIP score: -0.29633234307169914 | Current loss: -0.2892357409000397


 46%|████▌     | 1002/2200 [01:34<01:44, 11.48it/s]

Iter 1000 | Last 100 CLIP score: -0.29574556678533553 | Current loss: -0.32409659028053284


 50%|█████     | 1102/2200 [01:43<01:36, 11.37it/s]

Iter 1100 | Last 100 CLIP score: -0.294598953127861 | Current loss: -0.33130884170532227


 55%|█████▍    | 1202/2200 [01:51<01:26, 11.51it/s]

Iter 1200 | Last 100 CLIP score: -0.2954101786017418 | Current loss: -0.3035051226615906


 59%|█████▉    | 1302/2200 [02:00<01:14, 12.02it/s]

Iter 1300 | Last 100 CLIP score: -0.3000993075966835 | Current loss: -0.2960399389266968


 64%|██████▎   | 1402/2200 [02:08<01:07, 11.86it/s]

Iter 1400 | Last 100 CLIP score: -0.2952698065340519 | Current loss: -0.2838353216648102


 68%|██████▊   | 1502/2200 [02:17<01:09, 10.05it/s]

Iter 1500 | Last 100 CLIP score: -0.29110165491700174 | Current loss: -0.3049045205116272


 73%|███████▎  | 1603/2200 [02:26<00:52, 11.34it/s]

Iter 1600 | Last 100 CLIP score: -0.30063384369015694 | Current loss: -0.2950535714626312


 77%|███████▋  | 1703/2200 [02:35<00:40, 12.28it/s]

Iter 1700 | Last 100 CLIP score: -0.29370751708745957 | Current loss: -0.2687332034111023


 82%|████████▏ | 1803/2200 [02:43<00:32, 12.26it/s]

Iter 1800 | Last 100 CLIP score: -0.29359083116054535 | Current loss: -0.2601427733898163


 86%|████████▋ | 1903/2200 [02:51<00:24, 12.14it/s]

Iter 1900 | Last 100 CLIP score: -0.2968218138813972 | Current loss: -0.2814525365829468


 91%|█████████ | 2003/2200 [03:00<00:16, 11.67it/s]

Iter 2000 | Last 100 CLIP score: -0.29478189155459406 | Current loss: -0.2719579339027405


 96%|█████████▌| 2103/2200 [03:09<00:08, 11.74it/s]

Iter 2100 | Last 100 CLIP score: -0.2995603257417679 | Current loss: -0.2538548409938812


100%|██████████| 2200/2200 [03:17<00:00, 11.15it/s]
