In [1]:
from dask_kubernetes import KubeCluster
from dask.distributed import Client, progress
import dask.array as da
import numpy as np
import zarr
import allel
import sys
import ag3
import psutil
from humanize import naturalsize
import numba

### parameters

In [2]:
species_group = ["gamb_colu", "arab", "gamb_colu_arab"]
module_path = "../../notebooks/"
n_downsample = 100_000
max_allele = 3
random_seed = 42
n_workers = 10

#regions
region_3L_free =  '3L', 15000000, 41000000
region_3R_free =  '3R', 1, 37000000

#I've chosen to limit gamb_colu to just 3L for ease of comparison across masks - up for discussion
regions = {"gamb_colu" : region_3L_free,
            "arab" : region_3L_free,
            "gamb_colu_arab"  : region_3L_free
            }


#This is correct - gamb_colu_arab includes ALL samples
sample_query = {"gamb_colu" : "species_gambcolu_arabiensis == 'gamb_colu'",
                "arab" : "species_gambcolu_arabiensis == 'arabiensis'",
                "gamb_colu_arab" : "species_gambcolu_arabiensis != 'NA'"
                }

In [3]:
# Data storage option, uses about 34 MB
# Either 'local' or 'cloud'
storage_option = 'cloud' 

# Specify the storage paths
# The string "{}" will be replaced with each species_group
output_local_zarr_path_template = '{}.pca_umap_input_alleles.zarr'
output_cloud_zarr_path_template = 'vo_agam_production/ag3_data_paper/{}.pca_umap_input_alleles.zarr'

### Check the output doesn't already exist

In [4]:
# Writing the PCA data to the cloud will require the appropriate authentication and authorization.
if storage_option == 'cloud':
    import gcsfs
    # UNCOMMENT THIS TO AUTHENTICATE. YOU ONLY NEED TO RUN THIS ONCE.
    # After running this once, your authentication token should then be cached in `~/.gcs_tokens`
    # Once you have authenticated, you should comment this out again to avoid re-authenticating.
    # gcs_browser = gcsfs.GCSFileSystem(project='malariagen-jupyterhub', token='browser')
    
    # Use `cache_timeout=0` to prevent object list cache, to avoid recreating map for Zarr consolidated metadata
    auth_fs = gcsfs.GCSFileSystem(project='malariagen-jupyterhub', token='cache', cache_timeout=0)
    
    for sp in species_group:

        # Check that the output's Zarr metadata file is not already on the cloud.
        # We don't want to accidentally overwrite or delete existing data, which might have been used in downstream analysis.
        # We don't simply check for the existence of the Zarr file here (i.e. output_cloud_zarr_path),
        # We might want to re-run the first parts of this notebook again,
        # so the Zarr store might legitimately exist but be incomplete.
        # The Zarr store is not considered complete until the Zarr metadata file (.zmetadata) is present and correct.
        # The final part of this notebook includes steps to create the Zarr metadata file and then validate it.

        output_cloud_zarr_metadata_path = f'{output_cloud_zarr_path_template.format(sp)}/.zmetadata'
        print(f'Checking for {output_cloud_zarr_metadata_path}')
        assert not auth_fs.exists(output_cloud_zarr_metadata_path)

Checking for vo_agam_production/ag3_data_paper/gamb_colu.pca_umap_input_alleles.zarr/.zmetadata
Checking for vo_agam_production/ag3_data_paper/arab.pca_umap_input_alleles.zarr/.zmetadata
Checking for vo_agam_production/ag3_data_paper/gamb_colu_arab.pca_umap_input_alleles.zarr/.zmetadata


In [5]:
if storage_option == 'local':
    from os import path
    
    for sp in species_group:
        output_local_zarr_metadata_path = f'{output_local_zarr_path_template.format(sp)}/.zmetadata'
        print(f'Checking for {output_local_zarr_metadata_path}')
        assert not path.isfile(output_local_zarr_metadata_path)

### functions

In [6]:
def show_memory():
      vm = psutil.virtual_memory()
      print(f"{naturalsize(vm.used)} used, {naturalsize(vm.available)} available, {naturalsize(vm.total)} total")
