In [1]:
import datasets
from nltk.tokenize import word_tokenize, sent_tokenize
from nltk.corpus import stopwords, wordnet, words
from nltk.stem import WordNetLemmatizer
import nltk

# Fetch every NLTK resource the notebook uses so it runs on a fresh machine.
# word_tokenize / sent_tokenize need the Punkt models in addition to the
# stopword list and WordNet. Recent NLTK releases look the models up under
# 'punkt_tab' and older ones under 'punkt', so both are requested
# (nltk.download reports an unavailable package instead of raising).
nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('punkt')
nltk.download('punkt_tab')

# Shared helpers used by the tokenisation cells below.
lemmatizer = WordNetLemmatizer()
en_stopwords = set(stopwords.words('english'))

# Corpus whose 'train' split is scanned for co-occurrence counts.
data = datasets.load_dataset("wikipedia", "20220301.en")
# data = datasets.load_dataset("bookcorpus/bookcorpus")

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/Yourui/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to /Users/Yourui/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [2]:
import string

def tokenize(sentence):
    """Turn a sentence into a list of lemmatized content words.

    Every character in ``string.punctuation`` is deleted from ``sentence``
    before it is split with ``word_tokenize``. A token is kept only when it is
    not in ``en_stopwords`` and ``wordnet.synsets`` returns at least one
    synset for it; each kept token is passed through ``lemmatizer.lemmatize``.

    The stop-word test is a plain set-membership check, so it is
    case-sensitive; this function does not change the case of its input.
    """
    strip_punctuation = str.maketrans('', '', string.punctuation)
    lemmas = []
    for raw_token in word_tokenize(sentence.translate(strip_punctuation)):
        if raw_token in en_stopwords:
            continue
        # Drop tokens WordNet has no entry for.
        if not wordnet.synsets(raw_token):
            continue
        lemmas.append(lemmatizer.lemmatize(raw_token))
    return lemmas

In [3]:
# Smoke-check tokenize() on the first sentence of the first training article.
first_article = next(iter(data['train']))
first_sentence = sent_tokenize(first_article['text'].lower())[0]
tokenize(first_sentence)

['anarchism',
 'political',
 'philosophy',
 'movement',
 'sceptical',
 'authority',
 'reject',
 'involuntary',
 'coercive',
 'form',
 'hierarchy']

In [4]:
import contextlib
import numpy as np

def generate_frequencies(word, n_occurrences=10000, deltas=None):
    """Count the tokens that appear near ``word`` in the training corpus.

    Rows of ``data['train']`` are lower-cased, split into sentences and run
    through ``tokenize``. For every token equal to the lemmatized ``word``,
    the tokens found at each offset in ``deltas`` (relative to that token's
    position in the tokenized sentence) are tallied.

    Parameters
    ----------
    word : str
        Target word; it is lemmatized before matching.
    n_occurrences : int
        Stop scanning once this many occurrences of the target were counted.
    deltas : list of int, optional
        Relative token offsets to inspect. Defaults to the four positions on
        either side of the target.

    Returns
    -------
    dict
        Maps each neighbouring token to the number of times it was seen. If
        the corpus is exhausted before ``n_occurrences`` is reached, the
        counts gathered so far are returned.
    """
    if deltas is None:
        deltas = [-4, -3, -2, -1, 1, 2, 3, 4]
    word = lemmatizer.lemmatize(word)

    frequencies = {}
    occurrences = 0

    for row_number, row in enumerate(data['train']):
        for sentence in sent_tokenize(row['text'].lower()):
            # Cheap substring pre-filter so only candidate sentences are
            # tokenized. NOTE(review): this skips sentences where only an
            # inflected form that lemmatizes to ``word`` is present and does
            # not contain it as a substring - confirm that is acceptable.
            if word not in sentence:
                continue

            tokenized = tokenize(sentence)
            n_tokens = len(tokenized)

            for index, token in enumerate(tokenized):
                if token != word:
                    continue

                for delta in deltas:
                    neighbour_index = index + delta
                    # Explicit bounds check on both sides: a negative index
                    # would not raise, it would silently wrap around and count
                    # a token from the far end of the sentence.
                    if 0 <= neighbour_index < n_tokens:
                        neighbour = tokenized[neighbour_index]
                        frequencies[neighbour] = frequencies.get(neighbour, 0) + 1

                occurrences += 1
                if occurrences >= n_occurrences:
                    return frequencies

        if row_number % 1000 == 0:
            print(f'"{word}", {row_number}th row processed, {occurrences}/{n_occurrences} occurrences')

    # Corpus exhausted before the target count was reached: return the
    # partial tally rather than falling through and returning None.
    return frequencies

In [5]:
# Keep only the words that occur most frequently next to the target word (e.g., 1000+ occurrences), then aggregate those words and collect data:
    # Table 1: Most frequent words
    # Table 2: Most frequent words with count vectors
# test discarding bits with high variance values

In [6]:
import json

def store_encoding(word, fname, args):
    """Compute the neighbour frequencies of ``word`` and save them as JSON.

    Parameters
    ----------
    word : str
        Target word; also the key under which the result is stored.
    fname : str or path-like
        JSON file holding one entry per word. Entries already in the file are
        kept (an entry for ``word`` itself is overwritten). The file and its
        parent directory are created if they do not exist yet.
    args : dict
        Keyword arguments forwarded to ``generate_frequencies``.
    """
    # Local import keeps this cell self-contained; ``os`` is only needed here.
    import os

    frequencies = generate_frequencies(word, **args)

    # A missing file simply means nothing has been stored yet; failing here
    # would discard the corpus scan that just finished.
    try:
        with open(fname, 'r') as f:
            encodings = json.load(f)
    except FileNotFoundError:
        encodings = {}
    encodings[word] = frequencies

    parent_dir = os.path.dirname(fname)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(fname, 'w') as f:
        json.dump(encodings, f, indent=4)

In [7]:
# Build and persist the neighbour-frequency table for each target word.
l = ["man", "woman", "king", "queen"]

for value in l:
    store_encoding(
        value,
        'data/wikipedia_20000_frequencies.json',
        dict(n_occurrences=20000, deltas=[-4, -3, -2, -1, 1, 2, 3, 4]),
    )

"man", 0th row processed, 0/20000 occurrences
"man", 1000th row processed, 877/20000 occurrences
"man", 2000th row processed, 1544/20000 occurrences
"man", 3000th row processed, 2304/20000 occurrences
"man", 4000th row processed, 3101/20000 occurrences
"man", 5000th row processed, 3978/20000 occurrences
"man", 6000th row processed, 4911/20000 occurrences
"man", 7000th row processed, 5798/20000 occurrences
"man", 8000th row processed, 7266/20000 occurrences
"man", 9000th row processed, 8036/20000 occurrences
"man", 10000th row processed, 8912/20000 occurrences
"man", 11000th row processed, 9575/20000 occurrences
"man", 12000th row processed, 10353/20000 occurrences
"man", 13000th row processed, 11170/20000 occurrences
"man", 14000th row processed, 11840/20000 occurrences
"man", 15000th row processed, 12723/20000 occurrences
"man", 16000th row processed, 13811/20000 occurrences
"man", 17000th row processed, 14777/20000 occurrences
"man", 18000th row processed, 14926/20000 occurrences
"ma