README:  
The PhaNNs expandedDB was downloaded from https://phanns.com/download/expandedDB.tgz on 01/17/2023 (MM/DD/YYYY).  
`all_sequence_ids_to_vectors_dict.pkl` is a pickled dictionary that maps the PhaNNs sequence IDs to their protbert_bfd embeddings; it is available for download from the KellyLab GCP in the viral_protein_family_plm_embeddings bucket.

In [None]:
from sklearn.preprocessing import LabelBinarizer, LabelEncoder
from sklearn.metrics import classification_report
from sklearn.metrics import recall_score, precision_score, f1_score
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.optimizers import SGD, Adam
from tensorflow.keras import backend as K
from tensorflow.keras.utils import to_categorical
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle
import random
import os
from tqdm import tqdm

# collecting and organizing train and test sequences

In [None]:
def get_faa(path: str) -> list:
    """Read a FASTA file and return its sequences as a list of strings.

    Lines belonging to one record are joined into a single string. Header
    lines (starting with '>') only delimit records; their text is discarded.

    Records without any residues are skipped everywhere in the file, so an
    empty file yields ``[]`` and a trailing header with no sequence does not
    add an empty string to the result.

    Args:
        path: Path of the FASTA file to read.

    Returns:
        The non-empty sequences, in file order.
    """
    seqs = []
    chunks = []

    with open(path) as handle:
        for line in handle:
            line = line.rstrip()
            if line.startswith('>'):
                # A new header closes the record collected so far.
                if chunks:
                    seqs.append(''.join(chunks))
                    chunks = []
            elif line:
                # Blank lines carry no residues; ignoring them keeps an
                # otherwise empty record from being emitted as ''.
                chunks.append(line)

    # Flush the last record only if it actually holds sequence data.
    if chunks:
        seqs.append(''.join(chunks))
    return seqs

def get_faa_identifier(path: str):
    """Return the header lines of a FASTA file.

    Every line starting with '>' is returned in file order, with trailing
    whitespace removed and the leading '>' kept.
    """
    with open(path) as handle:
        return [row.rstrip() for row in handle if row.startswith('>')]

In [None]:
# Category names; each one is matched as a substring of the file names
# found in the data directory when the sequences are collected.
cats = [
    'HTJ',
    'portal',
    'major_tail',
    'major_capsid',
    'minor_tail',
    'baseplate',
    'collar',
    'shaft',
    'other',
    'tail_fiber',
    'minor_capsid',
]

In [None]:
# Directory holding the PhaNNs expandedDB FASTA files (path relative to the working directory).
data_dir = 'expandedDB'

In [None]:
# Gather the FASTA header lines of every category, split into a training and
# a test set. Both dictionaries map category name -> list of header lines.
training_dict = {}
testing_dict = {}

for category in cats:
    train_ids = []
    test_ids = []
    for fname in os.listdir(data_dir):
        # Keep only non-hidden files whose name contains the category.
        if category not in fname or fname.startswith('.'):
            continue
        headers = get_faa_identifier('{0}/{1}'.format(data_dir, fname))
        # Files whose name prefix (text before the first '_') is '11' form
        # the test set; all other files go to training.
        if fname.split('_')[0] == '11':
            test_ids.extend(headers)
        else:
            train_ids.extend(headers)

    print(category)
    print('total number of sequences in training 10 sets: {0}'.format(len(train_ids)))
    print('total number of sequences in testing 1 sets: {0}'.format(len(test_ids)))
    print('\n')

    training_dict[category] = train_ids
    testing_dict[category] = test_ids

# Load the pLM embeddings of the PhaNNs sequences

In [None]:
# Load the dictionary used below to look up one embedding vector per FASTA
# header line.
# NOTE: unpickling can execute arbitrary code; only load this file when it
# comes from the trusted source named in the README.
with open('all_sequence_ids_to_vectors_dict.pkl', 'rb') as handle:
    # The context manager closes the file handle once loading has finished.
    seqs_to_vectors = pickle.load(handle)

In [None]:
# Look up the embedding of every training header and sort it into one of two
# flat lists: the 'other' category versus all remaining categories.
pvp_training_vecs = []
other_training_vecs = []
for category, headers in training_dict.items():
    bucket = other_training_vecs if category == 'other' else pvp_training_vecs
    bucket.extend([seqs_to_vectors[header] for header in headers])

In [None]:
# Number of training vectors in the 'other' class (bare expression: displayed as the cell output in a notebook).
len(other_training_vecs)

In [None]:
# Number of training vectors from all categories except 'other' (bare expression: displayed as the cell output in a notebook).
len(pvp_training_vecs)

In [None]:
# Look up the embedding of every test header and sort it into one of two
# flat lists: the 'other' category versus all remaining categories.
pvp_testing_vecs = []
other_testing_vecs = []
for category, headers in testing_dict.items():
    bucket = other_testing_vecs if category == 'other' else pvp_testing_vecs
    bucket.extend([seqs_to_vectors[header] for header in headers])

