In [15]:
import os

import numpy as np
import pandas as pd
from tqdm import tqdm

import pyrosetta
from pyrosetta.rosetta.core.scoring import get_score_function
from pyrosetta.rosetta.protocols.analysis import InterfaceAnalyzerMover
from pyrosetta.toolbox import mutants

# Start the Rosetta runtime with the core/basic/protocols log channels muted.
pyrosetta.init("-mute core basic protocols ")
# pyrosetta.init("-mute core basic protocols -corrections::beta_nov16")

# NOTE(review): absolute, user-specific path; the notebook only runs on a
# machine that has this exact location. Consider a configurable data directory.
pickle_6YLA = '/home/zengyun1/AbRFC/data/test/CMAB0/6YLA_scores.pickle'

# Unpickling can execute arbitrary code: only load pickles from a trusted source.
data = pd.read_pickle(pickle_6YLA)


from SIN import compute_sin_scores
import json 
# Read the AIF lookup table inside a context manager so the file handle is
# closed deterministically instead of being left to the garbage collector.
with open('aif_matrix.json','r') as aif_file:
    aif = json.load(aif_file)
from pose_operations import MutateResidue


┌──────────────────────────────────────────────────────────────────────────────┐
│                                 PyRosetta-4                                  │
│              Created in JHU by Sergey Lyskov and PyRosetta Team              │
│              (C) Copyright Rosetta Commons Member Institutions               │
│                                                                              │
│ NOTE: USE OF PyRosetta FOR COMMERCIAL PURPOSES REQUIRE PURCHASE OF A LICENSE │
│         See LICENSE.PyRosetta.md or email license@uw.edu for details         │
└──────────────────────────────────────────────────────────────────────────────┘
PyRosetta-4 2024 [Rosetta PyRosetta4.Release.python38.ubuntu 2024.42+release.3366cf78a3df04339d1982e94531b77b098ddb99 2024-10-11T08:24:04] retrieved from: http://www.pyrosetta.org


In [17]:
def calculate_neighbors(residue_number, pose_residue_index, inter_energy):
    """Bucket residues by how strongly they interact with one focal residue.

    For every index in ``range(residue_number)`` the pair-energy record
    ``inter_energy[pose_residue_index, resid]`` is read, its first field is
    dropped (presumably the total score -- TODO confirm), and the largest
    absolute value among the remaining fields is compared with two thresholds.

    Args:
        residue_number: Number of residue indices to scan.
        pose_residue_index: Row of ``inter_energy`` belonging to the focal residue.
        inter_energy: 2-D array of per-pair energy records.

    Returns:
        ``(neighbors_1, neighbors_2)``: indices whose peak magnitude lies in
        ``(0.05, 0.15]`` and indices whose peak magnitude exceeds ``0.15``.
    """
    moderate, strong = [], []

    for resid in range(residue_number):
        record = inter_energy[pose_residue_index, resid]
        magnitudes = np.abs(np.array(list(record))[1:])
        peak = np.max(magnitudes)

        if peak > 0.15:
            strong.append(resid)
        elif peak > 0.05:
            moderate.append(resid)

    return moderate, strong

def get_interface_data(inter_resid_data, pose_residue_index, neighbors_1, neighbors_2):
    """Sum per-residue interface metrics over three residue selections.

    Each listed attribute of ``inter_resid_data`` is converted to a numpy
    array with NaNs replaced by 0.0, then summed over the focal residue
    (suffix ``_0``), ``neighbors_1`` (``_1``) and ``neighbors_2`` (``_2``).

    Args:
        inter_resid_data: Object exposing the per-residue metrics as attributes.
        pose_residue_index: Array index of the focal residue.
        neighbors_1: Array indices of the first neighbour group.
        neighbors_2: Array indices of the second neighbour group.

    Returns:
        Dict keyed ``i<attribute><suffix>`` holding the summed values.
    """
    metric_names = (
        'complexed_sasa', 'dG', 'dSASA', 'dSASA_fraction', 'dSASA_sc',
        'dhSASA', 'dhSASA_rel_by_charge', 'dhSASA_sc', 'interface_residues',
    )
    selections = (
        ('_0', pose_residue_index),
        ('_1', neighbors_1),
        ('_2', neighbors_2),
    )

    summary = {}
    for metric in metric_names:
        values = np.nan_to_num(np.array(getattr(inter_resid_data, metric)), nan=0.0)
        summary.update(
            {f'i{metric}{tag}': values[selection].sum() for tag, selection in selections}
        )
    return summary

