# Minhash LSH for Set Similarity

The Minhash-LSH algorithm is very suitable for finding near-duplicates in an efficient manner. This notebook will outline the main concepts and considerations related to the algorithm.

It was first introduced in the paper [Min-wise independent permutations](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.121.8215&rep=rep1&type=pdf)

Most of the notebook is inspired by Chapter 3 of [Mining of Massive Datasets](http://www.mmds.org/) Chapter 3 - Finding Similar Items.

The problem: You have a large collection of documents. You need to find all pairs with a high degree of similarity. This gets really computationally expensive with 1 million docs. LSH allows you to drastically reduce the number of comparisons made while only sacrificing a negligent ammount of recall.


Some example applications:

1. Plagiarism Detection
2. Uber: Detecting overlapping trips [link](https://eng.uber.com/lsh/)
3. Smule: Duplicate songs
4. Facebook: Fake news
5. Neural nets:
      * [The Reformer](http://ai.googleblog.com/2020/01/reformer-efficient-transformer.html)
      * Why not as an alternative to negative sampling?


# Imports

In [1]:
import pandas as pd
import numpy as np
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

CSS = """
.output {
    flex-direction: row;
}
"""

HTML('<style>{}</style>'.format(CSS))

# 1. Representing sets via their Characteristic Matrix

For our running examples, let's use the universe of the first 15 english letters: A - P.

The characteristic matrix of a set is a binary representation where the elements of the universe are enumerated and each set has a `0` for missing elements and a `1` for present elements. E.g. the set {A,B,D} would be 

<table> 
    <tr> <td> A </td> <td style="background-color: gold"> 1 </tr>
    <tr> <td> B </td> <td style="background-color: gold"> 1 </tr>
    <tr> <td> C </td> <td> 0 </tr>
    <tr> <td> D </td> <td style="background-color: gold"> 1 </tr>
    <tr> <td> E </td> <td> 0 </tr>
    <tr> <td> F </td> <td> ... </tr>
</table>

In [90]:
UNIVERSE = list("ABCDEFGHIJKLMNOP")
UNIVERSE

['A',
 'B',
 'C',
 'D',
 'E',
 'F',
 'G',
 'H',
 'I',
 'J',
 'K',
 'L',
 'M',
 'N',
 'O',
 'P']

Let's sample some sets:

In [94]:
def sample_set():
    """
    Samples a random set from the universe.
    """
    return pd.DataFrame(np.random.randint(0, 2, size=len(UNIVERSE)), index=UNIVERSE, columns=[""])

def pp_set(s):
    """
    Pretty prints sets - highlights the present elements
    """
    def highlight_ones(x):
        return f"background-color: {'gold' if x==1 else 'white'}"
    display(s.style.applymap(highlight_ones))

s = sample_set()
pp_set(sample_set())

Unnamed: 0,Unnamed: 1
A,1
B,1
C,1
D,1
E,1
F,0
G,1
H,1
I,0
J,0


# 2. Jaccard Similarity

The Jaccard Similarity of a pair of sets is defined as:

$$ jac\_sim(A,B) = \frac{|A\cap B|}{|A \cup B|}$$

Examples:

$$ jac\_sim(\{\color{green}{A},\color{green}{B},C\}, \{\color{green}{A},\color{green}{B},D\}) = \frac{1}{2}$$
$$ jac\_sim(\{\color{green}{A},C\}, \{\color{green}{A},B,D\}) = \frac{1}{4}$$
$$ jac\_sim(\{C\}, \{A,B,D\}) = \frac{0}{4}$$
$$ jac\_sim(\{\color{green}{A},\color{green}{B},\color{green}{C},D\}, \{\color{green}{A},\color{green}{B},\color{green}{C},E\}) = \frac{3}{5}$$

# 3. Minhash

The operation to get the minhash of a set:
1. Permute the enumeration of the Universe.
2. The minhash of a set is the first present element under the new enumeration.

Example:

Suppose *ABCDEFGHIJKLMNOP* is permuted to *HIDCABJEFKLMNOPG*.
Then, the set {ABD}
<div style="display:flex; margin-top:12px;">
<table> 
    <tr> <td> A </td> <td> 1 </tr>
    <tr> <td> B </td> <td> 1 </tr>
    <tr> <td> C </td> <td> 0 </tr>
    <tr> <td> D </td> <td> 1 </tr>
    <tr> <td> E </td> <td> 0 </tr>
    <tr> <td> F </td> <td> ... </tr>
</table>
will become 
<table> 
    <tr> <td> H </td> <td> 0 </tr>
    <tr> <td> I </td> <td> 0 </tr>
    <tr> <td style="background-color:yellow"> D </td> <td style="background-color:yellow"> 1 </tr>
    <tr> <td> C </td> <td> 0 </tr>
    <tr> <td> A </td> <td> 1 </tr>
    <tr> <td> B </td> <td> 1 </tr>
    <tr> <td> J </td> <td> ... </tr>
</table>
</div>

and so the minhash of ABD under this permutation is `D`.

In [98]:
def permute(s):
    """
    Permutes a set
    """
    return s.reindex(np.random.permutation(s.index))
    
def pp_minhash(s, permute_set=True):
    """
    Pretty prints the minhash of a given set of a dataframe of sets
    """
    p = permute(s) if permute_set else s
    def find_first(rows):
        res = ['background-color:white']  * len(rows)
        res[rows.tolist().index(1)] = 'background-color:gold'
        return res
    display(p.style.apply(find_first, axis=0))

pp_set(s)
p = permute(s)
display(HTML('<div style="top:50%;position:absolute;  margin: 0;"> => permute => </div>'))
pp_set(p)
display(HTML('<div style="top:50%;position:absolute;  margin: 0;"> => minhash =></div>'))
pp_minhash(p, permute_set=False)

Unnamed: 0,Unnamed: 1
A,1
B,0
C,0
D,1
E,1
F,1
G,1
H,0
I,0
J,0


Unnamed: 0,Unnamed: 1
O,0
M,0
E,1
J,0
I,0
B,0
H,0
C,0
N,0
K,1


Unnamed: 0,Unnamed: 1
O,0
M,0
E,1
J,0
I,0
B,0
H,0
C,0
N,0
K,1


# 4. Minhash equality probability is Jaccard similarity

We will prove that the probability that the minhashes of two sets X and Y under the same random permutation is equal to their jaccard similarity: 

$$ P(minhash(X) = minhash(Y)) = jac\_sim(X, Y)$$

In [102]:
pair = pd.concat([sample_set(), sample_set()], axis=1)
pair.columns=['X','Y']
pp_set(pair)

Unnamed: 0,X,Y
A,1,1
B,1,1
C,0,0
D,1,1
E,0,0
F,0,0
G,0,0
H,1,1
I,1,1
J,1,1


Let's define 3 kinds of rows:
1. <span style="background-color:lightgreen;">Green rows</span>: these are rows with entries (1,1)
2. <span style="background-color:darkorange;">Orange</span> rows: these are rows with entries (0,1) or (1,0)
3. Gray rows: these are rows with entries (0,0)

In [103]:
def pp_jaccard(df):
    def highlight_eq_rows(row):
        l = row.iloc[0]
        r = row.iloc[1]
        if l == 0 and r == 0:
            return ['']* 2
        elif l != r:
            return ['background-color: darkorange']* 2
        else:
            return ['background-color: lightgreen'] * 2
    return df.style.apply(highlight_eq_rows, axis=1)

In [105]:
pp_jaccard(pair)

Unnamed: 0,X,Y
A,1,1
B,1,1
C,0,0
D,1,1
E,0,0
F,0,0
G,0,0
H,1,1
I,1,1
J,1,1


Let's draw a random permutation and visualize what the minhash of each of the sets is:

In [121]:
pp_minhash(pair)

Unnamed: 0,X,Y
C,0,0
F,0,0
E,0,0
L,1,0
M,0,0
I,1,1
K,1,0
H,1,1
P,0,1
G,0,0


It turns out, that the probability of having an equal minhash is:

$$ P(minhash(X) = minhash(Y)) = \frac{\color{green}{green}}{\color{green}{green}+\color{darkorange}{orange}}$$


Thus, if we sample the minhash operation enough times, we can get a reliable estimate of the true Jaccard similarity.

Let's see if this holds empirically.

In [122]:
def get_mh(s):
    p = permute(s)
    return p.index[p.apply(lambda x: x.tolist().index(1), axis=0)]

In [123]:
def jac_sim(pair):
    return (pair['X'] & pair['Y']).sum() / (pair['X'] | pair['Y']).sum()

In [124]:
jac_sim(pair)

0.7272727272727273

In [125]:
from math import sqrt
def wilson(p, n, z = 1.96):
    denominator = 1 + z**2/n
    centre_adjusted_probability = p + z*z / (2*n)
    adjusted_standard_deviation = sqrt((p*(1 - p) + z*z / (4*n)) / n)
    
    return f"{centre_adjusted_probability:.2f}+-{2*adjusted_standard_deviation:.2f}"

In [126]:
eq = []
for i in range(2000):
    mh_equality = get_mh(pair)
    eq.append(mh_equality[0] == mh_equality[1])
display(wilson(np.sum(eq)/len(eq), len(eq)))


'0.73+-0.02'

To construct minhash signatures, we fix N random permutations.
Then, the signature for a set would be a list containing the minhashes of the set under the different permutations. 

# 5. Min - okay. Why hash? Minhash in Practice

We've come up with a way to map sets to similarity-preserving signatures. Unfortunately, one of the steps involves calculating permutations of the universe, which is an expensive/slow operation when the universe is large. In practice, instead of permuting the universe and finding the minumum present element, we use a **hash function and pretend like it's giving us an ordering**.

Let's consider the function f(x) = 5 * x + 7 mod 17. In that case, for the set {A, B, D} we would get:
<table>
    <thead>
      <tr> <td>Element</td> <td>Original index</td> <td> 5x+7 mod 17 </td></tr>
    </thead>
    <tbody  >
        <tr><td>A</td><td>0</td><td>7</td></tr>
        <tr><td>B</td><td>1</td><td>12</td></tr>
        <tr><td>D</td><td>3</td><td style="background-color:gold">5</td></tr>
    </tbody>
</table>

In this case, the minhash for the set {A,B,D} would be **D** because it receives the minumal value of the hash function.

To sum up, to generate minhash signatures of documents, the procedure is as follows:

1. Map documents to sets (Domain dependent)
2. Generate a number of random hash functions
3. For each set, for each hash function, compute the hash values of all present elements and find the minimums. These minima make up the minhash signature.

# 6. Locality Sensitive Hashing

So, we've come up with a way to compare sets based on their *Minhash signature*. However, we still need to do the full $\frac{n*(n-1)}{2}$ comparisons.    
This is where Locality Sensitive Hashing comes into play. Suppose we have minhash signatures of length 12. Then, we can divide each signature into `b` bands of `r` rows, such that `b*r = 12`

This is how it works:
1. Divide the signature into bands
2. For each band, index the signature based on the hash of a band 
3. For each band, for each bucket compare only the collisions 

<table>
    <tbody>
        <tr> <td> Doc 1 <td style="border-left:2px solid black;background-color: lightgreen"> 12 </td> <td style="background-color: lightgreen"> 15</td> <td style="background-color: lightgreen"> 7</td> <td style="border-right:2px solid black;background-color:lightgreen"> 21 </td> <td style="background-color:darkorange"> 78 </td> <td style="background-color:darkorange"> 66 </td> <td style="background-color:darkorange"> 51 </td> <td style="background-color:darkorange; border-right:2px solid black;"> 43 </td> <td style="background-color:violet"> 113 </td> <td style="background-color:violet"> 41 </td> <td style="background-color:violet"> 16 </td> <td style="background-color:violet;border-right:2px solid black;"> 2 </td>  </tr>
        <tr> <td> Doc 2 <td style="border-left:2px solid black;background-color: lightgreen"> 12 </td> <td style="background-color: lightgreen"> 15</td> <td style="background-color: lightgreen"> 7</td> <td style="border-right:2px solid black;background-color:lightgreen"> 21 </td> <td style="background-color:darkorange"> 78 </td> <td style="background-color:darkorange"> 66 </td> <td style="background-color:darkorange"> 51 </td> <td style="background-color:darkorange; border-right:2px solid black;"> 54 </td> <td style="background-color:violet"> 113 </td> <td style="background-color:violet"> 17 </td> <td style="background-color:violet"> 16 </td> <td style="background-color:violet;border-right:2px solid black;"> 2 </td>  </tr>
        <tr> <td> Doc 3 <td style="border-left:2px solid black;background-color: lightgreen"> 12 </td> <td style="background-color: lightgreen"> 15</td> <td style="background-color: lightgreen"> 7</td> <td style="border-right:2px solid black;background-color:lightgreen"> 18 </td> <td style="background-color:darkorange"> 21 </td> <td style="background-color:darkorange"> 14 </td> <td style="background-color:darkorange"> 48 </td> <td style="background-color:darkorange; border-right:2px solid black;"> 25 </td> <td style="background-color:violet"> 113 </td> <td style="background-color:violet"> 42 </td> <td style="background-color:violet"> 8 </td> <td style="background-color:violet;border-right:2px solid black;"> 16 </td>  </tr>
    </tbody>
    <caption align="bottom"> Splitting signatures with N=12 into b=3 bands of r=4 entries </caption>
</table>

<div style="display:flex;">
    
<table style="margin-top:12px;">
    <caption>Band 1 hashes</caption>
    <tbody>
        <tr> <td style="background-color:lightgreen">(12, 15, 7, 21) </td> <td> Doc 1, Doc 2 </td> </tr>
        <tr> <td style="background-color:lightgreen">(12, 15, 7, 18) </td> <td> Doc 3 </td> </tr>
    </tbody>
</table> 
<table >
    <caption>Band 2 hashes</caption>
    <tbody>
        <tr> <td style="background-color:darkorange">(78, 66, 51, 43) </td> <td> Doc 1</td> </tr>
        <tr> <td style="background-color:darkorange">(78, 66, 51, 54) </td> <td> Doc 2</td> </tr>
        <tr> <td style="background-color:darkorange">(21, 14, 48, 25) </td> <td> Doc 3</td> </tr>
    </tbody>
</table> 
<table > 
    <caption>Band 3 hashes</caption>
    <tbody>
        <tr> <td style="background-color:violet">(113, 41, 16, 2) </td> <td> Doc 1 </td> </tr>
        <tr> <td style="background-color:violet">(113, 17, 16, 2) </td> <td> Doc 2 </td> </tr>
        <tr> <td style="background-color:violet">(113, 42, 8, 16) </td> <td> Doc 3 </td> </tr>
    </tbody>
</table> </div>

# 7. Selecting the band width for LSH

Choosing the band width is quite important: it has repercurssions in precision, recall and even performance.
Let's see what is the relationhip between the jaccard similarity $s$ of two documents with $b$ bands and $r$ entries in each band.

1. The probability that the signatures agree in all entries of one particular band is $s^r$
2. The probability that the signatures disagree in at least one row of a particular band is $1 − s^r$
3. The probability that the signatures disagree in at least one row of each of the bands is $(1 − s^r)^b$
4. The probability that the signatures agree in all the rows of at least one band, and therefore become a candidate pair, is $$1 − (1 − s^r)^b$$


In [127]:
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
sns.set('talk')
sns.set_style('whitegrid')

In [128]:
@interact(r=(1, 25,1))
def probability_of_match(r):
    b = 100 // r
    xs = np.arange(0, 1, 0.001)
    ys = 1 - (1 - xs**r) ** b
    fig, ax = plt.subplots(figsize=(12,6))
    sns.despine(fig)
    ax.set_ylabel('Probability of Collision', size=18)
    ax.set_xlabel('Jaccard Similarity', size = 18)
    #ax.plot(xs, ys, 'bo-', lw=4, markevery=[800]);
    ax.annotate(f"{ys[800]:.2f}", (xs[800], ys[800]), textcoords="offset points", xytext=(-50,0),c='b', weight='bold')
    ax.set_title("You must find the right balance between band width and \n number of bands", loc='left', pad=15)
    ax.fill_between(xs[:801], ys[:801], facecolor='darkorange', lw=4);
    ax.fill_between(xs[800:].tolist(),ys[800:].tolist(), facecolor='b', lw=4);

interactive(children=(IntSlider(value=13, description='r', max=25, min=1), Output()), _dom_classes=('widget-in…

# Additional Resources

1. [Mining of Massive Datasets Chapter 3](http://www.mmds.org/)
2. [Reformer - The Efficient Transformer](http://ai.googleblog.com/2020/01/reformer-efficient-transformer.html)
3. [Datasketch](http://ekzhu.com/datasketch/lsh.html)
4. [Spark Minhash LSH](https://spark.apache.org/docs/2.1.0/ml-features.html#minhash-for-jaccard-distance)