## Setup

Sets up the environment for analyses of the phase 1 selection paper.

In [None]:
# python standard library
import sys
import os
import operator
import itertools
import collections
import functools
import glob
import csv
import datetime
import bisect
import sqlite3
import subprocess
import random
import gc
import shutil
import shelve
import contextlib
import tempfile
import math

In [None]:
# plotting setup
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from matplotlib.gridspec import GridSpec
import seaborn as sns
# Seaborn context/style first; the explicit rcParams below then override it.
sns.set_context('paper')
sns.set_style('white')
sns.set_style('ticks')
rcParams = plt.rcParams
# N.B., reduced font size
rcParams['font.size'] = 8
rcParams['axes.labelsize'] = 10
rcParams['xtick.labelsize'] = 10
rcParams['ytick.labelsize'] = 10
rcParams['legend.fontsize'] = 10
rcParams['axes.linewidth'] = .5
rcParams['lines.linewidth'] = .5
rcParams['patch.linewidth'] = .5
rcParams['ytick.direction'] = 'out'
rcParams['xtick.direction'] = 'out'
# 'savefig.jpeg_quality' is not a valid rcParam in newer matplotlib releases;
# setting it unconditionally raises KeyError there, so only set it when the
# installed matplotlib still knows the key.
if 'savefig.jpeg_quality' in rcParams:
    rcParams['savefig.jpeg_quality'] = 100
rcParams['lines.markeredgewidth'] = .5

In [None]:
%matplotlib inline

In [None]:
# general purpose third party packages
import numpy as np
nnz = np.count_nonzero
import scipy
import scipy.stats
import scipy.spatial.distance
import numexpr
import h5py
import tables
import bcolz
import dask
import dask.array as da
import pandas as pd
import IPython
from IPython.display import clear_output, display, HTML
import sklearn
import sklearn.decomposition
import sklearn.manifold
import petl as etl
etl.config.display_index_header = True
import humanize
from humanize import naturalsize, intcomma, intword
import zarr
from scipy.stats import entropy
import lmfit

In [None]:
import allel

In [None]:
sys.path.insert(0, '../agam-report-base/src/python')
from util import *
import zcache
import veff
import hapclust

ag1k_dir = '../../data/release'
from ag1k import phase1_ar3
phase1_ar3.init(os.path.join(ag1k_dir, 'phase1.AR3'))

from ag1k import phase1_ar31
phase1_ar31.init(os.path.join(ag1k_dir, 'phase1.AR3.1'))

In [None]:
from ag1k import phase1_selection
phase1_selection.init(os.path.join(ag1k_dir, 'phase1.selection.1.RC2'))

In [None]:
# One table of gene features across all chromosomes: per chromosome, keep the
# rows whose 'type' is 'gene' and unpack the 'ID' key of the attributes dict
# into its own field, then concatenate the per-chromosome tables.
tbl_genes = etl.cat(*(
    get_geneset_features(phase1_ar3.geneset_agamp42_fn, chrom)
    .eq('type', 'gene')
    .unpackdict('attributes', ['ID'])
    for chrom in chromosomes
))
tbl_genes

In [None]:
# Lookup from gene ID to its record in tbl_genes (one record per ID).
lkp_gene = tbl_genes.recordlookupone('ID')
# Spot-check the lookup with a single gene ID.
lkp_gene['AGAP004707']

In [None]:
def plot_list_genes_track(namespace, chrom, ax, gene_labels, plot=True, x_loc=None, **kwargs):
    """Draw the gene track for one chromosome and flag selected genes on it.

    Parameters
    ----------
    namespace : object
        Provides ``genome`` (indexable by chromosome) and ``geneset_agamp42_fn``.
    chrom : str
        Chromosome to draw.
    ax : matplotlib axes
        Axes to draw on.
    gene_labels : mapping
        Gene ID -> label text; each ID is looked up in the module-level
        ``lkp_gene`` and only genes located on `chrom` are marked.
    plot : bool
        If True, draw all genes with ``plot_genes`` before adding the markers.
    x_loc : iterable of (chrom, pos), optional
        Extra locations to mark with a filled triangle.
    **kwargs
        Accepted but not used.
    """
    sns.despine(ax=ax, left=True, offset=5)

    if plot:
        plot_genes(namespace.genome, namespace.geneset_agamp42_fn,
                   chrom, ax=ax, height=.2, label=False,
                   barh_kwargs=dict(lw=0.1, alpha=.2))

    # the y axis carries no information on this track
    ax.set_yticks([])
    ax.set_ylabel('')

    # only the X panel carries the axis captions, placed outside its right edge
    if chrom == 'X':
        ax.set_xlabel('Position (Mbp)', ha='left')
        ax.xaxis.set_label_coords(1.1, -.87, transform=ax.transAxes)
        ax.set_ylabel('Genes', ha='left', va='center', rotation=0)
        ax.yaxis.set_label_coords(1.1, 0.5, transform=ax.transAxes)

    # labelled gene markers: '+' strand near the top pointing down,
    # any other strand near the bottom pointing up
    for gid in gene_labels:
        rec = lkp_gene[gid]
        if rec.seqid != chrom:
            continue
        midpoint = (rec.start + rec.end) / 2
        y, marker, yt = (.9, 'v', -1) if rec.strand == '+' else (.1, '^', -4)
        ax.plot([midpoint], [y], marker=marker, mfc='w', mec='k')
        ax.annotate(gene_labels[gid], xy=(midpoint, y), xytext=(5, yt),
                    textcoords='offset points', fontsize=7, fontstyle='italic')

    # extra unlabelled locations, drawn as filled triangles near the bottom
    for _chrom, _pos in (x_loc if x_loc is not None else ()):
        if _chrom == chrom:
            ax.plot([_pos], [.1], marker='^', mfc='k', mec='k')

    # ticks every 10 Mbp, labelled in Mbp; the last tick is dropped on 2R and 3R
    chrom_len = len(namespace.genome[chrom])
    xticks = np.arange(0, chrom_len, 10000000)
    if chrom in {'3R', '2R'}:
        xticks = xticks[:-1]
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticks//1000000)
    ax.set_xlim(0, chrom_len)
    ax.set_ylim(0, 1)

