In [1]:
#Load modules
import zarr
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
import yaml
from pathlib import Path
import allel

from dask.distributed import Client
import dask
dask.config.set(**{'array.slicing.split_large_chunks': False}) # Silence large chunk warnings
import dask.array as da
from dask import delayed, compute
from dask_gateway import Gateway
import functools
import numcodecs
from fsspec.implementations.zip import ZipFileSystem
from collections.abc import Mapping
import gcsfs
import numba
import psutil
from humanize import naturalsize

import pickle
import platform

import traceback
import logging

from pyprojroot import here
from bokeh.plotting import *
import plotly.express as px
import plotly.graph_objects as go
from plotly.validators.scatter.marker import SymbolValidator

### Connect to Google Cloud Storage (GCS)

In [None]:
# GCS filesystem handle used by the data-loading helpers below
# (is_gcloud, load_filter, load_position); constructed with default arguments.
gcs = gcsfs.GCSFileSystem()

In [None]:
# Connectivity check: show the first three entries of the release bucket.
gcs.ls('vo_afun_release_master_us_central1')[:3]

### Set up data access

In [4]:
# Release bucket layout (bucket-relative paths, no gs:// prefix).
production_root = Path('vo_afun_release_master_us_central1')
vo_afun_staging = production_root / 'v1.0'
sampleset_staging_dir = vo_afun_staging / 'snp_genotypes' / 'all'

# Site filters as bucket-path strings: decision-tree (dt) or static (sc) filter sets.
_site_filters_root = vo_afun_staging / 'site_filters'
genomic_positions_site_filter_dt_data_cloud_zarr_dir = (_site_filters_root / 'dt_20200416' / 'funestus').as_posix()
genomic_positions_site_filter_sc_data_cloud_zarr_dir = (_site_filters_root / 'sc_20220908' / 'funestus').as_posix()

# Release configuration tracked in this repository; parsed with yaml.BaseLoader.
repo_clone_path = here()
release_config_path = repo_clone_path / 'tracking' / 'release' / 'v1.0' / 'config.yml'
config = yaml.load(release_config_path.read_text(), Loader=yaml.BaseLoader)

samplesets = config["sample_sets"]

In [5]:
# Sample metadata table (path is relative to the notebook's working directory).
METADATA_CSV = "../../metadata/supp1_tab2.csv"
meta = pd.read_csv(METADATA_CSV)
meta.columns

Index(['sample_id', 'geographic_cohort', 'geographic_cohort_colour',
       'geographic_cohort_shape', 'PCA_cohort', 'PCA_cohort_colour',
       'mitochondrial_id', 'karyotype_3La', 'karyotype_3Ra', 'karyotype_3Rb',
       'karyotype_2Ra', 'karyotype_2Rh', 'median_coverage'],
      dtype='object')

### Functions

In [6]:
def a1(n):
    """Tajima's a1: the harmonic sum of 1/i for i = 1 .. n-1 (0.0 when n <= 1)."""
    indices = np.arange(1, n, dtype=float)
    return np.reciprocal(indices).sum()

In [7]:
def a2(n):
    """Tajima's a2: the sum of 1/i**2 for i = 1 .. n-1 (0.0 when n <= 1)."""
    squares = np.arange(1, n) ** 2
    return (1 / squares).sum()

In [8]:
def b1(n):
    """Tajima's b1 = (n + 1) / (3 (n - 1))."""
    numerator = n + 1
    denominator = 3 * (n - 1)
    return numerator / denominator

In [9]:
def b2(n):
    """Tajima's b2 = 2 (n**2 + n + 3) / (9 n (n - 1))."""
    numerator = 2 * (n ** 2 + n + 3)
    denominator = 9 * n * (n - 1)
    return numerator / denominator

In [10]:
def c1(n):
    """Tajima's c1 = b1 - 1 / a1."""
    harmonic = a1(n)
    return b1(n) - 1 / harmonic

In [11]:
def c2(n):
    """Tajima's c2 = b2 - (n + 2) / (a1 n) + a2 / a1**2."""
    h1 = a1(n)
    h2 = a2(n)
    return b2(n) - (n + 2) / (h1 * n) + h2 / h1 ** 2

