# Network Training Tutorial

This notebook walks through training the PCN point-cloud completion network on a small synthetic dataset generated with the DNA origami simulator. The goal is to demonstrate the end-to-end workflow: environment setup, data preparation, model training, and qualitative evaluation.

## 1. Environment setup

Run the following cell first. On Google Colab it clones the repository (if it is not already present) and installs the package together with its Python dependencies. In every environment it locates the repository root (the directory containing `setup.py`) and makes it the working directory.

In [None]:
import os
import sys
import pathlib

# Detect Google Colab by checking whether its runtime module has been imported.
IN_COLAB = 'google.colab' in sys.modules
START_DIR = pathlib.Path.cwd()

if IN_COLAB:
    # A fresh Colab runtime does not contain the repository: clone it into the
    # current directory (only if it is not there yet) and search from the clone.
    if not (START_DIR / 'smlm').exists():
        !git clone https://github.com/dianamindroc/smlm.git
    START_DIR = START_DIR / 'smlm'

# Walk upwards from START_DIR and take the first directory that contains
# setup.py as the repository root. The for/else raises only when the loop
# finishes without finding a match (no `break`).
for candidate in [START_DIR, *START_DIR.parents]:
    if (candidate / 'setup.py').exists():
        PROJECT_ROOT = candidate
        break
else:
    raise RuntimeError('Could not locate the repository root. Please ensure you run this notebook from within the smlm project.')

# Later cells use relative paths (e.g. 'artifacts') and import project packages;
# presumably both rely on the working directory being the repository root.
os.chdir(PROJECT_ROOT)
print(f'Working directory set to: {PROJECT_ROOT}')

# Install the project (via its setup.py) only on Colab; local users are
# expected to have prepared their environment already.
if IN_COLAB:
    %pip install -q .


## 2. Imports and helper functions

We import the simulator dataset, the data transforms, the PCN model and the Chamfer-distance losses, and fix the random seed for reproducibility. A helper function for plotting point clouds is defined for later visualisation.

In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import matplotlib.pyplot as plt

from dataset.SMLMSimulator import DNAOrigamiSimulator
from model_architectures import pcn, losses
from model_architectures.transforms import Padding, ToTensor
from helpers.misc import set_seed

# Fix random seeds for reproducibility. Which RNGs are seeded is defined in
# helpers.misc.set_seed (presumably Python, NumPy and torch — confirm there).
set_seed(42)

def plot_point_cloud_triplet(gt, partial, reconstructed, title_suffix=''):
    """Draw ground truth, partial input and reconstruction side by side in 3D.

    Creates a new 1x3 figure with one 3D scatter plot per cloud (axes hidden)
    and applies ``tight_layout``. Nothing is returned, so calling this as the
    last line of a notebook cell does not display the figure twice.

    Args:
        gt: ground-truth cloud; indexed as ``cloud[:, 0]``, ``cloud[:, 1]`` and
            ``cloud[:, 2]`` for the x, y and z coordinates.
        partial: partial (input) cloud, indexed the same way.
        reconstructed: predicted cloud, indexed the same way.
        title_suffix: optional text appended to every subplot title.
    """
    # Importing Axes3D registers the '3d' projection on matplotlib < 3.2;
    # on newer versions it is a harmless no-op.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

    fig = plt.figure(figsize=(15, 4))
    entries = (
        (gt, 'Ground truth'),
        (partial, 'Input (partial)'),
        (reconstructed, 'Reconstruction'),
    )
    for idx, (cloud, title) in enumerate(entries, start=1):
        ax = fig.add_subplot(1, 3, idx, projection='3d')
        ax.scatter(cloud[:, 0], cloud[:, 1], cloud[:, 2], s=4)
        # Only insert the separating space when a suffix is given, so default
        # titles carry no trailing blank.
        ax.set_title(f"{title} {title_suffix}" if title_suffix else title)
        ax.set_axis_off()
    plt.tight_layout()


## 3. Build synthetic training and validation datasets

We use the DNA Origami simulator to generate paired complete and partial point clouds. Padding guarantees a fixed number of points per sample so the model can process mini-batches.

In [None]:
# Define dye properties for the simulator
dye_properties = {
    'density_range': (10, 40),
    'blinking_times_range': (10, 30),
    'intensity_range': (500, 2500),
    'precision_range': (0.5, 2.0)
}

# Number of points after padding (kept modest for a fast tutorial run)
MAX_POINTS = 384
transform_pipeline = transforms.Compose([Padding(MAX_POINTS), ToTensor()])

full_dataset = DNAOrigamiSimulator(
    num_samples=240,
    structure_type='box',
    dye_properties=dye_properties,
    augment=True,
    remove_corners=True,
    transform=transform_pipeline
)

train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

BATCH_SIZE = 8
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

print(f"Training batches: {len(train_loader)}, validation batches: {len(val_loader)}")


## 4. Initialise the PCN model and optimizer

We keep the architecture compact (a reduced latent dimension and a modest number of output points) so that the tutorial runs quickly on a CPU or a small GPU.

In [None]:
# Prefer the GPU when one is visible to torch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)

# Compact model configuration for a quick tutorial run. NUM_DENSE is
# presumably the number of points in the fine output; it is tied to MAX_POINTS
# so predictions and padded targets have the same size.
NUM_DENSE = MAX_POINTS
LATENT_DIM = 512
GRID_SIZE = 2

model = pcn.PCN(
    num_dense=NUM_DENSE,
    latent_dim=LATENT_DIM,
    grid_size=GRID_SIZE,
    classifier=False,
    channels=3,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Total number of parameters (trainable or not), reported in millions.
param_count = sum(map(torch.numel, model.parameters()))
print(f'Model parameters: {param_count / 1_000_000:.2f} M')


## 5. Training and validation loops

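The training objective is a weighted sum of Chamfer distances on both model outputs: the fine reconstruction's Chamfer distance plus `alpha` times the coarse reconstruction's Chamfer distance, each measured against the complete ground-truth cloud. We track the sample-weighted average loss per epoch on the training and validation splits.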

In [None]:
# Weight of the coarse-output Chamfer term relative to the fine-output term.
alpha = 0.1


def chamfer_l1_cdist(pred, target):
    """Compute a symmetric Chamfer distance with plain ``torch.cdist``.

    Needs no compiled extension, so it works on CPU-only environments.

    Args:
        pred: tensor of shape ``[B, N, D]``.
        target: tensor of shape ``[B, M, D]``.

    Returns:
        Scalar tensor: the average of (mean distance from each predicted point
        to its nearest target point) and (mean distance from each target point
        to its nearest predicted point). Distances are Euclidean, not squared.
    """
    dist = torch.cdist(pred, target)  # [B, N, M] pairwise Euclidean distances
    pred_to_target = dist.min(dim=-1).values  # [B, N]
    target_to_pred = dist.min(dim=-2).values  # [B, M]
    return (pred_to_target.mean() + target_to_pred.mean()) / 2


def chamfer_l1(pred, target):
    """Return the L1 Chamfer loss between ``pred`` and ``target``.

    First tries the project's ``losses.cd_loss_l1``. If that raises a
    RuntimeError whose message mentions "chamfer" (e.g. an unavailable compiled
    kernel), falls back to :func:`chamfer_l1_cdist`. Any other RuntimeError is
    re-raised unchanged.
    """
    try:
        return losses.cd_loss_l1(pred, target)
    except RuntimeError as exc:
        if 'chamfer' not in str(exc).lower():
            raise
        return chamfer_l1_cdist(pred, target)

def run_epoch(loader, train=True):
    """Make one pass over ``loader`` and return the per-sample mean loss.

    With ``train`` true, the model is put in training mode and every batch
    triggers a backward pass, gradient-norm clipping at 1.0 and an optimizer
    step. Otherwise the model is put in eval mode and autograd is disabled.
    Relies on the notebook-level ``model``, ``optimizer``, ``device``,
    ``alpha`` and ``chamfer_l1``.

    Args:
        loader: iterable of dict batches holding ``'pc'`` (complete target
            cloud) and ``'partial_pc'`` (partial input cloud) tensors.
        train: whether to update the model weights.

    Returns:
        float: mean of ``chamfer(fine) + alpha * chamfer(coarse)`` weighted by
        batch size; 0.0 if the loader yields no batches.
    """
    # Module.eval() is Module.train(False), so one call covers both modes.
    model.train(bool(train))

    loss_sum = 0.0
    sample_total = 0
    for batch in loader:
        target_full = batch['pc'].to(device)
        # Swap the last two axes before the forward pass (the model presumably
        # expects channels-first input — confirm against pcn.PCN).
        partial_input = batch['partial_pc'].to(device).permute(0, 2, 1)

        with torch.set_grad_enabled(train):
            if train:
                optimizer.zero_grad()
            coarse, fine, _, _ = model(partial_input)
            loss_fine = chamfer_l1(fine, target_full)
            loss_coarse = chamfer_l1(coarse, target_full)
            loss = loss_fine + alpha * loss_coarse

            if train:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

        n_in_batch = partial_input.size(0)
        loss_sum += loss.item() * n_in_batch
        sample_total += n_in_batch

    return loss_sum / max(1, sample_total)

NUM_EPOCHS = 5
train_history, val_history = [], []

# Each epoch: one optimisation pass over the training split, then one
# evaluation pass over the validation split; both averages are recorded.
for epoch in range(1, NUM_EPOCHS + 1):
    train_loss = run_epoch(train_loader, train=True)
    train_history.append(train_loss)
    val_loss = run_epoch(val_loader, train=False)
    val_history.append(val_loss)
    print(f"Epoch {epoch:02d} | train loss: {train_loss:.4f} | val loss: {val_loss:.4f}")


## 6. Loss curves

In [None]:
from matplotlib.ticker import MaxNLocator

# Plot against 1-based epoch numbers so the x-axis matches the "Epoch NN"
# lines printed during training (plt.plot(y) alone would start at 0).
epochs = range(1, len(train_history) + 1)

plt.figure(figsize=(6, 4))
plt.plot(epochs, train_history, marker='o', label='train')
plt.plot(epochs, val_history, marker='o', label='validation')
plt.xlabel('Epoch')
# The recorded value is the combined objective, not a single Chamfer distance.
plt.ylabel('Loss (fine + alpha * coarse Chamfer)')
plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))  # whole epochs only
plt.title('Training progress')
plt.legend()
plt.grid(True, linestyle='--', alpha=0.4)
plt.show()


## 7. Qualitative inspection

We visualise a validation sample, comparing the ground truth, the partial input, and the reconstructed point cloud.

In [None]:
# Evaluation mode (affects layers such as batch norm / dropout, if present).
model.eval()

# val_loader was built with shuffle=False, so this is always its first batch.
example_batch = next(iter(val_loader))
partial_example, full_example = (example_batch[key].to(device) for key in ('partial_pc', 'pc'))

# Same axis swap as in training; no gradients are needed for inspection.
with torch.no_grad():
    _, fine_pred, _, _ = model(partial_example.permute(0, 2, 1))

# Keep only the first sample of the batch, as NumPy arrays on the CPU.
reconstructed, partial_np, full_np = (
    tensor[0].cpu().numpy() for tensor in (fine_pred, partial_example, full_example)
)

plot_point_cloud_triplet(full_np, partial_np, reconstructed)


## 8. Save the trained weights (optional)

In [None]:
# Checkpoints go to ./artifacts, relative to the working directory set in
# the environment-setup cell; the directory is created on first use.
OUTPUT_DIR = pathlib.Path('artifacts')
OUTPUT_DIR.mkdir(exist_ok=True)

model_path = OUTPUT_DIR.joinpath('pcn_demo_weights.pth')
# Only the parameters/buffers are saved (state_dict), not the optimizer state.
torch.save(model.state_dict(), model_path)
print('Model checkpoint saved to', model_path.resolve())


## Next steps

- Replace the simulator with experimental datasets by pointing the configuration to your data directories.
- Integrate logging (e.g. Weights & Biases) for richer monitoring.
- Extend the model with classification heads or alternative loss functions, as done in the full training scripts under `scripts/`.