Skip to content
Permalink
Browse files

Adding seq2seq stub + various fixes

  • Loading branch information...
boudinfl committed Nov 11, 2018
1 parent 37b391a commit db50e65bdf78b9c1a9372e4b32c1dd30e3bdf04b
@@ -8,6 +8,7 @@

from pke.data_structures import Document


class Reader(object):
    """Abstract base class for document readers.

    Concrete readers must override :meth:`read`.
    """

    def read(self, path):
        """Read the document located at *path*.

        Args:
            path (str): path to the document to read.

        Raises:
            NotImplementedError: always; subclasses provide the behavior.
        """
        raise NotImplementedError
@@ -59,7 +60,6 @@ def __init__(self, language=None):
if language is None:
self.language = 'en'


def read(self, text, **kwargs):
"""Read the input file and use spacy to pre-process.
@@ -6,3 +6,4 @@
from pke.supervised.api import SupervisedLoadFile
from pke.supervised.feature_based.kea import Kea
from pke.supervised.feature_based.wingnus import WINGNUS
from pke.supervised.neural_based.seq2seq import Seq2Seq
@@ -2,15 +2,16 @@
# Author: Florian Boudin
# Date: 09-10-2018

"""Kea keyphrase extraction model.
"""Kea supervised keyphrase extraction model.
Supervised approach to keyphrase extraction described in:
Kea is a supervised model for keyphrase extraction that uses two features,
namely TF x IDF and first occurrence, to classify keyphrase candidates as
keyphrase or not. The model is described in:
* Ian Witten, Gordon Paynter, Eibe Frank, Carl Gutwin and Craig Nevill-Mannin.
KEA: Practical Automatic Keyphrase Extraction.
*Proceedings of the 4th ACM Conference on Digital Libraries*, pages 254–255,
1999.
"""

from __future__ import absolute_import
@@ -37,57 +38,60 @@ class Kea(SupervisedLoadFile):
import pke
from nltk.corpus import stopwords
# define a list of stopwords
stoplist = stopwords.words('english')
# 1. create a Kea extractor.
extractor = pke.supervised.Kea()
# 2. load the content of the document.
extractor.load_document(input='path/to/input.xml')
extractor.load_document(input='path/to/input',
language='en',
normalization=None)
# 3. select 1-3 grams that do not start or end with a stopword as
# candidates.
stoplist = stopwords.words('english')
# candidates. Candidates that contain punctuation marks as words
# are discarded.
extractor.candidate_selection(stoplist=stoplist)
# 4. classify candidates as keyphrase or not keyphrase.
df = pke.load_document_frequency_file(input_file='path/to/df.tsv.gz')
model_file = 'path/to/kea_model'
extractor.candidate_weighting(self, model_file=model_file, df=df)
extractor.candidate_weighting(self,
model_file=model_file,
df=df)
# 5. get the 10-highest scored candidates as keyphrases
keyphrases = extractor.get_n_best(n=10)
"""

def __init__(self):
    """Redefining initializer for Kea.

    Delegates to the SupervisedLoadFile initializer; Kea adds no extra
    state of its own.
    """
    # NOTE(review): the diff left two consecutive string literals here;
    # the second was a dead no-op expression and has been removed so a
    # single docstring remains.
    super(Kea, self).__init__()

def candidate_selection(self, stoplist=None, **kwargs):
"""Select 1-3 grams as keyphrase candidates. Candidates that start or
end with a stopword are discarded.
"""Select 1-3 grams of `normalized` words as keyphrase candidates.
Candidates that start or end with a stopword are discarded. Candidates
that contain punctuation marks (from `string.punctuation`) as words are
filtered out.
Args:
stoplist (list): the stoplist for filtering candidates, defaults
to the nltk stoplist. Words that are punctuation marks from
string.punctuation are not allowed.
to the nltk stoplist.
"""

# select ngrams from 1 to 3 grams
self.ngram_selection(n=3)

# filter candidates containing punctuation marks
self.candidate_filtering(list(string.punctuation) +
['-lrb-', '-rrb-', '-lcb-', '-rcb-', '-lsb-',
'-rsb-'])
self.candidate_filtering(list(string.punctuation))

# initialize stoplist list if not provided
if stoplist is None:
stoplist = self.stoplist

# filter candidates that start or end with a stopword
# Python 2/3 compatible
for k in list(self.candidates):

