In [None]:
# =============================================================================
#  IMPORTANT: 
# =============================================================================
# This notebook is a modified replica of the custom_data notebook provided in
# the gs2mesh repository. In order to make variable image sizes work, we have 
# implemented patches to the intermediate steps and outputs of gs2mesh. 

# DO NOT change the order of the cells. The functionality is dependent on it.
# DO NOT edit in places that you are not prompted to edit.
# PLEASE KEEP args.skip_video_extraction and args.skip_colmap set to True, and 
# follow the guidelines provided for data population and COLMAP reconstruction 
# in our repository.

# Please note:
# If saving the cleaned TSDF mesh fails with an error saying the PLY could not
# be saved because it has 0 vertices, lower args.TSDF_cleaning_threshold (the
# minimal cluster size kept when cleaning). For smaller garments such as
# shorts, a 10x reduction from the default value is sufficient.
# =============================================================================
#  Imports
# =============================================================================
import os
import shutil
import subprocess
import sys
import types

import numpy as np
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
from sam2.sam2_video_predictor import SAM2VideoPredictor

import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.io as pio
pio.renderers.default = 'notebook' # changed from iframe

import open3d as o3d
gs2mesh_path = os.path.join(os.getcwd(), "gs2mesh")

# Changes the working directory to gs2mesh
os.chdir(gs2mesh_path)

# Adds gs2mesh to sys.path so Python can find modules inside it
sys.path.insert(0, gs2mesh_path)
print("Changed working directory to:", os.getcwd())

from gs2mesh_utils.argument_utils import ArgParser
from gs2mesh_utils.colmap_utils import extract_frames, create_downsampled_colmap_dir, run_colmap, visualize_colmap_poses
from gs2mesh_utils.eval_utils import create_strings
from gs2mesh_utils.renderer_utils import Renderer
from gs2mesh_utils.stereo_utils import Stereo
from gs2mesh_utils.tsdf_utils import TSDF
from gs2mesh_utils.masker_utils import init_predictor, Masker

%load_ext autoreload
%autoreload 2

device = 'cuda' if torch.cuda.is_available() else 'cpu'
base_dir = os.path.abspath(os.getcwd())

**Parameters:** (edit only here)

In [None]:
# =============================================================================
#  Initialize argument parser - DO NOT EDIT!
# =============================================================================
# Create argument parser with default arguments (the 'custom' preset). Every
# args.<name> assignment in the Parameters section below overrides one default.
args = ArgParser('custom')

# =============================================================================
#  Parameters - EDIT ONLY HERE!
# =============================================================================
# Garment to be processed; must be one of [upper, lower, dress] (matched
# case-insensitively by set_references below, which raises ValueError for
# anything else). upper corresponds to tops, sweaters, jackets, etc., lower
# corresponds to pants, shorts, etc., and dress is self-explanatory.
garment_type = 'lower'

# Absolute path of the directory where the final (cleaned) garment mesh is
# stored as template.obj by the mesh-cleaning cell.
# NOTE(review): machine-specific absolute path; change it for your machine.
mesh_output_path = '/home/hramzan/Desktop/semester-project/Gaussian-Garments/data/outputs/actor02_seq1'

# General params
args.dataset_name = 'custom' # Name of the dataset
args.colmap_name = 'actor02_seq1' # Name of the directory with the COLMAP sparse model
args.experiment_folder_name = None # Name of the experiment folder

# Preprocessing params
args.downsample = 1 # Downsampling factor

# Gaussian Splatting parameters
args.GS_iterations = 30000  # Number of Gaussian Splatting iterations
args.GS_save_test_iterations = [7000, 30000]  # Gaussian Splatting test iterations to save
args.GS_white_background = False  # Use white background in Gaussian Splatting

# Renderer parameters
args.renderer_baseline_absolute = None  # Absolute value of the renderer baseline (None uses 7 percent of scene radius)
args.renderer_baseline_percentage = 7.0  # Percentage value of the renderer baseline
args.renderer_scene_360 = True # Scene is a 360 scene
args.renderer_folder_name = None  # Name of the renderer folder (None uses the colmap name)
args.renderer_save_json = True  # Save renderer data to JSON
args.renderer_sort_cameras = False  # Sort cameras in the renderer (True if using unordered set of views)

# Stereo parameters
args.stereo_model = 'DLNR_Middlebury'  # Stereo model to use
args.stereo_occlusion_threshold = 3  # Occlusion threshold for stereo model (Lower value masks out more areas)
args.stereo_shading_eps = 1e-4 # Small value used for visualization of the depth gradient. Adjusted according to the scale of the scene.
args.stereo_warm = False  # Use the previous disparity as initial disparity for current view (False if views are not sorted)

# Masking parameters (GroundingDINO detection + SAM2 segmentation).
# The prompt should describe the garment being reconstructed.
args.masker_automask = True # Use GroundingDINO for automatic object detection for masking with SAM2
args.masker_prompt = 'beige shorts' # Prompt for GroundingDINO
args.masker_SAM2_local = False # Use local SAM2 weights

# TSDF parameters
args.TSDF_scale = 1.0  # Fix depth scale
args.TSDF_dilate = 1  # Take every n-th image (1 to take all images)
args.TSDF_valid = None  # Choose valid images as a list of indices (None to ignore)
args.TSDF_skip = None  # Choose non-valid images as a list of indices (None to ignore)
args.TSDF_use_occlusion_mask = True  # Ignore occluded regions in stereo pairs for better geometric consistency
# NOTE: the "Initialize TSDF" cell sets args.TSDF_use_mask = True
# unconditionally, so the value below is overridden before the TSDF runs.
args.TSDF_use_mask = False  # Use object masks (optional)
args.TSDF_invert_mask = False  # Invert the background mask for TSDF. Only if TSDF_use_mask is True
args.TSDF_erode_mask = True  # Erode masks in TSDF. Only if TSDF_use_mask is True
args.TSDF_erosion_kernel_size = 10  # Erosion kernel size in TSDF.  Only if TSDF_use_mask is True
args.TSDF_closing_kernel_size = 10  # Closing kernel size in TSDF.  Only if TSDF_use_mask is True.
args.TSDF_voxel = 2  # Voxel size (voxel length is TSDF_voxel/512)
args.TSDF_sdf_trunc = 0.04  # SDF truncation in TSDF
args.TSDF_min_depth_baselines = 4  # Minimum depth baselines in TSDF
args.TSDF_max_depth_baselines = 20  # Maximum depth baselines in TSDF
# Lower this if the cleaned mesh ends up with 0 vertices (e.g. ~10x smaller
# for small garments such as shorts).
args.TSDF_cleaning_threshold = 100000  # Minimal cluster size for clean mesh

# Running parameters
# Each skip_* flag makes the corresponding cell do nothing. Outputs that later
# stages need from a skipped stage presumably must already exist on disk from
# a previous run (TODO confirm).
args.video_extension = 'mp4'  # Video file extension
args.video_interval = 10  # Extract every n-th frame - aim for 3fps
args.GS_port = 8090  # GS port number (relevant if running several instances at the same time)
args.skip_video_extraction = True  # Skip the video extraction stage
args.skip_colmap = True  # Skip the COLMAP stage
args.skip_GS = True  # Skip the GS stage
args.skip_rendering = True  # Skip the rendering stage
args.skip_masking = True  # Skip the masking stage
args.skip_TSDF = False  # Skip the TSDF stage

# =============================================================================
#  DO NOT EDIT THESE LINES:
# =============================================================================
# COLMAP dataset directory: <base_dir>/data/<dataset_name>/<colmap_name>.
colmap_dir = os.path.abspath(os.path.join(base_dir,'data', args.dataset_name, args.colmap_name))
# Naming strings ('splatting', 'dataset', 'experiment_name', 'output_dir_root',
# 'TSDF', ...) that later cells use to name/locate their outputs.
strings = create_strings(args)

def set_references(garment_type):
    """Map a garment type to its (horizontal, vertical) reference image indices.

    Args:
        garment_type: "upper", "lower" or "dress", compared case-insensitively.

    Returns:
        A 2-tuple ``(horizontal_ref, vertical_ref)``.

    Raises:
        ValueError: if ``garment_type`` is not one of the supported types.
    """
    reference_table = {
        "upper": (1, 0),
        "lower": (0, 0),
        "dress": (0, 0),
    }
    normalized = garment_type.lower()
    if normalized not in reference_table:
        raise ValueError(f"Invalid garment type: {garment_type}")
    return reference_table[normalized]

# Reference image indices for the horizontal and vertical masking passes; they
# become the "ref" entries of garment_dict in the image-organising cell.
h_ref, v_ref = set_references(garment_type)

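**Visualize the COLMAP reconstruction:** (frame extraction and COLMAP are skipped in this notebook, since `args.skip_video_extraction` and `args.skip_colmap` stay True. You must already have a COLMAP dataset: copy it into the "data/<dataset_name>" folder of the gs2mesh root and set `args.colmap_name` to its directory name.)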

In [None]:
# =============================================================================
#  Visualize the sparse COLMAP output and the COLMAP poses.
# =============================================================================
# OPTIONAL: path to a ground-truth point cloud to overlay; only meaningful
# when that cloud is already aligned with the COLMAP sparse point cloud.
GT_path = None
# Tuning hints: adjust depth_scale if the cameras are not visible, and adjust
# subsample if the points are not visible.
visualize_colmap_poses(
    colmap_dir,
    depth_scale=10.0,
    subsample=100,
    visualize_points=True,
    GT_path=GT_path,
)

**Run Gaussian Splatting:**

In [None]:
# =============================================================================
#  Run Gaussian Splatting
# =============================================================================
# Trains Gaussian Splatting on the COLMAP dataset by running the third-party
# gaussian-splatting train.py in a subprocess. Failures are reported but do not
# stop the notebook (same best-effort behaviour as before).
if not args.skip_GS:
    try:
        gs_repo_dir = os.path.join(base_dir, 'third_party', 'gaussian-splatting')
        gs_model_path = os.path.join(base_dir, 'splatting_output', strings['splatting'], args.colmap_name)
        # The same iteration list is used for both test and save checkpoints.
        iteration_args = [str(iteration) for iteration in args.GS_save_test_iterations]
        # An argument list (no shell) keeps paths with spaces or shell
        # metacharacters intact.
        train_cmd = [
            'python', 'train.py',
            '-s', colmap_dir,
            '--port', str(args.GS_port),
            '--model_path', gs_model_path,
            '--iterations', str(args.GS_iterations),
            '--test_iterations', *iteration_args,
            '--save_iterations', *iteration_args,
        ]
        if args.GS_white_background:
            train_cmd.append('--white_background')
        # cwd= runs train.py from the gaussian-splatting checkout without
        # changing the notebook's own working directory.
        gs_result = subprocess.run(train_cmd, cwd=gs_repo_dir)
        if gs_result.returncode != 0:
            print(f"ERROR: Gaussian Splatting training exited with code {gs_result.returncode}")
    except Exception as exc:
        # Report what failed instead of a bare "ERROR"; KeyboardInterrupt is
        # no longer swallowed.
        print(f"ERROR: {exc!r}")
    finally:
        # Always leave the notebook in base_dir, as the original cell did.
        os.chdir(base_dir)

**Prepare GS renderer for rendering stereo views:**

In [None]:
# =============================================================================
#  Initialize renderer
# =============================================================================
# Renderer built from the gs2mesh root, the COLMAP dataset and the naming
# strings produced by create_strings(args).
renderer = Renderer(
    base_dir,
    colmap_dir,
    strings['output_dir_root'],
    args,
    dataset=strings['dataset'],
    splatting=strings['splatting'],
    experiment_name=strings['experiment_name'],
    device=device,
)

In [None]:
# =============================================================================
#  Visualize GS point cloud with COLMAP poses
# =============================================================================
# Points drawn in green lie inside the field of view of at least one camera,
# given the min/max depth truncation used later in the TSDF stage.
# - The object to reconstruct should be green; if it is not, adjust
#   TSDF_max_depth_baselines so that it is included.
# - If too much background is green, reduce TSDF_max_depth_baselines.
renderer.visualize_poses(subsample=100, depth_scale=10)

In [None]:
# =============================================================================
#  Prepare renderer
# =============================================================================
# ONLY NEED TO RUN ONCE PER SCENE!! Initializes renderer, takes some time
# Does nothing when args.skip_rendering is True (set in the Parameters cell).
if not args.skip_rendering:
    renderer.prepare_renderer()

**Run Rendering + Stereo Model:**

In [None]:
# =============================================================================
#  Initialize stereo
# =============================================================================
# Built unconditionally (even when rendering is skipped) because the masking,
# results-viewing and TSDF cells all take `stereo`. It presumably reads the
# args.stereo_* parameters from args; confirm in gs2mesh_utils.stereo_utils.
stereo = Stereo(base_dir, renderer, args, device=device)

In [None]:
# =============================================================================
#  Run stereo
# =============================================================================
%matplotlib inline
if not args.skip_rendering:
    stereo.run(start=0, visualize=False)

**Run SAM2 Masker (NECESSARY FOR GARMENT SEGMENTATION):**

In [None]:
# =============================================================================
#  Initialize SAM2 predictor + GroundingDINO model
# =============================================================================
# ONLY NEED TO RUN ONCE PER SCENE!! Initializes SAM2 predictor and GroundingDino model, takes some time
# Of the returned values, later cells only use GD_model and images_dir;
# view_masker_init builds its own SAM2 predictor and inference state for each
# image-size subfolder.
if not args.skip_masking:
    GD_model, predictor, inference_state, images_dir = init_predictor(base_dir, renderer, args, device=device) 

In [None]:
# =============================================================================
#  Organises rendered images by size, since ActorsHQ has horizontal and vertical frames
# =============================================================================
# Each JPEG in images_dir is copied into an "<height>x<width>" subfolder so
# SAM2 can be run separately on each image size. The vertical (height > width)
# and horizontal (width > height) subfolders are then chosen by aspect ratio
# instead of hard-coded heights, so other datasets with two orientations work.
# NOTE(review): the previous code matched heights 1022 (vertical) and 747
# (horizontal); this assumes those ActorsHQ renders are 1022x747 and 747x1022.


def _pick_orientation_bin(size_bins, orientation):
    """Return the first "HxW" label in size_bins matching orientation.

    Args:
        size_bins: "<height>x<width>" labels, in discovery order.
        orientation: "vertical" (height > width) or "horizontal" (width > height).

    Raises:
        ValueError: if no label matches (including an unknown orientation),
            instead of the opaque StopIteration from a bare next().
    """
    for label in size_bins:
        bin_height, bin_width = (int(part) for part in label.split('x'))
        if orientation == "vertical" and bin_height > bin_width:
            return label
        if orientation == "horizontal" and bin_width > bin_height:
            return label
    raise ValueError(f"No {orientation} images found; image sizes present: {size_bins}")


if not args.skip_masking:
    subdir_bins = []  # "HxW" labels in the order they are first seen
    for img in os.listdir(images_dir):
        if os.path.splitext(img)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]:
            src_im = os.path.join(images_dir, img)
            height, width = plt.imread(src_im).shape[:2]
            subdir_bin = f"{height}x{width}"
            if subdir_bin not in subdir_bins:
                subdir_bins.append(subdir_bin)
            os.makedirs(os.path.join(images_dir, subdir_bin), exist_ok=True)
            shutil.copy(src_im, os.path.join(images_dir, subdir_bin, img))

    garment_dict = {
        "vertical": {"subdir": _pick_orientation_bin(subdir_bins, "vertical"), "ref": v_ref},
        "horizontal": {"subdir": _pick_orientation_bin(subdir_bins, "horizontal"), "ref": h_ref},
    }

In [None]:
# =============================================================================
#  Patch on Masker class to handle variable size images data logic
# =============================================================================
def segment_custom(self, images_dir):
    """Propagate SAM2 masks over one image subfolder and save one mask per view.

    Bound onto a Masker instance with types.MethodType. The frame positions
    yielded by the predictor are mapped back to view indices through the
    sorted numeric file names in ``images_dir``. This assumes SAM2 orders the
    frames by the same numeric sort.

    Args:
        images_dir: folder of JPEG frames named "<view index>.<ext>". It is
            expected to be the folder this masker's inference state was
            initialised on.

    Side effects:
        Writes left_mask.npy and left_mask.png into
        self.renderer.render_folder_name(view_index) for every propagated
        frame, then closes all matplotlib figures.
    """
    jpeg_suffixes = [".jpg", ".jpeg", ".JPG", ".JPEG"]
    view_indices = sorted(
        int(name.split('.')[0])
        for name in os.listdir(images_dir)
        if os.path.splitext(name)[-1] in jpeg_suffixes
    )

    propagation = self.predictor.propagate_in_video(self.inference_state)
    for frame_pos, _obj_ids, mask_logits in propagation:
        # Threshold the first object's logits at 0 and drop the leading
        # singleton axis (presumably leaving an (H, W) boolean mask).
        binary_mask = (mask_logits[0] > 0.0).cpu().numpy().squeeze(0)
        view_dir = self.renderer.render_folder_name(view_indices[frame_pos])
        np.save(os.path.join(view_dir, 'left_mask.npy'), binary_mask)
        plt.imsave(os.path.join(view_dir, 'left_mask.png'), binary_mask)
    plt.close('all')

In [None]:
def view_masker_init(orientation, sam2_model_id="facebook/sam2-hiera-large"):
    """Build a SAM2-backed Masker for the images of one orientation.

    Loads a SAM2 video predictor and initialises its inference state on the
    image-size subfolder that garment_dict maps ``orientation`` to. It then
    wraps both in a gs2mesh Masker seeded with that orientation's reference
    image index.

    Args:
        orientation: key of garment_dict, i.e. "vertical" or "horizontal".
        sam2_model_id: model id passed to SAM2VideoPredictor.from_pretrained.
            The default is the checkpoint this notebook always used.

    Returns:
        (masker, images_subdir): the Masker and the subfolder it operates on.

    Raises:
        KeyError: if ``orientation`` is not a key of garment_dict.
    """
    orientation_cfg = garment_dict[orientation]
    images_subdir = os.path.join(images_dir, orientation_cfg["subdir"])
    predictor = SAM2VideoPredictor.from_pretrained(sam2_model_id, device=device)
    inference_state = predictor.init_state(video_path=images_subdir)

    masker = Masker(GD_model, predictor, inference_state, images_subdir, renderer, stereo, args,
                    image_number=orientation_cfg["ref"], visualize=True)
    return masker, images_subdir

In [None]:
# Interactive (ipympl) backend for the Masker's visualize=True output.
%matplotlib widget
# First masking pass: the "horizontal" image-size subfolder.
if not args.skip_masking:
    masker, images_subdir = view_masker_init("horizontal")

In [None]:
if not args.skip_masking:
    masker.segment_custom = types.MethodType(segment_custom, masker)
    masker.segment_custom(images_subdir)

    # once masks are generated, we need to store the first frame outputs
    # to handle overwrite issue by moving the contents of 000 folder to 000_save.
    # The vertical-pass cell moves them back, so copy_src and copy_dest must
    # keep these names.
    copy_src = renderer.render_folder_name(0)
    copy_dest = copy_src + '_save'

    # Pure-Python equivalent of `mkdir` + `mv src/* dest`. Unlike the unquoted
    # shell commands, this tolerates an existing _save folder on re-runs and
    # paths containing spaces.
    os.makedirs(copy_dest, exist_ok=True)
    for entry in os.listdir(copy_src):
        shutil.move(os.path.join(copy_src, entry), os.path.join(copy_dest, entry))

In [None]:
# Interactive (ipympl) backend for the Masker's visualize=True output.
%matplotlib widget
# Second masking pass: the "vertical" image-size subfolder.
if not args.skip_masking:
    masker, images_subdir = view_masker_init("vertical")

In [None]:
if not args.skip_masking:
    masker.segment_custom = types.MethodType(segment_custom, masker)
    masker.segment_custom(images_subdir)

    # now, repopulate the original 000 folder with the stored data.
    # Refuse to touch the 000 folder if the backup is missing, e.g. when this
    # cell is re-run after a previous run already removed it. Clearing first and
    # restoring nothing would lose the folder's data.
    if not os.path.isdir(copy_dest):
        raise FileNotFoundError(
            f"Backup folder {copy_dest} not found; re-run the horizontal masking cells first."
        )
    if os.path.isdir(copy_src):
        shutil.rmtree(copy_src)
    os.rename(copy_dest, copy_src)

    # if the garment is an upper, then the 000 folder does not have any mask
    # since it is skipped in masking. so, a manual copy paste from one of the
    # later frames (hard-coded to frame 8) is done
    if garment_type.lower() == 'upper':
        src_path = renderer.render_folder_name(8)
        dest_path = renderer.render_folder_name(0)
        for mask_name in ('left_mask.png', 'left_mask.npy'):
            shutil.copy(os.path.join(src_path, mask_name), os.path.join(dest_path, mask_name))

**View Results:**

In [None]:
# ====================================================================================================
#  View left-right renders, segmentation mask, disparity, occlusion mask and shading (depth gradient)
# ====================================================================================================
# Static inline backend again, after the interactive widget backend used for masking.
%matplotlib inline
stereo.view_results()

**TSDF**

In [None]:
# =============================================================================
#  Initialize TSDF
# =============================================================================
# Force mask usage regardless of the Parameters cell, so the TSDF uses the
# SAM2 garment masks (TSDF_use_mask: "Use object masks").
args.TSDF_use_mask = True
# strings['TSDF'] is the naming string produced by create_strings(args).
tsdf = TSDF(renderer, stereo, args, strings['TSDF'])

In [None]:
# ================================================================================
#  Run TSDF. the TSDF class will have an attribute "mesh" with the resulting mesh
# ================================================================================
%matplotlib inline
# Does nothing when args.skip_TSDF is True; later cells then rely on whatever
# tsdf.mesh holds. NOTE(review): confirm TSDF provides a mesh without run().
if not args.skip_TSDF:
    tsdf.run(visualize=False)

In [None]:
# =============================================================================
#  Save the original mesh before cleaning
# =============================================================================
# The output location is decided inside TSDF (which received strings['TSDF']
# at construction); presumably under the gs2mesh output directory — confirm.
tsdf.save_mesh()

In [None]:
# =============================================================================
#  Clean the mesh using clustering and save the cleaned mesh.
# =============================================================================
# original mesh is still available under tsdf.mesh (the cleaned is tsdf.clean_mesh)
# TSDF.clean_mesh() leaves its result on the instance as tsdf.clean_mesh, which
# shadows the method. Calling it through the class keeps this cell re-runnable
# (e.g. after lowering args.TSDF_cleaning_threshold); a plain tsdf.clean_mesh()
# would try to call the mesh object.
type(tsdf).clean_mesh(tsdf)

# saving the cleaned mesh in desired output directory (created if missing)
os.makedirs(mesh_output_path, exist_ok=True)
template_mesh_path = os.path.join(mesh_output_path, 'template.obj')
# open3d reports write failures through its boolean return value instead of
# raising, so check it explicitly rather than failing silently.
if not o3d.io.write_triangle_mesh(template_mesh_path, tsdf.clean_mesh):
    raise RuntimeError(
        f"Could not write the cleaned mesh to {template_mesh_path}. If the cleaned "
        "mesh has 0 vertices, lower args.TSDF_cleaning_threshold."
    )
print("Saved cleaned garment mesh to:", template_mesh_path)

In [None]:
# =============================================================================
#  Show clean mesh
# =============================================================================
# OPTIONAL: path to a ground-truth point cloud to overlay; only useful when it
# is aligned with the COLMAP sparse point cloud.
GT_path = None
tsdf.visualize_mesh(
    subsample=100,
    GT_path=GT_path,
    show_clean=True,
)