In [1]:
import json
import torch
import torch.nn as nn

from tqdm import tqdm
from fairseq.models.bart import BARTModel
from utils import read_lines

In [2]:
# Machine-specific locations (model checkpoints, datasets) come from a JSON config one
# directory up; this notebook reads the 'bart.large.xsum' and 'xsum_fariseq' keys.
# The `with` block closes the file handle, and JSON is decoded as UTF-8 regardless of
# the platform's locale encoding.
with open('../path_config.json', encoding='utf-8') as config_file:
    PATH = json.load(config_file)

#### Load Model (bart.large.xsum checkpoint)

In [3]:
# Load BART from the 'bart.large.xsum' checkpoint directory named in the path config.
# The same directory is passed as `data_name_or_path`, so any auxiliary files fairseq
# needs are presumably stored next to model.pt (TODO confirm against the checkpoint).
bart_ckpt_dir = PATH['bart.large.xsum']
bart_xsum = BARTModel.from_pretrained(
    bart_ckpt_dir,
    checkpoint_file='model.pt',
    data_name_or_path=bart_ckpt_dir,
)

In [4]:
# Move weights to the GPU, switch to eval mode (disables dropout) and cast to fp16.
# Each nn.Module method returns the module itself, so the calls can be chained.
bart_xsum.cuda().eval().half()
print('- model loaded.')

- model loaded.


#### Read XSum

In [5]:
# Read the XSum training split: train.source holds the documents and train.target the
# reference summaries, presumably one example per line (read_lines is project-local).
# Note: the config key really is spelled 'xsum_fariseq'.
xsum_dir = PATH['xsum_fariseq']
document_path, target_path = (xsum_dir + '/train.source', xsum_dir + '/train.target')
xsum_source, xsum_target = (read_lines(path) for path in (document_path, target_path))
print(len(xsum_source))
# Documents and summaries are paired by position, so the counts must match.
assert len(xsum_target) == len(xsum_source)

203575


#### Inference

In [6]:
# Collected generated summaries, in the same order as the input documents.
outputs = list()

In [7]:
# Number of training documents to summarize; set to None to run on the whole split
# (slicing with None keeps every element).
N_SAMPLES = 100
data = xsum_source[:N_SAMPLES]

In [8]:
class Args:
    """Beam-search decoding settings consumed by the batched inference loop."""

    batch_size = 8              # documents per `bart_xsum.sample` call
    beam_size = 6               # forwarded as `beam`
    lenpen = 1.0                # length penalty, forwarded as `lenpen`
    min_len, max_len = 10, 60   # forwarded as `min_len` / `max_len_b`


args = Args()

In [9]:
# Summarize `data` with beam search, `args.batch_size` documents per generation call.
# - `outputs` is rebuilt here so re-running this cell cannot append duplicate summaries.
# - Every batch, including the final partial one, runs under no_grad (inference only).
# - `verbose` is left at its default for every batch so logging is uniform.
# - Empty `data` yields no batches instead of failing on data[0].
# Progress is reported per document.
outputs = []
with torch.no_grad(), tqdm(total=len(data)) as progress:
    for start in range(0, len(data), args.batch_size):
        batch = data[start:start + args.batch_size]
        hypotheses_batch = bart_xsum.sample(batch,
                                            beam=args.beam_size, lenpen=args.lenpen,
                                            max_len_b=args.max_len, min_len=args.min_len,
                                            # forbid any repeated trigram within a hypothesis
                                            no_repeat_ngram_size=3)
        outputs.extend(hypotheses_batch)
        progress.update(len(batch))

100%|██████████| 99/99 [00:30<00:00,  3.30it/s]


In [10]:
# Spot-check one generated summary; outputs are aligned with xsum_source by position.
sample_idx = 20
outputs[sample_idx]

'A Royal Navy submarine has returned to Plymouth after 11 months at sea.'

In [11]:
# Reference summary for document 20, to compare against the generated summary.
sample_idx = 20
xsum_target[sample_idx]

'A Royal Navy submarine has returned home to Devonport after 11 months at sea.'