In [1]:
#Load modules
import zarr
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
import yaml
from pathlib import Path
import allel

from dask.distributed import Client
import dask
dask.config.set(**{'array.slicing.split_large_chunks': False}) # Silence large chunk warnings
import dask.array as da
from dask import delayed, compute
from dask_gateway import Gateway
import functools
import numcodecs
from fsspec.implementations.zip import ZipFileSystem
from collections.abc import Mapping
import gcsfs
import numba
import psutil
from humanize import naturalsize

import pickle
import platform

import traceback
import logging

from pyprojroot import here
from bokeh.plotting import *
import plotly.express as px
import plotly.graph_objects as go
from plotly.validators.scatter.marker import SymbolValidator

### Connect to Google Cloud Storage (GCS)

In [2]:
# Module-level GCS filesystem handle; the loader functions below
# (is_gcloud, load_filter, load_position) read it as a global.
gcs = gcsfs.GCSFileSystem()



In [3]:
# Sanity check: list the first three entries under the release bucket.
gcs.ls('vo_afun_release')[:3]

['vo_afun_release/reference',
 'vo_afun_release/v1.0',
 'vo_afun_release/v1.0-config.json']

### Set up data access

In [4]:
# Bucket-relative locations of the v1.0 release and its per-sample-set SNP genotypes.
production_root = Path('vo_afun_release')
vo_afun_staging = Path(production_root, 'v1.0')
sampleset_staging_dir = Path(vo_afun_staging, 'snp_genotypes', 'all')

# Site-filter locations: "dt_*" (decision tree) and "sc_*" (static) variants.
# The "sc" path is the default filter_dir of the loader functions below.
genomic_positions_site_filter_dt_data_cloud_zarr_dir = 'vo_afun_release/v1.0/site_filters/dt_20200416/funestus'
genomic_positions_site_filter_sc_data_cloud_zarr_dir = 'vo_afun_release/v1.0/site_filters/sc_20220908/funestus'

# The release config lives in the local repository clone (root located by pyprojroot.here()).
repo_clone_path = here()
release_config_path = repo_clone_path / 'tracking' / 'release' / 'v1.0' / 'config.yml'

with open(release_config_path) as fh:
    config = yaml.load(fh, Loader=yaml.BaseLoader)
    
# `samplesets` is the same list object as config["sample_sets"], so the
# removal below also changes `config` in place.
samplesets = config["sample_sets"]
# Remove the sample set that contains no samples
# (list.remove raises ValueError if the name is not present).
samplesets.remove("1242-VO-KE-TCHOUASSI-VMF00082")

In [5]:
# Per-sample metadata, read from a path relative to the notebook's working directory.
# The row index of `meta` is later used as the sample (column) index into the
# genotype arrays — presumably rows follow the concatenated sample-set order; verify.
meta = pd.read_csv("../../metadata/supp1.csv")
meta.columns

Index(['sample_id', 'geographic_cohort', 'geographic_cohort_colour',
       'geographic_cohort_shape', 'PCA_cohort', 'PCA_cohort_colour',
       'mitochondrial_id', 'karyotype_3La', 'karyotype_3Ra', 'karyotype_3Rb',
       'karyotype_2Ra', 'karyotype_2Rh', 'median_coverage'],
      dtype='object')

### Functions

In [6]:
# load a single array from field/chrom/sampleset
# internal path for calldata is chrom/calldata/field
# sampleset_calldata = sampleset_staging_dir / sset
# sampleset is needed to load species spec.
def load_single_field(zarr_path, internal_path, sset, exclude_males=False, samples=None):
    """Open one array of a zarr hierarchy as a lazy dask array with at least 2 dimensions.

    Parameters
    ----------
    zarr_path : path-like with ``as_posix()``
        Location of the zarr group; resolved to a store through ``is_gcloud``.
    internal_path : str
        Path of the array inside the group, e.g. ``"<chrom>/calldata/GT"``.
    sset, exclude_males, samples
        Accepted so existing callers keep working; not used by this function.

    Returns
    -------
    dask array
        The array found at ``internal_path``; a 1-D array is reshaped to ``(1, n)``
        so that callers can always concatenate along axis 1.
    """
    # Open strictly read-only. zarr.group() is a "create" call: it initialises
    # a new, writable group when nothing exists at the store, which must never
    # happen when reading release data. With mode='r' a missing group raises
    # immediately instead.
    root = zarr.open_group(is_gcloud(zarr_path), mode='r')

    arr = da.from_zarr(root[internal_path])

    if arr.ndim == 1:
        arr = arr.reshape((1, -1))

    return arr

