In [1]:
from eli5_utils import *

# Build the ELI5 dataset under the local 'eli5' directory.
# ELI5NLP and `nlp` are presumably provided by the star import from eli5_utils.
eli5_dbuilder = ELI5NLP(data_dir='eli5')
eli5_dbuilder.download_and_prepare()

# Materialize the train / validation / test splits as dataset objects.
eli5_train = eli5_dbuilder.as_dataset(split=nlp.splits.Split.TRAIN)
eli5_valid = eli5_dbuilder.as_dataset(split=nlp.splits.Split.VALIDATION)
eli5_test = eli5_dbuilder.as_dataset(split=nlp.splits.Split.TEST)

#### Training the retriever

In [2]:
class ArgumentsQAR():
    """Hyper-parameter container for training the question/answer retriever.

    An instance is handed to train_qa_retriever_epoch / evaluate_qa_retriever and
    is also read directly by this cell when building the optimizer and scheduler.

    NOTE(review): the output recorded for this cell shows 539 steps per epoch and
    a save name of "retriever_models/eli5_retriever_model_512", and later cells
    load "retriever_models/eli5_retriever_model_512_9.pth". The batch size and
    save name below therefore look edited after the recorded run; confirm which
    values are intended before re-running.
    """
    def __init__(self):
        # Used below to derive the number of scheduler steps per epoch.
        self.batch_size = 2048
        # The next three are only consumed inside the eli5_utils helpers;
        # presumably token limit, sub-batch size and logging interval -- TODO confirm.
        self.max_length = 128
        self.checkpoint_batch_size = 128
        self.print_freq = 10
        # Passed as model_name to make_qa_retriever_model.
        self.pretrained_model_name = "google/bert_uncased_L-8_H-512_A-8"
        # Checkpoint path prefix; "_<epoch>.pth" is appended when saving.
        self.model_save_name = "retriever_models/eli5_retriever_model_2048"
        # Learning rate given to AdamW.
        self.learning_rate = 2e-5
        # Number of passes of the training loop; also scales the scheduler length.
        self.num_epochs = 20

import os

qar_args = ArgumentsQAR()

# Retriever datasets over the ELI5 splits; both use min_answer_length=64.
qar_train_dset = ELI5DatasetQARetriver(eli5_train, min_answer_length=64, training=True)
qar_valid_dset = ELI5DatasetQARetriver(eli5_valid, min_answer_length=64, training=False)

# from_file=None: no retriever checkpoint is passed in for this run.
qar_tokenizer, qar_model = make_qa_retriever_model(
    model_name=qar_args.pretrained_model_name,
    from_file=None,
    device="cuda:0"
)

qar_optimizer = AdamW(qar_model.parameters(), lr=qar_args.learning_rate, eps=1e-8)
# Linear schedule with 100 warm-up steps; total steps = epochs x batches per epoch.
qar_scheduler = get_linear_schedule_with_warmup(
        qar_optimizer,
        num_warmup_steps=100,
        num_training_steps=qar_args.num_epochs * math.ceil(len(qar_train_dset) / qar_args.batch_size)
)

# Create the checkpoint directory before training starts: torch.save does not
# create missing parent directories, and failing only after a full epoch of
# training would waste that whole epoch.
qar_save_dir = os.path.dirname(qar_args.model_save_name)
if qar_save_dir:
    os.makedirs(qar_save_dir, exist_ok=True)

