# OMA scripting
# ===============================

In [1]:
# system dependencies
import sys
import logging
import os
import time

# library dependencies
import click
import duckdb as ddb
import pandas as pd
import pyhmmer
from sklearn.utils import resample
from tqdm import tqdm


# local dependencies
import pairpro.utils as pp_utils
# blast
import pairpro.user_blast as pp_up
# hmmer
import pairpro.hmmer as pp_hmmer
# structure
import pairpro.structures as pp_structures
# ML
from pairpro.train_val_wrapper import train_val_wrapper

In [2]:
####################
### PATHS & VARS ###
####################
# db Paths
TEST_DB_PATH = '../tmp/oma.db' 

# BLAST Paths
BLAST_OUTPUT_DIR = '../data/protein_pairs/blast_output/'

# HMMER Paths
HMM_PATH = '../data/pfam/Pfam-A.hmm'  # ./Pfam-A.hmm
PRESS_PATH = '../data/pfam/pfam'
HMMER_OUTPUT_DIR = '../data/protein_pairs/'
PARSE_HMMER_OUTPUT_DIR = '../data/protein_pairs/parsed_hmmer_output/'
WORKER_WAKE_UP_TIME = 25  # this is to ensure that if a worker that is about to be shut down due to previous task completetion doesn't actually start running

# Structure Paths
STRUCTURE_DIR = '../data/structures/'
STRUCTURE_OUTPUT_DIR = '../data/protein_pairs/structures/'

# ML Paths
MODEL_PATH = '../data/models/'

In [3]:
##################
# Aux. functions #
##################

def auto_balance_data(dataframe, target_column):
    """
    Automatically balances the dataframe based on the label distribution in the target column.
    Applies under-sampling, over-sampling, or a combination based on the label's distribution.

    Args:
        dataframe (pandas.DataFrame): The training dataframe.
        target_column (str): The column whose labels should be balanced.

    Returns:
        pandas.DataFrame: A new DataFrame with balanced labels.
    """
    # Count the frequency of each class
    class_counts = dataframe[target_column].value_counts()
    max_count = class_counts.max()
    min_count = class_counts.min()
    
    # Determine the ratio of the largest class to the smallest class
    ratio = max_count / min_count
    print(f'ratio of max to min: {ratio}')

    # Decide the strategy based on the ratio
    if ratio < 1.5:
        # If ratio is small (fairly balanced already), over-sample the minority
        over_sampled_dfs = []
        for label in class_counts.index:
            label_df = dataframe[dataframe[target_column] == label]
            resampled_df = resample(label_df, replace=True, n_samples=max_count)
            over_sampled_dfs.append(resampled_df)
        balanced_df = pd.concat(over_sampled_dfs)
    else:
        # If the imbalance is significant, under-sample the majority and over-sample the minority
        under_sampled_dfs = []
        over_sampled_dfs = []
        for label in class_counts.index:
            label_df = dataframe[dataframe[target_column] == label]
            if class_counts[label] == max_count:
                # Under-sample the majority class
                resampled_df = resample(label_df, replace=False, n_samples=min_count)
                under_sampled_dfs.append(resampled_df)
            else:
                # Over-sample the minority class
                resampled_df = resample(label_df, replace=True, n_samples=max_count)
                over_sampled_dfs.append(resampled_df)
        balanced_df = pd.concat(under_sampled_dfs + over_sampled_dfs)

    return balanced_df

## Actual script w/o click stuff

In [4]:
##### database construction #####

con = ddb.connect(TEST_DB_PATH, read_only=False) # create a database. Has to be read_only=False

# create main table
con.execute("""CREATE OR REPLACE TABLE OMA_main AS 
            (
            SELECT query_id, subject_id, pair_id, query, subject 
            FROM
            (
                SELECT protein1_uniprot_id AS query_id, protein2_uniprot_id AS subject_id, pair_id, protein1_sequence AS query, protein2_sequence AS subject
                FROM combined_pairs
            ) 
            );""")

con.commit() # commit the changes. Otherwise, the table will not be created.

