## 1. Imports and Model Loading

In [None]:
import os

os.chdir("notebook")
os.getcwd()

In [None]:
import uuid
import imageio
import numpy as np
from IPython.display import Image as ImageDisplay

PATH = os.getcwd()
TAG = "hf"

## 2. Load input image to lift to 3D (multiple objects)

In [None]:
import gc
import torch

# gc.collect()
# torch.cuda.empty_cache()

In [None]:
import torch

print("CUDA available:", torch.cuda.is_available())
print("CUDA devices:", torch.cuda.device_count())

In [None]:
os.environ["LIDRA_SKIP_INIT"] = "true"

from pathlib import Path
from sam3d_objects.pipeline.inference_with_embeddings import InferenceWithEmbeddings


config_path = f"{PATH}/../checkpoints/{TAG}/pipeline.yaml"

data_dir = Path("../../../../../projects/FRI/jn16867/3d-counting/Stacks-3D-Real/scenes")

pipeline = InferenceWithEmbeddings(config_path, compile=False)

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from importlib import reload
import sam3d_objects.data.precompute_embeddings as pe

reload(pe)

In [None]:
for category_dir in sorted(data_dir.iterdir()):
    if (category_dir / "geco2_mask" / "image.png").exists():
        print(f"Processing sample for category: {category_dir}")
        pe.preprocess_sample(category_dir, pipeline)
    else:
        print(f"No GeCo2 mask found for category: {category_dir}")

In [None]:
from pathlib import Path

train_dir_1 = Path("../../../../../projects/FRI/jn16867/3d-counting/scenes_part1")
train_dir_2 = Path("../../../../../projects/FRI/jn16867/3d-counting/scenes_part2")

In [None]:
import numpy as np

data = np.load(train_dir_1 / "00003219_1" / "00003219_1.npz")
print(data.files)


In [None]:
from matplotlib import pyplot as plt
from PIL import Image

image = np.array(Image.open(train_dir_1 / "00003531_1" / "images" / "RGB0010.jpg").convert("RGB"))
box = np.array(Image.open(train_dir_1 / "00003531_1" / "box_seg" / "Box_Mask0010.png").convert("RGB"))
obj = np.array(Image.open(train_dir_1 / "00003531_1" / "obj_seg" / "Objects_Mask0010.png").convert("RGB"))
floor = np.array(Image.open(train_dir_1 / "00003531_1" / "floor_seg" / "Ground_Mask0010.png").convert("RGB"))

plt.imshow(image)
plt.show()

In [None]:
plt.imshow(obj)
plt.show()

In [None]:
os.environ["LIDRA_SKIP_INIT"] = "true"

from sam3d_objects.pipeline.inference_with_embeddings import InferenceWithEmbeddings

config_path = f"{PATH}/../checkpoints/{TAG}/pipeline.yaml"
pipeline = InferenceWithEmbeddings(config_path, compile=False)

In [None]:
from PIL import Image, ImageOps
import json


frame = "0010"
data_dir = Path("../../../../../projects/FRI/jn16867/3d-counting/scenes_part1")

for scene in sorted(data_dir.iterdir()):
    image = np.array(Image.open(scene / "images" / ("RGB" + frame + ".jpg")).convert("RGB"))
    box_mask = np.array(Image.open(scene / "box_seg" / ("Box_Mask" + frame + ".png")).convert("L"))
    obj_mask = np.array(Image.open(scene / "obj_seg" / ("Objects_Mask" + frame + ".png")).convert("L"))
    
    floor_mask = Image.open(scene / "floor_seg" / ("Ground_Mask" + frame + ".png")).convert("L")
    container_mask = np.array(ImageOps.invert(floor_mask))
    
    obj_mask = (obj_mask > 0).astype(np.uint8) * 255
    container_mask = (container_mask > 0).astype(np.uint8) * 255

    print(f"Computing embeddings for container")
    container_out = pipeline.run_with_embeddings(image, container_mask, seed=42)
    print(f"Computing embeddings for object")
    object_out = pipeline.run_with_embeddings(image, obj_mask, seed=42)
    
    with open(scene / "gt_count.json") as f:
        gt_count = json.load(f)
        
    save_dict = {
        "container": container_out,
        "object": object_out,
        "true_count": gt_count
    }
    
    torch.save(save_dict, data_dir / "embeddings.pt")
    print(f"Saved embeddings for {data_dir.name}")
    
    break


