# Relevance Transformer Experiments

In [195]:
from src.autoregressive_transformer import AutoregressiveTransformer
from src.dataset_loaders import SRC_TGT_pairs
from src.vocab_classes import Shared_Vocab, BERT_Vocab
from src.useful_utils import string_split_v3, string_split_v2, string_split_v1
from src.trainers import Model_Trainer
from src.retrieval import PyLuceneRetriever, OracleBLEURetriever
import torch
import torch.nn as nn
from torchtext.data import Field, BucketIterator
import numpy as np
import dotmap
from collections import Counter
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [223]:
# for CoNaLa
src_train_fp = "datasets/CoNaLa/conala-train.src"
tgt_train_fp = "datasets/CoNaLa/conala-train.tgt"
src_valid_fp = "datasets/CoNaLa/conala-valid.src"
tgt_valid_fp = "datasets/CoNaLa/conala-valid.tgt"
src_test_fp = "datasets/CoNaLa/conala-test.src"
tgt_test_fp = "datasets/CoNaLa/conala-test.tgt"

model_save_file = "conala-tiny-transformer-valid-testing/model_file_step_195000.torch"
output_dir = "conala-retrieval-testing"
max_seq_len = 75

In [284]:
# for Hearthstone
src_train_fp = "datasets/HS/train_hs.in"
tgt_train_fp = "datasets/HS/train_hs.out"
src_valid_fp = "datasets/HS/valid_hs.in"
tgt_valid_fp = "datasets/HS/valid_hs.out"
src_test_fp = "datasets/HS/test_hs.in"
tgt_test_fp = "datasets/HS/test_hs.out"

model_save_file = "hearthstone-tiny-transformer-valid-testing/model_file_step_195000.torch"
output_dir = "hearthstone-retrieval-testing"
max_seq_len = 400

In [324]:
# for Django
src_train_fp = "datasets/django_folds/django.fold1-10.full_train.src"
tgt_train_fp = "datasets/django_folds/django.fold1-10.full_train.tgt"
src_valid_fp = "datasets/django_folds/django.fold1-10.valid.src"
tgt_valid_fp = "datasets/django_folds/django.fold1-10.valid.tgt"
src_test_fp = "datasets/django_folds/django.fold1-10.test.src"
tgt_test_fp = "datasets/django_folds/django.fold1-10.test.tgt"

model_save_file = "django-tiny-transformer-custom-tok-50-seq-len-850-vocab-shuffled-data/model_file_step_195000.torch"
output_dir = "django-retrieval-testing"
max_seq_len = 50

In [325]:
# hyperparams
vocab_size = 850
embed_dim = 512
att_heads = 4
layers = 2
batch_size = 32
dim_feedforward = 1024

In [326]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_samples = SRC_TGT_pairs(src_train_fp, tgt_train_fp, max_seq_len=max_seq_len).samples
valid_samples = SRC_TGT_pairs(src_valid_fp, tgt_valid_fp, max_seq_len=max_seq_len).samples
test_samples = SRC_TGT_pairs(src_test_fp, tgt_test_fp, max_seq_len=max_seq_len).samples

vocab = Shared_Vocab(train_samples, vocab_size, string_split_v3, use_OOVs=True)


model = AutoregressiveTransformer(vocab_size=vocab_size, embed_dim=embed_dim, att_heads=att_heads, \
                                  layers=layers, dim_feedforward=dim_feedforward, max_seq_length=max_seq_len).to(device)
test_dataset = model.data2dataset(test_samples, vocab)
test_iterator = BucketIterator(
    test_dataset,
    batch_size = batch_size,
    sort=True,
    sort_key = model.sample_order_fn,
    device = device)

valid_dataset = model.data2dataset(valid_samples, vocab)
valid_iterator = BucketIterator(
    valid_dataset,
    batch_size = batch_size,
    sort=True,
    sort_key = model.sample_order_fn,
    device = device)

model.load_model(model_save_file)
trainer = Model_Trainer(model, vocab, test_iterator=test_iterator)

'output_dir' not defined, training and model outputs won't be saved.


In [327]:
outputs = trainer.evaluate(test_iterator)
regular_results = [out["BLEU"] for out in outputs]
avg_BLEU = np.average(regular_results)
print(f"Small Copy Transformer BLEU: {avg_BLEU*100:.2f}")

HBox(children=(IntProgress(value=0, max=59), HTML(value='')))


Small Copy Transformer BLEU: 78.74


## Retrieval decoding

