In [43]:
import pandas as pd
import requests
from Bio.PDB import PDBList, MMCIFParser
import pickle
import os
from tqdm import tqdm
from pathlib import Path
import shutil
import numpy as np

def handle_case_xray(json_response, dict_info, chain_dict, uniprot_id, idx, error_dict, file_descriptor):
    """Collect PDB entries with a reported resolution that map to ``uniprot_id``.

    Walks the 'PDB' cross references of a UniProt JSON record, queries the
    RCSB entry endpoint for each one and, when the entry carries a
    'resolution_combined' field, queries every polymer entity of that entry.
    An entity is accepted when it has exactly one 'rcsb_polymer_entity_align'
    record whose 'reference_database_accession' equals ``uniprot_id``.

    Args:
        json_response: Decoded UniProt record; must contain
            'uniProtKBCrossReferences'.
        dict_info: Mapping updated in place: PDB id -> (sample sequence
            length, first 'resolution_combined' value).
        chain_dict: Mapping updated in place: PDB id -> list of chain ids
            obtained by splitting 'pdbx_strand_id' on commas.
        uniprot_id: UniProt accession being processed.
        idx: Value stored in ``error_dict`` when a request fails.
        error_dict: Error bookkeeping mapping passed through ``handle_dict``.
        file_descriptor: Writable text stream used by ``handle_dict``.

    Returns:
        The tuple ``(dict_info, chain_dict, error_dict)``.
    """
    for xref in json_response['uniProtKBCrossReferences']:
        if xref['database'] != 'PDB':
            continue

        pdb_id = xref['id']
        entry_reply = requests.get(f'https://data.rcsb.org/rest/v1/core/entry/{pdb_id}')
        if entry_reply.status_code != 200:
            error_dict = handle_dict(uniprot_id, idx, error_dict, file_descriptor, 'Requests returned error (RCSB PDB)')
            continue

        entry_json = entry_reply.json()
        # Entries without a combined resolution are left to handle_case_no_xray.
        if 'resolution_combined' not in entry_json['rcsb_entry_info']:
            continue

        for entity in entry_json['rcsb_entry_container_identifiers']['polymer_entity_ids']:
            entity_reply = requests.get(f'https://data.rcsb.org/rest/v1/core/polymer_entity/{pdb_id}/{entity}')
            if entity_reply.status_code != 200:
                error_dict = handle_dict(uniprot_id, idx, error_dict, file_descriptor, 'Requests returned error (RCSB PDB entity)')
                continue

            entity_json = entity_reply.json()
            if 'rcsb_polymer_entity_align' not in entity_json:
                continue
            alignments = entity_json['rcsb_polymer_entity_align']
            # Only entities aligned to a single reference, and to this accession, qualify.
            if len(alignments) != 1 or alignments[0]['reference_database_accession'] != uniprot_id:
                continue

            dict_info[pdb_id] = (
                entity_json['entity_poly']['rcsb_sample_sequence_length'],
                entry_json['rcsb_entry_info']['resolution_combined'][0],
            )
            chain_dict[pdb_id] = entity_json['entity_poly']['pdbx_strand_id'].split(',')

    return dict_info, chain_dict, error_dict


