In [1]:
from eli5_utils import *

# Instantiate the ELI5 dataset builder (pointed at data_dir='eli5') and run its
# download/prepare step before any split is requested.
eli5_dbuilder = ELI5NLP(data_dir='eli5')
eli5_dbuilder.download_and_prepare()

# Fetch the train / validation / test splits, in that order, from the builder.
eli5_train, eli5_valid, eli5_test = [
    eli5_dbuilder.as_dataset(split=split_name)
    for split_name in (
        nlp.splits.Split.TRAIN,
        nlp.splits.Split.VALIDATION,
        nlp.splits.Split.TEST,
    )
]

#### Training the retriever

In [2]:
class ArgumentsQAR:
    """Hyper-parameters and paths used when training the QA retriever.

    Every argument is optional and defaults to the value this notebook was
    originally written with, so ``ArgumentsQAR()`` produces the same
    configuration as before, while single settings can be overridden without
    editing the class, e.g. ``ArgumentsQAR(batch_size=32, num_epochs=3)``.

    Attributes:
        batch_size: Divides ``len(qar_train_dset)`` when the scheduler's total
            number of training steps is computed; also reaches the train/eval
            helpers through this object (presumably the examples per batch --
            confirm in eli5_utils).
        max_length: Not read in this notebook; presumably the maximum token
            length used by the helpers -- confirm in eli5_utils.
        print_freq: Not read in this notebook; presumably the number of steps
            between progress lines -- confirm in eli5_utils.
        pretrained_model_name: Passed as ``model_name`` to
            ``make_qa_retriever_model``.
        model_save_name: Path prefix of the per-epoch checkpoint files
            (``'<model_save_name>_<epoch>.pth'``).
        learning_rate: ``lr`` handed to the AdamW optimizer.
        num_epochs: Number of iterations of the training loop; also a factor
            in the scheduler's total number of training steps.
    """

    def __init__(
        self,
        batch_size=128,
        max_length=64,
        print_freq=10,
        pretrained_model_name="google/bert_uncased_L-8_H-512_A-8",
        model_save_name="retriever_models/test_model",
        learning_rate=2e-5,
        num_epochs=1,
    ):
        self.batch_size = batch_size
        self.max_length = max_length
        self.print_freq = print_freq
        self.pretrained_model_name = pretrained_model_name
        self.model_save_name = model_save_name
        self.learning_rate = learning_rate
        self.num_epochs = num_epochs


# Default configuration; pass keyword overrides here to experiment.
qar_args = ArgumentsQAR()

In [3]:
# Wrap the raw train / validation splits for the retriever; both use
# min_answer_length=64 and differ only in the `training` flag.
qar_train_dset = ELI5DatasetQARetriver(eli5_train, training=True, min_answer_length=64)
qar_valid_dset = ELI5DatasetQARetriver(eli5_valid, training=False, min_answer_length=64)

# Build the tokenizer and model from the configured pretrained checkpoint name
# (from_file=None, i.e. no saved retriever file is supplied) on the first CUDA device.
qar_tokenizer, qar_model = make_qa_retriever_model(
    device="cuda:0",
    from_file=None,
    model_name=qar_args.pretrained_model_name,
)

qar_optimizer = AdamW(qar_model.parameters(), eps=1e-8, lr=qar_args.learning_rate)

# Total scheduler steps = epochs * batches per epoch (last partial batch rounded up).
qar_steps_per_epoch = math.ceil(len(qar_train_dset) / qar_args.batch_size)
qar_scheduler = get_linear_schedule_with_warmup(
    qar_optimizer,
    num_warmup_steps=100,
    num_training_steps=qar_args.num_epochs * qar_steps_per_epoch,
)

