[Lucy and Bamman 2021](https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00383/101877/Characterizing-English-Variation-across-Social) uses KMeans clustering over BERT representations to induce word senses, and then characterizes how those senses are used distinctively within online communities. In this notebook, we'll explore inferring distinct senses by clustering.
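Stripped of BERT, the core idea is small: each occurrence of a word becomes a vector, and KMeans partitions those vectors so that each cluster can be read as a candidate sense. The sketch below assumes synthetic stand-in vectors (two tight clouds in 768 dimensions, the size of `bert-base-uncased` hidden states) instead of real BERT output, so it runs without downloading a model.

```python
import numpy as np
from sklearn.cluster import KMeans

# Hypothetical stand-ins for contextual token vectors: two "senses",
# each a tight cloud of 10 points around its own random center.
rng = np.random.default_rng(0)
dim = 768
center_a = rng.normal(size=dim)
center_b = rng.normal(size=dim)
sense_a = center_a + 0.05 * rng.normal(size=(10, dim))
sense_b = center_b + 0.05 * rng.normal(size=(10, dim))
vectors = np.vstack([sense_a, sense_b])

# Each occurrence receives a cluster ID; the IDs themselves are arbitrary (0 vs 1).
kmeans = KMeans(n_clusters=2, n_init=10, random_state=0).fit(vectors)
labels = kmeans.labels_
print(labels)
```

With real data the clouds are far less tidy, which is why the rest of the notebook inspects the sentences in each cluster rather than trusting the labels blindly.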

In [None]:
from transformers import BertModel, BertTokenizer
import numpy as np
from sklearn.cluster import KMeans
from collections import Counter

In [None]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

In [None]:
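import torch

def get_bert_for_token(string, term):
    
    # tokenize, truncating to BERT's 512-wordpiece limit so long lines don't crash the model
    inputs = tokenizer(string, return_tensors="pt", truncation=True)
    
    # convert input ids to wordpiece tokens
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    
    # find the first location of the query term among those tokens (so we know which BERT rep to use);
    # this raises a ValueError if the term was split into several wordpieces or truncated away
    term_idx = tokens.index(term)
    
    # run BERT; no_grad skips gradient bookkeeping, since we only need the forward pass
    with torch.no_grad():
        outputs = model(**inputs)

    # return the final-layer BERT rep for that token index
    # The output is a pytorch tensor object, so convert it to a numpy array to work with numpy/sklearn functions
    return outputs.last_hidden_state[0][term_idx].numpy()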
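`tokens.index(term)` only works when the query word is a single wordpiece in BERT's vocabulary; a rarer word is split into pieces (continuations are prefixed with `##`), so the lookup fails. One possible workaround, sketched here with a hypothetical helper `find_wordpiece_span`, is to locate the run of pieces and average their vectors. In the notebook, `pieces` would come from `tokenizer.tokenize(term)` and `hidden` from `outputs.last_hidden_state[0]`; below, both are hand-written stand-ins, not real tokenizer or model output.

```python
import numpy as np

def find_wordpiece_span(tokens, pieces):
    """Return the indices of the first contiguous run of `pieces` inside `tokens`
    that is not followed by a '##' continuation (i.e. is a whole word), else None."""
    n = len(pieces)
    for start in range(len(tokens) - n + 1):
        end = start + n
        followed_by_continuation = end < len(tokens) and tokens[end].startswith("##")
        if tokens[start:end] == pieces and not followed_by_continuation:
            return list(range(start, end))
    return None

# Illustrative wordpieces (made up for this example).
tokens = ["[CLS]", "a", "foo", "##bar", "here", "[SEP]"]
pieces = ["foo", "##bar"]

# Stand-in for last_hidden_state[0]: 6 tokens x 2 dimensions.
hidden = np.arange(12, dtype=float).reshape(6, 2)

span = find_wordpiece_span(tokens, pieces)
term_vec = hidden[span].mean(axis=0)  # average the piece vectors into one word vector
print(span, term_vec)
```

Averaging subword vectors is one common choice, not necessarily what Lucy and Bamman do; using only the first piece is another.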

In [None]:
def read_data(filename):
    data=[]
    with open(filename, encoding="utf-8") as file:
        for line in file:
            data.append(line.rstrip())
    return data

First, let's examine uses of the word "cabinet" from several contemporary novels.

In [None]:
data=read_data("../data/cabinet.txt")
reps=[]
for sentence in data:
    reps.append(get_bert_for_token(sentence, "cabinet"))

In [None]:
kmeans = KMeans(n_clusters=2, random_state=0).fit(reps)

In [None]:
for idx in np.argsort(kmeans.labels_):
    print("%s\t%s" % (kmeans.labels_[idx], data[idx]))
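Sorting by label is enough for eyeballing, but grouping the sentences by cluster also makes it easy to see how large each cluster is (for instance, whether one sense dominates). A minimal sketch with a hypothetical `group_by_cluster` helper follows; the sentences and labels are invented for illustration, and in the notebook you would call `group_by_cluster(kmeans.labels_, data)` instead.

```python
from collections import defaultdict

def group_by_cluster(labels, sentences):
    """Map each cluster ID to the list of sentences assigned to it."""
    groups = defaultdict(list)
    for label, sentence in zip(labels, sentences):
        groups[int(label)].append(sentence)  # int() turns numpy ints into plain keys
    return dict(groups)

# Invented example sentences and labels, for illustration only.
toy_sentences = [
    "She put the plates back in the cabinet.",
    "The cabinet resigned after the vote.",
    "He locked the filing cabinet.",
]
toy_labels = [0, 1, 0]

groups = group_by_cluster(toy_labels, toy_sentences)
for cluster_id, members in sorted(groups.items()):
    print(cluster_id, len(members))
```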

Now let's examine a word that has more polysemy: *right*. Explore clustering with different numbers of clusters: how many clusters do you need before they match what you would consider the right number of distinct senses?
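One quantitative aid when varying the number of clusters is the silhouette score, which rewards clusters that are tight and well separated. This is a heuristic swapped in here as an assumption, not the procedure from Lucy and Bamman; the sketch runs on synthetic 2-D data with three obvious groups, and for the notebook you would pass `np.array(reps)` instead. A high score does not guarantee the clusters line up with senses a linguist would recognize, so reading the example sentences is still the real test.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

def silhouette_by_k(vectors, k_values, random_state=0):
    """Fit KMeans for each k and return {k: mean silhouette score}."""
    scores = {}
    for k in k_values:
        labels = KMeans(n_clusters=k, n_init=10, random_state=random_state).fit_predict(vectors)
        scores[k] = silhouette_score(vectors, labels)
    return scores

# Hypothetical stand-in data: three well-separated clouds of 20 points each.
rng = np.random.default_rng(0)
centers = np.array([[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]])
vectors = np.vstack([c + 0.1 * rng.normal(size=(20, 2)) for c in centers])

scores = silhouette_by_k(vectors, range(2, 7))
best_k = max(scores, key=scores.get)
print(best_k)
```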

In [None]:
data=read_data("../data/right200.txt")
reps=[]
for sentence in data:
    reps.append(get_bert_for_token(sentence, "right"))

In [None]:
kmeans = KMeans(n_clusters=2, random_state=0).fit(reps)

In [None]:
max_per_class=5
cluster_counts=Counter()
last_lab=None
for idx in np.argsort(kmeans.labels_):
    clusterID=kmeans.labels_[idx]
    if cluster_counts[clusterID] < max_per_class:
        cluster_counts[clusterID]+=1
        if clusterID != last_lab and last_lab is not None:
            print()
        last_lab=clusterID
        print("%s\t%s" % (clusterID, data[idx]))
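The loop above prints whichever five sentences of each cluster come first in sort order, which is essentially arbitrary. An alternative is to show the members closest to each cluster center, which tend to be the most typical uses. The sketch below uses scikit-learn's `KMeans.transform`, which returns each point's distance to every cluster center; `closest_to_centroid` is a hypothetical helper and the 1-D data are a synthetic stand-in. In the notebook you would call `closest_to_centroid(kmeans, np.array(reps))` and print `data[idx]` for each returned index.

```python
import numpy as np
from sklearn.cluster import KMeans

def closest_to_centroid(kmeans, vectors, max_per_class=5):
    """Return {cluster ID: indices of that cluster's members nearest its center}."""
    distances = kmeans.transform(vectors)  # shape (n_samples, n_clusters)
    result = {}
    for cluster_id in range(kmeans.n_clusters):
        members = np.where(kmeans.labels_ == cluster_id)[0]
        order = np.argsort(distances[members, cluster_id])  # nearest first
        result[cluster_id] = members[order[:max_per_class]].tolist()
    return result

# Hypothetical stand-in data: two groups around 0 and around 10.
vectors = np.array([[0.0], [0.5], [-2.0], [10.0], [13.0], [9.8]])
kmeans = KMeans(n_clusters=2, n_init=10, random_state=0).fit(vectors)

nearest = closest_to_centroid(kmeans, vectors, max_per_class=2)
print(nearest)
```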