def handle_case_no_xray(json_response, dict_info, chain_dict, uniprot_id, idx, error_dict, file_descriptor):
    """Collect PDB entries that map to ``uniprot_id`` without requiring a resolution.

    Walks the 'PDB' cross references of a UniProt JSON record, queries the
    RCSB entry endpoint for each one and then every polymer entity of that
    entry. An entity is accepted when it has exactly one
    'rcsb_polymer_entity_align' record whose 'reference_database_accession'
    equals ``uniprot_id``. Accepted entries are stored with a resolution
    of 0.0.

    Args:
        json_response: Decoded UniProt record; must contain
            'uniProtKBCrossReferences'.
        dict_info: Mapping updated in place: PDB id -> (sample sequence
            length, 0.0).
        chain_dict: Mapping updated in place: PDB id -> list of chain ids
            obtained by splitting 'pdbx_strand_id' on commas.
        uniprot_id: UniProt accession being processed.
        idx: Value stored in ``error_dict`` when a request fails.
        error_dict: Error bookkeeping mapping passed through ``handle_dict``.
        file_descriptor: Writable text stream used by ``handle_dict``.

    Returns:
        The tuple ``(dict_info, chain_dict, error_dict)``.
    """
    for xref in json_response['uniProtKBCrossReferences']:
        if xref['database'] != 'PDB':
            continue

        pdb_id = xref['id']
        entry_reply = requests.get(f'https://data.rcsb.org/rest/v1/core/entry/{pdb_id}')
        if entry_reply.status_code != 200:
            error_dict = handle_dict(uniprot_id, idx, error_dict, file_descriptor, 'Requests returned error (RCSB PDB)')
            continue

        entry_json = entry_reply.json()
        for entity in entry_json['rcsb_entry_container_identifiers']['polymer_entity_ids']:
            entity_reply = requests.get(f'https://data.rcsb.org/rest/v1/core/polymer_entity/{pdb_id}/{entity}')
            if entity_reply.status_code != 200:
                error_dict = handle_dict(uniprot_id, idx, error_dict, file_descriptor, 'Requests returned error (RCSB PDB entity)')
                continue

            entity_json = entity_reply.json()
            if 'rcsb_polymer_entity_align' not in entity_json:
                continue
            alignments = entity_json['rcsb_polymer_entity_align']
            # Only entities aligned to a single reference, and to this accession, qualify.
            if len(alignments) != 1 or alignments[0]['reference_database_accession'] != uniprot_id:
                continue

            # No resolution is looked up in this pass; 0.0 is stored as a placeholder.
            dict_info[pdb_id] = (entity_json['entity_poly']['rcsb_sample_sequence_length'], 0.0)
            chain_dict[pdb_id] = entity_json['entity_poly']['pdbx_strand_id'].split(',')

    return dict_info, chain_dict, error_dict


def delete_if_exist(folder_name):
    """Recursively remove ``folder_name`` when it is an existing directory.

    Paths that do not exist, or that exist but are not directories, are left
    untouched.
    """
    target = Path(folder_name)
    # is_dir() is False for a missing path, so it covers the existence check too.
    if not target.is_dir():
        return
    shutil.rmtree(target)

def sort_dict(pdb_id, dict_info):
    """Build the ranking key for ``pdb_id`` from its ``dict_info`` entry.

    ``dict_info[pdb_id]`` is unpacked as ``(sequence_length, resolution)``.
    The resolution is negated so that, when the key is maximised, a longer
    sequence wins first and a numerically smaller resolution breaks ties.
    """
    length, res = dict_info[pdb_id]
    ranking_key = (length, -res)
    return ranking_key

def handle_dict(k, v, d, file_descriptor, message):
    """Log ``message`` for key ``k`` and append ``v`` to the list stored under ``k``.

    Args:
        k: Key to record (written to the log line and used as the dict key).
        v: Value appended to ``d[k]``; the list is created on first use.
        d: Mapping of key -> list of values, updated in place.
        file_descriptor: Writable text stream receiving one "k : message" line.
        message: Text written after the key.

    Returns:
        The same mapping ``d`` that was passed in.
    """
    file_descriptor.write(f"{k} : {message}\n")
    d.setdefault(k, []).append(v)
    return d

