In [None]:
# Input files, relative to the notebook's directory.
_OUT_DIR = '../out/'
_ASSOC_DIR = _OUT_DIR + 'associations/'

odds_ratio = _ASSOC_DIR + 'odds_ratio.tsv'            # per-OG 'odds-ratio' and 'lrt-pvalue'
virulence = _OUT_DIR + 'virulence_genes.tsv'          # 'og' -> 'gene' table
filtered = _ASSOC_DIR + 'summary_cont_lmm_kmer.tsv'   # per-OG summary with 'specific_hits'
names = _ASSOC_DIR + 'associated_ogs.final.tsv'       # provides 'preferred_og_name'
phenotypes = '../data/phenotypes/phenotypes.tsv'      # per-strain phenotypes (e.g. 'killed')
tree = _OUT_DIR + 'gubbins/tree.nwk'                  # Newick tree
rtab = _OUT_DIR + 'roary/gene_presence_absence.Rtab'  # Roary presence/absence matrix

In [None]:
# plotting imports
%matplotlib inline

import matplotlib.pyplot as plt
import seaborn as sns
from adjustText import adjust_text

sns.set_style('white')

# Base font sizes shared by every figure in this notebook.
plt.rcParams.update({
    'font.size': 11,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    'axes.labelsize': 12,
    'axes.titlesize': 12,
    'legend.fontsize': 11,
})

In [None]:
import numpy as np
import pandas as pd
from Bio import Phylo

In [None]:
def _masked_cmap(base):
    """Return a private copy of colormap ``base`` with this figure's extreme colours.

    Copying avoids mutating the globally registered colormap object, which
    would otherwise leak ``set_bad``/``set_under`` into every later plot.
    """
    import copy

    cmap = copy.copy(base)
    cmap.set_bad('grey', 0.2)   # NaN / masked cells
    cmap.set_under('red', 1)    # values below vmin
    return cmap


def _draw_presence_heatmap(ax, matrix, cmap, xlabel):
    """Draw ``matrix`` (rows: OGs, columns: strains) on ``ax``, one column per OG.

    The matrix is transposed so strains run top-to-bottom in the order of its
    columns. The colour scale is fixed to [0, 1]. OG columns are separated by
    grey lines and labelled with the matrix's row labels.
    """
    n_ogs = matrix.shape[0]
    ax.imshow(matrix.T,
              cmap=cmap,
              vmin=0, vmax=1,
              aspect='auto',
              interpolation='none')
    ax.set_yticks([])
    for i in range(n_ogs):
        ax.axvline(i + 0.5, color='grey', alpha=0.77)
    ax.set_xticks(list(range(n_ogs)))
    ax.set_xticklabels(matrix.index, rotation=90, size=12)
    ax.set_xlabel(xlabel, size=12)


