# Structural Variant Handling and Prediction Alignment

This notebook demonstrates how supremo_lite handles complex **structural variants** (SVs) including inversions (INV), duplications (DUP), and breakends (BND). These variants require special coordinate transformation logic for prediction alignment.

## Learning Objectives

- 🔄 **Inversions (INV)**: Understand reverse complement operations and cross-pattern masking
- 📋 **Duplications (DUP)**: Learn tandem duplication handling and SVLEN usage
- 🔗 **Breakends (BND)**: Explore chimeric reference assembly for translocations
- 📊 **2D Prediction Alignment**: Visualize how SVs affect contact map predictions
- 🧬 **Coordinate Transformations**: Master complex genomic coordinate handling

## Structural Variant Types

| Type | Description | Sequence Change | Prediction Impact |
|------|-------------|-----------------|-------------------|
| **INV** | Inversion | Reverse complement of region | Cross-pattern masking (2D) |
| **DUP** | Duplication | Tandem repeat of region | Length increase |
| **BND** | Breakend/Translocation | Join two distant loci | Chimeric reference |

## Setup

In [None]:
import supremo_lite as sl
from supremo_lite.mock_models import TestModel, TestModel2D, TORCH_AVAILABLE
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pyfaidx import Fasta
import os

sns.set_style("whitegrid")

if not TORCH_AVAILABLE:
    raise ImportError("PyTorch required. Install with: pip install torch")

print(f"supremo_lite version: {sl.__version__}")

# Load test data
test_data_dir = "../../tests/data"
reference = Fasta(os.path.join(test_data_dir, "test_genome.fa"))

## Part 1: Inversions (INV)

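An inversion replaces a genomic region with its reverse complement. The sequence length is unchanged, but positions inside the region now map to the opposite strand in reverse order, which complicates prediction alignment.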

In [None]:
# Load inversion variants
inv_vcf = os.path.join(test_data_dir, "inv", "inv.vcf")
inv_variants = sl.read_vcf(inv_vcf)

print("Inversion variants:")
print(inv_variants[['chrom', 'pos', 'ref', 'alt', 'info']])

# Generate sequences for first inversion
ref_seqs, alt_seqs, metadata = sl.get_alt_ref_sequences(
    reference_fn=reference,
    variants_fn=inv_variants.iloc[:1],
    seq_len=200,
    encode=True
)

print("\nGenerated sequences for inversion:")
print(f"  Variant type: {metadata[0]['variant_type']}")
print(f"  Position: {metadata[0]['chrom']}:{metadata[0]['variant_pos1']}")
print(f"  Reference allele: {metadata[0]['ref']}")
print(f"  Alternate allele: {metadata[0]['alt']}")

### Visualizing Inversion Sequence Changes

In [None]:
# Get raw sequences to see the inversion
ref_raw, alt_raw, _ = sl.get_alt_ref_sequences(
    reference_fn=reference,
    variants_fn=inv_variants.iloc[:1],
    seq_len=80,  # Smaller window for visualization
    encode=False
)

print("Inversion example (80bp window):")
print("="*80)
print(f"Reference: {ref_raw[0]}")
print(f"Alternate: {alt_raw[0]}")
print("="*80)

# The inverted region is the stretch where the two strings differ;
# inside it, the alternate is the reverse complement of the reference
print("\nLook for the region that's reverse complemented in the alternate sequence!")
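
Rather than scanning the two strings by eye, the changed stretch can be located programmatically. The sketch below is a minimal illustration on toy strings; `reverse_complement` and `differing_span` are hypothetical helpers, not part of supremo_lite. Note that the mismatch span can be narrower than the true inverted interval when bases at its edges happen to match after inversion.

```python
def reverse_complement(seq: str) -> str:
    """Return the reverse complement of a DNA string (case preserved)."""
    complement = str.maketrans("ACGTNacgtn", "TGCANtgcan")
    return seq.translate(complement)[::-1]


def differing_span(ref: str, alt: str):
    """Return the half-open (start, end) window covering all mismatches, or None."""
    if len(ref) != len(alt):
        raise ValueError("sequences must have equal length")
    mismatches = [i for i, (r, a) in enumerate(zip(ref, alt)) if r != a]
    if not mismatches:
        return None
    return mismatches[0], mismatches[-1] + 1


# Toy data: invert positions 4..7 ("CCGT") of a 12 bp reference
toy_ref = "AAAACCGTTTTT"
toy_alt = toy_ref[:4] + reverse_complement(toy_ref[4:8]) + toy_ref[8:]

span = differing_span(toy_ref, toy_alt)
start, end = span
# Reverse-complementing the changed stretch of ALT recovers the reference
recovered = reverse_complement(toy_alt[start:end])
print(span, recovered == toy_ref[start:end])
```

With the notebook's data, the same helpers could be applied to `ref_raw[0]` and `alt_raw[0]`, assuming both are plain strings of equal length.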

### 2D Predictions for Inversions: Cross-Pattern Masking

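When a region is inverted, bins inside it no longer correspond one-to-one between reference and alternate. In a 2D contact map every bin appears as both a row and a column, so both the rows AND the columns covering the inverted region must be masked. The result is a cross-shaped pattern.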
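
The masking idea can be sketched with plain NumPy. This is an illustration only, not supremo_lite's internal implementation; `cross_pattern_mask` and the bin range are hypothetical.

```python
import numpy as np


def cross_pattern_mask(matrix, start_bin, end_bin):
    """Return a copy with rows and columns in [start_bin, end_bin) set to NaN."""
    masked = np.array(matrix, dtype=float, copy=True)
    masked[start_bin:end_bin, :] = np.nan  # rows covering the inverted bins
    masked[:, start_bin:end_bin] = np.nan  # columns covering the inverted bins
    return masked


# Toy 6x6 contact map with the inversion covering bins 2 and 3
toy_map = np.ones((6, 6))
masked_map = cross_pattern_mask(toy_map, 2, 4)
n_masked = int(np.isnan(masked_map).sum())
print(n_masked)
```

Two masked rows and two masked columns overlap in a 2x2 block, so 20 of the 36 entries end up as NaN.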

In [None]:
# Initialize 2D model
model_2d = TestModel2D(n_targets=1, bin_size=8, crop_length=10, diag_offset=2)

# Run predictions
ref_preds_2d = model_2d(ref_seqs)
alt_preds_2d = model_2d(alt_seqs)

# Align predictions
n_bins = (200 - 2*model_2d.crop_length) // model_2d.bin_size
ref_aligned_2d, alt_aligned_2d = sl.align_predictions_by_coordinate(
    ref_pred=ref_preds_2d[0, 0],
    alt_pred=alt_preds_2d[0, 0],
    metadata=metadata[0],
    prediction_type="2D",
    bin_size=model_2d.bin_size,
    crop_length=model_2d.crop_length,
    diag_offset=model_2d.diag_offset,
    matrix_size=n_bins
)

print("2D predictions aligned for inversion")
print(f"  Matrix shape: {ref_aligned_2d.shape}")
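
The `n_bins` expression in the cell above follows from simple arithmetic: crop the window on both sides, then divide what remains by the bin size. A small sketch, assuming `crop_length` is given in base pairs and trimmed from each edge (as the formula in the cell implies); `expected_n_bins` is a hypothetical helper.

```python
def expected_n_bins(seq_len: int, bin_size: int, crop_length: int) -> int:
    """Number of output bins after cropping both edges of the input window."""
    usable = seq_len - 2 * crop_length
    if usable <= 0:
        raise ValueError("crop_length removes the entire window")
    # Leftover bases that do not fill a whole bin are dropped
    return usable // bin_size


n_bins_example = expected_n_bins(seq_len=200, bin_size=8, crop_length=10)
print(n_bins_example)
```

For the settings used here, 180 usable bases give 22 full bins, with 4 bases left over.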

In [None]:
# Visualize inversion cross-pattern
ref_2d_np = ref_aligned_2d.cpu().numpy() if hasattr(ref_aligned_2d, 'cpu') else ref_aligned_2d
alt_2d_np = alt_aligned_2d.cpu().numpy() if hasattr(alt_aligned_2d, 'cpu') else alt_aligned_2d

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Reference
im1 = axes[0].imshow(ref_2d_np, cmap='Reds', vmin=0, vmax=1, origin='lower')
axes[0].set_title('Reference Contact Map', fontsize=12, fontweight='bold')
plt.colorbar(im1, ax=axes[0], label='Contact')

# Alternate (with inversion)
im2 = axes[1].imshow(alt_2d_np, cmap='Blues', vmin=0, vmax=1, origin='lower')
axes[1].set_title('Alternate Contact Map (Inverted)', fontsize=12, fontweight='bold')
plt.colorbar(im2, ax=axes[1], label='Contact')

# Show masked regions
mask = np.isnan(alt_2d_np).astype(float)
im3 = axes[2].imshow(mask, cmap='RdYlBu_r', vmin=0, vmax=1, origin='lower')
axes[2].set_title('Masked Regions (Cross-Pattern)', fontsize=12, fontweight='bold')
plt.colorbar(im3, ax=axes[2], label='Masked')

for ax in axes:
    ax.set_xlabel('Bin position')
    ax.set_ylabel('Bin position')

fig.suptitle(f'Inversion Cross-Pattern Masking\n'
             f'{metadata[0]["chrom"]}:{metadata[0]["variant_pos1"]} {metadata[0]["alt"]}',
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("\nKey observation:")
print("• Right panel shows the cross-pattern (rows AND columns masked)")
print("• This is because inversion affects both dimensions in contact space")
print("• supremo_lite correctly implements this cross-pattern masking!")

## Part 2: Duplications (DUP)

A tandem duplication inserts a second copy of a region immediately after the original. In the VCF, the INFO field's SVLEN and END entries give the size and end coordinate of the duplicated region.
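
To make the SVLEN and END fields concrete, here is a minimal sketch that extracts them from a raw INFO string. It assumes the INFO value is an unparsed `key=value;...` string and uses a made-up record; `parse_info` and `sv_length` are hypothetical helpers, not supremo_lite functions.

```python
def parse_info(info: str) -> dict:
    """Split a VCF INFO string into a dict; flag entries map to True."""
    fields = {}
    for item in info.split(";"):
        if not item:
            continue
        key, sep, value = item.partition("=")
        fields[key] = value if sep else True
    return fields


def sv_length(pos: int, info: str) -> int:
    """Length of a symbolic SV from SVLEN, falling back to END - POS."""
    fields = parse_info(info)
    if "SVLEN" in fields:
        # SVLEN may be a comma-separated list; take the first allele
        return abs(int(fields["SVLEN"].split(",")[0]))
    if "END" in fields:
        # For symbolic alleles POS is the base before the event
        return int(fields["END"]) - pos
    raise ValueError("INFO has neither SVLEN nor END")


# Made-up record for illustration
example_len = sv_length(100, "SVTYPE=DUP;END=150;SVLEN=50")
print(example_len)
```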

In [None]:
# Load duplication variants
dup_vcf = os.path.join(test_data_dir, "dup", "dup.vcf")
dup_variants = sl.read_vcf(dup_vcf)

print("Duplication variants:")
print(dup_variants[['chrom', 'pos', 'ref', 'alt', 'info']])

# Generate sequences
ref_seqs_dup, alt_seqs_dup, metadata_dup = sl.get_alt_ref_sequences(
    reference_fn=reference,
    variants_fn=dup_variants.iloc[:1],
    seq_len=200,
    encode=True
)

print("\nDuplication metadata:")
for key, value in metadata_dup[0].items():
    print(f"  {key}: {value}")

### Visualizing Duplication Effects

In [None]:
# Get raw sequences
ref_dup_raw, alt_dup_raw, _ = sl.get_alt_ref_sequences(
    reference_fn=reference,
    variants_fn=dup_variants.iloc[:1],
    seq_len=120,
    encode=False
)

print("Duplication example:")
print("="*80)
print(f"Reference: {ref_dup_raw[0]}")
print(f"Alternate: {alt_dup_raw[0]}")
print("="*80)
print(f"\nReference length: {len(ref_dup_raw[0])}")
print(f"Alternate length: {len(alt_dup_raw[0])}")
print(f"Length difference: +{len(alt_dup_raw[0]) - len(ref_dup_raw[0])} bp (duplicated region)")
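
A toy example makes the sequence change concrete. The sketch below applies a tandem duplication to a short string using 0-based, half-open coordinates; `apply_tandem_dup` is a hypothetical helper and does not reflect how supremo_lite builds its alternate sequences internally.

```python
def apply_tandem_dup(seq: str, start: int, end: int) -> str:
    """Insert a second copy of seq[start:end] directly after the original copy."""
    if not 0 <= start < end <= len(seq):
        raise ValueError("invalid duplication interval")
    return seq[:end] + seq[start:end] + seq[end:]


toy_ref = "AACCGGTT"
toy_alt = apply_tandem_dup(toy_ref, 2, 6)  # duplicate "CCGG"
length_gain = len(toy_alt) - len(toy_ref)
print(toy_alt, length_gain)
```

The alternate grows by exactly the length of the duplicated interval, which is the quantity SVLEN reports.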

### 1D Prediction Alignment for Duplications

In [None]:
# Run 1D predictions
model_1d = TestModel(n_targets=2, bin_size=8, crop_length=10)
ref_preds_1d = model_1d(ref_seqs_dup)
alt_preds_1d = model_1d(alt_seqs_dup)

# Align
ref_aligned_1d, alt_aligned_1d = sl.align_predictions_by_coordinate(
    ref_pred=ref_preds_1d[0],
    alt_pred=alt_preds_1d[0],
    metadata=metadata_dup[0],
    prediction_type="1D",
    bin_size=model_1d.bin_size,
    crop_length=model_1d.crop_length
)

# Plot
ref_np = ref_aligned_1d[0].cpu().numpy() if hasattr(ref_aligned_1d, 'cpu') else ref_aligned_1d[0]
alt_np = alt_aligned_1d[0].cpu().numpy() if hasattr(alt_aligned_1d, 'cpu') else alt_aligned_1d[0]

plt.figure(figsize=(14, 5))
positions = np.arange(len(ref_np))

plt.plot(positions, ref_np, 'o-', label='Reference', color='steelblue', linewidth=2, markersize=5)
alt_masked = np.ma.masked_invalid(alt_np)
plt.plot(positions, alt_masked, 's-', label='Alternate (Duplicated)', color='coral', linewidth=2, markersize=5)

# Mark duplication region
var_pos = metadata_dup[0]['variant_pos0']
window_start = metadata_dup[0]['window_start']
effective_start = window_start + model_1d.crop_length
var_bin = (var_pos - effective_start) // model_1d.bin_size

plt.axvline(x=var_bin, color='red', linestyle='--', alpha=0.7, label='Duplication start')
plt.xlabel('Bin position', fontsize=11)
plt.ylabel('Prediction value', fontsize=11)
plt.title(f'Duplication Effect on 1D Predictions\n'
          f'{metadata_dup[0]["chrom"]}:{metadata_dup[0]["variant_pos1"]} {metadata_dup[0]["alt"]}',
          fontsize=13, fontweight='bold')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("\nDuplication alignment:")
print("• The duplication adds bases, so some bins have no counterpart in the other sequence")
print("• Those bins are filled with NaN and show up as gaps in the plot")
print("• This maintains coordinate correspondence for comparison")
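
The cell above converts the variant position into a bin index by subtracting the cropped window start and dividing by the bin size. The same arithmetic as a standalone sketch, assuming 0-based coordinates and a `crop_length` in base pairs; `position_to_bin` is a hypothetical helper that also reports when the position falls outside the predicted bins.

```python
def position_to_bin(pos0, window_start, crop_length, bin_size, n_bins):
    """Map a 0-based genomic position to an output bin, or None if out of range."""
    effective_start = window_start + crop_length
    offset = pos0 - effective_start
    if offset < 0:
        return None  # position lies in the cropped left edge
    bin_index = offset // bin_size
    return bin_index if bin_index < n_bins else None


# Toy numbers: window starts at 1000, 10 bp cropped, 8 bp bins, 22 bins
toy_bin = position_to_bin(1100, 1000, 10, 8, 22)
cropped = position_to_bin(1005, 1000, 10, 8, 22)
print(toy_bin, cropped)
```

Checking for `None` before calling `plt.axvline` avoids drawing a marker at a negative or out-of-range bin.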

## Part 3: Breakends (BND)

Breakends represent translocations where two distant genomic loci are joined. supremo_lite creates a **chimeric reference** sequence from both breakpoints.
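
In a VCF, the mate locus and the join orientation of a breakend are encoded in the ALT string using the four bracket forms `t[p[`, `t]p]`, `]p]t`, and `[p[t`. The sketch below parses that notation with a regular expression. `parse_bnd_alt` is a hypothetical helper and its output keys are illustrative; it is not how supremo_lite parses breakends.

```python
import re

# Hypothetical helper: parse a VCF breakend ALT such as "G]17:198982]"
_BND_ALT = re.compile(
    r"^(?P<lead>[ACGTNacgtn.]*)"    # local bases when they come first
    r"(?P<bracket>[\[\]])"          # opening bracket
    r"(?P<chrom>.+):(?P<pos>\d+)"   # mate chromosome and position
    r"(?P=bracket)"                 # same bracket closes the mate locus
    r"(?P<trail>[ACGTNacgtn.]*)$"   # local bases when they come last
)


def parse_bnd_alt(alt: str) -> dict:
    """Extract mate locus and orientation from a breakend ALT string."""
    match = _BND_ALT.match(alt)
    # Exactly one of lead/trail must carry the local bases
    if match is None or bool(match["lead"]) == bool(match["trail"]):
        raise ValueError(f"not a breakend ALT: {alt!r}")
    local_first = bool(match["lead"])
    bracket = match["bracket"]
    return {
        "mate_chrom": match["chrom"],
        "mate_pos": int(match["pos"]),
        "local_first": local_first,
        # 't]p]' and '[p[t' join the reverse-complemented mate piece
        "mate_reverse_complemented": (bracket == "]") == local_first,
    }


parsed = parse_bnd_alt("G]17:198982]")
print(parsed)
```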

In [None]:
# Load breakend variants
bnd_vcf = os.path.join(test_data_dir, "bnd", "bnd.vcf")
bnd_variants = sl.read_vcf(bnd_vcf)

print("Breakend variants:")
print(bnd_variants[['chrom', 'pos', 'ref', 'alt', 'info']])

# BNDs are processed together as pairs
print("\nNote: BND variants come in pairs that define a translocation")

### Generating Chimeric References for BNDs

In [None]:
# Generate sequences for BND
ref_seqs_bnd, alt_seqs_bnd, metadata_bnd = sl.get_alt_ref_sequences(
    reference_fn=reference,
    variants_fn=bnd_variants,
    seq_len=200,
    encode=True
)

print(f"Generated {len(metadata_bnd)} BND sequence(s)")

# Examine BND metadata
for i, meta in enumerate(metadata_bnd):
    print(f"\nBND {i+1} metadata:")
    print(f"  Variant type: {meta['variant_type']}")
    print(f"  Chromosome 1: {meta['chrom']}")
    print(f"  Position 1: {meta['variant_pos1']}")
    if 'mate_chrom' in meta:
        print(f"  Chromosome 2: {meta['mate_chrom']}")
        print(f"  Position 2: {meta['mate_pos']}")
        print(f"  Fusion name: {meta.get('fusion_name', 'N/A')}")

### Visualizing BND Chimeric Sequences

In [None]:
# Get raw BND sequences
ref_bnd_raw, alt_bnd_raw, meta_bnd_raw = sl.get_alt_ref_sequences(
    reference_fn=reference,
    variants_fn=bnd_variants,
    seq_len=80,
    encode=False
)

if len(ref_bnd_raw) > 0:
    print("BND chimeric reference example:")
    print("="*80)
    print(f"Reference (chimeric): {ref_bnd_raw[0]}")
    print(f"Alternate (fusion):   {alt_bnd_raw[0]}")
    print("="*80)
    print("\nThe reference is a chimera joining two genomic loci!")
    print("This allows prediction comparison at the breakpoint.")

### 2D Contact Maps for BND Fusions

In [None]:
if len(ref_seqs_bnd) > 0:
    # Run 2D predictions on BND
    model_2d_bnd = TestModel2D(n_targets=1, bin_size=8, crop_length=10, diag_offset=2)
    ref_preds_bnd_2d = model_2d_bnd(ref_seqs_bnd[:1])
    alt_preds_bnd_2d = model_2d_bnd(alt_seqs_bnd[:1])
    
    # Align
    n_bins = (200 - 2*model_2d_bnd.crop_length) // model_2d_bnd.bin_size
    ref_aligned_bnd, alt_aligned_bnd = sl.align_predictions_by_coordinate(
        ref_pred=ref_preds_bnd_2d[0, 0],
        alt_pred=alt_preds_bnd_2d[0, 0],
        metadata=metadata_bnd[0],
        prediction_type="2D",
        bin_size=model_2d_bnd.bin_size,
        crop_length=model_2d_bnd.crop_length,
        diag_offset=model_2d_bnd.diag_offset,
        matrix_size=n_bins
    )
    
    # Visualize
    ref_bnd_np = ref_aligned_bnd.cpu().numpy() if hasattr(ref_aligned_bnd, 'cpu') else ref_aligned_bnd
    alt_bnd_np = alt_aligned_bnd.cpu().numpy() if hasattr(alt_aligned_bnd, 'cpu') else alt_aligned_bnd
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    
    im1 = axes[0].imshow(ref_bnd_np, cmap='Reds', vmin=0, vmax=1, origin='lower')
    axes[0].set_title('Chimeric Reference Contact Map', fontsize=12, fontweight='bold')
    axes[0].set_xlabel('Bin position')
    axes[0].set_ylabel('Bin position')
    plt.colorbar(im1, ax=axes[0], label='Contact')
    
    im2 = axes[1].imshow(alt_bnd_np, cmap='Blues', vmin=0, vmax=1, origin='lower')
    axes[1].set_title('Fusion Alternate Contact Map', fontsize=12, fontweight='bold')
    axes[1].set_xlabel('Bin position')
    axes[1].set_ylabel('Bin position')
    plt.colorbar(im2, ax=axes[1], label='Contact')
    
    fusion_name = metadata_bnd[0].get('fusion_name', 'Unknown')
    fig.suptitle(f'BND/Translocation Contact Maps: {fusion_name}', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()
    
    print("\nBND contact map interpretation:")
    print("• Left: Chimeric reference joining two loci")
    print("• Right: Fusion alternate sequence")
    print("• Contact patterns show interactions across the breakpoint")
else:
    print("No paired BND variants found for visualization")

## Comparing All SV Types

Let's create a summary comparison of how each SV type affects predictions:

In [None]:
# Summary table
import pandas as pd

sv_summary = pd.DataFrame([
    {
        'SV Type': 'INV',
        'Sequence Change': 'Reverse complement',
        'Length Change': 0,
        '1D Masking': 'Affected bins',
        '2D Masking': 'Cross-pattern (rows + cols)',
        'Key Challenge': 'Coordinate reversal'
    },
    {
        'SV Type': 'DUP',
        'Sequence Change': 'Tandem repeat',
        'Length Change': '+SVLEN',
        '1D Masking': 'Inserted bins',
        '2D Masking': 'Rows + cols at insertion',
        'Key Challenge': 'Length increase'
    },
    {
        'SV Type': 'BND',
        'Sequence Change': 'Chimeric fusion',
        'Length Change': 'Varies',
        '1D Masking': 'Context-dependent',
        '2D Masking': 'Context-dependent',
        'Key Challenge': 'Distant loci joining'
    }
])

print("\nStructural Variant Summary:")
print(sv_summary.to_string(index=False))

print("\n" + "="*80)
print("KEY INSIGHTS:")
print("="*80)
print("\n1. INV: Cross-pattern masking is crucial for 2D predictions")
print("   - Both rows AND columns must be masked at inversion site")
print("   - Implemented correctly in Phase 1.1 of supremo_lite development")

print("\n2. DUP: Length changes require NaN padding for alignment")
print("   - Reference gets NaN bins where duplication adds bases")
print("   - Maintains genomic coordinate correspondence")

print("\n3. BND: Chimeric reference enables breakpoint analysis")
print("   - Reference assembled from two distant loci")
print("   - Allows prediction comparison at fusion junction")

print("\n4. All SVs: align_predictions_by_coordinate() handles them correctly!")
print("   - Automatic detection of SV type from metadata")
print("   - Appropriate masking strategy applied")
print("   - Returns comparable predictions for analysis")
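
Insight 2 says the reference receives NaN bins where a duplication adds bases. A minimal 1D sketch of that idea with toy numbers follows; `pad_reference_with_nan` is a hypothetical helper, and the real alignment in supremo_lite may differ in detail (for example in how the padded track is cropped afterwards).

```python
import numpy as np


def pad_reference_with_nan(ref_bins, insert_bin, n_inserted_bins):
    """Insert NaN bins into the reference track at the duplication site."""
    ref = np.asarray(ref_bins, dtype=float)
    return np.insert(ref, insert_bin, [np.nan] * n_inserted_bins)


toy_ref = [0.1, 0.2, 0.3, 0.4]
toy_alt = np.array([0.1, 0.2, 0.9, 0.8, 0.3, 0.5])  # two extra bins at index 2

padded_ref = pad_reference_with_nan(toy_ref, insert_bin=2, n_inserted_bins=2)
diff = toy_alt - padded_ref                  # NaN at the padded bins
mean_effect = np.nanmean(np.abs(diff))       # ignores the padded bins
print(padded_ref, mean_effect)
```

After padding, both tracks have the same length and bins to the right of the duplication line up again, so a bin-wise difference is meaningful everywhere except at the NaN positions.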

## Best Practices for SV Analysis

### 1. **Always check metadata**
```python
# SV metadata contains critical information
print(metadata[0]['variant_type'])  # SV_INV, SV_DUP, SV_BND, etc.
if 'sym_variant_end' in metadata[0]:
    print(f"SV spans: {metadata[0]['variant_pos1']} - {metadata[0]['sym_variant_end']}")
```

### 2. **Use appropriate window sizes**
```python
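# For large SVs, use larger windows
# (variant_type taken from the metadata of an earlier run)
variant_type = metadata[0]['variant_type']
if variant_type.startswith('SV_'):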
    seq_len = 500  # Larger window for SVs
else:
    seq_len = 200  # Standard window
```

### 3. **Verify alignment visually**
```python
# Always plot results for SVs
# Check for expected masking patterns
# INV: Cross-pattern in 2D
# DUP/DEL: NaN regions in affected sequence
```

### 4. **Handle NaN values properly**
```python
# Use NaN-aware statistics so masked bins do not enter the mean
import numpy as np
diff = alt_np - ref_np  # NaN wherever either prediction is masked
mean_effect = np.nanmean(np.abs(diff))
```
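
For 2D maps it is also useful to know how many entries survived masking, and to avoid the warning `np.nanmean` emits when everything is NaN. A small sketch on toy matrices; `masked_mean_abs_diff` is a hypothetical helper.

```python
import numpy as np


def masked_mean_abs_diff(ref, alt):
    """Mean |alt - ref| over entries valid in both arrays, plus the valid count."""
    ref = np.asarray(ref, dtype=float)
    alt = np.asarray(alt, dtype=float)
    valid = ~(np.isnan(ref) | np.isnan(alt))
    if not valid.any():
        return float("nan"), 0  # nothing to compare
    return float(np.abs(alt[valid] - ref[valid]).mean()), int(valid.sum())


toy_ref_map = np.array([[1.0, 2.0], [3.0, 4.0]])
toy_alt_map = np.array([[1.0, np.nan], [np.nan, 6.0]])

effect, n_valid = masked_mean_abs_diff(toy_ref_map, toy_alt_map)
print(effect, n_valid)
```

Reporting the valid count alongside the mean makes it obvious when a statistic rests on only a handful of unmasked bins.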

## Summary

In this notebook, you learned advanced structural variant handling:

1. ✅ **Inversions (INV)** - Reverse complement and cross-pattern masking
2. ✅ **Duplications (DUP)** - Tandem repeats and SVLEN handling
3. ✅ **Breakends (BND)** - Chimeric references for translocations
4. ✅ **2D Alignment** - Complex masking strategies for contact maps
5. ✅ **Coordinate Transforms** - Maintaining genomic correspondence
6. ✅ **Visualization** - Interpreting SV effects on predictions
7. ✅ **Best Practices** - Guidelines for robust SV analysis

supremo_lite handles all these complex cases automatically through `align_predictions_by_coordinate()`!

## Next Steps

- **[05_saturation_mutagenesis.ipynb](05_saturation_mutagenesis.ipynb)** - Systematic mutagenesis and prediction analysis
- **[User Guide: Prediction Alignment](../user_guide/prediction_alignment.md)** - Detailed documentation on alignment strategies