def calculate_interface_properties(pose,interface_chain,pose_residue_index,neighbors_1,neighbors_2):
    """Run an InterfaceAnalyzerMover on ``pose`` and collect interface features.

    Args:
        pose: Pose the analyzer is applied to.
        interface_chain: Interface specification passed to ``set_interface``.
        pose_residue_index: Array index of the focal residue, forwarded to
            ``get_interface_data``.
        neighbors_1: First neighbour group, forwarded to ``get_interface_data``.
        neighbors_2: Second neighbour group, forwarded to ``get_interface_data``.

    Returns:
        ``(interface_data, interface_resid)``: the per-residue sums extended
        with four whole-interface values, and the analyzer's interface set
        with every member shifted down by one.
    """
    analyzer = InterfaceAnalyzerMover()
    analyzer.set_interface(interface_chain)
    analyzer.set_calc_dSASA(True)
    analyzer.set_compute_interface_energy(True)
    analyzer.set_compute_packstat(True)
    analyzer.apply(pose)

    # Per-residue metrics summed over the focal residue and its neighbour groups.
    interface_data = get_interface_data(
        analyzer.get_all_per_residue_data(), pose_residue_index, neighbors_1, neighbors_2
    )

    # Whole-interface values, read from a single data snapshot.
    overall = analyzer.get_all_data()
    interface_data.update({
        'idelta_unsat_hbonds': overall.delta_unsat_hbonds,
        'iinterface_hbonds': overall.interface_hbonds,
        'ipackstat': overall.packstat,
        'isc_value': overall.sc_value,
    })

    # Shift the analyzer's residue numbers down by one for array indexing.
    interface_resid = [member - 1 for member in analyzer.get_interface_set()]

    return interface_data, interface_resid

def calculate_residue_energy(pose, pose_residue_index, neighbors_1):
    """Sum per-residue energy terms for the focal residue and one neighbour group.

    Args:
        pose: Object whose ``energies().residue_total_energies_array()``
            yields an array indexable by energy-term name.
        pose_residue_index: Array index of the focal residue (suffix ``_0``).
        neighbors_1: Array indices of the neighbour group (suffix ``_1``).

    Returns:
        Dict keyed ``i<term><suffix>`` with the summed energies.
    """
    per_residue = pose.energies().residue_total_energies_array()

    score_terms = (
        'total_score', 'fa_atr', 'fa_rep', 'fa_sol', 'fa_intra_rep',
        'fa_intra_sol_xover4', 'lk_ball_wtd', 'fa_elec', 'pro_close',
        'hbond_sr_bb', 'hbond_lr_bb', 'hbond_bb_sc', 'hbond_sc',
    )
    selections = (('_0', pose_residue_index), ('_1', neighbors_1))

    return {
        f'i{term}{tag}': per_residue[term][selection].sum()
        for term in score_terms
        for tag, selection in selections
    }



In [18]:
def extract_features(pose, pose_residue_index, interface_chain):
    """Score ``pose`` and assemble its energy and interface features.

    The pose is scored with the ``ref2015`` score function and then analysed
    with ``calculate_interface_properties``.

    Args:
        pose: Pose to score and analyse.
        pose_residue_index: 1-based index of the residue of interest; it is
            reduced by one before being used as an array index.
        interface_chain: Interface specification forwarded to
            ``calculate_interface_properties``.

    Returns:
        ``(features, interface_resid)``: a ``pd.Series`` holding ``dE2`` (the
        pose total energy) followed by the interface and residue-energy
        features, and the shifted interface residue indices.
    """
    n_residues = pose.total_residue()
    array_index = pose_residue_index - 1

    # Score the pose so its energies object can be queried below.
    score_function = pyrosetta.create_score_function('ref2015')
    score_function(pose)

    # Pair energies drive the neighbour selection.
    pair_energies = pose.energies().residue_pair_energies_array()
    neighbors_1, neighbors_2 = calculate_neighbors(n_residues, array_index, pair_energies)

    interface_data, interface_resid = calculate_interface_properties(
        pose, interface_chain, array_index, neighbors_1, neighbors_2
    )
    energy_data = calculate_residue_energy(pose, array_index, neighbors_1)

    # Assemble in the same order: total energy, interface terms, residue energies.
    combined = {'dE2': pose.energies().total_energy()}
    combined.update(interface_data)
    combined.update(energy_data)

    return pd.Series(combined), interface_resid

