In [None]:
#Load modules
import zarr
import os
import re
import pandas as pd
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
import yaml
from pathlib import Path
import allel
import malariagen_data

from dask.distributed import Client
import dask
dask.config.set(**{'array.slicing.split_large_chunks': False}) # Silence large chunk warnings
import dask.array as da
from dask import delayed, compute
from dask_gateway import Gateway
import functools
import numcodecs
from fsspec.implementations.zip import ZipFileSystem
from collections.abc import Mapping
import gcsfs
import numba
import psutil
from humanize import naturalsize

import pickle
import platform

import traceback
import logging

from pyprojroot import here
from bokeh.plotting import *
import plotly.express as px
import plotly.graph_objects as go
from plotly.validators.scatter.marker import SymbolValidator

In [3]:
# Access the data from the cloud through the MalariaGEN Af1 API client.
# `af1` is used further down for sample-set listing, sample metadata and
# the list of valid contigs.
af1 = malariagen_data.Af1()
af1

MalariaGEN Af1 API client,MalariaGEN Af1 API client
"Please note that data are subject to terms of use,  for more information see the MalariaGEN website or contact data@malariagen.net.","Please note that data are subject to terms of use,  for more information see the MalariaGEN website or contact data@malariagen.net..1"
Storage URL,gs://vo_afun_release/
Data releases available,1.0
Results cache,
Cohorts analysis,20221129
Site filters analysis,dt_20200416
Software version,malariagen_data 7.13.0
Client location,unknown


### Connect to GCS

In [None]:
# Google Cloud Storage filesystem handle; used below to list the release
# bucket and to build the zarr store mappers.
gcs = gcsfs.GCSFileSystem()

In [None]:
# Quick access check: show the first three entries of the release bucket.
gcs.ls('vo_afun_release_master_us_central1')[:3]

### Set up data access

In [5]:
# Bucket layout of the release, expressed relative to the bucket name.
production_root = Path('vo_afun_release_master_us_central1')
vo_afun_staging = production_root / 'v1.0'
sampleset_staging_dir = vo_afun_staging / 'snp_genotypes' / 'all'
haplotypes_dir = vo_afun_staging / 'snp_haplotypes'

# Site filter locations: decision tree or static filters.
genomic_positions_site_filter_dt_data_cloud_zarr_dir = 'vo_afun_release_master_us_central1/v1.0/site_filters/dt_20200416/funestus'
genomic_positions_site_filter_sc_data_cloud_zarr_dir = 'vo_afun_release_master_us_central1/v1.0/site_filters/sc_20220908/funestus'

# Release configuration lives in the repository clone, under analysis/config.yml.
repo_clone_path = here()
release_config_path = repo_clone_path.joinpath('analysis', 'config.yml')

# BaseLoader keeps every scalar as a plain string and builds no Python objects.
with release_config_path.open() as fh:
    config = yaml.load(fh, Loader=yaml.BaseLoader)


In [None]:
# Sample metadata table. `metadata_path` is also re-read inside
# `load_haplotypes`, and the `geographic_cohort` column of `metadata`
# provides the list of cohorts analysed below.
# NOTE(review): the path is relative to the notebook's working directory.
metadata_path = "../../metadata/supp1_tab2.csv"
metadata = pd.read_csv(metadata_path)
metadata.columns

### Settings

In [14]:

# `analysis` selects the sub-directory of the haplotype zarr stores that is opened.
analysis = 'funestus'
# The position of a chromosome in this list is the integer stored in the
# `variant_contig` coordinate of the haplotype datasets.
chromosomes = ['2RL', '3RL', 'X']

In [16]:
# extract the sample sets
# NOTE(review): "1.0" is the only release the Af1 client reports as
# available; confirm this is the intended release before re-running.
release = "1.0"
sample_sets = af1.sample_sets(release=release)['sample_set'].tolist()
sample_sets

