In [3]:
# Reload edited project modules (e.g. onmt.*) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [1]:
import torch
import json
exec('from __future__ import unicode_literals')

import copy
import os

import re
import sys
import random
import numpy as np

import spacy
spacy_nlp = spacy.load('en_core_web_sm')

# Make the repository root and its `onmt` directory importable; each path is appended at most once.
for rel_path in ('../', '../onmt'):
    module_path = os.path.abspath(os.path.join(rel_path))
    if module_path not in sys.path:
        sys.path.append(module_path)

from onmt.translate.translator import build_translator
from onmt.constants import ModelTask
import onmt.keyphrase.eval as eval
from onmt.keyphrase.pke.utils import compute_document_frequency
from onmt.keyphrase.utils import validate_phrases, if_present_duplicate_phrases
from onmt.utils.parse import ArgumentParser
import onmt.keyphrase.pke as pke

from kp_gen_eval_transfer import _get_parser



In [3]:
import importlib
# Re-execute onmt.keyphrase.eval so edits to it take effect without restarting the kernel.
# Note: the module is bound to the name `eval` in the import cell, which shadows Python's
# built-in eval() for the rest of this notebook.
importlib.reload(eval)

<module 'onmt.keyphrase.eval' from '/zfs1/hdaqing/rum20/kp/OpenNMT-kpg-transfer/onmt/keyphrase/eval.py'>

### Wiki-transferred KPs

Also see **onmt.keyphrase.kp_inference.py**.

This part of the notebook is meant for manually finding a good prefix `prompt` that yields good pseudo labels. The actual code to generate pseudo labels is:
```
### for kp20k/openkp/kptimes/stackex
python kp_gen_magkp_transfer.py -config config/transfer_kp/infer/keyphrase-one2seq-controlled.yml -tasks pred -exp_root_dir /zfs1/hdaqing/rum20/kp/transfer_exps/kp_bart_DA/bart_kppretrain_wiki_1e5_controlled -data_dir /zfs1/hdaqing/rum20/kp/data/kp/oag_v1_cs_nokp/ -output_dir /zfs1/hdaqing/rum20/kp/data/kp/oag_v1_cs_nokp_wikiTL/ -gpu 0 -batch_size 32 -beam_size 1 -max_length 60

### for Mag
python kp_gen_magkp_transfer_labelling.py -config config/transfer_kp/infer/keyphrase-one2seq-controlled.yml -tasks pred -exp_root_dir /zfs1/hdaqing/rum20/kp/transfer_exps/kp_bart_DA/bart_kppretrain_wiki_1e5_controlled -data_dir /zfs1/hdaqing/rum20/kp/data/kp/oag_v1_cs_nokp/ -output_dir /zfs1/hdaqing/rum20/kp/data/kp/oag_v1_cs_nokp_wikiTL/ -gpu 0 -batch_size 32 -beam_size 1 -max_length 60
```

#### Load functions

