  [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Kuhlman-Lab/PIPPack/blob/main/notebooks/PIPPack.ipynb)

# This notebook packs the side chains of a protein, given its PDB code or an uploaded PDB file, and can resample the predicted side chains.

## **Step 1: Check Colab Settings**

PIPPack does not require the use of a GPU, but predictions will be faster with a GPU. To check that GPU use is enabled:
- Use the Colab settings bar above to navigate to `Runtime` -> `Change runtime type`
- Make sure that `Runtime type` is set to `Python 3`
- Make sure that `Hardware accelerator` is set to `GPU`
- Click `Save` to confirm
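
Once the runtime has restarted, you can also confirm from code that a GPU is attached. The snippet below is a minimal sketch using only the standard library: it looks for the `nvidia-smi` tool and asks it to list devices. The helper name `gpu_available` is hypothetical and not part of PIPPack; Step 4 does its own check with `torch.cuda.is_available()`.

```python
import shutil
import subprocess


def gpu_available() -> bool:
    """Return True if `nvidia-smi` exists and lists at least one GPU."""
    exe = shutil.which("nvidia-smi")
    if exe is None:
        # No NVIDIA driver tools on PATH: assume a CPU-only runtime.
        return False
    try:
        result = subprocess.run(
            [exe, "-L"], capture_output=True, text=True, timeout=30
        )
    except (OSError, subprocess.TimeoutExpired):
        return False
    # `nvidia-smi -L` prints one "GPU n: ..." line per device.
    return result.returncode == 0 and "GPU" in result.stdout


if gpu_available():
    print("GPU detected")
else:
    print("No GPU detected; PIPPack will run on the CPU")
```

If no GPU is reported, PIPPack still runs, only more slowly.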

In [1]:
%%capture
# @title ## **Step 2: Set up PIPPack**
# @markdown Import PIPPack and its dependencies to this session. This may take a minute or two.

# @markdown You only need to do this once *per session*. To re-run PIPPack on a new protein, with a new sequence, or with another model, you can start on Step 3.

# Cleaning out any remaining data
!rm -rf /content/PIPPack

# Making sure Colab can access the GitHub repo
import os
if not os.path.exists("/content/PIPPack"):
  ! git clone https://github.com/Kuhlman-Lab/PIPPack.git
  %cd /content/PIPPack

# Downloading various dependencies
! pip install omegaconf lightning==2.0.1 biopython nglview hydra-core torch_geometric

In [None]:
# %%capture
#@title ## **Step 3: Input Data and Prediction Options**

#@markdown ### **INPUT SETTINGS**

# -------- Collecting Settings for PIPPack run --------- #

from google.colab import files
import os
import sys
from urllib import request
from urllib.error import HTTPError


def download_pdb(pdbcode, datadir, downloadurl="https://files.rcsb.org/download/"):
    """
    Downloads a PDB file from the Internet and saves it in a data directory.
    :param pdbcode: The standard PDB ID e.g. '3ICB' or '3icb'
    :param datadir: The directory where the downloaded file will be saved
    :param downloadurl: The base PDB download URL, cf.
        `https://www.rcsb.org/pages/download/http#structures` for details
    :return: the full path to the downloaded PDB file or None if something went wrong
    """

    pdbfn = pdbcode + ".pdb"
    url = downloadurl + pdbfn
    outfnm = os.path.join(datadir, pdbfn)
    try:
        request.urlretrieve(url, outfnm)
        return outfnm
    except Exception as err:
        print(str(err), file=sys.stderr)
        return None

#@markdown You may either specify a PDB code to fetch or upload a custom PDB file.

#@markdown PDB code (e.g., 1PGA):
pdb = "1PGA" # @param {type: "string"}

#@markdown Upload custom PDB?
custom_pdb = True # @param {type: "boolean"}
#@markdown If enabled, a `Choose files` button will appear once this cell is run.

if custom_pdb:
  print('PDB Upload:')
  uploaded_pdb = files.upload()
  for fn in uploaded_pdb.keys():
    pdb = os.path.basename(fn)
    if not pdb.endswith('.pdb'):
      raise ValueError(f"Uploaded file {pdb} does not end in '.pdb'. Please check and rename file as needed.")
    os.rename(fn, os.path.join("/content/", pdb))
    pdb_file = os.path.join("/content/", pdb)
else:
  # download_pdb() reports any download error (including an HTTP 404) and returns None
  fn = download_pdb(pdb, "/content/")
  if fn is None:
    raise ValueError(f"Failed to fetch {pdb} from the RCSB PDB. Please double-check the PDB code and try again.")
  pdb_file = fn

#@markdown Chain(s) of interest (e.g., A or A,B):
chain = "" # @param {type:"string"}

#@markdown If left blank, PIPPack will pack all chains in the PDB. To pack just a subset of chains, input the chain ids as a comma-separated list (e.g., A,B)

#@markdown Sequence to pack:
sequence = "" # @param {type:"string"}

#@markdown If left blank, the native sequence in the PDB will be repacked. If specifying multiple chains, separate the chain sequences with a '/'. The length of the sequence(s) MUST match the length of the chain(s) in the PDB file.

# FASTA upload and ProteinMPNN sequence generation are disabled in this notebook.
# The variables are still defined because this cell and Step 4 read them.
custom_fasta = False
###@markdown If enabled, a `Choose files` button will appear once this cell is run. PIPPack will pack all sequences found in the fasta file as long as their length matches the length of the structure in the PDB. If specifying multiple chains, separate the chain sequences with a '/'.

## @markdown Generate ProteinMPNN sequences:
n_proteinmpnn_seqs = 0

##@markdown Sequence sampling temperature:
seq_temperature = 0.2

## @markdown Note the ProteinMPNN model used in this notebook is the v_48_010 model from the official repo: https://github.com/dauparas/ProteinMPNN.

if custom_fasta:
  print('Fasta Upload:')
  uploaded_fasta = files.upload()
  for fn in uploaded_fasta.keys():
    fasta = os.path.basename(fn)
    if not fasta.endswith('.fasta'):
      raise ValueError(f"Uploaded file {fasta} does not end in '.fasta'. Please check and rename file as needed.")
    os.rename(fn, os.path.join("/content/", fasta))
    fasta_file = os.path.join("/content/", fasta)

    with open(fasta_file, 'r') as f:
      lines = f.readlines()
    # Skip header lines and blank lines
    seqs = [line.strip() for line in lines if line.strip() and not line.startswith(">")]
else:
  seqs = [sequence]
seqs = [seq.split('/') for seq in seqs]

convert_mse_to_met = True # @param {type: "boolean"}
#@markdown If enabled, any MSE residue will be parsed as MET. If disabled, MSE residues are ignored.


#@markdown ### **MODEL SETTINGS**

#@markdown Model to use:
model = "PIPPack (ensemble)" # @param ["PIPPack (model 1)", "PIPPack (model 2)", "PIPPack (model 3)", "PIPPack (ensemble)"]

#@markdown Side chain sampling temperature:
chi_temperature = 0.0 # @param {type:"number"}
# @markdown Higher temperature results in more diversity of conformations, but also conformations that the model is less confident about. Use T=0.0 for best results.

# @markdown Number of recycles:
recycles = 3 # @param {type: "integer"}

# @markdown Random seed:
use_seed = True # @param {type: "boolean"}
seed = 3445 # @param {type: "integer"}
# @markdown Enabling seed by checking "use_seed" will use the "seed" value for random number generation. Disabling will result in a random seed.

#@markdown Resampling Arguments:
use_resampling = True #@param {type:"boolean"}
sample_temp = 0.1 # @param {type:"number"}
clash_overlap_tolerance = 0.4 # @param {type:"number"}
pro_tolerance_factor = 12 # @param {type:"integer"}
max_iters = 50 # @param {type:"integer"}
metropolis_temp = 0.000005 # @param {type:"number"}
#@markdown Note: these arguments control the resampling procedure that is applied after PIPPack prediction to reduce the number of clashing residues. The defaults are the settings used for the results in the paper. Feel free to experiment with other settings, and share if you find anything interesting!

resample_args = {
    "sample_temp": sample_temp,
    "clash_overlap_tolerance": clash_overlap_tolerance,
    "pro_tolerance_factor": pro_tolerance_factor,
    "max_iters": max_iters,
    "metropolis_temp": metropolis_temp,
}
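
The sequence option above expects one sequence per chain, separated by `/`, with each sequence exactly as long as its chain. The sketch below shows one way to check this before running Step 4. It assumes that residues can be counted from the CA atoms of `ATOM` records; PIPPack's own parser may count differently (for example when MSE residues are converted). The helper names `toy_atom`, `chain_lengths`, and `check_sequences` are hypothetical and not part of PIPPack.

```python
def toy_atom(serial, name, resn, chain_id, resi):
    """Build a fixed-width PDB ATOM record with dummy coordinates."""
    return (f"ATOM  {serial:5d} {name:<4s} {resn:>3s} {chain_id}{resi:4d}    "
            f"{0.0:8.3f}{0.0:8.3f}{0.0:8.3f}")


def chain_lengths(pdb_text):
    """Count residues per chain from CA atoms, in order of appearance."""
    seen = {}
    for line in pdb_text.splitlines():
        if line.startswith("ATOM") and line[12:16].strip() == "CA":
            # Residue number plus insertion code identifies a residue.
            seen.setdefault(line[21], set()).add(line[22:27])
    return {chain_id: len(residues) for chain_id, residues in seen.items()}


def check_sequences(sequence, lengths):
    """Split a '/'-separated sequence string and compare it to chain lengths."""
    seqs = sequence.split("/")
    if len(seqs) != len(lengths):
        raise ValueError(f"Expected {len(lengths)} sequence(s), got {len(seqs)}.")
    for seq, (chain_id, n_res) in zip(seqs, lengths.items()):
        if len(seq) != n_res:
            raise ValueError(
                f"Chain {chain_id} has {n_res} residues but its sequence has {len(seq)}."
            )
    return seqs


# Toy structure: chain A with three residues, chain B with two.
toy_pdb = "\n".join([
    toy_atom(1, "CA", "MET", "A", 1),
    toy_atom(2, "CA", "GLY", "A", 2),
    toy_atom(3, "CA", "LYS", "A", 3),
    toy_atom(4, "CA", "ALA", "B", 1),
    toy_atom(5, "CA", "SER", "B", 2),
])
lengths = chain_lengths(toy_pdb)
print(lengths)                             # {'A': 3, 'B': 2}
print(check_sequences("MGK/AS", lengths))  # ['MGK', 'AS']
```

A sequence such as `"MG/AS"` would raise a `ValueError` here because chain A has three residues.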

In [None]:
#@title ## **Step 4: Run PIPPack**

# ---------- Model Name Parsing ---------- #
import pickle

if model == "PIPPack (model 1)":
  model_names = ["pippack_model_1"]
elif model == "PIPPack (model 2)":
  model_names = ["pippack_model_2"]
elif model == "PIPPack (model 3)":
  model_names = ["pippack_model_3"]
else:
  model_names = ["pippack_model_1", "pippack_model_2", "pippack_model_3"]

# Parse one of the configs to get n_chi_bins
with open(f'/content/PIPPack/model_weights/{model_names[0]}_config.pickle', 'rb') as f:
  cfg = pickle.load(f)
n_chi_bins = cfg.model.n_chi_bins


# ---------- PDB Parsing and Dataset Creation ---------- #
import sys

sys.path.append('/content/PIPPack')

from data.protein import from_pdb_file
from data.top2018_dataset import transform_structure, collate_fn
from inference import replace_protein_sequence

# Load the appropriate chains from the PDB file
chains = chain.strip().split(',')
if chains == ['']:
  chains = None
protein = vars(from_pdb_file(pdb_file, chain_id=chains, mse_to_met=convert_mse_to_met))

# Generate or swap sequence, if necessary
if n_proteinmpnn_seqs > 0:
  import torch
  from proteinmpnn.model_utils import ProteinMPNN
  from utils.train_utils import load_checkpoint
  import data.residue_constants as rc

  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

  # Always refresh the count so a changed setting takes effect when rerunning the cell
  n_gen_seqs = n_proteinmpnn_seqs

  # Don't reload sequence design model if rerunning cell
  if 'seq_model' not in locals():
    seq_model_ckpt = '/content/proteinmpnn_ckpt.pt'
    request.urlretrieve('https://github.com/dauparas/ProteinMPNN/raw/main/vanilla_model_weights/v_48_010.pt', seq_model_ckpt)
    seq_model = ProteinMPNN(k_neighbors=48, augment_eps=0.0, use_ipmp=False).to(device)
    seq_model_ckpt = torch.load(seq_model_ckpt, map_location='cpu')
    seq_model.load_state_dict(seq_model_ckpt['model_state_dict'])

  # Reshape protein for sequence design
  seqs = []
  batch = transform_structure(protein)
  batch = collate_fn([batch]).to(device)
  for i in range(n_gen_seqs):
    randn = torch.randn(batch.S.shape, device=device)
    sample_out = seq_model.sample(batch.X, randn, batch.S, torch.ones_like(batch.S), torch.zeros_like(batch.S), batch.residue_index, batch.residue_mask, temperature=seq_temperature)
    seq = sample_out['S'].squeeze(0)
    seqs.append([''.join([rc.restypes_with_x[s] for s in seq])])
    print(f'Generated sequence with ProteinMPNN: {seqs[-1]}')

if seqs != [['']]:
  # Replace any Xs that were generated
  seqs = [[s.replace('X', 'G') for s in seq] for seq in seqs]
  proteins = replace_protein_sequence(protein, os.path.basename(pdb_file)[:-4], seqs)
else:
  proteins = [(os.path.basename(pdb_file)[:-4], protein)]

# Transform proteins
proteins = [(protein[0], transform_structure(protein[1], n_chi_bins, sc_d_mask_from_seq=True)) for protein in proteins]


# ---------- Model Loading ------------- #
import hydra
import torch
from utils.train_utils import load_checkpoint

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

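# Reload models only on the first run or when the model selection has changed
if 'models' not in locals() or locals().get('loaded_model_names') != model_names:
  models = []
  for model_name in model_names:
    cfg_file = f'/content/PIPPack/model_weights/{model_name}_config.pickle'
    ckpt_file = f'/content/PIPPack/model_weights/{model_name}_ckpt.pt'

    with open(cfg_file, 'rb') as f:
      cfg = pickle.load(f)

    m = hydra.utils.instantiate(cfg.model).to(device)
    load_checkpoint(ckpt_file, m)

    models.append(m)
  # Remember which models are loaded so a new selection triggers a reload
  loaded_model_names = list(model_names)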

# ---------- Inference -------- #
import time
import lightning
from inference import pdbs_from_prediction
from ensembled_inference import sample_epoch
from model.modules import get_atom14_coords
from model.resampling import resample_loop
import warnings

# Seed for predictions
if not use_seed:
  seed = None
# Clear the global seed set by PL
if "PL_GLOBAL_SEED" in os.environ:
  os.environ.pop("PL_GLOBAL_SEED")
with warnings.catch_warnings():
  warnings.simplefilter('ignore')
  _ = lightning.seed_everything(seed)

predictions = {}
for protein_item in proteins:
  # Unpack batch
  pdb_name = protein_item[0]
  protein = protein_item[1]

  # Collate the batch
  batch = collate_fn([protein])

  # Perform inference
  t0 = time.time()
  sample_results = sample_epoch(models, batch, chi_temperature, device, n_recycle=recycles)

  # Perform resampling, if desired
  if use_resampling:
    for i in range(batch.S.shape[0]):
      # Get the protein components.
      temp_protein = {
          "S": sample_results["S"][i],
          "X": sample_results["X"][i],
          "X_mask": sample_results["X_mask"][i],
          "BB_D": sample_results["BB_D"][i],
          "residue_index": sample_results["residue_index"][i],
          "residue_mask": sample_results["residue_mask"][i],
          "chi_logits": sample_results["chi_logits"][i],
          "chi_bin_offset": sample_results["chi_bin_offset"][i] if "chi_bin_offset" in sample_results else None,
      }
      pred_xyz = sample_results["final_X"][i]

      # Perform resampling
      resample_xyz, _ = resample_loop(temp_protein, pred_xyz, **resample_args)

      # Update the coordinates
      sample_results["final_X"][i] = resample_xyz
  print(f'Packed {"and resampled " if use_resampling else ""}{pdb_name} ({batch.S.numel()} residues) using {model} in {time.time() - t0:.3f} sec.')

  # Store prediction and processed pdb string
  pdb_str = pdbs_from_prediction(sample_results)[0]
  predictions[pdb_name] = {'results': sample_results, 'pdb_str': pdb_str}
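
Step 3 describes `chi_temperature` as a trade-off: higher values give more diverse conformations that the model is less confident about, and T=0.0 gives the best results. The sketch below illustrates the generic idea of temperature-scaled sampling over discrete bins using only the standard library. It is an illustration of the general technique under that assumption, not PIPPack's `sample_epoch` implementation; the function name `sample_bin` and the toy logits are made up for this example.

```python
import math
import random


def sample_bin(logits, temperature, rng=random):
    """Pick a bin index from logits; a temperature of 0 takes the argmax."""
    if temperature <= 0.0:
        return max(range(len(logits)), key=lambda i: logits[i])
    scaled = [value / temperature for value in logits]
    top = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(value - top) for value in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


toy_logits = [2.0, 1.0, 0.1]  # hypothetical scores for three bins
rng = random.Random(0)

# T = 0.0 is deterministic: the highest-scoring bin is always chosen.
print(sample_bin(toy_logits, 0.0))

# T = 1.0 spreads samples across bins in proportion to softmax(logits).
counts = [0, 0, 0]
for _ in range(1000):
    counts[sample_bin(toy_logits, 1.0, rng)] += 1
print(counts)
```

Lowering the temperature toward zero concentrates the samples on the top bin, which matches the recommendation to use T=0.0 when a single best packing is wanted.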


In [None]:
# @title ## **Step 5: Visualize Packed Nanobody**
!pip install py3Dmol > /dev/null
import py3Dmol
from pathlib import Path
import math

# Configuration parameters
Output_pdb_name = 'design_15_model'  # @param {type:"string"}
hotspot_res = ""  # @param {type:"string"}
# @markdown Enter the hotspot residues separated by commas. Leave blank for none.

# Load PDB file
pdb_file_path = f"../{Output_pdb_name}.pdb"
pdb_content = Path(pdb_file_path).read_text()

# Dictionary mapping three-letter amino acid codes to one-letter codes
AA_THREE_TO_ONE = {
    'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C',
    'GLN': 'Q', 'GLU': 'E', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I',
    'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'PHE': 'F', 'PRO': 'P',
    'SER': 'S', 'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V'
}

# Parse interacting residues between chain H (nanobody) and chain T (target)
atoms_H = []  # H chain (nanobody heavy chain)
atoms_T = []  # T chain (target)
sequences = {}  # Residue names of each chain, keyed by residue number

# Use `chain_id` rather than `chain` so the Step 3 `chain` setting is not overwritten
for line in pdb_content.splitlines():
    if line.startswith(("ATOM  ", "HETATM")):
        chain_id = line[21]
        resi = int(line[22:26])
        resn = line[17:20].strip()
        x, y, z = map(float, (line[30:38], line[38:46], line[46:54]))

        # Add the residue to its chain's sequence
        if chain_id not in sequences:
            sequences[chain_id] = {}
        if resi not in sequences[chain_id]:
            sequences[chain_id][resi] = resn

        if chain_id == 'H':  # Nanobody heavy chain
            atoms_H.append((resi, resn, chain_id, (x, y, z)))
        elif chain_id == 'T':  # Target chain
            atoms_T.append((resi, resn, chain_id, (x, y, z)))

# Print the primary sequences
print("🔵 Primary Sequences:")
for chain_id in sorted(sequences.keys()):
    print(f"Chain {chain_id}:")
    residues = sorted(sequences[chain_id].items())
    seq = ""
    for resi, resn in residues:
        # Convert the three-letter code to a one-letter code
        one_letter = AA_THREE_TO_ONE.get(resn, 'X')
        seq += one_letter

    # Print the formatted sequence (at most 80 characters per line)
    for i in range(0, len(seq), 80):
        print(f"   {seq[i:i+80]}")
    print()

# Find residues in chain H within 6Å of chain T
interacting_residues = set()
for resi_H, resn_H, chain_H, coords_H in atoms_H:
    for resi_T, resn_T, chain_T, coords_T in atoms_T:
        if math.dist(coords_H, coords_T) <= 6.0:
            interacting_residues.add(resi_H)
            break

# Print interacting residues
if interacting_residues:
    print("🟡 Interacting residues on chain H (within 6Å of chain T):")
    # Get residue names for interacting residues
    residue_names = {}
    for resi, resn, _, coords in atoms_H:  # chain field unused; avoids overwriting `chain` from Step 3
        if resi in interacting_residues:
            residue_names[resi] = resn

    for resi in sorted(interacting_residues):
        one_letter = AA_THREE_TO_ONE.get(residue_names[resi], 'X')
        print(f"  {one_letter}{resi} ({residue_names[resi]}{resi})")
else:
    print("⚠️ No H-chain residues found within 6Å of chain T")

# Create visualization
view = py3Dmol.view(width=800, height=600)
view.addModel(pdb_content, 'pdb')

# Style the visualization
view.setStyle({'chain':['H']}, {'cartoon': {'color': 'steelblue'}})  # Heavy chain
view.setStyle({'chain':['L']}, {'cartoon': {'color': 'forestgreen'}})  # Light chain
view.addSurface(py3Dmol.VDW, {'color': 'lightgrey', 'opacity': 0.75}, {'chain': 'T'})  # Target chain

# Add hotspot residues
if hotspot_res:
    hotspot_residues = [int("".join(filter(str.isdigit, r))) for r in hotspot_res.split(',')]
    view.addStyle({'chain':'T','resi': hotspot_residues},  # Target chain
                  {'sphere': {'color': 'red', 'radius': 1.5}})

# Highlight entire interacting residues on chain H (not just atoms within threshold)
if interacting_residues:
    view.addStyle({'chain': 'H', 'resi': list(interacting_residues)},
                  {'stick': {'color': 'yellow'}})

view.zoomTo()
view.show()

# Print legend
print("\nLegend:")
print("🟦 Antibody (chain H, heavy chain) - Steel blue")
print("🟩 Antibody (chain L, light chain) - Forest green")
print("⬜ Target (chain T) - Light grey surface")
print("🔴 Epitope Hotspots - Red spheres")
print("🟡 Interacting Residues (entire residue) - Yellow sticks")

In [8]:
# @title ## **Step 6: Save Predictions**

# ---------- Save the packed protein ---------- #
# @markdown You can save the packed proteins using the `pdb_name`. That is, if you saw "Packed <u>1PGA</u> (56 residues) using PIPPack (ensemble) in  0.856 sec." as an output from Step 4, then `pdb_name='1PGA'`, as indicated by the underline.

# @markdown The output file will have the name: "{pdb_name}_{model_name}.pdb" and can be accessed by clicking the folder icon on the left.


# @markdown ### **Individual Download**
# Get the input PDB name for output file naming
input_pdb_name = os.path.basename(pdb_file).replace('.pdb', '')
pdb_name = input_pdb_name # @param {type:"string"}
# @markdown Don't worry if running this cell results in a `ValueError`. Read the message, as it lists the valid options for `pdb_name`.

# @markdown ### **Download all**
# @markdown Enable this option and run this cell to download all of the predictions.

download_all = True # @param {type: "boolean"}

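if download_all:
  # Save every prediction made in Step 4
  output_files = list(predictions.keys())
elif pdb_name not in predictions:
  raise ValueError(f"'{pdb_name}' is not a valid pdb_name. Try one of {list(predictions.keys())}.")
else:
  output_files = [pdb_name]

# Use `out_name` so the `pdb` setting from Step 3 is not overwritten
model_tag = model.replace('(', '').replace(')', '').replace(' ', '_').lower()
for out_name in output_files:
  with open(f"/content/{out_name}_{model_tag}.pdb", 'w') as f:
    f.write(predictions[out_name]['pdb_str'])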
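
The output name described above combines `pdb_name` with a simplified form of the model label. The sketch below mirrors the string handling used in the cell so you can predict the file name in advance. The helper name `output_path` is hypothetical and not part of PIPPack.

```python
def output_path(pdb_name, model_label, out_dir="/content"):
    """Build the saved file path from a PDB name and a model label."""
    # Drop parentheses, turn spaces into underscores, and lowercase the label.
    tag = model_label.replace("(", "").replace(")", "").replace(" ", "_").lower()
    return f"{out_dir}/{pdb_name}_{tag}.pdb"


print(output_path("1PGA", "PIPPack (ensemble)"))  # /content/1PGA_pippack_ensemble.pdb
print(output_path("1PGA", "PIPPack (model 2)"))   # /content/1PGA_pippack_model_2.pdb
```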

### **License**
The source code for PIPPack, including licensing information, can be found [here](https://github.com/Kuhlman-Lab/PIPPack).

### **Citation Information**

If you use PIPPack in your own research, please cite [this preprint](https://www.biorxiv.org/content/10.1101/2023.08.03.551328):

Randolph NZ, Kuhlman B. Invariant point message passing for protein side chain packing. bioRxiv. August 3, 2023. doi:10.1101/2023.08.03.551328