In [7]:
## General function to concatenate data.
## Selected chunk size may be more appropriate for some than others.
def concatenate_along_axis(base_dir, internal_path, req_samplesets):
    """Load the same field from several sample sets and join the pieces along axis 1.

    Parameters
    ----------
    base_dir : path-like supporting ``/``
        Directory containing one zarr group per sample set.
    internal_path : str
        Path of the array inside each sample-set group.
    req_samplesets : iterable of str
        Sample-set names, in the order in which they are concatenated.

    Returns
    -------
    The result of ``da.concatenate`` over the per-sample-set arrays (axis=1).
    """
    per_sampleset = []
    for sampleset in req_samplesets:
        per_sampleset.append(
            load_single_field(base_dir / sampleset, internal_path, sampleset)
        )

    return da.concatenate(per_sampleset, axis=1)

In [8]:
def is_gcloud(path):
    """Resolve ``path`` to a zarr store.

    Returns a GCS mapper for the path when the module-level ``gcs`` filesystem
    exists; when ``gcs`` has not been defined (NameError), falls back to the
    plain POSIX path string so the data can be read from local disk.
    """
    posix_path = path.as_posix()
    try:
        return gcs.get_mapper(posix_path)
    except NameError:
        # No `gcs` global in this session: treat the path as a local one.
        return posix_path

In [9]:
def load_filter(chrom, filter_dir = genomic_positions_site_filter_sc_data_cloud_zarr_dir):
    """Return the site-filter ``variants/filter_pass`` array of one chromosome as a dask array.

    Parameters
    ----------
    chrom : str
        Key of the chromosome group inside the site-filter zarr hierarchy.
    filter_dir : str
        GCS location of the site-filter hierarchy (defaults to the "sc" filter).
    """
    # The hierarchy is opened read-only through the module-level `gcs` filesystem.
    site_filters = zarr.Group(gcs.get_mapper(filter_dir), read_only=True)
    return da.from_zarr(site_filters[chrom]['variants/filter_pass'])

In [10]:
def load_position(chrom):
    """Read all of ``variants/POS`` for ``chrom`` from the release ``sites`` zarr.

    The trailing ``[:]`` reads the whole array eagerly, so the return value is
    an in-memory array rather than a lazy one.
    """
    sites = zarr.open(
        gcs.get_mapper(f'gs://vo_afun_release/v1.0/snp_genotypes/all/sites'),
        mode='r',
    )
    return sites[chrom]['variants/POS'][:]

In [11]:
def read_in_genotypes_positions(chrom, samples_idx, samplesets, posl, posu,
                                filter_dir = genomic_positions_site_filter_sc_data_cloud_zarr_dir):
    """Load filtered genotypes and positions for the half-open window ``[posl, posu)``.

    Parameters
    ----------
    chrom : str
        Chromosome key in the genotype, sites and site-filter hierarchies.
    samples_idx : sequence of int
        Column indices (axis 1) of the samples to keep, applied after site filtering.
    samplesets : iterable of str
        Sample sets whose genotypes are concatenated along the sample axis.
    posl, posu : int
        Window bounds; ``-1`` means "first position" / "one past the last position".
    filter_dir : str
        Location of the site filter to apply.

    Returns
    -------
    (gt, pos)
        Genotypes restricted to passing sites in the window and to ``samples_idx``,
        and the positions of those sites.
    """
    # Genotypes for all requested sample sets, and every site position of the chromosome.
    genotypes = allel.GenotypeDaskArray(
        concatenate_along_axis(sampleset_staging_dir, f"{chrom}/calldata/GT", samplesets)
    )
    pos = load_position(chrom)

    # Replace the -1 sentinels with bounds that cover the whole chromosome.
    if posu == -1:
        posu = pos.max() + 1
    if posl == -1:
        posl = pos.min()

    loc_filt = load_filter(chrom, filter_dir)

    # A site is kept when it passes the site filter and lies inside the window.
    in_window = (pos >= posl) & (pos < posu)
    keep = (loc_filt) & (in_window)

    genotypes = genotypes.compress(keep, axis=0)
    pos = pos[keep]

    # Finally restrict to the requested samples.
    genotypes = da.take(genotypes, samples_idx, axis=1)

    return genotypes, pos

