In [69]:
# !pip install biopython
# !pip install findspark
# !pip install py4j

## Import Lib

In [47]:
from pyspark.sql import *
from pyspark import *
from pyspark.accumulators import AccumulatorParam
from Bio import SeqIO
import networkx as nx

In [48]:
import logging
import numpy as np
import random
import json
import re
import itertools as it
import gc
import sys
import findspark
findspark.init()
from collections import namedtuple
date_strftime_format = '%Y-%m-%y %H:%M:%S'
logging.basicConfig(stream=sys.stdout, level=logging.INFO, format="%(asctime)s %(message)s", datefmt=date_strftime_format)

In [49]:
from pyspark.sql.window import Window
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.dataframe import Column
import random

## Define Parameter

In [50]:
L_MER = 20
BUILD_GRAPH = 'SPARK_MAP' # 'ORIGINAL'
LABEL_GRAPH = 'LABEL_PREGEL' # 'LABEL_PREGEL'
PREGEL = 'NAIVE' # 'CENTRAL_DEGREE' 'NAIVE' 'LPA' 
COMPONENT = 30

### Log file

In [51]:
class Log4j:
    def __init__(self, spark):
        root_class = "guru.learningjournal.spark.examples"
        conf = spark.sparkContext.getConf()
        app_name = conf.get("spark.app.name")
        log4j = spark._jvm.org.apache.log4j
        self.logger = log4j.LogManager.getLogger(root_class + '.' + app_name)
    
    def warn(self, message):
        self.logger.warn(message)
        logging.warn(message)
        
    def info(self, message):
        self.logger.info(message)
        logging.info(message)
    
    def error(self, message):
        self.logger.error(message)
        logging.error(message)
    
    def debug(self, message):
        self.logger.debug(message)
        logging.debug(message)

## Initial SparkSession

In [52]:
spark = SparkSession.builder.appName('Hello Spark').master('local[*]').getOrCreate()
sc = spark.sparkContext
sc.addPyFile('/Users/DELL/.ivy2/jars/graphframes_graphframes-0.8.2-spark3.1-s_2.12.jar')
logger = Log4j(spark)
from graphframes import *


In [53]:
from graphframes import GraphFrame
from graphframes.lib import Pregel

## Ultility Function

In [54]:
def load_meta_reads(filename, type='fasta'):
    try:
        seqs = list(SeqIO.parse(filename, type))

        reads = []
        labels = []

        # Detect for paired-end or single-end reads
        # If the id of two first reads are different (e.g.: .1 and .2), they are paired-end reads
        is_paired_end = False
        if len(seqs) > 2 and seqs[0].id[-1:] != seqs[1].id[-1:]:
            is_paired_end = True

        label_list = dict()
        label_index = 0
        for i in range(0, len(seqs), 2 if is_paired_end else 1):
            read, label = format_read(seqs[i])
            if is_paired_end:
                read2, label2 = format_read(seqs[i + 1])
                read += read2
            reads += [str(read)]
        
            # Create labels
            if label not in label_list:
                label_list[label] = label_index
                label_index += 1
            labels.append(label_list[label])
        
        del seqs
        return reads, labels
    except:
        print('Error when loading file {} '.format(filename))
        return []

In [55]:
def format_read(read):
    # Return sequence and label
    z = re.split('[|={,]+', read.description)
    return read.seq, z[3]

In [56]:
reads, labels = load_meta_reads('data/S1.fna', type='fasta')

In [57]:
class DictParam(AccumulatorParam):
    def zero(self,  value = ""):
        return dict()

    def addInPlace(self, value1, value2):
        for i in value2.keys():
            if i in value1:
                value1[i].append(value2[i])
            else:
                value1[i] = [value2[i]]
        return value1

In [58]:
class DictEdgeParam(AccumulatorParam):
    def zero(self,  value = ""):
        return dict()

    def addInPlace(self, value1, value2):
        for i in value2.keys():
            if i in value1:
                value1[i] += value2[i]
            else:
                value1[i] = value2[i]
        return value1

