In [None]:
# Automatically re-import edited modules before each cell runs, and render
# matplotlib figures inline in the notebook.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import os
from pathlib import Path
import sys

import cv2
import dill as pickle
import matplotlib.pyplot as plt
import numpy as np
import open3d as o3d
from PIL import Image
import plotly.graph_objects as go
import pyransac3d as pyrsc
from tqdm import tqdm
from scipy.spatial import KDTree
import torch
import transforms3d as t3d
import visu3d as v3d

sys.path.append(os.path.join(os.getcwd(), "dust3r"))
from dust3r.inference import inference
from dust3r.model import AsymmetricCroCo3DStereo
from dust3r.utils.image import load_images
from dust3r.image_pairs import make_pairs
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
from lang_sam import LangSAM

from barrelnet.pointnet.pointnet_utils import PointNetEncoder, feature_transform_reguliarzer
from barrelnet.pointnet.barrelnet import BarrelNet
from barrelnet.pointnet.data import pts2inference_format
from barrelnet.dust3r_utils import save_dust3r_outs, read_dust3r, resize_to_dust3r
from barrelnet.langsam_utils import display_image_with_masks, display_image_with_boxes
from barrelnet.utils import segment_pc_from_mask, get_local_plane_mask, rotate_pts_to_ax, get_surface_line_traces, get_ray_trace
from barrelnet.synthbarrel import get_cyl_endpoints, get_cylinder_surf

In [None]:
# --- Configuration -----------------------------------------------------------
# Input image folder. Other datasets used previously:
#   data/dive8-barrel-10-45-less, data/barrel1-5sec-contrast, data/barrel2-5sec-contrast
image_dir = Path("data/barrelddt1")

if not image_dir.exists():
    raise FileNotFoundError(f"Image directory {image_dir} not found.")

# (H, W) that LangSAM masks are resampled to before being matched against the
# per-view point maps; presumably the DUSt3R working resolution for this
# dataset (long side 512) — confirm when switching datasets.
# Full-resolution frames were (875, 1920).
H, W = (224, 512)

reconstr_dir = Path(f"results/{image_dir.name}-reconstr")
mask_dir = reconstr_dir / "masks"
maskcomp_dir = reconstr_dir / "image_with_masks"
ply_dir = reconstr_dir / "pc_plys"
resizeimg_dir = reconstr_dir / "resized"
for out_dir in (mask_dir, maskcomp_dir, ply_dir, resizeimg_dir):
    out_dir.mkdir(parents=True, exist_ok=True)

text_prompt = "underwater barrel"
imgpaths = sorted(image_dir.glob("*.jpg"))
# Fail fast: every later stage indexes into imgpaths.
if not imgpaths:
    raise FileNotFoundError(f"No *.jpg images found in {image_dir}.")
dust3rout_path = reconstr_dir / "dust3r_out.pth"

In [None]:
# Save a resized copy of every frame (via resize_to_dust3r, presumably mirroring
# DUSt3R's own resizing) for LangSAM, so masks match the reconstruction size.
for src_path in imgpaths:
    with Image.open(src_path) as src_img:
        resize_to_dust3r(src_img, 512).save(resizeimg_dir / src_path.name)

# 3D reconstruction with DUSt3R

In [None]:
# Load the pretrained DUSt3R network (ViT-L / 512 px / DPT head, per the
# checkpoint name) and move it to the GPU.
device = "cuda"
model_name = "checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth"
dust3r_net = AsymmetricCroCo3DStereo.from_pretrained(model_name)
model = dust3r_net.to(device)

In [None]:
# Run DUSt3R on every image pair, jointly align the pairwise point maps into one
# scene, and cache the result to disk.
# NOTE(review): expects `model` to still be the DUSt3R network; the LangSAM cell
# later rebinds `model`, so re-run the model-loading cell first if re-running this.
batch_size = 1
schedule = "cosine"
lr = 0.01
niter = 300

