In [1]:
import sys
sys.path.insert(0, '/home/noamatia/repos/control_point_e/')

In [2]:
import os
import torch
import numpy as np
import open3d as o3d
from tqdm.auto import tqdm
from point_e.models.download import load_checkpoint
from point_e.diffusion.sampler import PointCloudSampler
from point_e.models.configs import MODEL_CONFIGS, model_from_config
from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
base_name = 'base40M-textvec'
base_model = model_from_config(MODEL_CONFIGS[base_name], device)
base_model.eval()
base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[base_name])
base_model.load_state_dict(load_checkpoint(base_name, device))

<All keys matched successfully>

In [4]:
sampler = PointCloudSampler(
    device=device,
    s_churn = [3],
    sigma_max = [120],
    num_points=[1024],
    sigma_min = [1e-3],
    models=[base_model],
    use_karras = [True],
    karras_steps = [64],
    guidance_scale=[3.0],
    diffusions=[base_diffusion],
    aux_channels=['R', 'G', 'B'],
    model_kwargs_key_filter=['texts']
)

In [5]:
def build_experiment_dir(prompt1, prompt2, t, i):
    d = f'experiment2/p1_{prompt1}_p2_{prompt2}_t_{t}_i_{i}'
    os.makedirs(d, exist_ok=True)
    return d

In [6]:
experimental1_t = 30
sampler.experiment2_t = experimental1_t
prompt1, prompt2 = 'a_chair', 'a_chair_with_armrests'

In [7]:
for i in tqdm(range(25), total=25):
    os.environ['EXPERIMENT2_DIR'] = build_experiment_dir(prompt1, prompt2, experimental1_t, i)
    samples = None
    for x in tqdm(sampler.sample_batch_progressive(batch_size=2, model_kwargs=dict(texts=[prompt1.replace("_", " "), prompt2.replace("_", " ")]))):
        samples = x

  0%|          | 0/25 [00:00<?, ?it/s]

65it [00:07,  9.08it/s]
65it [00:06,  9.71it/s]00:07<02:51,  7.16s/it]
65it [00:06,  9.60it/s]00:13<02:38,  6.89s/it]
65it [00:06,  9.60it/s]00:20<02:30,  6.83s/it]
65it [00:06,  9.60it/s]00:27<02:23,  6.81s/it]
65it [00:06,  9.54it/s]00:34<02:15,  6.80s/it]
65it [00:06,  9.53it/s]00:40<02:09,  6.80s/it]
65it [00:06,  9.55it/s]00:47<02:02,  6.81s/it]
65it [00:06,  9.60it/s]00:54<01:55,  6.81s/it]
65it [00:06,  9.40it/s]01:01<01:48,  6.80s/it]
65it [00:06,  9.58it/s][01:08<01:42,  6.83s/it]
65it [00:06,  9.58it/s][01:15<01:35,  6.82s/it]
65it [00:06,  9.55it/s][01:21<01:28,  6.81s/it]
65it [00:06,  9.56it/s][01:28<01:21,  6.81s/it]
65it [00:06,  9.55it/s][01:35<01:14,  6.81s/it]
65it [00:06,  9.52it/s][01:42<01:08,  6.81s/it]
65it [00:06,  9.53it/s][01:49<01:01,  6.81s/it]
65it [00:06,  9.54it/s][01:55<00:54,  6.82s/it]
65it [00:06,  9.57it/s][02:02<00:47,  6.82s/it]
65it [00:06,  9.33it/s][02:09<00:40,  6.81s/it]
65it [00:06,  9.56it/s][02:16<00:34,  6.86s/it]
65it [00:06,  9.52it/s][0

In [7]:
from matplotlib import pyplot as plt
from point_e.util.plotting import plot_point_cloud
from point_e.util.point_cloud import PointCloud


for i in range(25):
    os.environ["EXPERIMENT2_DIR"] = build_experiment_dir(
        prompt1, prompt2, experimental1_t, i
    )
    selected_indices_path = os.path.join(os.environ["EXPERIMENT2_DIR"], "selected_indices.txt")
    if not os.path.exists(selected_indices_path):
        continue
    experiment2_indices = np.loadtxt(selected_indices_path, dtype=int)
    if not len(experiment2_indices):
        continue
    ply_path = os.path.join(os.environ["EXPERIMENT2_DIR"], "1.ply")
    pc = PointCloud.from_ply(ply_path)
    pc.set_color_by_indices(experiment2_indices)
    fig = plot_point_cloud(pc)
    path = os.path.join(os.getenv("EXPERIMENT2_DIR"), f"1_selected.png")
    fig.savefig(path)
    plt.close()
    sampler.experiment2_indices = experiment2_indices
    samples = None
    for x in tqdm(sampler.sample_batch_progressive(batch_size=2, model_kwargs=dict(texts=[prompt1.replace("_", " "), prompt2.replace("_", " ")]))):
        samples = x
    sampler.experiment2_indices = None

  experiment2_indices = np.loadtxt(selected_indices_path, dtype=int)
  experiment2_indices = np.loadtxt(selected_indices_path, dtype=int)


0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

  experiment2_indices = np.loadtxt(selected_indices_path, dtype=int)
  experiment2_indices = np.loadtxt(selected_indices_path, dtype=int)
  experiment2_indices = np.loadtxt(selected_indices_path, dtype=int)
  experiment2_indices = np.loadtxt(selected_indices_path, dtype=int)


0it [00:00, ?it/s]

  experiment2_indices = np.loadtxt(selected_indices_path, dtype=int)
  experiment2_indices = np.loadtxt(selected_indices_path, dtype=int)


0it [00:00, ?it/s]