In [59]:
def build_dict_origin():
    logging.info('Start 1')
    lmers_dict = dict()
    for idx, r in enumerate(reads):
        for j in range(0,len(r)-L_MER+1):
            lmer = r[j:j+L_MER]
            if lmer in lmers_dict:
                lmers_dict[lmer] += [idx]
            else:
                lmers_dict[lmer] = [idx]
    E=dict()
    for lmer in lmers_dict:
        for e in it.combinations(lmers_dict[lmer],2):
            if e[0]!=e[1]:
                e_curr=(e[0],e[1])
                if e_curr in E:
                    E[e_curr] += 1 # Number of connected lines between read a and b
                else:
                    E[e_curr] = 1
    E_Filtered = {kv[0]: kv[1] for kv in E.items() if kv[1] >= 20}
    
    G = nx.Graph()
    print('Adding nodes...')
    color_map = {0: 'red', 1: 'green', 2: 'blue', 3: 'yellow', 4: 'darkcyan', 5: 'violet',
                6: 'black', 7: 'grey', 8: 'sienna', 9: 'wheat', 10: 'olive', 11: 'lightgreen',
                12: 'cyan', 13: 'slategray', 14: 'navy', 15: 'hotpink'}
    for i in range(0, len(labels)):
        G.add_node(i, label=labels[i], color=color_map[labels[i]])

    print('Adding edges...')
    for kv in E_Filtered.items():
        G.add_edge(kv[0][0], kv[0][1], weight=kv[1])
    print('Graph constructed!')
    logging.info('End 1')
    return G
#     print(E_Filtered[(0, 29033)])
#     print(len(E_Filtered.keys()))
#     print(lmers_dict[ATAAATACCTTCATTTAATA])


In [60]:
def build_dict_spark_map(readsRDD, spark):
    logging.info('Start 2')
    def create_lmers_pos(tuple):
        idx, r = tuple
        lmers_dict =list()
        for j in range(0,len(r)-L_MER+1):
            lmer = r[j:j+L_MER]
            lmers_dict.append((lmer, idx))
#         print(lmers_dict)
        return lmers_dict
    def create_edge(x):
        lmer, idx = x
#         print(lmer, idx)
        global edge_dict
        E=dict()
        for e in it.combinations(idx,2):
            if e[0]!=e[1]:
                e_curr=(e[0],e[1])
                if e_curr in E:
                    E[e_curr] += 1 # Number of connected lines between read a and b
                else:
                    E[e_curr] = 1
        edge_dict += E
    logging.info("Building hash table...")
    lmers_dict = readsRDD.map(create_lmers_pos).flatMap(lambda x: [i for i in x]).groupByKey().mapValues(list).filter(lambda x: len(x[1]) > 2)
    logging.info('Build edge ...')
    lmers_dict.coalesce(20).foreach(create_edge)
    # Step 1: map
        # ['asdasd': 2, 'asdasdgg': 3, 'asdasd': 4]; ['asdasd': 2, 'asdasdgg': 3, 'asdasd': 4]
    # Step 3: flat
        # ['asdasd': [2, 4, 90000], 'asdasdgg': 3]
    #     res = readsRDD.map(create_lmers_pos).flatMap(lambda x: [i for i in x]).reduceByKey(lambda x,y: x.append(y)).count()
    global edge_dict
    E = edge_dict.value
    E_Filtered = {kv[0]: kv[1] for kv in E.items() if kv[1] >= 20}
    color_map = {0: 'red', 1: 'green', 2: 'blue', 3: 'yellow', 4: 'darkcyan', 5: 'violet',
            6: 'black', 7: 'grey', 8: 'sienna', 9: 'wheat', 10: 'olive', 11: 'lightgreen',
            12: 'cyan', 13: 'slategray', 14: 'navy', 15: 'hotpink'}
    logging.info('Add nodes ...')
    vertices = spark.createDataFrame([(i, labels[i], color_map[labels[i]]) for i in range(0, len(labels))], ['id', 'colorId', 'color'])
    logging.info('Add edges ...')
    edges_data = []
    for kv in E_Filtered.items():
        edges_data += [(kv[0][0], kv[0][1], kv[1])]
        edges_data += [(kv[0][1], kv[0][0], kv[1])]
    edges = spark.createDataFrame(edges_data, ['src', 'dst', 'numOfLmers'])
    vertices = vertices.persist()
    edges = edges.persist()
    logging.info('Building graph ...')
    g = GraphFrame(vertices, edges)
    return g
    logging.info('End 2')
#     return res


In [61]:
# def build_overlap_graph( reads, spark ):

edge_dict = spark.sparkContext.accumulator({}, DictEdgeParam())

readsRDD = spark.sparkContext.parallelize(enumerate(reads)).repartition(40).cache()

G = build_dict_spark_map(readsRDD,spark) if BUILD_GRAPH == 'SPARK_MAP' else build_dict_origin()

print(G.vertices.count())
print(G.edges.count())


