In [1]:
#Load modules
import zarr
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
import yaml
from pathlib import Path
import allel

from dask.distributed import Client
import dask
dask.config.set(**{'array.slicing.split_large_chunks': False}) # Silence large chunk warnings
import dask.array as da
from dask import delayed, compute
from dask_gateway import Gateway
import functools
import numcodecs
from fsspec.implementations.zip import ZipFileSystem
from collections.abc import Mapping
import gcsfs
import numba
import psutil
from humanize import naturalsize

import pickle
import platform

import traceback
import logging

from pyprojroot import here
from bokeh.plotting import *
import plotly.express as px
import plotly.graph_objects as go
from plotly.validators.scatter.marker import SymbolValidator

### Connect to Google Cloud Storage (GCS)

In [None]:
# File-system handle used by the loader functions below to open zarr stores;
# built with gcsfs's default arguments.
gcs = gcsfs.GCSFileSystem()

In [None]:
# Connectivity check: show the first three entries of the release bucket
gcs.ls('vo_afun_release_master_us_central1')[:3]

### Set up data access

In [4]:
# Release bucket -> v1.0 release -> genotype call sets for all sample sets
production_root = Path('vo_afun_release_master_us_central1')
vo_afun_staging = production_root / 'v1.0'
sampleset_staging_dir = vo_afun_staging / 'snp_genotypes' / 'all'

# Site-filter zarr stores: decision tree or static filters
# (presumably dt_* = decision tree and sc_* = static; confirm against the release notes)
genomic_positions_site_filter_dt_data_cloud_zarr_dir = 'vo_afun_release_master_us_central1/v1.0/site_filters/dt_20200416/funestus'
genomic_positions_site_filter_sc_data_cloud_zarr_dir = 'vo_afun_release_master_us_central1/v1.0/site_filters/sc_20220908/funestus'

# The analysis configuration lives inside the repository clone
repo_clone_path = here()
release_config_path = repo_clone_path / 'analysis' / 'config.yml'

# BaseLoader keeps every YAML scalar as a plain string
with release_config_path.open() as fh:
    config = yaml.load(fh, Loader=yaml.BaseLoader)

# Sample sets to analyse; also used as a default argument by read_in_genotypes
samplesets = config["sample_sets"]

In [None]:
# Per-sample metadata table; the path is relative to the notebook's working directory
meta = pd.read_csv("../../metadata/supp1_tab2.csv")
# List the available columns (a `sample_id` column is required further down)
meta.columns

### Functions

In [6]:
def load_single_field(zarr_path, internal_path, sset, exclude_males=False, samples=None):
    """Lazily load one array from a zarr hierarchy as a dask array with >= 2 dims.

    Parameters
    ----------
    zarr_path : pathlib.Path
        Location of the zarr group (e.g. ``sampleset_staging_dir / sset``);
        resolved to a store through ``is_gcloud``.
    internal_path : str
        Path of the array inside the group; for call data this is
        ``"<chrom>/calldata/<field>"``.
    sset : str
        Sample set name. Currently unused by the body; kept so existing
        callers keep working.
    exclude_males : bool, optional
        Currently unused.
    samples : optional
        Currently unused.

    Returns
    -------
    dask.array.Array
        The requested array; a 1-D array is reshaped to shape ``(1, n)``.
    """
    # Open read-only. ``zarr.group`` is a *create* call: on a path that holds no
    # group it initialises an empty one, which hides a wrong path behind a later
    # KeyError (and may attempt a write to the bucket).
    root = zarr.open_group(is_gcloud(zarr_path), mode='r')

    arr = da.from_zarr(root[internal_path])

    # Promote 1-D arrays to a single row so every field can be concatenated
    # along axis 1 by concatenate_along_axis.
    if arr.ndim == 1:
        arr = arr.reshape((1, -1))

    return arr

