# 3-Pronged Classification Method for Taxonomic (3PCM)

In this project, we aim to classify astrovirus RNA sequences at the genus level. Although this framework is designed for astroviruses, it can be used for other viral classification and clustering tasks.

The input is a collection of astrovirus genomic sequences in `.fasta` format. The goal is to classify this collection and compare the output to the ground truth if it is available.

Since the input sequences are RNA, the alphabet we are using is [A, C, G, U].
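Sequence files from public databases often spell RNA genomes with T instead of U and may contain ambiguity codes such as N. The following is a minimal sketch of how a sequence could be brought onto the [A, C, G, U] alphabet and checked for other symbols. The helper names `to_rna` and `non_alphabet_symbols` are hypothetical and are not part of this project's code.

```python
RNA_ALPHABET = frozenset("ACGU")

def to_rna(sequence: str) -> str:
    """Upper-case a sequence and rewrite DNA-style T as RNA-style U."""
    return sequence.upper().replace("T", "U")

def non_alphabet_symbols(sequence: str) -> set:
    """Return the symbols in `sequence` that fall outside A, C, G, U."""
    return set(sequence) - RNA_ALPHABET

example = to_rna("acgtTGCAN")
print(example)                        # ACGUUGCAN
print(non_alphabet_symbols(example))  # {'N'}
```

Whether such symbols should be dropped, masked, or kept is a modelling decision that this sketch leaves open.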

The pipeline of the project:

1.   Load and pre-process the data (alternatively, load the pickle files in which the data has already been pre-processed and is ready for the next steps)
2.   Prong 1 (Supervised Classification)
3.   Prong 2 (Unsupervised Clustering)
4.   Prong 3 (Extract host information)
5.   Putting it all together


In [1]:
# dependencies
import pickle
import pandas as pd
from pathlib import Path

# import biopython, installing it first if it is missing
try:
  import Bio
except ImportError:
  !pip install -q biopython
  import Bio

from Bio import SeqIO
print(f"biopython version: {Bio.__version__}")

biopython version: 1.79


## 1. Load and pre-process the data

To convert the genomic sequences from `.fasta` format to `torch.utils.data.Dataset`, we go through the following initial steps:

* Read the data from the `.fasta` files and convert it to a `pandas.DataFrame`. The columns of the dataframe: accession ID, sequence, sequence length, *label*.
* Explore the dataset and distribution of different labels.
* Pre-process the data if necessary.
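To make the table layout concrete, here is a small sketch that builds one row per sequence with the four columns listed above and then tallies the label distribution. The accession IDs and sequences are toy values invented for illustration. The same list of rows could be passed to `pandas.DataFrame(rows)`.

```python
from collections import Counter

# Toy records standing in for parsed FASTA entries: (accession ID, sequence, label).
toy_records = [
    ("ACC001", "ACGUACGUAC", "Avastrovirus"),
    ("ACC002", "ACGUAC", "Mamastrovirus"),
    ("ACC003", "ACGUACGU", "Mamastrovirus"),
]

# One row per sequence, with the four columns named in the list above.
rows = [
    {"accession_id": acc, "sequence": seq, "sequence_length": len(seq), "label": label}
    for acc, seq, label in toy_records
]

# Distribution of labels and range of sequence lengths.
label_counts = Counter(row["label"] for row in rows)
lengths = [row["sequence_length"] for row in rows]

print(dict(label_counts))
print(min(lengths), max(lengths))
```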

### 1.1 Read the data from `.fasta` file

For this purpose we use the `Biopython` library. `Biopython` is a collection of Python modules for working with bioinformatics data and for common computational operations. It also contains a module for parsing sequence files (`SeqIO`).

`SeqIO.parse(file_path, format)` reads a sequence file and yields `SeqRecord` objects that contain, among other fields:
* id - ID used to identify the sequence - a string
* seq - `Seq` object containing the sequence
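With Biopython's FASTA parser, the `id` is the first whitespace-delimited word of the header line and `seq` is the concatenation of the sequence lines that follow. The sketch below mimics that behaviour in plain Python so the record structure can be inspected without any files. `Record` and `parse_fasta_text` are hypothetical stand-ins, not Biopython APIs, and the FASTA text is a toy example.

```python
from typing import Iterator, NamedTuple

class Record(NamedTuple):
    id: str
    seq: str

def parse_fasta_text(text: str) -> Iterator[Record]:
    """Yield one Record per '>' header found in FASTA-formatted text."""
    header, chunks = None, []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):
            if header is not None:
                yield Record(header, "".join(chunks))
            # The ID is the first whitespace-delimited token of the header.
            tokens = line[1:].split(maxsplit=1)
            header, chunks = (tokens[0] if tokens else ""), []
        else:
            chunks.append(line)
    if header is not None:
        yield Record(header, "".join(chunks))

toy_fasta = ">ACC001 toy astrovirus record\nACGU\nACGU\n>ACC002\nUUGG\n"
records = list(parse_fasta_text(toy_fasta))
print(records[0].id, records[0].seq)  # ACC001 ACGUACGU
```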

In [2]:
# Setup path to a data folder
data_path = Path("../data/")
fasta_path = data_path / "fasta_files"

In [3]:
# Get all the FASTA paths
fasta_path_list = list(fasta_path.glob("*.fasta"))

# Read the first record of each FASTA file
seq_records = []
print("Reading the fasta files...")
for fasta_file in fasta_path_list:
  seq_records.append(next(SeqIO.parse(fasta_file, "fasta")))

print(f"{len(seq_records)} sequence records have been read.")

Reading the fasta files...
992 sequence records have been read.


### 1.2 Extract the labels from metadata file

In [4]:
def extract_genus_label(seq_name: str,
                        metadata_df):
  """
  Return the genus-level label of a sequence record, looked up by name in the
  metadata.
  """
  metadata_df = metadata_df.set_index(["Name"])
  return metadata_df.loc[seq_name, "Genus"]

In [5]:
def extract_host_as_genus_label(seq_name: str,
                                metadata_df):
  """
  Return a potential genus-level label for an unlabeled sequence record,
  based on the class of its host.
  """
  host_class = metadata_df.set_index(["Name"]).loc[seq_name, "Class"]
  if host_class == "Mammalia":
    return "Mamastrovirus"
  elif host_class == "Aves":
    return "Avastrovirus"
  else:
    return host_class
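The host-to-genus rule in `extract_host_as_genus_label` can also be written as a lookup table, which keeps the mapping in one place. This is only a sketch of the same rule on plain strings. The names `HOST_CLASS_TO_GENUS` and `genus_from_host_class` are hypothetical, and the metadata lookup is left out.

```python
# Host class -> candidate genus, as used by the rule above.
HOST_CLASS_TO_GENUS = {
    "Mammalia": "Mamastrovirus",
    "Aves": "Avastrovirus",
}

def genus_from_host_class(host_class: str) -> str:
    """Map a host class to a candidate genus; other values pass through unchanged."""
    return HOST_CLASS_TO_GENUS.get(host_class, host_class)

print(genus_from_host_class("Aves"))      # Avastrovirus
print(genus_from_host_class("Reptilia"))  # Reptilia
```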


### 1.3 Extract information about the sequences

The datasets built here are saved to `pickle` files in section 1.4. Each dataset is a list of tuples of the form `(label, sequence, accession_id)`.
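As a small illustration of the `(label, sequence, accession_id)` layout, the sketch below shows how such a list can be split into parallel columns or turned into a lookup by accession ID. The records are toy values, not entries from the real datasets.

```python
# Toy dataset in the (label, sequence, accession_id) layout.
toy_dataset = [
    ("Avastrovirus", "ACGUAC", "ACC001"),
    ("Mamastrovirus", "ACGUACGU", "ACC002"),
]

# zip(*...) transposes the list of tuples into three parallel tuples.
labels, sequences, accession_ids = zip(*toy_dataset)

# A lookup from accession ID to label is handy when comparing predictions.
label_by_id = {acc: label for label, _, acc in toy_dataset}

print(labels)
print(label_by_id["ACC002"])  # Mamastrovirus
```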

In [6]:
# dataset2: records with a known genus label
# dataset3: records of unknown genus, labeled by host class instead
dataset2 = []
dataset3 = []

metadata_genus_df = pd.read_csv(data_path / "metadata_genus.csv")
xl_file = pd.ExcelFile(data_path / "metadata_host.xlsx")
# extract different sheets in xl file
dfs = {sheet_name: xl_file.parse(sheet_name) for sheet_name in xl_file.sheet_names}
metadata_host_df = dfs["1039 documents"]


for seq_record in seq_records:
  label = extract_genus_label(seq_record.id, metadata_genus_df)
  sequence = str(seq_record.seq)
  if label != "Unknown":
    record = (label, sequence, seq_record.id)
    dataset2.append(record)
  else:
    label = extract_host_as_genus_label(seq_record.id, metadata_host_df)
    if label in ["Avastrovirus", "Mamastrovirus"]:
      record = (label, sequence, seq_record.id)
      dataset3.append(record)

In [7]:
from utilities import get_stats
train_data = dataset2
test_data = dataset3

get_stats(train_data, "Dataset 2");
get_stats(test_data, "Dataset 3");

-----------Statistics about Dataset 2: ------------
# of samples:  684
# of classes:  2
min seq length: 5003
mean seq length: 6629
max seq length: 7799
data distribution: 
Avastrovirus =>      213
Mamastrovirus =>      471
-------------------------------------------------
-----------Statistics about Dataset 3: ------------
# of samples:  229
# of classes:  2
min seq length: 5084
mean seq length: 6432
max seq length: 8417
data distribution: 
Avastrovirus =>       42
Mamastrovirus =>      187
-------------------------------------------------
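The implementation of `utilities.get_stats` is not shown in this notebook. As a rough idea of what such a summary involves, the following hypothetical `summarize` function computes the same kinds of quantities from a list of `(label, sequence, accession_id)` tuples. How the real helper rounds the mean length is an assumption here.

```python
from collections import Counter
from statistics import mean

def summarize(dataset):
    """Summarize a list of (label, sequence, accession_id) tuples."""
    lengths = [len(seq) for _, seq, _ in dataset]
    counts = Counter(label for label, _, _ in dataset)
    return {
        "samples": len(dataset),
        "classes": len(counts),
        "min_length": min(lengths),
        "mean_length": round(mean(lengths)),  # rounding is an assumption
        "max_length": max(lengths),
        "distribution": dict(counts),
    }

toy = [
    ("Avastrovirus", "ACGU", "ACC001"),
    ("Mamastrovirus", "ACGUAC", "ACC002"),
    ("Mamastrovirus", "ACGUACGU", "ACC003"),
]
stats = summarize(toy)
print(stats)
```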


In [8]:
training_data_file = "dataset2.p"
testing_data_file = "dataset3.p"

### 1.4 Saving the datasets in `pickle` files

In [10]:
with open(data_path / training_data_file, "wb") as f:
    pickle.dump(dataset2, f)
with open(data_path / testing_data_file, "wb") as f:
    pickle.dump(dataset3, f)

### 1.5 Reading the data from `pickle` files

In [11]:
from utilities import get_stats
training_data_file = "dataset2.p"
testing_data_file = "dataset3.p"
# Use context managers so the files are closed after loading
with open(data_path / training_data_file, "rb") as f:
    train_data = pickle.load(f)
with open(data_path / testing_data_file, "rb") as f:
    test_data = pickle.load(f)
get_stats(train_data, "Dataset 2");
get_stats(test_data, "Dataset 3");

-----------Statistics about Dataset 2: ------------
# of samples:  684
# of classes:  2
min seq length: 5003
mean seq length: 6629
max seq length: 7799
data distribution: 
Mamastrovirus =>      471
Avastrovirus =>      213
-------------------------------------------------
-----------Statistics about Dataset 3: ------------
# of samples:  229
# of classes:  2
min seq length: 5084
mean seq length: 6432
max seq length: 8417
data distribution: 
Mamastrovirus =>      187
Avastrovirus =>       42
-------------------------------------------------


## 2. Prong 1 (Supervised Classification)

In [15]:
import statistics

from prong1 import supervised_classification

In [16]:
"""
Select one of "10-nearest-neighbors", "nearest-centroid-mean",
"nearest-centroid-median", "logistic-regression", "linear-svm",
"quadratic-svm", "cubic-svm", "sgd", "decision-tree", "random-forest",
"adaboost", "gaussian-naive-bayes", "lda", "qda", "multilayer-perceptron"
"""
SUPERVISED_ALGORITHM = "linear-svm"
K = 6 # change this hyperparameter if needed
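The notebook does not state what `K` controls inside `supervised_classification`. If it is the k-mer length, as is common in alignment-free sequence classification, each sequence would be turned into a vector of k-mer counts before a classifier is trained. The sketch below shows that idea under this assumption. `kmer_count_vector` is a hypothetical helper and not the feature extraction used in `prong1`.

```python
from collections import Counter
from itertools import product

def kmer_count_vector(sequence, k, alphabet="ACGU"):
    """Count every length-k word over `alphabet`, in lexicographic order."""
    kmers = ["".join(p) for p in product(alphabet, repeat=k)]
    # Slide a window of width k over the sequence and tally each word.
    counts = Counter(sequence[i:i + k] for i in range(len(sequence) - k + 1))
    # Windows containing symbols outside the alphabet are simply not looked up.
    return [counts[kmer] for kmer in kmers]

vector = kmer_count_vector("ACGUACGU", k=2)
print(len(vector))  # 16
print(sum(vector))  # 7
```

With a four-letter alphabet and k = 6, the vector would have 4^6 = 4096 entries per sequence.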

In [17]:
def extract_class_names(training_data_file):
    """Return the sorted class names that belong to a known dataset file."""
    if training_data_file in ["dataset1.p", "dataset2.p", "dataset2_NR.p", "dataset3.p", "dataset3_NR.p"]:
        class_names = sorted(["Avastrovirus", "Mamastrovirus"])
    elif training_data_file == "potyvirus.p":
        class_names = sorted(["Astrovirus", "Potyvirus"])
    elif training_data_file == "mamastrovirus.p":
        class_names = sorted(["HAstV", "Non-HAstV Mamastroviurs"])
    elif training_data_file == "avastrovirus.p":
        class_names = sorted(["GoAstV", "Non-GoAstV Avastrovirus"])
    else:
        # Fail loudly instead of hitting an unbound variable below
        raise ValueError(f"Unknown dataset file: {training_data_file}")
    return class_names
class_names = extract_class_names(training_data_file)

In [18]:
prong1_pred = supervised_classification(training_data=train_data,
                                        class_names=class_names,
                                        testing_data=test_data,
                                        k=K,
                                        algorithm=SUPERVISED_ALGORITHM)

Prong 1 starting...
Classification algorithm: linear-svm
Accuracy: 96.0699%
Prong 1 predictions:
{'JF713711': 'Mamastrovirus', 'KX907134': 'Mamastrovirus', 'KX266901': 'Mamastrovirus', 'MK404648': 'Mamastrovirus', 'FJ375759': 'Mamastrovirus', 'MG660832': 'Mamastrovirus', 'LC577872': 'Mamastrovirus', 'OM480542': 'Mamastrovirus', 'OM451148': 'Mamastrovirus', 'KJ571486': 'Mamastrovirus', 'JF327666': 'Mamastrovirus', 'OM451114': 'Mamastrovirus', 'MN087316': 'Mamastrovirus', 'MN920672': 'Avastrovirus', 'JN420358': 'Mamastrovirus', 'KT224358': 'Mamastrovirus', 'OM451116': 'Mamastrovirus', 'MN920670': 'Avastrovirus', 'MT138009': 'Avastrovirus', 'KM254166': 'Avastrovirus', 'NC_018702': 'Mamastrovirus', 'OM480521': 'Mamastrovirus', 'LC201620': 'Mamastrovirus', 'LC577870': 'Mamastrovirus', 'KF233994': 'Mamastrovirus', 'MT138010': 'Mamastrovirus', 'MN920669': 'Avastrovirus', 'MN503236': 'Mamastrovirus', 'OK107512': 'Mamastrovirus', 'MZ603072': 'Mamastrovirus', 'JA816489': 'Avastrovirus', 'MK08943
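The reported accuracy compares the predictions with the labels stored in the test set. For Dataset 3 those labels were derived from the host class, so the figure measures agreement with the host-derived labels (96.0699% is what 220 matches out of 229 would give). A minimal sketch of that comparison follows, using toy data. The `accuracy` helper is hypothetical and is not the code inside `prong1`.

```python
def accuracy(predictions, labelled_data):
    """Fraction of records whose predicted label equals the reference label."""
    reference = {acc: label for label, _, acc in labelled_data}
    hits = sum(predictions[acc] == label for acc, label in reference.items())
    return hits / len(reference)

toy_test = [
    ("Mamastrovirus", "ACGU", "ACC001"),
    ("Avastrovirus", "ACGUAC", "ACC002"),
    ("Mamastrovirus", "ACGUACGU", "ACC003"),
    ("Avastrovirus", "ACG", "ACC004"),
]
toy_pred = {
    "ACC001": "Mamastrovirus",
    "ACC002": "Avastrovirus",
    "ACC003": "Avastrovirus",   # the one mismatch
    "ACC004": "Avastrovirus",
}
print(f"Accuracy: {accuracy(toy_pred, toy_test):.4%}")  # Accuracy: 75.0000%
```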

## 3. Prong 2 (Unsupervised Clustering)

In [19]:
from prong2 import unsupervised_clustering

In [20]:
UNSUPERVISED_ALGORITHM = "k-means++"
CLUSTER_COUNT = 2
K = 6
RUN_COUNT = 10

In [21]:
# Reuse extract_class_names, defined in the Prong 1 section above
class_names = extract_class_names(training_data_file)

In [22]:
prong2_pred = unsupervised_clustering(training_data=train_data,
                                      class_names=class_names,
                                      clusters_count=CLUSTER_COUNT,
                                      testing_data=test_data,
                                      k=K,
                                      algorithm=UNSUPERVISED_ALGORITHM)

Prong 2 starting...
Clustering algorithm: k-means++
Accuracy: 83.5808%
NMI: 0.2209
ARI: 0.3778
Silhouette Score: 0.0625
Prong 2 predictions:
{'JF713711': 'Mamastrovirus', 'KX907134': 'Mamastrovirus', 'KX266901': 'Mamastrovirus', 'MK404648': 'Mamastrovirus', 'FJ375759': 'Mamastrovirus', 'MG660832': 'Mamastrovirus', 'LC577872': 'Mamastrovirus', 'OM480542': 'Mamastrovirus', 'OM451148': 'Mamastrovirus', 'KJ571486': 'Mamastrovirus', 'JF327666': 'Mamastrovirus', 'OM451114': 'Mamastrovirus', 'MN087316': 'Mamastrovirus', 'MN920672': 'Mamastrovirus', 'JN420358': 'Mamastrovirus', 'KT224358': 'Avastrovirus', 'OM451116': 'Mamastrovirus', 'MN920670': 'Avastrovirus', 'MT138009': 'Avastrovirus', 'KM254166': 'Mamastrovirus', 'NC_018702': 'Mamastrovirus', 'OM480521': 'Mamastrovirus', 'LC201620': 'Mamastrovirus', 'LC577870': 'Mamastrovirus', 'KF233994': 'Mamastrovirus', 'MT138010': 'Mamastrovirus', 'MN920669': 'Avastrovirus', 'MN503236': 'Mamastrovirus', 'OK107512': 'Mamastrovirus', 'MZ603072': 'Mamastr
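A clustering algorithm returns anonymous cluster IDs, so an accuracy figure requires some assignment of clusters to class names. How `unsupervised_clustering` does this is not shown here. One common approach, sketched below with toy data, is to try every one-to-one assignment and keep the best score. The function name `clustering_accuracy` is hypothetical.

```python
from itertools import permutations

def clustering_accuracy(cluster_ids, true_labels, class_names):
    """Best accuracy over all one-to-one assignments of cluster IDs to class names."""
    clusters = sorted(set(cluster_ids))
    best = 0.0
    for assignment in permutations(class_names, len(clusters)):
        mapping = dict(zip(clusters, assignment))
        hits = sum(mapping[c] == t for c, t in zip(cluster_ids, true_labels))
        best = max(best, hits / len(true_labels))
    return best

toy_clusters = [0, 0, 1, 1, 1]
toy_truth = ["Mamastrovirus", "Mamastrovirus", "Avastrovirus", "Avastrovirus", "Mamastrovirus"]
print(clustering_accuracy(toy_clusters, toy_truth, ["Avastrovirus", "Mamastrovirus"]))  # 0.8
```

Trying all permutations is only practical for a handful of clusters, which is the case with `CLUSTER_COUNT = 2`.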

## 4. Prong 3 (Extract host information)

In [26]:
import pandas as pd

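def host_identification(testing_data):
  """
  Look up the host of every test sequence in the host metadata sheet.
  Returns two dicts keyed by sequence name: host class and host species.
  """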
  xl_file = pd.ExcelFile(data_path / "metadata_host.xlsx")
  # extract different sheets in xl file
  dfs = {sheet_name: xl_file.parse(sheet_name) for sheet_name in xl_file.sheet_names}
  metadata_host_df = dfs["1039 documents"]
  df = metadata_host_df.set_index(["Name"])

  dict_y_pred_class = {}
  dict_y_pred_species = {}
  for label, seq, name in testing_data:
    dict_y_pred_class[name] = df.loc[name, "Class"]
    dict_y_pred_species[name] = df.loc[name, "Host"]

  print("Host (class level):")
  print(dict_y_pred_class)
  print("Host (species level):")
  print(dict_y_pred_species)
  print("--------------------------------------------------")
  return dict_y_pred_class, dict_y_pred_species

prong3_pred, prong3_pred_species = host_identification(testing_data=test_data);

Host (class level):
{'JF713711': 'Mammalia', 'KX907134': 'Mammalia', 'KX266901': 'Mammalia', 'MK404648': 'Mammalia', 'FJ375759': 'Mammalia', 'MG660832': 'Mammalia', 'LC577872': 'Mammalia', 'OM480542': 'Mammalia', 'OM451148': 'Mammalia', 'KJ571486': 'Mammalia', 'JF327666': 'Mammalia', 'OM451114': 'Mammalia', 'MN087316': 'Mammalia', 'MN920672': 'Aves', 'JN420358': 'Mammalia', 'KT224358': 'Mammalia', 'OM451116': 'Mammalia', 'MN920670': 'Aves', 'MT138009': 'Aves', 'KM254166': 'Aves', 'NC_018702': 'Mammalia', 'OM480521': 'Mammalia', 'LC201620': 'Mammalia', 'LC577870': 'Mammalia', 'KF233994': 'Mammalia', 'MT138010': 'Aves', 'MN920669': 'Aves', 'MN503236': 'Mammalia', 'OK107512': 'Mammalia', 'MZ603072': 'Mammalia', 'JA816489': 'Mammalia', 'MK089435': 'Mammalia', 'KX266903': 'Mammalia', 'LC549662': 'Mammalia', 'MZ357117': 'Mammalia', 'JX857869': 'Mammalia', 'MZ819166': 'Mammalia', 'MT138014': 'Aves', 'MK327365': 'Mammalia', 'KX266907': 'Mammalia', 'MG693175': 'Mammalia', 'KX907132': 'Mammalia'

## 5. Putting it all together

In [27]:
for name in prong1_pred:
  print(f"Predictions for {name}:")
  print(f"Prong 1: {prong1_pred[name]}")
  print(f"Prong 2: {prong2_pred[name]}")
  print(f"Prong 3: Host {prong3_pred[name]}")
  print("--------------------------------------------------")

Predictions for JF713711:
Prong 1: Mamastrovirus
Prong 2: Mamastrovirus
Prong 3: Host Mammalia
--------------------------------------------------
Predictions for KX907134:
Prong 1: Mamastrovirus
Prong 2: Mamastrovirus
Prong 3: Host Mammalia
--------------------------------------------------
Predictions for KX266901:
Prong 1: Mamastrovirus
Prong 2: Mamastrovirus
Prong 3: Host Mammalia
--------------------------------------------------
Predictions for MK404648:
Prong 1: Mamastrovirus
Prong 2: Mamastrovirus
Prong 3: Host Mammalia
--------------------------------------------------
Predictions for FJ375759:
Prong 1: Mamastrovirus
Prong 2: Mamastrovirus
Prong 3: Host Mammalia
--------------------------------------------------
Predictions for MG660832:
Prong 1: Mamastrovirus
Prong 2: Mamastrovirus
Prong 3: Host Mammalia
--------------------------------------------------
Predictions for LC577872:
Prong 1: Mamastrovirus
Prong 2: Mamastrovirus
Prong 3: Host Mammalia
-----------------------------
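The notebook prints the three predictions side by side and leaves the interpretation to the reader. One possible way to summarize them, sketched below, is to translate the host class into a genus and report the majority label along with whether all three prongs agree. This consensus rule is an assumption for illustration and not part of the original pipeline. `combine_prongs` is a hypothetical helper. The two example entries use the values shown in the outputs above for `JF713711` and `KM254166`.

```python
from collections import Counter

HOST_CLASS_TO_GENUS = {"Mammalia": "Mamastrovirus", "Aves": "Avastrovirus"}

def combine_prongs(prong1, prong2, prong3_host_class):
    """Return {accession: (majority genus or None, True if all prongs agree)}."""
    combined = {}
    for name in prong1:
        votes = [
            prong1[name],
            prong2[name],
            # Hosts outside the mapping contribute a None vote.
            HOST_CLASS_TO_GENUS.get(prong3_host_class[name]),
        ]
        genus, count = Counter(votes).most_common(1)[0]
        majority = genus if count >= 2 else None
        combined[name] = (majority, count == len(votes))
    return combined

p1 = {"JF713711": "Mamastrovirus", "KM254166": "Avastrovirus"}
p2 = {"JF713711": "Mamastrovirus", "KM254166": "Mamastrovirus"}
p3 = {"JF713711": "Mammalia", "KM254166": "Aves"}

for name, (genus, unanimous) in combine_prongs(p1, p2, p3).items():
    print(name, genus, "unanimous" if unanimous else "disputed")
```

Sequences flagged as disputed are the ones where the prongs disagree and a closer look may be worthwhile.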