# Python in Life Sciences


The aim of this tutorial is to show how the main Python tools used in bioinformatics are applied in practice and to present a workflow for building a simple CLI application.


Outline:
1. Prototyping
    * processing SAM files (pysam)
    * processing TSV files (pandas, pyarrow)
    * adding multiprocessing 
2. CLI application counting gRNAs from alignment files
    
Problem:
Let's assume we have data from a CRISPR screen experiment. For the purpose of this tutorial, we focus on one step of the data processing: counting the genes identified by guide RNA (gRNA) sequences aligned to the library. A real analysis starts with demultiplexing, trimming the reads and aligning them to the library; we assume these steps have already been performed.

Task: process the SAM file to get counts of the genes identified by the gRNAs, write the results to a TSV file and visualize them.

In [0]:
!curl -Lk -o data.zip 'https://github.com/barbarakalinowska/tutorial-python/blob/master/data.zip?raw=true'
!unzip data.zip

## Install required packages


In [0]:
! pip install pysam biopython pandas pyarrow matplotlib seaborn

Upload the data and place it in the `/content/data/` folder.

## Processing SAM files

#### Approach 1: parse the SAM file read by read with the pysam module.

First, let's import the required packages.

In [0]:
import csv
import glob
import os
import matplotlib.pyplot as plt
import multiprocessing as mp
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import pysam
import seaborn as sns

from Bio import SeqIO
from collections import defaultdict

import warnings
warnings.filterwarnings('ignore')

To obtain all gRNA identifiers from the library, the easiest way is to parse the library FASTA file with the Biopython module.

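Convert the content of `screen_library.fasta` to a dictionary and count the gRNA sequences:

In [0]:
library = "data/screen_library.fasta"

# Map each gRNA ID to its SeqRecord; the number of keys is the number of gRNAs
gRNA_dict = SeqIO.to_dict(SeqIO.parse(library, "fasta"))
print(len(gRNA_dict))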


Count gRNA identifiers for each unique gene:

In [0]:
genes_count = defaultdict(int)
with open(library, "r") as handle:
    for record in SeqIO.parse(handle, "fasta"):
        genes_count[record.id.split("_")[0]] += 1
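Taking the gene as the text before the first underscore of the gRNA ID is a naming convention of this library. The sketch below applies that rule to a few hypothetical IDs (the `GENE_n` format is an assumption, not read from the data) and counts them with `collections.Counter`, which behaves like the `defaultdict(int)` above:

```python
from collections import Counter


def gene_from_grna_id(grna_id: str) -> str:
    """Return the gene part of a gRNA ID, i.e. the text before the first '_'."""
    return grna_id.split("_", 1)[0]


# Hypothetical gRNA identifiers in a GENE_index format
grna_ids = ["BRCA1_1", "BRCA1_2", "TP53_1", "NON_TARGETING_1"]
genes_per_library = Counter(gene_from_grna_id(g) for g in grna_ids)
print(genes_per_library)
```

Note the last ID: a name that itself contains an underscore, such as a hypothetical `NON_TARGETING` control, is truncated to `NON`, so it is worth checking the library's naming before relying on this rule.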

Parse the SAM file and print some information about each read:

In [0]:
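sam = pysam.AlignmentFile("data/sam_files/sample1.sam", "r")
for read in sam.fetch():
    # Reference (gRNA) name, 0-based alignment start and mapping status
    print(read.reference_name, read.reference_start, read.is_unmapped)
    # Read ID, sequence, length and optional tags (e.g. NM)
    print(read.query_name, read.query_sequence, read.query_length, read.get_tags())
sam.close()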
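pysam hides the SAM text format behind attributes such as `reference_name` and `is_unmapped`. To see what it reads, here is a minimal pure-Python parser for one alignment line: the 11 mandatory tab-separated fields defined by the SAM specification, followed by optional `TAG:TYPE:VALUE` fields. It is an illustration only (it converts just integer tags), and the example line is made up:

```python
# Names of the 11 mandatory SAM fields, in file order
SAM_FIELDS = ["QNAME", "FLAG", "RNAME", "POS", "MAPQ", "CIGAR",
              "RNEXT", "PNEXT", "TLEN", "SEQ", "QUAL"]


def parse_sam_line(line: str) -> dict:
    """Split one SAM alignment line into mandatory fields and a dict of tags."""
    cols = line.rstrip("\n").split("\t")
    record = dict(zip(SAM_FIELDS, cols[:11]))
    for field in ("FLAG", "POS", "MAPQ", "PNEXT", "TLEN"):
        record[field] = int(record[field])
    tags = {}
    for opt in cols[11:]:
        tag, typ, value = opt.split(":", 2)
        # Integer tags (type 'i') such as NM are converted; others stay strings
        tags[tag] = int(value) if typ == "i" else value
    record["TAGS"] = tags
    return record


# Made-up line: a 20-nt read aligned to gRNA "BRCA1_1" with one mismatch
example = "\t".join(["read1", "0", "BRCA1_1", "1", "42", "20M", "*", "0", "0",
                     "ACGT" * 5, "I" * 20, "NM:i:1", "MD:Z:19A0"])
rec = parse_sam_line(example)
print(rec["RNAME"], (rec["FLAG"] & 0x4) != 0, rec["TAGS"]["NM"])
```

In pysam terms, `RNAME` is `read.reference_name`, bit `0x4` of `FLAG` is `read.is_unmapped`, and the tag dictionary corresponds to `read.get_tags()` (which returns a list of `(tag, value)` pairs). Note that SAM `POS` is 1-based, whereas pysam's `reference_start` is 0-based.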

Count reads mapped to gRNA sequences per gene.

In [0]:
genes_count = defaultdict(int)
sam_file = "data/sam_files/sample1.sam"
sam = pysam.AlignmentFile(sam_file, "r")
for read in sam.fetch():
    if not read.is_unmapped:
        genes_count[read.reference_name.split("_", 1)[0]] += 1
sam.close()

Add read filters: a maximum aligned length of 20 and at most 1 mismatch (the `NM` tag, i.e. the edit distance to the reference):

In [0]:
genes_count = defaultdict(int)
sam_file = "data/sam_files/sample1.sam"
sam = pysam.AlignmentFile(sam_file, "r")
for read in sam.fetch():
    if not read.is_unmapped and read.query_alignment_length <= 20 and read.get_tag('NM') <= 1:
        genes_count[read.reference_name.split("_", 1)[0]] += 1
sam.close()
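The same filter logic can be pulled out into a small function that is easy to test without a SAM file. This is a sketch on already-parsed values (an integer FLAG, the read sequence and a tag dictionary) rather than on pysam objects. Two choices differ from the loop above and are assumptions: it uses the full read length (`len(seq)`, like pysam's `query_length`) instead of the aligned length, and it rejects reads without an `NM` tag instead of raising an error as `get_tag` would.

```python
def passes_filters(flag: int, seq: str, tags: dict,
                   max_len: int = 20, max_nm: int = 1) -> bool:
    """Return True for a mapped read no longer than max_len with NM <= max_nm."""
    if flag & 0x4:            # bit 0x4 set: the read is unmapped
        return False
    if len(seq) > max_len:    # full read length, not only the aligned part
        return False
    nm = tags.get("NM")       # edit distance to the reference, if present
    return nm is not None and nm <= max_nm


print(passes_filters(0, "ACGT" * 5, {"NM": 1}))   # mapped, 20 nt, 1 mismatch
print(passes_filters(4, "ACGT" * 5, {}))          # unmapped
print(passes_filters(0, "ACGT" * 6, {"NM": 0}))   # 24 nt: too long
```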

Now, we save the results in a TSV file.

In [0]:
counts_report = 'report_test.tsv'
header = ["gene", "count"]
with open(counts_report, 'w', newline='') as csvfile:
    gene_count_csv = csv.writer(csvfile, delimiter='\t')
    gene_count_csv.writerow(header)
    for gene in sorted(genes_count.keys()):
        gene_count_csv.writerow([gene, genes_count[gene]])
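One detail worth knowing: `csv.writer` ends rows with `\r\n` by default, even with `delimiter='\t'`. If plain Unix line endings are preferred in the TSV, pass `lineterminator='\n'` (and open the file with `newline=''`, as the csv documentation recommends). The check below writes made-up counts to in-memory buffers:

```python
import csv
import io

counts = {"TP53": 3, "BRCA1": 5}  # made-up gene counts

# Default dialect: rows end with '\r\n'
default_buf = io.StringIO()
csv.writer(default_buf, delimiter="\t").writerow(["gene", "count"])

# Explicit Unix line endings
unix_buf = io.StringIO()
writer = csv.writer(unix_buf, delimiter="\t", lineterminator="\n")
writer.writerow(["gene", "count"])
for gene in sorted(counts):
    writer.writerow([gene, counts[gene]])

print(repr(default_buf.getvalue()))
print(repr(unix_buf.getvalue()))
```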

Let's gather all the steps into functions:

In [0]:
def write_report(genes_count, report):
    """Write gene counts to a TSV file, sorted by gene name."""
    header = ["gene", "count"]
    with open(report, 'w', newline='') as csvfile:
        gene_count_csv = csv.writer(csvfile, delimiter='\t')
        gene_count_csv.writerow(header)
        for gene in sorted(genes_count.keys()):
            gene_count_csv.writerow([gene, genes_count[gene]])


def count_genes_pysam(sam_aln, report):
    """Count mapped reads per gene (aligned length <= 20, NM <= 1) and write a report."""
    sam = pysam.AlignmentFile(sam_aln, "r")  # "r": plain-text SAM input
    genes_count = defaultdict(int)
    for read in sam.fetch():
        if not read.is_unmapped and read.query_alignment_length <= 20 and read.get_tag('NM') <= 1:
            genes_count[read.reference_name.split("_", 1)[0]] += 1
    sam.close()
    write_report(genes_count, report)

And check the function's performance:

In [0]:
%timeit count_genes_pysam("data/sam_files/sample1.sam", "counts_report1.tsv")

## Processing TSV files

#### Approach 2: treat the SAM file as a TSV file.

We are going to use pandas to process the TSV file quickly.

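Because the lines of a SAM file do not have a constant number of columns (the number of optional tag fields varies), we need to define the column names before reading the file:

In [0]:
# SAM fields in order: QNAME, FLAG, RNAME, POS, MAPQ, CIGAR, RNEXT, PNEXT,
# TLEN, SEQ, QUAL (so 'quality' holds MAPQ and 'aln' holds the base qualities),
# followed by up to seven optional TAG:TYPE:VALUE fields
column_names = ["read_id", "flags_sum", "ref", "pos", "quality", "cigar",
                "ref_aln", "aln_pos", "insert", "read_seq", "aln",
                "opt1", "opt2", "opt3", "opt4", "opt5", "opt6", "opt7"]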

In [0]:
aln = pd.read_csv("data/sam_files/sample1.sam", delimiter="\t", names=column_names, comment="@", 
                  index_col=False, compression='infer')
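A caveat with `comment="@"`: pandas treats `@` as a comment character anywhere in a line, not only at its start, so the rest of an alignment line is dropped if, for example, its base-quality string contains `@` (Phred score 31 in the common Phred+33 encoding). A safer sketch, assuming an uncompressed SAM file whose header lines all come first (as the SAM specification requires), counts the header lines and skips them with `skiprows`. The helper names and the two-line SAM text are made up for illustration:

```python
import os
import tempfile
from itertools import takewhile

import pandas as pd


def count_header_lines(path):
    """Count the leading '@' header lines of a plain-text SAM file."""
    with open(path) as handle:
        return sum(1 for _ in takewhile(lambda line: line.startswith("@"), handle))


def read_sam_table(path, names):
    """Read SAM alignment lines with pandas without using comment='@'."""
    return pd.read_csv(path, sep="\t", names=names, index_col=False,
                       skiprows=count_header_lines(path))


# Made-up SAM: one header line and one read whose quality string contains '@'
sam_text = ("@SQ\tSN:BRCA1_1\tLN:20\n"
            "read1\t0\tBRCA1_1\t1\t42\t20M\t*\t0\t0\t"
            + "ACGT" * 5 + "\t" + "I@" * 10 + "\tNM:i:1\n")
names = ["read_id", "flags_sum", "ref", "pos", "quality", "cigar",
         "ref_aln", "aln_pos", "insert", "read_seq", "aln", "opt1"]

with tempfile.NamedTemporaryFile("w", suffix=".sam", delete=False) as tmp:
    tmp.write(sam_text)
df = read_sam_table(tmp.name, names)
os.remove(tmp.name)
print(df[["read_id", "aln", "opt1"]])
```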

In [0]:
aln.to_csv("data/tsv_files/sample1.tsv.gz", sep="\t", index=False, compression="gzip")

Let's quickly investigate the dataframe:

In [0]:
aln.head(10)

In [0]:
aln.shape

We want to keep only aligned reads. Unmapped reads have `*` in the CIGAR field, so let's first check which CIGAR strings occur:

In [0]:
aln['cigar'].unique()

Remove the unmapped reads by filtering out the rows with `*` in the 'cigar' field, and print the new number of rows:

In [0]:
aln_filtered = aln[aln['cigar'] != "*"].copy()
print(aln_filtered.shape)
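Filtering on the CIGAR string works for this data, but the SAM specification names bit `0x4` of the FLAG field as the only reliable indicator that a read is unmapped, and this is what pysam's `is_unmapped` checks. A vectorised version of that test on the `flags_sum` column might look like this (toy DataFrame with made-up values):

```python
import pandas as pd

# Toy alignments: flag 0 = mapped forward, 16 = mapped reverse, 4 = unmapped
toy = pd.DataFrame({
    "read_id": ["r1", "r2", "r3"],
    "flags_sum": [0, 16, 4],
    "cigar": ["20M", "20M", "*"],
})

# Keep rows where bit 0x4 ("segment unmapped") is not set
mapped = toy[(toy["flags_sum"] & 0x4) == 0]
print(mapped["read_id"].tolist())
```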

Add a column "gene" containing the gene identifiers:

In [0]:
aln_filtered['gene'] = aln_filtered['ref'].str.split("_", n=1, expand=True)[0]

In [0]:
aln_filtered.head()

Filter by the maximum number of mismatches, 1. Strictly, the `NM` tag holds the edit distance to the reference, so insertions and deletions also count:

In [0]:
# First, remove the rows with NaN values in the 'opt3' column
aln_filtered = aln_filtered[~aln_filtered.opt3.isna()].copy()
# Add an 'NM' column with the value of the NM tag; this assumes NM is the third
# optional field ('opt3') in every line, which depends on the aligner
aln_filtered['NM'] = aln_filtered['opt3'].str.split(":", expand=True)[2].astype(int)
# Remove the rows with more than 1 mismatch
aln_filtered = aln_filtered[aln_filtered.NM <= 1]
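Reading `NM` from a fixed column relies on the aligner writing the tags in the same order on every line. A more defensive sketch searches all optional columns for the `NM:i:` prefix; `extract_nm` is a hypothetical helper, the column names follow the `opt1`, `opt2`, ... convention above, and the toy data is made up:

```python
import pandas as pd


def extract_nm(df: pd.DataFrame, opt_cols) -> pd.Series:
    """Return the NM tag value per row (NaN if absent), searching all opt columns."""
    nm = pd.Series(float("nan"), index=df.index)
    for col in opt_cols:
        values = df[col].astype(str)
        is_nm = values.str.startswith("NM:i:")
        # Text after the 'NM:i:' prefix; non-numeric values become NaN
        parsed = pd.to_numeric(values.str[5:], errors="coerce")
        nm = nm.where(~is_nm, parsed)
    return nm


toy = pd.DataFrame({
    "opt1": ["AS:i:-6", "NM:i:0", None],
    "opt2": ["NM:i:1", "MD:Z:20", None],
})
toy["NM"] = extract_nm(toy, ["opt1", "opt2"])
# Rows without an NM tag (NaN) are dropped by the comparison
print(toy[toy["NM"] <= 1].index.tolist())
```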

Now, let's use an aggregation to count the occurrences of each gene:

In [0]:
genes_count = aln_filtered.groupby('gene').size()
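`groupby('gene').size()` returns one count per gene, sorted by gene name. `value_counts()` computes the same numbers but sorts them by frequency, which is handy for a quick look at the most abundant genes. A small check on made-up gene labels:

```python
import pandas as pd

genes = pd.DataFrame({"gene": ["TP53", "BRCA1", "TP53", "KRAS", "TP53"]})

by_group = genes.groupby("gene").size()   # indexed by gene, sorted by name
by_freq = genes["gene"].value_counts()    # sorted by count, descending

print(by_group.to_dict())
print(by_freq.index[0], by_freq.iloc[0])
```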

In [0]:
genes_count.to_csv("counts_report2.tsv", sep="\t", header=["count"])

Now, let's gather the processing steps into a single function (note that, unlike `count_genes_pysam`, it filters out unmapped reads only):

In [0]:
def count_genes_pandas(sam_aln, report, cols):
    """Count mapped reads per gene with pandas and write a TSV report."""
    aln = pd.read_csv(sam_aln, delimiter="\t", names=cols, comment="@",
                      index_col=False, compression='infer')
    aln = aln[aln['cigar'] != "*"].copy()
    aln['gene'] = aln['ref'].str.split("_", n=1, expand=True)[0]
    genes_count = aln.groupby('gene').size()
    genes_count.to_csv(report, sep="\t", header=["count"])

In [0]:
%time count_genes_pandas("data/sam_files/sample1.sam", "count_report_pd.tsv", column_names)

## Parquet files

#### Approach 3: converting the file into Parquet format.
In this approach, we test whether storing the data in a Parquet file instead of a TSV file speeds up processing. Note that the timing below covers only the counting, not the one-off conversion.

In [0]:
pd_file = pd.read_table('data/sam_files/sample1.sam', comment="@", 
                       names=column_names, index_col=False, compression='infer')
pq.write_table(pa.Table.from_pandas(pd_file), 'data/sample1.pq', compression='snappy')

In [0]:
def count_genes_pq(pq_file, report):
    """Count mapped reads per gene from a Parquet file and write a TSV report."""
    aln = pd.read_parquet(pq_file, use_threads=True)
    aln = aln[aln['cigar'] != "*"].copy()
    aln['gene'] = aln['ref'].str.split("_", n=1, expand=True)[0]
    genes_count = aln.groupby('gene').size().to_frame(name="count")
    genes_count.to_csv(report, sep="\t", header=["count"])

In [0]:
%time count_genes_pq("data/sample1.pq", "data/count_report.tsv")

## Multiprocessing

#### Approach 4: dividing a file into chunks

This approach is especially useful for large files: it avoids reading the whole file into memory at once and keeps the intermediate results on disk.
Here we use the pyarrow module to write each chunk to a Parquet file. Parquet is a compressed, columnar format, so reading it back is typically faster than parsing the equivalent text file.

First, divide the SAM file into chunks:

In [0]:
os.makedirs('data/pq_files_in', exist_ok=True)
reader = pd.read_table('data/sam_files/sample1.sam', chunksize=1000, comment="@",
                       names=column_names, index_col=False, compression='infer')

# Write each chunk of 1000 alignment rows to its own Parquet file
for chunk_no, chunk in enumerate(reader):
    pq.write_table(pa.Table.from_pandas(chunk),
                   os.path.join('data/pq_files_in', 'aln-{:04d}.parquet'.format(chunk_no)),
                   compression='snappy')
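Parquet chunks are one option. If only the final counts are needed, the same chunked reading can aggregate counts on the fly, so that at most one chunk is held in memory at a time. A sketch with a made-up SAM body (only the first six mandatory columns, no header) and a deliberately tiny `chunksize`:

```python
import io

import pandas as pd

# Made-up SAM body with the first six mandatory columns only
sam_body = "\n".join([
    "r1\t0\tTP53_1\t1\t42\t20M",
    "r2\t4\t*\t0\t0\t*",
    "r3\t0\tTP53_2\t1\t42\t20M",
    "r4\t16\tKRAS_1\t1\t42\t20M",
])
names = ["read_id", "flags_sum", "ref", "pos", "quality", "cigar"]

total = pd.Series(dtype="int64")
for chunk in pd.read_csv(io.StringIO(sam_body), sep="\t", names=names,
                         chunksize=2):
    mapped = chunk[chunk["cigar"] != "*"]
    genes = mapped["ref"].str.split("_", n=1).str[0]
    # Add this chunk's counts to the running total (missing genes count as 0)
    total = total.add(genes.value_counts(), fill_value=0)

total = total.astype("int64").sort_index()
print(total.to_dict())
```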

Prepare a function that returns a one-column DataFrame with the gene of each mapped read, taken from its gRNA ID:

In [0]:
def extract_genes(df):
    """Return a 'gene' column for the mapped reads in df."""
    df = df[df['cigar'] != "*"].copy()
    df['gene'] = df['ref'].str.split("_", n=1, expand=True)[0]
    return df[['gene']]

Prepare a function that extracts the genes from a Parquet file and saves them as an intermediate output:

In [0]:
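def process_aln(filename, out_dir):
    """Extract genes from one Parquet chunk and write them to out_dir."""
    chunk = pq.read_table(filename, use_threads=True).to_pandas()
    chunk_genes = extract_genes(chunk)
    # preserve_index=False: do not store the filtered row index as an extra column
    pq.write_table(pa.Table.from_pandas(chunk_genes, preserve_index=False),
                   os.path.join(out_dir, os.path.basename(filename)),
                   compression='snappy')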

Extract genes from all parquet files using multiprocessing:

In [0]:
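%%timeit
out_dir = "data/pq_files_out"
in_dir = "data/pq_files_in"
os.makedirs(out_dir, exist_ok=True)
pool = mp.Pool()
results = [pool.apply_async(process_aln, args=(filename, out_dir))
           for filename in glob.glob(os.path.join(in_dir, '*.parquet'))]
pool.close()
pool.join()
# apply_async hides worker exceptions until .get() is called, so re-raise them here
for result in results:
    result.get()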
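When only the final counts are needed, the intermediate Parquet files can be skipped: each worker returns the counts for its chunk and the main process sums them (a map-reduce pattern). The sketch below shows this with `Pool.map` and `collections.Counter` on made-up chunks of reference names; `count_chunk` and `merge_counts` are hypothetical helpers, and the `__main__` guard is required on platforms that start workers with the `spawn` method.

```python
import multiprocessing as mp
from collections import Counter


def count_chunk(ref_names):
    """Map step: count genes in one chunk of reference names ('*' = unmapped)."""
    return Counter(ref.split("_", 1)[0] for ref in ref_names if ref != "*")


def merge_counts(partial_counts):
    """Reduce step: sum the per-chunk Counters into one."""
    total = Counter()
    for counts in partial_counts:
        total.update(counts)
    return total


# Made-up chunks of RNAME values, as they might come from a split SAM file
chunks = [["TP53_1", "*", "KRAS_2"], ["TP53_3", "TP53_1"], ["*"]]

if __name__ == "__main__":
    with mp.Pool(processes=2) as pool:
        partials = pool.map(count_chunk, chunks)
    print(merge_counts(partials))
```

Returning small `Counter` objects keeps inter-process traffic low, since pickling a handful of gene counts is much cheaper than pickling whole DataFrames.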

Merge the outputs to obtain the final report:

In [0]:
%%time
df = pq.read_table('data/pq_files_out/', use_threads=True).to_pandas()
genes = df.groupby('gene').size().to_frame(name="count")
genes.to_csv('counts_report_pq.tsv', sep="\t")

Finally, measure the time of the whole pipeline, from gene extraction to the merged report:

In [0]:
%%time
out_dir = "data/pq_files_out"
in_dir = "data/pq_files_in"
os.makedirs(out_dir, exist_ok=True)
pool = mp.Pool()
results = [pool.apply_async(process_aln, args=(filename, out_dir))
           for filename in glob.glob(os.path.join(in_dir, '*.parquet'))]
pool.close()
pool.join()
# Re-raise any exception from the workers
for result in results:
    result.get()

# Merge the per-chunk gene lists and count reads per gene
df = pq.read_table('data/pq_files_out/', use_threads=True).to_pandas()
genes = df.groupby('gene').size().to_frame(name="count")
genes.to_csv('counts_report_pq.tsv', sep="\t")

## Plotting

Now, let's process all four samples and plot the results.

In [0]:
sam_dir = "./data/sam_files/"
sam_files = sorted(glob.glob(os.path.join(sam_dir, '*.sam')))

In [0]:
print(sam_files)


In [0]:
for sam_file in sam_files:
    sample = os.path.basename(sam_file).split(".")[0]
    print(sample)
    count_genes_pandas(sam_file, os.path.join("data/tsv_files/", sample+"_report.tsv"), column_names)

In [0]:
gene_counts_1 = pd.read_csv("data/tsv_files/sample1_report.tsv", sep="\t")
gene_counts_1 = gene_counts_1.set_index('gene')
gene_counts_2 = pd.read_csv("data/tsv_files/sample2_report.tsv", sep="\t")
gene_counts_2 = gene_counts_2.set_index("gene")
gene_counts_3 = pd.read_csv("data/tsv_files/sample3_report.tsv", sep="\t")
gene_counts_3 = gene_counts_3.set_index("gene")
gene_counts_4 = pd.read_csv("data/tsv_files/sample4_report.tsv", sep="\t")
gene_counts_4 = gene_counts_4.set_index("gene")

In [0]:
gene_counts_all =  pd.merge(gene_counts_1, gene_counts_2,
                            on="gene", how="outer", suffixes=['_1', '_2'])

In [0]:
gene_counts_all_p2 =  pd.merge(gene_counts_3, gene_counts_4,
                            on="gene", how="outer", suffixes=['_3', '_4'])

In [0]:
gene_counts_all = pd.merge(gene_counts_all, gene_counts_all_p2, on="gene", how="outer")

In [0]:
gene_counts_all.head()

In [0]:
gene_counts_all.columns = ["sample1", "sample2", "sample3", "sample4"]
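The three pairwise merges can also be written as a single `pd.concat` along the columns, which aligns the per-sample tables on their gene index and fills genes missing from a sample with `NaN` (an outer join). A sketch on made-up per-sample counts, with the gene as the index as in the reports above:

```python
import pandas as pd

# Made-up per-sample reports, indexed by gene like the TSV reports above
reports = {
    "sample1": pd.DataFrame({"count": [5, 2]},
                            index=pd.Index(["BRCA1", "TP53"], name="gene")),
    "sample2": pd.DataFrame({"count": [7]},
                            index=pd.Index(["TP53"], name="gene")),
}

# One column per sample; outer alignment on the gene index
combined = pd.concat({name: df["count"] for name, df in reports.items()}, axis=1)
print(combined)
```

With the report files written above, each per-sample series could be built as, for example, `pd.read_csv(path, sep="\t", index_col="gene")["count"]`.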

In [0]:
sns.set(style="whitegrid")
data = pd.melt(gene_counts_all)
ax = sns.boxplot(x="variable", y="value", data=data)
ax.set(xlabel='sample', ylabel='counts', title="Genes abundance")
plt.show()

In [0]:
for sample in gene_counts_all.columns:
    # histplot replaces the deprecated distplot(kde=False, hist=True)
    ax = sns.histplot(gene_counts_all[sample].dropna(), label=sample)
ax.set(xlabel='counts', title="Genes abundance")
ax.legend()
plt.show()

In [0]:
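for i, sample in enumerate(gene_counts_all.columns):
    values = gene_counts_all[sample].dropna()
    # Density-scaled histogram plus a shaded KDE, replacing the deprecated
    # distplot(kde=True, kde_kws={'shade': True, 'linewidth': 3})
    ax = sns.histplot(values, stat="density", color=f"C{i}", label=sample)
    sns.kdeplot(values, fill=True, linewidth=3, color=f"C{i}", ax=ax)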
ax.set(xlabel='counts', title="Genes abundance")
ax.legend()
plt.show()