show_memory()

1.0 GB used, 14.4 GB available, 15.8 GB total


In [7]:
# Define a function to return the genotype data for a given chromosome region, for a given sample set.
# Use the specified intake catalog.
def load_genotype_calldata(cat, sample_set, chrom_arm, region_slice_obj):
    print('- load_genotype_calldata', sample_set, chrom_arm, region_slice_obj)
    zarr_data = cat.ag3.snp_genotypes(sample_set=sample_set).to_zarr()
    return da.from_zarr(zarr_data[chrom_arm]["calldata"]["GT"])[region_slice_obj]

In [8]:
class Util:
# Define a function to count the presence of each allele in a given array of genotypes (samples * variants * alleles)
    # and return an array of allele counts with a row for each possible allele (limited by max_allele) for each sample, and column for each variant
    @staticmethod
    @numba.njit(numba.int8[:, :](numba.int8[:, :, :], numba.int8), nogil=True)
    def numpy_genotype_tensor_to_allele_counts_melt(gt, max_allele):
        # Create an array of zeros (for defaults) with the same number of colums (variants) as the genotype array but a row for each allele, for each sample
        out = np.zeros((gt.shape[0] * (max_allele + 1), gt.shape[1]), dtype=np.int8)
        # For each row (sample) in the genotype array
        for i in range(gt.shape[0]):
            # For each column (variant) in the genotype array
            for j in range(gt.shape[1]):
                # For each allele in the genotype array 
                for k in range(gt.shape[2]):
                    allele = gt[i, j, k]
                    # If the value in the genotype array at this row and colum and 3rd dimension (i.e. the allele value) is between 0 and max_allele  
                    if 0 <= allele <= max_allele:
                        # Increment the value of the `out` array at the row corresponding to this allele for this sample, at the corresponding variant column
                        out[(i * (max_allele + 1)) + allele, j] += 1
        return out

## the only way I could get the code to work is to define this ^^ as Util otherwise the map blocks code doesn't run - why is this?

In [9]:
# Define a function to apply the above function chunk-wise
def dask_genotype_tensor_to_allele_counts_melt(gt, max_allele):
    # Determine output chunks - change axis 0; preserve axis 1; drop axis 2.
    dim0_chunks = tuple(np.array(gt.chunks[0]) * (max_allele + 1))
    chunks = (dim0_chunks, gt.chunks[1])
    
    return gt.map_blocks(
        Util.numpy_genotype_tensor_to_allele_counts_melt,
        max_allele=max_allele,
        chunks=chunks,
        dtype="i1",
        drop_axis=2,
    )

In [10]:
#can this be improved with dask?
def ld_prune(gn, size=500, step=200, threshold=.1, n_iter=1):
    for i in range(n_iter):
        loc_unlinked = allel.locate_unlinked(gn, size=size, step=step, threshold=threshold)
        n = np.count_nonzero(loc_unlinked)
        n_remove = gn.shape[0] - n
        print('iteration', i+1, 'retaining', n, 'removing', n_remove, 'variants')
        gn = gn.compress(loc_unlinked, axis=0)
    return gn

## run

In [11]:
# cluster setup
cluster = KubeCluster()
cluster.scale_up(n_workers)

# dask client setup
client = Client(cluster)
print(cluster.dashboard_link)

# grab data from release
v3 = ag3.release_data()

sample_sets = v3.all_wild_sample_sets

# load_metadata
metadata = v3.load_sample_set_metadata(sample_sets)

distributed.scheduler - INFO - Clear task state
distributed.scheduler - INFO -   Scheduler at:   tcp://10.32.55.11:33257
distributed.scheduler - INFO -   dashboard at:                     :8787
distributed.scheduler - INFO - Receive client connection: Client-2f7418b6-0679-11eb-b59b-327940786d26
distributed.core - INFO - Starting established connection


/user/lee.hart@well.ox.ac.uk/proxy/8787/status


  res_values = method(rvalues)