def plot_condition(tree,
                   phenotypes,
                   mdist,
                   r1, r2,
                   mdist_offset=0,
                   fig_height=35):
    """Draw a tree beside a phenotype strip and two OG presence/absence heatmaps.

    Every panel has one row per tree tip, in ``tree.get_terminals()`` order.
    Left to right the panels are: tree | phenotype | associated OGs (``r1``) |
    other virulence OGs (``r2``).

    Parameters
    ----------
    tree : Bio.Phylo tree
        Tree whose terminal names index ``phenotypes`` and select the columns
        of ``r1`` and ``r2``.
    phenotypes : pandas.Series
        Per-strain phenotype indexed by terminal name. It is drawn on a red
        scale from 0 to its maximum. Strains missing from it are drawn grey.
    mdist : float
        Largest root-to-tip distance; sets the tree panel's x-limit.
    r1 : pandas.DataFrame
        Matrix with OG rows and strain columns. It is drawn on a red scale
        fixed to [0, 1], with rows sorted by row sum (descending).
    r2 : pandas.DataFrame
        Matrix like ``r1``, drawn on a blue scale in its given row order.
    mdist_offset : float, optional
        Extra room added to the right of the tree panel's x-limit.
    fig_height : float, optional
        Figure height in inches (default 35, the previous hard-coded value).

    Returns
    -------
    None
        The figure is left open on the current pyplot state.
    """
    strains = [leaf.name for leaf in tree.get_terminals()]
    n_strains = len(strains)
    # reindex (not .loc) so strains without a phenotype become NaN and are
    # drawn in the colormap's "bad" grey instead of raising a KeyError
    p_vector = phenotypes.reindex(strains)

    # grid columns: tree (0-9), padding (10-12), phenotype (13), gap (14),
    # then one column per OG of r1 followed by one per OG of r2
    grid_length = 13 + 1 + 1 + r1.shape[0] + r2.shape[0]
    grid_shape = (n_strains, grid_length)

    fig = plt.figure(figsize=(grid_length / 2.5, fig_height))

    reds = _masked_cmap(plt.cm.Reds)
    blues = _masked_cmap(plt.cm.Blues)

    # Associated OGs, largest row sum first. Sort by position: reordering by
    # label with .loc would duplicate rows whenever two OGs share a name.
    order = (r1.sum(axis=1)
               .reset_index(drop=True)
               .sort_values(ascending=False)
               .index)
    ax3 = plt.subplot2grid(grid_shape, (0, 15),
                           colspan=r1.shape[0],
                           rowspan=n_strains)
    _draw_presence_heatmap(ax3, r1.iloc[order][strains], reds, 'associated OGs')

    # Other virulence OGs, in the caller's order, right after the r1 block.
    ax2 = plt.subplot2grid(grid_shape, (0, 15 + r1.shape[0]),
                           colspan=r2.shape[0],
                           rowspan=n_strains)
    _draw_presence_heatmap(ax2, r2[strains], blues, 'other virulence OGs')

    # Single-column phenotype strip.
    ax1 = plt.subplot2grid(grid_shape, (0, 13),
                           colspan=1,
                           rowspan=n_strains)
    ax1.imshow([[value] for value in p_vector],
               cmap=reds,
               vmin=0, vmax=p_vector.max(),
               aspect='auto',
               interpolation='none')
    ax1.set_yticks([])
    ax1.set_xticks([0])
    ax1.set_xticklabels(['Phenotype'], rotation=90, size=12)

    # The tree axes are created last so they are pyplot's current axes.
    # Phylo.draw forwards the xticks/yticks/xlim/axis/... keyword options to
    # pyplot functions, which act on the current axes.
    ax = plt.subplot2grid(grid_shape, (0, 0),
                          colspan=10,
                          rowspan=n_strains)

    fig.subplots_adjust(wspace=0, hspace=0)

    strain_set = set(strains)
    # Smaller font for the tip labels only; rc_context restores the global
    # setting afterwards instead of leaking it into later figures.
    with plt.rc_context({'font.size': 10}):
        Phylo.draw(tree, axes=ax,
                   show_confidence=False,
                   label_func=lambda clade: clade.name if clade.name in strain_set else None,
                   xticks=([],),
                   yticks=([],),
                   ylabel=('',), suptitle=('',),
                   xlim=(-0.01, mdist + 0.01 + mdist_offset),
                   axis=('off',),
                   do_show=False)

In [None]:
# Parse the Gubbins tree. The config cell binds `tree` to the Newick *path*,
# and this cell rebinds the name to the parsed tree. Keep hold of the path the
# first time through so that re-running this cell re-reads the file instead of
# passing an already-parsed tree back into Phylo.read.
if isinstance(tree, str):
    tree_path = tree
tree = Phylo.read(tree_path,
                  'newick')
# reorder clades by subtree size (Biopython ladderize, default direction)
tree.ladderize()
# largest root-to-tip distance; used for the tree panel's x-limit
mdist = max(tree.distance(tree.root, leaf) for leaf in tree.get_terminals())

In [None]:
# Per-strain phenotype table; the first column (strain name) becomes the index.
k = pd.read_csv(phenotypes, sep='\t', index_col=0)

In [None]:
m = pd.read_table(odds_ratio, index_col=0)

# The p-value column can contain the literal string 'NAN'; turn that into a
# real NaN and parse everything else as float.
m['lrt-pvalue'] = m['lrt-pvalue'].map(
    lambda value: np.nan if value == 'NAN' else float(value))

# drop every row with a missing value in any column
m = m.dropna()

In [None]:
# Association summary restricted to rows with more than 10 'specific_hits'.
f = (pd.read_table(filtered, index_col=0)
       .loc[lambda table: table['specific_hits'] > 10])

