In [1]:
from eli5_utils import *

# Dataset builder for ELI5, using the local 'eli5' directory as its data dir.
eli5_dbuilder = ELI5NLP(data_dir='eli5')
eli5_dbuilder.download_and_prepare()

# Pull out the train / validation / test splits in one pass.
eli5_train, eli5_valid, eli5_test = [
    eli5_dbuilder.as_dataset(split=split_name)
    for split_name in (
        nlp.splits.Split.TRAIN,
        nlp.splits.Split.VALIDATION,
        nlp.splits.Split.TEST,
    )
]

#### Training the retriever

In [2]:
class ArgumentsQAR:
    """Hyper-parameters for training the question/answer retriever.

    The instance is passed as `args` to the retriever training and
    evaluation helpers; every setting is a plain instance attribute.
    """

    def __init__(self):
        vars(self).update(
            # Batch size; the notebook also divides the dataset length by it
            # to compute the number of scheduler steps per epoch.
            batch_size=2048,
            # Presumably the maximum tokenized sequence length -- confirm in eli5_utils.
            max_length=128,
            # Presumably a sub-batch size used inside the training helper -- confirm.
            checkpoint_batch_size=128,
            # Logging period; the captured logs show one line every 10 steps.
            print_freq=10,
            # Passed as `model_name` when building the retriever model.
            pretrained_model_name="google/bert_uncased_L-8_H-512_A-8",
            # Checkpoint path prefix; "_<epoch>.pth" is appended when saving.
            model_save_name="retriever_models/eli5_retriever_model_2048",
            # Learning rate handed to AdamW.
            learning_rate=2e-5,
            # Number of passes over the training set.
            num_epochs=20,
        )

# Retriever hyper-parameters (see ArgumentsQAR).
qar_args = ArgumentsQAR()

# Train / validation datasets; both use min_answer_length=64 and differ only
# in the `training` flag.
qar_train_dset = ELI5DatasetQARetriver(
    eli5_train,
    min_answer_length=64,
    training=True,
)
qar_valid_dset = ELI5DatasetQARetriver(
    eli5_valid,
    min_answer_length=64,
    training=False,
)

# Start from the pretrained encoder (from_file=None: no fine-tuned checkpoint
# is loaded here) on the first GPU.
qar_tokenizer, qar_model = make_qa_retriever_model(
    model_name=qar_args.pretrained_model_name,
    from_file=None,
    device="cuda:0",
)

qar_optimizer = AdamW(
    qar_model.parameters(),
    lr=qar_args.learning_rate,
    eps=1e-8,
)

# Linear warmup/decay schedule budgeted for one scheduler step per batch over
# all epochs.
qar_steps_per_epoch = math.ceil(len(qar_train_dset) / qar_args.batch_size)
qar_total_steps = qar_args.num_epochs * qar_steps_per_epoch
qar_scheduler = get_linear_schedule_with_warmup(
    qar_optimizer,
    num_warmup_steps=100,
    num_training_steps=qar_total_steps,
)

import os

# Create the checkpoint directory up front: torch.save does not create missing
# parent directories, and discovering that only after a full training epoch
# would waste the whole epoch.
qar_save_dir = os.path.dirname(qar_args.model_save_name)
if qar_save_dir:
    os.makedirs(qar_save_dir, exist_ok=True)

