## Replicate QC

This notebook takes a merged sampleset, i.e. the GT arrays of all samples in a set after they have been merged by the `combine-zarr-callset` pipeline, and computes pairwise distances between every pair of samples.

Each contig is handled separately, so the resulting outputs have dimensions: number of contigs x number of sample pairs (`npairs`).

NB: We restrict the analysis to positions that are biallelic in the phase 2 callset.

## Use

0. Log onto `datalab.malariagen.net` and clone the repo with submodules.

1. Copy this notebook and the `replicate-qc-analysis` notebook to the tracking directory of the sampleset on which you wish to perform replicate QC.

2. Ensure the `merged.zarr` is in the expected location on the datalab server.

3. Run this notebook to compute the pairwise distances.

4. Run the other notebook to discard samples that fail QC.

In [None]:
import zarr
import allel
import pandas as pd
from pathlib import Path

In [None]:
# The sampleset is named after the current working directory, i.e. the
# tracking directory this notebook was copied into; the bare name displays it.
sampleset = Path.cwd().name
sampleset

In [None]:
# Location of this sampleset's zarr callset within GCS storage.
storage_path = 'ag1000g-release/observatory/%s/callset.zarr' % sampleset

In [None]:
# Path to the sample manifest of this sampleset.
manifest_fn = "/gcs/observatory/" + sampleset + "/manifest"

In [None]:
!pip install dask-distance
import dask_distance as dadist
import scipy.spatial.distance as dist
import os

In [None]:
import dask.array as da
import numpy as np

In [None]:
# NOTE(review): `chunksize` is not referenced by any of the cells shown; the
# rechunking in the distance loop derives its own chunk sizes from the data.
# Confirm before removing.
chunksize = 30000

In [None]:
def trans_d(block, metric="euclidean"):
    """Return the condensed pairwise distances between rows of ``block`` as a column.

    Thin wrapper around ``scipy.spatial.distance.pdist`` that reshapes the
    1-D condensed result into shape ``(n_pairs, 1)`` so it can be used as a
    ``dask.array.map_blocks`` function producing one column per block.

    Parameters
    ----------
    block : array-like, 2-D
        Observations in rows.
    metric : str or callable
        Any metric accepted by ``pdist``.
    """
    condensed = dist.pdist(block, metric=metric)
    return condensed[:, None]

In [None]:
def count_nmissing(X1, X2):
    """Count positions at which both vectors hold a called (non-negative) value.

    Missing calls are encoded as negative values (the distance loop fills
    them with -1). Despite its name, this returns the number of positions
    that are *non*-missing in both ``X1`` and ``X2``; it is used as a
    ``pdist`` metric to obtain the per-pair denominator.
    """
    first = np.array(X1)
    second = np.array(X2)
    called_in_both = np.logical_and(first >= 0, second >= 0)
    return np.sum(called_in_both)

In [None]:
def cib_dist_nmissing(X1, X2):
    """Cityblock (L1) distance between two vectors, ignoring missing positions.

    A position is dropped when either vector holds a negative (missing)
    value there; the remaining positions are compared with
    ``scipy.spatial.distance.cityblock``, i.e. ``sum(|X1 - X2|)``.
    """
    a = np.array(X1)
    b = np.array(X2)
    keep = (a >= 0) & (b >= 0)
    return dist.cityblock(a[keep], b[keep])

In [None]:
# Anonymous, read-only GCS filesystem for the project hosting the callset,
# plus a mapping over this sampleset's zarr store that zarr can read from.
import gcsfs

gcs_bucket_fs = gcsfs.GCSFileSystem(
    project='malariagen-jupyterhub',
    access='read_only',
    token='anon',
)
store = gcsfs.mapping.GCSMap(
    storage_path,
    gcs=gcs_bucket_fs,
    create=False,
    check=False,
)

In [None]:
# Root zarr group of this sampleset's merged callset, read through the GCS store.
calldata = zarr.Group(store)

In [None]:
# Sample manifest, read as CSV; it must contain a `sample_name` column.
df = pd.read_csv(manifest_fn)

In [None]:
# Sample names, in manifest order.
# NOTE(review): pair indices are built from this list's length, but distances
# are computed over the sample axis of the GT arrays — this assumes both
# contain the same samples in the same order. Confirm.
samples = df["sample_name"].tolist()

In [None]:
# Start a dask cluster on Kubernetes with 40 workers; the bare `cluster`
# expression displays it in the notebook.
from dask_kubernetes import KubeCluster
cluster = KubeCluster(n_workers=40)
cluster

In [None]:
# Connect a dask distributed client to the cluster; the bare `client`
# expression displays it in the notebook.
# NOTE(review): `progress` is imported but not used in the cells shown.
from dask.distributed import Client, progress
client = Client(cluster)
client

In [None]:
# Number of samples. The number of sample pairs compared is n * (n - 1) / 2,
# so run time grows roughly quadratically with the size of the sampleset.
len(samples)

In [None]:
# Read-only handles to:
# - the phase 2 AR1 callset, used to decide which sites are biallelic;
# - the all-sites (non-N) store, whose positions and alleles index the sites
#   axis of the merged GT arrays.
# Both are opened with mode="r": zarr.open_group otherwise defaults to "a",
# which would allow writes and silently create an empty group at a wrong path.
phase2_callset = zarr.open_group(
    "/gcs/phase2/AR1/variation/main/zarr2/ag1000g.phase2.ar1", mode="r")
