# Initialize

In [None]:
from pathlib import Path

import git

# Find the git repository enclosing the current directory (searching upward)
# and anchor every project path at its working tree, so the notebook can be
# started from anywhere inside the repository.
repo = git.Repo(Path.cwd(), search_parent_directories=True)
ROOT = Path(repo.working_tree_dir)
SRC = ROOT.joinpath("src")

In [None]:
cd $SRC

# Preamble

In [None]:
import matplotlib as mpl
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from tqdm.auto import trange

from analyze_bigram_encoders import (
    gen_neg_bigram_ixs,
    gen_pos_bigram_ixs,
    plot_bigram_norm,
    plot_result,
    plot_uniformity,
)
from misc import WV, BigramEncoder, load_wiki, process_word_vecs

# Seed torch's global RNG (on all devices) so torch-based randomness later in
# the notebook is reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load Word Vectors

In [None]:
# fastText crawl vectors (presumably 2M words x 300 dims, per the file name).
FAST_TEXT = ROOT / "data/raw/crawl-300d-2M.vec"
word2index, word_vecs = process_word_vecs(FAST_TEXT)

# The word embeddings are L2-normalized along dim=1 (F.normalize's defaults;
# presumably one row per word). Set this to False to use the raw,
# unnormalized vectors instead of editing commented-out code.
NORMALIZE_WORD_VECS = True
wv = WV(F.normalize(word_vecs) if NORMALIZE_WORD_VECS else word_vecs, word2index)

# Load Sentences

In [None]:
# Wikipedia sentences, capped via max_len=25 (presumably tokens per sentence).
sentences = load_wiki(max_len=25)

ix_sents, sent_lengths = wv.to_ix_sents(
    sentences, filter_stopwords=False, return_sent_lengths=True, adjust=True
)

# Shuffle sentences and their lengths with the same permutation. A dedicated,
# seeded generator makes the sentence order identical on every run, so results
# computed from ix_sents are reproducible.
SHUFFLE_SEED = 0
perm = torch.randperm(
    len(ix_sents), generator=torch.Generator().manual_seed(SHUFFLE_SEED)
)
ix_sents = ix_sents[perm]
sent_lengths = sent_lengths[perm]

# Figure 1

(a) $f_{\odot}(\mathbf{w}, \mathbf{w'})$

In [None]:
# "mult" encoder, i.e. f_⊙ in the caption above; legend left at its default.
plot_result(
    "mult",
    wv,
    ix_sents,
    1000,
    average_comparison=True,
)

(b) $f_1(\mathbf{w}, \mathbf{w'})$

In [None]:
# "tanh" encoder, i.e. f_1 in the caption above; legend suppressed.
plot_result(
    "tanh",
    wv,
    ix_sents,
    1000,
    average_comparison=True,
    add_legend=False,
)

(c) $f_{10}(\mathbf{w}, \mathbf{w'})$

In [None]:
# "tanh10" encoder, i.e. f_10 in the caption above; legend suppressed.
plot_result(
    "tanh10",
    wv,
    ix_sents,
    1000,
    average_comparison=True,
    add_legend=False,
)

(d) $f_{\infty}(\mathbf{w}, \mathbf{w'})$

In [None]:
# "sign" encoder, i.e. f_∞ in the caption above; legend suppressed.
plot_result(
    "sign",
    wv,
    ix_sents,
    1000,
    average_comparison=True,
    add_legend=False,
)

(e) $f_{T}(\mathbf{w}, \mathbf{w'})$

In [None]:
# "T" encoder, i.e. f_T in the caption above, loaded via model_path.
# The model path is anchored at ROOT (like the other project paths) rather
# than the cwd-relative "../models/...", so the cell does not depend on the
# current working directory.
# NOTE(review): the 4th positional argument is 100 here vs. 1000 for the other
# plot_result panels — presumably because the learned model is costlier to
# evaluate; confirm this is intentional.
plot_result(
    "T",
    wv,
    ix_sents,
    100,
    average_comparison=True,
    model_path=str(ROOT / "models/bigram_nn_wiki_train_1000000.pth"),
    add_legend=False,
)

(f) The distribution of $\lVert f(\mathbf{w}, \mathbf{w'}) \rVert$ with $(w, w') \in B(S)$

In [None]:
# Distribution of the norm of bigram encodings. outdir is presumably where the
# figure file is saved. model_path is anchored at ROOT, matching outdir, so it
# does not depend on the current working directory.
plot_bigram_norm(
    wv=wv,
    ix_sents=ix_sents,
    batch_size=1000,
    outdir=ROOT / "paper/img",
    model_path=str(ROOT / "models/bigram_nn_wiki_train_1000000.pth"),
    seed=0,
    add_legend=False,
)

(g) Uniformity when $(w, w')$ are random word pairs

In [None]:
# Uniformity plot for random word pairs (word_pair="random"); this panel
# carries the legend. model_path is anchored at ROOT, matching outdir, so it
# does not depend on the current working directory.
plot_uniformity(
    word_pair="random",
    wv=wv,
    ix_sents=ix_sents,
    batch_size=1000,
    outdir=ROOT / "paper/img",
    model_path=str(ROOT / "models/bigram_nn_wiki_train_1000000.pth"),
    seed=0,
    add_legend=True,
)

(h) Uniformity when $(w, w')$ are bigrams

In [None]:
# Uniformity plot for bigrams (word_pair="bigram"); legend suppressed.
# model_path is anchored at ROOT, matching outdir, so it does not depend on
# the current working directory.
plot_uniformity(
    word_pair="bigram",
    wv=wv,
    ix_sents=ix_sents,
    batch_size=1000,
    outdir=ROOT / "paper/img",
    model_path=str(ROOT / "models/bigram_nn_wiki_train_1000000.pth"),
    seed=0,
    add_legend=False,
)