In [1]:
from perses.analysis.analysis import Analysis
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pymbar
%matplotlib inline
import os
import itertools
from tqdm import notebook as tqdm_notebook
import pandas as pd


from simtk.openmm import unit
from openmmtools.constants import kB
# kT at 300 K divided by kcal/mol — presumably a plain number for converting work
# values between kT and kcal/mol (TODO confirm the result carries no unit).
# NOTE(review): not referenced in the cells visible here — confirm it is used further down.
KT_KCALMOL = kB * 300 * unit.kelvin / unit.kilocalories_per_mole

In [None]:
def subtract_offset(forward_work, reverse_work):
    """Re-reference every work trajectory to its own first sample.

    For each cycle (one row of ``forward_work`` / ``reverse_work``) the first
    value is subtracted from all *following* values; the first sample itself is
    dropped, so each returned cycle is one element shorter than its input.

    Parameters
    ----------
    forward_work, reverse_work : iterable of sequences
        One work trajectory per cycle.

    Returns
    -------
    tuple of np.ndarray
        ``(forward_work_offset, reverse_work_offset)``.
    """

    def _relative_to_start(cycles):
        # Keep the element-wise subtraction so plain lists work as well as arrays.
        shifted = [
            np.array([sample - trajectory[0] for sample in trajectory[1:]])
            for trajectory in cycles
        ]
        return np.array(shifted)

    print("--> subtracting offset")

    return _relative_to_start(forward_work), _relative_to_start(reverse_work)


def analyse(forward_accumulated, reverse_accumulated):
    """Run pymbar's BAR estimator on accumulated forward and reverse works.

    Parameters
    ----------
    forward_accumulated, reverse_accumulated : array-like
        Total work per cycle for the forward and reverse protocols; passed to
        ``pymbar.bar.BAR`` unchanged.

    Returns
    -------
    tuple
        ``(dg, ddg)`` — the two values BAR returns, in that order (presumably
        the free energy estimate and its uncertainty; this relies on the
        pymbar 3 tuple-returning API).
    """
    print("--> computing dg, ddg")

    free_energy, uncertainty = pymbar.bar.BAR(forward_accumulated, reverse_accumulated)
    return free_energy, uncertainty


def plot_works(forward_work_offset,
               reverse_work_offset,
               dg,
               ddg,
               phase,
               mutation,
               title,
               save=False,
               output_dir=None,
               time_per_sample_ns=12.1e-4):
    """Plot work trajectories and final-work distributions for one phase.

    Two figures are produced one after the other on the current pyplot figure:
    the work trajectories of every cycle, and the distribution of the last
    work value of every cycle with vertical lines at ``dg`` and ``dg ± ddg``.
    Reverse works are negated in both figures.

    Parameters
    ----------
    forward_work_offset, reverse_work_offset : iterable of array-like
        One offset-subtracted work trajectory per cycle.
    dg, ddg : float
        Values drawn as the solid (``dg``) and dashed (``dg ± ddg``) vertical
        lines on the distribution figure.
    phase : str
        Appended to the plot title and used in the output file names.
    mutation : str
        Used in the output file names.
    title : str
        Plot title prefix.
    save : bool, optional
        If True, write PNGs into ``output_dir`` instead of showing the figures.
    output_dir : str, optional
        Destination directory for the PNGs; created if it does not exist.
        When ``save`` is True and this is None, nothing is written or shown.
    time_per_sample_ns : float, optional
        Spacing between consecutive work samples on the x axis (the axis is
        labelled in ns). Defaults to the previously hard-coded 12.1e-4.

    Returns
    -------
    None
    """

    CB_color_cycle = ['#377eb8', '#ff7f00', '#4daf4a',
                      '#f781bf', '#a65628', '#984ea3',
                      '#999999', '#e41a1c', '#dede00']
    forward_color, reverse_color = CB_color_cycle[0], CB_color_cycle[1]

    def _plot_trajectories(cycles, negate, color, label):
        for index, cycle in enumerate(cycles):
            work = np.asarray(cycle)
            if negate:
                work = -work
            # Sample k (1-based) is drawn at k * time_per_sample_ns.
            elapsed = np.arange(1, len(work) + 1) * time_per_sample_ns
            # Label only the first line so the legend has one entry per direction.
            label_kwargs = {'label': label} if index == 0 else {}
            plt.plot(elapsed, work, color=color, **label_kwargs)

    # Plot work trajectories
    print("--> plotting work trajs")
    _plot_trajectories(forward_work_offset, False, forward_color, 'forward')
    _plot_trajectories(reverse_work_offset, True, reverse_color, 'reverse')

    plt.xlabel("$t_{neq}$ (ns)")
    plt.ylabel("work (kT)")
    plt.title(f"{title} {phase}")
    plt.legend(loc='best')
    _save_or_show_work_figure(save, output_dir, f"{mutation}_{phase}_work_traj.png")

    # Plot work distributions
    print("--> plotting work distrib")

    accumulated_forward = [cycle[-1] for cycle in forward_work_offset]
    accumulated_reverse = [-cycle[-1] for cycle in reverse_work_offset]
    # NOTE: sns.distplot is deprecated in recent seaborn releases; it is kept here
    # so the figures stay identical (histplot/kdeplot render differently).
    sns.distplot(accumulated_forward, color=forward_color, label='forward')
    sns.distplot(accumulated_reverse, color=reverse_color, label='reverse')
    plt.axvline(dg)
    plt.axvline(dg + ddg, linestyle='dashed')
    plt.axvline(dg - ddg, linestyle='dashed')
    plt.xlabel("work (kT)")
    plt.ylabel("p(w)")
    plt.title(f"{title} {phase}")
    plt.legend(loc='best')
    _save_or_show_work_figure(save, output_dir, f"{mutation}_{phase}_work_dist.png")