# Train the retriever for `num_epochs` epochs; after each epoch, checkpoint the
# model / optimizer / scheduler state and report the validation loss.
for e in range(qar_args.num_epochs):
    train_qa_retriever_epoch(
        qar_model, qar_train_dset, qar_tokenizer,
        qar_optimizer, qar_scheduler, qar_args, e
    )
    # Save optimizer and scheduler state alongside the weights so the
    # checkpoint holds everything this loop mutates.
    m_save_dict = {
        'model': qar_model.state_dict(),
        'optimizer': qar_optimizer.state_dict(),
        'scheduler': qar_scheduler.state_dict(),
    }
    # One file per epoch: "<model_save_name>_<epoch>.pth". Log the exact path
    # so the output identifies which file was written.
    qar_checkpoint_path = '{}_{}.pth'.format(qar_args.model_save_name, e)
    print("Saving model {}".format(qar_checkpoint_path))
    torch.save(m_save_dict, qar_checkpoint_path)
    eval_loss = evaluate_qa_retriever(qar_model, qar_valid_dset, qar_tokenizer, qar_args)
    print("Evaluation loss epoch {:4d}: {:.3f}".format(e, eval_loss))


 0     0 of   539 	 L: 6.342 	 -- 9.093
 0    10 of   539 	 L: 6.347 	 -- 95.173
 0    20 of   539 	 L: 6.308 	 -- 182.368
 0    30 of   539 	 L: 6.243 	 -- 271.287
 0    40 of   539 	 L: 6.161 	 -- 362.347
 0    50 of   539 	 L: 6.010 	 -- 454.841
 0    60 of   539 	 L: 5.733 	 -- 545.507
 0    70 of   539 	 L: 5.391 	 -- 638.503
 0    80 of   539 	 L: 4.973 	 -- 729.353
 0    90 of   539 	 L: 4.670 	 -- 818.165
 0   100 of   539 	 L: 4.343 	 -- 908.974
 0   110 of   539 	 L: 4.084 	 -- 997.924
 0   120 of   539 	 L: 3.877 	 -- 1086.849
 0   130 of   539 	 L: 3.733 	 -- 1176.005
 0   140 of   539 	 L: 3.645 	 -- 1264.747
 0   150 of   539 	 L: 3.508 	 -- 1353.476
 0   160 of   539 	 L: 3.459 	 -- 1443.157
 0   170 of   539 	 L: 3.384 	 -- 1530.991
 0   180 of   539 	 L: 3.237 	 -- 1619.006
 0   190 of   539 	 L: 3.172 	 -- 1706.880
 0   200 of   539 	 L: 3.150 	 -- 1794.816
 0   210 of   539 	 L: 3.081 	 -- 1883.141
 0   220 of   539 	 L: 3.028 	 -- 1972.121
 0   230 of   539 	 L: 2.9

 3   240 of   539 	 L: 1.592 	 -- 2081.457
 3   250 of   539 	 L: 1.572 	 -- 2167.973
 3   260 of   539 	 L: 1.563 	 -- 2254.105
 3   270 of   539 	 L: 1.565 	 -- 2340.173
 3   280 of   539 	 L: 1.548 	 -- 2426.425
 3   290 of   539 	 L: 1.610 	 -- 2512.734
 3   300 of   539 	 L: 1.562 	 -- 2598.918
 3   310 of   539 	 L: 1.535 	 -- 2685.304
 3   320 of   539 	 L: 1.556 	 -- 2772.164
 3   330 of   539 	 L: 1.557 	 -- 2858.908
 3   340 of   539 	 L: 1.598 	 -- 2945.203
 3   350 of   539 	 L: 1.581 	 -- 3031.560
 3   360 of   539 	 L: 1.568 	 -- 3117.719
 3   370 of   539 	 L: 1.534 	 -- 3204.090
 3   380 of   539 	 L: 1.583 	 -- 3290.546
 3   390 of   539 	 L: 1.542 	 -- 3376.847
 3   400 of   539 	 L: 1.580 	 -- 3463.172
 3   410 of   539 	 L: 1.558 	 -- 3549.349
 3   420 of   539 	 L: 1.541 	 -- 3635.657
 3   430 of   539 	 L: 1.567 	 -- 3722.059
 3   440 of   539 	 L: 1.571 	 -- 3808.347
 3   450 of   539 	 L: 1.557 	 -- 3894.831
 3   460 of   539 	 L: 1.574 	 -- 3981.229
 3   470 of

 6   480 of   539 	 L: 1.359 	 -- 4175.100
 6   490 of   539 	 L: 1.363 	 -- 4262.016
 6   500 of   539 	 L: 1.304 	 -- 4348.756
 6   510 of   539 	 L: 1.348 	 -- 4435.486
 6   520 of   539 	 L: 1.379 	 -- 4522.126
 6   530 of   539 	 L: 1.333 	 -- 4608.723
