# Prediction Alignment: From Sequences to Aligned Predictions

This notebook walks through the complete workflow, from generating variant sequences to running model predictions and aligning them for comparison. Alignment is the core step for analyzing how genetic variants affect model predictions.

## What You'll Learn

- Generate reference and alternate sequences around variants
- Run mock genomic models (1D and 2D predictions)
- Align predictions to account for coordinate changes from variants
- Visualize prediction differences between reference and alternate alleles

## The Complete Workflow

```
Reference Genome + VCF
        ↓
get_alt_ref_sequences()
        ↓
Reference & Alternate Sequences
        ↓
TestModel / TestModel2D
        ↓
Reference & Alternate Predictions
        ↓
align_predictions_by_coordinate()
        ↓
Aligned Predictions (ready for comparison!)
```

## Setup

In [10]:
import supremo_lite as sl
from supremo_lite.mock_models import TestModel, TestModel2D, TORCH_AVAILABLE
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pyfaidx import Fasta
import os

# Set plotting style
sns.set_style("whitegrid")

# Check PyTorch availability
if not TORCH_AVAILABLE:
    raise ImportError("PyTorch is required for this notebook. Install with: pip install torch")

print(f"supremo_lite version: {sl.__version__}")
print(f"PyTorch available: {TORCH_AVAILABLE}")

# Load test data
test_data_dir = "../../tests/data"
reference = Fasta(os.path.join(test_data_dir, "test_genome.fa"))
variants = sl.read_vcf(os.path.join(test_data_dir, "multi", "multi.vcf"))

print(f"\nLoaded {len(variants)} variants from test data")

supremo_lite version: 0.5.4
PyTorch available: True

Loaded 8 variants from test data


## Step 1: Generate Reference and Alternate Sequences

First, we create sequence windows around each variant:

In [11]:
# Generate ref/alt sequences for every variant in the VCF,
# using a 200bp window (sufficient for our mock models)
seq_len = 200

# Note: get_alt_ref_sequences is a generator that yields chunks
results = list(sl.get_alt_ref_sequences(
    reference_fn=reference,
    variants_fn=variants, 
    seq_len=seq_len,
    encode=True  # Get encoded tensors for models
))

# Unpack from the first chunk
alt_seqs, ref_seqs, metadata = results[0]

print(f"Generated sequences:")
print(f"  Reference sequences shape: {ref_seqs.shape}")
print(f"  Alternate sequences shape: {alt_seqs.shape}")
print(f"  Number of variants: {len(metadata)}")

print("\nFirst variant metadata:")
print(metadata.iloc[0].to_dict())

Generated sequences:
  Reference sequences shape: torch.Size([8, 200, 4])
  Alternate sequences shape: torch.Size([8, 200, 4])
  Number of variants: 8

First variant metadata:
{'chrom': 'chr1', 'window_start': 0, 'window_end': 200, 'variant_pos0': 9, 'variant_pos1': 10, 'ref': 'A', 'alt': 'T', 'variant_type': 'SNV'}
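To make the ref/alt window idea concrete, here is a minimal pure-Python sketch for the SNV case only. The helpers `snv_windows` and `one_hot` are hypothetical and are not part of supremo_lite. The centring rule and the ACGT channel order are also assumptions, so do not read them as the library's exact behaviour. The sketch clips the window at the chromosome start, which is one way a variant at `variant_pos0=9` can end up with `window_start=0`, as in the metadata above.

```python
import numpy as np

# Assumed channel order; supremo_lite's encode=True may use a different one.
BASES = "ACGT"

def one_hot(seq: str) -> np.ndarray:
    """Encode a DNA string as a (len(seq), 4) float array; N and other symbols -> all zeros."""
    out = np.zeros((len(seq), 4), dtype=np.float32)
    for i, base in enumerate(seq.upper()):
        idx = BASES.find(base)
        if idx >= 0:
            out[i, idx] = 1.0
    return out

def snv_windows(chrom_seq: str, pos0: int, ref: str, alt: str, seq_len: int):
    """Hypothetical helper: (ref_window, alt_window, window_start) around an SNV.

    The window is centred on the variant, then clipped so it stays inside the chromosome.
    """
    if chrom_seq[pos0] != ref:
        raise ValueError("reference allele does not match the genome")
    start = max(0, pos0 - seq_len // 2)
    start = min(start, max(0, len(chrom_seq) - seq_len))
    ref_win = chrom_seq[start:start + seq_len]
    offset = pos0 - start  # variant position inside the window
    alt_win = ref_win[:offset] + alt + ref_win[offset + 1:]
    return ref_win, alt_win, start

genome = "ACGTACGTAA" * 30  # 300 bp toy chromosome
ref_win, alt_win, start = snv_windows(genome, pos0=9, ref="A", alt="T", seq_len=200)
print(start, len(ref_win), one_hot(alt_win).shape)
```

Stacking `one_hot` outputs for a batch of variants gives the `(batch, seq_len, 4)` layout shown above. Indels need extra bookkeeping, because the alternate window must be padded or trimmed back to `seq_len`.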


## Understanding Mock Models

Before running predictions, let's understand the mock model architectures:

### TestModel (1D Predictions)
- **Input**: Sequences of shape `(batch, seq_len, 4)`
- **Output**: Predictions of shape `(batch, n_targets, n_bins)`
- **Features**:
  - `bin_size`: Predictions at lower resolution (e.g., 1 prediction per 8bp)
  - `crop_length`: Edge bases removed before prediction
  - Multiple targets (e.g., different histone marks)
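Binning and cropping together decide how many output bins there are and which bin a base falls into. The sketch below assumes cropping is applied in base pairs at both ends of the window before binning, which is how the bin count is computed in Step 2. The function names are illustrative only.

```python
def n_output_bins(seq_len: int, crop_length: int, bin_size: int) -> int:
    # Bases left after trimming crop_length from each end, grouped into whole bins
    return (seq_len - 2 * crop_length) // bin_size

def bp_to_bin(pos0: int, window_start: int, crop_length: int, bin_size: int) -> int:
    # Offset of a 0-based genomic position from the first uncropped base, in whole bins.
    # A negative result means the base lies in the cropped left edge.
    return (pos0 - window_start - crop_length) // bin_size

print(n_output_bins(200, 10, 8))  # bins for a 200 bp window, crop 10, bin 8
print(bp_to_bin(9, 0, 10, 8))     # first variant: variant_pos0=9, window_start=0
```

Under these assumptions, 200 bp with a 10 bp crop and 8 bp bins gives 22 bins. The first variant (`variant_pos0=9`) lands in bin -1, which means it sits inside the cropped edge and has no output bin of its own.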

### TestModel2D (2D Contact Map Predictions)
- **Input**: Sequences of shape `(batch, seq_len, 4)`
- **Output**: Predictions of shape `(batch, n_targets, n_flattened_ut_bins)`
- **Features**:
  - Contact predictions between genomic positions
  - `diag_offset=2`: Bins on or near the main diagonal are excluded from the output
  - Flattened upper triangle format
  - `bin_size` and `crop_length` like 1D model
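The flattened upper-triangle output can be converted to and from a square contact map with NumPy. This sketch assumes an Akita-style layout, where values are ordered as in `np.triu_indices(n_bins, k=diag_offset)`. Masked diagonals become NaN, and the map is mirrored to make it symmetric. `ut_to_matrix` and `matrix_to_ut` are hypothetical helpers, not supremo_lite functions.

```python
import numpy as np

def ut_to_matrix(flat: np.ndarray, n_bins: int, diag_offset: int) -> np.ndarray:
    """Rebuild a symmetric (n_bins, n_bins) map from a 1D flattened upper triangle.

    Assumes entries are ordered like np.triu_indices(n_bins, k=diag_offset).
    """
    rows, cols = np.triu_indices(n_bins, k=diag_offset)
    if flat.shape[-1] != rows.size:
        raise ValueError(f"expected {rows.size} values, got {flat.shape[-1]}")
    mat = np.full((n_bins, n_bins), np.nan)  # masked diagonals stay NaN
    mat[rows, cols] = flat
    mat[cols, rows] = flat  # mirror into the lower triangle
    return mat

def matrix_to_ut(mat: np.ndarray, diag_offset: int) -> np.ndarray:
    """Inverse operation: read the kept upper-triangle entries back out."""
    rows, cols = np.triu_indices(mat.shape[0], k=diag_offset)
    return mat[rows, cols]

m = ut_to_matrix(np.arange(6.0), n_bins=5, diag_offset=2)
print(m)
```

With `n_bins=5` and `diag_offset=2`, only 6 values are stored. The main diagonal and the first off-diagonal come back as NaN.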

## Step 2: Run 1D Predictions with TestModel

In [5]:
# Initialize 1D model
model_1d = TestModel(
    seq_length=seq_len,  # Must match the input window length
    n_targets=2,         # Predict 2 different signals
    bin_length=8,        # 1 prediction per 8bp
    crop_length=10       # Remove 10bp from each edge
)

print("TestModel (1D) configuration:")
print(f"  Targets: {model_1d.n_targets}")
print(f"  Bin size: {model_1d.bin_size}")
print(f"  Crop length: {model_1d.crop_length}")
print(f"  Input sequence length: {seq_len}")
print(f"  Effective sequence length: {seq_len - 2*model_1d.crop_length}")
print(f"  Number of bins: {(seq_len - 2*model_1d.crop_length) // model_1d.bin_size}")

# Run predictions
ref_preds_1d = model_1d(ref_seqs)
alt_preds_1d = model_1d(alt_seqs)

print(f"\nPrediction shapes:")
print(f"  Reference: {ref_preds_1d.shape}")
print(f"  Alternate: {alt_preds_1d.shape}")
print(f"  Format: (batch={ref_preds_1d.shape[0]}, targets={ref_preds_1d.shape[1]}, bins={ref_preds_1d.shape[2]})")


## Step 3: Align 1D Predictions

Now we align the predictions to account for coordinate changes caused by the variant:

In [None]:
# Align predictions for the first variant
var_idx = 0

ref_aligned_1d, alt_aligned_1d = sl.align_predictions_by_coordinate(
    ref_pred=ref_preds_1d[var_idx],
    alt_pred=alt_preds_1d[var_idx],
    metadata=metadata.iloc[var_idx],  # Row for this variant (metadata is a DataFrame)
    prediction_type="1D",
    bin_size=model_1d.bin_size,
    crop_length=model_1d.crop_length
)

print(f"Aligned prediction shapes:")
print(f"  Reference aligned: {ref_aligned_1d.shape}")
print(f"  Alternate aligned: {alt_aligned_1d.shape}")

row = metadata.iloc[var_idx]  # metadata is a DataFrame: select the row by position
print(f"\nVariant info:")
print(f"  Type: {row['variant_type']}")
print(f"  {row['ref']} → {row['alt']}")
print(f"  Position: {row['chrom']}:{row['variant_pos1']}")
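For intuition about what alignment does to an indel, here is a deliberately simplified 1D sketch. It is not supremo_lite's implementation. It assumes that both windows start at the same genomic coordinate, that the indel begins at bin `bin_idx`, and that the indel spans exactly `n_shift` whole bins. The allele lacking the affected sequence gets a NaN gap, and the bins pushed past the window edge are dropped, so both vectors keep the same length. The function name `align_1d_indel` is hypothetical.

```python
import numpy as np

def align_1d_indel(ref, alt, bin_idx, n_shift, variant_type):
    """Hypothetical left-anchored alignment of equal-length ref/alt bin vectors."""
    ref = np.asarray(ref, dtype=float)
    alt = np.asarray(alt, dtype=float)
    gap = np.full(n_shift, np.nan)
    if variant_type == "INS":
        # Inserted bins exist only in alt: open a gap in ref and
        # drop the ref bins that are pushed out of the window
        ref = np.concatenate([ref[:bin_idx], gap, ref[bin_idx:len(ref) - n_shift]])
    elif variant_type == "DEL":
        # Deleted bins exist only in ref: open the gap in alt instead
        alt = np.concatenate([alt[:bin_idx], gap, alt[bin_idx:len(alt) - n_shift]])
    elif variant_type != "SNV":
        raise ValueError(f"unsupported variant type: {variant_type}")
    return ref, alt

ref_bins = np.arange(8.0)
alt_bins = np.arange(10.0, 18.0)
print(align_1d_indel(ref_bins, alt_bins, bin_idx=3, n_shift=2, variant_type="INS")[0])
```

After this step, position `k` in both vectors refers to the same genomic bin. That is what makes an element-wise `alt - ref` meaningful.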

## Visualizing 1D Aligned Predictions

Let's visualize how the predictions compare:

In [None]:
print("\nInterpretation:")
print("• Blue circles: Reference allele predictions")
print("• Orange squares: Alternate allele predictions")
print("• Red dashed line: Variant position")
print("• NaN values (gaps) show regions affected by indels")
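The cell above prints the interpretation but does not include any plotting code. Below is a minimal matplotlib sketch that produces that kind of figure for one target: circles for the reference, squares for the alternate, and a dashed red line at the variant bin. Matplotlib leaves gaps in a line wherever a value is NaN. `plot_1d_alignment` is a hypothetical helper that takes plain NumPy arrays, so convert tensors first with `.detach().cpu().numpy()`.

```python
import matplotlib.pyplot as plt
import numpy as np

def plot_1d_alignment(ref, alt, variant_bin=None, ax=None, title=""):
    """Plot aligned ref/alt tracks for one target; NaN bins appear as gaps."""
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 3))
    bins = np.arange(len(ref))
    ax.plot(bins, ref, marker="o", label="Reference")
    ax.plot(bins, alt, marker="s", label="Alternate")
    # Only mark the variant if it falls inside the (cropped) output window
    if variant_bin is not None and 0 <= variant_bin < len(ref):
        ax.axvline(variant_bin, color="red", linestyle="--", label="Variant")
    ax.set_xlabel("Bin position")
    ax.set_ylabel("Prediction")
    ax.set_title(title)
    ax.legend()
    return ax

ref_track = np.array([0.1, 0.2, np.nan, 0.4])
alt_track = np.array([0.1, 0.3, 0.5, 0.4])
plot_1d_alignment(ref_track, alt_track, variant_bin=2, title="Target 0")
plt.show()
```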

## Step 4: Run 2D Predictions with TestModel2D

Now let's predict contact maps (2D predictions):

In [None]:
# Initialize 2D model
model_2d = TestModel2D(
    n_targets=1,       # Single contact map
    bin_size=8,        # Match 1D model
    crop_length=10,    # Match 1D model
    diag_offset=2      # Mask 2 diagonal bins
)

print("TestModel2D (Contact Map) configuration:")
print(f"  Targets: {model_2d.n_targets}")
print(f"  Bin size: {model_2d.bin_size}")
print(f"  Crop length: {model_2d.crop_length}")
print(f"  Diagonal offset: {model_2d.diag_offset}")
print(f"  Input sequence length: {seq_len}")
print(f"  Number of bins: {(seq_len - 2*model_2d.crop_length) // model_2d.bin_size}")

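n_bins = (seq_len - 2*model_2d.crop_length) // model_2d.bin_size
# Flattened upper-triangle length, assuming an Akita-style layout that keeps
# entries with j - i >= diag_offset (np.triu_indices(n_bins, k=diag_offset));
# this should match ref_preds_2d.shape[2] below
n_ut_bins = np.triu_indices(n_bins, k=model_2d.diag_offset)[0].size
print(f"  Upper triangle bins (flattened): {n_ut_bins}")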

# Run predictions
ref_preds_2d = model_2d(ref_seqs)
alt_preds_2d = model_2d(alt_seqs)

print(f"\n2D Prediction shapes (flattened):")
print(f"  Reference: {ref_preds_2d.shape}")
print(f"  Alternate: {alt_preds_2d.shape}")
print(f"  Format: (batch={ref_preds_2d.shape[0]}, targets={ref_preds_2d.shape[1]}, flattened_bins={ref_preds_2d.shape[2]})")
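The flattened length also has a closed form that you can check by hand. This assumes the same Akita-style convention, which keeps entries with `j - i >= diag_offset`. The kept diagonals `d = diag_offset .. n_bins - 1` hold `n_bins - d` entries each, so the total is `m(m + 1)/2` with `m = n_bins - diag_offset`. The sketch cross-checks this formula against `np.triu_indices`. `n_ut_entries` is an illustrative name.

```python
import numpy as np

def n_ut_entries(n_bins: int, diag_offset: int) -> int:
    # Sum over kept diagonals of (n_bins - d) = 1 + 2 + ... + m, with m = n_bins - diag_offset
    m = max(n_bins - diag_offset, 0)
    return m * (m + 1) // 2

# Cross-check against NumPy's own upper-triangle indexing
for n, k in [(22, 2), (10, 0), (10, 1), (5, 2)]:
    assert n_ut_entries(n, k) == np.triu_indices(n, k=k)[0].size

print(n_ut_entries(22, 2))  # 22 bins, diag_offset=2
```

For the 22-bin setup in this notebook, the formula gives 20 × 21 / 2 = 210 flattened values. If the model's last output dimension differs, it uses another masking convention.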

## Step 5: Align 2D Predictions

Aligning 2D contact maps is more complex because variants affect both dimensions:

In [None]:
# Align 2D predictions for first variant
ref_aligned_2d, alt_aligned_2d = sl.align_predictions_by_coordinate(
    ref_pred=ref_preds_2d[var_idx, 0],  # First target
    alt_pred=alt_preds_2d[var_idx, 0],
    metadata=metadata.iloc[var_idx],  # Row for this variant (metadata is a DataFrame)
    prediction_type="2D",
    bin_size=model_2d.bin_size,
    crop_length=model_2d.crop_length,
    diag_offset=model_2d.diag_offset,
    matrix_size=n_bins  # Required for 2D
)

print(f"Aligned 2D prediction shapes:")
print(f"  Reference aligned: {ref_aligned_2d.shape}")
print(f"  Alternate aligned: {alt_aligned_2d.shape}")
print(f"  Format: (n_bins, n_bins) - square matrix")
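In 2D, the same coordinate shift applies along both axes, so an indel opens a gap as whole NaN rows and columns. This is a hypothetical simplification with the same assumptions as a left-anchored 1D alignment: both maps start at the same genomic bin, and the indel covers exactly `n_shift` whole bins starting at `bin_idx`. It is not supremo_lite's algorithm. `np.ix_` selects and places the surviving block of the matrix.

```python
import numpy as np

def _open_gap_2d(mat, bin_idx, n_shift):
    """Move bins >= bin_idx by n_shift on both axes; the gap becomes NaN rows/cols."""
    n = mat.shape[0]
    kept = np.arange(n - n_shift)  # bins that still fit in the window
    new_pos = np.concatenate([np.arange(bin_idx), np.arange(bin_idx + n_shift, n)])
    out = np.full((n, n), np.nan)
    out[np.ix_(new_pos, new_pos)] = mat[np.ix_(kept, kept)]
    return out

def align_2d_indel(ref_mat, alt_mat, bin_idx, n_shift, variant_type):
    """Hypothetical 2D alignment: gap the allele that lacks the affected sequence."""
    ref_mat = np.asarray(ref_mat, dtype=float)
    alt_mat = np.asarray(alt_mat, dtype=float)
    if variant_type == "INS":
        ref_mat = _open_gap_2d(ref_mat, bin_idx, n_shift)
    elif variant_type == "DEL":
        alt_mat = _open_gap_2d(alt_mat, bin_idx, n_shift)
    elif variant_type != "SNV":
        raise ValueError(f"unsupported variant type: {variant_type}")
    return ref_mat, alt_mat

ref_map = np.arange(36.0).reshape(6, 6)
ref_al, alt_al = align_2d_indel(ref_map, ref_map + 100, bin_idx=2, n_shift=1, variant_type="INS")
print(ref_al)
```

The NaN row and column at bin 2 form the cross pattern described in the summary below.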

## Visualizing 2D Aligned Predictions (Contact Maps)

Contact maps show pairwise interactions between genomic positions:

In [None]:
# Convert to NumPy (detach first in case the tensors still track gradients)
ref_2d_np = ref_aligned_2d.detach().cpu().numpy() if hasattr(ref_aligned_2d, 'cpu') else np.asarray(ref_aligned_2d)
alt_2d_np = alt_aligned_2d.detach().cpu().numpy() if hasattr(alt_aligned_2d, 'cpu') else np.asarray(alt_aligned_2d)

# Create figure with three panels
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Reference contact map
im1 = axes[0].imshow(ref_2d_np, cmap='Reds', vmin=0, vmax=1, origin='lower', interpolation='nearest')
axes[0].set_title('Reference Contact Map', fontsize=12, fontweight='bold')
axes[0].set_xlabel('Bin position')
axes[0].set_ylabel('Bin position')
plt.colorbar(im1, ax=axes[0], label='Contact strength')

# Alternate contact map
im2 = axes[1].imshow(alt_2d_np, cmap='Blues', vmin=0, vmax=1, origin='lower', interpolation='nearest')
axes[1].set_title('Alternate Contact Map', fontsize=12, fontweight='bold')
axes[1].set_xlabel('Bin position')
axes[1].set_ylabel('Bin position')
plt.colorbar(im2, ax=axes[1], label='Contact strength')

# Difference map (alt - ref)
# Only compare where both have valid values
diff = np.where(
    np.isnan(ref_2d_np) | np.isnan(alt_2d_np),
    np.nan,
    alt_2d_np - ref_2d_np
)

im3 = axes[2].imshow(diff, cmap='RdBu_r', vmin=-0.5, vmax=0.5, origin='lower', interpolation='nearest')
axes[2].set_title('Difference (Alt - Ref)', fontsize=12, fontweight='bold')
axes[2].set_xlabel('Bin position')
axes[2].set_ylabel('Bin position')
plt.colorbar(im3, ax=axes[2], label='Δ Contact strength')

# Mark the variant bin on all plots: convert its 0-based genomic position into a
# bin index within the cropped window (a negative value means it lies in the cropped edge)
var_row = metadata.iloc[var_idx]
effective_start = var_row['window_start'] + model_2d.crop_length
variant_bin = (var_row['variant_pos0'] - effective_start) // model_2d.bin_size
for ax in axes:
    ax.axhline(y=variant_bin, color='lime', linestyle='--', linewidth=1, alpha=0.7)
    ax.axvline(x=variant_bin, color='lime', linestyle='--', linewidth=1, alpha=0.7)

fig.suptitle(f'2D Contact Map Alignment: {var_row["variant_type"]} Variant\n'
             f'{var_row["chrom"]}:{var_row["variant_pos1"]} '
             f'{var_row["ref"]} → {var_row["alt"]}',
             fontsize=14, fontweight='bold', y=1.05)
plt.tight_layout()
plt.show()

print("\nInterpretation:")
print("• Left: Reference allele contact map")
print("• Middle: Alternate allele contact map")
print("• Right: Difference map (blue=decreased, red=increased contacts)")
print("• Green lines: Variant bin position")
print("• White/missing values: Regions affected by indels or masked diagonal")
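In NumPy, plain subtraction already propagates NaN, so `alt - ref` produces the same masked difference as the `np.where(...)` construction above. The sketch below checks this equivalence and adds a NaN-aware summary: the number of bins comparable in both alleles and the largest absolute change. `contact_difference` and `summarize_difference` are illustrative names, not library functions.

```python
import numpy as np

def contact_difference(ref_map, alt_map):
    """Alt - Ref; any bin that is NaN in either allele stays NaN."""
    return np.asarray(alt_map, dtype=float) - np.asarray(ref_map, dtype=float)

def summarize_difference(diff):
    """Count comparable bins and report the largest |change| among them."""
    n_valid = int(np.count_nonzero(~np.isnan(diff)))
    max_abs = float(np.nanmax(np.abs(diff))) if n_valid else float("nan")
    return n_valid, max_abs

ref_map = np.array([[np.nan, 0.2, 0.5], [0.2, np.nan, 0.1], [0.5, 0.1, np.nan]])
alt_map = np.array([[np.nan, 0.6, np.nan], [0.6, np.nan, 0.0], [np.nan, 0.0, np.nan]])
diff = contact_difference(ref_map, alt_map)
manual = np.where(np.isnan(ref_map) | np.isnan(alt_map), np.nan, alt_map - ref_map)
print(np.array_equal(diff, manual, equal_nan=True), summarize_difference(diff))
```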

## Understanding Diagonal Masking

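2D models often mask bins near the diagonal. Self-interactions on the main diagonal and contacts at very short distances are dominated by the strong distance-dependent signal, so they carry little sequence-specific information: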

In [None]:
# Visualize diagonal masking effect
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Create a simple example matrix
n = 15
example = np.ones((n, n))

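# Mask bins closer than diag_offset to the diagonal (|i - j| < diag_offset),
# assuming the Akita-style convention np.triu_indices(n, k=diag_offset)
diag_offset = 2
rows, cols = np.indices((n, n))
example[np.abs(rows - cols) < diag_offset] = np.nan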

# No masking
axes[0].imshow(np.ones((n, n)), cmap='Greys', vmin=0, vmax=1, origin='lower')
axes[0].set_title('Without Diagonal Masking', fontsize=12, fontweight='bold')
axes[0].set_xlabel('Position')
axes[0].set_ylabel('Position')

# With masking
axes[1].imshow(example, cmap='Greys', vmin=0, vmax=1, origin='lower')
axes[1].set_title(f'With Diagonal Masking (offset={diag_offset})', fontsize=12, fontweight='bold')
axes[1].set_xlabel('Position')
axes[1].set_ylabel('Position')

fig.suptitle('Diagonal Masking in Contact Map Predictions', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print(f"\nDiagonal offset = {diag_offset} means:")
print(f"  Bins with |i - j| < {diag_offset} (on or next to the diagonal) are masked (NaN)")
print(f"  This removes self-interactions and very short-range contacts")
print(f"  TestModel2D uses diag_offset={model_2d.diag_offset}")

## Comparing Multiple Variants

Let's compare predictions across different variant types:

In [None]:
print("\nNotice:")
print("• Each variant type (SNV, INS, DEL) affects predictions differently")
print("• Indels (INS/DEL) introduce NaN gaps so that ref and alt bins stay at matching genomic coordinates")
print("• SNVs do not shift coordinates, so their aligned predictions need no alignment gaps")
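The plotting code for this comparison is not included in the cell above. A simple way to compare many variants numerically is to reduce each aligned pair to a few NaN-aware numbers. The sketch below is a hypothetical helper, `effect_table`, that takes `(variant_id, variant_type, ref_aligned, alt_aligned)` tuples and reports how many bins are gaps and the mean absolute change over the remaining bins. The demo values are made up for illustration and are not model output.

```python
import numpy as np

def effect_table(records):
    """Summarize aligned ref/alt pairs per variant (hypothetical helper).

    records: iterable of (variant_id, variant_type, ref_aligned, alt_aligned),
    where both aligned arrays share a shape and use NaN for gaps.
    """
    table = []
    for variant_id, variant_type, ref, alt in records:
        ref = np.asarray(ref, dtype=float)
        alt = np.asarray(alt, dtype=float)
        valid = ~(np.isnan(ref) | np.isnan(alt))  # bins comparable in both alleles
        diffs = np.abs(alt[valid] - ref[valid])
        table.append({
            "id": variant_id,
            "type": variant_type,
            "gap_bins": int((~valid).sum()),
            "mean_abs_diff": float(diffs.mean()) if diffs.size else float("nan"),
        })
    return table

demo = [
    ("snv_1", "SNV", [0.1, 0.2, 0.3, 0.4], [0.1, 0.5, 0.3, 0.4]),
    ("del_1", "DEL", [0.1, 0.2, 0.3, 0.4], [0.1, np.nan, np.nan, 0.2]),
]
for summary in effect_table(demo):
    print(summary)
```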

## Key Concepts Summary

### Model Architecture Components

1. **Binning (`bin_size`)**
   - Reduces resolution: 1 prediction per N base pairs
   - Example: `bin_size=8` means 8bp → 1 prediction
   - More efficient than per-base predictions

2. **Edge Cropping (`crop_length`)**
   - Removes bases from sequence edges before prediction
   - Accounts for edge effects in convolutional models
   - Example: `crop_length=10` removes 10bp from each end

3. **Diagonal Masking (`diag_offset`, 2D only)**
   - Masks bins near the diagonal in contact maps
   - Removes self-interactions and very short-range contacts
   - Example: `diag_offset=2` masks the main diagonal and the first off-diagonal (`|i - j| < 2`)

### Prediction Alignment

- **Why align?** Indels shift genomic coordinates between ref and alt
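- **How?** Insert NaN bins into the allele that lacks the affected sequence (the reference for insertions, the alternate for deletions) so each remaining bin refers to the same genomic coordinates in ref and alt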
- **1D alignment**: Masks affected bins in the prediction vector
- **2D alignment**: Masks affected rows AND columns (cross-pattern for inversions)

### Output Interpretation

- **Solid lines/values**: Valid predictions at those genomic positions
- **Gaps/NaN**: Regions affected by indels or diagonal masking
- **Differences**: Where variant changes model predictions

## Summary

In this notebook, you learned the complete prediction alignment workflow:

1. Generate sequences - Create ref/alt windows around variants
2. 1D predictions - Run TestModel for genomic signal predictions
3. 2D predictions - Run TestModel2D for contact map predictions
4. Align predictions - Account for coordinate changes from variants
5. Visualize results - Compare ref vs alt predictions
6. Understand models - Binning, cropping, diagonal masking concepts

This workflow is the foundation for analyzing how genetic variants affect genomic predictions.

## Next Steps

- **[04_structural_variants.ipynb](04_structural_variants.ipynb)** - Handle complex structural variants (INV, DUP, BND) with prediction alignment
- **[05_saturation_mutagenesis.ipynb](05_saturation_mutagenesis.ipynb)** - Systematic mutagenesis and prediction analysis