def _save_or_show_work_figure(save, output_dir, filename):
    """Save the current pyplot figure to ``output_dir/filename`` or show it, then clear it."""
    if save:
        if output_dir is not None:
            # savefig does not create missing directories, so make sure it exists first.
            os.makedirs(output_dir, exist_ok=True)
            target = os.path.join(output_dir, filename)
            plt.savefig(target, dpi=500)
            print(f"--> saved to: {target}")
        else:
            print("--> No output_dir specified!")
    else:
        plt.show()
    plt.clf()

In [None]:
# NTRK1 mutations to analyse for each compound. The keys are the compound names
# iterated as `tki` by the analysis cell; the values are the mutation labels.
ntrk1_mutations = dict(
    larotrectinib=['G595R', 'G667C'],
    selitrectinib=['G595R', 'G667C'],
    repotrectinib=['G595R'],
)

In [None]:
# Settings for the NTRK1 non-equilibrium work analysis.
# NOTE(review): '4xey' and the "dasatinib" in `sim_df_dasatinib` look carried over from
# another analysis — confirm they match the NTRK1 input file names before trusting results.
pdb_code = '4xey'
N_WORK_JOBS = 100  # job indices 0..N_WORK_JOBS-1 are probed; missing files are skipped
WORK_ROOT = '/data/chodera/glassw/kinoml/NTRK/run_neq_WithTraj/NTRK1'
PLOT_ROOT = '/lila/home/glassw/GITHUB/study-abl-resistance/notebooks/perses_analysis_output'
WORK_KINDS = [('complex', 'forward'), ('complex', 'reverse'),
              ('apo', 'forward'), ('apo', 'reverse')]

# Legacy result table keyed by mutation only: a mutation that occurs under several
# TKIs keeps just the value from the last TKI processed. Kept for existing consumers.
sim_df_dasatinib = {}
# Same [binding_dg, binding_ddg] entries keyed by (tki, mutation), so no result is lost.
sim_binding_by_tki = {}


