In [None]:
# Input locations used by the cells below (all relative to the notebook).

# phenotype table
phenotypes = '../data/phenotypes/phenotypes.tsv'

# phylogeny and pangenome files
tree = '../out/gubbins/tree.nwk'
rtab = '../out/roary/gene_presence_absence.Rtab'
spangenome = '../out/roary/sampled_pangenome.faa'

# association outputs (`mappings` is a directory, listed with os.listdir later)
odds_ratio = '../out/associations/odds_ratio.tsv'
filtered = '../out/associations/filtered_cont_lmm_rtab.tsv'
tnames = '../out/associations/associated_ogs.final.tsv'
mappings = '../out/associations/kmer_mappings/'

# virulence gene table and expression fold changes
virulence = '../out/virulence_genes.tsv'
fold_changes = '../out/rna/fold_changes.tsv'

In [None]:
# plotting imports
%matplotlib inline

import matplotlib.pyplot as plt
import seaborn as sns
from adjustText import adjust_text

# White seaborn theme first, then fix the font sizes shared by every figure.
sns.set_style('white')

plt.rcParams.update({
    'font.size': 11,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    'axes.labelsize': 12,
    'axes.titlesize': 12,
    'legend.fontsize': 11,
})

In [None]:
import os
import numpy as np
import pandas as pd
from Bio import Phylo
from Bio import SeqIO

