Script for clustering simulated sequences using the TMRCA metric

# Import modules

In [6]:
import numpy as np
import zarr
import allel
import scipy.cluster.hierarchy as sch
import scipy.spatial
import matplotlib
import matplotlib.pyplot as plt
import scipy.signal
from scipy.ndimage import gaussian_filter1d
from numpy.lib.stride_tricks import sliding_window_view
import dask
from dask.delayed import delayed
from dask.base import compute
from itertools import combinations
import time
import seaborn as sns
import sys

ModuleNotFoundError: No module named 'zarr'

In [7]:
# First command-line argument, kept as the raw string from sys.argv.
# NOTE(review): `seed` is not referenced by any code visible in this file -
# confirm whether it is still needed (when no argument is supplied this line
# raises IndexError).
seed = sys.argv[1]
# As with the Hamming-distance code, this function was taken from Anushka Thawani.
# It was further adapted on the high-performance computer via shell scripts;
# those adaptations could not be represented here.
def convert(file, genome, n_individuals=200):
    '''
    This function extracts haplotype sequences from a vcf file and records
    which haplotypes carry the sweep mutation.
    Adapted from: http://alimanfoo.github.io/2018/04/09/selecting-variants.html

    The vcf file is first converted to a zarr store named file + '.zarr'
    (overwritten if it already exists); all data are then read from that store.

    Arguments:
        file: name of vcf file without the '.vcf' extension (from SLiM soft sweep simulation)
        genome: length of genome used in SLiM simulation
        n_individuals: number of individuals (leading sample columns of the
            genotype array) to keep; defaults to 200, i.e. 400 haplotypes

    Returns:
        ht: haplotype sequences for the selected individuals
        pos: allel.SortedIndex of the variant positions
        samp_freq: frequency of sweep mutation in sample
        cols: dict mapping haplotype index to a colour ('r' if the haplotype
            carries the sweep mutation, grey otherwise); used to color dendrogram
        sweep: per-haplotype sum of alleles over the variant(s) at the sweep position

    Raises:
        KeyError: raised by pos.locate_range when no variant is recorded at
            the sweep position
    '''

    v = file + '.vcf'
    z = file + '.zarr'

    # Convert once to zarr and read from there; parsing the whole vcf into
    # memory as well is not needed.
    allel.vcf_to_zarr(v, z, fields='*', overwrite=True)
    data = zarr.open_group(z, mode='r')

    # Sorted index of the genomic position of each variant
    pos = allel.SortedIndex(data['variants/POS'])

    # Extract genotypes for the first n_individuals individuals and convert to haplotypes
    gt = data['calldata/GT'][:, 0:n_individuals]
    ht = allel.GenotypeArray(gt).to_haplotypes()
    n_haplotypes = ht.shape[1]

    # Position of sweep mutation.
    # NOTE(review): this is one more than the `mutation` value computed in
    # analysis() - presumably a 0-based vs 1-based coordinate shift; confirm.
    mutation = int((genome + 1) / 2) + 1

    # Rows of ht located at the sweep position, summed per haplotype
    contains_sweep = pos.locate_range(mutation, mutation)
    sweep = np.sum(ht[contains_sweep], axis=0)

    # Frequency of the sweep mutation in the sample
    samp_freq = np.sum(sweep) / n_haplotypes

    # This dictionary is used later to color the dendrogram branches according to
    # whether or not the corresponding sequence contains the sweep mutation
    cols = {
        i: 'r' if carries else "#808080"
        for i, carries in enumerate(sweep)
    }

    return ht, pos, samp_freq, cols, sweep

NameError: name 'sys' is not defined

In [1]:
def sliding_distance(L, mutation, genome, haplotypes, pos, gts):
    '''
    This function calculates the sliding window homozygosity for all pairs of haplotype sequences

    For every start position x in 0..genome-1 the window spans positions
    x to x + L (as passed to pos.locate_range). Homozygosity of a pair in a
    window is 1 minus the Hamming distance over the variants in that window.
    A window that contains no variants is assigned homozygosity 1.0 for every
    pair, since no differences are observed there.

    Arguments:
        L: length of sliding window
        mutation: position of sweep mutation in genome (not used by this function)
        genome: length of genomes
        haplotypes: haplotype sequences
        pos: positions of variants
        gts: number of haplotype pairs

    Returns:
        hom: float32 array of shape (genome, gts) holding the sliding
            homozygosity for all pairs of sequences
    '''

    hom = np.empty(shape=(genome, gts), dtype=np.float32)

    prev_region = None   # slice of variants used for the previous non-empty window
    window_hom = None    # homozygosity computed for prev_region

    for x in range(genome):
        try:
            region = pos.locate_range(x, x + L)
        except KeyError:
            # No variants fall inside this window. Previously the row was left
            # as uninitialised memory; write a defined value instead.
            hom[x, :] = 1.0
            continue

        # Only recompute when the set of variants in the window has changed -
        # consecutive windows usually cover the same variants.
        if region != prev_region:
            d = allel.pairwise_distance(haplotypes[region], metric="hamming")
            window_hom = 1 - d  # 1-hamming distance = homozygosity
            prev_region = region

        hom[x, :] = window_hom

    return hom

In [None]:
def find_troughs(smooth, mutation_pos, threshold=None):
    '''
    This function calculates the shared haplotype length (SHL) for one pair of
    haplotype sequences from its smoothed homozygosity profile

    The breakpoints are not the flanking troughs themselves: each one is the
    midpoint between a flanking trough and the mean position of the peaks
    lying between the two troughs (or the troughs' midpoint if there is no peak).

    Arguments:
        smooth: smoothed sliding window homozygosity for one pair of sequences (1-D array)
        mutation_pos: position of sweep mutation in genome
        threshold: only troughs whose smoothed homozygosity is below this value
            are kept; when None the module-level global `thresh` is used

    Returns:
        lower: position of breakpoint left of the sweep site
        upper: position of breakpoint right of the sweep site
        SHL: shared haplotype length

    Raises:
        IndexError: if there is no qualifying trough on the left or on the
            right of the sweep site
    '''

    if threshold is None:
        threshold = thresh  # global set by analysis()

    # Indexes of all troughs, then keep those below the threshold
    troughs, _ = scipy.signal.find_peaks(-smooth)
    troughs = troughs[smooth[troughs] < threshold]

    peaks, _ = scipy.signal.find_peaks(smooth)

    # Find positions of troughs flanking sweep site
    bp = np.searchsorted(troughs, mutation_pos)
    if bp == 0 or bp == len(troughs):
        # Without this check bp == 0 would index troughs[-1], silently using
        # the LAST trough as the left breakpoint and giving a negative SHL.
        raise IndexError('no qualifying trough on both sides of the sweep site')
    lower = troughs[bp - 1]
    upper = troughs[bp]

    # Find the average peak position around the sweep site
    between = peaks[(peaks >= lower) & (peaks <= upper)]
    if between.size != 0:
        highest = np.mean(between)
    else:
        highest = (lower + upper) / 2

    lower = (lower + highest) / 2
    upper = (upper + highest) / 2

    SHL = upper - lower

    return int(lower), int(upper), SHL

In [None]:
def find_breakpoint(haplotype_pair):
    '''
    Smooth the sliding homozygosity profile of one pair of sequences and
    return its breakpoints and SHL

    Reads two module-level globals set by analysis(): `mutation` (position of
    the sweep mutation) and `points_g` (passed to gaussian_filter1d as its
    second positional argument, sigma).

    Arguments:
        haplotype_pair: 1-D sliding homozygosity profile for one pair of
            haplotype sequences

    Returns:
        lower: position of breakpoint left of the sweep site
        upper: position of breakpoint right of the sweep site
        SHL: shared haplotype length
        All three are -1.3 when find_troughs raises IndexError.
    '''

    sweep_site = mutation
    profile = gaussian_filter1d(haplotype_pair, points_g)

    try:
        return find_troughs(profile, sweep_site)
    except IndexError:
        # Sentinel marking a pair for which no breakpoints were found
        failed = -1.3
        return failed, failed, failed

In [None]:
def find_snp(n, gts, ht, results_computed_1, pos):
    '''
    This function measures how much each pair of haplotype sequences differs
    over its shared haplotype length

    For pair i the window runs from results_computed_1[i, 1] to
    results_computed_1[i, 2]; the value stored is the Hamming distance between
    the two haplotypes over the variants located in that window.
    NOTE(review): with the "hamming" metric this is presumably the proportion
    of differing variant sites rather than a raw SNP count - confirm.

    Arguments:
        n: number of haplotype sequences
        gts: number of haplotype pairs
        ht: haplotype sequences
        results_computed_1: array with one row per pair; column 1 is the lower
            breakpoint and column 2 the upper breakpoint (output from
            find_breakpoint function, prefixed with an index column)
        pos: positions of variants

    Returns:
        diffs: float32 array of length gts with the distance for every pair;
            -1.3 where no variant lies in the pair's window
    '''

    # Pairs in the same order as a condensed distance matrix: (0,1), (0,2), ...
    pairwise = list(combinations(range(n), 2))

    diffs = np.empty(shape=(gts), dtype=np.float32)
    for i in range(gts):
        start = results_computed_1[i, 1]
        stop = results_computed_1[i, 2]

        try:
            window_pos = pos.locate_range(start, stop)
        except KeyError:
            # No variants between the breakpoints (this includes the -1.3
            # sentinel breakpoints): flag the pair with the same sentinel.
            diffs[i] = -1.3
            continue

        pair = ht[:, pairwise[i]]
        window = pair[window_pos]

        d = allel.pairwise_distance(window, metric="hamming")

        # d holds a single distance for the two columns; take the element
        # explicitly instead of relying on array-to-scalar conversion.
        diffs[i] = d[0]

    return diffs

In [1]:
def _colour_leaf_labels(ax, cols):
    '''Colour each x tick label of a dendrogram axis using cols[leaf index].'''
    for label in ax.get_xticklabels():
        label.set_color(cols[int(label.get_text())])


def analysis(file, genome, pop, window, threshold, points, r=1, u=1):
    '''
    This function clusters the sequences stored in a .vcf file and saves a
    two-panel dendrogram to 'accurate_' + file + '.pdf'.

    Side effects: sets the module-level globals `mutation`, `thresh` and
    `points_g` (read by find_breakpoint / find_troughs), prints the file name,
    and, through convert(), writes file + '.zarr'.

    Arguments:
        file: name of vcf file (without the '.vcf' extension)
        genome: length of genome (in SLiM simulation)
        pop: effective population size (in SLiM simulation)
        window: length of sliding window
        threshold: threshold above which troughs are ignored
        points: value passed to the 1D-gaussian filter as its sigma argument (see scipy documentation)
        r: recombination rate
        u: mutation rate

    Raises:
        ValueError: if no pair yields a finite, positive TMRCA estimate
    '''

    # Imported explicitly: `import dask` / `import matplotlib` alone do not
    # guarantee that these submodules are loaded.
    import dask.array as da
    from matplotlib.gridspec import GridSpec

    print(file)

    global mutation, thresh, points_g
    mutation = int((genome + 1) / 2)
    thresh = threshold
    points_g = points

    # Extract haplotype sequences from .vcf file
    ht, pos, _samp_freq, cols, _sweep = convert(file, genome)

    # Calculate sliding homozygosity for all pairs of haplotype sequences
    n = ht.shape[1]            # number of haplotypes
    gts = n * (n - 1) // 2     # number of haplotype pairs
    hom = sliding_distance(window, mutation, genome, ht, pos, gts)

    # Find SHL for all pairs of haplotype sequences (one dask chunk per pair)
    hom_dask = da.from_array(hom, chunks=(genome, 1))
    del hom  # drop this function's own reference to the large array
    results_computed = da.apply_along_axis(find_breakpoint, 0, hom_dask).compute()

    # One row per pair: [pair index, lower breakpoint, upper breakpoint, SHL]
    index = np.arange(gts).reshape(-1, 1)
    results_computed_1 = np.concatenate((index, results_computed.T), axis=1)

    # Calculate the TMRCA from the SHLs and number of SNPs
    # NOTE(review): r is divided by 2*pop and u by 100*pop; the rationale for
    # these scaling factors is not documented in this file - confirm against
    # the simulation parameters.
    r_scaled = r / (2 * pop)
    u_scaled = u / (100 * pop)
    shls = results_computed_1[:, 3]   # SHLs for all pairs (a view into results_computed_1)
    shls[shls <= 0] = genome          # no usable SHL (sentinel or non-positive): use the whole genome
    diffs = find_snp(n, gts, ht, results_computed_1, pos)
    tmrca = (1 + (diffs * shls)) / (2 * shls * (r_scaled + u_scaled))

    # Replace non-finite and non-positive TMRCA values by the mean of the
    # valid ones. The mean is taken over valid values only, so that the
    # negative values produced by the -1.3 sentinel do not bias it.
    valid = np.isfinite(tmrca) & (tmrca > 0)
    if not valid.any():
        raise ValueError('no finite, positive TMRCA estimates to cluster')
    tmrca[~valid] = tmrca[valid].mean()

    # Clustering (tmrca is used as a condensed pairwise distance vector)
    Z = sch.linkage(tmrca, method='complete')

    # Plot dendrogram without colouring branches
    matplotlib.rcParams.update({'font.size': 24})
    fig = plt.figure(figsize=(30, 12))
    gs = GridSpec(2, 1, hspace=0.1, wspace=1, height_ratios=(1, 1))

    ax_dend = fig.add_subplot(gs[0, 0])
    sns.despine(ax=ax_dend, offset=5, bottom=True, top=True)
    sch.dendrogram(Z, color_threshold=0, above_threshold_color='#808080', ax=ax_dend)
    _colour_leaf_labels(ax_dend, cols)

    ax_dend.set_ylabel('Haplotype age/generations', fontsize=24)
    ax_dend.set_title('Haplotype clusters', fontsize=24)

    # Plot dendrogram and colour branches
    ax_dend_2 = fig.add_subplot(gs[1, 0])

    dflt_col = "#808080"

    # A merged cluster keeps a colour only if both of its children share it.
    # Ids up to len(Z) are original leaves; larger ids are merged clusters.
    n_merges = len(Z)
    link_cols = {}
    for i, children in enumerate(Z[:, :2].astype(int)):
        c1, c2 = (link_cols[x] if x > n_merges else cols[x] for x in children)
        link_cols[i + 1 + n_merges] = c1 if c1 == c2 else dflt_col

    sns.despine(ax=ax_dend_2, offset=5, bottom=True, top=True)
    sch.dendrogram(Z, color_threshold=None, link_color_func=lambda x: link_cols[x], ax=ax_dend_2)
    _colour_leaf_labels(ax_dend_2, cols)

    ax_dend_2.set_ylabel('Haplotype age/generations', fontsize=24)

    # Save dendrogram
    output = 'accurate_' + file + '.pdf'
    fig.savefig(output)

    return

In [None]:
# Run the clustering for simulation '001_C' (input '001_C.vcf', output
# 'accurate_001_C.pdf'); u is left at its default of 1.
analysis(file='001_C',genome=39999,pop=1000,r=0.1,window=1000,threshold=0.87,points=280)

In [None]:
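# NOTE: the commented-out calls below use an older calling convention in which
# file, genome and pop were passed together as one list. With the current
# signature they must be written as
# analysis(file=..., genome=..., pop=..., r=..., window=..., threshold=..., points=...).
# analysis(['001_A',39999,1000],r=0.1,window=1000,threshold=0.87,points=280)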
# analysis(['001_B',39999,1000],r=0.1,window=1000,threshold=0.87,points=280)
# analysis(['001_C',39999,1000],r=0.1,window=1000,threshold=0.87,points=280)
# analysis(['001_D',39999,1000],r=0.1,window=1000,threshold=0.87,points=280)
# analysis(['001_E',39999,1000],r=0.1,window=1000,threshold=0.87,points=280)

# analysis(['002_A',19999,1000],r=1,window=600,threshold=0.87,points=250)
# analysis(['002_B',19999,1000],r=1,window=600,threshold=0.87,points=250)
# analysis(['002_C',19999,1000],r=1,window=600,threshold=0.87,points=250)
# analysis(['002_D',19999,1000],r=1,window=600,threshold=0.87,points=250)
# analysis(['002_E',19999,1000],r=1,window=600,threshold=0.87,points=250)

# analysis(['003_A',9999,1000],r=10,window=500,threshold=0.99,points=210)
# analysis(['003_B',9999,1000],r=10,window=500,threshold=0.99,points=210)
# analysis(['003_C',9999,1000],r=10,window=500,threshold=0.99,points=210)
# analysis(['003_D',9999,1000],r=10,window=500,threshold=0.99,points=210)
# analysis(['003_E',9999,1000],r=10,window=500,threshold=0.99,points=210)

# analysis(['004_A',99999,1000],r=0.01,window=1000,threshold=0.87,points=280)
# analysis(['004_B',99999,1000],r=0.01,window=1000,threshold=0.87,points=280)
# analysis(['004_C',99999,1000],r=0.01,window=1000,threshold=0.87,points=280)
# analysis(['004_D',99999,1000],r=0.01,window=1000,threshold=0.87,points=280)
# analysis(['004_E',99999,1000],r=0.01,window=1000,threshold=0.87,points=280)

# analysis(['005_A',99999,1000],r=0.01,window=1000,threshold=0.87,points=280)
# analysis(['005_B',99999,1000],r=0.01,window=1000,threshold=0.87,points=280)
# analysis(['005_C',99999,1000],r=0.01,window=1000,threshold=0.87,points=280)
# analysis(['005_D',99999,1000],r=0.01,window=1000,threshold=0.87,points=280)
# analysis(['005_E',99999,1000],r=0.01,window=1000,threshold=0.87,points=280)

# analysis(['006_A',99999,1000],r=0.01,window=1000,threshold=0.87,points=280)
# analysis(['006_B',99999,1000],r=0.01,window=1000,threshold=0.87,points=280)
# analysis(['006_C',99999,1000],r=0.01,window=1000,threshold=0.87,points=280)
# analysis(['006_D',99999,1000],r=0.01,window=1000,threshold=0.87,points=280)
# analysis(['006_E',99999,1000],r=0.01,window=1000,threshold=0.87,points=280)