In [None]:
import seaborn as sns
import glob
import os
import re
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict
sns.set_context("talk")
%matplotlib inline

In [None]:
import pysam

# Colours and column indices for the five nucleotide symbols; both follow the
# fixed order A, C, G, T, N used by every counting and plotting helper below.
nt_pal = sns.xkcd_palette(
    ["medium green", "sky blue", "goldenrod", "red", "light purple"])
nt2i = {nt: idx for idx, nt in enumerate("ACGTN")}


def target_from_ref(seqname):
    '''
        Return the gene, strand and exon from guide fasta sequence name.

        The name is split on '_'. Seven fields are read as
        chrom, start, end, gene, strand, x, exon; six fields (controls) are
        read as chrom, exon, end, gene, strand, x. A '~' in the strand field
        is converted to '-'.

        Args:
            seqname:    Sequence name from the guide FASTA

        Returns:
            Tuple of (gene, strand, exon)

        Raises:
            RuntimeError if the name has neither six nor seven fields.
    '''
    fields = seqname.split('_')
    n_fields = len(fields)
    if n_fields == 7:
        gene, strand, exon = fields[3], fields[4], fields[6]
    elif n_fields == 6:
        #controls
        gene, strand, exon = fields[3], fields[4], fields[1]
    else:
        raise RuntimeError("Could not parse sequence name: {}".format(
            seqname))
    return (gene, strand.replace('~', '-'), exon)


def nucleotide_counts_from_fasta(ref_fasta):
    '''
        Count nucleotides at each position (cycle) across all records of a
        FASTA file.

        Sequence lines are upper-cased before counting. A record may be
        wrapped over several lines: positions carry on from the previous
        line until the next '>' header. Records are expected to be the same
        length, but a longer record simply extends the table.

        Args:
            ref_fasta:  Path to FASTA file

        Returns:
            Long-format pandas DataFrame with the columns "Cycle" (1-based
            position), "Nucleotide" (A, C, G, T or N) and "Count".

        Raises:
            KeyError if a sequence contains a character other than
            A, C, G, T or N.
    '''
    nt_counts = [] #2-d - 1st dimension is Cycle, 2nd is Nucleotide (A,C,G,T,N)
    with open(ref_fasta, 'rt') as fa:
        offset = 0  # index of the next base within the current record
        for line in fa:
            if line.startswith(">"):
                offset = 0
                continue
            for n in line.rstrip().upper():
                # grow the table on demand rather than sizing it from the
                # first sequence line only
                if offset == len(nt_counts):
                    nt_counts.append([0] * 5)
                nt_counts[offset][nt2i[n]] += 1
                offset += 1
    ref_fa_nt_counts = defaultdict(list)
    for i, row in enumerate(nt_counts):
        for j, n in enumerate("ACGTN"):
            ref_fa_nt_counts["Cycle"].append(i + 1)
            ref_fa_nt_counts["Nucleotide"].append(n)
            ref_fa_nt_counts["Count"].append(row[j])
    return pd.DataFrame.from_dict(ref_fa_nt_counts)


def plot_fasta_nucleotide_counts(ref_fasta,
                                 title="Guide Sequences Nucleotide Counts",
                                 palette=None):
    '''
        Draw a stacked bar plot of per-position nucleotide counts for a
        reference FASTA file of equal length sequences.

        Args:
            ref_fasta:  Path to FASTA file

            title:      Title for plot. Default = "Guide Sequences
                        Nucleotide Counts"

            palette:    Name of palette to use if the default colorscheme
                        is not suitable for your purposes. Must be available
                        via seaborn's color_palette method (see
                        https://seaborn.pydata.org/tutorial/color_palettes.html)

    '''
    pal = nt_pal if palette is None else sns.color_palette(palette, 5)
    counts_df = nucleotide_counts_from_fasta(ref_fasta)
    plt.figure(figsize=(12,9))
    n_cycles = len(counts_df.Cycle.drop_duplicates())
    # running top of the stack at each cycle; each nucleotide sits on it
    stack_top = np.zeros(n_cycles)
    for nt, idx in nt2i.items():
        nt_rows = counts_df[counts_df.Nucleotide == nt]
        heights = nt_rows.Count.values
        plt.bar(nt_rows.Cycle.values, heights, color=pal[idx],
                edgecolor='white', width=1, label=nt, bottom=stack_top)
        stack_top += heights
    plt.xlabel("Cycle")
    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    plt.title(title)
    # leave 5% headroom above the tallest stack
    y_low, y_high = plt.ylim()
    plt.ylim([y_low, y_high * 1.05])
    plt.show()


