## Parsing AlaScan Output Files

### What This Does

Parses FoldX AlaScan results and matches ΔΔG values to specific residues in the PDB structure using **line-by-line correspondence**.

### Key Insight: Line-by-Line Matching

FoldX AlaScan outputs ΔΔG values in the **same order** as residues appear in the PDB file. This function exploits this ordering to correctly assign ΔΔG values to residues without relying on residue numbers (which can be ambiguous).

### How It Works

**Step 1: Parse AlaScan File** (`.fxout`)
```
Example lines:
LEU 479 to ALA energy change is 1.19627
VAL 480 to ALA energy change is 0.85432
ARG 481 to ALA energy change is 2.34567
...
```
- Extract ΔΔG values in order
- Store in list: `[1.19627, 0.85432, 2.34567, ...]`
- Ignore comment lines starting with `#`
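A minimal standard-library sketch of Step 1, assuming every data line follows the `RESNAME RESNUM to ALA energy change is DDG` format shown above. The helper name `parse_alascan_lines` is hypothetical. Unlike the notebook function, it keeps the residue name and number alongside ΔΔG, which makes it possible to cross-check the ordering against the PDB later.

```python
import re

# Hypothetical pattern for lines such as "LEU 479 to ALA energy change is 1.19627".
# The number group also accepts an optional exponent (e.g. "1.2e-05").
ALASCAN_LINE = re.compile(
    r"^(\w+)\s+(\d+)\s+to ALA energy change is\s+([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)"
)


def parse_alascan_lines(lines):
    """Return (residue_name, residue_number, ddG) tuples in file order."""
    records = []
    for raw in lines:
        line = raw.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and comments
        match = ALASCAN_LINE.match(line)
        if match:
            name, number, ddg = match.groups()
            records.append((name, int(number), float(ddg)))
    return records


example = [
    "# comment line",
    "LEU 479 to ALA energy change is 1.19627",
    "VAL 480 to ALA energy change is 0.85432",
    "",
    "ARG 481 to ALA energy change is 2.34567",
]
print([ddg for _, _, ddg in parse_alascan_lines(example)])
# [1.19627, 0.85432, 2.34567]
```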

**Step 2: Parse PDB Structure**
- Iterate through structure: model → chain → residue
- Only standard amino acids (`residue.id[0] == ' '`), so heteroatoms, waters, and ligands are skipped
- Extract: residue name, chain ID, residue number
- Build list in **PDB order**
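To make "PDB order" concrete, here is a dependency-free sketch that reads fixed-column `ATOM` records directly instead of using Biopython. It is an illustration only: it assumes each chain's records are contiguous in the file (Biopython groups records by chain, which can differ for interleaved files) and approximates the `residue.id[0] == ' '` filter by keeping `ATOM` and dropping `HETATM` records. Both function names are hypothetical.

```python
def pdb_atom_line(record, serial, atom_name, res_name, chain, res_num):
    """Hypothetical test helper: builds the first 30 columns of an ATOM/HETATM record."""
    return f"{record:<6}{serial:>5} {atom_name:<4} {res_name:>3} {chain}{res_num:>4}    "


def residues_in_pdb_order(pdb_lines):
    """(residue_name, chain, residue_number) for ATOM residues of the first model, in file order."""
    residues, seen = [], set()
    for line in pdb_lines:
        if line.startswith("ENDMDL"):
            break  # only the first model is used
        if not line.startswith("ATOM"):
            continue  # HETATM records (waters, ligands) are skipped
        res_name = line[17:20].strip()  # columns 18-20
        chain = line[21]                # column 22
        res_num = int(line[22:26])      # columns 23-26
        insertion_code = line[26:27]    # column 27
        key = (chain, res_num, insertion_code)
        if key not in seen:             # one entry per residue, not per atom
            seen.add(key)
            residues.append((res_name, chain, res_num))
    return residues


lines = [
    pdb_atom_line("ATOM", 1, "N", "LEU", "A", 479),
    pdb_atom_line("ATOM", 2, "CA", "LEU", "A", 479),
    pdb_atom_line("ATOM", 3, "N", "VAL", "A", 480),
    pdb_atom_line("HETATM", 4, "O", "HOH", "A", 601),
    pdb_atom_line("ATOM", 5, "N", "LEU", "B", 479),
    "ENDMDL",
    pdb_atom_line("ATOM", 6, "N", "LEU", "A", 479),  # second model, ignored
]
print(residues_in_pdb_order(lines))
# [('LEU', 'A', 479), ('VAL', 'A', 480), ('LEU', 'B', 479)]
```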

**Step 3: Match by Index**
- AlaScan line 1 → PDB residue 1
- AlaScan line 2 → PDB residue 2
- ...and so on
- Assumes perfect 1:1 correspondence
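Because the method depends entirely on that 1:1 correspondence, a stricter variant of Step 3 refuses to pair lists of different lengths instead of filling gaps with a default value. This is a sketch under that assumption; `match_by_index` is a hypothetical name, and the input tuples follow the shapes of the two parsing steps above.

```python
def match_by_index(ddg_values, residues):
    """Pair the i-th AlaScan ddG with the i-th PDB residue (both in file order)."""
    if len(ddg_values) != len(residues):
        # Refuse to guess: padding with 0.0 would silently bias later averages
        raise ValueError(
            f"{len(ddg_values)} AlaScan values vs {len(residues)} PDB residues; "
            "the 1:1 ordering assumption does not hold"
        )
    return [
        {"residue_name": name, "chain": chain, "residue_number": number, "ddG": ddg}
        for (name, chain, number), ddg in zip(residues, ddg_values)
    ]


rows = match_by_index(
    [1.19627, 0.85432],
    [("LEU", "A", 479), ("VAL", "A", 480)],
)
print(rows[1])
# {'residue_name': 'VAL', 'chain': 'A', 'residue_number': 480, 'ddG': 0.85432}
```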

### Input

- **alascan_file**: FoldX AlaScan output file (`*_AS.fxout`)
  - Format: `RESNAME RESNUM to ALA energy change is DDG`
- **pdb_file**: Corresponding PDB structure
  - Must be the **exact** structure used for AlaScan calculation

### Output

**DataFrame with columns:**
- `residue_name`: 3-letter amino acid code (e.g., LEU, VAL, ARG)
- `chain`: Chain identifier (A, B, C, ...)
- `residue_number`: Residue number from PDB
- `ddG`: ΔΔG value (kcal/mol)
  - Positive = mutation to ALA destabilizes (residue is important)
  - Negative = mutation to ALA stabilizes (residue is destabilizing)
  - Near zero = minimal effect

### Example Output

| residue_name | chain | residue_number | ddG    |
|--------------|-------|----------------|--------|
| LEU          | A     | 479            | 1.196  |
| VAL          | A     | 480            | 0.854  |
| ARG          | A     | 481            | 2.346  |
| ...          | ...   | ...            | ...    |
| LEU          | B     | 479            | 1.203  |

### Important Assumptions

1. **Order matters**: AlaScan output and PDB residue order **must match exactly**
2. **Same structure**: PDB file must be the one used for AlaScan (not a different conformation)
3. **Standard residues only**: Skips heteroatoms, waters, ligands
4. **One model**: Uses first model if PDB contains multiple models

### Validation

The function prints diagnostic information:
- Number of AlaScan lines parsed
- Number of residues matched
- Number of chains found
- ΔΔG value range (sanity check)

**Verify these make sense before proceeding!**
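One more cheap validation, assuming each AlaScan line carries the residue name and number as in the format above: compare them position by position with the PDB residues. Any disagreement means the ordering assumption is broken at that index. The function name is hypothetical, and if FoldX renames some residues (for example, histidine protonation variants), a few name mismatches may be benign and should be inspected rather than trusted blindly.

```python
def ordering_mismatches(alascan_records, pdb_residues):
    """
    Compare AlaScan (name, number, ddG) records with PDB (name, chain, number)
    residues position by position; return the indices where they disagree.
    """
    mismatches = []
    for i, (record, residue) in enumerate(zip(alascan_records, pdb_residues)):
        as_name, as_number, _ = record
        pdb_name, _, pdb_number = residue
        if as_name != pdb_name or as_number != pdb_number:
            mismatches.append(i)
    return mismatches


alascan = [("LEU", 479, 1.19627), ("VAL", 480, 0.85432)]
pdb = [("LEU", "A", 479), ("VAL", "A", 480)]
print(ordering_mismatches(alascan, pdb))                  # []
print(ordering_mismatches(alascan, list(reversed(pdb))))  # [0, 1]
```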

In [None]:
import nglview as nv
import pandas as pd
import numpy as np
import re
from pathlib import Path
from Bio.PDB import PDBParser

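def parse_alascan_file(alascan_file, pdb_file):
    """
    Parse FoldX AlaScan output and match it line-by-line with a PDB structure.
    
    Args:
        alascan_file: Path to AlaScan results (*_AS.fxout)
        pdb_file: Path to the PDB structure used for the AlaScan run
    
    Returns:
        pd.DataFrame with columns: residue_name, chain, residue_number, ddG
    """
    # Parse AlaScan results in file order: RESNAME RESNUM to ALA energy change is DDG
    # (the number pattern also accepts an optional exponent, e.g. 1.2e-05)
    line_pattern = re.compile(
        r'(\w+)\s+(\d+)\s+to ALA energy change is\s+([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)'
    )
    alascan_records = []  # (residue_name, residue_number, ddG)
    with open(alascan_file, 'r') as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            match = line_pattern.match(line)
            if match:
                alascan_records.append(
                    (match.group(1), int(match.group(2)), float(match.group(3)))
                )
    
    print(f"AlaScan file contains {len(alascan_records)} lines")
    
    # Parse the PDB structure; only the first model is used
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure('protein', pdb_file)
    model = next(iter(structure))
    
    data = []
    order_mismatches = 0
    for chain in model:
        for residue in chain:
            if residue.id[0] != ' ':  # skip heteroatoms, waters, ligands
                continue
            
            idx = len(data)  # position of this residue in PDB order
            if idx < len(alascan_records):
                as_name, as_num, ddg = alascan_records[idx]
                # Cheap check of the ordering assumption (FoldX may rename some
                # residues, e.g. histidine states, so inspect before trusting)
                if as_name != residue.resname or as_num != residue.id[1]:
                    order_mismatches += 1
            else:
                ddg = np.nan  # no AlaScan line left: mark missing instead of faking 0.0
            
            data.append({
                'residue_name': residue.resname,
                'chain': chain.id,
                'residue_number': residue.id[1],
                'ddG': ddg
            })
    
    df = pd.DataFrame(data)
    if len(data) != len(alascan_records):
        print(f"⚠️  Count mismatch: {len(alascan_records)} AlaScan lines vs {len(data)} PDB residues")
    if order_mismatches:
        print(f"⚠️  {order_mismatches} residues differ in name/number from the AlaScan line at the same index")
    print(f"Matched {min(len(data), len(alascan_records))} residues across {df['chain'].nunique()} chains")
    print(f"Chains: {sorted(df['chain'].unique())}")
    print(f"ΔΔG range: {df['ddG'].min():.2f} to {df['ddG'].max():.2f} kcal/mol")
    
    return df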

## Frame Averaging of AlaScan Results

### What This Calculates

For each residue in each chain, we calculate statistics **across MD trajectory frames**:

1. **Mean ΔΔG**: Average ΔΔG value across all frames
   - Groups by `(chain, residue_number, residue_name)`
   - Each frame represents a different conformational snapshot

2. **Standard Deviation (std)**:
```
   std = √[Σ(ΔΔG_i - mean)² / (n-1)]
```
   - **Measures**: Conformational variability over time
   - **Interpretation**: How much does the mutation effect fluctuate across protein dynamics?
   - Sample standard deviation with Bessel's correction (ddof=1)

3. **Standard Error of the Mean (sem)**:
```
   sem = std / √n_frames
```
   - **Measures**: Uncertainty in the mean estimate
   - **Interpretation**: How confident are we in this average ΔΔG value?
   - Smaller SEM = more confident (more frames = better sampling)

4. **Additional Statistics**:
   - `count`: Number of frames analyzed
   - `min`, `max`: Range of ΔΔG values across frames

### How It Works

**Step 1**: Parse AlaScan output files (`*_AS.fxout`) and matching PDB structures

**Step 2**: Combine data from all frames into a single dataframe

**Step 3**: Group by `(chain, residue_number, residue_name)` and calculate:
- `mean`: Average ΔΔG across frames
- `std`: Conformational variability
- `sem`: Uncertainty in the mean
- `count`: Number of frames (should be consistent)
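Step 3 is a `groupby(...).agg(...)` in the notebook. The sketch below does the same grouping with the standard library so the key structure is explicit. It assumes one row per (frame, residue); `frame_average` is a hypothetical name, and like pandas it reports a NaN standard deviation when a residue appears in only one frame.

```python
import math
import statistics
from collections import defaultdict


def frame_average(rows):
    """
    rows: one dict per (frame, residue) with chain, residue_number, residue_name, ddG.
    Returns {(chain, residue_number, residue_name): summary dict}.
    """
    groups = defaultdict(list)
    for row in rows:
        key = (row["chain"], row["residue_number"], row["residue_name"])
        groups[key].append(row["ddG"])

    summary = {}
    for key, values in sorted(groups.items()):
        n = len(values)
        std = statistics.stdev(values) if n > 1 else float("nan")  # pandas also gives NaN for n=1
        summary[key] = {
            "ddG_mean": statistics.mean(values),
            "ddG_std": std,
            "ddG_sem": std / math.sqrt(n),
            "n_frames": n,
            "ddG_min": min(values),
            "ddG_max": max(values),
        }
    return summary


rows = [
    {"chain": "A", "residue_number": 479, "residue_name": "LEU", "ddG": 1.0},
    {"chain": "A", "residue_number": 479, "residue_name": "LEU", "ddG": 1.4},
    {"chain": "B", "residue_number": 479, "residue_name": "LEU", "ddG": 1.2},
    {"chain": "B", "residue_number": 479, "residue_name": "LEU", "ddG": 1.2},
]
summary = frame_average(rows)
print(summary[("B", 479, "LEU")]["ddG_std"])  # 0.0
```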

**Step 4**: Map mean ΔΔG to B-factors in PDB for visualization
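The notebook does Step 4 with Biopython's `set_bfactor` and `PDBIO`. As a dependency-free illustration of what ends up in the file, this sketch writes ΔΔG directly into the fixed-width B-factor field (columns 61-66) of `ATOM` records, using a dict keyed by `(chain, residue_number)` instead of filtering a DataFrame per residue. The function name and the example record are hypothetical.

```python
def set_bfactors(pdb_lines, ddg_by_residue):
    """
    Write ΔΔG into the B-factor field (columns 61-66) of ATOM records.
    ddg_by_residue maps (chain, residue_number) -> ΔΔG; other lines are left unchanged.
    """
    out = []
    for line in pdb_lines:
        if line.startswith("ATOM") and len(line) >= 66:
            key = (line[21], int(line[22:26]))
            if key in ddg_by_residue:
                # "%6.2f" fits the 6-character field for values from -99.99 to 999.99
                line = f"{line[:60]}{ddg_by_residue[key]:6.2f}{line[66:]}"
        out.append(line)
    return out


# Hypothetical ATOM record: name/residue fields, x/y/z, occupancy, B-factor, element
atom = (
    f"{'ATOM':<6}{1:>5} {'CA':<4} {'LEU':>3} A{479:>4}    "
    f"{1.0:8.3f}{2.0:8.3f}{3.0:8.3f}{1.0:6.2f}{0.0:6.2f}          C"
)
print(repr(set_bfactors([atom], {("A", 479): 1.5})[0][60:66]))  # '  1.50'
```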

### Example

**Chain A, residue LEU-327, 10 MD frames (illustrative values; frames 4-9 not shown):**

| Frame | ΔΔG (kcal/mol) |
|-------|----------------|
| 1     | 2.3            |
| 2     | 2.7            |
| 3     | 2.1            |
| ...   | ...            |
| 10    | 2.6            |

**Output:**
- `ddG_mean = 2.5` (average across 10 frames)
- `ddG_std = 0.8` (conformational fluctuation)
- `ddG_sem = 0.8/√10 = 0.25` (uncertainty in mean)
- `n_frames = 10`

### Interpretation

- **High std**: Residue effect is conformation-dependent (dynamic)
- **Low std**: Residue effect is stable across conformations (static)
- **Low sem**: Well-sampled, confident in the mean value

### Output

- **CSV per assembly**: `{assembly}_frame_averaged.csv` with columns:
  - `chain`, `residue_number`, `residue_name`
  - `ddG_mean`, `ddG_std`, `ddG_sem`
  - `n_frames`, `ddG_min`, `ddG_max`
- **PDB per assembly**: `{assembly}_frame_averaged.pdb` with ΔΔG in B-factors
- **Plots per assembly**: Distribution, sequence profile, variance analysis

### Next Step

This frame-averaged data serves as **input** for symmetry averaging, where we further average across symmetry-related chains (e.g., 5-fold symmetry units).

In [None]:
import pandas as pd
import numpy as np
import glob
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from Bio.PDB import PDBParser, PDBIO

# ========================================
# CONFIGURATION
# ========================================
ALASCAN_BASE = "/home/engiray/work/YVF/2-trjconv/17d/viral_assembly_pipeline_results/alascan"

OUTPUT_DIR = f"{ALASCAN_BASE}/alascan_averaged_results"
Path(OUTPUT_DIR).mkdir(exist_ok=True)

# Assembly types to process
ASSEMBLY_TYPES = ['3-fold', '5-fold', 'dimer', 'dimer_para', 'dimer_perp']

# ========================================
# PROCESS EACH ASSEMBLY TYPE
# ========================================
for ASSEMBLY_TYPE in ASSEMBLY_TYPES:
    print("\n" + "="*70)
    print(f"PROCESSING: {ASSEMBLY_TYPE}")
    print("="*70)
    
    ALASCAN_DIR = f"{ALASCAN_BASE}/{ASSEMBLY_TYPE}"
    print(f"AlaScan directory: {ALASCAN_DIR}")
    
    # Find AlaScan files
    alascan_files = sorted(glob.glob(f"{ALASCAN_DIR}/*/*_AS.fxout"))
    
    if not alascan_files:
        print(f"⚠️  No AlaScan files found for {ASSEMBLY_TYPE}, skipping...")
        continue
    
    print(f"Found {len(alascan_files)} AlaScan files")
    
    # Pair each AlaScan file with its PDB (NAME_AS.fxout -> NAME.pdb).
    # Keeping the pair together means a missing PDB cannot shift the pairing.
    file_pairs = []
    for alascan_file in alascan_files:
        stem = Path(alascan_file).stem
        if stem.endswith('_AS'):
            stem = stem[:-len('_AS')]  # strip only the trailing suffix
        pdb_file = Path(alascan_file).parent / f"{stem}.pdb"
        if pdb_file.exists():
            file_pairs.append((alascan_file, str(pdb_file)))
        else:
            print(f"⚠️  PDB not found: {pdb_file}")
    
    print(f"Found {len(file_pairs)} matching PDB files")
    
    if not file_pairs:
        print(f"⚠️  No PDB files found for {ASSEMBLY_TYPE}, skipping...")
        continue
    
    pdb_files = [pdb for _, pdb in file_pairs]  # used below for the template and plot title
    
    # ========================================
    # PARSE ALL FILES
    # ========================================
    all_frames = []
    
    for i, (alascan_file, pdb_file) in enumerate(file_pairs):
        frame_name = Path(alascan_file).parent.name
        print(f"  Parsing frame {i+1}/{len(file_pairs)}: {frame_name}")
        
        df = parse_alascan_file(alascan_file, pdb_file)
        df['frame'] = i
        df['frame_name'] = frame_name
        all_frames.append(df)
    
    # Combine all frames
    combined_df = pd.concat(all_frames, ignore_index=True)
    print(f"Total data points: {len(combined_df)}")
    print(f"Chains: {sorted(combined_df['chain'].unique())}")
    
    # ========================================
    # CALCULATE STATISTICS (AVERAGE OVER FRAMES ONLY)
    # ========================================
    stats_df = combined_df.groupby(['chain', 'residue_number', 'residue_name']).agg({
        'ddG': ['mean', 'std', 'sem', 'count', 'min', 'max']
    }).reset_index()
    
    stats_df.columns = ['chain', 'residue_number', 'residue_name', 
                        'ddG_mean', 'ddG_std', 'ddG_sem', 'n_frames', 'ddG_min', 'ddG_max']
    
    print(f"\nStatistics calculated:")
    print(f"  Total residues: {len(stats_df)}")
    print(f"  Frames per residue: {stats_df['n_frames'].mean():.1f}")
    print(f"\nΔΔG statistics (mean over frames):")
    print(stats_df['ddG_mean'].describe())
    
    # ========================================
    # SAVE STATISTICS
    # ========================================
    csv_file = f"{OUTPUT_DIR}/{ASSEMBLY_TYPE}_frame_averaged.csv"
    stats_df.to_csv(csv_file, index=False)
    print(f"\n✓ Saved statistics: {csv_file}")
    
    # ========================================
    # CREATE PLOTS
    # ========================================
    fig, axes = plt.subplots(2, 2, figsize=(7, 6))
    fig.suptitle(f'{ASSEMBLY_TYPE} - Frame-Averaged AlaScan Results', 
                 fontsize=16, fontweight='bold')
    
    # Plot 1: Mean ΔΔG distribution
    axes[0, 0].hist(stats_df['ddG_mean'], bins=50, edgecolor='black', alpha=0.7)
    axes[0, 0].axvline(0, color='red', linestyle='--', linewidth=2)
    axes[0, 0].set_xlabel('Mean ΔΔG (kcal/mol)', fontsize=12)
    axes[0, 0].set_ylabel('Count', fontsize=12)
    axes[0, 0].set_title('Distribution of Mean ΔΔG', fontsize=13, fontweight='bold')
    axes[0, 0].grid(alpha=0.3)
    
    # Plot 2: Standard deviation
    axes[0, 1].hist(stats_df['ddG_std'], bins=50, edgecolor='black', alpha=0.7, color='orange')
    axes[0, 1].set_xlabel('Standard Deviation (kcal/mol)', fontsize=12)
    axes[0, 1].set_ylabel('Count', fontsize=12)
    axes[0, 1].set_title('Distribution of ΔΔG Std Dev', fontsize=13, fontweight='bold')
    axes[0, 1].grid(alpha=0.3)
    
    # Plot 3: Mean ΔΔG along sequence (first chain)
    first_chain = sorted(stats_df['chain'].unique())[0]
    chain_data = stats_df[stats_df['chain'] == first_chain].sort_values('residue_number')
    axes[1, 0].errorbar(chain_data['residue_number'], chain_data['ddG_mean'], 
                       yerr=chain_data['ddG_std'], fmt='o', markersize=4, 
                       alpha=0.6, capsize=2, color='steelblue')
    axes[1, 0].axhline(0, color='black', linestyle='--', linewidth=1)
    axes[1, 0].set_xlabel('Residue Number', fontsize=12)
    axes[1, 0].set_ylabel('Mean ΔΔG (kcal/mol)', fontsize=12)
    axes[1, 0].set_title(f'ΔΔG Along Sequence (Chain {first_chain}, n={len(pdb_files)} frames)', 
                         fontsize=13, fontweight='bold')
    axes[1, 0].grid(alpha=0.3)
    
    # Plot 4: Std vs Mean
    axes[1, 1].scatter(stats_df['ddG_mean'], stats_df['ddG_std'], alpha=0.3, s=10)
    axes[1, 1].set_xlabel('Mean ΔΔG (kcal/mol)', fontsize=12)
    axes[1, 1].set_ylabel('Standard Deviation (kcal/mol)', fontsize=12)
    axes[1, 1].set_title('Variance vs Mean ΔΔG', fontsize=13, fontweight='bold')
    axes[1, 1].grid(alpha=0.3)
    
    plt.tight_layout()
    plot_file = f"{OUTPUT_DIR}/{ASSEMBLY_TYPE}_plots.png"
    plt.savefig(plot_file, dpi=150, bbox_inches='tight')
    print(f"✓ Saved plots: {plot_file}")
    #plt.close()
    
    # ========================================
    # CREATE AVERAGED PDB STRUCTURE
    # ========================================
    template_pdb = pdb_files[0]
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure('protein', template_pdb)
    
    # Map averaged ΔΔG to B-factors
    for model in structure:
        for chain in model:
            chain_id = chain.id
            for residue in chain:
                if residue.id[0] == ' ':
                    res_num = residue.id[1]
                    
                    # Find averaged ΔΔG
                    match = stats_df[(stats_df['chain'] == chain_id) & 
                                    (stats_df['residue_number'] == res_num)]
                    
                    if len(match) > 0:
                        ddg_mean = match['ddG_mean'].iloc[0]
                        for atom in residue:
                            atom.set_bfactor(ddg_mean)
    
    # Save structure
    pdb_file = f"{OUTPUT_DIR}/{ASSEMBLY_TYPE}_frame_averaged.pdb"
    io = PDBIO()
    io.set_structure(structure)
    io.save(pdb_file)
    print(f"✓ Saved structure: {pdb_file}")
    
    # ========================================
    # SUMMARY STATISTICS
    # ========================================
    print("\n" + "-"*70)
    print("TOP 10 MOST IMPORTANT RESIDUES (highest mean ΔΔG)")
    print("-"*70)
    top_important = stats_df.nlargest(10, 'ddG_mean')[['residue_name', 'chain', 'residue_number', 'ddG_mean', 'ddG_std']]
    print(top_important.to_string(index=False))
    
    print("\n" + "-"*70)
    print("TOP 10 MOST DESTABILIZING RESIDUES (lowest mean ΔΔG)")
    print("-"*70)
    top_destab = stats_df.nsmallest(10, 'ddG_mean')[['residue_name', 'chain', 'residue_number', 'ddG_mean', 'ddG_std']]
    print(top_destab.to_string(index=False))

print("\n" + "="*70)
print("ALL ASSEMBLIES PROCESSED")
print("="*70)
print(f"\nResults saved in: {OUTPUT_DIR}/")
print("\nFiles created for each assembly:")
print("  - {assembly}_frame_averaged.csv  (statistics)")
print("  - {assembly}_frame_averaged.pdb  (structure with ΔΔG in B-factor)")
print("  - {assembly}_plots.png           (visualization plots)")


## Symmetry Averaging with Proper Error Propagation

### What This Calculates

For each unique residue position in the asymmetric unit, we calculate:

1. **Mean ΔΔG**: Average across all symmetry-related chains and frames
   - Groups by `(sym_position, residue_number, residue_name)`
   - Takes mean of frame-averaged ΔΔG values

2. **Frame Standard Deviation (σ_frame)**:
   - Average of the std values from frame averaging step
   - **Measures**: Conformational variability across MD trajectory
   - **Interpretation**: How much does this residue's ΔΔG fluctuate over time?

3. **Symmetry Standard Deviation (σ_symmetry)**:
   - Std of ΔΔG across symmetry-related chains
   - **Measures**: Symmetry-breaking between supposedly equivalent units
   - **Interpretation**: Do all 5-fold (or 3-fold) related copies behave the same?

4. **Total Standard Deviation**:
```
   σ_total = √(σ²_frame + σ²_symmetry)
```
   - Combines both sources of variance
   - Describes total spread of all ΔΔG measurements

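5. **Nested Design Standard Error**:
```
   SEM = √(σ²_symmetry/n_fold + σ²_frame/(n_fold × n_frames))
```
   - Treats symmetry-related units as the top-level replicates, with frames nested within each unit
   - Because σ_symmetry is the std of *frame-averaged* chain means, it already carries a σ²_frame/n_frames contribution; under a simple nested model the second term counts that frame noise again, so this SEM is conservative (somewhat too large)
   - Caveat: all chains come from the same MD frames, and consecutive frames may be autocorrelated, so neither chains nor frames are fully independent; treat this SEM as an approximate uncertainty for the mean ΔΔG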

### How It Works

**Step 1**: Assign each chain to its symmetry position (0, 1, 2, 3 in asymmetric unit)

**Step 2**: Group by symmetry position and calculate:
- Mean of frame-averaged ΔΔG values → `ddG_mean`
- Std across symmetry units → `ddG_symmetry_std`
- Mean of frame stds → `ddG_frame_std`

**Step 3**: Combine variance components:
- Total variance = frame variance + symmetry variance
- `ddG_total_std = √(σ²_frame + σ²_symmetry)`

**Step 4**: Calculate the nested-design SEM:
- `ddG_total_sem = √(σ²_symmetry/n_fold + σ²_frame/(n_fold × n_frames))`
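Steps 2-4 can be traced on a toy chain × frame matrix for one asymmetric-unit position. This sketch uses hypothetical values and a hypothetical function name, and mirrors the notebook's definitions (mean of per-chain stds, std of chain means). For comparison, it also reports the SEM computed from the chain means alone (`σ_symmetry/√n_fold`), which under a simple nested model already includes the frame-noise contribution.

```python
import math
import statistics


def symmetry_components(ddg_by_chain):
    """
    ddg_by_chain: {chain: [ΔΔG per frame]} for one asymmetric-unit position,
    one list per symmetry-related chain, all with the same number of frames.
    """
    per_chain = list(ddg_by_chain.values())
    n_fold = len(per_chain)
    n_frames = len(per_chain[0])
    chain_means = [statistics.mean(v) for v in per_chain]
    frame_std = statistics.mean(statistics.stdev(v) for v in per_chain)  # mean of per-chain stds
    symmetry_std = statistics.stdev(chain_means)  # std of frame-averaged chain means
    return {
        "ddG_mean": statistics.mean(chain_means),
        "ddG_frame_std": frame_std,
        "ddG_symmetry_std": symmetry_std,
        "ddG_total_std": math.sqrt(frame_std**2 + symmetry_std**2),
        "ddG_total_sem": math.sqrt(symmetry_std**2 / n_fold + frame_std**2 / (n_fold * n_frames)),
        # For comparison: SEM of the grand mean from the chain means alone
        "sem_from_chain_means": symmetry_std / math.sqrt(n_fold),
    }


# Hypothetical 3 chains x 3 frames for one position
c = symmetry_components({"A": [1.0, 2.0, 3.0], "E": [2.0, 3.0, 4.0], "I": [3.0, 4.0, 5.0]})
print(round(c["ddG_total_sem"], 3), round(c["sem_from_chain_means"], 3))  # 0.667 0.577
```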

### Example

**5-fold assembly, 10 frames, residue LEU-327 at position 0:**
- 5 symmetry-related chains (A, E, I, M, Q) each averaged over 10 frames
- Mean of those 5 values → `ddG_mean`
- Std of those 5 values → `ddG_symmetry_std` (between-chain variation)
- Mean of their individual stds → `ddG_frame_std` (within-chain variation)
- Combined: `σ_total = √(σ²_frame + σ²_symmetry)`
- SEM: `√(σ²_symmetry/5 + σ²_frame/50)`
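The chain list and the SEM in this example can be checked directly. This sketch assumes, as `assign_symmetry_position` in the code below does, that chains are lettered consecutively with `chains_per_unit` chains per symmetry unit. The helper name and the two σ values are hypothetical.

```python
import math


def chains_at_position(position, n_fold, chains_per_unit):
    """Chain IDs sharing one asymmetric-unit position, assuming consecutive letters per unit."""
    return [chr(ord("A") + unit * chains_per_unit + position) for unit in range(n_fold)]


print(chains_at_position(0, n_fold=5, chains_per_unit=4))  # ['A', 'E', 'I', 'M', 'Q']

# Hypothetical variance components for one residue (kcal/mol)
sigma_symmetry, sigma_frame, n_fold, n_frames = 0.3, 0.5, 5, 10
sem = math.sqrt(sigma_symmetry**2 / n_fold + sigma_frame**2 / (n_fold * n_frames))
print(round(sem, 3))  # 0.152
```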

### Output

- **CSV**: `ddG_mean`, `ddG_total_std`, `ddG_total_sem`, `ddG_frame_std`, `ddG_symmetry_std`, `n_frames`
- **PDB files**: ΔΔG values mapped to B-factors for visualization
- **Plots**: Variance decomposition, distributions, sequence profiles

In [None]:
import pandas as pd
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from Bio.PDB import PDBParser, PDBIO, Select

# ========================================
# CONFIGURATION
# ========================================
INPUT_DIR = f"{ALASCAN_BASE}/alascan_averaged_results"
OUTPUT_DIR = f"{ALASCAN_BASE}/alascan_symmetry_averaged_results"
Path(OUTPUT_DIR).mkdir(exist_ok=True)

# Symmetry configuration for each assembly type
SYMMETRY_CONFIG = {
    '3-fold': {'n_fold': 3, 'chains_per_unit': 4},
    '5-fold': {'n_fold': 5, 'chains_per_unit': 4},
    'dimer': {'n_fold': 1, 'chains_per_unit': 4},
    'dimer_para': {'n_fold': 2, 'chains_per_unit': 4},
    'dimer_perp': {'n_fold': 2, 'chains_per_unit': 4}
}

def assign_symmetry_position(chain_id, chains_per_unit):
    """Assign symmetry position based on chain ID"""
    chain_idx = ord(chain_id) - ord('A')
    position = chain_idx % chains_per_unit
    unit = chain_idx // chains_per_unit
    return position, unit

# Define Select class
class AsymmetricUnitSelect(Select):
    def __init__(self, chains):
        self.chains = chains
    
    def accept_chain(self, chain):
        return chain.id in self.chains

# ========================================
# PROCESS EACH ASSEMBLY TYPE
# ========================================
for ASSEMBLY_TYPE, symmetry in SYMMETRY_CONFIG.items():
    print("\n" + "="*70)
    print(f"PROCESSING: {ASSEMBLY_TYPE}")
    print("="*70)
    
    csv_file = f"{INPUT_DIR}/{ASSEMBLY_TYPE}_frame_averaged.csv"
    pdb_file = f"{INPUT_DIR}/{ASSEMBLY_TYPE}_frame_averaged.pdb"
    
    if not Path(csv_file).exists():
        print(f"⚠️  File not found: {csv_file}, skipping...")
        continue
    
    stats_df = pd.read_csv(csv_file)
    print(f"Loaded {len(stats_df)} residues")
    print(f"Chains: {sorted(stats_df['chain'].unique())}")
    print(f"\nSymmetry: {symmetry['n_fold']}-fold, {symmetry['chains_per_unit']} chains per unit")
    
    # ========================================
    # ADD SYMMETRY POSITIONS
    # ========================================
    stats_df['sym_position'] = stats_df['chain'].apply(
        lambda c: assign_symmetry_position(c, symmetry['chains_per_unit'])[0]
    )
    stats_df['sym_unit'] = stats_df['chain'].apply(
        lambda c: assign_symmetry_position(c, symmetry['chains_per_unit'])[1]
    )
    
    print("\nChain to symmetry mapping:")
    unique_chains = sorted(stats_df['chain'].unique())
    for chain in unique_chains:
        pos, unit = assign_symmetry_position(chain, symmetry['chains_per_unit'])
        print(f"  Chain {chain} -> Unit {unit}, Position {pos}")
    
    # ========================================
    # AVERAGE OVER SYMMETRY WITH PROPER ERROR PROPAGATION
    # ========================================
    if symmetry['n_fold'] > 1:
        # Multiple units - calculate all variance components
        symmetry_avg = stats_df.groupby(['sym_position', 'residue_number', 'residue_name']).agg({
            'ddG_mean': 'mean',           # Grand mean across symmetry units
            'ddG_std': 'mean',            # Mean of frame stds (within-unit variability)
            'n_frames': 'first'
        }).reset_index()
        
        # Calculate std across symmetry units (between-unit variability)
        symmetry_std = stats_df.groupby(['sym_position', 'residue_number', 'residue_name'])['ddG_mean'].std().reset_index()
        symmetry_std.columns = ['sym_position', 'residue_number', 'residue_name', 'ddG_symmetry_std']
        
        # Merge
        symmetry_avg = symmetry_avg.merge(symmetry_std, on=['sym_position', 'residue_number', 'residue_name'])
        
        # Rename for clarity
        symmetry_avg.rename(columns={'ddG_std': 'ddG_frame_std'}, inplace=True)
        
        # Calculate total standard deviation (combined variance sources)
        symmetry_avg['ddG_total_std'] = np.sqrt(
            symmetry_avg['ddG_frame_std']**2 + 
            symmetry_avg['ddG_symmetry_std']**2
        )
        
        # Calculate nested design SEM
        # SEM = sqrt(σ²_symmetry/n_fold + σ²_frame/(n_fold × n_frames))
        symmetry_avg['ddG_total_sem'] = np.sqrt(
            symmetry_avg['ddG_symmetry_std']**2 / symmetry['n_fold'] +
            symmetry_avg['ddG_frame_std']**2 / (symmetry['n_fold'] * symmetry_avg['n_frames'])
        )
        
        # Also calculate individual SEM components for reference
        symmetry_avg['ddG_frame_sem'] = symmetry_avg['ddG_frame_std'] / np.sqrt(
            symmetry['n_fold'] * symmetry_avg['n_frames']
        )
        symmetry_avg['ddG_symmetry_sem'] = symmetry_avg['ddG_symmetry_std'] / np.sqrt(symmetry['n_fold'])
        
        has_symmetry_std = True
        
        print(f"\nSymmetry-averaged data: {len(symmetry_avg)} residues")
        print(f"  (averaged over {symmetry['n_fold']} symmetry-related chains)")
        
    else:
        # Only one unit - no symmetry averaging possible
        symmetry_avg = stats_df.groupby(['sym_position', 'residue_number', 'residue_name']).agg({
            'ddG_mean': 'mean',
            'ddG_std': 'mean',
            'n_frames': 'first'
        }).reset_index()
        
        symmetry_avg.rename(columns={'ddG_std': 'ddG_frame_std'}, inplace=True)
        
        # No symmetry contribution
        symmetry_avg['ddG_symmetry_std'] = np.nan
        symmetry_avg['ddG_total_std'] = symmetry_avg['ddG_frame_std']
        symmetry_avg['ddG_total_sem'] = symmetry_avg['ddG_frame_std'] / np.sqrt(symmetry_avg['n_frames'])
        symmetry_avg['ddG_frame_sem'] = symmetry_avg['ddG_total_sem']
        symmetry_avg['ddG_symmetry_sem'] = np.nan
        
        has_symmetry_std = False
        print("\n⚠️  Note: Only 1 symmetry unit - no symmetry std calculated")
    
    print(f"\nΔΔG statistics (symmetry-averaged):")
    print(f"  Mean ΔΔG over residues: {symmetry_avg['ddG_mean'].mean():.3f} kcal/mol (std across residues: {symmetry_avg['ddG_mean'].std():.3f})")
    if has_symmetry_std:
        print(f"  Mean frame std: {symmetry_avg['ddG_frame_std'].mean():.3f} kcal/mol")
        print(f"  Mean symmetry std: {symmetry_avg['ddG_symmetry_std'].mean():.3f} kcal/mol")
    print(f"  Mean total std: {symmetry_avg['ddG_total_std'].mean():.3f} kcal/mol")
    print(f"  Mean total SEM: {symmetry_avg['ddG_total_sem'].mean():.3f} kcal/mol")
    
    # ========================================
    # SAVE STATISTICS
    # ========================================
    # Reorder columns for clarity
    col_order = ['sym_position', 'residue_number', 'residue_name', 'ddG_mean',
                 'ddG_total_std', 'ddG_total_sem', 
                 'ddG_frame_std', 'ddG_frame_sem',
                 'ddG_symmetry_std', 'ddG_symmetry_sem',
                 'n_frames']
    
    csv_out = f"{OUTPUT_DIR}/{ASSEMBLY_TYPE}_symmetry_averaged.csv"
    symmetry_avg[col_order].to_csv(csv_out, index=False)
    print(f"\n✓ Saved statistics: {csv_out}")
    
    # ========================================
    # CREATE PLOTS
    # ========================================
    fig, axes = plt.subplots(2, 3, figsize=(7, 6))
    fig.suptitle(f'{ASSEMBLY_TYPE} - Symmetry-Averaged AlaScan Results', 
                 fontsize=16, fontweight='bold')
    
    # Plot 1: Mean ΔΔG distribution
    axes[0, 0].hist(symmetry_avg['ddG_mean'], bins=50, edgecolor='black', alpha=0.7)
    axes[0, 0].axvline(0, color='red', linestyle='--', linewidth=2)
    axes[0, 0].set_xlabel('Mean ΔΔG (kcal/mol)', fontsize=12)
    axes[0, 0].set_ylabel('Count', fontsize=12)
    axes[0, 0].set_title('Distribution of Mean ΔΔG', fontsize=13, fontweight='bold')
    axes[0, 0].grid(alpha=0.3)
    
    # Plot 2: Total standard deviation
    axes[0, 1].hist(symmetry_avg['ddG_total_std'], bins=50, 
                    edgecolor='black', alpha=0.7, color='green')
    axes[0, 1].set_xlabel('Total Std Dev (kcal/mol)', fontsize=12)
    axes[0, 1].set_ylabel('Count', fontsize=12)
    axes[0, 1].set_title('Distribution of Total ΔΔG Std', 
                        fontsize=13, fontweight='bold')
    axes[0, 1].grid(alpha=0.3)
    
    # Plot 3: Frame vs Symmetry contributions
    if has_symmetry_std:
        axes[0, 2].scatter(symmetry_avg['ddG_frame_std'], 
                          symmetry_avg['ddG_symmetry_std'], 
                          alpha=0.3, s=10)
        axes[0, 2].plot([0, symmetry_avg['ddG_frame_std'].max()], 
                       [0, symmetry_avg['ddG_frame_std'].max()], 
                       'r--', label='Equal contribution')
        axes[0, 2].set_xlabel('Frame Std Dev (kcal/mol)', fontsize=12)
        axes[0, 2].set_ylabel('Symmetry Std Dev (kcal/mol)', fontsize=12)
        axes[0, 2].set_title('Frame vs Symmetry Contributions', 
                            fontsize=13, fontweight='bold')
        axes[0, 2].legend()
    else:
        axes[0, 2].hist(symmetry_avg['ddG_frame_std'], bins=50, 
                        edgecolor='black', alpha=0.7, color='orange')
        axes[0, 2].set_xlabel('Frame Std Dev (kcal/mol)', fontsize=12)
        axes[0, 2].set_ylabel('Count', fontsize=12)
        axes[0, 2].set_title('Distribution of Frame Std', 
                            fontsize=13, fontweight='bold')
    axes[0, 2].grid(alpha=0.3)
    
    # Plot 4: Mean ΔΔG along sequence with total error bars
    colors = plt.cm.tab10(np.linspace(0, 1, symmetry['chains_per_unit']))
    for pos in range(symmetry['chains_per_unit']):
        pos_data = symmetry_avg[symmetry_avg['sym_position'] == pos].sort_values('residue_number')
        
        axes[1, 0].errorbar(pos_data['residue_number'], pos_data['ddG_mean'],
                           yerr=pos_data['ddG_total_std'],  # Use total std
                           fmt='o-', markersize=3,
                           alpha=0.7, capsize=2, label=f'Position {pos}', color=colors[pos])
    
    axes[1, 0].axhline(0, color='black', linestyle='--', linewidth=1)
    axes[1, 0].set_xlabel('Residue Number', fontsize=12)
    axes[1, 0].set_ylabel('Mean ΔΔG (kcal/mol)', fontsize=12)
    axes[1, 0].set_title(f'ΔΔG Along Sequence (±Total Std)', fontsize=13, fontweight='bold')
    if symmetry['chains_per_unit'] > 1:
        axes[1, 0].legend(fontsize=8)
    axes[1, 0].grid(alpha=0.3)
    
    # Plot 5: Total std vs Mean ΔΔG
    axes[1, 1].scatter(symmetry_avg['ddG_mean'], symmetry_avg['ddG_total_std'], 
                      alpha=0.3, s=10, color='green')
    axes[1, 1].set_xlabel('Mean ΔΔG (kcal/mol)', fontsize=12)
    axes[1, 1].set_ylabel('Total Std Dev (kcal/mol)', fontsize=12)
    axes[1, 1].set_title('Total Variance vs Mean ΔΔG', fontsize=13, fontweight='bold')
    axes[1, 1].grid(alpha=0.3)
    
    # Plot 6: Total SEM distribution
    axes[1, 2].hist(symmetry_avg['ddG_total_sem'], bins=50, 
                    edgecolor='black', alpha=0.7, color='purple')
    axes[1, 2].set_xlabel('Total SEM (kcal/mol)', fontsize=12)
    axes[1, 2].set_ylabel('Count', fontsize=12)
    axes[1, 2].set_title('Distribution of Total SEM (Nested Design)', 
                        fontsize=13, fontweight='bold')
    axes[1, 2].grid(alpha=0.3)
    
    plt.tight_layout()
    plot_out = f"{OUTPUT_DIR}/{ASSEMBLY_TYPE}_symmetry_plots.png"
    plt.savefig(plot_out, dpi=150, bbox_inches='tight')
    print(f"✓ Saved plots: {plot_out}")
    plt.close()
    
    # ========================================
    # CREATE PDB STRUCTURES
    # ========================================
    # Skip only the structure output (not the summary below) if the PDB is missing
    if not Path(pdb_file).exists():
        print(f"⚠️  PDB file not found: {pdb_file}, skipping structure creation...")
    else:
        parser = PDBParser(QUIET=True)
        structure = parser.get_structure('protein', pdb_file)
        
        asymmetric_chains = [chr(ord('A') + i) for i in range(symmetry['chains_per_unit'])]
        
        # Map symmetry-averaged ΔΔG to B-factors
        for model in structure:
            for chain in model:
                chain_id = chain.id
                sym_pos, _ = assign_symmetry_position(chain_id, symmetry['chains_per_unit'])
                
                for residue in chain:
                    if residue.id[0] == ' ':
                        res_num = residue.id[1]
                        
                        match = symmetry_avg[(symmetry_avg['sym_position'] == sym_pos) &
                                             (symmetry_avg['residue_number'] == res_num)]
                        
                        if len(match) > 0:
                            ddg_mean = match['ddG_mean'].iloc[0]
                            for atom in residue:
                                atom.set_bfactor(ddg_mean)
        
        # Save full structure
        pdb_out_full = f"{OUTPUT_DIR}/{ASSEMBLY_TYPE}_symmetry_averaged_full.pdb"
        io = PDBIO()
        io.set_structure(structure)
        io.save(pdb_out_full)
        print(f"✓ Saved full structure: {pdb_out_full}")
        
        # Save asymmetric unit
        pdb_out_asym = f"{OUTPUT_DIR}/{ASSEMBLY_TYPE}_symmetry_averaged_asymmetric.pdb"
        io = PDBIO()
        io.set_structure(structure)
        io.save(pdb_out_asym, AsymmetricUnitSelect(asymmetric_chains))
        print(f"✓ Saved asymmetric unit: {pdb_out_asym}")
    
    # ========================================
    # SUMMARY STATISTICS
    # ========================================
    print("\n" + "-"*70)
    print("TOP 10 MOST IMPORTANT RESIDUES (highest mean ΔΔG)")
    print("-"*70)
    cols = ['residue_name', 'sym_position', 'residue_number', 'ddG_mean', 
            'ddG_total_std', 'ddG_total_sem']
    if has_symmetry_std:
        cols.extend(['ddG_frame_std', 'ddG_symmetry_std'])
    top_important = symmetry_avg.nlargest(10, 'ddG_mean')[cols]
    print(top_important.to_string(index=False))
    
    print("\n" + "-"*70)
    print("TOP 10 MOST DESTABILIZING RESIDUES (lowest mean ΔΔG)")
    print("-"*70)
    top_destab = symmetry_avg.nsmallest(10, 'ddG_mean')[cols]
    print(top_destab.to_string(index=False))
    
    if has_symmetry_std:
        print("\n" + "-"*70)
        print("MOST VARIABLE ACROSS SYMMETRY (highest symmetry std)")
        print("-"*70)
        top_variable = symmetry_avg.nlargest(10, 'ddG_symmetry_std')[cols]
        print(top_variable.to_string(index=False))
        
        print("\n" + "-"*70)
        print("VARIANCE DECOMPOSITION")
        print("-"*70)
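        # Average the squared stds (variances), not the stds: per residue
        # total_std² = frame_std² + symmetry_std², so the mean variances add
        # up and the two percentages below sum to 100%
        total_var = (symmetry_avg['ddG_total_std']**2).mean()
        frame_var = (symmetry_avg['ddG_frame_std']**2).mean()
        sym_var = (symmetry_avg['ddG_symmetry_std']**2).mean()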
        print(f"  Total variance: {total_var:.4f}")
        print(f"  Frame variance: {frame_var:.4f} ({100*frame_var/total_var:.1f}%)")
        print(f"  Symmetry variance: {sym_var:.4f} ({100*sym_var/total_var:.1f}%)")

print("\n" + "="*70)
print("ALL ASSEMBLIES PROCESSED WITH PROPER ERROR PROPAGATION")
print("="*70)
print("\nOutput columns explained:")
print("  ddG_mean:         Grand mean across frames and symmetry units")
print("  ddG_total_std:    √(σ²_frame + σ²_symmetry) - combined variability")
print("  ddG_total_sem:    √(σ²_sym/n_fold + σ²_frame/(n_fold×n_frames)) - nested design")
print("  ddG_frame_std:    Mean std across frames (conformational variability)")
print("  ddG_symmetry_std: Std across symmetry units (symmetry-breaking)")


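The output columns printed above come from a nested design: several symmetry copies, each sampled over several frames. As a minimal sketch of how those columns relate to each other, assume one residue's ΔΔG values are arranged as an `n_fold × n_frames` array with one row per symmetry copy. The helper `nested_summary` is hypothetical, and the use of the sample standard deviation (`ddof=1`) is an assumption; the upstream averaging code may use different conventions.

```python
import numpy as np

def nested_summary(ddg):
    """Summarize one residue's ΔΔG values from a nested (copies x frames) design.

    ddg: 2D array of shape (n_fold, n_frames); row i holds the per-frame ΔΔG
    of symmetry copy i. Hypothetical helper; sample std (ddof=1) is assumed.
    """
    ddg = np.asarray(ddg, dtype=float)
    n_fold, n_frames = ddg.shape
    copy_means = ddg.mean(axis=1)                    # one mean per symmetry copy
    frame_std = ddg.std(axis=1, ddof=1).mean()       # mean std across frames
    # Spread between copies is undefined for a single copy
    sym_std = copy_means.std(ddof=1) if n_fold > 1 else np.nan
    sym_var = 0.0 if np.isnan(sym_std) else sym_std**2
    total_std = np.sqrt(frame_std**2 + sym_var)
    total_sem = np.sqrt(sym_var / n_fold + frame_std**2 / (n_fold * n_frames))
    return {
        'ddG_mean': ddg.mean(),
        'ddG_frame_std': frame_std,
        'ddG_symmetry_std': sym_std,
        'ddG_total_std': total_std,
        'ddG_total_sem': total_sem,
    }

# Example: 3 symmetry copies x 4 frames for one residue (made-up values)
example = [[1.0, 1.2, 0.8, 1.0],
           [1.5, 1.4, 1.6, 1.5],
           [0.9, 1.1, 1.0, 1.0]]
print(nested_summary(example))
```

With a single copy (`n_fold = 1`) the sketch returns NaN for the symmetry std and the total std reduces to the frame std, consistent with the later comment that the dimer's symmetry std is NaN.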
In [None]:
"""
================================================================================
INTERFACE CONTRIBUTION ANALYSIS (Symmetry-Averaged Data)
================================================================================

CRITICAL INSIGHT:
-----------------
After symmetry averaging, we have PER-COPY effects for each assembly type.
Therefore, we do NOT multiply when subtracting the dimer baseline.

CALCULATION:
------------
For each residue position in the asymmetric unit:

- 3-fold interface = 3-fold_asymmetric - dimer_asymmetric
- 5-fold interface = 5-fold_asymmetric - dimer_asymmetric
- dimer_para interface = dimer_para_asymmetric - dimer_asymmetric
- dimer_perp interface = dimer_perp_asymmetric - dimer_asymmetric

All values are per-copy (already averaged over symmetry units).

WHAT THIS MEANS:
----------------
The difference reflects the change in residue importance when going from:
- Isolated dimer context → Specific vertex/interface context

Example:
- If LEU-327 has ΔΔG = 2.5 in dimer and 3.2 in 5-fold:
  → Interface contribution = 3.2 - 2.5 = 0.7 kcal/mol
  → This residue becomes MORE important at the 5-fold vertex

PROPER ERROR PROPAGATION:
-------------------------
For subtraction (A - B) of independent measurements:

1. Standard Deviation:
   σ_interface = √(σ²_assembly + σ²_dimer)
   
2. Standard Error (Nested):
   SEM_interface = √(SEM²_assembly + SEM²_dimer)

No multiplication because we're comparing per-copy values.
================================================================================
"""

import pandas as pd
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns

# ========================================
# CONFIGURATION
# ========================================
INPUT_DIR = f"{ALASCAN_BASE}/alascan_symmetry_averaged_results"
OUTPUT_DIR = f"{ALASCAN_BASE}/alascan_interface_contributions"
Path(OUTPUT_DIR).mkdir(exist_ok=True)

# Interface definitions (no multipliers needed for symmetry-averaged data)
INTERFACE_CONFIG = {
    '3-fold': {'label': '3-fold vertex (vs dimer)'},
    '5-fold': {'label': '5-fold vertex (vs dimer)'},
    'dimer_para': {'label': 'Para dimer (vs dimer)'},
    'dimer_perp': {'label': 'Perp dimer (vs dimer)'}
}

# ========================================
# LOAD DATA
# ========================================
print("Loading symmetry-averaged data...")

# Load dimer (baseline)
dimer_file = f"{INPUT_DIR}/dimer_symmetry_averaged.csv"
dimer_df = pd.read_csv(dimer_file)
print(f"✓ Loaded dimer: {len(dimer_df)} residues")

# Load all other assemblies
assemblies = {}
for assembly_name in INTERFACE_CONFIG.keys():
    csv_file = f"{INPUT_DIR}/{assembly_name}_symmetry_averaged.csv"
    if Path(csv_file).exists():
        assemblies[assembly_name] = pd.read_csv(csv_file)
        print(f"✓ Loaded {assembly_name}: {len(assemblies[assembly_name])} residues")
    else:
        print(f"⚠️  File not found: {csv_file}")

# ========================================
# CALCULATE INTERFACE CONTRIBUTIONS
# ========================================
interface_results = {}

for assembly_name, assembly_df in assemblies.items():
    print(f"\n{'='*70}")
    print(f"CALCULATING: {assembly_name} interface")
    print(f"{'='*70}")
    
    # Merge assembly and dimer data (same positions)
    merged = assembly_df.merge(
        dimer_df,
        on=['sym_position', 'residue_number', 'residue_name'],
        suffixes=('_assembly', '_dimer')
    )
    
    print(f"Merged {len(merged)} residues")
    
    # Calculate interface contribution (no multiplication - both are per-copy)
    merged['ddG_interface'] = merged['ddG_mean_assembly'] - merged['ddG_mean_dimer']
    
    # Error propagation for total std
    # σ²_interface = σ²_assembly + σ²_dimer (simple addition for uncorrelated measurements)
    merged['ddG_interface_total_std'] = np.sqrt(
        merged['ddG_total_std_assembly']**2 + 
        merged['ddG_total_std_dimer']**2
    )
    
    # Error propagation for nested SEM
    # SEM²_interface = SEM²_assembly + SEM²_dimer
    merged['ddG_interface_total_sem'] = np.sqrt(
        merged['ddG_total_sem_assembly']**2 + 
        merged['ddG_total_sem_dimer']**2
    )
    
    # Propagate frame and symmetry components separately for interpretation
    merged['ddG_interface_frame_std'] = np.sqrt(
        merged['ddG_frame_std_assembly']**2 + 
        merged['ddG_frame_std_dimer']**2
    )
    
    # Check if symmetry std exists (will be NaN for dimer)
    if 'ddG_symmetry_std_assembly' in merged.columns:
        # Dimer has NaN for symmetry_std, so fill with 0
        merged['ddG_interface_symmetry_std'] = np.sqrt(
            merged['ddG_symmetry_std_assembly']**2 + 
            merged['ddG_symmetry_std_dimer'].fillna(0)**2
        )
    else:
        merged['ddG_interface_symmetry_std'] = np.nan
    
    # Select output columns
    output_cols = [
        'sym_position', 'residue_number', 'residue_name',
        'ddG_interface', 'ddG_interface_total_std', 'ddG_interface_total_sem',
        'ddG_interface_frame_std', 'ddG_interface_symmetry_std',
        'ddG_mean_assembly', 'ddG_total_std_assembly', 'ddG_total_sem_assembly',
        'ddG_mean_dimer', 'ddG_total_std_dimer', 'ddG_total_sem_dimer',
        'n_frames_assembly'
    ]
    
    interface_df = merged[output_cols].copy()
    interface_df.rename(columns={'n_frames_assembly': 'n_frames'}, inplace=True)
    
    interface_results[assembly_name] = interface_df
    
    # Save results
    csv_out = f"{OUTPUT_DIR}/{assembly_name}_interface.csv"
    interface_df.to_csv(csv_out, index=False)
    print(f"✓ Saved: {csv_out}")
    
    # Summary statistics
    print(f"\nInterface ΔΔG statistics (per-copy contribution):")
    print(f"  Mean: {interface_df['ddG_interface'].mean():.3f} ± {interface_df['ddG_interface'].std():.3f} kcal/mol")
    print(f"  Range: [{interface_df['ddG_interface'].min():.3f}, {interface_df['ddG_interface'].max():.3f}]")
    print(f"  Mean total std: {interface_df['ddG_interface_total_std'].mean():.3f} kcal/mol")
    print(f"  Mean total SEM: {interface_df['ddG_interface_total_sem'].mean():.3f} kcal/mol")
    
    # Count significant changes (|ΔΔG_interface| > 2×SEM)
    significant = interface_df[np.abs(interface_df['ddG_interface']) > 2 * interface_df['ddG_interface_total_sem']]
    print(f"  Significant changes (|ΔΔG| > 2×SEM): {len(significant)} residues ({100*len(significant)/len(interface_df):.1f}%)")

# ========================================
# CREATE COMPARISON PLOTS
# ========================================
print(f"\n{'='*70}")
print("CREATING COMPARISON PLOTS")
print(f"{'='*70}")

# Plot 1: Distribution comparison
fig, axes = plt.subplots(2, 2, figsize=(7.5, 6.5))
fig.suptitle('Interface ΔΔG Distributions (vs Dimer Baseline)', 
             fontsize=12, fontweight='bold')

for idx, (assembly_name, interface_df) in enumerate(interface_results.items()):
    ax = axes.flat[idx]
    
    label = INTERFACE_CONFIG[assembly_name]['label']
    
    ax.hist(interface_df['ddG_interface'], bins=50, edgecolor='black', 
            alpha=0.7, label=label)
    ax.axvline(0, color='red', linestyle='--', linewidth=2, label='No change')
    ax.set_xlabel('Interface ΔΔG (kcal/mol)', fontsize=10)
    ax.set_ylabel('Count', fontsize=10)
    ax.set_title(f'{assembly_name.replace("-", " ").title()}', fontsize=11, fontweight='bold')
    ax.legend(fontsize=8)
    ax.grid(alpha=0.3)
    ax.tick_params(labelsize=9)

plt.tight_layout()
plot_file = f"{OUTPUT_DIR}/interface_distributions.png"
plt.savefig(plot_file, dpi=300, bbox_inches='tight')
plt.show()
print(f"✓ Saved: {plot_file}")
plt.close()

# Plot 2: Sequence profiles
fig, axes = plt.subplots(2, 2, figsize=(7.5, 6.5))
fig.suptitle('Interface ΔΔG Along Sequence (Position 0)', 
             fontsize=12, fontweight='bold')

for idx, (assembly_name, interface_df) in enumerate(interface_results.items()):
    ax = axes.flat[idx]
    
    # Plot position 0 only
    pos0 = interface_df[interface_df['sym_position'] == 0].sort_values('residue_number')
    
    ax.errorbar(pos0['residue_number'], pos0['ddG_interface'],
               yerr=pos0['ddG_interface_total_std'],
               fmt='o-', markersize=2, alpha=0.7, capsize=1, linewidth=0.5)
    ax.axhline(0, color='black', linestyle='--', linewidth=1)
    ax.set_xlabel('Residue Number', fontsize=10)
    ax.set_ylabel('Interface ΔΔG (kcal/mol)', fontsize=10)
    ax.set_title(f'{assembly_name.replace("-", " ").title()}', fontsize=11, fontweight='bold')
    ax.grid(alpha=0.3)
    ax.tick_params(labelsize=9)

plt.tight_layout()
plot_file = f"{OUTPUT_DIR}/interface_sequence_profiles.png"
plt.savefig(plot_file, dpi=300, bbox_inches='tight')
plt.show()
print(f"✓ Saved: {plot_file}")
plt.close()

# Plot 3: Assembly vs Dimer comparison
fig, axes = plt.subplots(2, 2, figsize=(7.5, 6.5))
fig.suptitle('Assembly ΔΔG vs Dimer ΔΔG', 
             fontsize=12, fontweight='bold')

for idx, (assembly_name, interface_df) in enumerate(interface_results.items()):
    ax = axes.flat[idx]
    
    ax.scatter(interface_df['ddG_mean_dimer'], interface_df['ddG_mean_assembly'],
              alpha=0.3, s=5)
    
    # Identity line
    min_val = min(interface_df['ddG_mean_dimer'].min(), interface_df['ddG_mean_assembly'].min())
    max_val = max(interface_df['ddG_mean_dimer'].max(), interface_df['ddG_mean_assembly'].max())
    ax.plot([min_val, max_val], [min_val, max_val], 'r--', label='No change', linewidth=2)
    
    ax.set_xlabel('Dimer ΔΔG (kcal/mol)', fontsize=10)
    ax.set_ylabel(f'{assembly_name} ΔΔG (kcal/mol)', fontsize=10)
    ax.set_title(f'{assembly_name.replace("-", " ").title()}', fontsize=11, fontweight='bold')
    ax.legend(fontsize=8)
    ax.grid(alpha=0.3)
    ax.tick_params(labelsize=9)

plt.tight_layout()
plot_file = f"{OUTPUT_DIR}/assembly_vs_dimer.png"
plt.savefig(plot_file, dpi=300, bbox_inches='tight')
plt.show()
print(f"✓ Saved: {plot_file}")
plt.close()

# ========================================
# TOP INTERFACE RESIDUES
# ========================================
print(f"\n{'='*70}")
print("TOP INTERFACE RESIDUES")
print(f"{'='*70}")

for assembly_name, interface_df in interface_results.items():
    label = INTERFACE_CONFIG[assembly_name]['label']
    
    print(f"\n{label}:")
    print("-" * 70)
    print("TOP 10 MORE IMPORTANT AT INTERFACE (most positive ΔΔG_interface):")
    top_stab = interface_df.nlargest(10, 'ddG_interface')[
        ['residue_name', 'sym_position', 'residue_number', 
         'ddG_interface', 'ddG_interface_total_sem', 'ddG_mean_assembly', 'ddG_mean_dimer']
    ]
    print(top_stab.to_string(index=False))
    
    print("\nTOP 10 LESS IMPORTANT AT INTERFACE (most negative ΔΔG_interface):")
    top_destab = interface_df.nsmallest(10, 'ddG_interface')[
        ['residue_name', 'sym_position', 'residue_number', 
         'ddG_interface', 'ddG_interface_total_sem', 'ddG_mean_assembly', 'ddG_mean_dimer']
    ]
    print(top_destab.to_string(index=False))

print(f"\n{'='*70}")
print("INTERFACE ANALYSIS COMPLETE")
print(f"{'='*70}")
print(f"\nResults saved in: {OUTPUT_DIR}/")
print("\nInterpretation:")
print("  - Positive ΔΔG_interface: Residue MORE important at this interface (vs isolated dimer)")
print("  - Negative ΔΔG_interface: Residue LESS important at this interface (vs isolated dimer)")
print("  - Near zero: Similar importance in both contexts")
print("\nError propagation (symmetry-averaged data):")
print("  σ²_interface = σ²_assembly + σ²_dimer")
print("  (No multiplication - comparing per-copy effects)")

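For a single residue, the interface calculation above reduces to a difference of two per-copy means plus quadrature addition of their SEMs. This sketch uses the LEU-327 numbers from the docstring (3.2 and 2.5 kcal/mol); the SEM values are made up for illustration, and `interface_contribution` is a hypothetical helper that assumes the assembly and dimer estimates are independent.

```python
import math

def interface_contribution(mean_asm, sem_asm, mean_dimer, sem_dimer, k=2.0):
    """Per-copy interface ΔΔG, its propagated SEM, and a significance flag.

    Assumes independent assembly and dimer estimates, so variances add.
    `k` is the threshold in units of SEM (the cell above uses 2×SEM).
    Hypothetical helper for illustration.
    """
    ddg_interface = mean_asm - mean_dimer                 # no multiplier: both per-copy
    sem_interface = math.sqrt(sem_asm**2 + sem_dimer**2)  # quadrature sum
    significant = abs(ddg_interface) > k * sem_interface
    return ddg_interface, sem_interface, significant

# LEU-327 from the docstring: 3.2 (5-fold) vs 2.5 (dimer); SEMs are made up
d, s, sig = interface_contribution(3.2, 0.15, 2.5, 0.20)
print(f"ΔΔG_interface = {d:.2f} ± {s:.2f} kcal/mol, significant: {sig}")
```

The 2×SEM threshold is the same criterion the cell uses to count significant changes per assembly.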
In [None]:
"""
================================================================================
VIRAL CAPSID CONTRIBUTION ANALYSIS (With Correct Error Propagation)
================================================================================

GOAL:
-----
Calculate total viral capsid contribution per residue using icosahedral symmetry,
with CORRECT error propagation accounting for correlated dimer measurements.

ALGEBRAIC EXPANSION:
--------------------
Starting from interface-specific contributions:
- vertex_5fold = 5fold_raw - 5×dimer
- vertex_3fold = 3fold_raw - 3×dimer  
- inter_dimer_para = dimer_para_raw - 2×dimer
- inter_dimer_perp = dimer_perp_raw - 2×dimer

Total viral capsid:
capsid = 12×vertex_5fold + 20×vertex_3fold + 60×inter_dimer_para + 60×inter_dimer_perp + 90×dimer

Expanding:
capsid = 12×(5fold - 5×dimer) + 20×(3fold - 3×dimer) + 60×(para - 2×dimer) + 60×(perp - 2×dimer) + 90×dimer
       = 12×5fold + 20×3fold + 60×para + 60×perp + dimer×(-60 - 60 - 120 - 120 + 90)
       = 12×5fold + 20×3fold + 60×para + 60×perp - 270×dimer

CORRECT ERROR PROPAGATION:
--------------------------
Because the SAME dimer measurement is used in all subtractions, we must use the
NET coefficient on dimer:

Variance:
σ²_capsid = 12²×σ²_5fold + 20²×σ²_3fold + 60²×σ²_para + 60²×σ²_perp + 270²×σ²_dimer
          = 144×σ²_5fold + 400×σ²_3fold + 3600×σ²_para + 3600×σ²_perp + 72900×σ²_dimer

Standard Error (nested design):
SEM²_capsid = 144×SEM²_5fold + 400×SEM²_3fold + 3600×SEM²_para + 3600×SEM²_perp + 72900×SEM²_dimer

KEY INSIGHT:
------------
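We calculate interface contributions for INTERPRETATION but use RAW assembly values
for VARIANCE calculation. Every interface term contains the same dimer estimate, so
the terms are correlated; treating them as independent would give a dimer variance
coefficient of (12×5)² + (20×3)² + (60×2)² + (60×2)² + 90² = 44100 instead of
270² = 72900 and would underestimate the uncertainty.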
================================================================================
"""

import pandas as pd
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns

# ========================================
# CONFIGURATION
# ========================================
INPUT_DIR = f"{ALASCAN_BASE}/alascan_symmetry_averaged_results"
INTERFACE_DIR = f"{ALASCAN_BASE}/alascan_interface_contributions"
OUTPUT_DIR = f"{ALASCAN_BASE}/alascan_viral_capsid"
Path(OUTPUT_DIR).mkdir(exist_ok=True)

# Icosahedral symmetry coefficients
ICOSAHEDRAL_COEFFICIENTS = {
    '5-fold': 12,      # 12 five-fold vertices
    '3-fold': 20,      # 20 three-fold vertices
    'dimer_para': 60,  # 60 parallel dimer-dimer interfaces
    'dimer_perp': 60,  # 60 perpendicular dimer-dimer interfaces
    'dimer': -270      # Net dimer contribution (calculated above)
}

# ========================================
# LOAD SYMMETRY-AVERAGED DATA
# ========================================
print("Loading symmetry-averaged data...")

assemblies_data = {}
for assembly_name in ['3-fold', '5-fold', 'dimer_para', 'dimer_perp', 'dimer']:
    csv_file = f"{INPUT_DIR}/{assembly_name}_symmetry_averaged.csv"
    if Path(csv_file).exists():
        assemblies_data[assembly_name] = pd.read_csv(csv_file)
        print(f"✓ Loaded {assembly_name}: {len(assemblies_data[assembly_name])} residues")
    else:
        print(f"❌ Missing: {csv_file}")
        raise FileNotFoundError(f"Required file not found: {csv_file}")

# ========================================
# MERGE ALL DATA
# ========================================
print("\nMerging all assembly data...")

# Start with dimer as base
merged = assemblies_data['dimer'][['sym_position', 'residue_number', 'residue_name']].copy()

# Add data from each assembly
for assembly_name, df in assemblies_data.items():
    suffix = assembly_name.replace('-', '_')
    
    # Select columns we need
    cols_to_add = ['sym_position', 'residue_number', 'residue_name',
                   'ddG_mean', 'ddG_total_std', 'ddG_total_sem',
                   'ddG_frame_std', 'ddG_symmetry_std']
    
    df_subset = df[cols_to_add].copy()
    df_subset.columns = ['sym_position', 'residue_number', 'residue_name',
                         f'ddG_{suffix}', f'std_{suffix}', f'sem_{suffix}',
                         f'frame_std_{suffix}', f'sym_std_{suffix}']
    
    merged = merged.merge(df_subset, on=['sym_position', 'residue_number', 'residue_name'])

print(f"✓ Merged data: {len(merged)} residues")

# ========================================
# CALCULATE VIRAL CAPSID CONTRIBUTIONS
# ========================================
print("\nCalculating viral capsid contributions...")

# Total ΔΔG using raw assembly values
merged['ddG_capsid'] = (
    12 * merged['ddG_5_fold'] +
    20 * merged['ddG_3_fold'] +
    60 * merged['ddG_dimer_para'] +
    60 * merged['ddG_dimer_perp'] +
    (-270) * merged['ddG_dimer']  # Net coefficient
)

# Correct variance propagation using raw assembly values
# σ² = Σ(coeff² × σ²_assembly)
merged['ddG_capsid_total_std'] = np.sqrt(
    144 * merged['std_5_fold']**2 +
    400 * merged['std_3_fold']**2 +
    3600 * merged['std_dimer_para']**2 +
    3600 * merged['std_dimer_perp']**2 +
    72900 * merged['std_dimer']**2
)

# Correct SEM propagation
merged['ddG_capsid_total_sem'] = np.sqrt(
    144 * merged['sem_5_fold']**2 +
    400 * merged['sem_3_fold']**2 +
    3600 * merged['sem_dimer_para']**2 +
    3600 * merged['sem_dimer_perp']**2 +
    72900 * merged['sem_dimer']**2
)

# Decompose into frame and symmetry components
merged['ddG_capsid_frame_std'] = np.sqrt(
    144 * merged['frame_std_5_fold']**2 +
    400 * merged['frame_std_3_fold']**2 +
    3600 * merged['frame_std_dimer_para']**2 +
    3600 * merged['frame_std_dimer_perp']**2 +
    72900 * merged['frame_std_dimer']**2
)

merged['ddG_capsid_sym_std'] = np.sqrt(
    144 * merged['sym_std_5_fold'].fillna(0)**2 +
    400 * merged['sym_std_3_fold'].fillna(0)**2 +
    3600 * merged['sym_std_dimer_para'].fillna(0)**2 +
    3600 * merged['sym_std_dimer_perp'].fillna(0)**2 +
    72900 * merged['sym_std_dimer'].fillna(0)**2
)

# Also calculate interface contributions for interpretation (but don't use for final variance)
merged['vertex_5fold'] = merged['ddG_5_fold'] - 5 * merged['ddG_dimer']
merged['vertex_3fold'] = merged['ddG_3_fold'] - 3 * merged['ddG_dimer']
merged['inter_dimer_para'] = merged['ddG_dimer_para'] - 2 * merged['ddG_dimer']
merged['inter_dimer_perp'] = merged['ddG_dimer_perp'] - 2 * merged['ddG_dimer']

# Select output columns
output_cols = [
    'sym_position', 'residue_number', 'residue_name',
    'ddG_capsid', 'ddG_capsid_total_std', 'ddG_capsid_total_sem',
    'ddG_capsid_frame_std', 'ddG_capsid_sym_std',
    # Interface contributions for interpretation
    'vertex_5fold', 'vertex_3fold', 'inter_dimer_para', 'inter_dimer_perp',
    # Raw assembly values
    'ddG_5_fold', 'ddG_3_fold', 'ddG_dimer_para', 'ddG_dimer_perp', 'ddG_dimer'
]

capsid_df = merged[output_cols].copy()

# ========================================
# SAVE RESULTS
# ========================================
csv_out = f"{OUTPUT_DIR}/viral_capsid_contributions.csv"
capsid_df.to_csv(csv_out, index=False)
print(f"\n✓ Saved: {csv_out}")

# Summary statistics
print(f"\n{'='*70}")
print("VIRAL CAPSID CONTRIBUTION STATISTICS")
print(f"{'='*70}")
print(f"Total residues: {len(capsid_df)}")
print(f"\nΔΔG_capsid (per residue):")
print(f"  Mean: {capsid_df['ddG_capsid'].mean():.3f} ± {capsid_df['ddG_capsid'].std():.3f} kcal/mol")
print(f"  Range: [{capsid_df['ddG_capsid'].min():.3f}, {capsid_df['ddG_capsid'].max():.3f}]")
print(f"  Mean total std: {capsid_df['ddG_capsid_total_std'].mean():.3f} kcal/mol")
print(f"  Mean total SEM: {capsid_df['ddG_capsid_total_sem'].mean():.3f} kcal/mol")

# ========================================
# VARIANCE DECOMPOSITION
# ========================================
print(f"\n{'='*70}")
print("VARIANCE DECOMPOSITION (which assembly contributes most?)")
print(f"{'='*70}")

# Calculate contribution to total variance from each assembly
var_5fold = (144 * merged['std_5_fold']**2).mean()
var_3fold = (400 * merged['std_3_fold']**2).mean()
var_para = (3600 * merged['std_dimer_para']**2).mean()
var_perp = (3600 * merged['std_dimer_perp']**2).mean()
var_dimer = (72900 * merged['std_dimer']**2).mean()
total_var = var_5fold + var_3fold + var_para + var_perp + var_dimer

print(f"Mean variance contributions:")
print(f"  5-fold (12×):     {var_5fold:.2f} ({100*var_5fold/total_var:.1f}%)")
print(f"  3-fold (20×):     {var_3fold:.2f} ({100*var_3fold/total_var:.1f}%)")
print(f"  Para (60×):       {var_para:.2f} ({100*var_para/total_var:.1f}%)")
print(f"  Perp (60×):       {var_perp:.2f} ({100*var_perp/total_var:.1f}%)")
print(f"  Dimer (-270×):    {var_dimer:.2f} ({100*var_dimer/total_var:.1f}%)")
print(f"  Total:            {total_var:.2f}")

# ========================================
# CREATE PLOTS
# ========================================
print(f"\n{'='*70}")
print("CREATING PLOTS")
print(f"{'='*70}")

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
fig.suptitle('Viral Capsid Contribution Analysis', fontsize=14, fontweight='bold')

# Plot 1: ΔΔG_capsid distribution
axes[0, 0].hist(capsid_df['ddG_capsid'], bins=50, edgecolor='black', alpha=0.7)
axes[0, 0].axvline(0, color='red', linestyle='--', linewidth=2)
axes[0, 0].set_xlabel('Capsid ΔΔG (kcal/mol)', fontsize=10)
axes[0, 0].set_ylabel('Count', fontsize=10)
axes[0, 0].set_title('Distribution of Capsid ΔΔG', fontsize=11, fontweight='bold')
axes[0, 0].grid(alpha=0.3)

# Plot 2: Total std distribution
axes[0, 1].hist(capsid_df['ddG_capsid_total_std'], bins=50, edgecolor='black', 
                alpha=0.7, color='orange')
axes[0, 1].set_xlabel('Total Std (kcal/mol)', fontsize=10)
axes[0, 1].set_ylabel('Count', fontsize=10)
axes[0, 1].set_title('Distribution of Uncertainty', fontsize=11, fontweight='bold')
axes[0, 1].grid(alpha=0.3)

# Plot 3: Variance decomposition pie chart
var_contributions = [var_5fold, var_3fold, var_para, var_perp, var_dimer]
var_labels = ['5-fold\n(12×)', '3-fold\n(20×)', 'Para\n(60×)', 'Perp\n(60×)', 'Dimer\n(-270×)']
axes[0, 2].pie(var_contributions, labels=var_labels, autopct='%1.1f%%', startangle=90)
axes[0, 2].set_title('Variance Contributions', fontsize=11, fontweight='bold')

# Plot 4: Sequence profile (position 0)
pos0 = capsid_df[capsid_df['sym_position'] == 0].sort_values('residue_number')
axes[1, 0].errorbar(pos0['residue_number'], pos0['ddG_capsid'],
                   yerr=pos0['ddG_capsid_total_std'],
                   fmt='o-', markersize=2, alpha=0.7, capsize=1, linewidth=0.5)
axes[1, 0].axhline(0, color='black', linestyle='--', linewidth=1)
axes[1, 0].set_xlabel('Residue Number', fontsize=10)
axes[1, 0].set_ylabel('Capsid ΔΔG (kcal/mol)', fontsize=10)
axes[1, 0].set_title('Sequence Profile (Asymmetric Unit)', fontsize=11, fontweight='bold')
axes[1, 0].grid(alpha=0.3)

# Plot 5: Std vs Mean
axes[1, 1].scatter(capsid_df['ddG_capsid'], capsid_df['ddG_capsid_total_std'],
                  alpha=0.3, s=10)
axes[1, 1].set_xlabel('Mean Capsid ΔΔG (kcal/mol)', fontsize=10)
axes[1, 1].set_ylabel('Total Std (kcal/mol)', fontsize=10)
axes[1, 1].set_title('Uncertainty vs Mean', fontsize=11, fontweight='bold')
axes[1, 1].grid(alpha=0.3)

# Plot 6: Interface contributions breakdown
pos0 = capsid_df[capsid_df['sym_position'] == 0].sort_values('residue_number')
axes[1, 2].plot(pos0['residue_number'], 12*pos0['vertex_5fold'], label='12×5-fold', alpha=0.7, linewidth=1)
axes[1, 2].plot(pos0['residue_number'], 20*pos0['vertex_3fold'], label='20×3-fold', alpha=0.7, linewidth=1)
axes[1, 2].plot(pos0['residue_number'], 60*pos0['inter_dimer_para'], label='60×Para', alpha=0.7, linewidth=1)
axes[1, 2].plot(pos0['residue_number'], 60*pos0['inter_dimer_perp'], label='60×Perp', alpha=0.7, linewidth=1)
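# Alongside the interface terms, the remaining dimer term is +90×dimer
# (the -270 coefficient applies only when using raw assembly values)
axes[1, 2].plot(pos0['residue_number'], 90*pos0['ddG_dimer'], label='90×Dimer', alpha=0.7, linewidth=1)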
axes[1, 2].axhline(0, color='black', linestyle='--', linewidth=1)
axes[1, 2].set_xlabel('Residue Number', fontsize=10)
axes[1, 2].set_ylabel('Contribution (kcal/mol)', fontsize=10)
axes[1, 2].set_title('Interface Contributions', fontsize=11, fontweight='bold')
axes[1, 2].legend(fontsize=7)
axes[1, 2].grid(alpha=0.3)

plt.tight_layout()
plot_file = f"{OUTPUT_DIR}/viral_capsid_analysis.png"
plt.savefig(plot_file, dpi=300, bbox_inches='tight')
plt.show()
print(f"✓ Saved: {plot_file}")
plt.close()

# ========================================
# TOP RESIDUES
# ========================================
print(f"\n{'='*70}")
print("TOP RESIDUES FOR VIRAL CAPSID STABILITY")
print(f"{'='*70}")

print("\nTOP 20 MOST IMPORTANT (highest ΔΔG_capsid):")
top_stab = capsid_df.nlargest(20, 'ddG_capsid')[
    ['residue_name', 'sym_position', 'residue_number', 
     'ddG_capsid', 'ddG_capsid_total_sem']
]
print(top_stab.to_string(index=False))

print("\nTOP 20 MOST DESTABILIZING (lowest ΔΔG_capsid):")
top_destab = capsid_df.nsmallest(20, 'ddG_capsid')[
    ['residue_name', 'sym_position', 'residue_number', 
     'ddG_capsid', 'ddG_capsid_total_sem']
]
print(top_destab.to_string(index=False))

print(f"\n{'='*70}")
print("ANALYSIS COMPLETE")
print(f"{'='*70}")
print("\nKey insight: Proper error propagation accounts for correlated dimer measurements")
print("Net coefficient on raw dimer ΔΔG: -270 (+90 applies only alongside the interface terms)")

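The net dimer coefficient in the docstring above can be checked by tracking how many dimers each interface term subtracts. This sketch uses only the multiplicities stated there (12, 20, 60, 60, and the explicit +90 dimer term); the dictionary names are illustrative. It also shows what would happen if the five terms were treated as independent: the σ²_dimer multiplier would be the sum of squared coefficients rather than the square of their sum.

```python
# Multiplicity of each interface type in the capsid (from the derivation above)
MULTIPLICITY = {'5-fold': 12, '3-fold': 20, 'dimer_para': 60, 'dimer_perp': 60}
# Number of dimers subtracted inside each interface term
DIMERS_PER_TERM = {'5-fold': 5, '3-fold': 3, 'dimer_para': 2, 'dimer_perp': 2}
DIMER_TERM = 90  # explicit +90×dimer term in the capsid sum

def dimer_coefficients():
    """Coefficient on the raw dimer ΔΔG contributed by each term of the capsid sum."""
    coeffs = {name: -MULTIPLICITY[name] * DIMERS_PER_TERM[name] for name in MULTIPLICITY}
    coeffs['dimer'] = DIMER_TERM
    return coeffs

coeffs = dimer_coefficients()
net = sum(coeffs.values())                               # all terms share one dimer estimate
var_correct = net**2                                     # multiplier on σ²_dimer (correct)
var_if_independent = sum(c**2 for c in coeffs.values())  # ignores the shared dimer

print(coeffs)
print(f"Net dimer coefficient: {net}")
print(f"σ²_dimer multiplier, correlated (correct): {var_correct}")
print(f"σ²_dimer multiplier, treated as independent: {var_if_independent}")
```

The gap between 72900 and 44100 is the covariance between terms that the raw-value formula accounts for. The same bookkeeping explains why, when the capsid sum is written with the interface terms, the remaining dimer term is +90 rather than -270.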

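The viewer below colors residues on a red-white-blue scale centered on the median ΔΔG. Its scale computation and the blue/white/red counts can be isolated as pure functions, which makes them easy to check without NGLView. This is a sketch that mirrors the logic of `update_view`; the names `color_scale` and `bin_counts` are hypothetical. NumPy's default linear quantile interpolation matches the pandas `quantile` default used in the cell.

```python
import numpy as np

def color_scale(values, steepness=1.0, adaptive=False):
    """Return (vmin, vmax, median) for a diverging scale centered on the median.

    Fixed mode:    median ± steepness (kcal/mol)
    Adaptive mode: median ± steepness × IQR / 2
    """
    values = np.asarray(values, dtype=float)
    median = np.median(values)
    if adaptive:
        q1, q3 = np.quantile(values, [0.25, 0.75])
        half_width = steepness * 0.5 * (q3 - q1)
    else:
        half_width = steepness
    return median - half_width, median + half_width, median

def bin_counts(values, vmin, vmax):
    """Count residues below (blue), inside (white, inclusive) and above (red) the scale."""
    values = np.asarray(values, dtype=float)
    n_blue = int((values < vmin).sum())
    n_red = int((values > vmax).sum())
    return n_blue, len(values) - n_blue - n_red, n_red

ddg = [-1.0, 0.0, 0.5, 1.0, 3.0]  # made-up ΔΔG values
vmin, vmax, median = color_scale(ddg, steepness=1.0)
print(vmin, vmax, median, bin_counts(ddg, vmin, vmax))
```

In fixed mode `steepness` acts as the half-width of the scale in kcal/mol, so larger values spread the colors over a wider ΔΔG range.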
In [None]:
# ========================================
# INTERACTIVE ALASCAN VISUALIZATION
# ========================================
import ipywidgets as widgets
from IPython.display import display
import pandas as pd
import nglview as nv
from Bio.PDB import PDBParser, PDBIO

# Analysis mode configurations
MODES = {
    'Symmetry Averaged (Full)': {
        'output_dir': f"{ALASCAN_BASE}/alascan_symmetry_averaged_results",
        'pdb_dir': f"{ALASCAN_BASE}/alascan_symmetry_averaged_results",
        'csv_end': "_symmetry_averaged.csv",
        'pdb_end': "_symmetry_averaged_full.pdb",
        'ddg_col': "ddG_mean",
        'assemblies': ["3-fold", "5-fold", "dimer_para", "dimer_perp", "dimer"],
        'single_csv': False
    },
    'Symmetry Averaged (Asymmetric)': {
        'output_dir': f"{ALASCAN_BASE}/alascan_symmetry_averaged_results",
        'pdb_dir': f"{ALASCAN_BASE}/alascan_symmetry_averaged_results",
        'csv_end': "_symmetry_averaged.csv",
        'pdb_end': "_symmetry_averaged_asymmetric.pdb",
        'ddg_col': "ddG_mean",
        'assemblies': ["3-fold", "5-fold", "dimer_para", "dimer_perp", "dimer"],
        'single_csv': False
    },
    'Frame Averaged': {
        'output_dir': f"{ALASCAN_BASE}/alascan_averaged_results",
        'pdb_dir': f"{ALASCAN_BASE}/alascan_symmetry_averaged_results",
        'csv_end': "_frame_averaged.csv",
        'pdb_end': "_frame_averaged.pdb",
        'ddg_col': "ddG_mean",
        'assemblies': ["3-fold", "5-fold", "dimer_para", "dimer_perp", "dimer"],
        'single_csv': False
    },
    'Interface Contributions (Full)': {
        'output_dir': f"{ALASCAN_BASE}/alascan_interface_contributions",
        'pdb_dir': f"{ALASCAN_BASE}/alascan_symmetry_averaged_results",
        'csv_end': "_interface.csv",
        'pdb_end': "_symmetry_averaged_full.pdb",
        'ddg_col': "ddG_interface",
        'assemblies': ["3-fold", "5-fold", "dimer_para", "dimer_perp"],
        'single_csv': False
    },    
    'Interface Contributions (Asymmetric)': {
        'output_dir': f"{ALASCAN_BASE}/alascan_interface_contributions",
        'pdb_dir': f"{ALASCAN_BASE}/alascan_symmetry_averaged_results",
        'csv_end': "_interface.csv",
        'pdb_end': "_symmetry_averaged_asymmetric.pdb",
        'ddg_col': "ddG_interface",
        'assemblies': ["3-fold", "5-fold", "dimer_para", "dimer_perp"],
        'single_csv': False
    },
    'Viral Capsid': {
        'output_dir': f"{ALASCAN_BASE}/alascan_viral_capsid",
        'pdb_dir': f"{ALASCAN_BASE}/alascan_symmetry_averaged_results",
        'csv_file': "viral_capsid_contributions.csv",
        'pdb_end': "_symmetry_averaged_asymmetric.pdb",
        'ddg_col': "ddG_capsid",
        'assemblies': ["3-fold"],  # selects the PDB template only; ΔΔG comes from the single CSV
        'single_csv': True
    }
}

# Widgets
mode_dropdown = widgets.Dropdown(
    options=list(MODES.keys()),
    value='Symmetry Averaged (Full)',
    description='Mode:',
)

assembly_dropdown = widgets.Dropdown(
    options=MODES['Symmetry Averaged (Full)']['assemblies'],
    value=MODES['Symmetry Averaged (Full)']['assemblies'][0],
    description='Assembly:',
)

steepness_slider = widgets.FloatSlider(
    value=1.0,
    min=0.2,
    max=10.0,
    step=0.1,
    description='Steepness:',
    continuous_update=False,
    readout_format='.1f',
)

adaptive_checkbox = widgets.Checkbox(
    value=False,
    description='Adaptive (IQR-based)',
)

view_widget = widgets.Output()

def update_mode(change=None):
    """Update assembly options when mode changes"""
    mode_config = MODES[mode_dropdown.value]
    assembly_dropdown.options = mode_config['assemblies']
    assembly_dropdown.value = mode_config['assemblies'][0]
    assembly_dropdown.disabled = mode_config['single_csv']
    update_view()

def update_view(change=None):
    """Update visualization"""
    mode_config = MODES[mode_dropdown.value]
    assembly = assembly_dropdown.value
    steepness = steepness_slider.value
    adaptive = adaptive_checkbox.value
    
    with view_widget:
        view_widget.clear_output(wait=True)
        
        # Load data
        if mode_config['single_csv']:
            csv_file = f"{mode_config['output_dir']}/{mode_config['csv_file']}"
            print(f"Note: Viewing {mode_dropdown.value} (all assemblies combined)")
        else:
            csv_file = f"{mode_config['output_dir']}/{assembly}{mode_config['csv_end']}"
        
        stats = pd.read_csv(csv_file)
        ddg_col = mode_config['ddg_col']
        
        # Calculate color scale
        has_sym_position = 'sym_position' in stats.columns
        has_chain = 'chain' in stats.columns
        
        if adaptive:
            q1 = stats[ddg_col].quantile(0.25)
            q3 = stats[ddg_col].quantile(0.75)
            iqr = q3 - q1
            median = stats[ddg_col].median()
            vmin = median - (steepness * 0.5 * iqr)
            vmax = median + (steepness * 0.5 * iqr)
        else:
            median = stats[ddg_col].median()
            vmin = median - steepness
            vmax = median + steepness
        
        # Display info
        title = f"{mode_dropdown.value}"
        if not mode_config['single_csv']:
            title += f" - {assembly}"
        title += f" (steepness={steepness:.1f}, adaptive={adaptive})"
        
        print(f"\n{title}:")
        print(f"  ΔΔG range: {stats[ddg_col].min():.2f} to {stats[ddg_col].max():.2f}")
        print(f"  Color scale: {vmin:.2f} to {vmax:.2f} (median: {median:.2f})")
        
        # Distribution
        n_red = len(stats[stats[ddg_col] > vmax])
        n_blue = len(stats[stats[ddg_col] < vmin])
        n_white = len(stats[(stats[ddg_col] >= vmin) & (stats[ddg_col] <= vmax)])
        total = len(stats)
        
        print(f"  Distribution: {n_blue} blue ({100*n_blue/total:.1f}%) | "
              f"{n_white} white ({100*n_white/total:.1f}%) | "
              f"{n_red} red ({100*n_red/total:.1f}%)")
        
        # Load PDB and map ΔΔG to B-factors
        pdb_file = f"{mode_config['pdb_dir']}/{assembly}{mode_config['pdb_end']}"
        parser = PDBParser(QUIET=True)
        structure = parser.get_structure('protein', pdb_file)
        
        for model in structure:
            for chain in model:
                chain_id = chain.id
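                for residue in chain:
                    # Standard residues only (hetero flag ' '); skips waters and ligands
                    if residue.id[0] == ' ':
                        res_num = residue.id[1]
                        
                        # Find the matching ΔΔG row. For symmetric assemblies the
                        # chains are assumed to be lettered A, B, C, ... with four
                        # symmetry positions, so chain index modulo 4 gives
                        # sym_position (e.g. chains A and E both map to 0).
                        if has_sym_position:
                            chain_idx = ord(chain_id) - ord('A')
                            sym_pos = chain_idx % 4
                            match = stats[(stats['sym_position'] == sym_pos) &
                                          (stats['residue_number'] == res_num)]
                        elif has_chain:
                            match = stats[(stats['chain'] == chain_id) &
                                          (stats['residue_number'] == res_num)]
                        else:
                            match = stats[stats['residue_number'] == res_num]
                        
                        # Unmatched residues get the median, the white midpoint of
                        # the color scale. The value is written to every atom's
                        # B-factor so that NGL can color by it.
                        ddg_value = match[ddg_col].iloc[0] if len(match) > 0 else median
                        for atom in residue:
                            atom.set_bfactor(ddg_value)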
        
        # Save and display
        temp_pdb = f"/tmp/temp_{mode_dropdown.value.replace(' ', '_')}_{assembly}.pdb"
        io = PDBIO()
        io.set_structure(structure)
        io.save(temp_pdb)
        
        view = nv.show_structure_file(temp_pdb)
        view.clear_representations()
        view.add_representation('cartoon', 
                                selection='protein',
                                color_scheme='bfactor',
                                color_scale='rwb',
                                color_domain=[vmin, vmax])
        # Highlight selected residues as semi-transparent spheres (NGL
        # selection by residue number). To color them by ΔΔG as well, add
        # color_scheme='bfactor', color_scale='rwb', color_domain=[vmin, vmax].
        highlight_residues = ['380', '331', '299', '325', '52', '56',
                              '170', '173', '407', '416', '461']  # '81' disabled
        for resnum in highlight_residues:
            view.add_representation('spacefill', selection=resnum, opacity=0.5)
        
        view.center()
        display(view)

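# Sketch (hypothetical helpers, not called by update_view): the color-window
# and distribution logic of update_view() rewritten as pure functions, so it
# can be checked without widgets, CSV files or PDB files. The names
# compute_color_domain and count_color_bins are illustrative only. Assumes a
# list of at least two ΔΔG values with no NaNs; statistics.quantiles with
# method='inclusive' uses the same linear interpolation as pandas' default
# Series.quantile.
import statistics


def compute_color_domain(values, steepness, adaptive):
    """Return (vmin, vmax, median) for the red-white-blue color scale."""
    data = [float(v) for v in values]
    median = statistics.median(data)
    if adaptive:
        # Quartile cut points; the middle one is the median and is ignored here
        q1, _, q3 = statistics.quantiles(data, n=4, method='inclusive')
        half_width = steepness * 0.5 * (q3 - q1)
    else:
        half_width = steepness
    return median - half_width, median + half_width, median


def count_color_bins(values, vmin, vmax):
    """Return (n_blue, n_white, n_red): counts below, inside and above the window."""
    data = [float(v) for v in values]
    n_blue = sum(v < vmin for v in data)
    n_red = sum(v > vmax for v in data)
    return n_blue, len(data) - n_blue - n_red, n_red


# Example: for ΔΔG values [1, 2, 3, 4, 5], q1 = 2 and q3 = 4, so the adaptive
# window at steepness 1.0 is the median 3 +/- 1, i.e. (2.0, 4.0), and
# count_color_bins on the same values gives (1, 3, 1).
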
# Connect callbacks
mode_dropdown.observe(update_mode, names='value')
assembly_dropdown.observe(update_view, names='value')
steepness_slider.observe(update_view, names='value')
adaptive_checkbox.observe(update_view, names='value')

# Display
controls = widgets.VBox([
    mode_dropdown,
    assembly_dropdown,
    steepness_slider,
    adaptive_checkbox
])

display(controls)
display(view_widget)
update_view()
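

# Sketch (hypothetical helpers, not called by update_view): the per-residue
# lookup in update_view() filters the whole DataFrame once per residue. A
# dictionary built in one pass returns the same values with constant-time
# lookups. key_mode, build_ddg_lookup and lookup_ddg are illustrative names.
# Rows are plain dicts here (e.g. from stats.to_dict('records')) so the sketch
# is self-contained; like the original code, it assumes chains lettered
# A, B, C, ... and sym_position = chain index modulo 4.


def key_mode(columns):
    """Choose the matching strategy in the same order as update_view()."""
    if 'sym_position' in columns:
        return 'sym_position'
    if 'chain' in columns:
        return 'chain'
    return 'residue_number'


def build_ddg_lookup(rows, ddg_col, mode):
    """Map each residue key to its ΔΔG; the first row per key wins, as with .iloc[0]."""
    lookup = {}
    for row in rows:
        res_num = int(row['residue_number'])
        if mode == 'sym_position':
            key = (int(row['sym_position']), res_num)
        elif mode == 'chain':
            key = (row['chain'], res_num)
        else:
            key = res_num
        lookup.setdefault(key, float(row[ddg_col]))
    return lookup


def lookup_ddg(lookup, chain_id, res_num, mode, default):
    """Return the ΔΔG for a PDB residue, or default (e.g. the median) if absent."""
    if mode == 'sym_position':
        key = ((ord(chain_id) - ord('A')) % 4, res_num)
    elif mode == 'chain':
        key = (chain_id, res_num)
    else:
        key = res_num
    return lookup.get(key, default)


# Possible use inside the residue loop (not applied in update_view):
#   mode = key_mode(stats.columns)
#   lookup = build_ddg_lookup(stats.to_dict('records'), ddg_col, mode)
#   ddg_value = lookup_ddg(lookup, chain_id, res_num, mode, median)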