In [12]:
def e1(n):
    """Tajima's e1 = c1 / a1 (coefficient of S in the variance of d)."""
    numerator, denominator = c1(n), a1(n)
    return numerator / denominator

In [13]:
def e2(n):
    """Tajima's e2 = c2 / (a1**2 + a2) (coefficient of S(S-1) in the variance of d)."""
    h1, h2 = a1(n), a2(n)
    return c2(n) / (h1 ** 2 + h2)

In [14]:
# Load a single array (e.g. "{chrom}/calldata/GT" or "samples") from one
# sample set's zarr store as a lazy dask array.
def load_single_field(zarr_path, internal_path, sset, exclude_males=False, samples=None):
    """Open `internal_path` inside the zarr group at `zarr_path` as a dask array.

    Parameters
    ----------
    zarr_path : pathlib.Path
        Sample-set root. Resolved by `is_gcloud` to a GCS mapper, or to a
        local POSIX path when no global `gcs` filesystem is defined.
    internal_path : str
        Key of the array inside the group, e.g. ``f"{chrom}/calldata/GT"``.
    sset, exclude_males, samples
        Currently unused; kept so existing callers keep working.

    Returns
    -------
    dask.array.Array
        The array. 1-D arrays are reshaped to ``(1, n)`` so that
        `concatenate_along_axis` can join them along axis 1.
    """
    # Open read-only. zarr.group() opens the store writable and would
    # initialise an empty group in the (production) bucket if the path were
    # missing; with mode='r' a missing group raises instead.
    inz = zarr.open_group(is_gcloud(zarr_path), mode='r')

    oo = da.from_zarr(inz[internal_path])

    if oo.ndim == 1:
        oo = oo.reshape((1, -1))

    return oo

In [15]:
## General function to concatenate one field across several sample sets.
def concatenate_along_axis(base_dir, internal_path, req_samplesets):
    """Load `internal_path` from every sample set and join the results along axis 1.

    Each sample set is read from ``base_dir / <sample set>`` via
    `load_single_field` (which returns 1-D fields as ``(1, n)`` rows), and the
    pieces are concatenated in the order given by `req_samplesets`.
    """
    per_set_arrays = []
    for sample_set in req_samplesets:
        per_set_arrays.append(load_single_field(base_dir / sample_set, internal_path, sample_set))
    return da.concatenate(per_set_arrays, axis=1)

In [16]:
def is_gcloud(path):
    """Resolve `path` to a store usable by zarr.

    Returns ``gcs.get_mapper(<posix path>)`` when a global `gcs` filesystem
    exists; if the name `gcs` is undefined (NameError), returns the plain POSIX
    path string so the data can be read locally.
    """
    posix_path = path.as_posix()
    try:
        return gcs.get_mapper(posix_path)
    except NameError:
        # No GCS filesystem in this session: fall back to a local path.
        return posix_path

In [17]:
def load_filter(chrom, filter_dir=genomic_positions_site_filter_sc_data_cloud_zarr_dir):
    """Lazily load the per-site `variants/filter_pass` mask for `chrom`.

    `filter_dir` is a bucket path opened through the global `gcs` filesystem
    (read-only); it defaults to the sc_20220908 site-filter set.
    """
    filters_root = zarr.Group(gcs.get_mapper(filter_dir), read_only=True)
    return da.from_zarr(filters_root[chrom]['variants/filter_pass'])

In [18]:
def load_position(chrom):
    """Return `variants/POS` for `chrom`, read fully into memory.

    Positions come from the shared ``snp_genotypes/all/sites`` zarr of the
    v1.0 release, opened read-only through the global `gcs` filesystem.
    """
    sites_store = gcs.get_mapper(
        'gs://vo_afun_release_master_us_central1/v1.0/snp_genotypes/all/sites')
    sites_root = zarr.open(sites_store, mode='r')
    return sites_root[chrom]['variants/POS'][:]

