In [None]:
import sqlite3
import os
from collections import Counter, defaultdict
import matplotlib.pyplot as plt
from matplotlib import axes
import matplotlib.colors
import numpy as np
from ete3 import NCBITaxa
# Shared NCBI taxonomy handle used for every name -> taxid / lineage / rank lookup in this notebook.
ncbi = NCBITaxa()
import pandas as pd
import statistics
import random
import scipy.stats as stats
import plotly.graph_objs as go
from plotly.offline import iplot, init_notebook_mode
# Initialise plotly's offline notebook mode.
# NOTE(review): `axes`, `statistics`, `random`, `go` and `iplot` are not used in the cells shown here.
init_notebook_mode()

# <center>Streamlining CRISPR spacer-based bacterial host predictions to decipher the viral dark matter (Dion et al. 2020)</center>

## Goal: optimize phage host predictions using CRISPR spacers, determine sensitivity and specificity.

## Step 1: Build spacers database

### 1.1 Download all bacterial genomes on NCBI

#### Each bacterial genome has an FTP address that corresponds to the location of its files on the NCBI server. The FTP addresses of all bacterial genomes are listed in the ftp_path column of this summary document:

#### ftp.ncbi.nlm.nih.gov/genomes/genbank/bacteria/assembly_summary.txt

#### As of March 23rd 2020, there were 580 384 bacterial genomes deposited on NCBI.

### 1.2 Identify CRISPR loci with CRISPRDetect

#### We used CRISPRDetect to identify CRISPR loci, as it is the most sensitive and accurate tool for this task. More information on CRISPRDetect is available in the related publication and on the authors' GitHub page:

#### https://pubmed.ncbi.nlm.nih.gov/27184979/

#### https://github.com/ambarishbiswas/CRISPRDetect_2.2

#### We added the option '-array_quality_score_cutoff 3', as recommended by the authors for FASTA input. We also added the option '-tmp_dir $SLURM_TMPDIR' because CRISPRDetect creates temporary files without deleting them, which rapidly inflates the number of files. CRISPRDetect generates three output files; we parsed the information from the GFF file.

### 1.3 Composition and diversity of the spacer database

In [None]:
from SpacersDB import CrisprOpenDB

In [None]:
# Direct sqlite3 cursor on the spacer database, used by the summary queries that follow.
spacer_db_path = 'SpacersDB/CrisprOpenDB.sqlite'
conn = sqlite3.connect(spacer_db_path)
c = conn.cursor()

In [None]:
# number of CRISPR-positive bacteria (distinct assembly accessions with at least one spacer)
crispr_positive_query = '''select DISTINCT ORG.ASSEMBLY_ACCESSION \
from SPACER_TABLE ST, ORGANISM ORG \
where ST.GENEBANK_ID=ORG.GENEBANK_ID'''
len(c.execute(crispr_positive_query).fetchall())

In [None]:
# assembly accessions of the CRISPR-positive bacteria (distinct accessions with at least one spacer)
crispr_positive_rows = c.execute('''select DISTINCT ORG.ASSEMBLY_ACCESSION \
from SPACER_TABLE ST, ORGANISM ORG \
where ST.GENEBANK_ID=ORG.GENEBANK_ID''').fetchall()
accession = [row[0] for row in crispr_positive_rows]

In [None]:
# total number of spacers in the database
spacer_count_query = '''select COUNT(*) from SPACER_TABLE'''
c.execute(spacer_count_query).fetchall()

In [None]:
# length of the shortest spacer
shortest_query = '''select MIN(SPACER_LENGTH) from SPACER_TABLE'''
c.execute(shortest_query).fetchall()

In [None]:
# length of the longest spacer
longest_query = '''select MAX(SPACER_LENGTH) from SPACER_TABLE'''
c.execute(longest_query).fetchall()

In [None]:
# number of spacers whose length is between 28 and 43 nt (inclusive)
in_range_query = '''select COUNT(*) from SPACER_TABLE where SPACER_LENGTH BETWEEN 28 AND 43'''
c.execute(in_range_query).fetchall()

In [None]:
# Percentage of spacers that are 28-43 nt long. Computed from the database rather than
# from counts copied out of earlier outputs (11674395 / 11767782 when first run), so it
# stays correct if the spacer database is rebuilt.
spacers_in_range = c.execute('''select COUNT(*) from SPACER_TABLE where SPACER_LENGTH BETWEEN 28 AND 43''').fetchone()[0]
spacers_total = c.execute('''select COUNT(*) from SPACER_TABLE''').fetchone()[0]
spacers_in_range / spacers_total * 100

In [None]:
db_explorer = CrisprOpenDB.CrisprOpenDB(os.path.join("SpacersDB", "CrisprOpenDB.sqlite"))

# Every spacer paired with the genus of the genome it was found in.
genus_query = "select ST.SPACER, ORG.GENUS \
from SPACER_TABLE ST, ORGANISM ORG \
where ST.GENEBANK_ID=ORG.GENEBANK_ID"
df = pd.read_sql_query(genus_query, db_explorer._connection)

In [None]:
# Keep only spacers whose host genus is known.
spacers_diversity = df.loc[df['GENUS'] != 'Unknown']

In [None]:
# Preview the spacer / host-genus pairs with a known genus.
spacers_diversity

In [None]:
# Total number of spacers per host genus (duplicates included).
total_spacers = Counter(spacers_diversity['GENUS'])

In [None]:
# Ten genera with the most spacers (duplicates included).
total_spacers.most_common(10)

In [None]:
# How many times more spacers the most abundant genus has than the runner-up
# (Salmonella vs Listeria when first run: 8325687 / 588364). Computed from
# total_spacers instead of hard-coding counts copied from an earlier output.
(top_genus, top_count), (runner_up_genus, runner_up_count) = total_spacers.most_common(2)
top_count / runner_up_count

In [None]:
# Italic (Matplotlib mathtext) y-axis labels for the total-spacers bar chart.
# NOTE(review): presumably transcribed from total_spacers.most_common(10); keep in sync with x_total.
# Raw strings: '\i' is not a valid escape sequence (SyntaxWarning since Python 3.12).
y_total = [r'$\it{Salmonella}$',
           r'$\it{Listeria}$',
           r'$\it{Escherichia}$',
           r'$\it{Clostridioides}$',
           r'$\it{Mycobacterium}$',
           r'$\it{Campylobacter}$',
           r'$\it{Pseudomonas}$',
           r'$\it{Klebsiella}$',
           r'$\it{Acinetobacter}$',
           r'$\it{Streptococcus}$']

In [None]:
# Salmonella has ~14x more spacers than the second most abundant genus, which would
# squash every other bar. Its bar is therefore drawn at a capped length and the axis
# break is added manually in Illustrator; the remaining values are the real totals.
SALMONELLA_DISPLAY_COUNT = 800000  # capped bar length, not the real count
x_total = [SALMONELLA_DISPLAY_COUNT] + [588364, 368069, 271784, 199856,
                                        147421, 121125, 79980, 78313, 58885]

In [None]:
fig, ax = plt.subplots(figsize=(6, 5.5))

y_total_pos = np.arange(len(y_total))

# Orange = Salmonella; genera also in the unique-spacers top 10 reuse that chart's
# colours; the rest are grey. edgecolor matches Patch.set_color (face + edge).
bar_colors = ['#d55e00', '#949494', '#949494', '#949494', '#949494',
              '#cc78bc', '#0173b2', '#949494', '#029e73', '#de8f05']
ax.barh(y_total_pos, x_total, align='center', color=bar_colors, edgecolor=bar_colors)
ax.set_yticks(y_total_pos)
ax.set_yticklabels(y_total, size=20)

# Pin the tick positions before labelling them: labels are in thousands, the 700k slot
# is left blank for the manual axis break, and the capped Salmonella bar ends at '8000'.
# Without set_xticks the labels would be attached to whatever ticks the auto-locator picks.
ax.set_xticks(np.arange(0, 800001, 100000))
ax.set_xticklabels(['0', '100', '200', '300', '400', '500', '600', '', '8000'], size=18)

ax.invert_yaxis()
ax.set_xlabel('Total spacer sequences (thousands)', size=20)

In [None]:
# Number of distinct spacer sequences per host genus.
genus_spacer_pairs = spacers_diversity.drop_duplicates(subset=['SPACER', 'GENUS'])
unique_spacers = Counter(genus_spacer_pairs['GENUS'])

In [None]:
# Distinct spacer sequences across all known genera (a missing value, if any, counts once).
spacers_diversity['SPACER'].nunique(dropna=False)

In [None]:
# Ten genera with the most distinct spacer sequences.
unique_spacers.most_common(10)

