In [None]:
import os
import sys
import cv2
import time
import torch
import faulthandler
import numpy as np

from tqdm import tqdm

# Ask the interpreter to dump Python tracebacks if the process hits a fatal
# error (e.g. a crash inside a native extension) while the notebook runs.
faulthandler.enable()

In [None]:
os.environ["PYTHONUTF8"] = "1"
os.environ['TORCHDYNAMO_VERBOSE'] = "1"
os.environ["PYTHONIOENCODING"] = "utf-8"

torch._dynamo.config.verbose = True
torch._dynamo.config.suppress_errors = False
torch.set_float32_matmul_precision('high')

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

In [None]:
# Report the torch build and the CUDA version it was compiled against,
# one per line.
print(torch.__version__, torch.version.cuda, sep="\n")

In [None]:
# Put the parent of the current directory first on the import path so the
# project packages imported below can be resolved.
parent_dir = os.path.abspath(os.path.join(os.path.curdir, ".."))
sys.path.insert(0, parent_dir)

In [None]:
from model.arvane import (
    ArvaneModel
)
from model.utils import(
    load_config
)
from model.loaders.loaders import (
    get_scans,
    InferenceDataset
)
from source.analyzer.arvane_analyzer import (
    ArvaneAnalyzer
)

In [None]:
# Load the project configuration; the path is relative to the notebook's
# working directory.
config_path = "../config/config.yml"
config = load_config(config_path)

In [None]:
# Only the third value returned by get_scans (the test scans) is used here.
_, _, test_scans = get_scans(
    config.dataset_dir,
    config.tsdf_dir,
    config.depth_guidance.pred_depth_dir,
)
dataset = InferenceDataset(
    scans=test_scans,
    load_depth=True,
)

# DataLoader raises ValueError when persistent_workers=True is combined with
# num_workers=0, so only keep workers alive when some are actually configured.
num_predict_workers = config.workers_predict
predict_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    num_workers=num_predict_workers,
    persistent_workers=num_predict_workers > 0,
    shuffle=False,  # keep the dataset's own sample order
)

In [None]:
# Build the model on the selected device and restore the trained weights.
model = ArvaneModel(config).to(device)

# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# from a CUDA run still loads when this notebook falls back to the CPU.
state_dict = torch.load("../checkpoint/arvane.pt", map_location=device)
model.load_state_dict(state_dict)

In [None]:
# Switch the model to evaluation mode before compiling it.
model.eval()
# NOTE: `model` is rebound to the compiled wrapper, so re-running this cell
# would pass the already-compiled object to torch.compile again.
model = torch.compile(model)

In [None]:
# Pull a single batch from the loader; the cells below inspect it and use its
# 'gt_origin' / 'gt_maxbound' entries to start the model.
loader_iter = iter(predict_loader)
first_batch = next(loader_iter)

In [None]:
# Inspect the shape of the 'color' entry of the first batch.
first_batch['color'].shape

In [None]:
# Start the model with the 'gt_origin' and 'gt_maxbound' entries taken from
# the first batch.
gt_origin, gt_maxbound = first_batch["gt_origin"], first_batch["gt_maxbound"]
model.start(gt_origin=gt_origin, gt_maxbound=gt_maxbound)

In [None]:
# List the keys present in the first batch.
first_batch.keys()

In [None]:
# colors = [
#     batch['color'][0, 0, ...]
#     for batch in tqdm(predict_loader, total=len(predict_loader))
# ][:10]

# colors = torch.stack(colors, dim=0)
# colors.shape

In [None]:
# in_colors = colors[:, None, ...].to(device=device, dtype=torch.float32)
# in_colors.shape

In [None]:
# with torch.no_grad():
#     depth, f_px = model.create_depth_before_update(
#         in_colors,
#         save_to_file=True,
#         save_path=f"../output/scannet_v2/scene0708_00",
#         exist_ok=True
#     )

In [None]:
# Feed every batch of the prediction loader through the model.
# The target device and the list of inputs are loop-invariant, so they are
# built once here instead of once per tensor per batch.
# NOTE(review): inputs are sent to `config.device`, while the model itself was
# moved to `device` earlier in the notebook -- confirm the two always agree.
input_device = torch.device(config.device)
# Order matters: these are passed positionally to model.update.
input_keys = ("color", "depth", "K_color", "K_depth", "poses")

with torch.no_grad():
    for batch in tqdm(predict_loader, total=len(predict_loader)):
        # Mark the start of a new step for CUDA graphs before each update.
        torch.compiler.cudagraph_mark_step_begin()
        model.update(
            *(
                batch[key].to(device=input_device, dtype=torch.float32)
                for key in input_keys
            )
        )

        model.update_view()

In [None]:
# Stack the stored k_color entries along a new leading axis, then inspect the
# shape after adding a front axis and taking index 0 of the third axis.
stacked_k_color = torch.stack(model.container.k_color, dim=0)
stacked_k_color[None, :, 0, ...].shape

In [None]:
# Take index 0 of the first axis of every stored pose, stack the results and
# inspect the shape with an extra leading axis.
leading_poses = [pose[0, ...] for pose in model.container.poses]
torch.stack(leading_poses, dim=0)[None].shape

In [None]:
# Run the model's final update with gradients disabled, passing the output
# .ply path; the returned value is kept in `result`.
final_ply_path = "../output/final-result.ply"
with torch.no_grad():
    result = model.final_update(final_ply_path)