Permalink
Browse files

Initial checkin.

  • Loading branch information...
0 parents commit b558997fd9db8406b2a24a1831d06e283dbf35a6 @cortesi cortesi committed Jun 18, 2012
Showing with 998 additions and 0 deletions.
  1. +2 −0 .coveragerc
  2. +9 −0 .gitignore
  3. +2 −0 README
  4. 0 netlib/__init__.py
  5. +160 −0 netlib/odict.py
  6. +218 −0 netlib/protocol.py
  7. +182 −0 netlib/tcp.py
  8. +113 −0 test/test_odict.py
  9. +163 −0 test/test_protocol.py
  10. +93 −0 test/test_tcp.py
  11. +56 −0 test/tutils.py
@@ -0,0 +1,2 @@
+[report]
+include = *netlib*
@@ -0,0 +1,9 @@
+MANIFEST
+/build
+/dist
+/tmp
+/doc
+*.py[cdo]
+*.swp
+*.swo
+.coverage
2 README
@@ -0,0 +1,2 @@
+Netlib is a collection of common utility functions, used by the pathod and
+mitmproxy projects.
No changes.
@@ -0,0 +1,160 @@
+import re, copy
+
def safe_subn(pattern, repl, target, *args, **kwargs):
    """
    Wrapper around re.subn that smooths over Unicode conversion problems
    by casting the pattern and the replacement to str first. We really
    need a better solution that is aware of the actual content encoding.

    repl may also be a callable, as accepted by re.subn. Callables are
    passed through untouched: casting one to str would substitute the
    text of its repr instead of calling it.

    Any extra positional and keyword arguments are forwarded to re.subn.

    Returns a (new_string, number_of_substitutions) tuple.
    """
    if not callable(repl):
        repl = str(repl)
    return re.subn(str(pattern), repl, target, *args, **kwargs)
+
+
class ODict:
    """
    A dictionary-like object for managing ordered (key, value) data.

    The data lives in self.lst, a list of [key, value] pairs kept in
    insertion order. A key may occur more than once, so item access
    works with lists of values rather than with single values.
    """
    # String types rejected by __setitem__. basestring only exists on
    # Python 2; fall back to the Python 3 string types when it is missing.
    try:
        _string_types = (basestring,)
    except NameError:
        _string_types = (str, bytes)

    def __init__(self, lst=None):
        self.lst = lst or []

    def _kconv(self, s):
        """
        Key conversion hook applied to keys before they are compared.
        The base class compares keys unchanged.
        """
        return s

    def __eq__(self, other):
        # Comparing with something that is not an ODict is "not equal",
        # rather than an AttributeError on the missing .lst attribute.
        if not isinstance(other, ODict):
            return NotImplemented
        return self.lst == other.lst

    def __ne__(self, other):
        # Python 2 does not derive != from ==, so spell it out.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __getitem__(self, k):
        """
        Returns a list of values matching key. The list is empty if no
        pair matches.
        """
        k = self._kconv(k)
        return [i[1] for i in self.lst if self._kconv(i[0]) == k]

    def _filter_lst(self, k, lst):
        """
        Returns a new list holding the pairs of lst whose key does not
        match k.
        """
        k = self._kconv(k)
        return [i for i in lst if self._kconv(i[0]) != k]

    def __len__(self):
        """
        Total number of (key, value) pairs.
        """
        return len(self.lst)

    def __setitem__(self, k, valuelist):
        """
        Sets the values for key k. If there are existing values for this
        key, they are cleared. The new pairs are appended at the end.

        Raises ValueError if valuelist is a string instead of a list of
        values.
        """
        if isinstance(valuelist, self._string_types):
            raise ValueError("ODict valuelist should be lists.")
        new = self._filter_lst(k, self.lst)
        new.extend([k, i] for i in valuelist)
        self.lst = new

    def __delitem__(self, k):
        """
        Delete all items matching k.
        """
        self.lst = self._filter_lst(k, self.lst)

    def __contains__(self, k):
        """
        True if at least one pair has a key matching k.
        """
        return any(self._kconv(i[0]) == self._kconv(k) for i in self.lst)

    def add(self, key, value):
        """
        Appends a single (key, value) pair, keeping any existing pairs
        for the same key. The value is converted with str().
        """
        self.lst.append([key, str(value)])

    def get(self, k, d=None):
        """
        Returns the list of values matching k, or d if there are none.
        """
        if k in self:
            return self[k]
        else:
            return d

    def items(self):
        """
        Returns a shallow copy of the list of [key, value] pairs.
        """
        return self.lst[:]

    def _get_state(self):
        """
        Returns the pairs as a list of tuples.
        """
        return [tuple(i) for i in self.lst]

    @classmethod
    def _from_state(klass, state):
        """
        Builds an instance from a list of pairs as returned by
        _get_state.
        """
        return klass([list(i) for i in state])

    def copy(self):
        """
        Returns a copy of this object. The pair list is deep-copied.
        """
        lst = copy.deepcopy(self.lst)
        return self.__class__(lst)

    def __repr__(self):
        # One "key: value" line per pair, each terminated by CRLF. The
        # trailing empty element provides the final line terminator.
        elements = [itm[0] + ": " + itm[1] for itm in self.lst]
        elements.append("")
        return "\r\n".join(elements)

    def in_any(self, key, value, caseless=False):
        """
        Do any of the values matching key contain value?

        If caseless is true, value comparison is case-insensitive.
        """
        if caseless:
            value = value.lower()
        for i in self[key]:
            if caseless:
                i = i.lower()
            if value in i:
                return True
        return False

    def match_re(self, expr):
        """
        Match the regular expression against each (key, value) pair. For
        each pair a string of the following format is matched against:

        "key: value"

        Returns True if any pair matches, False otherwise.
        """
        for k, v in self.lst:
            if re.search(expr, "%s: %s"%(k, v)):
                return True
        return False

    def replace(self, pattern, repl, *args, **kwargs):
        """
        Replaces a regular expression pattern with repl in both keys and
        values, using safe_subn. Extra arguments are passed on to it.

        Returns the number of replacements made.
        """
        nlst, count = [], 0
        for i in self.lst:
            k, c = safe_subn(pattern, repl, i[0], *args, **kwargs)
            count += c
            v, c = safe_subn(pattern, repl, i[1], *args, **kwargs)
            count += c
            nlst.append([k, v])
        self.lst = nlst
        return count
