In [None]:
from pathlib import Path

import git

# Find the git repository that contains the current working directory
# (parent directories are searched) and use its working tree as project root.
repo = git.Repo(Path.cwd(), search_parent_directories=True)
ROOT = Path(repo.working_tree_dir)

In [None]:
cd $ROOT

# Preamble

In [None]:
import pickle

import numpy as np
import src.disc as disc
import src.neural_model as nm
import src.sentence_reconstruction as sr
import torch.nn.functional as F
from src.misc import WV, process_word_vecs
from src.neural_model import BigramNN
from tqdm.auto import tqdm

Choose between `nm` and `disc` to generate results for $f_T$ and $f_\odot$, respectively.

In [None]:
# Selects which `src` module supplies the method-specific helpers further
# down: "nm" -> src.neural_model, "disc" -> src.disc. It is also used as the
# sub-folder of RESULTS for the bigram and trigram outputs.
METHOD = "disc" # allowed values: "nm", "disc"

In [None]:
# Input and output locations, all relative to the repository root.
DATA = ROOT.joinpath("data")
FAST_TEXT = DATA.joinpath("raw", "crawl-300d-2M.vec")
TRAIN = DATA.joinpath("processed", "train.1000000.pkl")
VALID = DATA.joinpath("processed", "valid.pkl")
TEST = DATA.joinpath("processed", "test.pkl")
RESULTS = ROOT.joinpath("results")

# Start / end tokens; both are added to the vocabulary further down.
START_TOKEN = "▷"
END_TOKEN = "◁"

# Evaluation set-up.
# VOCAB: "whole" builds the vocabulary from train + valid + test; any other
#        value builds it from the test split only.
# MODE:  split whose sentences are reconstructed, "test" or "valid".
# SIZE:  upper bound on the number of sentences taken from that split.
VOCAB = "whole"
MODE = "test"
SIZE = 10000

In [None]:
# Read the pre-trained word vectors and wrap them in a WV object; the two
# intermediate names are deleted once WV has been built.
word_index, word_vecs = process_word_vecs(FAST_TEXT)
wv = WV(word_vecs=word_vecs, word_index=word_index)
del word_index, word_vecs

# Load the three pickled splits. The context managers close every file handle
# as soon as it has been read instead of leaving that to the garbage collector.
# NOTE: unpickling executes arbitrary code -- only load files you produced.
with open(TRAIN, "rb") as f:
    train = pickle.load(f)
with open(VALID, "rb") as f:
    valid = pickle.load(f)
with open(TEST, "rb") as f:
    test = pickle.load(f)

# Sentence Representation

In [None]:
# Bind the METHOD-specific helper functions used by the reconstruction cells.
if METHOD == "nm":
    get_candidate_sents = nm.get_candidate_sents
    gen_bvs_i2b = nm.gen_bvs_i2b
elif METHOD == "disc":
    get_candidate_sents = disc.get_candidate_sents
    gen_bvs_i2b = disc.gen_bvs_i2b
else:
    # Fail with a message that names the offending value, so a typo in the
    # configuration cell is obvious.
    raise NotImplementedError(
        'Unknown METHOD {!r}; expected "nm" or "disc".'.format(METHOD)
    )

Remove out-of-vocabulary (OOV) words from the sentences.

In [None]:
# Keep only the tokens that are present in the word-vector vocabulary.
train = [[token for token in sentence if token in wv.vocab] for sentence in train]
valid = [[token for token in sentence if token in wv.vocab] for sentence in valid]
test = [[token for token in sentence if token in wv.vocab] for sentence in test]

# Working vocabulary: all three splits when VOCAB == "whole", otherwise only
# the test split; the start / end tokens are always included.
vocab = {
    token
    for sentence in (train + valid + test if VOCAB == "whole" else test)
    for token in sentence
}
vocab |= {START_TOKEN, END_TOKEN}

# Restrict the word vectors to that vocabulary, then normalise them
# (F.normalize defaults: L2 norm along dim 1).
wv.adjust(vocab)
wv.vecs = F.normalize(wv.vecs)