In [19]:
def read_in_genotypes_positions(chrom, samples_idx, samplesets, posl, posu,
                                filter_dir=genomic_positions_site_filter_sc_data_cloud_zarr_dir):
    """Load site-filtered genotypes and positions for a region of one chromosome.

    Parameters
    ----------
    chrom : str
        Chromosome group name in the zarr stores.
    samples_idx : array-like of int
        Indices along axis 1 (samples) of the concatenated genotype array to keep.
    samplesets : list of str
        Sample sets to concatenate, in order.
    posl, posu : int
        Keep sites with ``posl <= POS < posu``. ``posl == -1`` means "from the
        first site"; ``posu == -1`` means "through the last site".
    filter_dir : str
        Site-filter zarr location (defaults to the sc_20220908 filter set).

    Returns
    -------
    (gt, pos)
        Lazily filtered and sample-subset genotypes, and the positions of the
        retained sites.
    """
    genotypes = allel.GenotypeDaskArray(
        concatenate_along_axis(sampleset_staging_dir, f"{chrom}/calldata/GT", samplesets))
    positions = load_position(chrom)

    # Resolve the -1 sentinels to the full extent of the chromosome.
    upper = positions.max() + 1 if posu == -1 else posu
    lower = positions.min() if posl == -1 else posl

    site_mask = load_filter(chrom, filter_dir)
    region_mask = (positions >= lower) & (positions < upper)
    keep = site_mask & region_mask

    genotypes = genotypes.compress(keep, axis=0)
    positions = positions[keep]

    genotypes = da.take(genotypes, samples_idx, axis=1)

    return genotypes, positions

In [21]:
def mean_per_window(values, window_size):
    """NaN-ignoring mean over consecutive, non-overlapping windows of `window_size` elements.

    Trailing elements that do not fill a complete window are ignored. Returns a
    float64 array with one entry per complete window.
    """
    n_windows = len(values) // window_size
    means = np.zeros(n_windows)
    means[:] = [np.nanmean(values[k * window_size:(k + 1) * window_size])
                for k in range(n_windows)]
    return means

In [22]:
def sum_per_window(values, window_size):
    """NaN-ignoring sum over consecutive, non-overlapping windows of `window_size` elements.

    Trailing elements that do not fill a complete window are ignored. Returns a
    float64 array with one entry per complete window.
    """
    n_windows = len(values) // window_size
    sums = np.zeros(n_windows)
    sums[:] = [np.nansum(values[k * window_size:(k + 1) * window_size])
               for k in range(n_windows)]
    return sums

In [23]:
def compute_tajimas_d(chrom, samples_idx, posl, posu, window_size, missing_frac,
                      samplesets):
    """Windowed Tajima's D and nucleotide diversity for one region of a chromosome.

    Sites in ``[posl, posu)`` passing the site filter (see
    `read_in_genotypes_positions`) are further restricted to those with at
    least ``(1 - missing_frac) * 2 * len(samples_idx)`` called alleles. The
    retained sites are grouped into consecutive windows of `window_size`
    *sites* (not bp). Trailing sites that do not fill a window are not used
    here; their start is returned so a caller can carry them into the next region.

    Returns
    -------
    D : numpy.ndarray
        Tajima's D per window (not finite where the variance term is 0).
    windowed_mpd : numpy.ndarray
        Sum of per-site mean pairwise difference per window.
    window_centers : numpy.ndarray
        Mean position of the sites in each window.
    end_pos : int
        Where the next region should start: the first retained site not
        assigned to a window; one past the last retained site if every
        retained site was used; `posu` if no site was retained.
    """
    # read in filtered genotypes and positions for this region
    gt, pos = read_in_genotypes_positions(chrom, samples_idx, samplesets, posl=posl, posu=posu)

    # count alleles
    ac = gt.count_alleles(max_allele=3)

    # missingness filter only (no biallelic / MAF filtering is applied)
    missing_filter = ac.sum(axis=1) >= (1 - missing_frac) * 2 * len(samples_idx)

    # filtered allele counts and positions
    ac_f = ac.compress(missing_filter, axis=0)
    pos_f = pos[missing_filter]

    # which retained sites are segregating
    seg_f = ac_f.is_segregating()

    # per-site pi
    mpd = allel.mean_pairwise_difference(ac_f)

    # sum pi per window; window centre = mean site position
    windowed_mpd = sum_per_window(mpd, window_size)
    window_centers = mean_per_window(pos_f, window_size)

    # number of segregating sites per window
    windowed_S = count_segregating_per_window(seg_f, window_size)

    # Constants assume no missing data (n = all alleles of the selected
    # samples); values at sites with missing calls will be slightly off, as in
    # scikit-allel.
    n_alleles = len(samples_idx) * 2
    ch_a1 = a1(n_alleles)
    ch_e1 = e1(n_alleles)
    ch_e2 = e2(n_alleles)

    # Tajima's D = (pi - S / a1) / sqrt(e1 S + e2 S (S - 1))
    d = windowed_mpd - windowed_S / ch_a1
    std = np.sqrt((ch_e1 * windowed_S) + (ch_e2 * windowed_S * (windowed_S - 1)))
    D = d / std

    # Start of the next region. Indexing pos_f[n_used] directly would raise
    # IndexError when every retained site was consumed (site count an exact
    # multiple of window_size, or no sites at all).
    n_used = len(D) * window_size
    if n_used < len(pos_f):
        end_pos = pos_f[n_used]
    elif len(pos_f) > 0:
        end_pos = pos_f[-1] + 1
    else:
        end_pos = posu

    return D, windowed_mpd, window_centers, end_pos