# load_images is given the original frames and resizes them itself (size=512).
images = load_images(str(image_dir), size=512)
pairs = make_pairs(images, scene_graph="complete", prefilter=None, symmetrize=True)
output = inference(pairs, model, device, batch_size=batch_size)

scene = global_aligner(output, device=device, mode=GlobalAlignerMode.PointCloudOptimizer)
align_opts = dict(init="mst", niter=niter, schedule=schedule, lr=lr)
loss = scene.compute_global_alignment(**align_opts)

outdict = save_dust3r_outs(scene, dust3rout_path)

In [None]:
# Load the cached reconstruction and select one view (the last by default);
# its point cloud `pc` and camera `v3dcam` are what later cells work with.
pc_final, pcs_each, v3dcams = read_dust3r(dust3rout_path)
pc_idx = -1
pc, v3dcam = pcs_each[pc_idx], v3dcams[pc_idx]
# To work with the aggregated cloud instead, set `pc = pc_final`.
v3d.make_fig([pc, v3dcams])

In [None]:
# Sanity check: render the selected point cloud through the first camera.
first_cam_render = v3dcams[0].render(pc)
plt.imshow(first_cam_render)

In [None]:
def _save_colored_ply(points3d, out_path):
    """Write a v3d point cloud (positions `.p`, colours `.rgb`) to a .ply file.

    Colours are divided by 255 because Open3D expects floats in [0, 1]; this
    assumes `.rgb` holds 0-255 values.
    """
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points3d.p)
    pcd.colors = o3d.utility.Vector3dVector(points3d.rgb.astype(np.float64) / 255)
    o3d.io.write_point_cloud(str(out_path), pcd)


# Per-view files are named after `imgpaths`, so the two lists must line up
# (the DUSt3R loader is given the whole folder, not just the *.jpg glob).
if len(pcs_each) != len(imgpaths):
    raise ValueError(
        f"{len(pcs_each)} reconstructed views but {len(imgpaths)} image paths; "
        "cannot name per-view .ply files reliably."
    )
_save_colored_ply(pc_final, ply_dir / "pts_agg.ply")
for imgpath, imgpc in zip(imgpaths, pcs_each):
    _save_colored_ply(imgpc, ply_dir / f"{imgpath.stem}_pts.ply")

# Segmentation with SAM

## Running LangSAM (text-prompted SAM)

In [None]:
# Load the text-prompted LangSAM segmenter used in the next cell.
# NOTE(review): this rebinds `model`, which previously held the DUSt3R network;
# re-running the DUSt3R inference cell after this one would pass LangSAM to
# dust3r's `inference`. Consider a distinct name such as `langsam_model`.
model = LangSAM()

In [None]:
# Segment every resized frame with the text prompt. For each frame this saves
# all predicted masks (`<stem>_mask_<k>.png`, k starting at 1; mask_1 is the one
# later cells load), an overlay image, and records the first box in `bboxes`.
# `bboxes` has exactly one row per entry of `imgpaths`, so `bboxes[pc_idx]`
# always refers to the same frame as `pcs_each[pc_idx]`. Frames with no
# detection keep the sentinel row [-1, -1, -1, -1] instead of shifting the
# boxes of later frames.
NO_BBOX = -1
bboxes = np.full((len(imgpaths), 4), NO_BBOX, dtype=int)
for img_idx, srcpath in enumerate(tqdm(imgpaths)):
    imgpath = resizeimg_dir / srcpath.name
    with Image.open(imgpath) as raw_img:
        image_pil = raw_img.convert("RGB")

    masks, boxes, phrases, logits = model.predict(image_pil, text_prompt)

    if len(masks) == 0:
        print(f"No objects of the '{text_prompt}' prompt detected in the image.")
        continue

    masks_np = [mask.squeeze().cpu().numpy() for mask in masks]

    bbox_mask_path = maskcomp_dir / f"{imgpath.stem}_img_with_mask.png"
    bbox_mask_path.parent.mkdir(parents=True, exist_ok=True)
    display_image_with_masks(image_pil, masks_np, boxes, logits, figwidth=13, savefig=bbox_mask_path, all_masks=True, show=False)

    # save masks
    for mask_num, mask_np in enumerate(masks_np, start=1):
        mask_path = mask_dir / f"{imgpath.stem}_mask_{mask_num}.png"
        Image.fromarray((mask_np * 255).astype(np.uint8)).save(mask_path)

    # each box is x_min, y_min, x_max, y_max; keep the first (pairs with mask_1)
    bboxes[img_idx] = np.array(boxes[0], dtype=int)