# Reverse lookup: index -> word.
index2word = {wv.dict[token]: token for token in wv.dict}

# Sentences to reconstruct: at most SIZE sentences from the split chosen by MODE.
sents = list({"test": test, "valid": valid}[MODE][:SIZE])

# Unigram view of the sentences (without markers) and the matching sentence
# representations built with disc.disc.
original_unigram_sents = sr.ngram_sents(sents, 1, markers=False)
unigram_sent_vecs = sr.ngram_sent_vecs(
    original_unigram_sents, disc.disc, np.array(wv.vecs), wv.dict, 1
)

## Reconstruct Unigrams

In [None]:
# Reconstruct the unigrams of every sentence, or load them from the cache.
file_name = sr.make_file_name(1, sents)

if (RESULTS / file_name).exists():
    with open(RESULTS / file_name, "rb") as f:
        reconstructed_unigram_sents = pickle.load(f)
else:
    # The first 40,000 words close to the sentence cover in 99% of the cases all
    # the unigrams in the sentence. (Evaluated on the validation set).
    TOP_N = 40000

    # Convert the embedding matrix once: it is loop-invariant, so there is no
    # reason to rebuild the array (twice) for every sentence.
    vocab_vecs = np.array(wv.vecs)

    reconstructed_unigram_sents = []
    for unigram_sent_vec in tqdm(unigram_sent_vecs):
        # Indices of the TOP_N words with the largest dot product with the
        # sentence vector, in descending order.
        indices = np.argsort(vocab_vecs @ unigram_sent_vec)[::-1][:TOP_N]
        words = [index2word[index] for index in indices]
        reconstructed_unigram_sents.append(
            sr.reconstruct(
                unigram_sent_vec,
                vocab_vecs[indices],
                # Positions in the reduced dictionary -> words.
                dict(enumerate(words)),
                solver="bp",
            )
        )
    # Serialize. Create the output folder first so that the result of the
    # long-running loop is not lost to a missing directory.
    RESULTS.mkdir(parents=True, exist_ok=True)
    with open(RESULTS / file_name, "wb") as f:
        pickle.dump(reconstructed_unigram_sents, f)

# Evaluation
sr.accuracy(original_unigram_sents, reconstructed_unigram_sents)

## Reconstruct Bigrams

In [None]:
# Reconstruct the bigrams of every sentence, or load them from the cache.
original_bigram_sents = sr.ngram_sents(sents, 2, markers=True)
file_name = sr.make_file_name(2, sents)
bigram_dir = RESULTS / METHOD
bigram_candidates_file = bigram_dir / "candidate_sents_from_bigrams.pkl"
# True when no cached reconstruction exists and the bigrams must be computed.
bigrams_recomputed = not (bigram_dir / file_name).exists()

if not bigrams_recomputed:
    with open(bigram_dir / file_name, "rb") as f:
        reconstructed_bigram_sents = pickle.load(f)
else:
    if METHOD == "nm":
        bigram_vec_repr = BigramNN("diff")
    elif METHOD == "disc":

        def bigram_vec_repr(x):
            """Apply ``disc.disc`` to ``x`` after converting it with ``.numpy()``."""
            return disc.disc(x.numpy())

    else:
        raise NotImplementedError(
            'Unknown METHOD {!r}; expected "nm" or "disc".'.format(METHOD)
        )
    bigram_sent_vecs = sr.ngram_sent_vecs(
        original_bigram_sents, bigram_vec_repr, wv.vecs, wv.dict, 2
    )

    # Flatten every sentence vector to one dimension.
    bigram_sent_vecs = [vec.reshape(-1) for vec in bigram_sent_vecs]

    # Iterator yielding one (bigram vectors, index -> bigram) pair per
    # sentence, built from the reconstructed unigrams.
    bvs_i2b = gen_bvs_i2b(
        reconstructed_unigram_sents, bigram_vec_repr, wv.vecs, wv.dict, markers=True
    )

    reconstructed_bigram_sents = []
    for bigram_sent_vec in tqdm(bigram_sent_vecs):
        bigram_vecs, index2bigram = next(bvs_i2b)
        reconstructed_bigram_sents.append(
            sr.reconstruct(bigram_sent_vec, bigram_vecs, index2bigram, solver="omp")
        )
    # Serialize. Create the output folder first so that the result of the
    # long-running loop is not lost to a missing directory.
    bigram_dir.mkdir(parents=True, exist_ok=True)
    with open(bigram_dir / file_name, "wb") as f:
        pickle.dump(reconstructed_bigram_sents, f)