In [None]:
import json

with open(scene / "gt_count.json") as f:
    gt_count = json.load(f)
    print(gt_count)

In [None]:
plt.imshow(floor)
plt.show()

In [None]:
image = np.array(Image.open(train_dir_1 / "00003851_0" / "image.png").convert("RGB"))
first_frame = "frame_00001.png"

if (data_dir / "box_seg").exists():
    container_mask = np.array(Image.open(data_dir / "box_seg" / first_frame).convert("L"))
elif (data_dir / "obj_seg").exists():
    container_mask = np.array(Image.open(data_dir / "obj_seg" / first_frame).convert("L"))
else:
    print(f"No box or object segmentation found for data_dir: {data_dir}. Skipping category.")
container_mask = (container_mask > 0).astype(np.uint8) * 255

object_mask = np.array(Image.open(data_dir / "geco2_mask" / "mask.png"))
object_mask = (object_mask > 0).astype(np.uint8) * 255


In [None]:
import torch
from torch.utils.data import DataLoader
from sam3d_objects.data.count_dataset import CountDataset


data_dir = Path("../../../../../projects/FRI/jn16867/3d-counting/Stacks-3D-Real/scenes")

dataset = CountDataset(data_dir)
print(f"Dataset length: {len(dataset)}")

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size]
)

train_loader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
    collate_fn=lambda x: x[0]
)

val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
    collate_fn=lambda x: x[0]
)

In [None]:
from sam3d_objects.pipeline.sam3d_count_predictor import SAM3DCountPredictor
import torch.nn as nn

device = "cuda"
num_epochs = 100
lr = 1e-3
output_dir = Path("../model_checkpoints/stacks-3d/")

use_hybrid = False
model = SAM3DCountPredictor(use_hybrid=use_hybrid).to(device)
    
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)

mse_loss = nn.MSELoss()
l1_loss = nn.L1Loss()

best_val_mae = float('inf')

In [None]:
# Load best current model and continue training
checkpoint = torch.load(Path("../model_checkpoints/stacks-3d/checkpoint_epoch_100.pth"), map_location=device)

model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

start_epoch = checkpoint["epoch"] + 1
best_val_mae = checkpoint["best_val_mae"]

In [None]:
from tqdm import tqdm
from sam3d_objects.pipeline.inference_with_embeddings import extract_geometric_features, compute_geometric_count_estimate


