# Running IF-SitePred in Jupyter NB

## Imports

In [1]:
import os
import json
import esm
import torch
import pickle
import numpy as np 

## Utility functions

In [2]:
def save_to_pickle(variable, file_path):
    """Serialise ``variable`` with pickle and write it to ``file_path``.

    The file is opened in binary write mode, so an existing file at that
    path is overwritten.
    """
    with open(file_path, mode='wb') as handle:
        pickle.Pickler(handle).dump(variable)

def read_from_pickle(file_path):
    """Load and return the object stored in the pickle file ``file_path``.

    Only use this on trusted files: unpickling can execute arbitrary code.
    """
    with open(file_path, mode='rb') as handle:
        loaded = pickle.load(handle)
    return loaded

## Residue prediction helper functions

In [3]:
def get_chains(pdb_path):
    """Return the chain identifiers found in the ATOM records of a PDB file.

    Parameters
    ----------
    pdb_path : str
        Path to a PDB-format text file.

    Returns
    -------
    list of str
        Unique chain identifiers (character at index 21 of each ATOM line),
        in order of first appearance. HETATM and other records are ignored.
    """
    # Context manager so the file handle is closed deterministically.
    with open(pdb_path, 'r') as pdb_file:
        chains = [line[21:22] for line in pdb_file if line.startswith('ATOM')]
    # dict.fromkeys de-duplicates while keeping a reproducible order
    # (a plain set() gives an arbitrary order).
    return list(dict.fromkeys(chains))

def get_index_coords(pdb_path, chain):
    """Map the coordinates of every ATOM record of one chain to its residue number.

    Parameters
    ----------
    pdb_path : str
        Path to a PDB-format text file.
    chain : str
        Chain identifier; only ATOM lines whose character at index 21
        equals this value are used.

    Returns
    -------
    dict
        Keys are ``"x,y,z"`` strings built from the coordinate text exactly
        as it appears in the file (whitespace stripped); values are the
        integer residue numbers (columns 23-26). If two atoms share the same
        coordinate text, the later one wins.
    """
    coords_resi = {}
    # Context manager so the file handle is closed deterministically.
    with open(pdb_path, 'r') as pdb_file:
        for atom in pdb_file:
            if not (atom.startswith('ATOM') and atom[21:22] == chain):
                continue
            # Fixed-width PDB columns: x = 31-38, y = 39-46, z = 47-54.
            # The z slice stops at index 54 so it cannot pick up the first
            # character of the occupancy field (columns 55-60).
            coords = [atom[30:38].strip(), atom[38:46].strip(), atom[46:54].strip()]
            coords_resi[','.join(coords)] = int(atom[22:26].strip())
    return coords_resi


def load_models(num_models, models_dir='/Users/2394007/Documents/FROM_CLUSTER/CLUSTER/summer_project/binding-sites/models'):
    """Load the pickled classifiers ``lgbm_if1_0.pkl`` ... ``lgbm_if1_{num_models-1}.pkl``.

    Parameters
    ----------
    num_models : int
        Number of model files to load, indexed from 0.
    models_dir : str, optional
        Directory containing the model pickles. Defaults to the original
        machine-specific location, so existing calls behave as before; pass
        a different directory to run elsewhere.

    Returns
    -------
    list
        The unpickled model objects, in index order.

    Notes
    -----
    ``pickle.load`` can execute arbitrary code; only load trusted model files.
    """
    models = []
    for i in range(num_models):
        model_path = os.path.join(models_dir, f'lgbm_if1_{i}.pkl')
        # Context manager so each file handle is closed after loading.
        with open(model_path, 'rb') as model_file:
            models.append(pickle.load(model_file))
    return models

