In [1]:
import logging
import os
import sys
from contextlib import ExitStack
from glob import glob
from itertools import repeat

import torch
import torch.nn as nn

import neusum

In [None]:
# Route DEBUG-and-above log records to the notebook's stdout.
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
# Every later cell logs through `logger`; `log` is kept as an alias so any code
# that still refers to the old name keeps working.
logger = logging.getLogger(__name__)
log = logger
# Prefer the GPU when PyTorch can see one.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# --- Data / preprocessing ----------------------------------------------------
# True: rebuild the vocabulary from the training files (prepare_data_online);
# False: use the dictionaries stored in the checkpoint.
ONLINE_PROCESS_DATA = True
MAX_SENT_LENGTH = 80   # tokens kept per sentence when loading dev/test data
MAX_DOC_LEN = 500      # assigned to onlinePreprocess.max_doc_len
NORM_LAMBDA = 20       # assigned to onlinePreprocess.norm_lambda
# NOTE(review): the two drop bounds below are not referenced in this notebook;
# the run uses opt.drop_too_short / opt.drop_too_long instead.
DROP_TOO_SHORT = 50
DROP_TOO_LONG = 500

# --- Decoding / evaluation ---------------------------------------------------
STRIPPING_MODE = 'none'  # one of: 'none', 'normal', 'magnitude', 'topk'
THRESHOLD = 1e-4         # decoding threshold (the value evalModel passes to getLabel)
BATCH_SIZE = 4           # documents per batch built by load_dev_data
DEC_INIT = 'simple'      # decoder initialiser: 'simple' (DecInit) or 'att' (DecInitAtt)

# --- Output ------------------------------------------------------------------
# Directory holding model checkpoints and prediction files ('' = current directory).
SAVE_PATH = ''

### Preprocessing

In [None]:
def prepare_data_online(
    train_src: str,
    src_vocab: str,
    train_tgt,
    tgt_vocab,
    train_oracle,
    train_src_rouge,
    src_section: str,
    drop_too_short: int = 10,
    drop_too_long: int = 500,
    bert_annotation: str = ''
):
    """Build the source vocabulary and preprocess the training split in memory.

    Args:
        train_src: training documents; also the single corpus handed to
            ``initVocabulary`` for the source vocabulary.
        src_vocab: forwarded to ``initVocabulary`` together with a size cap of
            1,000,000.
        train_tgt, train_oracle, train_src_rouge, src_section, bert_annotation:
            forwarded unchanged to ``makeData``.
        tgt_vocab: accepted for call compatibility but unused; no target
            vocabulary is built (``dicts['tgt']`` is ``None``).
        drop_too_short, drop_too_long: forwarded unchanged to ``makeData``.

    Returns:
        ``{'dicts': {'src': ..., 'tgt': None}, 'train': {...}}``, where ``train``
        holds the eight values returned by ``makeData`` under the keys
        src, src_raw, tgt, oracle, src_rouge, src_section, src_section_raw and
        bert_annotation (in that order).
    """
    source_vocabulary = initVocabulary('source', [train_src], src_vocab, 1000000)
    vocabularies = {'src': source_vocabulary, 'tgt': None}

    logger.info('Preparing training ...')
    (src_ids, src_raw, tgt, oracle, src_rouge,
     section_ids, section_raw, bert) = makeData(
        train_src, train_tgt, train_oracle, train_src_rouge, src_section,
        vocabularies['src'], vocabularies['tgt'],
        drop_too_short, drop_too_long, bert_annotation)

    training_split = {
        'src': src_ids,
        'src_raw': src_raw,
        'tgt': tgt,
        'oracle': oracle,
        'src_rouge': src_rouge,
        'src_section': section_ids,
        'src_section_raw': section_raw,
        'bert_annotation': bert,
    }
    return {'dicts': vocabularies, 'train': training_split}

