In [1]:
from eli5_utils import *

eli5_dbuilder = ELI5NLP(data_dir='eli5')
eli5_dbuilder.download_and_prepare()

eli5_train = eli5_dbuilder.as_dataset(split=nlp.splits.Split.TRAIN)
eli5_valid = eli5_dbuilder.as_dataset(split=nlp.splits.Split.VALIDATION)
eli5_test = eli5_dbuilder.as_dataset(split=nlp.splits.Split.TEST)


#### Training the retriever

In [None]:
class ArgumentsQAR():
    def __init__(self):
        self.batch_size = 320
        self.max_length = 96
        self.print_freq = 100
        self.pretrained_model_name = "google/bert_uncased_L-8_H-512_A-8"
        self.model_save_name = "retriever_models/eli5_retriever_model_512"
        self.learning_rate = 2e-5
        self.num_epochs = 10

qar_args = ArgumentsQAR()

qar_train_dset = ELI5DatasetQARetriver(eli5_train, min_answer_length=64, training=True)
qar_valid_dset = ELI5DatasetQARetriver(eli5_valid, min_answer_length=64, training=False)

qar_tokenizer, qar_model = make_qa_retriever_model(
    model_name=qar_args.pretrained_model_name,
    from_file=None,
    device="cuda:0"
)

qar_optimizer = AdamW(qar_model.parameters(), lr=qar_args.learning_rate, eps=1e-8)
qar_scheduler = get_linear_schedule_with_warmup(
        qar_optimizer,
        num_warmup_steps=100,
        num_training_steps=qar_args.num_epochs * math.ceil(len(qar_train_dset) / qar_args.batch_size)
)

for e in range(qar_args.num_epochs):
    train_qa_retriever_epoch(
        qar_model, qar_train_dset, qar_tokenizer,
        qar_optimizer, qar_scheduler, qar_args, e
    )
    m_save_dict = {
        'model': qar_model.state_dict(),
        'optimizer': qar_optimizer.state_dict(),
        'scheduler': qar_scheduler.state_dict(),
    }
    print("Saving model {}".format(qar_args.model_save_name))
    torch.save(m_save_dict, '{}_{}.pth'.format(qar_args.model_save_name, e))
    eval_loss = evaluate_qa_retriever(qar_model, qar_valid_dset, qar_tokenizer, qar_args)
    print("Evaluation loss epoch {:4d}: {:.3f}".format(e, eval_loss))


In [None]:
# TODO: evaluate recall@N for validation / test set

In [None]:
kilt_snippets_dbuilder = KiltSnippets(data_dir='kilt_snippets_100w')
kilt_snippets_dbuilder.download_and_prepare()
wiki_passages = kilt_snippets_dbuilder.as_dataset(split=nlp.splits.Split.TRAIN)

make_qa_dense_index(qar_model, qar_tokenizer,
                    wiki_passages,
                    batch_size=512,
                    index_name='kilt_passages_reps_16.dat',
                    device='cuda:0')


#### Training the Seq2seq model

In [1]:
from eli5_utils import *

eli5_dbuilder = ELI5NLP(data_dir='eli5')
eli5_dbuilder.download_and_prepare()

eli5_train = eli5_dbuilder.as_dataset(split=nlp.splits.Split.TRAIN)
eli5_valid = eli5_dbuilder.as_dataset(split=nlp.splits.Split.VALIDATION)
eli5_test = eli5_dbuilder.as_dataset(split=nlp.splits.Split.TEST)

eli5_train_docs = json.load(open('eli5_train_precomputed_dense_docs.json'))
eli5_valid_docs = json.load(open('eli5_valid_precomputed_dense_docs.json'))
eli5_test_docs = json.load(open('eli5_test_precomputed_dense_docs.json'))

s2s_train_dset = ELI5DatasetS2S(eli5_train, document_cache=dict([(k, d) for k, d, src_ls in eli5_train_docs]))
s2s_valid_dset = ELI5DatasetS2S(eli5_valid, document_cache=dict([(k, d) for k, d, src_ls in eli5_valid_docs]), training=False)

In [2]:
qa_s2s_tokenizer, qa_s2s_model = make_qa_s2s_model(
    model_name="bart-large",
    from_file=None,
    device="cuda:0"
)