In [None]:
# Italic (Matplotlib mathtext) y-axis labels for the unique-spacers bar chart.
# NOTE(review): presumably transcribed from unique_spacers.most_common(10); keep in sync.
# Raw strings: '\i' is not a valid escape sequence (SyntaxWarning since Python 3.12).
y_unique = [r'$\it{Salmonella}$',
            r'$\it{Clostridium}$',
            r'$\it{Streptomyces}$',
            r'$\it{Lactobacillus}$',
            r'$\it{Acinetobacter}$',
            r'$\it{Streptococcus}$',
            r'$\it{Bifidobacterium}$',
            r'$\it{Pseudomonas}$',
            r'$\it{Campylobacter}$',
            r'$\it{Corynebacterium}$']

In [None]:
from matplotlib.ticker import FuncFormatter

fig, ax = plt.subplots(figsize=(6, 5.5))

y_unique_pos = np.arange(len(y_unique))
x_unique = [count for _genus, count in unique_spacers.most_common(10)]

# Genera also in the total-spacers top 10 reuse that chart's colours; the rest are grey.
# edgecolor matches Patch.set_color (face + edge).
bar_colors = ['#d55e00', '#949494', '#949494', '#949494', '#029e73',
              '#de8f05', '#949494', '#0173b2', '#cc78bc', '#949494']
ax.barh(y_unique_pos, x_unique, align='center', color=bar_colors, edgecolor=bar_colors)
ax.set_yticks(y_unique_pos)
ax.set_yticklabels(y_unique, size=20)

# Express ticks in thousands with a formatter instead of a fixed label list, which would
# silently mislabel the axis whenever Matplotlib picks a different set of ticks.
ax.xaxis.set_major_formatter(FuncFormatter(lambda value, _pos: f'{value / 1000:g}'))
ax.tick_params(axis='x', labelsize=18)

ax.invert_yaxis()
ax.set_xlabel('Unique spacer sequences (thousands)', size=20)

In [None]:
# For each genus, compare unique vs total spacers. Genera with a unique/total ratio of at
# least 0.95 go to the "above" series of the scatter plot, the others to "under";
# every ratio is also kept for the distribution summaries.
under95_x, under95_y = [], []
above95_x, above95_y = [], []
ratio_unique_total_distribution = []

for genus_name, n_unique in unique_spacers.items():
    n_total = total_spacers[genus_name]
    ratio = n_unique / n_total
    ratio_unique_total_distribution.append(ratio)
    target_x, target_y = (above95_x, above95_y) if ratio >= 0.95 else (under95_x, under95_y)
    target_x.append(n_total)
    target_y.append(n_unique)

In [None]:
plt.figure(figsize=(8, 9))

plt.scatter(x=under95_x,
            y=under95_y,
            alpha=0.5, s=12, label='< 0.95', c='gainsboro')

plt.scatter(x=above95_x,
            y=above95_y,
            alpha=0.5, s=12, label='≥ 0.95', c='mediumvioletred')

legend = plt.legend(title='unique/total\n   spacers', fontsize=22, numpoints=1, bbox_to_anchor=(0.4, 1))
# Enlarge the legend markers. `legend_handles` (Matplotlib >= 3.7) replaces
# `legendHandles`, which was removed in 3.9; set_sizes avoids poking the private `_sizes`.
handles = legend.legend_handles if hasattr(legend, 'legend_handles') else legend.legendHandles
for handle in handles:
    handle.set_sizes([80])
plt.setp(legend.get_title(), fontsize=22)
plt.xscale('log')
plt.yscale('log')
plt.xlabel('Total number of spacers', size=22)
plt.ylabel('Number of unique spacers', size=22)
plt.ylim(1, 10000000)
plt.xticks(size=22)
plt.yticks(size=22)

In [None]:
# Lowest unique/total spacer ratio among all genera.
min(ratio_unique_total_distribution)

In [None]:
# Unique/total spacer ratio for Salmonella, the genus with the most spacers.
unique_spacers['Salmonella']/total_spacers['Salmonella']

In [None]:
# Fraction of genera whose spacers are almost all unique (unique/total >= 0.95).
n_high_ratio = sum(1 for ratio_value in ratio_unique_total_distribution if ratio_value >= 0.95)
n_high_ratio / len(ratio_unique_total_distribution)

In [None]:
# Distribution of unique/total spacer ratios across genera.
label_size = 20
plt.figure(figsize=(8, 3))
plt.hist(ratio_unique_total_distribution, bins=20, color='k')
plt.xlabel('Ratio unique/total spacers', size=label_size)
plt.ylabel('Frequency', size=label_size)
plt.xticks(size=label_size)
plt.yticks(size=label_size)

## Step 2: Evaluate recall and precision using phages with known hosts

### 2.1 Download phage genomes

#### We used the NCBIVirus webtool to obtain the list of all phages and download their genomes.

#### https://www.ncbi.nlm.nih.gov/labs/virus/vssi/#/virus?SeqType_s=Nucleotide

#### In the left column, we selected Virus > Bacteriophages, all taxids and Nucleotide Completeness > complete.  

#### We then separated the results by sequence type (GenBank or RefSeq), because almost all RefSeq records duplicate GenBank records and we wanted to avoid duplicates as much as possible. We downloaded the two resulting tables in CSV format, keeping "Accession", "Species", "Genus", "Family", "Length", "Host" and "GenBank_Title". We also downloaded all the genomes in FASTA format.

### 2.2 Clean up dataset

#### For some phages in this list the bacterial host is unknown, mostly phages sequenced from metagenomes. Such phages are named after the organism from which they were sequenced, followed by "phage" (e.g. bee, lynx or chimpanzee phages). They must be removed from our dataset, because the host of every phage used to test our prediction method has to be known.

In [None]:
# Build the benchmark set of phages with known bacterial hosts from the two NCBI Virus
# exports (GenBank and RefSeq, 2020-04-23).

BACTERIA_TAXID = 2
# Title keywords marking "biological" duplicates: laboratory-derived variants (mutants,
# clones, evolved isolates) that would over-represent some sequences.
DUPLICATE_MARKERS = ['mutant', 'Mutant', 'clone', 'isolate', 'evolved']


def _is_bacterial(name, name_to_taxids):
    """Return True if any taxid that ``name`` maps to lies under Bacteria.

    get_name_translator returns a list of taxids per name (homonyms exist across
    kingdoms), so every candidate is checked instead of only the first one.
    """
    return any(BACTERIA_TAXID in ncbi.get_lineage(taxid)
               for taxid in name_to_taxids.get(name, []))


def clean_phage_table(csv_path):
    """Load one NCBI Virus export and keep only phages usable for benchmarking.

    Steps:
      1. Fill an empty 'Host' with the first word of 'Species'. An empty host is
         sometimes a curation omission, and the phage name starts with the host genus.
      2. Keep phages whose host name resolves to a bacterial NCBI taxon. This drops
         metagenomic "hosts" (bees, lynx, ...) and unresolvable names.
      3. Drop laboratory-derived duplicates (see DUPLICATE_MARKERS).

    Missing values are replaced by 0 first; later cells rely on this (e.g. Family != 0).
    """
    table = pd.read_csv(csv_path).fillna(0)

    missing_host = table['Host'] == 0
    table.loc[missing_host, 'Host'] = table.loc[missing_host, 'Species'].str.split(' ').str[0]

    hosts = [host for host in table['Host'].unique() if isinstance(host, str)]
    name_to_taxids = ncbi.get_name_translator(hosts)
    bacterial_hosts = {host for host in hosts if _is_bacterial(host, name_to_taxids)}
    # Vectorised filter instead of dropping rows one at a time (quadratic).
    table = table[table['Host'].isin(bacterial_hosts)]

    for marker in DUPLICATE_MARKERS:
        table = table[~table['GenBank_Title'].str.contains(marker, regex=False, na=False)]
    return table


genbank_phages = clean_phage_table('genbank_20200423.csv')
# RefSeq goes through the same cleaning. Previously it skipped the duplicate-marker
# filter, so RefSeq copies of GenBank mutants (removed above) could be re-admitted.
refseq_phages = clean_phage_table('refseq_20200423.csv')

# Almost all RefSeq records duplicate a GenBank record: keep only RefSeq phages whose
# title is not already in the GenBank table.
refseq_only = refseq_phages[~refseq_phages['GenBank_Title'].isin(genbank_phages['GenBank_Title'])]

phage_table = pd.concat([genbank_phages, refseq_only]).set_index('Accession')

### 2.2 Search for homology between phage genomes and the spacers database

### 2.3 Perform predictions

In [None]:
def load_spacer_table():
    """Return one row per spacer with its host taxonomy and array context.

    Columns: SPACER_ID, GENEBANK_ID, GENUS, FAMILY, TORDER, SPACER_LENGTH,
    COUNT_SPACER (size of the spacer's array, joined on GENEBANK_ID + NUMERO_LOCUS)
    and POSITION_INSIDE_LOCUS.
    """
    spacer_query = "select ST.SPACER_ID, ST.GENEBANK_ID, ORG.GENUS, ORG.FAMILY, ORG.TORDER, ST.SPACER_LENGTH, SAL.COUNT_SPACER, ST.POSITION_INSIDE_LOCUS  \
        from ORGANISM ORG, SPACER_TABLE ST, SPACER_ARRAY_LENGTH SAL \
        where ST.GENEBANK_ID=ORG.GENEBANK_ID and ST.GENEBANK_ID=SAL.GENEBANK_ID and ST.NUMERO_LOCUS=SAL.NUMERO_LOCUS"
    explorer = CrisprOpenDB.CrisprOpenDB(os.path.join("SpacersDB", "CrisprOpenDB.sqlite"))
    return pd.read_sql_query(spacer_query, explorer._connection)