+
+
class ODictCaseless(ODict):
    """
    An ODict whose keys are "caseless". Key case is _preserved_ as
    stored, but case is not considered when setting or getting items.
    """
    def _kconv(self, s):
        # Keys are lower-cased for comparison only.
        return s.lower()
@@ -0,0 +1,218 @@
+import string, urlparse
+
class ProtocolError(Exception):
    """
    Error raised by the protocol helpers in this module.

    code: numeric code for the error. The raise sites in this module
          pass 400 and 509.
    msg:  human-readable description of the problem.
    """
    def __init__(self, code, msg):
        # Hand the arguments to Exception too, so that e.args is
        # populated and the exception can be pickled and copied.
        Exception.__init__(self, code, msg)
        self.code, self.msg = code, msg

    def __str__(self):
        return "ProtocolError(%s, %s)"%(self.code, self.msg)
+
+
def parse_url(url):
    """
    Split an absolute URL into its components.

    Returns a (scheme, host, port, path) tuple, or None on error. A URL
    without a scheme, or with a non-numeric port, is an error.

    If the URL has no explicit port, 443 is used for https and 80 for
    every other scheme. The returned path always starts with "/" and
    includes the params, query and fragment, if there are any.
    """
    scheme, netloc, path, params, query, fragment = urlparse.urlparse(url)
    if not scheme:
        return None
    if ':' in netloc:
        # Split on the last colon only; everything before it is the host.
        host, port = netloc.rsplit(':', 1)
        try:
            port = int(port)
        except ValueError:
            return None
    else:
        host = netloc
        port = 443 if scheme == "https" else 80
    path = urlparse.urlunparse(('', '', path, params, query, fragment))
    if not path.startswith("/"):
        path = "/" + path
    return scheme, host, port, path
