In [None]:
import os
import re
import io
import pandas as pd
import numpy as np

#phi and psi
from Bio.PDB.MMCIFParser import MMCIFParser
from Bio import PDB
from typing import Dict, Tuple

# import the python libraries to create/connect to a Spark Session
from pyspark.sql import SparkSession, Row
from pyspark.sql.types import MapType, ArrayType, StringType, FloatType#, IntegerType, FloatType, StringType
from pyspark.sql.functions import udf, pandas_udf, col
from pyspark.ml.linalg import Vectors
from pyspark.mllib.linalg.distributed import RowMatrix
from pyspark.ml.linalg import VectorUDT

In [None]:
# Create the notebook's Spark session, or reuse one that already exists.
spark = (
    SparkSession.builder
    .master("spark://master:7077")
    .appName("Proteindata spark application")
    .config("spark.executor.memory", "4096m")
    .getOrCreate()
)

In [None]:
# The SparkContext behind the session is what the RDD calls below go through.
sc = spark.sparkContext
# Echo both handles so the cell output shows what we are connected to.
print(spark)
print(sc)

In [None]:
# Directory holding the mmCIF files, plus a template that builds a full path.
base_path = '/data_files/cif_files'
base_path_edit = '/data_files/cif_files/{}'
# os.listdir returns entries in arbitrary order, so sort the names: otherwise
# the development subset below can differ between runs and machines.
file_names = sorted(os.listdir(base_path))
# Number of files in the development subset; set to None to use every file.
SMALL_SAMPLE_SIZE = 129
small_file_list = [base_path_edit.format(file_name) for file_name in file_names][:SMALL_SAMPLE_SIZE]
small_files_rdd = sc.parallelize(small_file_list)

In [None]:
# Sanity check: show the first three file paths held by the RDD.
small_files_rdd.take(3)

In [None]:
def _read_text_field(lines, start):
    """Return the value that starts on ``lines[start]`` with whitespace removed.

    Handles a semicolon-delimited multi-line text field of any length (opening
    ``;`` on the first line, closing ``;`` at the start of a later line) as
    well as a plain single-token value placed on the line after its tag.
    Returns an empty string when no value is present.
    """
    if start >= len(lines):
        return ''

    first = lines[start].strip()
    if not first.startswith(';'):
        tokens = first.split()
        # A following tag ('_...') or section separator ('#') is not a value.
        if not tokens or tokens[0][0] in '_#':
            return ''
        return tokens[0]

    parts = [first[1:]]
    for raw in lines[start + 1:]:
        # A ';' at the start of a line closes the text field.
        if raw.startswith(';'):
            break
        parts.append(raw)
    return ''.join(''.join(parts).split())


def parse_file(file):
    """Extract the entry id and canonical one-letter sequence from an mmCIF file.

    Parameters
    ----------
    file : str
        Path of the mmCIF file to read.

    Returns
    -------
    tuple
        ``(name, sequence, length)`` where ``name`` is the value of
        ``_entry.id``, ``sequence`` is the value of
        ``_entity_poly.pdbx_seq_one_letter_code_can`` with all whitespace
        removed, and ``length`` is ``len(sequence)``. Missing items are
        returned as ``''`` (and ``length`` as ``0``). If an item occurs more
        than once, the last occurrence wins.
    """
    with open(file, 'r') as fin:
        lines = fin.readlines()

    name = ''
    sequence = ''
    for i, line in enumerate(lines):
        id_match = re.match(r'_entry\.id\s+(\S+)', line)
        if id_match:
            name = id_match.group(1)

        # The tag must be followed by at least three whitespace characters,
        # which is how the column-aligned key/value layout writes it.
        seq_match = re.match(r'_entity_poly\.pdbx_seq_one_letter_code_can\s{3}(.*)', line)
        if seq_match:
            inline_value = ''.join(seq_match.group(1).split())
            if inline_value:
                # Value sits on the same line as the tag.
                sequence = inline_value
            else:
                # Value starts on the next line (usually a ';' text field).
                sequence = _read_text_field(lines, i + 1)

    return name, sequence, len(sequence)

In [None]:
def foo(l, dtype=float):
    """Return a new list with ``dtype`` applied to every element of ``l``.

    ``dtype`` is any one-argument callable; it defaults to ``float``.
    """
    return [dtype(item) for item in l]