In [None]:
def evalModel(model: nn.Module,
              summarizer,
              evalData,
              output_len: int = 1,
              prefix: str = 'dev',
              postfix: str = '',
              stripping_mode: str = STRIPPING_MODE,
              specifyEpoch: int = -1,
              threshold: float = THRESHOLD):
    """Decode ``evalData`` with ``summarizer``, write the predictions to disk and score them.

    Files written under ``SAVE_PATH`` (``.{postfix}`` appears only when
    ``postfix`` is non-empty):

    * ``{prefix}.{postfix}.out.{epoch}`` - one ``"<sentences>\\t<prediction>"``
      line per evaluated document.
    * ``{prefix}.{postfix}.{output_len}_n_out.{epoch}`` - only when
      ``output_len > 1``; additionally records the attention scores and top-k
      indices for each step (debugging aid).

    Args:
        model: unused in this function; kept for call-site compatibility.
        summarizer: passed to ``getLabel``, which performs the decoding.
        evalData: batched evaluation data (e.g. the output of ``load_dev_data``).
        output_len: number of sentences output at once; meant for debugging.
            When > 1, make sure the beam_size is greater than or equal to
            n_best_size.
        prefix: file-name prefix, e.g. ``'dev'`` or ``'test'``.
        postfix: optional extra file-name tag.
        stripping_mode: forwarded to ``getLabel``.
        specifyEpoch: epoch number used in the output file names; values <= 0
            fall back to the global ``evalModelCount``.
        threshold: forwarded to ``getLabel``. Defaults to the notebook's
            ``THRESHOLD`` constant (this value used to be hard-coded as 0.0001).

    Returns:
        ``[scores_total, scores_hit1, scores_metrics, scores_bleu_raw]``.

    TODO: involve log-linear model in decoding?!
    """
    global evalModelCount

    if specifyEpoch <= 0:
        # NOTE(review): evalModelCount is never assigned in this notebook, so this
        # fallback raises NameError unless it is defined elsewhere; pass
        # specifyEpoch explicitly.
        specifyEpoch = evalModelCount

    predict, gold, predict_sents, attnScore, topkPred, gold_sents = getLabel(
        summarizer,
        evalData,
        output_len,
        threshold=threshold,
        stripping_mode=stripping_mode,
        isEval=True
    )

    scores_total = compute_selection_acc(gold, predict)
    scores_hit1 = compute_selection_acc(gold, predict, hit1mode=True)
    scores_metrics = compute_metric(gold, predict)
    scores_bleu_raw = compute_bleu_raw(
        gold_sents, predict_sents, gold, predict)

    name_suffix = '.' + postfix if postfix else ''

    out_path = os.path.join(
        SAVE_PATH, '{1}{2}.out.{0}'.format(specifyEpoch, prefix, name_suffix))
    with open(out_path, 'w', encoding='utf-8') as of:
        for p, sent in zip(predict, predict_sents):
            of.write('{0}\t{1}'.format(sent, p) + '\n')

    if output_len > 1:
        n_out_path = os.path.join(
            SAVE_PATH,
            '{1}{2}.{3}_n_out.{0}'.format(specifyEpoch, prefix, name_suffix, output_len))
        with open(n_out_path, 'w', encoding='utf-8') as of:
            for p, sent, score, topk_idx in zip(predict, predict_sents, attnScore, topkPred):
                # topk_idx is not necessarily the same as the selected index:
                # beam search maximises the score of a whole route, so on a
                # given step it may not pick the "max attention" index.
                of.write('{0}\t{1}\t{2}\t{3}\n'.format(
                    sent,
                    p,
                    tuple(s.tolist() for s in score),
                    tuple(ti.tolist() for ti in topk_idx)))

    return [scores_total, scores_hit1, scores_metrics, scores_bleu_raw]