['1229-VO-GH-DADZIE-VMF00095',
 '1230-VO-GA-CF-AYALA-VMF00045',
 '1231-VO-MULTI-WONDJI-VMF00043',
 '1232-VO-KE-OCHOMO-VMF00044',
 '1235-VO-MZ-PAAIJMANS-VMF00094',
 '1236-VO-TZ-OKUMU-VMF00090',
 '1240-VO-CD-KOEKEMOER-VMF00099',
 '1240-VO-MZ-KOEKEMOER-VMF00101']

In [18]:
# Candidate window sizes (number of SNPs per window) tried during H12 calibration.
window_sizes=(100, 200, 300, 500, 700, 1000, 2000, 3000, 4000)
# Distinct values of the `geographic_cohort` metadata column; every cohort is analysed in turn.
geographic_cohorts = list(metadata.geographic_cohort.unique())

### Connect to the cluster

In [21]:
# Dask Gateway handle; used below to list existing clusters and to start a new one.
gateway = Gateway()

In [None]:
# Check whether any cluster is currently running before starting a new one.
gateway.list_clusters()

In [None]:
# Start a new cluster whose workers use the same conda environment as this kernel.
# Requires the `gateway` handle created above.
conda_prefix = os.environ["CONDA_PREFIX"]
# NOTE(review): assumes the environment name is the sixth '/'-separated
# component of CONDA_PREFIX — this depends on the host's directory layout; confirm.
current_environment = 'global/'+conda_prefix.split('/')[5]
cluster = gateway.new_cluster(
    profile='standard', 
    conda_environment = current_environment,
)
cluster

In [27]:
# Scale the cluster to 50 workers and attach a client to it.
cluster.scale(50)
client = cluster.get_client()

### Functions

In [28]:
def da_from_zarr(z, inline_array, chunks="auto"):
    """Wrap a zarr array in a dask array.

    Parameters
    ----------
    z : zarr array
        Source array.
    inline_array : bool
        Forwarded to ``dask.array.from_array`` when that version accepts it.
    chunks : str or tuple, default "auto"
        Dask chunking. ``"native"`` reuses the zarr array's own chunks.

    Returns
    -------
    dask.array.Array
    """
    # dask cannot pick "auto" chunks for object dtype, so fall back to the
    # zarr chunking in that case as well as when it is requested explicitly.
    wants_native_chunks = chunks == "native" or z.dtype == object
    if wants_native_chunks:
        chunks = z.chunks

    try:
        return da.from_array(
            z, chunks=chunks, fancy=False, lock=False, inline_array=inline_array
        )
    except TypeError:
        # older dask releases do not know the inline_array argument
        return da.from_array(z, chunks=chunks, fancy=False, lock=False)

In [29]:
def xarray_concat(datasets, dim, data_vars="minimal", coords="minimal", compat="override", join="override", **kwargs):
    """Concatenate datasets along `dim`, returning a lone dataset untouched.

    The keyword defaults are passed straight to ``xr.concat``; any extra
    keyword arguments are forwarded as well.
    """
    if len(datasets) != 1:
        return xr.concat(
            datasets,
            dim=dim,
            data_vars=data_vars,
            coords=coords,
            compat=compat,
            join=join,
            **kwargs,
        )
    # nothing to concatenate
    return datasets[0]

In [30]:
def load_position(chrom):
    """Read the whole `variants/POS` array of one chromosome from the SNP sites zarr."""
    sites_store = gcs.get_mapper(
        'gs://vo_afun_release_master_us_central1/v1.0/snp_genotypes/all/sites'
    )
    sites_root = zarr.open(sites_store, mode='r')
    # the trailing [:] pulls the full array into memory
    return sites_root[chrom]['variants/POS'][:]

In [31]:
def open_haplotypes(sample_set, analysis=analysis):
    """Open the consolidated haplotype zarr group of one sample set.

    The store is located at ``{haplotypes_dir}/{sample_set}/{analysis}/zarr``.
    """
    haplotype_store = gcs.get_mapper(f"{haplotypes_dir}/{sample_set}/{analysis}/zarr")
    return zarr.open_consolidated(store=haplotype_store)

