In [1]:
import os
import shutil
from html import escape

import torch
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
from point_e.diffusion.sampler import PointCloudSampler
from point_e.models.configs import MODEL_CONFIGS, model_from_config
from point_e.models.download import load_checkpoint
from point_e.util.plotting import plot_point_cloud

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Build the 'base40M-textvec' model and its matching diffusion process from
# point-e's config tables, switch the model to inference mode, then load the
# pretrained weights returned by load_checkpoint for the same config name.
base_name = 'base40M-textvec'
base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[base_name])
base_model = model_from_config(MODEL_CONFIGS[base_name], device)
base_model.eval()
base_model.load_state_dict(load_checkpoint(base_name, device))

<All keys matched successfully>

In [3]:
# Single-stage sampler: one base model / diffusion pair (no upsampler stage),
# 1024 points per cloud with R, G, B auxiliary channels, guidance scale 3.0.
# model_kwargs_key_filter presumably restricts the kwargs forwarded to the
# model to the 'texts' prompts — confirm against PointCloudSampler.
sampler = PointCloudSampler(
    device=device,
    models=[base_model],
    diffusions=[base_diffusion],
    num_points=[1024],
    aux_channels=list('RGB'),
    guidance_scale=[3.0],
    model_kwargs_key_filter=['texts'],
)

In [9]:
# Values assigned to `sampler.experimental_t` in the sweep below.
# NOTE(review): what experimental_t controls depends on the modified sampler — confirm.
experimental_ts = [30, 35, 40, 45]

# (base prompt, edited prompt) pairs; the sweep samples each pair together as
# one batch of two so the resulting clouds can be compared side by side.
prompt_pairs = [
    ('a chair', 'an armchair'),
    ('a chair', 'a chair with long legs'),
    ('a chair', 'a chair with a square backrest'),
    ('a chair', 'a chair with a thick seat'),  # was misspelled "thich"
    ('an office chair', 'an office chair with wheels'),
]

In [10]:
def plot_pc(sampler, samples, j, output_dir, experimental_t, prompt):
    """Render sample ``j`` of a batch as a point-cloud figure and save it as a PNG.

    Args:
        sampler: sampler whose ``output_to_point_clouds`` converts ``samples``
            into a sequence of point clouds.
        samples: final value yielded by ``sampler.sample_batch_progressive``.
        j: index of the sample within the batch to plot.
        output_dir: directory the PNG is written to (expected to exist already).
        experimental_t: value used as the filename prefix.
        prompt: text prompt of sample ``j``; spaces become underscores in the
            filename.

    The image is saved as ``{output_dir}/{experimental_t}_{prompt}.png`` with
    spaces in ``prompt`` replaced by underscores.
    """
    pc = sampler.output_to_point_clouds(samples)[j]
    fig = plot_point_cloud(pc, color=False)
    try:
        filename = f'{experimental_t}_{prompt.replace(" ", "_")}.png'
        fig.savefig(os.path.join(output_dir, filename))
    finally:
        # Close this specific figure (not merely the "current" one), even if
        # saving fails, so figures don't pile up across the sweep's many calls.
        plt.close(fig)

In [11]:
# Sweep: for every prompt pair and every experimental_t, sample N_RUNS batches
# of two point clouds (one per prompt), save the final clouds as PNGs, move the
# files the sampler leaves in SCRATCH_DIR into the run's folder, and build an
# HTML comparison table. output.html is written even if the sweep is interrupted
# or fails part-way, so completed runs are not lost.
#
# NOTE(review): `experimental_t` and SCRATCH_DIR are not part of upstream
# point-e's PointCloudSampler; this assumes a modified sampler that reads
# `sampler.experimental_t` and writes '{t}_movie_{j}.gif' files into SCRATCH_DIR
# on every run — confirm against the sampler actually in use.
N_RUNS = 10                           # independent samples per (pair, t)
SCRATCH_DIR = 'experimental_sampler'  # where the sampler drops its files

html = "<table>\n"
try:
    for prompt1, prompt2 in prompt_pairs:
        # Must match the filename scheme used by plot_pc.
        slug1 = prompt1.replace(" ", "_")
        slug2 = prompt2.replace(" ", "_")
        for experimental_t in experimental_ts:
            # Escape prompt text so characters like '<' or '&' can't break the table.
            label1 = escape(f'{prompt1} (t={experimental_t})')
            label2 = escape(f'{prompt2} (t={experimental_t})')
            for i in range(N_RUNS):
                # Re-assigned before every run in case the sampler mutates it.
                sampler.experimental_t = experimental_t
                output_dir = os.path.join('experiment1', f'{slug1}_{slug2}_{i}')
                os.makedirs(output_dir, exist_ok=True)
                samples = None
                progress = sampler.sample_batch_progressive(
                    batch_size=2, model_kwargs=dict(texts=[prompt1, prompt2]))
                for x in tqdm(progress, desc=f't={experimental_t} run {i}'):
                    samples = x  # only the final step's output is kept
                if samples is None:
                    raise RuntimeError('sampler produced no output')
                for j, prompt in enumerate([prompt1, prompt2]):
                    plot_pc(sampler, samples, j, output_dir, experimental_t, prompt)
                for file in os.listdir(SCRATCH_DIR):
                    shutil.move(os.path.join(SCRATCH_DIR, file),
                                os.path.join(output_dir, file))
                # Column order: prompt1 PNG, prompt1 GIF, prompt2 PNG, prompt2 GIF
                # (movie_j assumed to correspond to batch index j).
                labels = [label1, label1, label2, label2]
                srcs = [
                    f'{output_dir}/{experimental_t}_{slug1}.png',
                    f'{output_dir}/{experimental_t}_movie_0.gif',
                    f'{output_dir}/{experimental_t}_{slug2}.png',
                    f'{output_dir}/{experimental_t}_movie_1.gif',
                ]
                html += '<tr>' + ''.join(
                    f'<td><font size="5">{label}</font></td>' for label in labels
                ) + '</tr>\n'
                html += '<tr>' + ''.join(
                    f"<td><img src='{escape(src)}'></td>" for src in srcs
                ) + '</tr>\n'
finally:
    html += "</table>"
    with open('output.html', 'w', encoding='utf-8') as f:
        f.write(html)

65it [00:14,  4.53it/s]
65it [00:12,  5.13it/s]
65it [00:11,  5.62it/s]
65it [00:10,  6.10it/s]
65it [00:13,  4.70it/s]
65it [00:12,  5.30it/s]
65it [00:11,  5.44it/s]
65it [00:10,  6.30it/s]