# Store candidate sentences generated from bigrams. The sentence-reconstruction
# cell at the end of the notebook loads this file, so it is rebuilt not only
# after a fresh reconstruction but also when the reconstruction came from the
# cache and the candidate file is missing.
if bigrams_recomputed or not bigram_candidates_file.exists():
    candidate_sents = [
        get_candidate_sents(reconstructed_bigram_sent)
        for reconstructed_bigram_sent in tqdm(reconstructed_bigram_sents)
    ]

    with open(bigram_candidates_file, "wb") as f:
        pickle.dump(candidate_sents, f)

# Evaluation
if METHOD == "nm":
    display(sr.accuracy(original_bigram_sents, reconstructed_bigram_sents))
elif METHOD == "disc":
    # Sort the word order of each bigram in the sentences so as to compare unordered bigrams
    display(
        sr.accuracy(
            sr.sorted_ngrams(original_bigram_sents),
            sr.sorted_ngrams(reconstructed_bigram_sents),
        )
    )

## Reconstruct Trigrams

In [None]:
# Reconstruct the trigrams of every sentence, or load them from the cache.
original_trigram_sents = sr.ngram_sents(sents, 3, markers=True)

file_name = sr.make_file_name(3, sents)
trigram_dir = RESULTS / METHOD
trigram_candidates_file = trigram_dir / "candidate_sents_from_trigrams.pkl"
# True when no cached reconstruction exists and the trigrams must be computed.
trigrams_recomputed = not (trigram_dir / file_name).exists()

if not trigrams_recomputed:
    with open(trigram_dir / file_name, "rb") as f:
        reconstructed_trigram_sents = pickle.load(f)
else:
    trigram_sent_vecs = sr.ngram_sent_vecs(
        original_trigram_sents, disc.disc, np.array(wv.vecs), wv.dict, 3
    )

    def reconstruct_trigram_sents(
        reconstructed_bigram_sents,
        trigram_vec_repr,
        word_vecs,
        word2index,
        sent_vecs=None,
    ):
        """Reconstruct the trigrams of each sentence from its bigrams.

        Args:
            reconstructed_bigram_sents: reconstructed bigrams, one entry per
                sentence.
            trigram_vec_repr: callable forwarded to ``sr.tvs_i2t``.
            word_vecs: word-vector matrix forwarded to ``sr.tvs_i2t``.
            word2index: word -> index mapping forwarded to ``sr.tvs_i2t``.
            sent_vecs: trigram sentence vectors, paired positionally with
                ``reconstructed_bigram_sents``. Defaults to the
                ``trigram_sent_vecs`` computed in this cell.

        Returns:
            A list with one ``sr.reconstruct`` result per sentence.
        """
        if sent_vecs is None:
            sent_vecs = trigram_sent_vecs
        reconstructed_trigram_sents = []
        # zip is materialised so that tqdm knows the total number of items.
        for bigram_sent, trigram_sent_vec in tqdm(
            list(zip(reconstructed_bigram_sents, sent_vecs))
        ):
            candidate_sents = get_candidate_sents(bigram_sent)
            if candidate_sents:
                # Candidate sentences exist: take the trigrams from them.
                trigrams = disc.get_candidate_trigrams(candidate_sents)
                solver = "omp"
            else:
                # No candidate sentence: derive the trigrams from the bigrams
                # and use the "bp" solver instead.
                trigrams = sr.bigram_sent2trigrams(bigram_sent)
                solver = "bp"
            trigram_vecs, index2trigram = sr.tvs_i2t(
                trigrams, trigram_vec_repr, word_vecs, word2index
            )
            reconstructed_trigram_sents.append(
                sr.reconstruct(
                    trigram_sent_vec, trigram_vecs, index2trigram, solver=solver
                )
            )
        return reconstructed_trigram_sents

    reconstructed_trigram_sents = reconstruct_trigram_sents(
        reconstructed_bigram_sents,
        disc.disc,
        np.array(wv.vecs),
        wv.dict,
        sent_vecs=trigram_sent_vecs,
    )
    # Serialize. Create the output folder first so that the result of the
    # long-running loop is not lost to a missing directory.
    trigram_dir.mkdir(parents=True, exist_ok=True)
    with open(trigram_dir / file_name, "wb") as f:
        pickle.dump(reconstructed_trigram_sents, f)

