## Setup and Imports

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import py3Dmol
import sys
from pathlib import Path

# The project root is taken to be the parent of the current working directory.
# It is appended to sys.path so that the `src.contact_steering` import below
# resolves when the notebook is run from its own sub-directory.
ROOT = Path.cwd().parent
sys.path.append(str(ROOT))

from src.contact_steering import (
    get_baseline_structure as z_get_baseline,
    run_with_intervention,
    create_targeted_steering_intervention,
    define_hairpin_topology,
    compute_comprehensive_metrics,
    compute_backbone_no_distances,
    load_directions as z_load_directions,
    HBOND_DISTANCE,
)

import warnings
from tqdm import TqdmWarning

# Keep cell output readable by silencing three known-noisy warning sources:
# tqdm's own warnings, any message mentioning "mmCIF", and UserWarnings
# raised from the Bio.PDB.DSSP module.
warnings.filterwarnings("ignore", category=TqdmWarning)
warnings.filterwarnings("ignore", message=".*mmCIF.*")
warnings.filterwarnings("ignore", category=UserWarning, module="Bio.PDB.DSSP")

In [None]:
from demo_utils import contact_show_protein as show_protein, contact_show_side_by_side as show_side_by_side


def get_hbond_pairs(outputs, topology, threshold=HBOND_DISTANCE):
    """Return the cross-strand pairs whose backbone N-O distance is below ``threshold``.

    Args:
        outputs: Model output exposing a ``positions`` tensor. Only
            ``positions[-1, 0]`` is read (presumably the last refinement step
            of the first batch element -- confirm against the model docs).
        topology: Object exposing ``cross_strand_pairs``, an iterable of
            ``(i, j)`` residue index pairs.
        threshold: Distance cutoff; a pair is kept only when its smaller
            directional N-O distance is strictly below this value.

    Returns:
        A list of ``(res_i, res_j, min_no_dist)`` tuples, in the same order as
        ``topology.cross_strand_pairs``.
    """
    final_coords = outputs.positions[-1, 0].cpu()
    no_matrix = compute_backbone_no_distances(final_coords)

    def _closest(i, j):
        # The distance matrix is read in both directions (i->j and j->i)
        # and the smaller of the two is used.
        return min(no_matrix[i, j].item(), no_matrix[j, i].item())

    scored = ((i, j, _closest(i, j)) for i, j in topology.cross_strand_pairs)
    return [entry for entry in scored if entry[2] < threshold]

## Load Distance Probe Directions

In [None]:
import pandas as pd

# Probe directions are read from the project's models/ directory.
directions = z_load_directions(ROOT / "models" / "gradient_directions.pt")
print(f"Directions available for {len(directions.directions)} blocks")

# Load probe evaluation metrics
eval_df = pd.read_csv(ROOT / "models" / "probe_evaluation.csv")

# Per-block projection spread, ordered by block index.
blocks = sorted(directions.directions.keys())
stds = [directions.stds[block_idx] for block_idx in blocks]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 3.5))

# Left panel: probe test R² per block.
ax1.plot(eval_df['block'], eval_df['r2'], 'o-', markersize=3, color='steelblue')
ax1.set(xlabel='Block', ylabel='R²', title='Distance Probe Accuracy (Test R²)')
ax1.grid(alpha=0.3)

# Right panel: standard deviation of projections per block.
ax2.plot(blocks, stds, 'o-', markersize=3, color='coral')
ax2.set(xlabel='Block', ylabel='Std dev of projections', title='Projection Spread by Block')
ax2.grid(alpha=0.3)

plt.tight_layout()
plt.show()

## Load Model

In [None]:
from transformers import EsmForProteinFolding, AutoTokenizer

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the ESMFold checkpoint, move it to the chosen device and put it in
# evaluation mode.
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
model = model.to(device)
model.eval()

tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
print(f"Model loaded on {device}")

Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded on cuda


## Contact Induction on Minimal Alpha Helix

Use the same helix-turn-helix peptide from the charge steering notebook.
Steer cross-strand residue pairs toward close contact (target 5.5 Å)
using the pairwise distance probe intervention on z.