In [None]:
# Spacer metadata (host genus/family/order, length, position within its array) for every spacer.
spacer_table = load_spacer_table()

In [None]:
def load_blastn_results(blastn_file):
    """Read a headerless, tab-separated BLASTN result file into a DataFrame.

    The 12 columns are named positionally (presumably BLAST ``-outfmt 6``):
    query id, spacer id, % identity, alignment length, mismatches, gaps,
    query start/end, subject start/end, e-value and score.
    """
    column_names = [
        "Query", "SPACER_ID", "identity", "alignement_length",
        "mismatch", "gap", "q_start", "q_end",
        "s_start", "s_end", "e_value", "score",
    ]
    return pd.read_csv(blastn_file, sep="\t", names=column_names)

In [None]:
# BLASTN alignments between phage genomes and the spacer database.
# The file is available upon request; set its path here, or export it in the
# CRISPR_ALIGNMENT_RESULTS environment variable.
path_to_alignment_result = os.environ.get('CRISPR_ALIGNMENT_RESULTS', '')
if not path_to_alignment_result:
    raise ValueError('Set path_to_alignment_result (or the CRISPR_ALIGNMENT_RESULTS '
                     'environment variable) to the BLASTN alignment results file.')
alignement_results = load_blastn_results(path_to_alignment_result)

In [None]:
# Attach spacer metadata to every alignment; alignments to unknown spacer ids keep NaNs.
summary_table = alignement_results.merge(spacer_table, on="SPACER_ID", how="left")

In [None]:
# Inspect the merged alignment / spacer-metadata table.
summary_table

In [None]:
def findHost(summary_table, criteria, limit):
    """Predict the bacterial host of one phage from its spacer alignments.

    Hits are filtered first. Only gapless alignments to spacers of a known genus are
    kept, then either a mismatch or an e-value threshold is applied. A cascade of
    criteria is then applied until a single taxon remains.

    Parameters
    ----------
    summary_table : pandas.DataFrame
        Alignments of one phage merged with spacer metadata. Columns read: gap, GENUS,
        FAMILY, TORDER, SPACER_LENGTH, alignement_length, mismatch, e_value, q_start,
        q_end, POSITION_INSIDE_LOCUS, COUNT_SPACER. Rows containing any missing value
        are discarded.
    criteria : {'mismatch', 'e_value'}
        Which quantity ``limit`` thresholds.
    limit : int or float
        Maximum number of mismatches or maximum e-value. For mismatches, spacer bases
        outside the alignment also count as mismatches.

    Returns
    -------
    dict or None
        ``{"PREDICTION": taxon}`` plus one ``"level_<n>"`` entry for every criterion
        that did not settle the prediction, listing the genera still in contention
        after it. ``PREDICTION`` is a genus, except after criterion 4, where it is a
        family or order name. None if no hit passes the filters or no common family or
        order exists.

    Raises
    ------
    ValueError
        If ``criteria`` is neither ``'mismatch'`` nor ``'e_value'``.
    """
    hits = summary_table.dropna()
    hits = hits[(hits.gap <= 0) & (hits.GENUS != 'Unknown')]

    if criteria == 'mismatch':
        # Spacer bases left outside the local alignment are counted as mismatches too.
        true_num_mismatch = (hits["SPACER_LENGTH"] - hits["alignement_length"]) + hits["mismatch"]
        hits = hits.assign(true_num_mismatch=true_num_mismatch)
        hits = hits[hits.true_num_mismatch <= limit]
    elif criteria == 'e_value':
        hits = hits[hits.e_value <= limit]
    else:
        raise ValueError(f"criteria must be 'mismatch' or 'e_value', got {criteria!r}")

    if hits.empty:
        return None

    # Midpoint of each hit on the phage genome (truncated to int) identifies a target site.
    # assign() returns a new frame, avoiding writes into a filtered slice.
    hits = hits.assign(mean_position=np.array((hits["q_start"] + hits["q_end"]) / 2, dtype=int))

    # Criterion 1: a single candidate genus is the host.
    genera = hits["GENUS"].unique().tolist()
    if len(genera) == 1:
        return {"PREDICTION": genera[0]}
    no_pred = {"level_1": genera}

    # Criterion 2: the genus targeting the most distinct sites on the phage genome wins
    # (example phage: MF153391). Several spacers hitting the same site count once.
    target_sites = set(zip(hits["GENUS"], hits["mean_position"]))
    ranking = Counter(genus for genus, _site in target_sites).most_common()
    best_count = ranking[0][1]
    if ranking[1][1] != best_count:
        return {**no_pred, "PREDICTION": ranking[0][0]}
    genuses_to_keep = [genus for genus, count in ranking if count == best_count]
    hits = hits[hits["GENUS"].isin(genuses_to_keep)]
    no_pred["level_2"] = genuses_to_keep

    # Criterion 3: among tied genera, prefer the hit whose spacer sits closest to the
    # 5' end of its array (example phage: MF158036). The relative position is 0 for the
    # first spacer and 1 for the last, assuming POSITION_INSIDE_LOCUS is 1-based.
    # NOTE(review): single-spacer arrays (COUNT_SPACER == 1) give 0/0 = NaN, which
    # is never the minimum; confirm this is intended.
    relative_position = (hits["POSITION_INSIDE_LOCUS"].astype(float) - 1) / (hits["COUNT_SPACER"] - 1)
    sub_section = hits[relative_position == relative_position.min()]
    genera = sub_section["GENUS"].unique().tolist()
    if len(genera) == 1:
        return {**no_pred, "PREDICTION": genera[0]}
    no_pred["level_3"] = genera

    # Criterion 4: last common ancestor among the remaining hits, i.e. a shared family,
    # otherwise a shared order. This does not return a genus.
    families = set(sub_section["FAMILY"])
    if len(families) == 1:
        return {**no_pred, "PREDICTION": families.pop()}
    orders = set(sub_section["TORDER"])
    if len(orders) == 1:
        return {**no_pred, "PREDICTION": orders.pop()}
    return None

In [None]:
def predict(criteria, value):
    """Run findHost for every phage of the global ``phage_table``.

    Alignments are taken from the global ``summary_table``, split by the ``Query``
    column once per call. Previously, every phage scanned the whole table with a
    boolean mask.

    Parameters
    ----------
    criteria : {'mismatch', 'e_value'}
        Forwarded to findHost.
    value : int or float
        Threshold forwarded to findHost as ``limit``.

    Returns
    -------
    dict
        Phage accession -> findHost result, only for phages that received a prediction.
    """
    hits_by_phage = {query: hits for query, hits in summary_table.groupby("Query")}
    pred_dict = {}
    for phage in phage_table.index:
        hits = hits_by_phage.get(phage)
        if hits is None:
            continue  # no alignment at all, so no prediction possible
        pred = findHost(hits, criteria, value)
        if pred is not None:
            pred_dict[phage] = pred
    return pred_dict

In [None]:
def find_lca_rank(host, pred_host, taxonomy=None):
    """Return the taxonomic rank of the lowest common ancestor of two taxon names.

    Parameters
    ----------
    host, pred_host : str
        Scientific names, e.g. the true host and a predicted genus, family or order.
    taxonomy : optional
        Object exposing ete3-style ``get_name_translator``, ``get_lineage`` and
        ``get_rank``. Defaults to the module-level ``ncbi`` handle.

    Returns
    -------
    str
        Rank of the deepest taxid shared by both lineages (e.g. ``'genus'`` when both
        names fall in the same genus).

    Raises
    ------
    KeyError
        If either name is unknown to the taxonomy.
    """
    if taxonomy is None:
        taxonomy = ncbi
    translator = taxonomy.get_name_translator([host, pred_host])
    missing = [name for name in (host, pred_host) if name not in translator]
    if missing:
        raise KeyError(f"unknown taxon name(s): {missing}")

    def _lineage(name):
        # A name can map to several taxids (homonyms across kingdoms). Prefer the one
        # under Bacteria (taxid 2) rather than blindly taking the first candidate.
        lineages = [taxonomy.get_lineage(taxid) for taxid in translator[name]]
        return next((lineage for lineage in lineages if 2 in lineage), lineages[0])

    lca = None
    # Lineages run from the root downwards: the LCA is the last element of the common prefix.
    for host_taxid, pred_taxid in zip(_lineage(host), _lineage(pred_host)):
        if host_taxid != pred_taxid:
            break
        lca = host_taxid
    return taxonomy.get_rank([lca])[lca]

