In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pysam
import xarray as xr

In [None]:
%matplotlib inline

In [None]:
# Notebook-wide plotting defaults: figure size (width, height) and base font size.
plt.rcParams['figure.figsize'] = [8, 4]
plt.rcParams['font.size'] = 12

In [None]:
def filter_reads(bam, include=0, exclude=0):
    """Lazily yield the reads of ``bam`` whose flag passes a bitmask filter.

    Parameters
    ----------
    bam : iterable
        Iterable of objects exposing an integer ``flag`` attribute.
    include : int
        Bitmask of flag bits that must ALL be set on a read.
    exclude : int
        Bitmask of flag bits of which NONE may be set on a read.

    Yields
    ------
    The reads from ``bam``, in their original order, that satisfy both masks.
    """
    def _passes(flag):
        has_all_required = (flag & include) == include
        has_any_forbidden = (flag & exclude) != 0
        return has_all_required and not has_any_forbidden

    yield from (read for read in bam if _passes(read.flag))


def gen_read_pairs(bam):
    """Match reads by query name and yield them as ``(read1, read2)`` tuples.

    Reads are buffered by ``qname`` until a second read with the same name
    arrives; the pair is then emitted with the first-in-pair read first, and
    the buffered entry is dropped so memory stays proportional to the number
    of reads still waiting for a mate.

    Parameters
    ----------
    bam : iterable
        Iterable of reads exposing ``qname``, ``is_read1`` and ``is_read2``.

    Yields
    ------
    tuple
        ``(read1, read2)`` for every completed pair, in the order the second
        mate of each pair is encountered.
    """
    pending = {}
    for read in bam:
        qname = read.qname
        # pop() removes the waiting mate as it is matched. The original
        # `del[qname]` only unbound the local name, so entries were never
        # freed and a stale mate could be paired again with a later read.
        mate = pending.pop(qname, None)
        if mate is None:
            pending[qname] = read
        elif read.is_read1:
            yield read, mate
        elif read.is_read2:
            yield mate, read
        else:
            # Read is flagged as neither first nor second in pair: emit
            # nothing and keep the waiting mate, as before.
            pending[qname] = mate

In [None]:
# Path (relative to this notebook) of the HG00096 low-coverage alignment BAM.
in_bam = (
    '../../data/ftp.1000genomes.ebi.ac.uk/vol1/ftp/phase3/data/'
    'HG00096/alignment/HG00096.mapped.ILLUMINA.bwa.GBR.low_coverage.20120522.bam')

with pysam.AlignmentFile(in_bam) as bam:
    print(bam.references)
    # Keep the header around; later cells use it to build the genomic bins.
    header = bam.header
    # Take only the first pair passing the flag filter as an example record.
    # include=2 requires the "properly paired" bit; exclude=1804 (4+8+256+512+1024)
    # rejects unmapped, mate-unmapped, secondary, QC-fail and duplicate reads.
    # NOTE: read1/read2 stay undefined if the filter yields no pair at all.
    for read1, read2 in gen_read_pairs(filter_reads(bam, include=2, exclude=1804)):
        break

In [None]:
def generate_bins(bam_header, bin_size):
    """Yield, per reference, the half-open ``(start, stop)`` bins tiling it.

    Every bin is ``bin_size`` long except possibly the last one of a
    reference, which is clipped to the reference length.

    Parameters
    ----------
    bam_header
        Object exposing ``references`` (iterable of names) and
        ``get_reference_length(name)``.
    bin_size : int
        Bin width; must be positive.

    Yields
    ------
    tuple
        ``(reference_name, [(start, stop), ...])``. The list is empty for a
        zero-length reference.

    Raises
    ------
    ValueError
        If ``bin_size`` is not positive (raised when iteration starts).
    """
    if bin_size <= 0:
        raise ValueError('bin_size must be a positive integer')
    for ref in bam_header.references:
        ref_len = bam_header.get_reference_length(ref)
        # Clip each stop to the reference length. The original let the last
        # regular bin overshoot and then appended a bin with start > stop.
        bin_bounds = [
            (start, min(start + bin_size, ref_len))
            for start in range(0, ref_len, bin_size)
        ]
        yield ref, bin_bounds

