In [2]:
from multiprocessing import cpu_count
from gensim.models import Word2Vec, callbacks

def train_embeddings(df, sg=0, negative=5, window=8, min_count=5, size=300, epochs=10,
                     workers=None, seed=1):
    """Train a gensim Word2Vec model on the ``body_clean`` column of ``df``.

    Parameters
    ----------
    df : mapping-like
        Object where ``df["body_clean"].to_list()`` yields the training corpus.
        Presumably each entry is an already-tokenised document (list of str),
        which is what Word2Vec consumes -- confirm against the loader.
    sg, negative, window, min_count : int
        Forwarded unchanged to ``Word2Vec``.
    size : int
        Embedding dimensionality; forwarded as ``vector_size``.
    epochs : int
        Number of training passes over the corpus.
    workers : int or None, optional
        Number of worker threads. ``None`` (the default) uses ``cpu_count()``,
        which is what this function always did before the parameter existed.
    seed : int, optional
        Seed forwarded to ``Word2Vec``. The default of 1 is gensim's documented
        default, so existing calls behave as before. Per the gensim docs a fully
        reproducible run additionally needs ``workers=1``.

    Returns
    -------
    Word2Vec
        The trained model.
    """
    sentences = df["body_clean"].to_list()

    class EpochLossLogger(callbacks.CallbackAny2Vec):
        """Print, after every epoch, how much the reported training loss grew."""

        def __init__(self):
            self.epoch = 0
            # Declared here so the attribute exists before the first epoch ends.
            self.loss_previous_step = 0.0

        def on_epoch_end(self, model):
            loss = model.get_latest_training_loss()
            # The first reading is printed as-is; later readings are printed as the
            # difference from the previous one.
            # NOTE(review): the sudden drops in the recorded output suggest the
            # running-loss counter may lose precision on large corpora -- confirm
            # against the gensim issue tracker before interpreting these numbers.
            if self.epoch == 0:
                shown = loss
            else:
                shown = loss - self.loss_previous_step
            print('Loss after epoch {}: {}'.format(self.epoch, shown))
            self.epoch += 1
            self.loss_previous_step = loss

    if workers is None:
        workers = cpu_count()

    model = Word2Vec(
        sg=sg,
        negative=negative,
        window=window,
        min_count=min_count,
        vector_size=size,
        workers=workers,
        seed=seed,
    )
    model.build_vocab(corpus_iterable=sentences)
    model.train(
        corpus_iterable=sentences,
        compute_loss=True,
        callbacks=[EpochLossLogger()],
        epochs=epochs,
        total_examples=model.corpus_count,
    )
    return model

In [4]:
from loader import load_dataset

def evaluate(ds, prefix="embeddings/"):
    """Load a saved Word2Vec model and print its top matches for "man" minus "woman".

    Parameters
    ----------
    ds : str
        Dataset name; the model is read from ``prefix + ds + ".bin"``.
    prefix : str, optional
        String prepended to the file name (defaults to ``"embeddings/"``).

    Returns
    -------
    None
        The five results of ``most_similar`` are printed, not returned.
    """
    model = Word2Vec.load(prefix + ds + ".bin")
    # Both sides are passed as explicit lists. Some gensim 4.x releases appear to
    # wrap a bare string only when no negatives are supplied, in which case the
    # string would be iterated character by character (verify for the pinned
    # version); lists are accepted by every release.
    print(model.wv.most_similar(positive=["man"], negative=["woman"], topn=5))

# Names of the datasets to process. Each one is handed to load_dataset() and
# reused as the file stem of the saved model.
dataset_names = [
    "incels",
    "braincels",
    "trufemcels",
    "mensrights",
    "incels_full",
    "feminism_full",
    "feminism_2015_2017",
    "feminism_2017_2019",
    "feminism_2019_2021",
    "feminism_2021_2023",
]

# For every dataset: load it, train embeddings with the default settings,
# save the model under embeddings/, then print the evaluate() report, which
# reloads the model from the file that was just written.
for name in dataset_names:
    ds = load_dataset(name)
    model = train_embeddings(ds)
    model.save(f"embeddings/{name}.bin")
    evaluate(name)

Loss after epoch 0: 5268405.0
Loss after epoch 1: 4455162.0
Loss after epoch 2: 3899823.0
Loss after epoch 3: 3833034.0
Loss after epoch 4: 3318686.0
Loss after epoch 5: 3305714.0
Loss after epoch 6: 3299068.0
Loss after epoch 7: 3150756.0
Loss after epoch 8: 3141032.0
Loss after epoch 9: 1716280.0
[('bro', 0.4186728000640869), ('dude', 0.3685024082660675), ('yeahhh', 0.32525625824928284), ('lad', 0.31590956449508667), ('mane', 0.30899643898010254)]
Loss after epoch 0: 10844291.0
Loss after epoch 1: 8644629.0
Loss after epoch 2: 7824318.0
Loss after epoch 3: 7221154.0
Loss after epoch 4: 4996408.0
Loss after epoch 5: 4862000.0
Loss after epoch 6: 4676888.0
Loss after epoch 7: 4509496.0
Loss after epoch 8: 4181508.0
Loss after epoch 9: 3782944.0
[('homie', 0.39327824115753174), ('fella', 0.3634299635887146), ('dude', 0.36313050985336304), ('bruh', 0.35077112913131714), ('hoge', 0.34670573472976685)]
Loss after epoch 0: 861876.5625
Loss after epoch 1: 794704.0625
Loss after epoch 2: 7616