In [None]:
# Same helix-turn-helix as the charge steering notebook
helix1 = "AEAAAKEAAAKEAAAK"   # 16 residues
turn   = "GPG"                 # 3 residues
helix2 = "KAAAEKAAAAEKAAAE"   # 16 residues

ind_seq = helix1 + turn + helix2  # 35 residues

# Define hairpin topology over the full sequence. The region end and the turn
# length are derived from the strings above rather than hard-coded, so that
# editing the peptide cannot silently desynchronise the topology from it.
ind_topology = define_hairpin_topology(
    region_start=0,
    region_end=len(ind_seq),
    turn_length=len(turn),
)

# Intervention parameters
ind_win_start = 22       # first block of the steering window (inclusive)
ind_win_end   = 31       # last block of the steering window (inclusive)
ind_mag       = 10       # magnitude in σ
ind_target    = 5.5      # target Cβ-Cβ distance (Å)

# Summary of the configuration; index ranges are printed as half-open [start:end).
print(f"Sequence:  {ind_seq}")
print(f"Length:    {len(ind_seq)}")
print(f"Strand 1: [{ind_topology.strand1_start}:{ind_topology.strand1_end})")
print(f"Turn:     [{ind_topology.turn_start}:{ind_topology.turn_end})")
print(f"Strand 2: [{ind_topology.strand2_start}:{ind_topology.strand2_end})")
print(f"Cross-strand pairs: {ind_topology.cross_strand_pairs}")
print(f"Window: blocks {ind_win_start}-{ind_win_end}, magnitude={ind_mag}σ, target={ind_target}Å")

Sequence:  AEAAAKEAAAKEAAAKGPGKAAAEKAAAAEKAAAE
Length:    35
Strand 1: [0:16)
Turn:     [16:19)
Strand 2: [19:35)
Cross-strand pairs: [(0, 34), (1, 33), (2, 32), (3, 31), (4, 30), (5, 29), (6, 28), (7, 27), (8, 26), (9, 25), (10, 24), (11, 23), (12, 22), (13, 21), (14, 20), (15, 19)]
Window: blocks 22-31, magnitude=10σ, target=5.5Å


In [None]:
# Unsteered reference prediction for the induction peptide. The baseline call
# also returns a collector whose `z_blocks` are used later when building the
# steering intervention.
with torch.no_grad():
    ind_bl_out, ind_bl_collector = z_get_baseline(model, tokenizer, device, ind_seq)
# Keep the first entry of what output_to_pdb returns (one sequence was folded).
ind_bl_pdb = model.output_to_pdb(ind_bl_out)[0]

# Contact / H-bond metrics are evaluated against the same target distance that
# the steering step will use, so baseline and steered numbers are comparable.
ind_bl_metrics = compute_comprehensive_metrics(ind_bl_out, ind_topology, ind_target)
ind_bl_hbonds = get_hbond_pairs(ind_bl_out, ind_topology)

print(f"Baseline: {ind_bl_metrics['n_potential_hbonds']} H-bonds, "
      f"mean Cβ dist {ind_bl_metrics['mean_cb_dist']:.1f}Å, "
      f"contacts(6Å)={ind_bl_metrics['n_contacts_6A']}/{len(ind_topology.cross_strand_pairs)}")

# Last expression of the cell: its value is what the notebook displays.
show_protein(ind_bl_pdb, ind_topology)

Baseline: 0 H-bonds, mean Cβ dist 10.7Å, contacts(6Å)=1/16


<py3Dmol.view at 0x7e1531ae17b0>

In [None]:
# Blocks in the steering window; ind_win_end is inclusive, hence the +1.
ind_blocks = set(range(ind_win_start, ind_win_end + 1))

# Build the per-block interventions that steer the cross-strand pairs toward
# ind_target. The block list is passed in sorted order: iterating a set has no
# guaranteed order, so list(ind_blocks) would depend on an implementation detail.
ind_z_interventions = create_targeted_steering_intervention(
    topology=ind_topology,
    seq_len=len(ind_seq),
    directions=directions,
    blocks=sorted(ind_blocks),
    baseline_z_list=ind_bl_collector.z_blocks,
    target_distance=ind_target,
    magnitude=ind_mag,
    device=device,
)

