Skip to content

Commit

Permalink
Skip persisting incoming message if recipient is not allowed
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin Qiu committed Sep 2, 2014
1 parent a38a897 commit 00df351
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 8 deletions.
24 changes: 17 additions & 7 deletions mailchute/smtpd/mailchute.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import smtpd
from email.parser import Parser
from mailchute import db
from mailchute import settings
from mailchute.model import RawMessage, IncomingEmail
from logbook import Logger

Expand All @@ -9,6 +10,12 @@


class MessageProcessor(object):
def _should_persist(self, recipient):
allowed_receiver_domain = settings.RECEIVER_DOMAIN
recipient_domain = recipient.split('@')[1].lower()
return (allowed_receiver_domain is None
or recipient_domain == settings.RECEIVER_DOMAIN)

def __call__(self, peer, mailfrom, recipients, data):
try:
mailfrom = mailfrom.lower()
Expand All @@ -22,12 +29,16 @@ def __call__(self, peer, mailfrom, recipients, data):
raw_message = RawMessage(message=data)

for recipient in recipients:
incoming_email = IncomingEmail(
sender=mailfrom, recipient=recipient,
raw_message=raw_message,
subject=email['subject'],
)
db.session.add(incoming_email)
if self._should_persist(recipient):
incoming_email = IncomingEmail(
sender=mailfrom, recipient=recipient,
raw_message=raw_message,
subject=email['subject'],
)
db.session.add(incoming_email)
else:
logger.info('{} is not an allowed recipient. Skip.'.format(
recipient))

db.session.commit()
logger.info("Message saved")
Expand All @@ -36,6 +47,5 @@ def __call__(self, peer, mailfrom, recipients, data):
db.session.rollback()



class MailchuteSMTPServer(smtpd.SMTPServer):
process_message = MessageProcessor()
17 changes: 16 additions & 1 deletion tests/test_smtpd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from mailchute import db
from mailchute.smtpd.mailchute import MessageProcessor
from mailchute.model import IncomingEmail
from unittest.mock import patch


class TestMessageProcessor(BaseTestCase):
Expand Down Expand Up @@ -44,7 +45,9 @@ def test_process_message_with_subject(self):
assert 1 == len(emails)
assert 'another test' == emails[0].subject

def test_process_message_recipient_wrong_domain(self):
@patch('mailchute.smtpd.mailchute.settings')
def test_process_message_recipient_wrong_domain(self, settings):
settings.RECEIVER_DOMAIN = 'receiver.com'
self.message_processor(
'PEER',
'johndoe@example.com',
Expand All @@ -53,3 +56,15 @@ def test_process_message_recipient_wrong_domain(self):
)
emails = db.session.query(IncomingEmail).all()
assert 0 == len(emails)

@patch('mailchute.smtpd.mailchute.settings')
def test_process_message_no_check_recipient_domain(self, settings):
settings.RECEIVER_DOMAIN = None
self.message_processor(
'PEER',
'johndoe@example.com',
['janesmith@test.com'],
'DATA',
)
emails = db.session.query(IncomingEmail).all()
assert 1 == len(emails)

0 comments on commit 00df351

Please sign in to comment.