In [1]:
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.patches as patches
import os
import sys
from scipy import stats
from scipy.stats import spearmanr
import seaborn as sns

In [6]:
import importlib
from terminator.models import load_model
importlib.reload(load_model)
from terminator.data import load_data
importlib.reload(load_data)

<module 'terminator.data.load_data' from '/data1/groups/keatinglab/COORDinator_shared/terminator/data/load_data.py'>

In [3]:
class model_args():
    """Plain container of options for loading a COORDinator model and its data.

    Every constructor argument is stored unchanged as an attribute of the same
    name; two additional attributes (``base_sc_mask`` and ``sc_info``) are
    fixed and not configurable. Instances are presumably consumed by
    ``load_model.load_model``, which defines what each option means —
    NOTE(review): confirm option semantics against that loader.

    List-valued options (``subset_list``, ``sc_mask``, ``sc_screen_range``,
    ``pdb_list``) default to ``None`` and are then replaced by a fresh empty
    list, so separate instances never share one mutable default list. A list
    passed explicitly is stored as-is (not copied).
    """

    def __init__(self, dataset, model_dir, dev='cpu', mpnn_dihedrals=False, subset=False, subset_list=None,
                 sc_mask_rate=0, sc_mask=None, mask_neighbors=False, checkpoint='net_best_checkpoint.pt', use_sc=True,
                 no_mask=False, sc_screen=False, sc_screen_range=None, data_only=False, verbose=False, use_esm=False,
                 return_esm=False, zero_node_features=False, zero_edge_features=False, name_cluster=False, batch_size=1,
                 per_protein_loss=False, chain_handle='', mask_interface=True, random_interface=False, esm=None,
                 batch_converter=None, use_struct_predict=False, inter_cutoff=16,
                 noise_level=0, bond_length_noise_level=0, bond_angle_noise_level=0, flex_type='', from_wds=False,
                 pdb_list=None, add_gap=False, sep_complex=False,
                 ener_data=None, ener_col='ener', adj_index=True, alphabetize_data=True, center_node_ablation=False,
                 center_node=False):
        # A literal `[]` default is evaluated once and shared by every call;
        # build a new list per instance instead.
        if subset_list is None:
            subset_list = []
        if sc_mask is None:
            sc_mask = []
        if sc_screen_range is None:
            sc_screen_range = []
        if pdb_list is None:
            pdb_list = []

        self.dataset = dataset
        self.model_dir = model_dir
        self.dev = dev
        self.subset = subset
        self.subset_list = subset_list
        self.sc_mask = sc_mask
        self.sc_mask_rate = sc_mask_rate
        # Fixed values, not exposed as constructor arguments.
        self.base_sc_mask = 0.0
        self.mask_neighbors = mask_neighbors
        self.use_sc = use_sc
        self.mpnn_dihedrals = mpnn_dihedrals
        self.no_mask = no_mask
        self.sc_info = 'all'
        self.sc_screen = sc_screen
        self.sc_screen_range = sc_screen_range
        self.data_only = data_only
        self.verbose = verbose
        self.checkpoint = checkpoint
        self.use_esm = use_esm
        # Read by get_model_and_data to decide whether ESM objects are returned.
        self.return_esm = return_esm
        self.zero_node_features = zero_node_features
        self.zero_edge_features = zero_edge_features
        self.name_cluster = name_cluster
        self.batch_size = batch_size
        self.per_protein_loss = per_protein_loss
        self.chain_handle = chain_handle
        self.mask_interface = mask_interface
        self.random_interface = random_interface
        self.esm = esm
        self.batch_converter = batch_converter
        self.use_struct_predict = use_struct_predict
        self.inter_cutoff = inter_cutoff
        self.noise_level = noise_level
        self.bond_length_noise_level = bond_length_noise_level
        self.bond_angle_noise_level = bond_angle_noise_level
        self.flex_type = flex_type
        self.from_wds = from_wds
        self.pdb_list = pdb_list
        self.add_gap = add_gap
        self.sep_complex = sep_complex
        self.ener_data = ener_data
        self.ener_col = ener_col
        self.adj_index = adj_index
        self.alphabetize_data = alphabetize_data
        self.center_node_ablation = center_node_ablation
        self.center_node = center_node


