In [1]:
# !pip install biopython
# !pip install findspark
# !pip install py4j

In [2]:
from pyspark.sql import *
from pyspark import *
from pyspark.accumulators import AccumulatorParam
from Bio import SeqIO
import networkx as nx

In [3]:
import logging
import re
import itertools as it
import gc
import sys
import findspark
findspark.init()
from collections import namedtuple
# Timestamp format for log lines: year-month-day hour:minute:second.
# '%d' is the day of month (the previous '%y' printed the two-digit year twice).
date_strftime_format = '%Y-%m-%d %H:%M:%S'
logging.basicConfig(stream=sys.stdout, level=logging.INFO, format="%(asctime)s %(message)s", datefmt=date_strftime_format)

In [4]:
class Log4j:
    """Send each message both to Spark's JVM log4j logger and to Python ``logging``.

    The log4j logger is named ``<root_class>.<spark.app.name>``; every call is
    also echoed through Python's root ``logging`` logger at the same level.
    """

    def __init__(self, spark, root_class="guru.learningjournal.spark.examples"):
        """Bind to the log4j ``LogManager`` behind an active SparkSession.

        Args:
            spark: active SparkSession; its JVM gateway (``spark._jvm``) and
                ``spark.app.name`` configuration are used.
            root_class: prefix of the log4j logger name. The default keeps the
                previously hard-coded value.
        """
        conf = spark.sparkContext.getConf()
        app_name = conf.get("spark.app.name")
        log4j = spark._jvm.org.apache.log4j
        self.logger = log4j.LogManager.getLogger(root_class + '.' + app_name)

    def warn(self, message):
        """Log ``message`` at warning level."""
        self.logger.warn(message)  # log4j's Java method is named `warn`
        # logging.warn is a deprecated alias; logging.warning is the supported API.
        logging.warning(message)

    def info(self, message):
        """Log ``message`` at info level."""
        self.logger.info(message)
        logging.info(message)

    def error(self, message):
        """Log ``message`` at error level."""
        self.logger.error(message)
        logging.error(message)

    def debug(self, message):
        """Log ``message`` at debug level."""
        self.logger.debug(message)
        logging.debug(message)

In [5]:
def format_read(read):
    """Split a sequence record into ``(sequence, label)``.

    The record's ``description`` is tokenised on runs of ``|``, ``=``, ``{``
    and ``,``; the fourth token is returned as the label. Raises IndexError
    when the description yields fewer than four tokens.
    """
    label = re.split('[|={,]+', read.description)[3]
    return read.seq, label

In [30]:
def load_meta_reads(filename, type='fasta'):
    """Load reads and integer class labels from a sequence file.

    Records are parsed with ``SeqIO.parse(filename, type)``. If the ids of the
    first two records end in different characters (e.g. ``.1`` / ``.2``), the
    file is treated as paired-end and each consecutive pair of records is
    concatenated into a single read.

    Labels come from ``format_read``. Each distinct label string is mapped to
    an integer id in order of first appearance and is printed the first time
    it is seen.

    Args:
        filename: path of the sequence file.
        type: format name passed to ``SeqIO.parse`` (default ``'fasta'``).

    Returns:
        ``(reads, labels)``: a list of read strings and a parallel list of
        integer label ids. On any error the problem is printed and
        ``([], [])`` is returned, so callers can always unpack two values.
    """
    try:
        seqs = list(SeqIO.parse(filename, type))

        # Paired-end heuristic: the first two ids differ in their last character.
        is_paired_end = len(seqs) >= 2 and seqs[0].id[-1:] != seqs[1].id[-1:]
        if is_paired_end and len(seqs) % 2:
            raise ValueError('paired-end file has an odd number of records ({})'.format(len(seqs)))

        reads = []
        labels = []
        label_list = dict()
        step = 2 if is_paired_end else 1
        for i in range(0, len(seqs), step):
            read, label = format_read(seqs[i])
            if is_paired_end:
                # The mate is appended to the first record; only the first record's label is used.
                read2, _ = format_read(seqs[i + 1])
                read += read2
            reads.append(str(read))

            if label not in label_list:
                print(label)
                label_list[label] = len(label_list)
            labels.append(label_list[label])

        return reads, labels
    except Exception as err:
        print('Error when loading file {}: {}'.format(filename, err))
        return [], []

