# DeepMicrobes Prediction Results
DeepMicrobes was trained on the MarRef training set, and the resulting model was tested against 855 microbe test tfrec files. Each test tfrec file holds thousands of double-stranded reads. The model makes a prediction for each read, together with a confidence score. The original DeepMicrobes paper suggested that keeping only predictions with a confidence score > 50% would lead to > 98% accuracy.

### Before running this script:
* Use a trained model to generate a 'result.txt' file for each of the 855 microbes in the test set.
* Generate a 'summary.txt' file for each of the 855 files at each cutoff (0% to 100% confidence). These cutoffs are the decision threshold percentages.
* Use 'process_shell_scripts_accuracy.ipyn' to generate one concatenated file per confidence level (0-100) that has the true microbe in the 'name' column, the predicted microbe in the 'prediction' column, and the number of reads for that microbe in the 'reads' column.

**Import required packages**

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from statistics import mean
import math
import numpy as np
import seaborn as sns
import altair as alt

### Import files
* First, import the concatenated file for each threshold, as made by process_shell_scripts_loss.py. Each file is used in this notebook to generate the ROC curve. The files are named 'concat_60.csv', etc.
* Also upload the 'all_microbes.txt' file. This file lists all tested microbes, which is needed to calculate the TP, TN, FP, and FN properly.
* Finally, upload the labels_MarRef.csv file. This provides labels for the predictions, so we can see which microbes are and aren't being predicted correctly.
#### Ctrl + M O to hide output

In [None]:
from google.colab import files

uploaded = files.upload()
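
Outside Colab, `files.upload()` is not available. As a minimal sketch, assuming the concatenated files follow the `concat_<threshold>.csv` naming described above, the threshold can be parsed from each file name and every other file skipped. The helper names `parse_threshold` and `threshold_files` are hypothetical and not part of the original notebook.

```python
import re

# Assumed naming scheme: concat_<integer threshold>.csv
_CONCAT_RE = re.compile(r"^concat_(\d+)\.csv$")

def parse_threshold(filename):
    """Return the integer threshold encoded in a file name, or None."""
    match = _CONCAT_RE.match(filename)
    return int(match.group(1)) if match else None

def threshold_files(filenames):
    """Map threshold -> file name, ignoring names that do not match."""
    found = {}
    for fn in filenames:
        threshold = parse_threshold(fn)
        if threshold is not None:
            found[threshold] = fn
    return dict(sorted(found.items()))

example = threshold_files(["concat_60.csv", "all_microbes.txt", "concat_0.csv"])
print(example)
```

Keying on the parsed threshold rather than on the raw name means helper files such as 'all_microbes.txt' never reach `pd.read_csv` by accident.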

## Calculate the TP, FP, TN, FN for each threshold.
*   TP: genome is species A and correctly predicted as A. The 'prediction' and 'name' columns agree.
*   FP: genome is not species A but incorrectly predicted to be A. The species is in the **prediction** column but something else is in the name column.
*   TN: genome correctly predicted not to be species A. This is everything else; it is easiest to subtract TP+FP+FN from the total.
*   FN: genome is actually species A but incorrectly predicted as not species A. The species is in the **name** column but something else is in the prediction column.


### Calculate the true positive rate and false positive rate
- TPR: TP/(TP + FN) 
- FPR: FP/(FP+TN)
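
To make the one-vs-rest definitions concrete, here is a small sketch on made-up data in the same long format as the concatenated files ('name', 'prediction', 'reads'). The labels, read counts, and the helper `one_vs_rest_counts` are illustrative assumptions, not values or code from the actual runs.

```python
import pandas as pd

# Toy data: true label ('name'), predicted label ('prediction'), read count ('reads').
toy = pd.DataFrame({
    "name":       ["A", "A", "B", "B", "C"],
    "prediction": ["A", "B", "B", "C", "C"],
    "reads":      [90, 10, 70, 30, 100],
})

def one_vs_rest_counts(df, label):
    """TP, FP, FN, TN read counts with `label` treated as the positive class."""
    is_true = df["name"] == label
    is_pred = df["prediction"] == label
    tp = df.loc[is_true & is_pred, "reads"].sum()
    fn = df.loc[is_true & ~is_pred, "reads"].sum()
    fp = df.loc[~is_true & is_pred, "reads"].sum()
    tn = df["reads"].sum() - tp - fn - fp
    return int(tp), int(fp), int(fn), int(tn)

tp, fp, fn, tn = one_vs_rest_counts(toy, "B")
tpr = tp / (tp + fn)
fpr = fp / (fp + tn)
print(tp, fp, fn, tn, tpr, fpr)
```

For label B, 70 of its 100 reads are recovered (TPR 0.7), and 10 of the 200 reads from other labels are wrongly assigned to it (FPR 0.05).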

In [None]:
# collect the uploaded file names in a list
# (named file_names so it does not shadow the google.colab `files` module)
file_names = []

ROC_curve = pd.DataFrame(columns=['threshold', 'TPR', 'FPR', 'Precision', 'Recall'])
#ROC_curve.threshold = [5,10,15,20,25,30,35,40,45,55,60,65,70,75,80,85,90,95]
ROC_curve.threshold = [0,60]

for fn in uploaded.keys():
  file_names.append(fn)

microbe_list = pd.read_csv('all_microbes.txt', header=None)

for file in file_names:
  if file != 'all_microbes.txt':
    file_threshold = pd.read_csv(file)
  #  file_threshold.drop('Unnamed: 0', inplace=True, axis=1)
    threshold = int(file.replace('.csv', '').replace('concat_', ''))
    
    # make list of unique microbes and set four measurements to zero
    # unique_microbes = pd.DataFrame(file_threshold.name.drop_duplicates(), columns=['name'])
    TP, TN, FN, FP, TPR, FPR, Precision, Recall = [], [], [], [], [], [], [], [] # list for TP, TN, FN, FP
    total_counts = file_threshold.reads.sum() # sum of all counts in the file

    #calculate the TP, TN, FN, FP for all microbes in the file
    for name in microbe_list[0]:
        # make a dataframe where 'name' column == microbe name
        microbe = file_threshold[file_threshold.name == name]
        # count up all reads where 'name' == microbe name
        total_counts_microbe = microbe.reads.sum()
        # TP where prediction column == microbe name as well
        true_positive = microbe.reads[microbe.prediction == name].sum()
        TP.append(true_positive)
        # FN where species in name but not in prediction in the microbes file
        false_negative = total_counts_microbe - true_positive
        FN.append(false_negative)
        # FP where microbe is in the prediction column but not in the names column
        # obtain by taking sum of every instance where in the prediction column and subtracting names column
        false_positive = (file_threshold[file_threshold.prediction == name].reads.sum() - true_positive)
        FP.append(false_positive)
        # True negative is the total counts minus all counts where microbe exists
        true_negative = total_counts - (true_positive + false_positive + false_negative)
        TN.append(true_negative)
        # append each calculation for each microbe
        TPR.append(true_positive/(true_positive+false_negative))
        FPR.append(false_positive/(false_positive+true_negative))
        Precision.append(true_positive/(true_positive+false_positive))
        Recall.append(true_positive/(true_positive+false_negative))
    TPR = [0 if math.isnan(x) else x for x in TPR]
    FPR = [0 if math.isnan(x) else x for x in FPR]
    Precision = [0 if math.isnan(x) else x for x in Precision]
    Recall = [0 if math.isnan(x) else x for x in Recall]
    print(mean(TPR))
    
    # Keep the per-microbe TP/FP/TN/FN and precision/recall/TPR/FPR for thresholds 0 and 60
    if threshold == 0:
      list_for_metrics_0 = pd.DataFrame({'microbe':list(microbe_list[0]),
                                      'TP':TP, 'FP':FP, 'TN':TN, 'FN':FN,
                                      'TPR':TPR,
                                      'FPR':FPR,
                                      'Precision':Precision,
                                      'Recall':Recall})
    if threshold == 60:
      list_for_metrics_60 = pd.DataFrame({'microbe':list(microbe_list[0]),
                                      'TP':TP, 'FP':FP, 'TN':TN, 'FN':FN,
                                      'TPR':TPR,
                                      'FPR':FPR,
                                      'Precision':Precision,
                                      'Recall':Recall})
    # Add the mean of each per-microbe metric to the row for this threshold.
    # .loc avoids the chained assignment that triggers SettingWithCopyWarning.
    ROC_curve.loc[ROC_curve.threshold == threshold, 'TPR'] = mean(TPR)
    ROC_curve.loc[ROC_curve.threshold == threshold, 'FPR'] = mean(FPR)
    ROC_curve.loc[ROC_curve.threshold == threshold, 'Precision'] = mean(Precision)
    ROC_curve.loc[ROC_curve.threshold == threshold, 'Recall'] = mean(Recall)
print(ROC_curve)  


0.513737884516994



0.912169170731006
   threshold       TPR       FPR Precision    Recall
0          0  0.513738  0.000382  0.611692  0.513738
1         60  0.912169  0.000012  0.993885  0.912169


In [None]:
print(ROC_curve.to_markdown(index=False)) 

|   threshold |      TPR |         FPR |   Precision |   Recall |
|------------:|---------:|------------:|------------:|---------:|
|           0 | 0.50755  | 0.000273352 |       0.52  | 0.50755  |
|          60 | 0.937357 | 1.40042e-05 |       0.952 | 0.937357 |


In [None]:
# work on a copy so the numeric ROC_curve stays available for plotting
for_table = ROC_curve.copy()
for_table.Precision = ROC_curve.Precision.apply(lambda x: '%.3f' % x)
for_table.TPR = ROC_curve.TPR.apply(lambda x: '%.3f' % x)
for_table.Recall = ROC_curve.Recall.apply(lambda x: '%.3f' % x)

In [None]:
for_table

|    |   threshold |   TPR |      FPR |   Precision |   Recall |
|---:|------------:|------:|---------:|------------:|---------:|
|  0 |           0 | 0.514 | 0.000382 |       0.612 |    0.514 |
|  1 |          60 | 0.912 |  1.2e-05 |       0.994 |    0.912 |


In [None]:
list_for_metrics_60.head()

In [None]:
list_for_metrics_0.head()

Save results as .csv file, if needed.

In [None]:
# list_for_metrics is only assigned further down, so use the threshold-60 frame directly
list_for_metrics_60.head()
list_for_metrics_60.to_csv('/content/sample_data/threshold_55_metrics.csv')

The proper way to view the precision, recall, TPR, and FPR is to look at each microbe's individual metrics. But to draw the initial ROC curve (similar to the DeepMicrobes manuscript), we take an average over all microbes at each confidence threshold.
*As we will see later, since the data is somewhat skewed (many microbes with very high accuracy, and some with near 0 accuracy), taking an average might not be the best way to judge the performance of the model.*
However, for an initial ROC curve, taking an average is acceptable, **merely to get a feeling for the overall metrics** at each threshold point. After that, we can pick one threshold and take a deeper dive into the model.
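
The caveat about averaging skewed per-microbe scores can be illustrated with a short sketch. The recall values below are invented for illustration only (most near 1, a couple near 0); they are not results from the model.

```python
import numpy as np

# Hypothetical per-microbe recall values: most near 1, a few near 0.
recall = np.array([0.99, 0.98, 1.00, 0.97, 0.99, 0.02, 0.00, 0.98])

macro_mean = recall.mean()               # what the ROC table averages
median = np.median(recall)               # robust to the few failures
share_above_90 = (recall > 0.90).mean()  # fraction of well-predicted microbes

print(f"mean={macro_mean:.3f} median={median:.3f} share>0.90={share_above_90:.2f}")
```

Two failing classes out of eight pull the mean down to about 0.74 even though the typical class sits at 0.98, which is why the per-microbe view is used for the deeper analysis.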

In [None]:
ROC_curve.plot(x='FPR', y='TPR')
plt.title("ROC Curve", fontdict={'fontsize':20})
plt.xlabel('FPR')
plt.ylabel('TPR')

- Precision: TP/(TP+FP)
- Recall: TP/(TP+FN)
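
Both ratios are undefined when their denominator is zero (for example, a microbe that is never predicted has TP + FP = 0). The loop above handles this by replacing NaN with 0 afterwards. A possible alternative, sketched here with a hypothetical helper `safe_ratio`, is to guard the division itself:

```python
import numpy as np

def safe_ratio(numerator, denominator):
    """Element-wise numerator / denominator, returning 0 where denominator is 0."""
    numerator = np.asarray(numerator, dtype=float)
    denominator = np.asarray(denominator, dtype=float)
    out = np.zeros_like(numerator)
    np.divide(numerator, denominator, out=out, where=denominator != 0)
    return out

# Made-up counts for three classes; the second class is never predicted.
tp = np.array([90, 0, 5])
fp = np.array([10, 0, 0])
fn = np.array([10, 50, 0])

precision = safe_ratio(tp, tp + fp)
recall = safe_ratio(tp, tp + fn)
print(precision, recall)
```

Treating an undefined precision as 0 is a convention, not a fact about the model; reporting such classes separately is an equally reasonable choice.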

In [None]:
plt.plot(ROC_curve.threshold,ROC_curve.Precision)
plt.plot(ROC_curve.threshold,ROC_curve.Recall)

plt.gca().legend(('Precision', 'Recall'))
plt.xlabel('decision threshold')
plt.ylabel('Precision/Recall')
plt.title("Precision vs Recall", fontdict={'fontsize':20})
plt.show()

In [None]:
list_for_metrics = list_for_metrics_60

There is a **clear uptick between 55 and 60** for both precision and recall; beyond that, there is no significant improvement. Therefore, for the rest of the analysis, I'll use a threshold of 60 while exploring the data.

At first, the precision, recall, TPR and FPR graphs look abnormal. After downloading and validating the data for each microbe, however, the calculations look correct. As suggested before, we should not average over the dataset; instead, we look at each microbe's metrics at a single threshold.
With threshold = 60, we explore what FPR and TPR look like. Below, each dot is a microbe.

In [None]:
sns.scatterplot(data=list_for_metrics, x="FPR", y="TPR")
plt.scatter(x=list_for_metrics.FPR, y=list_for_metrics.TPR, marker='o')
plt.title('3 Epochs Threshold 60')
plt.xlabel('FPR')
plt.xticks(rotation = 45)
plt.ylabel('TPR')

In [None]:
#sns.scatterplot(data=list_for_metrics, x="FPR", y="TPR", palette="deep")
alt.Chart(list_for_metrics).mark_circle(size=60).encode(
    x='FPR',
    y='TPR',
    #color='Origin',
    #tooltip=['Name', 'Origin', 'Horsepower', 'Miles_per_Gallon']
).configure_axis(
    grid=False, labelFontSize=16,
    titleFontSize=16).properties(
    title='2 Epochs Genus (60 Threshold)').configure_title(
    fontSize=20,
)

We also took the log (base 2) of both rates to get a better picture of the spread of false and true positive rates.

In [None]:
plt.scatter(x=np.log2(list_for_metrics.FPR), y=np.log2(list_for_metrics.TPR), marker='o')
plt.xlabel('log2(FPR)')
plt.ylabel('log2(TPR)')
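
One caveat with the log transform: `np.log2(0)` is `-inf`, so microbes with an FPR or TPR of exactly 0 fall off the plot (with a runtime warning). The sketch below shows two common workarounds on made-up rates; the pseudocount `eps` is an arbitrary assumption, not a value used in this analysis.

```python
import numpy as np

fpr = np.array([0.0, 1e-7, 4e-7, 1e-4])  # hypothetical rates
tpr = np.array([0.0, 0.15, 0.98, 1.0])

# Option 1: drop points where either rate is zero before taking the log.
keep = (fpr > 0) & (tpr > 0)
log_fpr = np.log2(fpr[keep])
log_tpr = np.log2(tpr[keep])

# Option 2: add a small pseudocount so zeros stay on the plot.
eps = 1e-9
log_fpr_eps = np.log2(fpr + eps)
log_tpr_eps = np.log2(tpr + eps)

print(int(keep.sum()), bool(np.isfinite(log_fpr_eps).all()))
```

Option 1 hides the failing microbes, while option 2 keeps them visible at a floor set by `eps`; which is preferable depends on whether the zeros are the point of the plot.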

It is clear that there are some outliers with a TPR of 0, which the model cannot predict at all. The remaining microbes are spread over quite high true positive rates. See the histograms below:

In [None]:
list_for_metrics.TPR.hist(bins=20)
plt.xlabel('TPR')

In [None]:
list_for_metrics.FPR.hist(bins=20)
plt.xlabel('FPR')

The false positive rate is low across the board. The working assumption is that the microbes with no true positives are the ones whose reads the model assigns to other microbes.

### What are the precision and recall when the threshold is > 60?
Precision is a measure of quality: high precision means the model returns more correct than incorrect predictions. Recall is a measure of quantity: high recall means the model returns a large share of the correct microbes and misses few.

In [None]:
#plt.scatter(x=list_for_metrics.Precision, y=list_for_metrics.Recall, marker='o')
#plt.xlabel('Precision')
#plt.ylabel('Recall')
#plt.title('3 Epochs Threshold 0 - Species')

alt.Chart(list_for_metrics).mark_circle(size=60).encode(
    x='Precision',
    y='Recall',
    #color='Origin',
    #tooltip=['Name', 'Origin', 'Horsepower', 'Miles_per_Gallon']
).configure_axis(
    grid=False, labelFontSize=16,
    titleFontSize=16).properties(
    title='2 Epochs Genus (60 Threshold)').configure_title(
    fontSize=20,
)

Recall and precision are sometimes inversely correlated, but that does not seem to be the case here: microbes with high precision also have high recall. The model performs nearly equivalently on precision and recall, and there is almost a linear trend between the two!

In [None]:
print("Under 75 Precision: ")
print((list_for_metrics.Precision < 0.75).value_counts())
list_for_metrics.Precision.hist(bins=20)
plt.xlabel('Precision')

### It seems the dataset is slightly imbalanced.
There are many microbes predicted very accurately, but some with a very low or zero true positive rate. With a threshold of > 60% confidence, how many microbes have a low TPR?

In [None]:
print("Under 75 TPR: ")
print((list_for_metrics.TPR < 0.75).value_counts())
print("Under 50 TPR: ")
print((list_for_metrics.TPR < 0.5).value_counts())
print("Under 25 TPR: ")
print((list_for_metrics.TPR < 0.25).value_counts())
print("Equals 0 ")
print((list_for_metrics.TPR == 0).value_counts())
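
The separate cutoff counts above can also be collected into one table with `pd.cut`. This is a sketch on hypothetical TPR values; note that the bins here are right-inclusive (for example, a TPR of exactly 0.25 falls in the first bin), which differs slightly from the strict `<` comparisons used in the cell above.

```python
import pandas as pd

tpr = pd.Series([0.0, 0.10, 0.30, 0.60, 0.80, 0.99, 1.0])  # hypothetical values

bins = [0.0, 0.25, 0.5, 0.75, 1.0]
labels = ["<=0.25", "0.25-0.50", "0.50-0.75", "0.75-1.00"]
binned = pd.cut(tpr, bins=bins, labels=labels, include_lowest=True)

# sort=False keeps the bins in their natural order instead of by frequency.
counts = binned.value_counts(sort=False)
n_zero = int((tpr == 0).sum())

print(counts)
print("exactly zero:", n_zero)
```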

Only 24 of the 855 microbes have a TPR below 50%! Similarly, how many are read incorrectly? Which have a high false positive rate?

In [None]:
print("Over 1e-6 FPR: ")
print((list_for_metrics.FPR > 1e-6).value_counts())
print("Over 1e-5 FPR: ")
print((list_for_metrics.FPR > 1e-5).value_counts())
print("Over 1e-4 FPR: ")
print((list_for_metrics.FPR > 1e-4).value_counts())
print("Over 1e-3 FPR: ")
print((list_for_metrics.FPR > 1e-3).value_counts())

Overall, with enough reads the model should arrive at the correct microbe.
* I suspect that the high precision and recall are due to the test set. Each microbe file has 10,000 reads for the model to check. If the model is trained well, it is likely to get enough positive 'hits' to achieve good recall.
* It is also possible that the training and test sets are quite similar, which could explain the near-perfect precision. In particular, if the test set is the same as the training set, the model has likely memorized the microbe data, which leads to very high precision.
* With a skewed distribution, metrics like TPR and FPR may not be appropriate. Better measures might be **balanced accuracy, F-measure, and the Matthews correlation coefficient.**

The F-measure (F1) is the harmonic mean of precision and recall, with both weighted equally: F = 2(Precision*Recall)/(Precision+Recall)

In [None]:
f_measure = 2*list_for_metrics.Precision*list_for_metrics.Recall/(list_for_metrics.Precision+list_for_metrics.Recall)
print("Under 95 F measure: ")
print((f_measure < 0.95).value_counts())
f_measure[f_measure<0.95].hist(bins=20)
plt.xlabel('F measure')

The F-measure shows that, when precision and recall are combined, the model scores high on both.
However, the F-measure is often criticized as flawed: because it ignores true negatives, its value can be biased.
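
A short sketch of why ignoring true negatives matters: with made-up counts, F1 stays fixed while plain accuracy moves as TN grows. The helper names `f1_score` and `accuracy` are defined here for illustration and are not taken from any library.

```python
def f1_score(tp, fp, fn):
    """Harmonic mean of precision and recall, written in terms of counts."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0

def accuracy(tp, fp, fn, tn):
    """Fraction of all items that were classified correctly."""
    return (tp + tn) / (tp + fp + fn + tn)

tp, fp, fn = 40, 10, 50  # hypothetical counts for one class
for tn in (100, 1_000_000):
    print(tn, round(f1_score(tp, fp, fn), 4), round(accuracy(tp, fp, fn, tn), 4))
```

F1 does not take `tn` as an input at all, so it cannot reward or penalize the model for how it handles everything that is not the class in question.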


#### An 'unbiased F-measure' is the Matthews correlation coefficient (MCC)
MCC is used in ML as a measure of the quality of binary classifications, similar to the Pearson correlation coefficient. It is 'balanced', so it is usually a good measure even when the classes are of very different sizes. A value of +1 means perfect prediction for a binary classification. MCC is only high if the classifier does well on both the negative and the positive elements.
* MCC = sqrt(TPR x TNR x PPV x NPV) - sqrt(FNR x FPR x FOR x FDR)
* TNR = true negative rate = TN/N = 1-FPR
* PPV = positive predictive value = TP/(TP+FP)
* FDR = false discovery rate = 1-PPV
* NPV = negative predictive value = TN/(TN+FN)
* FOR = false omission rate = FN/(TN+FN)
* FNR = 1-TPR
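
As a sanity check on the rate-based formula, the sketch below compares it with the more familiar count-based form, MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN)). The counts are those of the Zunongwangia row in the results table later in this notebook; the function names are hypothetical.

```python
import math

def mcc_from_counts(tp, fp, tn, fn):
    """Count-based MCC; returns 0 when any marginal total is empty."""
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return (tp * tn - fp * fn) / denom if denom else 0.0

def mcc_from_rates(tp, fp, tn, fn):
    """Rate-based MCC as listed above (assumes all denominators are non-zero)."""
    tpr, tnr = tp / (tp + fn), tn / (tn + fp)
    ppv, npv = tp / (tp + fp), tn / (tn + fn)
    fnr, fpr, fdr, f_om = 1 - tpr, 1 - tnr, 1 - ppv, 1 - npv  # f_om is FOR
    return math.sqrt(tpr * tnr * ppv * npv) - math.sqrt(fnr * fpr * f_om * fdr)

counts = dict(tp=8, fp=0, tn=17446968, fn=42)
mcc_a = mcc_from_counts(**counts)
mcc_b = mcc_from_rates(**counts)
print(round(mcc_a, 6), round(mcc_b, 6))
```

Both forms agree to rounding and match the 0.4 reported for that row. The count-based form has the practical advantage of a single guarded division, whereas the rate-based form yields NaN whenever TP + FP or TP + FN is zero.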


Balanced accuracy is also usually seen as a better, less biased metric than precision/recall, though still not as good as MCC.
* Balanced accuracy = (TPR+TNR)/2
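
To see why balanced accuracy is preferred over plain accuracy here, consider the counts from the Spirillospora row of the results table later in this notebook. This is a sketch for illustration; only the four counts come from the notebook output.

```python
# Counts from the Spirillospora row (TP, FP, TN, FN).
tp, fp, tn, fn = 10, 0, 17445668, 1340

plain_accuracy = (tp + tn) / (tp + fp + tn + fn)

tpr = tp / (tp + fn)   # sensitivity
tnr = tn / (tn + fp)   # specificity
balanced_accuracy = (tpr + tnr) / 2

print(round(plain_accuracy, 6), round(balanced_accuracy, 6))
```

Plain accuracy is above 0.9999 because true negatives dominate, while balanced accuracy is about 0.504, which reflects that almost none of this microbe's reads are recovered.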

In [None]:
TNR = 1-list_for_metrics.FPR
PPV = list_for_metrics.TP/(list_for_metrics.TP+list_for_metrics.FP)
NPV = list_for_metrics.TN/(list_for_metrics.TN+list_for_metrics.FN)
FOR = list_for_metrics.FN/(list_for_metrics.TN+list_for_metrics.FN)
FDR = 1-PPV
FNR = 1-list_for_metrics.TPR

list_for_metrics['MCC'] = np.sqrt(list_for_metrics.TPR*TNR*PPV*NPV)-np.sqrt(FNR*list_for_metrics.FPR*FOR*FDR)
list_for_metrics['Accuracy'] = (list_for_metrics.TPR+TNR)/2
#balanced_accuracy = (list_for_metrics.TPR+TNR)/2
list_for_metrics = list_for_metrics.fillna(0)

#plt.scatter(x=list_for_metrics.MCC, y=balanced_accuracy, marker='o')
#plt.xlabel('MCC')
#plt.ylabel('Accuracy')

alt.Chart(list_for_metrics).mark_circle(size=60).encode(
    x='MCC',
    y='Accuracy',
    #color='Origin',
    #tooltip=['Name', 'Origin', 'Horsepower', 'Miles_per_Gallon']
).configure_axis(
    grid=False, labelFontSize=16,
    titleFontSize=16).properties(
    title='2 Epochs Genus (60 Threshold)').configure_title(
    fontSize=20,
)

MCC correlates strongly with accuracy, suggesting the dataset is not as skewed as we initially believed. However, some microbes have a lower MCC than accuracy, which suggests MCC is the better metric to use.

In [None]:
print("Under 95 MCC: ")
print((list_for_metrics.MCC < 0.95).value_counts())
list_for_metrics.MCC[list_for_metrics.MCC<0.95].hist(bins=20)
plt.xlabel('Matthews Correlation Coefficient')

Clearly, there are some outliers. What are these microbes that the model does not predict accurately? Perhaps there is a reason they are predicted so poorly. Let's see which microbes have low TPR, TNR, and MCC.

In [None]:
from google.colab import files

uploaded = files.upload()

In [None]:
label_list = pd.read_csv('labels_MarRef.csv')
label_list['microbe'] = label_list.MarRef_ID
label_list.head()

In [None]:
list_for_metrics = list_for_metrics.merge(label_list[['microbe', 'Species']], how='left', on='microbe')
list_for_metrics.head()
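
A left merge silently leaves `NaN` for microbes that have no label and duplicates rows if a label appears twice. As a hedged sketch on toy frames (the IDs and species names are placeholders), pandas' `validate` and `indicator` options can surface both problems:

```python
import pandas as pd

metrics = pd.DataFrame({"microbe": ["m1", "m2", "m3"], "TPR": [0.9, 0.1, 0.8]})
labels = pd.DataFrame({"microbe": ["m1", "m2"], "Species": ["sp. one", "sp. two"]})

merged = metrics.merge(
    labels,
    how="left",
    on="microbe",
    validate="many_to_one",  # raises MergeError if a microbe has several label rows
    indicator=True,          # adds a '_merge' column: 'both' or 'left_only'
)

# Microbes that found no matching label row.
missing = merged.loc[merged["_merge"] == "left_only", "microbe"].tolist()
print(missing)
```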

In [None]:
unidentified = list_for_metrics[list_for_metrics.MCC < 0.5]
unidentified

|     | microbe              |   TP |   FP |       TN |   FN |      TPR |          FPR |   Precision |   Recall |      MCC |   Accuracy |
|----:|:---------------------|-----:|-----:|---------:|-----:|---------:|-------------:|------------:|---------:|---------:|-----------:|
|  13 | Agarilytica          |    1 |    0 | 17446951 |   66 | 0.014925 |          0.0 |         1.0 | 0.014925 | 0.122169 |   0.507463 |
|  26 | Amycolatopsis        |   40 |    0 | 17446838 |  140 | 0.222222 |          0.0 |         1.0 | 0.222222 | 0.471403 |   0.611111 |
|  57 | Cognaticolwellia     |   27 |    2 | 17446836 |  153 |     0.15 |  1.14634e-07 |    0.931034 |     0.15 | 0.373703 |      0.575 |
|  90 | Granulosicoccus      |    8 |    0 | 17446981 |   29 | 0.216216 |          0.0 |         1.0 | 0.216216 |  0.46499 |   0.608108 |
| 162 | Marinobacter_A       |   38 |    0 | 17446865 |  115 | 0.248366 |          0.0 |         1.0 | 0.248366 | 0.498362 |   0.624183 |
| 267 | Nocardiopsis         |   64 |    2 | 17446179 |  773 | 0.076464 | 1.146383e-07 |    0.969697 | 0.076464 | 0.272292 |   0.538232 |
| 358 | Rhodococcus_B        |   40 |    1 | 17446819 |  158 |  0.20202 | 5.731704e-08 |     0.97561 |  0.20202 | 0.443949 |    0.60101 |
| 387 | Altererythrobacter_D |   28 |    7 | 17446849 |  134 |  0.17284 | 4.012184e-07 |         0.8 |  0.17284 | 0.371847 |    0.58642 |
| 398 | Spirillospora        |   10 |    0 | 17445668 | 1340 | 0.007407 |          0.0 |         1.0 | 0.007407 | 0.086063 |   0.503704 |
| 514 | Zunongwangia         |    8 |    0 | 17446968 |   42 |     0.16 |          0.0 |         1.0 |     0.16 |      0.4 |       0.58 |


In [None]:
# finding all the species first names
series = pd.DataFrame(list_for_metrics.microbe.drop_duplicates().str.split(expand=True))[0]
series.value_counts().transpose()

Vibrio               26
Shewanella           25
Pseudoalteromonas    19
Streptomyces         13
Marinobacter         11
                     ..
KOR42                 1
Jeotgalibacillus      1
Janibacter            1
Izemoplasma_B         1
Zunongwangia          1
Name: 0, Length: 515, dtype: int64
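
The cell above takes the first whitespace-separated token of each name as the genus. A compact equivalent, sketched on a few example names (chosen for illustration, not read from the label file), uses `str.split(n=1).str[0]`, which also handles entries that consist of a single word:

```python
import pandas as pd

names = pd.Series([
    "Vibrio cholerae",
    "Vibrio harveyi",
    "Shewanella baltica",
    "Zunongwangia",  # single-token entry, kept as-is
])

# Split once on whitespace and keep the first token as the genus.
genus = names.str.split(n=1).str[0]
genus_counts = genus.value_counts()
print(genus_counts)
```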

In [None]:
new_series = pd.DataFrame(unidentified.microbe.drop_duplicates().str.split(expand=True))[0]
new_series = pd.DataFrame(new_series.value_counts())
new_series

|                      |   0 |
|:---------------------|----:|
| Agarilytica          |   1 |
| Amycolatopsis        |   1 |
| Cognaticolwellia     |   1 |
| Granulosicoccus      |   1 |
| Marinobacter_A       |   1 |
| Nocardiopsis         |   1 |
| Rhodococcus_B        |   1 |
| Altererythrobacter_D |   1 |
| Spirillospora        |   1 |
| Zunongwangia         |   1 |


#### What are the phyla for these?
* Streptomyces       5 - Actinomycetota
* Micromonospora     4 - Actinomycetota
* Mycobacterium      1 - Actinomycetota
* Marinobacter       1 - Pseudomonadota
* Shewanella         1 - Pseudomonadota
* Staphylothermus    1 - Thermoproteota
* Piscirickettsia    1 - Pseudomonadota
* Alteromonas        1 - Pseudomonadota
* Spirillospora      1 - Actinomycetota
* Bacillus_A         1 - Bacillota

### Quite a lot of these are in the phylum Actinomycetota, and several are Streptomyces...
The failure rate is low! 