In [1]:
#Load modules
import zarr
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
import yaml
from pathlib import Path
import allel

from dask.distributed import Client
import dask
dask.config.set(**{'array.slicing.split_large_chunks': False}) # Silence large chunk warnings
import dask.array as da
from dask import delayed, compute
from dask_gateway import Gateway
import functools
import numcodecs
from fsspec.implementations.zip import ZipFileSystem
from collections.abc import Mapping
import gcsfs
import numba
import psutil
from humanize import naturalsize

import pickle
import platform

import traceback
import logging

from pyprojroot import here
from bokeh.plotting import *
import plotly.express as px
import plotly.graph_objects as go
from plotly.validators.scatter.marker import SymbolValidator

### Connect to Google Cloud Storage (GCS)

In [None]:
# File-system handle for Google Cloud Storage, created with gcsfs defaults.
# The loader functions defined later in this notebook read this module-level name.
gcs = gcsfs.GCSFileSystem()

In [None]:
# Show the first three entries of the release bucket (presumably a quick access check).
gcs.ls('vo_afun_release_master_us_central1')[:3]

### Set up data access

In [5]:
# Bucket-relative locations of the genotype data.
production_root = Path('vo_afun_release_master_us_central1')
vo_afun_staging = production_root / 'v1.0'
sampleset_staging_dir = vo_afun_staging / 'snp_genotypes' / 'all'

# Site filters: decision tree (dt) or static cutoff (sc).
genomic_positions_site_filter_dt_data_cloud_zarr_dir = 'vo_afun_release_master_us_central1/v1.0/site_filters/dt_20200416/funestus'
genomic_positions_site_filter_sc_data_cloud_zarr_dir = 'vo_afun_release_master_us_central1/v1.0/site_filters/sc_20220908/funestus'

# Release configuration lives in the repository, located via pyprojroot.
repo_clone_path = here()
release_config_path = repo_clone_path.joinpath('analysis', 'config.yml')

# BaseLoader keeps every scalar as a plain string.
with release_config_path.open() as config_file:
    config = yaml.load(config_file, Loader=yaml.BaseLoader)

samplesets = config["sample_sets"]

In [None]:
# Read the sample metadata table; the path is relative to the notebook's working directory.
meta = pd.read_csv("../../metadata/supp1_tab2.csv")
# Display the available column names.
meta.columns

In [2]:
# Access the data from the cloud.
# malariagen_data is not part of the notebook's import cell, so it is imported
# here; without it this cell raises NameError on a fresh kernel.
import malariagen_data

af1 = malariagen_data.Af1()
af1

MalariaGEN Af1 API client,MalariaGEN Af1 API client
"Please note that data are subject to terms of use,  for more information see the MalariaGEN website or contact data@malariagen.net.","Please note that data are subject to terms of use,  for more information see the MalariaGEN website or contact data@malariagen.net..1"
Storage URL,gs://vo_afun_release/
Data releases available,1.0
Results cache,
Cohorts analysis,20230823
Site filters analysis,dt_20200416
Software version,malariagen_data 7.13.0
Client location,unknown


### Connect to the cluster

### Set up dask cluster

In [16]:
# Connect to the dask gateway using its default configuration.
gateway = Gateway()
# Display any clusters that already exist before a new one is created.
gateway.list_clusters()

[]

In [17]:
# Reuses the `gateway` object created in the previous cell.
conda_prefix = os.environ["CONDA_PREFIX"]
# NOTE(review): takes the 6th '/'-separated component of CONDA_PREFIX as the
# environment name; this depends on the exact install path layout - confirm.
current_environment = 'global/'+conda_prefix.split('/')[5]
cluster = gateway.new_cluster(
    profile='standard', 
    conda_environment = current_environment,
)
cluster

