## Setup and Imports

In [23]:
import torch
import pandas as pd
from transformers import EsmForProteinFolding, AutoTokenizer
from transformers.models.esm.modeling_esmfold import EsmFoldingTrunk
import sys
import py3Dmol
from pathlib import Path

ROOT = Path.cwd().parent
sys.path.append(str(ROOT))

from src.module_patching import (
    run_and_collect,
    patch_trunk_all_blocks,
    create_pairwise_mask,
)
from demo_utils import show_protein

import warnings
warnings.filterwarnings("ignore", message=".*mmCIF.*")
warnings.filterwarnings("ignore", category=UserWarning, module="Bio.PDB.DSSP")

## Load Model

In [24]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1").to(device)
model.eval()
tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
print(f"Model loaded on {device}")

Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded on cuda


## Example Case

In [25]:
case = pd.read_csv(ROOT / "demo_notebooks" / "demo_full_patching.csv").iloc[0]

target_seq = case["target_sequence"]
donor_seq = case["donor_sequence"]
target_start = int(case["target_start"])
target_end = int(case["target_end"])
donor_start = int(case["donor_start"])
donor_end = int(case["donor_end"])
mask_mode = case["patch_mask_mode"]

print(f"Target: {case['target_name']}, loop {case['loop_idx']}")
print(f"Donor: {case['donor_pdb']}")
print(f"Target patch region: [{target_start}:{target_end})")
print(f"Donor hairpin region: [{donor_start}:{donor_end})")
print(f"Mask mode: {mask_mode}")

Target: 1P68, loop 0
Donor: 1jof
Target patch region: [13:34)
Donor hairpin region: [119:140)
Mask mode: intra


## Example Target Alpha Helical Protein


In [26]:
##1. Original Alpha Target
with torch.no_grad():
    target_inputs = tokenizer(target_seq, return_tensors='pt', add_special_tokens=False).to(device)
    
    target_outputs = model(**target_inputs, num_recycles=0)

target_pdb = model.output_to_pdb(target_outputs)[0]

In [27]:
show_protein(target_pdb, target_start, target_end)

<py3Dmol.view at 0x7dc28262e860>

## Extract Donor Hairpin Representations

In [28]:
# 3. Collect donor representations for patching
donor_outputs, donor_collected = run_and_collect(
    model, tokenizer, device, donor_seq,
    collect_esm=True, collect_trunk=True,
    collect_s=True, collect_z=True, collect_ipa=True,
)

donor_s_region = {k: v[:, donor_start:donor_end, :] for k, v in donor_collected.s_blocks.items()}
donor_z_full = donor_collected.z_blocks

donor_pdb = model.output_to_pdb(donor_outputs)[0]

In [29]:
show_protein(donor_pdb, donor_start, donor_end)

<py3Dmol.view at 0x7dc22fea3460>

## Run forward donor pass (patch sequence and pairwise representations at all blocks)

In [30]:
pairwise_mask = create_pairwise_mask(
    donor_start, donor_end, len(donor_seq),
    target_start, target_end, len(target_seq),
    mode=mask_mode,
)

# 4. Patched run
with patch_trunk_all_blocks(
    model, donor_s_region, donor_z_full,
    target_start, target_end, donor_start,
    pairwise_mask, "both"
):
    with torch.no_grad():
        inputs = tokenizer(target_seq, return_tensors='pt', add_special_tokens=False).to(device)
        patched_outputs = model(**inputs, num_recycles=0)

patched_pdb = model.output_to_pdb(patched_outputs)[0]


In [31]:
show_protein(patched_pdb, target_start, target_end)

<py3Dmol.view at 0x7dc22fea0550>

### Show side by side

In [32]:
view = py3Dmol.view(width=800, height=400, viewergrid=(1, 3))

configs = [
    (target_pdb, target_start, target_end, "Original Target"),
    (donor_pdb, donor_start, donor_end, "Donor Hairpin"),
    (patched_pdb, target_start, target_end, "Patched Result"),
]