for e in range(qar_args.num_epochs):
    train_qa_retriever_epoch(
        qar_model, qar_train_dset, qar_tokenizer,
        qar_optimizer, qar_scheduler, qar_args, e
    )
    # Save model, optimizer and scheduler state after every epoch.
    m_save_dict = {
        'model': qar_model.state_dict(),
        'optimizer': qar_optimizer.state_dict(),
        'scheduler': qar_scheduler.state_dict(),
    }
    print("Saving model {}".format(qar_args.model_save_name))
    torch.save(m_save_dict, '{}_{}.pth'.format(qar_args.model_save_name, e))
    # Validation loss after each epoch.
    eval_loss = evaluate_qa_retriever(qar_model, qar_valid_dset, qar_tokenizer, qar_args)
    print("Evaluation loss epoch {:4d}: {:.3f}".format(e, eval_loss))


 0     0 of   539 	 L: 6.342 	 -- 9.093
 0    10 of   539 	 L: 6.347 	 -- 95.173
 0    20 of   539 	 L: 6.308 	 -- 182.368
 0    30 of   539 	 L: 6.243 	 -- 271.287
 0    40 of   539 	 L: 6.161 	 -- 362.347
 0    50 of   539 	 L: 6.010 	 -- 454.841
 0    60 of   539 	 L: 5.733 	 -- 545.507
 0    70 of   539 	 L: 5.391 	 -- 638.503
 0    80 of   539 	 L: 4.973 	 -- 729.353
 0    90 of   539 	 L: 4.670 	 -- 818.165
 0   100 of   539 	 L: 4.343 	 -- 908.974
 0   110 of   539 	 L: 4.084 	 -- 997.924
 0   120 of   539 	 L: 3.877 	 -- 1086.849
 0   130 of   539 	 L: 3.733 	 -- 1176.005
 0   140 of   539 	 L: 3.645 	 -- 1264.747
 0   150 of   539 	 L: 3.508 	 -- 1353.476
 0   160 of   539 	 L: 3.459 	 -- 1443.157
 0   170 of   539 	 L: 3.384 	 -- 1530.991
 0   180 of   539 	 L: 3.237 	 -- 1619.006
 0   190 of   539 	 L: 3.172 	 -- 1706.880
 0   200 of   539 	 L: 3.150 	 -- 1794.816
 0   210 of   539 	 L: 3.081 	 -- 1883.141
 0   220 of   539 	 L: 3.028 	 -- 1972.121
 0   230 of   539 	 L: 2.9

 3   240 of   539 	 L: 1.592 	 -- 2081.457
 3   250 of   539 	 L: 1.572 	 -- 2167.973
 3   260 of   539 	 L: 1.563 	 -- 2254.105
 3   270 of   539 	 L: 1.565 	 -- 2340.173
 3   280 of   539 	 L: 1.548 	 -- 2426.425
 3   290 of   539 	 L: 1.610 	 -- 2512.734
 3   300 of   539 	 L: 1.562 	 -- 2598.918
 3   310 of   539 	 L: 1.535 	 -- 2685.304
 3   320 of   539 	 L: 1.556 	 -- 2772.164
 3   330 of   539 	 L: 1.557 	 -- 2858.908
 3   340 of   539 	 L: 1.598 	 -- 2945.203
 3   350 of   539 	 L: 1.581 	 -- 3031.560
 3   360 of   539 	 L: 1.568 	 -- 3117.719
 3   370 of   539 	 L: 1.534 	 -- 3204.090
 3   380 of   539 	 L: 1.583 	 -- 3290.546
 3   390 of   539 	 L: 1.542 	 -- 3376.847
 3   400 of   539 	 L: 1.580 	 -- 3463.172
 3   410 of   539 	 L: 1.558 	 -- 3549.349
 3   420 of   539 	 L: 1.541 	 -- 3635.657
 3   430 of   539 	 L: 1.567 	 -- 3722.059
 3   440 of   539 	 L: 1.571 	 -- 3808.347
 3   450 of   539 	 L: 1.557 	 -- 3894.831
 3   460 of   539 	 L: 1.574 	 -- 3981.229
 3   470 of

 6   480 of   539 	 L: 1.359 	 -- 4175.100
 6   490 of   539 	 L: 1.363 	 -- 4262.016
 6   500 of   539 	 L: 1.304 	 -- 4348.756
 6   510 of   539 	 L: 1.348 	 -- 4435.486
 6   520 of   539 	 L: 1.379 	 -- 4522.126
 6   530 of   539 	 L: 1.333 	 -- 4608.723
Saving model retriever_models/eli5_retriever_model_512
Evaluation loss epoch    6: 1.052
 7     0 of   539 	 L: 1.236 	 -- 8.700
 7    10 of   539 	 L: 1.370 	 -- 95.146
 7    20 of   539 	 L: 1.316 	 -- 181.676
 7    30 of   539 	 L: 1.288 	 -- 268.245
 7    40 of   539 	 L: 1.285 	 -- 354.808
 7    50 of   539 	 L: 1.307 	 -- 441.356
 7    60 of   539 	 L: 1.277 	 -- 527.899
 7    70 of   539 	 L: 1.306 	 -- 614.368
 7    80 of   539 	 L: 1.296 	 -- 700.755
 7    90 of   539 	 L: 1.303 	 -- 787.304
 7   100 of   539 	 L: 1.329 	 -- 873.831
 7   110 of   539 	 L: 1.322 	 -- 960.588
 7   120 of   539 	 L: 1.340 	 -- 1047.184
 7   130 of   539 	 L: 1.303 	 -- 1133.643
 7   140 of   539 	 L: 1.384 	 -- 1220.307
 7   150 of   539 	 L: 