In [5]:
def get_model_and_data(args):
    """Load a model with its dataloaders and hyperparameters via ``load_model.load_model``.

    The loader yields ten values when ``args.return_esm`` is truthy (including
    the ESM model, batch converter and ESM options) and seven otherwise. Both
    cases are normalised to the same 10-tuple here, with the three ESM slots
    set to ``None`` when ESM outputs were not requested.

    Returns:
        (model, dataloader, target_dataloader, binder_dataloader, sampler,
         esm, batch_converter, esm_options, run_hparams, model_hparams)
    """
    wants_esm = bool(args.return_esm)
    loaded = load_model.load_model(args)

    if wants_esm:
        (model, dataloader, target_dataloader, binder_dataloader, sampler,
         esm, batch_converter, esm_options, run_hparams, model_hparams) = loaded
    else:
        (model, dataloader, target_dataloader, binder_dataloader, sampler,
         run_hparams, model_hparams) = loaded
        esm = batch_converter = esm_options = None

    return (model, dataloader, target_dataloader, binder_dataloader, sampler,
            esm, batch_converter, esm_options, run_hparams, model_hparams)

In [2]:
def plot_dms(plot_data, sequence, title='COORDinator Protein Predictions', clabel=r'Predicted $\Delta$$\Delta$G$_{bound-unbound}$ (a.u.)', row_norm=False, all_norm=False):
    """Plot a deep-mutational-scan style heatmap of per-mutation scores.

    Rows are labelled with the wild-type residues of ``sequence``. Columns are
    labelled with the 20 amino acids in alphabetical one-letter order (mutant).

    Args:
        plot_data: 2-D array-like of scores, one row per residue of ``sequence``
            and (presumably) 20 columns in ``ACDEFGHIKLMNPQRSTVWY`` order.
            NaNs are tolerated by the nan-aware min/max used for normalisation.
        sequence: wild-type sequence; ``list(sequence)`` gives the row labels.
        title: figure title.
        clabel: colour-bar label. The default is a raw string so that its
            mathtext backslash commands are not parsed as Python escapes.
        row_norm: if True, min-max scale each row independently to [0, 1].
        all_norm: if True, min-max scale the whole matrix to [0, 1]; takes
            precedence over ``row_norm``.

    A constant row (``row_norm``) or constant matrix (``all_norm``) has zero
    range. It is mapped to 0 rather than computing 0/0 = NaN, which would blank
    it out of the heatmap and emit a RuntimeWarning.

    Side effects: calls ``sns.set(font_scale=0.6)``, which changes the global
    seaborn theme for every later plot, and then ``plt.show()``.
    """
    if all_norm:
        lowest = np.nanmin(plot_data)
        span = np.nanmax(plot_data - lowest)
        if span == 0:
            span = 1  # constant matrix: every entry becomes 0
        plot_data = (plot_data - lowest) / span
    elif row_norm:
        min_vals = np.nanmin(plot_data, axis=1, keepdims=True)
        max_vals = np.nanmax(plot_data, axis=1, keepdims=True)
        span = max_vals - min_vals
        # Constant rows would give 0/0 = NaN; divide by 1 instead so they map to 0.
        span = np.where(span == 0, 1, span)
        plot_data = (plot_data - min_vals) / span

    amino_acids = list('ACDEFGHIKLMNPQRSTVWY')

    # Diverging colormap: blue -> light gray (gray90) -> red.
    blue = (0.0, 0.0, 1.0)
    gray90 = (0.9, 0.9, 0.9)
    red = (1.0, 0.0, 0.0)
    cmap = mcolors.LinearSegmentedColormap.from_list("Blue_Gray90_Red", [blue, gray90, red])

    fig, ax = plt.subplots(figsize=(4, 4))
    sns.set(font_scale=0.6)
    # NOTE(review): with row_norm/all_norm the data lies in [0, 1], so centring
    # at 0 presumably uses only the gray-to-red half of the colormap — confirm intended.
    sns.heatmap(
        plot_data,
        cmap=cmap,
        center=0,
        xticklabels=amino_acids,
        yticklabels=list(sequence),
        cbar_kws={'shrink': 0.8, 'pad': 0.02, 'label': clabel},
        ax=ax,
    )

    # Horizontal, left-aligned residue labels nudged left so they clear the heatmap.
    for tick in ax.get_yticklabels():
        tick.set_rotation(0)
        tick.set_ha('left')
        tick.set_position((-0.03, tick.get_position()[1]))

    plt.setp(ax.get_xticklabels(), rotation=0)
    ax.set_xlabel('Mut Residue')
    ax.set_ylabel('WT Residue')
    ax.set_title(title)

    plt.show()

1

The history saving thread hit an unexpected error (OperationalError('database is locked')).History will not be written to the database.