Saving model retriever_models/eli5_retriever_model_512
Evaluation loss epoch    6: 1.052
 7     0 of   539 	 L: 1.236 	 -- 8.700
 7    10 of   539 	 L: 1.370 	 -- 95.146
 7    20 of   539 	 L: 1.316 	 -- 181.676
 7    30 of   539 	 L: 1.288 	 -- 268.245
 7    40 of   539 	 L: 1.285 	 -- 354.808
 7    50 of   539 	 L: 1.307 	 -- 441.356
 7    60 of   539 	 L: 1.277 	 -- 527.899
 7    70 of   539 	 L: 1.306 	 -- 614.368
 7    80 of   539 	 L: 1.296 	 -- 700.755
 7    90 of   539 	 L: 1.303 	 -- 787.304
 7   100 of   539 	 L: 1.329 	 -- 873.831
 7   110 of   539 	 L: 1.322 	 -- 960.588
 7   120 of   539 	 L: 1.340 	 -- 1047.184
 7   130 of   539 	 L: 1.303 	 -- 1133.643
 7   140 of   539 	 L: 1.384 	 -- 1220.307
 7   150 of   539 	 L: 

In [None]:
# TODO: evaluate recall@N for validation / test set

#### Use the trained retriever to index Wikipedia

In [None]:
from eli5_utils import *

# Fine-tuned retriever checkpoint, named "<model_save_name>_<epoch>.pth" by the
# retriever training loop.
# NOTE(review): the training cell in this notebook currently saves under
# "retriever_models/eli5_retriever_model_2048", while this path points at the
# "_512" run, epoch 9 -- confirm this is the checkpoint that should be indexed.
retriever_checkpoint = '{}_{}.pth'.format("retriever_models/eli5_retriever_model_512", 9)

# Rebuild the retriever with the same base model name and load the weights
# from the checkpoint file.
qar_tokenizer, qar_model = make_qa_retriever_model(
    model_name="google/bert_uncased_L-8_H-512_A-8",
    from_file=retriever_checkpoint,
    device="cuda:0",
)

# Passages to index, taken from the TRAIN split of the KILT snippets builder.
kilt_snippets_dbuilder = KiltSnippets(data_dir='kilt_snippets_100w')
kilt_snippets_dbuilder.download_and_prepare()
wiki_passages = kilt_snippets_dbuilder.as_dataset(split=nlp.splits.Split.TRAIN)

# Build the dense index for the passages under the given index file name.
make_qa_dense_index(
    qar_model,
    qar_tokenizer,
    wiki_passages,
    batch_size=512,
    max_length=96,
    index_name='kilt_passages_reps_16.dat',
    device='cuda:0',
)


#### Training the Seq2seq model

In [1]:
from eli5_utils import *

# Load the ELI5 train / validation splits.
eli5_dbuilder = ELI5NLP(data_dir='eli5')
eli5_dbuilder.download_and_prepare()

eli5_train = eli5_dbuilder.as_dataset(split=nlp.splits.Split.TRAIN)
eli5_valid = eli5_dbuilder.as_dataset(split=nlp.splits.Split.VALIDATION)

# Precomputed support documents for each split. Use context managers so the
# file handles are closed deterministically instead of being leaked.
with open('eli5_train_precomputed_dense_docs.json') as train_docs_file:
    eli5_train_docs = json.load(train_docs_file)
with open('eli5_valid_precomputed_dense_docs.json') as valid_docs_file:
    eli5_valid_docs = json.load(valid_docs_file)

