In [1]:
import torch

from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config
from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget
from shap_e.util.image_util import load_image

import os

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# load models

In [4]:
# xm is for rendering latents
xm = load_model('transmitter', device=device)
model = load_model('image300M', device=device)
diffusion = diffusion_from_config(load_config('diffusion'))

# compute latents from image
- [-] As of March 25 2024, this cell is running very slow (19 minutes and still going for a single batch). I remember this step taking far less. Is this a problem with hyperparameters or a temporary issue with the kernel?
    - [-] The hyperparameters are at default values.
    - [-] Rendering interpolation takes a lot longer than rendering for a single object
## computing time
| batch size | time |
| --- | --- |
| 1 | 3m, 26s |


In [42]:
create_latent = True
batch_size = 1
guidance_scale = 3.0

# To get the best result, you should remove the background and show only the object of interest to the model.
# image = load_image("example_data/corgi.png")
# image = load_image("example_data/cube.png")
image = load_image("../../../content/cube_tall_bbg.png")
image = load_image("../../../content/cube_bbg.png")

if create_latent:
    latents = sample_latents(
        batch_size=batch_size,
        model=model,
        diffusion=diffusion,
        guidance_scale=guidance_scale,
        model_kwargs=dict(images=[image] * batch_size),
        progress=True,
        clip_denoised=True,
        use_fp16=True,
        use_karras=True,
        karras_steps=64,
        sigma_min=1e-3,
        sigma_max=160,
        s_churn=0,
    )

  0%|          | 0/64 [00:00<?, ?it/s]

# latent loading, manipulation, and visualization

In [9]:

# load latents from file
if 0:
    latents_cube = torch.load('../../latents/cube_latents.pt')
    latents_cube_tall = torch.load('../../latents/cube_tall_latents.pt')
    # print minimum and maximum values of latents
    print(f'latents_cube min: {latents_cube.min()}')
    print(f'latents_cube max: {latents_cube.max()}')
    print(f'latents_cube_tall min: {latents_cube_tall.min()}')
    print(f'latents_cube_tall max: {latents_cube_tall.max()}')

In [11]:

# function for plotting a histogram of latents
def plot_hist(latents):
    import matplotlib.pyplot as plt
    plt.hist(latents.flatten().cpu().numpy(), bins=500)
    plt.show()

# function for interpolating among two latent vectors
def interpolate_latents(lat_A, lat_B, intp_steps):
    latents = []
    for i in range(intp_steps):
        latents.append(lat_A + (lat_B - lat_A) * i / intp_steps)
    return latents

# function for extrapolating from a latent vector
def extrapolate_latents(lat_A, lat_B, extp_steps):
    latents = []
    for i in range(extp_steps):
        latents.append(lat_B + (lat_B - lat_A) * i / extp_steps)
    return latents

# function for extracting transformation from two latent vectors
def extract_transformation(lat_A, lat_B):
    return lat_B - lat_A


In [None]:
latent_transform = extract_transformation(latents_cube[0], latents_cube_tall[0])
# print the indices of five largest values in the latent_transform vector
print(torch.topk(latent_transform, 5))

plot_hist(latents_cube)
plot_hist(latent_transform)

In [None]:
# interpolate between two latents
intp_latents = interpolate_latents(latents_cube[3], latents_cube_tall[0], 20)
extp_latents = extrapolate_latents(latents_cube_tall[0], latents_cube[0], 10)

# render latents

## *render_transformation* function below took:

| size | samples | time |
| --- | --- | --- |
| 32 | 5 | 15m21.0s |
| 32 | 10 | 38m25.5s |
| 32 | 10 | 54m50.5s |
|___|___|___|
| 64 | 20 | 14m5.7s |
| 64 | 20 | 19m40.4s |
| 64 | 20 | 31m.52.9s |

