In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import json
import torch
import pandas as pd
from PIL import Image
from tqdm.auto import tqdm
from spice import SPICE
from point_e.util.point_cloud import PointCloud
from point_e.models.download import load_checkpoint
from point_e.util.plotting import render_point_cloud
from point_e.diffusion.sampler import PointCloudSampler
from point_e.models.configs import MODEL_CONFIGS, model_from_config
from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config

In [2]:
# Hyper-parameters forwarded to SPICE.load_from_checkpoint; copy_prompt is also the
# text prompt used for the reconstruction ("Copy") pass when building the report.
batch_size = 6
num_points = 1024
lr = 7e-5 * 0.4
cond_drop_prob = 0.5
copy_prob = 0.1
copy_prompt = "COPY"

# Absolute input locations. The run name encodes the cond_drop / copy / copy_prompt
# settings that the values above mirror.
shapenet_uid_to_partnet_uid_path = "/scratch/noam/partnet/chair.json"
shapetalk_csv = "/scratch/noam/control_point_e/datasets/chair/val.csv"
run_name = (
    "08_28_2024_21_57_22_train_chair_val_chair_prompt_key_utterance"
    "_cond_drop_0_5_copy_0_1_copy_prompt_COPY"
)
checkpoint_path = (
    "/scratch/noam/control_point_e/executions/"
    f"{run_name}/checkpoints/epoch=74-step=66375.ckpt"
)

In [3]:
# Load the ShapeNet->PartNet uid mapping and the ShapeTalk split. Keep only rows whose
# source_uid has a PartNet entry (it is looked up when loading shapes for the report)
# and whose llama3_utterance is not "Unknown".
with open(shapenet_uid_to_partnet_uid_path, "r") as f:
    shapenet_uid_to_partnet_uid = json.load(f)
df = pd.read_csv(shapetalk_csv)
has_partnet = df.source_uid.isin(set(shapenet_uid_to_partnet_uid))
df = df[has_partnet & (df.llama3_utterance != "Unknown")]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# map_location makes the CPU fallback above usable: without it torch.load restores
# tensors onto the device they were saved from, which fails when CUDA is unavailable.
model = SPICE.load_from_checkpoint(
    lr=lr,
    dev=device,
    copy_prob=copy_prob,
    val_dataloader=None,
    num_points=num_points,
    test_dataloader=None,
    batch_size=batch_size,
    copy_prompt=copy_prompt,
    map_location=device,
    cond_drop_prob=cond_drop_prob,
    checkpoint_path=checkpoint_path,
)
# Inference only: switch training-mode layers (e.g. dropout) off before sampling.
# Trailing ';' keeps the notebook from printing the full module repr.
model.eval();

In [4]:
# Utterance -> part name that the edit targets. The part is later used to mask
# (injection_indices) and remove that region of the reconstruction.
# Grouped by part for readability; the comprehension preserves the listed order,
# which also fixes the iteration order of the report loop.
llama3_utterance_to_part = {
    utterance: part_name
    for part_name, utterances in (
        ("leg", (
            "a chair with long legs",
            "a chair with thin legs",
            "a chair with thick legs",
        )),
        ("seat", (
            "a chair with a thin seat",
            "a chair with a thick seat",
        )),
        ("arm", (
            "a chair with thin armrests",
            "a chair with thick armrests",
        )),
        ("back", (
            "a chair with a long backrest",
            "a chair with a thin backrest",
            "a chair with a thick backrest",
            "a chair with a short backrest",
        )),
    )
    for utterance in utterances
}

In [5]:
from html import escape

# Build an HTML report. For up to `samples_per_utterance` random rows per utterance it shows:
# - the source (condition) shape;
# - the COPY-prompt reconstruction;
# - plain SPICE editing;
# - SPICE with `injection_indices` taken from the edited part's mask ("inpainting");
# - the reconstruction with that part removed.
images_dir = "inpainting/images"
report_path = "inpainting/index.html"  # <img src> paths below are relative to this file
seeds_root = "/scratch/noam/seeds"  # one injection_dir per row is created under here
samples_per_utterance = 5
num_condition_points = 4096  # points kept from the loaded PartNet shape
num_guidance_points = 1024  # points encoded into the sampler guidance
# Per-row image suffixes, in the order they appear as table columns.
image_suffixes = ("condition", "copy", "spice", "output", "masked")
os.makedirs(images_dir, exist_ok=True)