VBox(children=(HTML(value='<h2>GatewayCluster</h2>'), HBox(children=(HTML(value='\n<div>\n<style scoped>\n    …

In [18]:
# Obtain a dask client bound to the cluster created above.
client=cluster.get_client()

In [19]:
# Request 50 workers for the cluster.
cluster.scale(50)

### Functions

In [8]:
def load_single_field(zarr_path, internal_path, sset, exclude_males=False, samples=None):
    """Return one array of a zarr hierarchy as a dask array.

    Parameters
    ----------
    zarr_path : pathlib.Path
        Location of the zarr hierarchy; resolved to a store by `is_gcloud`.
    internal_path : str
        Path of the array inside the hierarchy, e.g. "<chrom>/calldata/<field>".
    sset, exclude_males, samples
        Accepted for interface compatibility; not used by this function.

    Returns
    -------
    dask.array.Array
        The requested array; a 1-D array is reshaped to a single row of shape (1, n).
    """
    store = is_gcloud(zarr_path)
    root = zarr.group(store, overwrite=False)

    field = da.from_zarr(root[internal_path])

    # Promote 1-D arrays to one row so that callers can always join on axis 1.
    return field.reshape((1, -1)) if field.ndim == 1 else field

In [9]:
def concatenate_along_axis(base_dir, internal_path, req_samplesets):
    """Load the same array from every requested sample set and join them on axis 1.

    Parameters
    ----------
    base_dir : pathlib.Path
        Directory that contains one zarr hierarchy per sample set.
    internal_path : str
        Path of the array inside each hierarchy.
    req_samplesets : iterable of str
        Sample set names; each is appended to `base_dir` to locate its hierarchy.

    Returns
    -------
    dask.array.Array
        The per-sample-set arrays concatenated along axis 1, in the given order.
    """
    per_sampleset = []
    for sampleset in req_samplesets:
        per_sampleset.append(
            load_single_field(base_dir / sampleset, internal_path, sampleset)
        )

    return da.concatenate(per_sampleset, axis=1)

In [10]:
def is_gcloud(path):
    """Resolve `path` to something zarr can open.

    Returns a mapper from the module-level `gcs` file system when that name
    exists; if looking it up raises NameError, the POSIX string form of
    `path` is returned instead.
    """
    location = path.as_posix()
    try:
        mapper = gcs.get_mapper(location)
    except NameError:
        # No `gcs` handle has been created: fall back to the plain path string.
        return location
    return mapper

In [9]:
def load_filter(chrom, filter_dir = genomic_positions_site_filter_dt_data_cloud_zarr_dir):
    """Return the `variants/filter_pass` array of one chromosome as a dask array.

    Parameters
    ----------
    chrom : str
        Name of the chromosome group inside the site-filter hierarchy.
    filter_dir : str
        Bucket path of the site-filter zarr hierarchy; defaults to the
        decision-tree filter configured at the top of the notebook.
    """
    filter_root = zarr.Group(gcs.get_mapper(filter_dir), read_only=True)
    chrom_group = filter_root[chrom]
    return da.from_zarr(chrom_group['variants/filter_pass'])

In [10]:
def load_position(chrom):
    """Read `variants/POS` for one chromosome from the sites zarr into memory."""
    sites_store = gcs.get_mapper(
        'gs://vo_afun_release_master_us_central1/v1.0/snp_genotypes/all/sites')
    sites_root = zarr.open(sites_store, mode='r')
    # [:] materialises the whole array rather than returning a lazy handle.
    return sites_root[chrom]['variants/POS'][:]

In [11]:
def read_in_genotypes_positions(
    chrom,
    samples_idx,
    samplesets,
    posl,
    posu,
    filter_dir = genomic_positions_site_filter_dt_data_cloud_zarr_dir,
):
    """Load site-filtered genotypes and positions for a region and a set of samples.

    Parameters
    ----------
    chrom : str
        Chromosome to load.
    samples_idx : sequence of int
        Indices along axis 1 of the concatenated genotype array to keep.
    samplesets : iterable of str
        Sample sets to concatenate.
    posl, posu : int
        Half-open interval [posl, posu) of positions to keep. A value of -1
        means "from the smallest position" / "through the largest position".
    filter_dir : str
        Bucket path of the site-filter zarr hierarchy.

    Returns
    -------
    (gt, pos)
        Genotypes restricted to the kept sites and samples, and the matching positions.
    """
    # Genotypes of all requested sample sets, and the site positions.
    genotypes = allel.GenotypeDaskArray(
        concatenate_along_axis(sampleset_staging_dir, f"{chrom}/calldata/GT", samplesets)
    )
    pos = load_position(chrom)

    # Resolve the -1 sentinels to the full extent of the chromosome.
    upper = pos.max() + 1 if posu == -1 else posu
    lower = pos.min() if posl == -1 else posl

    # Keep sites that pass the site filter and fall inside the interval.
    site_pass = load_filter(chrom, filter_dir)
    in_region = (pos >= lower) & (pos < upper)
    keep = (site_pass) & (in_region)

    genotypes = genotypes.compress(keep, axis=0)
    pos = pos[keep]

    # Subset to the desired samples.
    genotypes = da.take(genotypes, samples_idx, axis=1)

    return genotypes, pos

In [34]:
def compute_roh_per_sample(sample_idx, ssets, chrom, min_roh):
    """Infer runs of homozygosity (ROH) for one sample on one chromosome.

    Parameters
    ----------
    sample_idx : sequence of int
        One-element sequence holding the sample's index along the samples
        axis of the concatenated genotype array.
    ssets : iterable of str
        Sample sets to load.
    chrom : str
        Chromosome to analyse.
    min_roh : int
        Minimum run size, forwarded to `allel.roh_mhmm`.

    Returns
    -------
    (roh_count, froh, df_roh)
        Number of runs whose `is_marginal` flag is False, and the two values
        returned by `allel.roh_mhmm` (fraction and table of runs).
    """
    # Site-filtered genotypes and positions for the whole chromosome.
    gt, pos = read_in_genotypes_positions(chrom, sample_idx, ssets, posl=-1, posu=-1)

    # Exactly one sample was selected: drop the samples axis and bring the data
    # into memory. (The original referenced `gv` before assigning it and then
    # passed the 3-D `gt` to roh_mhmm.)
    gt_values = getattr(gt, "values", gt)
    gv = allel.GenotypeVector(np.asarray(gt_values[:, 0]))

    # NOTE(review): roh_mhmm also documents `is_accessible` / `contig_size`
    # arguments; confirm whether one of them must be supplied for this call.
    df_roh, froh = allel.roh_mhmm(gv, pos, min_roh=min_roh)

    # Skip runs flagged as marginal (at the edge of the segment) to avoid
    # double counting.
    roh_count = len(df_roh[df_roh['is_marginal']==False])

    return roh_count, froh, df_roh

In [None]:
def compute_het_per_sample(sample_idx, ssets, chrom, window_size):
    """Compute windowed heterozygosity for one sample on one chromosome.

    Parameters
    ----------
    sample_idx : sequence of int
        One-element sequence holding the sample's index along the samples
        axis of the concatenated genotype array.
    ssets : iterable of str
        Sample sets to load.
    chrom : str
        Chromosome to analyse.
    window_size : int
        Window size, forwarded to `allel.windowed_statistic`.

    Returns
    -------
    (het_win, pos_win)
        Mean of the per-site heterozygous flag in each window, and the windows
        as returned by `allel.windowed_statistic`.
    """
    # Site-filtered genotypes and positions for the whole chromosome.
    gt, pos = read_in_genotypes_positions(chrom, sample_idx, ssets, posl=-1, posu=-1)

    # Exactly one sample was selected: drop the samples axis and bring the data
    # into memory. (The original referenced `gv` before assigning it.)
    gt_values = getattr(gt, "values", gt)
    gv = allel.GenotypeVector(np.asarray(gt_values[:, 0]))

    het = gv.is_het()

    # Average the heterozygous flag within windows; the per-window counts
    # returned as the third value are not needed.
    het_win, pos_win, _ = allel.windowed_statistic(
        pos, values=het, statistic=np.mean, size=window_size
    )

    return het_win, pos_win

In [None]:
def compute_roh_all(meta, chrom, min_roh, ssets):
    """Run the ROH computation for every row of `meta` on one chromosome.

    Each index label of `meta` is passed to `compute_roh_per_sample` as the
    sample index, and the resulting count and fraction are written into the
    `ROH_count_<chrom>` / `ROH_frac_<chrom>` columns of that row. `meta` is
    modified in place and also returned.
    """
    result_columns = [f'ROH_count_{chrom}', f'ROH_frac_{chrom}']

    for row_label in meta.index:
        roh_count, roh_fraction, _ = compute_roh_per_sample([row_label], ssets, chrom, min_roh)
        meta.loc[row_label, result_columns] = roh_count, roh_fraction

    return meta

In [None]:
def plot_roh_per_sample(roh_df, het_win, pos_win, sample_name, chrom):
    """Plot windowed heterozygosity with ROH segments for one sample and save it as SVG.

    Parameters
    ----------
    roh_df : pandas.DataFrame
        Table of runs; the `start` and `length` columns are used.
    het_win : array_like
        Heterozygosity per window.
    pos_win : numpy.ndarray
        Windows, one row per window; the row mean is used as the x coordinate.
    sample_name : str
        Used in the title and in the output file name.
    chrom : str
        Used in the x-axis label and in the output file name.

    Returns
    -------
    matplotlib.figure.Figure
        The figure, which is also written to `ROH_<sample_name>_<chrom>.svg`
        in the current working directory.
    """
    # Local import: the notebook imports matplotlib.pyplot only, so the `mpl`
    # alias the original relied on is not defined.
    from matplotlib.ticker import FuncFormatter

    # plotting setup
    fig, ax = plt.subplots(figsize=(10, 2.5))
    sns.despine(ax=ax, offset=10)

    # plot heterozygosity at the midpoint of each window
    window_mid = np.mean(pos_win, axis=1)
    ax.plot(window_mid, het_win, linewidth=1.5)
    ax.set_ylim(-0.01, 0.04)
    ax.set_yticks(np.arange(0, 0.04, 0.02))
    ax.set_xlim(0, pos_win.max())
    ax.set_xlabel(f'Chromosome {chrom} position (Mbp)', fontsize=7)
    ax.set_ylabel('heterozygosity', fontsize=7)
    # True division: floor division made every tick label end in ".0".
    ax.xaxis.set_major_formatter(FuncFormatter(lambda x, p: f"{x / 1e6:.1f}"))

    # plot roh as grey bars, using the table passed in rather than a global
    xranges = np.column_stack([roh_df.start, roh_df.length])
    yrange = (-.008, 0.006)
    ax.broken_barh(xranges, yrange, facecolor='#B8B8B8', linewidth=None)

    ax.set_title(f'Heterozygosity and ROHs for {sample_name}')
    fig.tight_layout()

    # Save through the figure object so the right figure is written even if
    # another one has become current.
    fig.savefig(f'ROH_{sample_name}_{chrom}.svg', dpi=300, bbox_inches='tight')

    return fig

In [None]:
# Put meta in the same order as the samples axis of the stored genotype arrays.
sample_order = concatenate_along_axis(sampleset_staging_dir, "samples", samplesets).compute()
sample_order = (sample_order[0]).astype(str)

# Duplicate IDs would make .loc return extra rows and break the row <-> sample
# correspondence that the ROH functions rely on.
assert meta['VBS_sample_id'].is_unique, "duplicate VBS_sample_id values in meta"

# Chained (not in-place) so that a missing sample ID leaves `meta` untouched
# and the cell can simply be re-run.
meta = (
    meta
    .set_index('VBS_sample_id')
    .loc[sample_order]
    .reset_index()
)
meta.head()

### Run the ROH computation for all samples

In [38]:
min_roh = 100_000
roh_chroms = ['2RL', '3RL', 'X']

for chrom in roh_chroms:
    meta = compute_roh_all(meta, chrom, min_roh, samplesets)
    # Write the table after every chromosome so finished work is kept on disk.
    meta.to_csv("results_roh.tsv", sep='\t', index=False)

# Total ROH count over the analysed chromosomes.
meta['ROH_count'] = sum(meta[f'ROH_count_{name}'] for name in roh_chroms)
meta.to_csv("results_roh.tsv", sep='\t', index=False)

## Compute ROH for a single sample

In [None]:
sample_name = 'VBS24196'
chrom = '2RL'
window_size = 100_000

# Index labels of the matching rows as a plain list of ints. The original
# wrapped a pandas Index in another list, producing a nested indexer.
sample_idx = meta.index[meta.VBS_sample_id == sample_name].tolist()
assert len(sample_idx) == 1, f"expected exactly one row for {sample_name}, found {len(sample_idx)}"

# `min_roh` is defined in the all-samples cell above.
count, frac, df_roh = compute_roh_per_sample(sample_idx, samplesets, chrom, min_roh)
# The original line ended with a stray ':' (SyntaxError).
het_win, pos_win = compute_het_per_sample(sample_idx, samplesets, chrom, window_size)