In [None]:
import torch

from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config
from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget

In [None]:
# Use the GPU when torch reports one is available; otherwise run on the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

In [None]:
# Load the two named models onto `device`, in this order: `xm` is the object
# later cells render with, `model` is the one passed to the sampler.
xm, model = (
    load_model(model_name, device=device)
    for model_name in ('transmitter', 'text300M')
)
# Diffusion object built from the 'diffusion' config; also passed to the sampler.
diffusion = diffusion_from_config(load_config('diffusion'))

In [None]:
# Knobs a reader is most likely to change.
prompt = "taylor swift in a red dress"
batch_size = 4
guidance_scale = 15.0

# Every keyword handed to sample_latents, gathered in one place.
# The prompt is repeated so there is one text entry per requested sample.
sampler_settings = dict(
    batch_size=batch_size,
    model=model,
    diffusion=diffusion,
    guidance_scale=guidance_scale,
    model_kwargs={"texts": [prompt for _ in range(batch_size)]},
    progress=True,
    clip_denoised=True,
    use_fp16=True,
    use_karras=True,
    karras_steps=64,
    sigma_min=1e-3,
    sigma_max=160,
    s_churn=0,
)

latents = sample_latents(**sampler_settings)

In [None]:
# Preview every sampled latent as a GIF widget, rendered in 'nerf' mode.
size = 64  # size of the renders; higher values take longer to render
render_mode = 'nerf'  # 'stf' is the other mode used in this notebook

cameras = create_pan_cameras(size, device)
for i, latent in enumerate(latents):
    # One list of frames per latent, shown immediately.
    images = decode_latent_images(
        xm,
        latent,
        cameras,
        rendering_mode=render_mode,
    )
    display(gif_widget(images))

In [None]:
# Same preview as the 'nerf' cell, but rendered in 'stf' mode.
# NOTE(review): the mesh-export cells further down reuse `render_mode` and read
# 'meshes' / 'raw_meshes' from the render result -- presumably only available
# in 'stf' mode; confirm before switching this back to 'nerf'.
size = 64  # size of the renders; higher values take longer to render
render_mode = 'stf'  # you can change this to 'nerf'

cameras = create_pan_cameras(size, device)
for i, latent in enumerate(latents):
    # One list of frames per latent, shown immediately.
    images = decode_latent_images(
        xm,
        latent,
        cameras,
        rendering_mode=render_mode,
    )
    display(gif_widget(images))

In [None]:
# Render the first sampled latent by calling the renderer directly, keeping
# whatever `render_views` returns in `decoded` rather than only image frames.
rendering_mode = 'stf'
size = 64
cameras = create_pan_cameras(size, device)
latent = latents[0]
from shap_e.models.transmitter.base import Transmitter, VectorDecoder
from shap_e.util.collections import AttrDict

# A Transmitter exposes `bottleneck_to_params` on its encoder; any other
# object is expected to provide it directly.
params_source = xm.encoder if isinstance(xm, Transmitter) else xm

# Inference only: disable autograd so the result does not keep a computation
# graph alive. This mirrors the @torch.no_grad() decorator on
# decode_latent_images_foo, which performs the same call.
with torch.no_grad():
    decoded = xm.renderer.render_views(
        AttrDict(cameras=cameras),
        # `latent[None]` adds a leading axis of length one in front of the single latent.
        params=params_source.bottleneck_to_params(latent[None]),
        options=AttrDict(rendering_mode=rendering_mode, render_with_direction=False),
    )

In [None]:
import base64
import io
from typing import Union

import ipywidgets as widgets
import numpy as np
import torch
from PIL import Image

from shap_e.models.nn.camera import DifferentiableCameraBatch, DifferentiableProjectiveCamera
from shap_e.models.transmitter.base import Transmitter, VectorDecoder
from shap_e.util.collections import AttrDict

@torch.no_grad()
def decode_latent_images_foo(
    xm: Union[Transmitter, VectorDecoder],
    latent: torch.Tensor,
    cameras: DifferentiableCameraBatch,
    rendering_mode: str = "stf",
):
    """
    Render one latent through ``xm.renderer`` and return the raw render output.

    Nothing is post-processed: the object produced by ``render_views`` is
    handed back untouched so callers can read any field from it (later cells
    index it with ``'meshes'`` and ``'raw_meshes'``). The image-conversion
    statements that used to follow the ``return`` could never run and have
    been removed; use ``decode_latent_images`` when image frames are wanted.

    :param xm: object exposing ``renderer``. If it is a ``Transmitter`` its
        ``encoder`` supplies ``bottleneck_to_params``; otherwise ``xm`` itself
        must provide that method.
    :param latent: a single latent; a leading axis is added via ``latent[None]``
        before it is converted to renderer parameters.
    :param cameras: cameras placed in the batch given to the renderer.
    :param rendering_mode: forwarded as the ``rendering_mode`` render option.
    :return: whatever ``xm.renderer.render_views`` returns.
    """
    # Pick the object that knows how to turn a latent into renderer parameters.
    params_source = xm.encoder if isinstance(xm, Transmitter) else xm
    decoded = xm.renderer.render_views(
        AttrDict(cameras=cameras),
        params=params_source.bottleneck_to_params(latent[None]),
        options=AttrDict(rendering_mode=rendering_mode, render_with_direction=False),
    )
    return decoded

In [None]:
# Raw render output for the first latent; later cells read x['meshes'] and
# x['raw_meshes'] from it.
x = decode_latent_images_foo(
    xm=xm,
    latent=latents[0],
    cameras=cameras,
    rendering_mode=render_mode,
)

In [None]:
# First entry of the 'meshes' field; its vertex_colors are read further down.
mesh = x['meshes'][0]

In [None]:
# First entry of the 'raw_meshes' field; this is the object that gets exported.
rm = x['raw_meshes'][0]

In [None]:
# Copy columns 0, 1, 2 of mesh.vertex_colors onto the raw mesh as its
# "R", "G" and "B" vertex channels, in that order.
for column, channel_name in enumerate("RGB"):
    rm.vertex_channels[channel_name] = mesh.vertex_colors[:, column]

In [None]:
# Convert the raw mesh via its tri_mesh() method; the result is what gets written to disk.
tm = rm.tri_mesh()

In [None]:
# Write the mesh to "yoda.ply" in the working directory (binary mode).
with open("yoda.ply", mode='wb') as ply_out:
    tm.write_ply(ply_out)

In [None]:
import trimesh

def convert_ply_to_gltf(ply_file, gltf_file):
    """
    Load the mesh stored at ``ply_file`` and save it to ``gltf_file``.

    The mesh is always exported with ``file_type='glb'``, so the output file
    holds GLB data whatever extension ``gltf_file`` carries.
    """
    # Load and serialise in one step; export() hands back the encoded bytes.
    glb_bytes = trimesh.load_mesh(ply_file).export(file_type='glb')

    # Write the encoded bytes to the destination path.
    with open(gltf_file, 'wb') as out:
        out.write(glb_bytes)

# Input and output paths -- replace these with your own files.
input_ply_file, output_gltf_file = "yoda.ply", "yoda.glb"

# Run the conversion; the output is written in GLB form.
convert_ply_to_gltf(ply_file=input_ply_file, gltf_file=output_gltf_file)

In [None]:
import pytorch3d