In [None]:
def cycle_counts_from_fq(fq,
                         progress_interval=1000000):
    '''
        Count nucleotides at each cycle (read position) of one FASTQ file.

        Args:
            fq:     Path to the FASTQ file to analyze (opened with
                    pysam.FastxFile)

            progress_interval:
                    Print a progress message each time this many records
                    have been processed. Use None (or 0) to disable
                    progress messages. Default = 1000000

        Returns:
            A list with one entry per cycle; each entry is a list of five
            counts in the order A, C, G, T, N. Reads may differ in length,
            in which case the list is as long as the longest read. A file
            without records gives an empty list.

        Raises:
            KeyError if a read contains a character other than
            A, C, G, T or N.
    '''
    nt_counts = []  #2-d - 1st dimension is Cycle, 2nd is NT (A,C,G,T,N)
    records = 0
    with pysam.FastxFile(fq) as fqfile:
        for entry in fqfile:
            # report before handling the next record, once at least one
            # record has been counted
            if (progress_interval and records
                    and records % progress_interval == 0):
                print("Processed {:,} records for {}".format(
                    records, fq))
            for i, n in enumerate(entry.sequence):
                # a read longer than any seen so far extends the table
                if i == len(nt_counts):
                    nt_counts.append([0] * 5)
                nt_counts[i][nt2i[n]] += 1
            records += 1
    print("Finished processing {:,} records for {}".format(
          records, fq))
    return nt_counts
    


def plot_cycle_counts(cycle_counts, skip_samples=[], palette=None):
    '''
        Create one stacked bar plot of nucleotide counts per cycle for each
        file in a table of cycle counts. Files may differ in read length;
        each plot is sized from that file's own cycles.

        Args:
            cycle_counts:
                        Dataframe with the columns "File", "Cycle",
                        "Nucleotide" and "Count" (one row per file, cycle
                        and nucleotide).

            skip_samples:
                        List of values in the "File" column to ignore.

            palette:    Name of palette to use if the default colorscheme
                        is not suitable for your purposes. Must be available
                        via seaborn's color_palette method (see
                        https://seaborn.pydata.org/tutorial/color_palettes.html)

    '''

    if palette is None:
        pal = nt_pal
    else:
        pal = sns.color_palette(palette, 5)
    for fq in sorted(np.unique(cycle_counts.File)):
        if fq in skip_samples:
            continue
        file_counts = cycle_counts[cycle_counts.File == fq]
        plt.figure(figsize=(12, 9))
        # size the running stack height from this file only, so that files
        # with different read lengths do not cause a shape mismatch
        margin_bottom = np.zeros(len(file_counts.Cycle.drop_duplicates()))
        for nt, i in nt2i.items():
            color = pal[i]
            tmp_df = file_counts[file_counts.Nucleotide == nt]
            plt.bar(tmp_df.Cycle.values,
                    tmp_df.Count.values,
                    color=color,
                    edgecolor='white',
                    width=1,
                    label=nt,
                    bottom=margin_bottom)
            margin_bottom += tmp_df.Count.values
        plt.xlabel("Cycle")
        plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        plt.title("Fastq {}".format(fq))
        # leave 5% headroom above the tallest stack
        ylim = plt.ylim()
        new_ylim = [ylim[0], ylim[1] * 1.05]
        plt.ylim(new_ylim)
        plt.show()


