In [None]:
# Subject species handled by this notebook; the list order is the order in
# which they are processed by the loops below.
species = [
    "bambooshark",
    "homo",
    "gallus",
    "xenopus",
    "lepisosteus",
    "bonytongue",
    "ictalurus",
    "danio",
    "esox",
    "gadus",
    "takifugu",
]

In [None]:
# Load modules 

import phylopandas as phy
from phylogenetics import tools
import phylogenetics

%matplotlib inline 
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

import re, pickle, glob, os, random, string

In [None]:
# ---------------------------------------------------
# Functions for plotting genomic context 
# ---------------------------------------------------

def _smooth_hits(hits_df,window_size,global_bounds,
                 weight_by_evalue=True,scale_min_feature_to=3):
    """
    Generate running average of hits along the block.

    hits_df: dataframe with all hits on this block.  Must contain the columns
             subject_start, subject_end and e_value and at least one row.
             The dataframe is copied; the caller's frame is not modified.
    window_size: size for smoothing window (same units as subject_start)
    global_bounds: tuple holding total size of block: (start,end)
    weight_by_evalue: whether or not to weight average by the -log(evalue) of
                      the hits.  If False, every hit has weight 1.0.
    scale_min_feature_to: to make faster, take the smallest feature in the data
                          and rescale so this is scale_min_feature_to units long.

    Returns (indexes, smoothed): two 1-D float arrays of equal length.
    indexes are positions in the original (un-scaled) coordinates; smoothed
    is the boxcar-averaged hit weight at each of those positions.  A single
    zero-valued point is added at a global bound whenever the hits do not
    already reach that bound.
    """

    # e_value == 0 is replaced by this value so that -log(e_value) is finite
    zero_evalue_floor = 5.5e-126

    df = hits_df.copy()

    # Flip all hits so they go start -> end.  Done on whole arrays rather than
    # with chained .iloc assignment, which is not guaranteed to write back
    # into the dataframe.
    raw_start = df.subject_start.to_numpy()
    raw_end = df.subject_end.to_numpy()
    starts = np.minimum(raw_start,raw_end)
    ends = np.maximum(raw_start,raw_end)

    # Smallest feature is the shortest hit or the window, whichever is smaller
    min_feature_size = min(window_size,np.min(ends - starts))

    # Make minimum feature size roughly scale_min_feature_to units long
    # (scale is a power of ten).  A zero-length feature would give a scale
    # of zero, so fall back to no rescaling in that case.
    if min_feature_size > 0:
        scale = 10**np.floor(np.log10(min_feature_size/scale_min_feature_to))
    else:
        scale = 1.0

    # Rescale indexes to be smaller and more manageable
    df["subject_start"] = np.round(starts/scale).astype(int)
    df["subject_end"] = np.round(ends/scale).astype(int)
    window_size = int(np.round(window_size/scale))
    global_bounds = [int(np.round(bound/scale)) for bound in global_bounds[:2]]

    # Sort values from small to large
    df = df.sort_values(by="subject_start")

    # Edit e_values to be -log(e_value).  Force float so the floor value is
    # not truncated to 0 when the column happens to have an integer dtype.
    e_values = np.array(df.e_value,dtype=float)
    e_values[e_values == 0] = zero_evalue_floor
    df["e_value"] = -np.log(e_values)

    N = np.max(df.subject_end) - np.min(df.subject_start)

    # Window size can't be bigger than the total length of the fragment, and
    # must be at least one so the averaging kernel is not empty.
    window_size = int(max(1,min(window_size,N)))

    starts = df.subject_start.to_numpy()
    ends = df.subject_end.to_numpy()
    weights = df.e_value.to_numpy()

    # Hit sets are groups of hits close enough together that they must be
    # convolved as one stretch.  The first (left-most) hit seeds the first set.
    hit_sets = [[(starts[0],ends[0],weights[0])]]
    furthest_end = ends[0]

    # Go over remaining hits
    for k in range(1,len(starts)):

        # No overlap with anything seen so far in this set (even allowing for
        # the window), so create new set in list
        if starts[k] > (furthest_end + window_size):
            hit_sets.append([])
            furthest_end = ends[k]
        else:
            furthest_end = max(furthest_end,ends[k])

        # Append hit to last hit set in the list
        hit_sets[-1].append((starts[k],ends[k],weights[k]))

    kernel = np.ones((window_size,))/window_size

    all_indexes = []
    all_smoothed = []
    for hit_set in hit_sets:

        # Outer edges of hit set, padded by one window on each side
        total_min = min(hit[0] for hit in hit_set) - window_size
        total_max = max(hit[1] for hit in hit_set) + window_size

        # Populate array for doing convolution
        to_convolve = np.zeros(total_max - total_min,dtype=float)
        for s, e, w in hit_set:
            v = w if weight_by_evalue else 1.0
            # hit ends are inclusive, hence the + 1
            to_convolve[(s - total_min):(e - total_min + 1)] += v

        # Do convolution
        smoothed = np.convolve(to_convolve,kernel,mode='same')
        indexes = np.arange(len(smoothed),dtype=float) + total_min

        all_indexes.append(indexes)
        all_smoothed.append(smoothed)

    # Pin the curve to zero at the left edge of the block if the first hit
    # set starts to the right of it
    if all_indexes[0][0] > global_bounds[0]:
        all_indexes.insert(0,np.array([global_bounds[0]],dtype=float))
        all_smoothed.insert(0,np.zeros(1,dtype=float))

    # ... and at the right edge if the last hit set stops short of it
    if all_indexes[-1][-1] < global_bounds[1]:
        all_indexes.append(np.array([global_bounds[1]],dtype=float))
        all_smoothed.append(np.zeros(1,dtype=float))

    indexes = np.concatenate(all_indexes)
    smoothed = np.concatenate(all_smoothed)

    # Convert positions back to the original coordinate scale
    return indexes*scale, smoothed


