# Count-Min Sketch
---

The Count-Min (CM) sketch is a probabilistic data structure that provides
a lossy, compressed summary of large count/frequency datasets.
It is typically used for streaming data. At the heart of the CM sketch
is hashing: it uses a set of hash functions, each with its own
constant-size hash table. The hash functions are chosen independently of one
another, so each distributes the data differently within its table, and two
items that collide in one table are unlikely to collide in all of them. This
redundancy allows CM sketches to achieve a high degree of lossy compression
while still producing good estimates of the original counts.

### Internals
---
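The core data storage structure within a CM sketch is a $d \times w$ table of counters, $\text{count}$, with one row per hash function and $w$ columns per row. The dimensions are $w = \left\lceil\frac{e}{\epsilon}\right\rceil$ and $d = \left\lceil\ln\frac{1}{\delta}\right\rceil$. The parameter $\epsilon$ sets the additive error and $\delta$ sets the failure probability. With probability at least $1-\delta$, a point estimate exceeds the true count by at most $\epsilon\|\boldsymbol{a}\|_1$, where $\|\boldsymbol{a}\|_1$ is the sum of all counts added to the sketch.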

<img src="./img/cm_internal_table.png" width="400" />
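The two sizing formulas translate directly into code. This is a minimal sketch using the natural log and rounding both dimensions up to integers. `cm_dimensions` is a hypothetical helper and is not part of the notebook.

```python
import math


def cm_dimensions(epsilon: float, delta: float) -> tuple[int, int]:
    """Return (w, d) for a Count-Min sketch with additive error epsilon and failure probability delta."""
    w = math.ceil(math.e / epsilon)       # columns per row
    d = math.ceil(math.log(1.0 / delta))  # rows, i.e. number of hash functions
    return w, d


w, d = cm_dimensions(0.01, 0.01)
print(w, d)  # → 272 5
```

The formulas can also be read in reverse. The notebook's default `indexes=2**6` columns corresponds to $\epsilon = e/64 \approx 0.042$. Its `hashfuncs=2**6` rows corresponds to $\delta = e^{-64}$, which is an extremely small failure probability.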

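Each row $j \in 1..d$ of the table is used as the hash table for hash function $h_j$, which maps an item to one of the $w$ columns. When we add an event $i$ with count $c$ to the sketch, $c$ is added to one counter in each row: $\text{count}[j, h_j(i)] \leftarrow \text{count}[j, h_j(i)] + c$ for every $j$.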

<img src="./img/cm_adding_event.png" width="400" />
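Here is a minimal sketch of the update rule using NumPy. The per-row hash `row_hash` is a hypothetical stand-in for the notebook's seeded `mmh3.hash`. It hashes the string `"row:key"` with `hashlib.blake2b` and reduces the digest modulo $w$, so each row behaves like a different hash function. BLAKE2b is used only for convenience and is not a provably pairwise-independent family.

```python
import hashlib

import numpy as np


def row_hash(key: str, row: int, w: int) -> int:
    """Hypothetical per-row hash (stands in for mmh3.hash(key, seed=row) % w)."""
    digest = hashlib.blake2b(f"{row}:{key}".encode(), digest_size=8).digest()
    return int.from_bytes(digest, "little") % w


def cm_add(count: np.ndarray, key: str, c: int = 1) -> None:
    """Add c occurrences of key by bumping one counter in every row."""
    d, w = count.shape
    for j in range(d):
        count[j, row_hash(key, j, w)] += c


count = np.zeros((5, 272), dtype=np.int64)  # d = 5 rows, w = 272 columns
cm_add(count, "ACGTACGTACGTACGT")
cm_add(count, "ACGTACGTACGTACGT", c=2)
print(count.sum(axis=1))  # each row's total is the total count added: [3 3 3 3 3]
```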

### Operations
---
#### Point Query $Q(i)$
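A point query is the estimation of $a_i$, the count of item $i$ in the original data. The estimate is the smallest of the $d$ counters that item $i$ hashes to. With non-negative updates, collisions can only add to a counter, so the estimate never undercounts: $Q(i) \ge a_i$.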

<img src="./img/cm_point_q.png" width="400" />
$$Q(i) = \min_j\text{count}[j, h_j(i)]$$
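Below is a small experiment that checks the point query against exact counts. It uses a deliberately tiny sketch ($d = 4$, $w = 16$) so that collisions are common. `row_hash` is the same hypothetical BLAKE2b stand-in for seeded `mmh3`, and the words are made-up toy data. However often items collide, no estimate falls below the true count.

```python
import hashlib
from collections import Counter

import numpy as np


def row_hash(key, row, w):
    # Hypothetical stand-in for mmh3.hash(key, seed=row) % w
    digest = hashlib.blake2b(f"{row}:{key}".encode(), digest_size=8).digest()
    return int.from_bytes(digest, "little") % w


def cm_add(count, key, c=1):
    for j in range(count.shape[0]):
        count[j, row_hash(key, j, count.shape[1])] += c


def cm_query(count, key):
    """Point query Q(i): the minimum of the item's counters over the d rows."""
    d, w = count.shape
    return min(int(count[j, row_hash(key, j, w)]) for j in range(d))


words = ["AAAA"] * 50 + ["CCCC"] * 20 + [f"W{i}" for i in range(200)]
count = np.zeros((4, 16), dtype=np.int64)
for word in words:
    cm_add(count, word)

exact = Counter(words)
for word in ["AAAA", "CCCC", "W7"]:
    print(word, exact[word], cm_query(count, word))  # estimate >= exact count
```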

#### Range Query $Q(l, r)$
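A range query from $l..r$ is the estimation of the sum over that range:
$$Q(l,r) \approx \sum_{i=l}^r a_i$$
To calculate a range query accurately, about $\log_2 n$ sketches must be kept, one for each level of dyadic ranges spanning $1..n$. Level $k$ holds the ranges $[b\,2^k + 1,\ (b+1)\,2^k]$. Any range $l..r$ splits into at most $2\log_2 n$ dyadic ranges. $Q(l,r)$ is therefore the sum of that many point queries, each made against the sketch for its range's level.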
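This is a minimal sketch of the dyadic scheme. For simplicity it uses 0-indexed positions $0..n-1$ with $n$ a power of two. `DyadicCM`, `dyadic_decompose` and `row_hash` are hypothetical names, and `row_hash` again stands in for seeded `mmh3`. Each level $k$ has its own Count-Min table keyed by the block index `i >> k`. A range query sums one point query per dyadic block.

```python
import hashlib

import numpy as np


def row_hash(key, row, w):
    # Hypothetical stand-in for mmh3.hash(key, seed=row) % w
    digest = hashlib.blake2b(f"{row}:{key}".encode(), digest_size=8).digest()
    return int.from_bytes(digest, "little") % w


def dyadic_decompose(l, r):
    """Split the inclusive, 0-indexed range [l, r] into maximal dyadic blocks.
    Block (k, b) covers indices b * 2**k through (b + 1) * 2**k - 1."""
    blocks = []
    while l <= r:
        k = 0
        # Grow the block while it stays aligned at l and ends within r
        while l % 2 ** (k + 1) == 0 and l + 2 ** (k + 1) - 1 <= r:
            k += 1
        blocks.append((k, l >> k))
        l += 2 ** k
    return blocks


class DyadicCM:
    """One Count-Min table per dyadic level over indices 0..n-1 (n a power of two)."""

    def __init__(self, n, w=64, d=4):
        self.levels = n.bit_length()  # levels 0..log2(n)
        self.tables = np.zeros((self.levels, d, w), dtype=np.int64)

    def add(self, i, c=1):
        d, w = self.tables.shape[1:]
        for k in range(self.levels):
            block = i >> k  # index of the level-k block containing i
            for j in range(d):
                self.tables[k, j, row_hash(block, j, w)] += c

    def _point(self, k, block):
        d, w = self.tables.shape[1:]
        return min(int(self.tables[k, j, row_hash(block, j, w)]) for j in range(d))

    def range_query(self, l, r):
        # Sum one point query per dyadic block of [l, r]
        return sum(self._point(k, b) for k, b in dyadic_decompose(l, r))


sketch = DyadicCM(n=16)
data = [1, 3, 3, 4, 7, 8, 12, 12, 15]
for i in data:
    sketch.add(i)

print(dyadic_decompose(3, 12))  # [(0, 3), (2, 1), (2, 2), (0, 12)]
exact = sum(1 for i in data if 3 <= i <= 12)
print(exact, sketch.range_query(3, 12))  # the estimate is never below the exact 7
```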

#### Inner Product $Q(\boldsymbol{a}, \boldsymbol{b})$
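The inner product $\sum_i a_i b_i$ of two count arrays can be estimated using one sketch per array. Both sketches must use the same dimensions and the same hash functions. Take the inner product of each pair of corresponding rows, then the minimum over rows. As with point queries, the estimate never falls below the true value when counts are non-negative.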
$$Q(\boldsymbol{a}, \boldsymbol{b}) = \min_j\sum_{k=1}^w\text{count}_a[j, k]*\text{count}_b[j, k]$$
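This small check compares the sketched inner product with the exact one on toy data. Both sketches share the same hypothetical `row_hash` (a BLAKE2b stand-in for seeded `mmh3`) and the same $d \times w$ shape, which the formula requires. `inner_product` mirrors the row-wise-then-minimum computation that the notebook's classifiers use.

```python
import hashlib
from collections import Counter

import numpy as np


def row_hash(key, row, w):
    # Hypothetical stand-in for mmh3.hash(key, seed=row) % w
    digest = hashlib.blake2b(f"{row}:{key}".encode(), digest_size=8).digest()
    return int.from_bytes(digest, "little") % w


def sketch(items, w=32, d=4):
    count = np.zeros((d, w), dtype=np.int64)
    for item in items:
        for j in range(d):
            count[j, row_hash(item, j, w)] += 1
    return count


def inner_product(count_a, count_b):
    # Inner product of each pair of rows, then the minimum over rows j
    return int((count_a * count_b).sum(axis=1).min())


a = ["x"] * 5 + ["y"] * 3 + ["z"]
b = ["x"] * 2 + ["z"] * 4 + ["q"] * 6
ca, cb = Counter(a), Counter(b)
exact = sum(ca[k] * cb[k] for k in ca)  # 5*2 + 3*0 + 1*4 = 14
print(exact, inner_product(sketch(a), sketch(b)))  # estimate >= exact
```

Within a row, every cell contributes the product of the summed counts that landed there. That product contains each matching $a_i b_i$ plus non-negative cross terms from collisions, so each row, and hence the minimum, is at least the exact value.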

In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import seaborn as sns # visualizations
import os
import glob
import time
import sys
import mmh3  # MurmurHash3 bindings; "pip install mmh3" should do it

Here we will look at the space/time trade-offs of a Count-Min sketch versus a more naive implementation. In the class definitions, fill in the TODOs.

In [None]:
# naively count frequencies

class dictionary():
    """Exact (naive) frequency counter backed by a Python dict."""

    def __init__(self):
        self.dictionary = {}
        # Track the dict's own size (its hash table). sys.getsizeof does not
        # include the key strings or count objects that the table points to.
        self.nbytes = sys.getsizeof(self.dictionary)

    def getsize(self):
        print("Dictionary is Size: {} Bytes\n".format(self.nbytes))

    def add(self, token):
        if token in self.dictionary:
            self.dictionary[token] += 1
        else:
            self.dictionary[token] = 1
        self.nbytes = sys.getsizeof(self.dictionary)

    def timed_update(self, tokenlist):
        startsize = self.nbytes
        start = time.perf_counter()
        for token in tokenlist:
            self.add(token)
        end = time.perf_counter() - start
        dsize = self.nbytes - startsize
        print("Time Elapsed: {} Seconds \n".format(end))
        print("Change In Memory: {} Bytes\n".format(dsize))

    def estimate(self, token):
        # Exact count; a token that was never added has count 0
        if token not in self.dictionary:
            print("Error: Token Not Found \n")
            return 0
        return self.dictionary[token]

In [None]:

class CountMinSketch():

    def __init__(self, seqlist=None, indexes=2**6, hashfuncs=2**6):
        self.N = indexes    # width w: counters (columns) per row
        self.M = hashfuncs  # depth d: rows, one per hash function
        self.seeds = np.arange(hashfuncs).tolist()
        self.table = np.zeros((self.M, self.N), dtype=np.int64)  # integer counters
        self.hashes = [self._genhash(seed) for seed in self.seeds]
        self.nbytes = sys.getsizeof(self.table) + sys.getsizeof(self.hashes)
        if seqlist is not None:
            for value in seqlist:
                self.add(value)

    def _genhash(self, seed):
        # Row hash: seeded MurmurHash3, reduced to a column index in 0..N-1
        def hash_fn(val):
            index = mmh3.hash(val, seed=seed)
            return index % self.N  # Python's % is non-negative for positive N
        return hash_fn

    def getsize(self):
        print("Sketch is Size: {} Bytes\n".format(self.nbytes))

    def add(self, value):
        # Increment one counter in every row
        for ix in range(self.M):
            column = self.hashes[ix](value)
            self.table[ix, column] += 1

    def timed_update(self, valuelist):
        start = time.perf_counter()
        for value in valuelist:
            self.add(value)
        end = time.perf_counter() - start
        dsize = sys.getsizeof(self.table) + sys.getsizeof(self.hashes)
        print("Time Elapsed: {} Seconds \n".format(end))
        print("Memory Usage: {} Bytes\n".format(dsize))

    def estimate(self, value):
        # Implement a point query from the sketch (see figure above if lost):
        # the minimum counter for `value` across all M rows
        results = []
        for ix in range(self.M):
            column = self.hashes[ix](value)
            results.append(self.table[ix, column])
        return np.min(results)
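Both classes above measure memory with `sys.getsizeof`, which reports only the container object itself. For a `dict`, that is the hash table of references, not the key strings and count objects it points to. For a NumPy array that owns its data, it does include the data buffer. The naive dictionary's reported size therefore understates its real footprint.

A rough way to include the entries is sketched below. `dict_footprint` is a hypothetical helper and is not part of the notebook. It counts objects shared between entries (such as cached small integers) once per entry, so treat its result as an upper estimate.

```python
import sys

import numpy as np


def dict_footprint(d):
    """Rough deep size of a flat dict: the container plus each key and value object.
    Objects shared between entries are counted once per entry, so this overestimates."""
    return sys.getsizeof(d) + sum(sys.getsizeof(k) + sys.getsizeof(v) for k, v in d.items())


# Toy dict of 1000 sixteen-character keys, each with count 1
counts = {f"{i:016d}": 1 for i in range(1000)}
print("dict container only:", sys.getsizeof(counts), "bytes")
print("dict with entries:  ", dict_footprint(counts), "bytes")

# A 64 x 64 table of int64 counters, like the default sketch above
table = np.zeros((64, 64), dtype=np.int64)
print("sketch table:", sys.getsizeof(table), "bytes, of which data:", table.nbytes)
```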

We have taken genomes from the fruit fly (*Drosophila melanogaster*) and from a human chromosome, and from them created lists of 16-character sequences to simulate "words" (not biologically accurate, but it serves our purposes here). Load these lists of words, store them in both your naive implementation and your CMS implementation, and evaluate their performance.

In [None]:
fruitfly = np.load("genomedata/fruitfly.npy")
human = np.load("genomedata/human.npy")

In [None]:
print(fruitfly[:5])

In [None]:
genomedict = dictionary()
genomedict.timed_update(human)

In [None]:
genomesketch = CountMinSketch()
genomesketch.timed_update(human)

In [None]:
fruitflydict = dictionary()
fruitflydict.timed_update(fruitfly)

In [None]:
fruitflysketch = CountMinSketch()
fruitflysketch.timed_update(fruitfly)

Reflections:

1) Compare and contrast the time/space performance of the naive implementations for the fruit fly and human chromosome sequences.

2) Compare and contrast the time/space performance of the CMS implementations for the fruit fly and human chromosome sequences.

3) When might the CMS be a more prudent tool than something more basic?

4) Tweak the CMS size (in the init method) to something bigger or smaller (e.g. $2^2$, $2^9$). What does this do to runtime and memory usage?

## Kmer Count Similarity
---

As it turns out, we can use the inner product between two sketches to estimate similarity.

The inner product between two vectors, $\boldsymbol{a}$ and $\boldsymbol{b}$ is given by

$$\boldsymbol{a} \cdot \boldsymbol{b} = \|\boldsymbol{a}\|\|\boldsymbol{b}\|\cos{\theta},$$

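where $\theta$ is the angle between the two vectors. For vectors of fixed length,
the dot product decreases as $\theta$ increases toward $90^\circ$ and is largest when $\theta$ is at or near $0$.
Count vectors have no negative entries, so $\theta$ always lies between $0^\circ$ and $90^\circ$. We can use
this to determine how similar two vectors are.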
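Here is a small numeric check of the formula above, using NumPy on made-up toy count vectors. It also shows why vector length matters. Scaling a vector scales its dot product with another vector but leaves $\cos\theta$ unchanged, so comparing raw dot products can favor longer vectors. Dividing by the norms isolates the angle. `cosine` is a hypothetical helper name.

```python
import numpy as np


def cosine(a, b):
    """cos(theta) = (a . b) / (||a|| ||b||)."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


a = np.array([3, 0, 1, 2])
b = np.array([6, 0, 2, 4])  # same direction as a, twice the length
c = np.array([0, 5, 0, 0])  # shares no nonzero coordinate with a

print(np.dot(a, b), cosine(a, b))  # large dot product, cosine of 1 (up to rounding)
print(np.dot(a, c), cosine(a, c))  # orthogonal: both are zero
```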

K-mer counts can be viewed as vectors with a very large number of dimensions (one per possible k-mer), so dot products can be used
to determine similarity. Below we classify k-mer counts from shorter sequences by
comparing them against k-mer counts from larger genome sequences using the inner product.
Instead of keeping massive count arrays in memory, we
sketch them. Recall that the inner product between two sketches is given by

$$Q(\boldsymbol{a}, \boldsymbol{b}) = \min_j\sum_{k=1}^w\text{count}_a[j, k]*\text{count}_b[j, k].$$
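The notebook's data files already contain 16-character words, and the text does not say exactly how they were cut from the genomes. One plausible way to build such counts, sketched here as an assumption, is to slide a window of length $k$ along a sequence and count the windows with `collections.Counter`. A step of 1 gives overlapping k-mers, and a step of $k$ gives non-overlapping chunks. `kmers` is a hypothetical helper, and each distinct k-mer is one coordinate of the count vector. With $k = 16$ over the alphabet {A, C, G, T}, there are $4^{16} = 2^{32}$ (about 4.3 billion) possible coordinates, which is why sketching these vectors is attractive.

```python
from collections import Counter


def kmers(sequence, k=16, step=1):
    """Hypothetical k-mer extraction: every length-k window, advancing by `step` characters."""
    return [sequence[i:i + k] for i in range(0, len(sequence) - k + 1, step)]


seq = "ACGTACGTTTACGTAC"
counts = Counter(kmers(seq, k=4))
print(counts.most_common(3))  # [('ACGT', 3), ('CGTA', 2), ('GTAC', 2)]
```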

Here we implement this in `BinaryClassification` to see whether we can determine which of two "complete" sketches a sub-sketch (a sketch formed from a portion of a given sequence) came from. Fill in the blanks and try it out below:

In [None]:
class BinaryClassification():
    # Sketches A and B are the reference sketches whose similarity to a subsketch we compare.
    # All sketches must share the same dimensions and hash seeds so their tables line up.

    def __init__(self, A, B, update=False):  # `update` is currently unused
        self.sketchA = A
        self.sketchB = B

    def _dotProduct(self, tableA, tableB):
        # Row-wise inner products, then the minimum over rows
        return (tableA * tableB).sum(axis=1).min()

    def classify(self, subsketch):
        x = self._dotProduct(self.sketchA.table, subsketch.table)
        y = self._dotProduct(self.sketchB.table, subsketch.table)

        if x > y:
            print("Subsketch is of class A")
        else:
            print("Subsketch is of class B")

In [None]:
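# Classify these sketches with the BinaryClassification object (A = human, B = fruit fly)
fly_or_human = BinaryClassification(genomesketch,fruitflysketch)

# Sub-sketches from every 4th word; default sizes and seeds match the reference sketches
subhuman = CountMinSketch(seqlist=human[::4])
subfly = CountMinSketch(seqlist=fruitfly[::4])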

fly_or_human.classify(subhuman)
fly_or_human.classify(subfly)

If you have convinced yourself that a CMS in conjunction with the inner product can serve as
a supervised learning algorithm, then proceed to classify the mystery sequences with the template
class below. We have provided "sequences" containing words/k-mers of length 16 from the Atlantic
Cod, Fruit Fly, Garter Snake, Human, Nematode, a simulated genome, and Yeast. The answers are in
the text file provided.

<img src="./img/dros_fruit_fly.jpg" width="400" />
<img src="./img/namethatpok.jpg" width="400" />


In [None]:
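class classifier():
    # Multi-class version: predict the label of the reference sketch with the largest inner product
    def __init__(self, sketchset, labels):
        self.sketchset = sketchset
        self.labels = labels

    def _dotProduct(self, tableA, tableB):
        # Row-wise inner products, then the minimum over rows
        return (tableA * tableB).sum(axis=1).min()

    def compare(self, subtable):
        # `subtable` is the count table of the sketch to classify (e.g. mysterysketch.table)
        dotproduct = []
        for sketch in self.sketchset:
            dotproduct.append(self._dotProduct(sketch.table, subtable))
        return self.labels[np.argmax(dotproduct)]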

In [None]:
# Note that you will need to load these files with NumPy and instantiate CMS instances first
files = [
    "genomedata/atlanticcod.npy",
    "genomedata/fruitfly.npy",
    "genomedata/gartersnake.npy",
    "genomedata/human.npy",
    "genomedata/nematode.npy",
    "genomedata/yeast.npy",
]
labels = [
    "Atlantic Cod",
    "Fruit Fly",
    "Garter Snake",
    "Human",
    "Nematode",
    "Yeast",
]
mysteryfiles = [
    "genomedata/mystery1.npy",
    "genomedata/mystery2.npy",
    "genomedata/mystery3.npy",
    "genomedata/mystery4.npy",
    "genomedata/mystery5.npy",
    "genomedata/mystery6.npy",
    "genomedata/mystery8.npy",
]

In [None]:
sketchset = [CountMinSketch(seqlist=np.load(file),indexes=2**6,hashfuncs=2**6) for file in files]
mysteryset = [CountMinSketch(seqlist=np.load(file),indexes=2**6,hashfuncs=2**6) for file in mysteryfiles]

In [None]:
mystery_classifier = classifier(sketchset, labels)
for name, mysterysketch in zip(mysteryfiles, mysteryset):
    print(name, mystery_classifier.compare(mysterysketch.table))