![Neural NMR Hero](assets/nmr_hero.png)

# 🧠 Multi-Modal AI Research: The Structural-Magnetic Bridge ⚛️

**Objective**: Learn how to generate synchronized datasets of 3D structural tensors and NMR observables (predicted chemical shifts) for multi-modal AI training.

### 🌟 The Vision: "AlphaFold-NMR"
In modern structural biology, 3D coordinates are only half the story. Experimental verification often comes from **Nuclear Magnetic Resonance (NMR)**. NMR chemical shifts are highly sensitive to the local electronic environment, so each nucleus's resonance frequency acts as a "fingerprint" of the local geometry.

In this lab, we build an end-to-end pipeline that treats the protein as both a **Geometric Object** and a **Magnetic Observable**. This data is used to train models that can:
1. **Back-Calculate**: Predict NMR shifts from structure.
2. **De-Novo Solve**: Predict structure directly from chemical shifts.

In [None]:
# @title Setup & Installation { display-mode: "form" }
import os
import sys
from pathlib import Path

try:
    current_path = Path(".").resolve()
    repo_root = current_path.parent.parent 
    if (repo_root / "synth_pdb").exists():
        if str(repo_root) not in sys.path:
            sys.path.insert(0, str(repo_root))
            print(f"📌 Added local library to path: {repo_root}")
except Exception:
    pass

if 'google.colab' in str(get_ipython()):
    if not os.path.exists("installed.marker"):
        print("Running on Google Colab. Installing dependencies...")
        get_ipython().run_line_magic('pip', 'install synth-pdb torch numpy matplotlib py3Dmol biotite')
        
        with open("installed.marker", "w") as f:
            f.write("done")
        
        print("🔄 Installation complete. KERNEL RESTARTING AUTOMATICALLY...")
        os.kill(os.getpid(), 9)
    else:
        print("✅ Dependencies Ready.")
else:
    import synth_pdb
    print(f"✅ Running locally. Using synth-pdb version: {synth_pdb.__version__}")

In [None]:
import numpy as np
import time
import torch
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import py3Dmol
import biotite.structure as struc
from synth_pdb.batch_generator import BatchedGenerator
from synth_pdb.chemical_shifts import predict_chemical_shifts, calculate_csi
from synth_pdb.generator import generate_pdb_content

print("Magnetic Resonance Engine: ONLINE ⚡")

## 1. Synchronized Generation: The Coords-Shift Tensor

We generate a batch of structures with significant structural drift, then predict the chemical shifts of each structure's nuclei (for the first five samples in this demo). This creates a paired dataset: `(X, Y) = (Coordinates, NMR Shifts)`.

In [None]:
# FIX: Use explicit hyphenation to avoid 'ASPTRP' merging errors
sequence = "-".join(["TRP-PHE-TYR-HIS-LYS-GLU-ASP"] * 3) # 21 residues, rich in Aromatics
n_samples = 100

print(f"🚀 Generating {n_samples} synchronized multi-modal structural samples...")

generator = BatchedGenerator(sequence, n_batch=n_samples, full_atom=True)
batch = generator.generate_batch(drift=2.0) 

print("✅ Structural Tensors Generated.")
print("⚡ Predicting Chemical Shifts (SPARTA-Lite + Ring Currents)...")

from io import StringIO
import biotite.structure.io.pdb as pdb_io

all_shifts = []
for i in range(5):  # Analyze only the first 5 samples in detail for the demo
    # Convert the batch member to a biotite structure for the NMR engine
    pdb_str = batch.to_pdb(i)
    struct = pdb_io.PDBFile.read(StringIO(pdb_str)).get_structure(model=1)

    shifts = predict_chemical_shifts(struct)
    all_shifts.append(shifts)

print(f"✅ Paired Data Ready. Sample 0 Chain A Res 1 chemical shifts: {all_shifts[0]['A'][1]}")

## 2. Fold Recognition: The CSI Plot

The **Chemical Shift Index (CSI)** is derived from the secondary shift: the deviation of a nucleus's chemical shift from its "random coil" baseline.

- **Alpha Helices**: Move C-alpha shifts **Downfield** (+ ppm).
- **Beta Sheets**: Move C-alpha shifts **Upfield** (- ppm).
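
The sign convention for C-alpha secondary shifts can be turned into a three-state index. The sketch below is a minimal illustration, not the `calculate_csi` implementation from `synth_pdb`: the names `secondary_shift` and `csi_label` are hypothetical, the random-coil values are approximate numbers listed only for the residue types in this lab's sequence, and the 0.7 ppm cutoff is an assumption that mirrors the thresholds drawn in the plot.

```python
# Approximate random-coil C-alpha shifts in ppm (illustrative values only;
# use a published random-coil table for real analysis).
RANDOM_COIL_CA = {
    "ALA": 52.5, "ASP": 54.2, "GLU": 56.6, "HIS": 55.0,
    "LYS": 56.2, "PHE": 57.7, "TRP": 57.5, "TYR": 57.9,
}

def secondary_shift(resname, observed_ca_ppm):
    """Observed C-alpha shift minus the residue-specific random-coil value."""
    return observed_ca_ppm - RANDOM_COIL_CA[resname]

def csi_label(delta_ppm, threshold=0.7):
    """Return +1 (helix-like), -1 (strand-like), or 0 (coil-like)."""
    if delta_ppm > threshold:
        return 1
    if delta_ppm < -threshold:
        return -1
    return 0

# Hypothetical observed shifts, chosen only to exercise each branch.
observed = [("ALA", 55.0), ("LYS", 54.0), ("GLU", 56.9)]
labels = [csi_label(secondary_shift(name, ppm)) for name, ppm in observed]
print(labels)  # [1, -1, 0]
```

Using a residue-specific baseline matters here: the random-coil values for aromatic residues sit several ppm above the Ala value, so a single generic baseline would bias every label in this sequence.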

Let's visualize this footprint for our first sample.

In [None]:
# Extract CA secondary shifts for Sample 0
sample_idx = 0
res_ids = sorted(all_shifts[sample_idx]['A'].keys())
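# Subtract a generic Ala random-coil value (52.5 ppm) as a rough visual baseline;
# a true secondary shift would use a residue-specific baseline.
ca_deltas = [all_shifts[sample_idx]['A'][r].get('CA', 0) - 52.5 for r in res_ids]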

plt.figure(figsize=(12, 4))
plt.bar(res_ids, ca_deltas, color='#9b59b6', alpha=0.7, label="Delta-CA (Secondary Shift)")
plt.axhline(0.7, color='red', linestyle='--', alpha=0.3, label="Helix Threshold")
plt.axhline(-0.7, color='blue', linestyle='--', alpha=0.3, label="Sheet Threshold")
plt.title("The Magnetic Footprint of Protein Folding")
plt.xlabel("Residue Number")
plt.ylabel("CSI Deviation (ppm)")
plt.legend()
plt.grid(alpha=0.2)
plt.show()

print("Educational Insight: Runs of consecutive positive deviations suggest helix; runs of negative deviations suggest strand.")

## 3. Visualizing Ring Current Effects (Tertiary Proximity)

Aromatic rings (Phe, Tyr, Trp) act like tiny electromagnets: the applied field drives a current in their delocalized π electrons. Nuclei that sit close to the "face" of the ring are shielded and shift upfield (toward lower ppm). This is how NMR "sees" tertiary packing.
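
One common first approximation for this effect is the point-dipole model, in which the ring-current shift scales with the geometric factor (1 - 3cos²θ)/r³, where r is the distance from the ring center to the nucleus and θ is the angle between that vector and the ring normal. The sketch below illustrates only that geometry and is not necessarily what `predict_chemical_shifts` does internally; `ring_geometric_factor` is a hypothetical helper, and its result is in arbitrary units rather than ppm.

```python
import math

def ring_geometric_factor(ring_center, ring_normal, nucleus):
    """Point-dipole geometric factor (1 - 3*cos^2(theta)) / r^3.

    Negative values mean shielding (upfield), positive values deshielding.
    Inputs are 3D points/vectors in angstroms; output is in arbitrary units.
    """
    v = [n - c for n, c in zip(nucleus, ring_center)]
    r = math.sqrt(sum(x * x for x in v))
    normal_len = math.sqrt(sum(x * x for x in ring_normal))
    if r == 0.0 or normal_len == 0.0:
        raise ValueError("need a nonzero distance and a nonzero ring normal")
    cos_theta = sum(a * b for a, b in zip(v, ring_normal)) / (r * normal_len)
    return (1.0 - 3.0 * cos_theta ** 2) / r ** 3

# Ring at the origin lying in the xy-plane, so its normal points along z.
center, normal = (0.0, 0.0, 0.0), (0.0, 0.0, 1.0)
above_face = ring_geometric_factor(center, normal, (0.0, 0.0, 3.0))
in_plane = ring_geometric_factor(center, normal, (3.0, 0.0, 0.0))
print(f"above face: {above_face:+.4f}, in plane: {in_plane:+.4f}")
```

A nucleus 3 Å above the ring face gets a negative (shielding) factor, while one at the same distance in the ring plane gets a positive factor of half the magnitude.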

In [None]:
view = py3Dmol.view(width=800, height=400)
view.setBackgroundColor("#fdfdfd")

pdb_str = batch.to_pdb(0)

# 1. Highlight the Aromatic Rings
view.addModel(pdb_str, 'pdb')
view.setStyle({'model': 0}, {'cartoon': {'color': '#667eea', 'opacity': 0.6}})
view.setStyle({'resn': ['PHE', 'TYR', 'TRP']}, {'stick': {'radius': 0.25, 'color': '#ffcc00'}})

# 2. Show the "Magnetic Cloud"
# We'll put a surface around aromatics to visualize the 'Influence Zone'
view.addSurface(py3Dmol.MS, {'opacity': 0.2, 'color': '#ffcc00'}, {'resn': ['PHE', 'TYR', 'TRP']})

view.zoomTo()
view.center()
view.show()

print("Yellow regions indicate Aromatic hubs that distort the local magnetic field of nearby nuclei.")

## 4. Multi-Modal PyTorch Pipeline

Finally, we wrap both signals in a PyTorch `Dataset` and serve them through a `DataLoader`. Every sample is a tuple of `(Geometry, NMR)`.
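
One detail worth considering when tensorizing shifts: a missing value stored as `0.0` is indistinguishable from a real 0 ppm entry. The sketch below shows one possible way to carry a validity mask alongside the features. It assumes the nested layout `{chain: {res_id: {atom_name: ppm}}}` seen in the earlier cells; `shifts_to_features` is a hypothetical helper, not part of `synth_pdb`, and it returns plain lists that can be wrapped with `torch.tensor`.

```python
def shifts_to_features(shifts, chain="A", atoms=("CA", "HA")):
    """Flatten one sample's shift dict into per-residue rows plus a mask.

    Returns (features, mask): features[i][j] is the shift of atoms[j] in the
    i-th residue (0.0 if absent) and mask[i][j] is True where a value exists.
    """
    residues = shifts.get(chain, {})
    features, mask = [], []
    for res_id in sorted(residues):
        atom_shifts = residues[res_id]
        features.append([float(atom_shifts.get(a, 0.0)) for a in atoms])
        mask.append([a in atom_shifts for a in atoms])
    return features, mask

# Hypothetical toy sample: residue 2 has no HA entry.
toy = {"A": {1: {"CA": 57.5, "HA": 4.7}, 2: {"CA": 57.7}}}
features, mask = shifts_to_features(toy)
print(features)  # [[57.5, 4.7], [57.7, 0.0]]
print(mask)      # [[True, True], [True, False]]
```

A loss function can then multiply per-element errors by the mask so that absent entries do not contribute to training.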

In [None]:
class MultiModalProteinDataset(Dataset):
    def __init__(self, coords, shifts_list):
        self.coords = torch.from_numpy(coords).float()
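        # Tensorize the 'CA' and 'HA' shifts as per-residue features
        n_samples = len(shifts_list)
        # Take the residue count from the shift dictionaries; with full_atom=True the
        # atom axis of coords holds more than four atoms per residue
        n_res = max((len(s.get('A', {})) for s in shifts_list), default=0)
        
        self.nmr_features = torch.zeros((n_samples, n_res, 2)) # [CA_shift, HA_shift]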
        
        for i in range(n_samples):
            # Only use chain A for the demo
            if 'A' not in shifts_list[i]: continue
            s = shifts_list[i]['A']
            sorted_keys = sorted(s.keys())
            for r_idx, r_id in enumerate(sorted_keys):
                if r_idx < n_res:
                    self.nmr_features[i, r_idx, 0] = s[r_id].get('CA', 0.0)
                    self.nmr_features[i, r_idx, 1] = s[r_id].get('HA', 0.0)
                    
    def __len__(self):
        return len(self.coords)
        
    def __getitem__(self, idx):
        return self.coords[idx], self.nmr_features[idx]

# Create the synchronized dataset
ds = MultiModalProteinDataset(batch.coords[:5], all_shifts)
loader = DataLoader(ds, batch_size=2, shuffle=True)

batch_coords, batch_nmr = next(iter(loader))
print("✅ Multi-Modal Batch Data Ready.")
print(f"Geometry Shape: {batch_coords.shape}")
print(f"NMR Tensor Shape: {batch_nmr.shape} (Input for Transformer Encoded Shifts)")

### 🏆 Next Steps
1. **Predicting Reality**: Try generating structures with `--conformation beta` and see how the **CSI Plot** flips! 📉
2. **Transformer Training**: Feed the `batch_nmr` tensor into a 1D Transformer to see if it can recover secondary structure labels.

You are now generating the same type of data used to train the next generation of experimental AI solvers. **The lab is yours.** 🧬🤖