2022-12-22 15:38:11 Start 2
2022-12-22 15:38:11 Building hash table...
2022-12-22 15:38:11 Build edge ...
2022-12-22 15:40:31 Add nodes ...
2022-12-22 15:40:33 Add edges ...
2022-12-22 15:40:50 Building graph ...
96367
703416


In [62]:
edge_dict = spark.sparkContext.accumulator({}, DictEdgeParam())

In [63]:
# edge_dict.aid()
print(edge_dict.value)

{}


In [64]:
LB = G.labelPropagation(50)

In [68]:
GL = LB.groupBy('label').agg(collect_list('id').alias('group')).select('group').collect()

In [69]:
GROUP = [gl[0] for gl in GL]

In [71]:
print(len(GROUP))

8000


In [72]:
with open('sample.txt', 'w') as filehandle:
    for listitem in GROUP:
        filehandle.write('%s\n' % listitem)

In [17]:
logging.info('Connected Components Algorithm ...')
sc.setCheckpointDir("/tmp/graphframes-example-connected-components")
CC = G.connectedComponents()

2022-11-22 12:37:37 Connected Components Algorithm ...


In [18]:
CC.groupBy('component').count().orderBy('count', ascending=False).show()

+---------+-----+
|component|count|
+---------+-----+
|       47| 1600|
|       35| 1183|
|      111|  944|
|       48|  663|
|       40|  660|
|       58|  631|
|       16|  587|
|        5|  572|
|       54|  564|
|       14|  471|
|       24|  431|
|      461|  425|
|      141|  423|
|       70|  417|
|      152|  412|
|       25|  392|
|      125|  389|
|       23|  388|
|        0|  379|
|       30|  371|
+---------+-----+
only showing top 20 rows



In [19]:
vertices = CC.filter(col('component') == COMPONENT).drop('component')

if LABEL_GRAPH == 'LABEL_PREGEL':
    num_vertices = vertices.count()
    if PREGEL == 'CENTRAL_DEGREE':
        out_degrees = G.outDegrees
        vertices = vertices.join(out_degrees, ['id'], 'left').orderBy('outDegree', ascending=False).withColumn('row1', row_number().over(Window.orderBy(col("outDegree").desc())))
        initialVertice = spark.createDataFrame([[sample] for sample in range(1, int(num_vertices/10) + 1)], ['initialID'])
        vertices = vertices.join(initialVertice, vertices['row1'] == initialVertice['initialID'], 'left').drop('row1')
    if PREGEL == 'NAIVE':
        vertices = vertices.withColumn('row1', row_number().over(Window.orderBy("id")))
        initialVertice = spark.createDataFrame([[sample] for sample in random.sample(range(1, num_vertices), int(num_vertices/50) + 1)], ['initialVertice']).withColumn('initialId', row_number().over(Window.orderBy("initialVertice")))
        vertices = vertices.join(initialVertice, vertices['row1'] == initialVertice['initialVertice'], 'left').drop('row1','initialVertice')
    vertices.cache()
    vertices.orderBy('initialID', ascending=False).show()

vertices.toPandas().to_json('vertices.json', orient='records')
    
edges = G.edges.join(vertices,G.edges.src ==  vertices.id,"leftsemi")
edges.cache()
edges.toPandas().to_json('edges.json', orient='records')

subGraph = GraphFrame(vertices, edges)
# if LABEL_GRAPH == 'LABEL_SPARK_FUNCTION':
    

+-----+-------+-----+---------+
|   id|colorId|color|initialId|
+-----+-------+-----+---------+
|40795|      0|  red|        8|
|27321|      0|  red|        7|
|23531|      0|  red|        6|
|16852|      0|  red|        5|
|13484|      0|  red|        4|
|11086|      0|  red|        3|
| 7323|      0|  red|        2|
| 4385|      0|  red|        1|
|   30|      0|  red|     null|
|   31|      0|  red|     null|
|  679|      0|  red|     null|
| 1156|      0|  red|     null|
|  541|      0|  red|     null|
|  923|      0|  red|     null|
|  957|      0|  red|     null|
| 1277|      0|  red|     null|
| 1294|      0|  red|     null|
| 1769|      0|  red|     null|
| 1110|      0|  red|     null|
| 1588|      0|  red|     null|
+-----+-------+-----+---------+
only showing top 20 rows



In [20]:
edges.show()
print(edges.count())
vertices.show()
print(vertices.count())

