In [1]:
import torch

from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config
from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget
from shap_e.util.image_util import load_image

import os

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [3]:
render_mode = 'nerf' # you can change this to 'stf' for mesh rendering
# size = 64 # this is the size of the renders; higher values take longer to render.
size = 32 # this is the size of the renders; higher values take longer to render. 16 causes an assertion error.
cameras = create_pan_cameras(size, device)

In [4]:
input_folder = "../../../content/batch_web"
output_folder = "../../../content/batch_output"

image_files = [file for file in os.listdir(input_folder) if file.lower().endswith(('.png', '.jpg', '.jpeg'))]
print(image_files)
image_path = os.path.join(input_folder, image_files[0])
print(image_path)

['muffin_blueberry.png', 'chair_blue_rocking_fabric.png', 'chair_blue_rocking.png', 'chair_blue_rocking_outdoor.png']
../../../content/batch_web/muffin_blueberry.png


In [5]:
xm = load_model('transmitter', device=device) # rendering latents
model = load_model('image300M', device=device)
diffusion = diffusion_from_config(load_config('diffusion'))

In [6]:
def render_latents(latents):
# function for rendering interpolated or extrapolated latents
    for i, latent in enumerate(latents):
        images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)
    return images
    # images = []
    # for latent in latents:
    #     images.append(decode_latent_images(xm, latent, cameras, rendering_mode=render_mode))
    #     # images[0].save('../../../content/{}_{}.gif'.format(name, i), save_all=True, append_images=images[1:], duration=100, loop=0)
    # return images

def render_transformation(latents, name):
# function for rendering the interpolation between two latents as a single gif
    # render the first frame from the first latent, second frame from the second latent, and so on
    images = []
    for i, latent in enumerate(latents):
        images.append(decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)[0])
        
    # write images to a gif
    images[0].save('../../../content/{}.gif'.format(name), save_all=True, append_images=images[1:], duration=100, loop=0)

def process_images(input_folder, output_folder):
    # Get a list of image files in the input folder
    image_files = [file for file in os.listdir(input_folder) if file.lower().endswith(('.png', '.jpg', '.jpeg'))]

    for image_file in image_files:
        # Load the image
        image_path = os.path.join(input_folder, image_file)
        image = load_image(image_path)

        batch_size = 1
        guidance_scale = 3.0

        # computing latent
        latent_vector = sample_latents(
            batch_size=batch_size,
            model=model,
            diffusion=diffusion,
            guidance_scale=guidance_scale,
            model_kwargs=dict(images=[image] * batch_size),
            progress=True,
            clip_denoised=True,
            use_fp16=True,
            use_karras=True,
            karras_steps=64,
            sigma_min=1e-3,
            sigma_max=160,
            s_churn=0,
        )

        gif = render_latents(latent_vector)

        index = 0

        # Save latent vector
        latent_vector_path = os.path.join(output_folder, f"{os.path.splitext(image_file)[0]}_{index}.npy")

        # Check if a file with same name and index exists and increment index and file name if it does
        while os.path.exists(latent_vector_path):
            index += 1
            print(f"File {latent_vector_path} already exists, incrementing index to {index}")
            latent_vector_path = os.path.join(output_folder, f"{os.path.splitext(image_file)[0]}_{index}.npy")
        torch.save(latent_vector, latent_vector_path)
        print(f"Saved latent vector to {latent_vector_path}")
        del(latent_vector)

        # Save GIF
        gif_path = os.path.join(output_folder, f"{os.path.splitext(image_file)[0]}_{index}.gif")
        gif[0].save(gif_path, save_all=True, append_images=gif[1:], duration=100, loop=0)
        print(f"Saved gif to {gif_path}")
        del(gif)


In [None]:
for i in range(2):
    process_images(input_folder, output_folder)