In [1]:
from perses.analysis.analysis import Analysis
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pymbar
%matplotlib inline
import os
import itertools
from tqdm import notebook as tqdm_notebook
import pandas as pd


from simtk.openmm import unit
from openmmtools.constants import kB
# Thermal energy kB * T at T = 300 K, divided by kcal/mol. Used further down
# as a multiplicative factor on the result tables whose columns are labelled
# in kcal / mol (i.e. presumably converting values from units of kT).
KT_KCALMOL = kB * 300 * unit.kelvin / unit.kilocalories_per_mole

In [2]:
def subtract_offset(forward_work, reverse_work):
    """Re-reference each work trajectory to its own first sample.

    For every cycle, the first value is subtracted from all *subsequent*
    values and the first sample itself is dropped, so each returned
    trajectory is one element shorter than its input.

    Parameters
    ----------
    forward_work, reverse_work : iterable of sequences
        One sequence of work values per cycle.

    Returns
    -------
    tuple of np.ndarray
        ``(forward_work_offset, reverse_work_offset)``, one row per cycle.
    """

    def _relative_to_first(cycles):
        # Kept element-wise (rather than sliced arithmetic) so that empty
        # cycles and plain Python sequences behave exactly as before.
        shifted = [
            np.array([sample - cycle[0] for sample in cycle[1:]])
            for cycle in cycles
        ]
        return np.array(shifted)

    print("--> subtracting offset")

    return _relative_to_first(forward_work), _relative_to_first(reverse_work)


def analyse(forward_accumulated, reverse_accumulated):
    """Run pymbar's BAR estimator on forward and reverse accumulated works.

    Parameters
    ----------
    forward_accumulated, reverse_accumulated : array-like
        Total work of each forward / reverse cycle, passed to
        ``pymbar.bar.BAR`` unchanged.

    Returns
    -------
    tuple
        The two values returned by ``pymbar.bar.BAR`` -- named ``dg`` and
        ``ddg`` by callers (presumably the free energy estimate and its
        uncertainty; see the pymbar documentation).
    """
    print("--> computing dg, ddg")

    free_energy, uncertainty = pymbar.bar.BAR(forward_accumulated,
                                              reverse_accumulated)
    return free_energy, uncertainty


def plot_works(forward_work_offset,
               reverse_work_offset,
               dg,
               ddg,
               phase,
               mutation,
               title,
               save=False,
               output_dir=None,
               time_per_sample=12.1e-4):
    """Plot work trajectories and final-work distributions for one phase.

    Two plots are drawn one after the other on the current pyplot figure
    (the figure is cleared after each one):

    1. every forward and (sign-flipped) reverse work trajectory against time;
    2. the distributions of the final forward and (sign-flipped) final
       reverse work values, with vertical lines at ``dg`` and ``dg +/- ddg``.

    Parameters
    ----------
    forward_work_offset, reverse_work_offset : iterable of np.ndarray
        One work trajectory per cycle. Reverse cycles are negated, so they
        must support unary minus.
    dg, ddg : float
        Free energy estimate and its uncertainty, drawn on the distribution
        plot.
    phase : str
        Phase name; appended to the title and used in the output file names.
    mutation : str
        Mutation name; used in the output file names.
    title : str
        Plot title prefix.
    save : bool, optional
        If True, save the plots as PNG files in ``output_dir`` instead of
        showing them.
    output_dir : str, optional
        Directory for the saved plots. If ``save`` is True and this is None,
        a message is printed and the plot is neither saved nor shown.
    time_per_sample : float, optional
        Spacing between consecutive work samples on the x axis, in the units
        of the x label (ns). The default is the value that used to be
        hard-coded here.
        TODO: determine this automatically from the simulation settings.

    Returns
    -------
    None
    """

    CB_color_cycle = ['#377eb8', '#ff7f00', '#4daf4a',
                  '#f781bf', '#a65628', '#984ea3',
                  '#999999', '#e41a1c', '#dede00']
    forward_color, reverse_color = CB_color_cycle[0], CB_color_cycle[1]

    def _plot_trajectories(cycles, color, label, negate):
        """Draw one line per cycle; only the first line gets a legend label."""
        for i, cycle in enumerate(cycles):
            x = [(j + 1) * time_per_sample for j in range(len(cycle))]
            y = -cycle if negate else cycle
            # Labelling a single line keeps the legend to one entry per direction.
            label_kwargs = {'label': label} if i == 0 else {}
            plt.plot(x, y, color=color, **label_kwargs)

    def _finish_plot(suffix):
        """Add title and legend, then save or show the plot and clear the figure."""
        plt.title(f"{title} {phase}")
        plt.legend(loc='best')
        if not save:
            plt.show()
        elif output_dir is None:
            print("--> No output_dir specified!")
        else:
            figure_path = os.path.join(output_dir, f"{mutation}_{phase}_{suffix}.png")
            plt.savefig(figure_path, dpi=500)
            print(f"--> saved to: {figure_path}")
        plt.clf()

    # Plot work trajectories
    print("--> plotting work trajs")

    _plot_trajectories(forward_work_offset, forward_color, 'forward', negate=False)
    _plot_trajectories(reverse_work_offset, reverse_color, 'reverse', negate=True)

    plt.xlabel("$t_{neq}$ (ns)")
    plt.ylabel("work (kT)")
    _finish_plot("work_traj")

    # Plot work distributions
    print("--> plotting work distrib")

    # The final value of each trajectory; reverse works are sign-flipped so
    # both distributions are drawn on the same axis.
    accumulated_forward = [cycle[-1] for cycle in forward_work_offset]
    accumulated_reverse = [-cycle[-1] for cycle in reverse_work_offset]
    # NOTE(review): sns.distplot is deprecated in recent seaborn releases;
    # migrating to histplot/displot would change the rendered figure, so it
    # is left as is here.
    sns.distplot(accumulated_forward, color=forward_color, label='forward')
    sns.distplot(accumulated_reverse, color=reverse_color, label='reverse')
    plt.axvline(dg)
    plt.axvline(dg + ddg, linestyle='dashed')
    plt.axvline(dg - ddg, linestyle='dashed')
    plt.xlabel("work (kT)")
    plt.ylabel("p(w)")
    _finish_plot("work_dist")

