In [2]:
import os
import sys

# Must be set before torch is imported so that only this GPU is visible to CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
# Make the local control_point_e checkout and the changeit3d tree inside it importable.
sys.path.insert(0, "/home/noamatia/repos/control_point_e")
sys.path.insert(0, "/home/noamatia/repos/control_point_e/changeit3d")

import argparse
import logging
import os.path as osp
from functools import partial
from html import escape

import tqdm
import torch
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader

from shapetalk import ShapeTalk
from control_point_e import ControlPointE
from point_e.util.point_cloud import PointCloud
from point_e.util.plotting import render_point_cloud
from point_e.diffusion.sampler import PointCloudSampler
from changeit3d.utils.basics import parallel_apply
from changeit3d.language.vocabulary import Vocabulary
from changeit3d.in_out.pointcloud import pc_loader_from_npz
from changeit3d.evaluation.all_metrics import run_all_metrics

Dask dataframe query planning is disabled because dask-expr is not installed.

You can install it with `pip install dask[dataframe]` or `conda install dask`.
This will raise in a future version.



Jitting Chamfer 3D
Loaded JIT 3D CUDA chamfer distance


In [3]:
# --- Evaluation configuration ------------------------------------------------
# Shape class to evaluate, size of the evaluation subset, and the number of
# points requested per ground-truth cloud when loading the .npz files.
obj, num_samples, n_sample_points = "chair", 100, 2048

# Filesystem roots: experiment artefacts and the ShapeTalk point clouds.
base_dir = "/scratch/noam/control_point_e"
top_pc_dir = "/scratch/noam/shapetalk/point_clouds/scaled_to_align_rendering"

# Training run under evaluation and the checkpoint file to load from it.
run_name = "08_06_2024_16_10_17_train_chair_chamfer_0_5_val_chair_prompt_key_utterance_cond_drop_0_5_copy_0_1_copy_prompt_COPY"
ckpt = "epoch=99-step=52000.ckpt"

In [4]:
# Build (or reload) the fixed evaluation subset so that every run and baseline
# is scored on exactly the same rows.
output_dir = os.path.join(base_dir, "eval", obj, f"{num_samples}_random_samples")
os.makedirs(output_dir, exist_ok=True)
dataset_dir = os.path.join(base_dir, "datasets", obj, "val.csv")  # path of the validation CSV (a file, despite the name)
samples_path = os.path.join(output_dir, "samples.csv")
sample_seed = 2022  # same seed value the loader and argparse defaults in this notebook use

if os.path.exists(samples_path):
    # A subset was already drawn: reuse it verbatim.
    df = pd.read_csv(samples_path)
else:
    # Seeded draw so the subset can be regenerated if the cache is deleted.
    # The index is reset so the in-memory frame matches what a later
    # `read_csv(samples_path)` returns (0..n-1); downstream cells use the row
    # index as the image number.
    df = (
        pd.read_csv(dataset_dir)
        .sample(num_samples, random_state=sample_seed)
        .reset_index(drop=True)
    )
    df.to_csv(samples_path, index=False)

In [4]:
# Dataset, loader and model used by every sampling cell below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Shared sizes, named once instead of being repeated as literals.
batch_size = 6
num_points = 1024

dataset = ShapeTalk(
    df=df,
    batch_size=batch_size,
    device=device,
    num_points=num_points,
    prompt_key="utterance",
)
# shuffle=False keeps loader order aligned with the row order of `df`.
data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)

# Checkpoint location derived from `base_dir` rather than a second hard-coded
# copy of the same root.
ckpt_path = os.path.join(base_dir, "executions", run_name, "checkpoints", ckpt)
# NOTE(review): the keyword values below presumably mirror the training
# configuration encoded in `run_name` (copy 0.1, cond_drop 0.5, prompt COPY) --
# confirm when switching runs.
model = ControlPointE.load_from_checkpoint(
    ckpt_path,
    lr=7e-5 * 0.4,
    dev=device,
    batch_size=batch_size,
    timesteps=1024,
    num_points=num_points,
    copy_prob=0.1,
    copy_prompt="COPY",
    cond_drop_prob=0.5,
)

Creating ShapeTalk dataset:   0%|          | 0/100 [00:00<?, ?it/s]

Creating ShapeTalk dataset: 100%|██████████| 100/100 [00:00<00:00, 518.78it/s]