# create a table for proteins in pairs
con.execute("""CREATE OR REPLACE TABLE processed_proteins AS 
    (
        SELECT DISTINCT pid, protein_seq
        FROM 
        (
            SELECT protein1_uniprot_id AS pid, protein2_sequence as protein_seq
            FROM combined_pairs
            UNION ALL
            SELECT protein2_uniprot_id AS pid, protein2_sequence as protein_seq
            FROM combined_pairs
        )   
    );""")

con.commit() # commit the changes. Otherwise, the table will not be created.

<duckdb.DuckDBPyConnection at 0x112370f30>

In [14]:
con.close()

**Quick comment**:
This works as exepcted. Let's assume. We want to BLAST via Click. We will think about synergy b/w modules later, i.e., ml_feature_list as a way to keep track of choices.

### BLAST

In [13]:
print('Starting to run BLAST')
dataframe_for_blast = con.execute("SELECT * FROM OMA_main LIMIT 2000").df()
print(f"DataFrame shape before BLAST processing: {dataframe_for_blast.shape}")

# run blast
s_time = time.time()
print('Starting to run BLAST')
blast_df = pp_up.blast_pairs(dataframe_for_blast, cpus=4)
print(f'BLAST completed in {time.time()-s_time} seconds')

# save blast results to csv
blast_df.to_csv(f'{BLAST_OUTPUT_DIR}blast_output.csv', index=False)

Starting to run BLAST
DataFrame shape before BLAST processing: (2000, 5)
Starting to run BLAST
Found and skipped 0 invalid row(s) containing invalid amino acid sequences.


BLAST completed in 4.8534321784973145 seconds


### Testing DB manuplations and BLAST

In [17]:
con.execute("""CREATE OR REPLACE TEMP TABLE blast_results AS 
                    SELECT * FROM read_csv_auto('../data/protein_pairs/blast_output/blast_output.csv', HEADER=TRUE)""")

<duckdb.DuckDBPyConnection at 0x10af7a5f0>

In [18]:
# sanity check
con.execute("DESCRIBE blast_results").df()

Unnamed: 0,column_name,column_type,null,key,default,extra
0,pair_id,VARCHAR,YES,,,
1,query_id,VARCHAR,YES,,,
2,subject_id,VARCHAR,YES,,,
3,bit_score,DOUBLE,YES,,,
4,local_gap_compressed_percent_id,DOUBLE,YES,,,
5,scaled_local_query_percent_id,DOUBLE,YES,,,
6,scaled_local_symmetric_percent_id,DOUBLE,YES,,,
7,query_align_len,BIGINT,YES,,,
8,query_align_cov,DOUBLE,YES,,,
9,subject_align_len,BIGINT,YES,,,


In [19]:
# sanity check
con.execute("SELECT * FROM OMA_main LIMIT 5").df()

Unnamed: 0,query_id,subject_id,pair_id,query,subject
0,Q6GG31,A0A0L9Z481,clean_1,MFKFNEDEENLKCSFCGKDQDQVKKLVAGSGVYICNECIELCSEIV...,MSKLDEKKQLKCSFCGKTQDQVRRLIAGPGVYICDECIELCSEIIN...
1,Q6GG31,A0A2T7B9D1,clean_2,MFKFNEDEENLKCSFCGKDQDQVKKLVAGSGVYICNECIELCSEIV...,MTDKRKDSSGKLLYCSFCGKSQHEVRKLIAGPSVYICDECVDLCND...
2,A0A0M9XI34,A0A6L8P192,clean_3,MADTVKTTRETAGTPAATHWHQRADRRGGRGTRTLRVRTSAVLVAA...,MLANPEKQTEVIHYEKIPSGFSIMWREFRKDKLAMFSLFFLALILI...
3,Q4L904,A0A643CKU9,clean_4,MFKIGNLELQSRLLLGTGKFENEDVQTEAIKASETNVLTFAVRRMN...,MARRGNVWNVYGAELNSRLLLGSALYPSPEVLKQAILNSGTEVVTV...
4,Q4L8Y1,A0A427NXW7,clean_5,MTELNGRVAIITGASSGIGAATAKALEKQGVKVVLAGRSHDKLNTL...,MTAPLEGQVAIVTGGARGIGRGIALTLAGAGADILLADLLDDALDA...


