Skip to content

Commit

Permalink
Use plain socket objects instead of wrapper classes
Browse files Browse the repository at this point in the history
Refactor socket creation to remove the socket wrapper classes so that
these objects have less surprising behavior when used in worker hooks,
worker classes, and custom applications.

Close #3013.
  • Loading branch information
tilgovi committed Dec 28, 2023
1 parent b5d78e8 commit a24ff07
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 190 deletions.
4 changes: 2 additions & 2 deletions gunicorn/arbiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def start(self):

self.LISTENERS = sock.create_sockets(self.cfg, self.log, fds)

listeners_str = ",".join([str(lnr) for lnr in self.LISTENERS])
listeners_str = ",".join([sock.get_uri(lnr) for lnr in self.LISTENERS])
self.log.debug("Arbiter booted")
self.log.info("Listening at: %s (%s)", listeners_str, self.pid)
self.log.info("Using worker: %s", self.cfg.worker_class_str)
Expand Down Expand Up @@ -461,7 +461,7 @@ def reload(self):
lnr.close()
# init new listeners
self.LISTENERS = sock.create_sockets(self.cfg, self.log)
listeners_str = ",".join([str(lnr) for lnr in self.LISTENERS])
listeners_str = ",".join([sock.get_uri(lnr) for lnr in self.LISTENERS])
self.log.info("Listening at: %s", listeners_str)

# do some actions on reload
Expand Down
4 changes: 2 additions & 2 deletions gunicorn/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2076,7 +2076,7 @@ class KeyFile(Setting):
section = "SSL"
cli = ["--keyfile"]
meta = "FILE"
validator = validate_string
validator = validate_file_exists
default = None
desc = """\
SSL key file
Expand All @@ -2088,7 +2088,7 @@ class CertFile(Setting):
section = "SSL"
cli = ["--certfile"]
meta = "FILE"
validator = validate_string
validator = validate_file_exists
default = None
desc = """\
SSL certificate file
Expand Down
264 changes: 97 additions & 167 deletions gunicorn/sock.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,130 +14,56 @@
from gunicorn import util


class BaseSocket(object):

def __init__(self, address, conf, log, fd=None):
self.log = log
self.conf = conf

self.cfg_addr = address
if fd is None:
sock = socket.socket(self.FAMILY, socket.SOCK_STREAM)
bound = False
def _get_socket_family(addr):
if isinstance(addr, tuple):
if util.is_ipv6(addr[0]):
return socket.AF_INET6
else:
sock = socket.fromfd(fd, self.FAMILY, socket.SOCK_STREAM)
os.close(fd)
bound = True

self.sock = self.set_options(sock, bound=bound)

def __str__(self):
return "<socket %d>" % self.sock.fileno()

def __getattr__(self, name):
return getattr(self.sock, name)

def set_options(self, sock, bound=False):
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if (self.conf.reuse_port
and hasattr(socket, 'SO_REUSEPORT')): # pragma: no cover
try:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
except socket.error as err:
if err.errno not in (errno.ENOPROTOOPT, errno.EINVAL):
raise
if not bound:
self.bind(sock)
sock.setblocking(0)
return socket.AF_INET

# make sure that the socket can be inherited
if hasattr(sock, "set_inheritable"):
sock.set_inheritable(True)
if isinstance(addr, (str, bytes)):
return socket.AF_UNIX

sock.listen(self.conf.backlog)
return sock
raise TypeError("Unable to determine socket family for: %r" % addr)

def bind(self, sock):
sock.bind(self.cfg_addr)

def close(self):
if self.sock is None:
return
def create_socket(conf, log, addr):
family = _get_socket_family(addr)

if family is socket.AF_UNIX:
# remove any existing socket at the given path
try:
self.sock.close()
except socket.error as e:
self.log.info("Error while closing socket %s", str(e))

self.sock = None


class TCPSocket(BaseSocket):

FAMILY = socket.AF_INET

def __str__(self):
if self.conf.is_ssl:
scheme = "https"
st = os.stat(addr)
except OSError as e:
if e.args[0] != errno.ENOENT:
raise
else:
scheme = "http"

addr = self.sock.getsockname()
return "%s://%s:%d" % (scheme, addr[0], addr[1])

def set_options(self, sock, bound=False):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
return super().set_options(sock, bound=bound)


class TCP6Socket(TCPSocket):

FAMILY = socket.AF_INET6

def __str__(self):
(host, port, _, _) = self.sock.getsockname()
return "http://[%s]:%d" % (host, port)


class UnixSocket(BaseSocket):

FAMILY = socket.AF_UNIX

def __init__(self, addr, conf, log, fd=None):
if fd is None:
try:
st = os.stat(addr)
except OSError as e:
if e.args[0] != errno.ENOENT:
raise
if stat.S_ISSOCK(st.st_mode):
os.remove(addr)
else:
if stat.S_ISSOCK(st.st_mode):
os.remove(addr)
else:
raise ValueError("%r is not a socket" % addr)
super().__init__(addr, conf, log, fd=fd)

def __str__(self):
return "unix:%s" % self.cfg_addr