In [5]:
# Replace the sampler attached to the loaded model. Every per-stage argument
# is a one-element list, i.e. a single sampling stage is configured.
sampler_config = dict(
    device=model.dev,
    models=[model.model],
    diffusions=[model.diffusion],
    num_points=[1024],
    aux_channels=["R", "G", "B"],
    model_kwargs_key_filter=["texts"],
    guidance_scale=[3.0],
    use_karras=[True],
    karras_steps=[64],
    sigma_min=[1e-3],
    sigma_max=[120],
    s_churn=[3],
)
model.sampler = PointCloudSampler(**sampler_config)

In [30]:
# Render the source and target reference clouds of every evaluation example to
# results/source/<n>.png and results/target/<n>.png.
#
# Rendering is skipped per image when the PNG already exists. The previous
# guard tested the sampled-output tensor file (`output_path`), which is
# unrelated to these renders and suppressed them once any sampling run had
# been saved.
output_path = os.path.join(output_dir, f"{run_name}.pt")  # kept: later cells read this global
results_source_dir = os.path.join(output_dir, "results", "source")
os.makedirs(results_source_dir, exist_ok=True)
results_target_dir = os.path.join(output_dir, "results", "target")
os.makedirs(results_target_dir, exist_ok=True)


def _render_missing(latents, out_dir, start_idx):
    """Render each row of ``latents`` to ``<out_dir>/<index>.png`` unless it exists.

    Indices start at ``start_idx``. Returns the next free index,
    ``start_idx + len(latents)``.
    """
    png_paths = [os.path.join(out_dir, f"{start_idx + k}.png") for k in range(len(latents))]
    if not all(os.path.exists(p) for p in png_paths):
        # assumes output_to_point_clouds yields one cloud per latent row -- TODO confirm
        clouds = model.sampler.output_to_point_clouds(latents)
        for pc, png_path in zip(clouds, png_paths):
            if not os.path.exists(png_path):
                render_point_cloud(pc, theta=np.pi, output_path=png_path)
    return start_idx + len(latents)


# i / j are the running image numbers for the source / target folders.
i, j = 0, 0
for batch in tqdm.tqdm(data_loader):
    i = _render_missing(batch["source_latents"], results_source_dir, i)
    j = _render_missing(batch["target_latents"], results_target_dir, j)

  0%|          | 0/17 [00:00<?, ?it/s]

In [5]:
# Sample an edited shape for every evaluation example (text prompt plus source
# shape guidance), render each result to results/<run_name>/outputs/<n>.png and
# cache the raw sampler outputs in <run_name>.pt. Skipped when the cache exists.
i = 0  # running image number across batches
output = None
output_path = os.path.join(output_dir, f"{run_name}.pt")
results_dir = os.path.join(output_dir, "results", run_name, "outputs")
os.makedirs(results_dir, exist_ok=True)
if not os.path.exists(output_path):
    # Collect per-batch outputs and concatenate once at the end; concatenating
    # inside the loop re-copies everything accumulated so far on each batch.
    output_chunks = []
    for batch in tqdm.tqdm(data_loader):
        prompts = batch["prompts"]
        source_latents = batch["source_latents"].to(device)
        curr_output = model.sampler.sample_batch(
            guidances=[source_latents, None],
            model_kwargs={"texts": prompts},
            batch_size=6,
        )
        for pc in model.sampler.output_to_point_clouds(curr_output):
            render_point_cloud(
                pc,
                theta=np.pi,
                output_path=os.path.join(results_dir, f"{i}.png"),
            )
            i += 1
        output_chunks.append(curr_output.detach().cpu())
    if not output_chunks:
        # Saving `None` here would create a cache file that blocks every later run.
        raise RuntimeError("data_loader produced no batches; nothing to save")
    output = torch.cat(output_chunks, dim=0)
    torch.save(output, output_path)

100%|██████████| 17/17 [40:44<00:00, 143.81s/it]


