In [None]:
# Environment setup: install the third-party packages that later cells import.
# CLIP is installed straight from the repository's default branch, so it is not version-pinned.
!pip install git+https://github.com/openai/CLIP.git
# Kaolin is pinned to 0.17.0 and resolved from the wheel index named for torch 2.5.1 / CUDA 12.1.
!pip install kaolin==0.17.0 -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.5.1_cu121.html

# Downgrade numpy to a compatible version
# NOTE(review): --force-reinstall replaces the numpy that is already installed; presumably the
# runtime has to be restarted if numpy was imported before this cell ran — confirm.
!pip install numpy==1.23.5 --force-reinstall
# nltk is installed unpinned; presumably needed by utilities.prompt_enricher — confirm.
!pip install nltk

In [None]:
# Fetch the project sources. The local modules imported later (Normalization, mesh, render,
# utils, utilities.*) are presumably provided by this repository — confirm.
!git clone https://github.com/ezmi234/Affordance_Highlighting_Project_2024.git

In [None]:
# Work from the repository root; relative paths used later (e.g. 'data/Auto.obj') are
# presumably resolved against this directory — confirm.
%cd Affordance_Highlighting_Project_2024
# Switch to the branch named extensions-experiments.
!git checkout extensions-experiments

In [None]:
import torch

# Print the runtime details up front so a mismatched torch / CUDA build is noticed early.
print(f"PyTorch version: {torch.__version__}, CUDA version: {torch.version.cuda}, GPU available: {torch.cuda.is_available()}")

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
import clip
import copy
import json
import kaolin as kal
import kaolin.ops.mesh
import numpy as np
import os
import random
import torch.nn as nn
import torchvision
from PIL import Image
from datetime import datetime
from google.colab import drive
drive.mount('/content/drive')

from Normalization import MeshNormalizer
from mesh import Mesh
from pathlib import Path
from render import Renderer
from tqdm import tqdm
from torchvision import transforms
from utils import color_mesh
import time
from utilities.prompt_enricher import generate_prompts
from utilities.positional_encoding import PositionalEncoding, FourierFeatureTransform, LocalPositionalEncoding, NGPHashEncoding

In [None]:
class NeuralHighlighter(nn.Module):
    """MLP that maps each input point to a probability distribution over classes.

    The network is: optional input encoding -> Linear -> ReLU -> LayerNorm,
    followed by ``num_layers`` hidden (Linear -> ReLU -> LayerNorm) blocks and a
    final Linear -> Softmax(dim=1). Because the softmax runs over dim 1, the
    input is expected to be 2-D, i.e. ``(num_points, input_dim)``.
    """

    def __init__(self, input_dim=3,
                 hidden_dim=256,
                 output_dim=2,
                 num_layers=6,
                 positional_encoding=False,
                 sigma=5.0,
                 encoding_type='none',
                 num_frequencies=10,
                 grid_resolution=16,
                 hash_levels=16,
                 hash_features_per_level=2):
        """
        Args:
            input_dim: size of each input point, usually 3 (x, y, z)
            hidden_dim: width of every hidden layer
            output_dim: number of output classes, 2 for [highlight, gray]
            num_layers: number of hidden Linear -> ReLU -> LayerNorm blocks placed
                after the input projection (the network therefore contains
                ``num_layers + 2`` Linear layers in total)
            positional_encoding: not read by this class; kept so existing callers
                that pass it keep working. Use ``encoding_type`` instead.
            sigma: passed to FourierFeatureTransform ('fourier' only)
            encoding_type: one of 'none', 'positional', 'fourier', 'local', 'hash'
            num_frequencies: passed to PositionalEncoding / LocalPositionalEncoding
            grid_resolution: passed to LocalPositionalEncoding ('local' only)
            hash_levels: number of levels passed to NGPHashEncoding ('hash' only)
            hash_features_per_level: features per level for NGPHashEncoding

        Raises:
            ValueError: if ``encoding_type`` is not one of the supported names.
        """
        super(NeuralHighlighter, self).__init__()

        layers = []

        # Input stage: an optional encoding module followed by the projection to hidden_dim.
        # The in_features of each projection are sized for the matching encoder's output.
        if encoding_type == 'none':
            layers.append(nn.Linear(input_dim, hidden_dim))

        elif encoding_type == 'positional':
            layers.append(PositionalEncoding(input_dim, num_frequencies))
            layers.append(nn.Linear(input_dim + 2 * input_dim * num_frequencies, hidden_dim))

        elif encoding_type == 'fourier':
            layers.append(FourierFeatureTransform(input_dim, hidden_dim, sigma))
            layers.append(nn.Linear(hidden_dim * 2 + input_dim, hidden_dim))

        elif encoding_type == 'local':
            layers.append(LocalPositionalEncoding(grid_resolution, num_frequencies))
            # NOTE(review): the hard-coded 3 is presumably the spatial dimension and ignores
            # input_dim — confirm against LocalPositionalEncoding's output size.
            layers.append(nn.Linear(2 * num_frequencies * 3, hidden_dim))

        elif encoding_type == 'hash':
            layers.append(NGPHashEncoding(
                input_dim=input_dim,
                n_levels=hash_levels,
                n_features_per_level=hash_features_per_level,
                log2_hashmap_size=19
            ))
            total_features = hash_levels * hash_features_per_level
            layers.append(nn.Linear(total_features, hidden_dim))

        else:
            # Previously an unknown name silently produced a network without any input
            # projection, which only failed later with an opaque shape error.
            raise ValueError(
                f"Unknown encoding_type {encoding_type!r}; expected one of "
                "'none', 'positional', 'fourier', 'local', 'hash'"
            )

        layers.append(nn.ReLU())
        layers.append(nn.LayerNorm([hidden_dim]))

        # Hidden blocks
        for _ in range(num_layers):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.LayerNorm([hidden_dim]))

        # Output layer: class probabilities per input row
        layers.append(nn.Linear(hidden_dim, output_dim))
        layers.append(nn.Softmax(dim=1))

        # Kept as a ModuleList named `mlp` so parameter names (mlp.<index>.*) stay unchanged.
        self.mlp = nn.ModuleList(layers)

    def forward(self, x):
        """Apply every layer in order and return the softmax output, one row per input row."""
        for layer in self.mlp:
            x = layer(x)
        return x