In [155]:
def retrieval_output_nudging_creator(oracle=False, relevance_interpol=0.0005, k_docs=10, k_words=10, peak_scaling_factor=40.0, \
                                     num_stop_words=10, verbose=False):
    src_train_samples = [src for src, tgt in train_samples]
    tgt_train_samples = [tgt for src, tgt in train_samples]
    if oracle:
        retriever = OracleBLEURetriever(ids_to_keep=k_docs)
        retriever.add_multiple_docs(train_samples)
    else:
        retriever = PyLuceneRetriever()
        retriever.add_multiple_docs(src_train_samples)
    
    counts = Counter(string_split_v3(" ".join(tgt_train_samples))).most_common(num_stop_words)
    stop_words = [x[0] for x in counts]
    
    def nudge_fn(last_token_log_probs, single_decoder_input, batch_encoder_ids, batch_decoder_truth_ids, OOVs):
        OOVs = OOVs.cpu().tolist()
        src_sent = vocab.decode_input(batch_encoder_ids, OOVs, copy_marker="")
        tgt_sent = vocab.decode_output(batch_decoder_truth_ids, OOVs, copy_marker="")
        current_pred = vocab.decode_output(single_decoder_input, OOVs, copy_marker="")
        top_5_ids = torch.argsort(last_token_log_probs.cpu(), descending=True)[:5]
        top_5_words = [vocab.decode_output([idx], OOVs, copy_marker="") for idx in top_5_ids]
        if verbose:
            print("## DECODE STEP ##")
            print(f"SRC input:      {src_sent}")
            print(f"TGT truth:      {tgt_sent}")
            print(f"decoded so far: {current_pred}")
            print(f"top words     : {' | '.join(top_5_words)}")
            print()
        if oracle:
            doc_ranking = retriever.search(src_sent, tgt_sent, max_retrieved_docs=k_docs)
        else:
            doc_ranking = retriever.search(src_sent, max_retrieved_docs=k_docs)
            
        retrieved_samples = [(tgt_train_samples[doc_id], score) for doc_id, score in doc_ranking]
        scoring_dict = {}
        for sample, score in retrieved_samples:
            if verbose:
                print(f"DOC: {sample}")
            sample_toks = string_split_v3(sample)
            for tok in sample_toks:
                if tok in scoring_dict:
                    scoring_dict[tok] += (peak_scaling_factor * score)/len(sample_toks)
                else:
                    scoring_dict[tok] = (peak_scaling_factor * score)/len(sample_toks)
        top_retrieved_words = [tok for tok in sorted(scoring_dict.items(), key=lambda item: -item[1]) if tok[0] not in stop_words][:k_words]
        if verbose:
            print(f"RETRIEVAL top words: {[tok for tok, score in top_retrieved_words]}")
            print()
            print()
        top_retrieved_ids = [(vocab.encode_output(tok, OOVs)[0], score) for tok, score in top_retrieved_words]
        top_retrieved_ids = [(i, s) for i, s in top_retrieved_ids if i != vocab.UNK]
        
        num_relevant_terms = len(list(set(batch_decoder_truth_ids.cpu().tolist()).intersection([x[0] for x in top_retrieved_ids])))
        total_relevant_terms.append(num_relevant_terms)
        
        relevance_vector = torch.zeros_like(last_token_log_probs).fill_(-5000.0)
        for idx, score in top_retrieved_ids:
            if idx not in single_decoder_input:
                relevance_vector[idx] = score
        relevance_vector.softmax(-1)
        
        if top_5_ids[0] == vocab.EOS:
            new_probs = last_token_log_probs
        else:
            new_probs = (1-relevance_interpol) * last_token_log_probs + relevance_interpol * relevance_vector
        
        new_top_pred = torch.argmax(new_probs)
        if verbose:
            if top_5_ids[0] != new_top_pred:
                print("Relevance impact:")
                print(f"SRC input:      {src_sent}")
                print(f"TGT truth:      {tgt_sent}")
                print(f"decoded so far: {current_pred}")
                print(f"RETRIEVAL top words: {[tok for tok, score in top_retrieved_words][:]}")
                print(f"Prerdicted {vocab.decode_output([new_top_pred], OOVs)} over {top_5_words[0]}")
                print(num_relevant_terms)
                print()
            
        return new_probs
    
    return nudge_fn


In [250]:
retrieval_decoder_fn = retrieval_output_nudging_creator(oracle=False, 
                                                        relevance_interpol=0.0003, 
                                                        k_docs=1, 
                                                        k_words=10,  
                                                        peak_scaling_factor=10.0, 
                                                        num_stop_words=10)

model = AutoregressiveTransformer(vocab_size=vocab_size, embed_dim=embed_dim, att_heads=att_heads, \
                                  layers=layers, dim_feedforward=dim_feedforward, max_seq_length=max_seq_len, \
                                  output_nudge_fn=retrieval_decoder_fn).to(device)
model.eval()
model.load_model(model_save_file)