In [None]:
def specificity_sensitivity(pred_host_dict, total_number_of_phages):
    """Score host predictions against the known hosts in the global ``phage_table``.

    Parameters
    ----------
    pred_host_dict : dict
        Output of ``predict``: phage accession -> dict with a ``'PREDICTION'`` key.
    total_number_of_phages : int
        Recall denominator (number of phages that were queried).

    Returns
    -------
    (defaultdict, float)
        The first element is keyed by LCA rank. 'genus', 'family' and 'order' hold
        cumulative precision in percent: a genus-level hit also counts at family and
        order level. Any other rank key holds a raw count. The precisions are NaN when
        there is no prediction. The second element is recall in percent.
    """
    specificity_by_rank = defaultdict(int)
    for phage, pred in pred_host_dict.items():
        real_host = phage_table.loc[phage, 'Host']
        specificity_by_rank[find_lca_rank(real_host, pred['PREDICTION'])] += 1

    number_of_predictions = len(pred_host_dict)
    sensitivity = number_of_predictions / total_number_of_phages * 100

    # Make the counts cumulative up the hierarchy.
    # NOTE(review): LCAs at intermediate ranks (e.g. 'subfamily', 'suborder') are not
    # folded in here; confirm whether they should count at family/order level.
    specificity_by_rank['family'] += specificity_by_rank['genus']
    specificity_by_rank['order'] += specificity_by_rank['family']

    for rank in ('genus', 'family', 'order'):
        if number_of_predictions:
            specificity_by_rank[rank] = specificity_by_rank[rank] / number_of_predictions * 100
        else:
            specificity_by_rank[rank] = float('nan')  # precision undefined without predictions

    return specificity_by_rank, sensitivity

In [None]:
# Host predictions for every maximal number of mismatches from 0 to 10.
(pred_0_mismatch, pred_1_mismatch, pred_2_mismatch, pred_3_mismatch,
 pred_4_mismatch, pred_5_mismatch, pred_6_mismatch, pred_7_mismatch,
 pred_8_mismatch, pred_9_mismatch, pred_10_mismatch) = (
    predict('mismatch', max_mismatch) for max_mismatch in range(11))

In [None]:
# Recall denominator: `predict` queries every phage in phage_table, so its size is the
# number of phages that could have received a prediction. This was previously
# hard-coded as 9484, which goes stale whenever the input tables or cleaning change.
total_phages = len(phage_table)

specificity_by_rank_0 = specificity_sensitivity(pred_0_mismatch, total_phages)
specificity_by_rank_1 = specificity_sensitivity(pred_1_mismatch, total_phages)
specificity_by_rank_2 = specificity_sensitivity(pred_2_mismatch, total_phages)
specificity_by_rank_3 = specificity_sensitivity(pred_3_mismatch, total_phages)
specificity_by_rank_4 = specificity_sensitivity(pred_4_mismatch, total_phages)
specificity_by_rank_5 = specificity_sensitivity(pred_5_mismatch, total_phages)
specificity_by_rank_6 = specificity_sensitivity(pred_6_mismatch, total_phages)
specificity_by_rank_7 = specificity_sensitivity(pred_7_mismatch, total_phages)
specificity_by_rank_8 = specificity_sensitivity(pred_8_mismatch, total_phages)
specificity_by_rank_9 = specificity_sensitivity(pred_9_mismatch, total_phages)
specificity_by_rank_10 = specificity_sensitivity(pred_10_mismatch, total_phages)

In [None]:
# 'seaborn-deep' was renamed 'seaborn-v0_8-deep' in Matplotlib 3.6 and the old name was
# removed in 3.8; use whichever name this installation provides.
plt.style.use('seaborn-v0_8-deep' if 'seaborn-v0_8-deep' in plt.style.available else 'seaborn-deep')

mismatch_results = [specificity_by_rank_0, specificity_by_rank_1, specificity_by_rank_2,
                    specificity_by_rank_3, specificity_by_rank_4, specificity_by_rank_5,
                    specificity_by_rank_6, specificity_by_rank_7, specificity_by_rank_8,
                    specificity_by_rank_9, specificity_by_rank_10]
# Each result is (precision-by-rank dict, recall %).
genus = [by_rank['genus'] for by_rank, _recall in mismatch_results]
family = [by_rank['family'] for by_rank, _recall in mismatch_results]
order = [by_rank['order'] for by_rank, _recall in mismatch_results]
sensitivity_recap = [recall for _by_rank, recall in mismatch_results]

x_axis = list(range(11))  # maximal number of mismatches tested

plt.figure(figsize=(6, 6))

plt.plot(x_axis, sensitivity_recap, linestyle='--', color='black', marker='o')
# Invisible copy of the recall curve; presumably a placeholder so the line styling
# matches the e-value figure, whose legend uses this slot as a 'Precision:' header.
plt.plot(x_axis, sensitivity_recap, alpha=0)
plt.plot(x_axis, order, marker='v')
plt.plot(x_axis, family, marker='o')
plt.plot(x_axis, genus, marker='D')

plt.xticks(x_axis, fontsize=20)
plt.xlabel("Maximal number of mismatches", fontsize=20)
plt.yticks(fontsize=20)
plt.ylim([0, 100])
plt.ylabel("%", fontsize=20)


In [None]:
# Host predictions for maximal e-values from 1e-9 to 1e-1 (one decade apart).
(pred_000000001_evalue, pred_00000001_evalue, pred_0000001_evalue,
 pred_000001_evalue, pred_00001_evalue, pred_0001_evalue,
 pred_001_evalue, pred_01_evalue, pred_1_evalue) = (
    predict('e_value', max_evalue)
    for max_evalue in (1e-9, 1e-8, 1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1))

In [None]:
# Recall denominator: `predict` queries every phage in phage_table, so its size is the
# number of phages that could have received a prediction. This was previously
# hard-coded as 9484, which goes stale whenever the input tables or cleaning change.
total_phages = len(phage_table)

specificity_by_rank_evalue_000000001 = specificity_sensitivity(pred_000000001_evalue, total_phages)
specificity_by_rank_evalue_00000001 = specificity_sensitivity(pred_00000001_evalue, total_phages)
specificity_by_rank_evalue_0000001 = specificity_sensitivity(pred_0000001_evalue, total_phages)
specificity_by_rank_evalue_000001 = specificity_sensitivity(pred_000001_evalue, total_phages)
specificity_by_rank_evalue_00001 = specificity_sensitivity(pred_00001_evalue, total_phages)
specificity_by_rank_evalue_0001 = specificity_sensitivity(pred_0001_evalue, total_phages)
specificity_by_rank_evalue_001 = specificity_sensitivity(pred_001_evalue, total_phages)
specificity_by_rank_evalue_01 = specificity_sensitivity(pred_01_evalue, total_phages)
specificity_by_rank_evalue_1 = specificity_sensitivity(pred_1_evalue, total_phages)

In [None]:
# 'seaborn-deep' was renamed 'seaborn-v0_8-deep' in Matplotlib 3.6 and the old name was
# removed in 3.8; use whichever name this installation provides.
plt.style.use('seaborn-v0_8-deep' if 'seaborn-v0_8-deep' in plt.style.available else 'seaborn-deep')

evalue_results = [specificity_by_rank_evalue_000000001, specificity_by_rank_evalue_00000001,
                  specificity_by_rank_evalue_0000001, specificity_by_rank_evalue_000001,
                  specificity_by_rank_evalue_00001, specificity_by_rank_evalue_0001,
                  specificity_by_rank_evalue_001, specificity_by_rank_evalue_01,
                  specificity_by_rank_evalue_1]
# Each result is (precision-by-rank dict, recall %).
genus = [by_rank['genus'] for by_rank, _recall in evalue_results]
family = [by_rank['family'] for by_rank, _recall in evalue_results]
order = [by_rank['order'] for by_rank, _recall in evalue_results]
sensitivity_recap = [recall for _by_rank, recall in evalue_results]

x_axis = [0.000000001, 0.00000001, 0.0000001, 0.000001, 0.00001,
          0.0001, 0.001, 0.01, 0.1]

plt.figure(figsize=(6, 6))
plt.plot(x_axis, sensitivity_recap, linestyle='--', color='black', marker='o')
# Invisible line: gives the legend an empty-marker 'Precision:' header entry.
plt.plot(x_axis, sensitivity_recap, alpha=0)
plt.plot(x_axis, order, marker='v')
plt.plot(x_axis, family, marker='o')
plt.plot(x_axis, genus, marker='D')

plt.xticks(x_axis, fontsize=20)
plt.xscale('log')
plt.xlabel("Maximum e-value", fontsize=20)
plt.yticks(fontsize=20)
plt.ylim([0, 100])
plt.ylabel("%", fontsize=20)  # same y label as the mismatch figure
plt.legend(["Recall", 'Precision:', "Order", "Family", "Genus"], bbox_to_anchor=(1.05, 1), prop={'size': 20})


### 2.4 Data exploration