In [4]:
def eval_and_print(src_text, tgt_kps, pred_kps, pred_scores, unk_token='<unk>'):
    """Score predicted keyphrases against the ground truth for one document and print a report.

    The source, targets and predictions are tokenized with `spacy_nlp` and lower-cased.
    Predictions are then filtered (phrases containing `unk_token`/punctuation and stemmed
    duplicates count as invalid) and split into present/absent by whether they occur in
    the source. Exact-match F1 is computed for all and present predictions (@k, @10) and
    for absent predictions (@M).

    :param src_text: source document text (e.g. title + ' . ' + body).
    :param tgt_kps: ground-truth keyphrases, as strings.
    :param pred_kps: predicted keyphrases, as strings, in ranked order.
    :param pred_scores: per-prediction scores aligned with `pred_kps`, or None/empty.
    :param unk_token: token marking unknown words; phrases containing it are invalid.
    :return: None; the report and intermediate values are printed.
    """
    def tokenize(text):
        # Lower-cased spaCy token texts (the textcat pipe is disabled).
        return [t.text.lower() for t in spacy_nlp(text, disable=["textcat"])]

    src_seq = tokenize(src_text)
    tgt_seqs = [tokenize(p) for p in tgt_kps]
    pred_seqs = [tokenize(p) for p in pred_kps]

    topk_range = ['k', 10]
    absent_topk_range = ['M']
    metric_names = ['f_score']  # 'precision', 'recall', 'f_score'

    # 1st filtering, ignore phrases having <unk> and puncs
    valid_pred_flags = validate_phrases(pred_seqs, unk_token)
    # 2nd filtering: filter out phrases that don't appear in text, and keep unique ones after stemming
    present_pred_flags, _, duplicate_flags = if_present_duplicate_phrases(src_seq, pred_seqs)
    # Treat duplicates as invalid. The flags are combined with `*` / `~` and later used as
    # boolean masks, which assumes numpy bool arrays (NOTE(review): confirm in onmt.keyphrase.utils).
    valid_pred_flags = valid_pred_flags * ~duplicate_flags if len(valid_pred_flags) > 0 else []
    valid_and_present_flags = valid_pred_flags * present_pred_flags if len(valid_pred_flags) > 0 else []
    valid_and_absent_flags = valid_pred_flags * ~present_pred_flags if len(valid_pred_flags) > 0 else []

    # compute match scores (exact, partial and mixed), for exact it's a list otherwise matrix
    match_scores_exact = eval.compute_match_scores(tgt_seqs=tgt_seqs, pred_seqs=pred_seqs,
                                                   do_lower=True, do_stem=True, type='exact')
    # split tgts by present/absent
    present_tgt_flags, _, _ = if_present_duplicate_phrases(src_seq, tgt_seqs)
    present_tgts = [tgt for tgt, present in zip(tgt_seqs, present_tgt_flags) if present]
    # `not` instead of `~`: `~` only negates numpy bools; for a Python bool, ~True == -2 (truthy).
    absent_tgts = [tgt for tgt, present in zip(tgt_seqs, present_tgt_flags) if not present]

    # filter out results of invalid preds
    valid_preds = [seq for seq, valid in zip(pred_seqs, valid_pred_flags) if valid]
    valid_present_pred_flags = present_pred_flags[valid_pred_flags]

    valid_match_scores_exact = match_scores_exact[valid_pred_flags]

    # split preds by present/absent and exact/partial/mixed
    valid_present_preds = [pred for pred, present in zip(valid_preds, valid_present_pred_flags) if present]
    valid_absent_preds = [pred for pred, present in zip(valid_preds, valid_present_pred_flags) if not present]
    present_exact_match_scores = valid_match_scores_exact[valid_present_pred_flags]
    absent_exact_match_scores = valid_match_scores_exact[~valid_present_pred_flags]

    all_exact_results = eval.run_classic_metrics(valid_match_scores_exact, valid_preds, tgt_seqs, metric_names, topk_range)
    present_exact_results = eval.run_classic_metrics(present_exact_match_scores, valid_present_preds, present_tgts, metric_names, topk_range)
    absent_exact_results = eval.run_classic_metrics(absent_exact_match_scores, valid_absent_preds, absent_tgts, metric_names, absent_topk_range)

    eval_results_names = ['all_exact', 'present_exact', 'absent_exact']
    eval_results_list = [all_exact_results, present_exact_results, absent_exact_results]

    print_out = print_predeval_result(src_text,
                                      tgt_seqs, present_tgt_flags,
                                      pred_seqs, pred_scores, present_pred_flags, valid_pred_flags,
                                      valid_and_present_flags, valid_and_absent_flags, match_scores_exact,
                                      eval_results_names, eval_results_list)

    print('[#present_tgts=%d] ' % len(present_tgts), str(present_tgts))
    print('[#absent_tgts=%d]' % len(absent_tgts), str(absent_tgts))

    print('[#valid_present_preds=%d]' % len(valid_present_preds), str(valid_present_preds))
    print('[#valid_absent_preds=%d]' % len(valid_absent_preds), str(valid_absent_preds))

    print('match_scores_exact', str(match_scores_exact))

    print('valid_match_scores_exact', str(valid_match_scores_exact))
    print('all_exact_results', str(all_exact_results))

    print('present_exact_match_scores', str(present_exact_match_scores))
    print('present_exact_results', str(present_exact_results))

    print('absent_exact_match_scores', str(absent_exact_match_scores))
    print('absent_exact_results', str(absent_exact_results))
    print(print_out)


