/
test_twitter_trainer.py
98 lines (72 loc) · 2.88 KB
/
test_twitter_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from tests.base_case import ChatBotTestCase
from unittest.mock import Mock, MagicMock
from chatterbot.trainers import TwitterTrainer
import os
import json
def _load_search_statuses():
    """
    Load the raw status dictionaries from the ``test_data/get_search.json``
    fixture located next to this test module.

    Returns:
        The value stored under the fixture's ``'statuses'`` key
        (``None`` if the key is absent, matching ``dict.get``).
    """
    current_directory = os.path.dirname(os.path.realpath(__file__))
    data_file = os.path.join(
        current_directory,
        'test_data',
        'get_search.json'
    )
    # The context manager closes the handle even if the JSON is malformed.
    # JSON is UTF-8 by specification, so do not depend on the platform's
    # locale encoding (tweet text frequently contains non-ASCII characters).
    with open(data_file, encoding='utf-8') as tweet_data:
        data = json.load(tweet_data)
    return data.get('statuses')


def get_search_side_effect(*args, **kwargs):
    """
    Stand-in for ``Api.GetSearch``.

    Ignores all arguments and returns every status in the fixture file,
    each converted with ``twitter.Status.NewFromJsonDict``.
    """
    from twitter import Status
    return [Status.NewFromJsonDict(x) for x in _load_search_statuses()]


def get_status_side_effect(*args, **kwargs):
    """
    Stand-in for ``Api.GetStatus``.

    Ignores all arguments and returns the second status (index 1) in the
    fixture file, converted with ``twitter.Status.NewFromJsonDict``.
    """
    from twitter import Status
    return Status.NewFromJsonDict(_load_search_statuses()[1])
class TwitterTrainerTestCase(ChatBotTestCase):
    """
    Tests for ``TwitterTrainer`` with its Twitter API client replaced by
    mocks that serve canned data from ``test_data/get_search.json``.
    """

    def setUp(self):
        """
        Instantiate the trainer class for testing.
        """
        super().setUp()
        self.trainer = TwitterTrainer(
            self.chatbot,
            twitter_consumer_key='twitter_consumer_key',
            twitter_consumer_secret='twitter_consumer_secret',
            twitter_access_token_key='twitter_access_token_key',
            twitter_access_token_secret='twitter_access_token_secret',
            show_training_progress=False
        )
        # Swap out the API client so the trainer never touches the network;
        # GetSearch / GetStatus are answered from the local JSON fixture.
        self.trainer.api = Mock()
        self.trainer.api.GetSearch = MagicMock(side_effect=get_search_side_effect)
        self.trainer.api.GetStatus = MagicMock(side_effect=get_status_side_effect)

    def _train_and_get_statements(self):
        """
        Run training against the mocked API and return all stored statements.

        Also asserts that training produced more than one statement, which
        every training test below relies on before inspecting the results.
        """
        self.trainer.train()
        statements = self.trainer.chatbot.storage.filter()
        self.assertGreater(len(statements), 1)
        return statements

    def test_random_word(self):
        word = self.trainer.random_word('random')
        # assertGreater reports the actual length on failure; assertTrue does not.
        self.assertGreater(len(word), 3)

    def test_get_words_from_tweets(self):
        tweets = get_search_side_effect()
        words = self.trainer.get_words_from_tweets(tweets)
        self.assertIn('about', words)
        self.assertIn('favorite', words)
        self.assertIn('things', words)

    def test_get_statements(self):
        statements = self.trainer.get_statements()
        self.assertEqual(len(statements), 1)

    def test_train(self):
        self._train_and_get_statements()

    def test_train_sets_search_text(self):
        statements = self._train_and_get_statements()
        # Expected value is derived from the fixture data.
        self.assertEqual(statements[0].search_text, 'ur u')

    def test_train_sets_search_in_response_to(self):
        statements = self._train_and_get_statements()
        # Expected value is derived from the fixture data.
        self.assertEqual(statements[0].search_in_response_to, 'ur u')