In [1]:
import os

from eli5_utils import *

# Create the ELI5 dataset builder rooted at the 'eli5' data directory and
# run its download/prepare step.
eli5_dbuilder = ELI5NLP(data_dir='eli5')
eli5_dbuilder.download_and_prepare()

# Materialize the train / validation / test splits, in that order.
eli5_train, eli5_valid, eli5_test = [
    eli5_dbuilder.as_dataset(split=split_name)
    for split_name in (
        nlp.splits.Split.TRAIN,
        nlp.splits.Split.VALIDATION,
        nlp.splits.Split.TEST,
    )
]

#### Training the retriever

In [None]:
class ArgumentsQAR():
    """Settings for training the ELI5 question-answer retriever.

    Every setting is a keyword argument whose default is the value this
    notebook has always used, so ``ArgumentsQAR()`` yields the same
    configuration as before, while e.g. ``ArgumentsQAR(batch_size=64)``
    allows a smaller run without editing the class.

    Attributes:
        batch_size: Number of examples per training batch; also used in the
            notebook to compute the scheduler's total number of steps.
        max_length: Read by the training / evaluation helpers that receive
            this object (presumably a token-length cap -- confirm in
            eli5_utils).
        print_freq: Read by the training helper (presumably the logging
            interval in steps -- confirm in eli5_utils).
        pretrained_model_name: Model name passed to
            ``make_qa_retriever_model``.
        model_save_name: Path prefix for checkpoints; the notebook appends
            ``_<epoch>.pth`` when saving.
        learning_rate: Learning rate passed to the optimizer.
        num_epochs: Number of training epochs.
    """

    def __init__(
        self,
        batch_size=512,
        max_length=128,
        print_freq=100,
        pretrained_model_name="google/bert_uncased_L-8_H-512_A-8",
        model_save_name="retriever_models/eli5_retriever_model_512",
        learning_rate=2e-5,
        num_epochs=10,
    ):
        self.batch_size = batch_size
        self.max_length = max_length
        self.print_freq = print_freq
        self.pretrained_model_name = pretrained_model_name
        self.model_save_name = model_save_name
        self.learning_rate = learning_rate
        self.num_epochs = num_epochs

# Retriever hyper-parameters (defaults).
qar_args = ArgumentsQAR()

# Retriever datasets over the ELI5 train and validation splits.
qar_train_dset = ELI5DatasetQARetriver(
    eli5_train,
    min_answer_length=64,
    training=True,
)
qar_valid_dset = ELI5DatasetQARetriver(
    eli5_valid,
    min_answer_length=64,
    training=False,
)

# Tokenizer and retriever model for the configured pretrained name,
# placed on the first CUDA device.
qar_tokenizer, qar_model = make_qa_retriever_model(
    model_name=qar_args.pretrained_model_name, from_file=None, device="cuda:0"
)

qar_optimizer = AdamW(qar_model.parameters(), lr=qar_args.learning_rate, eps=1e-8)

# Total scheduler steps = epochs * batches per epoch, where a final partial
# batch still counts as one step (hence the ceil).
qar_steps_per_epoch = math.ceil(len(qar_train_dset) / qar_args.batch_size)
qar_scheduler = get_linear_schedule_with_warmup(
    qar_optimizer,
    num_warmup_steps=100,
    num_training_steps=qar_args.num_epochs * qar_steps_per_epoch,
)


In [None]:
# torch.save does not create missing parent directories. The first save only
# happens after a full epoch of training, so make sure the checkpoint
# directory exists up front instead of failing an epoch in.
qar_save_dir = os.path.dirname(qar_args.model_save_name)
if qar_save_dir:  # empty when model_save_name has no directory component
    os.makedirs(qar_save_dir, exist_ok=True)

for e in range(qar_args.num_epochs):
    # One pass over the training set.
    train_qa_retriever_epoch(
        qar_model, qar_train_dset, qar_tokenizer,
        qar_optimizer, qar_scheduler, qar_args, e
    )
    # Checkpoint model, optimizer and scheduler state after every epoch,
    # one file per epoch.
    m_save_dict = {
        'model': qar_model.state_dict(),
        'optimizer': qar_optimizer.state_dict(),
        'scheduler': qar_scheduler.state_dict(),
    }
    print("Saving model {}".format(qar_args.model_save_name))
    torch.save(m_save_dict, '{}_{}.pth'.format(qar_args.model_save_name, e))
    # Report the validation loss for this epoch.
    eval_loss = evaluate_qa_retriever(qar_model, qar_valid_dset, qar_tokenizer, qar_args)
    print("Evaluation loss epoch {:4d}: {:.3f}".format(e, eval_loss))



