In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import sys
sys.path.insert(0, "/home/noamatia/repos/control_point_e")
sys.path.insert(0, "/home/noamatia/repos/control_point_e/changeit3d")
import tqdm
import torch
import logging
import argparse
import numpy as np
import pandas as pd
import os.path as osp
from functools import partial
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from control_point_e import ControlPointE
from control_shapenet import ControlShapeNet
from point_e.util.plotting import plot_point_cloud
from changeit3d.utils.basics import parallel_apply
from changeit3d.language.vocabulary import Vocabulary
from changeit3d.in_out.pointcloud import pc_loader_from_npz
from changeit3d.evaluation.all_metrics import run_all_metrics

Dask dataframe query planning is disabled because dask-expr is not installed.

You can install it with `pip install dask[dataframe]` or `conda install dask`.
This will raise in a future version.



Jitting Chamfer 3D
Loaded JIT 3D CUDA chamfer distance


In [2]:
# --- Evaluation configuration -------------------------------------------------
# Object class under evaluation and the size of the evaluation subset.
obj = "chair"
num_samples = 500
# Number of points kept per generated / ground-truth point cloud.
n_sample_points = 2048

# --- Locations on disk --------------------------------------------------------
base_dir = "/scratch/noam/control_point_e"
top_pc_dir = "/scratch/noam/shapetalk/point_clouds/scaled_to_align_rendering"

# --- Model under evaluation ---------------------------------------------------
# Exactly one run_name should be active; the others are kept for quick switching.
ckpt = "epoch=99-step=52000.ckpt"
# run_name = "07_31_2024_17_20_58_utterance_train_chair_val_chair_switch_0_chamfer_0_5"
run_name = "07_31_2024_20_53_33_utterance_train_chair_val_chair_switch_0_25_chamfer_0_5"
# run_name = "07_31_2024_20_53_04_utterance_train_chair_val_chair_switch_0_5_chamfer_0_5"

In [3]:
# Pick the evaluation subset once and cache it as samples.csv, so every run
# (and every compared checkpoint) is evaluated on the same rows.
output_dir = os.path.join(base_dir, "eval_changeit3d", obj, f"{num_samples}_random_samples")
os.makedirs(output_dir, exist_ok=True)
# NOTE: despite its name, dataset_dir is the path of the validation CSV file.
dataset_dir = os.path.join(base_dir, "datasets", obj, "val.csv")
samples_path = os.path.join(output_dir, "samples.csv")
if os.path.exists(samples_path):
    # Reuse the previously drawn subset.
    df = pd.read_csv(samples_path)
else:
    # Seed the draw (same seed value the notebook uses for the point-cloud
    # loader) so the subset can be regenerated if the cached CSV is lost.
    df = pd.read_csv(dataset_dir).sample(num_samples, random_state=2022)
    df.to_csv(samples_path, index=False)

In [4]:
# Build the evaluation dataset / loader and restore the trained model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = ControlShapeNet(
    df=df,
    batch_size=6,
    device=device,
    num_points=1024,
    prompt_key="utterance",
)
# shuffle=False keeps the batch order aligned with the rows of df.
data_loader = DataLoader(dataset=dataset, batch_size=6, shuffle=False)
# The checkpoint lives under <base_dir>/executions/<run_name>/checkpoints/<ckpt>;
# derive it from base_dir instead of repeating the literal root path.
checkpoint_path = os.path.join(base_dir, "executions", run_name, "checkpoints", ckpt)
# NOTE(review): switch_prob is fixed at 0.5 here regardless of the value encoded
# in run_name — confirm it has no effect at inference time.
model = ControlPointE.load_from_checkpoint(
    checkpoint_path,
    lr=7e-5 * 0.4,
    dev=device,
    batch_size=6,
    timesteps=1024,
    num_points=1024,
    switch_prob=0.5,
)

Creating ControlShapeNet dataset: 100%|██████████| 500/500 [00:01<00:00, 281.25it/s]