def angle_transformer(file_path):
    """Compute backbone phi and psi dihedral angles for the first chain of a file.

    The structure is parsed with Biopython, converted to internal coordinates,
    and the backbone dihedra of the first chain are sorted into phi and psi
    by the atom names of their four defining atoms:

    * phi: C(i-1) - N(i) - CA(i) - C(i)
    * psi: N(i) - CA(i) - C(i) - N(i+1)

    (The earlier version stored N-CA-C-N, which is psi, as phi, and
    CA-C-N-CA, which is omega, as psi.)

    Parameters
    ----------
    file_path : str
        Path of the mmCIF file to parse.

    Returns
    -------
    tuple of (list of float, list of float)
        ``(phi, psi)``. phi is undefined for the first residue and psi for the
        last, so a ``0.0`` placeholder is put at the front of phi and at the
        end of psi. Angle values are whatever ``Dihedron.angle`` reports
        (degrees, per the Biopython documentation).
        NOTE(review): per-residue alignment of the two lists assumes an
        unbroken chain - confirm for structures with gaps.
    """
    # The structure id is only a label; derive it from the file name alone so
    # dots in directory names do not truncate it.
    file_model = os.path.splitext(os.path.basename(file_path))[0]
    cif_parser = MMCIFParser()
    structure = cif_parser.get_structure(file_model, file_path)
    structure.atom_to_internal_coordinates()  # xyz coordinates -> angles and bond lengths
    chain: PDB.Chain.Chain = list(structure.get_chains())[0]  # first chain only

    ic_chain: PDB.internal_coords.IC_Chain = chain.internal_coord

    d: Dict[Tuple[PDB.internal_coords.AtomKey,
                  PDB.internal_coords.AtomKey,
                  PDB.internal_coords.AtomKey,
                  PDB.internal_coords.AtomKey],
            PDB.internal_coords.Dihedron] = ic_chain.dihedra

    phi_atoms = ('C', 'N', 'CA', 'C')
    psi_atoms = ('N', 'CA', 'C', 'N')

    phi_angles_list = [0]  # placeholder: the first residue has no phi
    psi_angles_list = []

    for key, dihedron in d.items():
        # akl[3] is the atom-name field of each AtomKey in the 4-tuple.
        atom_names = tuple(atom_key.akl[3] for atom_key in key)
        if atom_names == phi_atoms:
            phi_angles_list.append(dihedron.angle)
        elif atom_names == psi_atoms:
            psi_angles_list.append(dihedron.angle)

    psi_angles_list.append(0)  # placeholder: the last residue has no psi

    return foo(phi_angles_list), foo(psi_angles_list)

In [None]:
def split_aa(sequence):
    """Split a sequence into a list of single-character tokens.

    The characters are joined with single spaces and split on a single space
    again, so an empty sequence yields ``['']`` rather than ``[]``.
    """
    return ' '.join(sequence).split(' ')

def tokens_df_creator(file_path):
    """Build the token Row for one file, for use with ``flatMap``.

    Returns a one-element list holding a Row with fields ``id``, ``length``
    and ``tokens`` when the parsed sequence has at most 128 residues, and an
    empty list otherwise (so longer sequences are dropped by ``flatMap``).
    """
    name, sequence, length = parse_file(file_path)

    # Guard clause: sequences longer than 128 residues produce no row.
    if length > 128:
        return []

    return [Row(id=name, length=length, tokens=split_aa(sequence))]

# One Row per kept file; flatMap flattens the per-file lists, so files whose
# creator returned [] contribute nothing.
tokens_rdd = (
    sc.parallelize(small_file_list)
    .flatMap(tokens_df_creator)
)



def angles_df_creator(file_path):
    """Build the dihedral-angle Row for one file, for use with ``flatMap``.

    Returns a one-element list holding a Row with fields ``id``, ``phi`` and
    ``psi`` (dense vectors) when the parsed sequence has at most 128 residues,
    and an empty list otherwise.

    The length filter runs before the structure is parsed, so files that will
    be discarded never pay for (or fail in) ``angle_transformer``.
    """
    name, sequence, length = parse_file(file_path)

    if length > 128:
        return []

    phi, psi = angle_transformer(file_path)
    return [Row(id=name, phi=Vectors.dense(phi), psi=Vectors.dense(psi))]



# One angle Row per kept file; files over the length limit are dropped.
angles_rdd = (
    sc.parallelize(small_file_list)
    .flatMap(angles_df_creator)
)


In [None]:
# Turn the two Row RDDs into DataFrames; no explicit schema is passed here.
tokens_df = tokens_rdd.toDF()
angles_df = angles_rdd.toDF()