def plot_hits_on_assembly(hit_df,
                          subjects_to_plot=None,
                          window_size=0.0001,
                          focal_query=None,
                          focal_locations=None,
                          block_lengths=None,
                          use_evalue_hit_height=True,
                          min_evalue=5.5e-126,
                          num_blocks_to_plot=5,
                          min_hits_cutoff=4,
                          figsize=(10,5),
                          dpi=90,
                          only_plot_focal=False,
                          manual_title=None,
                          height_norm_factor=53):
    """
    Draw, for each subject species, one matplotlib figure with one panel per
    block (most-hit blocks first) showing hit positions and a smoothed hit
    curve for every query species.  Returns None; each figure is left as the
    current pyplot figure after it is drawn.

    hit_df: dataframe of hits.  The columns read are subject_species,
            query_species, query_name, block, subject_start, subject_end
            and e_value.
    subjects_to_plot: either string for single subject species to make
                      plot for or a list of strings containing subject
                      species.  if None, plot all subject species in
                      the dataframe.
    window_size: size of window for smoothing hits per genomic region.
                 if < 1, use as a fraction of the largest block span (at
                 least one position).  otherwise use as an actual window.
    focal_query: focal query used to generate the regions of interest.
                 If defined, the focal_query is *not* included in the
                 smoothed curve, and will be highlighted with its own color.
                 If None, ignore.
    focal_locations: dictionary containing locations of known focal genes.
                     key should be subject_species, value should be a list
                     of tuples whose first two entries are
                     (block,location_on_block).  Species without an entry
                     are treated as having no focal location.  If None,
                     ignore.
    block_lengths: dictionary of block lengths for this species, keyed by
                   block name
    use_evalue_hit_height: weight the height of the curve by -log(evalue)
                           of the hits in the window
    min_evalue: smallest e-value assumed when computing the normalization
                (only used when height_norm_factor is None)
    num_blocks_to_plot: number of blocks to plot per subject
    min_hits_cutoff: don't plot a block if it has fewer than
                     min_hits_cutoff hits
    figsize: (width,height) of a single panel; the figure height is
             height times the number of panels
    dpi: figure dpi (passed directly to matplotlib.pyplot.subplots)
    only_plot_focal: only plot block if it has a focal hit
    manual_title: if not None, text placed above the block name in the
                  title of the first panel
    height_norm_factor: value the smoothed curves are divided by.  The
                        default, 53, is the value that used to be
                        hard-coded here.  If None, it is computed as the
                        largest number of distinct queries from any one
                        query species (less one if focal_query is given),
                        times -log(min_evalue) if use_evalue_hit_height.

    Raises ValueError if focal_locations is given without focal_query.
    """

    plot_y_min = -0.2
    plot_y_max =  1.0
    focal_color = "green"
    gene_vertical = -0.1
    species_colors = {"homo":"indigo","danio":"peru"}

    # --------------------------------------------------------------
    # Sanity checking
    # --------------------------------------------------------------

    if focal_locations is not None and focal_query is None:
        err = "focal_locations can only be specified if a focal_query"
        err += " is specified.\n"
        raise ValueError(err)

    # --------------------------------------------------------------
    # Figure out what subject species we are going to make plots for
    # --------------------------------------------------------------

    if subjects_to_plot is None:
        subjects_to_plot = sorted(set(hit_df.subject_species))
    elif isinstance(subjects_to_plot,str):
        subjects_to_plot = [subjects_to_plot]

    # --------------------------------------------------------------
    # X-axis and which blocks to plot
    #
    # Find the longest chunk of genome we'll have to plot.  Put
    # every assembly/chromosome etc. on this x-length scale so plots
    # are comparable.  Main output is "x_plot_bounds" dictionary, which
    # gives us the bounds for any plot that involves a subject species
    # and block pair.
    #
    # Also record the number of hits on each block, creating a
    # dictionary that keys each subject species to a list that holds
    # the blocks ranked from most to fewest hits.
    # --------------------------------------------------------------

    # Do pass through all pairs, finding maximum span of each block,
    # as well as the maximum span seen overall.
    biggest_span = -1
    all_min_max = {}
    block_hit_numbers = {}

    for ss in set(hit_df.subject_species):

        all_min_max[ss] = {}
        block_hit_numbers[ss] = []

        ss_df = hit_df[hit_df.subject_species == ss]
        for b in set(ss_df.block):

            block_df = ss_df[ss_df.block == b]
            all_positions = block_df.subject_start.tolist()
            all_positions.extend(block_df.subject_end.tolist())
            min_seen = np.min(all_positions)
            max_seen = np.max(all_positions)

            all_min_max[ss][b] = (min_seen,max_seen)
            biggest_span = max(biggest_span,max_seen - min_seen)

            block_hit_numbers[ss].append((len(block_df),b))

        # Sort blocks so they go from highest number of hits to
        # lowest number, keyed to each species
        block_hit_numbers[ss].sort(reverse=True)

    # Leave a 10% margin around the widest block
    biggest_span = biggest_span*1.1

    # Go back through pairs, creating plot bounds that add offsets to
    # left and right so region of interest is always centered.
    x_plot_bounds = {}
    for ss, block_bounds in all_min_max.items():
        x_plot_bounds[ss] = {}
        for b, (min_seen,max_seen) in block_bounds.items():
            leftover = (biggest_span - (max_seen - min_seen))/2.
            x_plot_bounds[ss][b] = (min_seen,max_seen,leftover)

    # Figure out window size to use.  A fractional window is relative to the
    # widest span; never let it round down to an empty window.
    if window_size < 1:
        window_size = max(1,int(np.round(window_size*biggest_span)))

    # --------------------------------------------------------------
    # Y-axis
    #
    # Normalization applied to the smoothed curves
    # --------------------------------------------------------------

    query_species = set(hit_df.query_species)

    if height_norm_factor is None:

        # Get the maximum number of query genes that come from a single
        # query species
        num_queries = []
        for qs in query_species:
            qs_df = hit_df[hit_df.query_species == qs]
            num_queries.append(len(set(qs_df.query_name)))

        # Maximum possible height is total number of queries less the focal
        # sequence, which we ignore in the score.
        height_norm_factor = np.max(num_queries)
        if focal_query is not None:
            height_norm_factor = height_norm_factor - 1

        if use_evalue_hit_height:
            height_norm_factor *= -np.log(min_evalue)

    # --------------------------------------------------------------
    # Get down to plotting business
    # --------------------------------------------------------------

    for ss in subjects_to_plot:

        # Use the dataframe that was passed in, not a notebook-level global
        df = hit_df[hit_df.subject_species == ss]

        # Construct dictionary of focal locations, with start keyed
        # to block name.
        focal_for_species = {}
        if focal_locations is not None:
            try:
                tmp = focal_locations[ss]
                focal_for_species = dict([(t[0],t[1]) for t in tmp])
            except KeyError:
                focal_for_species = {}

        num_blocks = min(len(block_hit_numbers[ss]),num_blocks_to_plot)

        # squeeze=False so that a single panel still comes back as an array
        # that can be indexed with ax[i]
        fig, ax = plt.subplots(num_blocks,1,
                               figsize=(figsize[0],figsize[1]*(num_blocks)),
                               dpi=dpi,frameon=False,squeeze=False)
        ax = ax[:,0]

        for i in range(num_blocks):

            num_hits = block_hit_numbers[ss][i][0]
            block = block_hit_numbers[ss][i][1]

            # If too few hits, we're done plotting for this subject
            # because the block hits are ordered from most to fewest
            # hits.
            if num_hits < min_hits_cutoff:
                break

            # Find the focal position on this block if it exists
            focal_position = focal_for_species.get(block)

            if only_plot_focal and focal_position is None:
                continue

            x_bounds = x_plot_bounds[ss][block][:2]
            x_offset = x_plot_bounds[ss][block][2]

            ax[i].set_xlim((x_bounds[0]-x_offset,x_bounds[1]+x_offset))
            ax[i].set_ylim((plot_y_min,plot_y_max))

            # Try to add line for focal position, if this is in this
            # species on this block.
            if focal_position is not None:
                ax[i].plot((focal_position,focal_position),
                             (plot_y_min,gene_vertical),"--",color=focal_color)

            block_df = df[df.block == block]

            for qs in query_species:

                qs_df = block_df[block_df.query_species == qs]

                # Marker below the genome line at the start of every hit
                ax[i].plot(qs_df.subject_start,
                             np.ones(len(qs_df.subject_start))*gene_vertical,
                             "^",color=species_colors[qs])

                # The focal query is left out of the smoothed curve
                tmp_df = qs_df[qs_df.query_name != focal_query]

                if len(tmp_df.e_value) > 0:
                    indexes, smoothed = _smooth_hits(tmp_df,
                                                     window_size=window_size,
                                                     global_bounds=x_bounds,
                                                     weight_by_evalue=use_evalue_hit_height)

                    smoothed = smoothed/height_norm_factor
                    ax[i].fill_between(indexes,smoothed,color=species_colors[qs],alpha=0.5)
                    ax[i].plot(indexes,smoothed,'-',color=species_colors[qs],linewidth=0.5)

                # Text list of "start: query name" for every hit; the homo
                # list is anchored in the left margin, every other query
                # species to the right of the block start.
                if qs == "homo":
                    x_value = x_bounds[0] - x_offset*0.9
                else:
                    x_value = x_bounds[0] + x_offset*0.9

                num_hits_drawn = len(qs_df)
                for j in range(num_hits_drawn):
                    num = qs_df.iloc[j].subject_start
                    name = qs_df.iloc[j].query_name
                    out = "{}: {}".format(num,name)
                    ax[i].text(x_value,
                             (num_hits_drawn - j )/18,
                             out,
                             color=species_colors[qs],ha="left")

            # Work out extent of the genome line
            if block_lengths is not None:

                block_length = block_lengths[block]

                block_start_location = 0
                block_end_location = block_length

                # If the block runs past the plotted region, clip the genome
                # line to the plot and mark the clipped edge with a dashed
                # line.
                if block_start_location < (x_bounds[0]-x_offset):
                    block_start_location = x_bounds[0]-x_offset
                    ax[i].plot((block_start_location,block_start_location),(-0.2,1),"b--")

                if block_end_location > (x_bounds[1]+x_offset):
                    block_end_location = x_bounds[1] + x_offset
                    ax[i].plot((block_end_location,block_end_location),(-0.2,1),"b--")

            else:
                block_start_location = x_bounds[0]
                block_end_location = x_bounds[1]

            # Plot focal query hits on top of the other markers
            focal_query_df = block_df[block_df.query_name == focal_query]
            ax[i].plot(focal_query_df.subject_start,
                         np.ones(len(focal_query_df.subject_start))*gene_vertical,
                          "^",color=focal_color)

            # Plot genome
            ax[i].plot((block_start_location,block_end_location),(0,0),"k-",linewidth=2)

            if i == 0 and manual_title is not None:
                title = "{}\n\n{}".format(manual_title,block)
                ax[i].set_title(title)
            else:
                ax[i].set_title(" \n \n{}".format(block))

            ax[i].set_xlim((x_bounds[0]-x_offset,x_bounds[1]+x_offset))
            ax[i].set_ylim((plot_y_min,plot_y_max))
            ax[i].spines['top'].set_visible(False)
            ax[i].spines['right'].set_visible(False)

            ax[i].set_xlabel("genome position (Mb)")
            ax[i].set_ylabel("relative evalue/bp")

        plt.tight_layout()

    

