In [2]:
#Load modules
import zarr
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
import yaml
from pathlib import Path
import allel

from dask.distributed import Client
import dask
dask.config.set(**{'array.slicing.split_large_chunks': False}) # Silence large chunk warnings
import dask.array as da
from dask import delayed, compute
from dask_gateway import Gateway
import functools
import numcodecs
from fsspec.implementations.zip import ZipFileSystem
from collections.abc import Mapping
import gcsfs
import numba
import psutil
from humanize import naturalsize

import pickle
import platform

import traceback
import logging

from pyprojroot import here
from bokeh.plotting import *
import plotly.express as px
import plotly.graph_objects as go
from plotly.validators.scatter.marker import SymbolValidator

### Connect to Google Cloud Storage (GCS)

In [None]:
# Google Cloud Storage filesystem handle; used below via get_mapper() by
# is_gcloud, load_filter and load_position to open zarr stores in the bucket.
gcs = gcsfs.GCSFileSystem()

In [None]:
# Quick connectivity check: show the first few entries of the release bucket.
bucket_listing = gcs.ls('vo_afun_release_master_us_central1')
bucket_listing[:3]

### Set up data access

In [5]:
# Release data locations inside the GCS bucket.
production_root = Path('vo_afun_release_master_us_central1')
vo_afun_staging = production_root / 'v1.0'
sampleset_staging_dir = vo_afun_staging / 'snp_genotypes' / 'all'

# Site-filter zarr stores: decision-tree (dt) and static (sc) filters.
genomic_positions_site_filter_dt_data_cloud_zarr_dir = 'vo_afun_release_master_us_central1/v1.0/site_filters/dt_20200416/funestus'
genomic_positions_site_filter_sc_data_cloud_zarr_dir = 'vo_afun_release_master_us_central1/v1.0/site_filters/sc_20220908/funestus'

# The sample sets to analyse are listed in the repository's analysis config.
repo_clone_path = here()
release_config_path = repo_clone_path.joinpath('analysis', 'config.yml')
with release_config_path.open() as config_file:
    # BaseLoader keeps every scalar as a plain string (no type coercion).
    config = yaml.load(config_file, Loader=yaml.BaseLoader)
samplesets = config["sample_sets"]

In [None]:
# Sample metadata (supplementary table); path is relative to this notebook's directory.
metadata_csv = "../../metadata/supp1_tab2.csv"
meta = pd.read_csv(metadata_csv)
meta.columns

### Functions

In [7]:
def load_single_field(zarr_path, internal_path, sset, exclude_males=False, samples=None):
    """Load one array from a sample-set zarr store as a lazy dask array.

    Parameters
    ----------
    zarr_path : pathlib.Path
        Location of the sample-set zarr group (resolved to a store by is_gcloud).
    internal_path : str
        Path of the array inside the group, e.g. ``"<chrom>/calldata/GT"`` or
        ``"samples"``.
    sset, exclude_males, samples
        Accepted for call compatibility but currently unused.

    Returns
    -------
    dask.array.Array
        The array; 1-D arrays are reshaped to a single row ``(1, n)`` so they
        can be concatenated along axis 1 like multi-dimensional fields.
    """
    # Open read-only, as load_filter / load_position do: the release data must
    # never be written, and a mistyped path should fail loudly instead of
    # silently creating an empty group (which zarr.group would do).
    group = zarr.open_group(is_gcloud(zarr_path), mode='r')

    arr = da.from_zarr(group[internal_path])

    if arr.ndim == 1:
        # e.g. the "samples" array: make it (1, n) for concatenation on axis 1.
        arr = arr.reshape((1, -1))

    return arr