In [24]:
def count_segregating_per_window(seg_f, window_size):
    """Count segregating sites in consecutive, non-overlapping windows of sites.

    Windows use the same site-index layout as `sum_per_window` and
    `mean_per_window`: window ``w`` covers indices
    ``[w * window_size, (w + 1) * window_size)``. Trailing sites that do not
    fill a whole window are dropped. (Previous searchsorted bounds counted
    ``(w * window_size, (w + 1) * window_size]`` - shifted by one site
    relative to the pi windows.)

    Parameters
    ----------
    seg_f : array-like of bool, or an object with ``.compute()`` (dask array)
        Per-site flag, True where the site is segregating.
    window_size : int
        Number of sites per window.

    Returns
    -------
    numpy.ndarray of int
        One count per complete window.
    """
    if hasattr(seg_f, 'compute'):
        seg_f = seg_f.compute()
    seg = np.asarray(seg_f, dtype=bool)

    n_windows = len(seg) // window_size
    # Drop the incomplete tail and sum each row of a (n_windows, window_size) view.
    return seg[:n_windows * window_size].reshape(n_windows, window_size).sum(axis=1)

In [26]:
def loop_through_reading_frames(chrom, samples_idx, samplesets, outdir, cohortname,
                                reading_frame_size, window_size, missing_frac):
    """Compute Tajima's D and pi along a whole chromosome, one reading frame at a time.

    Frames are half-open position ranges ``[rf_start, rf_end)`` processed with
    `compute_tajimas_d`. Each frame starts at the position that call returns,
    so sites left over after a frame's last complete window are carried into
    the next frame. If the returned start does not advance, the frame is
    widened by another `reading_frame_size` instead of being recomputed
    unchanged (which previously looped forever).

    After every frame the accumulated results are written to
    ``{outdir}/{chrom}_{cohortname}_tajima_d.npy``, ``..._pi.npy`` and
    ``..._windows.npy`` (window-centre positions). Files are written even
    when the chromosome fits in a single frame.
    """
    chrom_size = load_position(chrom).max()

    taj_d_parts, mpd_parts, window_parts = [], [], []
    rf_start, rf_end = 0, reading_frame_size
    while True:
        # Every frame uses the same estimator.
        taj_d_rf, mpd_rf, windows_rf, next_start = compute_tajimas_d(
            chrom, samples_idx, rf_start, rf_end, window_size=window_size,
            missing_frac=missing_frac, samplesets=samplesets)
        taj_d_parts.append(taj_d_rf)
        mpd_parts.append(mpd_rf)
        window_parts.append(windows_rf)

        # Checkpoint the concatenated results after every frame.
        for suffix, parts in (('tajima_d', taj_d_parts), ('pi', mpd_parts),
                              ('windows', window_parts)):
            np.save(f'{outdir}/{chrom}_{cohortname}_{suffix}.npy', np.concatenate(parts))

        if rf_end >= chrom_size:
            break
        if next_start > rf_start:
            # Resume at the first site not yet assigned to a window.
            rf_start = next_start
            rf_end = rf_start + reading_frame_size
        else:
            # No progress (too few sites for a window): widen the frame.
            rf_end += reading_frame_size
        

