# Spec2Tax: Predicting Sample Taxonomy from MS/MS Spectra

## Motivation


Public mass spectral repositories contain a vast collection of mass spectra (about 1.2 billion), and many of these spectra carry attached metadata, such as the taxonomy of the sample they were measured from. A classifier trained to predict a sample's taxonomy from its MS2 data could be applied to untargeted pharmacokinetics experiments to quickly distinguish animal-derived from plant-derived compounds.

## Imports
This notebook uses the metabolomics repo (https://github.com/enveda/sg-pipeline-dev), which is still a work in progress but will hopefully be released as a package.

In [44]:
import sys

from pathlib import Path
import pandas as pd
from collections import defaultdict

import numpy as np
from collections import Counter
from tqdm import tqdm
import pickle5 as pickle
import awswrangler as wr
import os   
import boto3
from io import BytesIO

from sklearn import linear_model, model_selection
from sklearn.metrics import roc_auc_score, average_precision_score, auc, f1_score, recall_score
from sklearn.model_selection import StratifiedKFold
from sklearn.linear_model import MultiTaskElasticNetCV
import logging
from typing import List, Optional, Tuple, Union

logger = logging.getLogger(__name__)

# Path to the directory of spec2vec embeddings of data collected from GNPS;
# each file in it is read as a parquet table further down.
local_spec2vec = 'local_spec2vec'

# Directory that holds the subsampled datasets (one sub-directory per
# taxonomic task). makedirs(exist_ok=True) creates it in one step and is a
# no-op when it already exists, avoiding a check-then-create race.
local_results_dir = "spec2tax"
os.makedirs(local_results_dir, exist_ok=True)


# Sampling data
Since there is too much data to hold in memory at once, we read the locally saved spec2vec parquet files one at a time. From each file we keep only the spec2vec embeddings and the taxonomic labels, collected into two parallel lists.

In [None]:
tax_rank = 'class'
def get_taxonomic_ranked_dataset(rank: str = 'class', dir: str = local_spec2vec):
    """Collect taxonomy labels and spec2vec embeddings for one taxonomic rank.

    Every file in ``dir`` is read as a parquet table. Rows whose ``rank``
    column equals the string ``'undefined'`` are dropped. The remaining
    ``rank`` values and ``spec2vec`` values are appended to two parallel lists.

    Args:
        rank: Name of the taxonomy column used as the label, e.g. ``'class'``
            or ``'family'``.
        dir: Directory holding the parquet files. The name shadows the
            ``dir`` builtin but is kept because callers pass it by keyword.

    Returns:
        ``(labels, spec2vec_data)``: two lists of equal length. Embeddings
        are not validated here and may be ``None``; callers filter those out.
    """
    labels = []
    spec2vec_data = []
    # Sort the listing so the concatenation order, and therefore which
    # samples a downstream per-label cap keeps, is reproducible.
    for file_name in tqdm(sorted(os.listdir(dir))):
        # Only the label column and the embedding column are needed.
        df = pd.read_parquet(os.path.join(dir, file_name), columns=[rank, 'spec2vec'])
        df = df[df[rank] != 'undefined']
        labels.extend(df[rank])
        spec2vec_data.extend(df['spec2vec'])
    return labels, spec2vec_data

class_labels, class_data = get_taxonomic_ranked_dataset(rank='class', dir=local_spec2vec)
family_labels, family_data = get_taxonomic_ranked_dataset(rank='family', dir=local_spec2vec)

Now that we have data and labels at the class and family taxonomic levels, we can prepare datasets from them. 

There is clearly class imbalance in the data. For the task of classifying at the class level of the taxonomy,
let's choose the five most abundant classes: Coscinodiscophyceae (diatoms), Insecta (insects), Mammalia (mammals), Actinomycetia (bacteria), and Anthozoa (marine invertebrates).
We'll leave out Clitellata, Eurotiomycetes, Gammaproteobacteria, and Magnoliopsida.

In [None]:
# Keep every (embedding, label) pair except those belonging to the classes
# left out of the class-level task; the two lists stay index-aligned.
new_X, new_y = [], []
for embedding, taxon in zip(class_data, class_labels):
    if taxon not in ['Clitellata', 'Eurotiomycetes', 'Gammaproteobacteria', 'Magnoliopsida']:
        new_X.append(embedding)
        new_y.append(taxon)

In [None]:
# Drop samples whose embedding is None; each surviving embedding stays
# paired with its label. Both results are tuples.
class_data, class_labels = zip(*(
    (embedding, taxon)
    for embedding, taxon in zip(new_X, new_y)
    if embedding is not None
))


In [None]:
def make_balanced_sample(labels, data, cap: int = 5000):
    """Cap the number of samples kept per label.

    Walks ``labels`` and ``data`` in lockstep and keeps at most ``cap``
    samples for each distinct label: the first ones encountered. Labels are
    emitted in order of first appearance, so the output is deterministic
    for a given input order.

    Args:
        labels: Sequence of labels, parallel to ``data``.
        data: Sequence of samples (e.g. spec2vec embeddings).
        cap: Maximum number of samples retained per label.

    Returns:
        ``(all_class_labels, all_class_data)``: parallel lists in which each
        label appears exactly once per retained sample. Labels with fewer
        than ``cap`` samples keep all of them; nothing is oversampled.
    """
    kept = defaultdict(list)  # label -> retained samples, in input order
    for label, sample in zip(labels, data):
        bucket = kept[label]
        if len(bucket) < cap:
            bucket.append(sample)

    all_class_labels = []
    all_class_data = []
    for label, samples in kept.items():
        # One label per retained sample (not ``cap`` labels), so the two
        # outputs always have the same length.
        all_class_labels.extend([label] * len(samples))
        all_class_data.extend(samples)
    return all_class_labels, all_class_data
            
balanced_class_labels, balanced_class_data = make_balanced_sample(list(class_labels), list(class_data))

# Integer-encode the *balanced* labels so that y.pkl lines up row-for-row
# with X.pkl. Codes follow the sorted order of the distinct labels.
label_to_int_map = {label: code for code, label in enumerate(np.unique(balanced_class_labels))}
y_with_ints = [label_to_int_map[label] for label in balanced_class_labels]

y = np.array(balanced_class_labels)
X = np.array(balanced_class_data)

# One output sub-directory per taxonomic task, under local_results_dir.
out_dir = os.path.join(local_results_dir, tax_rank)
os.makedirs(out_dir, exist_ok=True)
with open(os.path.join(out_dir, "X.pkl"), 'wb') as fh:
    pickle.dump(X, fh)
with open(os.path.join(out_dir, "y.pkl"), 'wb') as fh:
    pickle.dump(y_with_ints, fh)

Preparing the family-level datasets within Mammalia, Gammaproteobacteria, and Magnoliopsida.

In [None]:
tax_rank = 'Mammalia'
# Keep only the Mammalia families of interest that have an embedding.
# Results go into new_X/new_y rather than rebinding family_data /
# family_labels, so later cells still filter the full family-level data.
new_X, new_y = [], []
for x, label in zip(family_data, family_labels):
    if label in ['Hominidae', 'Muridae', 'Rhinocerotidae'] and x is not None:
        new_X.append(x)
        new_y.append(label)
if not new_y:
    raise ValueError(f"No {tax_rank} samples with a spec2vec embedding were found")

balanced_family_labels, balanced_family_data = make_balanced_sample(new_y, new_X)

# Integer-encode the balanced labels (sorted order of distinct labels) so
# that y.pkl lines up row-for-row with X.pkl.
label_to_int_map = {label: code for code, label in enumerate(np.unique(balanced_family_labels))}
y_with_ints = [label_to_int_map[label] for label in balanced_family_labels]

y = np.array(balanced_family_labels)
X = np.array(balanced_family_data)

out_dir = os.path.join(local_results_dir, tax_rank)
os.makedirs(out_dir, exist_ok=True)
with open(os.path.join(out_dir, "X.pkl"), 'wb') as fh:
    pickle.dump(X, fh)
with open(os.path.join(out_dir, "y.pkl"), 'wb') as fh:
    pickle.dump(y_with_ints, fh)

In [None]:
tax_rank = 'Gammaproteobacteria'
# Keep only the Gammaproteobacteria families of interest that have an
# embedding. Results go into new_X/new_y rather than rebinding
# family_data / family_labels, so later cells still see the full data.
new_X, new_y = [], []
for x, label in zip(family_data, family_labels):
    if label in ['Enterobacteriaceae', 'Morganellaceae'] and x is not None:
        new_X.append(x)
        new_y.append(label)
if not new_y:
    raise ValueError(f"No {tax_rank} samples with a spec2vec embedding were found")

balanced_family_labels, balanced_family_data = make_balanced_sample(new_y, new_X, cap=1000)

# Integer-encode the balanced labels (sorted order of distinct labels) so
# that y.pkl lines up row-for-row with X.pkl.
label_to_int_map = {label: code for code, label in enumerate(np.unique(balanced_family_labels))}
y_with_ints = [label_to_int_map[label] for label in balanced_family_labels]

y = np.array(balanced_family_labels)
X = np.array(balanced_family_data)

out_dir = os.path.join(local_results_dir, tax_rank)
os.makedirs(out_dir, exist_ok=True)
with open(os.path.join(out_dir, "X.pkl"), 'wb') as fh:
    pickle.dump(X, fh)
with open(os.path.join(out_dir, "y.pkl"), 'wb') as fh:
    pickle.dump(y_with_ints, fh)

In [None]:
tax_rank = 'Magnoliopsida'
# Keep only the Magnoliopsida families of interest that have an embedding.
# Results go into new_X/new_y rather than rebinding family_data /
# family_labels, so the full family-level data stays available.
new_X, new_y = [], []
for x, label in zip(family_data, family_labels):
    if label in ['Malvaceae', 'Euphorbiaceae', 'Moraceae'] and x is not None:
        new_X.append(x)
        new_y.append(label)
if not new_y:
    raise ValueError(f"No {tax_rank} samples with a spec2vec embedding were found")

balanced_family_labels, balanced_family_data = make_balanced_sample(new_y, new_X, cap=500)

# Integer-encode the balanced labels (sorted order of distinct labels) so
# that y.pkl lines up row-for-row with X.pkl.
label_to_int_map = {label: code for code, label in enumerate(np.unique(balanced_family_labels))}
y_with_ints = [label_to_int_map[label] for label in balanced_family_labels]

y = np.array(balanced_family_labels)
X = np.array(balanced_family_data)

out_dir = os.path.join(local_results_dir, tax_rank)
os.makedirs(out_dir, exist_ok=True)
with open(os.path.join(out_dir, "X.pkl"), 'wb') as fh:
    pickle.dump(X, fh)
with open(os.path.join(out_dir, "y.pkl"), 'wb') as fh:
    pickle.dump(y_with_ints, fh)