In [None]:
import numpy as np
import pysam
import matplotlib.pyplot as plt
import time

In [None]:
# Select matplotlib's inline backend so figures are rendered in the notebook
# output rather than in a separate window.
%matplotlib inline

In [None]:
# Notebook-wide plotting defaults: figure size and base font size.
plt.rcParams.update({
    'figure.figsize': [8, 4],
    'font.size': 12,
})

In [None]:
def filter_reads(bam, include=0, exclude=0):
    """Yield the reads of ``bam`` whose ``flag`` passes a bitmask filter.

    Args:
        bam: iterable of objects exposing an integer ``flag`` attribute.
        include: bitmask of flag bits that must ALL be set on a read.
        exclude: bitmask of flag bits of which NONE may be set on a read.

    Yields:
        Each read that satisfies both masks, in input order.
    """
    for read in bam:
        bits = read.flag
        has_all_required = (bits & include) == include
        has_any_forbidden = (bits & exclude) != 0
        if has_any_forbidden or not has_all_required:
            continue
        yield read


def gen_read_pairs(bam):
    """Pair up reads that share a query name and yield them as (read1, read2).

    The first read seen for a name is parked until another read with the same
    name arrives.  The pair is then emitted in (read1, read2) order, decided
    by the flags of the read that arrived second.  Reads whose partner never
    shows up are never yielded.

    Args:
        bam: iterable of reads exposing ``qname``, ``is_read1`` and
            ``is_read2``.

    Yields:
        2-tuples of reads sharing a ``qname``.
    """
    waiting = {}
    for read in bam:
        name = read.qname
        if name not in waiting:
            waiting[name] = read
            continue
        # A read with this name is already parked: emit the pair.  If the
        # newcomer is flagged as neither read1 nor read2 it is dropped and
        # the parked read keeps waiting.
        if read.is_read1:
            yield read, waiting.pop(name)
        elif read.is_read2:
            yield waiting.pop(name), read


def generate_bins(bam_header, bin_size):
    """Yield ``(reference_name, bin_bounds)`` for every reference in a header.

    ``bin_bounds`` is a list of ``(start, end)`` tuples that tile
    ``[0, reference_length]`` in steps of ``bin_size``.  The end of one bin is
    the start of the next, and the final bin is truncated at the reference
    length.  A reference of length 0 gets an empty list.

    Args:
        bam_header: object exposing ``references`` (iterable of names) and
            ``get_reference_length(name)``.
        bin_size: positive bin width, in reference coordinates.

    Yields:
        ``(ref, bin_bounds)`` tuples, one per reference, in header order.

    Raises:
        ValueError: if ``bin_size`` is not positive (raised when the
            generator is first advanced).
    """
    if bin_size <= 0:
        raise ValueError(f'bin_size must be positive, got {bin_size}')
    for ref in bam_header.references:
        ref_len = bam_header.get_reference_length(ref)
        # Clamp each end to the reference length so the last bin never
        # overshoots; the number of bins is then ceil(ref_len / bin_size),
        # which matches indexing a position with ``pos // bin_size``.
        bin_bounds = [
            (start, min(start + bin_size, ref_len))
            for start in range(0, ref_len, bin_size)
        ]
        yield ref, bin_bounds


def first(it):
    """Return the first item of ``it``.

    Raises:
        Exception: if ``it`` yields no items.
    """
    iterator = iter(it)
    try:
        item = next(iterator)
    except StopIteration:
        raise Exception(f'Expected at least one item, got {it}')
    else:
        return item


def head(it, n=5):
    """Yield at most the first ``n`` items of ``it``.

    Exactly the yielded items are consumed from ``it``: when ``it`` is an
    iterator, the item following the last one yielded is still available to
    the caller afterwards.

    Args:
        it: any iterable.
        n: maximum number of items to yield; values below 1 yield nothing.

    Yields:
        The leading items of ``it``, in order.
    """
    if n < 1:
        return
    for i, val in enumerate(it, 1):
        yield val
        # Stop before fetching item i + 1 if it would exceed the limit, so
        # no extra element is pulled from the source and thrown away.
        if i + 1 > n:
            break


def one(it):
    """Return the sole item of ``it``.

    Raises:
        Exception: if ``it`` yields no items, or more than one.
    """
    iterator = iter(it)
    # Built up front: the same text is used for both failure cases.
    message = f'Expected one item, got {it}'
    try:
        only = next(iterator)
    except StopIteration:
        raise Exception(message)
    # A second successful fetch means there was more than one item.
    exhausted = object()
    if next(iterator, exhausted) is exhausted:
        return only
    raise Exception(message)


def count_reads(bam_filename, binsize):
    """Count properly paired read pairs per fixed-width bin of each reference.

    Reads are filtered with ``include=2`` (SAM flag: proper pair) and
    ``exclude=1804`` (SAM flags: unmapped, mate unmapped, not primary,
    QC fail, duplicate), then paired by name.  Each pair is assigned to the
    bin containing the start position of its forward-strand read.

    Args:
        bam_filename: path of the alignment file, opened with
            ``pysam.AlignmentFile``.
        binsize: bin width in reference coordinates.

    Returns:
        dict mapping each reference name in the header to a ``np.uint64``
        array with one count per bin.  A file with no qualifying pairs
        gives all-zero arrays.
    """
    with pysam.AlignmentFile(bam_filename) as bam:
        bins = dict(generate_bins(bam.header, binsize))
        counts = {
            contig: np.zeros(len(bins[contig]), dtype=np.uint64)
            for contig in bins
        }
        read_pairs = gen_read_pairs(
            filter_reads(bam, include=2, exclude=1804))

        # Successive pairs usually fall in the same bin, so tally them in a
        # plain Python int and only touch the numpy array when the bin (or
        # contig) changes.
        curr_contig = None
        curr_bin = None
        pending = 0

        for read_pair in read_pairs:
            # Forward-strand read first (is_reverse False sorts before True).
            f_read, _ = sorted(read_pair, key=lambda read: read.is_reverse)
            # proper pair reads, no need to validate
            count_contig = bam.get_reference_name(f_read.reference_id)
            count_bin = f_read.pos // binsize

            if count_bin != curr_bin or count_contig != curr_contig:
                if pending:
                    # Add rather than assign, so a bin that is revisited
                    # keeps what it had already accumulated.
                    counts[curr_contig][curr_bin] += np.uint64(pending)
                curr_contig = count_contig
                curr_bin = count_bin
                pending = 0
            pending += 1

        # Flush the tally for the last bin seen.
        if pending:
            counts[curr_contig][curr_bin] += np.uint64(pending)

        return counts

In [None]:
# Input alignment file, given relative to the notebook's working directory.
# The string is split at the directory / file-name boundary for readability.
in_bam = (
    '../../data/ftp.1000genomes.ebi.ac.uk/vol1/ftp/phase3/data/HG00096/alignment/'
    'HG00096.mapped.ILLUMINA.bwa.GBR.low_coverage.20120522.bam'
)

In [None]:
# Disabled: iterates over every filtered read pair in `in_bam` and discards it.
# with pysam.AlignmentFile(in_bam) as bam:
#     for r1, r2 in gen_read_pairs(filter_reads(bam, include=2, exclude=1804)):
#         pass

In [None]:
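# Disabled: bar chart of the number of bins per contig.
# NOTE(review): no notebook-level `bins` is defined in the visible cells, so
# this cannot run as written.
# plt.bar(list(bins.keys()), np.array([len(bins[contig]) for contig in bins]))
# plt.xticks(rotation=90)
# plt.show()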

In [None]:
def first(it):
    """Return the first item of ``it``.

    Raises:
        Exception: if ``it`` yields no items.
    """
    iterator = iter(it)
    try:
        item = next(iterator)
    except StopIteration:
        raise Exception(f'Expected at least one item, got {it}')
    else:
        return item


def head(it, n=5):
    """Yield at most the first ``n`` items of ``it``.

    Exactly the yielded items are consumed from ``it``: when ``it`` is an
    iterator, the item following the last one yielded is still available to
    the caller afterwards.

    Args:
        it: any iterable.
        n: maximum number of items to yield; values below 1 yield nothing.

    Yields:
        The leading items of ``it``, in order.
    """
    if n < 1:
        return
    for i, val in enumerate(it, 1):
        yield val
        # Stop before fetching item i + 1 if it would exceed the limit, so
        # no extra element is pulled from the source and thrown away.
        if i + 1 > n:
            break


def one(it):
    """Return the sole item of ``it``.

    Raises:
        Exception: if ``it`` yields no items, or more than one.
    """
    iterator = iter(it)
    # Built up front: the same text is used for both failure cases.
    message = f'Expected one item, got {it}'
    try:
        only = next(iterator)
    except StopIteration:
        raise Exception(message)
    # A second successful fetch means there was more than one item.
    exhausted = object()
    if next(iterator, exhausted) is exhausted:
        return only
    raise Exception(message)

In [None]:
# Run the binned count over the whole input with a bin size of 50_000 and
# report how long it took.  perf_counter() is used for the interval because,
# unlike time.time(), it is monotonic and unaffected by system clock changes.
start_time = time.perf_counter()
readcounts = count_reads(in_bam, 50_000)
end_time = time.perf_counter()
print(f'time in seconds: {end_time - start_time}')

In [None]:
# Per-bin read-pair counts for contig '1' (one array entry per bin).
counts = readcounts['1']

In [None]:
# Distribution of the per-bin read-pair counts held in `counts`.
plt.hist(counts, bins=100)
plt.xlabel('read pairs per bin')
plt.ylabel('number of bins')
plt.title('Distribution of per-bin read-pair counts')
plt.show()

In [None]:
# One figure per contig: the normalised per-bin counts (blue dots) and their
# deviation from the contig mean (red line).
# The loop uses its own variable names so that it does not rebind the
# notebook-level `counts` / `normed_counts` that other cells rely on.
for contig, contig_counts in readcounts.items():
    total = contig_counts.sum()
    if total == 0:
        # No pairs on this contig: normalising would divide by zero and
        # plot nothing but NaNs.
        continue
    contig_normed = contig_counts / total
    baseline = contig_normed.mean()
    fig, ax = plt.subplots()
    ax.plot(contig_normed, 'C0.', alpha=.4)
    ax.plot(contig_normed - baseline, 'r', alpha=.4)
    ax.set_title(contig)
    plt.show()

xarray indexing documentation: http://xarray.pydata.org/en/stable/indexing.html

SAM format specification (defines the FLAG bits used for read filtering): https://samtools.github.io/hts-specs/SAMv1.pdf

Related xarray issue: https://github.com/pydata/xarray/issues/1603

In [None]:
# Rescale the per-bin counts in `counts` to fractions of their total, so the
# values sum to 1 (provided the total is non-zero).
normed_counts = counts / counts.sum()

In [None]:
# Display the mean of the values in `normed_counts`.
normed_counts.mean()

In [None]:
# Distribution of the normalised per-bin values, with the mean and median
# marked.  The two lines are labelled so they can be told apart.
plt.hist(normed_counts, bins=100)
plt.axvline(normed_counts.mean(), c='r', label='mean')
plt.axvline(np.median(normed_counts), c='y', label='median')
plt.xlabel('fraction of read pairs per bin')
plt.ylabel('number of bins')
plt.legend()
plt.show()