In [10]:
import pandas as pd
import numpy as np
import dask.dataframe as dd
import dask.array as da
import os
import os.path as osp
import sgkit_plink 
from dotenv import load_dotenv; load_dotenv();

In [2]:
# Path prefix of the PLINK fileset for this QC step, rooted at the
# directory named by the GWAS_TUTORIAL_DIR environment variable.
path = os.path.join(os.environ["GWAS_TUTORIAL_DIR"], "HapMap_3_r3_1")
path

'/home/jovyan/work/data/gwas/tutorial/1_QC_GWAS/HapMap_3_r3_1'

In [88]:
%%time
# Load the PLINK fileset at `path` (tab-separated .bim, space-separated .fam).
# The module is imported as `sgkit_plink`; there is no `sgkp` alias.
ds = sgkit_plink.read_plink(path, bim_sep='\t', fam_sep=' ', bim_int_contig=False)
ds

CPU times: user 1.33 s, sys: 127 ms, total: 1.46 s
Wall time: 1.45 s


In [89]:
# Handle on the genotype call array for interactive inspection.
x = ds["call/genotype"]

In [157]:
import xarray as xr
from xarray import Dataset
from typing_extensions import Literal

# The two dimensions that QC statistics can be reduced over.
Dimension = Literal["samples", "variants"]

def _swap(dim: Dimension) -> Dimension:
    return 'samples' if dim == 'variants' else 'variants'

def call_rate(ds: Dataset, dim: Dimension) -> Dataset:
    """Count non-missing calls and the fraction they make up, reducing over ``dim``.

    A call counts as "called" only when none of its alleles is flagged in
    ``call/genotype_mask``.

    Parameters
    ----------
    ds
        Dataset with a ``call/genotype_mask`` variable that has a ``ploidy``
        dimension.
    dim
        Dimension to reduce over: ``'samples'`` gives per-variant values,
        ``'variants'`` gives per-sample values.

    Returns
    -------
    Dataset
        ``<prefix>/n_called`` and ``<prefix>/call_rate``, where ``<prefix>`` is
        ``variant`` or ``sample`` (the singular of the dimension that is kept).
    """
    # 'variants' -> 'variant', 'samples' -> 'sample'
    odim = _swap(dim)[:-1]
    called = ~ds['call/genotype_mask'].any(dim='ploidy')
    n_called = called.sum(dim=dim)
    # Use `sizes` for the length lookup: newer xarray versions deprecate
    # treating `Dataset.dims` as a name -> length mapping.
    return xr.Dataset({
        f'{odim}/n_called': n_called,
        f'{odim}/call_rate': n_called / ds.sizes[dim],
    })

def genotype_count(ds: Dataset, dim: Dimension) -> Dataset:
    """Count genotype classes, reducing over ``dim``.

    A call is excluded from every count when any of its alleles is flagged in
    ``call/genotype_mask``, so partially-missing calls contribute nothing.

    Parameters
    ----------
    ds
        Dataset with ``call/genotype`` and ``call/genotype_mask`` variables
        that share a ``ploidy`` dimension.
    dim
        Dimension to sum over: ``'samples'`` gives per-variant counts,
        ``'variants'`` gives per-sample counts.

    Returns
    -------
    Dataset
        ``<prefix>/n_het``, ``<prefix>/n_hom_ref``, ``<prefix>/n_hom_alt`` and
        ``<prefix>/n_non_ref``, where ``<prefix>`` is ``variant`` or ``sample``
        (the singular of the dimension that is kept).
    """
    # 'variants' -> 'variant', 'samples' -> 'sample'
    odim = _swap(dim)[:-1]
    gt = ds['call/genotype']
    # A call is missing if any of its alleles is masked.
    missing = ds['call/genotype_mask'].any(dim='ploidy')

    # At least one allele is non-reference.
    non_ref = (gt > 0).any(dim='ploidy')
    # Every allele is non-reference and equal to the first allele. The first
    # allele is selected by dimension name so this does not depend on
    # `ploidy` being the last axis.
    first_allele = gt.isel(ploidy=0)
    hom_alt = ((gt > 0) & (gt == first_allele)).all(dim='ploidy')
    # Every allele is the reference allele.
    hom_ref = (gt == 0).all(dim='ploidy')
    # Neither homozygous class; only meaningful for called genotypes, which is
    # why missing calls are zeroed out before counting.
    het = ~(hom_alt | hom_ref)

    def count(flags: xr.DataArray) -> xr.DataArray:
        # Treat missing calls as False so they never contribute to a count.
        return xr.where(missing, False, flags).sum(dim=dim)

    return xr.Dataset({
        f'{odim}/n_het': count(het),
        f'{odim}/n_hom_ref': count(hom_ref),
        f'{odim}/n_hom_alt': count(hom_alt),
        f'{odim}/n_non_ref': count(non_ref),
    })