In [None]:
v = pd.read_csv(virulence, sep='\t')  # virulence OGs with their 'og' id and 'gene' name

In [None]:
# Mapping from the table's second column (used below as OG ids) to
# 'preferred_og_name'. `names` holds the TSV path until this cell rebinds it to
# the mapping; remember the path on the first run so re-running this cell
# re-reads the file instead of passing a dict to read_table.
if isinstance(names, str):
    names_path = names
names = pd.read_table(names_path, index_col=1)['preferred_og_name'].to_dict()

In [None]:
# OG id -> virulence gene name (later duplicates of an OG id win)
vnames = dict(zip(v['og'], v['gene']))

In [None]:
# All OGs (faint black) with the associated OGs (red) and the other virulence
# OGs (blue) highlighted: -log10 p-value against odds ratio.
fig, ax = plt.subplots(figsize=(4, 4))

ax.plot(-np.log10(m['lrt-pvalue']), m['odds-ratio'],
        'ko', alpha=0.1, label='_')

highlights = [
    (f.index, 'ro', 'associated OGs'),
    (v['og'], 'bo', 'other virulence OGs'),
]
for ogs, fmt, legend_label in highlights:
    subset = m.loc[m.index.intersection(ogs)]
    ax.plot(-np.log10(subset['lrt-pvalue']), subset['odds-ratio'],
            fmt, label=legend_label)

ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), frameon=True)
ax.set_xlabel('OG association $-log_{10}(pvalue)$')
ax.set_ylabel('OG odds ratio');

In [None]:
# Larger version of the scatter above, with each highlighted OG labelled by
# its display name; adjust_text repositions the labels to reduce overlap.
plt.figure(figsize=(7, 7))

plt.plot(-np.log10(m['lrt-pvalue']),
         m['odds-ratio'],
         'ko',
         alpha=0.03,
         label='_')

associated = m.loc[m.index.intersection(f.index)]
plt.plot(-np.log10(associated['lrt-pvalue']),
         associated['odds-ratio'],
         'ro',
         label='associated OGs')
# Read the coordinates by column name. Unpacking each row positionally would
# silently swap x and y (or fail) if the odds-ratio table's columns were
# reordered or extended.
text1 = [plt.text(-np.log10(pvalue), odds, names.get(og, og),
                  ha='center', va='center')
         for og, pvalue, odds in zip(associated.index,
                                     associated['lrt-pvalue'],
                                     associated['odds-ratio'])]
adjust_text(text1,
            arrowprops=dict(arrowstyle='->', color='k'),
            force_points=5)

virulent = m.loc[m.index.intersection(v['og'])]
plt.plot(-np.log10(virulent['lrt-pvalue']),
         virulent['odds-ratio'],
         'bo',
         label='other virulence OGs')
text2 = [plt.text(-np.log10(pvalue), odds, vnames.get(og, og),
                  ha='center', va='center')
         for og, pvalue, odds in zip(virulent.index,
                                     virulent['lrt-pvalue'],
                                     virulent['odds-ratio'])]
adjust_text(text2,
            arrowprops=dict(arrowstyle='->', color='k'),
            force_points=15)

plt.legend(loc='center left',
           bbox_to_anchor=(1, 0.5),
           frameon=True)

plt.xlabel('OG association $-log_{10}(pvalue)$')
plt.ylabel('OG odds ratio');

In [None]:
def _relabel_ogs(matrix):
    """Return a deep copy of ``matrix`` with OG ids replaced by display names.

    Each id is looked up in ``names`` first, then in ``vnames``; ids found in
    neither are kept as they are.
    """
    relabelled = matrix.copy(deep=True)
    relabelled.index = [names.get(og, vnames.get(og, og)) for og in relabelled.index]
    return relabelled


r = pd.read_table(rtab, index_col=0)
r1 = _relabel_ogs(r.loc[f.index])                 # associated OGs
r2 = _relabel_ogs(r.loc[v['og']]).sort_index()    # virulence OGs, alphabetical

In [None]:
# Tree | 'killed' phenotype | associated OGs | other virulence OGs
plot_condition(tree=tree,
               phenotypes=k['killed'],
               mdist=mdist,
               r1=r1,
               r2=r2)