In [None]:
# TODO: evaluate recall@N for validation / test set

#### Use the trained retriever to index Wikipedia

In [None]:
from eli5_utils import *

# Reload the retriever from the epoch-9 checkpoint.
# NOTE(review): this path uses the "_512" prefix, while ArgumentsQAR in the
# training cell saves under "retriever_models/eli5_retriever_model_2048";
# confirm which checkpoint is meant before re-running.
qar_tokenizer, qar_model = make_qa_retriever_model(
    model_name="google/bert_uncased_L-8_H-512_A-8",
    from_file='{}_{}.pth'.format("retriever_models/eli5_retriever_model_512", 9),
    device="cuda:0"
)

# Wikipedia passages come from the KILT snippets builder (its TRAIN split).
kilt_snippets_dbuilder = KiltSnippets(data_dir='kilt_snippets_100w')
kilt_snippets_dbuilder.download_and_prepare()
wiki_passages = kilt_snippets_dbuilder.as_dataset(split=nlp.splits.Split.TRAIN)

# Embed the passages with the retriever; the helper is given the output name
# 'kilt_passages_reps_16.dat' (presumably an on-disk array -- TODO confirm).
make_qa_dense_index(qar_model, qar_tokenizer,
                    wiki_passages,
                    batch_size=512, max_length=96,
                    index_name='kilt_passages_reps_16.dat',
                    device='cuda:0')


#### Training the Seq2seq model

In [1]:
from eli5_utils import *

# Build the ELI5 dataset and take the train / validation splits.
eli5_dbuilder = ELI5NLP(data_dir='eli5')
eli5_dbuilder.download_and_prepare()

eli5_train = eli5_dbuilder.as_dataset(split=nlp.splits.Split.TRAIN)
eli5_valid = eli5_dbuilder.as_dataset(split=nlp.splits.Split.VALIDATION)

# Precomputed supporting documents. Each record unpacks into three fields, of
# which only the first two (key, document) are used below. The files are opened
# with context managers so the handles are closed deterministically.
with open('precomputed/eli5_train_precomputed_dense_docs.json') as f:
    eli5_train_docs = json.load(f)
with open('precomputed/eli5_valid_precomputed_dense_docs.json') as f:
    eli5_valid_docs = json.load(f)

# Seq2seq datasets; document_cache maps each record key to its document.
s2s_train_dset = ELI5DatasetS2S(eli5_train, document_cache={k: d for k, d, _ in eli5_train_docs})
s2s_valid_dset = ELI5DatasetS2S(eli5_valid, document_cache={k: d for k, d, _ in eli5_valid_docs}, training=False)

# from_file=None: no seq2seq checkpoint is passed in for this training run.
qa_s2s_tokenizer, qa_s2s_model = make_qa_s2s_model(
    model_name="t5-small",
    from_file=None,
    device="cuda:0"
)

class ArgumentsS2S():
    """Hyper-parameter container for training the seq2seq answer generator.

    An instance is passed to train_qa_s2s_epoch and is also read by this cell
    when building the optimizer, the scheduler and the checkpoint file names.
    """
    def __init__(self):
        # Used below to derive the number of scheduler steps per epoch.
        self.batch_size = 4
        # Only consumed inside train_qa_s2s_epoch; presumably the number of
        # batches between optimizer steps (gradient accumulation) -- TODO confirm.
        self.backward_freq = 4
        # Only consumed inside the helper; presumably the input token limit -- TODO confirm.
        self.max_length = 1024
        # Only consumed inside the helper; presumably the logging interval -- TODO confirm.
        self.print_freq = 4000
        # Checkpoint path prefix; "_<epoch>.pth" is appended when saving.
        self.model_save_name = "seq2seq_models/eli5_t5_model_1024"
        # Learning rate given to AdamW.
        self.learning_rate = 1e-4
        # Number of passes of the training loop; also scales the scheduler length.
        self.num_epochs = 10

import os