with open(reconstr_dir / "bboxes.pickle", "wb") as f:
    pickle.dump(bboxes, f)

In [None]:
# Reload the saved LangSAM boxes so later cells can run without re-segmenting.
bboxes = pickle.loads((reconstr_dir / "bboxes.pickle").read_bytes())

## Segmenting the point cloud

In [None]:
# Load mask_1 for every frame that has one, resample it to (W, H), and erode it
# (presumably to keep clear of the barrel's silhouette edge).
erode_kernel = np.ones((5, 5), np.uint8)
imgid2mask = {}
for idx, imgpath in enumerate(imgpaths):
    maskpath = mask_dir / f"{imgpath.stem}_mask_1.png"
    if not maskpath.exists():
        continue
    maskpil = Image.open(maskpath).convert("1").resize((W, H), Image.Resampling.NEAREST)
    masknp = np.asarray(maskpil)
    imgid2mask[idx] = cv2.erode((masknp * 255).astype(np.uint8), erode_kernel, iterations=2)

# Vote: each masked view marks the points of `pc` that fall inside its mask.
npts = pc.shape[0]
barrelscores = np.zeros(npts)
for view_idx, mask in imgid2mask.items():
    # Local camera name: do not clobber the global `v3dcam` (the pc_idx camera),
    # which the seafloor cells rely on.
    view_cam = v3dcams[view_idx]
    barrelidxs = segment_pc_from_mask(pc, mask, view_cam)
    barrelscores[barrelidxs] += 1
# Arbitrary threshold: a point is "barrel" when more than half of the masked views agree.
barrelyes = barrelscores > len(imgid2mask) / 2
barrelcols = np.zeros_like(pc.p, dtype=np.uint8)
barrelcols[barrelyes] = [50, 222, 100]
barrelcols[~barrelyes] = [255, 0, 0]
barrelsegpc = v3d.Point3d(p=pc.p, rgb=barrelcols)
v3d.make_fig(barrelsegpc)

## Fit a plane to the local seafloor around the barrel, then rotate the scene so the floor lies at z = 0

In [None]:
# Image-space mask of the seafloor around the selected frame's barrel box.
# NOTE(review): 1.1 / 1.6 are presumably inner/outer bbox scale factors for
# get_local_plane_mask — confirm there.
PLANE_MASK_INNER, PLANE_MASK_OUTER = 1.1, 1.6
bbox = bboxes[pc_idx]
diffmask = get_local_plane_mask(bbox, PLANE_MASK_INNER, PLANE_MASK_OUTER, W, H)
plt.imshow(diffmask)

In [None]:
# Select the seafloor points around the barrel, fit a plane with RANSAC, and
# show the plane patch plus its normal.
# Use the selected view's camera explicitly: `diffmask` is built from
# bboxes[pc_idx], and the global `v3dcam` may have been rebound by an earlier loop.
floorcam = v3dcams[pc_idx]
localflooridxs = segment_pc_from_mask(pc, diffmask, floorcam)
floorcols = np.zeros_like(pc.p, dtype=np.uint8)
floorcols[:, 0] = 255
floorcols[localflooridxs] = [50, 222, 100]
floorsegpc = v3d.Point3d(p=pc.p, rgb=floorcols)

localfloorpts = pc.p[localflooridxs]

plane1 = pyrsc.Plane()
best_eq, best_inliers = plane1.fit(localfloorpts, thresh=0.005)
a, b, c, d = best_eq  # plane: a*x + b*y + c*z + d = 0
normal = np.array([a, b, c])

