forked from BennyThink/realXiaoice
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ice_server.py
147 lines (119 loc) · 4.66 KB
/
ice_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
#!/usr/bin/python
# coding: utf-8
# realXiaoice - server.py
# 2019/8/11 17:13
#
__author__ = "Benny <benny.think@gmail.com>"
import logging
import os
from platform import uname
import json
import traceback
from concurrent.futures import ThreadPoolExecutor
from tornado import web, ioloop, httpserver, gen, options
from tornado.concurrent import run_on_executor
from xiaoice import chat
# Access-control settings read by ChatHandler.accessibility on every request.
# Both are rebound by the ``__main__`` block from the command-line options.
ALLOWED_IPS = []  # empty list: no IP restriction is applied
AUTH = False  # False: auth codes are not checked
class BaseHandler(web.RequestHandler):
    """Common base class for the request handlers defined in this file."""

    def data_received(self, chunk):
        """Accept a chunk of streamed request data and deliberately do nothing."""
        return None
class IndexHandler(BaseHandler):
    """Serve a short plain-text usage summary of the chat API."""

    def get(self):
        """Write the usage summary as the response body."""
        # The leading and trailing empty entries reproduce the newline that
        # opens and closes the text.
        usage_lines = (
            '',
            'GET: http://127.0.0.1:6789/chat?text=hello',
            'POST:http://127.0.0.1:6789/chat, form-urlencoded or json with {"text":"hello"}',
            'Response: HTTP 200: {"text":"hi there", "debug":""}',
            'Other : {"text":"", "debug":"error"}',
            '',
        )
        self.write('\n'.join(usage_lines))
class ChatHandler(BaseHandler):
    """Serve ``/chat``: check access, then answer the ``text`` argument via ``chat``.

    Every response body is a dict with the keys ``text`` and ``debug``.
    GET and POST are handled identically: the request is processed by
    ``run_request`` and the returned dict is written to the client.
    """

    # Pool picked up by ``@run_on_executor`` (it looks for the ``executor``
    # attribute); shared by all instances of this handler.
    executor = ThreadPoolExecutor(max_workers=20)

    def get_correct_argument(self, name):
        """Return the request argument *name*, or ``None`` when unavailable.

        The value comes from the JSON body when the request declares the
        ``application/json`` media type and has a body; otherwise it comes
        from ``get_argument`` (query string / form body).  A body that is
        not valid JSON, or whose top level is not a JSON object, is logged
        and yields ``None``.
        """
        content_type = self.request.headers.get('Content-Type') or ''
        # Compare the media type only, so that parameters such as
        # "application/json; charset=utf-8" are still recognised as JSON.
        media_type = content_type.split(';', 1)[0].strip().lower()
        if media_type != 'application/json' or not self.request.body:
            return self.get_argument(name, None)

        try:
            payload = json.loads(self.request.body)
        except ValueError as e:
            logging.error('Failed to extract arguments {}'.format(e))
            return None
        if not isinstance(payload, dict):
            # A JSON list/string/number has no named fields to look up.
            logging.error('Failed to extract arguments {}'.format(
                'JSON body is not an object'))
            return None
        return payload.get(name)

    def _auth_code_is_valid(self):
        """Return True when the request's ``auth`` argument is listed in key.txt.

        Raises:
            OSError: if ``key.txt`` cannot be opened.
        """
        auth_code = self.get_correct_argument('auth') or ''
        with open('key.txt', encoding='u8') as key_file:
            valid_codes = [line.replace('\r', '').replace('\n', '')
                           for line in key_file]
        # An empty code must never be accepted: a blank line in key.txt
        # would otherwise authorise requests that send no auth code at all.
        return bool(auth_code) and auth_code in valid_codes

    def accessibility(self):
        """Return a denial payload when access is refused, else an empty dict.

        The auth code is checked only when ``AUTH`` is enabled (so key.txt
        is not required otherwise); the IP allow-list is checked only when
        ``ALLOWED_IPS`` is non-empty.  On refusal a warning is logged and
        the response status is set to 403.
        """
        # NOTE(review): X-Real-IP is client-controlled unless a trusted
        # proxy overwrites it - confirm the deployment sits behind one.
        ip = self.request.headers.get("X-Real-IP", "") or self.request.remote_ip

        msg = {}
        if AUTH and not self._auth_code_is_valid():
            msg = {"text": "", "debug": "Bad auth code."}
        elif ALLOWED_IPS and ip not in ALLOWED_IPS:
            msg = {"text": "", "debug": "Your IP is not allowed to access this API."}

        if msg:
            logging.warning('Access denied for {}'.format(ip))
            self.set_status(403)
        return msg

    @run_on_executor
    def run_request(self):
        """Process the request on the executor and return the response dict.

        Sets status 403 when access is denied, 400 when ``text`` is missing
        or empty, and 500 when ``chat`` raises; otherwise the default
        status is left untouched.
        """
        denied = self.accessibility()
        if denied:
            return denied

        user_input = self.get_correct_argument('text')
        if not user_input:
            self.set_status(400)
            return {"text": "", "debug": "Wrong params."}

        try:
            return {"text": chat(user_input), "debug": ""}
        except Exception as e:
            # Deliberate catch-all at the request boundary: any failure in
            # chat() becomes a 500 with the error text in "debug".
            logging.error(traceback.format_exc())
            self.set_status(500)
            return {"text": "", "debug": str(e)}

    @gen.coroutine
    def get(self):
        """Handle GET by writing the result of ``run_request``."""
        res = yield self.run_request()
        self.write(res)

    @gen.coroutine
    def post(self):
        """Handle POST by writing the result of ``run_request``."""
        res = yield self.run_request()
        self.write(res)