In [None]:
# Predicted taxon of every phage predicted with at most 2 mismatches, in prediction order.
final_pred_2_mismatch = [pred['PREDICTION'] for pred in pred_2_mismatch.values()]

In [None]:
# Contribution of each filter in identifying one final predicted host.
# For every phage predicted with at most 2 mismatches, count the candidate genera left
# before any filtering (all of its alignments) and after each findHost criterion.
# Counts are capped at 10 (the "10+" histogram bin). A stage never reached, because an
# earlier criterion already settled the prediction, counts as a single genus.
max_genera_shown = 10
genera_before_filtering = summary_table.groupby("Query")["GENUS"].nunique(dropna=False)

no_filter = []
crit_1 = []
crit_2 = []
crit_3 = []
crit_4 = []

for phage, pred in pred_2_mismatch.items():
    no_filter.append(min(genera_before_filtering[phage], max_genera_shown))

    if 1 <= len(pred) <= 4:
        remaining = [min(len(pred[level]), max_genera_shown) if level in pred else 1
                     for level in ('level_1', 'level_2', 'level_3')]
        crit_1.append(remaining[0])
        crit_2.append(remaining[1])
        crit_3.append(remaining[2])
        crit_4.append(1)


In [None]:
# % of predictions settled by criterion 1 or 2 (a single genus left after criterion 2).
crit_2.count(1)/len(pred_2_mismatch)*100

In [None]:
# % of predictions settled by criterion 3 (5'-most spacer).
(crit_3.count(1)-crit_2.count(1))/len(pred_2_mismatch)*100

In [None]:
# % of predictions settled by criterion 4 (common family/order; not a genus).
(crit_4.count(1)-crit_3.count(1))/len(pred_2_mismatch)*100

In [None]:
# 'seaborn-deep' was renamed 'seaborn-v0_8-deep' in Matplotlib 3.6 and the old name was
# removed in 3.8; use whichever name this installation provides.
plt.style.use('seaborn-v0_8-deep' if 'seaborn-v0_8-deep' in plt.style.available else 'seaborn-deep')
plt.figure(figsize=(8, 6))
# Integer-aligned bins 1..10; bin 10 collects the capped "10+" counts.
plt.hist([no_filter, crit_1, crit_2, crit_3, crit_4],
         bins=range(1, 11 + 1, 1))
plt.legend(labels=['No filter',
                   'After 2 mismatches',
                   'After highest number of targets',
                   'After 5\' end neighbouring spacers',
                   'After last common ancestor'],
           prop={'size': 14})