+-----+-----+----------+
|  src|  dst|numOfLmers|
+-----+-----+----------+
|11913|   30|        34|
|   30|11913|        34|
|11913|14452|        58|
|14452|11913|        58|
|11913| 5044|        25|
| 5044|11913|        25|
|   30|14452|        38|
|14452|   30|        38|
|   30| 5044|        54|
| 5044|   30|        54|
|14452| 5044|        29|
| 5044|14452|        29|
|11913|30305|        56|
|30305|11913|        56|
|11913|34464|        34|
|34464|11913|        34|
|   30|30305|        29|
|30305|   30|        29|
|   30|34464|        21|
|34464|   30|        21|
+-----+-----+----------+
only showing top 20 rows

2572
+----+-------+-----+---------+
|  id|colorId|color|initialId|
+----+-------+-----+---------+
|  30|      0|  red|     null|
|  31|      0|  red|     null|
| 541|      0|  red|     null|
| 679|      0|  red|     null|
| 923|      0|  red|     null|
| 957|      0|  red|     null|
|1110|      0|  red|     null|
|1156|      0|  red|     null|
|1277|      0|  red|     nul

## Pregel Function

In [21]:
def vertexProgram(msg, id):
    print('Msg', msg, 'ID', id, sep=', ')
    return msg
vertexProgramUdf = udf(vertexProgram)

def sendMsgToDst(srcID, dstID, src, dst):
    print('Send', 'src', src, 'dst', dst, srcID, dstID, sep = ', ')
    if srcID == None or dstID != None:
        return None
    else:
        return srcID
sendMsgToDstUdf = udf(sendMsgToDst)

def most_frequent(List):
    dict = {}
    count, itm = 0, ''
    for item in reversed(List):
        dict[item] = dict.get(item, 0) + 1
        if dict[item] >= count :
            count, itm = dict[item], item
    return itm

def aggMsgs(msg, id):
    print('Agg', id, msg,  sep = ', ')

    res = most_frequent(msg)
    dictionary = {
        "id": id,
        "label": res,
    }

    json_object = json.dumps(dictionary, indent=4)

    with open("sample.json", "a") as outfile:
        outfile.write(json_object)
        outfile.write(',\n')
    return res
aggMsgsUdf = udf(aggMsgs)

In [24]:
if LABEL_GRAPH == 'LABEL_SPARK_FUNCTION':
    LB = subGraph.labelPropagation(50)#
    LB.drop('colorId', 'color').toPandas().to_json('lb_function.json', orient='records')
if LABEL_GRAPH == 'LABEL_PREGEL':
    with open("sample.json", "w") as outfile:
        outfile.write('[')
    
    sc.setCheckpointDir("/tmp/graphframes-example-connected-components")
    LB = subGraph.pregel \
        .withVertexColumn("partitionID", col('initialID'), coalesce(Pregel.msg(), col('partitionID'))) \
        .sendMsgToDst(when(Pregel.dst('partitionID').isNull() & Pregel.src('partitionID').isNotNull(), Pregel.src('partitionID'))) \
        .aggMsgs(aggMsgsUdf(collect_list(Pregel.msg()),col('id')))  \
        .run()
    
    with open("sample.json", "a") as outfile:
        outfile.write(']')
#.setMaxIter(10) \
# .sendMsgToDst(sendMsgToDstUdf(Pregel.src('partitionID'),Pregel.dst('partitionID'),Pregel.src('id'),Pregel.dst('id'))) \

In [24]:
# sc.stop()

## Test labelPropagation

In [25]:
count_group = LB.groupBy('partitionID').count().orderBy('count', ascending=False)
count_group.show()
print(count_group.count())

+-----------+-----+
|partitionID|count|
+-----------+-----+
|          3|   81|
|          5|   67|
|          1|   52|
|          6|   49|
|          4|   43|
|          8|   40|
|          2|   21|
|          7|   18|
+-----------+-----+

8


In [26]:
LB.show()#.drop('color', 'colorId').toPandas().to_json('label.json', orient='records')

+-----+-------+-----+-----+
|   id|colorId|color|label|
+-----+-------+-----+-----+
| 3091|      0|  red|10912|
|  541|      0|  red| 2712|
| 1277|      0|  red|22217|
|25759|      0|  red|22217|
|31111|      0|  red|29179|
| 2030|      0|  red| 2030|
| 3452|      0|  red|29292|
|36890|      0|  red|41458|
|11638|      0|  red|29179|
|31212|      0|  red|35462|
|12860|      0|  red|15470|
|12889|      0|  red| 2712|
|14043|      0|  red|36563|
|33371|      0|  red|30097|
|11913|      0|  red| 6310|
|17209|      0|  red|36563|
|20854|      0|  red|29292|
| 2241|      0|  red| 6310|
|23358|      0|  red| 4635|
|10967|      0|  red|10975|
+-----+-------+-----+-----+
only showing top 20 rows