for epoch in range(start_epoch, start_epoch + num_epochs):
    model.train()
    train_loss = 0
    train_mae = 0
    
    print(f"\nEpoch {epoch + 1}/{num_epochs} - Training...")
    
    for sample in tqdm(train_loader):
        container_out = sample['container_outputs']
        object_out = sample['object_outputs']
        true_count = torch.tensor([sample['true_count']], dtype=torch.float32).to(device)
        
        shape_latent_container = container_out['shape_latent'].to(device)
        shape_latent_object = object_out['shape_latent'].to(device)
            
        slat_features_container = container_out['slat_features'].to(device)
        slat_features_object = object_out['slat_features'].to(device)
        
        if slat_features_container.dim() == 2:
            slat_features_container = slat_features_container.unsqueeze(0)
        if slat_features_object.dim() == 2:
            slat_features_object = slat_features_object.unsqueeze(0)
        
        geom_feat = extract_geometric_features(container_out, object_out).unsqueeze(0).to(device)
        
        geometric_estimate = torch.tensor(
            [compute_geometric_count_estimate(container_out, object_out, 1.0)],
            dtype=torch.float32
        ).to(device)
        
        if use_hybrid:
            pred_count, correction = model(
                shape_latent_container,
                shape_latent_object,
                slat_features_container,
                slat_features_object,
                geom_feat,
                geometric_estimate
            )
        else:
            pred_count, _ = model(
                shape_latent_container,
                shape_latent_object,
                slat_features_container,
                slat_features_object,
                geom_feat,
                None
            )
            
        loss = 0.5 * mse_loss(pred_count, true_count) + 0.5 * l1_loss(pred_count, true_count)
        
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        train_loss += loss.item()
        train_mae += torch.abs(pred_count - true_count).item()
        
    print(f"Train loss at epoch {epoch + 1}: {train_loss}")
    print(f"Train mae at epoch {epoch + 1}: {train_mae}")
    
    scheduler.step()
    
    avg_train_loss = train_loss / len(train_dataset)
    avg_train_mae = train_mae / len(train_dataset)

    # Validation
    if epoch % 5 == 0 or epoch == num_epochs - 1:
        model.eval()
        val_loss = 0
        val_mae = 0
        print("Validating")
        with torch.no_grad():
            for sample in tqdm(val_loader):
                container_out = sample['container_outputs']
                object_out = sample['object_outputs']
                true_count = torch.tensor([sample['true_count']], dtype=torch.float32).to(device)
                
                shape_latent_container = container_out['shape_latent'].to(device)
                shape_latent_object = object_out['shape_latent'].to(device)
                    
                slat_features_container = container_out['slat_features'].to(device)
                slat_features_object = object_out['slat_features'].to(device)
                
                if slat_features_container.dim() == 2:
                    slat_features_container = slat_features_container.unsqueeze(0)
                if slat_features_object.dim() == 2:
                    slat_features_object = slat_features_object.unsqueeze(0)
                
                geom_feat = extract_geometric_features(container_out, object_out).unsqueeze(0).to(device)
                
                geometric_estimate = torch.tensor(
                    [compute_geometric_count_estimate(container_out, object_out, 1.0)],
                    dtype=torch.float32
                ).to(device)
                
                pred_count, _ = model(
                    shape_latent_container,
                    shape_latent_object,
                    slat_features_container,
                    slat_features_object,
                    geom_feat,
                    None
                )
                
                loss = 0.5 * mse_loss(pred_count, true_count) + 0.5 * l1_loss(pred_count, true_count)
                
                val_loss += loss
                val_mae += torch.abs(pred_count - true_count).item()
            
            avg_val_loss = val_loss / len(val_dataset)
            avg_val_mae = val_mae / len(val_dataset)
            
            print(f"Epoch {epoch+1}/{num_epochs}")
            print(f"  Train Loss: {avg_train_loss:.4f}, Train MAE: {avg_train_mae:.2f}")
            print(f"  Val Loss: {avg_val_loss:.4f}, Val MAE: {avg_val_mae:.2f}")
            
            if avg_val_mae < best_val_mae:
                best_val_mae = avg_val_mae
                
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'best_val_mae': best_val_mae 
                }, output_dir / 'best_model.pth')
                print(f"  Saved new best model with VAL MAE: {avg_val_mae:.2f}")
            
            if (epoch + 1) % 10 == 0:
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }, output_dir / f"checkpoint_epoch_{epoch+1}.pth")
        
print("\nTraining completed")
print(f"Best Validation MAE: {best_val_mae:.2f}")

In [5]:
import os
os.environ["LIDRA_SKIP_INIT"] = "true"

from sam3d_objects.pipeline.inference_with_embeddings import InferenceWithEmbeddings
from pathlib import Path
from PIL import Image, ImageOps
import numpy as np
import torch
import json
import time

  import pynvml  # type: ignore[import]