In [None]:
def make_index(bins):
    """Flatten per-contig bin lists into three parallel coordinate lists.

    Parameters
    ----------
    bins : iterable
        Iterable of ``(contig, [(start, stop), ...])`` items.

    Returns
    -------
    dict
        ``{'contig': [...], 'start': [...], 'stop': [...]}`` with one entry
        per bin, in input order.
    """
    rows = [
        (contig, lo, hi)
        for contig, spans in bins
        for lo, hi in spans
    ]
    return {
        'contig': [row[0] for row in rows],
        'start': [row[1] for row in rows],
        'stop': [row[2] for row in rows],
    }

In [None]:
# Flat contig/start/stop lists for 50,000-wide bins over every reference in the header.
idx = make_index(generate_bins(header, 50_000))

In [None]:
# Lookup table: contig name -> list of (start, stop) bin bounds.
bins = dict(generate_bins(header, 50_000))

In [None]:
# Number of bins per contig, labelled by contig name.
# The data holds one value per contig, so it takes a single 'contig' dimension.
# The per-bin coordinate lists in `idx` have a different length (and would imply
# three dimensions), which makes the DataArray constructor raise.
lengths = xr.DataArray(
    [len(bins[contig]) for contig in bins],
    coords={'contig': list(bins)},
    dims=['contig'],
)

In [None]:
# Bar chart of the number of bins per contig; x tick labels are rotated 90 degrees.
plt.bar(list(bins.keys()), lengths.data)
plt.xticks(rotation=90)
plt.show()

In [None]:
# First nine tab-separated fields of read1's string form.
print(*read1.to_string().split('\t')[:9], sep='\t')

In [None]:
# First nine tab-separated fields of read2's string form, for comparison with its mate.
print(*read2.to_string().split('\t')[:9], sep='\t')

In [None]:
def count_reads(bam, binsize):
    """Yield filtered read pairs from a BAM file, non-reverse read first.

    Despite its name this generator does not count anything; it streams the
    pairs that the caller then tallies.

    Parameters
    ----------
    bam : str
        Path of the BAM file to open. (The original ignored this argument and
        always opened the module-level ``in_bam``.)
    binsize : int
        Unused; kept so existing calls keep working.

    Yields
    ------
    tuple
        ``(f_read, r_read)``: the two mates ordered by ``is_reverse`` (False
        before True). Mates with equal orientation keep their pair order
        because the sort is stable.
    """
    # The file stays open for as long as the generator is being consumed and
    # is closed when the generator is exhausted or closed.
    with pysam.AlignmentFile(bam) as handle:
        pairs = gen_read_pairs(filter_reads(handle, include=2, exclude=1804))
        for read_pair in pairs:
            f_read, r_read = sorted(read_pair, key=lambda read: read.is_reverse)
            yield f_read, r_read