In [32]:
def open_haplotype_sites(analysis=analysis):
    """Open the consolidated zarr group holding the phased sites.

    The store is located at ``{haplotypes_dir}/sites/{analysis}/zarr``.
    """
    sites_store = gcs.get_mapper(f"{haplotypes_dir}/sites/{analysis}/zarr")
    return zarr.open_consolidated(store=sites_store)

In [33]:
def haplotypes_dataset(*, chrom, sample_set, analysis, inline_array, chunks):
    """Assemble the haplotype dataset of one sample set on one chromosome.

    Parameters
    ----------
    chrom : str
        Chromosome name; must be an entry of the module-level `chromosomes` list.
    sample_set : str
        Sample set whose haplotype zarr is opened.
    analysis : str
        Analysis name used to locate both the haplotype and the sites zarr.
    inline_array, chunks
        Forwarded to `da_from_zarr` for every array that is wrapped.

    Returns
    -------
    xarray.Dataset or None
        None when `open_haplotypes` yields no group for the sample set.
    """
    hap_root = open_haplotypes(sample_set=sample_set, analysis=analysis)
    site_root = open_haplotype_sites(analysis=analysis)

    # some sample sets have no data for a given analysis
    if hap_root is None:
        return None

    def as_dask(zarr_array):
        return da_from_zarr(zarr_array, inline_array=inline_array, chunks=chunks)

    # variant coordinates
    pos_zarr = site_root[f"{chrom}/variants/POS"]
    variant_position = as_dask(pos_zarr)
    contig_code = chromosomes.index(chrom)
    variant_contig = da.full_like(pos_zarr, fill_value=contig_code, dtype="u1")

    # REF and ALT side by side, one row per variant
    ref_allele = as_dask(site_root[f"{chrom}/variants/REF"])
    alt_allele = as_dask(site_root[f"{chrom}/variants/ALT"])
    variant_allele = da.hstack([ref_allele[:, None], alt_allele[:, None]])

    # per-sample arrays
    call_genotype = as_dask(hap_root[f"{chrom}/calldata/GT"])
    sample_id = as_dask(hap_root["samples"])

    return xr.Dataset(
        data_vars={
            "variant_allele": (["variants", "alleles"], variant_allele),
            "call_genotype": (["variants", "samples", "ploidy"], call_genotype),
        },
        coords={
            "variant_position": (["variants"], variant_position),
            "variant_contig": (["variants"], variant_contig),
            "sample_id": (["samples"], sample_id),
        },
        attrs={"contigs": chrom},
    )