# Re-run the fold with the interventions applied at the selected blocks.
with torch.no_grad():
    ind_st_out = run_with_intervention(
        model, tokenizer, device, ind_seq,
        ind_blocks, ind_z_interventions,
    )
# Keep the first entry of what output_to_pdb returns (one sequence was folded).
ind_st_pdb = model.output_to_pdb(ind_st_out)[0]

# Same metrics as for the baseline, so the two runs can be compared directly.
ind_st_metrics = compute_comprehensive_metrics(ind_st_out, ind_topology, ind_target)
ind_st_hbonds = get_hbond_pairs(ind_st_out, ind_topology)

print(f"Steered: {ind_st_metrics['n_potential_hbonds']} H-bonds, "
      f"mean Cβ dist {ind_st_metrics['mean_cb_dist']:.1f}Å, "
      f"contacts(6Å)={ind_st_metrics['n_contacts_6A']}/{len(ind_topology.cross_strand_pairs)}")
for i, j, d in ind_st_hbonds:
    print(f"  H-bond: residues {i}-{j}: {d:.2f} Å")

# Last expression of the cell: its value is what the notebook displays.
show_protein(ind_st_pdb, ind_topology, ind_st_hbonds)

Steered: 13 H-bonds, mean Cβ dist 5.2Å, contacts(6Å)=12/16
  H-bond: residues 0-34: 2.02 Å
  H-bond: residues 1-33: 3.37 Å
  H-bond: residues 3-31: 3.34 Å
  H-bond: residues 4-30: 3.11 Å
  H-bond: residues 5-29: 2.15 Å
  H-bond: residues 6-28: 3.02 Å
  H-bond: residues 7-27: 2.17 Å
  H-bond: residues 8-26: 1.48 Å
  H-bond: residues 9-25: 2.53 Å
  H-bond: residues 10-24: 1.90 Å
  H-bond: residues 11-23: 1.31 Å
  H-bond: residues 12-22: 2.88 Å
  H-bond: residues 13-21: 2.69 Å


<py3Dmol.view at 0x7e15320c38e0>

In [None]:
# Baseline and steered structures next to each other, each with the H-bond
# pairs found for it; the right-hand label records the steering parameters.
show_side_by_side(
    ind_bl_pdb, ind_st_pdb, ind_topology,
    hb_left=ind_bl_hbonds, hb_right=ind_st_hbonds,
    label_left="Baseline (helix)",
    label_right=f"Steered (blocks {ind_win_start}-{ind_win_end}, {ind_mag}σ, target={ind_target}Å)",
)

Left: Baseline (helix)  |  Right: Steered (blocks 22-31, 10σ, target=5.5Å)


In [None]:
bm, sm = ind_bl_metrics, ind_st_metrics
delta = "\u0394"

# One entry per table row: (row label, metrics key, precision suffix).
# The suffix is appended to the column width spec; "" leaves the value's
# default formatting (used for the count-valued rows).
metric_rows = [
    ("Potential H-bonds", "n_potential_hbonds", ""),
    ("H-bond fraction", "hbond_fraction", ".2f"),
    ("Mean N-O dist (Å)", "mean_no_dist", ".1f"),
    ("Mean Cβ dist (Å)", "mean_cb_dist", ".1f"),
    ("Contacts (<6Å)", "n_contacts_6A", ""),
    ("Contacts (<8Å)", "n_contacts_8A", ""),
    ("Region pLDDT", "region_plddt", ".2f"),
    ("Overall pLDDT", "overall_plddt", ".2f"),
]

print(f"{'Metric':<30} {'Baseline':>12} {'Steered':>12} {delta:>8}")
print("-" * 64)
for row_label, key, precision in metric_rows:
    before, after = bm[key], sm[key]
    # The last column is the signed change (steered minus baseline).
    print(f"{row_label:<30} {before:>12{precision}} {after:>12{precision}} {after - before:>+8{precision}}")

Metric                             Baseline      Steered        Δ
----------------------------------------------------------------
Potential H-bonds                         0           13      +13
H-bond fraction                        0.00         0.81    +0.81
Mean N-O dist (Å)                       9.2          2.8     -6.4
Mean Cβ dist (Å)                       10.7          5.2     -5.5
Contacts (<6Å)                            1           12      +11
Contacts (<8Å)                            3           16      +13
Region pLDDT                           0.69         0.50    -0.19
Overall pLDDT                          0.68         0.50    -0.18


