# Correlations between multiple ChEMBL targets for the same gene symbol

In [None]:
from itertools import combinations, chain

from math import ceil

from collections import OrderedDict as odict

import textwrap

from os import environ

import seaborn as sns

### Configuration

In [None]:
# Notebook parameters: the species is taken from the 'species' environment
# variable, falling back to 'Human' when it is unset or empty...

species = environ.get('species')
if not species:
    species = 'Human'

species

In [None]:
# Name of the activity column that is averaged per (target, parent compound)
# and correlated between targets...

pactivity_col = "pchembl_value"

In [None]:
# Threshold on the number of parent compounds two targets must have in common:
# a pair is kept only when it shares MORE than this many compounds
# (i.e. at least pair_threshold + 1)...

pair_threshold = 3

In [None]:
# ChEMBL connection; the connection string is read from database.txt.
# The file is opened in a 'with' block so its handle is closed promptly...

with open('database.txt') as db_file:
    engine = create_engine(db_file.read().strip())

In [None]:
# Apply seaborn's default styling to the matplotlib figures drawn below...

sns.set()

### Load targets

In [None]:
# Load the pickled ChEMBL target table, keeping only rows with exclude == 0...

all_targets = pd.read_pickle('chembl_targets.pkl')
targets = all_targets.query("exclude == 0")
del all_targets

targets.shape

In [None]:
HTML(targets[targets['species'] == species].head().to_html())

### Get pairs of targets for gene symbols

For each gene symbol, find the pairs of its ChEMBL targets that share more than a minimum number (`pair_threshold`) of distinct parent compounds with pActivity values.

In [None]:
# %%javascript

# IPython.notebook.kernel.execute("notebook_name = " + "'" + window.document.getElementById("notebook_name").innerHTML + "'");

# The JavaScript magic above would read the name from the page, but it does
# not run under runipy, so the notebook name is hard-coded instead.
notebook_name = 'Correlation_Targets'

In [None]:
cache_file = '{}_{}.pkl'.format(notebook_name, species)

cache_file

In [None]:
# Retrieve activity data for a list of targets...

def get_data_for_targets(target_chemblids):
    """Return all rows of tt_curve_data_v1 for the given target ChEMBL IDs.

    Args:
        target_chemblids: iterable of target ChEMBL ID strings.

    Returns:
        A DataFrame of every column of tt_curve_data_v1 whose target_chemblid
        is one of the given IDs. For an empty input an empty DataFrame is
        returned without touching the database, because ``IN ()`` is not
        valid SQL.
    """

    target_chemblids = list(target_chemblids)

    if not target_chemblids:
        return pd.DataFrame()

    # One numbered bind parameter (:1, :2, ...) per ID, so the IDs are passed
    # as parameters rather than interpolated into the SQL text...

    placeholders = ', '.join(':{}'.format(n + 1) for n in range(len(target_chemblids)))

    sql = """
    select
        *
    from
      tt_curve_data_v1 a
    where
      a.target_chemblid in ({})
    """.format(placeholders)

    return pd.read_sql_query(sql, engine, params=target_chemblids)

In [None]:
# Discard results cached by an earlier run so they are recomputed below...
if os.path.exists(cache_file):
    os.remove(cache_file)

In [None]:
# %%cache $cache_file data_by_symbol means_by_symbol pairs_by_symbol

# For each gene symbol, keep the raw activity data, the per-compound mean
# pActivities (one column per target) and a summary table of target pairs...

data_by_symbol, means_by_symbol, pairs_by_symbol = odict(), odict(), odict()

# Gene symbols (for the selected species) mapped to more than one ChEMBL target...

symbols = targets.query("species == @species")[['symbol', 'chembl_id']].groupby('symbol').count().query('chembl_id > 1').reset_index()['symbol'].values.tolist()

for symbol in symbols:

    logging.info("Starting '{}'...".format(symbol))

    # Get list of ChEMBL targets for the gene symbol...

    symbol_targets = targets.query("(species == @species) & (symbol == @symbol)")

    target_chemblids = symbol_targets.chembl_id.values.tolist()

    # Get activity data for these ChEMBL targets, calculate means and unstack so each target is represented by a single column...

    data = get_data_for_targets(target_chemblids)

    if not data.shape[0]:
        logging.warning("> No data for '{}'.".format(symbol))
        continue

    means = data[['target_chemblid', 'parent_cmpd_chemblid', pactivity_col]].groupby(['target_chemblid', 'parent_cmpd_chemblid']).mean().unstack(level=0)

    # Drop the pactivity_col level so the columns are just the target ChEMBL IDs...

    means.columns = means.columns.droplevel()

    if means.shape[1] == 1:
        logging.warning("> Only one target with data for '{}'.".format(symbol))
        continue

    # Keep pairs of targets sharing MORE than pair_threshold compounds; r is the
    # correlation (DataFrame.corr default, Pearson) of their means over the shared compounds...

    pair_rows = []

    for x, y in combinations(means.columns.values, 2):

        shared = means[[x, y]].dropna(how='any')

        n = shared.shape[0]

        if n > pair_threshold:
            pair_rows.append((x, means[x].count(), y, means[y].count(), n, shared.corr().iloc[0, 1]))

    # Checked explicitly: depending on the pandas version, building a DataFrame
    # from no rows may silently give an empty frame instead of raising...

    if not pair_rows:
        logging.warning("> No pairs of target with sufficient compounds in common for '{}'.".format(symbol))
        continue

    target_pairs = pd.DataFrame(pair_rows, columns=['target_1', 'n_1', 'target_2', 'n_2', 'n', 'r'])

    # Add full names of targets, one merge per side of the pair...

    pref_names = symbol_targets[['chembl_id', 'pref_name']].reset_index(drop=True).set_index('chembl_id')

    target_pairs = (
        target_pairs
        .merge(pref_names.rename(columns={'pref_name': 'pref_name_1'}), left_on='target_1', right_index=True)
        .merge(pref_names.rename(columns={'pref_name': 'pref_name_2'}), left_on='target_2', right_index=True)
        .sort_values(['r', 'n'], ascending=False)
        .reset_index(drop=True)
    )

    target_pairs = target_pairs[['target_1', 'pref_name_1', 'n_1', 'target_2', 'pref_name_2', 'n_2', 'n', 'r']]

    # Done...

    data_by_symbol[symbol], means_by_symbol[symbol], pairs_by_symbol[symbol] = data, means, target_pairs

    logging.info("...OK.")