In [34]:
def haplotypes(chrom, analysis=analysis, sample_sets=None, sample_query=None, inline_array=True,
    chunks="native", cohort_size=None, random_seed=42):
    """Access haplotype data.

    Parameters
    ----------
    chrom : str or sequence of str
        One chromosome name, or several to be concatenated along the
        ``variants`` dimension in the order given.
    analysis : str
        Analysis name used to locate the haplotype zarr stores.
    sample_sets : str or sequence of str
        Sample set identifier(s) to load. Required.
    sample_query : str, optional
        Pandas query evaluated against the sample metadata (aligned to the
        phased samples) to select a subset of samples.
    inline_array, chunks
        Forwarded to `haplotypes_dataset`.
    cohort_size : int, optional
        When given, randomly downsample to exactly this many samples.
    random_seed : int
        Seed for the downsampling.

    Returns
    -------
    ds : xarray.Dataset or None
        A dataset of haplotypes and associated data, or None when none of
        the sample sets has data for a requested chromosome.

    Raises
    ------
    ValueError
        If no chromosome or sample set is given, if `sample_query` matches
        no sample, or if fewer than `cohort_size` samples are available.
    """
    # A bare string is ONE chromosome. Iterating over it would walk its
    # characters and load the same chromosome once per character.
    contigs = [chrom] if isinstance(chrom, str) else list(chrom)
    if len(contigs) == 0:
        raise ValueError("at least one chromosome must be given")

    if sample_sets is None:
        raise ValueError("sample_sets must be given")
    if isinstance(sample_sets, str):
        sample_sets = [sample_sets]
    else:
        sample_sets = list(sample_sets)

    # build datasets
    lx = []
    for contig in contigs:
        ly = []

        for s in sample_sets:
            y = haplotypes_dataset(chrom=contig, sample_set=s, analysis=analysis,
                inline_array=inline_array, chunks=chunks,)

            if y is not None:
                ly.append(y)

        if len(ly) == 0:
            # early out, no data for given sample sets and analysis
            return None

        # concatenate data from multiple sample sets
        lx.append(xarray_concat(ly, dim="samples"))

    # concatenate data from multiple chromosomes
    ds = xarray_concat(lx, dim="variants")

    # handle sample query
    if sample_query is not None:

        # load sample metadata
        df_samples = af1.sample_metadata(sample_sets=sample_sets)

        # align sample metadata with haplotypes
        phased_samples = ds["sample_id"].values.tolist()
        df_samples_phased = (
            df_samples.set_index("sample_id").loc[phased_samples].reset_index()
        )

        # apply the query
        loc_samples = df_samples_phased.eval(sample_query).values
        if np.count_nonzero(loc_samples) == 0:
            raise ValueError(f"No samples found for query {sample_query!r}")
        ds = ds.isel(samples=loc_samples)

    # handle cohort size
    if cohort_size is not None:
        # `sizes` is the mapping from dimension name to length
        n_samples = ds.sizes["samples"]
        if n_samples < cohort_size:
            raise ValueError(
                f"not enough samples ({n_samples}) for cohort size ({cohort_size})"
            )
        rng = np.random.default_rng(seed=random_seed)
        loc_downsample = rng.choice(n_samples, size=cohort_size, replace=False)
        loc_downsample.sort()
        ds = ds.isel(samples=loc_downsample)

    return ds

In [35]:
# Haplotype loader shared by the H12 calibration and the H12 scan.
# Only single chromosome arms listed in `af1.contigs` are supported.
# Not wrapped in functools.lru_cache: the default `sample_sets` is a list,
# which is not hashable.
def load_haplotypes(chrom, analysis, cohort_query, cohort_set, sample_sets=sample_sets, downsample=30, seed=42):
    """Load the haplotypes of one cohort on one chromosome arm.

    Parameters
    ----------
    chrom : str
        Chromosome arm; must be one of `af1.contigs`.
    analysis : str
        Analysis name used to locate the haplotype data.
    cohort_query : str
        Value of the `cohort_set` metadata column that defines the cohort.
    cohort_set : str
        Name of the metadata column holding the cohort labels.
    sample_sets : sequence of str
        Sample sets to load haplotypes from.
    downsample : int
        Cohorts with at least this many samples are randomly reduced to
        exactly this many; smaller cohorts are used in full.
    seed : int
        Seed for the downsampling.

    Returns
    -------
    pos : numpy.ndarray
        Variant positions.
    ht_cohort : allel.HaplotypeArray
        Haplotypes of the selected samples, loaded into memory.

    Raises
    ------
    ValueError
        If `chrom` is not a known contig, no haplotype data is available,
        or no sample belongs to the requested cohort.
    """
    if chrom not in af1.contigs:
        raise ValueError(
            f"unknown chromosome {chrom!r}; expected one of {list(af1.contigs)}"
        )

    print("accessing single chromosome arm for cohort " + cohort_query)

    # access haplotypes
    ds_haps = haplotypes(chrom=chrom, sample_sets=sample_sets, analysis=analysis)
    if ds_haps is None:
        raise ValueError(f"no haplotype data available for analysis {analysis!r}")
    pos = ds_haps["variant_position"].values
    gt = allel.GenotypeDaskArray(ds_haps['call_genotype'].data)

    # access sample metadata and align it to the order of the phased samples
    cohorts_meta_df = pd.read_csv(metadata_path)
    samples_phased = ds_haps['sample_id'].values  # dask computation happens here
    cohorts_meta_df_phased = cohorts_meta_df.set_index("VBS_sample_id").loc[samples_phased].reset_index()

    # after reset_index the row labels are the sample positions in the
    # genotype array, so the matching labels can be used directly with take()
    in_cohort = cohorts_meta_df_phased[cohort_set] == cohort_query
    cohort_index = cohorts_meta_df_phased.loc[in_cohort].index.values
    if len(cohort_index) == 0:
        raise ValueError(f"No samples found for {cohort_set} == {cohort_query!r}")

    # downsample; a private RandomState draws the same numbers as
    # np.random.seed(seed) + np.random.choice without touching global state
    if len(cohort_index) >= downsample:
        rng = np.random.RandomState(seed)
        cohort_index = rng.choice(cohort_index, size=downsample, replace=False)
        cohort_index.sort()

    gt_cohort = gt.take(cohort_index, axis=1)
    ht_cohort = gt_cohort.to_haplotypes().compute()  # dask computation happens here

    return pos, ht_cohort