In [None]:
def load_dev_data(summarizer,
                  src_file: str,
                  oracl_file: str,
                  src_section_file: str,
                  drop_too_short: int = 10,
                  drop_too_long: int = 500,
                  test_bert_annotation: str = '',
                  postfix: str = '',
                  qtype: str = ''):
    """Load dev/test data and group it into batches built by ``summarizer.buildData``.

    Similar to ``makeData`` in onlinePreprocess.py. The source, section and
    oracle files (plus the optional BERT-annotation file) are read line by line
    in parallel; reading stops at the end of the shortest file. Source and
    section lines hold sentences joined by ``##SENT##``. The first
    tab-separated field of each oracle / annotation line is turned into a tuple
    of sentence indices with ``make_tuple``. Blank rows are skipped.

    Args:
        summarizer: object whose ``buildData`` turns one batch into model input.
        src_file: path of the document file.
        oracl_file: path of the oracle (target sentence index) file.
        src_section_file: path of the section file, line-aligned with ``src_file``.
        drop_too_short: documents with fewer sentences than this are skipped.
        drop_too_long: documents with more sentences than this are skipped.
        test_bert_annotation: optional path of a BERT annotation file,
            line-aligned with ``src_file``.
        postfix: when non-empty, the run is treated as being on training data:
            the keyword patterns for ``qtype`` are passed with ``use_good=True``.
        qtype: key into ``loglinear.Config.Keyword``; used only when ``postfix``
            is set.

    Returns:
        A list with one ``buildData`` result per batch of up to ``BATCH_SIZE``
        documents.
    """
    if postfix:
        # A postfix means we are running on training data.
        keywords = loglinear.Config.Keyword[qtype]
        use_good = True
    else:
        # Normal dev/test case.
        keywords = []
        use_good = False

    use_bert = bool(test_bert_annotation)
    seq_length = MAX_SENT_LENGTH
    dataset = []

    def _empty_batch():
        # Fresh per-batch accumulators (here the oracle holds sentence indices).
        return {'src': [], 'src_raw': [], 'section': [], 'section_raw': [],
                'oracle': [], 'bert': []}

    def _build(batch):
        # bert_annotation is only passed when an annotation file was supplied.
        kwargs = {}
        if use_bert:
            kwargs['bert_annotation'] = batch['bert']
        kwargs['good_patterns'] = keywords
        kwargs['use_good'] = use_good
        return summarizer.buildData(
            batch['src'], batch['src_raw'], None, batch['oracle'], None,
            batch['section'], batch['section_raw'], **kwargs)

    batch = _empty_batch()
    # ExitStack guarantees every opened file is closed, including on errors
    # (the section file used to be leaked).
    with ExitStack() as stack:
        src_f = stack.enter_context(open(src_file, encoding='utf-8'))
        section_f = stack.enter_context(open(src_section_file, encoding='utf-8'))
        oracle_f = stack.enter_context(open(oracl_file, encoding='utf-8'))
        bert_lines = (stack.enter_context(open(test_bert_annotation, encoding='utf-8'))
                      if use_bert else repeat(None))

        for sline, secline, oline, bline in zip(src_f, section_f, oracle_f, bert_lines):
            sline, secline, oline = sline.strip(), secline.strip(), oline.strip()
            # Test emptiness only after stripping: lines read from a file keep
            # their '\n', so comparing the raw line with "" never matched and
            # blank rows crashed in make_tuple('').
            if not sline or not oline:
                continue

            src_sents = sline.split('##SENT##')
            section_sents = secline.split('##SENT##')
            if len(src_sents) < drop_too_short or len(src_sents) > drop_too_long:
                logger.info('Drop data too short or too long')
                continue

            # make_tuple converts the string form of a tuple into a tuple
            # (no sentinel is appended).
            oracle = list(make_tuple(oline.split('\t')[0]))
            if use_bert:
                bert = list(make_tuple(bline.strip().split('\t')[0]))

            batch['src_raw'].append(src_sents)
            batch['src'].append([s.split(' ')[:seq_length] for s in src_sents])
            batch['section_raw'].append(section_sents)
            batch['section'].append([s.split(' ')[:seq_length] for s in section_sents])
            batch['oracle'].append(torch.LongTensor(oracle))
            if use_bert:
                batch['bert'].append(torch.LongTensor(bert))

            if len(batch['src']) >= BATCH_SIZE:
                dataset.append(_build(batch))
                batch = _empty_batch()

    # Flush the final, possibly partial, batch.
    if batch['src']:
        dataset.append(_build(batch))

    return dataset

### Evaluation

