In [1]:
import sys
# Put the local repository checkout at the front of the module search path.
# NOTE(review): presumably this is what makes the `point_e` imports below
# resolve to this checkout rather than an installed copy — confirm.
sys.path.insert(0, '/home/noamatia/repos/control_point_e/')

In [2]:
import os
import torch
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from matplotlib import pyplot as plt
from point_e.util.plotting import plot_point_cloud
from point_e.util.point_cloud import PointCloud
from point_e.models.download import load_checkpoint
from point_e.diffusion.sampler import PointCloudSampler
from point_e.models.configs import MODEL_CONFIGS, model_from_config
from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [3]:
# Build the base text-conditioned model and its diffusion process, then load
# the fine-tuned control weights on top of the pretrained checkpoint.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
base_name = 'base40M-textvec'
base_model = model_from_config(MODEL_CONFIGS[base_name], device)
base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[base_name])

# Pretrained weights first; the control layers are created afterwards.
# NOTE(review): presumably the control layers must exist before the
# fine-tuned state dict is loaded so that its keys match — confirm.
base_model.load_state_dict(load_checkpoint(base_name, device))
base_model.create_control_layers()

# `map_location` keeps the load working when the checkpoint was saved from a
# GPU but this kernel fell back to the CPU (see `device` above).
control_checkpoint_path = '/scratch/noam/cntrl_pointe/07_25_2024_23_28_17_chair_train_chair_val_utterance/epoch_59.pth'
base_model.load_state_dict(torch.load(control_checkpoint_path, map_location=device))

# Last expression of the cell: switches to eval mode and displays the module.
base_model.eval()