In [None]:
def clip_loss(rendered_images, text_prompts, clip_transform, clip_model, tokenizer, device, aug_transform=None, n_augs=0, weights=None):
    """Negative weighted CLIP similarity between rendered views and text prompts.

    For every prompt the cosine similarity to each image embedding is averaged over
    the images; the per-prompt averages are combined with ``weights`` (normalized to
    sum to 1). With ``n_augs > 0`` this score is averaged over ``n_augs``
    independently augmented copies of the renders.

    Args:
        rendered_images: batch of images passed to ``clip_transform`` / ``aug_transform``.
        text_prompts: list of prompt strings (a single string is treated as one prompt).
        clip_transform: transform applied to the renders when ``n_augs == 0``.
        clip_model: object exposing ``encode_text`` and ``encode_image``.
        tokenizer: callable turning the prompts into a token tensor.
        device: device the tokens and weights are moved to.
        aug_transform: transform applied once per augmentation when ``n_augs > 0``.
        n_augs: number of augmented passes; 0 disables augmentation.
        weights: optional per-prompt weights; uniform when None.

    Returns:
        Scalar tensor: minus the (augmentation-averaged) weighted similarity.

    Raises:
        ValueError: on empty prompts, a weights/prompts length mismatch, negative
            ``n_augs``, or ``n_augs > 0`` without an ``aug_transform``.
    """
    if isinstance(text_prompts, str):
        # len() of a bare string would count characters, not prompts.
        text_prompts = [text_prompts]

    num_prompts = len(text_prompts)
    if num_prompts == 0:
        raise ValueError("text_prompts must contain at least one prompt")
    if n_augs < 0:
        raise ValueError(f"n_augs must be >= 0, got {n_augs}")
    if n_augs > 0 and aug_transform is None:
        raise ValueError("aug_transform is required when n_augs > 0")

    if weights is None:
        weights = [1.0 / num_prompts] * num_prompts
    else:
        if len(weights) != num_prompts:
            raise ValueError(
                f"Expected {num_prompts} weights (one per prompt), got {len(weights)}"
            )
        total = sum(weights)
        weights = [w / total for w in weights]
    weights = torch.tensor(weights, dtype=torch.float32, device=device)

    # Encode text once; no gradient flows into the text branch.
    text_tokens = tokenizer(text_prompts).to(device)
    with torch.no_grad():
        text_features = clip_model.encode_text(text_tokens).float()
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)  # L2 norm

    def _weighted_similarity(image_features):
        # (n_images, 1, d) against (1, n_prompts, d) -> (n_images, n_prompts).
        # cosine_similarity is scale-invariant, so the image features need no explicit L2 norm.
        similarities = torch.cosine_similarity(
            image_features.unsqueeze(1), text_features.unsqueeze(0), dim=-1
        )
        return (similarities.mean(dim=0) * weights).sum()

    if n_augs == 0:
        image_features = clip_model.encode_image(clip_transform(rendered_images)).float()
        score = _weighted_similarity(image_features)
    else:
        # Accumulate inside the loop so that every augmentation contributes. Previously the
        # update sat after the loop, so only the last augmentation (scaled by 1/n_augs) was used.
        score = 0.0
        for _ in range(n_augs):
            # .float() matches the text features' dtype, as in the non-augmented branch.
            image_features = clip_model.encode_image(aug_transform(rendered_images)).float()
            score = score + _weighted_similarity(image_features)
        score = score / n_augs

    return -score