In [17]:
def build_allele_counts(selection):
    print("building allele counts for", selection)
    
    sample_loc = metadata.eval(sample_query[selection]).values
    assert sample_loc.sum() > 0, "Must select >0 samples"

    genotypes = []
    site_filters = []
    
    chrom, start, stop = regions[selection]
    print(chrom, start, stop)
    #     
    pos = allel.SortedIndex(v3.load_variants(chrom))
    ix = pos.locate_range(start, stop)
    gt = v3.load_sample_set_calldata(chrom, sample_set=sample_sets)[ix]
    mask = v3.load_mask(chrom, selection)[ix]
    g = da.compress(sample_loc, gt, axis=1)

    genotypes.append(g)
    site_filters.append(mask)

    genotype_data = da.concatenate(genotypes, axis=0)
    site_filters_data = da.concatenate(site_filters, axis=0)

    melted_allele_counts = dask_genotype_tensor_to_allele_counts_melt(gt=genotype_data, max_allele=max_allele)
    
    # Get the number of genotyped samples
    number_of_samples = genotype_data.shape[1]
    print("Number of samples", number_of_samples)

    # Sum the allele counts
    allele_count_sums = da.sum(melted_allele_counts, axis=1, dtype='int16')

    # Determine which alleles meet the criteria, and record as a Boolean array.
    loc_midfreq_alleles = (allele_count_sums >= 2) & (allele_count_sums <= ((number_of_samples * 2) - 2))

    # Transform the Boolean site_filter index into the same space as the melted allele counts.
    loc_accessible = da.repeat(site_filters_data, max_allele + 1) # 4 alleles

    # Check that loc_accessible is the same shape as loc_midfreq_alleles
    assert loc_accessible.shape == loc_midfreq_alleles.shape
    
    # Determine the corresponding array indices for all of the mid-frequency alleles that are accessible
    # We use the '&' to choose sites that meet the critera AND are accessible.
    midfreq_alleles_as_indices = da.nonzero(loc_midfreq_alleles & loc_accessible)[0]

    # Compute (and bring into client memory) the midfreq_alleles_as_indices
    ix_select = midfreq_alleles_as_indices.compute()
    
    # Set/reset the random seed used for random variant selection
    # to ensure that we always select the same set of random variants
    np.random.seed(random_seed)

    # Randomly choose `n_downsample_variants` items from the array of accessible mid-frequency allele indices
    downsampled_site_indices = np.random.choice(
        ix_select, 
        size=n_downsample, 
        replace=False)

    # Sort the indices to allow contiguous parsing
    downsampled_site_indices.sort()

    # From the melted_allele_counts array, take the corresponding indices
    downsampled_allele_counts = da.take(melted_allele_counts, downsampled_site_indices, axis=0)   
    computed_downsampled_allele_counts = downsampled_allele_counts.compute()
    
    #prune
    pruned_downsampled_allele_counts = ld_prune(computed_downsampled_allele_counts)
    
    # finally save to zarr...
    print('saving to zarr')
    #z = zarr.ZipStore("../data/allele_counts_for_pca_umap/"+selection+".pca_umap_input_alleles.zarr.zip")
    #zg = zarr.group(z)
    #zg.create_dataset("allele_counts_pca_ready", data=pruned_downsampled_allele_counts)
    
    if storage_option == 'cloud':
    
        output_cloud_zarr_path = output_cloud_zarr_path_template.format(selection)
        print('Will attempt to store at:', output_cloud_zarr_path)

        # Sometimes errors with `overwrite=True`, sometimes errors without, when dir not exist
        # Keep the zarr_store for zarr.consolidate_metadata(zarr_store)
        zarr_store = auth_fs.get_mapper(output_cloud_zarr_path)
        zarr_group = zarr.group(zarr_store)
        
    if storage_option == 'local':

        # Eyeball output path
        output_local_zarr_path =  output_local_zarr_path_template.format(selection)
        print('Will attempt to store at:', output_local_zarr_path)

        # Keep the zarr_store for zarr.consolidate_metadata(zarr_store)
        zarr_store = zarr.DirectoryStore(output_local_zarr_path)
        zarr_group = zarr.open(zarr_store, mode='w')
    
    
    # Check the data type
    print('type(pruned_downsampled_allele_counts):', type(pruned_downsampled_allele_counts))
    
    # overwrite=True, otherwise `ValueError: path 'allele_counts_pca_ready' contains an array`
    zarr_group.create_dataset("allele_counts_pca_ready", data=pruned_downsampled_allele_counts, overwrite=True)
    
    # Check the stored data has all its chunks initialized
    assert zarr_group['allele_counts_pca_ready'].nchunks_initialized == zarr_group['allele_counts_pca_ready'].nchunks
    
    # Check the store contains the expected arrays
    assert 'allele_counts_pca_ready/.zarray' in zarr_store
    
    # Consolidate the Zarr metatdata
    zarr.consolidate_metadata(zarr_store)
    
    # Check the consolidated Zarr metadata
    zarr_consolidated_metadata = zarr.open_consolidated(zarr_store)
    assert list(zarr_consolidated_metadata.keys()) == ['allele_counts_pca_ready']
    