def print_predeval_result(src_text,
                          tgt_seqs, present_tgt_flags,
                          pred_seqs, pred_scores, present_pred_flags, valid_pred_flags,
                          valid_and_present_flags, valid_and_absent_flags, match_scores_exact,
                          results_names, results_list):
    """Build a human-readable evaluation report for one document and return it.

    The report lists:
    - the source text;
    - every ground-truth phrase, with present ones in [brackets];
    - every prediction with its 1-based rank and negated score. A prediction is shown in
      [brackets] if present, gets '[correct!]' if its exact-match score is 1.0, and gets an
      extra leading tab if invalid.

    The count summaries are repeated after the predictions. They are followed by F1@10 and
    F1@k for 'all_exact'/'present_exact', and F1@M for 'absent_exact'. Other result names
    are skipped.

    :param src_text: source document text.
    :param tgt_seqs: tokenized ground-truth phrases (lists of tokens).
    :param present_tgt_flags: per-target flags, truthy when the target occurs in the source.
    :param pred_seqs: tokenized predicted phrases, in ranked order.
    :param pred_scores: per-prediction scores aligned with `pred_seqs` (list, numpy array, ...),
        or None/empty to print 'Score N/A'.
    :param present_pred_flags: per-prediction "occurs in source" flags.
    :param valid_pred_flags: per-prediction validity flags.
    :param valid_and_present_flags: flags summed for the #(valid&present) count.
    :param valid_and_absent_flags: flags summed for the #(valid&absent) count.
    :param match_scores_exact: per-prediction exact-match scores (1.0 == correct).
    :param results_names: names parallel to `results_list`.
    :param results_list: dicts holding 'f_score@<k>' entries for the printed cut-offs.
    :return: the report string (it is not printed here).
    """
    num_tgts = len(present_tgt_flags)
    num_present_tgts = sum(present_tgt_flags)
    # Both summaries appear twice in the report; format them once.
    gt_header = '[GROUND-TRUTH] #(all)=%d, #(present)=%d, #(absent)=%d\n' % \
                (num_tgts, num_present_tgts, num_tgts - num_present_tgts)
    pred_header = '\n[PREDICTION] #(all)=%d, #(valid)=%d, #(present)=%d, ' \
                  '#(valid&present)=%d, #(valid&absent)=%d\n' % \
                  (len(pred_seqs), sum(valid_pred_flags), sum(present_pred_flags),
                   sum(valid_and_present_flags), sum(valid_and_absent_flags))
    # Explicit None/len check: truth-testing a numpy array or multi-element tensor raises.
    has_scores = pred_scores is not None and len(pred_scores) > 0

    print_out = '=' * 50
    print_out += '\n[Source]: %s \n' % (src_text)

    print_out += gt_header
    print_out += '\n'.join(['\t\t[%s]' % ' '.join(phrase) if is_present else '\t\t%s' % ' '.join(phrase)
                            for phrase, is_present in zip(tgt_seqs, present_tgt_flags)])

    print_out += pred_header
    for p_id, (words, match, is_valid, is_present) in enumerate(
            zip(pred_seqs, match_scores_exact, valid_pred_flags, present_pred_flags)):
        phrase = ' '.join(words)
        print_phrase = '[%s]' % phrase if is_present else phrase
        correct_str = '[correct!]' if match == 1.0 else ''
        # Scores are negated for display (presumably log-likelihoods, i.e. <= 0 -- TODO confirm).
        score_str = '[%.4f]' % (-pred_scores[p_id]) if has_scores else "Score N/A"

        pred_str = '\t\t[%d] %s\t%s \t%s\n' % (p_id + 1, score_str, print_phrase, correct_str)
        if not is_valid:
            # invalid predictions are indented one extra tab
            pred_str = '\t%s' % pred_str

        print_out += pred_str

    print_out += "\n ======================================================= \n"

    print_out += gt_header
    print_out += pred_header

    for name, results in zip(results_names, results_list):
        # F1@10 and F1@k for all/present predictions, F1@M for absent ones;
        # any other result type (e.g. partial matching) is ignored for now.
        if name in ('all_exact', 'present_exact'):
            topk_list = ['10', 'k']
        elif name == 'absent_exact':
            topk_list = ['M']
        else:
            continue

        for topk in topk_list:
            print_out += "\n --- batch {} F1 @{}: \t".format(name, topk) \
                         + "{:.4f}".format(results['f_score@{}'.format(topk)])

    print_out += "\n ======================================================="

    return print_out


