In [1]:
from Bio import SeqIO
import json
from collections import Counter, defaultdict
import subprocess # to run blastn
from sklearn.metrics import classification_report # for eval



In [2]:
# Class dictionairies
with open("train_classes.json") as f:
    train_classes = json.load(f)

with open("test_classes.json") as f:
    test_classes = json.load(f)

In [3]:
# overall majority class (use if there's no hits at all)
overall_majority = Counter(list(train_classes.values())).most_common(1)[0][0]

print('Most common class in train set: ', overall_majority)

Most common class in train set:  NRP


In [4]:
def run_blastn(test_fasta, db, output_file, k = 10):
    # This runs blastn and outputs test ID, and k similar train IDs
    subprocess.run([
        "./ncbi-blast-2.16.0+/bin/blastn",
        "-query", test_fasta,
        "-db", db,
        "-out", output_file,
        # Needed for a tsv output, we only need these fields for our algo
        "-outfmt", "6 qseqid sseqid pident length evalue bitscore",
        "-max_target_seqs", str(k) # blastn will return at most k
    ], check=True)
    
def read_blast_output(blast_output):
    test_to_NN = defaultdict(list) # test instance -> k nearest neighbors
    with open(blast_output) as f:
        for line in f: 
            # train id is the blast hits
            test_id, train_id, *_ = line.strip().split("\t")
            test_to_NN[test_id].append(train_id)
    return test_to_NN

# This is our blast kNN implementation
def predict_class(test_to_NN, train_classes, k = 10):
    preds = {}
    for test_id, train_ids in test_to_NN.items():
        train_ids = train_ids[:k] # top k ids only
        # the classes for the top k ids
        k_classes = [train_classes.get(train_id) for train_id in train_ids if train_classes.get(train_id) is not None]
        
        if k_classes: # find majority
            majority = Counter(k_classes).most_common(1)[0][0]
        else: # edge case: no hits 
            majority = overall_majority
            
        preds[test_id] = majority
        
    return preds  

In [5]:
# Returns metrics to compare with DGEB
def eval_metrics(preds, test_classes):
    y_pred = []
    y_true = []
    
    for test_id, true_class in test_classes.items():
        y_pred.append(preds.get(test_id, overall_majority))
        y_true.append(true_class)
        
    # sklearn eval metrics
    report = classification_report(y_true, y_pred, output_dict = False)
    print(report)
    
    return classification_report(y_true, y_pred, output_dict = True) 

In [6]:
### Running the algorithm:
k = 9 # hyperparam!
blast_output = "blast_results.tsv"
test_fasta = "mibig_test.fasta"
db = "mibig_train_db" # this is our local blastdb

# Runs locally using BLAST+
run_blastn(test_fasta, db, blast_output, k)

test_to_NN = read_blast_output(blast_output)
preds = predict_class(test_to_NN, train_classes, k)
report = eval_metrics(preds, test_classes)

              precision    recall  f1-score   support

    Alkaloid       0.33      0.08      0.13        12
         NRP       0.60      0.88      0.71       163
  Polyketide       0.81      0.75      0.78       138
        RiPP       1.00      0.46      0.63        67
  Saccharide       1.00      0.79      0.88        28
     Terpene       0.73      0.33      0.46        33

    accuracy                           0.71       441
   macro avg       0.75      0.55      0.60       441
weighted avg       0.75      0.71      0.70       441

