In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import os
from pathlib import Path
import sys

import cv2
import dill as pickle
import matplotlib.pyplot as plt
import numpy as np
import open3d as o3d
from PIL import Image
import plotly.graph_objects as go
import pyransac3d as pyrsc
from scipy.spatial import KDTree
import torch
from tqdm import tqdm
import visu3d as v3d

# dust3r is imported from a ./dust3r directory under the current working directory.
sys.path.append(os.path.join(os.getcwd(), "dust3r"))
from dust3r.inference import inference
from dust3r.model import AsymmetricCroCo3DStereo
from dust3r.utils.image import load_images
from dust3r.image_pairs import make_pairs
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
from lang_sam import LangSAM

from barrelnet.dust3r_utils import save_dust3r_outs, read_dust3r, resize_to_dust3r
from barrelnet.langsam_utils import display_image_with_masks, display_image_with_boxes
from barrelnet.utils import segment_pc_from_mask, get_bbox_mask, get_local_plane_mask

In [None]:
# Input images and the text prompt used to segment them.
image_dir = Path("data/barrelddt1")
text_prompt = "underwater barrel"
imgpaths = sorted(image_dir.glob("*.jpg"))

# Image height/width handed to read_dust3r and used when resizing masks.
# Alternative size kept for reference: H, W = (875, 1920)
H, W = 224, 512

# Everything this notebook writes lives under results/<image_dir name>-reconstr/.
reconstr_dir = Path(f"results/{image_dir.name}-reconstr")
mask_dir = reconstr_dir / "masks"
maskcomp_dir = reconstr_dir / "image_with_masks"
ply_dir = reconstr_dir / "pc_plys"
resizeimg_dir = reconstr_dir / "resized"
dust3rout_path = reconstr_dir / "dust3r_out.pth"

# Create all output directories up front; existing ones are left untouched.
for out_dir in (mask_dir, maskcomp_dir, ply_dir, resizeimg_dir):
    out_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Write a resized copy of every input image into resizeimg_dir, keeping the file name.
# The segmentation cell reads these copies back via resizeimg_dir / imgpath.name.
for imgpath in imgpaths:
    # The context manager closes the source file handle once the copy is saved,
    # instead of leaving one open handle per image until garbage collection.
    with Image.open(imgpath) as src_img:
        # 512 matches the size=512 passed to load_images in the reconstruction cell.
        img = resize_to_dust3r(src_img, 512)
        img.save(resizeimg_dir / imgpath.name)

# Reconstruction with DUSt3R

In [None]:
# Run pairwise inference on all images, globally align the results, and save the
# focals, poses and 3D points to dust3rout_path for the following cells.

# Fall back to CPU when no CUDA device is visible so the cell still runs (slowly)
# on machines without a GPU; with a GPU this is "cuda" exactly as before.
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 1
# Settings forwarded to scene.compute_global_alignment below.
schedule = "cosine"
lr = 0.01
niter = 300
model_name = "checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth"
model = AsymmetricCroCo3DStereo.from_pretrained(model_name).to(device)
# Images are loaded from image_dir itself (not from the resized copies).
images = load_images(str(image_dir), size=512)
pairs = make_pairs(images, scene_graph="complete", prefilter=None, symmetrize=True)
output = inference(pairs, model, device, batch_size=batch_size)

# Raw per-pair views and predictions, exposed for inspection; not used further in this cell.
view1, pred1 = output["view1"], output["pred1"]
view2, pred2 = output["view2"], output["pred2"]

scene = global_aligner(output, device=device, mode=GlobalAlignerMode.PointCloudOptimizer)
loss = scene.compute_global_alignment(init="mst", niter=niter, schedule=schedule, lr=lr)

imgs = scene.imgs
focals = scene.get_focals()
poses = scene.get_im_poses()
pts3d = scene.get_pts3d()
# Persist the aligned reconstruction; read_dust3r loads it back from the same path.
save_dust3r_outs(focals, poses, pts3d, savepath=dust3rout_path)
confidence_masks = scene.get_masks()

In [None]:
# Reload the reconstruction that the reconstruction cell saved to dust3rout_path.
# NOTE(review): presumably pts_final is the aggregated cloud and pts_each holds one
# point set per image (the PLY export cell writes them out that way) — confirm
# against read_dust3r.
pts_final, pts_each, v3dcams = read_dust3r(dust3rout_path, W, H)
# Point cloud built from the first entry of pts_each only.
pc = v3d.Point3d(p=pts_each[0])