In [None]:
def load_vectors(fname):
    """Load a whitespace-separated embedding file into a dict.

    Each non-blank line is expected to look like ``<token> <v1> <v2> ...``.

    Parameters
    ----------
    fname : str
        Path of the vector file (read as UTF-8; undecodable bytes are ignored).

    Returns
    -------
    dict
        Maps each token (first field of a line) to the list of floats parsed
        from the remaining fields. If a token appears more than once, the last
        line wins.

    Raises
    ------
    ValueError
        If a field after the token cannot be parsed as a float.
    """
    data = {}
    # The context manager closes the file even if a line fails to parse.
    with open(fname, 'r', encoding='utf-8', newline='\n', errors='ignore') as fin:
        for line in fin:
            letter_token = line.split()
            if not letter_token:
                # Skip blank lines instead of failing on letter_token[0].
                continue
            data[letter_token[0]] = [float(value) for value in letter_token[1:]]
    return data

# Token -> vector lookup table, loaded once on the driver.
vec_dict = load_vectors("/data_files/prot_bert.vec")

In [None]:
# Broadcast the lookup table so the UDF can read it through vec_broadcast.value.
vec_broadcast = sc.broadcast(vec_dict)

In [None]:
@udf(ArrayType(ArrayType(FloatType())))
def embed_sequence(tokens):
    """Map a token list to its list of embedding vectors, padded to 128 tokens.

    Sequences shorter than 128 tokens are right-padded with ``'X'`` before the
    lookup; sequences of 128 tokens or more are looked up as they are (the
    earlier version raised UnboundLocalError for them). Tokens that are absent
    from the broadcast lookup table are skipped, so the result can hold fewer
    vectors than tokens.

    Returns ``None`` when ``tokens`` is ``None``.
    """
    if tokens is None:
        return None

    # Read the table from the broadcast variable only; referencing the
    # driver-side dict here would ship it inside the UDF closure.
    local_vec_dict = vec_broadcast.value

    # Build a new list rather than extending the input in place.
    padded_tokens = list(tokens) + ['X'] * max(0, 128 - len(tokens))

    return [local_vec_dict[token] for token in padded_tokens if token in local_vec_dict]

In [None]:
# Add the per-sequence embedding matrix as an "embeddings" column, then preview.
tokens_df = tokens_df.withColumn("embeddings", embed_sequence(col("tokens")))
tokens_df.take(3)

In [None]:
def first_qr(embedding):
    """Return the R factor of the QR decomposition of ``embedding``.

    The Q factor is computed by ``np.linalg.qr`` but discarded.
    """
    _, upper = np.linalg.qr(embedding)
    return upper
    
# TODO: collect the per-partition R factors and stack them row-wise
# (e.g. append them at the end), possibly via a reduce with a stacking function.

def second_qr(R):
    """Return the R factor of the QR decomposition of ``R``.

    Same computation as a plain ``np.linalg.qr`` call with Q discarded.
    """
    factors = np.linalg.qr(R)
    return factors[1]
# At this point the result should be small enough to simply collect (unverified).

def funcky(tokens):
    # NOTE(review): unfinished stub. `tokens` is never used; the body reads the
    # module-level name `embedding` instead and discards first_qr's result, so
    # this function always returns None. Confirm the intended behavior before
    # relying on it.
    
    first_qr(embedding)

def stack_them(r_upper, r_lower):
    """Stack two R factors row-wise so they can be factorized together."""
    return np.vstack((r_upper, r_lower))


# NOTE(review): the draft of this cell could not run. It referenced undefined
# names (`small_df`, `stack_them`), called methods that do not exist
# (`Column.repartition`, `reshufle`, `list.svd`) and treated the value returned
# by `reduce` as an RDD. The steps below implement the two-stage QR the draft
# describes - confirm this is the intended computation.

# Column produced earlier by embed_sequence on tokens_df.
embedding = col("embeddings")

# One 2-D float matrix per sequence; rows with no embedded tokens are dropped
# because they cannot be stacked with the others.
embedding_rdd = (
    tokens_df.select(embedding).rdd
    .map(lambda row: np.asarray(row[0], dtype=float))
    .filter(lambda matrix: matrix.ndim == 2)
)

# Stage 1: QR of every matrix, then stack all R factors on the driver.
# `reduce` returns a plain ndarray, so despite its name r_rdd is not an RDD.
r_rdd = embedding_rdd.map(first_qr).reduce(stack_them)

# Stage 2: QR of the stacked R factors gives the global R.
r_global = second_qr(r_rdd)

# SVD of the small global R factor.
U, eps, V_T = np.linalg.svd(r_global, full_matrices=False)