In [None]:
# Train the retriever for qar_args.num_epochs epochs. After every epoch the
# model / optimizer / scheduler state is checkpointed and the loss on the
# validation set is printed.
for e in range(qar_args.num_epochs):
    train_qa_retriever_epoch(
        qar_model, qar_train_dset, qar_tokenizer,
        qar_optimizer, qar_scheduler, qar_args, e
    )
    # All three state dicts are stored together in a single checkpoint file.
    save_dict = {
        'model': qar_model.state_dict(),
        'optimizer': qar_optimizer.state_dict(),
        'scheduler': qar_scheduler.state_dict(),
    }
    # Fixed: the closing parenthesis of print(...) was missing (SyntaxError).
    print("Saving model {}".format(qar_args.model_save_name))
    # One file per epoch: '<model_save_name>_<epoch>.pth'.
    torch.save(save_dict, '{}_{}.pth'.format(qar_args.model_save_name, e))
    eval_loss = evaluate_qa_retriever(qar_model, qar_valid_dset, qar_tokenizer, qar_args)
    # Fixed: '{:/3f}' is not a valid format spec; '{:.3f}' prints three decimals.
    print("Evaluation loss epoch {:4d}: {:.3f}".format(e, eval_loss))
    
    

 0     0 of  2156 	 L: 5.018 	 -- 1.022
 0    10 of  2156 	 L: 5.013 	 -- 8.781
 0    20 of  2156 	 L: 4.958 	 -- 16.555
 0    30 of  2156 	 L: 4.921 	 -- 24.480
 0    40 of  2156 	 L: 4.849 	 -- 32.616
 0    50 of  2156 	 L: 4.797 	 -- 40.657
 0    60 of  2156 	 L: 4.642 	 -- 48.742
 0    70 of  2156 	 L: 4.347 	 -- 56.742
 0    80 of  2156 	 L: 4.064 	 -- 64.766
 0    90 of  2156 	 L: 3.819 	 -- 72.923
 0   100 of  2156 	 L: 3.621 	 -- 80.995
 0   110 of  2156 	 L: 3.313 	 -- 88.996
 0   120 of  2156 	 L: 3.149 	 -- 97.060
 0   130 of  2156 	 L: 2.877 	 -- 105.140
 0   140 of  2156 	 L: 2.861 	 -- 113.138
 0   150 of  2156 	 L: 2.694 	 -- 121.515
 0   160 of  2156 	 L: 2.748 	 -- 130.136
 0   170 of  2156 	 L: 2.572 	 -- 138.723
 0   180 of  2156 	 L: 2.475 	 -- 147.311
 0   190 of  2156 	 L: 2.483 	 -- 156.035
 0   200 of  2156 	 L: 2.476 	 -- 164.318
 0   210 of  2156 	 L: 2.434 	 -- 172.529
 0   220 of  2156 	 L: 2.365 	 -- 181.085
 0   230 of  2156 	 L: 2.278 	 -- 190.422
 0   24

In [None]:
# Run the validation pass on its own, outside the training loop. The call is
# the last expression of the cell, so its return value (used as the evaluation
# loss by the training loop) is displayed as the cell output.
evaluate_qa_retriever(qar_model, qar_valid_dset, qar_tokenizer, qar_args)
# TODO: evaluate recall@N for validation / test set

#### Training the Seq2seq model

In [2]:
def _load_json_docs(path):
    """Read the JSON file at ``path`` and return the decoded object.

    The file is opened in a ``with`` block so its handle is closed as soon as
    parsing finishes, instead of staying open until garbage collection as with
    ``json.load(open(path))``.
    """
    with open(path) as docs_file:
        return json.load(docs_file)


# Pre-computed documents for each split, one JSON file per split
# (presumably index-aligned with the matching eli5_* dataset -- confirm).
eli5_train_docs = _load_json_docs('eli5_train_precomputed_dense_docs.json')
eli5_valid_docs = _load_json_docs('eli5_valid_precomputed_dense_docs.json')
eli5_test_docs = _load_json_docs('eli5_test_precomputed_dense_docs.json')

In [5]:
# Show one training example side by side with the entry at the same index in
# the pre-computed documents.
example_idx = 123456
eli5_train[example_idx], eli5_train_docs[example_idx]

({'q_id': '8rn7rq',
  'title': 'Why do cars only ever have DC outlets even though most electrical devices use AC plugs?',
  'selftext': '',
  'answers': {'a_id': ['', '', '', '', '', ''],
   'text': ["To answer this, we probably need to look at why we use AC in buildings in the first place. When you transmit electricity from the power station to the consumers, you can choose to do it with a high current/low voltage, or low current/high voltage. We choose to do as much of the transmission as possible using the low current/high voltage as, because high currents have a nasty habit of heating the cables up excessively. At home, though, we use a low(ish) voltage, and so we need to be able to convert between the high transmission voltage and low domestic voltage - to do this, we use devices called transformers, and transformers only work with AC. If we didn't use AC, we'd need large numbers of local power stations close to consumers. In a car, though, we have a different system - firstly, th