#### Load translator and model

In [5]:
# Use CUDA device 0 for inference; print availability and the selected device index.
print(torch.cuda.is_available())
torch.cuda.set_device(0)
print(torch.cuda.current_device())

# Supervised deep keyphrase model, run through the OpenNMT 2.x inference pipeline:
# parse the base YAML config, then override individual options below.
parser = _get_parser()
config_path = '/zfs1/hdaqing/rum20/kp/OpenNMT-kpg-transfer/config/transfer_kp/infer/keyphrase-one2seq-controlled.yml'
opt = parser.parse_args('-config %s' % (config_path))

ckpt_path = '/zfs1/hdaqing/rum20/kp/fairseq-kpg/exps/kp/bart_kppretrain_wiki_1e5/ckpts/checkpoint_step_100000.pt'
# Option overrides, applied in this order.
opt_overrides = [
    # model: fairseq BART checkpoint + pretrained tokenizer
    ('models', [ckpt_path]),
    ('fairseq_model', True),
    ('encoder_type', 'bart'),
    ('decoder_type', 'bart'),
    ('pretrained_tokenizer', True),
    ('copy_attn', False),
    # batching / bucketing
    ('valid_batch_size', 1),
    ('batch_size_multiple', 1),
    ('bucket_size', 128),
    ('pool_factor', 256),
    # decoding: beam size 1 on GPU 0
    ('beam_size', 1),
    ('gpu', 0),
]
for opt_name, opt_value in opt_overrides:
    setattr(opt, opt_name, opt_value)

# `data` may come from the config as a single-quoted string; swap quotes and parse it as JSON.
if isinstance(opt.data, str):
    opt.data = json.loads(opt.data.replace('\'', '"'))
opt.data_task = ModelTask.SEQ2SEQ
ArgumentParser._get_all_transform(opt)

translator = build_translator(opt, report_score=False)

True
0




Loading pretrained vocabulary from /zfs1/hdaqing/rum20/kp/data/kp/hf_vocab/roberta-base-kp
Vocab size=50265, base vocab size=50265


In [6]:
# Check the model is on GPU: print the translator's device index, then its CUDA flag.
for attr_name in ('_gpu', '_use_cuda'):
    print(getattr(translator, attr_name))

0
True


In [16]:
# Evaluation set: one JSON example per line.
dataset_name = 'jptimes'
dataset_path = '/zfs1/hdaqing/rum20/kp/data/kp/json/%s/test.json' % dataset_name
# Document to inspect; set to None to sample a random one instead.
DOC_ID = 4399