In [None]:
def fig_linear_genome(plotf, genome, chromosomes=None, fig=None, 
                      bottom=0, height=1, width_factor=1.08, chrom_pad=0.035, 
                      clip_patch_kwargs=None, **kwargs):
    """Lay out one axes per chromosome side by side and plot into each.

    Each chromosome gets an axes whose width is proportional to its length,
    and `plotf` is called once per chromosome to draw into it.

    Parameters
    ----------
    plotf : callable
        Called as ``plotf(chrom=chrom, ax=ax, clip_patch=clip_patch, **kwargs)``.
    genome : mapping
        ``len(genome[chrom])`` gives the chromosome length.
    chromosomes : sequence of str, optional
        Defaults to ``['2R', '2L', '3R', '3L', 'X']``.
    fig : matplotlib figure, optional
        A new 8x1 inch figure is created when omitted.
    bottom, height : float
        Vertical placement of every axes, in figure coordinates.
    width_factor : float
        Divides the proportional widths so the panels and padding fit the figure.
    chrom_pad : float
        Horizontal gap added after '2L' and '3L'.
    clip_patch_kwargs : dict, optional
        Keyword arguments for the clip ``PathPatch``; edgecolor, facecolor and
        lw get defaults. The caller's dict is not modified.
    **kwargs
        Forwarded to `plotf`.

    Returns
    -------
    dict
        Chromosome name -> axes.
    """
    from matplotlib.path import Path

    if chromosomes is None:
        chromosomes = ['2R', '2L', '3R', '3L', 'X']
    genome_size = sum(len(genome[chrom]) for chrom in chromosomes)

    if fig is None:
        fig = plt.figure(figsize=(8, 1))

    # work on a copy: setdefault() below must not leak into the caller's dict
    clip_patch_kwargs = dict(clip_patch_kwargs or {})
    clip_patch_kwargs.setdefault('edgecolor', 'k')
    clip_patch_kwargs.setdefault('facecolor', 'none')
    clip_patch_kwargs.setdefault('lw', 1)

    # every outline below is a closed six-sided polygon
    codes = [Path.MOVETO] + [Path.LINETO] * 5 + [Path.CLOSEPOLY]

    left = 0
    axs = dict()
    for chrom in chromosomes:

        # calculate width needed for this chrom
        width = len(genome[chrom]) / (genome_size * width_factor)

        # create axes with a fully transparent background;
        # set_axis_bgcolor was replaced by set_facecolor in matplotlib 2.x,
        # so use whichever the installed version provides
        ax = fig.add_axes([left, bottom, width, height])
        set_background = getattr(ax, 'set_facecolor', None) or ax.set_axis_bgcolor
        set_background((1, 1, 1, 0))
        axs[chrom] = ax

        # construct clip path: right arms (and X) are tapered on the right,
        # all other chromosomes are tapered on the left
        if chrom in {'2R', '3R'}:
            verts = [(0.01, 0.02), (0.9, 0.02), (1.01, 0.3), (1.01, 0.7), (0.9, .98), (0.01, .98), (0.01, 0.02)]
        elif chrom == "X":
            verts = [(0.01, 0.02), (0.9, 0.02), (0.99, 0.3), (0.99, 0.7), (0.9, .98), (0.01, .98), (0.01, 0.02)]
        else:
            verts = [(0.1, 0.02), (.99, 0.02), (.99, .98), (.1, .98), (-0.01, .7), (-0.01, .3), (0.1, 0.02)]
        path = Path(verts, codes)
        clip_patch = mpl.patches.PathPatch(path, transform=ax.transAxes, **clip_patch_kwargs)

        # do the plotting
        plotf(chrom=chrom, ax=ax, clip_patch=clip_patch, **kwargs)

        # increment left coordinate, with a gap after each left arm
        left += width
        if chrom in {'2L', '3L'}:
            left += chrom_pad

    return axs

In [None]:
# Chromosome arms in plotting order: the four autosomal arms, then X.
autosomes = ('2R', '2L', '3R', '3L')
chromosomes = (*autosomes, 'X')