In [None]:
def get_guide_coverage(bam, ref, minimum_aligned_length=0,
                       minimum_fraction_aligned=0.0):
    '''
        Determine the number of reads aligned to sequences in FASTA reference
        file in BAM alignment file.

        Args:
            bam:    Alignment file in BAM/SAM format

            ref:    FASTA reference file (same as used for creating BAM/SAM
                    alignment)

            minimum_aligned_length:
                    Minimum number of bases that must have been aligned in
                    order to count a record. Default=0

            minimum_fraction_aligned:
                    Minimum fraction of bases that must have been aligned in
                    order to count a record. Default=0.0

        Returns:
            Tuple of two pandas dataframes:
              - one row per sequence in the reference, with the columns
                "Target" (sequence name with '~' replaced by '-'),
                "Coverage" (reads counted; 0 if none) and "Gene"
              - a single row with the columns "mapped", "unmapped",
                "filtered" and "passed" ("filtered" and "passed" are
                subsets of "mapped")

    '''
    # Read the reference names first so that a bad reference path fails
    # before the alignment file is scanned; the handle is closed right away.
    with pysam.FastaFile(ref) as fasta:
        targets = list(fasta.references)
    total_reads = 0
    counts = defaultdict(int)
    mapping = {'mapped': [0], 'unmapped': [0], 'filtered': [0], 'passed': [0]}
    with pysam.AlignmentFile(bam) as bamfile:
        for read in bamfile.fetch(until_eof=True):
            total_reads += 1
            if read.is_unmapped:
                mapping['unmapped'][0] += 1
                continue
            mapping['mapped'][0] += 1
            if minimum_aligned_length:
                if read.query_alignment_length < minimum_aligned_length:
                    mapping['filtered'][0] += 1
                    continue
            if minimum_fraction_aligned:
                if (read.query_alignment_length/read.query_length <
                        minimum_fraction_aligned):
                    mapping['filtered'][0] += 1
                    continue
            counts[read.reference_name] += 1
            mapping['passed'][0] += 1
    # an alignment file without records reports a fraction of 0 instead of
    # dividing by zero
    if total_reads:
        unmapped_fraction = mapping['unmapped'][0]/total_reads
    else:
        unmapped_fraction = 0.0
    print("{:,}/{:,} unmapped reads ({:g})".format(
        mapping['unmapped'][0], total_reads, unmapped_fraction))
    # counts are keyed by the raw reference names, so look them up before
    # the '~' -> '-' replacement below
    count_rows = {'Target': targets,
                  'Coverage': [counts.get(seq, 0) for seq in targets]}
    df = pd.DataFrame.from_dict(count_rows)
    df.Target = df.Target.apply(lambda x: x.replace('~', '-'))
    df['Gene'] = df.Target.apply(lambda x: target_from_ref(x)[0])
    mapping_df = pd.DataFrame.from_dict(mapping)
    return df, mapping_df

def guide_coverage_from_bams(bam_dir, ref, extensions='.bam',
                             skip_bams=[], minimum_aligned_length=0,
                             minimum_fraction_aligned=0.0):
    '''
        Returns dataframes of alignment counts for all BAMs in bam_dir.

        Args:
            bam_dir:    Directory to search; every file matching '*bam'
                        in it is processed, in sorted order.

            ref:        FASTA reference file, passed to get_guide_coverage

            extensions: Suffix removed from each BAM's file name to get
                        the value of the "Sample" column. Default = '.bam'

            skip_bams:  File names (relative to bam_dir) to ignore.

            minimum_aligned_length, minimum_fraction_aligned:
                        Passed to get_guide_coverage.

        Returns:
            Tuple of (counts, mapping) dataframes: the per-BAM outputs of
            get_guide_coverage stacked together, each with an added
            "Sample" column. Both are empty dataframes if no BAM was
            processed.
    '''
    # Compare normalised paths: a trailing slash on bam_dir makes the glob
    # results contain '//', which would never equal the joined skip paths.
    skip_paths = {os.path.normpath(os.path.join(bam_dir, x))
                  for x in skip_bams}
    count_frames = []
    mapping_frames = []
    for bam in sorted(glob.glob('{}/*bam'.format(bam_dir))):
        if os.path.normpath(bam) in skip_paths:
            continue
        print("Processing {}".format(bam))
        tmp_df, tmp_map_df = get_guide_coverage(bam, ref,
                                                minimum_aligned_length,
                                                minimum_fraction_aligned)
        sample = os.path.basename(bam).rsplit(extensions, 1)[0]
        tmp_df["Sample"] = sample
        tmp_map_df["Sample"] = sample
        count_frames.append(tmp_df)
        mapping_frames.append(tmp_map_df)
    # concatenate once at the end instead of growing a frame per BAM
    bam2counts = pd.concat(count_frames) if count_frames else pd.DataFrame()
    bam2mapped = (pd.concat(mapping_frames) if mapping_frames
                  else pd.DataFrame())
    return bam2counts,bam2mapped


## Visualise the nucleotide content for all our guide sequences

In [None]:
# Plot every guide FASTA in the data directory. The files are visited in
# sorted order so that `ref_fasta`, which keeps the last path after the loop
# and is reused as the reference when counting reads further down, does not
# depend on the order the file system happens to list them in.
for ref_fasta in sorted(glob.glob("../../data/*.fa")):
    plot_fasta_nucleotide_counts(ref_fasta,
                                 title = os.path.basename(ref_fasta))