In [21]:
columns_to_add = [("local_gap_compressed_percent_id", "DOUBLE"),
                          ("scaled_local_query_percent_id", "DOUBLE"),
                          ("scaled_local_symmetric_percent_id", "DOUBLE"),
                          ("query_align_len", "DOUBLE"),
                          ("query_align_cov", "DOUBLE"),
                          ("subject_align_len", "DOUBLE"),
                          ("subject_align_cov", "DOUBLE"),
                          ("bit_score", "DOUBLE")]

for column_name, column_type in columns_to_add:
    con.execute(f"""
        ALTER TABLE OMA_main
        ADD COLUMN {column_name} {column_type}
    """)

In [22]:
con.execute("SELECT * FROM OMA_main LIMIT 5").df()

Unnamed: 0,query_id,subject_id,pair_id,query,subject,local_gap_compressed_percent_id,scaled_local_query_percent_id,scaled_local_symmetric_percent_id,query_align_len,query_align_cov,subject_align_len,subject_align_cov,bit_score
0,Q6GG31,A0A0L9Z481,clean_1,MFKFNEDEENLKCSFCGKDQDQVKKLVAGSGVYICNECIELCSEIV...,MSKLDEKKQLKCSFCGKTQDQVRRLIAGPGVYICDECIELCSEIIN...,,,,,,,,
1,Q6GG31,A0A2T7B9D1,clean_2,MFKFNEDEENLKCSFCGKDQDQVKKLVAGSGVYICNECIELCSEIV...,MTDKRKDSSGKLLYCSFCGKSQHEVRKLIAGPSVYICDECVDLCND...,,,,,,,,
2,A0A0M9XI34,A0A6L8P192,clean_3,MADTVKTTRETAGTPAATHWHQRADRRGGRGTRTLRVRTSAVLVAA...,MLANPEKQTEVIHYEKIPSGFSIMWREFRKDKLAMFSLFFLALILI...,,,,,,,,
3,Q4L904,A0A643CKU9,clean_4,MFKIGNLELQSRLLLGTGKFENEDVQTEAIKASETNVLTFAVRRMN...,MARRGNVWNVYGAELNSRLLLGSALYPSPEVLKQAILNSGTEVVTV...,,,,,,,,
4,Q4L8Y1,A0A427NXW7,clean_5,MTELNGRVAIITGASSGIGAATAKALEKQGVKVVLAGRSHDKLNTL...,MTAPLEGQVAIVTGGARGIGRGIALTLAGAGADILLADLLDDALDA...,,,,,,,,


Cool! Now it works. for-loop of tuples ftw!

In [23]:
update_columns = ["local_gap_compressed_percent_id",
                          "scaled_local_query_percent_id",
                          "scaled_local_symmetric_percent_id",
                          "query_align_len",
                          "query_align_cov",
                          "subject_align_len",
                          "subject_align_cov",
                          "bit_score"]
        
for column in update_columns:
    con.execute(f"""
                UPDATE OMA_main
                SET {column} = (
                    SELECT b.{column}
                    FROM blast_results AS b
                    WHERE b.query_id = OMA_main.query_id
                    AND b.subject_id = OMA_main.subject_id
                    AND b.pair_id = OMA_main.pair_id
                    )""")

In [24]:
con.execute("SELECT * FROM OMA_main LIMIT 5").df()