In [12]:
def read_in_genotypes_at_doubleton_positions(chrom, doubleton_pos, samples_idx, samplesets,
                                filter_dir = genomic_positions_site_filter_sc_data_cloud_zarr_dir):
    """Load the genotypes of the selected samples at an explicit list of positions.

    Sites are selected purely by membership of their position in
    ``doubleton_pos``; no site filter is applied and ``filter_dir`` is not used.

    Returns the genotypes at those sites, restricted to ``samples_idx`` (axis 1).
    """
    genotypes = allel.GenotypeDaskArray(
        concatenate_along_axis(sampleset_staging_dir, f"{chrom}/calldata/GT", samplesets)
    )

    # Boolean mask over all sites: True where the position was requested.
    at_requested_sites = np.isin(load_position(chrom), doubleton_pos)

    genotypes = genotypes.compress((at_requested_sites), axis=0)

    return da.take(genotypes, samples_idx, axis=1)

In [13]:
def identify_doubletons_rf(chrom, samples_idx, posl, posu, samplesets):
    """Return the genotypes of one window ("reading frame") at sites carrying a doubleton.

    A site qualifies when any of alleles 0-3 is observed exactly twice among
    the selected samples. Window bounds are forwarded unchanged to
    ``read_in_genotypes_positions``; the positions it returns are discarded.
    """
    genotypes, _ = read_in_genotypes_positions(chrom, samples_idx, samplesets, posl, posu)

    # Allele counts per site (columns: alleles 0..3).
    allele_counts = genotypes.count_alleles(max_allele=3)

    # True for every site where at least one allele has a count of exactly 2.
    has_doubleton = (allele_counts == 2).any(axis=1)

    return genotypes.compress(has_doubleton, axis=0)
    
    

In [14]:
def identify_doubletons_chrom(chrom, chromname, samples_idx, samplesets, posmin, posmax, reading_frame_size):
    """Collect the genotypes at all doubleton sites in ``[posmin, posmax)``, window by window.

    The region is processed in consecutive windows of ``reading_frame_size``
    positions; each window is handled by ``identify_doubletons_rf`` and the
    per-window results are concatenated along the site axis.

    Parameters
    ----------
    chrom : str
        Chromosome key used to load the data.
    chromname : str
        Label of the region; not used by this function.
    samples_idx : sequence of int
        Sample indices to restrict the genotypes to.
    samplesets : iterable of str
        Sample sets to load.
    posmin : int
        Lower bound (inclusive). ``-1`` is passed through to the first window,
        where the reader interprets it as "first position".
    posmax : int
        Upper bound (exclusive). ``-1`` means "to the end of the chromosome".
    reading_frame_size : int
        Window width in positions.

    Returns
    -------
    The concatenated doubleton genotypes of all windows.
    """
    if posmax == -1:
        # Windows are half-open [start, end) (the reader keeps pos < end), so
        # the bound has to be one past the last position; otherwise the final
        # site of the chromosome would never be examined.
        posmax = load_position(chrom).max() + 1

    rf_start = posmin
    rf_end = min(rf_start + reading_frame_size, posmax)
    gt_doub = identify_doubletons_rf(chrom, samples_idx, rf_start, rf_end, samplesets)
    rf_start += reading_frame_size
    while rf_end < posmax:
        # Clamp the last window to the requested upper bound.
        rf_end = min(rf_start + reading_frame_size, posmax)
        gt_doub_rf = identify_doubletons_rf(chrom, samples_idx, rf_start, rf_end,
                                            samplesets=samplesets)
        gt_doub = gt_doub.concatenate([gt_doub_rf], axis=0)
        rf_start += reading_frame_size

    return gt_doub