In [12]:
# Names of the structural difference features.
# NOTE(review): this list is not referenced by any other cell visible here.
# NOTE(review): the per-residue energy names below ('total_score_0', 'fa_atr_0', ...)
# have no leading 'i', whereas calculate_residue_energy builds its keys as
# f'i{attr}{suffix}' ('itotal_score_0', ...). Confirm which spelling downstream
# code expects before selecting columns with this list.
new_feature1_columns = ['dE2', 'idelta_unsat_hbonds', 'iinterface_hbonds', 'ipackstat', 'isc_value', 
 'icomplexed_sasa_0', 'icomplexed_sasa_1', 'icomplexed_sasa_2',
 'idG_0', 'idG_1', 'idG_2', 
 'idSASA_0', 'idSASA_1', 'idSASA_2',
 'idSASA_fraction_0', 'idSASA_fraction_1', 'idSASA_fraction_2',
 'idSASA_sc_0', 'idSASA_sc_1', 'idSASA_sc_2',
 'idhSASA_0', 'idhSASA_1', 'idhSASA_2',
 'idhSASA_rel_by_charge_0', 'idhSASA_rel_by_charge_1', 'idhSASA_rel_by_charge_2',
 'idhSASA_sc_0', 'idhSASA_sc_1', 'idhSASA_sc_2', 
 'iinterface_residues_0', 'iinterface_residues_1', 'iinterface_residues_2',
 'total_score_0', 'total_score_1',
 'fa_atr_0', 'fa_atr_1',
 'fa_rep_0', 'fa_rep_1',
 'fa_sol_0', 'fa_sol_1',
 'fa_intra_rep_0', 'fa_intra_rep_1',
 'fa_intra_sol_xover4_0', 'fa_intra_sol_xover4_1',
 'lk_ball_wtd_0', 'lk_ball_wtd_1',
 'fa_elec_0', 'fa_elec_1',
 'pro_close_0', 'pro_close_1',
 'hbond_sr_bb_0', 'hbond_sr_bb_1',
 'hbond_lr_bb_0', 'hbond_lr_bb_1',
 'hbond_bb_sc_0', 'hbond_bb_sc_1',
 'hbond_sc_0', 'hbond_sc_1']

In [None]:
def _sin_aif_summary(pose, pose_residue_index, interface_resid):
    """Compute the SIN / AIF summary values of one pose.

    Args:
        pose: Pose handed to ``compute_sin_scores``.
        pose_residue_index: 1-based pose index of the mutated residue.
        interface_resid: 0-based indices of the interface residues.

    Returns:
        ``(sin_if, sin_res, sin_norm, aif_score)``: the normalised SIN score
        summed over the interface, the normalised SIN score of the mutated
        residue, the max-min range used for normalisation, and the row sum of
        the AIF matrix for the mutated residue.
    """
    sin_matrix = compute_sin_scores(pose)
    aif_matrix = sin_matrix.copy()
    num_residues = pose.total_residue()

    # Look up each one-letter code once rather than once per matrix cell.
    one_letter = [pose.residue(k + 1).name1() for k in range(num_residues)]
    for i in range(num_residues):
        for j in range(num_residues):
            # Only positive SIN entries are replaced by the pair's AIF value;
            # every other cell keeps its SIN value.
            if sin_matrix[i][j] > 0:
                aif_matrix[i][j] = aif[one_letter[i] + one_letter[j]]

    # Normalise with THIS pose's own scores. The previous code normalised the
    # reference pose with a variable that was only assigned later, for the mutant.
    sin_scores = sin_matrix.sum(axis=1)
    max_score = np.max(sin_scores)
    min_score = np.min(sin_scores)
    sin_norm = max_score - min_score
    normalized = np.round((sin_scores - min_score) / sin_norm, 6)

    sin_if = sum([normalized[i] for i in interface_resid])
    sin_res = normalized[pose_residue_index - 1]
    aif_score = aif_matrix[pose_residue_index - 1][:].sum()
    return sin_if, sin_res, sin_norm, aif_score