In [5]:
# Sample an edited shape for every evaluation row and cache the result as
# <output_dir>/<run_name>.pt. The whole cell is skipped when the cache exists,
# in which case `output` stays None.
output = None
output_path = os.path.join(output_dir, f"{run_name}.pt")
if not os.path.exists(output_path):
    # Collect per-batch CPU tensors and concatenate once at the end; calling
    # torch.cat inside the loop re-copies everything accumulated so far.
    sampled_batches = []
    for batch in tqdm.tqdm(data_loader):
        prompts = batch["prompts"]
        source_latents = batch["source_latents"].to(device)
        # Random subset of point indices, redrawn for every batch.
        # NOTE(review): 3072 is hard-coded as the index range — confirm it
        # matches the number of points the sampler returns.
        indices = torch.randperm(3072)[:n_sample_points]
        curr_output = model.sampler.sample_batch(
            guidances=[source_latents, None],
            model_kwargs={"texts": prompts},
            batch_size=6,
        )[:, :3, indices]  # keep the first three channels and the sampled points
        sampled_batches.append(curr_output.detach().cpu())
    output = torch.cat(sampled_batches, dim=0)
    # Keep only the first num_samples rows before saving.
    torch.save(output[:num_samples], output_path)

In [6]:
def str2bool(value):
    """Parse a command-line boolean.

    ``type=bool`` is a trap with argparse: ``bool("False")`` is ``True`` because
    any non-empty string is truthy. This converter accepts the usual spellings
    and rejects everything else.

    Args:
        value: a bool, or a string such as "true"/"false", "yes"/"no", "1"/"0".

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in {"true", "t", "yes", "y", "1"}:
        return True
    if text in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Build the argument namespace that run_all_metrics receives. The parser is
# only used as a container of defaults inside this notebook.
parser = argparse.ArgumentParser()
parser.add_argument('-shape_talk_file', type=str, default="/scratch/noam/shapetalk/language/shapetalk_preprocessed_public_version_0.csv")
parser.add_argument('-vocab_file', type=str, default="/scratch/noam/shapetalk/language/vocabulary.pkl")
parser.add_argument('-latent_codes_file', type=str, default="/scratch/noam/changeit3d/pretrained/shape_latents/pcae_latent_codes.pkl")
parser.add_argument('-pretrained_changeit3d', type=str, default="/scratch/noam/changeit3d/pretrained/changers/pcae_based/all_shapetalk_classes/decoupling_mag_direction/idpen_0.01_sc_False/best_model.pt")
parser.add_argument('-top_pc_dir', type=str, default="/scratch/noam/shapetalk/point_clouds/scaled_to_align_rendering")
parser.add_argument('--restrict_shape_class', type=str, nargs='*', default=['chair'])
parser.add_argument('--pretrained_shape_classifier', type=str, default="/scratch/noam/changeit3d/pretrained/pc_classifiers/rs_2022/all_shapetalk_classes/best_model.pkl")
parser.add_argument('--compute_fpd', type=str2bool, default=True)
parser.add_argument('--shape_part_classifiers_top_dir', type=str, default="/scratch/noam/changeit3d/pretrained/part_predictors/shapenet_core_based")
parser.add_argument('--pretrained_oracle_listener', type=str, default="/scratch/noam/changeit3d/pretrained/listeners/oracle_listener/all_shapetalk_classes/rs_2023/listener_dgcnn_based/ablation1/best_model.pkl")
parser.add_argument('--shape_generator_type', type=str, default="pcae", choices=["pcae", "sgf", "imnet"])
parser.add_argument('--pretrained_shape_generator', type=str, required=False, default="/scratch/noam/changeit3d/pretrained/pc_autoencoders/pointnet/rs_2022/points_4096/all_classes/scaled_to_align_rendering/08-07-2022-22-23-42/best_model.pt")
parser.add_argument('--n_sample_points', type=int, default=2048)
parser.add_argument('--sub_sample_dataset', type=int)
parser.add_argument('--gpu_id', type=int, default=0)
parser.add_argument('--save_reconstructions', default=False, type=str2bool)
parser.add_argument('--use_timestamp', default=False, type=str2bool)
parser.add_argument('--experiment_tag', type=str)
parser.add_argument('--random_seed', type=int, default=2022)
parser.add_argument('--log_dir', type=str, default='./logs')
parser.add_argument('--clean_train_val_data', type=str2bool, default=False)
parser.add_argument('-pretrained_listener_file', type=str, default="/scratch/noam/changeit3d/pretrained/listeners/oracle_listener/all_shapetalk_classes/rs_2023/listener_dgcnn_based/ablation1/best_model.pkl")
parser.add_argument('--batch_size', type=int, default=1024)
parser.add_argument('--num_workers', type=int, default=10)
parser.add_argument('--evaluate_retrieval_version', type=str2bool, default=False)
# Presumably added to swallow the kernel's connection-file flag; kept so that
# `args.f` still exists for any code that reads it.
parser.add_argument('--f', type=str)
# Parse an explicit empty list instead of sys.argv: inside a notebook sys.argv
# belongs to the kernel launcher, and its format differs between front-ends.
# To override a default, pass it here, e.g. parse_args(["--batch_size", "256"]).
args = parser.parse_args([])

In [7]:
# Logger handed to run_all_metrics.
logger = logging.getLogger('stdout_logger')
logger.setLevel(logging.INFO)
# Without a handler, INFO records are silently dropped (the logging module's
# last-resort handler only emits WARNING and above). Attach a stdout handler
# unless this logger or one of its ancestors already has one; the guard also
# keeps the cell idempotent, so re-running it does not duplicate output.
if not logger.hasHandlers():
    logger.addHandler(logging.StreamHandler(sys.stdout))

In [8]:
# Load the cached samples and swap the last two axes before converting to NumPy.
transformed_shapes = torch.load(output_path).transpose(1, 2).numpy()

# First rotation: a quarter turn about the z axis, applied by right-multiplying
# the row-vector points.
theta = np.pi / 2
cos_t, sin_t = np.cos(theta), np.sin(theta)
rotation = np.array(
    [[cos_t, -sin_t, 0.0], [sin_t, cos_t, 0.0], [0.0, 0.0, 1.0]],
    dtype=np.float32,
)
transformed_shapes = transformed_shapes @ rotation

# Second rotation: a quarter turn about the x axis, applied the same way.
theta = np.pi / 2
cos_t, sin_t = np.cos(theta), np.sin(theta)
rotation = np.array(
    [[1.0, 0.0, 0.0], [0.0, cos_t, -sin_t], [0.0, sin_t, cos_t]],
    dtype=np.float32,
)
transformed_shapes = transformed_shapes @ rotation

# One .npz path per evaluation row, built from the row's source_uid.
gt_pc_files = [osp.join(top_pc_dir, uid + ".npz") for uid in df.source_uid]

In [9]:
# Loader that reads n_sample_points points from an .npz file with a fixed seed.
pc_loader = partial(pc_loader_from_npz, random_seed=2022, n_samples=n_sample_points)
# Load all ground-truth clouds with 20 worker processes and stack them into one array.
gt_pcs = np.array(parallel_apply(gt_pc_files, pc_loader, n_processes=20))
# Vocabulary file path comes from the argument namespace.
vocab = Vocabulary.load(args.vocab_file)

In [10]:
# Per-row inputs for the metric suite, taken from the evaluation subset.
gt_classes = df["source_object_class"].values
sentences = df["utterance_spelled"].values
# Run every ChangeIt3D metric on the generated shapes against the ground truth.
results_on_metrics = run_all_metrics(
    transformed_shapes,
    gt_pcs,
    gt_classes,
    sentences,
    vocab,
    args,
    logger,
)

OutOfMemoryError: CUDA out of memory. Tried to allocate 4.00 GiB (GPU 0; 23.68 GiB total capacity; 10.46 GiB already allocated; 3.64 GiB free; 18.45 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [16]:
def save_point_cloud_png(point_cloud, png_path):
    """Plot ``point_cloud`` with ``model.theta`` and save the figure to ``png_path``."""
    fig = plot_point_cloud(point_cloud, theta=model.theta)
    fig.savefig(png_path)
    # Close this specific figure so figures do not pile up across iterations.
    plt.close(fig)


def save_latent_png(latent, png_path):
    """Render a single latent to ``png_path`` unless that file already exists.

    The latent is passed to the sampler as ``prev_samples`` with a batch of one,
    and the first resulting point cloud is plotted.
    """
    if os.path.exists(png_path):
        return
    latent_samples = model.sampler.sample_batch(
        batch_size=1,
        model_kwargs={},
        prev_samples=latent.unsqueeze(0),
    )
    save_point_cloud_png(model.sampler.output_to_point_clouds(latent_samples)[0], png_path)


# Render the first n_examples evaluation rows: the model output for the current
# run, plus the source and target shapes (shared between runs, so rendered once).
curr = 0
n_examples = 100
results_dir = os.path.join(output_dir, "results", run_name)
source_dir = os.path.join(output_dir, "results", "source")
target_dir = os.path.join(output_dir, "results", "target")
# savefig does not create missing directories, so create all three up front.
for directory in (results_dir, source_dir, target_dir):
    os.makedirs(directory, exist_ok=True)
for batch in tqdm.tqdm(data_loader):
    if curr >= n_examples:
        break
    prompts, source_latents, target_latents = (
        batch["prompts"],
        batch["source_latents"].to(device),
        batch["target_latents"].to(device),
    )
    samples = model.sampler.sample_batch(
        batch_size=6,
        model_kwargs={"texts": prompts},
        guidances=[source_latents, None],
    )
    pcs = model.sampler.output_to_point_clouds(samples)
    for i, (_prompt, source_latent, target_latent) in enumerate(zip(prompts, source_latents, target_latents)):
        if curr >= n_examples:
            break
        # Use a dedicated name here: `output_path` is the cached .pt file of the
        # sampling cell and must not be overwritten with a PNG path.
        save_point_cloud_png(pcs[i], os.path.join(results_dir, f"output_{curr}.png"))
        save_latent_png(source_latent, os.path.join(source_dir, f"source_{curr}.png"))
        save_latent_png(target_latent, os.path.join(target_dir, f"target_{curr}.png"))
        curr += 1

 20%|██        | 17/84 [1:17:18<5:04:40, 272.84s/it]


In [20]:
# Local import: only `escape` is needed, and importing the module itself would
# clash with the `html` string built below.
from html import escape

# Build results/results.html: one row per example, showing the prompt, the
# source and target renders, and the output of every compared run.
run_names = [
    "07_31_2024_17_20_58_utterance_train_chair_val_chair_switch_0_chamfer_0_5",
    "07_31_2024_20_53_33_utterance_train_chair_val_chair_switch_0_25_chamfer_0_5",
    "07_31_2024_20_53_04_utterance_train_chair_val_chair_switch_0_5_chamfer_0_5",
]
# Collect fragments in a list and join once instead of growing a string.
html_parts = ["""
<style>
    table {
        font-size: 20px;
    }