In [32]:
# Load every read of sample S1 together with its integer class label.
reads, labels = load_meta_reads(filename='data/S1.fna', type='fasta')

325989358
344204770


In [9]:
# Length of the substrings (L-mers) cut from every read; the overlap builders
# link two reads when they share L-mers of this length.
L_MER = 20

In [10]:
class DictParam(AccumulatorParam):
    """Spark accumulator parameter that collects ``{key: [values, ...]}`` dicts.

    Two shapes of update reach ``addInPlace``. Single adds such as
    ``{lmer: read_index}`` carry scalar values. Partial results merged from
    tasks carry lists. Both are flattened into the same list per key, so merged
    values never become nested lists.
    """

    def zero(self, value=""):
        # Every accumulator copy starts from an empty dict, whatever `value` is.
        return dict()

    def addInPlace(self, value1, value2):
        """Merge ``value2`` into ``value1`` (mutated and returned)."""
        for key, item in value2.items():
            # A list comes from another accumulator's partial result; a scalar is a single add.
            items = item if isinstance(item, list) else [item]
            value1.setdefault(key, []).extend(items)
        return value1

In [11]:
class DictEdgeParam(AccumulatorParam):
    """Spark accumulator parameter that sums per-key counts across dicts.

    For keys present in both operands the counts are added. Keys that are
    only in the incoming dict are copied over unchanged. In this notebook the
    keys are ``(read_a, read_b)`` pairs and the values are shared L-mer counts.
    """

    def zero(self, value=""):
        # Accumulation always starts from an empty dict, regardless of `value`.
        return {}

    def addInPlace(self, value1, value2):
        # Mutates and returns value1, which AccumulatorParam.addInPlace allows.
        for key, count in value2.items():
            if key not in value1:
                value1[key] = count
            else:
                value1[key] += count
        return value1

In [12]:
def build_dict_origin(min_shared=20):
    """Build the read-overlap graph on the driver, without Spark.

    Every read in the notebook global ``reads`` is cut into all of its
    ``L_MER``-long substrings. Each pair of distinct reads is weighted by how
    many L-mer occurrences they share. Pairs with weight ``>= min_shared``
    become graph edges.

    Args:
        min_shared: minimum number of shared L-mers for an edge (default 20,
            the previously hard-coded threshold).

    Returns:
        ``networkx.Graph`` with one node per read. Nodes carry attributes
        ``label`` (from the global ``labels``) and ``color``; labels outside
        the 16-entry palette get ``'lightgray'`` instead of raising KeyError.
        Edges carry a ``weight`` attribute holding the shared L-mer count.
    """
    logging.info('Start 1')
    # L-mer -> indices of the reads containing it (a read appears once per occurrence).
    lmers_dict = dict()
    for idx, r in enumerate(reads):
        for j in range(0, len(r) - L_MER + 1):
            lmers_dict.setdefault(r[j:j + L_MER], []).append(idx)

    # Ids are appended in read order, so every pair comes out as (smaller, larger).
    E = dict()
    for read_ids in lmers_dict.values():
        for a, b in it.combinations(read_ids, 2):
            if a != b:
                E[(a, b)] = E.get((a, b), 0) + 1
    E_Filtered = {pair: count for pair, count in E.items() if count >= min_shared}

    G = nx.Graph()
    print('Adding nodes...')
    color_map = {0: 'red', 1: 'green', 2: 'blue', 3: 'yellow', 4: 'darkcyan', 5: 'violet',
                 6: 'black', 7: 'grey', 8: 'sienna', 9: 'wheat', 10: 'olive', 11: 'lightgreen',
                 12: 'cyan', 13: 'slategray', 14: 'navy', 15: 'hotpink'}
    for i, label in enumerate(labels):
        G.add_node(i, label=label, color=color_map.get(label, 'lightgray'))

    print('Adding edges...')
    for (a, b), weight in E_Filtered.items():
        G.add_edge(a, b, weight=weight)
    print('Graph constructed!')
    logging.info('End 1')
    return G
#     print(E_Filtered[(0, 29033)])
#     print(len(E_Filtered.keys()))
#     print(lmers_dict[ATAAATACCTTCATTTAATA])