In [None]:
# Load data written out from 01_blast-genomes.ipynb
# NOTE: pickle.load can execute arbitrary code while loading; only read
# pickle files that were generated by this project.
block_lengths = {}
for s in species:
    # Context manager so each file handle is closed as soon as it is read
    with open("{}_length.pickle".format(s),"rb") as handle:
        block_lengths[s] = pickle.load(handle)
all_hits = pd.read_csv("all-hits.csv")

### Construct ortholog locations by reverse BLAST

In [None]:
# NOTE: everything in this cell is wrapped in a string literal, so none of
# the code below is executed.  Remove the enclosing triple quotes to
# regenerate homo_rblast_hits.pickle and danio_rblast_hits.pickle.
"""
# Reverse blast hits from TLR4 against human and danio
# This is slow.  It saves out homo_rblast_hits.pickle and
# danio_rblast_hits.pickle. These can be loaded in the next
# cell, meaning you only need to run this cell once, even if 
# you tweak the downstream analysis. 

homo_rblast_hits = {}
danio_rblast_hits = {}
for s in species:
    s_df = all_hits[np.logical_and(all_hits.subject_species == s,all_hits.query_name=="TLR4")]

    print(s)
    
    homo_rblast_hits[s] = {}
    danio_rblast_hits[s] = {}
    for i in range(len(s_df.sequence)):
        uid = s_df.iloc[i].uid
        homo_rblast_hits[s][uid] = tools.blast.local_blast(s_df.iloc[i],db="reference/homo",hitlist_size=10)
        danio_rblast_hits[s][uid] = tools.blast.local_blast(s_df.iloc[i],db="reference/danio",hitlist_size=10)

pickle.dump(homo_rblast_hits,open("homo_rblast_hits.pickle","wb"))
pickle.dump(danio_rblast_hits,open("danio_rblast_hits.pickle","wb"))
"""