for col, (pdb_str, hl_start, hl_end, title) in enumerate(configs):
    view.addModel(pdb_str, "pdb", viewer=(0, col))
    view.setStyle({"cartoon": {"color": "lightgray", "opacity": 0.8, "arrows" : True}}, viewer=(0, col))
    view.addStyle(
        {"resi": list(range(hl_start + 1, hl_end + 1))},
        {"cartoon": {"color": "purple", "opacity": 1.0}},
        viewer=(0, col),
    )
    view.zoomTo(viewer=(0, col))

view.show()

In [33]:
from src.module_patching import evaluate_hairpin



target_eval = evaluate_hairpin(target_outputs, model, target_start, target_end)
patched_eval = evaluate_hairpin(patched_outputs, model, target_start, target_end)

print(f"{'Metric':<30} {'Original':>12} {'Patched':>12}")
print("-" * 56)
print(f"{'Hairpin found':<30} {str(target_eval['hairpin_found']):>12} {str(patched_eval['hairpin_found']):>12}")
print(f"{'Mean pLDDT':<30} {target_eval['mean_plddt']:>12.2f} {patched_eval['mean_plddt']:>12.2f}")
print(f"{'Patch region pLDDT':<30} {target_eval['patch_region_plddt']:>12.2f} {patched_eval['patch_region_plddt']:>12.2f}")
if target_eval['ptm'] is not None:
    print(f"{'pTM':<30} {target_eval['ptm']:>12.3f} {patched_eval['ptm']:>12.3f}")

Metric                             Original      Patched
--------------------------------------------------------
Hairpin found                         False         True
Mean pLDDT                             0.68         0.54
Patch region pLDDT                     0.70         0.58
pTM                                   0.726        0.407


## Patching a single block

## Example Cases

In [34]:
sb_results = pd.read_csv(ROOT / "demo_notebooks" / "demo_block_patching.csv")

seq_row = sb_results[sb_results["patch_mode"] == "sequence"].iloc[0]
pair_row = sb_results[sb_results["patch_mode"] == "pairwise"].iloc[0]

seq_block = int(seq_row["block_idx"])
pair_block = int(pair_row["block_idx"])

print(f"Sequence patch: block {seq_block}, donor {seq_row['donor_pdb']}")
print(f"Pairwise/intra patch: block {pair_block}, donor {pair_row['donor_pdb']}")

Sequence patch: block 0, donor 9gtx
Pairwise/intra patch: block 30, donor 7jzs


## Extract donor representations

In [35]:
from src.block_patching import (
    run_and_collect as sb_run_and_collect,
    patch_trunk_single_block,
    create_pairwise_mask as sb_create_pairwise_mask,
)

# Use the sequence-patching case (can reuse donor from all-block if same donor)
sb_donor_seq = seq_row["donor_sequence"]
sb_target_seq = seq_row["target_sequence"]
sb_target_start = int(seq_row["target_start"])
sb_target_end = int(seq_row["target_end"])
sb_donor_start = int(seq_row["donor_hairpin_start"])
sb_donor_end = int(seq_row["donor_hairpin_end"])

# Collect donor trunk representations
sb_donor_outputs, sb_donor_collected = sb_run_and_collect(model, tokenizer, device, sb_donor_seq)

sb_donor_s_region = {
    k: v[:, sb_donor_start:sb_donor_end, :]
    for k, v in sb_donor_collected.s_blocks.items()
}
sb_donor_z_full = sb_donor_collected.z_blocks

sb_pairwise_mask_intra = sb_create_pairwise_mask(
    sb_donor_start, sb_donor_end, len(sb_donor_seq),
    sb_target_start, sb_target_end, len(sb_target_seq),
    mode="intra",
)

print("Donor representations collected")
print(f"  Blocks available: {len(sb_donor_s_region)}")

sb_donor_pdb = model.output_to_pdb(sb_donor_outputs)[0]

Donor representations collected
  Blocks available: 48


In [36]:
show_protein(sb_donor_pdb, sb_donor_start, sb_donor_end)

<py3Dmol.view at 0x7dc22fef87f0>

## Show target protein

In [37]:
##1. Original Alpha Target
with torch.no_grad():
    target_inputs = tokenizer(sb_target_seq, return_tensors='pt', add_special_tokens=False).to(device)
    
    target_outputs = model(**target_inputs, num_recycles=0)

