Skip to content

Commit

Permalink
Merge pull request #2479 from benoitc/capture-peer-name
Browse files Browse the repository at this point in the history
Capture peer name from accept
  • Loading branch information
tilgovi authored Dec 31, 2020
2 parents 03c642e + 3573fd3 commit 86eac4c
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 32 deletions.
35 changes: 11 additions & 24 deletions gunicorn/http/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
import io
import re
import socket
from errno import ENOTCONN

from gunicorn.http.unreader import SocketUnreader
from gunicorn.http.body import ChunkedReader, LengthReader, EOFReader, Body
from gunicorn.http.errors import (
InvalidHeader, InvalidHeaderName, NoMoreData,
Expand All @@ -29,9 +27,10 @@


class Message(object):
def __init__(self, cfg, unreader):
def __init__(self, cfg, unreader, peer_addr):
self.cfg = cfg
self.unreader = unreader
self.peer_addr = peer_addr
self.version = None
self.headers = []
self.trailers = []
Expand Down Expand Up @@ -69,16 +68,10 @@ def parse_headers(self, data):
# handle scheme headers
scheme_header = False
secure_scheme_headers = {}
if '*' in cfg.forwarded_allow_ips:
if ('*' in cfg.forwarded_allow_ips or
not isinstance(self.peer_addr, tuple)
or self.peer_addr[0] in cfg.forwarded_allow_ips):
secure_scheme_headers = cfg.secure_scheme_headers
elif isinstance(self.unreader, SocketUnreader):
remote_addr = self.unreader.sock.getpeername()
if self.unreader.sock.family in (socket.AF_INET, socket.AF_INET6):
remote_host = remote_addr[0]
if remote_host in cfg.forwarded_allow_ips:
secure_scheme_headers = cfg.secure_scheme_headers
elif self.unreader.sock.family == socket.AF_UNIX:
secure_scheme_headers = cfg.secure_scheme_headers

# Parse headers into key/value pairs paying attention
# to continuation lines.
Expand Down Expand Up @@ -169,7 +162,7 @@ def should_close(self):


class Request(Message):
def __init__(self, cfg, unreader, req_number=1):
def __init__(self, cfg, unreader, peer_addr, req_number=1):
self.method = None
self.uri = None
self.path = None
Expand All @@ -184,7 +177,7 @@ def __init__(self, cfg, unreader, req_number=1):

self.req_number = req_number
self.proxy_protocol_info = None
super().__init__(cfg, unreader)
super().__init__(cfg, unreader, peer_addr)

def get_data(self, unreader, buf, stop=False):
data = unreader.read()
Expand Down Expand Up @@ -280,16 +273,10 @@ def proxy_protocol(self, line):

def proxy_protocol_access_check(self):
# check in allow list
if isinstance(self.unreader, SocketUnreader):
try:
remote_host = self.unreader.sock.getpeername()[0]
except socket.error as e:
if e.args[0] == ENOTCONN:
raise ForbiddenProxyRequest("UNKNOW")
raise
if ("*" not in self.cfg.proxy_allow_ips and
remote_host not in self.cfg.proxy_allow_ips):
raise ForbiddenProxyRequest(remote_host)
if ("*" not in self.cfg.proxy_allow_ips and
isinstance(self.peer_addr, tuple) and
self.peer_addr[0] not in self.cfg.proxy_allow_ips):
raise ForbiddenProxyRequest(self.peer_addr[0])

def parse_proxy_protocol(self, line):
bits = line.split()
Expand Down
5 changes: 3 additions & 2 deletions gunicorn/http/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@ class Parser(object):

mesg_class = None

def __init__(self, cfg, source):
def __init__(self, cfg, source, source_addr):
self.cfg = cfg
if hasattr(source, "recv"):
self.unreader = SocketUnreader(source)
else:
self.unreader = IterUnreader(source)
self.mesg = None
self.source_addr = source_addr

# request counter (for keepalive connetions)
self.req_count = 0
Expand All @@ -38,7 +39,7 @@ def __next__(self):

# Parse the next request
self.req_count += 1
self.mesg = self.mesg_class(self.cfg, self.unreader, self.req_count)
self.mesg = self.mesg_class(self.cfg, self.unreader, self.source_addr, self.req_count)
if not self.mesg:
raise StopIteration()
return self.mesg
Expand Down
2 changes: 1 addition & 1 deletion gunicorn/workers/base_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def is_already_handled(self, respiter):
def handle(self, listener, client, addr):
req = None
try:
parser = http.RequestParser(self.cfg, client)
parser = http.RequestParser(self.cfg, client, addr)
try:
listener_name = listener.getsockname()
if not self.cfg.keepalive:
Expand Down
2 changes: 1 addition & 1 deletion gunicorn/workers/gthread.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def init(self):
**self.cfg.ssl_options)

# initialize the parser
self.parser = http.RequestParser(self.cfg, self.sock)
self.parser = http.RequestParser(self.cfg, self.sock, self.client)

def set_timeout(self):
# set the timeout
Expand Down
2 changes: 1 addition & 1 deletion gunicorn/workers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def handle(self, listener, client, addr):
client = ssl.wrap_socket(client, server_side=True,
**self.cfg.ssl_options)

parser = http.RequestParser(self.cfg, client)
parser = http.RequestParser(self.cfg, client, addr)
req = next(parser)
self.handle_request(listener, req, client, addr)
except http.errors.NoMoreData as e:
Expand Down
2 changes: 1 addition & 1 deletion tests/t.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, name):
def __call__(self, func):
def run():
src = data_source(self.fname)
func(src, RequestParser(src, None))
func(src, RequestParser(src, None, None))
run.func_name = func.func_name
return run

Expand Down
4 changes: 2 additions & 2 deletions tests/treq.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def test_req(sn, sz, mt):

def check(self, cfg, sender, sizer, matcher):
cases = self.expect[:]
p = RequestParser(cfg, sender())
p = RequestParser(cfg, sender(), None)
for req in p:
self.same(req, sizer, matcher, cases.pop(0))
assert not cases
Expand Down Expand Up @@ -282,5 +282,5 @@ def send(self):
read += chunk

def check(self, cfg):
p = RequestParser(cfg, self.send())
p = RequestParser(cfg, self.send(), None)
next(p)

0 comments on commit 86eac4c

Please sign in to comment.