In [1]:
#sv3d_p import
import math
import os
import sys
from glob import glob
from pathlib import Path
from typing import List, Optional

import cv2
import imageio
import numpy as np
import torch
from einops import rearrange, repeat
from fire import Fire
from omegaconf import OmegaConf
from PIL import Image
from rembg import remove
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
from sgm.inference.helpers import embed_watermark
from sgm.util import default, instantiate_from_config
from torchvision.transforms import ToTensor

In [2]:
def get_unique_embedder_keys_from_conditioner(conditioner):
    return list(set([x.input_key for x in conditioner.embedders]))


def get_batch(keys, value_dict, N, T, device):
    batch = {}
    batch_uc = {}

    for key in keys:
        if key == "fps_id":
            batch[key] = (
                torch.tensor([value_dict["fps_id"]])
                .to(device)
                .repeat(int(math.prod(N)))
            )
        elif key == "motion_bucket_id":
            batch[key] = (
                torch.tensor([value_dict["motion_bucket_id"]])
                .to(device)
                .repeat(int(math.prod(N)))
            )
        elif key == "cond_aug":
            batch[key] = repeat(
                torch.tensor([value_dict["cond_aug"]]).to(device),
                "1 -> b",
                b=math.prod(N),
            )
        elif key == "cond_frames" or key == "cond_frames_without_noise":
            batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=N[0])
        elif key == "polars_rad" or key == "azimuths_rad":
            batch[key] = torch.tensor(value_dict[key]).to(device).repeat(N[0])
        else:
            batch[key] = value_dict[key]

    if T is not None:
        batch["num_video_frames"] = T

    for key in batch.keys():
        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
            batch_uc[key] = torch.clone(batch[key])
    return batch, batch_uc


def load_model(
    config: str,
    device: str,
    num_frames: int,
    num_steps: int,
    verbose: bool = False,
):
    config = OmegaConf.load(config)
    if device == "cuda":
        config.model.params.conditioner_config.params.emb_models[
            0
        ].params.open_clip_embedding_config.params.init_device = device

    config.model.params.sampler_config.params.verbose = verbose
    config.model.params.sampler_config.params.num_steps = num_steps
    config.model.params.sampler_config.params.guider_config.params.num_frames = (
        num_frames
    )
    if device == "cuda":
        with torch.device(device):
            model = instantiate_from_config(config.model).to(device).eval()
    else:
        model = instantiate_from_config(config.model).to(device).eval()

    filter = DeepFloydDataFiltering(verbose=False, device=device)
    return model, filter

In [3]:
root_dir='/home/luoziqian/Works/Baseline'
input_path=os.path.join(root_dir,"test_model_images/239.png")
num_frames=21
model_config=os.path.join(root_dir,'configs/sv3d_p.yaml')
num_steps=50
elevations_deg=[10]*num_frames
polars_rad=[np.deg2rad(90-e) for e in elevations_deg]
azimuths_deg=np.linspace(0,360,num_frames+1)[1:]%360
azimuths_rad=[np.deg2rad((a-azimuths_deg[-1])%360) for a in azimuths_deg]
image_frame_ratio=None
motion_bucket_id=127
fps_id=6
cond_aug=0.02
decoding_t=6
device='cuda'
seed=0

In [4]:
model,filter=load_model(
    model_config,
    device,
    num_frames,
    num_steps,
    verbose=False,
)
torch.manual_seed(seed)

VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing
VideoTransformerBlock is using checkpointing
Initialized embedder #0: FrozenOpenCLIPImagePredictionEmbedder with 683800065 params. Trainable: False
Initialized embedder #1: VideoPredictionEmbedderWithEncoder with 83653863 params. Trainable: False
Initialized embedder #2: ConcatTimestepEmbedderND with 0 params. Trainable: Fa

<torch._C.Generator at 0x79020e6df550>

In [5]:
assert input_path is not None
assert os.path.exists(os.path.join(root_dir, input_path)) or os.path.exists(input_path)

In [6]:
if os.path.exists(os.path.join(root_dir, input_path)):
    image=Image.open(input_path)
elif os.path.exists(input_path):
    image=Image.open(os.path.join(root_dir, input_path))
else:
    raise FileNotFoundError(f"Could not find file {input_path}")

if image.mode == "RGBA":
    pass
else:
    # remove bg
    image.thumbnail([768, 768], Image.Resampling.LANCZOS)
    image = remove(image.convert("RGBA"), alpha_matting=True)

