Skip to content

Commit

Permalink
coverage for wsgi
Browse files Browse the repository at this point in the history
  • Loading branch information
earonesty committed Nov 22, 2019
1 parent 606277e commit 1cc979f
Showing 1 changed file with 151 additions and 65 deletions.
216 changes: 151 additions & 65 deletions smx/wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,6 @@ def is_script(self, path):
return "%expand%" in x
return ext in self.__expand

def parse(self, ctx, path, write):
with open(path) as f:
w = Writer()
w.write = write
ctx.expand_io(f, w)
yield ""

def static_resp(self, start_response, path):
content_type = "text/plain"
content_length = None
Expand Down Expand Up @@ -126,55 +119,57 @@ def __call__(self, env, start_response):
if os.path.isdir(full_path):
full_path = self.find_index(full_path)

if not self.is_script(full_path):
log.debug("STATIC %s", url)
yield from self.static_resp(start_response, full_path)
return

log.debug("SCRIPT %s", url)

content = b"{}"
length = env.get("CONTENT_LENGTH", 0)
content_type = env.get('CONTENT_TYPE', "")
info = {}
if length:
content = env['wsgi.input'].read(int(length))
if content_type.startswith('application/x-www-form-urlencoded'):
info = parse_query_string(content)
elif content_type.startswith('text/json'):
if content:
try:
info = json.loads(content)
except Exception:
raise HttpError(400, "Invalid JSON " + str(content, "utf-8"))
try:
if not self.is_script(full_path):
log.debug("STATIC %s", url)
yield from self.static_resp(start_response, full_path)
return

log.debug("SCRIPT %s", url)

content = b"{}"
length = env.get("CONTENT_LENGTH", 0)
content_type = env.get('CONTENT_TYPE', "")
info = {}
jq = {}
if length:
content = env['wsgi.input'].read(int(length))
if content_type.startswith('application/x-www-form-urlencoded'):
info = parse_query_string(content)
elif content_type.startswith('application/json'):
if content:
try:
jq = json.loads(content)
except Exception:
raise HttpError(400, "Invalid JSON " + str(content, "utf-8"))
else:
jq = {}
# elif ... handle more stuff

query = env.get('QUERY_STRING')

if query:
params = parse_query_string(query)
else:
info = {}
# elif ... handle more stuff

query = env.get('QUERY_STRING')
params = {}

if query:
params = parse_query_string(query)
else:
params = {}

info.update(params)
info.update(params)

# todo:
# we process the first MAX_MEM_SIZE bytes for status codes & errors
# if the file is larger than that
# we STREAM the results
# todo:
# we process the first MAX_MEM_SIZE bytes for status codes & errors
# if the file is larger than that
# we STREAM the results

ctx = Smx(self.ctx, environ=env)
ctx = Smx(self.ctx, environ=env)

headers = {}
headers = {}

ctx.set("form", lambda k: info.get(k))
ctx.set("header", headers)
ctx.set("error", lambda k, m=None, b=None: throw(HttpError(k, m, b)))
ctx.set("redirect", lambda k: throw(RedirectError(k)))
ctx.set("form", lambda k: info.get(k))
ctx.set("jq", jq)
ctx.set("header", headers)
ctx.set("error", lambda k, m=None, b=None: throw(HttpError(k, m, b)))
ctx.set("redirect", lambda k: throw(RedirectError(k)))

try:
fo = io.StringIO()
with open(full_path) as fi:
ctx.expand_io(fi, fo)
Expand All @@ -199,14 +194,13 @@ def __call__(self, env, start_response):
if response is None:
response = "<h1>%s %s</h1><p>An error was encountered processing your request</p>" % (e.code, e.msg or "Error")

headers = [('Content-Type', 'application/json'), ("Content-Length", str(len(response)))]
headers = [('Content-Type', 'text/html'), ("Content-Length", str(len(response)))]
if isinstance(e, RedirectError):
headers.append(('Location', e.path))
response = ""
else:
log.error("GET %s : ERROR : %s", url, e)


start_response(str(e.code) + ' ' + e.msg, headers)
yield bytes(response, "utf-8")
except ConnectionAbortedError as e:
Expand All @@ -218,6 +212,9 @@ def __call__(self, env, start_response):


if __name__ == "__main__":
main()