In [8]:
def concatenate_along_axis(base_dir, internal_path, req_samplesets):
    """Lazily join one field across sample sets along axis 1.

    Each sample set's array is loaded from ``base_dir / <sampleset>`` at
    ``internal_path`` (via load_single_field) and the pieces are concatenated
    in the order of ``req_samplesets``.
    """
    per_sampleset = []
    for sampleset in req_samplesets:
        per_sampleset.append(
            load_single_field(base_dir / sampleset, internal_path, sampleset))
    return da.concatenate(per_sampleset, axis=1)

In [9]:
def is_gcloud(path):
    """Turn ``path`` into something zarr can open.

    Returns an fsspec mapper from the global ``gcs`` filesystem when one has
    been created in this session; if ``gcs`` is not defined, falls back to the
    plain POSIX path string (i.e. the path is treated as local).
    Despite its name, this returns a store/path, not a boolean.
    """
    try:
        store = gcs.get_mapper(path.as_posix())
    except NameError:
        # No `gcs` object in this kernel -> use the path as a local one.
        store = path.as_posix()
    return store

In [10]:
def load_filter(chrom, filter_dir = genomic_positions_site_filter_sc_data_cloud_zarr_dir):
    """Lazily load the per-site ``filter_pass`` mask for one chromosome.

    Parameters
    ----------
    chrom : str
        Chromosome group inside the filter store (e.g. ``'X'``).
    filter_dir : str
        Bucket path of the site-filter zarr store (defaults to the sc filters).

    Returns
    -------
    dask.array.Array
        ``<chrom>/variants/filter_pass`` wrapped with ``da.from_zarr``.
    """
    site_filter_root = zarr.Group(gcs.get_mapper(filter_dir), read_only=True)
    chrom_group = site_filter_root[chrom]
    return da.from_zarr(chrom_group['variants/filter_pass'])

In [11]:
@functools.lru_cache(maxsize=1)
def load_position(chrom):
    """Return the genomic positions (``variants/POS``) of every site on ``chrom``.

    The positions are read eagerly from the release "sites" zarr store.
    This function is called once per reading-frame window for every cohort
    pair (via read_in_genotypes_positions), so the result for the most
    recently requested chromosome is cached instead of being re-downloaded
    each time. Because the cached array is shared between callers, it is
    marked read-only.

    NOTE(review): this reads from the ``vo_afun_release`` bucket, whereas the
    genotypes and site filters are read from
    ``vo_afun_release_master_us_central1`` -- presumably mirrors of the same
    release; confirm.
    """
    store = gcs.get_mapper('gs://vo_afun_release/v1.0/snp_genotypes/all/sites')
    root = zarr.open(store, mode='r')
    pos = root[chrom]['variants/POS'][:]
    pos.flags.writeable = False
    return pos

In [12]:
def read_in_genotypes_positions(chrom, samples1_idx, samples2_idx, samplesets, posl, posu,
                                filter_dir = genomic_positions_site_filter_sc_data_cloud_zarr_dir):
    """Load genotypes and positions for one window, split into two sample groups.

    Parameters
    ----------
    chrom : str
        Chromosome name in the genotype, sites and site-filter stores.
    samples1_idx, samples2_idx : array-like of int
        Indices along axis 1 (samples) of the concatenated genotype array.
    samplesets : list of str
        Sample sets whose ``<chrom>/calldata/GT`` arrays are concatenated.
    posl, posu : int
        Keep sites with ``posl <= POS < posu``; -1 means the first position /
        one past the last position on ``chrom``.
    filter_dir : str
        Site-filter zarr store; only sites whose ``filter_pass`` is True are kept.

    Returns
    -------
    gt1, gt2
        Lazy genotype arrays for each group, restricted to the kept sites.
    pos : numpy.ndarray
        Positions of the kept sites.
    """
    gt_d = concatenate_along_axis(sampleset_staging_dir, f"{chrom}/calldata/GT", samplesets)
    gt = allel.GenotypeDaskArray(gt_d)
    pos = load_position(chrom)

    if posu == -1:
        posu = pos.max() + 1
    if posl == -1:
        posl = pos.min()

    # Materialise the site filter once and combine it with the window in
    # NumPy. The same in-memory mask is then reused for genotypes and
    # positions (no hidden recompute), and the compressed genotype array
    # gets known chunk sizes.
    filter_pass = np.asarray(load_filter(chrom, filter_dir).compute())
    site_mask = filter_pass & (pos >= posl) & (pos < posu)

    gt = gt.compress(site_mask, axis=0)
    pos = pos[site_mask]

    # Subset to each group's samples.
    gt1 = da.take(gt, samples1_idx, axis=1)
    gt2 = da.take(gt, samples2_idx, axis=1)

    return gt1, gt2, pos