In [28]:
def compute_tajima_d_all_cohorts(chrom, cohorts, meta, outdir, size=30, samplesets=samplesets,
                                 reading_frame_size=10_000_000, window_size=20_000,
                                 missing_frac=0.05):
    """Run the reading-frame Tajima's D scan on `chrom` for each cohort in `cohorts`.

    Parameters
    ----------
    chrom : str
        Chromosome to scan.
    cohorts : iterable of str
        Values of ``meta.geographic_cohort`` to process.
    meta : pandas.DataFrame
        Sample metadata. Its index is used directly as the sample indices along
        axis 1 of the genotype array, so it is assumed to be reordered to the
        zarr sample order with a default positional index.
    outdir : str
        Directory for the per-cohort .npy outputs; created if missing.
    size : int
        Currently unused. NOTE(review): confirm whether a minimum cohort size or
        a down-sampling target was intended.
    samplesets : list of str
        Sample sets to read; defaults to the `samplesets` global at definition time.
    reading_frame_size, window_size, missing_frac
        Passed through to `loop_through_reading_frames`.
    """
    # Ensure the output directory exists before any .npy file is written into it.
    os.makedirs(outdir, exist_ok=True)

    for cohort in cohorts:
        # Only samples flagged 'Y' in subset_3 are used.
        in_cohort = (meta.geographic_cohort == cohort) & (meta.subset_3 == 'Y')
        samples_idx = meta.loc[in_cohort].index
        loop_through_reading_frames(chrom, samples_idx, samplesets, outdir, cohort,
                                    reading_frame_size, window_size, missing_frac)
        print(f"On chromosome {chrom} cohort {cohort} done")

In [None]:
# Reorder meta to the sample order of the concatenated zarr arrays, so that its
# default positional index can select samples along axis 1 of the genotypes.
sample_order = concatenate_along_axis(sampleset_staging_dir, "samples", samplesets).compute()
sample_order = (sample_order[0]).astype(str)
# Accept either sample-ID column name: 'VBS_sample_id', or 'sample_id' as listed
# in the columns of supp1_tab2.csv shown above.
sample_id_col = 'VBS_sample_id' if 'VBS_sample_id' in meta.columns else 'sample_id'
meta = meta.set_index(sample_id_col).loc[sample_order].reset_index()
assert len(meta) == len(sample_order), "meta is not aligned one-to-one with the zarr sample order"
meta.head()

### Set up Dask cluster

In [30]:
# Connect to Dask Gateway and list the clusters it currently reports
# (an empty list means none are running).
gateway = Gateway()
gateway.list_clusters()

[]

In [42]:
# Launch a Gateway cluster that uses the same conda environment as this kernel.
conda_prefix = os.environ["CONDA_PREFIX"]
# NOTE(review): assumes the environment name is component 5 of CONDA_PREFIX
# split on '/' - confirm for the deployment's path layout.
env_name = conda_prefix.split('/')[5]
current_environment = f'global/{env_name}'
cluster = gateway.new_cluster(profile='standard', conda_environment=current_environment)
cluster