## Contact Disruption on Minimal Beta Hairpin

Use the same GB1 hairpin peptide from the charge steering notebook.
Push cross-strand pairs apart by steering toward a large target distance (25 Å)
using the pairwise distance probe intervention on z.

In [None]:
# Same minimal beta-hairpin (GB1 hairpin) as the charge steering notebook
rep_seq = "GEWTYDDATKTFTVTE"  # 16 residues

# Define hairpin topology (same as charge notebook)
# The region [1, 15) excludes the first and last residue of the 16-mer.
rep_topology = define_hairpin_topology(region_start=1, region_end=15, turn_length=4)

# Disruption parameters — large target distance pushes strands apart
# The block window is inclusive on both ends (see how rep_blocks is built later).
rep_win_start = 22
rep_win_end   = 31
rep_mag       = 10        # magnitude in σ
rep_target    = 25.0      # target Cβ-Cβ distance (Å) — far apart

# Summary of the configuration; index ranges are printed as half-open [start:end).
print(f"Sequence:  {rep_seq}")
print(f"Length:    {len(rep_seq)}")
print(f"Strand 1: [{rep_topology.strand1_start}:{rep_topology.strand1_end})")
print(f"Turn:     [{rep_topology.turn_start}:{rep_topology.turn_end})")
print(f"Strand 2: [{rep_topology.strand2_start}:{rep_topology.strand2_end})")
print(f"Cross-strand pairs: {rep_topology.cross_strand_pairs}")
print(f"Window: blocks {rep_win_start}-{rep_win_end}, magnitude={rep_mag}σ, target={rep_target}Å")

Sequence:  GEWTYDDATKTFTVTE
Length:    16
Strand 1: [1:6)
Turn:     [6:10)
Strand 2: [10:15)
Cross-strand pairs: [(1, 14), (2, 13), (3, 12), (4, 11), (5, 10)]
Window: blocks 22-31, magnitude=10σ, target=25.0Å


In [None]:
# Unsteered reference prediction for the hairpin peptide. The baseline call
# also returns a collector whose `z_blocks` are used later when building the
# disruption intervention.
with torch.no_grad():
    rep_bl_out, rep_bl_collector = z_get_baseline(model, tokenizer, device, rep_seq)
# Keep the first entry of what output_to_pdb returns (one sequence was folded).
rep_bl_pdb = model.output_to_pdb(rep_bl_out)[0]

# Metrics are evaluated with the same target distance used for the disruption
# run, so baseline and disrupted numbers are comparable.
rep_bl_metrics = compute_comprehensive_metrics(rep_bl_out, rep_topology, rep_target)
rep_bl_hbonds = get_hbond_pairs(rep_bl_out, rep_topology)

print(f"Baseline: {rep_bl_metrics['n_potential_hbonds']} H-bonds, "
      f"mean Cβ dist {rep_bl_metrics['mean_cb_dist']:.1f}Å, "
      f"contacts(6Å)={rep_bl_metrics['n_contacts_6A']}/{len(rep_topology.cross_strand_pairs)}")

# Last expression of the cell: its value is what the notebook displays.
# NOTE(review): rep_bl_hbonds is computed above but not passed to the viewer here,
# unlike the steered view in the induction section — confirm this is intended.
show_protein(rep_bl_pdb, rep_topology)

Baseline: 3 H-bonds, mean Cβ dist 5.7Å, contacts(6Å)=3/5


<py3Dmol.view at 0x7e1531ae2080>

In [None]:
# Blocks in the steering window; rep_win_end is inclusive, hence the +1.
rep_blocks = set(range(rep_win_start, rep_win_end + 1))

# Build the per-block interventions that push the cross-strand pairs toward
# rep_target. The block list is passed in sorted order: iterating a set has no
# guaranteed order, so list(rep_blocks) would depend on an implementation detail.
rep_z_interventions = create_targeted_steering_intervention(
    topology=rep_topology,
    seq_len=len(rep_seq),
    directions=directions,
    blocks=sorted(rep_blocks),
    baseline_z_list=rep_bl_collector.z_blocks,
    target_distance=rep_target,
    magnitude=rep_mag,
    device=device,
)