def _load_work_arrays(tki, mutation, pdb_code):
    """Load every available work file for one tki/mutation pair.

    Returns a dict mapping each (phase, direction) in WORK_KINDS to the list of
    arrays found on disk (possibly empty).
    """
    loaded = {kind: [] for kind in WORK_KINDS}
    for j in tqdm_notebook.tqdm(range(N_WORK_JOBS)):
        for phase, direction in WORK_KINDS:
            path = f'{WORK_ROOT}/{mutation}/{pdb_code}_{tki}_{mutation}_{phase}_{j}_{direction}.npy'
            if os.path.exists(path):
                with open(path, 'rb') as f:
                    loaded[(phase, direction)].append(np.load(f))
    return loaded


def _combine_cycles(arrays):
    """Stack per-job arrays along the cycle axis and return (combined, accumulated).

    ``accumulated`` is last-minus-first sample of every cycle, taken directly
    from the raw arrays rather than from the offset-subtracted trajectories.
    """
    combined = np.concatenate(arrays)
    accumulated = combined[:, -1] - combined[:, 0]
    return combined, accumulated


for tki in ntrk1_mutations:

    plot_dir = f'{PLOT_ROOT}/{tki}'

    for mutation in ntrk1_mutations[tki]:

        # Load and combine arrays
        work_arrays = _load_work_arrays(tki, mutation, pdb_code)
        forward_complex_arrays = work_arrays[('complex', 'forward')]
        reverse_complex_arrays = work_arrays[('complex', 'reverse')]
        forward_apo_arrays = work_arrays[('apo', 'forward')]
        reverse_apo_arrays = work_arrays[('apo', 'reverse')]

        # All four phase/direction combinations are needed for a binding estimate.
        if not all(work_arrays.values()):
            print(f"--> dir {mutation} has at least one phase without data" )
            continue

        forward_complex_combined, forward_complex_accumulated = _combine_cycles(forward_complex_arrays)
        print(forward_complex_combined.shape)
        reverse_complex_combined, reverse_complex_accumulated = _combine_cycles(reverse_complex_arrays)

        forward_apo_combined, forward_apo_accumulated = _combine_cycles(forward_apo_arrays)
        print(forward_apo_combined.shape)
        reverse_apo_combined, reverse_apo_accumulated = _combine_cycles(reverse_apo_arrays)

        # Analyse

        ## complex
        forward_complex_work_offset, reverse_complex_work_offset = subtract_offset(forward_complex_combined,
                                                                                   reverse_complex_combined)

        complex_dg, complex_ddg = analyse(forward_complex_accumulated,
                                          reverse_complex_accumulated)

        ## apo
        forward_apo_work_offset, reverse_apo_work_offset = subtract_offset(forward_apo_combined,
                                                                           reverse_apo_combined)

        apo_dg, apo_ddg = analyse(forward_apo_accumulated, reverse_apo_accumulated)

        ## plot the work trajectories and distributions
        # The figures are written with plt.savefig, which does not create
        # missing directories, so make sure the per-TKI folder exists.
        os.makedirs(plot_dir, exist_ok=True)

        complex_plot = plot_works(forward_complex_work_offset,
                                  reverse_complex_work_offset,
                                  complex_dg,
                                  complex_ddg,
                                  phase='complex',
                                  mutation=mutation,
                                  title=f'{pdb_code.upper()}-{tki} {mutation}',
                                  save=True,
                                  output_dir=plot_dir)

        apo_plot = plot_works(forward_apo_work_offset,
                              reverse_apo_work_offset,
                              apo_dg,
                              apo_ddg,
                              phase='apo',
                              mutation=mutation,
                              title=f'{pdb_code.upper()}-{tki} {mutation}',
                              save=True,
                              output_dir=plot_dir)

        ## Get binding dg and ddg
        binding_dg = complex_dg - apo_dg
        # The two phase uncertainties are combined in quadrature.
        binding_ddg = (apo_ddg**2 + complex_ddg**2)**0.5
        sim_df_dasatinib[mutation] = [binding_dg, binding_ddg]
        sim_binding_by_tki[(tki, mutation)] = [binding_dg, binding_ddg]
        print(f"--> complex_dg: {complex_dg}")
        print(f"--> apo dg: {apo_dg}")