# get the candidate
@@ -99,8 +103,9 @@ def candidate_selection(self, stoplist=None, **kwargs):
del self.candidates[k]

def feature_extraction(self, df=None, training=False):
"""Extract features (tf*idf, first occurrence and length) for each
candidate.
"""Extract features for each keyphrase candidate. Features are the
tf*idf of the candidate and its first occurrence relative to the
document.
Args:
df (dict): document frequencies, the number of documents should be
@@ -0,0 +1,2 @@
# -*- coding: utf-8 -*-
# Python Keyphrase Extraction toolkit: supervised neural-based ranking models
@@ -0,0 +1,26 @@
# -*- coding: utf-8 -*-
# Author: Florian Boudin
# Date: 11-11-2018

"""
Implementation of the Seq2Seq model for automatic keyphrase extraction.
"""

from __future__ import absolute_import
from __future__ import print_function

from pke.supervised.api import SupervisedLoadFile


class Seq2Seq(SupervisedLoadFile):
    """Stub for the Seq2Seq keyphrase extraction model.

    Both pipeline steps are placeholders that do nothing; the actual
    model is not implemented yet.
    """

    def __init__(self):
        """Redefining initializer for Seq2Seq."""
        super(Seq2Seq, self).__init__()

    def candidate_selection(self):
        """Placeholder candidate selection step; intentionally a no-op."""

    def candidate_weighting(self):
        """Placeholder candidate weighting step; intentionally a no-op."""
@@ -78,10 +78,7 @@ def candidate_selection(self, lasf=3, cutoff=400, stoplist=None, **kwargs):
stoplist = self.stoplist

# filter candidates containing stopwords or punctuation marks
self.candidate_filtering(stoplist=list(string.punctuation) +
['-lrb-', '-rrb-', '-lcb-', '-rcb-', '-lsb-',
'-rsb-'] +
stoplist)
self.candidate_filtering(stoplist=list(string.punctuation) + stoplist)

# further filter candidates using lasf and cutoff
# Python 2/3 compatible
@@ -33,7 +33,8 @@ class TfIdf(LoadFile):
normalization=None)
# 3. select {1-3}-grams not containing punctuation marks as candidates.
extractor.candidate_selection(n=3, stoplist=list(string.punctuation))
extractor.candidate_selection(n=3,
stoplist=list(string.punctuation))
# 4. weight the candidates using a `tf` x `idf`
df = pke.load_document_frequency_file(input_file='path/to/df.tsv.gz')
@@ -59,12 +60,10 @@ def candidate_selection(self, n=3, stoplist=None, **kwargs):

# initialize empty list if stoplist is not provided
if stoplist is None:
stoplist = []
stoplist = list(string.punctuation)

# filter candidates containing punctuation marks
self.candidate_filtering(stoplist=list(string.punctuation) +
['-lrb-', '-rrb-', '-lcb-', '-rcb-',
'-lsb-', '-rsb-'] + stoplist)
self.candidate_filtering(stoplist=stoplist)

def candidate_weighting(self, df=None):
"""Candidate weighting function using document frequencies.
@@ -85,6 +84,7 @@ def candidate_weighting(self, df=None):

# loop through the candidates
for k, v in self.candidates.items():

# get candidate document frequency
candidate_df = 1 + df.get(k, 0)

@@ -98,10 +98,7 @@ def candidate_selection(self, n=3, stoplist=None, **kwargs):
self.ngram_selection(n=n)

# filter candidates containing punctuation marks
self.candidate_filtering(
stoplist=list(string.punctuation) +
['-lrb-', '-rrb-', '-lcb-', '-rcb-', '-lsb-', '-rsb-']
)
self.candidate_filtering(stoplist=list(string.punctuation))

# initialize empty list if stoplist is not provided
if stoplist is None:
@@ -8,7 +8,7 @@
license='gnu',
packages=['pke', 'pke.unsupervised', 'pke.supervised',
'pke.supervised.feature_based', 'pke.unsupervised.graph_based',
'pke.unsupervised.statistical'],
'pke.unsupervised.statistical', 'pke.supervised.neural_based'],
url="https://github.com/boudinfl/pke",
install_requires=[
'nltk',

0 comments on commit db50e65

Please sign in to comment.
You can’t perform that action at this time.