In [None]:
def _to_o3d_cloud(points):
    """Return an Open3D point cloud whose points are set from ``points``."""
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(points)
    return cloud


# Aggregated points first, then one PLY per entry of pts_each, named after the
# image at the same index in imgpaths.
pcd = _to_o3d_cloud(pts_final)
o3d.io.write_point_cloud(str(ply_dir / "pts_agg.ply"), pcd)
for i, imgpts in enumerate(pts_each):
    pcd = _to_o3d_cloud(imgpts)
    ply_path = ply_dir / f"{imgpaths[i].stem}_pts.ply"
    o3d.io.write_point_cloud(str(ply_path), pcd)

# Segmentation with LangSAM

In [None]:
# NOTE(review): this rebinds `model`, which held the DUSt3R network in the
# reconstruction cell; from here on `model` is the LangSAM segmenter.
model = LangSAM()

In [None]:
# For every resized image: run text-prompted segmentation, save an overlay figure and
# one PNG per mask, and remember the first box of each image that had a detection.
bboxes = []
for imgpath in tqdm(imgpaths):
    # Segment the resized copy written earlier, not the original image.
    imgpath = resizeimg_dir / imgpath.name
    # Close the file handle as soon as the RGB copy exists.
    with Image.open(imgpath) as src_img:
        image_pil = src_img.convert("RGB")

    masks, boxes, phrases, logits = model.predict(image_pil, text_prompt)

    if len(masks) == 0:
        # NOTE(review): nothing is appended to bboxes for this image, so rows of
        # bboxes stop lining up with imgpaths indices whenever a detection is missing.
        print(f"No objects of the '{text_prompt}' prompt detected in the image.")
        continue

    # Convert masks to numpy arrays
    masks_np = [mask.squeeze().cpu().numpy() for mask in masks]

    bbox_mask_path = maskcomp_dir / f"{imgpath.stem}_img_with_mask.png"
    bbox_mask_path.parent.mkdir(parents=True, exist_ok=True)
    display_image_with_masks(image_pil, masks_np, boxes, logits, figwidth=13, savefig=bbox_mask_path, all_masks=True, show=False)

    # Save the masks. Files are numbered from 1; <stem>_mask_1.png is the one read
    # back by the mask-voting cell. A dedicated counter name avoids clobbering any
    # outer loop index.
    for mask_num, mask_np in enumerate(masks_np, start=1):
        mask_path = mask_dir / f"{imgpath.stem}_mask_{mask_num}.png"
        mask_image = Image.fromarray((mask_np * 255).astype(np.uint8))
        mask_image.save(mask_path)

    # each box is x_min, y_min, x_max, y_max; only the first box per image is kept
    bboxes.append(boxes[0])

bboxes = np.array(bboxes, dtype=int)
with open(reconstr_dir / "bboxes.pickle", "wb") as f:
    pickle.dump(bboxes, f)

In [None]:
# Reload the boxes written by the segmentation cell so the cells below can run
# without repeating segmentation.
bboxes = pickle.loads((reconstr_dir / "bboxes.pickle").read_bytes())

In [None]:
# Label points of `pc` (built from the first entry of pts_each) as barrel / not barrel
# by letting every available per-image mask vote through its camera.

# Load the first saved mask of every image that has one, keyed by image index.
imgid2mask = {}
for idx, img in enumerate(imgpaths):
    maskpath = mask_dir / f"{img.stem}_mask_1.png"
    if not maskpath.exists():
        continue
    # Close the PNG file handle as soon as the mask has been converted and resized.
    with Image.open(maskpath) as maskfile:
        maskpil = maskfile.convert("1").resize((W, H), Image.Resampling.NEAREST)
    imgid2mask[idx] = np.asarray(maskpil)

npts = pc.shape[0]
idxs = np.arange(npts)  # not used in this cell; kept in case other cells rely on it
barrelscores = np.zeros(npts)
# One vote per mask for every point index that segment_pc_from_mask returns.
for i, mask in imgid2mask.items():
    v3dcam = v3dcams[i]
    barrelidxs = segment_pc_from_mask(pc, mask, v3dcam)
    barrelscores[barrelidxs] += 1

