# Initialize

In [None]:
# Load IPython's autoreload extension and reload imported modules before each
# cell runs, so edits to imported modules (e.g. the project code in src/) take
# effect without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

import git

# Find the git repository that contains the current working directory
# (searching upward through parent directories) and derive project paths
# from its working tree.
repo = git.Repo(Path.cwd(), search_parent_directories=True)
ROOT = Path(repo.working_tree_dir)
SRC = ROOT.joinpath("src")

In [None]:
# Switch the working directory to the project's src/ folder; presumably this
# is what lets the local modules imported below (misc, analyze_bigram_encoders)
# be found.
cd $SRC

# Preamble

In [None]:
import matplotlib as mpl
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from tqdm.auto import trange

from analyze_bigram_encoders import plot_result
from misc import WV, BigramEncoder, load_wiki, process_word_vecs

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load Word Vectors

In [None]:
FAST_TEXT = ROOT.joinpath("data", "raw", "crawl-300d-2M.vec")
word2index, word_vecs = process_word_vecs(FAST_TEXT)

# The embeddings are L2-normalized with F.normalize (defaults: p=2, dim=1)
# before being wrapped in WV; presumably one row per word -- TODO confirm.
normalized_word_vecs = F.normalize(word_vecs)
wv = WV(normalized_word_vecs, word2index)
# Alternative without normalization: wv = WV(word_vecs, word2index)

# Load Sentences

In [None]:
sentences = load_wiki(max_len=25)
ix_sents, sent_lengths = wv.to_ix_sents(
    sentences,
    return_sent_lengths=True,
    adjust=True,
)

# Shuffle the sentences and their lengths with one shared random permutation,
# so row i of both tensors keeps referring to the same sentence.
perm = torch.randperm(len(ix_sents))
ix_sents, sent_lengths = ix_sents[perm], sent_lengths[perm]

# Evaluate the Bigram Encoder 

## NN model with stop words

|lr|margin|train accuracy|test accuracy|
|-|-|-|-|
|0.056337|0.13387|0.92140|0.36|
|0.17235|0.10772|0.92560|0.34|
|0.042079|0.014977|0.91110|0.17|
|0.33424|0.99048|0.73610|0.20|

In [None]:
# Plot results for the model keyed "T" (NN model, with stop words).
# NOTE(review): the fourth positional argument is 100 here but 1000 in the
# baseline cells below; its meaning is defined by plot_result -- confirm the
# difference is intended.
plot_result("T", wv, ix_sents, 100, add_legend=False)

## NN model without stop words

## Baselines

In [None]:
# Baseline "mult". Unlike the other baseline cells, add_legend is left at
# plot_result's default here.
plot_result("mult", wv, ix_sents, 1000)

In [None]:
# Baseline "tanh", plotted without a legend.
plot_result("tanh", wv, ix_sents, 1000, add_legend=False)

In [None]:
# Baseline "tanh10", plotted without a legend.
plot_result("tanh10", wv, ix_sents, 1000, add_legend=False)

In [None]:
# Baseline "sign", plotted without a legend.
plot_result("sign", wv, ix_sents, 1000, add_legend=False)

# Tests

## Test ``gen_pos_bigram_ixs`` and ``gen_neg_bigram_ixs``

Test whether all possible positive and negative examples are generated correctly and whether their counts are evenly distributed.

In [None]:
from collections import Counter
from itertools import product

import numpy as np

In [None]:
def test_example_generators(
    n_rows=1000000, *, pos_generator=None, neg_generator=None
):
    """Check that the positive/negative bigram generators are correct and even.

    Builds a toy batch of two index sentences, ``[1, 2, 3, 0, 0]`` (the zeros
    presumably being padding) and ``[4, 5, 6, 7, 8]``. The two sentences
    alternate, so each one makes up half of the rows. Both generators are run
    on the batch, and the pairs they return (one pair per output row) are
    counted.

    For each sentence, the expected positive pairs are its adjacent bigrams.
    The expected negative pairs are all other ordered pairs of its tokens
    (repeats included). The check asserts that:

    * each of the four (sentence, polarity) pair sets receives exactly
      ``n_rows // 2`` pairs;
    * within each set, the relative pair frequencies have a standard deviation
      below 0.01, i.e. the pairs are (near-)uniformly distributed;
    * no pair outside these sets is produced.

    Args:
        n_rows: Number of rows in the toy batch; rounded down to an even
            number. Must be at least 2.
        pos_generator: Callable that takes the index-sentence tensor and
            returns a tensor whose rows are positive index pairs. ``None``
            (the default) uses ``gen_pos_bigram_ixs`` as looked up at call
            time.
        neg_generator: The same for negative pairs. ``None`` (the default)
            uses ``gen_neg_bigram_ixs``.

    Returns:
        True when every check passes.

    Raises:
        ValueError: If ``n_rows`` is smaller than 2.
        AssertionError: If any check fails.
    """
    if pos_generator is None:
        pos_generator = gen_pos_bigram_ixs
    if neg_generator is None:
        neg_generator = gen_neg_bigram_ixs

    half = n_rows // 2
    if half == 0:
        # With no rows, the frequency checks below would divide by zero.
        raise ValueError("n_rows must be at least 2")
    total_rows = 2 * half

    # repeat(half, 1) stacks the 2-row block `half` times, so the two
    # sentences alternate and each makes up exactly half of the batch.
    ix_sents = torch.tensor([[1, 2, 3, 0, 0], [4, 5, 6, 7, 8]]).repeat(half, 1)

    counter_pos = Counter(tuple(pair) for pair in pos_generator(ix_sents).numpy())
    counter_neg = Counter(tuple(pair) for pair in neg_generator(ix_sents).numpy())

    pos_examples_1 = {(1, 2), (2, 3)}
    neg_examples_1 = set(product([1, 2, 3], repeat=2)) - pos_examples_1
    pos_examples_2 = {(4, 5), (5, 6), (6, 7), (7, 8)}
    neg_examples_2 = set(product([4, 5, 6, 7, 8], repeat=2)) - pos_examples_2

    checks = [
        ("positive, sentence 1", counter_pos, pos_examples_1),
        ("negative, sentence 1", counter_neg, neg_examples_1),
        ("positive, sentence 2", counter_pos, pos_examples_2),
        ("negative, sentence 2", counter_neg, neg_examples_2),
    ]
    for label, counter, expected_pairs in checks:
        total = sum(counter[pair] for pair in expected_pairs)
        # Expect one pair per row, i.e. `half` pairs for each sentence.
        assert total == half, f"{label}: expected {half} pairs, got {total}"
        freqs = [counter[pair] / total for pair in expected_pairs]
        assert np.std(freqs) < 0.01, f"{label}: uneven distribution {freqs}"

    # Combined with the per-set totals above, this rules out stray pairs
    # outside the expected sets (e.g. pairs involving the 0 entries).
    assert sum(counter_pos.values()) == total_rows
    assert sum(counter_neg.values()) == total_rows
    return True

In [None]:
# Run the generator checks. The function raises AssertionError if a check
# fails; otherwise it returns True, which the notebook displays.
test_example_generators()