In [3]:
class ArgumentsS2S():
    def __init__(self):
        self.batch_size = 1
        self.backward_freq = 8
        self.max_length = 64
        self.print_freq = 10
        self.model_save_name = "seq2seq_models/eli5_bart_model_256"
        self.learning_rate = 2e-5
        self.num_epochs = 10

s2s_args = ArgumentsS2S()
s2s_optimizer = AdamW(qa_s2s_model.parameters(), lr=s2s_args.learning_rate, eps=1e-8)
s2s_scheduler = get_linear_schedule_with_warmup(
        s2s_optimizer,
        num_warmup_steps=100,
        num_training_steps=s2s_args.num_epochs * math.ceil(len(s2s_train_dset) / s2s_args.batch_size)
)

In [4]:
batch_ids = make_qa_s2s_batch([s2s_train_dset[0]], qa_s2s_tokenizer, max_len=256, device="cuda:0")
batch_ids

{'input_ids': tensor([[    0,   864,    35,    11,  1037, 45676,     5,   477,     9, 21025,
              5,    78,    80,  1974,    19,    10,  6187,   111,    62,     5,
           1692,   111,    45,  1675,  6187,  1974,   939,   120,   167,  5377,
             35, 28696,   642, 15698, 23492,  3212,   238,    98,   373,   142,
              5, 23492,   154,   869,   531, 27545,     5,  2136,    22, 17745,
           3006, 37679,   113,   227,  1530,    98,    25,    45,     7,  1157,
              5,   869,     7,  3212, 33785,  1769,     8,  4296,   492,     5,
           2525,   117,    86,     7,  3211,     4,  2128,     5,    80,  1492,
             32,  2771,     6,  2455,    65,  2559,   486,     9,    22,  3662,
           4494,  2901,   228,   278,     9,   204, 12071,     4,     5,    97,
           1973,     7,  3679,    10,  6187,    16,     7,   304,    41,  2555,
          14261,    50,  1312,     7,  1803,   143,  1323,  6187,     4,    10,
            516,    16,  31

In [5]:
loss = qa_s2s_model(**batch_ids)[0]
loss

tensor(5.0085, device='cuda:0', grad_fn=<NllLossBackward>)

In [6]:
loss.backward()

In [7]:
def train_qa_s2s_epoch(model, dataset, tokenizer, optimizer, scheduler, args, e=0):
    model.train()
    # make iterator
    train_sampler = SequentialSampler(dataset)
    model_collate_fn = functools.partial(
        make_qa_s2s_batch,
        tokenizer=tokenizer, max_len=args.max_length, device='cuda:0'
    )
    data_loader = DataLoader(
        dataset, batch_size=args.batch_size,
        sampler=train_sampler, collate_fn=model_collate_fn
    )
    epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
    # accumulate loss since last print
    loc_steps = 0
    loc_loss = 0.0
    st_time = time()
    for step, batch_inputs in enumerate(epoch_iterator):
        print(step)
        loss = model(**batch_inputs)[0]
        loss.backward()
        # optimizer
        if step % args.backward_freq == 0:
            optimizer.step()
            scheduler.step()
            model.zero_grad()
        # some printing within the epoch
        loc_loss += loss.item()
        loc_steps += 1
        if step % args.print_freq == 0:
            print(
                "{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
                    e, step,
                    len(dataset) // args.batch_size,
                    loc_loss / loc_steps,
                    time() - st_time,
                )
            )
            loc_loss = 0
            loc_steps = 0



In [4]:
train_qa_s2s_epoch(qa_s2s_model, s2s_train_dset, qa_s2s_tokenizer, s2s_optimizer, s2s_scheduler, s2s_args, e=0)

0
 0     0 of 585009 	 L: 12.297 	 -- 0.385
1


RuntimeError: CUDA out of memory. Tried to allocate 14.00 MiB (GPU 0; 7.79 GiB total capacity; 6.22 GiB already allocated; 13.19 MiB free; 6.28 GiB reserved in total by PyTorch)

In [10]:
batch_ids['input_ids']

RuntimeError: CUDA error: device-side assert triggered