## Calculate our cycle counts for each of our fastq files in our fastq directory

You can explore the dataframe generated as desired. We will plot nucleotide content per cycle.

You may also want to use FastQC to generate plots of other, more detailed metrics.

In [None]:
# Paths to the sample sheets, relative to this notebook. The two
# commented-out lines are the alternative that takes the paths from the
# snakemake config.
# units_file = os.path.join("../../", snakemake.config["units"])
# samples_file = os.path.join("../../", snakemake.config["samples"])
units_file = "../../config/units.tsv"
samples_file = "../../config/samples.tsv"
# The units table is tab-separated; the columns sample_name, unit_name and
# fq1 are the ones this notebook reads.
fq_df = pd.read_csv(units_file, sep='\t')
# Lookup from "<sample_name>-<unit_name>" to sample_name, used below to
# relabel the per-BAM results by sample.
unit2sample = dict(zip(fq_df.sample_name + "-" + fq_df.unit_name,
                       fq_df.sample_name))

In [None]:
os.makedirs("../../results", exist_ok=True)
# Count nucleotides per cycle for the first-read FASTQ of every unit.
fq_counts = dict()
for fq in fq_df.fq1:
    fq_counts[fq] = cycle_counts_from_fq(fq,
                                            progress_interval=5e6)
print("Finished processing FASTQs. Creating dataframe.")
# Reshape into long format: one row per file, cycle and nucleotide.
cycle_count_rows = defaultdict(list)
for fq, counts in fq_counts.items():
    fn = os.path.basename(fq)
    # Remove the FASTQ extension as a suffix. str.rstrip() would treat
    # '.fastq.gz' as a set of characters and also eat trailing letters of
    # the sample name itself.
    for suffix in ('.fastq.gz', '.fq.gz', '.fastq', '.fq'):
        if fn.endswith(suffix):
            fn = fn[:-len(suffix)]
            break
    for i, row in enumerate(counts):
        for j, n in enumerate("ACGTN"):
            cycle_count_rows["File"].append(fn)
            cycle_count_rows["Cycle"].append(i + 1)
            cycle_count_rows["Nucleotide"].append(n)
            cycle_count_rows["Count"].append(row[j])
cycle_counts = pd.DataFrame.from_dict(cycle_count_rows)
cycle_counts.to_csv("../../results/cycle_counts.csv", index=False)
cycle_counts.head()

In [None]:
# One stacked bar plot of nucleotide counts per cycle for each FASTQ.
plot_cycle_counts(cycle_counts)

## Retrieve Read Counts for each Guide

In [None]:
# NOTE(review): `ref_fasta` is not set in this cell. It is the loop variable
# left over from the FASTA plotting loop above, i.e. the last "../../data/*.fa"
# path that loop visited. Confirm that this is the reference the BAMs were
# aligned to, or name the file explicitly.
# Reads count towards a guide only if at least 15 bases and at least 90% of
# the read are aligned.
bam2counts, bam2mapped = guide_coverage_from_bams("../../alignments/", ref_fasta,
                                                  extensions='.bam',
                                                  minimum_aligned_length=15,
                                                  minimum_fraction_aligned=0.9)

In [None]:
# Display the per-guide read counts (one row per guide and BAM).
bam2counts

In [None]:
regex = re.compile(r"_S\d+$")
# Relabel each row from its "<sample_name>-<unit_name>" BAM name to the
# sample name. Names that are not unit keys are kept unchanged, so running
# this cell a second time does not turn every Sample into NaN.
bam2counts['Sample'] = bam2counts.Sample.map(
    lambda name: unit2sample.get(name, name))
bam2counts.head()

In [None]:
# Write the per-guide read counts to CSV. The dataframe can be regenerated
# with pd.read_csv("../../results/read_counts.csv").
bam2counts.to_csv("../../results/read_counts.csv", index=False) 

In [None]:
# Relabel each row from its "<sample_name>-<unit_name>" BAM name to the
# sample name. Names that are not unit keys are kept unchanged, so running
# this cell a second time does not turn every Sample into NaN.
bam2mapped['Sample'] = bam2mapped.Sample.map(
    lambda name: unit2sample.get(name, name))
bam2mapped

In [None]:
# Write the mapped/unmapped/filtered/passed read totals to CSV. The dataframe
# can be regenerated with pd.read_csv("../../results/mapping_counts.csv").
bam2mapped.to_csv("../../results/mapping_counts.csv", index=False)