In [24]:
def compute_fst(chrom, samples1_idx, samples2_idx, posl, posu, missing_frac,
                samplesets):
    """Hudson's Fst components for two sample groups in the window [posl, posu).

    Parameters
    ----------
    chrom : str
        Chromosome name in the genotype / site-filter stores.
    samples1_idx, samples2_idx : array-like of int
        Indices of each group's samples along the sample axis of the
        concatenated genotype array.
    posl, posu : int
        Inclusive lower / exclusive upper position bound; -1 is passed through
        to read_in_genotypes_positions ("chromosome start / end").
    missing_frac : float
        Largest tolerated fraction of missing allele calls per group at a site.
    samplesets : list of str
        Sample sets whose genotypes are concatenated.

    Returns
    -------
    tuple
        ``(sum of Hudson numerators, sum of Hudson denominators,
        mean of per-site ratios, number of sites kept)``; with no site kept
        this is ``(0.0, 0.0, nan, 0)``.
    """
    gt1, gt2, pos = read_in_genotypes_positions(chrom, samples1_idx, samples2_idx, samplesets,
                                                posl=posl, posu=posu)

    # Allele counts (alleles 0..3) for both groups in ONE dask compute: the two
    # graphs share the genotype chunks, so the data are read once instead of
    # once per downstream use of the counts and of the missingness filter.
    ac1, ac2 = dask.compute(gt1.count_alleles(max_allele=3).values,
                            gt2.count_alleles(max_allele=3).values)

    # Missingness filter only: each group needs at least (1 - missing_frac) of
    # 2 allele calls per sample (i.e. diploid calls are assumed).
    missing_filter = ((ac1.sum(axis=1) >= (1 - missing_frac) * 2 * len(samples1_idx)) &
                      (ac2.sum(axis=1) >= (1 - missing_frac) * 2 * len(samples2_idx)))
    pos_f = pos[missing_filter]

    if len(pos_f) == 0:
        # Empty window: nothing to average.
        return 0.0, 0.0, np.nan, 0

    num, den = allel.hudson_fst(ac1[missing_filter], ac2[missing_filter])

    # Mean of per-site ratios: sites with den == 0 count as Fst 0; sites with a
    # NaN den are dropped because they satisfy neither comparison.
    informative = den > 0
    ratios = np.concatenate([num[informative] / den[informative],
                             np.zeros(np.sum(den == 0))])
    mor = np.nanmean(ratios) if ratios.size else np.nan

    return np.nansum(num), np.nansum(den), mor, len(pos_f)