def download_pdb_files(uniprot_id_list, file_descriptor, start=20, stop=30):
    """Download the best-ranked PDB structure for a window of UniProt accessions.

    For every accession in ``uniprot_id_list[start:stop]`` the UniProt record
    is fetched, candidate PDB entries are collected (first those with a
    reported resolution, then any entry as a fallback), the best candidate
    according to ``sort_dict`` is downloaded as mmCIF into 'data/PDB_files'
    and parsed.

    Args:
        uniprot_id_list: Indexable sequence of UniProt accessions.
        file_descriptor: Writable text stream that receives one line per
            recorded error.
        start: First position of the window to process. Defaults to 20, the
            window that used to be hard-coded; pass ``None`` to start at the
            beginning.
        stop: End (exclusive) of the window to process. Defaults to 30; pass
            ``None`` to run to the end of the list.

    Returns:
        dict mapping each failing accession to the list of positions at which
        it failed. Positions are indices into ``uniprot_id_list`` itself (not
        into the processed window), so they can be used directly to drop the
        corresponding rows of the frame the list came from.
    """
    error_dict = {}

    # Iterate absolute positions: enumerate() over the slice would restart at
    # 0 and make the recorded indices point at the wrong rows of the caller's
    # DataFrame.
    positions = range(len(uniprot_id_list))[start:stop]

    for idx in tqdm(positions):
        uniprot_id = uniprot_id_list[idx]
        url = f'https://www.uniprot.org/uniprotkb/{uniprot_id}.json'
        response = requests.get(url)
        if response.status_code != 200:
            error_dict = handle_dict(uniprot_id, idx, error_dict, file_descriptor, 'Requests returned error (Uniprot)')
            continue

        json_response = response.json()
        if 'uniProtKBCrossReferences' not in json_response:
            error_dict = handle_dict(uniprot_id, idx, error_dict, file_descriptor, 'ID no longer present in Uniprot')
            continue

        dict_info, chain_dict, error_dict = handle_case_xray(json_response=json_response,
                                                             dict_info={},
                                                             chain_dict={},
                                                             uniprot_id=uniprot_id,
                                                             idx=idx,
                                                             error_dict=error_dict,
                                                             file_descriptor=file_descriptor)
        if not dict_info:
            # Nothing with a resolution matched: retry without that requirement.
            dict_info, chain_dict, error_dict = handle_case_no_xray(json_response=json_response,
                                                                    dict_info=dict_info,
                                                                    chain_dict=chain_dict,
                                                                    uniprot_id=uniprot_id,
                                                                    idx=idx,
                                                                    error_dict=error_dict,
                                                                    file_descriptor=file_descriptor)

        if not dict_info:
            error_dict = handle_dict(uniprot_id, idx, error_dict, file_descriptor, 'Only alphafold or no structure info')
            continue

        # Longest sequence wins; ties are broken by the smallest resolution value.
        right_pdb_id = max(dict_info, key=lambda x: sort_dict(x, dict_info))

        # Download the mmCIF file using the retrieve_pdb_file method.
        pdb_list = PDBList()
        pdb_filename = pdb_list.retrieve_pdb_file(right_pdb_id, pdir="data/PDB_files", file_format="mmCif")

        # Use the path reported by Biopython instead of rebuilding it from the
        # upper-case id: the file on disk is named with the lower-case code,
        # which breaks a hand-built path on case-sensitive filesystems.
        if not os.path.exists(pdb_filename):
            error_dict = handle_dict(uniprot_id, idx, error_dict, file_descriptor, 'Download of PDB structure failed')
            continue

        parser = MMCIFParser()
        structure = parser.get_structure('protein_structure', pdb_filename)

        print(chain_dict[right_pdb_id])

        # NOTE(review): selected_chains is built but not yet returned or
        # stored anywhere; kept so the parsing step stays exercised.
        selected_chains = [chain for chain in structure[0]
                           if chain.get_id() in chain_dict[right_pdb_id]]

    return error_dict

# Column names applied to 'species_human.txt', which is read without a header row.
column_names = [
    'protein_xref_1',
    'protein_xref_2',
    'alternative_identifiers_1',
    'alternative_identifiers_2',
    'protein_alias_1',
    'protein_alias_2',
    'detection_method',
    'author_name',
    'pmid',
    'protein_taxid_1',
    'protein_taxid_2',
    'interaction_type',
    'source_database_id',
    'database_identifier',
    'confidence',
]

# Start from empty output folders on every run.
pdb_path = os.path.join('data', 'PDB_files')
delete_if_exist(pdb_path)
os.makedirs(pdb_path)

pkl_path = os.path.join('data', 'PKL_files')
delete_if_exist(pkl_path)
os.makedirs(pkl_path)

# Load the PSI-MITAB files into pandas DataFrames
df = pd.read_csv('hpidb2.mitab.txt', sep='\t', header=0, encoding='ISO-8859-1')
df2 = pd.read_csv('species_human.txt', sep='\t', header=None, encoding='ISO-8859-1', names=column_names)

df = pd.concat([df, df2], axis=0)

# Show the original number of rows
print(len(df))

# Keep only interactions where both partners carry a uniprot cross reference.
# .copy() makes the result an independent frame so the column rewrites below
# do not act on a view of df.
filtered_df = df[df['protein_xref_1'].str.contains('uniprot', na=False) &
                 df['protein_xref_2'].str.contains('uniprot', na=False)].copy()

# Show the filtered number of rows
print(len(filtered_df))

filtered_df = filtered_df[filtered_df['interaction_type'].str.contains('direct interaction', na=False)].copy()

print(len(filtered_df))

# Reduce each cross reference to the bare accession: drop everything up to
# 'uniprotkb:' and anything after the first '-'.
for xref_column in ('protein_xref_1', 'protein_xref_2'):
    filtered_df[xref_column] = filtered_df[xref_column].apply(
        lambda x: x.split('uniprotkb:')[-1].split('-')[0])

print(len(filtered_df))