def main():
import argparse
parser = argparse.ArgumentParser(description='Start dev server')
parser.add_argument('--debug', "-d", help='set logging level debug', action="store_true")
Expand All @@ -234,21 +231,37 @@ def __call__(self, env, start_response):
print("Serving on port %s..." % args.port)
httpd.serve_forever()

import pytest

@pytest.fixture
def app():
def app_fixture(test_env=False, with_init=None):
# doing this because pytest fixtures seem hard to add optional params to

import tempfile, shutil, wsgiref, wsgiref.util
root = os.path.join(tempfile.gettempdir(), os.urandom(32).hex())
root = os.path.join(tempfile.gettempdir(), "smx-tests." + os.urandom(32).hex())
os.mkdir(root)
app = SmxWsgi(root)

def req(url, post=b''):
if with_init:
init = os.path.join(root, os.urandom(16).hex())
with open(init, "w") as f:
f.write(with_init)
os.environ["SMX_INIT"] = init

if test_env:
os.environ["SMX_ROOT"] = root
app = SmxWsgi()
else:
app = SmxWsgi(root)

def req(url, post=b'', type=""):
temp = io.BytesIO(post)
qs = ""
split = url.split('?')
if len(split) == 2:
url, qs = split
environ = {
'PATH_INFO': url,
'QUERY_STRING': qs,
'REQUEST_METHOD': 'POST' if post else 'GET',
'CONTENT_LENGTH': len(post),
'CONTENT_TYPE': type,
'wsgi.input': temp,
}

Expand Down Expand Up @@ -282,29 +295,102 @@ def create(path, data=b''):

app.req = req
app.create = create
app.__del__ = lambda: shutil.rmtree(root)

yield app

shutil.rmtree(root)
return app


def test_basic(app):
def test_basic():
app = app_fixture()
app.create("hi.smx", "%add(1,1)")
res = app.req("/hi.smx")
assert res.data == b'2'

def test_redirect(app):
def test_redirect():
app = app_fixture()
app.create("hi.smx", "%redirect(/yo)")
res = app.req("/hi.smx")
assert res.data == b''
assert res.code == 302
assert res.head.get("Location") == "/yo"

def test_error(app):
def test_error():
app = app_fixture()
app.create("hi.smx", "%notavar%")
res = app.req("/hi.smx")
assert b'Traceback' in res.data
assert res.code == 500

def test_qs():
app = app_fixture(test_env=True)
app.create("hi.smx", "%add(%form(x),1)")
res = app.req("/hi.smx?x=4")
assert b'5' == res.data
assert res.code == 200

def test_static():
app = app_fixture(test_env=True)
app.create("hi.txt", "%add(1,1)")
res = app.req("/hi.txt")
assert b'%add(1,1)' == res.data
assert res.code == 200


def test_init():
app = app_fixture(with_init='%set(foo, 44)')
app.create("hi.smx", "%add(2,%foo%)")
res = app.req("/hi.smx")
assert b'46' == res.data
assert res.code == 200

def test_index():
app = app_fixture(test_env=True)
app.create("index.smx", "%add(1,1)")
res = app.req("/")
assert b'2' == res.data
assert res.code == 200

def test_post_jq():
app = app_fixture(test_env=True)
app.create("index.smx", "%add(%jq(x),1)")
res = app.req("/", post=b'{"x":4}', type="application/json")
assert b'5' == res.data
assert res.code == 200

def test_post_badjq():
app = app_fixture(test_env=True)
app.create("index.smx", "%add(%jq(x),1)")
res = app.req("/", post=b'"x":4}', type="application/json")
assert res.code == 400

def test_404():
app = app_fixture(test_env=True)
res = app.req("/")
assert res.code == 404

def test_500():
app = app_fixture(test_env=True)
app.create("index.smx", "%addsdfsfd%")
res = app.req("/")
assert res.code == 500

def test_main():
import threading

app = app_fixture(test_env=True)
app.create("index.smx", "%add(44,44)")

import sys
sys.argv = ["smx", "-r", app.root, '-p' '8001']

t = threading.Thread(target=main, daemon=True)
import requests
t.start()
import time
t = time.monotonic() + 1
while time.monotonic() < t:
try:
assert requests.get("http://127.0.0.1:8001").text == "88"
except requests.ConnectionError:
continue


0 comments on commit 1cc979f

Please sign in to comment.