In [15]:
def count_shared_doubletons(chrom, gt_doub, popdict, cohorts):
    """Count, for every pair of cohorts, how many doubletons the pair shares.

    A doubleton is an allele (0-3) observed exactly twice across all samples in
    ``gt_doub``. For each such allele/site the two copies are either both in one
    cohort (counted on the diagonal) or split one-and-one between two cohorts
    (counted in the upper triangle at ``[i, j]`` with ``i < j``, following the
    order of ``cohorts``). The lower triangle is left at zero.

    Parameters
    ----------
    chrom : str
        Not used by this function; kept for call compatibility.
    gt_doub
        Genotypes at candidate doubleton sites; must provide ``count_alleles``,
        ``compress`` and ``count_alleles_subpops``.
    popdict : dict
        Maps cohort name -> sample indices (columns of ``gt_doub``).
    cohorts : sequence of str
        Cohort names; defines the row/column order of the result.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(len(cohorts), len(cohorts))``.
    """
    # Allele counts over all samples, used to find which allele is the doubleton.
    ac = gt_doub.count_alleles(max_allele=3)
    print('finished counting alleles at doubleton sites')

    n_cohorts = len(cohorts)
    doubleton_counts = np.zeros((n_cohorts, n_cohorts))

    for allele in range(4):
        # Sites where this particular allele is the one seen exactly twice.
        doub = (ac[:, allele] == 2)
        gt_doub_allele = gt_doub.compress(doub, axis=0)
        print(f'determined {gt_doub_allele.shape[0]} sites where allele {allele} is doubleton')
        ac_subpop = gt_doub_allele.count_alleles_subpops(popdict, max_allele=3)

        # Every cohort gets its diagonal entry inside the same loop, so the
        # function also works for a single cohort (or none) instead of relying
        # on a loop variable surviving the loop.
        for r, cohort1 in enumerate(cohorts):
            counts1 = ac_subpop[cohort1][:, allele]
            # Both copies inside the same cohort.
            doubleton_counts[r, r] += (counts1 == 2).sum().compute()
            # One copy here, the other copy in a later cohort.
            for c, cohort2 in enumerate(cohorts[r + 1:], start=r + 1):
                counts2 = ac_subpop[cohort2][:, allele]
                doubleton_counts[r, c] += ((counts1 == 1) & (counts2 == 1)).sum().compute()
        print(f'allele {allele} done')

    return doubleton_counts
    
    

In [16]:
def compute_shared_doubletons(chrom, chromname, cohorts, meta, outdir, posmin=-1, posmax=-1, samplesets=samplesets,
                          reading_frame_size=15_000_000, sample_ids=None):
    """Compute the cohort-by-cohort doubleton sharing matrix for one region and save it as TSV.

    Parameters
    ----------
    chrom : str
        Chromosome key used to load genotypes.
    chromname : str
        Label of the region; used in progress messages and in the output file name.
    cohorts : sequence of str
        Values of ``meta.geographic_cohort`` to compare; defines the matrix order.
    meta : pandas.DataFrame
        Sample metadata. Its row index is used as the sample (column) index into
        the genotype arrays — assumes rows are in the same order as the
        concatenated genotype columns; TODO confirm.
    outdir : str
        Directory receiving ``doubletons_<chromname>.tsv``; created if missing.
    posmin, posmax : int
        Region bounds, half-open; ``-1`` means unbounded on that side.
    samplesets : iterable of str
        Sample sets to load (defaults to the module-level list at definition time).
    reading_frame_size : int
        Window width used while scanning the region.
    sample_ids : iterable of str, optional
        When given, the samples are selected by ``meta.sample_id`` membership.
        When omitted, the rows with ``meta.subset_3 == 'Y'`` are used.

    Returns
    -------
    numpy.ndarray
        The doubleton count matrix (also written to disk).

    Raises
    ------
    ValueError
        If ``sample_ids`` is omitted and ``meta`` has no ``subset_3`` column.
    """
    # Decide which samples take part before touching any genotype data.
    if sample_ids is not None:
        selected = meta.sample_id.isin(sample_ids)
    elif 'subset_3' in meta.columns:
        selected = meta.subset_3 == 'Y'
    else:
        raise ValueError(
            "meta has no 'subset_3' column to select samples from; "
            "pass sample_ids=... or supply metadata that contains it"
        )

    # Make sure the result can be written *before* the long computation starts,
    # so a missing directory cannot discard the finished matrix.
    out_path = f'{outdir}/doubletons_{chromname}.tsv'
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    samples_idx = meta.loc[selected].index
    # Re-index the subset 0..n-1 so cohort indices refer to columns of the
    # already sample-restricted genotype array.
    meta = meta.loc[samples_idx].reset_index(drop=True)
    popdict = {
        cohort: meta.loc[meta.geographic_cohort == cohort].index
        for cohort in cohorts
    }

    gt_doub = identify_doubletons_chrom(chrom, chromname, samples_idx, samplesets,
                                        posmin, posmax, reading_frame_size)
    print(f'identified {gt_doub.shape[0]} doubleton sites on {chromname}')
    doubleton_counts = count_shared_doubletons(chrom, gt_doub, popdict, cohorts)
    pd.DataFrame(doubleton_counts, index=cohorts, columns=cohorts).to_csv(out_path, sep='\t')

    return doubleton_counts
        
    

