In [None]:
import sys

# Make the sibling checkouts and the local helper package importable.
# The paths are relative, so the notebook must be run from its own directory.
for extra_path in ('../mast3r/', '../gaussian-splatting', '../src'):
    sys.path.append(extra_path)

In [None]:
import os
import torch
from pathlib import Path
import numpy as np

from mast3r.utils.misc import hash_md5
from mast3r.model import load_model
from mast3r.cloud_opt.sparse_ga import sparse_global_alignment

from dust3r.image_pairs import make_pairs
from dust3r.utils.image import load_images
from dust3r.utils.device import to_numpy

try:
    import lovely_tensors as lt
except:
    ! pip install --upgrade lovely-tensors
    import lovely_tensors as lt
    
lt.monkey_patch()

In [None]:
# --- Runtime, checkpoint and output location ---------------------------------
device = 'cuda'
model_path = "../mast3r/checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth"
outdir = "./out_mast3r/"
image_size = 512

# --- Global alignment (forwarded to sparse_global_alignment) -----------------
optim_level = 'refine+depth'   # depth optimisation is enabled when this contains 'depth'
lr1 = 0.07
niter1 = 500
lr2 = 0.014
niter2 = 200
matching_conf_thr = 5.0
shared_intrinsics = False

# --- Point cloud extraction and export ---------------------------------------
min_conf_thr = 1.5     # confidences above this value are kept in the masks
clean_depth = True     # forwarded to scene.get_dense_pts3d
norm_scene = True      # export the normalised points / poses instead of the raw ones
mask_images = False    # forwarded to save_images_masks

# --- Pair selection -----------------------------------------------------------
scenegraph_type = 'complete'
winsize = 1
win_cyclic = False
refid = 0

# --- Options not referenced by any other cell visible in this notebook -------
mask_sky = False
transparent_cams = False
cam_size = 0.2
TSDF_thresh = 0.0

# Build the scene-graph descriptor: "<type>[-<winsize>[-noncyclic]]" for the
# windowed types, "<type>-<refid>" for "oneref", and just "<type>" otherwise.
scene_graph_params = [scenegraph_type]
if scenegraph_type in ("swin", "logwin"):
    scene_graph_params.append(str(winsize))
    if not win_cyclic:
        scene_graph_params.append('noncyclic')
elif scenegraph_type == "oneref":
    scene_graph_params.append(str(refid))
scene_graph = '-'.join(scene_graph_params)

# Results are stored under a sub-directory named after hash_md5(model_path),
# with a `cache` folder nested inside it.
# NOTE(review): presumably the hash is of the path string, not the checkpoint
# contents; confirm against mast3r.utils.misc.hash_md5.
chkpt_tag = hash_md5(model_path)
outdir = os.path.join(outdir, chkpt_tag)
cache_dir = os.path.join(outdir, 'cache')

for directory in (outdir, cache_dir):
    os.makedirs(directory, exist_ok=True)

In [None]:
from pathlib import Path

# Convenience shortcut: Path(...).ls() returns the directory entries as a list.
Path.ls = lambda x: list(x.iterdir())

image_dir = Path('../data/images/turtle_imgs/')

# Frames must be named by an integer index (e.g. "12.png"); they are ordered
# numerically so that "10" sorts after "9". The index is taken from the file
# name via pathlib rather than by splitting on '/', so the sort key does not
# depend on the platform's path separator.
image_paths = [x for x in image_dir.ls() if x.suffix in ['.png', '.jpg']]
image_paths = sorted(image_paths, key=lambda x: int(x.name.split('.')[0]))
image_files = [str(x) for x in image_paths]

# Fail here with a clear message instead of deep inside the reconstruction code.
if not image_files:
    raise FileNotFoundError(f"No .png/.jpg images found in {image_dir}")

In [None]:
# Load the checkpoint at model_path for the configured device.
model = load_model(model_path, device)

# Use the configured image_size instead of a second hard-coded 512, so that
# changing the config cell actually changes the load size.
images = load_images(image_files, size=image_size)

# Build the image pairs described by the scene_graph string; symmetrize=True
# and prefilter=None are passed through unchanged.
pairs = make_pairs(images, scene_graph=scene_graph, prefilter=None, symmetrize=True)

In [None]:
# Run the sparse global alignment over the image pairs. Note that it receives
# the image file paths (not the loaded images) together with the cache
# directory and the optimisation settings from the config cell.
optimise_depth = 'depth' in optim_level
scene = sparse_global_alignment(
    image_files,
    pairs,
    cache_dir,
    model,
    lr1=lr1,
    niter1=niter1,
    lr2=lr2,
    niter2=niter2,
    device=device,
    opt_depth=optimise_depth,
    shared_intrinsics=shared_intrinsics,
    matching_conf_thr=matching_conf_thr,
)