In [None]:
# ================== HELPER FUNCTIONS =============================
def save_final_results(log_dir, name, mesh, mlp, vertices, colors, render, background):
    """Export the final hard segmentation as a colored .ply plus a render sheet.

    Puts ``mlp`` in eval mode (it is left in eval mode afterwards), takes the arg-max
    class per vertex, colors ``mesh`` with the fixed highlight/gray palette, renders
    5 views and writes ``<log_dir>/<name>.ply`` and ``<log_dir>/final_render.jpg``.

    Args:
        log_dir: directory that receives both output files.
        name: base file name (without extension) of the exported mesh.
        mesh: mesh object; it is recolored in place via ``color_mesh`` and exported.
        mlp: model mapping ``vertices`` to per-vertex class probabilities.
        vertices: model input, one row per vertex.
        colors: accepted for interface compatibility but not read; the fixed
            highlight (204, 255, 0) / gray (180, 180, 180) palette is always used.
        render: renderer exposing ``render_views``.
        background: background passed through to ``render.render_views``.
    """
    mlp.eval()
    with torch.no_grad():
        probs = mlp(vertices)
        # Build every helper tensor on the device of the model output instead of relying
        # on a notebook-global `device` variable.
        target_device = probs.device
        max_idx = torch.argmax(probs, 1, keepdim=True)

        # for renders: one-hot version of the arg-max so each vertex gets exactly one color
        one_hot = torch.zeros(probs.shape, device=target_device).scatter_(1, max_idx, 1)

        highlight = torch.tensor([204, 255, 0], device=target_device)
        gray = torch.tensor([180, 180, 180], device=target_device)
        palette = torch.stack((highlight / 255, gray / 255))
        color_mesh(one_hot, mesh, palette)
        rendered_images, _, _ = render.render_views(mesh, num_views=5,
                                                    show=False,
                                                    center_azim=0,
                                                    center_elev=0,
                                                    std=1,
                                                    return_views=True,
                                                    lighting=True,
                                                    background=background)

        # for mesh: integer RGB per vertex; class 0 is the highlight, everything else gray.
        # max_idx has shape (N, 1) and broadcasts against the 3-element colors to (N, 3).
        final_color = torch.where(max_idx == 0, highlight, gray)
        mesh.export(os.path.join(log_dir, f"{name}.ply"), extension="ply", color=final_color)
        save_renders(log_dir, 0, rendered_images, name='final_render.jpg')

def save_renders(dir, i, rendered_images, name=None):
    """Write a batch of rendered images to a single image file inside ``dir``.

    When ``name`` is given it is used as the file name directly under ``dir``;
    otherwise the file is ``renders/iter_<i>.jpg`` under ``dir``.
    """
    relative_path = 'renders/iter_{}.jpg'.format(i) if name is None else name
    torchvision.utils.save_image(rendered_images, os.path.join(dir, relative_path))

def get_clip_model(clipmodel='ViT-L/14', jit=False):
    """Load a CLIP model by name and return ``(model, preprocess)``.

    The model is placed on "cuda" when CUDA is available and on "cpu" otherwise;
    the chosen configuration is printed once loading has finished.
    """
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    loaded = clip.load(clipmodel, device=device, jit=jit)
    model, preprocess = loaded
    print(f"Loaded CLIP model: {clipmodel} on {device} (jit={jit})")
    return model, preprocess

