# PocaFoldAS inference (demo test data)
Run the trained PocaFoldAS on the bundled demo test split (no simulation). Configure the checkpoint and output folder, then run inference.


## 1. Project root
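Locate the repository root and make it the working directory, so that the relative paths used below resolve correctly. The root is the nearest directory containing `setup.py`, searching from the current directory upward.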


In [None]:
import os
import pathlib

# Walk upward from the current directory; the first folder that holds a
# setup.py is taken as the repository root. Relative paths in later cells
# (configs/, weights/, demo_data/) are resolved against it.
_search_start = pathlib.Path.cwd()
PROJECT_ROOT = next(
    (
        folder
        for folder in (_search_start, *_search_start.parents)
        if (folder / 'setup.py').exists()
    ),
    None,
)
if PROJECT_ROOT is None:
    raise RuntimeError('Could not locate the repository root. Please run this notebook from within the smlm project.')

os.chdir(PROJECT_ROOT)
print(f'Working directory set to: {PROJECT_ROOT}')



## 2. Paths, checkpoint, and output location
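By default this uses the demo test dataset, the tetrahedron checkpoint (`weights/tetra.pth`), and an automatically chosen output directory. Each can be overridden with an environment variable: `DEMO_TEST_ROOT`, `POCAFOLDAS_CKPT`, and `POCAFOLDAS_INFER_OUT` respectively.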


In [None]:
import os
import pathlib
import yaml
import torch

# Model hyperparameters for the demo run.
CONFIG_PATH = pathlib.Path('configs/config_demo_data.yaml')
# Both locations can be overridden from the environment; the defaults point at
# the bundled tetrahedron checkpoint and demo test split.
CKPT_PATH = pathlib.Path(os.environ.get('POCAFOLDAS_CKPT', 'weights/tetra.pth'))  # set to your trained weights
DATA_ROOT = pathlib.Path(os.environ.get('DEMO_TEST_ROOT', 'demo_data/tetrahedron_seed1234_test'))
CLASSES = ['tetra']
# Prefer the GPU when one is visible to torch, otherwise run on CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def choose_output_dir():
    """Pick the directory where inference outputs are written.

    Resolution order:
      1. ``POCAFOLDAS_INFER_OUT`` environment variable, when set and non-empty.
      2. ``/workspace/smlm_inference``, when ``/workspace`` exists and the
         current process may write to it (container layout).
      3. ``output/notebook_inference``, relative to the working directory.

    Returns:
        pathlib.Path: the chosen directory. It is not created here.
    """
    override = os.getenv('POCAFOLDAS_INFER_OUT')
    if override:
        return pathlib.Path(override)

    workspace_out = pathlib.Path('/workspace/smlm_inference')
    mount_point = workspace_out.parent
    mount_is_writable = mount_point.exists() and os.access(mount_point, os.W_OK)
    return workspace_out if mount_is_writable else pathlib.Path('output/notebook_inference')

OUTPUT_DIR = choose_output_dir()

# Fail fast, before any heavy work, if one of the required inputs is missing.
# Checked in order: config, data root, checkpoint.
for _required_path, _missing_message in (
    (CONFIG_PATH, f'Config file not found: {CONFIG_PATH}'),
    (DATA_ROOT, f'Data root not found: {DATA_ROOT}'),
    (CKPT_PATH, f'Checkpoint not found: {CKPT_PATH}. Set POCAFOLDAS_CKPT to your trained weights.'),
):
    if not _required_path.exists():
        raise FileNotFoundError(_missing_message)

OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Echo the resolved run settings.
for _label, _value in (
    ('Device:', DEVICE),
    ('Data root:', DATA_ROOT.resolve()),
    ('Output dir:', OUTPUT_DIR.resolve()),
    ('Checkpoint:', CKPT_PATH.resolve()),
):
    print(_label, _value)


### Container note
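- Default output dir: `/workspace/smlm_inference`. This is used when `/workspace` exists and is writable; otherwise the default is `output/notebook_inference` under the repository root.
- Override with `POCAFOLDAS_INFER_OUT` (environment variable names are case-sensitive), or bind-mount a host folder, e.g. `-v /host/logs:/workspace/smlm_inference`.
- Set `POCAFOLDAS_CKPT` to a checkpoint inside the container or to a mounted path.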


## 3. Load config hyperparameters


In [None]:
# Read the YAML config. An empty file loads as None, and a section written as
# `train:` / `dataset:` with no body also loads as None, so fall back to an
# empty dict in both cases and let the per-key defaults below apply.
cfg = yaml.safe_load(CONFIG_PATH.read_text(encoding='utf-8')) or {}
train_cfg = cfg.get('train') or {}
dset_cfg = cfg.get('dataset') or {}

# Model hyperparameters; defaults apply when a key is absent from the config.
num_dense = train_cfg.get('num_dense', 2048)
latent_dim = train_cfg.get('latent_dim', 1024)
grid_size = train_cfg.get('grid_size', 2)
channels = train_cfg.get('channels', 3)
classifier = bool(train_cfg.get('classifier', False))

# Dataset loading / augmentation options passed to PairedAnisoIsoDataset.
suffix = dset_cfg.get('suffix', '.csv')
remove_part_prob = dset_cfg.get('remove_part_prob', 0.0)
remove_outliers = dset_cfg.get('remove_outliers', False)
remove_corners = dset_cfg.get('remove_corners', False)
number_corners_remove = dset_cfg.get('number_corners_remove', [0, 1, 2])

print(f"Model: num_dense={num_dense}, latent_dim={latent_dim}, grid_size={grid_size}, channels={channels}, classifier={classifier}")
print(f"Data cfg: suffix={suffix}, remove_part_prob={remove_part_prob}, remove_corners={remove_corners}, remove_outliers={remove_outliers}")



## 4. Dataset and dataloader
Uses the demo test split (paired iso/aniso) and pads to the max point count found in the dataset.


In [None]:
from torchvision import transforms
from model_architectures.transforms import Padding, ToTensor
from helpers.data import get_highest_shape
from dataset.Dataset import PairedAnisoIsoDataset

# Scan both the iso and aniso subfolders for the largest point count, so
# every sample can be padded to a common size.
highest_shape = get_highest_shape(
    str(DATA_ROOT),
    CLASSES,
    subfolders=['iso', 'aniso'],
    suffix=suffix,
)
print('Highest shape:', highest_shape)

pc_transforms = transforms.Compose([Padding(highest_shape), ToTensor()])

# Dataset options come from the `dataset` section of the config (cell 3).
_dataset_options = dict(
    root_folder=str(DATA_ROOT),
    suffix=suffix,
    transform=pc_transforms,
    classes_to_use=CLASSES,
    remove_part_prob=remove_part_prob,
    remove_outliers=remove_outliers,
    remove_corners=remove_corners,
    number_corners_to_remove=number_corners_remove,
)
dataset = PairedAnisoIsoDataset(**_dataset_options)

print('Samples in dataset:', len(dataset))



## 5. Load model and checkpoint


In [None]:
from model_architectures.pocafoldas import PocaFoldAS

model = PocaFoldAS(num_dense=num_dense, latent_dim=latent_dim, grid_size=grid_size, classifier=classifier, channels=channels)

# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source. The code below assumes the file holds a flat
# {name: tensor} state dict. A wrapped checkpoint (e.g. {'state_dict': ...})
# would match zero keys and be reported by the warnings below.
state = torch.load(CKPT_PATH, map_location=DEVICE)
current = model.state_dict()

# Keep only checkpoint tensors whose name and shape match the freshly built
# model. Everything else is skipped rather than raising, but is reported so a
# mismatched checkpoint is not silently run with randomly initialised weights.
filtered = {
    k: v
    for k, v in state.items()
    if k in current and isinstance(v, torch.Tensor) and v.size() == current[k].size()
}
skipped_keys = sorted(set(state) - set(filtered))
missing_keys = sorted(set(current) - set(filtered))

current.update(filtered)
model.load_state_dict(current, strict=False)
model.to(DEVICE)
model.eval()
print('Loaded checkpoint with', len(filtered), 'matched keys (strict=False for flexibility).')
if skipped_keys:
    print(f'Warning: ignored {len(skipped_keys)} checkpoint keys (unknown name or shape mismatch), e.g. {skipped_keys[:5]}')
if missing_keys:
    print(f'Warning: {len(missing_keys)} model entries were not found in the checkpoint and keep their initial values, e.g. {missing_keys[:5]}')



## 6. Run inference and export a few predictions
Saves the predictions for the first `n_export` samples (default 5) as PLY files. It also reports the mean Chamfer L1 distance over all samples as a quick sanity check.


In [None]:
import numpy as np
import open3d as o3d
from torch.utils.data import DataLoader
from model_architectures.losses import l1_cd

loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
ply_dir = OUTPUT_DIR / 'ply'
ply_dir.mkdir(parents=True, exist_ok=True)

n_export = 5        # number of leading samples whose predictions are written as PLY
metrics = []        # per-sample Chamfer L1 values
n_exported = 0      # PLY files actually written successfully

with torch.no_grad():
    for idx, batch in enumerate(loader):
        # Swap the last two axes before feeding the model (presumably
        # points-first -> channels-first; TODO confirm against PocaFoldAS).
        partial = batch['partial_pc'].to(DEVICE).permute(0, 2, 1)
        gt = batch['pc'].to(DEVICE)
        filenames = batch['filename']

        coarse, fine, *_ = model(partial)
        pred = fine  # shape [B, N, 3]
        metrics.append(l1_cd(pred, gt).item())

        if idx < n_export:
            # Copy to host memory only for the samples that are exported.
            pred_np = pred.cpu().numpy()[0]
            out_path = ply_dir / f"{idx}_{pathlib.Path(filenames[0]).stem}.ply"
            pc = o3d.geometry.PointCloud()
            pc.points = o3d.utility.Vector3dVector(pred_np)
            # write_point_cloud returns False on failure instead of raising.
            if o3d.io.write_point_cloud(str(out_path), pc, write_ascii=True):
                n_exported += 1
            else:
                print(f"Warning: failed to write {out_path}")

print(f"Exported {n_exported} predictions to {ply_dir.resolve()}")
if metrics:
    print(f"Mean Chamfer L1 (demo subset): {np.mean(metrics):.4f}")
else:
    print("No samples found in the dataset; no metrics computed.")

