Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Changelog
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
.. contents::
:local:

Next Release
============

- Add support for ``Connection.connect_timeout`` parameter


.. _version-2.0.0:

2.0.0
Expand Down
17 changes: 14 additions & 3 deletions Modules/_librabbitmq/connection.c
Original file line number Diff line number Diff line change
Expand Up @@ -1052,6 +1052,7 @@ PyRabbitMQ_ConnectionType_init(PyRabbitMQ_Connection *self,
"frame_max",
"heartbeat",
"client_properties",
"connect_timeout",
NULL
};
char *hostname;
Expand All @@ -1063,11 +1064,13 @@ PyRabbitMQ_ConnectionType_init(PyRabbitMQ_Connection *self,
int frame_max = 131072;
int heartbeat = 0;
int port = 5672;
int connect_timeout = 0;
PyObject *client_properties = NULL;

if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|ssssiiiiO", kwlist,
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|ssssiiiiOi", kwlist,
&hostname, &userid, &password, &virtual_host, &port,
&channel_max, &frame_max, &heartbeat, &client_properties)) {
&channel_max, &frame_max, &heartbeat, &client_properties,
&connect_timeout)) {
return -1;
}

Expand All @@ -1089,6 +1092,7 @@ PyRabbitMQ_ConnectionType_init(PyRabbitMQ_Connection *self,
self->channel_max = channel_max;
self->frame_max = frame_max;
self->heartbeat = heartbeat;
self->connect_timeout = connect_timeout;
self->weakreflist = NULL;
self->callbacks = PyDict_New();
if (self->callbacks == NULL) return -1;
Expand Down Expand Up @@ -1127,6 +1131,7 @@ PyRabbitMQ_Connection_connect(PyRabbitMQ_Connection *self)
amqp_rpc_reply_t reply;
amqp_pool_t pool;
amqp_table_t properties;
struct timeval timeout = {0, 0};

pyobject_array_t pyobj_array = {0};

Expand All @@ -1144,7 +1149,13 @@ PyRabbitMQ_Connection_connect(PyRabbitMQ_Connection *self)
goto error;
}
Py_BEGIN_ALLOW_THREADS;
status = amqp_socket_open(socket, self->hostname, self->port);
if (self->connect_timeout <= 0) {
status = amqp_socket_open(socket, self->hostname, self->port);
} else {
timeout.tv_sec = self->connect_timeout;
status = amqp_socket_open_noblock(socket, self->hostname, self->port, &timeout);
}

Py_END_ALLOW_THREADS;
if (PyRabbitMQ_HandleAMQStatus(status, "Error opening socket")) {
goto error;
Expand Down
3 changes: 3 additions & 0 deletions Modules/_librabbitmq/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ typedef struct {
int frame_max;
int channel_max;
int heartbeat;
int connect_timeout;

int sockfd;
int connected;
Expand Down Expand Up @@ -267,6 +268,8 @@ static PyMemberDef PyRabbitMQ_ConnectionType_members[] = {
offsetof(PyRabbitMQ_Connection, frame_max), READONLY, NULL},
{"callbacks", T_OBJECT_EX,
offsetof(PyRabbitMQ_Connection, callbacks), READONLY, NULL},
{"connect_timeout", T_INT,
offsetof(PyRabbitMQ_Connection, connect_timeout), READONLY, NULL},
{NULL, 0, 0, 0, NULL} /* Sentinel */
};

Expand Down
3 changes: 2 additions & 1 deletion librabbitmq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,14 +191,15 @@ class Connection(_librabbitmq.Connection):
def __init__(self, host='localhost', userid='guest', password='guest',
virtual_host='/', port=5672, channel_max=0xffff,
frame_max=131072, heartbeat=0, lazy=False,
client_properties=None, **kwargs):
client_properties=None, connect_timeout=None, **kwargs):
if ':' in host:
host, port = host.split(':')
super(Connection, self).__init__(
hostname=host, port=int(port), userid=userid, password=password,
virtual_host=virtual_host, channel_max=channel_max,
frame_max=frame_max, heartbeat=heartbeat,
client_properties=client_properties,
connect_timeout=0 if connect_timeout is None else int(connect_timeout),
)
self.channels = {}
self._used_channel_ids = array('H')
Expand Down
34 changes: 34 additions & 0 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,45 @@
import socket
import unittest
from array import array
import time

from librabbitmq import Message, Connection, ConnectionError, ChannelError
TEST_QUEUE = 'pyrabbit.testq'


class test_Connection(unittest.TestCase):
def test_connection_defaults(self):
"""Test making a connection with the default settings."""
with Connection() as connection:
self.assertGreaterEqual(connection.fileno(), 0)

def test_connection_invalid_host(self):
"""Test connection to an invalid host fails."""
# Will fail quickly as OS will reject it.
with self.assertRaises(ConnectionError):
Connection(host="255.255.255.255")

def test_connection_invalid_port(self):
"""Test connection to an invalid port fails."""
# Will fail quickly as OS will reject it.
with self.assertRaises(ConnectionError):
Connection(port=0)

def test_connection_timeout(self):
"""Test connection timeout."""
# Can't rely on local OS being configured to ignore SYN packets
# (OS would normally reply with RST to closed port). To test the
# timeout, need to connect to something that is either slow, or
# never responds.
start_time = time.time()
with self.assertRaises(ConnectionError):
Connection(host="google.com", port=81, connect_timeout=3)
took_time = time.time() - start_time
# Allow some leaway to avoid spurious test failures.
self.assertGreaterEqual(took_time, 2)
self.assertLessEqual(took_time, 4)


class test_Channel(unittest.TestCase):

def setUp(self):
Expand Down