# 🧩 Noisy Particle Segmentation with Autoencoder
This notebook demonstrates:
- Generating the noisy dataset
- Training the convolutional autoencoder
- Visualizing reconstructions
- Exploring the latent space with PCA
- Generating new shapes by decoding points sampled around the latent-space center

In [None]:
# Install dependencies (if needed)
!pip install torch torchvision matplotlib scikit-image scikit-learn

In [None]:
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from src.dataset import NoisySegDataset
from src.model import Autoencoder1D
from src.utils import weighted_mse
from src.visualize import visualize_reconstructions, visualize_latent_space

## 1. Create dataset

In [None]:
# Build the dataset object used by every later cell. The arguments are passed
# straight to NoisySegDataset (src.dataset); presumably this yields 500 samples
# of 128x128 images -- TODO confirm against the class definition.
dataset = NoisySegDataset(n_samples=500, img_size=128)
print("Dataset size:", len(dataset))

# Preview the first sample. Indexing returns a pair; only the first element
# (the input) is displayed here, the second is not used in this cell.
x, y = dataset[0]

# squeeze() drops size-1 dimensions so imshow receives a plain 2-D image
# (assumes x carries a singleton channel axis -- TODO confirm).
fig, ax = plt.subplots()
ax.imshow(x.squeeze(), cmap='gray')
ax.set_title("Example Noisy Input")
ax.axis('off')
plt.show()

## 2. Train Autoencoder

In [None]:
# Seed PyTorch's global RNG before the model and loader are created so that
# weight initialisation and the shuffle order are repeatable on every
# "Restart & Run All". (Some CUDA kernels can still be non-deterministic.)
TRAIN_SEED = 42
torch.manual_seed(TRAIN_SEED)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Autoencoder1D(latent_dim=512).to(device)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

epochs = 20  # keep small for demo
for epoch in range(epochs):
    # Re-enter training mode each epoch: later cells switch the model to eval().
    model.train()
    total_loss = 0.0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)
        # weight=20.0 is forwarded to weighted_mse (src.utils); presumably it
        # up-weights part of the target in the loss -- TODO confirm there.
        loss = weighted_mse(out, y, weight=20.0)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Reported value is the mean of the per-batch losses for this epoch.
    print(f"Epoch {epoch+1}/{epochs} Loss: {total_loss/len(loader):.4f}")

## 3. Visualize Reconstructions

In [None]:
# Show reconstructions produced by the trained model using the project helper
# from src.visualize. n=3 is presumably the number of examples drawn -- TODO
# confirm against the helper's signature.
# NOTE(review): the training loop leaves the model in train() mode; confirm
# that the helper switches to eval() itself before running inference.
visualize_reconstructions(model, dataset, device, n=3)

## 4. Latent Space Visualization

In [None]:
# Plot the model's latent space for the dataset using the project helper from
# src.visualize. The projection and plotting happen inside the helper and are
# not visible from this notebook (presumably a PCA view -- TODO confirm).
visualize_latent_space(model, dataset, device)

## 5. Latent Sampling Playground 🎨

In [None]:
import numpy as np
from sklearn.decomposition import PCA

# Fixed seed so both the PCA fit (which may use a randomized solver) and the
# sampled directions are identical on every run of this cell.
SAMPLING_SEED = 0
rng = np.random.default_rng(SAMPLING_SEED)

# 1. Collect latent vectors from training set
# A dedicated, unshuffled loader: it leaves the training `loader` untouched and
# keeps the rows of `latents` in dataset order.
latent_loader = DataLoader(dataset, batch_size=16, shuffle=False)
latent_batches = []
model.eval()
with torch.no_grad():
    for batch, _ in latent_loader:
        z = model.encode(batch.to(device))
        latent_batches.append(z.cpu().numpy())
# Stack the per-batch arrays row-wise (assumes encode() returns one latent
# vector per sample, i.e. a 2-D (batch, latent_dim) array -- TODO confirm).
latents = np.vstack(latent_batches)

# 2. Center of mass
center = latents.mean(axis=0)

# 3. PCA for 2D exploration
pca = PCA(n_components=2, random_state=SAMPLING_SEED)
latents_2d = pca.fit_transform(latents)
# PCA subtracts the same mean it was fitted on, so this lands at (approximately)
# the origin of the 2-D plane; kept for plotting / inspection.
center_2d = pca.transform(center.reshape(1, -1))

# 4. Sample random vectors around the center
r = 5.0  # distance in PCA space
num_samples = 4
# squeeze=False keeps `axes` 2-D even when num_samples == 1.
fig, axes = plt.subplots(1, num_samples, figsize=(10, 3), squeeze=False)

for i, ax in enumerate(axes[0]):
    # Random direction on a circle of radius r in the PCA plane.
    theta = rng.uniform(0.0, 2 * np.pi)
    vec_2d = r * np.array([np.cos(theta), np.sin(theta)])
    vec_full = vec_2d @ pca.components_  # map back to latent_dim
    z_new = center + vec_full

    # Add a leading batch axis and cast to float32 before decoding.
    z_new_tensor = torch.tensor(z_new, dtype=torch.float32).unsqueeze(0).to(device)
    with torch.no_grad():
        generated = model.decode(z_new_tensor).cpu().squeeze().numpy()

    ax.imshow(generated, cmap="gray")
    ax.set_title(f"Sample {i+1}")
    ax.axis("off")

fig.suptitle("Generated Shapes from Latent Sampling", fontsize=14)
plt.show()

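🎉 Done! You can also sample new latent vectors manually using the PCA directions for generative exploration: for example, build `center + a * pca.components_[0] + b * pca.components_[1]` for values of `a` and `b` of your choosing and pass the result to `model.decode`.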