In [None]:
from colmap_dataset_utils import inv

# Camera poses: the scene exposes camera-to-world matrices; world-to-camera
# matrices are obtained by inverting them.
cam2world = scene.get_im_poses().detach().cpu().numpy()
world2cam = inv(cam2world)

# Intrinsics; the focals get an extra trailing axis.
principal_points = scene.get_principal_points().detach().cpu().numpy()
focals = scene.get_focals().detach().cpu().numpy()[..., None]

imgs = np.array(scene.imgs)

# Dense points and confidences; the second return value is not used here.
pts3d, _, confs = scene.get_dense_pts3d(clean_depth=clean_depth)

# Reshape every point set to the shape of the first image.
# Assumes all images share imgs[0].shape (TODO confirm for mixed-size inputs).
frame_shape = imgs[0].shape
pts3d = [points.detach().reshape(frame_shape) for points in pts3d]

# Boolean masks selecting the entries whose confidence exceeds min_conf_thr.
confs_np = to_numpy(confs)
masks = to_numpy([conf > min_conf_thr for conf in confs_np])

In [None]:
from colmap_dataset_utils import normalize_scene
from copy import deepcopy

# normalize_scene receives deep copies, so the extracted points, masks and
# poses are not modified by the call itself.
pts_norm, c2w_norm = normalize_scene(
    deepcopy(pts3d),
    deepcopy(masks),
    deepcopy(cam2world),
)

# When enabled, the normalised points and the inverse of the normalised poses
# replace the values that are exported later.
# NOTE(review): this overwrites pts3d while cam2world stays un-normalised, so
# re-running this cell alone would normalise already-normalised points against
# the original poses. Re-run the extraction cell first.
if norm_scene:
    pts3d, world2cam = pts_norm, inv(np.array(c2w_norm))

In [None]:
from visualisation import visualize_pcd, visualize_cameras

fig = None

# Aim for roughly num_to_show points in total across all images.
num_to_show = 10_000
num_of_valid = sum(m.sum() for m in masks)

# Floor division gives 0 when fewer than num_to_show points are valid; clamp to
# 1 so the value passed as `skip` is never 0.
# NOTE(review): assumes visualize_pcd treats `skip` as a sampling stride, so
# that 1 means "keep every point"; confirm against the helper.
skip = max(1, int(num_of_valid) // num_to_show)

# The normalised scene is always shown, regardless of norm_scene.
for p, i, m, c2w in zip(pts_norm, imgs, masks, c2w_norm):
    # Masked points of this image, coloured by the corresponding pixels.
    fig = visualize_pcd(p[m].cpu().numpy(), i[m], skip=skip, show=False, size=2, fig=fig)

    # Transposed 3x3 rotation block and the last column of the pose, each with
    # a leading batch axis.
    # NOTE(review): T keeps every row of the last column (c2w[:, 3]), including
    # the homogeneous entry if c2w is 4x4; confirm visualize_cameras expects it.
    R, T = np.transpose(c2w[None, :3, :3], (0, 2, 1)), c2w[None, :, 3]
    fig = visualize_cameras(R, T, fig=fig, show=False, radius=2, size=0.2)

fig

# Construct colmap dataset

After conversion, the following directory structure should be produced:

```
│   │   │   ├── images
│   │   │   ├── masks
│   │   │   ├── sparse/0
│   │   │   │   ├── cameras.bin
│   │   │   │   ├── images.bin
│   │   │   │   ├── points3D.bin
│   │   │   │   └── points3D.ply
```

In [None]:
# Root directory of the exported scene; created together with any missing
# parent directories, and left untouched if it already exists.
save_dir = Path('../data/scenes/turtle_mast3r')
save_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Export the reconstruction through the project's colmap_dataset_utils helpers.
# NOTE(review): what each helper writes is inferred from its name and arguments
# only; confirm file names and formats against colmap_dataset_utils.
from colmap_dataset_utils import (
    init_filestructure,
    save_images_masks,
    save_cameras,
    save_imagestxt,
    save_pointcloud,
    save_pointcloud_with_normals
)

# Returns four paths; save_path is not used again in this cell.
save_path, images_path, masks_path, sparse_path = init_filestructure(save_dir)
# Images and confidence masks; mask_images comes from the config cell.
save_images_masks(imgs, masks, images_path, masks_path, mask_images)
# Intrinsics (focals and principal points) together with the image array shape.
save_cameras(focals, principal_points, sparse_path, imgs_shape=imgs.shape)
# World-to-camera poses (the normalised ones when norm_scene is True).
save_imagestxt(world2cam, sparse_path)
# Alternative export without normals, kept for reference:
# save_pointcloud(imgs, pts3d, masks, sparse_path)
save_pointcloud_with_normals(imgs, pts3d, masks, sparse_path)