In [1]:
%load_ext autoreload
%autoreload 2

In [43]:
import argparse
import gzip
import os
import sys
import time
import io
# from multiprocessing import Pool
import multiprocessing

import numpy as np
import pandas as pd
from Bio.PDB import PDBParser
from tqdm.auto import tqdm
import utils.gcs_utils
import utils.spark_utils as sprk
from utils.proteins import *

In [152]:
tqdm.pandas()

In [3]:
import numpy as np
from bio_embeddings.embed import SeqVecEmbedder, ProtTransBertBFDEmbedder  

In [4]:
local_dir = "/Users/skyler.roh/Downloads/UP000005640_9606_HUMAN"
files = sorted([f"{local_dir}/{f}" for f in os.listdir(local_dir) if f.endswith(".cif.gz")])
fn = get_protein_sequence_from_cif

In [71]:
series = []
for file in tqdm(files):
    with gzip.open(file, 'rt') as f:
        content = f.read()
        series.append(get_protein_sequence_from_cif(content))
pd.concat(series)

  0%|          | 0/23391 [00:00<?, ?it/s]

Unnamed: 0,db_code,db_name,entity_id,id,pdbx_align_begin,pdbx_db_accession,pdbx_db_isoform,pdbx_seq_one_letter_code
0,A0A024R1R8_HUMAN,UNP,1,1,1,A0A024R1R8,,MSSHEGGKKKALKQPKKQAKEMDEEEKAFKQKQKEEQKKLEVLKAK...
0,NUD4B_HUMAN,UNP,1,1,1,A0A024RBG1,,MMKFKPNQTRTYDREGFKKRAACLCFRSEQEDEVLLVSSSRYPDQW...
0,A0A024RCN7_HUMAN,UNP,1,1,1,A0A024RCN7,,MERSFVWLSCLDSDSCNLTFRLGEVESHACSPSLLWNLLTQYLPPG...
0,A0A075B6H5_HUMAN,UNP,1,1,1,A0A075B6H5,,METVVTTLPREGGVGPSRKMLLLLLLLGPGSGLSAVVSQHPSRVIC...
0,KV37_HUMAN,UNP,1,1,1,A0A075B6H7,,MEAPAQLLFLLLLWLPDTTREIVMTQSPPTLSLSPGERVTLSCRAS...
...,...,...,...,...,...,...,...,...
0,A3LT2_HUMAN,UNP,1,1,1,U3KPV4,,MALKEGLRAWKRIFWRQILLTLGLLGLFLYGLPKFRHLEALIPMGV...
0,V9GZ13_HUMAN,UNP,1,1,1,V9GZ13,,MKNTSWIRKNWLLVAGISFIGVHLGTYFLQRSAKQSVKFQSQSKQK...
0,SACA6_HUMAN,UNP,1,1,1,W5XKT8,,MALLALASAVPSALLALAVFRVPAWACLLCFTTYSERLRICQMFVG...
0,PYDC5_HUMAN,UNP,1,1,1,W6CW81,,MESKYKEILLLTSLDNITDEELDRFKCFLPDEFNIATGKLHTLNST...


In [154]:
sequences_df = pd.concat(series)

In [178]:
sequences_df.to_pickle("~/mids/pss/vectorize/sequences.pkl")

In [156]:
sequences_df.head()

(23391, 8)

## Embedding the sequences with SeqVec and ProtTrans
Distribute with spark

In [163]:
from pyspark.sql.types import *
import pyspark.sql.functions as F
import dask.dataframe as dd
from dask.diagnostics import ProgressBar
from dask.multiprocessing import get as dget
import time

In [74]:
sparkSession = sprk.get_local_spark_session()

21/09/27 00:32:36 WARN SQLConf: The SQL config 'spark.sql.execution.arrow.enabled' has been deprecated in Spark v3.0 and may be removed in the future. Use 'spark.sql.execution.arrow.pyspark.enabled' instead of it.


In [75]:
sequences_df.dtypes

db_code                     object
db_name                     object
entity_id                   object
id                          object
pdbx_align_begin            object
pdbx_db_accession           object
pdbx_db_isoform             object
pdbx_seq_one_letter_code    object
dtype: object

In [76]:
sequences_spark_df = sparkSession.createDataFrame(
    sequences_df[["db_code", "pdbx_seq_one_letter_code"]]
)

In [78]:
sequences_spark_df.schema

StructType(List(StructField(db_code,StringType,true),StructField(pdbx_seq_one_letter_code,StringType,true)))