s2s_args = ArgumentsS2S()
s2s_optimizer = AdamW(qa_s2s_model.parameters(), lr=s2s_args.learning_rate, eps=1e-8)
# Linear schedule with 400 warm-up steps; total steps = epochs x batches per epoch.
# NOTE(review): if train_qa_s2s_epoch only steps the scheduler every
# `backward_freq` batches, this total is larger than the steps actually taken
# and the learning rate never decays to zero -- confirm against the helper.
s2s_scheduler = get_linear_schedule_with_warmup(
        s2s_optimizer,
        num_warmup_steps=400,
        num_training_steps=s2s_args.num_epochs * math.ceil(len(s2s_train_dset) / s2s_args.batch_size)
)

# Create the checkpoint directory before training starts: torch.save does not
# create missing parent directories, and failing only after a full epoch of
# training would waste that whole epoch.
s2s_save_dir = os.path.dirname(s2s_args.model_save_name)
if s2s_save_dir:
    os.makedirs(s2s_save_dir, exist_ok=True)

for e in range(s2s_args.num_epochs):
    train_qa_s2s_epoch(
        qa_s2s_model,
        s2s_train_dset, qa_s2s_tokenizer,
        s2s_optimizer, s2s_scheduler,
        s2s_args, e
    )
    # Save model, optimizer and scheduler state after every epoch.
    m_save_dict = {
        'model': qa_s2s_model.state_dict(),
        'optimizer': s2s_optimizer.state_dict(),
        'scheduler': s2s_scheduler.state_dict(),
    }
    print("Saving model {}".format(s2s_args.model_save_name))
    torch.save(m_save_dict, '{}_{}.pth'.format(s2s_args.model_save_name, e))



In [2]:
# NOTE(review): this cell repeats the training loop of the previous cell verbatim.
# It reuses the same model, optimizer and scheduler objects, and writes to the
# same '<model_save_name>_<epoch>.pth' paths, so running it after the previous
# cell overwrites those checkpoints. Confirm whether it is an intentional
# continuation or a leftover duplicate.
for e in range(s2s_args.num_epochs):
    train_qa_s2s_epoch(
        qa_s2s_model,
        s2s_train_dset, qa_s2s_tokenizer,
        s2s_optimizer, s2s_scheduler,
        s2s_args, e
    )
    # Save model, optimizer and scheduler state after every epoch.
    m_save_dict = {
        'model': qa_s2s_model.state_dict(),
        'optimizer': s2s_optimizer.state_dict(),
        'scheduler': s2s_scheduler.state_dict(),
    }
    print("Saving model {}".format(s2s_args.model_save_name))
    torch.save(m_save_dict, '{}_{}.pth'.format(s2s_args.model_save_name, e))



#### Testing seq2seq model

In [1]:
from eli5_utils import *

# Build the ELI5 dataset and take the train / validation splits.
eli5_dbuilder = ELI5NLP(data_dir='eli5')
eli5_dbuilder.download_and_prepare()

eli5_train = eli5_dbuilder.as_dataset(split=nlp.splits.Split.TRAIN)
eli5_valid = eli5_dbuilder.as_dataset(split=nlp.splits.Split.VALIDATION)

# Precomputed supporting documents. Each record unpacks into three fields, of
# which only the first two (key, document) are used below. The files are opened
# with context managers so the handles are closed deterministically.
with open('precomputed/eli5_train_precomputed_dense_docs.json') as f:
    eli5_train_docs = json.load(f)
with open('precomputed/eli5_valid_precomputed_dense_docs.json') as f:
    eli5_valid_docs = json.load(f)

# Seq2seq datasets; document_cache maps each record key to its document.
s2s_train_dset = ELI5DatasetS2S(eli5_train, document_cache={k: d for k, d, _ in eli5_train_docs})
s2s_valid_dset = ELI5DatasetS2S(eli5_valid, document_cache={k: d for k, d, _ in eli5_valid_docs}, training=False)

# NOTE(review): from_file=None passes no fine-tuned checkpoint, although this
# section is meant to test the trained seq2seq model. Confirm which saved
# '<model_save_name>_<epoch>.pth' file should be loaded here.
qa_s2s_tokenizer, qa_s2s_model = make_qa_s2s_model(
    model_name="t5-small",
    from_file=None,
    device="cuda:0"
)