# Draw a 0.4 x 0.4 patch of the fitted plane centred on the floor points
# (a patch around the origin can land far from the floor in view).
raycent = np.mean(localfloorpts, axis=0)
plane_half_extent = 0.2
xx, yy = np.meshgrid(
    np.linspace(raycent[0] - plane_half_extent, raycent[0] + plane_half_extent, 10),
    np.linspace(raycent[1] - plane_half_extent, raycent[1] + plane_half_extent, 10),
)
zz = (-a * xx - b * yy - d) / c  # NOTE: degenerate if the plane is near-vertical (c ~ 0)
fig = v3d.make_fig([floorsegpc, v3d.Ray(pos=raycent, dir=normal / 5)])
plane = go.Surface(x=xx, y=yy, z=zz, opacity=0.2)
fig.add_trace(plane)
fig.show()

In [None]:
# Colour the scene (barrel green, local floor blue, rest red), rotate it so the
# fitted floor normal points along +z, and shift it so the floor sits at z = 0.
scene_all_cols = np.zeros_like(barrelcols)
scene_all_cols[:] = [200, 0, 0]
scene_all_cols[barrelyes] = [70, 242, 22]
scene_all_cols[localflooridxs] = [54, 218, 255]

rotatedpts, R = rotate_pts_to_ax(floorsegpc.p, normal, [0, 0, 1.0], ret_R=True)
T = np.eye(4)
T[:3, :3] = R
# After rotating, floor points satisfy z = -d (assuming `normal` is unit length,
# as returned by the RANSAC plane fit — confirm), so translate by +d along z.
dtrans = np.eye(4)
dtrans[2, 3] = d
rotatedpts = np.hstack([rotatedpts, np.ones((rotatedpts.shape[0], 1))])
rotatedpts = (dtrans @ rotatedpts.T).T
# If the barrel ended up below the floor, the normal pointed down: flip 180 deg about x.
if np.mean(rotatedpts[barrelyes, 2]) < 0:
    rot180 = np.eye(4)
    rot180[:3, :3] = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]])
    rotatedpts = (rot180 @ rotatedpts.T).T
    T = rot180 @ dtrans @ T
else:
    T = dtrans @ T
rotatedpts = rotatedpts[:, :3]
rotatedpc = v3d.Point3d(p=rotatedpts, rgb=pc.rgb)

# Visualize: z = 0 floor plane over the scene's xy extent, plus a +z marker ray.
xx, yy = np.meshgrid(
    np.linspace(np.min(rotatedpts[:, 0]), np.max(rotatedpts[:, 0]), 10),
    np.linspace(np.min(rotatedpts[:, 1]), np.max(rotatedpts[:, 1]), 10),
)
zz = np.zeros_like(xx)
plane = go.Surface(x=xx, y=yy, z=zz, opacity=0.5, colorscale="purples")
fig = v3d.make_fig([v3d.Point3d(p=rotatedpts, rgb=scene_all_cols)])
fig.add_traces([
    plane,
    get_ray_trace([np.mean(rotatedpc.p[:, 0]), np.mean(rotatedpc.p[:, 1]), 0], [0, 0, 1], color="#6e0a6c", length=0.1, width=5, markersize=10),
    *get_surface_line_traces(xx, yy, zz)
])

In [None]:
# Persist the floor-aligned scene: points, barrel labels, colours, and the 4x4
# transform T composed in the rotation cell.
# NOTE(review): a later cell rebinds the global `T`; re-run the rotation cell
# before re-running this one.
rotpcdict = dict(p=rotatedpts, isbarrel=barrelyes, rgb=pc.rgb, T=T)
(reconstr_dir / "rotatedpts.pickle").write_bytes(pickle.dumps(rotpcdict))

# Cylinder fitting with BarrelNet (PointNet)