In [7]:
def concatenate_along_axis(base_dir, internal_path, req_samplesets):
    """Load the same array from several sample sets and join them along axis 1.

    Each sample set is read from ``base_dir / <sample set>`` with
    ``load_single_field``; the per-sample-set dask arrays are concatenated in
    the order given by ``req_samplesets``. The chunking is whatever the
    underlying arrays provide, which may suit some fields better than others.

    Parameters
    ----------
    base_dir : pathlib.Path
        Directory holding one zarr group per sample set.
    internal_path : str
        Path of the array inside each group.
    req_samplesets : iterable of str
        Sample set names to load.

    Returns
    -------
    dask.array.Array
    """
    per_sampleset = []
    for sampleset in req_samplesets:
        per_sampleset.append(
            load_single_field(base_dir / sampleset, internal_path, sampleset)
        )

    return da.concatenate(per_sampleset, axis=1)

In [8]:
def is_gcloud(path):
    """Turn a path into something zarr can open.

    Returns ``gcs.get_mapper(<posix path>)`` when a module-level ``gcs``
    filesystem exists. If that lookup raises ``NameError`` (``gcs`` was never
    defined), the plain POSIX path string is returned instead.

    Parameters
    ----------
    path : pathlib.PurePath
        Any object exposing ``as_posix()``.
    """
    try:
        store = gcs.get_mapper(path.as_posix())
    except NameError:
        # no `gcs` filesystem in scope: fall back to the bare path string
        store = path.as_posix()
    return store

In [9]:
def load_filter(chrom, filter_dir = genomic_positions_site_filter_sc_data_cloud_zarr_dir):
    """Return the per-site ``variants/filter_pass`` array of a contig as a dask array.

    Parameters
    ----------
    chrom : str
        Contig name, used as the top-level key inside the site-filter zarr group.
    filter_dir : str, optional
        Bucket path of the site-filter zarr group; defaults to the ``sc`` filter
        location defined in the configuration cell.

    Returns
    -------
    dask.array.Array
        Lazily loaded ``filter_pass`` values for ``chrom``.
    """
    # the group is opened read-only through the module-level `gcs` filesystem
    filter_root = zarr.Group(gcs.get_mapper(filter_dir), read_only=True)
    return da.from_zarr(filter_root[chrom]['variants/filter_pass'])

In [10]:
def load_position(chrom):
    """Return the ``variants/POS`` array of a contig, fully loaded into memory.

    Parameters
    ----------
    chrom : str
        Contig name, used as the top-level key inside the sites zarr group.
    """
    # NOTE(review): this reads from the `vo_afun_release` bucket, whereas the
    # genotypes and site filters come from `vo_afun_release_master_us_central1`
    # - confirm both hold the same site list.
    sites_store = gcs.get_mapper('gs://vo_afun_release/v1.0/snp_genotypes/all/sites')
    sites_root = zarr.open(sites_store, mode='r')
    # `[:]` materialises the whole zarr array
    return sites_root[chrom]['variants/POS'][:]

In [11]:
def read_in_genotypes(chrom, pos_min=None, pos_max=None,
                      samplesets = samplesets,
                      filter_dir = genomic_positions_site_filter_sc_data_cloud_zarr_dir):
    """Return the filter-passing genotypes of a contig within a position window.

    Parameters
    ----------
    chrom : str
        Contig name.
    pos_min, pos_max : int, optional
        Inclusive window bounds. A missing bound defaults to the smallest /
        largest position on the contig.
    samplesets : iterable of str, optional
        Sample sets to load; defaults to the configured list as it was when
        this function was defined.
    filter_dir : str, optional
        Site-filter zarr location passed on to ``load_filter``.

    Returns
    -------
    allel.GenotypeDaskArray
        Genotypes restricted (along axis 0) to sites that pass the filter and
        lie inside the window.
    """
    # genotype calls of all requested sample sets, joined along axis 1
    genotypes = allel.GenotypeDaskArray(
        concatenate_along_axis(sampleset_staging_dir, f"{chrom}/calldata/GT", samplesets)
    )

    # per-site pass/fail mask
    site_pass = load_filter(chrom, filter_dir)

    # per-site window mask; both bounds are inclusive
    pos = load_position(chrom)
    lower = pos.min() if pos_min is None else pos_min
    upper = pos.max() if pos_max is None else pos_max
    in_window = (pos >= lower) & (pos <= upper)

    return genotypes.compress(site_pass & in_window, axis=0)