def write_chimerax_attr_file(data, attr_name, file_name, model_id='1', chain_id='A'):
    """Write a per-residue attribute definition file for ChimeraX.

    Parameters
    ----------
    data : dict
        Maps residue number to the attribute value for that residue.
    attr_name : str
        Attribute name written in the ``attribute:`` header line.
    file_name : str
        Output path; an existing file is overwritten.
    model_id, chain_id : str
        Used to build each residue specifier ``#<model_id>/<chain_id>:<res_num>``.
    """
    header_lines = (
        f"attribute: {attr_name}\n",
        "match mode: 1-to-1\n",
        "recipient: residues\n",
        "\n",  # blank separator line for readability
    )
    with open(file_name, 'w') as out:
        out.writelines(header_lines)
        # One tab-indented "<specifier>\t<value>" line per residue.
        out.writelines(
            f"\t#{model_id}/{chain_id}:{res_num}\t{attr_value}\n"
            for res_num, attr_value in data.items()
        )

## Load ESM-IF model

In [4]:
# Load the ESM-IF1 inverse-folding model and its alphabet from a local checkpoint file.
# NOTE(review): absolute, machine-specific path; it must be adjusted to run elsewhere.
model, alphabet = esm.pretrained.load_model_and_alphabet_local('/Users/2394007/Documents/FROM_CLUSTER/CLUSTER/summer_project/binding-sites/ESM_models/esm_if1_gvp4_t16_142M_UR50.pt')
# Switch the model to evaluation mode; as the last expression of the cell, the
# returned module is displayed as the cell output.
model.eval()