In [None]:
def plot_tree(tree,
              phenotypes,
              mdist,
              rs,
              names=None, vnames=None,
              mdist_offset=0,
              color_strains=None,
              order=None, label_matrix=False):
    """Draw a tree next to a phenotype strip and one or more heatmaps.

    The figure is laid out on a grid with one row per tree tip: the tree
    occupies grid columns 0-9, the phenotype strip column 13, and the
    heatmaps start at column 15, one grid column per matrix row.

    Parameters
    ----------
    tree
        Tree object exposing ``get_terminals()``; it is drawn with
        ``Phylo.draw``. Tip order defines the row order of every panel.
    phenotypes
        pandas Series indexed by tip name. Tips without a value are
        drawn with the colormap's "bad" colour.
    mdist
        Number used, together with ``mdist_offset``, as the right x-limit
        of the tree axes.
    rs
        Sequence of ``(matrix, label, cmap)`` triples. ``matrix`` is a
        DataFrame whose rows become heatmap columns and whose columns are
        tip names; ``label`` is the panel's x-axis label; ``cmap`` is a
        matplotlib colormap (a copy is used, the object passed in is not
        modified).
    names, vnames
        Optional dicts mapping matrix row labels to display names;
        ``names`` is consulted first, then ``vnames``, then the raw label.
    mdist_offset
        Extra space added to the right x-limit of the tree axes.
    color_strains
        Tip names that get a text label. Defaults to every tip.
    order
        Row order applied to every matrix. When ``None`` each matrix is
        ordered by decreasing row sum.
    label_matrix
        If true, tip labels are written as y-tick labels on the phenotype
        strip instead of on the tree.

    Returns
    -------
    None
        The figure is left as the current pyplot figure.
    """
    import copy  # function-scope import: only needed to clone colormaps

    if names is None:
        names = {}
    if vnames is None:
        vnames = {}

    tips = tree.get_terminals()
    tip_names = [tip.name for tip in tips]
    n_tips = len(tips)

    # Resolve the default before first use; the label_matrix branch below
    # tests membership in it.
    if color_strains is None:
        color_strains = tip_names
    color_strains = set(color_strains)

    # reindex (not .loc) so tips without a phenotype become NaN rather
    # than raising KeyError.
    p_vector = phenotypes.reindex(tip_names)

    # Grid columns: 0-9 tree, 13 phenotype strip, 15+ heatmaps.
    tree_span = 10
    pheno_col = 13
    matrix_start = 15
    grid_length = matrix_start + sum(r.shape[0] for r, _, _ in rs)
    grid_shape = (n_tips, grid_length)

    # Width scales with the number of grid columns; height is fixed.
    fig = plt.figure(figsize=(grid_length/2.5, 35))

    curr_w = 0
    for r, label, cmap in rs:
        # Work on a copy so the caller's (possibly registered, shared)
        # colormap is left untouched.
        cmap = copy.copy(cmap)
        cmap.set_bad('grey', 0.2)
        cmap.set_under('red', 1)

        n_cols = r.shape[0]
        ax = plt.subplot2grid(grid_shape,
                              (0, matrix_start + curr_w),
                              colspan=n_cols,
                              rowspan=n_tips)
        if order is None:
            corder = list(r.T.sum().sort_values(ascending=False).index)
        else:
            corder = order

        # Columns are aligned to the tree tips; tips absent from the
        # matrix come out as NaN and are shown with the "bad" colour.
        matrix = r.loc[corder].reindex(columns=tip_names).T
        ax.imshow(matrix,
                  cmap=cmap,
                  vmin=r.min().min(), vmax=r.max().max(),
                  aspect='auto',
                  interpolation='none')
        ax.set_yticks([])
        # Thin separators between adjacent heatmap columns.
        for i in range(n_cols):
            ax.axvline(i+0.5,
                       color='grey',
                       alpha=0.77)
        ax.set_xticks(list(range(n_cols)))
        ax.set_xticklabels([names.get(x, vnames.get(x, x))
                            for x in corder],
                           rotation=90,
                           size=12)
        ax.set_xlabel(label,
                      size=12)
        curr_w += n_cols

    # Phenotype strip (again on a private copy of the colormap).
    pheno_cmap = copy.copy(plt.cm.Reds)
    pheno_cmap.set_bad(sns.xkcd_rgb['light grey'], 0.2)

    ax1 = plt.subplot2grid(grid_shape,
                           (0, pheno_col),
                           colspan=1,
                           rowspan=n_tips)
    ax1.imshow([[x] for x in p_vector],
               cmap=pheno_cmap,
               vmin=p_vector.min(), vmax=p_vector.max(),
               aspect='auto',
               interpolation='none')
    if label_matrix:
        ax1.set_yticks(list(range(n_tips)))
        ax1.set_yticklabels([name if name in color_strains
                             else ''
                             for name in tip_names],
                            rotation=0,
                            size=10)
    else:
        ax1.set_yticks([])
    ax1.set_xticks([0])
    ax1.set_xticklabels(['Phenotype'],
                        rotation=90,
                        size=12)

    # The tree axes are created last so they are the current pyplot axes
    # when Phylo.draw runs.
    ax = plt.subplot2grid(grid_shape,
                          (0, 0),
                          colspan=tree_span,
                          rowspan=n_tips)

    fig.subplots_adjust(wspace=0, hspace=0)

    def _label(clade):
        # Label only the selected tips, and only when the labels are not
        # already shown on the phenotype strip; otherwise return None.
        if clade.name in color_strains and not label_matrix:
            return clade.name

    suptitle = ''
    # Smaller font for the tree labels, scoped so the notebook-wide rc
    # settings are not changed by calling this function.
    with plt.rc_context({'font.size': 10}):
        Phylo.draw(tree, axes=ax,
                   show_confidence=False,
                   label_func=_label,
                   xticks=([],),
                   yticks=([],),
                   ylabel=('',), suptitle=(suptitle,),
                   xlim=(-0.01, mdist+0.01+mdist_offset),
                   axis=('off',),
                   do_show=False,)

In [None]:
# `tree` starts out as the Newick file path and is rebound to the parsed
# tree. The guard keeps the cell re-runnable: on a second execution `tree`
# is already a parsed tree and must not be handed to Phylo.read again.
if isinstance(tree, str):
    tree = Phylo.read(tree,
                      'newick')
tree.ladderize()
# Largest root-to-tip distance; used as the x-extent of the tree panel.
mdist = max(tree.distance(tree.root, tip) for tip in tree.get_terminals())

In [None]:
# Phenotype table, indexed by its first column.
k = pd.read_csv(phenotypes, sep='\t', index_col=0)

In [None]:
# Odds-ratio table, indexed by its first column.
m = pd.read_csv(odds_ratio, sep='\t', index_col=0)
# The p-value column may contain the literal string 'NAN'; turn those into
# real NaNs and cast everything else to float, then drop incomplete rows.
m['lrt-pvalue'] = m['lrt-pvalue'].map(
    lambda value: np.nan if value == 'NAN' else float(value)
)
m = m.dropna()