In [None]:
# Number of test vectors in the 'other' class (bare expression: displayed as the cell output in a notebook).
len(other_testing_vecs)

In [None]:
# Number of test vectors from all categories except 'other' (bare expression: displayed as the cell output in a notebook).
len(pvp_testing_vecs)

In [None]:
## Fit the label encoder on the two class names before the data is split, so
## the class -> integer mapping is defined up front.
## NOTE: `cats` is rebound here from the list of category names to the array
## of encoded class labels.
lb = LabelEncoder()
cats = lb.fit_transform(['pvp', 'other'])

In [None]:
# One string label per embedding vector, built separately for each split.
pvp_training_label, other_training_label = (
    ['pvp'] * len(pvp_training_vecs),
    ['other'] * len(other_training_vecs),
)
pvp_testing_label, other_testing_label = (
    ['pvp'] * len(pvp_testing_vecs),
    ['other'] * len(other_testing_vecs),
)

In [None]:
# Stack the 'pvp' entries first and the 'other' entries second, using the
# same order for vectors and labels so that row i of the vectors belongs to
# label i.
all_training_vecs = np.concatenate([pvp_training_vecs, other_training_vecs])
all_training_labels = np.concatenate([pvp_training_label, other_training_label])

all_testing_vecs = np.concatenate([pvp_testing_vecs, other_testing_vecs])
all_testing_labels = np.concatenate([pvp_testing_label, other_testing_label])

# FNN (feed-forward neural network) classifier

In [None]:
# Model inputs: the stacked embedding vectors of each split (aliases, no copy).
trainX, testX = all_training_vecs, all_testing_vecs

In [None]:
# Fit the encoder on the training labels only and reuse that exact
# class -> integer mapping for the test labels. Re-fitting on the test labels
# could silently produce a different mapping (e.g. if a class were missing
# from the test split); `transform` raises instead when it meets an unseen label.
trainY = lb.fit_transform(all_training_labels)
testY = lb.transform(all_testing_labels)

# Pin the one-hot width to the number of fitted classes so both splits always
# have the same number of columns, rather than letting it be inferred from the
# largest label present in each array.
n_classes = len(lb.classes_)
trainY = to_categorical(trainY, num_classes=n_classes)
testY = to_categorical(testY, num_classes=n_classes)

In [None]:
# Fully connected classifier: 1024 inputs -> 512 -> 256 -> 128 -> 9 -> 2-way
# softmax, with dropout (rate 0.2) after each of the first three hidden layers.
# NOTE(review): the `input_shape` arguments on the hidden layers look
# redundant, since only the first layer needs one - confirm before removing.
model = Sequential([
    Dense(512, input_shape=(1024,), activation="relu"),
    Dropout(0.2),
    Dense(256, input_shape=(512,), activation="relu"),
    Dropout(0.2),
    Dense(128, input_shape=(256,), activation="relu"),
    Dropout(0.2),
    Dense(9, input_shape=(128,), activation="relu"),
    Dense(2, activation="softmax"),
])

In [None]:
# Train for a fixed number of epochs with Adam at a learning rate of 1e-5.
# NOTE(review): binary_crossentropy is combined here with one-hot targets and
# a 2-unit softmax output; categorical_crossentropy is the more conventional
# pairing - confirm this is intended.
n_epoch = 5
opt = Adam(1e-5)
model.compile(optimizer=opt, loss="binary_crossentropy", metrics=["accuracy"])
# `H` keeps the object returned by fit().
H = model.fit(x=trainX, y=trainY, batch_size=60, epochs=n_epoch)

In [None]:
# Run the trained model on the test vectors; the output layer has two softmax units, so each prediction holds two class scores.
predictions = model.predict(testX)

In [None]:
# Visual check of the raw model outputs: heatmap of the prediction matrix (presumably rows = test vectors, columns = the two output units).
ax = sns.heatmap(predictions)
plt.show()

In [None]:
# Turn one-hot targets and softmax outputs back into class indices, then print
# the per-class report using the encoder's class names.
y_true = testY.argmax(axis=1)
y_pred = predictions.argmax(axis=1)
class_names = [str(name) for name in lb.classes_]
print(classification_report(y_true, y_pred, target_names=class_names))

In [None]:
# Summary metrics computed on class indices.
# NOTE(review): these scorers report on label 1 by default; presumably that is
# 'pvp' given the encoder's class ordering - confirm against lb.classes_.
y_true = testY.argmax(axis=1)
y_pred = predictions.argmax(axis=1)
print('Precision: {0}'.format(precision_score(y_true, y_pred)))
print('Recall: {0}'.format(recall_score(y_true, y_pred)))
print('F1: {0}'.format(f1_score(y_true, y_pred)))