In [None]:
import re
import io
import json
import tempfile

from copy import copy
from collections import Counter
from itertools import groupby, product, combinations
from zipfile import ZipFile

import numpy as np
import matplotlib.pyplot as plt

from Bio import SeqIO, AlignIO, codonalign
from Bio.Align.Applications import MuscleCommandline
from Bio.Seq import Seq, MutableSeq

## Introduction

In the "big data" era, getting the right data is often the main challenge.
NCBI Datasets makes it easier to retrieve the data needed to address a specific question.
Here's a concrete example of a problem that can be addressed with big data:

Bio problem:
- **Our question**: can we detect natural selection by comparing ortholog sequences between species?
- Review: the redundant (degenerate) genetic code; synonymous vs. non-synonymous substitutions (we do not consider frameshift indels)
- The molecular clock: synonymous substitutions, being largely neutral, accumulate at a roughly constant rate over time [FIGURE]
- Purifying selection: most non-synonymous mutations are harmful and eliminated by natural selection [FIGURE]
- Positive selection: some non-synonymous mutations may improve fitness. These will fix at a faster-than-neutral rate [FIGURE]
- **Idea**: we can compare the rates of syn/non-syn substitutions to look for signals of purifying or positive selection

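We will compute the **dN/dS ratio**: the rate of non-synonymous substitutions per non-synonymous site (dN) divided by the rate of synonymous substitutions per synonymous site (dS). Normalizing by the number of sites matters because, given the structure of the genetic code, most random point mutations are non-synonymous.
- Low dN/dS (well below 1) indicates strong purifying selection (meaning the gene is important and its protein is already well adapted).
- Intermediate (below, but closer to, 1) -> relaxed purifying selection.
- High (above 1) -> positive selection, rapid adaptation.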
[DIAGRAM]

We will be comparing Drosophila species [PHYLOGENY] with different levels of divergence across a large number of ortholog families to categorize the orthologs by dN/dS.

## Getting the data

In [None]:
# Directory holding one NCBI Datasets zip archive per ortholog gene,
# named "<gene_id>.zip" (see import_fasta).
datadir = "/data/orthologs_with_cds"

In [None]:
# Inspect the directory layout of the downloaded datasets.
!tree $datadir

## Computing dN/dS from sequences

We'll pick a gene and read the coding sequences (CDS) from the FASTA file in its dataset into a variable (the dataset also includes a data report, which is not used here).

We read the FASTA file using Biopython's `SeqIO` module:

In [None]:
def import_fasta(gene_id, datadir):
    """Read all CDS records from a gene's NCBI Datasets zip archive.

    Parameters
    ----------
    gene_id : int or str
        Gene identifier; the archive is expected at ``{datadir}/{gene_id}.zip``.
    datadir : str
        Directory containing the per-gene zip archives.

    Returns
    -------
    list of SeqRecord
        Every FASTA record in ``ncbi_dataset/data/cds.fna`` inside the archive.
    """
    archive_path = f"{datadir}/{gene_id}.zip"
    member = "ncbi_dataset/data/cds.fna"
    with ZipFile(archive_path) as archive, archive.open(member, "r") as raw:
        # ZipFile.open yields a binary stream; SeqIO.parse needs text.
        text_stream = io.TextIOWrapper(raw)
        # Materialize the records before the archive is closed.
        return list(SeqIO.parse(text_stream, "fasta"))

In [None]:
# Example gene used to walk through the computation step by step.
gene_id = 12798080
records = import_fasta(gene_id, datadir)

In [None]:
# Matches the "[organism=...]" tag in a FASTA description line. The name is
# everything up to the closing bracket, so names containing digits, periods,
# hyphens, etc. (e.g. "Drosophila sp. 1") are captured as well.
_ORGANISM_PATTERN = re.compile(r"\[organism=([^\]]+)\]")


def get_species(record):
    """Return the organism name from ``record.description``, or None.

    Parameters
    ----------
    record : object with a ``description`` string (e.g. a SeqRecord)

    Returns
    -------
    str or None
        The text inside the first ``[organism=...]`` tag, or ``None`` when the
        description has no such tag.
    """
    match = _ORGANISM_PATTERN.search(record.description)
    return match.group(1) if match else None

In [None]:
# Show the organism parsed from each record's description.
for record in records:
    print(get_species(record))

In [None]:
def longest_record_per_species(records):
    """Pick the longest sequence record for each species.

    Records are keyed by ``get_species`` (records without an organism tag end
    up under the ``None`` key). The input does not need to be sorted. This
    matters because ``itertools.groupby`` only merges *consecutive* records
    with the same key, so a dict built from it keeps only the last run of
    each species.

    Parameters
    ----------
    records : iterable of SeqRecord

    Returns
    -------
    dict
        Maps species name (or ``None``) to its longest record; on ties the
        earliest record wins.
    """
    longest = {}
    for record in records:
        species = get_species(record)
        # Strict ">" keeps the first of several equally long records,
        # matching max()'s tie-breaking.
        if species not in longest or len(record.seq) > len(longest[species].seq):
            longest[species] = record
    return longest

In [None]:
# Keep one CDS (the longest) per species for this gene.
dna_records = longest_record_per_species(records)

In [None]:
def translate_record(record):
    """Return a shallow copy of ``record`` with its sequence translated.

    The input record is not modified. Other attributes (``id``,
    ``description``, ...) carry over to the copy; because it is a shallow
    ``copy``, mutable attributes are shared with the input.
    """
    protein_seq = record.seq.translate()
    result = copy(record)
    result.seq = protein_seq
    return result

In [None]:
# Translate each species' CDS into its protein sequence.
protein_records = {spec: translate_record(rec) for spec, rec in dna_records.items()}

In [None]:
# Inspect the translated records.
print(protein_records)

In [None]:
def align_proteins(protein_records, muscle_exe="muscle", verbose=False):
    """Align protein sequences with MUSCLE and return the sorted alignment.

    Parameters
    ----------
    protein_records : dict
        Its values are the protein SeqRecords to align (keys are ignored).
    muscle_exe : str, optional
        Name or path of the MUSCLE executable (default ``"muscle"``).
    verbose : bool, optional
        If True, print the MUSCLE command line before running it.

    Returns
    -------
    MultipleSeqAlignment
        The alignment with rows sorted by record id, so the row order matches
        nucleotide records sorted by id when building a codon alignment.

    Any error raised while running MUSCLE propagates to the caller.
    """
    with tempfile.NamedTemporaryFile(mode="w+t") as fasta_in:
        SeqIO.write(protein_records.values(), fasta_in, "fasta")
        # MUSCLE reads the file by name, so make sure everything is on disk.
        fasta_in.flush()
        muscle_cline = MuscleCommandline(muscle_exe, input=fasta_in.name)
        if verbose:
            print(muscle_cline)
        stdout, _stderr = muscle_cline()
    # The aligned FASTA comes back on stdout; parse it from memory.
    protein_aln = AlignIO.read(io.StringIO(stdout), "fasta")
    protein_aln.sort()
    return protein_aln

In [None]:
# Align this gene's protein sequences across species.
protein_aln = align_proteins(protein_records)
print(protein_aln)

In [None]:
# Thread the nucleotide CDSs onto the protein alignment to get a codon alignment.
# Iterate over the dict's values (iterating the dict itself yields species
# names), sorted by id to match the row order of the id-sorted protein alignment.
codon_aln = codonalign.build(protein_aln, sorted(dna_records.values(), key=lambda x: x.id))

In [None]:
# Inspect the codon alignment.
print(codon_aln)

In [None]:
def number_of_substitutions(alignment) -> float:
    """Count mismatched aligned letter pairs in ``alignment``.

    ``alignment.substitutions`` is a square table of aligned letter-pair
    counts. Its diagonal holds identical pairs, so everything off the
    diagonal counts as a substitution.
    """
    pair_counts = alignment.substitutions
    identical_pairs = pair_counts.diagonal().sum()
    all_pairs = pair_counts.sum()
    return all_pairs - identical_pairs

In [None]:
# Raw dN/dS from substitution counts (not normalized per site).
total_subs = number_of_substitutions(codon_aln)     # mismatches in the codon (nucleotide) alignment
nonsyn_subs = number_of_substitutions(protein_aln)  # mismatches in the protein alignment
# Approximation: every nucleotide mismatch not visible at the protein level
# is counted as synonymous.
syn_subs = total_subs - nonsyn_subs
dnds = nonsyn_subs / syn_subs

In [None]:
# Total, non-synonymous and synonymous substitution counts, then their ratio,
# one value per line.
print(total_subs, nonsyn_subs, syn_subs, dnds, sep="\n")

In [None]:
def get_all_species(gene_id, datadir):
    """Return the set of species with at least one CDS for ``gene_id``.

    Species names come from ``get_species``; the set may contain ``None``
    when some record has no organism tag.
    """
    per_species = longest_record_per_species(import_fasta(gene_id, datadir))
    return set(per_species)

In [None]:
# List the per-gene archives and record which species each gene covers.
# NOTE(review): assumes every entry in datadir is a "<gene_id>.zip" archive;
# any other file would make import_fasta fail.
files = !ls {datadir}
species = {}
for i, f in enumerate(files):
    gene_id = f.split(".")[0]  # "<gene_id>.zip" -> "<gene_id>"
    print(i, gene_id)          # progress
    species[gene_id] = get_all_species(gene_id, datadir)
    print(species[gene_id])

In [None]:
# Number of genes (ortholog families) in which each species appears.
species_counts = Counter(
    sp
    for sp_set in species.values()
    for sp in sp_set
)

In [None]:
# How many ortholog families include each species.
print(species_counts)

In [None]:
def count_substitutions(gene_id, datadir, species1, species2):
    """Count substitutions between two species' orthologs of one gene.

    Uses the longest CDS of each species, aligns their translations, and
    builds the corresponding codon alignment.

    Returns
    -------
    tuple or None
        ``(total_subs, nonsyn_subs)``: mismatches in the codon alignment and in
        the protein alignment. Returns ``None`` if either species is missing
        for this gene, or if the codon alignment cannot be built (the error
        message is printed in that case).
    """
    longest_records = longest_record_per_species(import_fasta(gene_id, datadir))
    if species1 not in longest_records or species2 not in longest_records:
        return None
    dna_records = {sp: longest_records[sp] for sp in (species1, species2)}
    protein_records = {sp: translate_record(rec) for sp, rec in dna_records.items()}
    protein_aln = align_proteins(protein_records)
    # Sort nucleotide records by id to match the id-sorted protein alignment.
    nucleotide_records = sorted(dna_records.values(), key=lambda rec: rec.id)
    try:
        codon_aln = codonalign.build(protein_aln, nucleotide_records)
    except RuntimeError as err:
        print(err)
        return None
    return number_of_substitutions(codon_aln), number_of_substitutions(protein_aln)

In [None]:
# substitutions[comparison_species][gene_id] = (total_subs, nonsyn_subs)
# for every gene where the focal/comparison pair could be aligned.
substitutions = {}
focal_species = "Drosophila melanogaster"
comparison_species = ["Drosophila pseudoobscura", "Drosophila serrata", "Drosophila simulans"]
for comp in comparison_species:
    print(comp)
    per_gene = substitutions[comp] = {}
    for i, filename in enumerate(files):
        gene_id = filename.split(".")[0]
        print(i, gene_id)  # progress
        subs = count_substitutions(gene_id, datadir, focal_species, comp)
        if subs is not None:
            per_gene[gene_id] = subs

In [None]:
# Synonymous vs non-synonymous substitutions per gene, one colour per
# comparison with the focal species (D. melanogaster).
comparison_colors = {
    "Drosophila simulans": "b",
    "Drosophila pseudoobscura": "y",
    "Drosophila serrata": "g",
}
for comp, color in comparison_colors.items():
    counts = list(substitutions[comp].values())
    nonsyn = [n for _, n in counts]
    syn = [t - n for t, n in counts]
    # One call per comparison instead of one Line2D per gene.
    plt.loglog(nonsyn, syn, "." + color, alpha=0.25, label=comp)

# Reference line: 3 synonymous substitutions per non-synonymous one.
plt.loglog([1, 1000], [3, 3000], "--k")
plt.xlabel("non-synonymous substitutions")
plt.ylabel("synonymous substitutions")
plt.legend();

In [None]:
# Gene ids for which a D. pseudoobscura comparison was computed, sorted.
sorted(substitutions["Drosophila pseudoobscura"])

In [None]:
# Pool every comparison in `substitutions` (a dict of per-gene dicts of
# (total_subs, nonsyn_subs)) into one black scatter.
pooled_counts = [
    counts
    for per_gene in substitutions.values()
    for counts in per_gene.values()
]
plt.loglog(
    [nonsyn for _, nonsyn in pooled_counts],
    [total - nonsyn for total, nonsyn in pooled_counts],
    '.k', alpha=0.25,
)
plt.xlabel("non-synonymous substitutions")
plt.ylabel("synonymous substitutions");

In [None]:
# Pooled scatter of all comparisons in `substitutions` (a dict of per-gene
# dicts of (total_subs, nonsyn_subs)), plus a reference line at
# 2 synonymous substitutions per non-synonymous one.
pooled_counts = [
    counts
    for per_gene in substitutions.values()
    for counts in per_gene.values()
]
plt.loglog(
    [nonsyn for _, nonsyn in pooled_counts],
    [total - nonsyn for total, nonsyn in pooled_counts],
    '.k', alpha=0.25,
)
plt.loglog([1, 1000], [2, 2000], "--k")
plt.xlabel("non-synonymous substitutions")
plt.ylabel("synonymous substitutions");

## TO-DO:
- Repeat for different genes
- Plots of dN vs. dS, where each point is a species comparison
- Scatterplot of dN vs. dS, where each point is a gene
- The same, but highlighting a set of genes of interest

In [None]:
# Use an ordered tuple rather than a set: set iteration order for strings is
# hash-randomized per interpreter run, which would shuffle the codon order
# here and in every table built from `bases`.
bases = ("A", "C", "G", "T")
for comb in product(bases, repeat=3):
    s = Seq("".join(comb))
    print(s, "->", s.translate())

In [None]:
# Every codon over `bases`, as Seq objects.
codons = (Seq("".join(triplet)) for triplet in product(bases, repeat=3))
# Pair each codon with its amino acid and keep sense codons only
# (stop codons translate to "*").
translations = ((codon, codon.translate()) for codon in codons)
genetic_code = {codon: amino_acid for codon, amino_acid in translations if amino_acid != Seq("*")}

In [None]:
def count_differences(codon1, codon2):
    """Number of positions at which two sequences differ (Hamming distance).

    Positions beyond the end of the shorter sequence are ignored, as with ``zip``.
    """
    mismatches = 0
    for base1, base2 in zip(codon1, codon2):
        if base1 != base2:
            mismatches += 1
    return mismatches

# For every sense codon, count its single-nucleotide neighbours (among sense
# codons only) that encode the same amino acid (synonymous) or a different
# one (non-synonymous).
nonsyn_counts = Counter()
syn_counts = Counter()
for codon1, codon2 in combinations(genetic_code, 2):
    if count_differences(codon1, codon2) != 1:
        continue
    same_amino_acid = genetic_code[codon1] == genetic_code[codon2]
    tally = syn_counts if same_amino_acid else nonsyn_counts
    tally[codon1] += 1
    tally[codon2] += 1

In [None]:
# Per sense codon: number of non-synonymous and synonymous single-base neighbours.
for codon in genetic_code:
    print(codon, nonsyn_counts[codon], syn_counts[codon])

In [None]:
# Ratio of non-synonymous to synonymous single-base neighbours, summed over
# all sense codons (each codon weighted equally).
nonsyn_total = sum(nonsyn_counts.values())
syn_total = sum(syn_counts.values())
print(nonsyn_total / syn_total)

In [None]:
def expected_dnds(seq, nonsyn_counts, syn_counts):
    """Expected (null) dN/dS for ``seq``, based on its codon composition.

    For each complete in-frame codon of ``seq``, the numbers of non-synonymous
    and synonymous single-base neighbours are summed, and their ratio is
    returned. A codon missing from a table contributes 0 to that table's sum.
    Stop codons and codons with ambiguous bases therefore do not count, and a
    trailing partial codon is skipped.

    Parameters
    ----------
    seq : str or Seq
        Coding sequence, read in frame from position 0.
    nonsyn_counts, syn_counts : mapping
        Codon -> number of non-synonymous / synonymous neighbours. Missing
        keys count as 0, so both ``Counter`` and plain dicts work.

    Returns
    -------
    float or None
        ``None`` when no synonymous neighbours were counted (e.g. an empty
        sequence), since the ratio is then undefined.
    """
    nonsyn = 0
    syn = 0
    # Upper bound len(seq) - 2 stops before a trailing partial codon.
    for i in range(0, len(seq) - 2, 3):
        codon = seq[i:i + 3]
        # .get(): indexing a Counter never raises KeyError, so unknown codons
        # must be handled explicitly rather than via try/except.
        nonsyn += nonsyn_counts.get(codon, 0)
        syn += syn_counts.get(codon, 0)
    if syn == 0:
        return None
    return nonsyn / syn

In [None]:
# Expected dN/dS per gene, from the focal species' longest CDS.
gene_ids = [filename.split(".")[0] for filename in files]
focal_species = "Drosophila melanogaster"
expectations = {}
for gene_id in gene_ids:
    records = import_fasta(gene_id, datadir)
    longest_records = longest_record_per_species(records)
    # Genes without a focal-species CDS get no expectation.
    if focal_species in longest_records:
        seq = longest_records[focal_species].seq
        expectations[gene_id] = expected_dnds(seq, nonsyn_counts, syn_counts)

In [None]:
# Distribution of expected dN/dS across genes (from D. melanogaster CDSs).
# Undefined expectations (None) are left out so matplotlib gets plain numbers.
expected_values = [value for value in expectations.values() if value is not None]
plt.hist(expected_values)
plt.xlabel("expected dN/dS")
plt.ylabel("number of genes");

In [None]:
# Observed dN/dS (D. melanogaster vs D. pseudoobscura) normalized by the
# expected dN/dS from codon composition: omega = observed / expected.
omega = {}
for gene_id, (total_subs, nonsyn_subs) in substitutions["Drosophila pseudoobscura"].items():
    syn_subs = total_subs - nonsyn_subs
    # Named `expected`, not `expected_dnds`: reusing that name would shadow
    # the function and break any re-run of the expectations cell.
    expected = expectations.get(gene_id)
    if syn_subs <= 0 or not expected:
        # Ratio undefined: no synonymous substitutions, or no usable expectation.
        continue
    dnds = nonsyn_subs / syn_subs
    omega[gene_id] = dnds / expected

In [None]:
# Histogram of omega values. Use the dict's values (iterating the dict itself
# yields gene ids) and drop non-finite ratios (inf/nan), which break automatic
# bin ranges.
omega_values = [value for value in omega.values() if np.isfinite(value)]
plt.hist(omega_values)
plt.xlabel("omega = observed dN/dS / expected dN/dS")
plt.ylabel("number of genes");