def save_render(pc, idx, suffix):
    """Render `pc` and write it to ``{images_dir}/{idx}_{suffix}.png``."""
    Image.fromarray(render_point_cloud(pc)).save(f"{images_dir}/{idx}_{suffix}.png")


def sample_point_cloud(guidance, injection_dir, text, **sample_kwargs):
    """Draw one sample from ``model.sampler`` for prompt `text` and return it as a point cloud.

    Extra keyword arguments (e.g. ``injection_indices``) are forwarded to ``sample_batch``.
    """
    samples = model.sampler.sample_batch(
        batch_size=1,
        guidances=[guidance, None],
        injection_dir=injection_dir,
        model_kwargs={"texts": [text]},
        **sample_kwargs,
    )
    return model.sampler.output_to_point_clouds(samples)[0]


html = "<table style='font-size:36px;'>\n"
html += "<tr><th>ID</th><th>Prompt</th><th>Condition</th><th>Copy</th><th>SPICE</th><th>SPICE+Inpainting</th><th>Masked</th></tr>\n"
for llama3_utterance, part in tqdm(llama3_utterance_to_part.items(), total=len(llama3_utterance_to_part)):
    curr_df = df[df.llama3_utterance == llama3_utterance]
    curr_df = curr_df.sample(min(samples_per_utterance, len(curr_df)))
    for idx, row in tqdm(curr_df.iterrows(), desc=llama3_utterance, total=len(curr_df)):
        partnet_uid = shapenet_uid_to_partnet_uid[row.source_uid]
        condition_pc = PointCloud.load_partnet(partnet_uid, row.source_uid).random_sample(num_condition_points)
        save_render(condition_pc, idx, "condition")
        guidance = condition_pc.random_sample(num_guidance_points).encode().unsqueeze(0).to(device)
        # All three sampler calls for this row share the same injection_dir.
        # NOTE(review): presumably it stores/reuses per-row injected state across the calls — confirm.
        injection_dir = os.path.join(seeds_root, str(idx))
        os.makedirs(injection_dir, exist_ok=True)
        # Reconstruction pass. add_labels(condition_pc) presumably transfers the source's part
        # labels, which .mask(part) / .remove(part) below rely on — confirm.
        copy_pc = sample_point_cloud(guidance, injection_dir, copy_prompt).add_labels(condition_pc)
        save_render(copy_pc, idx, "copy")
        spice_pc = sample_point_cloud(guidance, injection_dir, row.utterance)
        save_render(spice_pc, idx, "spice")
        # NOTE(review): the mask indices come from a random_sample of copy_pc. They only line up with
        # the sampler's point order if that call returns copy_pc unchanged (i.e. it already holds
        # num_points points) — confirm.
        output_pc = sample_point_cloud(
            guidance,
            injection_dir,
            row.utterance,
            injection_indices=copy_pc.random_sample(num_points).mask(part),
        )
        save_render(output_pc, idx, "output")
        masked_pc = copy_pc.remove(part)
        save_render(masked_pc, idx, "masked")
        # Utterances are free text from the CSV; escape them so '<', '&' or quotes cannot break the table.
        cells = [str(idx), escape(str(row.utterance))]
        cells += [f"<img src='images/{idx}_{suffix}.png'>" for suffix in image_suffixes]
        html += "<tr>" + "".join(f"<td>{cell}</td>" for cell in cells) + "</tr>\n"
html += "</table>"
with open(report_path, "w", encoding="utf-8") as f:
    f.write(html)


  0%|          | 0/11 [00:00<?, ?it/s]

a chair with long legs:   0%|          | 0/5 [00:00<?, ?it/s]