called_sites = zarr.open_group("/gcs/observatory/ag.allsites.nonN.zarr.zip", mode="r")

In [None]:
# find biallelic sites
def find_phase2_bialleleic_sites(chrom, callset=None):
    """Return positions and alleles of the biallelic phase 2 sites on ``chrom``.

    A site is kept when no genotype call at that site uses an allele index
    greater than 1. Negative (missing) values never exceed 1, so they do not
    affect the test.

    Parameters
    ----------
    chrom : str
        Contig name, used as the top-level key into the callset.
    callset : zarr group, optional
        Callset to read from; defaults to the module-level ``phase2_callset``.

    Returns
    -------
    tuple of dask arrays
        ``(POS, ALT, REF)`` restricted to the biallelic sites. Note that ALT
        comes before REF.
    """
    if callset is None:
        callset = phase2_callset
    group = callset[chrom]

    g = allel.GenotypeDaskArray(group["calldata"]["genotype"])

    # TODO: restrict to PASS sites only.

    # Evaluate the mask eagerly so da.compress receives a concrete numpy array.
    biallelic = (g.max(axis=[1, 2]) <= 1).compute()

    d = {}
    for x in "POS", "REF", "ALT":
        v = group["variants"][x]
        # `chunks=` is the from_zarr keyword for block shape; the previous
        # `chunksize=` is not a from_zarr parameter and was not applied.
        dav = da.from_zarr(v, chunks=v.chunks)
        d[x] = da.compress(biallelic, dav, axis=0)

    return d["POS"], d["ALT"], d["REF"]

In [None]:
# Contigs to compute pairwise distances over; output rows follow this order.
contigs = ["3L", "3R", "2L", "2R", "X"]
config = {"pwd_contigs": contigs}

In [None]:
from itertools import combinations

In [None]:
# All unordered sample index pairs (i, j) with i < j. This is the same order
# as scipy's condensed distance vector from pdist, so the output columns
# follow this list.
pairs = list(combinations(range(len(samples)), 2))
npairs = len(pairs)

In [None]:
# Result arrays with one row per contig and one column per sample pair:
# `h` accumulates cityblock distances and `denom` the number of sites called
# in both samples of the pair.
h = np.zeros((len(contigs), npairs))
denom = np.zeros((len(contigs), npairs))

In [None]:
# For each contig, this loop:
# - restricts the merged GT array to phase 2 biallelic sites;
# - recodes the genotypes as alt-allele counts;
# - accumulates, for every sample pair, the cityblock distance (h) and the
#   number of sites called in both samples (denom).
alt_list = []
for cix, contig in enumerate(contigs):

    sites_pos = allel.SortedIndex(called_sites[contig]["variants/POS"])
    bial_pos, bial_alt, bial_ref = find_phase2_bialleleic_sites(contig)
    # Selection over the called sites that picks the phase 2 biallelic positions.
    loc = sites_pos.locate_keys(bial_pos)

    # Phase 2 allele order (REF followed by ALTs): the target of the remapping.
    alleles = da.hstack((bial_ref.reshape((-1, 1)), bial_alt))

    # reduce to biallelic sites all samples still
    print(contig, "compressing")
    gt_a = allel.GenotypeDaskArray(calldata[contig]["calldata/GT"]).compress(loc)

    print(contig, "remapping")
    mapping = allel.create_allele_mapping(
        ref=np.compress(loc, called_sites[contig]["variants/REF"]),
        alt=np.compress(loc, called_sites[contig]["variants/ALT"]),
        alleles=alleles)

    # Missing calls become -1 so the pdist metrics can skip them.
    count_alts = gt_a.map_alleles(mapping).to_n_alt(fill=-1)

    alt_list.append(count_alts)

    # Transpose and put all samples in a single chunk, because pdist needs
    # every sample in each block. Shrink the site chunks by the same factor
    # to keep block size roughly constant, but never below one site.
    ca = count_alts.T
    ratio = ca.shape[0] / ca.chunksize[0]
    newchunks = (ca.shape[0], max(1, int(ca.chunksize[1] / ratio)))
    ca = ca.rechunk(chunks=newchunks)
    nchunks = len(ca.chunks[1])

    # Each block yields an (npairs, 1) column of per-pair partial results;
    # declare the real output block shape rather than (1, 1).
    block_chunks = ((npairs,), (1,) * nchunks)

    D = ca.map_blocks(
        trans_d,
        metric=cib_dist_nmissing,
        chunks=block_chunks,
        dtype=float,
        drop_axis=(0, ),
        new_axis=(0, ))

    X = ca.map_blocks(
        trans_d,
        metric=count_nmissing,
        chunks=block_chunks,
        dtype=float,
        drop_axis=(0, ),
        new_axis=(0, ))

    # Sum over site blocks on the cluster. Evaluating both results in one
    # compute call shares the genotype loading and remapping work between them.
    h[cix], denom[cix] = da.compute(D.sum(axis=1), X.sum(axis=1))

In [None]:
# Save per-contig (rows) x per-pair (columns) results for the analysis notebook.
np.savez_compressed(
    "replicate-qc-%s" % sampleset,
    cityblock=h,
    nsites=denom,
)