VBox(children=(HTML(value='<h2>GatewayCluster</h2>'), HBox(children=(HTML(value='\n<div>\n<style scoped>\n    …

In [43]:
# Client bound to the Gateway cluster created above.
client = cluster.get_client()

In [44]:
N_WORKERS = 50
cluster.scale(N_WORKERS)

In [34]:
# Geographic cohorts with at least one sample flagged 'Y' in subset_3.
# NOTE(review): subset_3 is not among the supp1_tab2.csv columns shown above - confirm its source.
in_subset = meta['subset_3'] == 'Y'
cohorts = meta.loc[in_subset, 'geographic_cohort'].unique()
cohorts

array(['Ghana_Northern-Region', 'Gabon_Haut-Ogooue', 'CAR_Ombella-MPoko',
       'Cameroon_Adamawa', 'Ghana_Ashanti-Region',
       'Malawi_Southern-Region', 'Mozambique_Maputo',
       'Uganda_Eastern-Region', 'Benin_Atlantique-Dept', 'DRC_Kinshasa',
       'Nigeria_Ogun-State', 'Zambia_Eastern-Prov', 'Kenya_Nyanza-Prov',
       'Kenya_Western-Prov', 'Tanzania_Morogoro-Region', 'DRC_Haut-Uele',
       'Mozambique_Cabo-Delgado'], dtype=object)

In [35]:
compute_tajima_d_all_cohorts(chrom='X', cohorts=cohorts, meta=meta, outdir='tajima_d/')

On chromosome X cohort Ghana_Northern-Region done
On chromosome X cohort Gabon_Haut-Ogooue done
On chromosome X cohort Cameroon_Adamawa done
On chromosome X cohort Ghana_Ashanti-Region done
On chromosome X cohort Mozambique_Maputo done
On chromosome X cohort Uganda_Eastern-Region done
On chromosome X cohort Benin_Atlantique-Dept done
On chromosome X cohort DRC_Kinshasa done
On chromosome X cohort Nigeria_Ogun-State done
On chromosome X cohort Zambia_Eastern-Prov done
On chromosome X cohort Kenya_Nyanza-Prov done
On chromosome X cohort DRC_Haut-Uele done
On chromosome X cohort Mozambique_Cabo-Delgado done


In [36]:
compute_tajima_d_all_cohorts(chrom='3RL', cohorts=cohorts, meta=meta, outdir='tajima_d/')

On chromosome 3RL cohort Ghana_Northern-Region done
On chromosome 3RL cohort Gabon_Haut-Ogooue done
On chromosome 3RL cohort Cameroon_Adamawa done
On chromosome 3RL cohort Ghana_Ashanti-Region done
On chromosome 3RL cohort Mozambique_Maputo done
On chromosome 3RL cohort Uganda_Eastern-Region done
On chromosome 3RL cohort Benin_Atlantique-Dept done
On chromosome 3RL cohort DRC_Kinshasa done
On chromosome 3RL cohort Nigeria_Ogun-State done
On chromosome 3RL cohort Zambia_Eastern-Prov done
On chromosome 3RL cohort Kenya_Nyanza-Prov done
On chromosome 3RL cohort DRC_Haut-Uele done
On chromosome 3RL cohort Mozambique_Cabo-Delgado done


In [37]:
compute_tajima_d_all_cohorts(chrom='2RL', cohorts=cohorts, meta=meta, outdir='tajima_d/')

On chromosome 2RL cohort Ghana_Northern-Region done
On chromosome 2RL cohort Gabon_Haut-Ogooue done
On chromosome 2RL cohort Cameroon_Adamawa done
On chromosome 2RL cohort Ghana_Ashanti-Region done
On chromosome 2RL cohort Mozambique_Maputo done
On chromosome 2RL cohort Uganda_Eastern-Region done
On chromosome 2RL cohort Benin_Atlantique-Dept done
On chromosome 2RL cohort DRC_Kinshasa done
On chromosome 2RL cohort Nigeria_Ogun-State done
On chromosome 2RL cohort Zambia_Eastern-Prov done
On chromosome 2RL cohort Kenya_Nyanza-Prov done
On chromosome 2RL cohort DRC_Haut-Uele done
On chromosome 2RL cohort Mozambique_Cabo-Delgado done


In [51]:
# Shut down the cluster created in this notebook once the scans are finished.
cluster.shutdown()

In [16]:
# Clean-up: shut down every cluster the gateway still reports.
for cluster_report in gateway.list_clusters():
    gateway.connect(cluster_report.name).shutdown()