In [146]:
src = "plot dataframe ` df ` without a legend "
src_ids, OOVs = vocab.encode_input(src)

tgt = "df . plot ( legend = False ) "
tgt_ids = vocab.encode_output(tgt, OOVs)

batch = dotmap.DotMap()
batch.src = torch.tensor([vocab.SOS] + src_ids + [vocab.EOS]).unsqueeze(1).to(device)
batch.tgt = torch.tensor([vocab.SOS] + tgt_ids + [vocab.EOS]).unsqueeze(1).to(device)
batch.OOVs = torch.tensor(OOVs).unsqueeze(1).to(device)

# OVERRIDE
# batch = next(iter(test_iterator))

model.eval_step(batch, vocab)

[{'BLEU': 0.0698819818549069,
  'SRC': 'plot(COPY) dataframe(COPY) `(COPY) df(COPY) `(COPY) without a legend(COPY)',
  'TGT': 'df(COPY) . plot(COPY) ( legend(COPY) = False )',
  'PRED': '<unk> ( dataframe(COPY) )'}]

In [147]:
trainer = Model_Trainer(model, vocab, output_dir=output_dir)
outputs = trainer.evaluate(test_iterator, save_file="eval_test_samples.txt")
avg_BLEU = np.average([out["BLEU"] for out in outputs])
print(f"Relevance Transformer BLEU: {avg_BLEU*100:.2f}")
print(f"Average relevant terms: {np.average(total_relevant_terms)}")

Writing logs to: django-retrieval-testing/logs.txt


HBox(children=(IntProgress(value=0, max=59), HTML(value='')))


Relevance Transformer BLEU: 82.22
Average relevant terms: 0.7146699109944291


## Checking relevance

In [157]:
for k in [1,5,10,20,40,60,100,150,200]:

    total_relevant_terms = []
    retrieval_decoder_fn = retrieval_output_nudging_creator(oracle=False, 
                                                        relevance_interpol=0.0003, 
                                                        k_docs=k, 
                                                        k_words=10,  
                                                        peak_scaling_factor=10.0, 
                                                        num_stop_words=10)

    model = AutoregressiveTransformer(vocab_size=vocab_size, embed_dim=embed_dim, att_heads=att_heads, \
                                      layers=layers, dim_feedforward=dim_feedforward, max_seq_length=max_seq_len, \
                                      output_nudge_fn=retrieval_decoder_fn).to(device)
    model.eval()
    model.load_model(model_save_file)
    trainer = Model_Trainer(model, vocab, output_dir=output_dir)
    outputs = trainer.evaluate(test_iterator, save_file="eval_test_samples.txt")
    avg_BLEU = np.average([out["BLEU"] for out in outputs])
    print(f"Relevance Transformer BLEU: {avg_BLEU*100:.2f}")
    print(f"Average relevant terms: {np.average(total_relevant_terms)}")

JVM Running
Writing logs to: hearthstone-retrieval-testing/logs.txt


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))


Relevance Transformer BLEU: 73.97
Average relevant terms: 6.780221811460259
JVM Running
Writing logs to: hearthstone-retrieval-testing/logs.txt


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))


Relevance Transformer BLEU: 73.91
Average relevant terms: 9.15061133753242
JVM Running
Writing logs to: hearthstone-retrieval-testing/logs.txt


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))


Relevance Transformer BLEU: 74.05
Average relevant terms: 9.170876671619613
JVM Running
Writing logs to: hearthstone-retrieval-testing/logs.txt


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))


Relevance Transformer BLEU: 74.04
Average relevant terms: 9.190900649953575
JVM Running
Writing logs to: hearthstone-retrieval-testing/logs.txt


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))


Relevance Transformer BLEU: 73.89
Average relevant terms: 8.898982423681776
JVM Running
Writing logs to: hearthstone-retrieval-testing/logs.txt


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))


Relevance Transformer BLEU: 74.00
Average relevant terms: 8.950082888193037
JVM Running
Writing logs to: hearthstone-retrieval-testing/logs.txt


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))


Relevance Transformer BLEU: 74.00
Average relevant terms: 8.837170749677657
JVM Running
Writing logs to: hearthstone-retrieval-testing/logs.txt


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))


Relevance Transformer BLEU: 74.00
Average relevant terms: 8.86240559955793
JVM Running
Writing logs to: hearthstone-retrieval-testing/logs.txt


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))


Relevance Transformer BLEU: 74.00
Average relevant terms: 8.840117885430098


## Random parameter sarch

In [46]:

def test_model(relevance_interpol, k_docs, k_words, peak_scaling_factor, num_stop_words):
    tgt_train_samples = [tgt for src, tgt in train_samples]

    counts = Counter(string_split_v3(" ".join(tgt_train_samples))).most_common(30)
    stop_words = [x[0] for x in counts]
    
    retrieval_decoder_fn = retrieval_output_nudging_creator(oracle=False, 
                                                            relevance_interpol=relevance_interpol, 
                                                            k_docs=k_docs, 
                                                            k_words=k_words, 
                                                            peak_scaling_factor=peak_scaling_factor)

    model = AutoregressiveTransformer(vocab_size=vocab_size, embed_dim=embed_dim, att_heads=att_heads, \
                                      layers=layers, dim_feedforward=dim_feedforward, max_seq_length=max_seq_len, \
                                      output_nudge_fn=retrieval_decoder_fn).to(device)
    model.eval()
    model.load_model(model_save_file)
    
    trainer = Model_Trainer(model, vocab, output_dir=output_dir)
    outputs = trainer.evaluate(valid_iterator)
    avg_BLEU = np.average([out["BLEU"] for out in outputs])
    return avg_BLEU

In [47]:
search_results = []
def random_search():
    relevance_interpol = np.random.uniform(0.001, 0.00005)
    k_docs = np.random.randint(1,50)
    k_words = np.random.randint(1,50)
    peak_scaling_factor = np.random.uniform(0.5, 100.0)
    num_stop_words = np.random.randint(3,50)
    
    avg_BLEU = test_model(relevance_interpol, k_docs, k_words, peak_scaling_factor, num_stop_words)
    return ((relevance_interpol, k_docs, k_words, peak_scaling_factor, num_stop_words), avg_BLEU)

In [50]:
search_results = []

In [None]:
import tqdm.notebook as tqdm 
iters = 750
for i in tqdm.tqdm(range(iters)):
    result = random_search()
    search_results.append(result)

HBox(children=(IntProgress(value=0, max=750), HTML(value='')))

JVM Running
Writing logs to: django-retrieval-testing/logs.txt


HBox(children=(IntProgress(value=0, max=27), HTML(value='')))

In [54]:
search_results

[((0.00027132764995628045, 28, 17, 92.62807027519402, 3), 0.8242841085725205),
 ((0.00031631203678050704, 16, 15, 86.6994295528835, 26), 0.8238846502939439),
 ((0.0003436605959209651, 38, 38, 27.178666113265393, 49), 0.8199692004342154),
 ((0.0006881683563104867, 41, 25, 66.41908383318608, 24), 0.8024939538387946),
 ((0.0009053637508599583, 5, 37, 16.639937417495496, 45), 0.8042001087480264),
 ((5.567529626851529e-05, 49, 31, 15.225977444649727, 40), 0.8215407590067793),
 ((0.0006100539203266688, 28, 49, 25.098293836787544, 16), 0.8076904059900123),
 ((0.00041730675022604903, 22, 24, 92.75634677146452, 40), 0.8214348646617602),
 ((0.0009377920884179971, 15, 42, 26.877549587443877, 30), 0.7946313615804381),
 ((0.000993562057673889, 16, 24, 3.4886702890685406, 26), 0.7921574670166104),
 ((0.00016703763619177025, 17, 4, 52.553586304792546, 3), 0.8229666615318995),
 ((0.0005905587961810301, 24, 1, 33.45229385021521, 37), 0.8188397421491495),
 ((0.0004111915306344078, 31, 1, 43.411966182228

# T-test stuff

In [305]:
model = AutoregressiveTransformer(vocab_size=vocab_size, embed_dim=embed_dim, att_heads=att_heads, \
                                  layers=layers, dim_feedforward=dim_feedforward, max_seq_length=50, \
                                  output_nudge_fn=retrieval_decoder_fn).to(device)
model.eval()
model.load_model(model_save_file)
trainer = Model_Trainer(model, vocab, output_dir=output_dir)

Writing logs to: django-retrieval-testing/logs.txt


In [306]:
outputs = trainer.evaluate(test_iterator)

HBox(children=(IntProgress(value=0, max=59), HTML(value='')))




In [330]:
relev_results = [out["BLEU"] for out in outputs]
np.average(relev_results)

0.7873712048124596

In [210]:
base_results = [out["BLEU"] for out in outputs]
np.average(base_results)

0.17308943249100314

In [331]:
from scipy import stats
stats.ttest_ind(regular_results,relev_results)

Ttest_indResult(statistic=0.0, pvalue=1.0)

In [304]:
retrieval_decoder_fn = retrieval_output_nudging_creator(oracle=False, 
                                                        relevance_interpol=0.0003, 
                                                        k_docs=10, 
                                                        k_words=10,  
                                                        peak_scaling_factor=96.0, 
                                                        num_stop_words=17)

JVM Running