In [None]:
LB.orderBy('initialID', ascending=False).show()

In [None]:
LB.groupby('partitionID').count().sort(desc("partitionID")).show()

In [None]:
# print(type(lmers_dict_3))

In [None]:
LB.groupBy('label').count().orderBy('count', ascending=False).show()
LB.filter(col('label') == '34101').show()
LB.select(countDistinct("label")).show()
CC.groupBy('component').count().orderBy('count', ascending=False).write.csv("componnent")

In [None]:
# print(dict_test['ATAAATACCTTCATTTAATA'])

In [None]:
# if __name__ == '__main__':
# #     sg = spark.createDataFrame
#     edge_dict = spark.sparkContext.accumulator({}, DictEdgeParam())
# #     lmers_dict_3 = sc.accumulator({}, DictParam())
# #     lmers_dict_4 = sc.accumulator({}, DictParam())
# #     lmers_dict = sc.accumulator({}, DictParam())
# #     gc.collect()
#     logger.info('Start')
# #     dict_test = build_overlap_graph(reads, sc)
#     build_overlap_graph(reads, spark)
#     logger.info('End')

In [None]:
def build_dict_spark_foreach(readsRDD):
    logging.info('Start 3')
    def create_dict_foreach(tuple):
        idx, r = tuple
        global lmers_dict_3
        for j in range(0,len(r)-L_MER+1):
            lmer = r[j:j+L_MER]
            lmers_dict_3 += {lmer: idx}
    readsRDD.foreach(create_dict_foreach)
    global lmers_dict_3
    res = lmers_dict_3.value
#     print(res['ATAATTGGCAAGTGTTTTAG'])
    print(len(res.keys()))
    logging.info('End 3')
    return res

In [None]:
def build_dict_spark_mapPartition(readsRDD):
    logging.info('Start 4')
    def create_dict_mapPartition(partitionData):
        lmers_dict = dict()
        for idx, r in [*partitionData]:
            for j in range(0,len(r)-L_MER+1):
                lmer = r[j:j+L_MER]
                if lmer in lmers_dict:
                    lmers_dict[lmer] += [idx]
                else:
                    lmers_dict[lmer] = [idx]
        yield lmers_dict
    def merge_dict(x,y):
        for i in y.keys():
            if i in x:
                x[i] += y[i] 
            else:
                x[i] = y[i]
        return x

    lmers_dict = readsRDD.mapPartitions(create_dict_mapPartition).reduce(lambda x, y: merge_dict(x,y))
    logging.warning('Processing 1')
    E=dict()
    for lmer in lmers_dict:
        for e in it.combinations(lmers_dict[lmer],2):
            if e[0]!=e[1]:
                e_curr=(e[0],e[1])
            if e_curr in E:
                E[e_curr] += 1
            else:
                E[e_curr] = 1
    E_Filtered = {kv[0]: kv[1] for kv in E.items() if kv[1] >= 20}
    
#     print(E_Filtered[(0, 29033)])
    print(len(E_Filtered.keys()))
    
#     logging.warning('Processing 2')
#     res = readsRDD.mapPartitions(create_dict_mapPartition).collect()
#     count = 0
#     for dict1 in res:
#         count += len(dict1.keys())
#     print(count)
#     logging.warning('End Processing 2')
#     def create_edge(dictionary):
#         E  = dict()
#         for lmer in dictionary:
#             for e in it.combinations(dictionary[lmer],2):
#                 if e[0]!=e[1]:
#                     e_curr=(e[0],e[1])
#                 if e_curr in E:
#                     E[e_curr] += 1 # Number of connected lines between read a and b
#                 else:
#                     E[e_curr] = 1
#         E_Filtered = {kv[0]: kv[1] for kv in E.items() if kv[1] >= 20}
#         return E_Filtered
# #     first = readsRDD.mapPartitions(create_dict_mapPartition).map(create_edge).collect()
# #     print(first)
#     res = readsRDD.mapPartitions(create_dict_mapPartition).map(create_edge).collect()
    
#     count = 0
#     for edge in res:
#         count += len(edge.keys())
#     print(count)
# #     print(res[(49008, 56213)])
    logging.info('End 4')
#     return res?