In [3]:
# NTRK1 mutations to analyse, listed per drug (the keys are the drug names
# that appear in the work-file names).
ntrk1_mutations = dict(
    larotrectinib=['G613V', 'R780Q'],
    entrectinib=['G613V', 'R780Q'],
)

In [4]:
base_data_path = '/data/chodera/glassw/miame/bayer_mutations_2021_06_17/1_run_neq/ntrk1/'
base_output_dir = '/home/glassw/GITHUB/study-ntrk-resistance/notebooks/free_energy_calculations/ntrk1_G613V_R780Q_analysis/'
ntrk = 'NTRK1'

PHASES = ('complex', 'apo')
DIRECTIONS = ('forward', 'reverse')
# Work files are probed at indices 0 .. N_WORK_FILES - 1; missing files are skipped.
N_WORK_FILES = 100

# Results per TKI: {mutation: [binding_dg, binding_ddg]}. The keys are taken
# from ntrk1_mutations so that every TKI looped over below has a slot.
ntrk1_df = {tki: {} for tki in ntrk1_mutations}

for tki in ntrk1_mutations:

    for mutation in ntrk1_mutations[tki]:

        # Load every available work array, grouped by (phase, direction)
        work_arrays = {(phase, direction): [] for phase in PHASES for direction in DIRECTIONS}

        for j in tqdm_notebook.tqdm(range(N_WORK_FILES)):
            for (phase, direction), arrays in work_arrays.items():
                work_path = f'{base_data_path}{mutation}/NTRK1_{tki}_{mutation}_{phase}_{j}_{direction}.npy'
                if os.path.exists(work_path):
                    with open(work_path, 'rb') as f:
                        arrays.append(np.load(f))

        # All four (phase, direction) combinations need at least one file.
        if not all(work_arrays.values()):
            print(f"--> dir {mutation} has at least one phase without data" )
            continue

        # Stack the per-file arrays: one row per cycle, one column per work sample
        works = {key: np.concatenate(arrays) for key, arrays in work_arrays.items()}
        print(works['complex', 'forward'].shape)
        print(works['apo', 'forward'].shape)

        # Analyse each phase
        work_offsets = {}
        free_energies = {}
        for phase in PHASES:
            forward = works[phase, 'forward']
            reverse = works[phase, 'reverse']

            work_offsets[phase] = subtract_offset(forward, reverse)

            # Accumulated work of a cycle = last sample minus first sample.
            free_energies[phase] = analyse(forward[:, -1] - forward[:, 0],
                                           reverse[:, -1] - reverse[:, 0])

        ## make the output directory
        output_dir = f'{base_output_dir}{ntrk}/{tki}'
        os.makedirs(output_dir, exist_ok=True)

        ## plot the work trajectories and distributions
        for phase in PHASES:
            forward_work_offset, reverse_work_offset = work_offsets[phase]
            phase_dg, phase_ddg = free_energies[phase]
            plot_works(forward_work_offset,
                       reverse_work_offset,
                       phase_dg,
                       phase_ddg,
                       phase=phase,
                       mutation=mutation,
                       title=f'{ntrk.upper()}-{tki} {mutation}',
                       save=True,
                       output_dir=output_dir)

        ## Get binding dg and ddg
        complex_dg, complex_ddg = free_energies['complex']
        apo_dg, apo_ddg = free_energies['apo']
        # Difference between the two phases; uncertainties are combined in quadrature.
        binding_dg = complex_dg - apo_dg
        binding_ddg = (apo_ddg**2 + complex_ddg**2)**0.5
        ntrk1_df[tki][mutation] = [binding_dg, binding_ddg]
        print(f"--> complex_dg: {complex_dg}")
        print(f"--> apo dg: {apo_dg}")

HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))


(100, 1251)
(100, 1251)
--> subtracting offset
--> computing dg, ddg
--> subtracting offset
--> computing dg, ddg
--> plotting work trajs
--> saved to: /home/glassw/GITHUB/study-ntrk-resistance/notebooks/free_energy_calculations/ntrk1_G613V_R780Q_analysis/NTRK1/larotrectinib/G613V_complex_work_traj.png
--> plotting work distrib
--> saved to: /home/glassw/GITHUB/study-ntrk-resistance/notebooks/free_energy_calculations/ntrk1_G613V_R780Q_analysis/NTRK1/larotrectinib/G613V_complex_work_dist.png
--> plotting work trajs
--> saved to: /home/glassw/GITHUB/study-ntrk-resistance/notebooks/free_energy_calculations/ntrk1_G613V_R780Q_analysis/NTRK1/larotrectinib/G613V_apo_work_traj.png
--> plotting work distrib
--> saved to: /home/glassw/GITHUB/study-ntrk-resistance/notebooks/free_energy_calculations/ntrk1_G613V_R780Q_analysis/NTRK1/larotrectinib/G613V_apo_work_dist.png
--> complex_dg: 3.968279576335074
--> apo dg: 3.988159951607645


HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))


(100, 1251)
(100, 1251)
--> subtracting offset
--> computing dg, ddg
--> subtracting offset
--> computing dg, ddg
--> plotting work trajs
--> saved to: /home/glassw/GITHUB/study-ntrk-resistance/notebooks/free_energy_calculations/ntrk1_G613V_R780Q_analysis/NTRK1/larotrectinib/R780Q_complex_work_traj.png
--> plotting work distrib
--> saved to: /home/glassw/GITHUB/study-ntrk-resistance/notebooks/free_energy_calculations/ntrk1_G613V_R780Q_analysis/NTRK1/larotrectinib/R780Q_complex_work_dist.png
--> plotting work trajs
--> saved to: /home/glassw/GITHUB/study-ntrk-resistance/notebooks/free_energy_calculations/ntrk1_G613V_R780Q_analysis/NTRK1/larotrectinib/R780Q_apo_work_traj.png
--> plotting work distrib
--> saved to: /home/glassw/GITHUB/study-ntrk-resistance/notebooks/free_energy_calculations/ntrk1_G613V_R780Q_analysis/NTRK1/larotrectinib/R780Q_apo_work_dist.png
--> complex_dg: 181.99052192691738
--> apo dg: 182.55722932595708


HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))