In [None]:
# Tally read pairs per bin on the first reference.
# NOTE(review): `counts` is sized from bins['1'] while the loop stops at the first
# pair whose reference index is not 0 -- this assumes contig '1' is reference 0.
bin_size = 50_000
cri = count_reads(in_bam, bin_size)
counts = np.zeros(len(bins['1']), dtype=np.uint64)
for read1, read2 in cri:
    # Check the contig BEFORE counting so a pair from the next reference is
    # never added to this contig's bins.
    if read1.rname != 0:
        break
    start, stop = read1.pos, read2.reference_end
    # if spanning, pick first bin
    _bin = min(start // bin_size, stop // bin_size)
    # Increment the pair's own bin directly. The previous run-length scheme
    # stored each finished tally under the *next* bin's index, never flushed
    # the final bin, and overwrote a bin whenever pairs arrived out of order.
    counts[_bin] += 1
# Closing the generator leaves its `with` block and so closes the BAM file,
# even though the loop stopped early.
cri.close()

In [None]:
# Sum of the per-bin tallies held in `counts`.
counts.sum()

In [None]:
# Index of the 50,000-wide bin containing the current read1's position.
read1.pos // 50_000

In [None]:
# Distance from read1's position to read2's reference end for the current pair.
read2.reference_end - read1.pos

In [None]:
# read1's position plus its tlen; presumably meant to be compared with
# read2.reference_end -- TODO confirm.
read1.pos + read1.tlen

In [None]:
# Histogram (100 bins) of the per-bin tallies.
plt.hist(counts, bins=100)
plt.show()

In [None]:
import pandas as pd

In [None]:
# df = pd.DataFrame([(contig, start, stop) for contig in bins for start, stop in bins[contig] if contig == '1'], columns=['contig', 'start', 'stop'])

In [None]:
# `df` was only ever defined in a commented-out cell, so using it here raised
# NameError. Build the contig '1' interval table in this cell; its default
# integer index has one label per bin and serves as the 'intervals' coordinate.
df = pd.DataFrame(
    [('1', start, stop) for start, stop in bins['1']],
    columns=['contig', 'start', 'stop'],
)
xr.DataArray(counts, [('intervals', list(df.index))])

In [None]:
# xr.DataArray(counts, [('intervals', [(contig, start, stop) for contig in bins for start, stop in bins[contig] if contig == '1'])])

In [None]:
# Each bin's share of the total tally, drawn as semi-transparent points in bin order.
plt.plot(counts / counts.sum(), 'C0.', alpha=.4)
plt.show()

http://xarray.pydata.org/en/stable/indexing.html

https://samtools.github.io/hts-specs/SAMv1.pdf

https://github.com/pydata/xarray/issues/1603

In [None]:
# Per-bin tallies as a fraction of the total.
# NOTE(review): a zero total makes this a division by zero -- confirm counts is non-empty.
normed_counts = counts / counts.sum()

In [None]:
# Mean of the per-bin fractions.
normed_counts.mean()

In [None]:
# Histogram (100 bins) of the per-bin fractions, with vertical lines at the
# mean (red) and the median (yellow).
plt.hist(normed_counts, bins=100)
plt.axvline(normed_counts.mean(), c='r')
plt.axvline(np.median(normed_counts), c='y')
plt.show()

In [None]:
# Assemble the per-bin results for contig '1' into a Dataset.
# `contigs`, `starts` and `ends` were never defined anywhere in the notebook, so
# the bin coordinates are taken from bins['1'], the same list that sized `counts`.
# The builtin `str` replaces the `np.str` alias, which NumPy has removed.
contig1_bins = bins['1']
ds = xr.Dataset(
    coords=dict(
        pos=list(range(len(counts))),
        sample=['HG00096'],
    ),
    data_vars=dict(
        contig=('pos', np.array(['1'] * len(contig1_bins), dtype=str)),
        start=('pos', np.array([start for start, _ in contig1_bins], dtype=np.uint32)),
        end=('pos', np.array([stop for _, stop in contig1_bins], dtype=np.uint32)),
        depth=(
            ('pos', 'sample'),
            # counts is 1-D; wrap and transpose to get shape (pos, sample).
            np.array([counts], dtype=np.uint64).transpose()),
    )
)
ds

In [None]:
# Scratch cell removed: a bare `np.append()` call always raises TypeError because
# its `arr` and `values` arguments are required.

In [None]:
# Last three (start, stop) bins of contig '1'.
bins['1'][-3:]

In [None]:
# Scratch cell removed: the incomplete expression `np.` is a SyntaxError.

In [None]:
def generate_bins(bam_header, bin_size):
    """Return flat arrays describing the bins that tile every reference.

    NOTE: this redefinition shadows the earlier generator of the same name and
    returns a different shape (three arrays instead of per-contig tuples), so
    cells written for the generator must not be re-run after this one.

    Parameters
    ----------
    bam_header
        Object exposing ``references`` (iterable of names) and
        ``get_reference_length(name)``.
    bin_size : int
        Bin width; must be positive.

    Returns
    -------
    tuple of numpy.ndarray
        ``(contigs, starts, stops)``, parallel 1-D arrays with one entry per
        bin. Bins are half-open ``[start, stop)`` and the last bin of each
        reference is clipped to the reference length.

    Raises
    ------
    ValueError
        If ``bin_size`` is not positive.
    """
    if bin_size <= 0:
        raise ValueError('bin_size must be a positive integer')
    starts, stops, contigs = [], [], []
    for ref in bam_header.references:
        ref_len = bam_header.get_reference_length(ref)
        start = np.arange(0, ref_len, bin_size)
        # Clipping keeps `stop` the same length as `start`, including for a
        # zero-length reference.
        stop = np.minimum(start + bin_size, ref_len)
        starts.append(start)
        # The original appended the list `stops` to itself here.
        stops.append(stop)
        contigs.append(np.full(start.shape[0], ref))
    if not contigs:
        # np.hstack rejects an empty sequence; return empty arrays instead.
        return (np.array([], dtype=str),
                np.array([], dtype=int),
                np.array([], dtype=int))
    return np.hstack(contigs), np.hstack(starts), np.hstack(stops)

In [None]:
# generate_bins(header, 50_000)