## Minimal Example using SGNS from gensim

Here we create embedding snapshots that can be used in the web app by training a skip-gram with negative sampling (SGNS) word2vec model from the `gensim` library, as described in the paper "Temporal Analysis of Language through Neural Language Models" by [Kim et al. (2014)](https://arxiv.org/pdf/1405.3515.pdf). The only difference is that we first train the model on the whole corpus for several epochs before training on the individual time periods, since the corpus is very small.

In [None]:
import os
import random
import pickle
import logging
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px
from gensim.models import KeyedVectors, Word2Vec

from evolvemb import PretrainedEmbeddings
from evolvemb.diachronic_utils import most_changed_tokens, analyze_emb_over_time

%load_ext autoreload
%autoreload 2

In [None]:
# show INFO-level log messages, each prefixed with a timestamp and the log level
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s : %(levelname)s : %(message)s",
)

def load_nyt(start_date="2019-01-01", end_date="2020-12-31", path="data/nytimes_dataset.txt"):
    """Read the tokenized NYT sentences dated between two dates (both inclusive).

    Every non-blank line of the file must have the form ``date<TAB>sentence``.
    Dates are compared as plain strings, and reading stops at the first line
    dated after ``end_date``, i.e. the file is assumed to be sorted by date.

    Args:
        start_date: lines with a date string smaller than this are skipped.
        end_date: reading stops at the first line with a date string larger than this.
        path: location of the dataset file.

    Returns:
        A tuple ``(sentences, dates)`` of two equally long lists: ``sentences``
        contains one list of lowercased tokens per line and ``dates`` the
        corresponding date strings.

    Raises:
        ValueError: if a non-blank line contains no tab separator.
    """
    sentences = []
    dates = []
    # read as UTF-8 independent of the platform's default encoding
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # tolerate blank lines (e.g. at the end of the file)
                continue
            # split only at the first tab so a tab inside the sentence does not break the unpacking
            d, s = line.split("\t", 1)
            if d < start_date:
                continue
            if d > end_date:
                break
            dates.append(d)
            # lowercase! and some longer words mistakenly can end with "." due to the tokenizer; remove this!
            # (tokens of up to 3 characters are kept as they are)
            sentences.append([w[:-1] if len(w) > 3 and w.endswith(".") else w for w in s.lower().split()])
    print(f"Dataset contains {len(sentences)} sentences between {start_date} and {end_date}")
    return sentences, dates


def get_sgns_emb_snapshots(snapshots, start_date="2019-01-01", min_freq=100, saveemb=False):
    """Train (or load cached) SGNS word2vec embedding snapshots on the NYT data.

    A single Word2Vec model is first trained on all sentences between
    ``start_date`` and the last snapshot date and then trained further on the
    sentences of each individual period; after every period the current word
    vectors are stored as one snapshot.

    Args:
        snapshots: non-empty list of date strings in ascending order; each one
            marks the end of a period. The previous end date is reused as the
            next start date, which only works as expected if every snapshot
            date is larger than the last real date of its period.
        start_date: date string at which the first period begins.
        min_freq: passed to Word2Vec as ``min_count``.
        saveemb: if True, the computed snapshots are pickled to the cache file.

    Returns:
        A dict that maps the date of the last sentence in each period to the
        result of ``PretrainedEmbeddings(...).as_simple_pretrained()`` (with the
        ``embeddings`` attribute converted to a float32 array). If a cache
        file for these arguments exists and can be loaded, its content is
        returned instead.
    """
    savepath = f"data/snapshot_emb_sgns_{start_date}_{snapshots[-1]}_{min_freq}.pkl"
    # see if we can just load the embeddings
    if os.path.exists(savepath):
        try:
            # NOTE: unpickling can execute arbitrary code - only load cache files you created yourself
            with open(savepath, "rb") as f:
                return pickle.load(f)
        except Exception as e:
            # best effort: fall back to training from scratch
            print("could not load embeddings:", e)
    # learn embeddings instead
    snapshot_emb = {}
    # load full dataset
    sentences, _ = load_nyt(start_date, snapshots[-1])
    # fixed seed so that the shuffled order is the same in every run
    random.seed(10)
    random.shuffle(sentences)
    # train word2vec on whole dataset first (since it's really really small)
    # NOTE(review): `size` and `iter` look like the gensim < 4.0 parameter names
    # (later `vector_size` / `epochs`) - confirm against the installed version
    genw2v = Word2Vec(sentences, size=50, window=5, negative=13, sg=1, iter=100, min_count=min_freq, workers=4)
    # keep training on individual time periods
    for end_date in snapshots:
        # get current batch of sentences
        sentences, dates = load_nyt(start_date, end_date)
        if not sentences:
            # without sentences there is nothing to train on and no date to file the snapshot under
            print(f"no sentences between {start_date} and {end_date}; skipping this snapshot")
            start_date = end_date
            continue
        # train embeddings on these
        genw2v.train(sentences, total_examples=len(sentences), epochs=50)
        # save embeddings as simple pretrained embeddings
        snapshot_emb[dates[-1]] = PretrainedEmbeddings(genw2v.wv).as_simple_pretrained()
        start_date = end_date  # only works as expected if end_date > dates[-1]
    # reduce file size by ensuring dtype of numpy arrays is float32
    for emb in snapshot_emb.values():
        emb.embeddings = np.array(emb.embeddings, dtype=np.float32)
    # possibly save embeddings
    if saveemb:
        try:
            # -1 selects the highest available pickle protocol
            with open(savepath, "wb") as f:
                pickle.dump(snapshot_emb, f, -1)
            print(f"successfully saved embeddings at {savepath}")
        except Exception as e:
            # best effort: the embeddings are still returned if they cannot be saved
            print("error saving embeddings:", e)
    return snapshot_emb

In [None]:
# Snapshot end dates: one per month from June 2019 through December 2020, to
# look at the time before and after the corona outbreak in detail.
# The day "32" is deliberately impossible: this way, setting
# start_date = end_date inside get_sgns_emb_snapshots yields the expected results.
snapshots = [
    f"{year}-{month:02}-32"
    for year, first_month in ((2019, 6), (2020, 1))
    for month in range(first_month, 13)
]
# compute the SGNS embedding snapshots
snapshot_emb = get_sgns_emb_snapshots(snapshots, start_date="2019-04-01", min_freq=50, saveemb=True)
# to use the embeddings with app.py, save them with:
# pickle.dump(snapshot_emb, open("snapshot_emb.pkl", "wb"), -1)

In [None]:
# find the words that have changed the most at some point in the time period
most_changed = most_changed_tokens(snapshot_emb)
# print the first 25 entries of the result as "token (score)", one per line
print(
    "most changed tokens",
    "\n".join(f"{entry[0]:15} ({entry[1]:.4f})" for entry in most_changed[:25]),
    sep="\n",
)

In [None]:
# build the two interactive figures for the word "category" and display both
fig_time, fig_pca = analyze_emb_over_time(snapshot_emb, "category", savefigs="sgns")
for fig in (fig_time, fig_pca):
    fig.show()