In [1]:
import traceback
import logging

logger = logging.getLogger('produce-allele-counts')
logger.setLevel(logging.DEBUG)

# create console handler with a higher log level
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)

# create formatter and add it to the handlers
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
ch.setFormatter(formatter)
# add the handlers to logger
logger.addHandler(ch)

In [2]:
from dask_kubernetes import KubeCluster
from dask.distributed import Client, progress
import dask.array as da
import numpy as np
import zarr
import allel
import sys
import ag3
import psutil
from humanize import naturalsize
import numba

### parameters

In [3]:
species_group = ["gamb_colu", "arab", "gamb_colu_arab"]
n_downsample = 100_000
max_allele = 3
random_seed = 42
n_workers = 10

#regions
region_3L_free =  '3L', 15000000, 41000000
region_3R_free =  '3R', 1, 37000000

regions = {"gamb_colu" : [region_3L_free, region_3R_free],
           "arab" : [region_3L_free,],
           "gamb_colu_arab"  : [region_3L_free,]
          }


sample_query = {"gamb_colu" : "species_gambcolu_arabiensis == 'gamb_colu'",
                "arab" : "species_gambcolu_arabiensis == 'arabiensis'",
                "gamb_colu_arab" : "species_gambcolu_arabiensis in ('gamb_colu', 'arabiensis', 'intermediate')"
               }

### cloud storage

In [4]:
# Data storage, uses about 34 MB
output_cloud_zarr_path_template = 'vo_agam_production/ag3_data_paper/{}.pca_umap_input_alleles.zarr'
# Writing the PCA data to the cloud will require the appropriate authentication and authorization.

import gcsfs
# UNCOMMENT THIS TO AUTHENTICATE. YOU ONLY NEED TO RUN THIS ONCE.
# After running this once, your authentication token should then be cached in `~/.gcs_tokens`
# Once you have authenticated, you should comment this out again to avoid re-authenticating.
# gcs_browser = gcsfs.GCSFileSystem(project='malariagen-jupyterhub', token='browser')

# Use `cache_timeout=0` to prevent object list cache, to avoid recreating map for Zarr consolidated metadata
auth_fs = gcsfs.GCSFileSystem(project='malariagen-jupyterhub', token='cache', cache_timeout=0)

In [6]:
for sp in species_group:

    # Check that the output's Zarr metadata file is not already on the cloud.
    # We don't want to accidentally overwrite or delete existing data, which might have been used in downstream analysis.
    # We don't simply check for the existence of the Zarr file here (i.e. output_cloud_zarr_path),
    # We might want to re-run the first parts of this notebook again,
    # so the Zarr store might legitimately exist but be incomplete.
    # The Zarr store is not considered complete until the Zarr metadata file (.zmetadata) is present and correct.
    # The final part of this notebook includes steps to create the Zarr metadata file and then validate it.

    output_cloud_zarr_metadata_path = f'{output_cloud_zarr_path_template.format(sp)}/.zmetadata'
    print(f'Checking for {output_cloud_zarr_metadata_path}')
    assert not auth_fs.exists(output_cloud_zarr_metadata_path)

Checking for vo_agam_production/ag3_data_paper/gamb_colu.pca_umap_input_alleles.zarr/.zmetadata
Checking for vo_agam_production/ag3_data_paper/arab.pca_umap_input_alleles.zarr/.zmetadata
Checking for vo_agam_production/ag3_data_paper/gamb_colu_arab.pca_umap_input_alleles.zarr/.zmetadata


### functions

In [7]:
def show_memory():
      vm = psutil.virtual_memory()
      print(f"{naturalsize(vm.used)} used, {naturalsize(vm.available)} available, {naturalsize(vm.total)} total")
show_memory()

1.9 GB used, 13.6 GB available, 15.8 GB total


In [8]:
@numba.njit(numba.int8[:, :](numba.int8[:, :, :], numba.int8), nogil=True)
def numpy_genotype_tensor_to_allele_counts_melt(gt, max_allele):
    # Create an array of zeros (for defaults) with the same number of colums (variants) as the genotype array but a row for each allele, for each sample
    out = np.zeros((gt.shape[0] * (max_allele + 1), gt.shape[1]), dtype=np.int8)
    # For each row (sample) in the genotype array
    for i in range(gt.shape[0]):
        # For each column (variant) in the genotype array
        for j in range(gt.shape[1]):
            # For each allele in the genotype array 
            for k in range(gt.shape[2]):
                allele = gt[i, j, k]
                # If the value in the genotype array at this row and colum and 3rd dimension (i.e. the allele value) is between 0 and max_allele  
                if 0 <= allele <= max_allele:
                    # Increment the value of the `out` array at the row corresponding to this allele for this sample, at the corresponding variant column
                    out[(i * (max_allele + 1)) + allele, j] += 1
    return out