In [None]:
# TODO: evaluate recall@N for validation / test set

In [None]:
# Prepare the KILT snippets dataset from the 'kilt_snippets_100w' data
# directory; its TRAIN split is the passage collection that gets indexed.
kilt_snippets_dbuilder = KiltSnippets(data_dir='kilt_snippets_100w')
kilt_snippets_dbuilder.download_and_prepare()
wiki_passages = kilt_snippets_dbuilder.as_dataset(split=nlp.splits.Split.TRAIN)

# Build the dense index over the passages with the retriever model and
# tokenizer, under the given index file name.
make_qa_dense_index(
    qar_model,
    qar_tokenizer,
    wiki_passages,
    batch_size=512,
    index_name='kilt_passages_reps_16.dat',
    device='cuda:0',
)




#### Training the Seq2seq model

In [2]:
# Load the precomputed dense-retrieval documents for each split. The `with`
# blocks close each file as soon as it is parsed, rather than leaving the
# handles open until garbage collection.
with open('eli5_train_precomputed_dense_docs.json') as docs_file:
    eli5_train_docs = json.load(docs_file)
with open('eli5_valid_precomputed_dense_docs.json') as docs_file:
    eli5_valid_docs = json.load(docs_file)
with open('eli5_test_precomputed_dense_docs.json') as docs_file:
    eli5_test_docs = json.load(docs_file)

In [3]:
# Show training example 123456 next to the entry at the same index of the
# precomputed documents list.
eli5_train[123456], eli5_train_docs[123456]

({'q_id': '8rn7rq',
  'title': 'Why do cars only ever have DC outlets even though most electrical devices use AC plugs?',
  'selftext': '',
  'answers': {'a_id': ['', '', '', '', '', ''],
   'text': ["To answer this, we probably need to look at why we use AC in buildings in the first place. When you transmit electricity from the power station to the consumers, you can choose to do it with a high current/low voltage, or low current/high voltage. We choose to do as much of the transmission as possible using the low current/high voltage as, because high currents have a nasty habit of heating the cables up excessively. At home, though, we use a low(ish) voltage, and so we need to be able to convert between the high transmission voltage and low domestic voltage - to do this, we use devices called transformers, and transformers only work with AC. If we didn't use AC, we'd need large numbers of local power stations close to consumers. In a car, though, we have a different system - firstly, th

In [20]:
# Each precomputed entry is a (key, document, third-field) triple; only the
# key -> document mapping is handed to the seq2seq datasets as their cache.
s2s_train_doc_cache = {k: d for k, d, _ in eli5_train_docs}
s2s_train_dset = ELI5DatasetS2S(eli5_train, document_cache=s2s_train_doc_cache)

s2s_valid_doc_cache = {k: d for k, d, _ in eli5_valid_docs}
s2s_valid_dset = ELI5DatasetS2S(
    eli5_valid,
    document_cache=s2s_valid_doc_cache,
    training=False,
)

In [13]:
# Look up the question id of validation example 25.
s2s_valid_dset.data[25]['q_id']

'5e7vx3'

In [16]:
# Check that q_id '5e7vx3' has an entry in the validation document cache,
# and show the number of rows in the validation data.
'5e7vx3' in s2s_valid_dset.document_cache, s2s_valid_dset.data.num_rows

(True, 9813)

('question: in films, when a gun, usually a pistol, is dropped or thrown onto the ground, it will sometimes fire. is this something real guns do too, and if so, why, since the trigger isn\'t being pulled? context: <p> have been in existence for decades; strictly speaking, loaded chamber indicators are not safeties, nor are they efficacious with an untrained user.\n an indicator that is behind the ejector port does not rise enough to disrupt a shooter\'s sight picture, but enough to be easily seen or felt to alert a user that there is a round in the chamber to avoid negligent discharge of the gun.\n typical safeties.:trigger disconnector.\n a trigger disconnector captures the hammer in the cocked position after a shot has been fired, even if the trigger is held to the rear as the gun cycles. this <p> head with one arm or resist the crushing effort of a car crusher, as seen in the tv series (episodes 5 and 21, respectively). he was designed to be able "to penetrate virtually any building