Unnamed: 0,query_id,subject_id,pair_id,query,subject,local_gap_compressed_percent_id,scaled_local_query_percent_id,scaled_local_symmetric_percent_id,query_align_len,query_align_cov,subject_align_len,subject_align_cov,bit_score
0,Q6GG31,A0A0L9Z481,clean_1,MFKFNEDEENLKCSFCGKDQDQVKKLVAGSGVYICNECIELCSEIV...,MSKLDEKKQLKCSFCGKTQDQVRRLIAGPGVYICDECIELCSEIIN...,0.644186,0.635321,0.635321,436.0,0.986239,436.0,0.988532,1352.0
1,Q6GG31,A0A2T7B9D1,clean_2,MFKFNEDEENLKCSFCGKDQDQVKKLVAGSGVYICNECIELCSEIV...,MTDKRKDSSGKLLYCSFCGKSQHEVRKLIAGPSVYICDECVDLCND...,0.574118,0.568765,0.568765,429.0,0.983683,429.0,0.993007,1093.0
2,A0A0M9XI34,A0A6L8P192,clean_3,MADTVKTTRETAGTPAATHWHQRADRRGGRGTRTLRVRTSAVLVAA...,MLANPEKQTEVIHYEKIPSGFSIMWREFRKDKLAMFSLFFLALILI...,0.278317,0.275641,0.275641,312.0,0.974359,312.0,1.0,298.0
3,Q4L904,A0A643CKU9,clean_4,MFKIGNLELQSRLLLGTGKFENEDVQTEAIKASETNVLTFAVRRMN...,MARRGNVWNVYGAELNSRLLLGSALYPSPEVLKQAILNSGTEVVTV...,0.375,0.375,0.375,264.0,0.996212,264.0,1.0,375.0
4,Q4L8Y1,A0A427NXW7,clean_5,MTELNGRVAIITGASSGIGAATAKALEKQGVKVVLAGRSHDKLNTL...,MTAPLEGQVAIVTGGARGIGRGIALTLAGAGADILLADLLDDALDA...,0.321569,0.311787,0.311787,263.0,0.969582,263.0,0.980989,261.0


#### Quick tests for global alignment

In [13]:
print('Starting to run BLAST')
dataframe_for_blast = con.execute("SELECT * FROM OMA_main LIMIT 10000").df()
print(f"DataFrame shape before BLAST processing: {dataframe_for_blast.shape}")

# run blast
s_time = time.time()
print('Starting to run BLAST')
blast_df, con2 = pp_up.make_blast_df(dataframe_for_blast, cpus=4, path="../data/protein_pairs/blast_output/blast_db.db")
con2.close()
print(f'BLAST completed in {time.time()-s_time} seconds')

# save blast results to csv
blast_df.to_csv(f'{BLAST_OUTPUT_DIR}global_blast_output.csv', index=False)

Starting to run BLAST
DataFrame shape before BLAST processing: (10000, 5)
Starting to run BLAST
BLAST completed in 40.53030586242676 seconds


### HMMER

### Structure/FATCAT 2.0

#### Issues
* Issue 1: We need a way to get the PDB IDs from the Uniprot IDs.
    * Issue 1b: We need to update the structure module to have the above functionality.
* Issue 2: We need to make sure the module is working as expected as the previous data ingestion was purely meso-thermo protein pair data instead of orthologs.

##### Issue 1 Work

In [5]:
con.execute("SELECT query_id, subject_id FROM OMA_main").df()

Unnamed: 0,query_id,subject_id
0,Q6GG31,A0A0L9Z481
1,Q6GG31,A0A2T7B9D1
2,A0A0M9XI34,A0A6L8P192
3,Q4L904,A0A643CKU9
4,Q4L8Y1,A0A427NXW7
...,...,...
402324,A0A2A1KDX2,Q4AAJ3
402325,A0A4Q9W971,A0A380SBW5
402326,A0A0B6TQ36,A0A045JBZ6
402327,A0A0B8QRQ5,A0A663D6L2


The above are the uniprot IDs. So, these are our input data for the mapping of Uniprot IDs to PDB IDs.

__NOTE__: We need to think about the chain ID not just the PDB ID, i.e., PDB ID + Chain ID. So, we can accurately map sequences to structures.
Therefore, we can have good structure alignments via FATCAT 2.0.

###### Pontential Solutions
1. We can use the Uniprot API/SIFTS to get the mappings.
2. We can use the package `localpdb` to get the mappings.

In [6]:
from localpdb import PDB

In [8]:
# Setting up the local PDB database  (after setting up the local PDB database via the CLI)
pdb = PDB(db_path='../tmp/pdb.db', version='latest')

