In [1]:
from pathlib import Path

import imageio
import torch

from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.diffusion.sample import sample_latents
from shap_e.models.download import load_model, load_config
from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget

In [2]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [3]:
# Load the pieces used by the rest of the notebook onto the selected device:
#   xm        - the 'decoder' model, passed to the latent decoding helpers
#   model     - the 'text300M' model, passed to sample_latents
#   diffusion - diffusion process built from the 'diffusion' config
xm = load_model('decoder', device=device)
model = load_model('text300M', device=device)

diffusion_config = load_config('diffusion')
diffusion = diffusion_from_config(diffusion_config)

In [10]:
# Sample one latent per text prompt.
#
# The sampling is stochastic and its results are saved and reused further down,
# so the RNG is seeded to make a Restart & Run All reproduce the same latents.
# NOTE(review): assumes sample_latents draws its noise from torch's global
# generator - confirm against the shap_e implementation.
seed = 0  # change this to draw a different set of samples
torch.manual_seed(seed)

guidance_scale = 15.0
prompts = ["a cube", "a tall cube"]
batch_size = len(prompts)  # one latent per prompt

# To sample several variations of a single prompt instead, set batch_size
# explicitly and pass texts=[prompt] * batch_size.
latents = sample_latents(
    batch_size=batch_size,
    model=model,
    diffusion=diffusion,
    guidance_scale=guidance_scale,
    model_kwargs=dict(texts=prompts),
    progress=True,
    clip_denoised=True,
    use_fp16=True,
    use_karras=True,
    karras_steps=64, # 64 for 256x256, 128 for 512x512
    sigma_min=1e-3,
    sigma_max=160,
    s_churn=0,
)

  0%|          | 0/64 [00:00<?, ?it/s]

# Experiment notes:
The cell below renders .gif images of the given latents.
## October 20, 2023
- $\texttt{size = 128}$ takes over 16 minutes to render two latents.
- Should I keep using the notebook, or switch to a design where the model runs as a service and receives client requests, similar to how we did the Construct() demo for the Design X pitch day?
- Where and how do I store relevant vectors for useful latents?

In [11]:
# Render every sampled latent as a looping GIF, save it to disk and show it inline.
render_mode = 'nerf' # you can change this to 'stf'
size = 32 # this is the size of the renders; higher values take longer to render. 128 is a good value for presentation.

# Create the output directory before the (slow) decoding starts, so a missing
# folder cannot make the save fail after the render has already been computed.
render_dir = Path('../../images/experiment-20231020')
render_dir.mkdir(parents=True, exist_ok=True)

cameras = create_pan_cameras(size, device)
for i, latent in enumerate(latents):
    images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)
    image_path = str(render_dir / f'test{i}.gif')
    imageio.mimsave(image_path, images, 'GIF', loop=0)  # loop=0 is passed through to the GIF writer
    display(gif_widget(images))

HTML(value='<img src="data:image/gif;base64,R0lGODlhIAAgAIcAALTHuea4AOS4ANa4AJbNqH3SiI++UV/X0FjTdmHAsmDJeW/CcE…

HTML(value='<img src="data:image/gif;base64,R0lGODlhIAAgAIcAAIJ+i4F+jYF+i4F+ioB+jYB+ioB9jIF9iIF9h4B9iYB9iIB9h3…

In [12]:
# Sanity check: report the Python type of a single sampled latent.
latent_type = type(latents[0])
print(latent_type)

<class 'torch.Tensor'>


In [15]:
# Save the first sampled latent to disk so it can be reloaded without re-sampling.
latent_dir = Path('../../latents')
latent_dir.mkdir(parents=True, exist_ok=True)  # torch.save does not create missing folders
torch.save(latents[0], str(latent_dir / 'latent_cube.pt'))

# Kept as the last expression so the notebook actually displays the shape
# (a bare expression in the middle of a cell produces no output).
latents.shape

In [17]:
# Read the saved latent back from file.
# map_location puts the tensor on the device selected for this session, so a
# latent saved from a GPU run can still be loaded on a CPU-only machine.
latent_l = torch.load('../../latents/latent_cube.pt', map_location=device)
print(latent_l.shape)

torch.Size([1048576])


In [None]:
# Linearly interpolate between the first two sampled latents and render each step.
assert len(latents) >= 2, "latent interpolation needs at least two sampled latents"

# steps of latent interpolation (both endpoints included)
n = 8

start_latent, end_latent = latents[0], latents[1]

# Column vector of blend weights, created with the latents' own device and dtype
# so the operands of lerp always match.
weights = torch.linspace(
    0, 1, n, device=start_latent.device, dtype=start_latent.dtype
).unsqueeze(1)

# The (n, 1) weights broadcast against the 1-D latents, so the result already
# has one row per step: shape (n, latent_dim). No reshape is needed.
interpolated_latents = torch.lerp(start_latent, end_latent, weights)

size = 32

# Written relative to the notebook, like the other outputs, instead of to a
# user-specific absolute path; the folder is created if it does not exist yet.
interp_dir = Path('../../images')
interp_dir.mkdir(parents=True, exist_ok=True)

cameras = create_pan_cameras(size, device)
for i, latent in enumerate(interpolated_latents):
    images_inp = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)
    image_path_inp = str(interp_dir / f'test_interpolated{i}.gif')
    imageio.mimsave(image_path_inp, images_inp, 'GIF', loop=0)
    display(gif_widget(images_inp))

In [11]:
# Example of saving the latents as meshes (one .obj file per sampled latent).
from shap_e.util.notebooks import decode_latent_mesh

# Written relative to the notebook, like the other outputs, instead of to a
# user-specific absolute path; the folder is created if it does not exist yet.
mesh_dir = Path('../../meshes')
mesh_dir.mkdir(parents=True, exist_ok=True)

for i, latent in enumerate(latents):
    tri_mesh = decode_latent_mesh(xm, latent).tri_mesh()
    # The .obj file is written in text mode. To export .ply instead, open the
    # target file with mode 'wb' and call tri_mesh.write_ply(f).
    with open(mesh_dir / f'example_mesh_{i}.obj', 'w') as f:
        tri_mesh.write_obj(f)