-
Notifications
You must be signed in to change notification settings - Fork 59
/
kms_http_common.py
153 lines (113 loc) · 4.35 KB
/
kms_http_common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""Common code for mock kms http endpoint."""
import http.server
import json
import ssl
import urllib.parse
from abc import abstractmethod
# Control URLs handled by KmsHandlerBase.do_GET.
URL_PATH_STATS = "/stats"
URL_DISABLE_FAULTS = "/disable_faults"
URL_ENABLE_FAULTS = "/enable_faults"

# Fault names. Each comment sits directly above the constant it describes
# (a bare string placed *before* an assignment would be read by doc tools as
# the docstring of the previous constant).

# Fault which causes encrypt to return 500.
FAULT_ENCRYPT = "fault_encrypt"

# Fault which causes encrypt to return an error that contains a type and message.
FAULT_ENCRYPT_CORRECT_FORMAT = "fault_encrypt_correct_format"

# Fault which causes encrypt to return wrong fields in JSON.
FAULT_ENCRYPT_WRONG_FIELDS = "fault_encrypt_wrong_fields"

# Fault which causes encrypt to return bad BASE64.
FAULT_ENCRYPT_BAD_BASE64 = "fault_encrypt_bad_base64"

# Fault which causes decrypt to return 500.
FAULT_DECRYPT = "fault_decrypt"

# Fault which causes decrypt to return an error that contains a type and message.
FAULT_DECRYPT_CORRECT_FORMAT = "fault_decrypt_correct_format"

# Fault which causes decrypt to return the wrong key.
FAULT_DECRYPT_WRONG_KEY = "fault_decrypt_wrong_key"

# Fault which causes an OAuth request to return a 500.
FAULT_OAUTH = "fault_oauth"

# Fault which causes an OAuth request to return an error response.
FAULT_OAUTH_CORRECT_FORMAT = "fault_oauth_correct_format"
class Stats:
    """Request counters shared between the mock KMS server and its clients.

    ``repr()`` (and therefore ``str()``) renders the counters as a JSON object
    with the keys ``decrypts``, ``encrypts`` and ``faults``, in that order.
    """

    def __init__(self):
        self.encrypt_calls = self.decrypt_calls = self.fault_calls = 0

    def __repr__(self):
        snapshot = dict(
            decrypts=self.decrypt_calls,
            encrypts=self.encrypt_calls,
            faults=self.fault_calls,
        )
        return json.dumps(snapshot)
class KmsHandlerBase(http.server.BaseHTTPRequestHandler):
    """Base request handler for the mock KMS HTTP endpoints.

    GET requests are routed to the test-control URLs (stats and fault
    toggling). Subclasses supply POST handling and the encrypt/decrypt hooks.

    NOTE: this class does not use ``abc.ABCMeta``, so the ``@abstractmethod``
    markers below only document which hooks subclasses are expected to
    override; Python does not enforce them at instantiation time.
    """

    # With HTTP/1.1 the connection is kept alive by default, so every response
    # must carry a Content-Length header or the client cannot tell where the
    # body ends and will wait for the connection to close.
    protocol_version = "HTTP/1.1"

    def do_GET(self):
        """Serve a Test GET request."""
        print("Received GET: " + self.path)

        path = urllib.parse.urlsplit(self.path).path

        if path == URL_PATH_STATS:
            self._do_stats()
        elif path == URL_DISABLE_FAULTS:
            self._do_disable_faults()
        elif path == URL_ENABLE_FAULTS:
            self._do_enable_faults()
        else:
            body = "Unknown URL".encode()
            self.send_response(http.HTTPStatus.NOT_FOUND)
            self.send_header("Content-Length", str(len(body)))
            self.end_headers()
            self.wfile.write(body)

    @abstractmethod
    def do_POST(self):
        """Serve a POST request (implemented by subclasses)."""
        pass

    def _send_reply(self, data, status=http.HTTPStatus.OK):
        """Send ``data`` (bytes) as a complete response with ``status``.

        The body is echoed to stdout for diagnostics. Bytes that are not valid
        UTF-8 are shown with replacement characters instead of raising, so a
        binary body cannot abort the reply.
        """
        print("Sending Response: " + data.decode(errors="replace"))
        self.send_response(status)
        self.send_header("content-type", "application/octet-stream")
        self.send_header("Content-Length", str(len(data)))
        self.end_headers()
        self.wfile.write(data)

    @abstractmethod
    def _do_encrypt(self, raw_input):
        """Handle an encrypt request (implemented by subclasses)."""
        pass

    @abstractmethod
    def _do_encrypt_faults(self, raw_ciphertext):
        """Produce a faulty encrypt response (implemented by subclasses)."""
        pass

    @abstractmethod
    def _do_decrypt(self, raw_input):
        """Handle a decrypt request (implemented by subclasses)."""
        pass

    @abstractmethod
    def _do_decrypt_faults(self, blob):
        """Produce a faulty decrypt response (implemented by subclasses)."""
        pass

    def _send_header(self, content_length=None):
        """Send a 200 OK status line and headers; the caller writes any body.

        Args:
            content_length: number of body bytes the caller will write next.
                When given, a Content-Length header is emitted so HTTP/1.1
                keep-alive clients know where the response ends. ``None``
                (the default) omits the header, as earlier versions did.
        """
        self.send_response(http.HTTPStatus.OK)
        self.send_header("content-type", "application/octet-stream")
        if content_length is not None:
            self.send_header("Content-Length", str(content_length))
        self.end_headers()

    def _do_stats(self):
        """Reply with ``str()`` of the module-level ``stats`` object."""
        body = str(stats).encode('utf-8')
        self._send_header(content_length=len(body))
        self.wfile.write(body)

    def _do_disable_faults(self):
        """Set the module-level ``disable_faults`` flag; reply 200 with no body."""
        global disable_faults
        disable_faults = True
        self._send_header(content_length=0)

    def _do_enable_faults(self):
        """Clear the module-level ``disable_faults`` flag; reply 200 with no body."""
        global disable_faults
        disable_faults = False
        self._send_header(content_length=0)
def run(port, cert_file, ca_file, handler_class, server_class=http.server.HTTPServer, cert_required=False):
    """Run the mock KMS web server over TLS; blocks in ``serve_forever()``.

    Args:
        port: TCP port to listen on (all interfaces). ``0`` lets the OS pick
            one; the port actually bound is printed.
        cert_file: server certificate chain file (PEM), passed to
            ``SSLContext.load_cert_chain``.
        ca_file: CA bundle used to verify client certificates; skipped when
            falsy.
        handler_class: request handler class instantiated per request.
        server_class: server class to instantiate; defaults to ``HTTPServer``.
        cert_required: when True, clients must present a certificate that
            verifies against ``ca_file``.

    Raises:
        ValueError: if ``cert_file`` is not given (a TLS server needs one).
    """
    if not cert_file:
        raise ValueError("certfile must be specified for server-side operations")

    # ssl.wrap_socket() was deprecated in Python 3.7 and removed in 3.12; an
    # explicit server-side context reproduces its behaviour on all versions.
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.load_cert_chain(certfile=cert_file)
    if ca_file:
        context.load_verify_locations(cafile=ca_file)
    context.verify_mode = ssl.CERT_REQUIRED if cert_required else ssl.CERT_NONE

    server_address = ('', port)
    httpd = server_class(server_address, handler_class)
    httpd.socket = context.wrap_socket(httpd.socket, server_side=True)

    # Report the port actually bound (differs from `port` when port == 0).
    print("Mock KMS Web Server Listening on port " + str(httpd.server_address[1]))
    httpd.serve_forever()
# Shared mock-server state lives at module level rather than on a handler:
# socketserver constructs a fresh request-handler instance for every request,
# so anything stored on ``self`` is discarded once that request finishes.
# Handlers read and update these globals instead.

# Request counters, served by the /stats control URL.
stats = Stats()
# Toggled by the /disable_faults (True) and /enable_faults (False) control URLs.
disable_faults = False
# NOTE(review): not read or written by the code visible in this module;
# presumably holds one of the FAULT_* names (None = no injected fault) and is
# set by the launcher / handler subclasses -- confirm against its users.
fault_type = None