# 2024 COMP90042 Project
*Make sure you change the file name with your group id.*

# Readme
*If there is something to be noted for the marker, please mention here.*

*If you are planning to implement a program with Object Oriented Programming style, please put those the bottom of this ipynb file*

# 1.DataSet Processing
(You can add as many code blocks and text blocks as you need. However, YOU SHOULD NOT MODIFY the section title)

In [22]:
import pickle
with open("evidence_v1.pkl", "rb") as f:
    evidence = pickle.load(f)

In [18]:
import json
import re
import pandas as pd
import numpy as np

import nltk
nltk.download('punkt')

# CHANGE THIS TO FALSE ON SUBMISSON
# OR EXPERIMENTING OUTSIDE OF COLAB
LOCAL = False

dir = ""
if LOCAL:
  # In local testing, load data from drive
  from google.colab import drive
  drive.mount('/content/drive')

  dir = "drive/MyDrive/Colab Notebooks/Colab Files/"

# File paths
F_TRAIN = dir + "train-claims.json"
F_DEV = dir + "dev-claims.json"
F_UNLABELLED = dir + "test-claims-unlabelled.json"
F_EVIDENCE = dir + "evidence.json"
F_BASELINE = dir + "dev-claims-baseline.json"

# Loading relevant data
with open(F_TRAIN) as f_train, \
     open(F_DEV) as f_dev, \
     open(F_UNLABELLED) as f_unlabelled, \
     open(F_EVIDENCE)   as f_evidence, \
     open(F_BASELINE) as f_baseline:

  d_train = json.load(f_train)
  d_dev = json.load(f_dev)
  d_unlabelled = json.load(f_unlabelled)
  d_evidence = json.load(f_evidence)
  d_baseline = json.load(f_baseline)

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\mrpea\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [19]:
from nltk.stem.wordnet import WordNetLemmatizer
from nltk.tokenize import word_tokenize
from gensim.utils import deaccent
import spacy
nlp = spacy.load('en_core_web_sm')

nltk.download('wordnet')

from unidecode import unidecode


lemmatizer = WordNetLemmatizer()


def sentence_preprocessing(sentence):

    # Use gensim deaccent to match more characters to [a-z]
    sentence = nlp(sentence.lower())
    
    # Keep any token with 
    return [lemmatizer.lemmatize(word) for word in word_tokenize(sentence) if any(c.isalnum() for c in word)]


def evidence_preprocessing(evidences):
  processed = []
  for id, evidence in evidences.items():
    row = []
    row.append(id)
    row.append(evidence)
    row.append(sentence_preprocessing(evidence))

    # Appending an empty list to populate with embeddings later
    row.append([])

    processed.append(row)

  return pd.DataFrame(processed, columns = ["id", "raw evidence", "processed evidence", "embeddings"])


evidence = evidence_preprocessing(d_evidence)
evidence.head()