# resize object in frame
image_arr = np.array(image)
in_w, in_h = image_arr.shape[:2]
ret, mask = cv2.threshold(
    np.array(image.split()[-1]), 0, 255, cv2.THRESH_BINARY
)
x, y, w, h = cv2.boundingRect(mask)
max_size = max(w, h)
side_len = (
    int(max_size / image_frame_ratio)
    if image_frame_ratio is not None
    else in_w
)
padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
center = side_len // 2
padded_image[
    center - h // 2 : center - h // 2 + h,
    center - w // 2 : center - w // 2 + w,
] = image_arr[y : y + h, x : x + w]
# resize frame to 576x576
rgba = Image.fromarray(padded_image).resize((576, 576), Image.LANCZOS)
# white bg
rgba_arr = np.array(rgba) / 255.0
rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:])
input_image = Image.fromarray((rgb * 255).astype(np.uint8))


In [7]:
image = ToTensor()(input_image)
image = image * 2.0 - 1.0

image = image.unsqueeze(0).to(device)
H, W = image.shape[2:]
assert image.shape[1] == 3
F = 8
C = 4
shape = (num_frames, C, H // F, W // F)

In [8]:
value_dict = {}
value_dict["cond_frames_without_noise"] = image
value_dict["motion_bucket_id"] = motion_bucket_id
value_dict["fps_id"] = fps_id
value_dict["cond_aug"] = cond_aug
value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
value_dict["polars_rad"] = polars_rad
value_dict["azimuths_rad"] = azimuths_rad

In [9]:
with torch.no_grad():
    with torch.autocast(device):
        batch, batch_uc = get_batch(
            get_unique_embedder_keys_from_conditioner(model.conditioner),
            value_dict,
            [1, num_frames],
            T=num_frames,
            device=device,
        )
        c, uc = model.conditioner.get_unconditional_conditioning(
            batch,
            batch_uc=batch_uc,
            force_uc_zero_embeddings=[
                "cond_frames",
                "cond_frames_without_noise",
            ],
        )

        for k in ["crossattn", "concat"]:
            uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
            uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
            c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
            c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)

        randn = torch.randn(shape, device=device)

        additional_model_inputs = {}
        additional_model_inputs["image_only_indicator"] = torch.zeros(
            2, num_frames
        ).to(device)
        additional_model_inputs["num_video_frames"] = batch["num_video_frames"]

        def denoiser(input, sigma, c):
            return model.denoiser(
                model.model, input, sigma, c, **additional_model_inputs
            )

        samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
        model.en_and_decode_n_samples_a_time = decoding_t
        samples_x = model.decode_first_stage(samples_z)

        samples_x[-1:] = value_dict["cond_frames_without_noise"]
        samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

        samples = embed_watermark(samples)
        samples = filter(samples) # t c h w [21,3,576,576]



In [10]:
model= model.cpu()
torch.cuda.empty_cache()

In [11]:
#LGM import
import os
import tyro
import imageio
import numpy as np
import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from safetensors.torch import load_file
import rembg
import gradio as gr

import kiui
from kiui.op import recenter
from kiui.cam import orbit_camera

from core.options import AllConfigs, Options,config_defaults
from core.models import LGM
from mvdream.pipeline_mvdream import MVDreamPipeline
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)



In [12]:
output_ply_path='pointclouds.ply'
output_video_path='video.mp4'

In [13]:
opt = config_defaults['big']
opt.resume='pretrained/model_fp16.safetensors'
opt.num_frames=21

# model
model = LGM(opt)
# resume pretrained checkpoint
if opt.resume is not None:
    if opt.resume.endswith('safetensors'):
        ckpt = load_file(opt.resume, device='cpu')
    else:
        ckpt = torch.load(opt.resume, map_location='cpu')
    model.load_state_dict(ckpt, strict=False)
    print(f'[INFO] Loaded checkpoint from {opt.resume}')
else:
    print(f'[WARN] model randomly initialized, are you sure?')

[INFO] Loaded checkpoint from pretrained/model_fp16.safetensors


In [14]:
model.unet.mid_block.attns[0].num_frames

21

In [15]:
model = model.half().to(device)
model.eval()