(100, 1251)
(97, 1251)
--> subtracting offset
--> computing dg, ddg
--> subtracting offset
--> computing dg, ddg
--> plotting work trajs
--> saved to: /home/glassw/GITHUB/study-ntrk-resistance/notebooks/free_energy_calculations/ntrk1_G613V_R780Q_analysis/NTRK1/entrectinib/G613V_complex_work_traj.png
--> plotting work distrib
--> saved to: /home/glassw/GITHUB/study-ntrk-resistance/notebooks/free_energy_calculations/ntrk1_G613V_R780Q_analysis/NTRK1/entrectinib/G613V_complex_work_dist.png
--> plotting work trajs
--> saved to: /home/glassw/GITHUB/study-ntrk-resistance/notebooks/free_energy_calculations/ntrk1_G613V_R780Q_analysis/NTRK1/entrectinib/G613V_apo_work_traj.png
--> plotting work distrib
--> saved to: /home/glassw/GITHUB/study-ntrk-resistance/notebooks/free_energy_calculations/ntrk1_G613V_R780Q_analysis/NTRK1/entrectinib/G613V_apo_work_dist.png
--> complex_dg: -3.447700614218657
--> apo dg: -1.9070631676187133


HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))


(100, 1251)
(100, 1251)
--> subtracting offset
--> computing dg, ddg
--> subtracting offset
--> computing dg, ddg
--> plotting work trajs
--> saved to: /home/glassw/GITHUB/study-ntrk-resistance/notebooks/free_energy_calculations/ntrk1_G613V_R780Q_analysis/NTRK1/entrectinib/R780Q_complex_work_traj.png
--> plotting work distrib
--> saved to: /home/glassw/GITHUB/study-ntrk-resistance/notebooks/free_energy_calculations/ntrk1_G613V_R780Q_analysis/NTRK1/entrectinib/R780Q_complex_work_dist.png
--> plotting work trajs
--> saved to: /home/glassw/GITHUB/study-ntrk-resistance/notebooks/free_energy_calculations/ntrk1_G613V_R780Q_analysis/NTRK1/entrectinib/R780Q_apo_work_traj.png
--> plotting work distrib
--> saved to: /home/glassw/GITHUB/study-ntrk-resistance/notebooks/free_energy_calculations/ntrk1_G613V_R780Q_analysis/NTRK1/entrectinib/R780Q_apo_work_dist.png
--> complex_dg: 182.66244970364653
--> apo dg: 182.95941829989928


<Figure size 432x288 with 0 Axes>

In [5]:
# Directory prefix that the summary CSV files are written under (see the
# to_csv calls at the end of the notebook).
output_path = '/home/glassw/GITHUB/study-ntrk-resistance/notebooks/free_energy_calculations/ntrk1_G613V_R780Q_analysis/NTRK1/'
# Entrectinib results as a table, transposed so that each mutation is a row.
# NOTE(review): sim_df is not referenced by any later cell visible in this
# notebook -- confirm whether it is still needed.
sim_df = pd.DataFrame(ntrk1_df['entrectinib']).T




In [7]:
# Inspect the collected results: {tki: {mutation: [binding_dg, binding_ddg]}}.
ntrk1_df

{'larotrectinib': {'G613V': [-0.019880375272570916, 0.3327860875840932],
  'R780Q': [-0.5667073990397, 0.33242110481102394]},
 'entrectinib': {'G613V': [-1.5406374465999437, 0.31311028941591434],
  'R780Q': [-0.2969685962527535, 0.236678182419231]}}

In [8]:
# Build one table per drug (rows: mutations), scaled by KT_KCALMOL so the
# values match the kcal / mol column labels assigned below.
sim_ntrk1_lar, sim_ntrk1_ent = (
    pd.DataFrame(ntrk1_df[_drug]).T * KT_KCALMOL
    for _drug in ('larotrectinib', 'entrectinib')
)

for _table in (sim_ntrk1_lar, sim_ntrk1_ent):
    _table.columns = ["DDG (kcal / mol)", "dDDG (kcal / mol)"]

In [9]:
# Larotrectinib results: one row per mutation, columns DDG and dDDG (kcal / mol).
sim_ntrk1_lar

Unnamed: 0,DDG (kcal / mol),dDDG (kcal / mol)
G613V,-0.011852,0.198394
R780Q,-0.337849,0.198177


In [10]:
# Entrectinib results: one row per mutation, columns DDG and dDDG (kcal / mol).
sim_ntrk1_ent

Unnamed: 0,DDG (kcal / mol),dDDG (kcal / mol)
G613V,-0.918469,0.186664
R780Q,-0.177041,0.141099


In [47]:
# Write one CSV per drug. The file name is appended to output_path by plain
# string formatting, so output_path must end with a path separator.
sim_ntrk1_lar.to_csv(f'{output_path}ntrk1_laro_DDGs.csv')
sim_ntrk1_ent.to_csv(f'{output_path}ntrk1_ent_DDGs.csv')