In [9]:
# Define a function to apply the above function chunk-wise
def dask_genotype_tensor_to_allele_counts_melt(gt, max_allele):
    # Determine output chunks - change axis 0; preserve axis 1; drop axis 2.
    dim0_chunks = tuple(np.array(gt.chunks[0]) * (max_allele + 1))
    chunks = (dim0_chunks, gt.chunks[1])
    
    return gt.map_blocks(
        numpy_genotype_tensor_to_allele_counts_melt,
        max_allele=max_allele,
        chunks=chunks,
        dtype="i1",
        drop_axis=2,
    )

In [10]:
# Slowest point- difficult to improve with Dask
def ld_prune(gn, size=500, step=200, threshold=.1, n_iter=1):
    for i in range(n_iter):
        loc_unlinked = allel.locate_unlinked(gn, size=size, step=step, threshold=threshold)
        n = np.count_nonzero(loc_unlinked)
        n_remove = gn.shape[0] - n
        logger.info(f'iteration {i+1}: retaining {n}, removing {n_remove} variants')
        gn = gn.compress(loc_unlinked, axis=0)
    return gn

## run

In [11]:
# cluster setup
cluster = KubeCluster()
cluster.scale_up(n_workers)

distributed.scheduler - INFO - Clear task state
distributed.scheduler - INFO -   Scheduler at: tcp://10.33.164.221:40769
distributed.scheduler - INFO -   dashboard at:                     :8787


In [12]:
# dask client setup
client = Client(cluster)
client

distributed.scheduler - INFO - Receive client connection: Client-a5f7bbc2-095a-11eb-876c-c6bfdf39bfce
distributed.core - INFO - Starting established connection


0,1
Client  Scheduler: tcp://10.33.164.221:40769  Dashboard: /user/nicholasharding/proxy/8787/status,Cluster  Workers: 0  Cores: 0  Memory: 0 B


In [13]:
# grab data from release
v3 = ag3.release_data()

In [14]:
sample_sets = v3.all_wild_sample_sets

In [15]:
# load_metadata
metadata = v3.load_sample_set_metadata(sample_sets)

distributed.scheduler - INFO - Register tcp://10.33.75.9:45439
distributed.scheduler - INFO - Starting worker compute stream, tcp://10.33.75.9:45439
distributed.core - INFO - Starting established connection
distributed.scheduler - INFO - Register tcp://10.33.81.4:34107
distributed.scheduler - INFO - Starting worker compute stream, tcp://10.33.81.4:34107
distributed.core - INFO - Starting established connection
distributed.scheduler - INFO - Register tcp://10.32.247.4:43017
distributed.scheduler - INFO - Starting worker compute stream, tcp://10.32.247.4:43017
distributed.core - INFO - Starting established connection
distributed.scheduler - INFO - Register tcp://10.33.74.8:43411
distributed.scheduler - INFO - Starting worker compute stream, tcp://10.33.74.8:43411
distributed.core - INFO - Starting established connection
distributed.scheduler - INFO - Register tcp://10.32.254.7:34109
distributed.scheduler - INFO - Starting worker compute stream, tcp://10.32.254.7:34109
distributed.core - 