In [None]:
def plot_scatter_torch(real_list, pred_list, title):
    """Scatter predicted vs. real energies given as tensors, with a least-squares fit line.

    Draws on the current pyplot axes and does not call ``plt.show()``.

    Args:
        real_list: tensor of reference energies (x axis); may be on any device.
        pred_list: tensor of predicted energies (y axis), same number of elements.
        title: plot title.
    """
    # Convert each tensor exactly once. detach() allows tensors that require
    # grad (e.g. raw model outputs) to become numpy; cpu() handles GPU tensors;
    # reshape(-1) gives linregress the 1-D input it requires.
    real = real_list.detach().cpu().numpy().reshape(-1)
    pred = pred_list.detach().cpu().numpy().reshape(-1)

    plt.scatter(real, pred, s=1)
    plt.xlabel('Real E')
    plt.ylabel('Predicted E')

    # Ordinary least-squares fit; r is Pearson's correlation coefficient.
    slope, intercept, r_value, _, _ = stats.linregress(real, pred)
    plt.plot(real, slope * real + intercept, color='red', label=f'Best Fit Line (r={r_value:.3f})')
    plt.legend()
    plt.title(title)

In [None]:
def plot_scatter_list(real_list, pred_list, title='COORDinator predictions', xlabel='Real E', ylabel='Predicted E', midline=False):
    """Scatter predictions against reference values with a least-squares fit line.

    Draws on the current pyplot axes and does not call ``plt.show()``.

    Args:
        real_list: sequence/array of reference values (x axis).
        pred_list: sequence/array of predicted values (y axis), same length.
        title: plot title.
        xlabel: x-axis label.
        ylabel: y-axis label.
        midline: if True, make both axes span the common data range and draw
            the y = x identity line across that full range.
    """
    real = np.array(real_list)
    pred = np.array(pred_list)

    plt.scatter(real, pred, s=1)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)

    # Ordinary least-squares fit; r is Pearson's correlation coefficient.
    slope, intercept, r_value, _, _ = stats.linregress(real, pred)
    plt.plot(real, slope * real + intercept, color='red', label=f'Best Fit Line (r={r_value:.3f})')

    if midline:
        lo = min(real.min(), pred.min())
        hi = max(real.max(), pred.max())
        plt.xlim([lo, hi])
        plt.ylim([lo, hi])
        # Two endpoints fully define y = x and span the whole range at any data scale.
        plt.plot([lo, hi], [lo, hi], color='black')

    plt.legend()
    plt.title(title)

In [None]:
def save_preds(preds, mut_seqs, wt_seq, pdb):
    """Tabulate predicted energies together with the mutations that produced them.

    Args:
        preds: iterable of single-value tensors (``.cpu().item()`` is called on
            each), aligned element-wise with ``mut_seqs``.
        mut_seqs: iterable of integer-encoded mutant sequences (tensors).
        wt_seq: integer-encoded wild-type sequence (tensor), compared
            position-by-position against each mutant sequence.
        pdb: identifier written to the ``protein`` column of every row.

    Returns:
        dict of equal-length lists keyed ``pos``, ``wildtype``, ``mutant``,
        ``mutation``, ``pred_ener`` and ``protein``, one entry per prediction
        (suitable for ``pd.DataFrame``). Multiple mutations in one sequence are
        joined with ``;`` (e.g. ``"A0G;D2F"``). A mutant identical to
        ``wt_seq`` yields empty strings in the mutation columns.
        Positions are 0-based indices (NOTE(review): conventional mutation
        notation is 1-based — confirm downstream consumers expect 0-based).

    Relies on ``ints_to_seq_torch`` (defined outside this section), presumably
    decoding a length-1 index tensor to its one-letter residue.
    """
    pdb_stats = {'pos': [], 'wildtype': [], 'mutant': [], 'mutation': [], 'pred_ener': [], 'protein': []}

    for pred, mut_seq in zip(preds, mut_seqs):
        pos_list = []
        wt_list = []
        mut_list = []
        mutation_list = []
        for idx, (mc, wc) in enumerate(zip(mut_seq, wt_seq)):
            if mc != wc:
                pos = str(idx)
                wt_res = ints_to_seq_torch(wc.unsqueeze(0))
                mut_res = ints_to_seq_torch(mc.unsqueeze(0))
                pos_list.append(pos)
                wt_list.append(wt_res)
                mut_list.append(mut_res)
                mutation_list.append(wt_res + pos + mut_res)

        pdb_stats['protein'].append(pdb)
        pdb_stats['pos'].append(";".join(pos_list))
        pdb_stats['wildtype'].append(";".join(wt_list))
        pdb_stats['mutant'].append(";".join(mut_list))
        pdb_stats['mutation'].append(";".join(mutation_list))
        pdb_stats['pred_ener'].append(pred.cpu().item())

    return pdb_stats