def allele_count(ds: Dataset) -> Dataset:
    """Count alleles per variant.

    Each allele slot is handled independently, so a partially-missing call
    still contributes its non-missing allele(s).

    Parameters
    ----------
    ds
        Dataset with ``call/genotype`` and ``call/genotype_mask`` over
        ``variants``, ``samples`` and ``ploidy``, plus an ``alleles`` dimension.

    Returns
    -------
    Dataset
        - ``variant/allele_count``: (variants, alleles) number of non-missing
          allele slots equal to each index in ``range(n_alleles)``; indices at
          or above ``n_alleles`` are not counted.
        - ``variant/allele_total``: (variants,) number of non-missing slots.
        - ``variant/allele_frequency``: ``allele_count / allele_total``; the
          division yields NaN for variants with no non-missing alleles.
    """
    # Reduce over every (sample, ploidy) slot at once. This is equivalent to
    # stacking the two dims into a single "calls" dim and summing over it,
    # but avoids constructing a stacked MultiIndex.
    call_dims = ['samples', 'ploidy']
    gt = ds['call/genotype']
    mask = ds['call/genotype_mask']

    # Number of non-missing alleles (works with partial calls).
    an = (~mask).sum(dim=call_dims)
    # Number of occurrences of each allele index, ignoring masked slots.
    ac = xr.concat(
        [
            xr.where(mask, 0, gt == i).sum(dim=call_dims)
            for i in range(ds.sizes['alleles'])
        ],
        dim='alleles',
    ).transpose('variants', 'alleles')

    return xr.Dataset({
        'variant/allele_count': ac,
        'variant/allele_total': an,
        'variant/allele_frequency': ac / an,
    })

def variant_stats(ds: Dataset) -> Dataset:
    """Merge every per-variant statistic into a single Dataset.

    Combines call rate and genotype counts reduced over samples with the
    per-variant allele counts.
    """
    per_variant = [
        call_rate(ds, dim='samples'),
        genotype_count(ds, dim='samples'),
        allele_count(ds),
    ]
    return xr.merge(per_variant)

def sample_stats(ds: Dataset) -> Dataset:
    """Merge per-sample call-rate and genotype-count statistics."""
    reducers = (call_rate, genotype_count)
    return xr.merge([reducer(ds, dim='variants') for reducer in reducers])

In [158]:
# Render xarray objects as plain text wrapped at 80 columns. The trailing
# semicolon keeps the returned `set_options` object out of the cell output.
xr.set_options(display_width=80, display_style='text');

<xarray.core.options.set_options at 0x7f23ba1d9890>

In [159]:
# Re-read the fileset so this cell is self-contained, then materialise the
# combined per-variant and per-sample statistics.
ds = sgkit_plink.read_plink(
    path,
    bim_sep='\t',
    fam_sep=' ',
    bim_int_contig=False,
)
xr.merge([variant_stats(ds), sample_stats(ds)]).compute()

In [160]:
# Per-variant statistics for the current `ds`.
dss = ds.pipe(variant_stats)
dss

In [161]:
#ds_ac['variant/allele_frequency'].values[:10]

In [162]:
import pytest
import xarray as xr
import numpy as np
from numpy import ndarray
from sgkit.testing import simulate_genotype_call_dataset
from sgkit.stats.aggregation import allele_count

def get_dataset(calls, **kwargs):
    """Build a simulated genotype dataset whose calls are replaced by ``calls``.

    Parameters
    ----------
    calls
        Array-like of allele indices laid out as (variant, sample, ploidy);
        negative values mark missing alleles.
    **kwargs
        Forwarded to ``simulate_genotype_call_dataset``. Unless given
        explicitly, ``n_ploidy`` is taken from the third axis of ``calls`` and
        ``n_allele`` is made large enough to hold the largest allele index.

    Returns
    -------
    Dataset
        Simulated dataset whose ``call/genotype`` equals ``calls`` and whose
        ``call/genotype_mask`` is ``calls < 0``.
    """
    calls = np.asarray(calls)
    if calls.ndim == 3:
        # Match the simulated ploidy to the supplied calls so replacing
        # `call/genotype` below does not conflict with the `ploidy` size.
        kwargs.setdefault("n_ploidy", calls.shape[2])
    # Size the `alleles` dimension to cover every allele index used;
    # otherwise per-allele statistics that iterate over that dimension would
    # silently skip the higher indices. Keep the previous minimum of 2.
    max_allele = int(calls.max()) if calls.size else 0
    kwargs.setdefault("n_allele", max(2, max_allele + 1))
    ds = simulate_genotype_call_dataset(
        n_variant=calls.shape[0], n_sample=calls.shape[1], **kwargs
    )
    dims = ds["call/genotype"].dims
    ds["call/genotype"] = xr.DataArray(calls, dims=dims)
    ds["call/genotype_mask"] = xr.DataArray(calls < 0, dims=dims)
    return ds

In [163]:
# One row per variant; each inner pair is one sample's two alleles and
# negative values are masked as missing.
ds = get_dataset(calls=[
    # homozygous calls: 2/2, 1/1, 0/0
    [[2, 2], [1, 1], [0, 0]],
    # heterozygous calls
    [[0, 1], [1, 2], [2, 1]],
    # first allele missing in every sample
    [[-1, 0], [-1, 1], [-1, 2]],
    # fully missing calls
    [[-1, -1], [-1, -1], [-1, -1]],
])

In [165]:
print(
    genotype_count(ds, dim="samples")
    .to_dataframe()
    .to_markdown()
)

|   variants |   variant/n_het |   variant/n_hom_ref |   variant/n_hom_alt |   variant/n_non_ref |
|-----------:|----------------:|--------------------:|--------------------:|--------------------:|
|          0 |               0 |                   1 |                   2 |                   2 |
|          1 |               3 |                   0 |                   0 |                   3 |
|          2 |               0 |                   0 |                   0 |                   0 |
|          3 |               0 |                   0 |                   0 |                   0 |


In [85]:
def apply(ds, fn):
    """Return ``ds`` merged with the variables produced by ``fn(ds)``."""
    derived = fn(ds)
    return ds.merge(derived)
apply(ds, allele_count)