# Re-run the fold with the interventions applied at the selected blocks.
with torch.no_grad():
    rep_st_out = run_with_intervention(
        model, tokenizer, device, rep_seq,
        rep_blocks, rep_z_interventions,
    )
# Keep the first entry of what output_to_pdb returns (one sequence was folded).
rep_st_pdb = model.output_to_pdb(rep_st_out)[0]

# Same metrics as for the baseline, so the two runs can be compared directly.
rep_st_metrics = compute_comprehensive_metrics(rep_st_out, rep_topology, rep_target)
rep_st_hbonds = get_hbond_pairs(rep_st_out, rep_topology)

# Signed changes relative to the baseline (disrupted minus baseline).
dist_change  = rep_st_metrics['mean_cb_dist'] - rep_bl_metrics['mean_cb_dist']
hbond_change = rep_st_metrics['n_potential_hbonds'] - rep_bl_metrics['n_potential_hbonds']

print(f"After disruption (target={rep_target}Å, mag={rep_mag}σ):")
print(f"  Mean Cβ dist: {rep_st_metrics['mean_cb_dist']:.2f} Å  (Δ = {dist_change:+.2f} Å)")
print(f"  H-bonds:      {rep_st_metrics['n_potential_hbonds']}  (Δ = {hbond_change:+d})")

# Last expression of the cell: its value is what the notebook displays.
show_protein(rep_st_pdb, rep_topology)

After disruption (target=25.0Å, mag=10σ):
  Mean Cβ dist: 18.28 Å  (Δ = +12.55 Å)
  H-bonds:      0  (Δ = -3)


<py3Dmol.view at 0x7e153214e5f0>

In [None]:
# Baseline and disrupted structures next to each other, each with the H-bond
# pairs found for it; the right-hand label records the steering parameters.
show_side_by_side(
    rep_bl_pdb, rep_st_pdb, rep_topology,
    hb_left=rep_bl_hbonds, hb_right=rep_st_hbonds,
    label_left="Baseline (hairpin)",
    label_right=f"Disrupted (target={rep_target}Å, mag={rep_mag}σ)",
)

Left: Baseline (hairpin)  |  Right: Disrupted (target=25.0Å, mag=10σ)


In [None]:
bm, sm = rep_bl_metrics, rep_st_metrics
delta = "\u0394"

# One entry per table row: (row label, metrics key, precision suffix).
# The suffix is appended to the column width spec; "" leaves the value's
# default formatting (used for the count-valued rows).
metric_rows = [
    ("Potential H-bonds", "n_potential_hbonds", ""),
    ("H-bond fraction", "hbond_fraction", ".2f"),
    ("Mean N-O dist (Å)", "mean_no_dist", ".1f"),
    ("Mean Cβ dist (Å)", "mean_cb_dist", ".1f"),
    ("Contacts (<6Å)", "n_contacts_6A", ""),
    ("Contacts (<8Å)", "n_contacts_8A", ""),
    ("Region pLDDT", "region_plddt", ".2f"),
    ("Overall pLDDT", "overall_plddt", ".2f"),
]

print(f"{'Metric':<30} {'Baseline':>12} {'Disrupted':>12} {delta:>8}")
print("-" * 64)
for row_label, key, precision in metric_rows:
    before, after = bm[key], sm[key]
    # The last column is the signed change (disrupted minus baseline).
    print(f"{row_label:<30} {before:>12{precision}} {after:>12{precision}} {after - before:>+8{precision}}")

Metric                             Baseline    Disrupted        Δ
----------------------------------------------------------------
Potential H-bonds                         3            0       -3
H-bond fraction                        0.60         0.00    -0.60
Mean N-O dist (Å)                       4.6         16.3    +11.7
Mean Cβ dist (Å)                        5.7         18.3    +12.6
Contacts (<6Å)                            3            0       -3
Contacts (<8Å)                            3            0       -3
Region pLDDT                           0.78         0.64    -0.14
Overall pLDDT                          0.76         0.62    -0.14