class ArgumentsS2S():
    """Seq2seq hyper-parameter container, as defined in the testing section.

    NOTE(review): this redefines the ArgumentsS2S class from the training
    section with different values (batch size, backward_freq, max_length,
    print_freq, save name, learning rate). None of the visible cells below read
    `s2s_args`; confirm whether these values are still needed.
    """
    def __init__(self):
        self.batch_size = 1
        # Presumably the number of batches between optimizer steps -- TODO confirm.
        self.backward_freq = 16
        self.max_length = 512
        self.print_freq = 1000
        # Checkpoint path prefix; differs from the "_1024" prefix in the training section.
        self.model_save_name = "seq2seq_models/eli5_t5_model_512"
        self.learning_rate = 2e-5
        self.num_epochs = 10

s2s_args = ArgumentsS2S()


In [32]:
# Validation example 123: the item unpacks into the model input string (qd)
# and the reference answer (a).
qd, a = s2s_valid_dset[123]
# Display only the question part of the input (the text before ' context: ').
qd.split(' context: ')[0]

'question: why is google fibre taking so long to roll out?'

In [31]:
# One answer (num_answers=1) for the current `qd` with num_beams=2 (presumably
# beam search -- see qa_s2s_generate); [0] takes the first returned answer.
qa_s2s_generate(qd, qa_s2s_model, qa_s2s_tokenizer,
    num_answers=1,
    num_beams=2,
    max_input_length=512,
    device="cuda:0")[0]

"Because Google has to build a lot of cables, and they have to lay a lot more cables than you think. They also have to make sure that the cables are strong enough to handle the load, and make sure they don't break down. They have to build the cables in a way that they can make sure the cables don't get damaged."

In [30]:
# One sampled answer (do_sample=True) for the current `qd`, with temp=0.7 and
# top_p=0.95 passed to the helper; [0] takes the first returned answer.
# The result is not reproducible unless a random seed is fixed beforehand.
qa_s2s_generate(qd, qa_s2s_model, qa_s2s_tokenizer,
    num_answers=1,
    do_sample=True,
    temp=0.7,
    top_p=0.95,
    max_input_length=512,
    device="cuda:0")[0]

'Google Fiber has a ton of infrastructure to build and maintain in addition to the fiber itself. A lot of that infrastructure is already in place. Google is just waiting for you to pay for it. Google Fiber is a lot like a business. You have to buy a lot of equipment, and you have to pay a lot.'

In [41]:
# Validation example 0: model input string (qd) and reference answer (a).
qd, a = s2s_valid_dset[0]
# Display only the question part of the input (the text before ' context: ').
qd.split(' context: ')[0]

"question: the hubble telescope was launched in 1990. since our technology has advanced tremendously since then, wouldn't it be advantageous to send a more advanced telescope up there?"

In [42]:
# One answer (num_answers=1) for the current `qd` with num_beams=2 (presumably
# beam search -- see qa_s2s_generate); [0] takes the first returned answer.
qa_s2s_generate(qd, qa_s2s_model, qa_s2s_tokenizer,
    num_answers=1,
    num_beams=2,
    max_input_length=512,
    device="cuda:0")[0]

"The Hubble was launched in 1990. The Hubble telescope is still in orbit around the Sun. It's not going to get any more advanced than the Hubble, and it's not even close to the size of the Hubble. It was launched for science, not for fun. It is still a very small telescope."

In [43]:
# One sampled answer (do_sample=True) for the current `qd`, with temp=0.7 and
# top_p=0.95 passed to the helper; [0] takes the first returned answer.
# The result is not reproducible unless a random seed is fixed beforehand.
qa_s2s_generate(qd, qa_s2s_model, qa_s2s_tokenizer,
    num_answers=1,
    do_sample=True,
    temp=0.7,
    top_p=0.95,
    max_input_length=512,
    device="cuda:0")[0]

"The Hubble was launched because people needed more information about the universe and Hubble was the best instrument available. The Hubble is still in use today for science purposes, as it's used to study a lot of the light that comes from distant stars. If you want to send a telescope up to any star that's far away, you'd need to send it to a star that has a much bigger size and brighter light, and you'd have to send another telescope to observe the star. Since telescopes are only a few meters across and the Hubble is only a couple of meters across, you don't really need to do that."

In [15]:
# Validation example 1: model input string (qd) and reference answer (a).
qd, a = s2s_valid_dset[1]
# Display only the question part of the input (the text before ' context: ').
qd.split(' context: ')[0]

'question: could a computer be built out of electromagnetic relays instead of semiconductors?'