In [12]:
def compute_number_of_hets(meta, chrom, arm, pos_min=None, pos_max=None):
    """Count heterozygous and called genotypes per sample and store them in ``meta``.

    Creates (or overwrites) the columns ``n_het_<arm>`` and ``n_called_<arm>``.

    Parameters
    ----------
    meta : pandas.DataFrame
        Frame with one row per sample, in the same order as the genotype
        arrays. Modified in place and also returned.
    chrom : str
        Contig to read genotypes from.
    arm : str
        Label used in the column names and progress messages.
    pos_min, pos_max : int, optional
        Inclusive position window, forwarded to ``read_in_genotypes``.

    Returns
    -------
    pandas.DataFrame
        The same ``meta`` object.
    """
    gt = read_in_genotypes(chrom, pos_min, pos_max)
    print(f"read in genotypes on {arm}")
    # The two counts below are lazy dask graphs; nothing is read yet, so these
    # messages mark graph construction rather than finished computation.
    n_het = gt.count_het(axis=0)
    print(f'computed number of hets on {arm}')
    n_called = gt.count_called(axis=0)
    print(f'computed number of called sites on {arm}')

    # Evaluate both graphs in a single call so the shared genotype chunks are
    # loaded once instead of once per `.compute()`. Assigning only afterwards
    # also means a failed computation cannot leave just one column written.
    n_het_values, n_called_values = dask.compute(n_het, n_called)
    meta[f'n_het_{arm}'] = n_het_values
    meta[f'n_called_{arm}'] = n_called_values

    return meta
    

In [13]:
def compute_number_of_hets_additional(meta, chrom, arm, pos_min=None, pos_max=None):
    """Add the counts of a further position window to existing per-arm columns.

    The columns ``n_het_<arm>`` and ``n_called_<arm>`` must already exist
    (created by ``compute_number_of_hets``); the counts for this window are
    added to them with ``+=``. The function is therefore NOT idempotent:
    calling it twice for the same window counts that window twice.

    Parameters
    ----------
    meta : pandas.DataFrame
        Frame with one row per sample, in the same order as the genotype
        arrays. Modified in place and also returned.
    chrom : str
        Contig to read genotypes from.
    arm : str
        Label used in the column names and progress messages.
    pos_min, pos_max : int, optional
        Inclusive position window, forwarded to ``read_in_genotypes``.

    Returns
    -------
    pandas.DataFrame
        The same ``meta`` object.

    Raises
    ------
    KeyError
        If the per-arm columns do not exist yet.
    """
    gt = read_in_genotypes(chrom, pos_min, pos_max)
    print(f"read in genotypes on {arm}")
    # The two counts below are lazy dask graphs; nothing is read yet, so these
    # messages mark graph construction rather than finished computation.
    n_het = gt.count_het(axis=0)
    print(f'computed number of hets on {arm}')
    n_called = gt.count_called(axis=0)
    print(f'computed number of called sites on {arm}')

    # Evaluate both graphs in a single call so the shared genotype chunks are
    # loaded once, and so neither column is touched if the computation fails.
    n_het_values, n_called_values = dask.compute(n_het, n_called)
    meta[f'n_het_{arm}'] += n_het_values
    meta[f'n_called_{arm}'] += n_called_values

    return meta

In [None]:
# Put the metadata rows in the same order as the samples in the genotype arrays:
# the per-sample counts are later written into the frame purely by position.
sample_order = concatenate_along_axis(sampleset_staging_dir, "samples", samplesets).compute()
# the loader returns a (1, n_samples) array; take the single row as strings
sample_order = (sample_order[0]).astype(str)
# Chained, non-mutating form: if a sample is missing from the metadata the
# lookup raises and `meta` is left exactly as it was (the in-place version
# would leave it indexed by sample_id).
meta = meta.set_index('sample_id').loc[sample_order].reset_index()
# Duplicate sample_ids in the metadata would silently add rows; fail early.
assert len(meta) == len(sample_order), "metadata rows do not map one-to-one onto genotype samples"
meta.head()

### Set up dask cluster