In [None]:
# Filtered association results, indexed by the first column.
f = pd.read_csv(filtered, sep='\t', index_col=0)

In [None]:
# Virulence gene table (default integer index; 'og' and 'gene' columns are used below).
v = pd.read_csv(virulence, sep='\t')

In [None]:
# Map the second column of the table (used as index) to 'preferred_og_name'.
og_table = pd.read_csv(tnames, sep='\t', index_col=1)
names = og_table['preferred_og_name'].to_dict()

In [None]:
# 'og' -> 'gene' lookup for the virulence table.
vnames = pd.Series(v['gene'].values, index=v['og']).to_dict()

In [None]:
# Length of every record in the sampled pangenome FASTA, keyed by record id.
record_lengths = {}
for record in SeqIO.parse(spangenome, 'fasta'):
    record_lengths[record.id] = len(record)
oglen = pd.Series(record_lengths)

In [None]:
# Fold-change table, indexed by the first column.
fold = pd.read_csv(fold_changes, sep='\t', index_col=0)

In [None]:
# Odds ratio against -log10(p-value) for every OG, with two subsets
# highlighted: rows also present in `f`, and rows listed in v['og'].
assoc_hits = m.loc[m.index.intersection(f.index)]
virulence_hits = m.loc[m.index.intersection(v['og'])]

plt.figure(figsize=(4, 4))

# Background cloud; the label '_' starts with an underscore, so matplotlib
# leaves it out of the legend.
plt.plot(-np.log10(m['lrt-pvalue']), m['odds-ratio'],
         'ko', alpha=0.1, label='_')
plt.plot(-np.log10(assoc_hits['lrt-pvalue']), assoc_hits['odds-ratio'],
         'ro', label='associated OGs')
plt.plot(-np.log10(virulence_hits['lrt-pvalue']), virulence_hits['odds-ratio'],
         'bo', label='other virulence OGs')

# Legend outside the axes, to the right.
plt.legend(loc='center left',
           bbox_to_anchor=(1, 0.5),
           frameon=True)

plt.xlabel('OG association $-log_{10}(pvalue)$')
plt.ylabel('OG odds ratio');

In [None]:
# Larger version of the odds-ratio / p-value scatter with every highlighted
# OG annotated by name (labels are de-overlapped with adjust_text).
plt.figure(figsize=(7, 7))

assoc_hits = m.loc[m.index.intersection(f.index)]
virulence_hits = m.loc[m.index.intersection(v['og'])]

# Background cloud; the '_' label keeps it out of the legend.
plt.plot(-np.log10(m['lrt-pvalue']),
         m['odds-ratio'],
         'ko',
         alpha=0.03,
         label='_')

plt.plot(-np.log10(assoc_hits['lrt-pvalue']),
         assoc_hits['odds-ratio'],
         'ro',
         label='associated OGs')
# Coordinates are read by column name so the annotations always sit on the
# plotted points, regardless of how many columns `m` has or their order.
text1 = [plt.text(-np.log10(pvalue), odds, names.get(og, og),
                  ha='center', va='center')
         for og, pvalue, odds in zip(assoc_hits.index,
                                     assoc_hits['lrt-pvalue'],
                                     assoc_hits['odds-ratio'])]

plt.plot(-np.log10(virulence_hits['lrt-pvalue']),
         virulence_hits['odds-ratio'],
         'bo',
         label='other virulence OGs')
text2 = [plt.text(-np.log10(pvalue), odds, vnames.get(og, og),
                  ha='center', va='center')
         for og, pvalue, odds in zip(virulence_hits.index,
                                     virulence_hits['lrt-pvalue'],
                                     virulence_hits['odds-ratio'])]

adjust_text(text1 + text2,
            arrowprops=dict(arrowstyle='->', color='k'),
            force_points=15)

plt.legend(loc='center left',
           bbox_to_anchor=(1, 0.5),
           frameon=True)

plt.xlabel('OG association $-log_{10}(pvalue)$')
plt.ylabel('OG odds ratio');