In [39]:
# @functools.lru_cache(maxsize=None)
def h12_calibration(chrom, analysis, cohort_query, cohort_set, window_sizes, 
                    sample_sets, downsample=30, seed=42):
    """Compute windowed H12 for one cohort at each candidate window size.

    Returns a list with one H12 array per entry of `window_sizes`, in the
    same order.
    """
    print("load haplotypes")
    _, cohort_haplotypes = load_haplotypes(
        chrom=chrom,
        analysis=analysis,
        cohort_query=cohort_query,
        cohort_set=cohort_set,
        sample_sets=sample_sets,
        downsample=downsample,
        seed=seed,
    )

    print("starting calibration runs")
    h12_by_window = []
    for size in window_sizes:
        print(f"compute H12 at window size {size}")
        # moving_garud_h returns (H1, H12, H123, H2/H1); only H12 is kept
        garud_stats = allel.moving_garud_h(cohort_haplotypes, size=size)
        h12_by_window.append(garud_stats[1])

    return h12_by_window

In [40]:
def plot_h12_calibration(chrom, analysis, cohort_query, cohort_set, window_sizes, sample_sets, downsample=30, seed=42,title=None):
    """Plot the spread of H12 against window size for one cohort and save it.

    The figure shows the median, the 25-75% band and the 5-95% band of the
    windowed H12 values at every window size. It is written to the
    ``plots`` directory (created if missing) and then closed.

    Parameters
    ----------
    chrom, analysis, cohort_query, cohort_set, sample_sets, downsample, seed
        Forwarded to `h12_calibration`.
    window_sizes : sequence of int
        Window sizes (number of SNPs) to evaluate; also used as x ticks.
    title : str, optional
        Figure title.
    """
    # get H12 values
    calibration_runs = h12_calibration(chrom=chrom, analysis=analysis, cohort_query=cohort_query, cohort_set=cohort_set,
        sample_sets=sample_sets, window_sizes=window_sizes, downsample=downsample, seed=seed)

    # compute summaries
    q50 = [np.median(h12) for h12 in calibration_runs]
    q25 = [np.percentile(h12, 25) for h12 in calibration_runs]
    q75 = [np.percentile(h12, 75) for h12 in calibration_runs]
    q05 = [np.percentile(h12, 5) for h12 in calibration_runs]
    q95 = [np.percentile(h12, 95) for h12 in calibration_runs]

    # make a plot
    fig, ax = plt.subplots()
    x = window_sizes
    ax.grid()
    ax.fill_between(x, q05, q95, color='#bbbbff', label="5-95%")
    ax.fill_between(x, q25, q75, color='#7777ff', label="25-75%")
    ax.plot(x, q50, color='k', lw=2, linestyle="-", marker="o", label="median")
    ax.set_xscale("log")
    ax.set_xticks(window_sizes)
    ax.set_xticklabels(window_sizes)
    ax.set_xlabel("Window size (no. SNPs)")
    ax.set_ylabel("H12")
    ax.legend()
    if title:
        ax.set_title(title)

    # savefig does not create missing directories
    os.makedirs("plots", exist_ok=True)
    if cohort_set=='karyotype_3Ra':
        # these cohort labels are sanitized before being used in a file name
        sanitized_cohort_query = re.sub(r'\W+', '_', cohort_query)
        fig.savefig(f"plots/{chrom}_{sanitized_cohort_query}_h12_calibration")
    else:
        fig.savefig(f"plots/{chrom}_{cohort_query}_h12_calibration")
    # close this figure explicitly so figures do not pile up across the cohort loop
    plt.close(fig)