In [17]:
# One answer (num_answers=1) for the current `qd` with num_beams=2. Unlike the
# earlier examples this call also passes min_len=128 (presumably a minimum
# generated length -- confirm in qa_s2s_generate). [0] takes the first answer.
qa_s2s_generate(qd, qa_s2s_model, qa_s2s_tokenizer,
    num_answers=1,
    num_beams=2,
    min_len=128,
    max_input_length=512,
    device="cuda:0")[0]

"No, because they would be very expensive and very heavy, and they would need to be able to withstand a lot of heat and radiation to be useful. It would be a lot easier to just use a magnetic field instead of semiconductors. But it would be much easier to make a computer out of electromagnetic relays than a semiconductor one, and it would probably be cheaper to do so. So it's not really worth the cost to do it, but it's a good way to get a computer that can handle a lot more heat and heat without having to worry about it getting too hot or getting too cold. But if you want to build a computer with a lot less power, you'd need to make it a lot smaller, and you'd have to make sure that it could handle the amount of heat that would be required to run the computer."

In [36]:
# One sampled answer (do_sample=True) for the current `qd`, with temp=0.7 and
# top_p=0.95 passed to the helper; [0] takes the first returned answer.
# The result is not reproducible unless a random seed is fixed beforehand.
qa_s2s_generate(qd, qa_s2s_model, qa_s2s_tokenizer,
    num_answers=1,
    do_sample=True,
    temp=0.7,
    top_p=0.95,
    max_input_length=512,
    device="cuda:0")[0]

"No. Electromagnetics are far too reactive to be used in a computer. Not to mention the inherent resistance. It's also not practical. It would be a lot more energy than you can ever hope to gain. Also, it's a very bad idea to build a computer out of magnetic relays."

In [38]:
# Validation example 2: model input string (qd) and reference answer (a).
qd, a = s2s_valid_dset[2]
# Display only the question part of the input (the text before ' context: ').
qd.split(' context: ')[0]

"question: in trading places (1983, akroyd/murphy) how does the scheme at the end of the movie work? why would buying a lot of oj at a high price ruin the duke brothers? i have a vague understanding, but i'm hoping someone can explain it better to me. and maybe throw in some knowledge about the stock market in general while you're at it? thank you! edit: hey everyone, thanks for all the great answers! i think i actually really understand futures commodities now. who ever said reddit was a waste of time, eh? so yes, this question has been answered, and yes, i've seen/heard that great npr clip about this topic. thanks again to everyone!"

In [39]:
# One answer (num_answers=1) for the current `qd` with num_beams=2 (presumably
# beam search -- see qa_s2s_generate); [0] takes the first returned answer.
qa_s2s_generate(qd, qa_s2s_model, qa_s2s_tokenizer,
    num_answers=1,
    num_beams=2,
    max_input_length=512,
    device="cuda:0")[0]

"The Duke brothers have a lot of money, and they're trying to sell it to the Duke brothers. They're also trying to buy OJ at a high price. They don't want to buy it at a low price, because then they'd have to pay a lot more for it, and that would ruin them. So they buy a ton of OJ, and then sell it for a lot less."

In [40]:
# One sampled answer (do_sample=True) for the current `qd`, with temp=0.7 and
# top_p=0.95 passed to the helper; [0] takes the first returned answer.
# The result is not reproducible unless a random seed is fixed beforehand.
qa_s2s_generate(qd, qa_s2s_model, qa_s2s_tokenizer,
    num_answers=1,
    do_sample=True,
    temp=0.7,
    top_p=0.95,
    max_input_length=512,
    device="cuda:0")[0]

"URL_0  I think this explains it quite well. And you might like this too. It's a great episode of The Simpsons. Edit: I've heard it described in other comments as a stock market scam. I find this to be hilarious. edit2: Edit3: I find it hilarious. I've seen this described as a crime syndicate or a scam artist."

In [2]:
# Generate one answer for each of the first 2000 validation examples and keep
# (question, reference answer, generated answer) triples for later scoring.
st_time = time()
examples_with_generations = []
for i in range(2000):
    qd, a = s2s_valid_dset[i]
    generated_answers = qa_s2s_generate(
        qd, qa_s2s_model, qa_s2s_tokenizer,
        num_answers=1,
        num_beams=2,
        min_len=128,
        max_input_length=512,
        device="cuda:0"
    )
    beam = generated_answers[0]
    # Store only the question part of the input (the text before ' context: ').
    examples_with_generations.append((qd.split(' context: ')[0], a, beam))
    # Progress report: example index and elapsed seconds, every 10th example.
    if i % 10 == 0:
        print(i, time() - st_time)