In [116]:
sequences_dask = dd.from_pandas(sequences_df, 30)

In [190]:
def seqvectorize(x, seqvec): 
    if isinstance(x, str):
        return seqvec.reduce_per_protein(seqvec.embed(x))
    else:
        return [seqvec.reduce_per_protein(emb) for emb in seqvec.embed_many(x)]
# seqvec_udf = F.udf(seqvectorize, ArrayType(ArrayType(FloatType())))

In [191]:
print("initializing model")
seqvec = SeqVecEmbedder()
vectors = []
batch_size = 64

print(f"creating embedding vectors, batch size {batch_size}")
for i in tqdm(range(0, sequences_df.shape[0], batch_size)):
    start = time.time()
    embs = seqvectorize(list(sequences_df.pdbx_seq_one_letter_code.iloc[i:i+batch_size]), seqvec)
    vectors.extend(embs)
    print(f"{i} to {i+batch_size-1} finished in {time.time()-start//1}")
    

initializing model
creating embedding vectors, batch size 64


  0%|          | 0/366 [00:00<?, ?it/s]

0 to 63 finished in 723.2939350605011


KeyboardInterrupt: 

In [174]:
print(vectors[0])

[-0.00934542  0.03230684 -0.18556052 ... -0.14327957  0.19794334
  0.12167553]


In [167]:
SeqVecEmbedder().embed_many(list(sequences_df.pdbx_seq_one_letter_code.iloc[0:64]))

['MSSHEGGKKKALKQPKKQAKEMDEEEKAFKQKQKEEQKKLEVLKAKVVGKGPLATGGIKKSGKK',
 'MMKFKPNQTRTYDREGFKKRAACLCFRSEQEDEVLLVSSSRYPDQWIVPGGGMEPEEEPGGAAVREVYEEAGVKGKLGRL\nLGIFEQNQDRKHRTYVYVLTVTEILEDWEDSVNIGRKREWFKVEDAIKVLQCHKPVHAEYLEKLKLGCSPANGNSTVPSL\nPDNNALFVTAAQTSGLPSSVR',
 'MERSFVWLSCLDSDSCNLTFRLGEVESHACSPSLLWNLLTQYLPPGAGHILRTYNFPVLSCVSSCHLIGGKMPEN',
 'METVVTTLPREGGVGPSRKMLLLLLLLGPGSGLSAVVSQHPSRVICKSGTSVNIECRSLDFQATTMFWYRQLRKQSLMLM\nATSNEGSEVTYEQGVKKDKFPINHPNLTFSALTVTSAHPEDSSFYICSAR',
 'MEAPAQLLFLLLLWLPDTTREIVMTQSPPTLSLSPGERVTLSCRASQSVSSSYLTWYQQKPGQAPRLLIYGASTRATSIP\nARFSGSGSGTDFTLTISSLQPEDFAVYYCQQDYNLP',
 'MDMRVPAQLLGLLLLWLPGVRFDIQMTQSPSFLSASVGDRVSIICWASEGISSNLAWYLQKPGKSPKLFLYDAKDLHPGV\nSSRFSGRGSGTDFTLTIISLKPEDFAAYYCKQDFSYP',
 'MAWTPLLFLTLLLHCTGSLSQLVLTQSPSASASLGASVKLTCTLSSGHSSYAIAWHQQQPEKGPRYLMKLNSDGSHSKGD\nGIPDRFSGSSSGAERYLTISSLQSEDEADYYCQTWGTGI',
 'MSVPTMAWMMLLLGLLAYGSGVDSQTVVTQEPSFSVSPGGTVTLTCGLSSGSVSTSYYPSWYQQTPGQAPRTLIYSTNTR\nSSGVPDRFSGSILGNKAALTITGAQADDESDYYCVLYMGSGI',
 'MAWTPLLLLFPLLLHCTGSL

In [160]:
sequences_df["seqvec"] = sequences_df.pdbx_seq_one_letter_code.progress_apply(seqvectorize)

  0%|          | 0/23391 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [183]:
# with ProgressBar():
#     res = sequences_dask \
#         .map_partitions(lambda df: df.assign(seqvec=df.pdbx_seq_one_letter_code.apply(seqvectorize)), meta={"pdbx_seq_one_letter_code": str}) \
#         .compute(scheduler='threads')

In [None]:
res.to_parquet('seqvec.parquet')

In [None]:
# sequences_spark_df \
#     .repartition(1000, "db_code") \
#     .withColumn("seqvec", seqvec_udf(F.col("pdbx_seq_one_letter_code"))).show(10)