+
+
def read_headers(fp):
    """
    Read a set of headers from a file pointer. Stop once a blank line
    is reached, or at end of input.

    Returns a list of [name, value] lists in the order they were read.
    Values are stripped of surrounding whitespace. A line starting with
    a space or tab continues the previous header: it is appended to that
    header's value after a "\\r\\n " separator. Lines without a colon
    are ignored.
    """
    ret = []
    while 1:
        line = fp.readline()
        if not line or line == '\r\n' or line == '\n':
            break
        if line[0] in ' \t':
            # Continued header. A continuation with no header before it
            # has nothing to attach to, so it is skipped like any other
            # unparseable line instead of raising IndexError.
            if ret:
                ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip()
        else:
            i = line.find(':')
            # We're being liberal in what we accept, here.
            if i > 0:
                ret.append([line[:i], line[i+1:].strip()])
    return ret
+
+
def read_chunked(fp, limit):
    """
    Read a body in chunked transfer encoding from fp and return the
    decoded content as a single string.

    limit is the maximum total content length accepted, or None for no
    limit. After the terminating zero-length chunk, the remaining lines
    are read and discarded up to and including the next blank line.

    Raises IOError if the input ends early or a chunk is not followed
    by CRLF. Raises ProtocolError with code 400 for an invalid chunk
    length, and with code 509 if limit is exceeded.
    """
    chunks = []
    total = 0
    while 1:
        line = fp.readline(128)
        if line == "":
            raise IOError("Connection closed")
        if line == '\r\n' or line == '\n':
            continue
        try:
            length = int(line, 16)
            # A negative size would make fp.read() below read until EOF
            # and would lower the running total, defeating the limit.
            if length < 0:
                raise ValueError(line)
        except ValueError:
            # FIXME: Not strictly correct - this could be from the server, in which
            # case we should send a 502.
            raise ProtocolError(400, "Invalid chunked encoding length: %s"%line)
        if not length:
            break
        total += length
        if limit is not None and total > limit:
            msg = "HTTP Body too large."\
                  " Limit is %s, chunked content length was at least %s"%(limit, total)
            raise ProtocolError(509, msg)
        # Collect the pieces and join once at the end, rather than
        # growing a string chunk by chunk.
        chunks.append(fp.read(length))
        line = fp.readline(5)
        if line != '\r\n':
            raise IOError("Malformed chunked body")
    while 1:
        line = fp.readline()
        if line == "":
            raise IOError("Connection closed")
        if line == '\r\n' or line == '\n':
            break
    return "".join(chunks)
+
+
def has_chunked_encoding(headers):
    """
    Returns True if any transfer-encoding header value lists "chunked"
    among its comma-separated tokens, False otherwise.

    headers is indexed with "transfer-encoding" and must yield a list
    of header values. Tokens are compared case-insensitively, ignoring
    the whitespace that usually follows the commas.
    """
    for i in headers["transfer-encoding"]:
        for j in i.split(","):
            if j.strip().lower() == "chunked":
                return True
    return False
+
+
def read_http_body(rfile, headers, all, limit):
    """
    Read the body of an HTTP message from rfile and return it.

    The body is located as follows:

    - with chunked transfer encoding, it is read with read_chunked;
    - otherwise, with a content-length header, exactly that many bytes
      are read;
    - otherwise, if all is true, everything that remains is read, up to
      limit if one is set;
    - otherwise the body is the empty string.

    limit is the maximum body size accepted, or None for no limit.

    Raises ProtocolError with code 400 for an invalid content-length
    header, and with code 509 if the declared length exceeds limit.
    """
    if has_chunked_encoding(headers):
        content = read_chunked(rfile, limit)
    elif "content-length" in headers:
        try:
            l = int(headers["content-length"][0])
            # A negative length would make rfile.read() below read
            # until EOF, so it is as invalid as a non-numeric one.
            if l < 0:
                raise ValueError("negative content-length")
        except ValueError:
            # FIXME: Not strictly correct - this could be from the server, in which
            # case we should send a 502.
            raise ProtocolError(400, "Invalid content-length header: %s"%headers["content-length"])
        if limit is not None and l > limit:
            raise ProtocolError(509, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l))
        content = rfile.read(l)
    elif all:
        # -1 is the portable way to ask a file object for everything;
        # not every read() implementation accepts None.
        content = rfile.read(limit if limit else -1)
    else:
        content = ""
    return content
+
+
def parse_http_protocol(s):
    """
    Parse an HTTP protocol declaration such as "HTTP/1.1".

    Returns a (major, minor) tuple of ints, or None if s does not start
    with "HTTP/" or the version is not of the form "number.number".
    """
    if not s.startswith("HTTP/"):
        return None
    try:
        # Both the unpacking and the int conversions raise ValueError
        # on malformed input.
        major, minor = s.split('/')[1].split('.')
        return int(major), int(minor)
    except ValueError:
        return None