class RunServer:
    """Build the Tornado application and start the HTTP server."""

    root_path = os.path.dirname(__file__)
    page_path = os.path.join(root_path, 'pages')
    handlers = [
        (r'/', IndexHandler),
        (r'/chat', ChatHandler),
    ]
    # NOTE(review): the cookie secret is committed in source; consider
    # loading it from the environment if cookies are ever relied upon.
    # NOTE(review): Tornado documents autoreload as incompatible with
    # multi-process mode, which run_server uses off Windows - confirm.
    settings = {
        "cookie_secret": "5Li05DtnQewDZq1mDVB3HAAhFqUu2vD2USnqezkeu+M=",
        "xsrf_cookies": False,
        "autoreload": True
    }
    application = web.Application(handlers, **settings)

    @staticmethod
    def run_server(port=9876, host='', **kwargs):
        """Bind the server to *host*:*port* and run the IOLoop until Ctrl+C.

        Args:
            port: TCP port to bind.
            host: address to bind, passed to ``bind`` as given.
            **kwargs: extra keyword arguments for ``httpserver.HTTPServer``.
                ``xheaders`` defaults to True but may be overridden here.
        """
        # setdefault instead of a hard-coded keyword: a caller-supplied
        # xheaders no longer raises "got multiple values" TypeError.
        kwargs.setdefault('xheaders', True)
        tornado_server = httpserver.HTTPServer(RunServer.application, **kwargs)
        tornado_server.bind(port, host)

        if uname()[0] == 'Windows':
            # Default start(): a single process.
            tornado_server.start()
        else:
            # start(None): per the Tornado docs this forks multiple
            # processes based on the number of CPUs.
            tornado_server.start(None)

        try:
            print('Server is running on http://{host}:{port}'.format(host=host, port=port))
            # IOLoop.current() replaces the deprecated IOLoop.instance() alias.
            ioloop.IOLoop.current().start()
        except KeyboardInterrupt:
            ioloop.IOLoop.current().stop()
            print('"Ctrl+C" received, exiting.\n')
if __name__ == "__main__":
    # Script entry point: parse the command line, publish the access-control
    # settings as module globals, then start the server.
    logging.basicConfig(level=logging.INFO)
    options.define("p", default=6789, help="running port", type=int)
    options.define("h", default='', help="listen address", type=str)
    options.define("a", default='', help="Allowed IPs to access this server,split by comma", type=str)
    options.define("auth", default=False, help="Enable auth? default is set to false", type=bool)
    options.parse_command_line()

    p = options.options.p
    h = options.options.h
    allow = options.options.a
    AUTH = options.options.auth
    if allow:
        # Strip whitespace so "1.1.1.1, 2.2.2.2" matches both addresses.
        # Empty entries are kept on purpose: a malformed value such as ","
        # still yields a non-empty list and therefore still restricts access.
        ALLOWED_IPS = [entry.strip() for entry in allow.split(',')]
    RunServer.run_server(port=p, host=h)