</style>
<table>
"""]
html_parts.append("<tr><td>prompt</td><td>source</td><td>target</td>")
# Column headers, in the same order as run_names.
for p in ["0", "0_25", "0_5"]:
    html_parts.append(f"<td>{p}</td>")
html_parts.append("</tr>\n")

curr = 0
n_examples = 100
for batch in tqdm.tqdm(data_loader):
    if curr >= n_examples:
        break
    prompts = batch["prompts"]
    for prompt in prompts:
        if curr >= n_examples:
            break
        # Image paths are relative to the results/ directory that holds the HTML file.
        source_path = os.path.join("source", f"source_{curr}.png")
        target_path = os.path.join("target", f"target_{curr}.png")
        # Prompts are free text; escape them so characters like '<' or '&'
        # cannot break the table markup.
        html_parts.append(
            f"<tr><td>{escape(str(prompt))}</td><td><img src='{source_path}'></td><td><img src='{target_path}'></td>"
        )
        # Distinct loop names: the globals `run_name` and `output_path` set in
        # earlier cells must not be overwritten here.
        for compared_run in run_names:
            image_path = os.path.join(compared_run, f"output_{curr}.png")
            html_parts.append(f"<td><img src='{image_path}'></td>")
        html_parts.append("</tr>\n")
        curr += 1
html_parts.append("</table>")
html = "".join(html_parts)
with open(os.path.join(output_dir, "results", "results.html"), "w") as f:
    f.write(html)


 20%|██        | 17/84 [00:00<00:00, 914.44it/s]