In [43]:
render_mode = 'nerf' # you can change this to 'stf' for mesh rendering
# size = 64 # this is the size of the renders; higher values take longer to render.
size = 32 # this is the size of the renders; higher values take longer to render. 16 causes an assertion error.
cameras = create_pan_cameras(size, device)

In [None]:
# function for rendering interpolated or extrapolated latents
def render_latents(latents):
    for i, latent in enumerate(latents):
        images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)
    return images
    # images = []
    # for latent in latents:
    #     images.append(decode_latent_images(xm, latent, cameras, rendering_mode=render_mode))
    #     # images[0].save('../../../content/{}_{}.gif'.format(name, i), save_all=True, append_images=images[1:], duration=100, loop=0)
    # return images

# function for rendering the interpolation between two latents as a single gif
def render_transformation(latents, name):
    # render the first frame from the first latent, second frame from the second latent, and so on
    images = []
    for i, latent in enumerate(latents):
        images.append(decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)[0])
        
    # write images to a gif
    images[0].save('../../../content/{}.gif'.format(name), save_all=True, append_images=images[1:], duration=100, loop=0)

In [None]:
if 0:
    render_transformation(intp_latents[0:20], 'transformation_render_test_20_64_3')

if 0:
    # render_latents(extp_latents, size, 'cube_intp')
    images = render_latents(latents)

In [46]:
# images[0].save('../../../content/{}_{}.gif'.format('cube_debug', '0'), save_all=True, append_images=images[1:], duration=100, loop=0)

- 17 minutes for size = 128
- 12 minutes for size = 64
- 50 seconds for size = 32

# save latents to file

In [47]:
# save latents to file
if 0:
    torch.save(latents, '../../latents/cube_latents_bbg.pt')

# Automation for folders

In [55]:
input_folder = "../../../content/batch_sourceimages"
output_folder = "../../../content/batch_output"

image_files = [file for file in os.listdir(input_folder) if file.lower().endswith(('.png', '.jpg', '.jpeg'))]
print(image_files)
image_path = os.path.join(input_folder, image_files[0])
print(image_path)

['cube_tall_bbg.png', 'cube_bbg.png']
../../../content/batch_sourceimages/cube_tall_bbg.png


In [None]:
def process_images(input_folder, output_folder):
    # Get a list of image files in the input folder
    image_files = [file for file in os.listdir(input_folder) if file.lower().endswith(('.png', '.jpg', '.jpeg'))]

    for image_file in image_files:
        # Load the image
        image_path = os.path.join(input_folder, image_file)
        image = load_image(image_path)

        batch_size = 1
        guidance_scale = 3.0

        # computing latent
        latent_vector = sample_latents(
            batch_size=batch_size,
            model=model,
            diffusion=diffusion,
            guidance_scale=guidance_scale,
            model_kwargs=dict(images=[image] * batch_size),
            progress=True,
            clip_denoised=True,
            use_fp16=True,
            use_karras=True,
            karras_steps=64,
            sigma_min=1e-3,
            sigma_max=160,
            s_churn=0,
        )

        gif = render_latents(latent_vector)

        index = 0

        # Save latent vector
        latent_vector_path = os.path.join(output_folder, f"{os.path.splitext(image_file)[0]}_{index}.npy")

        # Check if a file with same name and index exists and increment index and file name if it does
        while os.path.exists(latent_vector_path):
            index += 1
            latent_vector_path = os.path.join(output_folder, f"{os.path.splitext(image_file)[0]}_{index}.npy")
        torch.save(latent_vector, latent_vector_path)
        del(latent_vector)

        # Save GIF
        gif_path = os.path.join(output_folder, f"{os.path.splitext(image_file)[0]}_{index}.gif")
        gif[0].save(gif_path, save_all=True, append_images=images[1:], duration=100, loop=0)
        del(gif)


In [57]:
# Example usage
input_folder = "../../../content/batch_sourceimages"
output_folder = "../../../content/batch_output"
process_images(input_folder, output_folder)

  0%|          | 0/64 [00:00<?, ?it/s]

: 

: 