In [8]:
# Sample with injection enabled on the sampler. Each example is sampled as a
# batch of two sharing the same source guidance: row 0 uses the "COPY" prompt,
# row 1 uses the edit prompt. Row 0 is rendered to the copy_* folder, row 1 to
# the outputs_* folder, and only row 1 is kept in the cached tensor.
i = 0  # running example number; also names the per-example seed directory
model.sampler.injection_percentile = 0.5
model.sampler.injection_t = 30
output = None
suffix_str = f"injection_t_{model.sampler.injection_t}_p_{str(model.sampler.injection_percentile).replace('.', '_')}"
output_path = os.path.join(output_dir, f"{run_name}_{suffix_str}.pt")
results_dir = os.path.join(output_dir, "results", run_name, f"outputs_{suffix_str}")
os.makedirs(results_dir, exist_ok=True)
copy_dir = os.path.join(output_dir, "results", run_name, f"copy_{suffix_str}")
os.makedirs(copy_dir, exist_ok=True)
try:
    if not os.path.exists(output_path):
        # Collect the kept rows and concatenate once after the loop.
        edited_chunks = []
        for batch in tqdm.tqdm(data_loader):
            for prompt, source_latent in zip(batch["prompts"], batch["source_latents"]):
                injection_seed_dir = os.path.join(output_dir, "seeds", f"{i}")
                os.makedirs(injection_seed_dir, exist_ok=True)
                model.sampler.injection_seed_dir = injection_seed_dir
                curr_output = model.sampler.sample_batch(
                    guidances=[torch.stack([source_latent, source_latent]).to(device)],
                    model_kwargs={"texts": ["COPY", prompt]},
                    batch_size=2,
                )
                pcs = model.sampler.output_to_point_clouds(curr_output)
                render_point_cloud(
                    pcs[0],
                    theta=np.pi,
                    output_path=os.path.join(copy_dir, f"{i}.png"),
                )
                render_point_cloud(
                    pcs[1],
                    theta=np.pi,
                    output_path=os.path.join(results_dir, f"{i}.png"),
                )
                i += 1
                edited_chunks.append(curr_output.detach().cpu()[1:])
        if not edited_chunks:
            # Saving `None` here would create a cache file that blocks every later run.
            raise RuntimeError("data_loader produced no examples; nothing to save")
        output = torch.cat(edited_chunks, dim=0)
        torch.save(output, output_path)
finally:
    # Always switch injection off again, even if sampling was interrupted, so
    # later cells sample without it. The original reset wrote to a misspelled
    # attribute (`precentile`) and left `injection_percentile` active.
    model.sampler.injection_percentile = None
    model.sampler.injection_t = None
    model.sampler.injection_seed_dir = None

  0%|          | 0/17 [00:00<?, ?it/s]

100%|██████████| 17/17 [34:50<00:00, 122.99s/it]


In [15]:
# Write results/index.html: for every evaluation example, one header row, one
# row with the prompt + source render + "copy" renders, and one row with the
# target render + edited outputs, with one column per injection percentile.
# Image paths are relative to the results directory the page is written into.
percentile_tags = ["1_0", "0_9", "0_75", "0_5"]


def _img_cell(src):
    """Return a table cell containing a single image with source ``src``."""
    return f"<td><img src='{src}'></td>"


header_labels = ["prompt", "shapetalk"] + [f"spice_p_{tag}" for tag in percentile_tags]
header_row = "<tr>" + "".join(f"<td><b>{label}</b></td>" for label in header_labels) + "</tr>\n"

html_parts = [
    "<style>\n",
    "td { font-size: 64px; } /* Increase the font size */\n",
    "img { width: 1000px; height: 1000px; object-fit: cover; } /* Set square dimensions */\n",
    "</style>\n",
    "<table>\n",
]
# Images are numbered by position in `df`, so enumerate rather than relying on
# the DataFrame index labels (which are only 0..n-1 when `df` was re-read from
# the cached CSV).
for i, utterance in enumerate(df["utterance"]):
    prompt_cell = f"<td>{escape(str(utterance), quote=False)}</td>"  # escape &, <, > in dataset text
    source_cell = _img_cell(f"source/{i}.png")
    target_cell = _img_cell(f"target/{i}.png")
    copy_cells = "".join(
        _img_cell(f"{run_name}/copy_injection_t_30_p_{tag}/{i}.png") for tag in percentile_tags
    )
    edit_cells = "".join(
        _img_cell(f"{run_name}/outputs_injection_t_30_p_{tag}/{i}.png") for tag in percentile_tags
    )
    html_parts.append(header_row)
    html_parts.append(f"<tr>{prompt_cell}{source_cell}{copy_cells}</tr>\n")
    html_parts.append(f"<tr><td></td>{target_cell}{edit_cells}</tr>\n")