# dump_pdb writes into these folders; create them up front so a missing
# directory cannot make the structure dumps fail.
os.makedirs("ref", exist_ok=True)
os.makedirs("mut", exist_ok=True)

# The template structure is identical for every row: parse it once and clone it.
template_pose = pyrosetta.pose_from_pdb('ref.pdb')
pack_radius = 5.0

total_feature=[]

for index, feature in tqdm(data.iterrows(), total=len(data)):

    chain_id = feature['chain']
    origin_aa = feature['refAA']
    target_aa = feature['mutAA']
    interface_chain = feature['sides'][1]+"_"+feature['sides'][0]

    # 'res' is a PDB residue number with an optional trailing insertion code.
    res_label = feature['res']
    if res_label[-1].isalpha():
        pose_residue_index = template_pose.pdb_info().pdb2pose(
            chain_id, res=int(res_label[:-1]), icode=res_label[-1])
    else:
        pose_residue_index = template_pose.pdb_info().pdb2pose(chain_id, res=int(res_label))
    # pdb2pose yields 0 when the residue is absent from the pose; stop with a
    # clear message instead of silently working on the wrong residue.
    if pose_residue_index == 0:
        raise ValueError(
            f"Residue {res_label} of chain {chain_id} not found in ref.pdb "
            f"(row label {feature['label']})")

    ref_pose = template_pose.clone()
    MutateResidue(ref_pose, pose_residue_index, origin_aa)
    ref_pose.dump_pdb("ref/"+feature['label']+".pdb")

    mut_pose = template_pose.clone()
    # NOTE(review): the mutant is mutated to the same target twice (toolbox
    # mutate_residue with repacking, then MutateResidue); confirm both are intended.
    mutants.mutate_residue(mut_pose, pose_residue_index, target_aa,pack_radius=pack_radius)
    MutateResidue(mut_pose, pose_residue_index, target_aa)
    mut_pose.dump_pdb("mut/"+feature['label']+".pdb")

    new_ref_feature,ref_interface_resid = extract_features(ref_pose,pose_residue_index,interface_chain)
    new_mut_feature,mut_interface_resid = extract_features(mut_pose,pose_residue_index,interface_chain)

    # Every structural feature is reported as reference minus mutant.
    new_feature1 = new_ref_feature - new_mut_feature

    ref_sin_if, ref_sin_res, ref_sin_norm, ref_AIF_score = _sin_aif_summary(
        ref_pose, pose_residue_index, ref_interface_resid)
    mut_sin_if, mut_sin_res, mut_sin_norm, mut_AIF_score = _sin_aif_summary(
        mut_pose, pose_residue_index, mut_interface_resid)

    # NOTE(review): 'refAA' / 'mutAA' hold the dumped PDB paths here, not the
    # amino-acid letters of the input row; confirm this is what downstream expects.
    new_feature2 = pd.Series({
    'pdb_ref' : feature['pdb_ref'],
    'chain' : feature['chain'],
    'res' : feature['res'],
    'refAA' : "ref/"+feature['label']+".pdb",
    'mutAA' : "mut/"+feature['label']+".pdb",
    'sides' : feature['sides'],
    'pdb_ref_repack' : feature['pdb_ref_repack'],
    'pdb_mut_repack' : feature['pdb_mut_repack'],
    'aif_score' : ref_AIF_score-mut_AIF_score,
    'sin_if' : ref_sin_if-mut_sin_if,
    'sin_res' : ref_sin_res-mut_sin_res,
    'sin_norm' : ref_sin_norm-mut_sin_norm,
    'yclf' : feature['yclf'],
    'yreg' : feature['yreg'],
    'label' : feature['label']})


    new_feature = pd.concat([new_feature1, new_feature2])
    total_feature.append(new_feature)
    
    
# One row per input mutation, one column per feature.
total_feature = pd.concat(total_feature, axis=1).transpose()
total_feature.to_pickle("6YLA_scores.pickle")


  1%|          | 9/954 [01:46<3:07:13, 11.89s/it]