In [27]:
def loop_through_reading_frames(chrom, samples1_idx, samples2_idx, posmax, posmin, samplesets,
                                reading_frame_size, missing_frac):
    """Hudson's Fst between two sample groups over positions [posmin, posmax].

    The region is split into consecutive windows ("reading frames") of
    ``reading_frame_size`` bp so that each compute_fst call stays small.
    Window numerators and denominators are summed (ratio-of-averages Fst).
    Window mean-of-ratios values are averaged, weighted by their site counts.

    NOTE: the bounds come in the order ``posmax, posmin``. Pass them by
    keyword so they cannot be swapped by accident.

    Parameters
    ----------
    chrom : str
        Chromosome name in the zarr stores.
    samples1_idx, samples2_idx : array-like of int
        Sample indices of the two groups (see compute_fst).
    posmax, posmin : int
        Inclusive region bounds in bp; -1 means the last / first known
        position on ``chrom``.
    samplesets : list of str
        Sample sets whose genotypes are loaded.
    reading_frame_size : int
        Window length in bp (must be positive).
    missing_frac : float
        Passed through to compute_fst.

    Returns
    -------
    tuple of float
        ``(ratio-of-averages Fst, site-weighted mean-of-ratios Fst)``; each is
        NaN when no window contributes.
    """
    if posmax == -1 or posmin == -1:
        positions = load_position(chrom)
        if posmax == -1:
            posmax = positions.max()
        if posmin == -1:
            posmin = positions.min()

    # compute_fst's upper bound is exclusive, so stop one past posmax to keep
    # the site at posmax; every window (including the last) is clamped to it.
    start = int(posmin)
    stop = int(posmax) + 1
    step = int(reading_frame_size)

    num = den = 0.0
    mor_weighted_sum = 0.0
    nsites = 0
    for win_start in range(start, stop, step):
        win_end = min(win_start + step, stop)
        num_rf, den_rf, mor_rf, nsites_rf = compute_fst(chrom, samples1_idx, samples2_idx,
                                                        win_start, win_end,
                                                        missing_frac=missing_frac,
                                                        samplesets=samplesets)
        num += num_rf
        den += den_rf
        # A window without usable sites reports mor = NaN with 0 sites; folding
        # it into the running mean would make the whole result NaN.
        if nsites_rf > 0 and not np.isnan(mor_rf):
            mor_weighted_sum += mor_rf * nsites_rf
            nsites += nsites_rf

    fst = num / den if den else np.nan
    mor = mor_weighted_sum / nsites if nsites else np.nan
    return fst, mor
        
        

In [29]:
def compute_fst_all_cohorts(chrom, chromname, cohorts, meta, outdir, posmin=-1, posmax=-1, samplesets=samplesets,
                          reading_frame_size=10_000_000, missing_frac=0.05):
    """Pairwise Hudson Fst between all cohorts on one chromosome (or region of it).

    Upper triangle (row = cohort earlier in ``cohorts``): ratio-of-averages Fst.
    Lower triangle: mean-of-ratios Fst. The matrix is rewritten to
    ``<outdir>/Fst_<chromname>.tsv`` after every pair, so partial results
    survive an interrupted run.

    Parameters
    ----------
    chrom : str
        Chromosome name in the zarr stores (e.g. ``'2RL'``).
    chromname : str
        Label for progress messages and the output file name (e.g. ``'2L'``).
    cohorts : list of str
        Values of ``meta.geographic_cohort`` to compare; also the matrix order.
    meta : pandas.DataFrame
        Sample metadata with ``geographic_cohort`` and ``subset_2`` columns.
        Its index values are used directly as positional sample indices into
        the genotype arrays, so it must follow the genotype sample order.
    outdir : str
        Existing directory for the TSV output.
    posmin, posmax : int
        Region bounds in bp; -1 means chromosome start / end.
    samplesets : list of str
        Sample sets to load.
    reading_frame_size : int
        Window size in bp used to split each computation.
    missing_frac : float
        Largest tolerated per-group fraction of missing allele calls per site.

    Returns
    -------
    pandas.DataFrame
        Cohort x cohort Fst matrix (diagonal left as NaN).
    """
    fst_matrix = pd.DataFrame(index=cohorts, columns=cohorts, dtype=float)
    out_path = os.path.join(outdir, f'Fst_{chromname}.tsv')

    # Sample indices per cohort (restricted to subset_2 == 'Y'), looked up once.
    in_subset = meta.subset_2 == 'Y'
    cohort_samples = {cohort: meta.loc[(meta.geographic_cohort == cohort) & in_subset].index
                      for cohort in cohorts}

    for n, cohort1 in enumerate(cohorts[:-1]):
        for cohort2 in cohorts[n+1:]:
            # Keywords on purpose: loop_through_reading_frames takes the bounds
            # as (posmax, posmin), so positional (posmin, posmax) swaps them.
            fst, mor = loop_through_reading_frames(
                chrom, cohort_samples[cohort1], cohort_samples[cohort2],
                posmax=posmax, posmin=posmin, samplesets=samplesets,
                reading_frame_size=reading_frame_size, missing_frac=missing_frac)
            fst_matrix.loc[cohort1, cohort2] = fst
            fst_matrix.loc[cohort2, cohort1] = mor
            print(f"On {chromname} cohort {cohort1} vs {cohort2} done")
            fst_matrix.to_csv(out_path, sep='\t')

    return fst_matrix
        
    