plt.xlabel('Number of predicted hosts (genus level)', size=20)
plt.ylabel('Number of phages', size=20)
plt.xticks(np.arange(1 + 1 / 2, 11, 1), ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10+"], size=20)
plt.yticks(size=20)
#plt.savefig("contributution_per_filter_20200812.png", dpi =300, bbox_inches='tight')

In [None]:
# Is there a difference in the accuracy of prediction depending on the filter used to identify the predicted host?
# The number of keys in a prediction dict identifies the criterion that settled it;
# 'Y' marks a genus-level correct prediction, 'N' anything else.
accuracy_by_filter = defaultdict(list)

for phage, pred in pred_2_mismatch.items():
    lca_of_prediction = find_lca_rank(phage_table.loc[phage]['Host'], pred['PREDICTION'])
    accuracy_by_filter[len(pred)].append('Y' if lca_of_prediction == 'genus' else 'N')


In [None]:
# % of criterion-1 predictions that are correct at genus level.
accuracy_by_filter[1].count('Y')/len(accuracy_by_filter[1])*100

In [None]:
# % of criterion-2 predictions that are correct at genus level.
accuracy_by_filter[2].count('Y')/len(accuracy_by_filter[2])*100

In [None]:
# % of criterion-3 predictions that are correct at genus level.
accuracy_by_filter[3].count('Y')/len(accuracy_by_filter[3])*100

In [None]:
# 'seaborn-deep' was renamed 'seaborn-v0_8-deep' in Matplotlib 3.6 and the old name was
# removed in 3.8; use whichever name this installation provides.
plt.style.use('seaborn-v0_8-deep' if 'seaborn-v0_8-deep' in plt.style.available else 'seaborn-deep')
plt.figure(figsize=(8, 6))

# Key count of a prediction dict -> criterion that settled it (1..4, see findHost).
# The levels are fixed explicitly so the bar positions always match the four
# yes/no values and tick labels. Previously N came from len(accuracy_by_filter),
# which is smaller when a level has no phage.
levels = [1, 2, 3, 4]
N = len(levels)
yes = [accuracy_by_filter[level].count('Y') for level in levels]
no = [accuracy_by_filter[level].count('N') for level in levels]
ind = np.arange(N)
width = 0.35

p1 = plt.bar(ind, yes, width, color='mediumseagreen')
p2 = plt.bar(ind, no, width, bottom=yes, color='indianred')

plt.ylabel('Number of phages', size=20)
plt.yticks(size=20)
plt.xticks(ind, ('Filter 1', 'Filter 2', 'Filter 3', 'Filter 4'), size=20)
legend = plt.legend(title='Prediction', labels=['Accurate', 'Wrong'], prop={'size': 14})
legend.get_title().set_fontsize(14)

In [None]:
# For every phage predicted with at most 2 mismatches, record the prediction, the true
# host genus (first word of 'Host') and whether the prediction is right at genus level
# and at family level (a genus-level hit is also right at family level).
pred_accuracy = defaultdict(list)

for phage, pred in pred_2_mismatch.items():
    prediction = pred['PREDICTION']
    real_host = phage_table.loc[phage]['Host'].split(' ')[0]
    lca_rank = find_lca_rank(real_host, prediction)
    accuracy_genus = 'Y' if lca_rank == 'genus' else 'N'
    accuracy_family = 'Y' if lca_rank in ('genus', 'family') else 'N'
    pred_accuracy[phage].extend((prediction, real_host, accuracy_genus, accuracy_family))

In [None]:
# One row per predicted phage (index = accession).
pred_accuracy_table = pd.DataFrame.from_dict(pred_accuracy, orient='index', columns = ['prediction','real_host','accuracy_genus','accuracy_family'])

In [None]:
# Attach genome length and viral family from the benchmark table (joined on accession).
pred_accuracy_table = pred_accuracy_table.merge(phage_table[['Length','Family']], right_index=True, left_index=True, how='left')

In [None]:
# Drop phages with no viral family (missing values were filled with 0 during cleaning).
pred_accuracy_table = pred_accuracy_table.loc[pred_accuracy_table['Family'] != 0]

In [None]:
# Inspect per-phage prediction accuracy.
pred_accuracy_table

In [None]:
# Number of distinct 'Host' labels in the benchmark set.
# NOTE(review): these are full host labels (possibly species names), whereas real_host in
# pred_accuracy_table keeps only the first word; the two counts are not at the same granularity.
len(set(phage_table.Host))

In [None]:
# Number of distinct true host genera (first word of 'Host') among phages in pred_accuracy_table.
len(set(pred_accuracy_table.real_host))

In [None]:
# Per host genus: [predictions made, phages tested, recall %, genus-level precision %,
# family-level precision %], all rounded to 2 decimals.
accuracy_by_real_host = defaultdict(list)
host_total = Counter([x.split(' ')[0] for x in phage_table.Host.tolist()])

for host_genus, n_phages in host_total.items():
    host_preds = pred_accuracy_table[pred_accuracy_table.real_host == host_genus]
    number_of_pred = len(host_preds)
    sensitivity = round(number_of_pred / n_phages * 100, 2)

    # Guard on the raw count, not on the rounded recall, which rounds to 0 for a tiny
    # non-zero recall. The former bare excepts are gone; the family branch also set
    # the wrong variable (yes_family instead of specificity_family).
    if number_of_pred == 0:
        specificity_genus = 0
        specificity_family = 0
    else:
        yes_genus = host_preds.accuracy_genus.tolist().count('Y')
        yes_family = host_preds.accuracy_family.tolist().count('Y')
        specificity_genus = round(yes_genus / number_of_pred * 100, 2)
        specificity_family = round(yes_family / number_of_pred * 100, 2)

    accuracy_by_real_host[host_genus].extend((number_of_pred,
                                              n_phages,
                                              sensitivity,
                                              specificity_genus,
                                              specificity_family))
    

In [None]:
# Number of host genera evaluated.
len(accuracy_by_real_host)

In [None]:
# One row per host genus: predictions made, phages tested, recall (%), genus- and family-level precision (%).
accuracy_by_host_table = pd.DataFrame.from_dict(accuracy_by_real_host, orient = 'index', 
                       columns=['number_pred','number_phages', 'sensitivity','accuracy_genus','accuracy_family'])

In [None]:
# Attach the number of unique spacers of each host genus (unique_spacers is a Counter,
# so genera without spacers get 0).
number_unique_spacers = [unique_spacers[host_genus] for host_genus in accuracy_by_host_table.index]
accuracy_by_host_table.loc[:, 'number_unique_spacers'] = number_unique_spacers

In [None]:
# Recall vs genus-level precision per host genus, for hosts with more than 10 benchmark
# phages; colour = number of unique spacers of the host (log scale).
well_sampled = accuracy_by_host_table[accuracy_by_host_table.number_phages > 10]

plt.figure(figsize=(4, 3))
plt.axvline(x=50, color='grey', linewidth=1)
plt.axhline(y=50, color='grey', linewidth=1)

plt.scatter(x=well_sampled.sensitivity,
            y=well_sampled.accuracy_genus,
            c=well_sampled.number_unique_spacers,
            norm=matplotlib.colors.LogNorm(),
            s=10)
plt.colorbar(label='Number of host unique spacers')
plt.xlabel('Recall (%)', size=10)
plt.ylabel('Precision (%)', size=10)
plt.xticks(size=10)
plt.yticks(size=10)

# Quadrant numbers.
plt.annotate('3', (95, 55), size=10, color='k')
plt.annotate('2', (1, 55), size=10, color='k')
plt.annotate('1', (1, 40), size=10, color='k')
plt.annotate('4', (95, 40), size=10, color='k')

# Hand-placed genus labels with leader lines. Raw strings: '\i' is not a valid escape
# sequence (SyntaxWarning since Python 3.12).
plt.annotate(r'$\it{Synechococcus}$' + '\n    ' + r'$\it{Prochlorococcus}$', (-4, 12), size=5, color='k')
plt.plot([0, -2], [2, 14], color='k', linewidth=0.5)
plt.plot([0, 2], [2, 10], color='k', linewidth=0.5)

plt.annotate(r'$\it{Cellulophaga}$', (17, 5), size=5, color='k')
plt.plot([19.5, 19.5], [2, 4], color='k', linewidth=0.5)

plt.annotate(r'$\it{Clostridioides}$', (78, 92), size=5, color='k')
plt.annotate(r'$\it{Prevotella}$', (90, 88.5), size=5, color='k')
plt.plot([100, 102], [98, 91], color='k', linewidth=0.5)
plt.plot([100, 98], [98, 94.5], color='k', linewidth=0.5)

plt.annotate(r'$\it{Enterococcus}$', (62, 98), size=5, color='k')
plt.plot([92.5, 84], [97.5, 98.5], color='k', linewidth=0.5)


#plt.savefig("sensitivity_specificity_by_host_20201007.png", dpi =300, bbox_inches='tight')

In [None]:
# Spearman rank correlation between recall (sensitivity) and genus-level
# precision (accuracy_genus), over hosts with more than 10 phages.
well_sampled_hosts = accuracy_by_host_table[accuracy_by_host_table.number_phages > 10]
stats.spearmanr(well_sampled_hosts.sensitivity, well_sampled_hosts.accuracy_genus)

In [None]:
# Spearman rank correlation between recall (sensitivity) and the host's number
# of unique spacers, over hosts with more than 10 phages.
well_sampled_hosts = accuracy_by_host_table[accuracy_by_host_table.number_phages > 10]
stats.spearmanr(well_sampled_hosts.sensitivity, well_sampled_hosts.number_unique_spacers)

In [None]:
# Spearman rank correlation between genus-level precision (accuracy_genus) and
# the host's number of unique spacers, over hosts with more than 10 phages.
well_sampled_hosts = accuracy_by_host_table[accuracy_by_host_table.number_phages > 10]
stats.spearmanr(well_sampled_hosts.accuracy_genus, well_sampled_hosts.number_unique_spacers)

In [None]:
# 'seaborn-deep' was renamed 'seaborn-v0_8-deep' in matplotlib 3.6 and the old
# alias was later removed; use whichever name this matplotlib provides.
# plt.style.use changes the style of every subsequent figure in the notebook.
plt.style.use('seaborn-v0_8-deep' if 'seaborn-v0_8-deep' in plt.style.available
              else 'seaborn-deep')

# The first three genera are labelled 'Category 1', the last three 'Category 3'.
bar_genera = ['Synechococcus', 'Prochlorococcus', 'Cellulophaga',
              'Clostridioides', 'Enterococcus', 'Prevotella']
N = len(bar_genera)

# Values come from accuracy_by_real_host: index 2 = sensitivity (plotted as
# recall), index 3 = accuracy_genus (plotted as precision), following the
# accuracy_by_host_table column order.
# The small offsets are presumably there so that bars at ~0 % remain visible.
# TODO confirm.
recall_offset = {'Synechococcus': 0.3, 'Prochlorococcus': 0.3}
precision_offset = {'Synechococcus': 0.3, 'Prochlorococcus': 0.3, 'Cellulophaga': 0.3}
sensitivity = [accuracy_by_real_host[genus][2] + recall_offset.get(genus, 0)
               for genus in bar_genera]
specificity = [accuracy_by_real_host[genus][3] + precision_offset.get(genus, 0)
               for genus in bar_genera]

ind = np.arange(N)
width = 0.35

plt.figure(figsize= (8,6))
plt.bar(ind, sensitivity, width, color = 'palevioletred')
plt.bar(ind+width, specificity, width, color = 'cornflowerblue')

# Bars are centred at ind and ind + width, so centre each label under its pair.
plt.xticks(ind + width / 2,
           [r'$\it{%s}$' % genus for genus in bar_genera],
           rotation = 20, size = 16)

plt.yticks(size = 20)
plt.ylim(0,119)

legend = plt.legend(labels = ['Recall', 'Precision'], prop={'size': 14},
                   loc='upper left', bbox_to_anchor=(0.01,0.9))
legend.get_title().set_fontsize('14')

# Vertical separator between the two category groups.
plt.axvline(x = 2.7, color = 'grey')
plt.annotate('Category 1', (0.5, 110), size = 16, color = 'grey')
plt.annotate('Category 3', (3.5, 110), size = 16, color = 'grey')
plt.ylabel('%', size = 20);

#plt.savefig("category_1_and_3_20201009.png", dpi =300, bbox_inches='tight')

### 2.5 Statistical analyses

#### We test whether our prediction approach performs better than picking a random hit from the alignment results.

In [None]:
# Rows with any missing field are dropped before collecting hits.
result_table = summary_table.dropna()

# Phage (Query) -> host genera of all its hits, in table row order.
# groupby() yields phages in sorted order, so this dict's iteration order is
# stable across runs. The null model below relies on that order for its seeded
# random draws. Iterating a set() of strings does not give a stable order,
# because str hashing is randomised per interpreter process (PYTHONHASHSEED).
# A single groupby pass also avoids re-filtering the table for every phage.
phage_all_hit_dict = {
    phage: hits['GENUS'].tolist()
    for phage, hits in result_table.groupby('Query')
}

In [None]:
# Statistical analysis: null model to verify if prediction approach better than random pick among hits
# For each replicate, draw one hit genus uniformly at random per phage.
# Record 'Y' when find_lca_rank reports 'genus' for (known host, drawn genus),
# otherwise 'N'.
N_NULL_REPLICATES = 1000

random.seed(12345)
null_prediction_dict = defaultdict(list)

# The known host does not change between replicates, so look it up once per phage.
real_host_by_phage = {phage: phage_table.loc[phage, 'Host'] for phage in phage_all_hit_dict}
# NOTE(review): memoises find_lca_rank on (known host, drawn genus). This
# assumes it is a pure function of its two arguments. Without the cache it is
# called up to N_NULL_REPLICATES times per phage.
lca_rank_cache = {}

for i in range(1, N_NULL_REPLICATES + 1):
    for phage, hit in phage_all_hit_dict.items():
        # Same sequence of random.choice calls as before, so draws are unchanged.
        random_host = random.choice(hit)
        real_host = real_host_by_phage[phage]
        key = (real_host, random_host)
        if key not in lca_rank_cache:
            lca_rank_cache[key] = find_lca_rank(real_host, random_host)
        null_prediction_dict[i].append('Y' if lca_rank_cache[key] == 'genus' else 'N')

In [None]:
# Percentage of 'Y' calls in each null-model replicate.
# NOTE(review): 9884 is a hard-coded denominator, presumably the number of
# phages scored per replicate. Confirm it equals len(phage_all_hit_dict).
success_rate_null_model = {
    replicate: calls.count('Y') / 9884 * 100
    for replicate, calls in null_prediction_dict.items()
}

In [None]:
# Highest null-model success rate (%) across all replicates.
max(success_rate_null_model.values())

In [None]:
# Distribution of the null-model success rate over all replicates.
plt.figure(figsize= (8,6))
# list(): pass matplotlib a plain sequence rather than a dict view.
plt.hist(list(success_rate_null_model.values()))
plt.title('Null model: random pick among hits', size = 16)
plt.xlabel('Precision (%)', size = 16)
plt.ylabel('Frequency', size = 16)
plt.xticks(size = 16)
plt.yticks(size = 16);

## Step 3: Compare performance with other methods

### 3.1 Select existing software

In [None]:
# WIsH phage list
# https://github.com/soedinglab/WIsH/blob/master/benchmark/PhageSequences.lst

### 3.2 Confirm that the phages used by the two other methods are present in our dataset

#### WIsH used only phages from RefSeq. During the dataset clean-up (step 2.1), we removed most RefSeq phage genomes because they duplicate GenBank phage genomes. We therefore return to the RefSeq table to retrieve the name for each accession number and search for that name in the final phage dataset.

In [None]:
from Bio import Entrez

def NCBIEntrez(accession, email="moira.dion.1@ulaval.ca"):
    """Return the ORGANISM name of an NCBI nucleotide record.

    Fetches the record in GenBank flat-file format (``rettype="gb"``) and
    scans it for the first line that starts with ``'  ORGANISM'``.

    Parameters
    ----------
    accession : str
        Nucleotide accession to fetch.
    email : str, optional
        Contact address assigned to ``Entrez.email`` before the request.

    Returns
    -------
    str or None
        The words following the ORGANISM keyword, joined by single spaces,
        or ``None`` when the record has no ORGANISM line.
    """
    Entrez.email = email
    handle = Entrez.efetch(db="nucleotide", id=accession, rettype="gb", retmode="text")
    try:
        for line in handle:
            if line.startswith('  ORGANISM'):
                # "  ORGANISM  Genus species ..." -> "Genus species ..."
                return ' '.join(line.split()[1:])
    finally:
        # Always release the HTTP handle, including on early return.
        handle.close()
    return None

In [None]:
# RefSeq phage table indexed by accession; the accession list is used for
# membership checks against the WIsH benchmark list.
refseq_table = (
    pd.read_csv('refseq_20200423.csv')
    .set_index('Accession')
)
refseq_accession = refseq_table.index.tolist()

In [None]:
# Path to the WIsH benchmark phage list (PhageSequences.lst, linked above),
# one accession per line. Must be filled in before running.
path_to_wish_list = ''
if not path_to_wish_list:
    raise ValueError('Set path_to_wish_list to the WIsH PhageSequences.lst file before running this cell.')
with open(path_to_wish_list,'r') as f:
    # Skip blank lines so they do not become empty accessions.
    wish_dataset = [line.strip() for line in f if line.strip()]

In [None]:
# Print each WIsH accession found in neither RefSeq accession list, together
# with the organism name NCBI reports for it, for manual reconciliation.
for i in wish_dataset :
    if i not in refseq_accession and i not in refseq_archaeal_virus_accession :
        organism = NCBIEntrez(i)
        # NCBIEntrez returns None when the record has no ORGANISM line;
        # concatenating None would raise TypeError and abort the loop.
        print(i + '\t' + (organism if organism is not None else 'ORGANISM NOT FOUND'))

In [None]:
# The following lines were manually verified

# NC_003525 is Stx2 converting phage I and host is unknown (removed)
# NC_005808 refers to AY526908
# NC_005809 refers to AY526909
# NC_005949 (Vibrio phage Vf12) no longer exists, but shares exactly the same sequence
## with NC_005948 (Vibrio phage Vf33). They were deposited at the same time from the
## same authors so we will use Vf33 as substitute.
# NC_007457 refers to DQ222851
# NC_008355 refers to NC_002194
# NC_010762 refers to EU568876 
# NC_010763 refers to EU676000
# NC_011105 refers to EU056923
# NC_011284 refers to FJ174694
# NC_011286 refers to FJ168660
# NC_020836 (Synechococcus phage KBS-M-1A) no longer exists. We tried finding a similar cyanophage 
## under a different name but we did not find anything. No other phage in our database has
## the same genome size (removed)
# NC_021296 refers to KC691254
# NC_021303 refers to KC661274
# NC_021778 refers to JN255163
# NC_021859 refers to KC787108
# NC_022326 refers to KF279416
# NC_022327 refers to KF279413
# NC_023596 refers to KF981875
# NC_023718 refers to JN175269
# NC_027338 refers to KJ578763
# NC_027343 refers to KJ578775
# NC_027980 refers to JQ267518

In [None]:
# Manually resolved WIsH accessions (see the notes in the previous cell):
# obsolete accession -> record used in its place. NC_003525 and NC_020836
# have no usable equivalent and are left out.
wish_manual_substitutions = {
    'NC_005808': 'AY526908',
    'NC_005809': 'AY526909',
    'NC_005949': 'NC_005948',
    'NC_007457': 'DQ222851',
    'NC_008355': 'NC_002194',
    'NC_010762': 'EU568876',
    'NC_010763': 'EU676000',
    'NC_011105': 'EU056923',
    'NC_011284': 'FJ174694',
    'NC_011286': 'FJ168660',
    'NC_021296': 'KC691254',
    'NC_021303': 'KC661274',
    'NC_021778': 'JN255163',
    'NC_021859': 'KC787108',
    'NC_022326': 'KF279416',
    'NC_022327': 'KF279413',
    'NC_023596': 'KF981875',
    'NC_023718': 'JN175269',
    'NC_027338': 'KJ578763',
    'NC_027343': 'KJ578775',
    'NC_027980': 'JQ267518',
}

# Keep the accessions present in either RefSeq list, then append the manual substitutes.
wish_dataset_clean = [
    accession for accession in wish_dataset
    if accession in refseq_accession or accession in refseq_archaeal_virus_accession
]
wish_dataset_clean.extend(wish_manual_substitutions.values())

print(len(wish_dataset), len(wish_dataset_clean))

In [None]:
def predict_wish(criteria, value) :
    """Run ``findHost`` on every accession of ``wish_dataset_clean``.

    Parameters
    ----------
    criteria, value
        Forwarded unchanged to ``findHost`` (e.g. ``'mismatch', 2``).

    Returns
    -------
    dict
        Accession -> ``findHost`` result. Accessions for which ``findHost``
        returns ``None`` are omitted.
    """
    # Select the benchmark phages' hits once and split them by Query in a single
    # pass, instead of scanning the whole summary table for every accession.
    wish_hits = summary_table[summary_table["Query"].isin(wish_dataset_clean)]
    hits_by_query = {query: hits for query, hits in wish_hits.groupby("Query", sort=False)}
    pred_dict = {}
    for query in wish_dataset_clean:
        df = hits_by_query.get(query)
        if df is None:
            # Accession without hits: pass an empty frame with the same columns.
            df = wish_hits.iloc[0:0]
        pred = findHost(df, criteria, value)
        # Identity test; `!= None` would be element-wise on a pandas result.
        if pred is not None:
            pred_dict[query] = pred
    return pred_dict

In [None]:
# Host predictions for the WIsH benchmark phages at the 2-mismatch setting.
wish_pred_2_mismatch = predict_wish(criteria='mismatch', value=2)

In [None]:
# Score the predictions against the full size of the cleaned benchmark list.
n_wish_phages = len(wish_dataset_clean)
wish_performance = specificity_sensitivity(wish_pred_2_mismatch, n_wish_phages)

In [None]:
# Number of WIsH benchmark phages that received a prediction at 2 mismatches.
len(wish_pred_2_mismatch)

In [None]:
# Display the specificity_sensitivity result for the WIsH benchmark subset.
wish_performance

## Step 4: Proof of concept on already published virome

### We use the virome dataset from Shkoporov et al. (2019, Cell)

In [None]:
# Downloaded from:
# https://figshare.com/articles/The_human_gut_virome_is_highly_diverse_stable_and_individual-specific_/9248864/1

# Retrieved from this folder: Dataset_S1/selected_contigs.fa
# For a total of 57 721 contigs

In [None]:
# BLAST alignment results between the Shkoporov virome contigs and the spacer
# database (file available on request). Enter the path to the file here.
path_to_shkoporov_alignment = ''
if not path_to_shkoporov_alignment:
    raise ValueError('Set path_to_shkoporov_alignment to the BLAST results file before running this cell.')
shkoporov_results = load_blastn_results(path_to_shkoporov_alignment)

In [None]:
# Attach spacer_table columns to every BLAST hit (left join keeps all hits).
shkoporov_summary_table = shkoporov_results.merge(spacer_table, on="SPACER_ID", how="left")

In [None]:
# Inspect the BLAST hits joined with spacer_table on SPACER_ID.
shkoporov_summary_table

In [None]:
# Contig names listed in 'shkoporov_virome.phages_combined.txt', one per line.
with open('shkoporov_virome.phages_combined.txt','r') as f:
    # Skip blank lines so they do not become empty contig names. Empty names
    # would also inflate len(shkoporov_vibrant_list), which is used as a
    # denominator later.
    shkoporov_vibrant_list = [line.strip() for line in f if line.strip()]

In [None]:
def predict_shkoporov(criteria, value) :
    """Run ``findHost`` on every contig of ``shkoporov_vibrant_list``.

    Parameters
    ----------
    criteria, value
        Forwarded unchanged to ``findHost`` (e.g. ``'mismatch', 2``).

    Returns
    -------
    dict
        Contig name -> ``findHost`` result. Contigs for which ``findHost``
        returns ``None`` are omitted.
    """
    # Select the listed contigs' hits once and split them by Query in a single
    # pass, instead of scanning the whole summary table for every contig.
    listed_hits = shkoporov_summary_table[shkoporov_summary_table["Query"].isin(shkoporov_vibrant_list)]
    hits_by_phage = {phage: hits for phage, hits in listed_hits.groupby("Query", sort=False)}
    pred_dict = {}
    for phage in shkoporov_vibrant_list:
        df = hits_by_phage.get(phage)
        if df is None:
            # Contig without hits: pass an empty frame with the same columns.
            df = listed_hits.iloc[0:0]
        pred = findHost(df, criteria, value)
        # Identity test; `!= None` would be element-wise on a pandas result.
        if pred is not None:
            pred_dict[phage] = pred
    return pred_dict

In [None]:
# Host predictions for the Shkoporov contigs at the 2-mismatch setting.
shkoporov_pred_2_mismatch = predict_shkoporov(criteria='mismatch', value=2)

In [None]:
# Keep only the predicted host (the 'PREDICTION' entry) for each contig.
shkoporov_final_pred_2_mismatch = {
    contig: prediction['PREDICTION']
    for contig, prediction in shkoporov_pred_2_mismatch.items()
}


In [None]:
# Supplementary table S2 from Shkoporov et al. 2019. Enter the path to the file here.
path_to_table_S2 = ''
if not path_to_table_S2:
    raise ValueError('Set path_to_table_S2 to supplementary table S2 before running this cell.')
# skiprows=2: the first two rows of the sheet are skipped before the header row.
virome_description = pd.read_excel(path_to_table_S2, skiprows=2)

In [None]:
# Inspect supplementary table S2 as loaded.
virome_description

In [None]:
# Contigs that have both our spacer-based prediction and a CRISPR-based host in table S2.
virome_for_comparison = virome_description.loc[
    virome_description['Contig name'].isin(shkoporov_final_pred_2_mismatch)
    & virome_description['CRISPR-based host'].notna()
]

In [None]:
# Contig -> [CRISPR-based host from table S2, our spacer-based prediction,
# rank returned by find_lca_rank for the pair (or 'unknown' if it fails)].
compare_shkop_spacer_pred = defaultdict(list)

for contig, shkop_pred in zip(virome_for_comparison['Contig name'],
                              virome_for_comparison['CRISPR-based host']):
    spacer_pred = shkoporov_final_pred_2_mismatch[contig]
    try:
        lca = find_lca_rank(shkop_pred, spacer_pred)
    except Exception:
        # Narrowed from a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit.
        lca = 'unknown'
    compare_shkop_spacer_pred[contig].extend([shkop_pred, spacer_pred, lca])

In [None]:
# How often each LCA rank (third field of each entry) occurs.
Counter(entry[2] for entry in compare_shkop_spacer_pred.values())

In [None]:
# Number of contig names read into shkoporov_vibrant_list.
len(shkoporov_vibrant_list)

In [None]:
# Number of contigs that received a spacer-based host prediction (2 mismatches).
len(shkoporov_final_pred_2_mismatch)

In [None]:
# Number of contigs compared against the CRISPR-based host from table S2.
len(compare_shkop_spacer_pred)

In [None]:
# Table S2 rows for listed contigs that have a CRISPR-based host.
virome_description.loc[
    virome_description['Contig name'].isin(shkoporov_vibrant_list)
    & virome_description['CRISPR-based host'].notna()
]

In [None]:
# Percentage of the listed contigs that received a host prediction.
n_predicted_contigs = len(shkoporov_final_pred_2_mismatch)
n_listed_contigs = len(shkoporov_vibrant_list)
n_predicted_contigs / n_listed_contigs * 100

In [None]:
# NOTE(review): hard-coded scratch value. 1392 is presumably
# len(shkoporov_final_pred_2_mismatch); TODO compute it instead.
1392/2

In [None]:
# NOTE(review): hard-coded scratch sum. These are presumably the counts of the
# eight most common predicted genera (top_eight is only built in a later cell).
# TODO derive from top_eight so it cannot go stale.
160+109+105+75+65+64+60+50

In [None]:
# Eight most frequently predicted hosts as (host, count) pairs, most common first.
host_prediction_counts = Counter(shkoporov_final_pred_2_mismatch.values())
top_eight = host_prediction_counts.most_common(8)

In [None]:
# NOTE(review): hard-coded scratch value. Presumably total predictions (1392)
# minus the top-eight total (688), i.e. the size of the 'Other' pie slice.
1392-688

In [None]:
# Pie chart of the eight most frequently predicted host genera, with every
# other prediction pooled into 'Other'.
labels = [f'<i>{genus}</i>' for genus, _ in top_eight]
labels.append('Other')
values = [count for _, count in top_eight]
# 'Other' = all predictions outside the top eight. It is computed here rather
# than hard-coded, so it stays correct if the predictions change.
values.append(len(shkoporov_final_pred_2_mismatch) - sum(values))
trace_pie_org_names = go.Pie(labels=labels, values=values, sort = False)
layout = go.Layout(title="Most common host genera")
# Pass the layout so its title is actually applied.
fig = go.Figure(data=[trace_pie_org_names], layout=layout)
fig.update_layout(legend=go.layout.Legend(x = 1, y = 1, font=dict(size=20)))
fig.update_traces(textinfo='value', textfont_size=20)
iplot(fig)

In [None]:
# Contigs whose predicted host is Clostridium, in prediction-dict order.
clostridium_contigs = [
    contig for contig, predicted_host in shkoporov_final_pred_2_mismatch.items()
    if predicted_host == 'Clostridium'
]

In [None]:
# All hits of the Clostridium-predicted contigs. The explicit copy is needed
# because the following cells add columns to this table; on a filtered view of
# shkoporov_summary_table, that raises SettingWithCopyWarning.
clostridium_table = shkoporov_summary_table[shkoporov_summary_table['Query'].isin(clostridium_contigs)].copy()

In [None]:
# Midpoint of each hit on the contig, (q_start + q_end) / 2, converted to int.
hit_midpoints = (clostridium_table["q_start"] + clostridium_table["q_end"]) / 2
clostridium_table.loc[:, 'mean_position'] = hit_midpoints.to_numpy(dtype=int)

In [None]:
# Mismatches over the whole spacer. Spacer bases outside the alignment count
# as mismatches, in addition to the mismatches reported within the alignment.
unaligned_spacer_bases = clostridium_table["SPACER_LENGTH"] - clostridium_table["alignement_length"]
clostridium_table.loc[:, "true_num_mismatch"] = unaligned_spacer_bases + clostridium_table["mismatch"]

In [None]:
# Contig -> (position, number of hits) pairs, most frequent first. Only hits
# from Clostridium spacers with at most 2 mismatches over the full spacer length
# are counted.
close_clostridium_hits = clostridium_table[(clostridium_table['true_num_mismatch'] <= 2) &
                                           (clostridium_table['GENUS'] == 'Clostridium')]
# Filter once and split by contig in a single pass, instead of re-filtering the
# whole table for every contig. Row order within each group is preserved, so
# Counter's tie order is the same as before.
positions_by_contig = {
    contig: positions.tolist()
    for contig, positions in close_clostridium_hits.groupby('Query', sort=False)['mean_position']
}
clostridium_targets = {
    contig: Counter(positions_by_contig.get(contig, [])).most_common()
    for contig in clostridium_contigs
}

In [None]:
def plot_targets(vOTU, genome_size, targets=None):
    """Plot spacer-target positions along one putative phage contig.

    Parameters
    ----------
    vOTU : str
        Contig name, used as the key into ``targets``.
    genome_size : int
        Contig length (bp). Sets the x-axis range and the figure width.
    targets : dict, optional
        Maps contig name to a list of ``(position, number_of_hits)`` pairs, as
        returned by ``Counter.most_common()``. Defaults to the notebook-level
        ``clostridium_targets``.
    """
    if targets is None:
        targets = clostridium_targets
    # One marker per hit, all at y = 0; hits at the same position overlap.
    x_coord = [position for position, n_hits in targets[vOTU] for _ in range(n_hits)]
    y_coord = [0] * len(x_coord)

    # About 0.9 inch of figure width per kb, and never below 1 inch. A zero
    # width (contigs shorter than ~556 bp) is an invalid figure size.
    width = max(1, round(genome_size * 0.0009))
    size = 20
    plt.figure(figsize=(width, 1))
    plt.scatter(x_coord, y_coord, s = 100, color = '#00CC96')
    plt.xlim(0,genome_size+500)
    plt.xticks(size = size)
    plt.ylim(-0.1,0.5)
    plt.yticks([])
    label = 'Position on putative phage contig (bp)'
    plt.xlabel(label, size = size)

In [None]:
# Target positions on one 22 057 bp Clostridium-predicted contig.
plot_targets(
    vOTU='metaspades_NG-13376_921T3_lib202033_5478_NODE_37_length_22057_cov_6.42255',
    genome_size=22057,
)