In [None]:
# ==== Hyperparameters and Settings ====
# These module-level names are read by the experiment loop further down in this cell.
render_res = 224  # side length used for the Renderer and for the image transforms below
learning_rate = 0.00005  # passed to the Adam optimizer
n_iter = 2200  # optimization steps per encoding run
obj_path = 'data/Auto.obj'
n_augs = 3  # augmented passes per step in clip_loss; 0 makes clip_loss use clip_transform instead
output_dir = './output/'  # not referenced by the code visible in this notebook; runs are written under base_path
clip_model_name = 'ViT-B/32'
prompt = 'A gray car with highlighted wheels'  # used only when synonyms_exts is False
# The next three values are handed to generate_prompts when synonyms_exts is True.
affordance = 'wheels'
semantic_class = 'car'
prompt_template = 'A gray {} with highlighted {}'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ==== Normalization and Augmentation ====
# Per-channel mean / std normalization shared by both pipelines.
# NOTE(review): presumably the statistics CLIP's own preprocessing uses — confirm.
clip_normalizer = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))

# Deterministic pipeline: resize to render_res x render_res, then normalize.
clip_transform = transforms.Compose([
    transforms.Resize((render_res, render_res)),
    clip_normalizer
])

# Random pipeline, applied once per augmentation.
# NOTE(review): scale=(1, 1) pins the sampled crop area fraction to 1; the aspect ratio is
# presumably still sampled from torchvision's default range — confirm.
# NOTE(review): fill=1 presumably pads with white for images in [0, 1] — confirm.
augment_transform = transforms.Compose([
    transforms.RandomResizedCrop(render_res, scale=(1, 1)),
    transforms.RandomPerspective(fill=1, p=0.8, distortion_scale=0.5),
    clip_normalizer
])

# ==== Load Mesh ====
# objbase (file name without extension) later names the exported .ply file.
objbase, extension = os.path.splitext(os.path.basename(obj_path))
render = Renderer(dim=(render_res, render_res))
mesh = Mesh(obj_path)
# MeshNormalizer is constructed on the mesh and then called; presumably normalizes it in place — confirm.
MeshNormalizer(mesh)()

# ==== Colors and Other Constants ====
# Row 0 is the highlight color (204, 255, 0) / 255, row 1 is gray (180, 180, 180) / 255.
colors = torch.tensor([[204/255, 1., 0.], [180/255, 180/255, 180/255]]).to(device)
# Background (1, 1, 1) passed to the renderer.
background = torch.tensor((1., 1., 1.)).to(device)

# ==== Background Image ====
# Background image loading is disabled while background_img_path is None;
# img_bg and bg_tensor then stay None.
# background_img_path = 'images/background.jpg'
background_img_path = None
img_bg = None
bg_tensor = None

def load_image(image_path, size=None):
    """Open ``image_path`` as an RGB image, optionally resized.

    When ``size`` is given the image is resized to it with bilinear resampling;
    otherwise the converted image is returned as-is.
    """
    rgb_image = Image.open(image_path).convert("RGB")
    if size is None:
        return rgb_image
    return rgb_image.resize(size, Image.BILINEAR)

# Optional background image: loaded at render resolution and stored as a batched tensor.
# Skipped while background_img_path is None.
if background_img_path is not None:
    img_bg = load_image(background_img_path, size=(render_res, render_res))
    bg_tensor = transforms.ToTensor()(img_bg).unsqueeze(0).to(device)

# ==== Output Root and Experiment Switches ====

# Every run creates its own timestamped sub-directory under this folder.
base_path = '/content/drive/MyDrive/extensions'

# The mesh is already loaded and normalized in the "Load Mesh" step earlier in this cell;
# the identical second Mesh(obj_path) / MeshNormalizer(mesh)() pair that used to sit here
# only repeated that work and has been removed.

# True: build several weighted prompts with generate_prompts; False: use `prompt` as-is.
synonyms_exts = True