In [36]:
# Request a new Dask Gateway cluster whose workers use a conda environment
# named after the one recorded in CONDA_PREFIX.
gateway = Gateway()
conda_prefix = os.environ["CONDA_PREFIX"]
# NOTE(review): picks the 6th '/'-separated component of CONDA_PREFIX as the
# environment name, which depends on the exact depth of the install path;
# confirm this holds on the target deployment.
current_environment = 'global/'+conda_prefix.split('/')[5]
cluster = gateway.new_cluster(
    profile='standard', 
    conda_environment = current_environment,
)
# display the cluster object as the cell output
cluster

VBox(children=(HTML(value='<h2>GatewayCluster</h2>'), HBox(children=(HTML(value='\n<div>\n<style scoped>\n    …

In [37]:
# Obtain a client for the cluster created above
client=cluster.get_client()

In [38]:
# Request 50 workers
cluster.scale(50)

In [None]:
# Result table: starts with just the sample_id column (in the order of `meta`);
# the n_het_* / n_called_* columns are appended by the cells below.
counts = pd.DataFrame(meta.sample_id)

In [16]:
# X: no position bounds, so every filter-passing site on the contig is counted
counts = compute_number_of_hets(counts, 'X', 'X')

In [29]:
# 3L, first window: positions 44,700,000-64,700,000 of contig 3RL (bounds inclusive).
# Creates the n_het_3L / n_called_3L columns.
# NOTE(review): position 44,700,000 is also inside the last 3R window - confirm intended.
counts = compute_number_of_hets(counts, '3RL', '3L', pos_min = 44_700_000, pos_max=64_700_000)

read in genotypes on 3L
computed number of hets on 3L
computed number of called sites on 3L


In [34]:
# 3L, second window: 64,700,001 to the last position of 3RL.
# Added onto the existing 3L columns with += - run exactly once.
counts = compute_number_of_hets_additional(counts, '3RL', '3L', pos_min = 64_700_001)

read in genotypes on 3L
computed number of hets on 3L
computed number of called sites on 3L


In [42]:
# 3R, first window: first position of 3RL up to 20,000,000 (inclusive).
# Creates the n_het_3R / n_called_3R columns.
counts = compute_number_of_hets(counts, '3RL', '3R', pos_max=20_000_000)

read in genotypes on 3R
computed number of hets on 3R
computed number of called sites on 3R


In [45]:
# 3R, second window: 20,000,001-44,700,000.
# Added onto the existing 3R columns with += - run exactly once.
counts = compute_number_of_hets_additional(counts, '3RL', '3R', pos_min=20_000_001, pos_max=44_700_000)

read in genotypes on 3R
computed number of hets on 3R
computed number of called sites on 3R


In [48]:
# 2R, first window: first position of 2RL up to 20,000,000 (inclusive).
# Creates the n_het_2R / n_called_2R columns.
counts = compute_number_of_hets(counts, '2RL', '2R', pos_max=20_000_000)

read in genotypes on 2R
computed number of hets on 2R
computed number of called sites on 2R


In [51]:
# 2R, second window: 20,000,001-40,000,000.
# Added onto the existing 2R columns with += - run exactly once.
counts = compute_number_of_hets_additional(counts, '2RL', '2R', pos_min=20_000_001, pos_max=40_000_000)

read in genotypes on 2R
computed number of hets on 2R
computed number of called sites on 2R


In [54]:
# 2R, third window: 40,000,001-57,350,000.
# Added onto the existing 2R columns with += - run exactly once.
counts = compute_number_of_hets_additional(counts, '2RL', '2R', pos_min=40_000_001, pos_max=57_350_000)

read in genotypes on 2R
computed number of hets on 2R
computed number of called sites on 2R


In [20]:
# 2L, first window: positions 57,350,000-78,000,000 of contig 2RL (bounds inclusive).
# Creates the n_het_2L / n_called_2L columns.
# NOTE(review): position 57,350,000 is also inside the last 2R window - confirm intended.
counts = compute_number_of_hets(counts, '2RL', '2L', pos_min = 57_350_000, pos_max = 78_000_000)

read in genotypes on 2L
computed number of hets on 2L
computed number of called sites on 2L


In [23]:
# 2L, second window: 78,000,001 to the last position of 2RL.
# Added onto the existing 2L columns with += - run exactly once.
counts = compute_number_of_hets_additional(counts, '2RL', '2L', pos_min = 78_000_001)

read in genotypes on 2L
computed number of hets on 2L
computed number of called sites on 2L


In [51]:
# Save the per-arm counts to the working directory (row index not written)
counts.to_csv("het_counts.csv", index=False)

### Repeat the heterozygosity counts, excluding the inversion regions

In [18]:
# 3L_no_inv, first window: 44,700,000-57,224,763 of contig 3RL (bounds inclusive).
# Creates the n_het_3L_no_inv / n_called_3L_no_inv columns.
counts = compute_number_of_hets(counts, '3RL', '3L_no_inv', pos_min = 44_700_000, pos_max=57_224_763)

read in genotypes on 3L_no_inv
computed number of hets on 3L_no_inv
computed number of called sites on 3L_no_inv


In [22]:
# 3L_no_inv, second window: 76,848,507 to the last position of 3RL, so
# positions 57,224,764-76,848,506 are left out of the 3L_no_inv totals.
# Added with += - run exactly once.
counts = compute_number_of_hets_additional(counts, '3RL', '3L_no_inv', pos_min = 76_848_507)

read in genotypes on 3L_no_inv
computed number of hets on 3L_no_inv
computed number of called sites on 3L_no_inv


In [25]:
# 3R_no_inv, first window: first position of 3RL up to 2,428,547 (inclusive).
# Creates the n_het_3R_no_inv / n_called_3R_no_inv columns.
counts = compute_number_of_hets(counts, '3RL', '3R_no_inv', pos_max=2_428_547)

read in genotypes on 3R_no_inv
computed number of hets on 3R_no_inv
computed number of called sites on 3R_no_inv


In [28]:
# 3R_no_inv, second window: 12,234,590-21,361,107, so positions
# 2,428,548-12,234,589 are left out. Added with += - run exactly once.
counts = compute_number_of_hets_additional(counts, '3RL', '3R_no_inv', pos_min=12_234_590, pos_max = 21_361_107)

read in genotypes on 3R_no_inv
computed number of hets on 3R_no_inv
computed number of called sites on 3R_no_inv


In [31]:
# 3R_no_inv, third window: 34,095,918-44,700,000, so positions
# 21,361,108-34,095,917 are left out. Added with += - run exactly once.
counts = compute_number_of_hets_additional(counts, '3RL', '3R_no_inv', pos_min=34_095_918, pos_max = 44_700_000)

read in genotypes on 3R_no_inv
computed number of hets on 3R_no_inv
computed number of called sites on 3R_no_inv


In [43]:
# 2R_no_inv, first window: first position of 2RL up to 15,459,000 (inclusive).
# Creates the n_het_2R_no_inv / n_called_2R_no_inv columns.
counts = compute_number_of_hets(counts, '2RL', '2R_no_inv', pos_max=15_459_000)

read in genotypes on 2R_no_inv
computed number of hets on 2R_no_inv
computed number of called sites on 2R_no_inv


In [46]:
# 2R_no_inv, second window: 15,459,001-25,459,000 (contiguous with the first).
# Added with += - run exactly once.
counts = compute_number_of_hets_additional(counts, '2RL', '2R_no_inv', pos_min=15_459_001, pos_max=25_459_000)

read in genotypes on 2R_no_inv
computed number of hets on 2R_no_inv
computed number of called sites on 2R_no_inv


In [49]:
# 2R_no_inv, third window: 39,360,000-57,350,000, so positions
# 25,459,001-39,359,999 are left out. Added with += - run exactly once.
counts = compute_number_of_hets_additional(counts, '2RL', '2R_no_inv', pos_min=39_360_000, pos_max=57_350_000)

read in genotypes on 2R_no_inv
computed number of hets on 2R_no_inv
computed number of called sites on 2R_no_inv


In [None]:
# Overwrite het_counts.csv; `counts` now also holds the *_no_inv columns
counts.to_csv("het_counts.csv", index=False)

In [52]:
# Release the cluster created for this analysis
cluster.shutdown()

In [14]:
# Fresh gateway handle, used by the clean-up loop in the next cell
gateway = Gateway()

In [15]:
# Clean-up: connect to and shut down every cluster the gateway lists,
# not only the one started in this notebook.
for report in gateway.list_clusters():
    gateway.connect(report.name).shutdown()