In [None]:
# Take the barrel points in the z-up floor frame and express them in a y-up
# frame (presumably BarrelNet's input convention — confirm).
# NOTE(review): rebinds the globals `R` and `T` set by the rotation cell.
barrelpts = rotatedpts[barrelyes]
barrelxymean = barrelpts[:, :2].mean(axis=0)
barrelpc = v3d.Point3d(p=barrelpts)
# -90 deg about x (euler2mat's default static xyz axes): maps +z onto +y.
R = t3d.euler.euler2mat(-np.pi / 2, 0, 0)
T = np.eye(4)
T[:3, :3] = R
barrelpc_yup = barrelpc.apply_transform(v3d.Transform.from_matrix(T))
barrelpc_yup.fig

In [None]:
## Load Model
# Restore the trained BarrelNet weights, then move the network to the GPU in eval mode.
# The freshly constructed module lives on the CPU, so map the checkpoint tensors
# to CPU as well; without map_location, torch.load restores them to the device
# they were saved from, which fails if that GPU is not present.
model_path = "checkpoints/pointnet_iter80_fixed.pth"
pointnet = BarrelNet(k=5, normal_channel=False)
pointnet.load_state_dict(torch.load(model_path, map_location="cpu"))
pointnet.cuda().eval()

In [None]:
# Run BarrelNet on the y-up barrel points. Outputs are normalised; `scale`
# (from pts2inference_format) is used below to rescale the radius.
barrel_tensor = torch.tensor(barrelpc_yup.p, device="cuda").float()
pts, scale = pts2inference_format(barrel_tensor, max_points=1000)
with torch.no_grad():
    radius_pred, yshift_pred, axis_pred = (out.cpu().numpy()[0] for out in pointnet(pts))
axis_pred, yshift_pred, radius_pred

In [None]:
# Convert BarrelNet's normalised predictions into a cylinder in the z-up floor
# frame and overlay it, together with the z = 0 floor plane, on the scene.
height_ratio = 2.8  # assumed barrel height / radius
axis_pred = axis_pred / np.linalg.norm(axis_pred)
# scale predictions
r = scale * radius_pred
h = r * height_ratio
y = yshift_pred * h
# Rotate the predicted axis from the y-up network frame back to the z-up floor
# frame. Recomputed here instead of inverting the global `T`, which other cells rebind.
zup2yup_R = t3d.euler.euler2mat(-np.pi / 2, 0, 0)
axpred_zup = zup2yup_R.T @ axis_pred  # inverse of a rotation matrix is its transpose
x1, x2 = get_cyl_endpoints(axpred_zup, h, y, axidx=2)
# Place the cylinder at the barrel points' xy centroid.
barrel_xy_center = np.mean(barrelpc.p, axis=0)[[0, 1]]
x1[[0, 1]] += barrel_xy_center
x2[[0, 1]] += barrel_xy_center
cyl_center = (x1 + x2) / 2
print(h, r, y)
print(x1, x2)

cylxx, cylyy, cylzz = get_cylinder_surf(x1, x2, r)
fig = v3d.make_fig([rotatedpc])
fig.add_trace(get_ray_trace(cyl_center, axpred_zup, length=h, width=8, color="#e81b00", markersize=10))
cylsurf = go.Surface(x=cylxx, y=cylyy, z=cylzz, opacity=1.0, surfacecolor=np.zeros_like(cylxx), colorscale="oranges")
fig.add_traces(get_surface_line_traces(cylxx, cylyy, cylzz, step=5, include_horizontal=False))

# Floor plane z = 0 spanning the scene's xy extent.
xmin, xmax = np.min(rotatedpc.p[:, 0]), np.max(rotatedpc.p[:, 0])
ymin, ymax = np.min(rotatedpc.p[:, 1]), np.max(rotatedpc.p[:, 1])
xx, yy = np.meshgrid(np.linspace(xmin, xmax, 10), np.linspace(ymin, ymax, 10))
zz = np.zeros_like(xx)
planesurf = go.Surface(x=xx, y=yy, z=zz, opacity=0.5, colorscale="purples")
fig.add_traces(get_surface_line_traces(xx, yy, zz))
fig.add_trace(cylsurf)
fig.add_trace(planesurf)
fig.show()