# Each JSON entry unpacks into three items; only the first two (used here as
# cache key and cached document) are needed, the third is ignored.
s2s_train_dset = ELI5DatasetS2S(
    eli5_train,
    document_cache={k: d for k, d, _ in eli5_train_docs},
)
s2s_valid_dset = ELI5DatasetS2S(
    eli5_valid,
    document_cache={k: d for k, d, _ in eli5_valid_docs},
    training=False,
)

# Start from the pretrained "bart-large" model (from_file=None: no fine-tuned
# checkpoint is loaded) on the first GPU.
qa_s2s_tokenizer, qa_s2s_model = make_qa_s2s_model(
    model_name="bart-large",
    from_file=None,
    device="cuda:0"
)

class ArgumentsS2S:
    """Hyper-parameters for training the seq2seq answer generator.

    The instance is passed as `args` to the seq2seq training helper; every
    setting is a plain instance attribute.
    """

    def __init__(self):
        vars(self).update(
            # Batch size; the notebook also divides the dataset length by it
            # to compute the number of scheduler steps per epoch.
            batch_size=1,
            # Presumably the number of batches accumulated per optimizer
            # step -- confirm in eli5_utils.
            backward_freq=16,
            # Presumably the maximum tokenized sequence length -- confirm.
            max_length=1024,
            # Presumably the logging period in steps -- confirm.
            print_freq=1000,
            # Checkpoint path prefix; "_<epoch>.pth" is appended when saving.
            model_save_name="seq2seq_models/eli5_bart_model_512",
            # Learning rate handed to AdamW.
            learning_rate=2e-5,
            # Number of passes over the training set.
            num_epochs=10,
        )

# Seq2seq hyper-parameters (see ArgumentsS2S).
s2s_args = ArgumentsS2S()

s2s_optimizer = AdamW(
    qa_s2s_model.parameters(),
    lr=s2s_args.learning_rate,
    eps=1e-8,
)

# Linear warmup/decay schedule budgeted for one scheduler step per batch over
# all epochs.
# NOTE(review): this count ignores `backward_freq`. If train_qa_s2s_epoch only
# steps the scheduler once every `backward_freq` batches, the learning rate
# will not finish decaying by the end of training -- confirm against eli5_utils.
s2s_batches_per_epoch = math.ceil(len(s2s_train_dset) / s2s_args.batch_size)
s2s_total_steps = s2s_args.num_epochs * s2s_batches_per_epoch
s2s_scheduler = get_linear_schedule_with_warmup(
    s2s_optimizer,
    num_warmup_steps=100,
    num_training_steps=s2s_total_steps,
)

In [2]:
import os

# Create the checkpoint directory up front: torch.save does not create missing
# parent directories, and discovering that only after a full training epoch
# would waste the whole epoch.
s2s_save_dir = os.path.dirname(s2s_args.model_save_name)
if s2s_save_dir:
    os.makedirs(s2s_save_dir, exist_ok=True)

# Train the seq2seq model for `num_epochs` epochs, checkpointing the model /
# optimizer / scheduler state after each one.
for e in range(s2s_args.num_epochs):
    train_qa_s2s_epoch(
        qa_s2s_model,
        s2s_train_dset, qa_s2s_tokenizer,
        s2s_optimizer, s2s_scheduler,
        s2s_args, e
    )
    # Save optimizer and scheduler state alongside the weights so the
    # checkpoint holds everything this loop mutates.
    m_save_dict = {
        'model': qa_s2s_model.state_dict(),
        'optimizer': s2s_optimizer.state_dict(),
        'scheduler': s2s_scheduler.state_dict(),
    }
    # One file per epoch: "<model_save_name>_<epoch>.pth". Log the exact path
    # so the output identifies which file was written.
    s2s_checkpoint_path = '{}_{}.pth'.format(s2s_args.model_save_name, e)
    print("Saving model {}".format(s2s_checkpoint_path))
    torch.save(m_save_dict, s2s_checkpoint_path)

