Skip to content

Commit

Permalink
Improve how sentences are split using Punkt tokenizer
Browse files (browse the repository at this point in the history)
  • Loading branch information
gunthercox committed Dec 15, 2018
1 parent 5174c52 commit a48b3ed
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 28 deletions.
40 changes: 24 additions & 16 deletions chatterbot/stemming.py
@@ -1,5 +1,6 @@
import string
from nltk import pos_tag
from nltk.data import load as load_data
from nltk.corpus import wordnet, stopwords


Expand Down Expand Up @@ -120,8 +121,6 @@ class PosHypernymStemmer(object):
"""

def __init__(self, language='english'):
self.punctuation_table = str.maketrans(dict.fromkeys(string.punctuation))

self.language = language

self.stopwords = None
Expand Down Expand Up @@ -182,32 +181,41 @@ def get_bigram_pair_string(self, text):
DT:beautiful JJ:wetland
"""
words = text.split()
WORD_INDEX = 0
POS_INDEX = 1

pos_tags = []

sentence_detector = load_data('tokenizers/punkt/english.pickle')

# Separate punctuation from last word in string
if words:
word_with_punctuation_removed = words[-1].strip(string.punctuation)
for sentence in sentence_detector.tokenize(text.strip()):

if word_with_punctuation_removed:
words[-1] = word_with_punctuation_removed
# Remove punctuation
if sentence and sentence[-1] in string.punctuation:
sentence_with_punctuation_removed = sentence[:-1]

pos_tags = pos_tag(words)
if sentence_with_punctuation_removed:
sentence = sentence_with_punctuation_removed

words = sentence.split()

pos_tags.extend(pos_tag(words))

hypernyms = self.get_hypernyms(pos_tags)

high_quality_bigrams = []
all_bigrams = []

word_count = len(words)
word_count = len(pos_tags)

if word_count <= 1:
all_bigrams = words
if all_bigrams:
all_bigrams[0] = all_bigrams[0].lower()
if word_count == 1:
all_bigrams.append(
pos_tags[0][WORD_INDEX].lower()
)

for index in range(1, word_count):
word = words[index].lower()
previous_word_pos = pos_tags[index][1]
word = pos_tags[index][WORD_INDEX].lower()
previous_word_pos = pos_tags[index - 1][POS_INDEX]
if word not in self.get_stopwords() and len(word) > 1:
bigram = previous_word_pos + ':' + hypernyms[index].lower()
high_quality_bigrams.append(bigram)
Expand Down
12 changes: 6 additions & 6 deletions tests/test_stemming.py
Expand Up @@ -131,7 +131,7 @@ def test_stemming(self):
'Hello, how are you doing on this awesome day?'
)

self.assertEqual(stemmed_text, 'JJ:awesome NN:time_unit')
self.assertEqual(stemmed_text, 'DT:awesome JJ:time_unit')

def test_string_becomes_lowercase(self):
stemmed_text = self.stemmer.get_bigram_pair_string('THIS IS HOW IT BEGINS!')
Expand All @@ -141,12 +141,12 @@ def test_string_becomes_lowercase(self):
def test_stemming_medium_sized_words(self):
stemmed_text = self.stemmer.get_bigram_pair_string('Hello, my name is Gunther.')

self.assertEqual(stemmed_text, 'NN:language_unit NNP:gunther')
self.assertEqual(stemmed_text, 'PRP$:language_unit VBZ:gunther')

def test_stemming_long_words(self):
stemmed_text = self.stemmer.get_bigram_pair_string('I play several orchestra instruments for pleasuer.')

self.assertEqual(stemmed_text, 'VBP:compete JJ:several JJ:orchestra NNS:device NN:pleasuer')
self.assertEqual(stemmed_text, 'PRP:compete VBP:several JJ:orchestra JJ:device IN:pleasuer')

def test_get_bigram_pair_string_punctuation_only(self):
bigram_string = self.stemmer.get_bigram_pair_string(
Expand Down Expand Up @@ -209,18 +209,18 @@ def test_get_bigram_pair_string_multiple_words(self):
'Hello Dr. Salazar. How are you today?'
)

self.assertEqual(bigram_string, 'NNP:scholar NNP:salazar. NN:present')
self.assertEqual(bigram_string, 'NNP:scholar NNP:salazar PRP:present')

def test_get_bigram_pair_string_single_character_words(self):
bigram_string = self.stemmer.get_bigram_pair_string(
'a e i o u'
)

self.assertEqual(bigram_string, 'NN:e NN:i VBP:o NN:u')
self.assertEqual(bigram_string, 'DT:e NN:i NN:o VBP:u')

def test_get_bigram_pair_string_two_character_words(self):
bigram_string = self.stemmer.get_bigram_pair_string(
'Lo my mu it is of us'
)

self.assertEqual(bigram_string, 'NN:letter PRP:us')
self.assertEqual(bigram_string, 'PRP$:letter IN:us')
4 changes: 2 additions & 2 deletions tests/training/test_list_training.py
Expand Up @@ -65,7 +65,7 @@ def test_training_sets_search_text(self):
))

self.assertIsLength(statements, 1)
self.assertEqual(statements[0].search_text, 'IN:kind NN:headdress')
self.assertEqual(statements[0].search_text, 'RB:kind PRP$:headdress')

def test_training_sets_search_in_response_to(self):

Expand All @@ -81,7 +81,7 @@ def test_training_sets_search_in_response_to(self):
))

self.assertIsLength(statements, 1)
self.assertEqual(statements[0].search_in_response_to, 'IN:kind NN:headdress')
self.assertEqual(statements[0].search_in_response_to, 'PRP:kind PRP$:headdress')

def test_database_has_correct_format(self):
"""
Expand Down
4 changes: 2 additions & 2 deletions tests/training/test_twitter_trainer.py
Expand Up @@ -88,11 +88,11 @@ def test_train_sets_search_text(self):
statements = list(self.trainer.chatbot.storage.filter())

self.assertGreater(len(statements), 1)
self.assertEqual(statements[0].search_text, 'VBP:sure NNP:jewel')
self.assertEqual(statements[0].search_text, 'PRP:sure IN:jewel')

def test_train_sets_search_in_response_to(self):
self.trainer.train()
statements = list(self.trainer.chatbot.storage.filter())

self.assertGreater(len(statements), 1)
self.assertEqual(statements[0].search_in_response_to, 'VBP:sure NNP:jewel')
self.assertEqual(statements[0].search_in_response_to, 'PRP:sure IN:jewel')
4 changes: 2 additions & 2 deletions tests/training/test_ubuntu_corpus_training.py
Expand Up @@ -176,7 +176,7 @@ def test_train_sets_search_text(self):
results = list(self.chatbot.storage.filter(text='Is anyone there?'))

self.assertEqual(len(results), 2)
self.assertEqual(results[0].search_text, 'NN:anyone')
self.assertEqual(results[0].search_text, 'VBZ:anyone')

def test_train_sets_search_in_response_to(self):
"""
Expand All @@ -190,7 +190,7 @@ def test_train_sets_search_in_response_to(self):
results = list(self.chatbot.storage.filter(in_response_to='Is anyone there?'))

self.assertEqual(len(results), 2)
self.assertEqual(results[0].search_in_response_to, 'NN:anyone')
self.assertEqual(results[0].search_in_response_to, 'VBZ:anyone')

def test_is_extracted(self):
"""
Expand Down

0 comments on commit a48b3ed

Please sign in to comment.