In [1]:
from pyspark import AccumulatorParam
from pyspark import SparkContext
from pyspark.sql import SparkSession

In [9]:
import hashlib
import numpy as np

In [10]:
#global accumulator
class MatrixAccumulatorParam(AccumulatorParam):
    def zero(self, value):
        return np.zeros(value.shape)
    def addInPlace(self,val1,val2):
        val1+= val2 
        return val1

In [11]:
#hash function   
def hash_Md5(item,depth,width):
    res = []
    md5 = hashlib.md5(str(hash(item)).encode('utf-8'))
    for i in range(depth):
        md5.update(str(i).encode('utf-8'))
        hashcode = int(md5.hexdigest(), 16) % width
        res.append([i,hashcode])
    return res

In [12]:
class CountMinSKetch():
    def __init__(self, delta, epsilon):
        
        if delta <= 0 or delta >= 1:
            raise ValueError("delta must be between 0 and 1")
        if epsilon <= 0 or epsilon >= 1:
            raise ValueError("epsilon must be between 0 and 1")
        
        self.__name__ = "CountMinSketch"     
        self.width = int(np.ceil(np.exp(1) / epsilon))
        self.depth = int(np.ceil(np.log(1 / delta)))
        self.count_table = sc.accumulator(np.zeros((self.depth, self.width)),MatrixAccumulatorParam())
        
    def getWidth(self):
        return self.width
    
    def getDepth(self):
        return self.depth
    
    def getCount(self):
        return self.count_table
    
    def increment(self, file):
        
        depth = self.depth
        width = self.width
        count_table = self.count_table
        
        def addWord(item, depth, width):
            hash_res = hash_Md5(item, depth, width)
            for el in hash_res:
                add_table = np.zeros((depth, width))
                hash_num = el[0]
                bucket = el[1]
                add_table[hash_num][bucket] = 1
                nonlocal count_table
                count_table += add_table
    
        lines = sc.textFile(file).flatMap(lambda line: line.split(" "))\
                  .foreach(lambda word: addWord(word, depth, width))
      
    
    def merge(self, cms):
        if self.__name__ != cms.__name__:
            raise Exception("Unable to merge!")
        
        if self.depth != cms.depth:
            raise Exception("Two count-min sketches need the same number of hash functions!")
            
        if self.width != cms.width:
            raise Exception("Two count-min sketches need the same number of buckets for each hash function!")
            
        self.count_table += cms.count_table.value
        
            
    def estimator(self, file):
        
        width = self.width
        depth = self.depth
        count_table = self.count_table.value
        
        def get_word_estimation(word, depth, width, count_table):
            hash_value = hash_Md5(word, depth, width)
            word_count = []
            for coordinate in hash_value:
                count = count_table[coordinate[0]][coordinate[1]]
                word_count.append([word, count])
            return word_count
        
        lines = sc.textFile(file)
        estimate = lines.flatMap(lambda line: line.split(" "))\
                        .flatMap(lambda word:get_word_estimation(word,depth,width,count_table,))\
                        .reduceByKey(min)
                        #.foreach(lambda word: get_word_estimation(word, depth, width, count_table))
        
        return estimate.collect()
        

In [13]:
# User defined input delta and epsilon here:
# the file path should be changed accordingly
delta =0.07
epsilon = 0.03
incrementFile = 'file:///usr/local/spark/README.md'

cms = CountMinSKetch(delta,epsilon) 
print("The number of hash functions are:", cms.getDepth())
print("The number of buckets for each hash function are:", cms.getWidth())
print("\n")

cms.increment(incrementFile)
print("Here is the count table after increment:")
print(cms.getCount())

The number of hash functions are: 3
The number of buckets for each hash function are: 91


Here is the count table after increment:
[[  3.   1.   4.  10.   5.   9.   1.   4.   4.   6.   4.   6.   1.   5.
    7.   6.  15.   2.   6.   2.   2.   3.   2.   0.  13.   2.   2.   3.
   15.   6.   2.  13.   4.   1.  72.   2.  15.   1.   2.   3.   3.   5.
    3.   9.   1.  10.   6.   4.   6.   2.   0.  14.   4.  10.   6.   3.
    4.   4.   2.   4.   4.   4.   6.   2.   4.  15.  22.   5.   3.   2.
    1.   6.   5.   9.   4.   5.   2.   1.  25.   2.   1.   4.   1.   2.
    6.   9.   5.   4.   8.   5.   4.]
 [  2.   4.   2.   5.  11.   5.  17.   3.   2.   7.   4.  10.   4.   1.
    5.   4.   9.   1.   3.   3.   1.   1.   4.   6.   8.   3.   6.   2.
   11.   4.   4.   5.   9.   4.   2.   3.   7.   3.   3.   2.   2.   4.
    5.   2.   1.   1.   7.   9.  10.   0.   3.   4.   0.   4.   2.  25.
    2.   4.   3.   6.   0.   6.   0.   3.   1.   4.  16.   5.   7.   8.
    1.   4.   6.   2.   1.   6.   2.  

In [14]:
print("Now we merge the same count table by loading the same file.\n The count table should be doubled")

cms2 = CountMinSKetch(delta,epsilon)
cms2.increment(incrementFile)
cms.merge(cms2)
print(cms.getCount())

Now we merge the same count table by loading the same file.
 The count table should be doubled
[[   6.    2.    8.   20.   10.   18.    2.    8.    8.   12.    8.   12.
     2.   10.   14.   12.   30.    4.   12.    4.    4.    6.    4.    0.
    26.    4.    4.    6.   30.   12.    4.   26.    8.    2.  144.    4.
    30.    2.    4.    6.    6.   10.    6.   18.    2.   20.   12.    8.
    12.    4.    0.   28.    8.   20.   12.    6.    8.    8.    4.    8.
     8.    8.   12.    4.    8.   30.   44.   10.    6.    4.    2.   12.
    10.   18.    8.   10.    4.    2.   50.    4.    2.    8.    2.    4.
    12.   18.   10.    8.   16.   10.    8.]
 [   4.    8.    4.   10.   22.   10.   34.    6.    4.   14.    8.   20.
     8.    2.   10.    8.   18.    2.    6.    6.    2.    2.    8.   12.
    16.    6.   12.    4.   22.    8.    8.   10.   18.    8.    4.    6.
    14.    6.    6.    4.    4.    8.   10.    4.    2.    2.   14.   18.
    20.    0.    6.    8.    0.    8.    4.   

In [15]:
print("Now estimate the times of appearance for each word in given test file:\n")
estimateFile = 'file:///usr/local/spark/test.txt'    
res = cms.estimator(estimateFile)
print(res)   

Now estimate the times of appearance for each word in given test file:

[('implementation', 0.0), ('this', 6.0), ('is', 12.0), ('python.', 8.0), ('min', 2.0), ('for', 26.0), ('test', 2.0), ('spark', 8.0), ('file', 12.0), ('sketch', 0.0), ('count', 8.0), ('on', 14.0), ('using', 14.0), ('hello', 2.0)]
