# 🌌 AI Latent Space Explorer
### Visualizing How Protein AI Models "See" Structural Diversity

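Modern structure prediction models like AlphaFold and trRosetta don't work from 3D coordinates alone. They reason over **2D maps of inter-residue relationships**. trRosetta in particular predicts a distance and a set of orientation angles for every residue pair (the "6D Orientogram").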

In this tutorial, we will:
1. Use the `BatchedGenerator` to create 500 unique protein conformations in parallel.
2. Transform these 3D structures into **6D trRosetta Orientograms**.
3. Use **Dimensionality Reduction (PCA)** to project this high-dimensional feature space (our stand-in for a model's "Latent Space") into a 2D interactive galaxy.

In [None]:
from IPython.display import display, HTML, clear_output
import ipywidgets as widgets
from ipywidgets import interact, IntSlider
import os, sys, numpy as np
import py3Dmol
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import plotly.graph_objects as go
import plotly.io as pio

# --- UNIVERSAL SETUP ---
# Plotly renderer for VS Code; switch to 'notebook' or 'colab' in other front ends
pio.renderers.default = 'vscode'
sys.path.append(os.path.abspath('../../'))
from synth_pdb import PeptideGenerator, EnergyMinimizer, PDBValidator, PeptideResult
from synth_pdb.batch_generator import BatchedGenerator
import biotite.structure as struc

# Robust 3Dmol.js injection
display(HTML('<script src="https://3Dmol.org/build/3Dmol-min.js"></script>'))
print("[✅] Environment Ready.")

## 1. Mass-Generating Proteins
We'll generate a batch of 500 structures for a short 7-residue peptide. `BatchedGenerator` uses vectorized math to build the whole batch in one pass, in milliseconds, instead of looping over structures.
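To make "vectorized" concrete, here is a minimal NumPy sketch that computes pairwise distance maps for an entire batch with one broadcasted expression. It is an illustration only, not `BatchedGenerator`'s actual implementation; `batched_distance_maps` and `toy_coords` are hypothetical names, and the random array stands in for coordinates of shape (Batch, Atoms, XYZ).

```python
import numpy as np

def batched_distance_maps(coords):
    """Pairwise distance maps for a whole batch at once.

    coords: array of shape (B, N, 3)
    returns: array of shape (B, N, N)
    """
    # (B, N, 1, 3) - (B, 1, N, 3) broadcasts to (B, N, N, 3)
    diff = coords[:, :, None, :] - coords[:, None, :, :]
    return np.linalg.norm(diff, axis=-1)

# Random stand-in for a batch of 500 structures with 7 points each
rng = np.random.default_rng(0)
toy_coords = rng.normal(size=(500, 7, 3))
toy_maps = batched_distance_maps(toy_coords)
print(toy_maps.shape)
```

No Python loop runs over the 500 structures; the subtraction and norm are applied to the whole array in a single call, which is the pattern vectorized generators rely on.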

In [None]:
sequence = "TRP-SER-GLY-ALA-VAL-PRO-ILE"
n_batch = 500

print(f"Generating {n_batch} structures...")
bg = BatchedGenerator(sequence, n_batch=n_batch)
batch = bg.generate_batch() # Vectorized engine: builds every structure in one call

print(f"Shape of coordinates: {batch.coords.shape} (Batch, Atoms, XYZ)")

## 2. Computing 6D Orientograms
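For every ordered pair of residues (i, j) in every structure, we calculate the following quantities (definitions follow the trRosetta convention):
- **Distance** (d): the Cβ-Cβ distance
- **Omega** (ω): the dihedral about the Cβ-Cβ axis (Cα-Cβ-Cβ-Cα), symmetric in i and j
- **Theta** (θ): the dihedral N-Cα-Cβ-Cβ, which depends on direction (i→j differs from j→i)
- **Phi** (φ): the planar angle Cα-Cβ-Cβ, which also depends on direction

The "6D" name comes from counting θ and φ in both directions alongside d and ω. trRosetta predicts distributions over these quantities and then uses them as restraints when folding the structure.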
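As a rough sketch of how these four quantities can be computed for a single residue pair, the snippet below uses plain NumPy and the trRosetta atom conventions (Cβ-Cβ distance, Cα-Cβ-Cβ-Cα and N-Cα-Cβ-Cβ dihedrals, Cα-Cβ-Cβ planar angle). The helper names and toy coordinates are hypothetical, and whether `get_6d_orientations()` uses exactly these conventions (or radians) is an assumption.

```python
import numpy as np

def dihedral(p0, p1, p2, p3):
    """Signed dihedral angle (radians) about the p1 -> p2 axis."""
    b0 = p0 - p1
    b1 = p2 - p1
    b2 = p3 - p2
    b1 = b1 / np.linalg.norm(b1)
    # Components of b0 and b2 perpendicular to the central axis
    v = b0 - np.dot(b0, b1) * b1
    w = b2 - np.dot(b2, b1) * b1
    x = np.dot(v, w)
    y = np.dot(np.cross(b1, v), w)
    return np.arctan2(y, x)

def planar_angle(p0, p1, p2):
    """Angle (radians) at p1 between p1 -> p0 and p1 -> p2."""
    u = p0 - p1
    v = p2 - p1
    cos = np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))
    return np.arccos(np.clip(cos, -1.0, 1.0))

def pair_orientation(res_i, res_j):
    """res_i, res_j: dicts mapping 'N', 'CA', 'CB' to (3,) arrays."""
    d = np.linalg.norm(res_j["CB"] - res_i["CB"])
    omega = dihedral(res_i["CA"], res_i["CB"], res_j["CB"], res_j["CA"])
    theta = dihedral(res_i["N"], res_i["CA"], res_i["CB"], res_j["CB"])
    phi = planar_angle(res_i["CA"], res_i["CB"], res_j["CB"])
    return d, omega, theta, phi

# Toy coordinates (hypothetical, not taken from a real structure)
res_a = {"N": np.array([-1.2, 0.8, 0.0]),
         "CA": np.array([0.0, 0.0, 0.0]),
         "CB": np.array([0.5, -0.8, 1.2])}
res_b = {"N": np.array([4.8, 2.5, 1.0]),
         "CA": np.array([5.5, 1.5, 0.3]),
         "CB": np.array([4.6, 0.4, 0.9])}

d, omega, theta, phi = pair_orientation(res_a, res_b)
print(f"d={d:.2f}  omega={omega:.2f}  theta={theta:.2f}  phi={phi:.2f}")
```

Swapping the two residues leaves d and ω unchanged but generally changes θ and φ, which is why the resulting θ and φ maps are not symmetric. Glycine has no Cβ, so real pipelines typically place a virtual Cβ first; that step is omitted here.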

In [None]:
print("Computing 6D Orientograms...")
orients = batch.get_6d_orientations()

# orients is a dict of arrays [Batch, N, N]
print(f"Computed orientations for {n_batch} structures.")
print(f"Keys available: {list(orients.keys())}")

## 3. Latent Space Projection (PCA)
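Each structure has four N×N geometry maps. We flatten and concatenate them into one high-dimensional feature vector per structure, then use PCA to project those vectors onto the two directions of greatest variance.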
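One caveat with feeding raw angles to PCA is that they wrap around: -π and +π describe the same orientation but sit at opposite ends of the number line. A common workaround is to replace each angle with its sine and cosine. The sketch below shows that encoding followed by a small SVD-based PCA in plain NumPy (a stand-in for `sklearn.decomposition.PCA`). It assumes the angle maps are in radians; `encode_features`, `pca_2d`, and the random toy maps are hypothetical.

```python
import numpy as np

def encode_features(dist, omega, theta, phi):
    """Flatten (B, N, N) maps into one (B, F) feature matrix.

    Angles (radians) become sin/cos pairs to remove the wraparound at +/- pi.
    """
    b = dist.shape[0]
    blocks = [dist.reshape(b, -1)]
    for ang in (omega, theta, phi):
        flat = ang.reshape(b, -1)
        blocks += [np.sin(flat), np.cos(flat)]
    return np.concatenate(blocks, axis=1)

def pca_2d(features):
    """Project rows of `features` onto their first two principal axes."""
    centered = features - features.mean(axis=0)
    # Rows of vt are principal axes, ordered by decreasing singular value
    _, s, vt = np.linalg.svd(centered, full_matrices=False)
    points = centered @ vt[:2].T
    explained = (s ** 2) / np.sum(s ** 2)
    return points, explained[:2]

# Random stand-ins for the four orientation maps of 50 structures
rng = np.random.default_rng(0)
B, N = 50, 7
toy_dist = rng.uniform(3.0, 20.0, size=(B, N, N))
toy_omega = rng.uniform(-np.pi, np.pi, size=(B, N, N))
toy_theta = rng.uniform(-np.pi, np.pi, size=(B, N, N))
toy_phi = rng.uniform(0.0, np.pi, size=(B, N, N))

feats = encode_features(toy_dist, toy_omega, toy_theta, toy_phi)
points, explained = pca_2d(feats)
print(feats.shape, points.shape)
print("Explained variance ratio of PC1, PC2:", explained)
```

The explained variance ratio indicates how much of the spread the 2D picture actually captures. Distances in ångströms also have a much larger numeric range than sin/cos values, so standardizing each column before PCA is worth considering if distance should not dominate the projection.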

In [None]:
# Flatten all features
feature_vector = np.concatenate([
    orients['dist'].reshape(n_batch, -1),
    orients['omega'].reshape(n_batch, -1),
    orients['theta'].reshape(n_batch, -1),
    orients['phi'].reshape(n_batch, -1)
], axis=1)

print(f"Feature vector size: {feature_vector.shape[1]}")

# Run PCA
pca = PCA(n_components=2)
latent_points = pca.fit_transform(feature_vector)

print("Projection complete.")

## 4. Interactive Latent Space Explorer
Hover over points in the "Galaxy" to read each structure's ID, then use the slider in the final cell to view that structure in 3D.
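To go from a position on the plot to a structure index, a nearest-neighbor lookup in the 2D projection is enough. The helper below is a hypothetical sketch (it is not part of `synth_pdb` or Plotly) and is demonstrated on a small hand-made array in place of `latent_points`.

```python
import numpy as np

def nearest_structure(latent_points, x, y, k=1):
    """Return indices of the k points closest to (x, y) in the 2D projection."""
    query = np.array([x, y])
    dists = np.linalg.norm(latent_points - query, axis=1)
    return np.argsort(dists)[:k]

# Four toy points standing in for the PCA projection
toy_points = np.array([[0.0, 0.0], [1.0, 1.0], [-2.0, 0.5], [3.0, -1.0]])
print(nearest_structure(toy_points, 0.9, 1.2))        # [1]
print(nearest_structure(toy_points, -1.0, 0.0, k=2))  # [0 2]
```

The returned index can be passed straight to the viewer's slider, for example to compare a structure at the edge of the galaxy with one near the center.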

In [None]:
fig = go.Figure(data=[go.Scatter(
    x=latent_points[:, 0],
    y=latent_points[:, 1],
    mode='markers',
    marker=dict(
        size=8,
        color=np.arange(n_batch), # Color by batch index (arbitrary order; helps locate a given ID)
        colorscale='Viridis',
        showscale=True
    ),
    text=[f"Protein ID: {i}" for i in range(n_batch)]
)])

fig.update_layout(
    title='Protein Latent Space (PCA of 6D Maps)',
    xaxis_title='Principal Component 1',
    yaxis_title='Principal Component 2',
    width=800, height=600,
    template="plotly_dark"
)

fig

In [None]:
def view_from_latent(index):
    import io
    import biotite.structure.io.pdb as pdb_io

    view = py3Dmol.view(width=400, height=400)

    # Get coords for this specific batch item
    coords = batch.coords[index]

    # Generate a template structure to borrow atom and residue annotations
    pgen = PeptideGenerator(sequence)
    res = pgen.generate()
    pdb_text = res.pdb

    # The batch is backbone-only by default, so its atom count may differ from
    # the template's. Users can experiment with
    # bg = BatchedGenerator(..., full_atom=True).
    if res.structure.array_length() == coords.shape[0]:
        res.structure.coord = coords
        # res.pdb was written before the coordinates were replaced, so
        # re-serialize (assumes res.structure is a biotite AtomArray)
        pdb_file = pdb_io.PDBFile()
        pdb_file.set_structure(res.structure)
        buffer = io.StringIO()
        pdb_file.write(buffer)
        pdb_text = buffer.getvalue()
    else:
        # Mismatch (e.g. full-atom vs. backbone): fall back to the template
        print(f"Atom count mismatch: showing a template structure, not batch item {index}.")

    view.addModel(pdb_text, "pdb")
    view.setStyle({'stick': {'color': 'spectrum'}, 'sphere': {'scale': 0.3}})
    view.zoomTo()

    # Also show the distance map for this structure
    fig_map, ax = plt.subplots(1, 1, figsize=(3, 3))
    ax.imshow(orients['dist'][index], cmap='magma')
    ax.set_title("Distance Map (AI View)")
    ax.axis('off')
    plt.show()

    return view.show()

interact(view_from_latent, index=IntSlider(min=0, max=n_batch-1, step=1, value=0));