In [14]:
def build_dict_spark_map(readsRDD, spark, min_shared=20):
    """Compute read-overlap edges with Spark transformations and show label counts.

    Each read is cut into all of its ``L_MER``-long substrings. Reads sharing
    an L-mer form candidate pairs, and each pair is weighted by the number of
    shared L-mers. The pair counts are summed with ``reduceByKey``, so no
    driver-side accumulator or pre-existing global is needed. Reads the
    notebook globals ``L_MER`` and ``labels``.

    Args:
        readsRDD: RDD of ``(read_index, sequence)`` tuples.
        spark: active SparkSession, used to build the vertex DataFrame.
        min_shared: minimum number of shared L-mers for an edge (default 20).

    Returns:
        dict mapping ``(smaller_index, larger_index)`` to the shared L-mer
        count, for pairs with a count of at least ``min_shared``.
    """
    logging.info('Start 2')

    def create_lmers_pos(indexed_read):
        # (lmer, read_index) for every L-mer start position in the read.
        idx, r = indexed_read
        return [(r[j:j + L_MER], idx) for j in range(0, len(r) - L_MER + 1)]

    def create_edges(lmer_and_ids):
        # groupByKey gives no ordering guarantee, so sort the ids to emit each
        # unordered pair under one canonical (smaller, larger) key.
        _, read_ids = lmer_and_ids
        return [((a, b), 1) for a, b in it.combinations(sorted(read_ids), 2) if a != b]

    E_Filtered = dict(
        readsRDD.flatMap(create_lmers_pos)
        .groupByKey()
        .mapValues(list)
        # Two occurrences are enough to form a pair.
        .filter(lambda kv: len(kv[1]) > 1)
        .flatMap(create_edges)
        .reduceByKey(lambda x, y: x + y)
        .filter(lambda kv: kv[1] >= min_shared)
        .collect()
    )
    print(len(E_Filtered.keys()))

    color_map = {0: 'red', 1: 'green', 2: 'blue', 3: 'yellow', 4: 'darkcyan', 5: 'violet',
                 6: 'black', 7: 'grey', 8: 'sienna', 9: 'wheat', 10: 'olive', 11: 'lightgreen',
                 12: 'cyan', 13: 'slategray', 14: 'navy', 15: 'hotpink'}
    columns = ["id", "labelId", "labelColor"]
    # Labels outside the palette get a neutral colour instead of raising KeyError.
    vertices = spark.createDataFrame(
        [(i, label, color_map.get(label, 'lightgray')) for i, label in enumerate(labels)]
    ).toDF(*columns)
    vertices.groupBy("labelId").count().show()
    logging.info('End 2')
    return E_Filtered


In [15]:
def build_dict_spark_foreach(readsRDD):
    """Build an L-mer -> read-index map with ``foreach`` and a dict accumulator.

    For every ``L_MER``-long substring of every read, the worker adds a
    ``{lmer: read_index}`` entry to an accumulator merged with ``DictParam``.
    The driver then reads the merged value. A fresh accumulator is created on
    every call, so no global accumulator has to exist beforehand and repeated
    calls do not re-add entries from earlier runs.

    Args:
        readsRDD: RDD of ``(read_index, sequence)`` tuples.

    Returns:
        The accumulator's merged dict. Its per-key value shape is whatever
        ``DictParam.addInPlace`` produces.
    """
    logging.info('Start 3')
    lmer_acc = readsRDD.context.accumulator({}, DictParam())

    def create_dict_foreach(indexed_read):
        idx, r = indexed_read
        for j in range(0, len(r) - L_MER + 1):
            # .add avoids rebinding the captured accumulator inside the closure.
            lmer_acc.add({r[j:j + L_MER]: idx})

    readsRDD.foreach(create_dict_foreach)
    res = lmer_acc.value
    print(len(res.keys()))
    logging.info('End 3')
    return res

