### Embedding and Cosine Similarity

In this section we collect the protein sequences and embed them using a pretrained model from Rostlab's [ProtTrans](https://github.com/agemagician/ProtTrans). We then compute the cosine similarity between every pair of embedded protein sequences and plot it as a heatmap.

In [1]:
# BioPython library for collecting the sequences from cif files
from Bio.PDB import PDBList
from Bio.PDB.MMCIFParser import MMCIFParser

In [2]:
# Data manipulation libraries
import os
import re
import io

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql import Row
from pyspark.sql.types import ArrayType, DoubleType, MapType
from pyspark.ml.linalg import DenseVector, VectorUDT

import pandas as pd
import numpy as np

In [3]:
# Creating the spark session
spark = SparkSession.builder \
    .master("spark://master:7077")\
    .appName("Proteindata spark application")\
    .config("spark.executor.memory", "4096m")\
    .getOrCreate()


In [4]:
# Creating the spark context
sc = spark.sparkContext

### Getting the file paths

In [5]:
# Getting the file paths
base_path = '/data_files/cif_files'
base_path_edit = '/data_files/cif_files/{}'
file_names = os.listdir(base_path)
file_list = [base_path_edit.format(i) for i in file_names]
files_rdd = sc.parallelize(file_list)

In [6]:
# Counting the paths as a quick sanity check
files_rdd.count()

1712

### Parsing the file

#### Parsing the files with the Biopython library to get the sequences. The function below returns the protein id, its sequence (as a list of one-letter codes) and the sequence length.

In [7]:
def parse_file(file):

    cif_parser = MMCIFParser(QUIET=True) # CIF file parser
    name = os.path.basename(file).split('.')[0] # Getting the id of the protein from the file name
    structure = cif_parser.get_structure(name, file) # The structure id is only a label

    # Dictionary for residue names
    d3to1 = {'CYS': 'C', 'ASP': 'D', 'SER': 'S', 'GLN': 'Q', 'LYS': 'K',
    'ILE': 'I', 'PRO': 'P', 'THR': 'T', 'PHE': 'F', 'ASN': 'N',
    'GLY': 'G', 'HIS': 'H', 'LEU': 'L', 'ARG': 'R', 'TRP': 'W',
    'ALA': 'A', 'VAL':'V', 'GLU': 'E', 'TYR': 'Y', 'MET': 'M'}

    # Start with an empty sequence so files without any model do not raise an error
    sequence = []
    # Use only the first model (NMR entries contain many copies of the same chains)
    # and concatenate all of its chains (assumption: one sequence per entry)
    first_model = next(iter(structure), None)
    if first_model is not None:
        for chain in first_model:
            # Skip water and other hetero residues (residue.id[0] != ' ');
            # unknown residue names are mapped to 'X'
            sequence.extend(d3to1.get(residue.get_resname(), 'X')
                            for residue in chain.get_residues()
                            if residue.id[0] == ' ')
    length = len(sequence)

    return name, sequence, length
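`chain.get_residues()` also yields waters and ligands, and a `d3to1.get(..., 'X')` lookup turns each of them into an `'X'`. Biopython marks these residues in the first field of `residue.id`: `' '` for standard residues, `'W'` for water and `'H_<name>'` for other hetero groups. The sketch below filters on that flag; it works on hypothetical `(resname, hetflag)` tuples so it runs without a CIF file. One trade-off to keep in mind: modified residues stored as HETATM records (for example selenomethionine, `MSE`) are dropped as well.

```python
# Same three-letter to one-letter mapping as in parse_file
D3TO1 = {'CYS': 'C', 'ASP': 'D', 'SER': 'S', 'GLN': 'Q', 'LYS': 'K',
         'ILE': 'I', 'PRO': 'P', 'THR': 'T', 'PHE': 'F', 'ASN': 'N',
         'GLY': 'G', 'HIS': 'H', 'LEU': 'L', 'ARG': 'R', 'TRP': 'W',
         'ALA': 'A', 'VAL': 'V', 'GLU': 'E', 'TYR': 'Y', 'MET': 'M'}

def to_one_letter(residues):
    """Map (resname, hetflag) pairs to one-letter codes, skipping water and hetero residues.

    Hypothetical helper: hetflag mimics residue.id[0] in Biopython.
    """
    return [D3TO1.get(name, 'X') for name, hetflag in residues if hetflag == ' ']

# Toy chain: two standard residues, a modified residue, a water, and an unknown name
chain = [('MET', ' '), ('ALA', ' '), ('MSE', 'H_MSE'), ('HOH', 'W'), ('SEC', ' ')]
print(to_one_letter(chain))  # ['M', 'A', 'X']
```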

### Creating a dataframe 
#### A dataframe containing the id, length and tokens of each sequence

In [8]:
# Wrapping each parsed file in a Row so the RDD can later be turned into a DataFrame
def tokens_df_creator(file_path):

    name, sequence, length = parse_file(file_path)
    row_value = {
        'id': name,
        'length': length,
        'tokens': sequence,
    }
    # Optional length filter (disabled): keep only sequences with 32 to 256 residues
    # if not 32 <= row_value['length'] <= 256:
    #     return []
    return [Row(**row_value)]

# flatMap flattens the returned lists, so each file contributes zero or one row
tokens_rdd = files_rdd.flatMap(tokens_df_creator)


In [9]:
# Turning the RDD into a DataFrame for easier usage
tokens_df = tokens_rdd.toDF()

#### Checking output

In [None]:
tokens_df.take(1)

In [None]:
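tokens_df.where(tokens_df.length==0).show()
# Note: Spark ran this as two showString jobs (at NativeMethodAccessorImpl.java:0)
# and it took about 8 minutes, possibly up to 12. tokens_df is not cached, so every
# action re-parses all CIF files; tokens_df.cache() would likely help.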

### Frequency analysis for every sequence

#### Here I create a dataframe with a frequency analysis of each sequence. Its columns are the sequence id, the most frequent amino acid (mf_aa) and the relative frequency of the most frequent amino acid (mf_aa_freq), i.e. its count divided by the sequence length.
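To spot-check the Spark SQL columns on a few rows, the same two values can be computed for a single token list in plain Python. This is a hypothetical helper, not part of the pipeline; note that `mf_aa_freq` is a fraction between 0 and 1 rather than a percentage.

```python
from collections import Counter

def most_frequent_residue(tokens):
    """Return (mf_aa, mf_aa_freq) for one tokenized sequence, or (None, None) if it is empty."""
    if not tokens:
        return None, None
    # most_common(1) returns [(token, count)] for the highest count;
    # ties are resolved in favour of the token seen first in the sequence
    token, count = Counter(tokens).most_common(1)[0]
    return token, count / len(tokens)

print(most_frequent_residue(list("AAGLA")))  # ('A', 0.6)
```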

In [None]:
tokens_df.printSchema()

In [None]:
# Step 1: count how often each distinct token occurs in the sequence
temp = (
    tokens_df
    .withColumn("Dist", F.array_distinct("tokens"))  # Get distinct tokens for the current sequence
    .withColumn(
        "Counts",
        F.expr(
            """
            transform(
                Dist,
                x -> aggregate(
                    tokens,
                    0,
                    (acc, y) -> IF(y = x, acc + 1, acc)
                )
            )
            """
        )  # For each distinct token, count its occurrences in the sequence
    )
    .withColumn(
        "Map",
        F.arrays_zip("Dist", "Counts")  # Combine tokens and their counts into an array of structs
    )
    .drop("Dist", "Counts")  # Drop intermediate columns
)

In [None]:
# Step 2: sort the (token, count) structs by count and keep the most frequent token
freq_df = temp.withColumn(
    "SortedMap",  # Sort the Map by counts in descending order
    F.expr(
        """
        array_sort(
            Map,
            (first, second) -> CASE
                WHEN first['Counts'] > second['Counts'] THEN -1
                WHEN first['Counts'] < second['Counts'] THEN 1
                ELSE 0
            END
        )
        """
    )  # The comparator must return 0 for ties to be a valid ordering
).withColumn(
    "mf_aa",  # Most frequent token as a one-element list
    F.expr("array(SortedMap[0]['Dist'])")
).withColumn(
    "mf_aa_count",  # Count of the most frequent token
    F.expr("SortedMap[0]['Counts']")
).withColumn(
    "mf_aa_freq",  # Relative frequency of the most frequent token
    F.col("mf_aa_count") / F.size(F.col("tokens"))  # Divide count by total number of tokens
).drop('tokens','length','Map','SortedMap','mf_aa_count')

In [None]:
freq_df.printSchema()

In [None]:
freq_pd = freq_df.toPandas()
# Could also be saved for later use with freq_pd.to_csv()

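### Distribution of the most frequent tokens

In [None]:
# mf_aa holds a one-element list per row, so .str[0] extracts the token before counting
freq_pd['mf_aa'].str[0].value_counts().plot(kind='bar')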

### Histogram of the frequency of most frequent tokens

In [None]:
freq_pd['mf_aa_freq'].hist()

### Creating the .vec file using ProtBert

#### Below is a short version of how we created the .vec file. Since ProtBert does not ship such a file, we looped through the residue vocabulary and embedded each residue separately.

```python
import numpy as np
import torch
# Transformer model and tokenizer used to create the embedding space
from transformers import BertModel, BertTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
model_embedd = BertModel.from_pretrained("Rostlab/prot_bert").to(device).eval()

vocab = ['L','A','G','V','E','S','I','K','R','D','T','P','N','Q','F','Y','M','H','C','W','X','U','B','Z','O']
# Path to save the .vec file
vec_file_path = 'prot_bert.vec'

# Open the file in write mode
with open(vec_file_path, 'w') as f:
    # Write the header (vocab size and vector dimension, taken from the model config)
    f.write(f"{len(vocab)} {model_embedd.config.hidden_size}\n")

    # Write each residue and its corresponding vector
    for letter in vocab:
        # encode() adds the special tokens: [CLS] letter [SEP]
        encoded_input = tokenizer.encode(letter, return_tensors='pt').to(device)
        with torch.no_grad():
            outputs = model_embedd(input_ids=encoded_input)
        # [0, 0] is the [CLS] position; [0, 1] would be the residue token itself
        vector = outputs.last_hidden_state[0, 0].cpu().numpy()
        vector_str = ' '.join(map(str, vector))  # Convert the vector to a string
        f.write(f"{letter} {vector_str}\n")
```
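The `.vec` file follows the plain-text word2vec/fastText layout: a header line with the vocabulary size and the vector dimension, then one line per token followed by its values. A toy round trip with made-up 3-dimensional vectors and hypothetical `write_vec`/`read_vec` helpers shows the format, including that a reader has to parse or skip the header instead of treating it as a token:

```python
import io

def write_vec(f, vectors):
    """Write {token: [floats]} as a header line followed by one token per line."""
    dim = len(next(iter(vectors.values())))
    f.write(f"{len(vectors)} {dim}\n")
    for token, vec in vectors.items():
        f.write(token + " " + " ".join(map(str, vec)) + "\n")

def read_vec(f):
    """Read the same format back, using the header only for validation."""
    count, dim = (int(x) for x in f.readline().split())
    data = {}
    for line in f:
        parts = line.rstrip().split()
        data[parts[0]] = [float(x) for x in parts[1:]]
    if len(data) != count or any(len(v) != dim for v in data.values()):
        raise ValueError("header does not match the file contents")
    return data

buf = io.StringIO()
write_vec(buf, {"A": [0.1, -0.2, 0.3], "G": [1.0, 0.0, 0.5]})
buf.seek(0)
print(read_vec(buf)["G"])  # [1.0, 0.0, 0.5]
```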

#### The problem with this implementation is that all contextual information between the residues of a sequence is lost, because each residue is embedded on its own and always receives the same vector. This strongly affects our output, since this context carries evolutionary, functional, structural and other information.
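Written out, this means that a mean embedding built from such a file depends only on the amino acid composition of the sequence. If $\mathbf{v}(a)$ is the fixed vector stored for residue type $a$, $n_a$ is the number of times $a$ occurs in a sequence $s_1 \dots s_L$, and $f_a = n_a / L$, then

$$
\bar{\mathbf{e}} = \frac{1}{L}\sum_{i=1}^{L} \mathbf{v}(s_i) = \sum_{a} \frac{n_a}{L}\,\mathbf{v}(a) = \sum_{a} f_a\,\mathbf{v}(a)
$$

so two sequences with the same composition get identical embeddings regardless of residue order, and sequences with similar compositions get similar embeddings.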

### Loading the pretrained embeddings

#### We load the .vec file as a standard Python dictionary, which at first exists only on the driver node. PySpark's broadcast function then ships a read-only copy of the dictionary to every executor, so the executors can look up the vectors locally.

In [10]:
# Creating the dictionary from the .vec file
def load_vectors(fname):
    data = {}
    with io.open(fname, 'r', encoding='utf-8', newline='\n', errors='ignore') as fin:
        next(fin)  # Skip the header line ("<vocab size> <dimension>")
        for line in fin:
            letter_token = line.rstrip().split()
            data[letter_token[0]] = DenseVector([float(value) for value in letter_token[1:]])
    return data

vec_dict = load_vectors('/data_files/prot_bert.vec')

In [11]:
# Broadcasting the dictionary
vec_broadcast = sc.broadcast(vec_dict)

### Creating list of vectors

In [12]:
# Embedding the sequence
# The embedding is done with a UDF (an untested alternative: store the vectors
# in a small DataFrame and join it on the token)
@F.udf(ArrayType(VectorUDT()))
def embed_sequence(tokens_list):
    vectors = vec_broadcast.value  # Use the broadcast copy, not the driver-side vec_dict
    return [vectors[token] for token in tokens_list if token in vectors]


In [13]:
# The embeddings are added as a new column (UDF applied to the DataFrame)
tokens_df = tokens_df.withColumn("embeddings",embed_sequence(tokens_df.tokens))

#### Output check

In [None]:
tokens_df.take(1)

In [None]:
tokens_df.select("id", "embeddings").take(1)

### Mean-pooling the per-residue vectors (len(sequence) × 1024) into a single 1024-dimensional vector

##### To reduce the dimensions of the embeddings we used the same approach as Rostlab to obtain one embedding for the whole protein sequence: averaging the per-residue vectors.
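$$
\bar{\mathbf{e}} = \frac{1}{L}\sum_{i=1}^{L} \mathbf{e}_i, \qquad \mathbf{e}_i \in \mathbb{R}^{1024}
$$
where $L$ is the sequence length and $\mathbf{e}_i$ is the vector of residue $i$.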

In [14]:
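# Creating another UDF to get the mean of each embedding row
@F.udf(VectorUDT())
def mean_calculator(embedding, length):
    if not embedding or not length:
        return None  # Empty sequences get a null embedding instead of a ZeroDivisionError
    # Sum the per-residue vectors with NumPy and divide by the sequence length
    return DenseVector(np.sum([v.toArray() for v in embedding], axis=0) / length)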

# Created mean embedding is added as a new column
tokens_df = tokens_df.withColumn("mean_embed",mean_calculator(tokens_df.embeddings,tokens_df.length))
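Mean pooling itself is a one-liner in NumPy. The sketch below uses made-up random vectors in place of the ProtBert `.vec` table (the `vec_table` dict is hypothetical) and also checks a consequence of using fixed per-residue vectors: reversing a sequence, which keeps its composition, leaves the mean embedding unchanged up to floating-point rounding.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical lookup table: one fixed random 1024-d vector per residue letter
vec_table = {aa: rng.standard_normal(1024) for aa in "ACDEFGHIKLMNPQRSTVWYX"}

def mean_embedding(tokens):
    """Mean-pool the per-residue vectors of one sequence into a single vector."""
    if not tokens:
        return None
    return np.stack([vec_table[t] for t in tokens]).mean(axis=0)

seq = list("MKTAYIAKQR")
shuffled = seq[::-1]  # same composition, different order
print(mean_embedding(seq).shape)                                   # (1024,)
print(np.allclose(mean_embedding(seq), mean_embedding(shuffled)))  # True
```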

In [15]:
# Selecting a subset of the tokens_df in order to persist in the memory for future usage
mean_embed_rdd = tokens_df.select("id","mean_embed").rdd
mean_embed_rdd = mean_embed_rdd.persist() # Persisting to get the data quickly, since the cosine similarity is computed on this RDD
mean_embed_rdd.take(3)

Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.runJob.
: org.apache.spark.SparkException: Job 2 cancelled 


In [None]:
# Creating a DF using the persisted RDD
mean_embed_df = mean_embed_rdd.toDF()
#mean_embed_df.printSchema(2)


#### Checking output

In [None]:
take5 = mean_embed_df.take(5)

In [None]:
print(len(take5[3][1]))

In [None]:
print(take5[1][0])

### Cosine similarity

#### Here we define a plain Python function (not a UDF) for the cosine similarity and apply it to every pair of protein embeddings. This shows whether any protein sequences are similar to each other.

In [None]:
def cos_sim_local(a, b):
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

In [None]:
cartesian_rdd = mean_embed_rdd.cartesian(mean_embed_rdd)  # All ordered pairs of proteins
upper_triangle_rdd = cartesian_rdd.filter(lambda x: x[0][0] < x[1][0])  # Keep each unordered pair once (id1 < id2)

In [None]:
result_rdd = upper_triangle_rdd.map(
    lambda pair: (pair[0][0], pair[1][0], cos_sim_local(pair[0][1], pair[1][1]))
)

# Convert to DataFrame (the 'dot_product' column holds the cosine similarity)
result_df = spark.createDataFrame(result_rdd, ["row_idx1", "row_idx2", "dot_product"])
result_df.count()
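Since there are only 1712 sequences, the full similarity matrix has 1712 × 1712 ≈ 2.9 million entries (about 23 MB as float64), which fits comfortably on the driver. An alternative to `cartesian()` is therefore to collect the mean vectors (for example with `mean_embed_rdd.map(lambda r: r[1].toArray()).collect()`) and compute all similarities with one matrix product. A minimal sketch with a hypothetical three-protein stand-in for the real data:

```python
import numpy as np

def cosine_similarity_matrix(X):
    """All-pairs cosine similarity for the rows of X (n_proteins x dim)."""
    norms = np.linalg.norm(X, axis=1, keepdims=True)
    # Rows of all zeros are left as zeros, so their similarity is 0 instead of NaN
    Xn = X / np.where(norms == 0, 1.0, norms)
    return Xn @ Xn.T

# Hypothetical ids and 2-d "embeddings"
ids = ["1abc", "2xyz", "3def"]
X = np.array([[1.0, 0.0], [1.0, 1.0], [0.0, 2.0]])
S = cosine_similarity_matrix(X)

# Upper triangle (i < j), matching the id1 < id2 filter used with cartesian()
iu = np.triu_indices(len(ids), k=1)
pairs = [(ids[i], ids[j], float(S[i, j])) for i, j in zip(*iu)]
print(pairs)
```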

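Optionally, write the pairwise similarities to disk, partitioned by the first protein id. This creates up to one directory per protein, and it may not be necessary for the steps below.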

In [None]:
path = '/data_files/output_files/' 
result_df.repartition('row_idx1').write.partitionBy('row_idx1').csv(path)

Next, find the pairs whose cosine similarity is below 0.98 and collect them.
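If the similarities are held in a NumPy matrix on the driver instead of `result_df`, the same threshold filter over the upper triangle can be sketched as follows (hypothetical helper and toy values):

```python
import numpy as np

def pairs_below(ids, S, threshold=0.98):
    """Return (id1, id2, sim) for every i < j with S[i, j] < threshold."""
    iu, ju = np.triu_indices(len(ids), k=1)
    mask = S[iu, ju] < threshold
    return [(ids[i], ids[j], float(S[i, j])) for i, j in zip(iu[mask], ju[mask])]

ids = ["p1", "p2", "p3"]
S = np.array([[1.00, 0.99, 0.90],
              [0.99, 1.00, 0.97],
              [0.90, 0.97, 1.00]])
print(pairs_below(ids, S))  # [('p1', 'p3', 0.9), ('p2', 'p3', 0.97)]
```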

In [None]:
num_partitions = 100  # Adjust based on the size of your dataset and cluster resources
repartitioned_df = result_df.repartition(num_partitions, 'row_idx1')
filtered_df = repartitioned_df.filter(F.col("dot_product") < 0.98)

In [None]:
filtered_df.count()

In [None]:
filtered_df2 = repartitioned_df.filter(F.col("dot_product") < 0.95)
df2 = filtered_df2.toPandas()

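#### For heatmap

In [None]:
df2.head(2)

In [None]:
# Pivot the pairs into a matrix: rows are row_idx1, columns are row_idx2.
# Assumption: df2_matrix (used below) is meant to be this pivot of df2.
df2_matrix = df2.pivot(index='row_idx1', columns='row_idx2', values='dot_product')
cos_df = df2_matrix.sort_index()

In [None]:
cos_df.head()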
    

#### Heatmap trial

In [None]:
import seaborn as sns
# heatmap accepts the (generally non-square) pivoted DataFrame directly; NaN cells are left blank
sns.heatmap(df2_matrix)
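Because only pairs with `row_idx1 < row_idx2` were computed, a matrix built directly from these pairs covers just the upper triangle. A sketch of mirroring the pairs into a full symmetric matrix (hypothetical helper; the diagonal is set to 1) is below. Pairs removed by a threshold filter stay `NaN`, which `sns.heatmap` masks automatically. The triples could come, for example, from `result_df.toPandas().itertuples(index=False)`.

```python
import numpy as np

def pairs_to_matrix(pairs):
    """Build a symmetric similarity matrix (diagonal = 1) from (id1, id2, sim) triples."""
    ids = sorted({a for a, _, _ in pairs} | {b for _, b, _ in pairs})
    index = {pid: k for k, pid in enumerate(ids)}
    M = np.full((len(ids), len(ids)), np.nan)  # NaN marks pairs that are missing
    np.fill_diagonal(M, 1.0)
    for a, b, s in pairs:
        # Write each similarity to both (a, b) and (b, a)
        M[index[a], index[b]] = M[index[b], index[a]] = s
    return ids, M

ids, M = pairs_to_matrix([("p1", "p2", 0.99), ("p1", "p3", 0.90), ("p2", "p3", 0.97)])
print(ids)      # ['p1', 'p2', 'p3']
print(M[2, 0])  # 0.9
```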

Here we see that the embeddings look very similar. I will continue the investigation by checking the distribution of amino acids in the most similar and least similar sequence pairs.
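One way to carry out that comparison is to reduce each sequence to its 20-dimensional amino acid composition and measure how far apart two compositions are. A minimal sketch with hypothetical helper names; the L1 distance is an illustrative choice, not something the analysis above prescribes:

```python
from collections import Counter
import numpy as np

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"

def composition(tokens):
    """20-dim vector of amino acid fractions; non-standard tokens such as 'X' are ignored."""
    counts = Counter(t for t in tokens if t in AMINO_ACIDS)
    total = sum(counts.values())
    vec = np.array([counts[a] for a in AMINO_ACIDS], dtype=float)
    return vec / total if total else vec

def composition_distance(tokens_a, tokens_b):
    """L1 distance between two compositions: 0 = identical, 2 = no shared residue types."""
    return float(np.abs(composition(tokens_a) - composition(tokens_b)).sum())

print(composition_distance(list("AAGG"), list("GAGA")))  # 0.0
print(composition_distance(list("AAAA"), list("GGGG")))  # 2.0
```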

In [None]:
mean_embed_rdd.unpersist()

In [None]:
sc.stop()

In [None]:
spark.stop()