def _normalize_example(ex, dataset_name):
    """Map a raw example in place onto the fields used below.

    The fields are `title`, `abstract`, `keywords` and `dataset_type`.

    Returns False (leaving `ex` untouched) when `dataset_name` matches no known layout.
    """
    if dataset_name.startswith('openkp'):
        ex['title'] = ''
        ex['abstract'] = ex['text']
        ex['keywords'] = ex['KeyPhrases']
        ex['dataset_type'] = 'webpage'
    elif dataset_name.startswith('stackex'):
        ex['abstract'] = ex['question']
        ex['keywords'] = ex['tags'].split(';')
        ex['dataset_type'] = 'qa'
    elif dataset_name.startswith(('kp20k', 'duc')):
        ex['keywords'] = ex['keywords'].split(';') if isinstance(ex['keywords'], str) else ex['keywords']
        ex['dataset_type'] = 'scipaper'
    elif dataset_name.startswith(('kptimes', 'jptimes')):
        ex['keywords'] = ex['keyword'].split(';') if isinstance(ex['keyword'], str) else ex['keyword']
        ex['dataset_type'] = 'news'
    else:
        return False
    return True


with open(dataset_path, 'r') as f:
    ex_dicts = [json.loads(line) for line in f]

# Warn once instead of printing a bare '????' for every example of an unknown dataset.
num_unnormalized = sum(not _normalize_example(ex, dataset_name) for ex in ex_dicts)
if num_unnormalized:
    print('Unknown dataset %r: %d examples were not normalized' % (dataset_name, num_unnormalized))

print('Loaded #(docs)=%d' % (len(ex_dicts)))
# randrange excludes its upper bound; random.randint(0, n) could return n (out of range).
doc_id = DOC_ID if DOC_ID is not None else random.randrange(len(ex_dicts))
ex_dict = ex_dicts[doc_id]

print(doc_id)


Loaded #(docs)=10000
4399


In [17]:
# How many keyphrases of each kind to request from the controlled model.
num_pres, num_header, num_cat, num_seealso, num_infill = 10, 5, 5, 2, 0

# Control prefix of the form '<present>N<header>N<category>N<seealso>N<infill>N<s>'.
control_fields = (('present', num_pres), ('header', num_header), ('category', num_cat),
                  ('seealso', num_seealso), ('infill', num_infill))
control_prefix = ''.join('<%s>%d' % field for field in control_fields) + '<s>'

# Shallow copy so adding the prefix key does not modify the example stored in `ex_dicts`.
new_ex_dict = copy.copy(ex_dict)
new_ex_dict['src_control_prefix'] = control_prefix

scores, preds = translator.translate(src=[new_ex_dict],
                                     batch_size=opt.batch_size,
                                     attn_debug=opt.attn_debug,
                                     opt=opt)

Loading pretrained vocabulary from /zfs1/hdaqing/rum20/kp/data/kp/hf_vocab/roberta-base-kp
Vocab size=50265, base vocab size=50265
Loading pretrained vocabulary from /zfs1/hdaqing/rum20/kp/data/kp/hf_vocab/roberta-base-kp


Translating in batches: 0it [00:00, ?it/s]

Vocab size=50265, base vocab size=50265
<present>10<header>5<category>5<seealso>2<infill>0<s>Search warrants tied to former Trump lawyer Michael Cohen released  . WASHINGTON - Months before the FBI raided Michael Cohen’s office and hotel room, investigators were examining the flow of foreign money into his bank accounts and looking into whether the funds might be connected to a plan to lift sanctions on Russia, according to court filings unsealed Wednesday. The five search warrant applications, made in the early weeks and months of special counsel Robert Mueller’s Russia investigation in 2017, were made public in response to requests from The Associated Press and other media organizations. Ultimately, Cohen was not charged by Mueller or by prosecutors in New York with anything related to Russian collusion or illegal influence peddling, but the documents shed further light on how he capitalized financially on his closeness to the president immediately following the 2016 election. Cohen,

Translating in batches: 1it [00:01,  1.53s/it]

Total translation time (s): 1.530445
Average translation time (s): 1.530445
Tokens per second: 1.306810