## Plot distribution of coverage of each guide for each sample 

In [None]:
# Sample IDs in sorted order, one colour each.
samples = np.sort(np.unique(bam2counts.Sample.values))
#alternatively manually enter a list of sample IDs in your preferred order - e.g. samples = ['sample_1', 'sample_2']
pal = sns.color_palette("colorblind", len(samples))
# One histogram of per-guide coverage for every sample.
for i, sample in enumerate(samples):
    fig = plt.figure(figsize=(12,6))
    sample_coverage = bam2counts.loc[bam2counts.Sample == sample, "Coverage"]
    sns.histplot(sample_coverage, kde=False, color=pal[i], label=sample)
    plt.legend()
    plt.ylabel('Guides')

In [None]:
# Number of rows for the first sample (samples[0]) whose Coverage is
# non-zero, i.e. guides with at least one counted read.
len(bam2counts[(bam2counts.Sample == samples[0]) &
               (bam2counts.Coverage != 0)])

In [None]:
# Number of rows for the first sample (samples[0]) whose Coverage is zero,
# i.e. guides with no counted reads.
len(bam2counts[(bam2counts.Sample == samples[0]) &
               (bam2counts.Coverage == 0)])

## Calculate and plot distribution of fractions of total mapped reads for each guide

In [None]:
# Mapped-read total per sample, taken from the first mapping row that
# carries that sample name.
samp2mapped = {
    lib: bam2mapped.loc[bam2mapped.Sample == lib, 'mapped'].values[0]
    for lib in bam2mapped.Sample.unique()
}

# Each guide's share of its sample's mapped reads.
bam2counts["Fraction Reads Mapped"] = bam2counts.apply(
    lambda row: row.Coverage/samp2mapped[row.Sample],
    axis=1)

In [None]:
# One histogram row per sample, coloured to match the coverage plots above;
# the y axes are independent.
g = sns.FacetGrid(bam2counts,
                  row="Sample",
                  hue="Sample",
                  palette=pal,
                  row_order=samples,
                  hue_order=samples,
                  height=5,
                  aspect=2,
                  sharey=False)
g = g.map(plt.hist, "Fraction Reads Mapped", bins=50)
# show tick labels on every facet, not just the outer ones
for ax in g.axes.flat:
    ax.tick_params(labelbottom=True, labelleft=True)
    ax.set_ylabel("Guides")
g.fig.tight_layout()
plt.savefig("../../results/t0_guide_representation_hist.png")

## Create pivot table suitable for downstream analysis

In [None]:
# Preview the long-format table that is pivoted below.
bam2counts.head()

In [None]:
# Wide table of Coverage: one row per Target, one column per Sample.
# `values` is given explicitly so that pivot_table does not also try to
# average the non-numeric "Gene" column, which raises a TypeError on
# pandas >= 2.
# NOTE(review): pivot_table averages by default, so a sample with several
# rows per Target (more than one unit) gets the mean of its units, not the
# sum - confirm that this is intended.
bam2counts_pivot = bam2counts.pivot_table(index="Target",
                                          columns="Sample",
                                          values="Coverage").reset_index()
bam2counts_pivot.columns.name = None
bam2counts_pivot.head()

In [None]:
# Add the gene parsed from each Target name, then write the wide table as a
# tab-separated file.
bam2counts_pivot['Gene'] = bam2counts_pivot.Target.apply(lambda x: target_from_ref(x)[0])
bam2counts_pivot.to_csv("../../results/read_counts_pivot.txt", sep='\t', index=False)
bam2counts_pivot.head()

In [None]:
# Wide table of "Fraction Reads Mapped": one row per Target, one column per
# Sample. `values` is given explicitly so that pivot_table does not also try
# to average the non-numeric "Gene" column, which raises a TypeError on
# pandas >= 2.
bam2frac_pivot = bam2counts.pivot_table(
    index="Target",
    columns="Sample",
    values="Fraction Reads Mapped").reset_index()
bam2frac_pivot.columns.name = None
# Add the gene parsed from each Target name, then write the table as a
# tab-separated file.
bam2frac_pivot['Gene'] = bam2frac_pivot.Target.apply(lambda x: target_from_ref(x)[0])
bam2frac_pivot.to_csv("../../results/fraction_counts_pivot.txt", sep='\t', index=False)
bam2frac_pivot.head()