html_parts.append("</table>")
html = "".join(html_parts)

os.makedirs(os.path.join(output_dir, "results"), exist_ok=True)
with open(os.path.join(output_dir, "results", "index.html"), "w", encoding="utf-8") as f:
    f.write(html)


In [193]:
# Build the `args` namespace passed to `run_all_metrics` from an argparse
# parser whose defaults point at the local pretrained assets.


def _str2bool(value):
    """Parse a command-line boolean.

    `type=bool` treats every non-empty string (including "False") as True;
    this converter accepts the usual spellings and rejects anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument('-shape_talk_file', type=str, default="/scratch/noam/shapetalk/language/shapetalk_preprocessed_public_version_0.csv")
parser.add_argument('-vocab_file', type=str, default="/scratch/noam/shapetalk/language/vocabulary.pkl")
parser.add_argument('-latent_codes_file', type=str, default="/scratch/noam/changeit3d/pretrained/shape_latents/pcae_latent_codes.pkl")
parser.add_argument('-pretrained_changeit3d', type=str, default="/scratch/noam/changeit3d/pretrained/changers/pcae_based/all_shapetalk_classes/decoupling_mag_direction/idpen_0.01_sc_False/best_model.pt")
parser.add_argument('-top_pc_dir', type=str, default="/scratch/noam/shapetalk/point_clouds/scaled_to_align_rendering")
parser.add_argument('--restrict_shape_class', type=str, nargs='*', default=['chair'])
parser.add_argument('--pretrained_shape_classifier', type=str, default="/scratch/noam/changeit3d/pretrained/pc_classifiers/rs_2022/all_shapetalk_classes/best_model.pkl")
parser.add_argument('--compute_fpd', type=_str2bool, default=True)
parser.add_argument('--shape_part_classifiers_top_dir', type=str, default="/scratch/noam/changeit3d/pretrained/part_predictors/shapenet_core_based")
parser.add_argument('--pretrained_oracle_listener', type=str, default="/scratch/noam/changeit3d/pretrained/listeners/oracle_listener/all_shapetalk_classes/rs_2023/listener_dgcnn_based/ablation1/best_model.pkl")
parser.add_argument('--shape_generator_type', type=str, default="pcae", choices=["pcae", "sgf", "imnet"])
parser.add_argument('--pretrained_shape_generator', type=str, required=False, default="/scratch/noam/changeit3d/pretrained/pc_autoencoders/pointnet/rs_2022/points_4096/all_classes/scaled_to_align_rendering/08-07-2022-22-23-42/best_model.pt")
parser.add_argument('--n_sample_points', type=int, default=2048)
parser.add_argument('--sub_sample_dataset', type=int)
parser.add_argument('--gpu_id', type=int, default=0)
parser.add_argument('--save_reconstructions', default=False, type=_str2bool)
parser.add_argument('--use_timestamp', default=False, type=_str2bool)
parser.add_argument('--experiment_tag', type=str)
parser.add_argument('--random_seed', type=int, default=2022)
parser.add_argument('--log_dir', type=str, default='./logs')
parser.add_argument('--clean_train_val_data', type=_str2bool, default=False)
parser.add_argument('-pretrained_listener_file', type=str, default="/scratch/noam/changeit3d/pretrained/listeners/oracle_listener/all_shapetalk_classes/rs_2023/listener_dgcnn_based/ablation1/best_model.pkl")
parser.add_argument('--batch_size', type=int, default=1024)
parser.add_argument('--num_workers', type=int, default=10)
parser.add_argument('--evaluate_retrieval_version', type=_str2bool, default=False)
# Kept so the parser still accepts the `--f=<connection file>` argument some
# notebook front-ends put on the kernel command line.
parser.add_argument('--f', type=str)
# Parse an explicit empty list: inside a notebook `sys.argv` belongs to the
# kernel launcher, so the defaults above are the actual configuration.
# `args.f` is therefore always None.
args = parser.parse_args([])

In [194]:
# Named logger handed to `run_all_metrics`; no handlers or level are set here.
logger = logging.getLogger("stdout_logger")

In [195]:
# Load the cached sampler outputs and swap the last two axes so the final axis
# is the one multiplied by the 3x3 matrices below.
# NOTE(review): `output_path` is whatever the most recently executed cell
# assigned to it -- confirm it points at the intended .pt file before running.
transformed_shapes = torch.load(output_path).transpose(1, 2).cpu().numpy()

# First alignment step: z-axis rotation matrix, applied on the right.
theta = np.pi * 1 / 2
cos_t, sin_t = np.cos(theta), np.sin(theta)
rotation = np.array(
    [[cos_t, -sin_t, 0.0], [sin_t, cos_t, 0.0], [0.0, 0.0, 1.0]],
    dtype=np.float32,
)
transformed_shapes = transformed_shapes @ rotation

# Second alignment step: x-axis rotation matrix, applied on the right.
theta = np.pi / 2
cos_t, sin_t = np.cos(theta), np.sin(theta)
rotation = np.array(
    [[1.0, 0.0, 0.0], [0.0, cos_t, -sin_t], [0.0, sin_t, cos_t]],
    dtype=np.float32,
)
transformed_shapes = transformed_shapes @ rotation

# One ground-truth .npz path per evaluation row, keyed by the source shape uid.
gt_pc_files = [osp.join(top_pc_dir, uid + ".npz") for uid in df.source_uid]

In [196]:
# Load every ground-truth cloud with a fixed sample size and seed (in a single
# process), stack them into one array, and load the vocabulary for the metrics.
pc_loader = partial(pc_loader_from_npz, random_seed=2022, n_samples=n_sample_points)
gt_pcs = np.array(parallel_apply(gt_pc_files, pc_loader, n_processes=1))
vocab = Vocabulary.load(args.vocab_file)

In [197]:
# Pass every source latent through the sampler (as `prev_samples`, with no
# text conditioning), re-encode the resulting clouds and cache them in gt.pt.
# Skipped when the cache file already exists.
gt_pcs_decoded = None
gt_path = os.path.join(output_dir, "gt.pt")
if not os.path.exists(gt_path):
    # Collect encoded clouds and stack once after the loop.
    encoded_clouds = []
    for batch in tqdm.tqdm(data_loader):
        source_latents = batch["source_latents"]
        samples = model.sampler.sample_batch(
            batch_size=6,
            model_kwargs={},
            prev_samples=source_latents,
        )
        for pc in model.sampler.output_to_point_clouds(samples):
            encoded_clouds.append(pc.random_sample(2048).encode())
    if not encoded_clouds:
        raise RuntimeError("data_loader produced no batches; nothing to save")
    # Keep only the first three channels of the decoded clouds.
    # presumably these are x, y, z followed by the R, G, B aux channels -- TODO confirm
    # The original sliced `gt_pcs` (the arrays loaded from .npz) here, which
    # discarded everything this loop had just produced.
    gt_pcs_decoded = torch.stack(encoded_clouds, dim=0)[:, :3, :]
    torch.save(gt_pcs_decoded, gt_path)

In [198]:
# Reload the cached decoded ground truth, keep the first `num_samples` entries
# and swap the last two axes so the final axis is the one multiplied by the
# 3x3 matrices below.
gt_pcs_decoded = torch.load(gt_path)[:num_samples].transpose(1, 2).cpu().numpy()

# First alignment step: z-axis rotation matrix, applied on the right.
theta = np.pi * 1 / 2
cos_t, sin_t = np.cos(theta), np.sin(theta)
rotation = np.array(
    [[cos_t, -sin_t, 0.0], [sin_t, cos_t, 0.0], [0.0, 0.0, 1.0]],
    dtype=np.float32,
)
gt_pcs_decoded = gt_pcs_decoded @ rotation

# Second alignment step: x-axis rotation matrix, applied on the right.
theta = np.pi / 2
cos_t, sin_t = np.cos(theta), np.sin(theta)
rotation = np.array(
    [[1.0, 0.0, 0.0], [0.0, cos_t, -sin_t], [0.0, sin_t, cos_t]],
    dtype=np.float32,
)
gt_pcs_decoded = gt_pcs_decoded @ rotation

In [199]:
# Per-example inputs for the metric suite: spelled-out utterances, the source
# object class of each row, and the saved changeit3d baseline outputs.
sentences = df["utterance_spelled"].values
gt_classes = df["source_object_class"].values
changeit3d_path = os.path.join(output_dir, "changeit3d.pt")
changeit3d_outputs = torch.load(changeit3d_path)

In [204]:
# Score one example (index `i`) of the changeit3d baseline. Slices of length 1
# keep the leading axis on every input.
# To score this notebook's own samples instead, pass `transformed_shapes[one]`
# as the first argument.
i = 2
one = slice(i, i + 1)
results_on_metrics = run_all_metrics(
    changeit3d_outputs[one],
    gt_pcs[one],
    gt_classes[one],
    sentences[one],
    vocab,
    args,
    logger,
)

Chamfer Distance (all pairs), Average: 2.865
Chamfer Distance (all pairs), Average: 2.865
Chamfer Distance (all pairs), Average: 2.865
Chamfer Distance (all pairs), Average: 2.865
Chamfer Distance (all pairs), Average: 2.865
Chamfer Distance (all pairs), Average: 2.865
Chamfer Distance (all pairs), Average, per class:
Chamfer Distance (all pairs), Average, per class:
Chamfer Distance (all pairs), Average, per class:
Chamfer Distance (all pairs), Average, per class:
Chamfer Distance (all pairs), Average, per class:
Chamfer Distance (all pairs), Average, per class:
  shape_class  holistic-chamfer
0       chair             2.865
  shape_class  holistic-chamfer
0       chair             2.865
  shape_class  holistic-chamfer
0       chair             2.865
  shape_class  holistic-chamfer
0       chair             2.865
  shape_class  holistic-chamfer
0       chair             2.865
  shape_class  holistic-chamfer
0       chair             2.865
LAB Average:0.9895337224006653
LAB Average:0.9

In [None]:
# For the first `n_examples` examples: sample an edited shape from the prompt
# and source guidance, save it as results/<run_name>/output_<n>.png, and render
# the source/target reference shapes once (skipped if their PNG exists).
# NOTE(review): `plot_point_cloud` is not among the imports in the imports
# cell -- confirm it is imported (presumably from point_e.util.plotting)
# before running this on a fresh kernel.
curr = 0
n_examples = 100
results_dir = os.path.join(output_dir, "results", run_name)
source_dir = os.path.join(output_dir, "results", "source")
target_dir = os.path.join(output_dir, "results", "target")
for _dir in (results_dir, source_dir, target_dir):
    # savefig does not create missing parent directories.
    os.makedirs(_dir, exist_ok=True)


def _render_latent_if_missing(latent, png_path):
    """Pass one latent through the sampler and save its plot, unless ``png_path`` exists."""
    if os.path.exists(png_path):
        return
    decoded = model.sampler.sample_batch(
        batch_size=1,
        model_kwargs={},
        prev_samples=latent.unsqueeze(0),
    )
    fig = plot_point_cloud(model.sampler.output_to_point_clouds(decoded)[0], theta=model.theta)
    fig.savefig(png_path)
    plt.close(fig)


for batch in tqdm.tqdm(data_loader):
    if curr >= n_examples:
        break
    prompts = batch["prompts"]
    source_latents = batch["source_latents"].to(device)
    target_latents = batch["target_latents"].to(device)
    samples = model.sampler.sample_batch(
        batch_size=6,
        model_kwargs={"texts": prompts},
        guidances=[source_latents, None],
    )
    pcs = model.sampler.output_to_point_clouds(samples)
    for i, (_prompt, source_latent, target_latent) in enumerate(zip(prompts, source_latents, target_latents)):
        if curr >= n_examples:
            break
        # Local name on purpose: the global `output_path` is read by the
        # metrics cells and must keep pointing at the cached .pt file.
        output_png = os.path.join(results_dir, f"output_{curr}.png")
        fig = plot_point_cloud(pcs[i], theta=model.theta)
        fig.savefig(output_png)
        plt.close(fig)
        # The original tested `target_path` before it was assigned when
        # deciding whether to render the source image.
        _render_latent_if_missing(source_latent, os.path.join(source_dir, f"source_{curr}.png"))
        _render_latent_if_missing(target_latent, os.path.join(target_dir, f"target_{curr}.png"))
        curr += 1

In [30]:
# Render the saved changeit3d baseline outputs: align them with two axis
# rotations, then pass each cloud through the sampler (as `prev_samples`) and
# save the plot as results/changeit3d/changeit3d_<n>.png.
# NOTE(review): `plot_point_cloud` is not among the imports in the imports
# cell -- confirm it is imported before running this on a fresh kernel.
changeit3d_outputs = torch.load(os.path.join(output_dir, "changeit3d.pt"))

# First alignment step: z-axis rotation matrix, applied on the right.
theta = np.pi * 2 / 2
cos_t, sin_t = np.cos(theta), np.sin(theta)
rotation = np.array(
    [[cos_t, -sin_t, 0.0], [sin_t, cos_t, 0.0], [0.0, 0.0, 1.0]],
    dtype=np.float32,
)
changeit3d_outputs = changeit3d_outputs @ rotation

# Second alignment step: x-axis rotation matrix, applied on the right.
theta = np.pi / 2
cos_t, sin_t = np.cos(theta), np.sin(theta)
rotation = np.array(
    [[1.0, 0.0, 0.0], [0.0, cos_t, -sin_t], [0.0, sin_t, cos_t]],
    dtype=np.float32,
)
changeit3d_outputs = changeit3d_outputs @ rotation

n_examples = 100
results_dir = os.path.join(output_dir, "results", "changeit3d")
os.makedirs(results_dir, exist_ok=True)
for curr in tqdm.tqdm(range(n_examples)):
    changeit_path = os.path.join(results_dir, f"changeit3d_{curr}.png")
    if os.path.exists(changeit_path):
        # Already rendered on a previous run.
        continue
    coords = changeit3d_outputs[curr]
    # All-zero R, G, B channels, one value per point.
    channels = {name: np.zeros_like(coords[:, 0], dtype=np.float32) for name in "RGB"}
    pc = PointCloud(coords, channels).random_sample(1024)
    samples = model.sampler.sample_batch(
        batch_size=1,
        model_kwargs={},
        prev_samples=pc.encode().unsqueeze(0).to(device),
    )
    pc = model.sampler.output_to_point_clouds(samples)[0]
    fig = plot_point_cloud(pc)
    fig.savefig(changeit_path)
    plt.close()


100%|██████████| 100/100 [08:39<00:00,  5.19s/it]


In [85]:
# Write results/results.html: one row per example with the prompt, the source,
# target and changeit3d renders, and one output render per compared run.
# Image paths are relative to the results directory the page is written into.
run_names = [
    "07_31_2024_17_20_58_utterance_train_chair_val_chair_switch_0_chamfer_0_5",
    "07_31_2024_20_53_33_utterance_train_chair_val_chair_switch_0_25_chamfer_0_5",
    "07_31_2024_20_53_04_utterance_train_chair_val_chair_switch_0_5_chamfer_0_5",
]
# Column captions for the runs above, in the same order.
run_labels = ["0", "0_25", "0_5"]
assert len(run_labels) == len(run_names), "one caption per compared run"

html_parts = [
    """