In [41]:
@functools.lru_cache(maxsize=None)
def h12_gwss(chrom, analysis, cohort_query, cohort_set, window_size, sample_sets, downsample=30, seed=42):
    """Run a windowed H12 scan for one cohort on one chromosome.

    Results are memoised, so every argument must be hashable (pass
    `sample_sets` as a tuple).

    Returns
    -------
    x : array
        Mean variant position of each window.
    h12 : array
        H12 value of each window.
    """
    print("load haplotypes")
    positions, cohort_haplotypes = load_haplotypes(
        chrom=chrom,
        analysis=analysis,
        cohort_query=cohort_query,
        cohort_set=cohort_set,
        sample_sets=sample_sets,
        downsample=downsample,
        seed=seed,
    )

    print(f"compute H12 at window size {window_size}")
    # moving_garud_h returns (H1, H12, H123, H2/H1); only H12 is kept
    h12 = allel.moving_garud_h(cohort_haplotypes, size=window_size)[1]

    print("compute window coordinates")
    window_mean_pos = allel.moving_statistic(positions, statistic=np.mean, size=window_size)
    return window_mean_pos, h12

### Run H12 calibration

##### X

In [None]:
#get optimal window sizes for Funestus cohorts
chrom = 'X'
# iterate over the cohort list built from the metadata (`geographic_cohorts`)
for cohort_query in geographic_cohorts:

    plot_h12_calibration(chrom=chrom, analysis=analysis, cohort_query=cohort_query, sample_sets=tuple(sample_sets),
                         cohort_set="geographic_cohort", downsample=20, window_sizes=window_sizes, title=cohort_query)

##### 2RL

In [None]:
# Calibration plots on 2RL: one figure per geographic cohort.
chrom = '2RL'
for cohort_query in geographic_cohorts:
    plot_h12_calibration(
        chrom=chrom,
        analysis=analysis,
        cohort_query=cohort_query,
        sample_sets=tuple(sample_sets),
        cohort_set="geographic_cohort",
        downsample=20,
        window_sizes=window_sizes,
        title=cohort_query,
    )

##### 3RL

In [None]:
#get optimal window sizes for Funestus cohorts
chrom = '3RL'
# iterate over the cohort list built from the metadata (`geographic_cohorts`)
# NOTE(review): downsample is 30 here but 20 for X and 2RL — confirm this is intended.
for cohort_query in geographic_cohorts:

    plot_h12_calibration(chrom=chrom, analysis=analysis, cohort_query=cohort_query, sample_sets=tuple(sample_sets),
                         cohort_set="geographic_cohort", downsample=30, window_sizes=window_sizes, title=cohort_query)

In [28]:
# Window size (number of SNPs) to use when computing H12, keyed first by
# chromosome and then by geographic cohort label.
# NOTE(review): presumably read off the calibration plots produced above — confirm.