In [19]:
# Evaluation source text: title and body separated by ' . ' (title is '' for openkp).
src_text = ' . '.join([new_ex_dict['title'], new_ex_dict['abstract']])

# Print per-phrase predictions alongside exact-match F1 summaries.
eval_and_print(src_text,
               tgt_kps=ex_dict['keywords'],
               pred_kps=preds[0],
               pred_scores=scores[0])


[autoreload of onmt.translate.translator failed: Traceback (most recent call last):
  File "/ihome/hdaqing/rum20/anaconda3/envs/kp/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 245, in check
    superreload(m, reload, self.old_objects)
  File "/ihome/hdaqing/rum20/anaconda3/envs/kp/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 394, in superreload
    module = reload(module)
  File "/ihome/hdaqing/rum20/anaconda3/envs/kp/lib/python3.7/imp.py", line 314, in reload
    return importlib.reload(module)
  File "/ihome/hdaqing/rum20/anaconda3/envs/kp/lib/python3.7/importlib/__init__.py", line 169, in reload
    _bootstrap._exec(spec, module)
  File "<frozen importlib._bootstrap>", line 630, in _exec
  File "<frozen importlib._bootstrap_external>", line 728, in exec_module
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
  File "/zfs1/hdaqing/rum20/kp/OpenNMT-kpg-transfer/onmt/translate/translator.py", line 24, in <modul

[#present_tgts=3]  [['robert', 'mueller'], ['michael', 'cohen'], ['viktor', 'vekselberg']]
[#absent_tgts=3] [['u.s', '.'], ['donald', 'trump'], ['russia', 'probe']]
[#valid_present_preds=8] [['viktor', 'vekselberg'], ['robert', 'mueller'], ['russian', 'president'], ['russian', 'collusion'], ['vladimir', 'putin'], ['michael', 'cohen'], ['fbi', 'raids'], ['were', 'made', 'public', 'in']]
[#valid_absent_preds=8] [['donald', 'trump'], ['search', 'warrant', 'documents'], ['american', 'people', 'of', 'ukrainian', '-', 'jewish', 'descent'], ['american', 'people', 'of', 'russian', '-', 'jewish', 'descent'], ['list', 'of', 'topics', 'characterized', 'as', 'pseudoscience'], ['list', 'of', 'people', 'from', 'new', 'york', 'city'], ['cohen', 'family'], ['in', 'a', 'court', 'filing', 'on']]
match_scores_exact [1. 1. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
valid_match_scores_exact [1. 1. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
all_exact_results {'f_score@k': 0.5, 'f_score@10': 0.5}
present_

### Noun phrases

Use min_len=2 and max_len=6, i.e. single-word phrases are ignored.
See **onmt.keyphrase.extract_np.py** for more examples.

Some issues:
1. Most common: participles are mistakenly tagged as VERB.
  - image segmentation through evolved cellular automata ['NOUN', 'NOUN', 'ADP', 'VERB', 'ADJ', 'NOUN']
  - HQCRFF-based modulator ['PROPN', '-', 'VERB', 'NOUN']
  - collection of organized data ['NOUN', 'ADP', 'VERB', 'NOUN']
  - irregularly-sampled data ['ADV', '-', 'VERB', 'NOUN']
2. NPs containing numbers (rare).
  - sidewall angle of 90 ['ADJ', 'NOUN', 'ADP', 'NUM']
3. Single-word phrases are ignored, and there seems to be no simple way to resolve this; as a result, some abbreviations and acronyms are missed.
4. Phrases starting with ADP are not matched.
  - in-band zeros ['ADP', '-', 'NOUN', 'NOUN']

In [2]:
# POS pattern for a noun phrase over tags concatenated WITHOUT separators ('-' marks a hyphen token).
_NP_POS_REGEX = re.compile(r'((^ADJ|^NOUN|^PROPN)(ADP|-|ADJ|NOUN|PROPN)*?)?(NOUN|PROPN)+')


def noun_chunks_by_pos_regex(text, min_len, max_len, nlp=None):
    '''
    Extract candidate noun phrases from `text`.

    Every n-gram of min_len..max_len tokens whose POS-tag sequence full-matches
    `_NP_POS_REGEX` is a candidate. The pattern loosely follows:

    https://files.ifi.uzh.ch/cl/hess/classes/ecl1/termerCIE.html
        (Adjective | Noun)* (Noun Preposition)? (Adjective | Noun)* Noun
    https://www.aclweb.org/anthology/D09-1027.pdf
        (JJ)*(NN|NNS|NNP)+

    Overlapping and nested n-grams are all returned and duplicates are kept.

    :param text: raw input text.
    :param min_len: minimum n-gram length in tokens.
    :param max_len: maximum n-gram length in tokens (inclusive).
    :param nlp: spaCy-style pipeline called as nlp(text, disable=[...]); defaults to the
        module-level `spacy_nlp`.
    :return: list of span texts, ordered by start token, then by length.
    '''
    if nlp is None:
        nlp = spacy_nlp
    doc = nlp(text, disable=["textcat"])

    cands = []
    # Enumerate all n-grams: every start position (including the last token, so that
    # min_len=1 also works), then every allowed length.
    for i in range(len(doc)):
        for k in range(min_len, max_len + 1):
            if i + k > len(doc):
                break
            span = doc[i: i + k]
            # spaCy tags '-' as punctuation; give it its own tag so the pattern can allow it.
            pos_str = ''.join('-' if t.text == '-' else t.pos_ for t in span)
            if _NP_POS_REGEX.fullmatch(pos_str):
                cands.append(span.text)

    return cands

In [3]:
# Datasets to extract noun-phrase "predictions" for. Previously used alternatives:
#   ['kp20k_train100k', 'kptimes_train100k', 'openkp_train100k', 'stackex_train100k']
#   ['kp20k', 'inspec', 'krapivin', 'nus', 'semeval', 'openkp', 'kptimes', 'jptimes', 'stackex', 'duc']
dataset_names = ['kp20k_valid2k', 'openkp_valid2k', 'kptimes_valid2k', 'stackex_valid2k']


def _np_source_text(ex, dataset_name):
    """Return the text to extract noun phrases from, following each dataset's field layout."""
    if dataset_name.startswith('openkp'):
        return ex['text']
    if dataset_name.startswith('stackex'):
        return ex['title'] + ' . ' + ex['question']
    return ex['title'] + ' . ' + ex['abstract']


for dataset_name in dataset_names:
    print('*' * 100)
    print(dataset_name)
    print('*' * 100)
    input_path = '/zfs1/hdaqing/rum20/kp/data/kp/json/%s/test.json' % dataset_name
    output_path = '/zfs1/hdaqing/rum20/kp/fairseq-kpg/exps/kp_nounphrase/checkpoint_step_9500-data_%s_test.pred' % dataset_name
    # The output directory may not exist yet.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    with open(input_path, 'r') as input_jsonl, open(output_path, 'w') as output_jsonl:
        for l_id, line in enumerate(input_jsonl):
            if l_id % 1000 == 0:
                print('%d' % l_id)
            ex = json.loads(line)

            src_text = _np_source_text(ex, dataset_name)
            nps = noun_chunks_by_pos_regex(src_text, min_len=2, max_len=6)

            # Remove duplicates but keep first-occurrence order. list(set(...)) made the
            # written order depend on string-hash randomization, so runs were not reproducible.
            nps = list(dict.fromkeys(nps))
            output_ex = {'pred_sents': nps}
            output_jsonl.write(json.dumps(output_ex) + '\n')


****************************************************************************************************
kp20k_valid2k
****************************************************************************************************
0
1000
****************************************************************************************************
openkp_valid2k
****************************************************************************************************
0
1000
****************************************************************************************************
kptimes_valid2k
****************************************************************************************************
0
1000
****************************************************************************************************
stackex_valid2k
****************************************************************************************************
0
1000
