In this notebook, we demonstrate how to extract contextual word embeddings with BERT and look up the contextually most similar words using nearest neighbor search.

This notebook was inspired by the Stack Overflow question https://stackoverflow.com/questions/59865719/how-to-find-the-closest-word-to-a-vector-using-bert

Note: We run this notebook on Google Colab in a fresh runtime. Subsequent runs after a runtime reset have had dependency issues.

# learn to extract embeddings from BERT

We use the `bert-embedding` package; see https://pypi.org/project/bert-embedding/

We use a GPU, so please select a GPU runtime in Colab.

In [None]:
!pip install mxnet-cu102
!pip install bert-embedding

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting numpy<2.0.0,>1.16.0
  Downloading numpy-1.23.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.1 MB)
Installing collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 1.14.6
    Uninstalling numpy-1.14.6:
      Successfully uninstalled numpy-1.14.6
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
scipy 1.7.3 requires numpy<1.23.0,>=1.16.5, but you have numpy 1.23.5 which is incompatible.
mxnet 1.4.0 requires numpy<1.15.0,>=1.8.2, but you have numpy 1.23.5 which is incompatible.
bert-embedding 1.0.1 requires numpy==1.14.6, but you have numpy 1.23.5 which is incompatible.
Successfully installed numpy-1.23.5


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting numpy==1.14.6
  Using cached numpy-1.14.6-cp38-cp38-linux_x86_64.whl
Installing collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 1.23.5
    Uninstalling numpy-1.23.5:
      Successfully uninstalled numpy-1.23.5
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
yellowbrick 1.5 requires numpy>=1.16.0, but you have numpy 1.14.6 which is incompatible.
xarray 0.20.2 requires numpy>=1.18, but you have numpy 1.14.6 which is incompatible.
xarray-einstats 0.3.0 requires numpy>=1.19, but you have numpy 1.14.6 which is incompatible.
tifffile 2022.10.10 requires numpy>=1.19.2, but you have numpy 1.14.6 which is incompatible.
thinc 8.1.5 requires numpy>=1.15.0, but you have numpy 1.14.6 which is incompatible.
tensorfl

In [None]:
import mxnet as mx
from bert_embedding import BertEmbedding

In [None]:
# Passing a GPU context did not work for us; this usage may be deprecated
# ctx = mx.gpu(0)
bert = BertEmbedding()

In [None]:
from tqdm.auto import tqdm, trange

In [None]:
bert_abstract = """We introduce a new language representation model called BERT, which stands for Bidirectional Encoder Representations from Transformers.
 Unlike recent language representation models, BERT is designed to pre-train deep bidirectional representations by jointly conditioning on both left and right context in all layers.
 As a result, the pre-trained BERT representations can be fine-tuned with just one additional output layer to create state-of-the-art models for a wide range of tasks, such as question answering and language inference, without substantial task-specific architecture modifications. 
BERT is conceptually simple and empirically powerful. 
It obtains new state-of-the-art results on eleven natural language processing tasks, including pushing the GLUE benchmark to 80.4% (7.6% absolute improvement), MultiNLI accuracy to 86.7 (5.6% absolute improvement) and the SQuAD v1.1 question answering Test F1 to 93.2 (1.5% absolute improvement), outperforming human performance by 2.0%."""

In [None]:
sentences = bert_abstract.split('\n')
result = bert(sentences)
toks, embs = result[0]
print(toks)
print(len(toks), len(embs))
print(embs[0][:10])

['we', 'introduce', 'a', 'new', 'language', 'representation', 'model', 'called', 'bert', ',', 'which', 'stands', 'for', 'bidirectional', 'encoder', 'representations', 'from', 'transformers']
18 18
[ 0.47964773  0.1824888  -0.28597528 -0.46567446  0.01248981 -0.07430505
 -0.18017295  0.37813222  0.9135139  -0.25295883]
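As the output shows, `bert(sentences)` returns one `(tokens, embeddings)` pair per sentence. To search over all tokens at once, these pairs can be flattened into a single matrix plus parallel lists that map each row back to its sentence and token position, which is what `process_sentences` does in the search-index class later in this notebook. A minimal sketch, with made-up 3-dimensional vectors standing in for real BERT output:

```python
import numpy as np

# Made-up stand-in for bert(sentences): one (tokens, embeddings) pair per sentence
result = [
    (["we", "introduce", "bert"],
     [np.array([1.0, 0.0, 0.0]), np.array([0.0, 2.0, 0.0]), np.array([0.0, 0.0, 3.0])]),
    (["bert", "is", "simple"],
     [np.array([1.0, 1.0, 0.0]), np.array([0.0, 1.0, 1.0]), np.array([1.0, 0.0, 1.0])]),
]

sentence_ids, token_ids, all_tokens, rows = [], [], [], []
for i, (toks, embs) in enumerate(result):
    for j, (tok, emb) in enumerate(zip(toks, embs)):
        sentence_ids.append(i)  # which sentence the token came from
        token_ids.append(j)     # position of the token within that sentence
        all_tokens.append(tok)
        rows.append(emb)

matrix = np.stack(rows)  # shape: (total number of tokens, embedding dimension)
print(matrix.shape)      # prints (6, 3)
# Row 3 is the second "bert": sentence 1, position 0
print(all_tokens[3], sentence_ids[3], token_ids[3])  # prints bert 1 0
```

Keeping the parallel lists means a nearest-neighbor hit on row `r` can be traced back to its sentence via `sentence_ids[r]`.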


# process a corpus

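We download the October 2018 dump of Gab posts from pushshift.io (322 MB, one xz-compressed JSON object per line). The 10k web-public .com corpus from https://wortschatz.uni-leipzig.de/en/download/ is a smaller corpus that could be used instead.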


In [None]:
!wget https://files.pushshift.io/gab/GABPOSTS_2018-10.xz

--2022-12-08 10:44:44--  https://files.pushshift.io/gab/GABPOSTS_2018-10.xz
Resolving files.pushshift.io (files.pushshift.io)... 172.67.170.36, 104.21.28.11, 2606:4700:3031::6815:1c0b, ...
Connecting to files.pushshift.io (files.pushshift.io)|172.67.170.36|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 337144496 (322M) [application/octet-stream]
Saving to: ‘GABPOSTS_2018-10.xz.1’


2022-12-08 10:45:03 (16.8 MB/s) - ‘GABPOSTS_2018-10.xz.1’ saved [337144496/337144496]



In [None]:
import lzma
import json
import re

import numpy as np
import pandas as pd

gab_posts = pd.DataFrame()
temp = []
# Strip @mentions, backslash sequences and http(s) URLs from each post body
strip_pattern = re.compile(r"(?:\@|\\|https?\://)\S+")
with lzma.open('GABPOSTS_2018-10.xz', mode='r') as file:
    for counter, line in enumerate(file):
        # Stop after roughly 100k posts to keep memory requirements manageable
        if counter > 100000:
            break
        post = json.loads(line)  # parse each JSON line only once
        # Only the body and date are kept; other fields could be added at the cost of memory
        temp.append({"body": strip_pattern.sub("", post['body']), "date": post['created_at']})
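
A minimal, self-contained sketch of the reading step above, assuming (as that cell does) that each line of the dump is a JSON object with `body` and `created_at` fields. It writes a tiny hypothetical `.xz` file to a temporary directory, then streams it back line by line with a hypothetical helper `read_posts`, parsing each line once and stripping mentions, backslash sequences and URLs with the same regular expression.

```python
import json
import lzma
import os
import re
import tempfile
from itertools import islice

# Same pattern as above: drop @mentions, backslash sequences, and http(s) URLs
STRIP_PATTERN = re.compile(r"(?:\@|\\|https?\://)\S+")


def read_posts(path, max_posts):
    """Stream at most `max_posts` JSON lines from an .xz file into a list of dicts."""
    posts = []
    with lzma.open(path, mode='rt', encoding='utf-8') as file:
        for line in islice(file, max_posts):
            post = json.loads(line)  # parse each line only once
            posts.append({"body": STRIP_PATTERN.sub("", post["body"]),
                          "date": post["created_at"]})
    return posts


# Hypothetical sample data written to a temporary file for illustration
sample = [
    {"body": "hello @bob see https://example.com/a now", "created_at": "2018-10-01T00:00:00"},
    {"body": "second post", "created_at": "2018-10-02T00:00:00"},
    {"body": "third post", "created_at": "2018-10-03T00:00:00"},
]
tmp_dir = tempfile.mkdtemp()
path = os.path.join(tmp_dir, "sample.xz")
with lzma.open(path, mode='wt', encoding='utf-8') as file:
    for post in sample:
        file.write(json.dumps(post) + "\n")

posts = read_posts(path, max_posts=2)
print(posts[0]["body"])  # prints "hello  see  now"
print(len(posts))        # prints 2
```

Because the file is streamed, only the kept posts are held in memory, not the whole 322 MB dump.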


In [None]:
# DataFrame.append was deprecated in pandas 1.4 and removed in 2.0; build the frame directly
gab_posts = pd.DataFrame(temp)

In [None]:
len(temp)

5001

In [None]:
banned_words = ['pussie', 'phonesex', 'footjob', 'horniest', 'clitoris', 'headfuck', 'areola', 'pussies', 'goddamnes', 'suicide', 'voyeurweb', 'suicide girls', 'niggarded', 'deepthroat', 'fuckbuddy', 'nigra', 'freefuck', 'boob', 'hentai', 'rentafuck', 'wanking', 'jerk off', 'molester', 'horney', 'titfuckin', 'milf', 'wrapping men', 'whorefucker', 'masturbating', 'dick', 'honkers', 'chocolate rosebuds', 'neonazi', 'vibrater', 'uptheass', 'shitdick', 'pussycat', 'naked', 'group sex', 'suckmydick', 'pussyeater', 'masturbate', 'stupidfuck', 'nig', 'rape', 'meth', 'virgin', 'livesex', 'terrorist', 'upskirt', 'shortfuck', 'genital', 'jiggaboo', 'marijuana', 'cumshots', 'koon', 'holestuffer', 'tit', 'assbagger', 'ball sack', 'sexpot', 'suckmyass', 'lovejuice', 'phukking', 'wigger', 'black cock', 'whiskeydick', 'blonde on blonde action', 'retarded', 'kunt', 'motherfuckin', 'orgy', 'ejaculation', 'fuckme', 'phone sex', 'fuckher', 'niggerhole', 'intercourse', 'pussylips', 'niggardly', 'tongethruster', 'nig nog', 'kumbullbe', 'nigger', 'wanker', 'peepshpw', 'cocks', 'omorashi', 'female squirting', 'blow job', 'bung hole', 'homicide', 'penetration', 'puddboy', 'gang bang', 'lickme', 'spermhearder', 'titties', 'rigger', 'shitblimp', 'twat', 'fag', 'gangbanger', 'orgasim', 'porno', 'assfuck', 'pussy', 'sodomy', 'cumshot', 'cock', 'jihad', 'niggaz', 'picaninny', 'bondage', 'dry hump', 'poorwhitetrash', 'whitenigger', 'nip', 'masturbation', 'peni5', 'sexed', 'escort', 'g-spot', 'muffindiver', 'fingerbang', 'shite', 'gypo', 'scrotum', 'creampie', 'goddamnmuthafucker', 'foreskin', 'titty', 'dildo', 'sexkitten', 'anus', 'niggling', 'niggerhead', 'footlicker', 'pussylover', 'limpdick', 'fucktard', 'male squirting', 'gangbang', 'nigg', 'suckdick', 'vagina', 'reestie', 'bangbros', 'givehead', 'spank', 'trailertrash', 'giant cock', 'fucktards', 'sexo', 'pussypounder', 'gaymuthafuckinwhore', 'negroid', 'lsd', 'ball gag', 'jigga', "nigger's", 'orgasm', 'nlgger', 'asskiss', 'coprolagnia', 
'boobs', 'pussylicker', 'whitetrash', 'mothafuckings', 'fingering', 'scum', 'paedophile', 'sperm', 'testicle', 'poopchute', 'wank', 'jerkoff', 'octopussy', 'pedophile', 'reverse cowgirl', 'negroes', 'suckmytit', 'big tits', 'sonofbitch', 'swastika', 'jizz', 'sexslave', 'bunghole', 'retard', 'hore', 'nipplering', 'kink', 'nipples', 'vaginal', 'tittie', 'hitler', 'jiggabo', 'pedobear', 'handjob', 'pubic', 'kkk', 'niggled', 'pthc', "negro's", 'doggystyle', 'samckdaddy', 'gangbanged', 'clit', 'hand job', 'beaners', 'ecchi', 'doggy style', 'nutten', 'bdsm', 'cunnilingus', 'killing', 'genitals', 'poop chute', 'fuckfest', 'spermherder', 'brunette action', 'motherfuck', 'cumming', 'erotic', 'splooge moose', 'foursome', 'niglet', 'nigre', 'incest', 'cunt', 'molest', 'threesome', 'kissass', 'narcotic', 'sexhouse', 'nudity', 'fudgepacker', 'snownigger', 'white power', 'jiggerboo', 'honky', 'rosy palm and her 5 sisters', 'nittit', 'horny', 'hotpussy', 'ball sucking', 'nignog', 'palesimian', 'jizjuice', 'zoophilia', 'nigga', 'asslicker', 'niggle', 'nlggor', 'pornography', 'sexing', 'slutt', 'titlicker', 'kunnilingus', 'fuckwhore', 'wet dream', 'spunk', 'pisser', 'puss', 'boner', 'skeet', 'sextoys', 'vibrator', 'manpaste', 'faggot', 'humping', 'nipple', 'double penetration', 'coons', 'assklown', 'pubes', 'fuckface', 'anal', 'nimphomania', 'blowjob', 'rimjob', 'fisting', 'niggardliness', 'sultry women', 'jizzim', 'kinkster', 'skankfuck', 'penis', 'how to kill', 'semen', 'mothafucker', 'analsex', 'niggur', 'panty', 'deep throat', 'foot fetish', 'freakyfucker', 'date rape', 'assblaster', 'bukkake', 'lesbo', 'spaghettinigger', 'beaner', 'clover clamps', 'twobitwhore', 'nigr', 'fuckfriend', 'sextoy', 'prostitute', 'pussyfucker', 'kanake', 'porchmonkey', 'testicles', 'erotism', 'pusy', 'assjockey', 'pimpjuic', 'booty call', 'kaffir', 'fuckable', 'goldenshower', 'homobangers', 'pegging', 'rapist', 'venus mound', 'raping', 'fudge packer', 'sexcam', 'timbernigger', 'viagra', 'make me 
come', 'beastiality', 'leather restraint', 'coon', 'futanari', 'fuckina', 'iblowu', 'masterbate', 'luckycammeltoe', "niggardliness's", 'fuckmehard', 'tits', 'suckme', 'intheass', 'niggarding', 'tonguetramp', 'niggor', 'schlong', 'niggah', 'raped', 'nazi', 'two girls one cup', 'huge fat', 'upthebutt', 'daterape', 'mastabater', 'cum', 'asslick', 'raghead', 'bestiality', 'golden shower', 'niggers', 'penises', 'mufflikcer', 'camel toe', 'shaved pussy', 'niggles', 'jijjiboo']


In [None]:

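# Replace line breaks with spaces and lowercase the text; assign the result back,
# since str.lower() returns a new Series rather than modifying the column in place
gab_posts['body'] = (gab_posts['body']
                     .replace(r'[\r\n]', ' ', regex=True)
                     .str.lower()
                     .replace('', np.nan))
# Drop posts whose body is empty (e.g. a post that was a single link, stripped above)
gab_posts = gab_posts.dropna(subset=['body'])
gab_posts['body']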


0                                      #trade #winning   
1       o deputado arthur lira (pp/al):  ⚠️ responde a...
2                     cocaine mitch comes out swinging.  
3       #demoncrats have no redeeming value. a black c...
4       putting up the rent in one town/city/country w...
                              ...                        
4996                            is this true - holy shit!
4997    #texasfirst #jesuskills #ziohomophobia #martin...
4998    just in: feinstein calls on white house, fbi t...
4999                                  keep thinking that 
5000    huh!! ((((gasp))) gt!!! those are my puppiesss...
Name: body, Length: 4289, dtype: object
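
To check the cleaning steps in isolation, here is a small sketch on a hypothetical toy DataFrame. It replaces line breaks with spaces, lowercases, turns empty bodies into `NaN` and drops them, assigning each result back explicitly; note that `str.lower()` returns a new Series, so its result only persists if it is assigned.

```python
import numpy as np
import pandas as pd

# Hypothetical toy posts standing in for gab_posts
posts = pd.DataFrame({
    "body": ["Line one\nLine two", "", "MiXeD\r\nCase", "plain"],
    "date": ["d1", "d2", "d3", "d4"],
})

# Replace each CR/LF character with a space, lowercase, and mark empty bodies as NaN
posts["body"] = (posts["body"]
                 .replace(r"[\r\n]", " ", regex=True)
                 .str.lower()
                 .replace("", np.nan))
# Drop rows whose body was empty
posts = posts.dropna(subset=["body"])

print(posts["body"].tolist())
# prints ['line one line two', 'mixed  case', 'plain']
```

Since `\r` and `\n` are replaced individually, a Windows line ending becomes two spaces, matching the two separate replacements in the cell above.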

Extract the post bodies as a NumPy array of strings, dropping the pandas row index.

In [None]:
all_sentences = gab_posts['body'].to_numpy()

# create a search index

In [None]:
from sklearn.neighbors import KDTree
import numpy as np


class ContextNeighborStorage:
    def __init__(self, sentences, model):
        self.sentences = sentences
        self.model = model

    def process_sentences(self):
        result = self.model(self.sentences)

        self.sentence_ids = []
        self.token_ids = []
        self.all_tokens = []
        all_embeddings = []
        for i, (toks, embs) in enumerate(tqdm(result)):
            for j, (tok, emb) in enumerate(zip(toks, embs)):
                self.sentence_ids.append(i)
                self.token_ids.append(j)
                self.all_tokens.append(tok)
                all_embeddings.append(emb)
        all_embeddings = np.stack(all_embeddings)
        # Normalize embeddings to unit length, so that ranking neighbors by Euclidean
        # distance gives the same order as ranking by cosine similarity
        self.normed_embeddings = (all_embeddings.T / (all_embeddings**2).sum(axis=1) ** 0.5).T

    def build_search_index(self):
        # this takes some time
        self.indexer = KDTree(self.normed_embeddings)

    def query(self, query_sent, query_word, k=10, filter_same_word=False):
        toks, embs = self.model([query_sent])[0]

        found = False
        for tok, emb in zip(toks, embs):
            if tok == query_word:
                found = True
                break
        if not found:
            raise ValueError('The query word {} is not a single token in sentence {}'.format(query_word, toks))
        emb = emb / np.linalg.norm(emb)  # unit-normalize the query, like the indexed embeddings

        if filter_same_word:
            initial_k = max(k, 100)
        else:
            initial_k = k
        di, idx = self.indexer.query(emb.reshape(1, -1), k=initial_k)
        distances = []
        neighbors = []
        contexts = []
        for i, index in enumerate(idx.ravel()):
            token = self.all_tokens[index]
            if filter_same_word and (query_word in token or token in query_word):
                continue
            distances.append(di.ravel()[i])
            neighbors.append(token)
            contexts.append(self.sentences[self.sentence_ids[index]])
            if len(distances) == k:
                break
        return distances, neighbors, contexts
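
Why normalizing makes a Euclidean KD-tree search equivalent to cosine search: for unit vectors a and b, ||a − b||² = 2 − 2·cos(a, b), so sorting by increasing Euclidean distance gives the same order as sorting by decreasing cosine similarity. A quick numerical check, using random 768-dimensional vectors as stand-ins for BERT token embeddings:

```python
import numpy as np

rng = np.random.default_rng(0)
embeddings = rng.normal(size=(200, 768))  # random stand-ins for token embeddings
query = rng.normal(size=768)

# Unit-normalize rows and the query, as process_sentences() and query() do
normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
q = query / np.linalg.norm(query)

cosine_sim = normed @ q
euclidean = np.linalg.norm(normed - q, axis=1)

# For unit vectors: ||a - b||^2 = 2 - 2 * cos(a, b)
print(np.allclose(euclidean ** 2, 2 - 2 * cosine_sim))  # prints True
# Hence the nearest points by Euclidean distance are the most cosine-similar ones
print(np.array_equal(np.argsort(euclidean)[:10], np.argsort(-cosine_sim)[:10]))  # prints True
```

This is also why the distances returned by `query` lie between 0 and 2.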

Now let's use this indexer

In [None]:
storage = ContextNeighborStorage(sentences=all_sentences, model=bert)
storage.process_sentences()


Building the KD-tree index takes some time.

In [None]:
storage.build_search_index()
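
KD-trees tend to lose their advantage over brute-force search in high-dimensional spaces such as BERT embeddings, so building and querying the tree can be slow. As an alternative (not what this notebook uses), exact k-nearest-neighbor search on unit-normalized embeddings can be done with a single matrix-vector product. The sketch below uses a hypothetical helper `brute_force_knn` and checks on random data that it returns the same neighbors as scikit-learn's `KDTree`.

```python
import numpy as np
from sklearn.neighbors import KDTree


def brute_force_knn(normed_embeddings, normed_query, k):
    """Exact k nearest neighbors of a unit query among unit-normalized rows.

    Returns (distances, indices) sorted by increasing Euclidean distance.
    """
    sims = normed_embeddings @ normed_query             # cosine similarities
    idx = np.argpartition(-sims, k - 1)[:k]             # top-k, in arbitrary order
    idx = idx[np.argsort(-sims[idx])]                   # sort top-k by similarity
    dists = np.sqrt(np.maximum(2 - 2 * sims[idx], 0))   # ||a - b|| = sqrt(2 - 2 cos)
    return dists, idx


rng = np.random.default_rng(1)
emb = rng.normal(size=(500, 64))
emb /= np.linalg.norm(emb, axis=1, keepdims=True)
query = rng.normal(size=64)
query /= np.linalg.norm(query)

bf_dist, bf_idx = brute_force_knn(emb, query, k=5)
kd_dist, kd_idx = KDTree(emb).query(query.reshape(1, -1), k=5)

print(np.array_equal(bf_idx, kd_idx.ravel()))  # prints True
```

The brute-force version needs no index-building step, at the cost of touching every embedding on each query.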

In [None]:
banned_sentences = {word: [] for word in banned_words}

# Collect every post that contains each banned word
for body in gab_posts['body']:
  for word in banned_words:
    # Note: this is a substring test, so short entries also match inside longer,
    # unrelated words; a whole-word match would be stricter
    if word in body:
      banned_sentences[word].append(body)
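
The `word in body` test above is a substring test, so a short entry also matches inside longer, unrelated words. A sketch of whole-word matching with regular expressions, using word boundaries (`\b`) and `re.escape` so that entries containing spaces or punctuation are matched literally; the helper `find_whole_word_matches`, the word list and the posts are hypothetical placeholders.

```python
import re


def find_whole_word_matches(posts, words):
    """Map each word to the posts that contain it as a whole word (case-insensitive)."""
    patterns = {w: re.compile(r"\b" + re.escape(w) + r"\b", re.IGNORECASE) for w in words}
    matches = {w: [] for w in words}
    for body in posts:
        for word, pattern in patterns.items():
            if pattern.search(body):
                matches[word].append(body)
    return matches


# Hypothetical placeholder data
posts = ["Concatenate the strings", "The cat sat down", "I like ice cream"]
matches = find_whole_word_matches(posts, ["cat", "ice cream"])
print(matches["cat"])        # prints ['The cat sat down']
print(matches["ice cream"])  # prints ['I like ice cream']
```

Here "Concatenate" no longer counts as a hit for "cat", which a plain substring test would report.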

In [None]:
# For each banned word, query nearest neighbors for at most 100 of its sentences
# (checking at most 1000 sentences) and count how often each neighbor token appears
final_count = {}
for banned_word, sentences in banned_sentences.items():
  ctr_passed_sentences = 0
  ctr_total_checked = 0
  final_count[banned_word] = {}
  for sentence in sentences:
    if ctr_passed_sentences >= 100 or ctr_total_checked >= 1000:
      break
    ctr_total_checked += 1
    try:
      # bert-embedding truncates each input to a fixed number of tokens (the max_seq_length
      # argument of BertEmbedding), so the word may be cut off and query() raises ValueError.
      # Multi-word entries can never match a single token and fail the same way.
      distances, neighbors, contexts = storage.query(query_sent=sentence, query_word=banned_word, k=5, filter_same_word=True)
    except ValueError:
      continue
    for w in neighbors:
      final_count[banned_word][w] = final_count[banned_word].get(w, 0) + 1
    ctr_passed_sentences += 1

Those are #Muzhood Agent Obama's Faggots he forced in as Generals to Awol pussies,  Where 6 Marines were murdered looking for the fag. 
The query word pussie is not a single token in sentence ['those', 'are', '#', 'muzhood', 'agent', 'obama', "'", 's', 'faggots', 'he', 'forced', 'in', 'as', 'generals', 'to', 'awol', 'pussi']
These women were all liars, in my view just like the Duke Lacross, Twanna brawley lie, the liars against Trump that were paid $750,000,the Judge Roy Moore accusers that are about to go to trial. All of them. It is weaponized old pussies. Time to fight!!
The query word pussie is not a single token in sentence ['these', 'women', 'were', 'all', 'liars', ',', 'in', 'my', 'view', 'just', 'like', 'the', 'duke', 'lacross', ',', 'twanna', 'brawley', 'lie']
Those are #Muzhood Agent Obama's Faggots he forced in as Generals to Awol pussies,  Where 6 Marines were murdered looking for the fag. 
The query word pussies is not a single token in sentence ['those', 'are', '#', 'muzh

In [None]:
finalized_top_10 = {}

for key, occur_dict in final_count.items():
  # Sort neighbor tokens by number of occurrences (descending) and keep the top 10;
  # banned words without any neighbors get an empty dict
  top_10 = sorted(occur_dict.items(), key=lambda item: item[1], reverse=True)[:10]
  finalized_top_10[key] = dict(top_10)
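
The counting loop and the top-10 selection can also be written with `collections.Counter`, whose `most_common(n)` returns the `n` most frequent items in descending order of count. A sketch on hypothetical neighbor lists, standing in for the `neighbors` returned by several `storage.query` calls for the same word:

```python
from collections import Counter

# Hypothetical neighbor lists returned by three queries for the same word
neighbor_lists = [["dog", "cat", "bird"], ["cat", "fish", "dog"], ["cat"]]

counts = Counter()
for neighbors in neighbor_lists:
    counts.update(neighbors)  # add 1 for each occurrence of each neighbor token

top_10 = dict(counts.most_common(10))
print(top_10)  # prints {'cat': 3, 'dog': 2, 'bird': 1, 'fish': 1}
```

Tied counts keep their first-seen order, and an empty `Counter` yields an empty dict, matching the behavior of the cell above.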

In [None]:
import json
with open('top_10_100000.json', 'w', encoding='utf-8') as f:
    json.dump(finalized_top_10, f, ensure_ascii=False, indent=4)