In [16]:
# Reorder the metadata rows to the sample order of the concatenated genotype
# arrays, so meta's positional index can be used directly as sample indices.
sample_order = concatenate_along_axis(sampleset_staging_dir, "samples", samplesets).compute()
sample_order = sample_order[0].astype(str)
meta = meta.set_index('sample_id').loc[sample_order].reset_index()
# Duplicated sample IDs in the metadata would silently shift every later index.
assert len(meta) == len(sample_order), "metadata must have exactly one row per genotyped sample"
meta.head()

Index(['sample_id', 'geographic_cohort', 'geographic_cohort_colour',
       'geographic_cohort_shape', 'PCA_cohort', 'PCA_cohort_colour',
       'mitochondrial_id', 'karyotype_3La', 'karyotype_3Ra', 'karyotype_3Rb',
       'karyotype_2Ra', 'karyotype_2Rh', 'median_coverage'],
      dtype='object')

### Set up Dask cluster

In [17]:
# Dask Gateway client; list any clusters that are already running
# (e.g. left over from an earlier session).
gateway = Gateway()
gateway.list_clusters()

[]

In [18]:
# Start a gateway cluster whose workers use this kernel's conda environment.
# NOTE(review): the environment name is taken from path component 5 of
# CONDA_PREFIX, which depends on the datalab's directory layout.
conda_prefix = os.environ["CONDA_PREFIX"]
current_environment = f"global/{conda_prefix.split('/')[5]}"
cluster = gateway.new_cluster(profile='standard', conda_environment=current_environment)
cluster