# Drop repeated (protein_xref_1, protein_xref_2) pairs, keeping the first
# occurrence. This is done positionally rather than by index label: after the
# concat above both source frames contribute labels starting at 0, so
# dropping by label would also remove unrelated rows sharing a label with a
# duplicate.
filtered_df = filtered_df.drop_duplicates(subset=['protein_xref_1', 'protein_xref_2'])
print(len(filtered_df))

filtered_df = filtered_df.reset_index(drop=True)

filtered_df.to_csv('filtered_ppi.csv')

# Download structures for both interaction partners, logging failures to error.txt
with open('error.txt', 'w+') as file_descriptor:
    error_dict_1 = download_pdb_files(filtered_df['protein_xref_1'].values, file_descriptor)
    error_dict_2 = download_pdb_files(filtered_df['protein_xref_2'].values, file_descriptor)

error_list_1 = [item for sublist in error_dict_1.values() for item in sublist]
error_list_2 = [item for sublist in error_dict_2.values() for item in sublist]

# Remove every row for which either partner produced an error.
error_indices = np.unique(error_list_1 + error_list_2)
print(error_indices)
filtered_df = filtered_df.drop(error_indices, axis='index')
print(len(filtered_df))

filtered_df.to_csv('filtered_ppi_dropped.csv')

#uniprot_id = 'Q9NPD8'
#url = f'https://www.uniprot.org/uniprotkb/{uniprot_id}.json'
#response = requests.get(url)
#if response.status_code == 200:
#    json_response = response.json()
#    dict_info = {}
#    chain_dict = {}
#    for entry in json_response['uniProtKBCrossReferences']:
#        if entry['database'] == 'PDB':
#            pdb_id = entry['id']
#            url_pdb = f'https://data.rcsb.org/rest/v1/core/entry/{pdb_id}'
#            pdb_response = requests.get(url_pdb)
#            if pdb_response.status_code == 200:
#                json_response_pdb = pdb_response.json()
#                print(pdb_id)
#                entities = json_response_pdb['rcsb_entry_container_identifiers']['polymer_entity_ids']
#                print(entities)
#                for entity in entities:
#                    url_pdb_entity = f'https://data.rcsb.org/rest/v1/core/polymer_entity/{pdb_id}/{entity}'
#                    pdb_response_entity = requests.get(url_pdb_entity)
#                    if pdb_response_entity.status_code == 200:
#                        json_response_pdb_entity = pdb_response_entity.json()
#                        if ('rcsb_polymer_entity_align' in json_response_pdb_entity and 
#                        len(json_response_pdb_entity['rcsb_polymer_entity_align']) == 1 and 
#                        json_response_pdb_entity['rcsb_polymer_entity_align'][0]['reference_database_accession'] == uniprot_id):
#                            dict_info[pdb_id] = (json_response_pdb_entity['entity_poly']['rcsb_sample_sequence_length'], 
#                                                 json_response_pdb['rcsb_entry_info']['resolution_combined'][0])
#                            chain_dict[pdb_id] = json_response_pdb_entity['entity_poly']['pdbx_strand_id'].split(',')
#                            print(json_response_pdb_entity['rcsb_polymer_entity_align'][0]['reference_database_accession'])
#                            print(json_response_pdb_entity['entity_poly']['rcsb_sample_sequence_length'])
#    
#    #s = sorted(s, key = lambda x: (x[1], x[2]))
#    right_pdb_id = max(dict_info, key=sort_dict)
#    # Create an instance of the PDBList class
#    pdb_list = PDBList()
#    # Download the MMCIF file using the retrieve_pdb_file method
#    pdb_filename = pdb_list.retrieve_pdb_file(right_pdb_id, pdir="data/PDB_files", file_format="mmCif")
#
#    with open(os.path.join('data', 'PDB_files', f'{right_pdb_id}.pkl'), 'wb') as fd:
#        pickle.dump(chain_dict[right_pdb_id], fd, protocol=pickle.HIGHEST_PROTOCOL)

146160
136682
7730
7730
4305


  0%|          | 0/10 [00:00<?, ?it/s]

P14859
1CQT
A,B
P14859
1E3O
C
P14859
1GT0
C
P14859
1HF0
A,B
P14859
1OCT
C
Downloading PDB structure '1cqt'...


 10%|█         | 1/10 [00:07<01:06,  7.40s/it]

['A', 'B']
M
N
O
P
A
B
I
J
P06730
1IPB
A
P06730
1IPC
A
P06730
1WKW
A
P06730
2V8W
A,E
P06730
2V8X
A,E
P06730
2V8Y
A,E


 10%|█         | 1/10 [00:12<01:50, 12.24s/it]


KeyboardInterrupt: 