0 2.744018793106079
10 28.753477811813354
20 53.937344551086426
30 78.60568451881409
40 104.10094237327576
50 128.5769989490509
60 154.27992796897888
70 179.06538081169128
80 203.8809735774994
90 230.13178277015686
100 256.45514965057373
110 282.3637926578522
120 307.88806414604187
130 333.8857891559601
140 360.0709044933319
150 386.02467918395996
160 414.57378482818604
170 442.61042857170105
180 469.20493245124817
190 494.8593752384186
200 520.1254770755768
210 545.6333200931549
220 572.047559261322
230 597.3021047115326
240 622.8987302780151
250 648.7013621330261
260 674.3510477542877
270 700.575044631958
280 728.2109053134918
290 754.1497228145599
300 782.2213447093964
310 808.7190420627594
320 835.6998612880707
330 863.1649961471558
340 889.878288269043
350 916.4322776794434
360 942.4431409835815
370 969.9250016212463
380 996.5509266853333
390 1022.6301748752594
400 1048.5962445735931
410 1074.9585995674133
420 1101.2033276557922
430 1128.582790851593
440 1155.4053695201874
450 118

In [34]:
# Display one (question, reference answer, generated answer) triple picked by
# `choice` (presumably random.choice via the star import -- TODO confirm).
choice(examples_with_generations)