def bind(self, sock):
old_umask = os.umask(self.conf.umask)
sock.bind(self.cfg_addr)
util.chown(self.cfg_addr, self.conf.uid, self.conf.gid)
os.umask(old_umask)
raise ValueError("%r is not a socket" % addr)

for i in range(5):
try:
sock = socket.socket(family)
sock.bind(addr)
sock.listen(conf.backlog)
if family is socket.AF_UNIX:
util.chown(addr, conf.uid, conf.gid)
return sock
except socket.error as e:
if e.args[0] == errno.EADDRINUSE:
log.error("Connection in use: %s", str(addr))
if e.args[0] == errno.EADDRNOTAVAIL:
log.error("Invalid address: %s", str(addr))
if i < 5:
msg = "connection to {addr} failed: {error}"
log.debug(msg.format(addr=str(addr), error=str(e)))
log.error("Retrying in 1 second.")
time.sleep(1)

def _sock_type(addr):
if isinstance(addr, tuple):
if util.is_ipv6(addr[0]):
sock_type = TCP6Socket
else:
sock_type = TCPSocket
elif isinstance(addr, (str, bytes)):
sock_type = UnixSocket
else:
raise TypeError("Unable to create socket from: %r" % addr)
return sock_type
log.error("Can't connect to %s", str(addr))
sys.exit(1)


def create_sockets(conf, log, fds=None):
Expand All @@ -150,67 +76,71 @@ def create_sockets(conf, log, fds=None):
"""
listeners = []

# get it only once
addr = conf.address
fdaddr = [bind for bind in addr if isinstance(bind, int)]
if fds:
fdaddr += list(fds)
laddr = [bind for bind in addr if not isinstance(bind, int)]

# check ssl config early to raise the error on startup
# only the certfile is needed since it can contains the keyfile
if conf.certfile and not os.path.exists(conf.certfile):
raise ValueError('certfile "%s" does not exist' % conf.certfile)

if conf.keyfile and not os.path.exists(conf.keyfile):
raise ValueError('keyfile "%s" does not exist' % conf.keyfile)

# sockets are already bound
if fdaddr:
for fd in fdaddr:
sock = socket.fromfd(fd, socket.AF_UNIX, socket.SOCK_STREAM)
sock_name = sock.getsockname()
sock_type = _sock_type(sock_name)
listener = sock_type(sock_name, conf, log, fd=fd)
listeners.append(listener)

return listeners

# no sockets is bound, first initialization of gunicorn in this env.
for addr in laddr:
sock_type = _sock_type(addr)
sock = None
for i in range(5):
try:
sock = sock_type(addr, conf, log)
except socket.error as e:
if e.args[0] == errno.EADDRINUSE:
log.error("Connection in use: %s", str(addr))
if e.args[0] == errno.EADDRNOTAVAIL:
log.error("Invalid address: %s", str(addr))
if i < 5:
msg = "connection to {addr} failed: {error}"
log.debug(msg.format(addr=str(addr), error=str(e)))
log.error("Retrying in 1 second.")
time.sleep(1)
else:
break

if sock is None:
log.error("Can't connect to %s", str(addr))
sys.exit(1)

listeners.append(sock)
# sockets are already bound
listeners = []
for fd in list(fds) + [a for a in conf.address if isinstance(a, int)]:
sock = socket.socket(fileno=fd)
set_socket_options(conf, sock)
listeners.append(sock)
else:
# first initialization of gunicorn
old_umask = os.umask(conf.umask)
try:
for addr in [bind for bind in conf.address if not isinstance(bind, int)]:
sock = create_socket(conf, log, addr)
set_socket_options(conf, sock)
listeners.append(sock)
finally:
os.umask(old_umask)

return listeners


def close_sockets(listeners, unlink=True):
for sock in listeners:
sock_name = sock.getsockname()
sock.close()
if unlink and _sock_type(sock_name) is UnixSocket:
os.unlink(sock_name)
try:
if unlink and sock.family is socket.AF_UNIX:
sock_name = sock.getsockname()
os.unlink(sock_name)
finally:
sock.close()


def get_uri(listener, is_ssl=False):
addr = listener.getsockname()
family = _get_socket_family(addr)
scheme = "https" if is_ssl else "http"

if family is socket.AF_INET:
(host, port) = listener.getsockname()
return f"{scheme}://{host}:{port}"

if family is socket.AF_INET6:
(host, port, _, _) = listener.getsockname()
return f"{scheme}://[{host}]:{port}"

if family is socket.AF_UNIX:
path = listener.getsockname()
return f"unix://{path}"


def set_socket_options(conf, sock):
sock.setblocking(False)

# make sure that the socket can be inherited
if hasattr(sock, "set_inheritable"):
sock.set_inheritable(True)

if sock.family in (socket.AF_INET, socket.AF_INET6):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if (conf.reuse_port and hasattr(socket, 'SO_REUSEPORT')): # pragma: no cover
try:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
except socket.error as err:
if err.errno not in (errno.ENOPROTOOPT, errno.EINVAL):
raise


def ssl_context(conf):
Expand Down
Loading

0 comments on commit a24ff07

Please sign in to comment.