In [2]:
# NOTE(review): `opt` (run options) and `onlinePreprocess` are not defined in this
# notebook; they must be provided before this cell runs.
if ONLINE_PROCESS_DATA:
    logger.info('Online Preprocessing data (to get vocabulary dictionary).')
    onlinePreprocess.seq_length = MAX_SENT_LENGTH
    onlinePreprocess.max_doc_len = MAX_DOC_LEN
    onlinePreprocess.shuffle = 1
    onlinePreprocess.norm_lambda = NORM_LAMBDA
    dataset = prepare_data_online(
        opt.train_src,
        opt.src_vocab,
        opt.train_tgt,
        opt.tgt_vocab,
        opt.train_oracle,
        opt.train_src_rouge,
        opt.train_src_section,
        opt.drop_too_short,
        opt.drop_too_long
    )
else:
    logger.info('Use preprocessed data stored in checkpoint.')
    dataset = {}  # the summarizer only needs the 'dicts' entry, filled in below

logger.info('Loading checkpoint...')
if opt.specific_epoch > 0:
    model_selected = os.path.join(SAVE_PATH, 'model_epoch_%s.pt' % opt.specific_epoch)
    logger.info('Loading from the specific epoch checkpoint "%s"' %
                model_selected)
else:
    # Otherwise pick the most recently modified checkpoint in SAVE_PATH.
    checkpoint_pattern = os.path.join(SAVE_PATH, '*.pt')
    model_path = glob(checkpoint_pattern)
    if not model_path:
        raise ValueError("Can't find model %s" % checkpoint_pattern)

    # Newest first, skipping log-linear model checkpoints.
    newest_first = sorted(model_path, key=os.path.getmtime, reverse=True)
    model_selected = next((c for c in newest_first if 'log_linear' not in c), None)
    if model_selected is None:
        raise ValueError('Only log-linear checkpoints match %s' % checkpoint_pattern)
    logger.info('Loading from the latest model "%s"' % model_selected)

# Load in both branches: this call used to be indented inside the `else`, so
# opt.specific_epoch > 0 left `checkpoint` undefined.
# torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load(model_selected, map_location=device)

logger.info('\tprevious training epochs: %d' % checkpoint['epoch'])

if not ONLINE_PROCESS_DATA:
    dataset['dicts'] = checkpoint['dicts']
dicts = checkpoint['dicts']

logger.info(' * vocabulary size. source = %d' %
            (dicts['src'].size()))
if ONLINE_PROCESS_DATA:
    logger.info(' * number of training sentences. %d' % len(dataset['train']['src']))
logger.info(' * maximum batch size. %d' % BATCH_SIZE)

logger.info('Building model...')

sent_encoder = neusum.Models.Encoder(opt, dicts['src'])
doc_encoder = neusum.Models.DocumentEncoder(opt)
pointer = neusum.Models.Pointer(opt)
if DEC_INIT == "simple":
    decIniter = neusum.Models.DecInit(opt)
elif DEC_INIT == "att":
    decIniter = neusum.Models.DecInitAtt(opt)
else:
    raise ValueError('Unknown decoder init method: {0}'.format(DEC_INIT))

model = neusum.Models.NMTModel(sent_encoder, doc_encoder, pointer, decIniter).to(device)

logger.info('Loading trained model...')
model.load_state_dict(checkpoint['model'])
summarizer = neusum.Summarizer(opt, model, dataset)

# NOTE(review): load_dev_data runs with its default drop range (10, 500), while the
# summary below logs opt.drop_too_short / opt.drop_too_long as "Keep Data";
# confirm which range the test set is meant to use.
testData = load_dev_data(summarizer, opt.dev_input_src, opt.dev_ref,
                         opt.dev_input_src_section, test_bert_annotation=opt.test_bert_annotation)
model.eval()
scores = evalModel(model, summarizer, testData,
                   opt.output_len, 'test', opt.set_postfix, opt.stripping_mode, checkpoint['epoch'])
logger.info('Using checkpoint: %s' % model_selected)
logger.info('Key hyperparmeters:')
logger.info('\tMax Decode Steps: %d' % opt.max_decode_step)
logger.info('\tKeep Data: [%d, %d)' % (opt.drop_too_short, opt.drop_too_long))
logger.info('\tTrained epoch: %s' % checkpoint['epoch'])
logger.info('Evaluate score: (accuracy) total / hit@1, {precision, recall, f1 score} (sentence-level)')
logger.info(scores)