# CryoLens Reconstruction from NDJSON Coordinates

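This notebook loads particle coordinates from an NDJSON file, finds the matching tomogram in a Copick project, extracts a subvolume around each coordinate, and runs CryoLens on every subvolume to produce per-particle reconstructions and embeddings. The reconstructions are averaged and saved as an MRC map, and the embeddings are saved alongside it for later analysis.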

In [1]:
import json
import numpy as np
import torch
from pathlib import Path
import copick

from cryolens.utils.checkpoint_loading import load_vae_model
from cryolens.inference.pipeline import InferencePipeline
from cryolens.data.copick import extract_particles_from_tomogram

## Configuration

In [2]:
import os

# Paths. The defaults point at the shared filesystem. Each path can be overridden
# through an environment variable, so the notebook can run elsewhere without
# editing this cell.
NDJSON_PATH = Path(os.environ.get(
    "CRYOLENS_NDJSON_PATH",
    "/mnt/czi-sci-ai/imaging-models/kyle/git/czi-ai/cryolens/notebooks/try_cryolens_ribosome_results.ndjson",
))
CHECKPOINT_PATH = Path(os.environ.get(
    "CRYOLENS_CHECKPOINT_PATH",
    "/mnt/czi-sci-ai/imaging-models/cryolens/mlflow/outputs/alternating_curriculum/cryolens-sim-015/checkpoints/model_epoch_2600_train_loss_6.436.pt",
))
COPICK_CONFIG = Path(os.environ.get(
    "CRYOLENS_COPICK_CONFIG",
    "/mnt/czi-sci-ai/imaging-models/data/cryolens/mlc/copick_czcdp/ml_challenge_experimental_only.json",
))
# NOTE(review): STRUCTURES_DIR is not referenced by any cell in this notebook as shown.
# Either wire it in or remove it.
STRUCTURES_DIR = Path(os.environ.get(
    "CRYOLENS_STRUCTURES_DIR",
    "/mnt/czi-sci-ai/imaging-models/data/cryolens/mlc/structures/mrcs",
))
OUTPUT_DIR = Path(os.environ.get("CRYOLENS_OUTPUT_DIR", "./reconstructions"))

# Parameters
VOXEL_SIZE = 10.0  # Angstroms; the closest available voxel spacing is used
BOX_SIZE = 48  # Voxels per side of each extracted subvolume

## Load NDJSON Coordinates

In [3]:
# Load coordinates from NDJSON: one JSON object per line, blank lines skipped.
coordinates = []
with open(NDJSON_PATH, encoding="utf-8") as f:
    for line_no, line in enumerate(f, start=1):
        if not line.strip():
            continue
        try:
            coordinates.append(json.loads(line))
        except json.JSONDecodeError as err:
            raise ValueError(f"Invalid JSON on line {line_no} of {NDJSON_PATH}") from err

if not coordinates:
    raise ValueError(f"No coordinates found in {NDJSON_PATH}")

print(f"Loaded {len(coordinates)} coordinates")
print(f"Run name: {coordinates[0]['run_name']}")
print(f"Structure: {coordinates[0]['prediction']}")

# The cells below load the tomogram for the first coordinate's run only.
# Make it visible when the file mixes runs.
run_names = sorted({str(c.get('run_name')) for c in coordinates})
if len(run_names) > 1:
    print(f"Warning: coordinates span {len(run_names)} runs {run_names}; "
          f"only the tomogram for '{coordinates[0]['run_name']}' is loaded below")

Loaded 10 coordinates
Run name: TS_100_3
Structure: ribosome


## Load Copick Project and Find Matching Run

In [7]:
# Peek at one run's metadata to see how Copick run names relate to the portal
# run names used in the NDJSON. The project is loaded here so this cell does not
# depend on `root` being defined by a later cell.
root = copick.from_file(str(COPICK_CONFIG))
root.runs[0].meta

CopickRunMetaCDP(name='16463', portal_run_id=16463, portal_run_name='TS_5_4')

In [11]:
# Load Copick project
root = copick.from_file(str(COPICK_CONFIG))

# Extract run name from NDJSON (e.g., "TS_100_3")
target_run_name = coordinates[0]['run_name']
print(f"Looking for run: {target_run_name}")


def _portal_run_name(run):
    """Return the run's portal run name, or None if it has no portal metadata."""
    return getattr(getattr(run, 'meta', None), 'portal_run_name', None)


# In this CDP-backed project, run.name is the portal run ID (e.g. '16463'), while the
# NDJSON uses the portal run name (e.g. 'TS_5_4'). So match on meta.portal_run_name,
# and also accept a direct run.name match for projects without portal metadata.
matching_run = next(
    (run for run in root.runs if target_run_name in (_portal_run_name(run), run.name)),
    None,
)

if matching_run is None:
    # Show the same kind of name we are matching against, not the numeric IDs.
    available = [_portal_run_name(r) or r.name for r in root.runs][:10]
    print(f"Available runs (first 10): {available}")
    raise ValueError(f"Could not find run matching '{target_run_name}'")

print(f"Found matching run: {matching_run.name}")

Looking for run: TS_100_3
Found matching run: 17682


## Load Tomogram

In [12]:
import zarr

# Pick the voxel spacing closest to the requested VOXEL_SIZE (the first one wins on ties).
voxel_spacings = list(matching_run.voxel_spacings)
if not voxel_spacings:
    raise ValueError(f"Run {matching_run.name} has no voxel spacings")
best_vs = min(voxel_spacings, key=lambda vs: abs(vs.voxel_size - VOXEL_SIZE))

print(f"Using voxel spacing: {best_vs.voxel_size}Å")

# Get the tomogram: prefer a denoised one, otherwise take the first available.
tomograms = list(best_vs.tomograms)
if not tomograms:
    raise ValueError(f"Run {matching_run.name} has no tomograms at {best_vs.voxel_size}Å")
tomogram = next(
    (tomo for tomo in tomograms if getattr(tomo, 'tomo_type', None) == 'denoised'),
    tomograms[0],
)
print(f"Using tomogram type: {getattr(tomogram, 'tomo_type', 'unknown')}")

# Load the tomogram data. The volume is looked up under one of these keys
# (presumably level 0 of a multiscale store — confirm). Otherwise the store's
# first key is used.
tomo_zarr = zarr.open(tomogram.zarr(), mode='r')
array_key = next((key for key in ('0', 's0', 'data') if key in tomo_zarr), None)
if array_key is None:
    available_keys = list(tomo_zarr.keys())
    if not available_keys:
        raise ValueError("Tomogram zarr store contains no arrays")
    array_key = available_keys[0]
tomogram_data = np.array(tomo_zarr[array_key])

print(f"Tomogram shape: {tomogram_data.shape}")

Using voxel spacing: 10.012Å
Tomogram shape: (184, 630, 630)


## Extract Particles at Coordinates

In [13]:
# Only use coordinates that belong to the run whose tomogram was loaded above.
# Coordinates from other runs would otherwise be cut out of the wrong tomogram.
run_coordinates = [c for c in coordinates if c.get('run_name') == target_run_name]
n_skipped = len(coordinates) - len(run_coordinates)
if n_skipped:
    print(f"Skipping {n_skipped} coordinates from runs other than {target_run_name}")

# NOTE(review): extract_particles_from_tomogram is given voxel_spacing, so these
# locations are presumably in Ångström rather than voxels — confirm.
positions = [
    (c['location']['x'], c['location']['y'], c['location']['z'])
    for c in run_coordinates
]

print(f"Extracting {len(positions)} particles...")

# Extract particles
particles = extract_particles_from_tomogram(
    tomogram_data=tomogram_data,
    positions=positions,
    voxel_spacing=best_vs.voxel_size,
    box_size=BOX_SIZE,
    normalize=True
)

particles = np.array(particles)
print(f"Extracted particles shape: {particles.shape}")

# The extractor can return fewer boxes than positions (the recorded run got 9 from 10).
# When that happens, particles[i] no longer necessarily corresponds to positions[i].
if len(particles) < len(positions):
    print(f"Warning: {len(positions) - len(particles)} of {len(positions)} positions "
          "did not yield a particle; particle indices no longer align with positions")
if len(particles) == 0:
    raise ValueError("No particles were extracted at the given coordinates")

Extracting 10 particles...
Extracted particles shape: (9, 48, 48, 48)


## Load Model and Generate Reconstructions

In [14]:
# Pick the compute device: CUDA when torch can see a GPU, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Restore the VAE weights together with the training config stored alongside the checkpoint.
# strict_loading=False presumably tolerates state-dict mismatches; the loader log
# shows renderer buffers being removed for compatibility.
model, config = load_vae_model(CHECKPOINT_PATH, device=device, load_config=True, strict_loading=False)
model.eval()  # switch to evaluation mode for inference
print("Model loaded")

# Wrap the model for per-volume inference. It reuses the normalization recorded in
# the training config and falls back to z-score when the config does not specify one.
pipeline = InferencePipeline(
    model=model,
    normalization_method=config.get('normalization', 'z-score'),
    device=device,
)

Using device: cuda


2025-10-23 10:27:29 - cryolens.utils.checkpoint_loading - INFO - Removing 1 coordinate/renderer buffers for compatibility
2025-10-23 10:27:29 - cryolens.utils.checkpoint_loading - INFO - Inferred from checkpoint: num_splats=768, latent_ratio=0.800
2025-10-23 10:27:29 - cryolens.utils.checkpoint_loading - INFO - Loaded training parameters from: /mnt/czi-sci-ai/imaging-models/cryolens/mlflow/outputs/alternating_curriculum/cryolens-sim-015/training_params.json
2025-10-23 10:27:29 - cryolens.utils.checkpoint_loading - INFO - Applied training parameters from config file
2025-10-23 10:27:30 - cryolens.utils.checkpoint_loading - INFO - Model loaded successfully from /mnt/czi-sci-ai/imaging-models/cryolens/mlflow/outputs/alternating_curriculum/cryolens-sim-015/checkpoints/model_epoch_2600_train_loss_6.436.pt


Model loaded


In [16]:
import mrcfile

# Run every extracted particle through the InferencePipeline, collecting one
# reconstruction and one embedding per particle.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

all_reconstructions = []
all_embeddings = []

n_particles = len(particles)
print(f"Processing {n_particles} particles...")

for i, particle in enumerate(particles, start=1):
    result = pipeline.process_volume(
        particle,
        return_embeddings=True,
        return_reconstruction=True
    )
    all_reconstructions.append(result['reconstruction'])
    all_embeddings.append(result['embeddings'])

    # Report every 10 particles, and also the final count so short runs show progress.
    if i % 10 == 0 or i == n_particles:
        print(f"  Processed {i}/{n_particles} particles")

print(f"Generated {len(all_reconstructions)} reconstructions")
if not all_reconstructions:
    raise ValueError("No reconstructions were generated; nothing to average or save")

# Voxel-wise mean over particles.
# NOTE(review): a plain mean assumes the per-particle reconstructions share a
# common frame (no per-particle alignment happens here) — confirm for this model.
reconstructions = np.array(all_reconstructions)
avg_reconstruction = reconstructions.mean(axis=0)

# Drop a leading singleton batch axis so the map is written as one 3D volume.
# mrcfile would otherwise store 4D data as a volume stack.
if avg_reconstruction.ndim == 4 and avg_reconstruction.shape[0] == 1:
    avg_reconstruction = avg_reconstruction[0]

# Name the output after the predicted structure recorded in the NDJSON,
# defaulting to "ribosome" when the field is absent.
structure_name = coordinates[0].get('prediction', 'ribosome')
output_path = OUTPUT_DIR / f"{target_run_name}_{structure_name}_reconstruction.mrc"
with mrcfile.new(str(output_path), overwrite=True) as mrc:
    mrc.set_data(avg_reconstruction.astype(np.float32))
    mrc.voxel_size = best_vs.voxel_size

print(f"Saved reconstruction to: {output_path}")

# Also save per-particle embeddings for later analysis. Row order follows
# `particles`, which may be shorter than the NDJSON coordinate list.
embeddings_array = np.array(all_embeddings)
embeddings_path = OUTPUT_DIR / f"{target_run_name}_embeddings.npy"
np.save(embeddings_path, embeddings_array)
print(f"Saved embeddings: {embeddings_path}")

Processing 9 particles...
Generated 9 reconstructions
Saved reconstruction to: reconstructions/TS_100_3_ribosome_reconstruction.mrc
Saved embeddings: reconstructions/TS_100_3_embeddings.npy


## Visualize (Optional)

In [None]:
import matplotlib.pyplot as plt

# Reduce the averaged reconstruction to a single 3D volume.
if avg_reconstruction.ndim == 4:
    vol = avg_reconstruction[0]  # Remove batch dimension
elif avg_reconstruction.ndim == 3:
    vol = avg_reconstruction
else:
    raise ValueError(f"Unexpected reconstruction shape: {avg_reconstruction.shape}")

# Central slice along each axis. Each axis gets its own center so non-cubic
# volumes are sliced through their middle.
# Assumes (z, y, x) axis order, matching the slice titles — TODO confirm.
cz, cy, cx = (size // 2 for size in vol.shape)
panels = [
    (vol[cz, :, :], 'XY slice'),
    (vol[:, cy, :], 'XZ slice'),
    (vol[:, :, cx], 'YZ slice'),
]

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for ax, (image, title) in zip(axes, panels):
    ax.imshow(image, cmap='gray')
    ax.set_title(title)
    ax.axis('off')

fig.tight_layout()
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
slices_path = OUTPUT_DIR / f"{target_run_name}_slices.png"
fig.savefig(slices_path, dpi=150)
plt.show()

print(f"Saved slices to: {slices_path}")