<a href="https://colab.research.google.com/github/jtrinquier/SoftAlign/blob/main/Structure_Search_SoftAlign.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# @title Import libraries
! pip install Bio
import jax
!  pip install git+https://github.com/deepmind/dm-haiku
import haiku as hk
import jax.numpy as jnp
from jax import vmap
import numpy as np
import time
import numpy as np
import time
import os
! git clone https://github.com/jtrinquier/SoftAlign.git
import sys
softalign_path = os.path.join(os.getcwd(), 'SoftAlign')

# Make the SoftAlign repository cloned above importable so the project modules
# imported right below resolve; the membership check keeps re-runs of this cell
# from appending the same directory to sys.path more than once.
if softalign_path not in sys.path:
    sys.path.append(softalign_path)
import ENCODING as enco
import Score_align as score_
import utils
import Input_MPNN as inp

Collecting Bio
  Downloading bio-1.8.0-py3-none-any.whl.metadata (5.7 kB)
Collecting biopython>=1.80 (from Bio)
  Downloading biopython-1.85-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Collecting gprofiler-official (from Bio)
  Downloading gprofiler_official-1.0.0-py3-none-any.whl.metadata (11 kB)
Collecting mygene (from Bio)
  Downloading mygene-3.2.2-py2.py3-none-any.whl.metadata (10 kB)
Collecting biothings-client>=0.2.6 (from mygene->Bio)
  Downloading biothings_client-0.4.1-py3-none-any.whl.metadata (10 kB)
Downloading bio-1.8.0-py3-none-any.whl (321 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m321.1/321.1 kB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading biopython-1.85-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m76.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading gprofiler_official-1.0.0-py3-none-any.whl (9.3

In [2]:
import search

## 🔍 Input Options for Structure Preprocessing

You can choose one of the following two input sources:

1. **SCOPE Database (Precomputed Inputs)**  
   Enable the checkbox to download and load preprocessed structural inputs from the SCOPE database. This is useful for benchmarking or testing.

2. **Custom PDB Folder**  
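   If you want to use your own protein structures, disable the SCOPE option and provide the path to your folder containing `.pdb` files. The script will process all PDBs in that folder using `Input_MPNN`. Optionally, also provide a CSV file (`chain_ids_file`) that selects the chain for each structure. Use one `pdb_name,chain_id` pair per line (e.g. `1abc,B`), where `pdb_name` is the file name without the `.pdb` extension. Files that are not listed use chain `A`.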
   





In [3]:
# @title Choose input source: SCOPE database or Custom PDB folder
use_scope_database = True  # @param {type:"boolean"}
pdb_folder_path = "pdb_files"  # @param {type:"string"}
chain_ids_file = ""  # @param {type:"string"}
# use_scope_database: True downloads the precomputed SCOPE inputs; False parses the .pdb files in pdb_folder_path.
# chain_ids_file: optional CSV of "pdb_name,chain_id" rows; leave empty to use chain 'A' for every file.

import os
import pickle
import csv
import requests

def load_chain_ids(chain_file_path):
    """Load a ``{pdb_name: chain_id}`` mapping from a two-column CSV file.

    Each row is ``pdb_name,chain_id``. Surrounding whitespace is stripped from
    both fields, extra columns are ignored, and the following rows are skipped:
    rows with fewer than two columns (including blank lines), rows with an empty
    name, and rows whose name starts with ``#`` (comments / header lines).

    Args:
        chain_file_path: Path to the CSV file.

    Returns:
        dict mapping each PDB name (exactly as written in the file) to its
        chain ID. When a name appears more than once, the last row wins.
    """
    chain_map = {}
    # The csv module expects files opened with newline="" so quoted fields
    # and line endings are handled by the reader itself.
    with open(chain_file_path, "r", newline="") as f:
        for row in csv.reader(f, delimiter=","):
            if len(row) < 2:
                continue
            pdb_name = row[0].strip()
            if not pdb_name or pdb_name.startswith("#"):
                continue
            chain_map[pdb_name] = row[1].strip()
    return chain_map

def process_pdb_folder(pdb_folder, chain_file=None, default_chain='A'):
    """Parse every ``.pdb`` file in a folder into SoftAlign/MPNN inputs.

    Files are visited in sorted name order so the returned dict, and hence the
    order of everything derived from it downstream, is reproducible
    (``os.listdir`` order is arbitrary). Files that fail to parse are reported
    and skipped.

    Args:
        pdb_folder: Directory containing the structure files. The extension
            check is case-insensitive (``.pdb`` or ``.PDB``).
        chain_file: Optional CSV path read with ``load_chain_ids``; a falsy
            value means every file uses ``default_chain``. Entries may name a
            file with or without its extension.
        default_chain: Chain used for files not listed in ``chain_file``.

    Returns:
        dict mapping the file name (including extension) to the
        ``(coords, mask, chain, res)`` tuple returned by
        ``inp.get_inputs_mpnn``.
    """
    data = {}
    chain_map = load_chain_ids(chain_file) if chain_file else {}
    for filename in sorted(os.listdir(pdb_folder)):
        # splitext removes only the trailing extension; str.replace(".pdb", "")
        # would also delete an inner ".pdb" and break the chain-file lookup.
        stem, ext = os.path.splitext(filename)
        if ext.lower() != ".pdb":
            continue
        pdb_path = os.path.join(pdb_folder, filename)
        chain = chain_map.get(stem, chain_map.get(filename, default_chain))
        try:
            coords, mask, chain_, res = inp.get_inputs_mpnn(pdb_path, chain=chain)
        except Exception as e:
            # Best-effort: one malformed file should not abort the whole folder.
            print(f"Error processing {pdb_path}: {e}")
            continue
        data[filename] = (coords, mask, chain_, res)
        print(f"Processed {filename} using chain {chain}")

    return data

if use_scope_database:
    try:
        import gdown
    except ImportError:
        !pip install -q gdown
        import gdown

    gdown.download(id="1DFWcUgPukTxWGPUxaeTM1kNEVNCkRgbO", output="dicti_inputs_SCOPE_colab", quiet=False)
    with open("dicti_inputs_SCOPE_colab", 'rb') as f:
        dicti_inputs = pickle.load(f)
    print("Loaded SCOPE database inputs.")
else:
    from Bio.PDB import PDBList

    dicti_inputs = process_pdb_folder(pdb_folder_path, chain_file=chain_ids_file)
    print("Processed custom PDB folder inputs.")



Downloading...
From (original): https://drive.google.com/uc?id=1DFWcUgPukTxWGPUxaeTM1kNEVNCkRgbO
From (redirected): https://drive.google.com/uc?id=1DFWcUgPukTxWGPUxaeTM1kNEVNCkRgbO&confirm=t&uuid=4e421af2-2641-4394-8ef8-f25ec988c264
To: /content/dicti_inputs_SCOPE_colab
100%|██████████| 236M/236M [00:01<00:00, 157MB/s]


Loaded SCOPE database inputs.


## Model Selection: Alignment Strategy

Select one of the two available alignment models for structural comparison:

1. **Smith-Waterman**  


2. **Softmax-Based**  


Use the dropdown menu to select your model, and the corresponding parameters will be loaded automatically.



In [4]:
# @title Choose Model Type
model_type = "Softmax"  # @param ["Smith-Waterman", "Softmax"]
params_path_sw = "./SoftAlign/models/CONT_SW_05_T_3_1"
params_path_sft = "./SoftAlign/models/CONT_SFT_06_T_3_1"

import pickle

# Checkpoint file for each alignment model offered in the dropdown above.
model_params_paths = {
    "Smith-Waterman": params_path_sw,
    "Softmax": params_path_sft,
}
if model_type not in model_params_paths:
    raise ValueError("Invalid model type selected.")
params_path = model_params_paths[model_type]

# Use a context manager so the file handle is closed promptly; the bare
# pickle.load(open(...)) left it open until garbage collection.
# NOTE: pickle can execute arbitrary code; only load trusted checkpoint files.
with open(params_path, "rb") as f:
    params = pickle.load(f)
print(f"Loaded parameters for {model_type} model.")


Loaded parameters for Softmax model.


  params = pickle.load(open(params_path, "rb"))


# ENCODING

##  Encoding Structures Using MPNN

Before performing any search or alignment, we **embed all protein structures** using a Message Passing Neural Network (MPNN).

- This step transforms raw 3D coordinates and masks into **learned feature representations**.
- These embeddings are stored in memory (`dicti_encodings`) and reused during the search process.
- The encoding is **performed once**, which makes future computations faster and more efficient.




In [5]:
# @title Prepare data
# @title Default title text
# PRNG key for the Haiku encoder; the "Create encodings" cell replaces it with PRNGKey(42) before encoding.
key = jax.random.PRNGKey(0)

# Encoder hyperparameters; they are bound as the default arguments of enco_ below.
# NOTE(review): presumably these must match the architecture the loaded
# checkpoint (params) was trained with — confirm before changing them.
num_layers = 3          # number of encoder layers (num_encoder_layers)
num_neighbors = 64      # passed to the encoder as k_neighbors
encoding_dim = 64       # width used for node, edge and hidden features
categorical = False     # True selects enco.ENCODING_KMEANS_SEQ instead of enco.ENCODING
nb_clusters = 20        # only used when categorical is True


def enco_(x1, node_features=encoding_dim,
          edge_features=encoding_dim, hidden_dim=encoding_dim,
          num_encoder_layers=num_layers,
          k_neighbors=num_neighbors, categorical=categorical, nb_clusters=nb_clusters):
    """Construct the SoftAlign structure encoder and run it on ``x1``.

    Intended to be wrapped with ``hk.transform``. Default arguments are taken
    from the module-level hyperparameters at definition time.

    Args:
        x1: Encoder input, forwarded unchanged to the encoder module.
        node_features, edge_features, hidden_dim: Feature widths.
        num_encoder_layers: Number of encoder layers.
        k_neighbors: Neighbour count passed to the encoder.
        categorical: If True, use ``enco.ENCODING_KMEANS_SEQ`` (which also
            takes ``nb_clusters``); otherwise use ``enco.ENCODING``.
        nb_clusters: Cluster count, only used when ``categorical`` is True.

    Returns:
        Whatever the chosen encoder module returns for ``x1``.
    """
    if not categorical:
        encoder = enco.ENCODING(node_features, edge_features, hidden_dim,
                                num_encoder_layers, k_neighbors)
    else:
        encoder = enco.ENCODING_KMEANS_SEQ(node_features, edge_features, hidden_dim,
                                           num_encoder_layers, k_neighbors,
                                           nb_clusters=nb_clusters)
    return encoder(x1)

# Haiku transform of enco_, giving the (init, apply) pair used with the loaded params.
ENCO = hk.transform(enco_)
@jax.jit
def enco_fast(params,key,input_data):
  """jit-compiled wrapper around ``ENCO.apply(params, key, input_data)``."""
  return ENCO.apply(params,key,input_data)

X1s, mask1s, chain1s, res1s = [], [], [], []
id1s, l1 = [], []

# Unpack each structure's (coords, mask, chain, res) tuple, keeping element 0
# of every field, and record the structure id and its residue count.
for pr1, (_X1, _mask1, _chain1, _res1) in dicti_inputs.items():
    id1s.append(pr1)
    X1s.append(_X1[0])
    mask1s.append(_mask1[0])
    chain1s.append(_chain1[0])
    res1s.append(_res1[0])
    l1.append(len(_X1[0]))

max_len = max(l1)

# NOTE: If max_size is too large, consider splitting your data into smaller chunks

print(f"max_size set to: {max_len}")





max_size set to: 1419


In [6]:
# @title Create encodings
ENCOO = hk.transform(enco_)
encodings = []
# Set batch size
bs = 10
num_samples = len(X1s)


def _as_object_array(items):
    """Return a 1-D object array with one entry per element of ``items``.

    ``np.array(items, dtype=object)`` only gives a 1-D array when the items
    have different leading lengths. If they all share one (a single structure,
    or equal-length chains), NumPy stacks them into a multi-dimensional object
    array or raises. Filling a preallocated array element by element always
    keeps exactly one entry per structure.
    """
    out = np.empty(len(items), dtype=object)
    for i, item in enumerate(items):
        out[i] = item
    return out


# Convert lists to 1-D object arrays so batches can be sliced per structure.
X1s = _as_object_array(X1s)
mask1s = _as_object_array(mask1s)
res1s = _as_object_array(res1s)
chain1s = _as_object_array(chain1s)

key = jax.random.PRNGKey(42)

beg = time.time()

# One loop handles the full batches and the trailing partial batch.
for start in range(0, num_samples, bs):
    stop = min(start + bs, num_samples)
    X_b, mask_b, res_b, chain_b = X1s[start:stop], mask1s[start:stop], res1s[start:stop], chain1s[start:stop]
    # pad_ takes two input sets; the same batch is passed for both and only
    # the second padded copy is encoded.
    X1, mask1, res1, chain1, X2, mask2, res2, chain2, lens = utils.pad_(
        X_b, mask_b, res_b, chain_b, X_b, mask_b, res_b, chain_b, max_len
    )
    input_data = X2, mask2, res2, chain2
    # Full batches use the jitted encoder. As before, the trailing short batch
    # goes through the un-jitted apply rather than jit-compiling a second batch shape.
    apply_fn = enco_fast if stop - start == bs else ENCOO.apply
    encodings.extend(apply_fn(params, key, input_data))

print(f"Encoding time: {time.time() - beg:.2f} s")

# Map each structure id to its encoding, trimmed back to its true (unpadded) length.
dicti_encodings = {
    struct_id: enc[:length, :] for struct_id, enc, length in zip(id1s, encodings, l1)
}

241.62516736984253


# One-vs-All Search

If you're only interested in evaluating a single query structure against all the others in your dataset, you can run this section.

It will compute the scores and save them in a .csv file for easy analysis.

📄 Output: A CSV file containing the scores for all structures compared to your query.

In [7]:
thresholds = np.arange(100, max_len + 100, 100)
print(thresholds)
reusable_target_data = search.setup_target_data(dicti_encodings, dicti_inputs, thresholds)
query_id = "d2dixa1" # @param
enc = dicti_encodings.get(query_id)

if enc is not None:
    l_query = enc.shape[0]
    # A single query is padded only to its own length.
    l_query_pad = l_query

    print(f"Processing single query: {query_id} (length={l_query}), using padding {l_query_pad}")

    try:
        search.compute_scores_for_query(
            query_id=query_id,
            target_data=reusable_target_data,
            model_type=model_type,
            l_query_pad=l_query_pad,
        )
    except Exception as e:
        print(f"Error processing {query_id}: {e}")
else:
    print(f"Query ID '{query_id}' not found in dicti_encodings.")


[ 100  200  300  400  500  600  700  800  900 1000 1100 1200 1300 1400
 1500]
--- Starting One-Time Target Setup ---
Dispatching pre-processing for all buckets...
Waiting for data to be moved to device...
--- One-Time Setup Finished in 73.36 seconds ---
Processing single query: d2dixa1 (length=73), using padding 73

Processing query: d2dixa1
(100,) (73,)
(73, 100) (73, 100)
(100,) (73,)
(73, 100) (73, 100)
(200,) (73,)
(73, 200) (73, 200)
(200,) (73,)
(73, 200) (73, 200)
(300,) (73,)
(73, 300) (73, 300)
(300,) (73,)
(73, 300) (73, 300)
(400,) (73,)
(73, 400) (73, 400)
(400,) (73,)
(73, 400) (73, 400)
(500,) (73,)
(73, 500) (73, 500)
(500,) (73,)
(73, 500) (73, 500)
(600,) (73,)
(73, 600) (73, 600)
(700,) (73,)
(73, 700) (73, 700)
(800,) (73,)
(73, 800) (73, 800)
(900,) (73,)
(73, 900) (73, 900)
(1000,) (73,)
(73, 1000) (73, 1000)
(1100,) (73,)
(73, 1100) (73, 1100)
(1200,) (73,)
(73, 1200) (73, 1200)
(1300,) (73,)
(73, 1300) (73, 1300)
(1400,) (73,)
(73, 1400) (73, 1400)
(1500,) (73,)


# 🔄 All-vs-All Search

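This section performs a full all-vs-all comparison: each encoded query is evaluated against all others in the dataset using `search.compute_scores_for_query`. Queries are grouped into 100-residue length buckets, and each query is padded to the upper bound of its bucket.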

Each query is treated as a search input, and scores are computed against a shared target set (`reusable_target_data`).


In [8]:
# Queries are bucketed by length in 100-residue steps; each bucket is padded
# to its upper bound so one compiled shape serves every query in it.
thresholds = np.arange(100, max_len + 100, 100)
print(thresholds)
reusable_target_data = search.setup_target_data(dicti_encodings, dicti_inputs, thresholds)

for threshold in thresholds:
    # Bucket-local names: assigning to max_len here would overwrite the global
    # maximum structure length, silently changing thresholds/padding of any cell
    # that is re-run afterwards.
    bucket_min = threshold - 100
    bucket_max = threshold
    l_query_pad = threshold  # ou autre logique si besoin

    print(f"\n=== Queries bt {bucket_min} and {bucket_max}  ===")
    compt = 0

    for query_id, enc in dicti_encodings.items():
        l_query = enc.shape[0]

        # Half-open interval (bucket_min, bucket_max]: a length that is an exact
        # multiple of 100 falls in exactly one bucket instead of being run twice.
        if bucket_min < l_query <= bucket_max:
            compt += 1
            print(f"Processing query: {query_id} (length={l_query}), count={compt}")
            try:
                search.compute_scores_for_query(
                    query_id=query_id,
                    target_data=reusable_target_data,
                    model_type=model_type,
                    l_query_pad=l_query_pad
                )
            except Exception as e:
                print(f"Error processing {query_id}: {e}")


[ 100  200  300  400  500  600  700  800  900 1000 1100 1200 1300 1400
 1500]
--- Starting One-Time Target Setup ---
Dispatching pre-processing for all buckets...


KeyboardInterrupt: 