# Run one full optimization per selected positional encoding. Each run gets its own seed
# reset, model, optimizer and timestamped output directory under base_path.
for pe in [
    'none',
    # 'positional',
    # 'fourier', 'local',
    # 'hash'
    ]:
    print("Positional encoding:", pe)
    print("Synonyms:", synonyms_exts)

    # Constrain most sources of randomness
    # (some torch backward functions within CLIP are non-deterministic)

    # ==== Set Seed for Determinism ====
    seed = 420
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # ==== Neural Highlighter ====
    if pe == 'none':
        mlp = NeuralHighlighter(encoding_type='none')
    elif pe == 'positional':
        mlp = NeuralHighlighter(encoding_type='positional', num_frequencies=4)
    elif pe == 'fourier':
        mlp = NeuralHighlighter(encoding_type='fourier', sigma=5.0)
    elif pe == 'local':
        mlp = NeuralHighlighter(encoding_type='local', num_frequencies=8, grid_resolution=16)
    elif pe == 'hash':
        mlp = NeuralHighlighter(encoding_type='hash', hash_levels=24, hash_features_per_level=4)
    else:
        # Fail with a clear message instead of an AttributeError on `None.to(device)`.
        raise ValueError(f"Unsupported positional encoding: {pe!r}")

    mlp.to(device)
    optim = torch.optim.Adam(mlp.parameters(), learning_rate)

    # ==== CLIP ====
    clip_model, preprocess = get_clip_model(clip_model_name)
    tokenizer = clip.tokenize

    vertices = copy.deepcopy(mesh.vertices).to(device)
    if pe == 'hash':
        # Rescale the vertices per axis to [0, 1] for the hash encoding. This is loop-invariant
        # (re-applying it changes nothing beyond float rounding), so it is done once here
        # instead of on every training step.
        vertices_min = vertices.min(0)[0]
        vertices_max = vertices.max(0)[0]
        vertices = (vertices - vertices_min) / (vertices_max - vertices_min + 1e-8)
    n_views = 5
    losses = []

    # ==== Setup Output Directory ====
    timestamp = datetime.now().strftime("%Y-%m-%d--%H:%M:%S")

    export_path = os.path.join(base_path, f"run_PE_{pe}_SYNONYMS_{synonyms_exts}_{timestamp}")
    Path(os.path.join(export_path, 'renders')).mkdir(parents=True, exist_ok=True)

    # ==== Prompts ====
    if synonyms_exts:
        prompts, weights = generate_prompts(prompt_template, semantic_class, affordance, clip_model, tokenizer, device)
        print(prompts)
        print(weights)
    else:
        prompts = [prompt]
        weights = None

    # A single prompt needs no weighting; clip_loss falls back to uniform weights on None.
    if len(prompts) == 1:
        weights = None

    start_time = time.time()
    # ==== Training Loop ====
    for i in tqdm(range(n_iter)):
        optim.zero_grad()

        # predict highlight probabilities
        pred_class = mlp(vertices)

        # color and render mesh
        sampled_mesh = mesh
        color_mesh(pred_class, sampled_mesh, colors)
        rendered_images, elev, azim = render.render_views(sampled_mesh,
                                                          num_views=n_views,
                                                          show=False,
                                                          center_azim=0,
                                                          center_elev=0,
                                                          std=1,
                                                          return_views=True,
                                                          lighting=True,
                                                          background=background)

        # compute CLIP loss
        loss = clip_loss(rendered_images, prompts, clip_transform, clip_model, tokenizer, device, augment_transform, n_augs, weights)
        # NOTE(review): retain_graph=True keeps the graph buffers alive after backward. No second
        # backward pass is visible in this cell, so it is presumably unnecessary — confirm that
        # color_mesh / the renderer hold no graph across iterations before dropping it.
        loss.backward(retain_graph=True)
        optim.step()

        current_loss = loss.item()
        losses.append(current_loss)

        # report
        if i % 100 == 0:
            recent_mean = np.mean(losses[-100:])
            print(f"Iter {i} | Last 100 CLIP score: {recent_mean} | Current loss: {current_loss}")
            save_renders(export_path, i, rendered_images)
            with open(os.path.join(export_path, "training_info.txt"), "a") as f:
                f.write(f"Iter {i} | Prompt: {prompts[0]}, Last 100 avg CLIP score: {recent_mean:.4f}, Current loss: {current_loss:.4f}\n")

    # ==== Save Final Results ====
    save_final_results(export_path, objbase, mesh, mlp, vertices, colors, render, background)

    total_time = time.time() - start_time
    minutes, seconds = divmod(total_time, 60)

    # With n_iter == 0 there is no loss yet; report NaN instead of failing on an undefined name.
    final_loss = losses[-1] if losses else float('nan')

    # ==== Save Prompt ====
    with open(os.path.join(export_path, "summary.txt"), "w") as f:
        f.write("Prompts:\n")
        # A dedicated loop name keeps the module-level `prompt` setting from being overwritten.
        for prompt_text in prompts:
            f.write(f"{prompt_text}\n")
        f.write(f"CLIP model: {clip_model_name}\n")
        f.write(f"Learning rate: {learning_rate}\n")
        f.write(f"Number of iterations: {n_iter}\n")
        f.write(f"Number of views: {n_views}\n")
        f.write(f"Number of augmentations: {n_augs}\n")
        f.write(f"Final CLIP score: {np.mean(losses[-100:]):.4f}\n")
        f.write(f"Final loss: {final_loss:.4f}\n")
        f.write(f"Total time: {int(minutes)}m {int(seconds)}s\n")