diff --git a/salmon/queue.py b/salmon/queue.py
index 6d497b3..6c9dd89 100644
--- a/salmon/queue.py
+++ b/salmon/queue.py
@@ -4,6 +4,9 @@ to do some serious surgery go use that.
 This works as a good API for the 90% case of "put mail in, get mail out"
 queues.
 """
+import contextlib
 import errno
+import fcntl
 import hashlib
+import json
 import logging
@@ -42,14 +45,6 @@ def _create_tmp(self):
         raise mailbox.ExternalClashError('Name clash prevented file creation: %s' % path)
 
 
-class QueueError(Exception):
-
-    def __init__(self, msg, data):
-        Exception.__init__(self, msg)
-        self._message = msg
-        self.data = data
-
-
 class Queue:
     """
     Provides a simplified API for dealing with 'queues' in Salmon.
@@ -104,6 +99,16 @@ def push(self, message):
             message = str(message)
         return self.mbox.add(message)
 
+    def _move_oversize(self, key, name):
+        if self.oversize_dir:
+            logging.info("Message key %s over size limit %d, moving to %s.",
+                         key, self.pop_limit, self.oversize_dir)
+            os.rename(name, os.path.join(self.oversize_dir, key))
+        else:
+            logging.info("Message key %s over size limit %d, DELETING (set oversize_dir).",
+                         key, self.pop_limit)
+            os.unlink(name)
+
     def pop(self):
         """
         Pops a message off the queue, order is not really maintained
@@ -115,21 +120,10 @@ def pop(self):
             over, over_name = self.oversize(key)
 
             if over:
-                if self.oversize_dir:
-                    logging.info("Message key %s over size limit %d, moving to %s.",
-                                 key, self.pop_limit, self.oversize_dir)
-                    os.rename(over_name, os.path.join(self.oversize_dir, key))
-                else:
-                    logging.info("Message key %s over size limit %d, DELETING (set oversize_dir).",
-                                 key, self.pop_limit)
-                    os.unlink(over_name)
+                self._move_oversize(key, over_name)
             else:
-                try:
-                    msg = self.get(key)
-                except QueueError as exc:
-                    raise exc
-                finally:
-                    self.remove(key)
+                msg = self.get(key)
+                self.remove(key)
             return key, msg
 
         return None, None
@@ -149,11 +143,11 @@ def get(self, key):
         try:
             return mail.MailRequest(self.dir, None, None, msg_data)
         except Exception as exc:
-            logging.exception("Failed to decode message: %s; msg_data: %r", exc, msg_data)
+            logging.exception("Failed to decode message: %s; msg_data: %r", exc, msg_data)
             return None
 
     def remove(self, key):
-        """Removes the queue, but not returned."""
+        """Removes the message with the given key from the queue."""
         self.mbox.remove(key)
 
     def __len__(self):
@@ -166,15 +160,8 @@ def __len__(self):
     def clear(self):
         """
         Clears out the contents of the entire queue.
-
-        Warning: This could be horribly inefficient since it pops messages
-        until the queue is empty. It could also cause an infinite loop if
-        another process is writing to messages to the Queue faster than we can
-        pop.
""" - # man this is probably a really bad idea - while len(self) > 0: - self.pop() + self.mbox.clear() def keys(self): """ @@ -188,3 +173,80 @@ def oversize(self, key): return os.path.getsize(file_name) > self.pop_limit, file_name else: return False, None + + +class Metadata: + def __init__(self, path): + # mkdir dir+metadata + self.path = os.path.join(path, "metadata") + os.mkdir(self.path) + + def get(self, key): + return json.load(open(os.path.join(self.path, key), "r")) + + def set(self, key, data): + json.dump(open(os.path.join(self.path, key), "w"), data) + + def remove(self, key): + os.unlink(open(os.path.join(self.path, key))) + + @contextlib.contextmanager + def lock(self, key): + i = 0 + meta_file = open(os.path.join(self.path, key), "rw") + while True: + # try for a lock using exponential backoff + try: + fcntl.flock(meta_file, fcntl.LOCK_EX | fcntl.LOCK_NB) + except BlockingIOError: + if i > 5: + # 2**5 is 30 seconds which is far too long + raise + time.sleep(2**i) + i += 1 + else: + break + + try: + yield meta_file + finally: + fcntl.flock(meta_file, fcntl.LOCK_UN) + meta_file.close() + + +class QueueWithMetadata(Queue): + """Just like Queue, except it stores envolope data""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.metadata = Metadata(self.dir) + + def push(self, message, Peer, From, To): + if not isinstance(To, list): + To = [To] + key = super().push(message) + with self.metadata.lock(key): + self.metadata.set(key, {"Peer": Peer, "From": From, "To": To}) + return key + + def get(self, key): + with self.metadata.lock(key): + msg = super().get(key) + metadata = self.metadata.get(key) + # move data from metadata to msg obj + for k, v in metadata.items(): + setattr(msg, k, v) + metadata["To"].remove(msg.To) + self.metadata.set(key, metadata) + return msg + + def remove(self, key): + with self.metadata.lock(key) as meta_file: + metadata = self.metadata.get(key) + # if there's still a To to be processed, leave the message on disk + if not metadata.get("To"): + super().remove(key) + self.metadata.remove(key) + + def clear(self): + self.metadata.clear() + super().clear() diff --git a/salmon/server.py b/salmon/server.py index 4faf695..6f0dbcf 100644 --- a/salmon/server.py +++ b/salmon/server.py @@ -23,6 +23,7 @@ ROUTER_VERSION_STRING = "Salmon Mail router, version %s" % __version__ SMTP_MULTIPLE_RCPTS_ERROR = "451 Will not accept multiple recipients in one transaction" +IN_QUEUE = "run/in_queue" lmtpd.__version__ = ROUTER_VERSION_STRING smtpd.__version__ = ROUTER_VERSION_STRING @@ -292,40 +293,47 @@ def process_message(self, Peer, From, To, Data, **kwargs): return _deliver(self, Peer, From, To, Data, **kwargs) -class SMTPOnlyOneRcpt(SMTP): - async def smtp_RCPT(self, arg): - if self.envelope.rcpt_tos: - await self.push(SMTP_MULTIPLE_RCPTS_ERROR) - else: - await super().smtp_RCPT(arg) - - class SMTPHandler: - def __init__(self, executor=None): + def __init__(self, executor=None, *, in_queue): self.executor = executor + self.in_queue = in_queue async def handle_DATA(self, server, session, envelope): - status = await server.loop.run_in_executor(self.executor, partial( - _deliver, - self, - session.peer, - envelope.mail_from, - envelope.rcpt_tos[0], - envelope.content, - )) - return status or "250 Ok" + try: + status = await server.loop.run_in_executor(self.executor, partial( + self.in_queue.queue.push, + envelope.content, + session.peer, + envolope.mail_from, + envolope.rcpt_tos, + )) + status = "250 Ok" + except Exception: + 
logging.exception("Raised exception while trying to push to Queue: %r, Peer: %r, From: %r, To: %r") + status = "550 Server error" + return status class AsyncSMTPReceiver(Controller): """Receives emails and hands it to the Router for further processing.""" - def __init__(self, handler=None, **kwargs): + def __init__(self, handler=None, in_queue=None, **kwargs): + if in_queue is None: + in_queue = QueueReceiver(queue.QueueWithMetadata(IN_QUEUE)) + self.in_queue = in_queue if handler is None: - handler = SMTPHandler() + handler = SMTPHandler(in_queue=self.in_queue) super().__init__(handler, **kwargs) def factory(self): - # TODO implement a queue - return SMTPOnlyOneRcpt(self.handler, enable_SMTPUTF8=self.enable_SMTPUTF8, ident=ROUTER_VERSION_STRING) + return SMTP(self.handler, enable_SMTPUTF8=self.enable_SMTPUTF8, ident=ROUTER_VERSION_STRING) + + def start(self): + super().start() + self.in_queue.start() + + def stop(self): + super().stop() + self.in_queue.stop() class LMTPHandler: @@ -340,7 +348,8 @@ async def handle_DATA(self, server, session, envelope): self, session.peer, envelope.mail_from, - rcpt, envelope.content, + rcpt, + envelope.content, )) statuses.append(status or "250 Ok") return "\r\n".join(statuses) @@ -389,20 +398,25 @@ class QueueReceiver: same way otherwise. """ - def __init__(self, queue_dir, sleep=10, size_limit=0, oversize_dir=None, workers=10): + def __init__(self, in_queue, sleep=10, size_limit=0, oversize_dir=None, workers=10): """ The router should be fully configured and ready to work, the queue_dir can be a fully qualified path or relative. The option workers dictates how many threads are started to process messages. Consider adding ``@nolocking`` to your handlers if you are able to. """ - self.queue = queue.Queue(queue_dir, pop_limit=size_limit, - oversize_dir=oversize_dir) + if isinstance(in_queue, str): + self.queue = queue.Queue(in_queue, pop_limit=size_limit, + oversize_dir=oversize_dir) + else: + self.queue = in_queue self.sleep = sleep # Pool is from multiprocess.dummy which uses threads rather than processes self.workers = Pool(workers) + self._running = True + def start(self, one_shot=False): """ Start simply loops indefinitely sleeping and pulling messages @@ -412,25 +426,35 @@ def start(self, one_shot=False): """ logging.info("Queue receiver started on queue dir %s", self.queue.dir) - logging.debug("Sleeping for %d seconds...", self.sleep) - - # if there are no messages left in the maildir and this a one-shot, the - # while loop terminates - while not (len(self.queue) == 0 and one_shot): - # if there's nothing in the queue, take a break - if len(self.queue) == 0: - time.sleep(self.sleep) - continue - try: - key, msg = self.queue.pop() - except KeyError: - logging.debug("Could not find message in Queue") - continue - - logging.debug("Pulled message with key: %r off", key) - self.workers.apply_async(self.process_message, args=(msg,)) + def _run(): + while self._running: + # if there's nothing in the queue, take a break + if len(self.queue) == 0: + if one_shot: + self._running = False + else: + logging.debug("Sleeping for %d seconds...", self.sleep) + time.sleep(self.sleep) + continue + + try: + key, msg = self.queue.pop() + except KeyError: + logging.debug("Could not find message in Queue") + continue + + logging.debug("Pulled message with key: %r off", key) + self.workers.apply_async(self.process_message, args=(msg,)) + self.main_thread = threading.Thread(target=_run) + self.main_thread.start() + + if one_shot: + self.main_thread.join() + def stop(self): 
+        self._running = False
+        self.main_thread.join()
         self.workers.close()
         self.workers.join()
 
@@ -441,12 +466,13 @@ def process_message(self, msg):
         """
 
         try:
-            logging.debug("Message received from Peer: %r, From: %r, to To %r.", msg.Peer, msg.From, msg.To)
+            logging.debug("Message received from Queue: %r, Peer: %r, From: %r, To: %r.",
+                          self.queue, msg.Peer, msg.From, msg.To)
             routing.Router.deliver(msg)
         except SMTPError as err:
             logging.exception("Raising SMTPError when running in a QueueReceiver is unsupported.")
             undeliverable_message(msg.Data, err.message)
         except Exception:
-            logging.exception("Exception while processing message from Peer: "
-                              "%r, From: %r, to To %r.", msg.Peer, msg.From, msg.To)
+            logging.exception("Exception while processing message from Queue: %r, Peer: "
+                              "%r, From: %r, To: %r.", self.queue, msg.Peer, msg.From, msg.To)
             undeliverable_message(msg.Data, "Router failed to catch exception.")
diff --git a/tests/command_tests.py b/tests/command_tests.py
index 8bda525..459d8cd 100644
--- a/tests/command_tests.py
+++ b/tests/command_tests.py
@@ -79,7 +79,7 @@ def test_queue_command(self, MockQueue):
         self.assertEqual(mq.__len__.call_count, 1)
 
     @patch('salmon.utils.daemonize')
-    @patch('salmon.server.SMTPReceiver')
+    @patch('salmon.server.AsyncSMTPReceiver')
     def test_log_command(self, MockSMTPReceiver, daemon_mock):
         runner = CliRunner()
         ms = MockSMTPReceiver()
diff --git a/tests/server_tests.py b/tests/server_tests.py
index 49cd19c..31f81ae 100644
--- a/tests/server_tests.py
+++ b/tests/server_tests.py
@@ -222,6 +222,7 @@ def sleepy(*args, **kwargs):
         receiver = server.QueueReceiver('run/queue', sleep=10, workers=1)
         with self.assertRaises(SleepCalled):
             receiver.start()
+            receiver.main_thread.join()
 
         self.assertEqual(receiver.workers.apply_async.call_count, 0)
         self.assertEqual(sleep_mock.call_count, 2)
@@ -427,7 +428,7 @@ def test_multiple_rcpts(self):
         code, _ = client.rcpt("you@localhost")
         self.assertEqual(code, 250)
         code, _ = client.rcpt("them@localhost")
-        self.assertEqual(code, 451)
+        self.assertEqual(code, 250)
 
 
 class AsyncLmtpServerTestCase(SalmonTestCase):
diff --git a/tests/utils_tests.py b/tests/utils_tests.py
index a8eb808..b15d99d 100644
--- a/tests/utils_tests.py
+++ b/tests/utils_tests.py
@@ -21,7 +21,6 @@ def test_make_fake_settings(self):
         assert settings
         assert settings.receiver
         assert settings.relay is None
-        settings.receiver.close()
 
     def test_import_settings(self):
         assert utils.settings is None
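A rough usage sketch of the new QueueWithMetadata (illustrative only, not part of the patch; the addresses and Peer value are made up, the directory matches the IN_QUEUE default added above, and it assumes a salmon project layout where run/ exists): each get() hands out one recipient at a time, and remove() only deletes the message once the last stored recipient has been processed.

    from salmon import queue

    q = queue.QueueWithMetadata("run/in_queue")
    key = q.push("Hello", Peer=("127.0.0.1", 1234), From="me@example.com",
                 To=["you@example.com", "them@example.com"])

    msg = q.get(key)   # msg.Peer and msg.From come from the stored metadata,
                       # msg.To == "you@example.com"
    q.remove(key)      # message stays on disk, "them@example.com" is still queued
    msg = q.get(key)   # msg.To == "them@example.com"
    q.remove(key)      # last recipient handled: message and metadata are deleted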