GVPTransformerModel(
  (encoder): GVPTransformerEncoder(
    (dropout_module): Dropout(p=0.1, inplace=False)
    (embed_tokens): Embedding(35, 512, padding_idx=1)
    (embed_positions): SinusoidalPositionalEmbedding()
    (embed_gvp_input_features): Linear(in_features=15, out_features=512, bias=True)
    (embed_confidence): Linear(in_features=16, out_features=512, bias=True)
    (embed_dihedrals): DihedralFeatures(
      (node_embedding): Linear(in_features=6, out_features=512, bias=True)
      (norm_nodes): Normalize()
    )
    (gvp_encoder): GVPEncoder(
      (embed_graph): GVPGraphEmbedding(
        (embed_node): Sequential(
          (0): GVP(
            (wh): Linear(in_features=3, out_features=256, bias=False)
            (ws): Linear(in_features=263, out_features=1024, bias=True)
            (wv): Linear(in_features=256, out_features=256, bias=False)
          )
          (1): LayerNorm(
            (scalar_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
         

In [5]:
# Load the 40 pickled models (lgbm_if1_0.pkl ... lgbm_if1_39.pkl) via load_models.
models = load_models(40)
# Number of models actually loaded (ensemble size).
model_num = len(models)

## Target selection

In [18]:
# Directory (relative to the notebook) that holds the target structure files.
target_dir = './../../clean_rep_chains'

In [19]:
# Target structure file names (e.g. '6cph_D.pdb'), read from a pickle file.
targets = read_from_pickle("./results/PDB_rep_chains_files_V2.pkl")

In [20]:
#targets = [os.path.join(target_dir, t) for t in target_names]

In [21]:
# Preview the first five target file names.
targets[:5]

['6cph_D.pdb', '7sol_C.pdb', '4npj_A.pdb', '6y7f_B.pdb', '4iqy_B.pdb']

In [23]:
# Total number of targets.
print(len(targets))

4037


In [9]:
# Sanity check: report every target whose file is missing from target_dir.
# Nothing is printed when all files are present.
for t in targets:
    if not os.path.isfile(os.path.join(target_dir, t)):
        print(f'{t} does not exist')

In [10]:
# Root directory for prediction outputs; the main loop creates one sub-directory per target in it.
preds_dir = "./results/IFSP_preds"

In [12]:
# Entries currently present in the predictions directory (presumably one per already-processed target).
preds = os.listdir(preds_dir)

In [13]:
# Display the full listing. NOTE(review): this prints every entry; a slice such as preds[:5] would keep the output short.
preds

['5xir_A',
 '8f16_A',
 '3bpt_A',
 '8e98_C',
 '4y72_A',
 '3iwl_A',
 '2q7r_F',
 '6b7h_A',
 '2n14_A',
 '3dk9_A',
 '2wng_A',
 '3gdh_B',
 '6kqy_A',
 '3s4y_A',
 '6uez_A',
 '3q4u_B',
 '1m8y_B',
 '7sxf_A',
 '2g4a_A',
 '5w8h_B',
 '1zs9_A',
 '1w73_B',
 '6rlr_A',
 '1d1j_A',
 '2jbh_B',
 '3p1m_B',
 '3p0y_A',
 '1kqn_D',
 '7okv_A',
 '6ohq_B',
 '2aex_A',
 '3fit_A',
 '1now_B',
 '5l6p_A',
 '6nt9_B',
 '6g0i_A',
 '2o52_B',
 '3nmq_A',
 '6km7_D',
 '8p50_B',
 '3n7o_A',
 '1su3_A',
 '6chd_A',
 '3ctz_A',
 '6i0z_A',
 '6gn5_A',
 '5om2_B',
 '3elb_A',
 '3n6s_A',
 '7mjd_A',
 '6j08_A',
 '8fp1_A',
 '4apc_B',
 '8gt5_A',
 '6bcu_R',
 '3bn3_A',
 '7ugk_B',
 '1cf4_A',
 '7kpu_D',
 '2vcy_A',
 '1psr_B',
 '4dch_A',
 '4egl_A',
 '4xdp_A',
 '1isi_A',
 '2obv_A',
 '2ci8_A',
 '4qhu_C',
 '1x05_A',
 '3vcm_B',
 '4orw_A',
 '7qbo_P',
 '3sa0_A',
 '2c47_C',
 '3zli_A',
 '8rbx_W',
 '4z2b_A',
 '6rj5_A',
 '4rqk_A',
 '2wwy_B',
 '6cvo_A',
 '1siq_A',
 '1p8t_A',
 '7sow_A',
 '7qns_AA',
 '7eoq_C',
 '6bth_B',
 '2hrb_A',
 '8fjx_A',
 '2qtz_A',
 '4zgg_A'

In [14]:
# Number of targets to process.
N_targets = len(targets)
print(N_targets)

4037


It is now 19:33 on 04/04/2024

## Running IF-SitePred on Human AF

In [17]:
errors = []  # targets with at least one residue whose coordinates could not be mapped to a residue number
done = 0     # targets skipped because all five output files already exist
for target_idx, target in enumerate(targets):

    # Progress indicator: print the index of every 50th target.
    if target_idx % 50 == 0:
        print(target_idx)

    target_id = target.split(".")[0]
    target_path = os.path.join(target_dir, target)

    prediction_dir = os.path.join(preds_dir, target_id)
    # makedirs also creates preds_dir itself on a fresh run and tolerates an
    # already existing directory.
    os.makedirs(prediction_dir, exist_ok=True)

    lig_scores_out = os.path.join(prediction_dir, "{}_lig_scores.pkl".format(target_id))
    lig_label_out = os.path.join(prediction_dir, "{}_lig_labels.pkl".format(target_id))
    binding_ress_out = os.path.join(prediction_dir, "{}_binding_ress.pkl".format(target_id))
    ligandability_attr_out = os.path.join(prediction_dir, "{}_ligandability.defattr".format(target_id))
    lig_label_attr_out = os.path.join(prediction_dir, "{}_ligand_binding.defattr".format(target_id))

    # Skip targets whose five result files are all present, so an interrupted
    # run can be resumed.
    output_files = (
        lig_scores_out,
        lig_label_out,
        binding_ress_out,
        ligandability_attr_out,
        lig_label_attr_out,
    )
    if all(os.path.isfile(out_file) for out_file in output_files):
        done += 1
        continue

    chains = sorted(get_chains(target_path))

    for chain in chains[:1]:  # looping through chains (only the first one, in sorted order, is used)

        structure = esm.inverse_folding.util.load_structure(target_path, chain)  # load the structure of this chain
        coords_resi = get_index_coords(target_path, chain)  # key: "x,y,z" coordinate string, value: residue number
        coords, seq = esm.inverse_folding.util.extract_coords_from_structure(structure)  # per-residue coordinates and sequence string
        # Inference only: no_grad avoids building the autograd graph.
        with torch.no_grad():
            if1 = esm.inverse_folding.util.get_encoder_output(model, alphabet, coords).detach().numpy()  # ESM-IF1 encoder output for the coordinates

        # One row of predictions per ensemble model
        # (presumably per-residue binding probabilities; confirm against the models).
        raw_preds = np.array([model_flaml.predict(if1) for model_flaml in models])
        round_preds = np.rint(raw_preds)  # rounded to 0/1

        # A residue is labelled 1 only if every model's rounded prediction is 1.
        pred_intersection = (round_preds == 1).all(axis=0).astype(int).tolist()

        # Ligandability score: mean of the raw predictions over all models.
        lig_scores = np.mean(raw_preds, axis=0)

        res_nums = []        # residue numbers (as strings) labelled 1
        res_score_dict = {}  # key is resNum, value is ligandability score
        bin_lab_dict = {}    # key is resNum, value is binary 0/1 label
        lookup_failed = False
        for res_idx in range(len(coords)):
            # The coordinates of the first atom of the residue, formatted to
            # three decimals, are the key used to recover the residue number
            # from the PDB file text.
            coord_str = ','.join("{:.3f}".format(x) for x in coords[res_idx][0])
            try:
                resnum = str(coords_resi[coord_str])
            except KeyError:
                # No ATOM record with these coordinates (alt locs could cause this);
                # the residue is left out of every output.
                lookup_failed = True
                continue
            res_score_dict[resnum] = round(lig_scores[res_idx], 2)
            bin_lab_dict[resnum] = pred_intersection[res_idx]
            if pred_intersection[res_idx] == 1:
                res_nums.append(resnum)

        # Report and record a problematic target once, not once per residue.
        if lookup_failed:
            print("ERROR with {}".format(target_id))
            errors.append(target)

        save_to_pickle(
            res_score_dict,
            lig_scores_out
        )

        save_to_pickle(
            bin_lab_dict,
            lig_label_out
        )

        save_to_pickle(
            {chain: res_nums},
            binding_ress_out
        )

        # Pass the processed chain so the residue specifiers in the .defattr
        # files match the structure (the default chain_id is 'A').
        write_chimerax_attr_file(
            res_score_dict, "ligandability",
            ligandability_attr_out,
            chain_id=chain
        )

        write_chimerax_attr_file(
            bin_lab_dict, "ligand_binding",
            lig_label_attr_out,
            chain_id=chain
        )
        print("There are {} ligand binding residues in chain {} of {}".format(str(len(res_nums)), chain, target_id))

0
50
100
150
200
250
300
350
400
450
500
550
600
650
700
750
800
850
900
950
1000
1050
1100
1150
1200
1250
1300
1350
1400
1450
1500
1550
1600
1650
1700
1750
1800
1850
1900
1950
2000
2050
2100
2150
2200
2250
2300
2350
2400
2450
2500
2550
2600
2650
2700
2750
2800
2850
2900
2950
3000
3050
3100
3150
3200
3250
3300
3350
3400
3450
3500
3550
3600
3650
3700
3750
3800
3850
3900
3950
4000


It finished at 22:35 on 04/04/2024 (182 minutes after starting).

In [16]:
# De-duplicate the targets recorded in errors and report how many distinct targets are affected.
errors_unique = list(set(errors))
print(len(errors_unique))

17


## ChimeraX colouring commands

To colour by ligandability score:
    
    color byattribute r:ligandability #!1 target scab palette 0,white:0.5,#febe55:1,#de2d26

To colour by binary ligand-binding label:

    color byattribute r:ligand_binding #!1 target scab palette 0,white:0.5,white:1,red