logging.info("Finished.")

In [None]:
# Number of retained target pairs for each symbol that had multiple targets
# with enough data to compare...

[(sym, len(pairs)) for sym, pairs in pairs_by_symbol.items()]

In [None]:
# Show data for each symbol...

# HTML('\n'.join("<h3>{}</h3>\n{}".format(x, data_by_symbol[x].to_html()) for x in sorted(data_by_symbol.keys())))

In [None]:
# Render each symbol's pair summary table under its own heading, in symbol order...

HTML('\n'.join(
    "<h3>{}</h3>\n{}".format(sym, pairs_by_symbol[sym].to_html())
    for sym in sorted(pairs_by_symbol)
))

### Plot pActivity data for pairs of targets

In [None]:
# Function to plot pChEMBL values for all pairs of ChEMBL targets for a symbol...

# Grid layout: subplots per row, and width/height (inches) of each subplot...

ncol, size = 4, 12

# Axis limits (pActivity units) shared by every subplot...

min_xc50, max_xc50 = 3.0, 10.0

# pActivity at which the dashed vertical and horizontal guide lines are drawn...

cutoff_xc50 = 5.0

def plots_for_symbol(symbol):
    """Scatter-plot mean pActivities for each target pair stored for ``symbol``.

    Reads ``pairs_by_symbol[symbol]`` (one subplot per row) and
    ``means_by_symbol[symbol]`` (per-compound means, one column per target).
    Subplots are laid out ``ncol`` per row and unused grid cells are hidden.
    Each subplot shows the y = x line and dashed guides at ``cutoff_xc50``.
    Does nothing when the symbol has no pairs.
    """

    target_pairs, target_means = pairs_by_symbol[symbol], means_by_symbol[symbol]

    n_pairs = target_pairs.shape[0]

    # plt.subplots cannot build a grid with zero rows...

    if not n_pairs:
        return

    # float() keeps the ceiling correct even where '/' is integer division (Python 2)...

    nrow = int(ceil(n_pairs / float(ncol)))

    # squeeze=False always yields a 2-D array of Axes, so ravel() flattens it
    # for any grid shape (including a single row or a single Axes)...

    fig, axes = plt.subplots(nrow, ncol, figsize=(size*ncol, size*nrow), squeeze=False)

    axes = axes.ravel()

    fig.suptitle(symbol)

    for ax in axes[n_pairs:]:
        ax.axis('off')

    for i, (_, rec) in enumerate(target_pairs.iterrows()):

        axis = axes[i]

        # Only compounds with a mean value for both targets...

        pair_means = target_means[[rec.target_1, rec.target_2]].dropna(how='any')

        axis.scatter(pair_means[rec.target_1], pair_means[rec.target_2])

        axis.set_xlim(min_xc50, max_xc50)
        axis.set_ylim(min_xc50, max_xc50)
        axis.set_aspect(1)

        axis.set_title("{}  ({}/{})  r = {:.2f}  [n = {}]".format(symbol, i+1, n_pairs, rec.r, rec.n))
        axis.set_xlabel("{}  [n = {}]\n{}".format(rec.target_1, rec.n_1, rec.pref_name_1))
        axis.set_ylabel("{}  [n = {}]\n{}".format(rec.target_2, rec.n_2, rec.pref_name_2))

        axis.plot((min_xc50, max_xc50), (min_xc50, max_xc50), color='r', linestyle='-', linewidth=1)
        axis.plot((cutoff_xc50, cutoff_xc50), (min_xc50, max_xc50), color='m', linestyle='--', linewidth=2)
        axis.plot((min_xc50, max_xc50), (cutoff_xc50, cutoff_xc50), color='m', linestyle='--', linewidth=2)

In [None]:
# Generate plots for all symbols...

for symbol in sorted(data_by_symbol):
    plots_for_symbol(symbol)