<style>
    table {
        font-size: 60px;
    }
</style>
<table>
""",
    "<tr><td>prompt</td><td>source</td><td>target</td><td>changeit3d</td>",
    "".join(f"<td>{label}</td>" for label in run_labels),
    "</tr>\n",
]

curr = 0
n_examples = 100
for batch in tqdm.tqdm(data_loader):
    if curr >= n_examples:
        break
    for prompt in batch["prompts"]:
        if curr >= n_examples:
            break
        # Local names on purpose: `run_name` and `output_path` are globals
        # that other cells depend on, so this loop must not overwrite them.
        row_cells = [f"<td>{escape(str(prompt), quote=False)}</td>"]  # escape &, <, > in dataset text
        row_cells.append(f"<td><img src='source/source_{curr}.png'></td>")
        row_cells.append(f"<td><img src='target/target_{curr}.png'></td>")
        row_cells.append(f"<td><img src='changeit3d/changeit3d_{curr}.png'></td>")
        for compared_run in run_names:
            row_cells.append(f"<td><img src='{compared_run}/output_{curr}.png'></td>")
        html_parts.append("<tr>" + "".join(row_cells) + "</tr>\n")
        curr += 1
html_parts.append("</table>")
html = "".join(html_parts)

os.makedirs(os.path.join(output_dir, "results"), exist_ok=True)
with open(os.path.join(output_dir, "results", "results.html"), "w", encoding="utf-8") as f:
    f.write(html)


 20%|██        | 17/84 [00:00<00:00, 981.58it/s]
