Skip to content
Permalink
Browse files

add feature: load a json reference file for training

  • Loading branch information...
boudinfl committed Nov 6, 2018
1 parent c885e58 commit dbdda39b47dcc2b27c188f600f2a922c5c678118
Showing with 27 additions and 18 deletions.
  1. +1 −1 examples/training_and_testing_a_kea_model/train.py
  2. +26 −17 pke/utils.py
@@ -29,4 +29,4 @@
language='en',
normalization="stemming",
df=df_counts,
model=pke.supervised.WINGNUS())
model=pke.supervised.Kea())
@@ -12,6 +12,7 @@
import glob
import pickle
import gzip
import json
import codecs
import logging

@@ -188,7 +189,7 @@ def train_supervised_model(input_dir,
model.__init__()

# get the document id from file name
doc_id = input_file.split('/')[-1].split('.')[0]
doc_id = '.'.join(input_file.split('/')[-1].split('.')[0:-1])

# load the document
model.load_document(input=input_file,
@@ -220,27 +221,35 @@ def load_references(input_file,
sep_ref_keyphrases=',',
reference_stemming=False,
stemmer='porter'):
""" Load a reference file and returns a dictionary. """
"""Load a reference file and returns a dictionary."""

logging.info('loading reference keyphrases from ' + input_file)
logging.info('loading reference keyphrases from {}'.format(input_file))

references = defaultdict(list)

with codecs.open(input_file, 'r', 'utf-8') as f:
for line in f:
cols = line.strip().split(sep_doc_id)
doc_id = cols[0].strip()
keyphrases = cols[1].strip().split(sep_ref_keyphrases)
for v in keyphrases:
if '+' in v:
for s in v.split('+'):
references[doc_id].append(s)
else:
references[doc_id].append(v)
if reference_stemming:
for i, k in enumerate(references[doc_id]):
stems = [Stemmer(stemmer).stem(u) for u in k.split()]
references[doc_id][i] = ' '.join(stems)

if input_file.endswith('.json'):
references = json.load(f)
for doc_id in references:
references[doc_id] = [keyphrase for variants in
references[doc_id] for keyphrase in
variants]
else:
for line in f:
cols = line.strip().split(sep_doc_id)
doc_id = cols[0].strip()
keyphrases = cols[1].strip().split(sep_ref_keyphrases)
for v in keyphrases:
if '+' in v:
for s in v.split('+'):
references[doc_id].append(s)
else:
references[doc_id].append(v)
if reference_stemming:
for i, k in enumerate(references[doc_id]):
stems = [Stemmer(stemmer).stem(u) for u in k.split()]
references[doc_id][i] = ' '.join(stems)

return references

0 comments on commit dbdda39

Please sign in to comment.
You can’t perform that action at this time.