# Finding interesting clusters using snorkel

With unsupervised learning methods it can be challenging to automatically figure out which clusters are useful/interesting. By using collections of weak labels, we can automatically identify potentially interesting (or less interesting) clusters. The overall strategy is as follows:

1. Vectorize the data
2. Embed the data into a lower dimensional space using [UMAP](https://umap-learn.readthedocs.io/en/latest/)
3. Run density based clustering using [HDBSCAN](https://hdbscan.readthedocs.io/en/latest/index.html)
4. Write labeling functions using [snorkel](https://www.snorkel.org/) to get collections of weak labels
5. Identify clusters that contain many points (e.g. >90%) with a given label 
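
Step 5 boils down to a per-cluster proportion. Here is a minimal sketch in plain Python, assuming one label per point; the function name `clusters_dominated_by` and the toy data are made up for illustration and are not part of the notebook below.

```python
from collections import Counter, defaultdict


def clusters_dominated_by(cluster_ids, labels, target, threshold=0.9):
    """Return {cluster: fraction} for clusters where `target` makes up
    at least `threshold` of the points."""
    per_cluster = defaultdict(Counter)
    for cluster, label in zip(cluster_ids, labels):
        per_cluster[cluster][label] += 1

    hits = {}
    for cluster, counts in per_cluster.items():
        fraction = counts[target] / sum(counts.values())
        if fraction >= threshold:
            hits[cluster] = fraction
    return hits


# Toy data (hypothetical): cluster 0 is all SPAM, cluster 1 is mixed
toy_clusters = [0, 0, 0, 1, 1, 1, 1]
toy_labels = ["SPAM", "SPAM", "SPAM", "SPAM", "HAM", "HAM", "ABSTAIN"]
spam_clusters = clusters_dominated_by(toy_clusters, toy_labels, "SPAM")
print(spam_clusters)
```

Only cluster 0 passes the default 90% threshold here; lowering the threshold trades precision for more candidate clusters.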

Optionally, we can go a step further. Snorkel can predict an overall label for each point (given its set of weak labels) using [weak supervision](https://www.snorkel.org/blog/weak-supervision) methods. We can use this to find potentially interesting clusters:

1. Use snorkel to generate an overall label for each data point using weak supervision
2. Train a supervised learning model to predict this overall label (hoping it generalizes better). It can also use a different feature set than the snorkel model.
3. Predict a label for every point in the dataset.
4. Identify clusters that have a high proportion of a given label. Or if you have a lot of confidence in your model, use it to prune your clusters.
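
To make "an overall label from a set of weak labels" concrete, the simplest possible combiner is a majority vote that ignores abstentions. The sketch below is only that baseline, written in plain Python with the -1/0/1 encoding for ABSTAIN/HAM/SPAM; the tie handling (abstain on a tie) is an assumption of this sketch, and a learned label model would weight the labeling functions rather than count them equally.

```python
from collections import Counter

ABSTAIN = -1


def majority_vote(weak_labels):
    """Combine one point's weak labels into a single label.

    Abstentions are ignored. If nothing voted, or the top two
    labels are tied, the result is ABSTAIN (a choice made for this sketch).
    """
    votes = Counter(label for label in weak_labels if label != ABSTAIN)
    if not votes:
        return ABSTAIN
    ranked = votes.most_common(2)
    if len(ranked) == 2 and ranked[0][1] == ranked[1][1]:
        return ABSTAIN
    return ranked[0][0]


print(majority_vote([-1, 1, 1, 0]))  # two SPAM votes beat one HAM vote
print(majority_vote([-1, -1, -1]))   # every labeling function abstained
print(majority_vote([0, 1]))         # tie between HAM and SPAM
```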

This notebook is an extension of the great [snorkel spam tutorial](https://www.snorkel.org/use-cases/01-spam-tutorial). It uses a YouTube comment dataset in which each comment has been labeled as SPAM or HAM. We won't actually use the labels to train a model, just to verify that we're on the right track.

In [1]:
import json
import pandas as pd
import numpy as np
import umap
import umap.plot

import scipy.sparse

import hdbscan

from sklearn.feature_extraction.text import CountVectorizer

import matplotlib.pyplot as plt

from warnings import warn
import numba

umap.plot.output_notebook()

In [62]:
# Download the spaCy english model
# You should only need to run this once; restart your Jupyter kernel after running it
# ! python -m spacy download en_core_web_sm

### Load in the spam dataset

In [2]:
from utils import load_spam_dataset

df_train, df_test = load_spam_dataset()

# We pull out the label vectors for ease of use later
Y_test = df_test.label.values

### Vectorize the text to use with UMAP

Here we just use `CountVectorizer` from sklearn, but there are many other vectorizers in [this repository](https://github.com/TutteInstitute/vectorizers).

In [3]:
vectorizer = CountVectorizer(min_df=1, stop_words='english')
word_doc_matrix = vectorizer.fit_transform(df_train.text)

In [4]:
word_doc_matrix

<1586x3699 sparse matrix of type '<class 'numpy.int64'>'
	with 12448 stored elements in Compressed Sparse Row format>
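
The matrix above has one row per comment (1586) and one column per distinct token (3699), with each entry counting how often that token occurs in that comment. The following toy re-implementation in plain Python shows the idea. It is a sketch, not sklearn's code: it lowercases and keeps tokens of two or more word characters, and it deliberately skips the stop-word removal used in the cell above. The helper name `count_vectorize` and the two example comments are made up.

```python
import re
from collections import Counter


def count_vectorize(docs):
    """Return (vocabulary, rows) where rows[i][j] is the count of
    vocabulary[j] in docs[i]. No stop-word filtering is applied."""
    tokenized = [re.findall(r"\b\w\w+\b", doc.lower()) for doc in docs]
    vocabulary = sorted({tok for toks in tokenized for tok in toks})
    column = {tok: j for j, tok in enumerate(vocabulary)}

    rows = []
    for toks in tokenized:
        row = [0] * len(vocabulary)
        for tok, n in Counter(toks).items():
            row[column[tok]] = n
        rows.append(row)
    return vocabulary, rows


vocab, matrix = count_vectorize(["Check out my channel", "Great song, great video"])
print(vocab)
for row in matrix:
    print(row)
```

Most entries are zero because each comment uses only a handful of the vocabulary, which is why the real matrix is stored in a sparse format.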

### Use UMAP to embed the vectorized text into a lower dimensional space

In [5]:
%%time
embedding_model = umap.UMAP(n_components=2, metric='hellinger')
embedding = embedding_model.fit_transform(word_doc_matrix)

OMP: Info #271: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.
Disconnection_distance = 1 has removed 2033938 edges.
It has only fully disconnected 26 vertices.
Use umap.utils.disconnected_vertices() to identify them.
  warn(


CPU times: user 21.7 s, sys: 1.26 s, total: 23 s
Wall time: 10.9 s
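
The `hellinger` metric compares two count vectors as probability distributions. The sketch below uses the textbook definition in plain Python; UMAP's own implementation may differ in details such as edge-case handling, and the handling of empty vectors here is an assumption of this sketch. It also suggests one reading of the warning above: two comments with no tokens in common are at distance exactly 1, which matches the reported `Disconnection_distance = 1`.

```python
import math


def hellinger(x, y):
    """Hellinger distance between two non-negative count vectors,
    each implicitly normalized to sum to 1."""
    sum_x, sum_y = sum(x), sum(y)
    if sum_x == 0 and sum_y == 0:
        return 0.0  # two empty vectors: treated as identical in this sketch
    if sum_x == 0 or sum_y == 0:
        return 1.0  # one empty vector: treated as maximally distant
    # Bhattacharyya coefficient of the two normalized distributions
    overlap = sum(math.sqrt(a * b) for a, b in zip(x, y)) / math.sqrt(sum_x * sum_y)
    return math.sqrt(max(0.0, 1.0 - overlap))


print(hellinger([2, 1, 0], [2, 1, 0]))  # identical distributions
print(hellinger([1, 1, 0], [0, 0, 2]))  # no shared tokens
print(hellinger([1, 1], [1, 0]))        # partial overlap
```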


UMAP flagged some vertices as fully disconnected, so we remove them before clustering. If we leave them in, HDBSCAN will throw errors.

In [6]:
is_disconnected = umap.utils.disconnected_vertices(embedding_model)
is_connected = [not x for x in is_disconnected]
connected_embedding = embedding[is_connected]

### Cluster the embedding using HDBSCAN

In [7]:
clusterer = hdbscan.HDBSCAN(metric='euclidean', min_cluster_size=5)
cluster_labels = list(clusterer.fit_predict(connected_embedding))

Add in a cluster label of -2 for disconnected points so that we can easily ignore them later.

In [8]:
for idx, val in enumerate(is_disconnected):
    if val:
        cluster_labels.insert(idx, -2)
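
The loop above relies on inserting in ascending index order so that every earlier position is already filled. An equivalent way to think about it, sketched in plain Python with a hypothetical helper name `merge_cluster_labels`, is to walk the disconnected mask once and pull the next clustered label whenever a point was connected:

```python
def merge_cluster_labels(is_disconnected, connected_labels, fill=-2):
    """Rebuild a full-length label list: `fill` for disconnected points,
    otherwise the next label from `connected_labels`, in order."""
    remaining = iter(connected_labels)
    return [fill if flag else next(remaining) for flag in is_disconnected]


# Toy example: points 1 and 3 were disconnected and never clustered
mask = [False, True, False, True, False]
labels_for_connected = [3, 5, -1]
merged = merge_cluster_labels(mask, labels_for_connected)
print(merged)

# Same result as the insert-based loop
via_insert = list(labels_for_connected)
for idx, val in enumerate(mask):
    if val:
        via_insert.insert(idx, -2)
print(via_insert == merged)
```

Either way, the final list has one entry per row of `df_train`, which is what the hover dataframe needs.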

### Show some data when hovering over a point

In [9]:
comment_text = df_train['text'].reset_index(drop=True)
hover_df = pd.DataFrame({'text': comment_text})
hover_df['cluster'] = cluster_labels
hover_df.reset_index(drop=True, inplace=True)

### Plot the embedding

In [10]:
f = umap.plot.interactive(embedding_model, labels=hover_df['cluster'], hover_data=hover_df, subset_points=is_connected)
umap.plot.show(f)

### Using snorkel

All of these labeling functions are taken from the [snorkel spam tutorial](https://www.snorkel.org/use-cases/01-spam-tutorial) so I won't go into too much detail with them. There are a wide variety of labeling functions you can create, including fairly sophisticated ones.

In [11]:
import re
from collections import Counter

from snorkel.labeling import labeling_function, LabelingFunction
from snorkel.preprocess import preprocessor
from textblob import TextBlob
from snorkel.preprocess.nlp import SpacyPreprocessor
from snorkel.labeling.lf.nlp import nlp_labeling_function
from snorkel.labeling import PandasLFApplier

In [12]:
# For clarity, we define constants to represent the class labels for spam, ham, and abstaining.
ABSTAIN = -1
HAM = 0
SPAM = 1

Generic labeling functions

In [13]:
@labeling_function()
def check(x):
    return SPAM if "check" in x.text.lower() else ABSTAIN


@labeling_function()
def check_out(x):
    return SPAM if "check out" in x.text.lower() else ABSTAIN

Using preprocessors and external models

In [14]:
@preprocessor(memoize=True)
def textblob_sentiment(x):
    scores = TextBlob(x.text)
    x.polarity = scores.sentiment.polarity
    x.subjectivity = scores.sentiment.subjectivity
    return x

@labeling_function(pre=[textblob_sentiment])
def textblob_polarity(x):
    return HAM if x.polarity > 0.9 else ABSTAIN

@labeling_function(pre=[textblob_sentiment])
def textblob_subjectivity(x):
    return HAM if x.subjectivity >= 0.5 else ABSTAIN

Keyword based labeling

In [15]:
def keyword_lookup(x, keywords, label):
    if any(word in x.text.lower() for word in keywords):
        return label
    return ABSTAIN


def make_keyword_lf(keywords, label=SPAM):
    return LabelingFunction(
        name=f"keyword_{keywords[0]}",
        f=keyword_lookup,
        resources=dict(keywords=keywords, label=label),
    )


"""Spam comments talk about 'my channel', 'my video', etc."""
keyword_my = make_keyword_lf(keywords=["my"])

"""Spam comments ask users to subscribe to their channels."""
keyword_subscribe = make_keyword_lf(keywords=["subscribe"])

"""Spam comments post links to other channels."""
keyword_link = make_keyword_lf(keywords=["http"])

"""Spam comments make requests rather than commenting."""
keyword_please = make_keyword_lf(keywords=["please", "plz"])

"""Ham comments actually talk about the video's content."""
keyword_song = make_keyword_lf(keywords=["song"], label=HAM)
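
One thing worth knowing about `keyword_lookup` is that `word in x.text.lower()` is a substring test, so the keyword "my" also fires on words that merely contain it. The sketch below shows a whole-word variant using a regular expression. It is a hypothetical alternative, not what the tutorial uses, and it uses `SimpleNamespace` as a stand-in for a dataframe row.

```python
import re
from types import SimpleNamespace

ABSTAIN, HAM, SPAM = -1, 0, 1


def keyword_lookup_whole_word(x, keywords, label):
    """Like keyword_lookup, but only matches keywords at word boundaries."""
    text = x.text.lower()
    if any(re.search(rf"\b{re.escape(word)}\b", text) for word in keywords):
        return label
    return ABSTAIN


comment = SimpleNamespace(text="Amy is awesome")
substring_match = "my" in comment.text.lower()
whole_word_result = keyword_lookup_whole_word(comment, ["my"], SPAM)
print(substring_match, whole_word_result)

link = SimpleNamespace(text="https://example.com")
print(keyword_lookup_whole_word(link, ["http"], SPAM))
```

The trade-off cuts both ways: substring matching is what lets the "http" keyword also catch "https" links, which the whole-word variant misses.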

Using regular expressions for labeling

In [16]:
@labeling_function()
def regex_check_out(x):
    return SPAM if re.search(r"check.*out", x.text, flags=re.I) else ABSTAIN

Heuristic labeling functions

In [17]:
@labeling_function()
def short_comment(x):
    """Ham comments are often short, such as 'cool video!'"""
    return HAM if len(x.text.split()) < 5 else ABSTAIN

Complex preprocessors

In [19]:
# The SpacyPreprocessor parses the text in text_field and
# stores the new enriched representation in doc_field
spacy = SpacyPreprocessor(text_field="text", doc_field="doc", memoize=True)

@labeling_function(pre=[spacy])
def has_person(x):
    """Ham comments mention specific people and are short."""
    if len(x.doc) < 20 and any([ent.label_ == "PERSON" for ent in x.doc.ents]):
        return HAM
    else:
        return ABSTAIN

In [20]:
@nlp_labeling_function()
def has_person_nlp(x):
    """Ham comments mention specific people and are short."""
    if len(x.doc) < 20 and any([ent.label_ == "PERSON" for ent in x.doc.ents]):
        return HAM
    else:
        return ABSTAIN

### Apply the labeling functions

In [21]:
lfs = [
    keyword_my,
    keyword_subscribe,
    keyword_link,
    keyword_please,
    keyword_song,
    regex_check_out,
    short_comment,
    has_person_nlp,
    textblob_polarity,
    textblob_subjectivity,
]

In [22]:
applier = PandasLFApplier(lfs=lfs)
L_train = applier.apply(df=df_train)
L_test = applier.apply(df=df_test)

# Labeling function names, in the same order as the columns of L_train
label_names = [lf.name for lf in lfs]

100%|█████████████████████████████████████████████████████████████████████████████████████████| 1586/1586 [00:09<00:00, 162.87it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:01<00:00, 148.90it/s]


We're going to visualize the embedding later, and it is useful to see how many times each label was assigned to each point. For example, a comment might be labeled ABSTAIN 7 times, SPAM 2 times, and HAM 1 time.

In [23]:
label_mapping = {-1: 'ABSTAIN', 0: 'HAM', 1: 'SPAM'}

labels = pd.DataFrame(L_train, columns=applier._lf_names)
labels = labels.applymap(lambda x: label_mapping[x])

row_counts = labels.apply(lambda row: Counter(row), axis=1)
label_df_hover = pd.DataFrame.from_dict(list(row_counts)).fillna(0.0)
label_df_hover

Unnamed: 0,ABSTAIN,SPAM,HAM
0,8,1.0,1.0
1,7,2.0,1.0
2,10,0.0,0.0
3,10,0.0,0.0
4,10,0.0,0.0
...,...,...,...
1581,8,2.0,0.0
1582,9,0.0,1.0
1583,6,3.0,1.0
1584,7,2.0,1.0


In [63]:
label_df_hover['text'] = df_train['text'].reset_index(drop=True)
label_df_hover['cluster'] = cluster_labels

In [64]:
f = umap.plot.interactive(embedding_model, labels=label_df_hover['cluster'], hover_data=label_df_hover)
umap.plot.show(f)

### Find clusters with a high proportion of a given label

Here we look for clusters that have lots of points with a given label. We can also look on a per-labeling-function basis, not just at an overall label.

In [27]:
def calculate_label_frequencies(df, label_names):
    hits_df = pd.DataFrame()
    grouped = df.groupby('cluster')

    for cluster in grouped.groups:
        g_df = grouped.get_group(cluster)
        # Calculate the frequency of each label within the cluster
        frequencies = g_df[label_names].apply(pd.value_counts, normalize=True).fillna(0.0)

        frequencies_df = frequencies.stack().rename_axis(['label', 'analytic']).to_frame(name='frequency').reset_index()
        frequencies_df.insert(0, 'cluster', cluster)

        hits_df = pd.concat([hits_df, frequencies_df])
    return hits_df

In [28]:
cluster_df = labels
cluster_df.insert(0, 'cluster', cluster_labels)
freqs = calculate_label_frequencies(cluster_df, label_names)

In [29]:
freqs

Unnamed: 0,cluster,label,analytic,frequency
0,-2,ABSTAIN,keyword_my,0.961538
1,-2,ABSTAIN,keyword_subscribe,1.000000
2,-2,ABSTAIN,keyword_http,1.000000
3,-2,ABSTAIN,keyword_please,1.000000
4,-2,ABSTAIN,keyword_song,1.000000
...,...,...,...,...
15,67,HAM,regex_check_out,0.000000
16,67,HAM,short_comment,0.181818
17,67,HAM,has_person_nlp,0.000000
18,67,HAM,textblob_polarity,0.000000


Let's define thresholds for what it means to be "a large proportion of the points". We can define one per analytic/labeling function, with a sensible default value.

In [30]:
def apply_thresholds(row, thresholds):
    # True when this label's frequency meets the threshold for its labeling function
    return row['frequency'] >= thresholds[row['analytic']]

In [31]:
default_threshold = 0.9
thresholds = {k: default_threshold for k in label_names}
labels_of_interest = ['SPAM']
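
The cell above gives every labeling function the same threshold. If some labeling functions are noisier than others, individual thresholds can be overridden on top of the default. A minimal sketch follows; the helper name `build_thresholds` and the 0.95 override are hypothetical values chosen only for illustration.

```python
def build_thresholds(label_names, default=0.9, overrides=None):
    """Map each labeling function name to a threshold, starting from
    `default` and applying any per-function `overrides`."""
    thresholds = {name: default for name in label_names}
    for name, value in (overrides or {}).items():
        if name not in thresholds:
            raise KeyError(f"unknown labeling function: {name}")
        thresholds[name] = value
    return thresholds


example_names = ["keyword_my", "keyword_http", "short_comment"]
example_thresholds = build_thresholds(example_names, overrides={"keyword_my": 0.95})
print(example_thresholds)
```

Raising an error on unknown names guards against a typo silently leaving a labeling function at the default.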

In [32]:
above_threshold = freqs[freqs.apply(apply_thresholds, axis=1, thresholds=thresholds)]
above_threshold

Unnamed: 0,cluster,label,analytic,frequency
0,-2,ABSTAIN,keyword_my,0.961538
1,-2,ABSTAIN,keyword_subscribe,1.000000
2,-2,ABSTAIN,keyword_http,1.000000
3,-2,ABSTAIN,keyword_please,1.000000
4,-2,ABSTAIN,keyword_song,1.000000
...,...,...,...,...
3,67,ABSTAIN,keyword_please,1.000000
4,67,ABSTAIN,keyword_song,1.000000
5,67,ABSTAIN,regex_check_out,1.000000
7,67,ABSTAIN,has_person_nlp,1.000000


Let's look for clusters that have a lot of comments labeled SPAM.

In [33]:
spam_hits = above_threshold[above_threshold['label']=='SPAM']
spam_hits

Unnamed: 0,cluster,label,analytic,frequency
25,1,SPAM,regex_check_out,1.0
25,4,SPAM,regex_check_out,0.954545
25,6,SPAM,regex_check_out,0.918033
22,7,SPAM,keyword_http,1.0
22,9,SPAM,keyword_http,1.0
22,10,SPAM,keyword_http,1.0
22,11,SPAM,keyword_http,0.96
22,12,SPAM,keyword_http,1.0
22,13,SPAM,keyword_http,1.0
22,15,SPAM,keyword_http,0.9


In [34]:
clusters_with_hits = pd.unique(spam_hits['cluster'])

In [35]:
clusters_with_hits

array([ 1,  4,  6,  7,  9, 10, 11, 12, 13, 15, 16, 17, 18, 19, 33, 36, 37,
       38, 50])

### Let's try to combine labels and train a classifier

In [36]:
from snorkel.labeling.model import MajorityLabelVoter
from snorkel.labeling.model import LabelModel
from snorkel.labeling import filter_unlabeled_dataframe
from snorkel.utils import probs_to_preds

from sklearn.linear_model import LogisticRegression

In [37]:
majority_model = MajorityLabelVoter()
preds_train = majority_model.predict(L=L_train)

In [38]:
label_model = LabelModel(cardinality=2, verbose=True)
label_model.fit(L_train=L_train, n_epochs=500, log_freq=100, seed=123)

In [39]:
probs_train = label_model.predict_proba(L=L_train)

In [40]:
df_train_filtered, probs_train_filtered = filter_unlabeled_dataframe(
    X=df_train, y=probs_train, L=L_train
)
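
As I understand it, the filtering step drops the comments on which every labeling function abstained, since they carry no training signal. The plain-Python sketch below illustrates that idea on toy data; it is not Snorkel's implementation, and the helper name `filter_unlabeled` is made up.

```python
ABSTAIN = -1


def filter_unlabeled(rows, label_matrix):
    """Keep only the rows where at least one labeling function voted.

    Returns the kept rows and the boolean mask used to select them."""
    keep = [any(vote != ABSTAIN for vote in votes) for votes in label_matrix]
    kept_rows = [row for row, flag in zip(rows, keep) if flag]
    return kept_rows, keep


toy_comments = ["check out my channel", "nice", "subscribe plz"]
toy_votes = [
    [1, -1, 1],    # two SPAM votes
    [-1, -1, -1],  # everyone abstained, so this row is dropped
    [-1, 1, -1],   # one SPAM vote
]
kept, mask = filter_unlabeled(toy_comments, toy_votes)
print(kept)
print(mask)
```

This is also why the classifier is trained on fewer rows than `df_train` contains, while predictions later in the notebook are made for every comment.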

In [41]:
vectorizer = CountVectorizer(ngram_range=(1, 5))
X_train = vectorizer.fit_transform(df_train_filtered.text.tolist())
X_test = vectorizer.transform(df_test.text.tolist())

In [42]:
preds_train_filtered = probs_to_preds(probs=probs_train_filtered)

In [43]:
sklearn_model = LogisticRegression(C=1e3, solver="liblinear")
sklearn_model.fit(X=X_train, y=preds_train_filtered)

LogisticRegression(C=1000.0, solver='liblinear')

In [44]:
print(f"Test Accuracy: {sklearn_model.score(X=X_test, y=Y_test) * 100:.1f}%")

Test Accuracy: 93.6%


### Plot the embedding while showing the predicted class

In [45]:
# We need to use the same vocabulary that the model was trained on
cv = CountVectorizer(ngram_range=(1, 5), vocabulary=vectorizer.vocabulary_)
X_train_all = cv.fit_transform(df_train.text.tolist())

In [46]:
pred_probs = sklearn_model.predict_proba(X_train_all)

In [47]:
def get_labels_with_confidence(predictions, class_labels, threshold=0.8):
    predicted_labels = []
    for probs in predictions:

        max_idx = np.argmax(probs)
        probability = probs[max_idx]
        if probability >= threshold:
            predicted_labels.append(class_labels[max_idx])
        else:
            predicted_labels.append('ABSTAIN')
    
    return predicted_labels

In [48]:
class_labels = ['HAM', 'SPAM']

In [49]:
comment_text = df_train['text'].reset_index(drop=True)
hover_data = pd.DataFrame({'text': comment_text})
hover_data['cluster'] = cluster_labels
hover_data['predicted_label'] = get_labels_with_confidence(pred_probs, class_labels, threshold=0.8)
hover_data.reset_index(drop=True, inplace=True)

In [50]:
hover_data

Unnamed: 0,text,cluster,predicted_label
0,pls http://www10.vakinha.com.br/VaquinhaE.aspx...,-1,HAM
1,"if your like drones, plz subscribe to Kamal Ta...",19,SPAM
2,go here to check the views :3﻿,66,SPAM
3,"Came here to check the views, goodbye.﻿",66,ABSTAIN
4,"i am 2,126,492,636 viewer :D﻿",67,HAM
...,...,...,...
1581,Check out my mummy chanel!,-1,SPAM
1582,The rap: cool Rihanna: STTUUPID﻿,34,SPAM
1583,I hope everyone is in good spirits I&#39;m a h...,33,SPAM
1584,Lil m !!!!! Check hi out!!!!! Does live the wa...,-1,SPAM


In [51]:
f = umap.plot.interactive(embedding_model, labels=hover_data['cluster'], hover_data=hover_data)
umap.plot.show(f)

In [52]:
cluster_df_preds = hover_data[['cluster', 'predicted_label']]
freqs_preds = calculate_label_frequencies(cluster_df_preds, ['predicted_label'])

In [53]:
freqs_preds

Unnamed: 0,cluster,label,analytic,frequency
0,-2,HAM,predicted_label,0.923077
1,-2,SPAM,predicted_label,0.076923
0,-1,SPAM,predicted_label,0.578182
1,-1,HAM,predicted_label,0.381818
2,-1,ABSTAIN,predicted_label,0.040000
...,...,...,...,...
0,66,SPAM,predicted_label,0.500000
1,66,HAM,predicted_label,0.300000
2,66,ABSTAIN,predicted_label,0.200000
0,67,SPAM,predicted_label,0.545455


In [54]:
default_threshold = 0.9
prediction_thresholds = {k: default_threshold for k in ['predicted_label']}
labels_of_interest = ['SPAM']

In [55]:
above_threshold_preds = freqs_preds[freqs_preds.apply(apply_thresholds, axis=1, thresholds=prediction_thresholds)]
above_threshold_preds

Unnamed: 0,cluster,label,analytic,frequency
0,-2,HAM,predicted_label,0.923077
0,0,HAM,predicted_label,1.0
0,1,SPAM,predicted_label,1.0
0,4,SPAM,predicted_label,0.954545
0,6,SPAM,predicted_label,0.991803
0,15,HAM,predicted_label,0.9
0,18,SPAM,predicted_label,0.978723
0,19,SPAM,predicted_label,1.0
0,21,SPAM,predicted_label,1.0
0,25,HAM,predicted_label,0.913043


In [56]:
spam_hits_predicted = above_threshold_preds[above_threshold_preds['label']=='SPAM']
spam_hits_predicted

Unnamed: 0,cluster,label,analytic,frequency
0,1,SPAM,predicted_label,1.0
0,4,SPAM,predicted_label,0.954545
0,6,SPAM,predicted_label,0.991803
0,18,SPAM,predicted_label,0.978723
0,19,SPAM,predicted_label,1.0
0,21,SPAM,predicted_label,1.0
0,33,SPAM,predicted_label,1.0
0,36,SPAM,predicted_label,0.9
0,37,SPAM,predicted_label,0.9
0,38,SPAM,predicted_label,1.0


In [57]:
clusters_with_predicted_hits = pd.unique(spam_hits_predicted['cluster'])

In [58]:
clusters_with_predicted_hits

array([ 1,  4,  6, 18, 19, 21, 33, 36, 37, 38, 39, 49, 50])

In [59]:
clusters_with_hits

array([ 1,  4,  6,  7,  9, 10, 11, 12, 13, 15, 16, 17, 18, 19, 33, 36, 37,
       38, 50])
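
Comparing the two arrays by eye is error-prone, so a quick set comparison helps. The sketch below copies the cluster ids from the two outputs above into plain Python sets; in the notebook itself you would build the sets from `clusters_with_hits` and `clusters_with_predicted_hits` instead.

```python
# Cluster ids copied from the outputs above
lf_hits = {1, 4, 6, 7, 9, 10, 11, 12, 13, 15, 16, 17, 18, 19, 33, 36, 37, 38, 50}
model_hits = {1, 4, 6, 18, 19, 21, 33, 36, 37, 38, 39, 49, 50}

flagged_by_both = sorted(lf_hits & model_hits)
only_labeling_functions = sorted(lf_hits - model_hits)
only_classifier = sorted(model_hits - lf_hits)

print("both:", flagged_by_both)
print("labeling functions only:", only_labeling_functions)
print("classifier only:", only_classifier)
```

Clusters flagged by only one of the two approaches, such as 12 and 13, are the natural ones to inspect by hand.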

In [60]:
hover_data[hover_data['cluster'] == 13]

Unnamed: 0,text,cluster,predicted_label
30,Have you tried a new social network TSU? This ...,13,HAM
62,"People, here is a new network like FB...you re...",13,HAM
132,need money?Enjoy https://www.tsu.co/emerson_za...,13,HAM
195,https://www.tsu.co/KodysMan plz ^^﻿,13,SPAM
231,https://www.tsu.co/Aseris get money here !﻿,13,HAM
390,"People, here is a new network like FB...you re...",13,HAM
427,https://www.paidverts.com/ref/tomuciux99 esyes...,13,HAM
531,https://www.tsu.co/ToMeks Go register ;) free ...,13,HAM
625,Sign up for free on TSU and start making money...,13,HAM


In [61]:
freqs_preds[freqs_preds['cluster'] == 12]

Unnamed: 0,cluster,label,analytic,frequency
0,12,HAM,predicted_label,0.75
1,12,SPAM,predicted_label,0.25