[32m2026-02-20 15:50:06.407[0m | [1mINFO    [0m | [36msam3d_objects.pipeline.inference_pipeline[0m:[36mset_attention_backend[0m:[36m17[0m - [1mGPU name is Tesla V100S-PCIE-32GB[0m


Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


[32m2026-02-20 15:52:39.265[0m | [1mINFO    [0m | [36msam3d_objects.model.backbone.tdfy_dit.modules.sparse[0m:[36m__from_env[0m:[36m39[0m - [1m[SPARSE] Backend: spconv, Attention: sdpa[0m
[32m2026-02-20 15:52:57.466[0m | [1mINFO    [0m | [36msam3d_objects.model.backbone.tdfy_dit.modules.attention[0m:[36m__from_env[0m:[36m30[0m - [1m[ATTENTION] Using backend: sdpa[0m


[SPARSE][CONV] spconv algo: auto
Warp 1.11.0 initialized:
   CUDA Toolkit 12.9, Driver 13.0
   Devices:
     "cpu"      : "x86_64"
     "cuda:0"   : "Tesla V100S-PCIE-32GB" (32 GiB, sm_70, mempool enabled)
     "cuda:1"   : "Tesla V100S-PCIE-32GB" (32 GiB, sm_70, mempool enabled)
   CUDA peer access:
     Supported fully (all-directional)
   Kernel cache:
     /d/hpc/home/jn16867/.cache/warp/1.11.0




In [4]:
import os

os.chdir("cso/sam-3d-objects/notebook")
os.getcwd()

'/d/hpc/home/jn16867/cso/sam-3d-objects/notebook'

In [7]:
PATH = os.getcwd()
TAG = "hf"

config_path = f"{PATH}/../checkpoints/{TAG}/pipeline.yaml"
pipeline = InferenceWithEmbeddings(config_path, compile=False)

[32m2026-02-20 18:34:08.099[0m | [1mINFO    [0m | [36msam3d_objects.pipeline.inference_pipeline[0m:[36m__init__[0m:[36m100[0m - [1mself.device: cuda[0m
[32m2026-02-20 18:34:08.099[0m | [1mINFO    [0m | [36msam3d_objects.pipeline.inference_pipeline[0m:[36m__init__[0m:[36m101[0m - [1mCUDA_VISIBLE_DEVICES: 0,1[0m
[32m2026-02-20 18:34:08.100[0m | [1mINFO    [0m | [36msam3d_objects.pipeline.inference_pipeline[0m:[36m__init__[0m:[36m102[0m - [1mActually using GPU: 0[0m
[32m2026-02-20 18:34:08.100[0m | [1mINFO    [0m | [36msam3d_objects.pipeline.inference_pipeline[0m:[36minit_pose_decoder[0m:[36m297[0m - [1mUsing pose decoder: ScaleShiftInvariant[0m
[32m2026-02-20 18:34:08.101[0m | [1mINFO    [0m | [36msam3d_objects.pipeline.inference_pipeline[0m:[36m__init__[0m:[36m133[0m - [1mLoading model weights...[0m
[32m2026-02-20 18:34:11.601[0m | [1mINFO    [0m | [36msam3d_objects.model.io[0m:[36mload_model_from_checkpoint[0m:[36m158

In [8]:
def prepare_data(data_dir: set, pipeline: InferenceWithEmbeddings):
    for scene in sorted(data_dir.iterdir()):
        if not os.path.exists(scene / "geco2_data"):
            print(f"Skipping scene {scene}, no geco2_data found")
            continue
        
        image = np.array(Image.open(scene / "geco2_data" / "image.png").convert("RGB"))
        obj_mask = np.array(Image.open(scene / "geco2_data" / "obj_mask.png").convert("L"))
        box_mask = np.array(Image.open(scene / "geco2_data" / "box_mask.png").convert("L"))
        
        obj_mask = (obj_mask > 0).astype(np.uint8)
        box_mask = (box_mask > 0).astype(np.uint8)
        
        object_out = pipeline.run_with_embeddings(image, obj_mask, seed=42)
        container_out = pipeline.run_with_embeddings(image, box_mask, seed=42)
        
        with open(scene / "gt_count.json") as f:
            gt_count = json.load(f)
        
        save_dict = {
            "container": container_out,
            "object": object_out,
            "true_count": gt_count
        }
        
        torch.save(save_dict, scene / "embeddings.pt")
        print(f"Saved embeddings for {scene.name}")
        
    return

    

In [11]:
data_dir = Path("../../../../../projects/FRI/jn16867/3d-counting/scenes_part1")

print(f"Preparing data")
start_time = time.time()

prepare_data(data_dir=data_dir, pipeline=pipeline)

end_time = time.time() - start_time
print(f"Finished preparing data, time taken: {end_time} seconds")

Preparing data


[32m2026-02-20 19:00:57.451[0m | [1mINFO    [0m | [36msam3d_objects.pipeline.inference_pipeline[0m:[36mmerge_image_and_mask[0m:[36m584[0m - [1mReplacing alpha channel with the provided mask[0m


RuntimeError: max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.

In [16]:
for scene in sorted(data_dir.iterdir()):
    if not os.path.exists(scene / "geco2_data"):
        print(f"Skipping scene {scene}, no geco2_data found")
        continue
    
    image = np.array(Image.open(scene / "geco2_data" / "image.png").convert("RGBA"))
    obj_mask = np.array(Image.open(scene / "geco2_data" / "obj_mask.png").convert("L"))
    box_mask = np.array(Image.open(scene / "geco2_data" / "box_mask.png").convert("L"))
    
    obj_mask = (obj_mask > 0).astype(np.uint8) * 255
    box_mask = (box_mask > 0).astype(np.uint8) * 255
    
    object_out = pipeline.run_with_embeddings(image, obj_mask, seed=42)
    container_out = pipeline.run_with_embeddings(image, box_mask, seed=42)
    
    with open(scene / "gt_count.json") as f:
        gt_count = json.load(f)
    
    save_dict = {
        "container": container_out,
        "object": object_out,
        "true_count": gt_count
    }
    
    torch.save(save_dict, scene / "embeddings.pt")
    print(f"Saved embeddings for {scene.name}")


[32m2026-02-20 19:13:30.381[0m | [1mINFO    [0m | [36msam3d_objects.pipeline.inference_pipeline[0m:[36mmerge_image_and_mask[0m:[36m584[0m - [1mReplacing alpha channel with the provided mask[0m
[32m2026-02-20 19:13:33.242[0m | [1mINFO    [0m | [36msam3d_objects.pipeline.inference_pipeline[0m:[36mget_condition_input[0m:[36m633[0m - [1mRunning condition embedder ...[0m
[32m2026-02-20 19:13:33.991[0m | [1mINFO    [0m | [36msam3d_objects.pipeline.inference_pipeline[0m:[36mget_condition_input[0m:[36m637[0m - [1mCondition embedder finishes![0m
[32m2026-02-20 19:13:47.658[0m | [1mINFO    [0m | [36msam3d_objects.pipeline.inference_pipeline[0m:[36mget_condition_input[0m:[36m633[0m - [1mRunning condition embedder ...[0m
[32m2026-02-20 19:13:47.816[0m | [1mINFO    [0m | [36msam3d_objects.pipeline.inference_pipeline[0m:[36mget_condition_input[0m:[36m637[0m - [1mCondition embedder finishes![0m
[32m2026-02-20 19:14:05.460[0m | [1mINFO    

Saved embeddings for 00000111_0


[32m2026-02-20 19:14:35.064[0m | [1mINFO    [0m | [36msam3d_objects.pipeline.inference_pipeline[0m:[36mmerge_image_and_mask[0m:[36m584[0m - [1mReplacing alpha channel with the provided mask[0m
[32m2026-02-20 19:14:35.233[0m | [1mINFO    [0m | [36msam3d_objects.pipeline.inference_pipeline[0m:[36mget_condition_input[0m:[36m633[0m - [1mRunning condition embedder ...[0m
[32m2026-02-20 19:14:35.363[0m | [1mINFO    [0m | [36msam3d_objects.pipeline.inference_pipeline[0m:[36mget_condition_input[0m:[36m637[0m - [1mCondition embedder finishes![0m
[32m2026-02-20 19:14:48.042[0m | [1mINFO    [0m | [36msam3d_objects.pipeline.inference_pipeline[0m:[36mget_condition_input[0m:[36m633[0m - [1mRunning condition embedder ...[0m
[32m2026-02-20 19:14:48.134[0m | [1mINFO    [0m | [36msam3d_objects.pipeline.inference_pipeline[0m:[36mget_condition_input[0m:[36m637[0m - [1mCondition embedder finishes![0m
[32m2026-02-20 19:14:59.543[0m | [1mINFO    

KeyboardInterrupt: 

In [15]:
print(obj_mask.min(), obj_mask.max(), obj_mask.sum())

0 1 1325