In [None]:
# Create dictionary keying species to TLR4 ortholog locations.  Each value is
# a list of (block, subject_start, subject_end) tuples, one per TLR4 hit whose
# reverse BLAST against human confirms it.  Species with no confirmed hit get
# no entry.

# NOTE: pickle.load can execute arbitrary code while loading; only read
# pickle files that were generated by this project.
with open("homo_rblast_hits.pickle","rb") as handle:
    homo_rblast_hits = pickle.load(handle)
with open("danio_rblast_hits.pickle","rb") as handle:
    danio_rblast_hits = pickle.load(handle)

pattern = re.compile("TLR4|toll-like receptor 4|toll like receptor 4",flags=re.IGNORECASE)

# A hit counts as an ortholog if a reverse-BLAST hit at position
# <= worst_allowed (0 = the top hit only) has a description matching pattern.
worst_allowed = 0

ortholog_locations = {}
for s in species:
    s_df = all_hits[np.logical_and(all_hits.subject_species == s,all_hits.query_name=="TLR4")]

    for i in range(len(s_df)):
        hit = s_df.iloc[i]
        uid = hit["uid"]
        homo_hits = homo_rblast_hits[s][uid]

        # Only the best (worst_allowed + 1) reverse hits are inspected; any()
        # records each location at most once even if several of them match.
        top_hit_defs = homo_hits.hit_def.iloc[:worst_allowed + 1]
        if any(pattern.search(hit_def) for hit_def in top_hit_defs):
            out_tuple = (hit["block"],hit["subject_start"],hit["subject_end"])
            ortholog_locations.setdefault(s,[]).append(out_tuple)

tlr4_ortholog_locations = ortholog_locations

### Generate graphs plotting hit e-value/bp along genomes of interest

In [None]:
# Settings shared by every species' figure
plot_options = dict(window_size=0.0001,
                    focal_query="TLR4",
                    focal_locations=tlr4_ortholog_locations,
                    num_blocks_to_plot=10,
                    min_hits_cutoff=1,
                    figsize=(10,4),
                    use_evalue_hit_height=True,
                    dpi=300)

# One figure per species; numbering of the saved figures starts at S2
for i, s in enumerate(species):

    fig_number = i + 2
    fig_title = "Fig S{}: hits on {} chromosomes".format(fig_number,s)

    print(s)
    plot_hits_on_assembly(all_hits,
                          subjects_to_plot=s,
                          block_lengths=block_lengths[s],
                          manual_title=fig_title,
                          **plot_options)

    # Saves the current pyplot figure
    plt.savefig("Fig-S{}.pdf".format(fig_number))
