# Enhance static quality

In [1]:
%load_ext autoreload
%autoreload 2

import logging
import os
import os.path as osp

import imageio
import numpy as np
import torch
from matplotlib import pyplot as plt
from tqdm import tqdm

from lib_4d.camera import SimpleFovCameras
from lib_4d.cfg_helpers import OptimCFG, GSControlCFG
from lib_4d.gs_dyn_model import DenseDynGaussian
from lib_4d.gs_geo_helpers import tsdf_meshing
from lib_4d.gs_static_model import StaticGaussian
from lib_4d.solver_gs import Solver
from lib_prior.diffusion.sd_sds import StableDiffusionSDS

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [10]:
# Scene data root and the run ids of the static / dynamic optimization logs,
# which live under `<src>/log/<run_id>`.
src = "./data/DAVIS/train"
STATIC_RUN_ID = "20240117_143638"
DYN_RUN_ID = "20240117_144846"
static_saved_dir = osp.join(src, "log", STATIC_RUN_ID)
dyn_saved_dir = osp.join(src, "log", DYN_RUN_ID)
device = torch.device("cuda:0")
# Second constructor argument of SimpleFovCameras; presumably the FOV in degrees — TODO confirm.
CAM_FOV = 40.0

solver = Solver(src, device, depth_mode="zoe")

# Static Gaussians and optimized cameras saved by the static run.
model_fn = osp.join(static_saved_dir, "joint_model.pth")
cam_fn = osp.join(static_saved_dir, "joint_cams.pth")
static_model = StaticGaussian(load_fn=model_fn).to(solver.device)
cams = SimpleFovCameras(solver.T, CAM_FOV)
# map_location makes the load independent of the device the checkpoint was saved on
# (e.g. a GPU index that does not exist on this machine).
cams.load_state_dict(torch.load(cam_fn, map_location=solver.device), strict=True)
cams.to(solver.device)

# Dynamic Gaussians saved by the dynamic run.
dyn_model_fn = osp.join(dyn_saved_dir, "joint_dyn_model_1.0res.pth")
dyn_model = DenseDynGaussian(load_fn=dyn_model_fn).to(solver.device)

| endurance | INFO | Jan-18-17:18:01 | Loading motion masks...   [solver_utils.py:117]
100%|██████████| 80/80 [00:00<00:00, 2404.77it/s]
| endurance | INFO | Jan-18-17:18:01 | Loading motion masks...   [solver_utils.py:117]
100%|██████████| 80/80 [00:00<00:00, 3217.02it/s]


| endurance | INFO | Jan-18-17:18:01 | Loading flows...   [solver_utils.py:58]
100%|██████████| 79/79 [00:00<00:00, 295.23it/s]
| endurance | INFO | Jan-18-17:18:01 | Loading rgbs...   [solver_utils.py:42]
100%|██████████| 80/80 [00:00<00:00, 303.12it/s]
| endurance | INFO | Jan-18-17:18:02 | Loading depths from ./data/DAVIS/train/zoe_depth ...   [solver_utils.py:102]
100%|██████████| 80/80 [00:00<00:00, 137.36it/s]
| endurance | INFO | Jan-18-17:18:02 | rgbs: (80, 480, 854, 3), depths: (80, 480, 854), flows: [158,(480, 854, 2)], motion_masks: (80, 480, 854)   [solver_utils.py:188]
| endurance | INFO | Jan-18-17:18:02 | Filtering depth maps...   [solver_utils.py:229]
100%|██████████| 80/80 [00:00<00:00, 458.56it/s]
| endurance | INFO | Jan-18-17:18:03 | rounding the flow ...   [prior2d.py:114]
100%|██████████| 158/158 [00:00<00:00, 723.85it/s]
| endurance | INFO | Jan-18-17:18:03 | Loading static model from data/DAVIS/train/log/20240117_143638/joint_model.pth   [gs_static_model.py:68]

In [16]:
# Render every training view (one per camera index) with the static and dynamic
# Gaussians; gradients are not needed for this pass.
with torch.no_grad():
    rendered_list = [
        solver.render_frame(
            cams.rel_focal,
            cams,
            40,  # NOTE(review): meaning of this positional argument is not visible here — confirm against Solver.render_frame
            static_model,
            dyn_model,
            render_view_id=view_id,
        )
        for view_id in tqdm(range(cams.T))
    ]

# Per-frame masks built by inverting the sky masks; handed to meshing to mask out sky.
invalid_masks = (~solver.prior2d.sky_masks).clone()

  0%|          | 0/80 [00:00<?, ?it/s]

100%|██████████| 80/80 [00:01<00:00, 72.73it/s]


In [17]:
# Fuse the per-view renders into a mesh with `tsdf_meshing` (presumably TSDF
# integration — confirm in lib_4d.gs_geo_helpers). `invalid_masks` comes from the
# sky masks computed in the render cell and marks pixels to leave out.
tsdf_meshing(
    rendered_dict_list=rendered_list,
    cams=cams,
    invalid_masks=invalid_masks,
)

  0%|          | 0/80 [00:00<?, ?it/s]

100%|██████████| 80/80 [00:01<00:00, 50.86it/s]