# Store the candidate sentences generated from trigrams. The
# sentence-reconstruction cell at the end of the notebook loads this file, so it
# is rebuilt not only after a fresh reconstruction but also when the
# reconstruction came from the cache and the candidate file is missing.
if trigrams_recomputed or not trigram_candidates_file.exists():
    candidate_sents_list = []
    for trigrams in tqdm(reconstructed_trigram_sents):
        candidates = disc.get_candidate_sents_trigrams(trigrams)
        # An empty result is replaced by a single empty candidate, so every
        # sentence has at least one candidate to choose from later.
        candidate_sents_list.append(candidates if len(candidates) > 0 else [()])

    with open(trigram_candidates_file, "wb") as f:
        pickle.dump(candidate_sents_list, f)

# Evaluate
sr.accuracy(
    sr.sorted_ngrams(original_trigram_sents),
    sr.sorted_ngrams(reconstructed_trigram_sents),
)

# Reconstruct Sentences

In [None]:
from joblib import Parallel, delayed
from nltk.translate.bleu_score import corpus_bleu, sentence_bleu


def reconstruct_sents(original_sents, candidate_sents_list):
    """Pick one candidate sentence for every original sentence.

    The two arguments are paired positionally with ``zip``. For each pair the
    candidate with the lowest ``sentence_bleu`` score against the original
    sentence is selected (``min``; on ties the first such candidate wins).

    NOTE(review): choosing the *lowest*-scoring candidate looks like a
    deliberate worst-case evaluation -- confirm that ``max`` is not intended.

    Args:
        original_sents: the reference sentences.
        candidate_sents_list: for each reference, a non-empty collection of
            candidate sentences (``min`` raises ``ValueError`` on an empty one).

    Returns:
        A list containing the selected candidate of every pair.
    """
    reconstructed_sents = []
    for reference, candidates in zip(original_sents, candidate_sents_list):

        def bleu_to_reference(candidate):
            return sentence_bleu([reference], candidate)

        reconstructed_sents.append(min(candidates, key=bleu_to_reference))
    return reconstructed_sents




# The sentences selected earlier serve as references for the final evaluation.
original_sents = sents

# Candidate sentences written by the bigram and the trigram cells.
with open(RESULTS / METHOD / "candidate_sents_from_bigrams.pkl", "rb") as f:
    candidate_sents_bigram = pickle.load(f)
with open(RESULTS / METHOD / "candidate_sents_from_trigrams.pkl", "rb") as f:
    candidate_sents_trigram = pickle.load(f)

# One reconstructed sentence per original, chosen among the trigram candidates.
reconstructed_sents = reconstruct_sents(original_sents, candidate_sents_trigram)

# Exact-match accuracy: share of originals whose tuple form equals the
# reconstruction at the same position.
accuracy = np.mean(
    [
        tuple(original_sent) == reconstructed_sents[i]
        for i, original_sent in enumerate(original_sents)
    ]
)
accuracy

# Corpus BLEU Score

In [None]:
# Corpus-level BLEU of the reconstructed sentences. Each original sentence is
# wrapped in a one-element list, i.e. it is the only reference for the
# reconstruction at the same position.
corpus_bleu([[sent] for sent in original_sents], reconstructed_sents)