In [18]:
for s in species_group:
    build_allele_counts(s)

building allele counts for gamb_colu
3L 15000000 41000000
Number of samples 2415


distributed.core - INFO - Event loop was unresponsive in Scheduler for 14.06s.  This is often caused by long-running GIL-holding functions or moving large chunks of data. This can cause timeouts and instability.
distributed.core - INFO - Event loop was unresponsive in Scheduler for 12.15s.  This is often caused by long-running GIL-holding functions or moving large chunks of data. This can cause timeouts and instability.
distributed.scheduler - INFO - Remove worker tcp://10.32.233.2:38611
distributed.core - INFO - Removing comms to tcp://10.32.233.2:38611


iteration 1 retaining 81239 removing 18761 variants
saving to zarr
Will attempt to store at: vo_agam_production/ag3_data_paper/gamb_colu.pca_umap_input_alleles.zarr
type(pruned_downsampled_allele_counts): <class 'numpy.ndarray'>
building allele counts for arab
3L 15000000 41000000
Number of samples 368
iteration 1 retaining 40862 removing 59138 variants
saving to zarr
Will attempt to store at: vo_agam_production/ag3_data_paper/arab.pca_umap_input_alleles.zarr
type(pruned_downsampled_allele_counts): <class 'numpy.ndarray'>
building allele counts for gamb_colu_arab
3L 15000000 41000000
Number of samples 2784


distributed.core - INFO - Event loop was unresponsive in Scheduler for 12.90s.  This is often caused by long-running GIL-holding functions or moving large chunks of data. This can cause timeouts and instability.
distributed.core - INFO - Event loop was unresponsive in Scheduler for 13.42s.  This is often caused by long-running GIL-holding functions or moving large chunks of data. This can cause timeouts and instability.
distributed.utils_perf - INFO - full garbage collection released 149.25 MB from 466 reference cycles (threshold: 10.00 MB)


iteration 1 retaining 81901 removing 18099 variants
saving to zarr
Will attempt to store at: vo_agam_production/ag3_data_paper/gamb_colu_arab.pca_umap_input_alleles.zarr
type(pruned_downsampled_allele_counts): <class 'numpy.ndarray'>


In [19]:
cluster.close()

distributed.scheduler - INFO - Scheduler closing...
distributed.scheduler - INFO - Scheduler closing all comms
distributed.scheduler - INFO - Remove worker tcp://10.32.116.2:39037
distributed.core - INFO - Removing comms to tcp://10.32.116.2:39037
distributed.client - ERROR - Failed to reconnect to scheduler after 10.00 seconds, closing client
distributed.utils - ERROR - 
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/distributed/utils.py", line 663, in log_errors
    yield
  File "/opt/conda/lib/python3.7/site-packages/distributed/client.py", line 1296, in _close
    await gen.with_timeout(timedelta(seconds=2), list(coroutines))
concurrent.futures._base.CancelledError
distributed.utils - ERROR - 
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/distributed/utils.py", line 663, in log_errors
    yield
  File "/opt/conda/lib/python3.7/site-packages/distributed/client.py", line 1025, in _reconnect
    await self._close()