target_pdb = model.output_to_pdb(target_outputs)[0]

In [38]:
show_protein(target_pdb, sb_target_start, sb_target_end)

<py3Dmol.view at 0x7dc282b1d750>

## Single block sequence patch

In [39]:
with patch_trunk_single_block(
    model, sb_donor_s_region, sb_donor_z_full,
    sb_target_start, sb_target_end, sb_donor_start,
    sb_pairwise_mask_intra, "sequence", seq_block,
):
    with torch.no_grad():
        inputs = tokenizer(sb_target_seq, return_tensors='pt', add_special_tokens=False).to(device)
        seq_patch_outputs = model(**inputs, num_recycles=0)

seq_patch_pdb = model.output_to_pdb(seq_patch_outputs)[0]
print(f"Sequence patch at block {seq_block} done")

Sequence patch at block 0 done


In [40]:
show_protein(seq_patch_pdb, sb_target_start, sb_target_end)

<py3Dmol.view at 0x7dc22ce32fe0>

In [41]:
with patch_trunk_single_block(
    model, sb_donor_s_region, sb_donor_z_full,
    sb_target_start, sb_target_end, sb_donor_start,
    sb_pairwise_mask_intra, "pairwise", pair_block,
):
    with torch.no_grad():
        inputs = tokenizer(sb_target_seq, return_tensors='pt', add_special_tokens=False).to(device)
        pair_patch_outputs = model(**inputs, num_recycles=0)

pair_patch_pdb = model.output_to_pdb(pair_patch_outputs)[0]
print(f"Pairwise/intra patch at block {pair_block} done")

Pairwise/intra patch at block 30 done


In [42]:
show_protein(pair_patch_pdb, sb_target_start, sb_target_end)

<py3Dmol.view at 0x7dc22fecba30>

## Side by side

In [43]:
view = py3Dmol.view(width=800, height=400, viewergrid=(1, 3))

configs = [
    (target_pdb, sb_target_start, sb_target_end, "Original Target"),
    (seq_patch_pdb, sb_target_start, sb_target_end, f"Seq patch (block {seq_block})"),
    (pair_patch_pdb, sb_target_start, sb_target_end, f"Pair/intra (block {pair_block})"),
]

for col, (pdb_str, hl_start, hl_end, title) in enumerate(configs):
    view.addModel(pdb_str, "pdb", viewer=(0, col))
    view.setStyle({"cartoon": {"color": "lightgray", "opacity": 0.8, 'arrows':True}}, viewer=(0, col))
    view.addStyle(
        {"resi": list(range(hl_start + 1, hl_end + 1))},
        {"cartoon": {"color": "purple", "opacity": 1.0}},
        viewer=(0, col),
    )
    view.zoomTo(viewer=(0, col))

view.show()

## Evaluate hairpins

In [44]:
from src.block_patching import evaluate_hairpin as sb_evaluate_hairpin

seq_eval = sb_evaluate_hairpin(seq_patch_outputs, model, sb_target_start, sb_target_end)
pair_eval = sb_evaluate_hairpin(pair_patch_outputs, model, sb_target_start, sb_target_end)

print(f"{'Metric':<30} {'Original':>12} {'Seq blockk '+str(seq_block):>12} {'Pair block '+str(pair_block):>12}")
print("-" * 68)
print(f"{'Hairpin found':<30} {str(target_eval['hairpin_found']):>12} {str(seq_eval['hairpin_found']):>12} {str(pair_eval['hairpin_found']):>12}")
print(f"{'Mean pLDDT':<30} {target_eval['mean_plddt']:>12.2f} {seq_eval['mean_plddt']:>12.2f} {pair_eval['mean_plddt']:>12.2f}")
print(f"{'Patch region pLDDT':<30} {target_eval['patch_region_plddt']:>12.2f} {seq_eval['patch_region_plddt']:>12.2f} {pair_eval['patch_region_plddt']:>12.2f}")

Metric                             Original Seq blockk 0 Pair block 30
--------------------------------------------------------------------
Hairpin found                         False         True        False
Mean pLDDT                             0.68         0.55         0.54
Patch region pLDDT                     0.70         0.54         0.45