In [17]:
# Re-display the metadata columns available for sample selection.
meta.columns

Index(['sample_id', 'geographic_cohort', 'geographic_cohort_colour',
       'geographic_cohort_shape', 'PCA_cohort', 'PCA_cohort_colour',
       'mitochondrial_id', 'karyotype_3La', 'karyotype_3Ra', 'karyotype_3Rb',
       'karyotype_2Ra', 'karyotype_2Rh', 'median_coverage'],
      dtype='object')

In [18]:
# Sample IDs read from the first (headerless) column of the CSV.
# NOTE(review): `hic30` is not referenced by any later cell visible in this notebook.
hic30 = pd.read_csv('../../meta/hicov_30_samples.csv', header=None)
hic30 = hic30[0].values
hic30

array(['VBS24195', 'VBS24199', 'VBS24225', 'VBS24201', 'VBS24218',
       'VBS24240', 'VBS24202', 'VBS24233', 'VBS24196', 'VBS24242',
       'VBS24213', 'VBS24236', 'VBS24227', 'VBS24216', 'VBS24230',
       'VBS24232', 'VBS24234', 'VBS24231', 'VBS24239', 'VBS24200',
       'VBS24197', 'VBS24203', 'VBS24222', 'VBS24235', 'VBS24226',
       'VBS24204', 'VBS24243', 'VBS24238', 'VBS24228', 'VBS24229',
       'VBS17757', 'VBS17756', 'VBS17759', 'VBS17761', 'VBS17755',
       'VBS17735', 'VBS17733', 'VBS17766', 'VBS17762', 'VBS17753',
       'VBS17740', 'VBS17750', 'VBS17746', 'VBS17742', 'VBS17747',
       'VBS17749', 'VBS17764', 'VBS17765', 'VBS17763', 'VBS17737',
       'VBS17754', 'VBS17758', 'VBS17768', 'VBS17734', 'VBS17760',
       'VBS17752', 'VBS17732', 'VBS17743', 'VBS17745', 'VBS17751',
       'VBS17196', 'VBS17206', 'VBS17239', 'VBS17222', 'VBS17236',
       'VBS17235', 'VBS17214', 'VBS17205', 'VBS17220', 'VBS17203',
       'VBS17197', 'VBS17192', 'VBS17215', 'VBS17212', 'VBS172

### Set up dask cluster

In [25]:
# Connect to the dask gateway and show the clusters it currently reports.
# `gateway` is reused by the cluster-creation and shutdown cells below.
gateway = Gateway()
gateway.list_clusters()

[ClusterReport<name=dev.cc530628979a4bd7950435ac155f99f2, status=RUNNING>]

In [20]:
# `gateway` comes from the cell above (creation is left commented out here).
#gateway = Gateway()
# Derive the environment name passed to the workers from the active conda prefix.
# NOTE(review): component [5] of the '/'-split prefix is assumed to be the
# environment name — this depends on the directory layout; confirm.
conda_prefix = os.environ["CONDA_PREFIX"]
current_environment = 'global/'+conda_prefix.split('/')[5]
cluster = gateway.new_cluster(
    profile='standard', 
    conda_environment = current_environment,
)
cluster

VBox(children=(HTML(value='<h2>GatewayCluster</h2>'), HBox(children=(HTML(value='\n<div>\n<style scoped>\n    …

In [21]:
# Obtain a client for the new cluster.
client=cluster.get_client()

In [22]:
# Request 50 workers for the doubleton computations below.
cluster.scale(50)

In [23]:
# Geographic cohorts to compare. The order of this list is the row/column
# order of every doubleton matrix and of the TSV files written below.
cohorts = ['Ghana_Northern-Region', 'Benin_Atlantique-Dept', 'Ghana_Ashanti-Region', 
                   'Nigeria_Ogun-State', 'Cameroon_Adamawa', 'CAR_Ombella-MPoko',
                   'DRC_Haut-Uele', 'Uganda_Eastern-Region', 'Kenya_Western-Prov', 
                   'Kenya_Nyanza-Prov', 'Gabon_Haut-Ogooue', 'DRC_Kinshasa',
                   'Tanzania_Morogoro-Region', 'Mozambique_Cabo-Delgado', 
                   'Zambia_Eastern-Prov', 'Malawi_Southern-Region', 'Mozambique_Maputo']

In [26]:
# Region labelled '2R2': positions [29,000,000, 57,350,000) of 2RL.
# NOTE(review): compute_shared_doubletons selects samples via meta.subset_3,
# which is not among the meta.columns displayed earlier in this notebook —
# confirm which metadata was loaded when this output was produced.
df2r = compute_shared_doubletons('2RL', '2R2', cohorts, meta, 'doubletons/hic30/', posmin = 29_000_000, posmax=57_350_000)

identified 1079777 doubleton sites on 2R2
finished counting alleles at doubleton sites
determined 1121 sites where allele 0 is doubleton
allele 0 done
determined 398065 sites where allele 1 is doubleton
allele 1 done
determined 424807 sites where allele 2 is doubleton
allele 2 done
determined 277136 sites where allele 3 is doubleton
allele 3 done


In [27]:
# Total over all cells of the 2R2 matrix.
df2r.sum()

1101129.0

In [28]:
# Region labelled '2R1': from the first position of 2RL up to (not including) 29,000,000.
df2r1 = compute_shared_doubletons('2RL', '2R1', cohorts, meta, 'doubletons/hic30/', posmax = 29_000_000)

identified 1624993 doubleton sites on 2R1
finished counting alleles at doubleton sites
determined 1345 sites where allele 0 is doubleton
allele 0 done
determined 603563 sites where allele 1 is doubleton
allele 1 done
determined 642134 sites where allele 2 is doubleton
allele 2 done
determined 425578 sites where allele 3 is doubleton
allele 3 done


In [29]:
# Total over all cells of the 2R1 matrix.
df2r1.sum()

1672620.0

In [None]:
# Region labelled '2L': from position 57,350,000 of 2RL to the end of the chromosome.
# NOTE(review): the saved output of this cell stops after allele 0 and the cell
# has no execution count, so this run appears not to have completed.
df2l = compute_shared_doubletons('2RL', '2L', cohorts, meta, 'doubletons/hic30/', posmin = 57_350_000)

identified 2997380 doubleton sites on 2L
finished counting alleles at doubleton sites
determined 2321 sites where allele 0 is doubleton


In [None]:
# Total over all cells of the 2L matrix (no output was saved for this cell).
df2l.sum()

In [26]:
# Region labelled '3R': from the first position of 3RL up to (not including) 44,700,000.
df3r = compute_shared_doubletons('3RL', '3R', cohorts, meta, 'doubletons/hic30/', posmax = 44_700_000)

identified 2067676 doubleton sites on 3R
finished counting alleles at doubleton sites
determined 1645 sites where allele 0 is doubleton
allele 0 done
determined 765002 sites where allele 1 is doubleton
allele 1 done
determined 814761 sites where allele 2 is doubleton
allele 2 done
determined 537484 sites where allele 3 is doubleton
allele 3 done


In [27]:
# Total over all cells of the 3R matrix.
df3r.sum()

2118892.0

In [28]:
# Region labelled '3L': from position 44,700,000 of 3RL to the end of the chromosome.
df3l = compute_shared_doubletons('3RL', '3L', cohorts, meta, 'doubletons/hic30/', posmin = 44_700_000)

identified 1763527 doubleton sites on 3L
finished counting alleles at doubleton sites
determined 1654 sites where allele 0 is doubleton
allele 0 done
determined 653703 sites where allele 1 is doubleton
allele 1 done
determined 694537 sites where allele 2 is doubleton
allele 2 done
determined 459720 sites where allele 3 is doubleton
allele 3 done


In [29]:
# Total over all cells of the 3L matrix.
df3l.sum()

1809614.0

In [30]:
# Whole X chromosome (no position bounds given).
dfx = compute_shared_doubletons('X', 'X', cohorts, meta, 'doubletons/hic30/')

identified 1034913 doubleton sites on X
finished counting alleles at doubleton sites
determined 539 sites where allele 0 is doubleton
allele 0 done
determined 388542 sites where allele 1 is doubleton
allele 1 done
determined 414850 sites where allele 2 is doubleton
allele 2 done
determined 267477 sites where allele 3 is doubleton
allele 3 done


In [31]:
# Total over all cells of the X matrix.
dfx.sum()

1071408.0

In [32]:
# Release the cluster created above once all regions are done.
cluster.shutdown()

In [19]:
# Clean-up: connect to every cluster the gateway still reports and shut it down.
for report in gateway.list_clusters():
    gateway.connect(report.name).shutdown()