In [None]:
# Presence/absence matrix, indexed by its first column.
r = pd.read_csv(rtab, sep='\t', index_col=0)
# Rows for the entries of `f`.
r1 = r.loc[f.index].copy()
# Rows for the virulence OGs, sorted by label.
# NOTE(review): drop_duplicates() compares row *values*, so two different
# OGs with identical profiles collapse into one - confirm this is intended.
r2 = r.loc[v['og']].drop_duplicates().copy().sort_index()

In [None]:
# Tree with the 'killed' phenotype and the two presence/absence panels.
plot_tree(
    tree=tree,
    phenotypes=k['killed'],
    mdist=mdist,
    rs=[(r1, 'associated OGs', plt.cm.Reds),
        (r2, 'other pathogenic OGs', plt.cm.Blues)],
    names=names,
    vnames=vnames,
    mdist_offset=0,
)

In [None]:
# Read every file in the k-mer mapping directory into one long table.
# The loop variable is deliberately NOT called `f`: that name already holds
# the filtered-associations DataFrame loaded earlier, and overwriting it
# would break any earlier cell that is re-run afterwards.
mapping_frames = []
# sorted() makes the concatenation order independent of the filesystem.
for mapping_file in sorted(os.listdir(mappings)):
    o = pd.read_table(os.path.join(mappings, mapping_file),
                      header=None)
    o.columns = ['strain', 'kmer', 'dna',
                 'start', 'end', 'strand',
                 'up', 'in', 'down']
    mapping_frames.append(o)
p = pd.concat(mapping_frames)
# Length of each k-mer string.
p['size'] = [len(kmer) for kmer in p['kmer'].values]

In [None]:
# Count k-mers of length >= 30 per (strain, 'in') pair, spread the 'in'
# values into columns, align rows to the tree tips, fill gaps with 0 and
# transpose so that rows are 'in' values and columns are strains.
t = (
    p.loc[p['size'] >= 30]
    .groupby(['strain', 'in'])
    .count()['kmer']
    .unstack()
    .reindex([tip.name for tip in tree.get_terminals()])
    .fillna(0)
    .T
)

In [None]:
# Tree with the raw k-mer counts as a single heatmap panel.
plot_tree(
    tree=tree,
    phenotypes=k['killed'],
    mdist=mdist,
    rs=[(t, 'associated kmers', plt.cm.Reds)],
    names=names,
    vnames=vnames,
    mdist_offset=0,
)

In [None]:
# Divide every row of the count matrix by the matching entry of `oglen`.
t1 = t.div(oglen.loc[t.index], axis=0)

In [None]:
# Same panel as before, using the length-normalised counts.
plot_tree(
    tree=tree,
    phenotypes=k['killed'],
    mdist=mdist,
    rs=[(t1, 'associated kmers', plt.cm.Reds)],
    names=names,
    vnames=vnames,
    mdist_offset=0,
)

In [None]:
# Binarise the counts (more than 10 k-mers -> 1, otherwise 0) and keep only
# the rows that contain at least one 1.
t2 = t.mask(t <= 10, 0).mask(t > 10, 1)
t2 = t2.loc[t2.max(axis=1) > 0]

In [None]:
# Tree with the binarised k-mer matrix.
plot_tree(
    tree=tree,
    phenotypes=k['killed'],
    mdist=mdist,
    rs=[(t2, 'associated kmers', plt.cm.Reds)],
    names=names,
    vnames=vnames,
    mdist_offset=0,
)

In [None]:
# Labels present in all three tables: fold changes, k-mer counts and r1.
idx = fold.index.intersection(t1.index)
idx = idx.intersection(r1.index)

In [None]:
# Combined figure restricted to the shared labels: k-mers, presence/absence
# and fold changes, all in the same column order (descending k-mer row sum).
# Only strains that are columns of `fold` are labelled, on the phenotype strip.
plot_tree(
    tree=tree,
    phenotypes=k['killed'],
    mdist=mdist,
    rs=[
        (t1.loc[idx], 'associated kmers', plt.cm.Blues),
        (r1.loc[idx], 'gene presence/absence', plt.cm.Reds),
        (fold.loc[idx], 'transcription\n(fold change)', sns.cm.vlag),
    ],
    names=names,
    vnames=vnames,
    mdist_offset=0,
    color_strains=fold.columns,
    order=t1.loc[idx].sum(axis=1).sort_values(ascending=False).index,
    label_matrix=True,
)