LGM(
  (unet): UNet(
    (conv_in): Conv2d(9, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down_blocks): ModuleList(
      (0): DownBlock(
        (nets): ModuleList(
          (0-1): 2 x ResnetBlock(
            (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
            (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (shortcut): Identity()
          )
        )
        (attns): ModuleList(
          (0-1): 2 x None
        )
        (downsample): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      )
      (1): DownBlock(
        (nets): ModuleList(
          (0): ResnetBlock(
            (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
            (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 12

In [16]:
tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=device)
proj_matrix[0, 0] = 1 / tan_half_fov
proj_matrix[1, 1] = 1 / tan_half_fov
proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
proj_matrix[2, 3] = 1

In [17]:
input_image = F.interpolate(samples, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False)
input_image = TF.normalize(input_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
rays_embeddings = model.prepare_default_rays(device,num_frames=opt.num_frames,elevation=10)
input_image = torch.cat([input_image, rays_embeddings], dim=1).unsqueeze(0) # [1, 4, 9, H, W]

In [18]:
with torch.no_grad():
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        # generate gaussians
        gaussians = model.forward_gaussians(input_image)
    
    # save gaussians
    model.gs.save_ply(gaussians, output_ply_path)

In [25]:
num_gaussian_sigle_view:int=int(gaussians.shape[1]/21)
num_gaussian_sigle_view

16384

In [27]:
# render 360 video 
images = []
elevation = 0
if opt.fancy_video:
    azimuth = np.arange(0, 720, 4, dtype=np.int32)
    for azi in tqdm.tqdm(azimuth):
        
        cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)

        cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
        
        # cameras needed by gaussian rasterizer
        cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
        cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
        cam_pos = - cam_poses[:, :3, 3] # [V, 3]

        scale = min(azi / 360, 1)

        image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=scale)['image']
        images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
else:
    azimuth = np.arange(0, 360, 2, dtype=np.int32)
    for azi in tqdm.tqdm(azimuth):
        
        cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)

        cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
        
        # cameras needed by gaussian rasterizer
        cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
        cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
        cam_pos = - cam_poses[:, :3, 3] # [V, 3]

        image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=1)['image']
        images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))

images = np.concatenate(images, axis=0)
imageio.mimwrite(output_video_path, images, fps=30)

100%|██████████| 180/180 [00:00<00:00, 653.79it/s]


In [28]:
video_views_dir='video_views'
os.makedirs(video_views_dir, exist_ok=True)
for i in range(21):
    # render 360 video 
    output_video_path=os.path.join(video_views_dir,f'video_view_{i}.mp4')
    images = []
    elevation = 0
    if opt.fancy_video:
        azimuth = np.arange(0, 720, 4, dtype=np.int32)
        for azi in tqdm.tqdm(azimuth):
            
            cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)

            cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
            
            # cameras needed by gaussian rasterizer
            cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
            cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
            cam_pos = - cam_poses[:, :3, 3] # [V, 3]

            scale = min(azi / 360, 1)

            image = model.gs.render(gaussians[:,num_gaussian_sigle_view*i:num_gaussian_sigle_view*(i+1)], cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=scale)['image']
            images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
    else:
        azimuth = np.arange(0, 360, 2, dtype=np.int32)
        for azi in tqdm.tqdm(azimuth):
            
            cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)

            cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
            
            # cameras needed by gaussian rasterizer
            cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
            cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
            cam_pos = - cam_poses[:, :3, 3] # [V, 3]

            image = model.gs.render(gaussians[:,num_gaussian_sigle_view*i:num_gaussian_sigle_view*(i+1)], cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=1)['image']
            images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))

    images = np.concatenate(images, axis=0)
    imageio.mimwrite(output_video_path, images, fps=30)

100%|██████████| 180/180 [00:00<00:00, 700.87it/s]
100%|██████████| 180/180 [00:00<00:00, 722.50it/s]
100%|██████████| 180/180 [00:00<00:00, 725.49it/s]
100%|██████████| 180/180 [00:00<00:00, 714.69it/s]
100%|██████████| 180/180 [00:00<00:00, 722.07it/s]
100%|██████████| 180/180 [00:00<00:00, 737.53it/s]
100%|██████████| 180/180 [00:00<00:00, 733.72it/s]
100%|██████████| 180/180 [00:00<00:00, 730.55it/s]
100%|██████████| 180/180 [00:00<00:00, 708.79it/s]
100%|██████████| 180/180 [00:00<00:00, 759.82it/s]
100%|██████████| 180/180 [00:00<00:00, 724.54it/s]
100%|██████████| 180/180 [00:00<00:00, 734.27it/s]
100%|██████████| 180/180 [00:00<00:00, 719.05it/s]
100%|██████████| 180/180 [00:00<00:00, 730.44it/s]
100%|██████████| 180/180 [00:00<00:00, 737.22it/s]
100%|██████████| 180/180 [00:00<00:00, 719.50it/s]
100%|██████████| 180/180 [00:00<00:00, 754.73it/s]
100%|██████████| 180/180 [00:00<00:00, 748.54it/s]
100%|██████████| 180/180 [00:00<00:00, 763.60it/s]
100%|██████████| 180/180 [00:00