In [16]:
def build_allele_counts(selection):
    
    logger.info(f"building allele counts for {selection}")
    
    sample_loc = metadata.eval(sample_query[selection]).values
    assert sample_loc.sum() > 0, "Must select >0 samples"

    genotypes = []
    site_filters = []
        
    for reg in (regions[selection]):      
        chrom, start, stop = reg
        logger.info(f"{chrom}: {start}-{stop}")

        pos = allel.SortedIndex(v3.load_variants(chrom))
        ix = pos.locate_range(start, stop)
        gt = v3.load_sample_set_calldata(chrom, sample_set=sample_sets)[ix]
        mask = v3.load_mask(chrom, selection)[ix]
        g = da.compress(sample_loc, gt, axis=1)

        genotypes.append(g)
        site_filters.append(mask)

    genotype_data = da.concatenate(genotypes, axis=0)
    site_filters_data = da.concatenate(site_filters, axis=0)

    melted_allele_counts = dask_genotype_tensor_to_allele_counts_melt(gt=genotype_data, max_allele=max_allele)

    # Get the number of genotyped samples
    number_of_samples = genotype_data.shape[1]
    logger.info(f"Number of samples {number_of_samples}")

    # Sum the allele counts
    allele_count_sums = da.sum(melted_allele_counts, axis=1, dtype='int16')

    # Determine which alleles meet the criteria, and record as a Boolean array.
    loc_midfreq_alleles = (allele_count_sums >= 2) & (allele_count_sums <= ((number_of_samples * 2) - 2))

    # Transform the Boolean site_filter index into the same space as the melted allele counts.
    loc_accessible = da.repeat(site_filters_data, max_allele + 1) # 4 alleles

    # Check that loc_accessible is the same shape as loc_midfreq_alleles
    assert loc_accessible.shape == loc_midfreq_alleles.shape

    # Determine the corresponding array indices for all of the mid-frequency alleles that are accessible
    # We use the '&' to choose sites that meet the critera AND are accessible.
    midfreq_alleles_as_indices = da.nonzero(loc_midfreq_alleles & loc_accessible)[0]

    # Compute (and bring into client memory) the midfreq_alleles_as_indices
    ix_select = midfreq_alleles_as_indices.compute()
    
    # Set/reset the random seed used for random variant selection
    # to ensure that we always select the same set of random variants
    np.random.seed(random_seed)

    logger.info(f"Performing downsampling to {n_downsample}")
    
    # Randomly choose `n_downsample_variants` items from the array of accessible mid-frequency allele indices
    downsampled_site_indices = np.random.choice(
        ix_select, 
        size=n_downsample, 
        replace=False)

    # Sort the indices to allow contiguous parsing
    downsampled_site_indices.sort()

    # From the melted_allele_counts array, take the corresponding indices
    downsampled_allele_counts = da.take(melted_allele_counts, downsampled_site_indices, axis=0)   
    computed_downsampled_allele_counts = downsampled_allele_counts.compute()
    
    
    logger.info("Performing LD pruning.")
    # prune
    pruned_downsampled_allele_counts = ld_prune(computed_downsampled_allele_counts)
    
    # finally save to zarr...
    output_cloud_zarr_path = output_cloud_zarr_path_template.format(selection)
    logger.info(f"Storing at {output_cloud_zarr_path}")

    # Sometimes errors with `overwrite=True`, sometimes errors without, when dir not exist
    # Keep the zarr_store for zarr.consolidate_metadata(zarr_store)
    zarr_store = auth_fs.get_mapper(output_cloud_zarr_path)
    zarr_group = zarr.group(zarr_store)

    # Check the data type
    logger.info(f"type(pruned_downsampled_allele_counts): {type(pruned_downsampled_allele_counts)}")
    
    # overwrite=True, otherwise `ValueError: path 'allele_counts_pca_ready' contains an array`
    zarr_group.create_dataset("allele_counts_pca_ready", data=pruned_downsampled_allele_counts, overwrite=True)
    
    # Check the stored data has all its chunks initialized
    assert zarr_group['allele_counts_pca_ready'].nchunks_initialized == zarr_group['allele_counts_pca_ready'].nchunks
    
    # Check the store contains the expected arrays
    assert 'allele_counts_pca_ready/.zarray' in zarr_store
    
    # Consolidate the Zarr metatdata
    zarr.consolidate_metadata(zarr_store)
    
    # Check the consolidated Zarr metadata
    zarr_consolidated_metadata = zarr.open_consolidated(zarr_store)
    assert list(zarr_consolidated_metadata.keys()) == ['allele_counts_pca_ready']

In [17]:
for s in species_group:
    build_allele_counts(s)

2020-10-08 11:37:36,538 - produce-allele-counts - INFO - building allele counts for gamb_colu
2020-10-08 11:37:36,547 - produce-allele-counts - INFO - 3L: 15000000-41000000
2020-10-08 11:37:44,433 - produce-allele-counts - INFO - 3R: 1-37000000
2020-10-08 11:37:49,683 - produce-allele-counts - INFO - Number of samples 2415
distributed.core - INFO - Event loop was unresponsive in Scheduler for 26.27s.  This is often caused by long-running GIL-holding functions or moving large chunks of data. This can cause timeouts and instability.
distributed.utils_perf - INFO - full garbage collection released 12.75 MB from 0 reference cycles (threshold: 10.00 MB)
2020-10-08 11:44:35,037 - produce-allele-counts - INFO - Performing downsampling to 100000
distributed.core - INFO - Event loop was unresponsive in Scheduler for 27.59s.  This is often caused by long-running GIL-holding functions or moving large chunks of data. This can cause timeouts and instability.
2020-10-08 11:50:22,279 - produce-allele

In [18]:
cluster.close()

distributed.scheduler - INFO - Scheduler closing...
distributed.scheduler - INFO - Scheduler closing all comms
distributed.scheduler - INFO - Remove worker tcp://10.32.247.4:43017
distributed.core - INFO - Removing comms to tcp://10.32.247.4:43017
distributed.scheduler - INFO - Remove worker tcp://10.33.75.9:45439
distributed.core - INFO - Removing comms to tcp://10.33.75.9:45439
distributed.client - ERROR - Failed to reconnect to scheduler after 10.00 seconds, closing client
distributed.utils - ERROR - 
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/distributed/utils.py", line 663, in log_errors
    yield
  File "/opt/conda/lib/python3.7/site-packages/distributed/client.py", line 1296, in _close
    await gen.with_timeout(timedelta(seconds=2), list(coroutines))
concurrent.futures._base.CancelledError
distributed.utils - ERROR - 
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/distributed/utils.py", line 663, in log_e