[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\mrpea\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


Unnamed: 0,id,raw evidence,processed evidence,embeddings
0,evidence-0,"John Bennet Lawes, English entrepreneur and ag...","[john, bennet, lawes, english, entrepreneur, a...",[]
1,evidence-1,Lindberg began his professional career at the ...,"[lindberg, began, his, professional, career, a...",[]
2,evidence-2,``Boston (Ladies of Cambridge)'' by Vampire We...,"[boston, lady, of, cambridge, by, vampire, wee...",[]
3,evidence-3,"Gerald Francis Goyer (born October 20, 1936) w...","[gerald, francis, goyer, born, october, 20, 19...",[]
4,evidence-4,He detected abnormalities of oxytocinergic fun...,"[he, detected, abnormality, of, oxytocinergic,...",[]


In [29]:
def claim_preprocessing(claims):
  processed = []
  for id, inner in claims.items():
    row = []
    row.append(id)
    row.append(inner.get("claim_text"))
    row.append(sentence_preprocessing(inner.get("claim_text")))

    # No label or evidence for unlabelled set
    row.append(inner.get("claim_label", None))
    row.append(inner.get("evidences", None))

    processed.append(row)

  return pd.DataFrame(processed, columns = ["id", "claim_text", "processed text", "claim_label", "evidences"])



train = claim_preprocessing(d_train)
dev = claim_preprocessing(d_dev)
unlabelled = claim_preprocessing(d_unlabelled)


dev.head()

Unnamed: 0,id,claim_text,processed text,claim_label,evidences
0,claim-752,[South Australia] has the most expensive elect...,"[south, australia, ha, the, most, expensive, e...",SUPPORTS,"[evidence-67732, evidence-572512]"
1,claim-375,when 3 per cent of total annual global emissio...,"[when, 3, per, cent, of, total, annual, global...",NOT_ENOUGH_INFO,"[evidence-996421, evidence-1080858, evidence-2..."
2,claim-1266,This means that the world is now 1C warmer tha...,"[this, mean, that, the, world, is, now, 1c, wa...",SUPPORTS,"[evidence-889933, evidence-694262]"
3,claim-871,"“As it happens, Zika may also be a good model ...","[a, it, happens, zika, may, also, be, a, good,...",NOT_ENOUGH_INFO,"[evidence-422399, evidence-702226, evidence-28..."
4,claim-2164,Greenland has only lost a tiny fraction of its...,"[greenland, ha, only, lost, a, tiny, fraction,...",REFUTES,"[evidence-52981, evidence-264761, evidence-947..."


In [27]:
from gensim.models import Word2Vec
EMBEDDING_DIM = 100

embedding_model = Word2Vec(sentences=evidence["processed evidence"],
                           vector_size=EMBEDDING_DIM,
                           window=4,
                           min_count=1,
                           workers=10,
                           negative=5
                           )

In [28]:
def sentence_embedding(sentence):

  # Failsafe
  if len(sentence) == 0:
    return np.zeros(EMBEDDING_DIM)

  embedding = np.zeros(EMBEDDING_DIM)
  for word in sentence:
    word_embedding = np.zeros(EMBEDDING_DIM)

    # get word vector for given word
    # if not found, ignore (treat as having the zero vector)
    try:
      word_embedding = embedding_model.wv[str(word)]
    except KeyError:
      pass

    embedding += word_embedding

  return embedding / len(sentence)


# Populate the empty column with embeddings
evidence["embeddings"] = evidence["processed evidence"].apply(sentence_embedding)
evidence.head()

Unnamed: 0,id,raw evidence,processed evidence,embeddings
0,evidence-0,"John Bennet Lawes, English entrepreneur and ag...","[john, bennet, lawes, english, entrepreneur, a...","[1.8147023021010682, 0.10596465552225709, -0.3..."
1,evidence-1,Lindberg began his professional career at the ...,"[lindberg, began, his, professional, career, a...","[2.707790378895071, -0.2509865226844947, -1.38..."
2,evidence-2,``Boston (Ladies of Cambridge)'' by Vampire We...,"[boston, lady, of, cambridge, by, vampire, wee...","[2.6244308948516846, -0.3796046035630362, -0.6..."
3,evidence-3,"Gerald Francis Goyer (born October 20, 1936) w...","[gerald, francis, goyer, born, october, 20, 19...","[2.1674460284835235, -0.1364949027245695, -0.9..."
4,evidence-4,He detected abnormalities of oxytocinergic fun...,"[he, detected, abnormality, of, oxytocinergic,...","[0.7273209213333971, 0.22991515433087067, -0.7..."


In [31]:
import pickle
version = 2

print("Saving...")
with open(f"embeddings_v{version}.pkl", "wb") as f:
    pickle.dump(embedding_model, f)

with open(f"evidence_v{version}.pkl", "wb") as f:
    pickle.dump(evidence, f)
    
print("Embeddings and processed evidence saved.")

with open(f"train_v{version}.pkl","wb") as f_train,\
     open(f"dev_v{version}.pkl", "wb") as f_dev,\
     open(f"unlabelled_v{version}.pkl","wb") as f_unlabelled:
    pickle.dump(train, f_train)
    pickle.dump(dev, f_dev)
    pickle.dump(unlabelled, f_unlabelled)

print("Train, dev, test sets saved.")


Saving...
Embeddings and processed evidence saved.
Train, dev, test sets saved.


# 2. Model Implementation
(You can add as many code blocks and text blocks as you need. However, YOU SHOULD NOT MODIFY the section title)

In [44]:
# Baseline retrieval: immediately use the raw embeddings to retrieve closest sentences
# Train a cutoff distance threshold.

from scipy.spatial.distance import cosine

# Similarity based on cosine similarity ([0-1], higher the more similar)
def similarity(text, evidence_ids):

    # Seems stupid and retrieving everything from w2v is probably cleaner
    # TODO: make this better
    evidence_embeddings = [evidence.loc[evidence['id'] == id, 'embeddings'].values[0] for id in evidence_ids]
    key_embedding = sentence_embedding(text)
    
    similarities = []
    for evidence_embedding in evidence_embeddings:
        similarities.append(1-cosine(key_embedding, evidence_embedding))

    return similarities


# Using 1 - fscore as the loss
def retrieval_loss(prediction, target):
    numerator = 0
    denominator = 0
    
    for p in prediction:
        if p in target:
            denominator += 2
            numerator += 2
        else:
            denominator += 1
    
    for t in target:
        if t not in prediction:
            denominator += 1
    
    return 1 - numerator/denominator



print(similarity(dev.loc[0, "processed text"], dev.loc[0, "evidences"]))

print(similarity(dev.loc[0,"processed text"], evidence.loc[[1,23,15,35,4444,10000,48,223,4499], "id"]))

[0.9588124658068482, 0.9708111020526521]
[0.8053056696120621, 0.8792417071129837, 0.8131490764220766, 0.8218331009062386, 0.8818242052551241, 0.6063855121980968, 0.8255621188432163, 0.7519479433293833, 0.6837166892458209]


In [47]:
tune = dev.copy(deep = True)

# https://stackoverflow.com/a/39259437
# Reference: apply on df on more than one column
tune['sim'] = tune[['processed text','evidences']].apply(lambda x: similarity(*x), axis=1)
tune.head()

Unnamed: 0,id,claim_text,processed text,claim_label,evidences,sim
0,claim-752,[South Australia] has the most expensive elect...,"[south, australia, ha, the, most, expensive, e...",SUPPORTS,"[evidence-67732, evidence-572512]","[0.9588124658068482, 0.9708111020526521]"
1,claim-375,when 3 per cent of total annual global emissio...,"[when, 3, per, cent, of, total, annual, global...",NOT_ENOUGH_INFO,"[evidence-996421, evidence-1080858, evidence-2...","[0.9483460908942689, 0.9404415279228605, 0.955..."
2,claim-1266,This means that the world is now 1C warmer tha...,"[this, mean, that, the, world, is, now, 1c, wa...",SUPPORTS,"[evidence-889933, evidence-694262]","[0.8901516367091654, 0.9028373134170873]"
3,claim-871,"“As it happens, Zika may also be a good model ...","[a, it, happens, zika, may, also, be, a, good,...",NOT_ENOUGH_INFO,"[evidence-422399, evidence-702226, evidence-28...","[0.8934156082157716, 0.8796556127921996, 0.925..."
4,claim-2164,Greenland has only lost a tiny fraction of its...,"[greenland, ha, only, lost, a, tiny, fraction,...",REFUTES,"[evidence-52981, evidence-264761, evidence-947...","[0.9336424264743003, 0.9096821736935496, 0.724..."


In [61]:
total = 0
count = 0
low_acc_rows = []
for i, cell in enumerate(tune['sim']):
    total += sum(cell)
    count += len(cell)

    if sum(cell)/len(cell) < 0.8:
        low_acc_rows.append(i)

print(low_acc_rows)

print(f"Average gold label w2v similarity: {total/count}")

tune.loc[low_acc_rows,:]

[5, 32, 34, 37, 38, 93, 97, 109, 120, 131, 141]
Average gold label w2v similarity: 0.8938452608162429


Unnamed: 0,id,claim_text,processed text,claim_label,evidences,sim
5,claim-1607,CO2 limits won't cool the planet.,"[co2, limit, wo, n't, cool, the, planet]",NOT_ENOUGH_INFO,"[evidence-913997, evidence-955328, evidence-40...","[0.812856634036198, 0.7118861282568756, 0.7928..."
32,claim-2168,IPCC graph showing accelerating trends is misl...,"[ipcc, graph, showing, accelerating, trend, is...",SUPPORTS,[evidence-41418],[0.5853555463893919]
34,claim-2426,"""Twentieth century global warming did not star...","[twentieth, century, global, warming, did, not...",REFUTES,[evidence-697238],[0.7870948671978204]
37,claim-2593,[T]he study indicates “Greenland’s ice may be ...,"[t, he, study, indicates, greenland, s, ice, m...",REFUTES,"[evidence-857269, evidence-797161]","[0.9022972061700051, 0.5573798282483118]"
38,claim-1567,IPCC overestimate temperature rise.,"[ipcc, overestimate, temperature, rise]",REFUTES,[evidence-105184],[0.5259836084705118]
93,claim-1933,"Newt Gingrich ""teamed with Nancy Pelosi and Al...","[newt, gingrich, teamed, with, nancy, pelosi, ...",NOT_ENOUGH_INFO,"[evidence-380361, evidence-977987, evidence-83...","[0.8139702982113997, 0.7765324931500546, 0.753..."
97,claim-1734,Jim Hansen had several possible scenarios; his...,"[jim, hansen, had, several, possible, scenario...",NOT_ENOUGH_INFO,"[evidence-1005240, evidence-485896, evidence-1...","[0.8079354959973795, 0.7865658690722043, 0.713..."
109,claim-392,a study that totally debunks the whole concept...,"[a, study, that, totally, debunks, the, whole,...",REFUTES,"[evidence-724010, evidence-846906, evidence-10...","[0.9040239554800842, 0.6142145576008137, 0.804..."
120,claim-1668,Thick arctic sea ice is in rapid retreat.,"[thick, arctic, sea, ice, is, in, rapid, retreat]",SUPPORTS,"[evidence-280204, evidence-1200544, evidence-6...","[0.8151305429609677, 0.7420533257334931, 0.767..."
131,claim-1928,"NASA Finds Antarctica is Gaining Ice,","[nasa, find, antarctica, is, gaining, ice]",DISPUTED,"[evidence-1099128, evidence-529248, evidence-5...","[0.7406952061582817, 0.8170975340261469, 0.719..."


In [101]:

# Look into gold labels with poor similarity values
for row in low_acc_rows:
    print(f"Row {row}, Verdict: {tune.loc[row, 'claim_label']}:")
    print("CLAIM: "+ " ".join(tune.loc[row, 'processed text']))

    for i, evidence_id in enumerate(tune.loc[row, 'evidences']):
        print("EVIDE: " + f"{tune.loc[row, 'sim'][i]:.4f} " + " ".join(evidence.loc[evidence['id'] == evidence_id, 'processed evidence'].values[0]))
    print('-'*70)

Row 5, Verdict: NOT_ENOUGH_INFO:
CLAIM: co2 limit wo n't cool the planet
EVIDE: 0.8129 le energy reach the upper atmosphere which is therefore cooler because of this absorption
EVIDE: 0.7119 occupational co 2 exposure limit have been set in the united state at 0.5 5000 ppm for an eight-hour period
EVIDE: 0.7929 a the temperature rise closer to the value the white daisy like the white daisy outreproduce the black daisy leading to a larger percentage of white surface and more sunlight is reflected reducing the heat input and eventually cooling the planet
EVIDE: 0.8087 global warming is the long-term rise in the average temperature of the earth 's climate system
EVIDE: 0.8617 if cloud cover increase more sunlight will be reflected back into space cooling the planet
----------------------------------------------------------------------
Row 32, Verdict: SUPPORTS:
CLAIM: ipcc graph showing accelerating trend is misleading
EVIDE: 0.5854 the ipcc need to look at this trend in the error and ask

In [None]:
import torch
import torch.nn as nn


# 3.Testing and Evaluation
(You can add as many code blocks and text blocks as you need. However, YOU SHOULD NOT MODIFY the section title)

In [10]:
import spacy
nlp = spacy.load("en_core_web_sm")

In [12]:
stopwords = spacy.lang.en.stop_words.STOP_WORDS
print(stopwords)


{'former', 'move', 'latterly', 'otherwise', 'on', 'something', 'until', 'whereas', 'whereby', 'ourselves', 'should', 'we', 'except', 'it', 'even', 'same', 'doing', 'towards', 'than', 'due', 'become', 'whether', 'already', 'if', 'here', 'where', "'ll", 'seems', 'became', 'twenty', '’ve', 'front', 'give', 'behind', '’s', 'also', 'nowhere', 'never', 'our', 'really', 'whenever', 'from', 'seeming', 'elsewhere', 'hers', 'may', 'after', 'many', 'unless', 'nothing', 'often', 'n’t', 'five', 'somewhere', 'though', 'before', 'go', 'her', 'throughout', '’d', 'therein', 'see', 'used', 'itself', 'any', 'beside', 'one', 'alone', 'done', 'perhaps', 'at', 'thereby', 'upon', 'anyhow', 'might', 'there', 'everyone', 'side', 'must', 'all', 'his', 'me', 'was', 'yourself', '‘ll', 'third', 'i', 'too', 'how', 'indeed', 'nor', 'can', 'being', 'when', 'per', 'afterwards', 'say', 'since', 'beforehand', 'to', 'fifteen', 'four', 'beyond', 'back', 'enough', 'thru', 'did', 'most', 'meanwhile', 'several', 'their', 'th

In [13]:
import nltk
from nltk.corpus import stopwords
nltk.download('stopwords')
print(stopwords.words('english'))

['i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', "you're", "you've", "you'll", "you'd", 'your', 'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she', "she's", 'her', 'hers', 'herself', 'it', "it's", 'its', 'itself', 'they', 'them', 'their', 'theirs', 'themselves', 'what', 'which', 'who', 'whom', 'this', 'that', "that'll", 'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has', 'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an', 'the', 'and', 'but', 'if', 'or', 'because', 'as', 'until', 'while', 'of', 'at', 'by', 'for', 'with', 'about', 'against', 'between', 'into', 'through', 'during', 'before', 'after', 'above', 'below', 'to', 'from', 'up', 'down', 'in', 'out', 'on', 'off', 'over', 'under', 'again', 'further', 'then', 'once', 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any', 'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such', 'no', 'nor', 'not', 'only', 'own', 'same', 'so', 'than', '

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\mrpea\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [20]:
vocab = [k for k in embedding_model.wv.key_to_index.keys()]
nonstandard = [w for w in vocab if any(c not in "qwertyuiopasdfghjklzxcvbnmQWERTYUIOPASDFGHJKLZXCVBNM0123456789 .'-–+,\\/[]()?!" for c in w)]

print(nonstandard)


NameError: name 'embedding_model' is not defined

In [45]:
from gensim.utils import deaccent

nonstandard2 = (deaccent(w) for w in nonstandard)
nonstandard2 = [w for w in nonstandard2 if any(c not in "qwertyuiopasdfghjklzxcvbnmQWERTYUIOPASDFGHJKLZXCVBNM0123456789 .'-–+,\\/[]()?!" for c in w)]
print(nonstandard2)

['²', 'Łodz', '°C', '°F', 'اباد', 'Białystok', 'Wrocław', 'Stanisław', 'محمد', 'علي', 'Słupsk', 'α', 'عليا', 'Александр', 'Władysław', '8:00', 'Sokołka', 'Biała', 'بن', 'Aydın', 'قلعه', 'Møre', 'Александрович', '³', 'سفلي', 'Płock', '10:00', 'ε', 'င', 'عبد', 'β', 'Michał', 'چاه', '9:00', 'μm', 'Николаевич', 'Sør-Trøndelag', 'اباد', 'θ', 'Tromsø', 'Владимирович', 'حسن', 'Сергеи', 'Adıyaman', 'Suwałki', 'Ostrołeka', 'Paweł', 'မ', 'Chełm', 'ರ', 'Łomza', ':20', 'π', '1.5°C', 'Иванович', 'န', 'Владимир', 'Bełchatow', 'ð', 'km²', '8:30', '6:00', 'Jørgen', 'က', 'Włocławek', 'تلمبه', 'مزرعه', 'µm', 'Nord-Trøndelag', 'Płonsk', 'μ', 'Encyclopædia', 'Bjørn', 'Николаи', 'حسين', 'Иван', '9:30', 'Đong', 'Сергеевич', 'خان', 'Sørensen', 'Евгении', 'Mała', '7:00', 'Łukow', 'Андреи', 'Søren', 'ζ', 'Mława', 'Groß', 'δ', 'Bolesław', 'ಗ', '11:00', 'احمد', 'pɾi', 'Bartın', 'Szamotuły', 'də', 'Małe', 'Østfold', '4:00', 'Tønsberg', 'Bærum', 'Łeczyca', 'ಮ', 'حاجي', 'κ', 'Sokołow', 'ðe', 'Jørgensen', 'λ', 'ده',

In [56]:
import unidecode as ud
nonstandard3 = [ud.unidecode(w) for w in nonstandard]
nonstandard3 = [w for w in nonstandard3 if any(c not in "qwertyuiopasdfghjklzxcvbnmQWERTYUIOPASDFGHJKLZXCVBNM0123456789 .'-–+,\\/[]`()?!@#$%^&*=:_" for c in w)]
print(nonstandard3)

['mSTf~', '`l~', 'yHy~', 'colspan=9|', 'mws~', '~20', '`ys~', '~70', '~10', '|=', '~2000', '~1', 'l~', '|=Hoan', 'dyl~', '~2.6', '~25', '~50', '||Karas', 'lHsyn~', '~0.3', 'trsh~', '~1000', 'Gz~', '~Akizakura~', '~30', '~80-90', 'qz~', 'fwz~', 'Zuo "Mu ', '~.367', '~.385', '~.344', '~4', 'Glacier|Helheim', 'qD~', '~340', '~77', '~23', '~240', '~2', 'mrs~', '|Khowesin', '||xui-doa', 'yHy~', '|intuitive', '~5mV', 'Microeconomics|', '~0.1', '"money', 'bank"', '|archer', 'Show|url=http', '//www.hendrickmotorsports.com/news/articles/64840/elliott-and-earnhardt-featured-on-dude-perfect-premiere|title=Elliott', 'premiere|publisher=Hendrick', 'Motorsports|date=April', 'mjtb~', '~11', '~Legal', '|Xam', '|=Khomani', '~130,000', '~2,700', '~5,000', '<u> ', '<o> ', "koro'ne|", '~2100', '|rogerian', 'supply|water', 'style=|', 'lwsT~', '~458.4', 'dioxide|CO', "'qT~", '~500', "n~'", '0~ZERO~', 'l`z~', '||Garoeb', '|Elected', 'lzw~', '|I|', 'bkhr~', 'fS~H', 'wjd~', '//lsec.cc.ac.cn/~lyuan/code.html', 

## Object Oriented Programming codes here

*You can use multiple code snippets. Just add more if needed*

In [24]:
import time
from multiprocessing.pool import ThreadPool

N_THREADS = 1

# Surprisingly couldn't find an implementation on torch/tf/keras...
# Byte Pair Encoding tokenizer to feed into BERT-like model below
class BPE:
    def __init__(self, corpus, vocab_size, min_count=1):
        self.corpus = corpus
        self.vocab_size = vocab_size
        self.min_count = min_count
        
        self.vocab = []
        
        self.word_freq = {}
        self.word_freq_partitions = [{} for _ in range(N_THREADS)]
        self.splits = {}  # e.g. highest: [high, est</w>]
        self.merges = {}  # e.g. [high, est</w>]: highest


    def train(self):

        t_start = time.time()
        for text in self.corpus:
            i = 0
            for word in text:
                self.word_freq[word] = self.word_freq.get(word, 0) + 1
                self.word_freq_partitions[i % N_THREADS][word] = self.word_freq_partitions[i % N_THREADS].get(word, 0)+1

        # initialize the self.splits
        for word in self.word_freq:
            self.splits[word] = list(word) + ["</w>"]


        alphabet = set(("</w>",))
        for word in self.word_freq:
           for letter in word:
               if word not in alphabet:
                   alphabet.add(letter)

        self.vocab = list(alphabet)
        self.vocab.sort()

        print(f"Time {time.time()-t_start:.2f} - Alphabet initialized with size {len(alphabet)}.")
        print(f"Alphabet: {self.vocab}")

        # Sanity keeping every 100 iterations - this takes HOURS
        # TODO: any chance of multithread part of this somehow?
        iter = 0
        while len(self.vocab) < self.vocab_size:

            t_s = time.time()
            # ~4s singlethread
            pair_freq = self.get_pair_freq()

            print(f"pair_freq: {time.time()-t_s:.2f}s")
            

            if len(pair_freq) == 0 or max(pair_freq.values()) < self.min_count:
                print(f"Time {time.time()-t_start:.2f} - No more pairs. Exiting at vocab size {len(self.vocab)}")

            
            pair = max(pair_freq, key=pair_freq.get)

            t_s = time.time()
            self.update_splits(pair[0], pair[1])
            # ~3s singlethread
            print(f"update_splits: {time.time()-t_s:.2f}s")

            self.merges[pair] = pair[0] + pair[1]

            self.vocab.append(pair[0] + pair[1])

            iter += 1
            if iter % 100 == 0:
                print(f"Time {time.time()-t_start:.2f} - Now on iteration {iter}.")

    def update_splits(self, left, right):
            for word, word_split in self.splits.items():
                new_split = []
                ptr = 0
                while ptr < len(word_split):
                    if (
                        word_split[ptr] == left
                        and ptr + 1 < len(word_split)
                        and word_split[ptr + 1] == right
                    ):
                        new_split.append(left + right)
                        ptr += 2
                    else:
                        new_split.append(word_split[ptr])
                        ptr += 1
                self.splits[word] = new_split


    # Returns a pair frequency dictionary with entries:
    # pair (tuple) : freq (int)
    def get_pair_freq(self):
        pair_freq = {}

        pair_freq_partitions = []

        pool = ThreadPool(N_THREADS)
        
        def get_pair_freq_subtask(partition):

            #pair_freq_partition = {}

            for word_freq_pair in partition.items():
                word, freq = word_freq_pair
                split = self.splits[word]
                for i in range(len(split)-1):
                    #pair_freq_partition[(split[i], split[i + 1])] = pair_freq_partition.get((split[i], split[i + 1]), 0) + freq
                    pair_freq[(split[i], split[i + 1])] = pair_freq.get((split[i], split[i + 1]), 0) + freq
            return

        
        pool.map(get_pair_freq_subtask, self.word_freq_partitions)


        return pair_freq
    
    def tokenize(self, s):
            splits = [list(t) + ["</w>"] for t in s.split()]
    
            for left, right in self.merges:
                for index, split in enumerate(splits):
                    new_split = []
                    ptr = 0
                    while ptr < len(split):
                        if (
                            ptr + 1 < len(split)
                            and split[ptr] == left
                            and split[ptr + 1] == right
                        ):
                            new_split.append(left + right)
                            ptr += 2
                        else:
                            new_split.append(split[ptr])
                            ptr += 1
                    assert "".join(new_split) == "".join(split)
                    splits[index] = new_split
    
            return sum(splits, [])

bpe = BPE(evidence.loc[:,"processed evidence"], 10000, min_count=4)

In [25]:
bpe.train()

Time 35.95 - Alphabet initialized with size 50.
Alphabet: ["'", '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', '</w>', '=', '\\', '^', '_', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '|', '~']
pair_freq: 3.89s
update_splits: 2.65s
pair_freq: 3.97s
update_splits: 2.63s


KeyboardInterrupt: 

False

In [1]:
import time
from multiprocessing.pool import ThreadPool
import pickle

with open("evidence_v1.pkl","rb") as f:

    evidence = pickle.load(f)


# Barely speeds up the process if at all???
# screw GIL, definitely not I/O bound
N_THREADS = 1


# Surprisingly couldn't find an implementation on torch/tf/keras
# Byte Pair Encoding tokenizer to feed into BERT-like model below
class BPE:
    def __init__(self, corpus, vocab_size, min_count=1):
        self.corpus = corpus
        self.vocab_size = vocab_size
        self.min_count = min_count
        
        self.vocab = []
        
        self.word_freq = {}
        
        # For multiprocessing only
        self.word_freq_partitions = [{} for _ in range(N_THREADS)]
        
        # word to fragments
        self.word_partitions = {}
        
        # fragments to bigger fragments
        self.merges = {}

    def train(self):

        t_start = time.time()
        for text in self.corpus:
            i = 0
            for word in text:
                self.word_freq[word] = self.word_freq.get(word, 0) + 1
                self.word_freq_partitions[i % N_THREADS][word] = self.word_freq_partitions[i % N_THREADS].get(word, 0)+1
        
        alphabet = set(("</w>",))
        for word in self.word_freq:

            # initialize word partitions
            char_list = list(word)
            char_list.append("</w>")
            self.word_partitions[word] = char_list

            # construct alphabet
            for letter in word:
               if letter not in alphabet:
                   alphabet.add(letter)

        self.vocab = list(alphabet)
        self.vocab.sort()

        print(f"Time {time.time()-t_start:.2f} - Alphabet initialized with size {len(alphabet)}.")
        print(f"Alphabet: {self.vocab}")

        # Sanity keeping every 100 iterations - this takes HOURS
        # 05/11: Any chance of multithreading part of this somehow?
        # 05/13: Nope (no cuda on local machine :c)
        iter = 0
        while len(self.vocab) < self.vocab_size:

            t_s = time.time()
            # ~4s singlethread
            pair_freq = self.get_pair_freq()

            print(f"pair_freq: {time.time()-t_s:.2f}s")
            

            if len(pair_freq) == 0 or max(pair_freq.values()) < self.min_count:
                print(f"Time {time.time()-t_start:.2f} - No more pairs. Exiting at vocab size {len(self.vocab)}")

            
            best_pair = max(pair_freq, key=pair_freq.get)

            t_s = time.time()
            self.update_word_partitions(best_pair[0], best_pair[1])
            
            # ~3s singlethread
            print(f"update_word_partitions: {time.time()-t_s:.2f}s")

            # functionally instant
            self.merges[best_pair] = best_pair[0] + best_pair[1]
            self.vocab.append(best_pair[0] + best_pair[1])

            iter += 1
            if iter % 100 == 0:
                print(f"Time {time.time()-t_start:.2f} - Now on iteration {iter}.")


    # Returns a pair frequency dictionary with entries:
    # pair (tuple) : freq (int)
    def get_pair_freq(self):
        pair_freq = {}

        #pair_freq_partitions = []

        # Multithreading attempt
        # Doesn't work I think.. not enough compute time on colab either
        pool = ThreadPool(N_THREADS)
        def get_pair_freq_subtask(partition):

            #pair_freq_partition = {}

            for word_freq_pair in partition.items():
                word, freq = word_freq_pair
                word_partition = self.word_partitions[word]

                for i in range(len(word_partition)-1):
                    #pair_freq_partition[(word_partition[i], word_partition[i + 1])] = pair_freq_partition.get((word_partition[i], word_partition[i + 1]), 0) + freq
                    pair_freq[(word_partition[i], word_partition[i + 1])] = pair_freq.get((word_partition[i], word_partition[i + 1]), 0) + freq
            return

        # This returns None, results are already aggregated by the time all threads join.
        pool.map(get_pair_freq_subtask, self.word_freq_partitions)


        return pair_freq


    # Merge word partitions by the new pattern
    # Modify in place for performance (saves ~0.2s/iter)
    def update_word_partitions(self, left, right):
        for word in self.word_partitions.keys():

            partition = self.word_partitions[word]
            length = len(partition)

            i = 0
            while i < length:
                
                # Last token, do nothing
                if i+1 >= length:
                    pass
                
                # Otherwise check for matching pattern
                # Delete the next entry if matches
                elif partition[i] == left and partition[i+1] == right:
                    partition[i] = partition[i] + partition[i+1]
                    
                    del partition[i+1]
                    length -= 1 # Otherwise runs out of bounds

                    assert(len(self.word_partitions[word]) == length)
    
                i += 1
    
    def save(self, path):
        with open(path, "wb") as f:
            pickle.dump(bpe, f)

bpe = BPE(evidence.loc[:,"processed evidence"], 20000, min_count=4)

# y
bpe.train()


with open("BPETokenizer_v1", "wb") as f:
    pickle.dump(bpe, f)

Time 32.23 - Alphabet initialized with size 50.
Alphabet: ["'", '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', '</w>', '=', '\\', '^', '_', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '|', '~']
pair_freq: 2.32s
update_word_partitions: 1.11s
pair_freq: 2.07s
update_word_partitions: 1.12s
pair_freq: 1.82s
update_word_partitions: 0.84s
pair_freq: 2.18s
update_word_partitions: 1.13s
pair_freq: 2.10s
update_word_partitions: 1.10s
pair_freq: 2.21s
update_word_partitions: 1.14s
pair_freq: 2.51s
update_word_partitions: 1.12s
pair_freq: 2.25s
update_word_partitions: 1.02s
pair_freq: 2.34s
update_word_partitions: 0.90s
pair_freq: 2.00s
update_word_partitions: 0.91s
pair_freq: 2.01s
update_word_partitions: 1.24s
pair_freq: 2.71s
update_word_partitions: 1.09s
pair_freq: 2.19s
update_word_partitions: 1.07s
pair_freq: 2.97s
update_word_partitions: 1.13s
pair_freq: 2.04s
update_word_

KeyboardInterrupt: 