CLIPImagePointDiffusionTransformer(
  (time_embed): MLP(
    (c_fc): Linear(in_features=512, out_features=2048, bias=True)
    (c_proj): Linear(in_features=2048, out_features=512, bias=True)
    (gelu): GELU(approximate='none')
  )
  (ln_pre): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (backbone): Transformer(
    (resblocks): ModuleList(
      (0-11): 12 x ResidualAttentionBlock(
        (attn): MultiheadAttention(
          (c_proj): Linear(in_features=512, out_features=512, bias=True)
          (attention): QKVMultiheadAttention()
          (c_qkv): Linear(in_features=512, out_features=1536, bias=True)
        )
        (ln_1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Linear(in_features=512, out_features=2048, bias=True)
          (c_proj): Linear(in_features=2048, out_features=512, bias=True)
          (gelu): GELU(approximate='none')
        )
        (ln_2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
     

In [4]:
# Sampler wrapping the single base model. Every per-stage option is given as
# a one-element list, matching the single entry in `models` / `diffusions`.
sampler = PointCloudSampler(
    device=device,
    # Model and diffusion process.
    models=[base_model],
    diffusions=[base_diffusion],
    # Output layout: 1024 points carrying RGB auxiliary channels.
    num_points=[1024],
    aux_channels=['R', 'G', 'B'],
    # Conditioning: only the `texts` key is kept from model_kwargs.
    model_kwargs_key_filter=['texts'],
    guidance_scale=[3.0],
    # Karras sampling schedule.
    use_karras=[True],
    karras_steps=[64],
    sigma_min=[1e-3],
    sigma_max=[120],
    s_churn=[3],
)

In [27]:
def build_experiment_dir(prompt2, t, i):
    """Create the output directories for one experiment run and return the main one.

    Two sibling directories are created (parents included, existing ones are
    left alone):

    * ``experiment2/<prompt2>_t_<t>_i_<i>``       (returned)
    * ``experiment2_objs/<prompt2>_t_<t>_i_<i>``

    Args:
        prompt2: Prompt text used as the leading part of the directory name.
        t: Value embedded after ``_t_`` in the directory name.
        i: Value embedded after ``_i_`` in the directory name.

    Returns:
        The relative path of the directory under ``experiment2``.
    """
    leaf = f'{prompt2}_t_{t}_i_{i}'
    d = f'experiment2/{leaf}'
    # Build the mirror path from the leaf instead of str.replace() on the full
    # path: replace() would also rewrite any "experiment2" inside the prompt.
    objs_dir = f'experiment2_objs/{leaf}'
    os.makedirs(d, exist_ok=True)
    os.makedirs(objs_dir, exist_ok=True)
    return d

In [36]:
# Experiment configuration.
# `experimental1_t` is stored on the sampler and embedded in output dir names.
experimental1_t = 0
sampler.experiment2_t = experimental1_t
# Fractions of the sorted index list that are passed to the sampler.
percentiles = [0.1, 0.25, 0.5, 0.75, 0.9]
# Fixed seed so the same 10 rows are evaluated on every run.
SAMPLE_SEED = 0
df = pd.read_csv('/home/noamatia/repos/control_point_e/data/chair_armrests/train.csv').sample(10, random_state=SAMPLE_SEED)
# Directory holding the `<uid>.npz` point clouds referenced by the CSV rows.
PCS_DIR = "/scratch/noam/shapetalk/point_clouds/scaled_to_align_rendering"

In [37]:
# For every sampled row: render the source/target point clouds, run the
# sampler once unrestricted and once per percentile, and add two rows to an
# HTML comparison table that is finally written to index.html.

MAX_ROWS = 100  # hard cap on the number of rows processed


def _heading_cell(text):
    """Return a table cell containing `text` as a 24px heading."""
    return f'<td><h1 style="font-size: 24px;">{text}</h1></td>'


def _image_cell(directory, filename):
    """Return a table cell with an <img> pointing at directory/filename."""
    return f'<td><img src="{os.path.join(directory, filename)}"></td>'


def _percentile_tag(percentile):
    """Format a percentile for use in file names, e.g. 0.25 -> '0_25'."""
    return f'{percentile}'.replace(".", "_")


def _table_row(label, directory, first_image, sample_index):
    """Return one <tr>: label, reference image, one image per percentile, full image."""
    cells = [_heading_cell(label), _image_cell(directory, first_image)]
    cells.extend(
        _image_cell(directory, f'{sample_index}_{_percentile_tag(p)}.png')
        for p in percentiles
    )
    cells.append(_image_cell(directory, f'{sample_index}.png'))
    return '<tr>' + ''.join(cells) + '</tr>\n'


def _run_sampler(texts, guidances):
    """Exhaust the progressive sampler and return the last item it yielded (or None)."""
    last = None
    for x in tqdm(sampler.sample_batch_progressive(batch_size=2, model_kwargs=dict(texts=texts), guidances=guidances)):
        last = x
    return last


# Header row: prompt | shapetalk | one column per percentile | 1024.
html_parts = ["<table>\n", '<tr>', _heading_cell('prompt'), _heading_cell('shapetalk')]
html_parts.extend(_heading_cell(p) for p in percentiles)
html_parts.append(_heading_cell('1024') + '</tr>\n')

prompt1 = ""
samples = None
# Count rows by position: the DataFrame index labels are arbitrary after
# .sample(), so comparing the label against the cap could stop the loop early.
for row_pos, (_, row) in enumerate(tqdm(df.iterrows(), total=len(df))):
    if row_pos == MAX_ROWS:
        break
    prompt2 = row['utterance'].replace(" ", "_")
    source_pc = PointCloud.load_shapenet(os.path.join(PCS_DIR, row["source_uid"] + ".npz")).random_sample(1024)
    source_latent = source_pc.encode()
    # The same source latent is used for both items of the batch of two.
    source_latents = torch.stack([source_latent, source_latent]).to(device)
    target_pc = PointCloud.load_shapenet(os.path.join(PCS_DIR, row["target_uid"] + ".npz")).random_sample(1024)
    target_latent = target_pc.encode()  # NOTE(review): not used within this cell
    texts = [prompt1.replace("_", " "), prompt2.replace("_", " ")]
    for i in range(1):
        # First pass runs with no index restriction.
        sampler.experiment2_indices = None
        sampler.precentile = None
        experiment_dir = build_experiment_dir(prompt2, experimental1_t, i)
        os.environ['EXPERIMENT2_DIR'] = experiment_dir

        for point_cloud, filename in ((source_pc, "source.png"), (target_pc, "target.png")):
            fig = plot_point_cloud(point_cloud)
            fig.savefig(os.path.join(experiment_dir, filename))
            # Close exactly the figure that was just saved.
            plt.close(fig)

        samples = _run_sampler(texts, [source_latents])

        # NOTE(review): sorted_indices.npy is presumably written into
        # EXPERIMENT2_DIR by the sampler during the pass above — confirm.
        reversed_sorted_indices = np.load(os.path.join(experiment_dir, "sorted_indices.npy"))
        for percentile in percentiles:
            sampler.precentile = _percentile_tag(percentile)
            # Keep only the leading `percentile` fraction of the index list.
            experiment2_indices = reversed_sorted_indices[:int(len(reversed_sorted_indices) * percentile)]
            sampler.experiment2_indices = experiment2_indices
            samples = _run_sampler(texts, [source_latents])

        # NOTE(review): the 0*.png / 1*.png images referenced here are not
        # written by this cell; presumably the sampler saves them — confirm.
        html_parts.append(_table_row(prompt1, experiment_dir, 'source.png', 0))
        html_parts.append(_table_row(prompt2, experiment_dir, 'target.png', 1))

html_parts.append("</table>")
html = "".join(html_parts)
with open("index.html", "w") as f:
    f.write(html)


  0%|          | 0/10 [00:00<?, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

KeyboardInterrupt: 