("question: why we can't transplant intestines it seems like we can transplant anything these days from hearts, livers, penises and even faces! i'm an ulcerative colitis patient and always wondered why it isn't possible.",
 "We can transplant intestines. However, ulcerative colitis is poor candidate because it is thought to be mostly caused by your own immune system attacking your intestines. They will probably attack someone else's intestines just as badly, so the disease will just reoccur. When you get a transplant, you usually have to be on powerful immunosuppressive drugs for the rest of your life to prevent rejection of the organ. But, immunosuppressives are also a treatment for ulcerative colitis itself. If the immunosuppressives were effective for you and the dangers of them seemed warranted, you would just take them in the first place and skip the transplant. If they don't work for you, then the transplant probably won't either.",
 "It's a lot easier to transplant a heart than 

In [6]:
# Display another triple picked by `choice`; the pick differs between runs
# unless a random seed is fixed.
choice(examples_with_generations)

("question: why do we have to build earth re-entry capsules for future space missions? - why can't we use the iss? why do we have to build earth re-entry capsules for future space missions? (all the future mars and asteroid ones) the capsules are heavy and difficult. why can't the mars trips for example just rendezvous with the international space station and then just come back to earth on a soyuz? wouldn't this be easier/cheaper/lighter? *edit:* thanks for all the great explanations.",
 'Because stopping in space isn\'t easy. For every bit of outward acceleration you build up as you head out, you have to turn around and stop it or slow it down enough to land or orbit. The same holds true on the way back. Depending on your flight plan, decelerating and matching orbits or docking chews up half or more of your fuel budget. It is much easier and consumes much less "reaction mass" if you take a man-containing craft and plummet it through the earth\'s atmosphere to allow it to lose a lot o

In [7]:
from nltk import PorterStemmer
from rouge import Rouge
from spacy.lang.en import English
from time import time

# Module-level helpers shared by compute_rouge_eli5 and compute_rouge_nlp.
stemmer = PorterStemmer()        # stems each token before scoring
rouge = Rouge()                  # scorer from the `rouge` package
# Tokenizer built from spaCy's English defaults.
# NOTE(review): `Defaults.create_tokenizer` may not exist in newer spaCy releases -- confirm the pinned version.
tokenizer = English().Defaults.create_tokenizer()
rouge_metric = nlp.load_metric('rouge')  # ROUGE metric loaded through the `nlp` library

def compute_rouge_eli5(compare_list):
    """Score predictions against references with the `rouge` package.

    compare_list holds (gold, pred) text pairs. Both sides are tokenized with
    the module-level `tokenizer`, each token is stemmed with `stemmer`, and the
    stemmed tokens are re-joined with single spaces before scoring.

    Returns the result of `rouge.get_scores(..., avg=True)`.
    """
    def _stemmed(text):
        return " ".join(stemmer.stem(str(token)) for token in tokenizer(text))

    preds = [_stemmed(pred) for _, pred in compare_list]
    golds = [_stemmed(gold) for gold, _ in compare_list]
    return rouge.get_scores(preds, golds, avg=True)

def compute_rouge_nlp(compare_list):
    """Score predictions against references with the `nlp` ROUGE metric.

    compare_list holds (reference, prediction) text pairs. Newlines are replaced
    by spaces, then both sides are tokenized with the module-level `tokenizer`,
    stemmed with `stemmer` and re-joined with single spaces.

    Newlines are replaced by a space and not by an empty string: deleting them
    would glue the last word of one line to the first word of the next and
    produce tokens that occur in neither text.

    Returns the result of `rouge_metric.compute` for rouge1, rouge2 and rougeL.
    Stemming is already applied here, hence use_stemmer=False.
    """
    def _stemmed(text):
        flattened = text.replace('\n', ' ')
        return " ".join(stemmer.stem(str(token)) for token in tokenizer(flattened))

    refs = [_stemmed(ref) for ref, _ in compare_list]
    preds = [_stemmed(pred) for _, pred in compare_list]
    scores = rouge_metric.compute(preds, refs, rouge_types=['rouge1', 'rouge2', 'rougeL'], use_stemmer=False)
    return scores

In [8]:
# ROUGE F-scores from the `rouge` package over (reference, generated) pairs.
rouge_res = compute_rouge_eli5(
    [(a, b) for q, a, b in examples_with_generations]
)
for t in ('1', '2', 'l'):
    f_score = rouge_res['rouge-' + t]['f']
    print('R-{} \t {:.3f}'.format(t, f_score))

R-1 	 0.267
R-2 	 0.056
R-l 	 0.267


In [9]:
# ROUGE F-measures from the `nlp` metric over (reference, generated) pairs.
# The accidental duplicated `rouge_res = rouge_res = ...` assignment is removed.
rouge_res = compute_rouge_nlp([(a, b) for q, a, b in examples_with_generations])
for t in ['1', '2', 'L']:
    # `.mid.fmeasure` is the F-measure of the aggregate's `mid` entry.
    print('R-{} \t {:.3f}'.format(t, rouge_res['rouge{}'.format(t)].mid.fmeasure))

R-1 	 0.264
R-2 	 0.056
R-L 	 0.152


#### Does the retriever discriminate between generated and gold answers?

In [36]:
# Move the seq2seq model off the GPU before loading the retriever; assigning
# to `_` keeps the returned module from being displayed as the cell output.
_ = qa_s2s_model.cpu()
# Release cached GPU memory that is no longer in use.
torch.cuda.empty_cache()

In [37]:
# Reload the retriever from the epoch-9 checkpoint.
# NOTE(review): this path uses the "_512" prefix, while ArgumentsQAR in the
# training cell saves under "retriever_models/eli5_retriever_model_2048";
# confirm which checkpoint is meant before re-running.
qar_tokenizer, qar_model = make_qa_retriever_model(
    model_name="google/bert_uncased_L-8_H-512_A-8",
    from_file='{}_{}.pth'.format("retriever_models/eli5_retriever_model_512", 9),
    device="cuda:0"
)

In [45]:
# Embed the questions, the reference answers and the generated answers with
# the retriever; gradients are not needed, hence no_grad.
with torch.no_grad():
    eval_questions = [q for q, _, _ in examples_with_generations]
    gold_passages = {'passage_text': [gold for _, gold, _ in examples_with_generations]}
    generated_passages = {'passage_text': [gen for _, _, gen in examples_with_generations]}

    q_reps = embed_question_for_retrieval(
        eval_questions, qar_tokenizer, qar_model, device='cuda:0'
    )
    # Both answer sets go through the passage encoder entry point.
    a_reps = embed_passages_for_retrieval(
        gold_passages, qar_tokenizer, qar_model, device='cuda:0'
    )
    b_reps = embed_passages_for_retrieval(
        generated_passages, qar_tokenizer, qar_model, device='cuda:0'
    )

In [54]:
# Inner product of each question with its reference answer and with its generated answer.
gold_scores = (q_reps * a_reps).sum(axis=-1)
generated_scores = (q_reps * b_reps).sum(axis=-1)
# Fraction of examples where the reference answer scores strictly higher than
# the generated one.
((gold_scores - generated_scores) > 0).sum() / len(examples_with_generations)

0.275