# SPIDS Quickstart Tutorial

Welcome to SPIDS (Sparse Phase Imaging by Diffraction Spectroscopy)!

This tutorial will walk you through:
1. Setting up a basic reconstruction
2. Understanding the algorithm
3. Visualizing results
4. Interpreting metrics

**Estimated time**: 15-20 minutes

## 1. Setup and Imports

In [None]:
# Standard imports
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch

# Make the repository root importable (the notebook sits two levels below it)
sys.path.insert(0, "../..")

# SPIDS imports
from prism.config.objects import get_object_params
from prism.models.networks import GenCropSpidsNet
from prism.utils.sampling import fermat_spiral_points

# Notebook settings
%matplotlib inline
plt.rcParams["figure.figsize"] = (12, 8)
plt.rcParams["figure.dpi"] = 100

print("✓ Imports successful")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 2. Choose Your Target Object

SPIDS supports several astronomical objects with realistic physical parameters:

In [None]:
# Available objects
objects = ["europa", "titan", "betelgeuse", "neptune"]

# Select one
obj_name = "europa"  # <-- Change this to try different objects

# Get object parameters
obj_params = get_object_params(obj_name)

print(f"Selected: {obj_name.capitalize()}")
print("\nPhysical parameters:")
for key, value in obj_params.items():
    print(f"  {key}: {value}")
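The exact keys in `obj_params` depend on `get_object_params`. One reason such physical parameters matter is resolvability: an object shows detail only if its angular size is large compared with the diffraction limit of the aperture, about 1.22 λ/D for a filled circular aperture. The sketch below uses only the standard library. The Europa diameter (about 3121.6 km) and Earth-Jupiter distance (about 628 million km near opposition) are illustrative approximations, not values read from SPIDS, and the helper names are hypothetical.

```python
import math

ARCSEC_PER_RAD = 180 / math.pi * 3600  # about 206264.8


def angular_diameter_arcsec(diameter_m: float, distance_m: float) -> float:
    """Angular diameter of an object of size diameter_m seen from distance_m."""
    return 2 * math.atan(diameter_m / (2 * distance_m)) * ARCSEC_PER_RAD


def rayleigh_limit_arcsec(wavelength_m: float, aperture_m: float) -> float:
    """Diffraction-limited resolution of a filled circular aperture: 1.22 * lambda / D."""
    return 1.22 * wavelength_m / aperture_m * ARCSEC_PER_RAD


# Illustrative values only (approximate, not SPIDS presets)
europa_arcsec = angular_diameter_arcsec(3121.6e3, 628.3e9)
limit_arcsec = rayleigh_limit_arcsec(550e-9, 1.0)  # 550 nm light, 1 m aperture

print(f"Europa angular diameter: {europa_arcsec:.3f} arcsec")
print(f"Rayleigh limit (1 m, 550 nm): {limit_arcsec:.3f} arcsec")
print(f"Resolution elements across the disk: {europa_arcsec / limit_arcsec:.1f}")
```

With these numbers Europa spans about 1 arcsec, so only a handful of Rayleigh resolution elements cover its disk with a 1 m aperture.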

## 3. Configure the Experiment

Set the key parameters:

In [None]:
# Experiment configuration
config = {
    "image_size": 256,  # Resolution (smaller = faster)
    "n_samples": 30,  # Number of measurements (more = better quality)
    "max_epochs": 5,  # Training epochs per sample (more = slower)
    "aperture_diameter": 32,  # Telescope aperture size in pixels
}

# Derived parameters
obj_size = obj_params["size"]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Configuration:")
for key, value in config.items():
    print(f"  {key}: {value}")
print(f"\nUsing device: {device}")
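`aperture_diameter` is given in pixels. A common way to model a telescope sub-aperture in simulation is as a binary disk in the (centered) Fourier plane of the image grid. The NumPy sketch below assumes that convention; `circular_aperture_mask` is a hypothetical helper, not part of SPIDS, and SPIDS may place or scale its aperture differently.

```python
import numpy as np


def circular_aperture_mask(image_size, diameter, center=None):
    """Boolean disk of `diameter` pixels on an image_size x image_size grid.

    `center` is (row, col) and defaults to the grid center.
    Hypothetical helper for illustration, not part of SPIDS.
    """
    if center is None:
        center = (image_size // 2, image_size // 2)
    rows, cols = np.ogrid[:image_size, :image_size]
    dist2 = (rows - center[0]) ** 2 + (cols - center[1]) ** 2
    return dist2 <= (diameter / 2) ** 2


mask = circular_aperture_mask(256, 32)  # image_size and aperture_diameter from `config`
print(mask.shape, int(mask.sum()))  # pixel count is close to pi * 16**2

# Under the Fourier-plane convention, a radius of 16 frequency pixels on a 256-pixel
# grid passes up to 16 cycles per image width, i.e. a finest period of 256 / 16 pixels.
finest_period_px = 256 / (32 / 2)
print(finest_period_px)
```

Under this assumption, doubling `aperture_diameter` halves the finest resolvable period, which is the trade-off that combining many shifted sub-aperture measurements tries to overcome.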

## 4. Generate Sampling Pattern

Visualize where the telescope will take measurements:

In [None]:
# Generate Fermat spiral sampling points
sample_points = fermat_spiral_points(n=config["n_samples"], radius=config["image_size"] // 4)

# Visualize
plt.figure(figsize=(8, 8))
plt.scatter(sample_points[:, 0], sample_points[:, 1], c=range(len(sample_points)), cmap="viridis")
plt.colorbar(label="Sample order")
plt.title(f"Sampling Pattern: Fermat Spiral ({len(sample_points)} points)")
plt.xlabel("X position")
plt.ylabel("Y position")
plt.axis("equal")
plt.grid(True, alpha=0.3)
plt.show()

print(f"Generated {len(sample_points)} sampling points")
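A Fermat ("sunflower") spiral places point k at angle k times the golden angle and at a radius proportional to the square root of k. Because disk area grows as r², the square-root radius gives roughly uniform point density, and the irrational golden angle keeps points from lining up along spokes. The sketch below is a generic construction; `fermat_spiral_sketch` is hypothetical, and SPIDS's `fermat_spiral_points` may differ in offset, ordering, or return type.

```python
import math

import numpy as np

GOLDEN_ANGLE = math.pi * (3 - math.sqrt(5))  # about 2.39996 rad (137.5 degrees)


def fermat_spiral_sketch(n, radius):
    """Return an (n, 2) array of (x, y) points on a Fermat spiral inside `radius`."""
    k = np.arange(n)
    r = radius * np.sqrt((k + 0.5) / n)  # sqrt spacing gives equal area per point
    theta = k * GOLDEN_ANGLE             # golden-angle rotation between neighbors
    return np.stack([r * np.cos(theta), r * np.sin(theta)], axis=1)


# Same values as the tutorial: 30 samples, radius image_size // 4 = 64
pts = fermat_spiral_sketch(30, 64)
radii = np.hypot(pts[:, 0], pts[:, 1])
print(pts.shape)             # (30, 2)
print(radii.max() <= 64)     # True
```

Since the radius increases monotonically with k, the sample order in the scatter plot above runs from the center outward.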

## 5. Create the Model

Initialize the generative neural network:

In [None]:
# Create model
model = GenCropSpidsNet(
    obj_size=obj_size,
    image_size=config["image_size"],
    latent_dim=512,
).to(device)

# Count parameters
n_params = sum(p.numel() for p in model.parameters())

print("Model created:")
print(f"  Parameters: {n_params:,}")
print(f"  Image grid: {config['image_size']}x{config['image_size']}")
print(f"  Object size (obj_size): {obj_size}x{obj_size}")

# Test forward pass
with torch.no_grad():
    test_output = model()
print(f"  Output shape: {test_output.shape}")

## 6. Simplified Training Loop

**Note**: This is a simplified version for tutorial purposes. Real experiments use the full trainer from `spids.core.trainers`.

In [None]:
# For the tutorial: run a short optimization loop (20 steps) on a placeholder loss
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
model.train()

# Dummy training
losses = []
for step in range(20):
    optimizer.zero_grad()

    output = model()

    # Placeholder loss that only exercises backpropagation;
    # the real loss compares predicted and measured telescope data
    loss = output.mean()

    loss.backward()
    optimizer.step()

    losses.append(loss.item())

# Plot training curve
plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.xlabel("Step")
plt.ylabel("Loss")
plt.title("Training Progress (Simplified Demo)")
plt.grid(True, alpha=0.3)
plt.show()

print("Training complete!")
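The dummy loss above stands in for a data-fidelity term. A generic version used in sub-aperture Fourier imaging (Fourier ptychography is the best-known example) works like this: for each sample point, keep only the part of the object's spectrum that falls inside a shifted aperture, transform back, and compare the resulting intensity with the measured one. The NumPy sketch below assumes that formulation; it is not SPIDS's actual loss, which may differ in normalization, noise model, or how sample points map to spectrum offsets. `disk`, `measure`, and `measurement_loss` are hypothetical names.

```python
import numpy as np


def disk(size, center, diameter):
    """Boolean disk of `diameter` pixels centered at (row, col) on a size x size grid."""
    rows, cols = np.ogrid[:size, :size]
    return (rows - center[0]) ** 2 + (cols - center[1]) ** 2 <= (diameter / 2) ** 2


def measure(obj, offset, diameter):
    """Intensity seen through a circular aperture shifted by offset (dy, dx) in the spectrum."""
    size = obj.shape[0]
    spectrum = np.fft.fftshift(np.fft.fft2(obj))        # move DC to the grid center
    center = (size // 2 + offset[0], size // 2 + offset[1])
    filtered = spectrum * disk(size, center, diameter)  # keep one sub-aperture of k-space
    field = np.fft.ifft2(np.fft.ifftshift(filtered))    # back to image space
    return np.abs(field) ** 2                            # detectors record intensity only


def measurement_loss(estimate, measurements, offsets, diameter):
    """Mean squared intensity error, averaged over all sample points."""
    errs = [np.mean((measure(estimate, o, diameter) - m) ** 2)
            for o, m in zip(offsets, measurements)]
    return float(np.mean(errs))


rng = np.random.default_rng(0)
true_obj = rng.random((64, 64))
offsets = [(0, 0), (4, -3), (-6, 2)]  # toy aperture positions (spectrum pixels)
meas = [measure(true_obj, o, 16) for o in offsets]

print(measurement_loss(true_obj, meas, offsets, 16))                  # 0.0
print(measurement_loss(np.zeros((64, 64)), meas, offsets, 16) > 0)   # True
```

For training, the same steps can be written with `torch.fft.fft2`, `torch.fft.fftshift`, `torch.fft.ifftshift`, and `torch.fft.ifft2` so gradients flow back to the model output.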

## 7. Visualize Results

In [None]:
# Get final reconstruction
with torch.no_grad():
    reconstruction = model().cpu()

# Visualize (taking magnitude if complex)
if reconstruction.shape[1] == 2:  # Complex (real + imaginary)
    magnitude = torch.sqrt(reconstruction[:, 0] ** 2 + reconstruction[:, 1] ** 2)
    phase = torch.atan2(reconstruction[:, 1], reconstruction[:, 0])
else:
    magnitude = reconstruction[:, 0]
    phase = None

# Plot
fig, axes = plt.subplots(1, 2 if phase is not None else 1, figsize=(12, 5))

if phase is not None:
    # Magnitude
    im1 = axes[0].imshow(magnitude[0], cmap="viridis")
    axes[0].set_title("Magnitude")
    axes[0].axis("off")
    plt.colorbar(im1, ax=axes[0])

    # Phase
    im2 = axes[1].imshow(phase[0], cmap="twilight", vmin=-np.pi, vmax=np.pi)
    axes[1].set_title("Phase")
    axes[1].axis("off")
    plt.colorbar(im2, ax=axes[1])
else:
    im = axes.imshow(magnitude[0], cmap="viridis")
    axes.set_title("Reconstruction")
    axes.axis("off")
    plt.colorbar(im, ax=axes)

plt.tight_layout()
plt.show()
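The visualization cell treats a two-channel output as (real, imaginary) parts and computes magnitude and phase explicitly. Packing the channels into a complex array gives the same result via `np.abs` and `np.angle`, with phase in (-π, π], matching the `vmin`/`vmax` used for the phase plot. The check below uses random data in place of a real reconstruction.

```python
import numpy as np

rng = np.random.default_rng(0)
recon = rng.standard_normal((1, 2, 8, 8))  # [batch, (real, imag), H, W], as assumed above

real, imag = recon[:, 0], recon[:, 1]

# Route 1: explicit formulas, as in the visualization cell
mag_a = np.sqrt(real ** 2 + imag ** 2)
phase_a = np.arctan2(imag, real)

# Route 2: pack into a complex array and use np.abs / np.angle
z = real + 1j * imag
mag_b, phase_b = np.abs(z), np.angle(z)

print(np.allclose(mag_a, mag_b), np.allclose(phase_a, phase_b))  # True True
```

In PyTorch the equivalent is `torch.complex(real, imag)` followed by `.abs()` and `.angle()`; `torch.complex` expects both inputs to be floating-point tensors of the same dtype.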

## 8. Next Steps

Try modifying:
- **Object**: Change `obj_name` to 'titan', 'betelgeuse', or 'neptune'
- **Samples**: Increase `n_samples` for better quality
- **Epochs**: Increase `max_epochs` for more training
- **Resolution**: Try larger `image_size` (warning: slower)

For full experiments, use the command-line interface:
```bash
cd ../..
uv run python main.py --obj_name europa --n_samples 100 --fermat --name my_experiment
```

Or use the Python API (see `examples/python_api/`).

## Summary

You've learned how to:
- ✓ Set up SPIDS
- ✓ Generate sampling patterns
- ✓ Create and train models
- ✓ Visualize results

**Next tutorials**:
- Tutorial 2: Pattern design and comparison
- Tutorial 3: Result analysis and metrics