In [16]:
def build_dict_spark_mapPartition(readsRDD, min_shared=20):
    """Build per-partition L-mer indexes, merge them on the driver and count overlaps.

    Each partition builds its own ``{lmer: [read_index, ...]}`` dict. The
    dicts are merged with ``reduce``, and the driver then counts, for every
    pair of distinct reads, how many L-mer occurrences they share. Reads the
    notebook global ``L_MER``.

    Args:
        readsRDD: RDD of ``(read_index, sequence)`` tuples.
        min_shared: minimum number of shared L-mers for an edge (default 20).

    Returns:
        dict mapping ``(smaller_index, larger_index)`` to the shared L-mer
        count, for pairs with a count of at least ``min_shared``.
    """
    logging.info('Start 4')

    def create_dict_mapPartition(partitionData):
        # One local index per partition; the partition iterator is consumed lazily.
        lmers_dict = dict()
        for idx, r in partitionData:
            for j in range(0, len(r) - L_MER + 1):
                lmers_dict.setdefault(r[j:j + L_MER], []).append(idx)
        yield lmers_dict

    def merge_dict(x, y):
        # Mutates and returns x; y's lists are copied, not aliased.
        for lmer, ids in y.items():
            x.setdefault(lmer, []).extend(ids)
        return x

    lmers_dict = readsRDD.mapPartitions(create_dict_mapPartition).reduce(merge_dict)
    logging.warning('Processing 1')

    E = dict()
    for read_ids in lmers_dict.values():
        # Merged lists follow partition order, not read order; sorting gives
        # every unordered pair a single (smaller, larger) key.
        for a, b in it.combinations(sorted(read_ids), 2):
            # Self-pairs (an L-mer repeated inside one read) are skipped entirely.
            if a != b:
                E[(a, b)] = E.get((a, b), 0) + 1
    E_Filtered = {pair: count for pair, count in E.items() if count >= min_shared}

    print(len(E_Filtered.keys()))
    logging.info('End 4')
    return E_Filtered

In [20]:
def build_overlap_graph(reads, spark, num_partitions=40):
    """Distribute the reads over Spark and build their overlap edges.

    Args:
        reads: sequence of reads (strings, as produced by ``load_meta_reads``).
            Each read is keyed by its position in this sequence.
        spark: active SparkSession.
        num_partitions: number of partitions for the reads RDD (default 40,
            the previously hard-coded value).

    Returns:
        Whatever ``build_dict_spark_map`` returns for the reads RDD.
    """
    print("Building hash table...")
    readsRDD = spark.sparkContext.parallelize(enumerate(reads)).repartition(num_partitions).cache()
    try:
        # Alternatives: build_dict_origin(), build_dict_spark_foreach(readsRDD),
        # build_dict_spark_mapPartition(readsRDD).
        return build_dict_spark_map(readsRDD, spark)
    finally:
        # The cached RDD is private to this call; release its cached blocks afterwards.
        readsRDD.unpersist()


In [33]:
if __name__ == '__main__':
    # Inside a notebook kernel __name__ is '__main__', so this cell always runs.
    spark = SparkSession.builder.appName('Hello Spark').master('local[*]').getOrCreate()
    logger = Log4j(spark)
    color_map = {0: 'red', 1: 'green', 2: 'blue', 3: 'yellow', 4: 'darkcyan', 5: 'violet',
            6: 'black', 7: 'grey', 8: 'sienna', 9: 'wheat', 10: 'olive', 11: 'lightgreen',
            12: 'cyan', 13: 'slategray', 14: 'navy', 15: 'hotpink'}
    # One vertex per read; labels outside the 16-colour palette get a neutral
    # colour instead of raising KeyError.
    vertices = spark.createDataFrame(
        [(i, label, color_map.get(label, 'lightgray')) for i, label in enumerate(labels)]
    ).toDF("id", "labelId", "labelColor")
    vertices.groupBy("labelId").count().show()
    # Print the label of every read that is followed by a read with a different label.
    for current, following in zip(labels, labels[1:]):
        if current != following:
            print(current)
#     sg = spark.createDataFrame
#     edge_dict = spark.sparkContext.accumulator({}, DictEdgeParam())
#     lmers_dict_3 = sc.accumulator({}, DictParam())
#     lmers_dict_4 = sc.accumulator({}, DictParam())
#     lmers_dict = sc.accumulator({}, DictParam())
#     gc.collect()
    logger.info('Start')
#     dict_test = build_overlap_graph(reads, sc)
#     build_overlap_graph(reads, spark)
#     gc.collect()
    logger.info('End')

+-------+-----+
|labelId|count|
+-------+-----+
|      0|44405|
|      1|51962|
+-------+-----+

0
0
2022-11-22 16:29:26 Start
2022-11-22 16:29:26 End


In [None]:
# print(type(lmers_dict_3))

In [None]:
# print(dict_test['ATAAATACCTTCATTTAATA'])

In [None]:
# sc.stop()