class GenomeFigure(object):
    """Figure with one subplot per chromosome, arranged in two columns.

    Panels at even positions (left column) are right-aligned and panels at odd
    positions (right column) are left-aligned, all on a common x scale set by
    the longest chromosome.

    Parameters
    ----------
    genome : mapping
        ``np.array(genome[chrom]).size`` gives the chromosome length.
    *args, **kwargs
        Passed to ``plt.figure``, except the optional keyword ``chromosomes``
        (sequence of names, default ``['2R', '2L', '3R', '3L', 'X']``).
    """
    
    def __init__(self, genome, *args, **kwargs):
        self.chromosomes = kwargs.pop('chromosomes', ['2R', '2L', '3R', '3L', 'X'])
        # measure every chromosome once; the sizes are needed both for the
        # common x scale and for each panel's ticks
        sizes = {chrom: np.array(genome[chrom]).size for chrom in self.chromosomes}
        maxchrsize = max(sizes.values())
        fig = plt.figure(*args, **kwargs)
        self.fig = fig
        self.ax = dict()
        # keep the 3x2 layout, growing only when more than six panels are needed
        nrows = max(3, (len(self.chromosomes) + 1) // 2)
        for i, chrom in enumerate(self.chromosomes):
            ax = fig.add_subplot(nrows, 2, i+1)
            self.ax[chrom] = ax
            size = sizes[chrom]
            if i % 2 == 1:
                # right column: left-aligned, y axis on the right
                sns.despine(ax=ax, offset=10, top=True, left=True, right=False)
                ax.set_xlim(0, maxchrsize)
                ax.yaxis.tick_right()
                ax.yaxis.set_label_position('right')
            else:
                # left column: right-aligned, y axis on the left
                ax.set_xlim(size - maxchrsize, size)
                ax.yaxis.tick_left()
                sns.despine(ax=ax, offset=10, top=True, left=False, right=True)
            # ticks every 5 Mbp; labels are derived from the ticks so the two
            # sequences always have the same length
            xticks = list(range(0, size, int(5e6)))
            ax.set_xticks(xticks)
            ax.set_xticklabels([x // 1000000 for x in xticks])
            ax.set_title(chrom, fontweight='bold')
            ax.xaxis.tick_bottom()
        fig.tight_layout()
        
    def apply(self, f, **kwargs):
        """Call ``f(chrom, ax, **kwargs)`` for each chromosome panel.

        The optional keyword ``chromosomes`` restricts the call to those
        panels; it defaults to all chromosomes of the figure.
        """
        chromosomes = kwargs.pop('chromosomes', self.chromosomes)
        for chrom in chromosomes:
            ax = self.ax[chrom]
            f(chrom, ax, **kwargs)
        
        
def subplots(*args, **kwargs):
    """``plt.subplots`` wrapper that despines every axes it creates.

    Arguments are passed straight to ``plt.subplots``. Returns the same
    ``(fig, ax)`` pair, where ``ax`` is a single axes or an array of axes.
    """
    fig, ax = plt.subplots(*args, **kwargs)
    # plt.subplots returns an ndarray for multi-panel grids, but despine's
    # `ax` argument takes a single axes, so handle each one separately
    axes = ax.flat if isinstance(ax, np.ndarray) else [ax]
    for single_ax in axes:
        sns.despine(ax=single_ax, offset=10)
    return fig, ax

In [None]:
def filter_zarr(data, start, stop, field, pos_field="POS"):
    """Return positions and `field` values for variants in a position range.

    `data[pos_field]` is wrapped in an ``allel.SortedIndex`` and
    ``locate_range(start, stop)`` selects the window; the same selection is
    applied to `data[field]`.

    Returns
    -------
    tuple
        ``(positions_in_range, field_values_in_range)``.
    """
    positions = allel.SortedIndex(data[pos_field])
    window = positions.locate_range(start, stop)
    selected_pos = positions[window]
    selected_values = data[field][window]
    return selected_pos, selected_values

In [None]:
def compute_gain(target_attr, pos, haps, plots=False):
    """Information gain of each variant with respect to a class label.

    Parameters
    ----------
    target_attr : integer array, one class label per haplotype
    pos : array of variant positions (only its length is used here)
    haps : 2-D array, variants x haplotypes, allele calls 0/1/2
    plots : bool
        If True, also call ``make_gain_plots(gain, pos, haps)``.

    Returns
    -------
    numpy.ndarray
        One gain value per variant: the entropy of `target_attr` minus the
        weighted entropies of `target_attr` within each allele group.
    """
    n_variants = pos.shape[0]
    n_haps = haps.shape[1]
    n_classes = np.bincount(target_attr).size

    def class_entropy(labels):
        # entropy of the class distribution among the given labels
        class_freqs = np.bincount(labels, minlength=n_classes) / labels.shape[0]
        return entropy(class_freqs)

    # entropy of the labels before any split
    target_entropy = class_entropy(target_attr)

    gain = np.zeros(n_variants)
    for i in range(n_variants):
        calls = haps[i]

        # weighted entropy after splitting the haplotypes by allele
        conditional_entropy = 0
        for allele in (0, 1, 2):
            members = target_attr[calls == allele]
            n_members = members.shape[0]
            if n_members == 0:
                continue
            conditional_entropy += (n_members / n_haps) * class_entropy(members)

        gain[i] = target_entropy - conditional_entropy

    if plots:
        make_gain_plots(gain, pos, haps)

    return gain

In [None]:
def make_gain_plots(gain, pos, haps):
    """Scatter information gain against position and against alt frequency.

    Parameters
    ----------
    gain : array, one value per variant
    pos : array of variant positions in bp (plotted in Mbp)
    haps : haplotype array providing ``count_alleles()``
    """
    fig, axes = plt.subplots(nrows=2, sharex=False, sharey=True, figsize=(9, 6))
    
    for ax in axes:
        sns.despine(ax=ax, offset=5)
    
    ax1, ax2 = axes
    ax1.scatter(pos/1e6, gain, marker="o", alpha=0.1, color="k")
    ax1.set_ylabel("Information gain")
    ax1.set_xlabel("position (Mbp)")
    
    # Column 0 of the frequency table is the reference allele, so the
    # alternate frequency shown on this axis is its complement. This also
    # works when the count table has no alternate column at all.
    ac = haps.count_alleles()
    alt_freq = 1 - ac.to_frequencies()[:, 0]
    ax2.scatter(alt_freq, gain, marker="o", alpha=0.1, color="k")
    ax2.set_ylabel("Information gain")
    ax2.set_xlabel("Alt freq.")

In [None]:
# assume access to callset pass.
def retrieve_allele_freqs(chrom, pos):
    """Tabulate REF, ALT and per-group allele frequencies at given positions.

    Reads ``phase1_ar3.callset_pass`` and the allele-frequency HDF5 file under
    ``ag1k_dir``.

    Parameters
    ----------
    chrom : str
        Chromosome key in both the callset and the frequency file.
    pos : array-like
        Positions to look up with ``SortedIndex.locate_keys``.

    Returns
    -------
    pandas.DataFrame
        Indexed by `pos`, with columns "REF", "ALT" and one column per group
        found under `chrom` in the frequency file.
    """
    af = os.path.join(ag1k_dir, "phase1.AR3/extras/allele_frequencies.h5")

    callset_pass = phase1_ar3.callset_pass

    reference = callset_pass[chrom]["variants/REF"][:].astype("<U1")
    alternate = callset_pass[chrom]["variants/ALT"][:].astype("<U1")
    positions = allel.SortedIndex(callset_pass[chrom]["variants/POS"][:])

    # boolean mask over the callset variants, True where the position is requested
    ix = positions.locate_keys(pos)

    # the context manager closes the HDF5 handle even if a lookup fails;
    # every dataset is read into memory before the file is closed
    with h5py.File(af, "r") as affh:
        groups = sorted(affh[chrom].keys())

        afq_df = pd.DataFrame(index=pos, columns=["REF", "ALT"] + groups)
        afq_df["REF"] = np.compress(ix, reference, 0)
        afq_df["ALT"] = np.compress(ix, alternate, 0)

        for grp in groups:
            afq_df[grp] = np.compress(ix, affh[chrom][grp], 0)

    return afq_df

In [None]:
# function that computes frequencies in each class
def compute_class_frequencies(gain, pos, haps, target_attr, names, plot=True):
    """Tabulate per-class alternate allele frequencies alongside the gain.

    Class 0 of `target_attr` is reported as "susceptible"; class ``i + 1`` is
    reported under ``names[i]``.

    Parameters
    ----------
    gain : array, one value per variant
    pos : array of variant positions, used as the index of the result
    haps : haplotype array providing ``count_alleles_subpops``
    target_attr : integer array, one class label per haplotype
    names : sequence of str
        Names of classes 1..n, in order.
    plot : bool
        If True, scatter the gain against the across-class variance.

    Returns
    -------
    pandas.DataFrame
        Indexed by `pos`; one frequency column per class, then "variance"
        (variance of the class frequencies per variant) and "gain".
    """
    target_attr = np.asarray(target_attr)
    # a list is needed for the column concatenation below
    names = list(names)

    assert pos.size == haps.shape[0] == gain.size
    assert np.bincount(target_attr).size == len(names) + 1

    # haplotype indices per class; class 0 comes from target_attr itself
    # rather than from a notebook-level global
    subpops = {"susceptible": np.where(target_attr == 0)[0]}
    for class_id, name in enumerate(names, start=1):
        subpops[name] = np.where(target_attr == class_id)[0]

    ac = haps.count_alleles_subpops(subpops=subpops, max_allele=1)
    af_ = pd.DataFrame(index=pos, columns=["susceptible"] + names)

    for k in subpops:
        # column 1 of the frequency table is the alternate allele
        af_[k] = ac[k].to_frequencies()[:, 1]

    # computed before "gain" is added, so only frequency columns contribute
    af_["variance"] = af_.var(axis=1)
    af_["gain"] = gain

    if plot:
        fig, ax = plt.subplots()
        ax.scatter(af_.variance, af_.gain)
        ax.set_xlabel("Variance")
        ax.set_ylabel("Information gain")

    return af_

In [None]:
from Bio.Data import IUPACData
# Reverse IUPAC lookup: sorted string of bases -> ambiguity code.
# "X" is dropped before inverting (presumably because it would collide with
# another code that expands to the same bases -- confirm).
iupac = dict(IUPACData.ambiguous_dna_values)
del iupac["X"]
iupac_rev_iub = {"".join(sorted(bases)): code for code, bases in iupac.items()}

# function that generates string for assay design.
def generate_assay_design_string(chrom, pos_list, buf=1):
    """Build an assay design string for each target SNP.

    Each string has the form ``<upstream>[REF|ALT]<downstream>`` with `buf`
    bases on either side of the target. Within the flanks, non-singleton
    variants are written as IUPAC ambiguity codes and inaccessible bases are
    lower-cased.

    Parameters
    ----------
    chrom : str
        Chromosome key into the phase1_ar3 genome, callset and accessibility data.
    pos_list : sequence of int
        Target positions (1-based). The first and last entries define the
        sequence window, so the list is presumably sorted ascending -- TODO confirm.
    buf : int
        Number of flanking bases on each side.

    Returns
    -------
    pandas.Series
        Design strings indexed by `pos_list`, named "assay_string".

    Raises
    ------
    AssertionError
        If a callset REF allele disagrees with the genome sequence, or a
        target position has a second or third alternate allele.
    """
    
    # position start, stop. This is *inclusive*
    start, stop = pos_list[0] - buf, pos_list[-1] + buf
    
    # get reference loci: This is 0 based, so we subtract 1. Then +1 for inclusivity.
    # Stored as an array of single upper-case characters so that individual
    # bases can be overwritten below.
    reference_seq = np.array(list(phase1_ar3.genome[chrom][(start - 1):stop].upper()))
    
    # find other variants in region
    all_variant_pos = allel.SortedIndex(phase1_ar3.callset[chrom]["variants/POS"])

    # slice of callset variants whose position lies within start..stop
    ix = all_variant_pos.locate_range(start, stop)
    
    reg_gt = allel.GenotypeArray(phase1_ar3.callset[chrom]["calldata/genotype"][ix])
    reg_pos = all_variant_pos[ix]
    reg_ref = phase1_ar3.callset[chrom]["variants/REF"][ix]
    reg_alt = phase1_ar3.callset[chrom]["variants/ALT"][ix]
    
    # every position in the window, used to map a position to its index in reference_seq
    span_pos = allel.SortedIndex(range(start, stop + 1))
        
    # loop through and apply IUPAC codes.
    for p, gt, rref, ralt in zip(reg_pos, reg_gt, reg_ref, reg_alt):
        
        # determine if singleton: sum the per-allele counts over samples, drop
        # the most common allele and require more than one copy among the rest
        gv = allel.GenotypeVector(gt)
        x = gv.to_allele_counts(max_allele=3).sum(0)
        not_singleton = np.sort(x)[:-1].sum() > 1
        
        # check fasta agrees with variant data.
        qq = phase1_ar3.genome[chrom][int(p - 1)]        
        assert rref.decode() == qq.upper(), (rref, "vs", qq)
        
        if not_singleton:
            # work out IUPAC code for REF plus all ALT alleles... and replace
            qix = span_pos.locate_key(p)
            c = rref + b''.join(ralt) 
            iupacstring = "".join(sorted(list(c.decode())))
            reference_seq[qix] = iupac_rev_iub[iupacstring]
            
    # determine accessibility: inaccessible bases are written in lower case
    is_accessible = phase1_ar3.accessibility[chrom]["is_accessible"][start - 1:stop]    
    for v in np.where(~is_accessible)[0]:
        reference_seq[v] = reference_seq[v].lower()
    
    qlist = []
    for p in pos_list:
        
        # index of the target within the region's variant arrays
        pix = reg_pos.locate_key(p)
        
        fref = reg_ref[pix].astype("<U1")
        falt = reg_alt[pix].astype("<U1")
        
        # assert biallelic
        assert falt[1] == falt[2] == ""
        
        centre = "[{0}|{1}]".format(fref, falt[0])
        
        # index of the target within reference_seq
        ix_span = span_pos.locate_key(p)                            
                                    
        # buf bases either side of the target, excluding the target itself
        upstream = "".join(reference_seq[(ix_span - buf):ix_span])
        dnstream = "".join(reference_seq[(ix_span + 1):(ix_span + 1 + buf)])
        
        qlist.append(upstream + centre + dnstream)
    
    return pd.Series(qlist, index=pos_list, name="assay_string")

In [None]:
def locate_functional_snps(chrom, haplotypes, positions, callset_pass, clusters, ref_ix):
    """Find variants whose frequency differs between swept clusters and a reference set.

    A variant is reported when, in at least one cluster, its alternate allele
    frequency is above 0.9 or below 0.1, and when at least one cluster differs
    from the reference frequency by more than 0.25.

    Parameters
    ----------
    chrom : str
        Chromosome key into ``phase1_ar3.callset_pass``.
    haplotypes : haplotype array, variants x haplotypes
    positions : array of variant positions matching the rows of `haplotypes`
    callset_pass : object
        NOTE(review): currently unused; annotations are always read from
        ``phase1_ar3.callset_pass``. Confirm whether this argument should be
        honoured instead.
    clusters : dict
        Cluster name (str) -> haplotype indices of that cluster.
    ref_ix : sequence of int
        Haplotype indices of the reference (unswept) set.

    Returns
    -------
    pandas.DataFrame
        Indexed by the selected positions, with columns "ref_af", one
        "<cluster>_af" per cluster, "annotation", "change", "is_neutral"
        and "gene".

    Raises
    ------
    AssertionError
        If a position is absent from the PASS callset, or the haplotype and
        annotation arrays do not match `positions` in length.
    """
    # larger as includes biallelics
    pos_pass = allel.SortedIndex(phase1_ar3.callset_pass[chrom]['variants/POS'][:])
    pass_loc = pos_pass.locate_range(positions[0], positions[-1])
    pos_region = pos_pass[pass_loc]

    ann_pass = phase1_ar3.callset_pass[chrom]['variants/ANN'][:][['Annotation', 'HGVS_p', 'Gene_ID']][pass_loc]

    # need to remove missing overlaps ie multiallelics
    ann_region = ann_pass.compress(np.in1d(pos_region, positions), 0)

    # all haplotype positions should be in the pos pass array; compare against
    # the callset positions (comparing `positions` with itself is always true)
    assert np.all(np.in1d(positions, pos_region)), "missing positions"

    # types of variants treated as (mostly) neutral
    annotation = ann_region['Annotation']
    neutral_effects = (b'intergenic_region',
                       b'intron_variant',
                       b'downstream_gene_variant',
                       b'upstream_gene_variant',
                       b'synonymous_variant',
                       b'3_prime_UTR_variant',
                       b'5_prime_UTR_variant')
    loc_type_neutral = np.zeros(annotation.shape, dtype=bool)
    for effect in neutral_effects:
        loc_type_neutral |= (annotation == effect)

    expected_size = positions.size
    for d in [haplotypes, ann_region]:
        assert d.shape[0] == expected_size, "wrong shape"

    print("{0:.3f} of muations are neutral".format(loc_type_neutral.mean()))

    # alternate allele frequency in the reference haplotypes...
    ref_haps = haplotypes.take(ref_ix, axis=1)
    unswept_freq = ref_haps.count_alleles(1).to_frequencies()[:, 1]

    # ...and in each swept cluster
    swept_f = {}
    for k, v in clusters.items():
        swept_haps = haplotypes.take(v, axis=1)
        swept_f[k] = swept_haps.count_alleles(1).to_frequencies()[:, 1]
    # column_stack needs a real sequence, not a dict view
    swept_freq = np.column_stack(list(swept_f.values()))

    # interesting if *any* cluster is above 0.9 / below 0.1
    # AND any cluster is more than 0.25 away from the reference
    # NOTE(review): an earlier comment said "*all* clusters"; the code tests
    # np.any -- confirm which is intended.
    maf_ok = np.any(swept_freq > 0.9, 1) | np.any(swept_freq < 0.1, 1)
    diff = np.abs(swept_freq - np.expand_dims(unswept_freq, axis=1))
    diff_ok = np.any(diff > 0.25, 1)

    pif = maf_ok & diff_ok

    # diagnostic output for the first selected variant; skipped when nothing
    # passes, where indexing [0] would raise IndexError
    if pif.any():
        print("diff", diff[pif][0])
        print("swept", swept_freq[pif][0])
        print("unswept", unswept_freq[pif][0])

    # position, whether functional
    pos_pif = positions[pif]
    neutral = loc_type_neutral[pif]

    func_changes = pd.DataFrame(index=pos_pif)
    func_changes["ref_af"] = unswept_freq[pif]
    for k, v in swept_f.items():
        print(k, pos_pif.size, v.size)
        func_changes[k + "_af"] = v[pif]
    func_changes["annotation"] = ann_region[pif]["Annotation"].astype("str")
    func_changes["change"] = ann_region[pif]["HGVS_p"].astype("str")
    func_changes["is_neutral"] = neutral

    func_changes["gene"] = ann_region[pif]["Gene_ID"].astype("str")

    return func_changes

In [None]:
from scipy.cluster.hierarchy import _convert_to_double
from scipy.spatial import distance
from scipy.cluster.hierarchy import _hierarchy

def plot_dendrogram(zhier, ax, method='complete', color_threshold=0, above_threshold_color='k'):
    """Draw a dendrogram with the x axis labelled in numbers of haplotypes.

    Parameters
    ----------
    zhier : linkage matrix accepted by ``scipy.cluster.hierarchy.dendrogram``
    ax : matplotlib axes to draw on
    method : str
        Accepted but not used.
    color_threshold, above_threshold_color
        Passed to ``dendrogram``.

    Returns
    -------
    tuple
        ``(zhier, r)`` where ``r`` is the dict returned by ``dendrogram``.
    """
    # plot dendrogram
    sns.despine(ax=ax, offset=5, bottom=True, top=False)
    r = scipy.cluster.hierarchy.dendrogram(zhier, no_labels=True, count_sort=True, 
                                           color_threshold=color_threshold, 
                                           above_threshold_color=above_threshold_color,
                                           ax=ax)

    # number of leaves in this dendrogram, taken from its own result rather
    # than from a notebook-level haplotype array
    n_haps = len(r['leaves'])

    # a tick every 200 haplotypes plus one at the end, mapped onto the
    # dendrogram's data interval
    xmin, xmax = ax.xaxis.get_data_interval()
    xticklabels = np.array(list(range(0, n_haps, 200)) + [n_haps])
    xticks = xticklabels / n_haps
    xticks = (xticks * (xmax - xmin)) + xmin
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticklabels)
    ax.set_xlabel('Haplotypes')
    ax.xaxis.set_label_position('top')
    ax.set_ylim(bottom=-.0001)
    
    ax.set_ylabel(r'$d_{xy}$')
    ax.autoscale(axis='x', tight=True)
    # return the linkage that was actually plotted
    return zhier, r

In [None]:
# draw an annotated h12 plot containing 2 populations
def draw_dbl_h12_plot(chrom, start, stop, pop1, pop2, gene_colors, ylim=0.5, path=None):
    """Draw mirrored H12 profiles for two populations above a gene track.

    `pop1` is drawn upwards and `pop2` downwards from zero on the same axes.

    Parameters
    ----------
    chrom : str
    start, stop : int
        Genomic window to plot.
    pop1, pop2 : str
        Population keys into ``phase1_selection.hstats_raw[chrom]`` and
        ``phase1_ar3.pop_colors``.
    gene_colors : mapping
        Gene ID -> marker face colour; genes on `chrom` are marked on the track.
    ylim : float
        Largest H12 value shown on either side of zero.
    path : str, optional
        If given, the figure is saved there.

    Returns
    -------
    dict
        Population -> position of that population's largest H12 value in the window.
    """
    gs = GridSpec(2, 1, height_ratios=[1.0, 1.0])
    fig = plt.figure(figsize=(10, 2))

    ax1 = fig.add_subplot(gs[0])
    ax2 = fig.add_subplot(gs[1], sharex=ax1)
    sns.despine(ax=ax1)
    sns.despine(ax=ax2)

    max_pos = {}
    for i, pop in enumerate([pop1, pop2]):

        h12_p, h12_v = filter_zarr(
            phase1_selection.hstats_raw[chrom][pop], start, stop, "h12", "pos")

        # locate the peak on the raw values: after mirroring, the maximum of
        # the negated series would be the smallest H12 instead
        max_pos[pop] = h12_p[np.nanargmax(h12_v)]

        # second population is mirrored below zero
        shown = h12_v if i == 0 else 0 - h12_v
        ax1.fill_between(h12_p[::10], shown[::10], 
                         label=pop, color=phase1_ar3.pop_colors[pop], alpha=1)

    # small margin beyond ylim on both sides; (-0.51, 0.51) for the default
    ax1.set_ylim((-ylim - 0.01, ylim + 0.01))
    ax1.xaxis.set_visible(False)
    
    q = np.arange(-ylim, ylim+0.01, 0.5)
    ax1.set_yticks(q)
    # both halves show magnitudes
    ax1.set_yticklabels(np.abs(q))
    ax1.set_ylabel(r"$h_{12}$", rotation=0)
    ax1.yaxis.set_label_coords(0.03, 0.8)
    ax1.tick_params(length=2)
    ax1.legend(ncol=2, loc=(0.75, 0.81))

    plot_genes(phase1_ar3.genome, phase1_ar3.geneset_agamp42_fn,
               chrom, start, stop, label=False, labels=None, label_unnamed=False, ax=ax2)

    # mark the genes of interest: '+' strand near the top pointing down,
    # any other strand near the bottom pointing up
    for gid in gene_colors:
        rec = lkp_gene[gid]
        if rec.seqid != chrom:
            continue
        x = (rec.start + rec.end) / 2
        y, marker = (.9, 'v') if rec.strand == '+' else (.1, '^')
        ax2.plot([x], [y], marker=marker, mfc=gene_colors[gid], mec='k')

    ax2.set_xticklabels(["{0:.1f}".format(x*1e-6) for x in  ax2.get_xticks()])
    ax2.tick_params(length=2)
    ax2.yaxis.set_label_coords(0.06, 0.92)
    ax2.set_xlabel("Genomic Position (Mbp)")
    
    if path is not None:
        fig.savefig(path, dpi=500, bbox_inches="tight")
        
    return max_pos

In [None]:
def truspan(cluster, r):
    """Return the (first, last) dendrogram leaf positions covered by a cluster.

    Parameters
    ----------
    cluster : iterable
        Members of the cluster, as they appear in ``r['leaves']``.
    r : dict
        Must hold the leaf order under ``'leaves'`` (a list).

    Raises
    ------
    AssertionError
        If the members do not occupy consecutive leaf positions.
    """
    leaf_order = r['leaves']
    # position of every member along the dendrogram's leaf order
    ranks = [leaf_order.index(member) for member in cluster]
    ranks.sort()
    # a cluster must be one uninterrupted run of leaves
    steps = np.diff(np.asarray(ranks))
    assert np.all(steps == 1)
    return min(ranks), max(ranks)

In [None]:
def draw_hap_cluster_plot(z, r, cluster_labels, vspans, df_haplotypes=phase1_ar31.df_haplotypes,
                          path=None):
    """Draw a dendrogram with a population colour bar and cluster brackets.

    Parameters
    ----------
    z : linkage matrix, passed to ``plot_dendrogram``
    r : dict
        Dendrogram result; ``r['leaves']`` gives the haplotype order.
    cluster_labels : sequence of str
        One label per entry of `vspans`; entries with an empty label are skipped.
    vspans : sequence of (xmin, xmax)
        Leaf-position span of each cluster.
    df_haplotypes : DataFrame
        Must have a ``population`` column with one row per haplotype.
    path : str, optional
        If given, the figure is saved there.
    """
    # number of haplotypes, taken from the dendrogram result rather than from
    # a notebook-level haplotype array
    n_haps = len(r['leaves'])

    gs = GridSpec(3, 1, height_ratios=[5.0, 1.0, 0.6])
    fig = plt.figure(figsize=(10, 2))

    ax1 = fig.add_subplot(gs[0])
    sns.despine(ax=ax1, offset=5, bottom=True, top=False)
    _ = plot_dendrogram(z, ax1)

    # population colour bar, in dendrogram leaf order
    ax_pops = fig.add_subplot(gs[1])
    hap_pops = df_haplotypes.population.values
    x = hap_pops.take(r['leaves'])
    hap_clrs = [phase1_ar3.pop_colors[p] for p in x]
    ax_pops.broken_barh(xranges=[(i, 1) for i in range(n_haps)], 
                        yrange=(0, 1),
                        color=hap_clrs)
    sns.despine(ax=ax_pops, offset=5, left=True, bottom=True)
    
    ax_pops.set_xticks([])
    ax_pops.set_yticks([])
    ax_pops.set_xlim(0, n_haps)
    ax_pops.yaxis.set_label_position('left')
    ax_pops.set_ylabel('Population', rotation=0, ha='right', va='center')

    # cluster brackets
    ax_clu = fig.add_subplot(gs[2])
    sns.despine(ax=ax_clu, bottom=True, left=True)
    ax_clu.set_xlim(0, n_haps)
    ax_clu.set_ylim(0, 1)
    for lbl, (xmin, xmax) in zip(cluster_labels, vspans):
        if not lbl:
            continue
        span = xmax - xmin
        # a zero-width span has no bracket to draw (and would divide by zero)
        if span != 0:
            # hack to get the "fraction" right, which controls length of bracket arms
            fraction = -20 / span
            ax_clu.annotate("", ha='left', va='center',
                            xy=(xmin, 1), xycoords='data',
                            xytext=(xmax, 1), textcoords='data',
                            arrowprops=dict(arrowstyle="-",
                                            connectionstyle="bar,fraction=%.4f" % fraction,
                                            ),
                            )
        ax_clu.text((xmax + xmin)/2, 0.1, lbl, va='top', ha='center')
        ax_pops.vlines([xmin, xmax], 0, 1, linestyle=':')

    ax_clu.set_xticks([])
    ax_clu.set_yticks([])

    if path is not None:
        fig.savefig(path, dpi=500, bbox_inches="tight")

In [None]:
# Directory holding the candidate-loci gene lists (one "<class>.txt" file per
# functional class is read from here when the annotations are loaded).
prior_loci = "../../data/analysis/20161101-selection-candidate-loci"

In [None]:
# Functional classes of candidate genes; each name is also the stem of its
# gene-list file.
func_annot = (
    "type1_metabolic",
    "type2_target_site",
    "type3_behavioural",
    "type4_cuticular",
)
# Short labels, in the same order as func_annot. An earlier variant used
# abbreviated names: ["m'bolic", "target s.", "behav'l", "cutic'r"]
func_annot_lab = ["I", "II", "III", "IV"]

In [None]:
# Load the gene list of every functional class.
# gene_classes: class name -> list of gene IDs ("Gene stable ID" column).
# gene_table:   one row per (class, gene), indexed by class name, with
#               columns "gene_id" and "classification".

gene_classes = {}
gene_table = {}

for an in func_annot:
    class_file = os.path.join(prior_loci, an + ".txt")
    gene_classes[an] = pd.read_csv(class_file, sep="\t")["Gene stable ID"].tolist()

    # one single-column frame per class, indexed by gene ID
    df = pd.DataFrame({"classification": an},
                      index=pd.Index(gene_classes[an], name="gene_id"))
    gene_table[an] = df

# stack the per-class frames; the class name stays as the index and the
# gene ID level is moved into a column
gene_table = pd.concat(gene_table).reset_index(level=1)

In [None]:
gene_table

## File paths

In [None]:
# Path templates for figure outputs. "{case}" and "{analysis_id}" are
# placeholders, presumably filled in with str.format by the analysis
# notebooks -- confirm at the call sites.
fstem_scan_plot = "../artwork/case{case}/scan_{analysis_id}.png"
fstem_scan_plot2 = "../artwork/case{case}/scan2_{analysis_id}.png"
fstem_tree_plot = "../artwork/case{case}/pwtree_{analysis_id}.png"
fstem_cluster_plot = "../artwork/case{case}/cluster_{analysis_id}.png"
fstem_switch_plot = "../artwork/case{case}/switch_{analysis_id}.png"

In [None]:
# Path template for a per-case, per-analysis SNP table; same "{case}" and
# "{analysis_id}" placeholders as the figure templates.
fstem_dis_snps = "../tables/case{case}/dis_snps_{analysis_id}.txt"