VBox(children=(HTML(value='<h2>GatewayCluster</h2>'), HBox(children=(HTML(value='\n<div>\n<style scoped>\n    …

In [19]:
# Client connected to the gateway cluster; subsequent dask computations use it.
client = cluster.get_client()

In [20]:
# Scale the cluster to a fixed number of workers.
n_workers = 50
cluster.scale(n_workers)

In [21]:
# Cohorts to compare; this order is also the row/column order of the Fst matrices.
cohorts_ordered = [
    'Ghana_Northern-Region',
    'Benin_Atlantique-Dept',
    'Ghana_Ashanti-Region',
    'Nigeria_Ogun-State',
    'Cameroon_Adamawa',
    'CAR_Ombella-MPoko',
    'DRC_Haut-Uele',
    'Uganda_Eastern-Region',
    'Kenya_Western-Prov',
    'Kenya_Nyanza-Prov',
    'Gabon_Haut-Ogooue',
    'DRC_Kinshasa',
    'Tanzania_Morogoro-Region',
    'Mozambique_Cabo-Delgado',
    'Zambia_Eastern-Prov',
    'Malawi_Southern-Region',
    'Mozambique_Maputo',
]

In [None]:
# Output directory for the Fst matrices; exist_ok makes this cell safe to re-run.
os.makedirs('results', exist_ok=True)

In [30]:
# Pairwise Fst matrix over the whole X chromosome (TSV written under results/).
fstx = compute_fst_all_cohorts(chrom='X', chromname='X', cohorts=cohorts_ordered,
                               meta=meta, outdir='results/')

On X cohort Ghana_Northern-Region vs Benin_Atlantique-Dept done
On X cohort Ghana_Northern-Region vs Ghana_Ashanti-Region done


Task exception was never retrieved
future: <Task finished name='Task-12684' coro=<Client._gather.<locals>.wait() done, defined at /home/conda/global/60d1d102327c6b963c52da1c248570d56dc2ebf066625a421ba924078e5d2fa7-20230228-170115-620472-19-binder-4.3.0/lib/python3.10/site-packages/distributed/client.py:2119> exception=AllExit()>
Traceback (most recent call last):
  File "/home/conda/global/60d1d102327c6b963c52da1c248570d56dc2ebf066625a421ba924078e5d2fa7-20230228-170115-620472-19-binder-4.3.0/lib/python3.10/site-packages/distributed/client.py", line 2128, in wait
    raise AllExit()
distributed.client.AllExit
Task exception was never retrieved
future: <Task finished name='Task-12685' coro=<Client._gather.<locals>.wait() done, defined at /home/conda/global/60d1d102327c6b963c52da1c248570d56dc2ebf066625a421ba924078e5d2fa7-20230228-170115-620472-19-binder-4.3.0/lib/python3.10/site-packages/distributed/client.py:2119> exception=AllExit()>
Traceback (most recent call last):
  File "/home/cond

On X cohort Ghana_Northern-Region vs Nigeria_Ogun-State done
On X cohort Ghana_Northern-Region vs Cameroon_Adamawa done
On X cohort Ghana_Northern-Region vs CAR_Ombella-MPoko done
On X cohort Ghana_Northern-Region vs DRC_Haut-Uele done
On X cohort Ghana_Northern-Region vs Uganda_Eastern-Region done
On X cohort Ghana_Northern-Region vs Kenya_Western-Prov done
On X cohort Ghana_Northern-Region vs Kenya_Nyanza-Prov done
On X cohort Ghana_Northern-Region vs Gabon_Haut-Ogooue done
On X cohort Ghana_Northern-Region vs DRC_Kinshasa done
On X cohort Ghana_Northern-Region vs Tanzania_Morogoro-Region done
On X cohort Ghana_Northern-Region vs Mozambique_Cabo-Delgado done
On X cohort Ghana_Northern-Region vs Zambia_Eastern-Prov done
On X cohort Ghana_Northern-Region vs Malawi_Southern-Region done
On X cohort Ghana_Northern-Region vs Mozambique_Maputo done
On X cohort Benin_Atlantique-Dept vs Ghana_Ashanti-Region done
On X cohort Benin_Atlantique-Dept vs Nigeria_Ogun-State done
On X cohort Benin_At

In [31]:
# Upper triangle: ratio-of-averages Fst; lower triangle: mean-of-ratios Fst.
fstx

Unnamed: 0,Ghana_Northern-Region,Benin_Atlantique-Dept,Ghana_Ashanti-Region,Nigeria_Ogun-State,Cameroon_Adamawa,CAR_Ombella-MPoko,DRC_Haut-Uele,Uganda_Eastern-Region,Kenya_Western-Prov,Kenya_Nyanza-Prov,Gabon_Haut-Ogooue,DRC_Kinshasa,Tanzania_Morogoro-Region,Mozambique_Cabo-Delgado,Zambia_Eastern-Prov,Malawi_Southern-Region,Mozambique_Maputo
Ghana_Northern-Region,,0.119498,0.09302,0.097453,0.097639,0.096285,0.099779,0.098446,0.100012,0.098088,0.252708,0.245524,0.302989,0.301853,0.299114,0.301987,0.329525
Benin_Atlantique-Dept,0.006401,,0.032116,0.036813,0.037268,0.034815,0.039908,0.03759,0.039199,0.037577,0.193935,0.186667,0.244524,0.243471,0.240809,0.243352,0.271892
Ghana_Ashanti-Region,0.00479,0.002584,,0.008025,0.010164,0.007801,0.013068,0.01054,0.012323,0.010589,0.166469,0.159388,0.217283,0.216249,0.213614,0.216211,0.244891
Nigeria_Ogun-State,0.005381,0.003147,0.000969,,0.014367,0.012047,0.017341,0.014846,0.016596,0.014861,0.170503,0.163438,0.221393,0.220412,0.217737,0.220328,0.249157
Cameroon_Adamawa,0.005514,0.003291,0.001207,0.001736,,0.013358,0.017078,0.015827,0.017214,0.015901,0.170194,0.163437,0.222071,0.220966,0.218118,0.220759,0.249434
CAR_Ombella-MPoko,0.004306,0.002407,0.000641,0.001163,0.00134,,0.007672,0.004719,0.007657,0.005455,0.16921,0.162019,0.218913,0.218466,0.215836,0.218501,0.247373
DRC_Haut-Uele,0.005173,0.00295,0.000846,0.001331,0.001488,0.000743,,0.003976,0.00623,0.00457,0.170728,0.163768,0.219752,0.219694,0.216835,0.219501,0.248259
Uganda_Eastern-Region,0.005187,0.002912,0.000778,0.001261,0.001449,0.000697,0.000271,,0.002359,0.000834,0.166212,0.15881,0.213426,0.213389,0.21065,0.213258,0.242304
Kenya_Western-Prov,0.004904,0.002823,0.000888,0.00139,0.00156,0.000554,0.000517,0.000304,,0.000732,0.148454,0.140531,0.188926,0.188968,0.186143,0.189105,0.219522
Kenya_Nyanza-Prov,0.005266,0.003,0.00087,0.00135,0.001544,0.000797,0.000377,0.000114,0.000175,,0.157574,0.150087,0.202002,0.202021,0.199326,0.201957,0.231556


In [32]:
# 2L is analysed as the part of the 2RL contig from this position onwards
# (presumably the 2R/2L boundary in this assembly -- confirm).
arm_2l_start = 57350000
fst2l = compute_fst_all_cohorts(chrom='2RL', chromname='2L', cohorts=cohorts_ordered,
                                meta=meta, outdir='results/', posmin=arm_2l_start)

On 2L cohort Ghana_Northern-Region vs Benin_Atlantique-Dept done
On 2L cohort Ghana_Northern-Region vs Ghana_Ashanti-Region done
On 2L cohort Ghana_Northern-Region vs Nigeria_Ogun-State done
On 2L cohort Ghana_Northern-Region vs Cameroon_Adamawa done
On 2L cohort Ghana_Northern-Region vs CAR_Ombella-MPoko done
On 2L cohort Ghana_Northern-Region vs DRC_Haut-Uele done
On 2L cohort Ghana_Northern-Region vs Uganda_Eastern-Region done
On 2L cohort Ghana_Northern-Region vs Kenya_Western-Prov done
On 2L cohort Ghana_Northern-Region vs Kenya_Nyanza-Prov done
On 2L cohort Ghana_Northern-Region vs Gabon_Haut-Ogooue done
On 2L cohort Ghana_Northern-Region vs DRC_Kinshasa done
On 2L cohort Ghana_Northern-Region vs Tanzania_Morogoro-Region done
On 2L cohort Ghana_Northern-Region vs Mozambique_Cabo-Delgado done
On 2L cohort Ghana_Northern-Region vs Zambia_Eastern-Prov done
On 2L cohort Ghana_Northern-Region vs Malawi_Southern-Region done
On 2L cohort Ghana_Northern-Region vs Mozambique_Maputo done
O

In [27]:
# Shut down this notebook's gateway cluster to release its workers.
cluster.shutdown()

In [16]:
# Shut down every cluster the gateway still lists (including ones from earlier sessions).
for cluster_report in gateway.list_clusters():
    gateway.connect(cluster_report.name).shutdown()