In [10]:
seq_df = con.execute("SELECT query, subject FROM OMA_main").df()
seq_df.head()

Unnamed: 0,query,subject
0,MFKFNEDEENLKCSFCGKDQDQVKKLVAGSGVYICNECIELCSEIV...,MSKLDEKKQLKCSFCGKTQDQVRRLIAGPGVYICDECIELCSEIIN...
1,MFKFNEDEENLKCSFCGKDQDQVKKLVAGSGVYICNECIELCSEIV...,MTDKRKDSSGKLLYCSFCGKSQHEVRKLIAGPSVYICDECVDLCND...
2,MADTVKTTRETAGTPAATHWHQRADRRGGRGTRTLRVRTSAVLVAA...,MLANPEKQTEVIHYEKIPSGFSIMWREFRKDKLAMFSLFFLALILI...
3,MFKIGNLELQSRLLLGTGKFENEDVQTEAIKASETNVLTFAVRRMN...,MARRGNVWNVYGAELNSRLLLGSALYPSPEVLKQAILNSGTEVVTV...
4,MTELNGRVAIITGASSGIGAATAKALEKQGVKVVLAGRSHDKLNTL...,MTAPLEGQVAIVTGGARGIGRGIALTLAGAGADILLADLLDDALDA...


In [12]:
test_sequence  = seq_df['query'].iloc[0]

In [14]:
dir(pdb)

['_PDB__chains',
 '_PDB__chains_copy',
 '_PDB__config',
 '_PDB__entries',
 '_PDB__entries_copy',
 '_PDB__lock',
 '_PDB__registered_attrs',
 '_PDB__rest_api_commands',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_add_col_chains',
 '_add_col_structures',
 '_get_current_indexes',
 '_loaded_plugins',
 '_loaded_plugins_handles',
 '_pdb_bundles_fn',
 '_pdb_entries_fn',
 '_pdb_entries_type_fn',
 '_pdb_res_fn',
 '_pdb_seqres_fn',
 '_pdbv',
 '_register_attr',
 '_remove_attr',
 '_set_filenames',
 '_working_path',
 'auto_filter',
 'bundles',
 'chains',
 'db_path',
 'entries',
 'extract',
 'load_plugin',
 'reset',
 'search',
 'search_seq',
 'search_seq_motif'

In [28]:
# Example: Querying for entries with a specific sequence
# sequence = 'YOUR_PROTEIN_SEQUENCE_HERE'
# query_result = 
pdb.search_seq_motif(test_sequence, type_='simple')

# Print out the PDB IDs and chains that match the query
# for entry in query_result.entries:
#     print(f"PDB ID: {entry.pdb_id}, Chain: {entry.chain_id}")

'Could not find response. Please revise your query.'

Tried every method for the pdb variable. It seems that the `localpdb` package isn't the best for this task. So, we will use the Uniprot API/SIFTS to get the mappings.

In [29]:
import requests

In [30]:
def download_sifts_file(file_url, save_path):
    response = requests.get(file_url, stream=True)
    if response.status_code == 200:
        with open(save_path, 'wb') as f:
            f.write(response.raw.read())
    else:
        print("Failed to download file")

# URL for a specific SIFTS file (e.g., UniProt to PDB mappings)
file_url = 'https://ftp.ebi.ac.uk/pub/databases/msd/sifts/csv/pdb_chain_uniprot.csv'
save_path = '../tmp/pdb_chain_uniprot.csv'

download_sifts_file(file_url, save_path)

In [33]:
pd.read_csv(save_path, skiprows=1).head()

Unnamed: 0,PDB,CHAIN,SP_PRIMARY,RES_BEG,RES_END,PDB_BEG,PDB_END,SP_BEG,SP_END
0,101m,A,P02185,1,154,0,153.0,1,154
1,102l,A,P00720,1,40,1,40.0,1,40
2,102l,A,P00720,42,165,41,,41,164
3,102m,A,P02185,1,154,0,153.0,1,154
4,103l,A,P00720,1,40,1,,1,40