# A point counts as barrel when more than a third of the loaded masks voted for it.
barrelyes = barrelscores > len(imgid2mask) / 3
# Colour barrel points [50, 222, 100] and all remaining points [255, 0, 0].
segcols = np.zeros_like(pc.p, dtype=np.uint8)
segcols[barrelyes] = [50, 222, 100]
segcols[~barrelyes] = [255, 0, 0]
segpc = v3d.Point3d(p=pc.p, rgb=segcols)
v3d.make_fig(segpc)

In [None]:
# First stored bounding box. NOTE(review): it is paired below with pts_each[0] /
# v3dcams[0], which is only right if the first image produced a detection — confirm.
bbox = bboxes[0]
# NOTE(review): presumably a mask of the region around the box (factors 1.1 and 1.6)
# used to pick nearby floor points — verify against get_local_plane_mask.
diffmask = get_local_plane_mask(bbox, 1.1, 1.6, W, H)
plt.imshow(diffmask)

In [None]:
# Rebuild the point cloud from the first entry of pts_each and take the first camera.
pc = v3d.Point3d(p=pts_each[0])
v3dcam = v3dcams[0]
# Indices of `pc` selected by diffmask through that camera.
localflooridxs = segment_pc_from_mask(pc, diffmask, v3dcam)
# Start from zeros, set component 0 to 255 everywhere, then overwrite the selected
# points with [50, 222, 100].
# assumes pc.p is (N, 3) so component 0 is the red channel — TODO confirm
segcols = np.zeros_like(pc.p, dtype=np.uint8)
segcols[:, 0] = 255
segcols[localflooridxs] = [50, 222, 100]
segpc = v3d.Point3d(p=pc.p, rgb=segcols)
v3d.make_fig(segpc)

In [None]:
# Fit a plane to the points of `pc` that were selected as local floor.
# (pyransac3d is imported as `pyrsc` in the import cell at the top of the notebook.)
points = pc.p[localflooridxs]

# A plane needs at least three points; fail with a clear message instead of an
# opaque error from inside the RANSAC sampler.
if len(points) < 3:
    raise ValueError(
        f"Need at least 3 local-floor points to fit a plane, got {len(points)}."
    )

# NOTE(review): RANSAC is randomized, so best_eq can vary between runs unless the
# RNG that pyransac3d draws from is seeded — confirm which RNG that is.
plane1 = pyrsc.Plane()
# 0.01 is the threshold argument of Plane.fit; best_eq is unpacked as (a, b, c, d)
# by the plane-visualisation cell.
best_eq, best_inliers = plane1.fit(points, 0.01)

In [None]:
# Fitted plane coefficients; unpacked as a, b, c, d in the next cell.
print(best_eq)

In [None]:
# Plane a*x + b*y + c*z + d = 0 from the fit; (a, b, c) is used as its normal.
a, b, c, d = best_eq
normal = np.array([a, b, c])

# Sample the plane on a 10x10 grid over [-0.2, 0.2] in x and y, solving the plane
# equation for z (this divides by c, so it assumes c != 0).
grid_axis = np.linspace(-0.2, 0.2, 10)
xx, yy = np.meshgrid(grid_axis, grid_axis)
zz = (-a * xx - b * yy - d) / c

# Show the local-floor points together with a ray that starts at their mean position
# and points along the negated normal scaled by 1/3.
floorpc = segpc[localflooridxs]
raycent = np.mean(floorpc.p, axis=0)
normal_ray = v3d.Ray(pos=raycent, dir=-normal / 3)
fig = v3d.make_fig([floorpc, normal_ray])

# Overlay the sampled plane as a translucent surface.
plane = go.Surface(x=xx, y=yy, z=zz, opacity=0.2)
fig.add_trace(plane)
fig.show()

In [None]:
# NOTE(review): rotate_to_zax_np is neither imported nor defined in the cells shown
# here, so this raises NameError on a fresh kernel unless it is brought into scope
# elsewhere — confirm where it lives and import it in the import cell.
# Presumably rotates the points so the fitted plane normal lines up with +z — verify.
asdf = rotate_to_zax_np(segpc.p, normal, [0, 0, 1])
# Show the rotated points with the colours assigned in the floor-segmentation cell.
v3d.Point3d(p=asdf, rgb=segcols).fig

# Pointnet stuff