In [1]:
import tensorflow as tf
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt 
import seaborn as sns
import sklearn
from Bio.Seq import Seq
from transformers import TFBertModel, BertTokenizer,BertConfig
import re
import pickle
import sys
import gc
import os

# np.random.seed(42)
# tf.random.set_seed(42)
# os.environ['PYTHONHASHSEED']=str(42)
tf.keras.utils.set_random_seed(42)


In [2]:
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert_bfd", do_lower_case=False )

In [3]:
embedding_model = TFBertModel.from_pretrained("Rostlab/prot_bert_bfd", from_pt=True)

Metal device set to: Apple M1 Max

systemMemory: 64.00 GB
maxCacheSize: 24.00 GB



2023-01-22 12:24:07.282927: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2023-01-22 12:24:07.283103: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing TFBertModel from a PyTorc

In [4]:
def get_embeddings(sequence_list):
    sequence_list = [re.sub(r"[UZOB]", "X", sequence) for sequence in sequence_list]
    ids = tokenizer.batch_encode_plus(sequence_list, add_special_tokens=True, padding=True, return_tensors="tf")
    input_ids = ids['input_ids']
    attention_mask = np.asarray(ids['attention_mask'])
    embedding = np.asarray(embedding_model(input_ids)[0])
    average_embeddings = []
    for seq_num in range(len(embedding)):

        seq_len = (attention_mask[seq_num] == 1).sum()
        residues = embedding[seq_num][1:seq_len-1]
        average = np.mean(residues, axis=0)
        average_embeddings.append(average)

    return np.asarray(average_embeddings)

In [5]:
sequences_Example = ["A E T C Z A O","S K T Z P"]

average_embeddings = get_embeddings(sequences_Example)
print(average_embeddings)

[[ 0.07877532 -0.07100235 -0.03651526 ...  0.02056815  0.052767
   0.07941834]
 [ 0.02623458 -0.10717591 -0.07274815 ... -0.04052575  0.00889488
  -0.01958423]]


In [6]:
del average_embeddings
gc.collect()

8

In [7]:
# Note that row 10467 of the COVID-19 sequences was deleted due to having an empty HCDR3. The total number of COVID-19 samples is now 11,867

In [8]:
tf.__version__

'2.10.0'

# Data Formatting

In [9]:
df = pd.read_csv("/Volumes/Seagate Portable Drive/Unpaired_COVID/total.csv")
df = df[["sequence_alignment_aa"]]
df = df.sample(200000, random_state=42)

In [11]:
dummy = []
head = []
with open("../Data/cAb-rep/cAb-Rep_heavy.nt.txt") as myfile:
    # count = 0
    for i in myfile:
        # if count <= 1:
        #     print(i)
        #     if i.find(">") == -1 & i.find("-") == -1:
        #         print(Seq.translate(i.strip()))
        #     count+=1
        dummy.append(i)
    np.random.shuffle(dummy)
    
    for i in dummy:
        if i.find(">") == -1 & i.find("-") == -1 & i.find("N") == -1: # These conditions must be met for a valid sequence, the longest was 141. However, there is no 141 sequence for COVID, the greatest is 138, so we go with that
            aa_sequence = Seq.translate(i.strip())
            if (len(aa_sequence) <= 138) & (len(aa_sequence) >= 100):
                head.append(aa_sequence)
                if len(head) >= 200000:
                    break
print(head[:5], len(head))
healthy_sequences = head



['EVQLVQSGPEVKKPGSSVKVSCKASGGTFSNFAFSWVRQAPGQGLEWMGSVILHLGTSTYAQKFQGRVTITADESTSAAFMDLNALTSDDTAVYYCARVVAVPGRVPYWFDPWGQGTLVTVSS', 'TLSLTCAVYGGSFSGYYWSWIRQPPGKGLEWIGEINHSGSTNYNPSLKSRVTISVDTSKNQFSLKLSSVTAADTAVYYCARVPPTSTVTTLGDDYWGQGTLVTVSS', 'QVQLVQSGPEVKKPGASVRVSCKPSGYPFSNYGISWMRQAPGQGLEWMGWVNIDKGNTKYAQKFQDRVTMTTDTSSSTVYLELRSLRSDDTALYYCARERGGYRYGDYWGQGTLVIVSS', 'TLSLTCAVYGGSFSGYYWSWIRQPPGKGLEWIGEIKHSGSTNYIPSLKSRVTISVDTSKNQFSLKLSSVTAADTAVYYCASRAGAAAASWGQGTLVTVSS', 'SETLSLTCAVHGGSFSDYYWTWIRQPPGKGLEWIGEINHRGGTNYNPSLKSRLNILVDTSKSQFSLKLSSVTAADTAVYFCARERFILIRGLTKYYYYMDVWGKGTTVTVS'] 200000


In [12]:
del head
del myfile
del dummy
gc.collect()

0

In [None]:
covid_sequences = df.to_numpy()
covid_sequences = np.squeeze(covid_sequences)
# np.random.shuffle(covid_sequences)
# print(len(max(healthy_sequences, key=len)))
print(len(max(covid_sequences, key=len)))

139


In [None]:
del df
gc.collect()

0

In [None]:
# order from https://www.ncbi.nlm.nih.gov/Class/MLACourse/Modules/MolBioReview/iupac_aa_abbreviations.html

# Preprocessing

In [None]:
if os.path.exists('/Volumes/Seagate Portable Drive/healthy_embeddings') == False:
    os.mkdir('/Volumes/Seagate Portable Drive/healthy_embeddings')

In [None]:
if os.path.exists('/Volumes/Seagate Portable Drive/covid_embeddings') == False:
    os.mkdir('/Volumes/Seagate Portable Drive/covid_embeddings')

In [None]:
# for i in range(20):
#     covid_sequences_new = [(" ".join(s)) for s in covid_sequences][round(len(covid_sequences) * 0.05 * i):round(len(covid_sequences) * 0.05 * (i+1))]
#     covid_average_embeddings, covid_residue_embeddings = get_embeddings(covid_sequences_new)

#     with open("/Volumes/Seagate Portable Drive/covid_embeddings/" + str(i) + ".pkl", "wb") as f:
#         pickle.dump([covid_average_embeddings, covid_residue_embeddings], f)

#     del covid_sequences_new
#     del covid_average_embeddings
#     del covid_residue_embeddings
#     print("Finished embeddings for " +i+ " of 100")
#     gc.collect()

In [None]:
for i in range(100):
    healthy_sequences_new = [(" ".join(s)) for s in healthy_sequences][round(len(healthy_sequences) * 0.01 * i):round(len(healthy_sequences) * 0.01 * (i+1))]
    healthy_average_embeddings = get_embeddings(healthy_sequences_new)

    with open("/Volumes/Seagate Portable Drive/healthy_embeddings/" + str(i) + ".pkl", "wb") as f:
        pickle.dump([healthy_average_embeddings], f)

    del healthy_sequences_new
    del healthy_average_embeddings
    print("Finished embeddings for " +str(i)+ " of 100")
    gc.collect()

In [None]:
for i in range(100):
    covid_sequences_new = [(" ".join(s)) for s in covid_sequences][round(len(covid_sequences) * 0.01 * i):round(len(covid_sequences) * 0.01 * (i+1))]
    covid_average_embeddings = get_embeddings(covid_sequences_new)

    with open("/Volumes/Seagate Portable Drive/covid_embeddings/" + str(i) + ".pkl", "wb") as f:
        pickle.dump([covid_average_embeddings], f)

    del covid_sequences_new
    del covid_average_embeddings
    print("Finished embeddings for " +str(i)+ " of 100")
    gc.collect()