+
+
def parse_init_connect(line):
    """
    Parse the first line of a CONNECT request, which has the form
    "CONNECT host:port HTTP/x.y".

    Returns (host, port, httpversion), or None if the line is not a
    valid CONNECT request line.
    """
    try:
        method, url, protocol = line.split()
    except ValueError:
        return None
    if method != 'CONNECT':
        return None
    try:
        host, port = url.split(":")
        # A non-numeric port is a malformed request, not a crash.
        port = int(port)
    except ValueError:
        return None
    httpversion = parse_http_protocol(protocol)
    if not httpversion:
        return None
    return host, port, httpversion
+
+
def parse_init_proxy(line):
    """
    Parse the first line of a request whose target is an absolute URL,
    which has the form "METHOD scheme://host[:port]/path HTTP/x.y".

    Returns (method, scheme, host, port, path, httpversion), or None if
    the line does not have three whitespace-separated parts, or its URL
    or protocol version cannot be parsed.
    """
    try:
        method, url, protocol = line.split()
    except ValueError:
        return None
    parts = parse_url(url)
    if not parts:
        return None
    scheme, host, port, path = parts
    httpversion = parse_http_protocol(protocol)
    if not httpversion:
        return None
    return method, scheme, host, port, path, httpversion
+
+
def parse_init_http(line):
    """
    Parse the first line of a request whose target is a path starting
    with "/", or "*".

    Returns (method, url, httpversion), or None if the line does not
    have three whitespace-separated parts, the target has neither of
    those forms, or the protocol version cannot be parsed.
    """
    try:
        method, url, protocol = line.split()
    except ValueError:
        return None
    if not (url.startswith("/") or url == "*"):
        return None
    httpversion = parse_http_protocol(protocol)
    if not httpversion:
        return None
    return method, url, httpversion
+
+
def request_connection_close(httpversion, headers):
    """
    Checks the request to see if the client connection should be closed.

    httpversion is a (major, minor) tuple. headers must support "in"
    and yield a list of values for the "connection" key.

    The first "close" or "keep-alive" token found in the connection
    headers decides the outcome. Tokens are compared
    case-insensitively. Without such a token, only HTTP/1.1
    connections are kept open.
    """
    if "connection" in headers:
        for value in ",".join(headers['connection']).split(","):
            value = value.strip().lower()
            if value == "close":
                return True
            elif value == "keep-alive":
                return False
    # HTTP 1.1 connections are assumed to be persistent
    if httpversion == (1, 1):
        return False
    return True
+
+
def response_connection_close(httpversion, headers):
    """
    Checks the response to see if the client connection should be closed.

    The connection is closed if request_connection_close says so for
    these headers, or if the response has neither chunked transfer
    encoding nor a content-length header.
    """
    if request_connection_close(httpversion, headers):
        return True
    # With neither chunked encoding nor a content-length, the end of
    # the body can only be signalled by closing the connection.
    if not has_chunked_encoding(headers) and "content-length" not in headers:
        return True
    return False
+
+
def read_http_body_request(rfile, wfile, headers, httpversion, limit):
    """
    Read the body of a request from rfile and return it.

    If the request carries an "Expect: 100-continue" header and
    httpversion is at least (1, 1), an interim "100 Continue" response
    is written to wfile first and the expect header is removed from
    headers. The expectation is compared case-insensitively.

    The body itself is read with read_http_body, with all set to False
    and the given limit.
    """
    if "expect" in headers:
        # FIXME: Should be forwarded upstream
        expect = ",".join(headers['expect'])
        if expect.lower() == "100-continue" and httpversion >= (1, 1):
            wfile.write('HTTP/1.1 100 Continue\r\n')
            # NOTE(review): no import of `version` is visible in this
            # module, so this line presumably raises NameError when
            # reached - confirm where NAMEVERSION should come from.
            wfile.write('Proxy-agent: %s\r\n'%version.NAMEVERSION)
            wfile.write('\r\n')
            del headers['expect']
    return read_http_body(rfile, headers, False, limit)
Oops, something went wrong.

0 comments on commit b558997

Please sign in to comment.