funestus_windows = {'2RL': {'Ghana_Northern-Region': 4000, 'Gabon_Haut-Ogooue': 4000, 'CAR_Ombella-M-Poko': 1000,
 'Cameroon_Adamawa': 1000, 'Ghana_Ashanti-Region': 2000, 'Malawi_Southern-Region': 4000, 'Mozambique_Maputo': 4000,
 'Uganda_Eastern-Region': 2000, 'Benin_Atlantique-Dept': 3000, 'DRC_Kinshasa': 4000, 'Nigeria_Ogun-State': 1000,
 'Zambia_Eastern-Prov': 4000, 'Kenya_Nyanza-Prov': 1000, 'Kenya_Western-Prov': 2000, 'Tanzania_Morogoro-Region': 4000,
 'DRC_Haut-Uele': 1000, 'Mozambique_Cabo-Delgado': 4000},
                    
 '3RL': {'Ghana_Northern-Region': 4000, 'Gabon_Haut-Ogooue': 4000, 'CAR_Ombella-M-Poko': 2000, 'Cameroon_Adamawa': 2000,
 'Ghana_Ashanti-Region': 2000, 'Malawi_Southern-Region': 4000, 'Mozambique_Maputo': 4000, 'Uganda_Eastern-Region': 2000,
 'Benin_Atlantique-Dept': 4000, 'DRC_Kinshasa': 4000, 'Nigeria_Ogun-State': 4000, 'Zambia_Eastern-Prov': 3000,
 'Kenya_Nyanza-Prov': 3000, 'Kenya_Western-Prov': 3000, 'Tanzania_Morogoro-Region': 4000, 'DRC_Haut-Uele': 1000,
 'Mozambique_Cabo-Delgado': 4000},
                    
 'X': {'Ghana_Northern-Region': 4000, 'Gabon_Haut-Ogooue': 4000, 'CAR_Ombella-M-Poko': 1000, 'Cameroon_Adamawa': 1000,
 'Ghana_Ashanti-Region': 1000, 'Malawi_Southern-Region': 4000, 'Mozambique_Maputo': 4000, 'Uganda_Eastern-Region': 1000,
 'Benin_Atlantique-Dept': 3000, 'DRC_Kinshasa': 4000, 'Nigeria_Ogun-State': 1000, 'Zambia_Eastern-Prov': 4000,
 'Kenya_Nyanza-Prov': 1000, 'Kenya_Western-Prov': 1000, 'Tanzania_Morogoro-Region': 4000, 'DRC_Haut-Uele': 1000,
 'Mozambique_Cabo-Delgado': 4000},
}

### Run H12 computation

##### X

In [None]:
# H12 scan on X for every geographic cohort, using the per-cohort window size.
chrom = 'X'
windowed_selection = dict()
for cohort_query in geographic_cohorts:
    pos, h12 = h12_gwss(
        chrom=chrom,
        analysis=analysis,
        cohort_query=cohort_query,
        sample_sets=tuple(sample_sets),
        cohort_set='geographic_cohort',
        window_size=funestus_windows[chrom][cohort_query],
    )
    # stored as (H12 values, window positions)
    windowed_selection[cohort_query] = (h12, pos)

# the dict is saved as a pickled object array
np.save(f'windowed_H12_geo_cohort_{chrom}.npy', windowed_selection)

##### 2RL

In [None]:
# H12 scan on 2RL for every geographic cohort, using the per-cohort window size.
chrom = '2RL'
windowed_selection = dict()
for cohort_query in geographic_cohorts:
    pos, h12 = h12_gwss(
        chrom=chrom,
        analysis=analysis,
        cohort_query=cohort_query,
        sample_sets=tuple(sample_sets),
        cohort_set='geographic_cohort',
        window_size=funestus_windows[chrom][cohort_query],
    )
    # stored as (H12 values, window positions)
    windowed_selection[cohort_query] = (h12, pos)

# the dict is saved as a pickled object array
np.save(f'windowed_H12_geo_cohort_{chrom}.npy', windowed_selection)

##### 3RL

In [None]:
# H12 scan on 3RL for every geographic cohort, using the per-cohort window size.
chrom = '3RL'
windowed_selection = dict()
for cohort_query in geographic_cohorts:
    pos, h12 = h12_gwss(
        chrom=chrom,
        analysis=analysis,
        cohort_query=cohort_query,
        sample_sets=tuple(sample_sets),
        cohort_set='geographic_cohort',
        window_size=funestus_windows[chrom][cohort_query],
    )
    # stored as (H12 values, window positions)
    windowed_selection[cohort_query] = (h12, pos)

# the dict is saved as a pickled object array
np.save(f'windowed_H12_geo_cohort_{chrom}.npy', windowed_selection)