Fix pylint errors
Javier Collado committed Nov 25, 2016
1 parent febb952 commit b8b4c65
Showing 6 changed files with 77 additions and 43 deletions.
7 changes: 7 additions & 0 deletions rabbithole/__init__.py
@@ -1,5 +1,12 @@
# -*- coding: utf-8 -*-

"""RabbitHole: Store messages from rabbitmq into a SQL database.
The way the message are stores is that each exchange name is mapped to a SQL
query that is executed when neded.
"""

__author__ = """Javier Collado"""
__email__ = 'javier@gigaspaces.com'
__version__ = '0.1.0'
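
Note: the mapping described in the new module docstring is the configuration's output section that cli.py passes to Database below. A minimal sketch of such a mapping, with a hypothetical exchange name and query:

    # Hypothetical exchange-to-query mapping; names are illustrative only.
    output = {
        'events': 'INSERT INTO events (payload) VALUES (:payload)',
    }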
31 changes: 20 additions & 11 deletions rabbithole/batcher.py
@@ -1,11 +1,20 @@
# -*- coding: utf-8 -*-

"""Batcher: group messages in batches before writing them to the database.
The strategy to batch messages is:
- store them in memory as they are received
- send them to the database when either the size or the time limit is
exceeded.
"""

import logging
import threading

from collections import defaultdict

-logger = logging.getLogger(__name__)
+LOGGER = logging.getLogger(__name__)


class Batcher(object):
@@ -46,7 +55,7 @@ def message_received_cb(self, exchange_name, payload):
with self.locks[exchange_name]:
batch = self.batches[exchange_name]
batch.append(payload)
-logger.debug(
+LOGGER.debug(
'Message added to %r batch (size: %d)',
exchange_name,
len(batch),
@@ -55,7 +64,7 @@ def message_received_cb(self, exchange_name, payload):
if len(batch) == 1:
self.start_timer(exchange_name)
elif len(batch) >= self.SIZE_LIMIT:
-logger.debug(
+LOGGER.debug(
'Size limit (%d) exceeded for %r',
self.SIZE_LIMIT,
exchange_name,
@@ -75,19 +84,19 @@ def time_expired_cb(self, exchange_name):
"""
# Use a lock to make sure that callback execution doesn't interleave
with self.locks[exchange_name]:
-logger.debug(
+LOGGER.debug(
'Time limit (%.2f) exceeded for %r',
self.TIME_LIMIT,
exchange_name,
)
self.insert_batch(exchange_name)
if exchange_name not in self.timers:
-logger.warning('Timer not found for: %r', exchange_name)
+LOGGER.warning('Timer not found for: %r', exchange_name)
return
del self.timers[exchange_name]

thread = threading.current_thread()
-logger.debug(
+LOGGER.debug(
'Timer thread finished: (%d, %s)',
thread.ident,
thread.name,
@@ -106,7 +115,7 @@ def insert_batch(self, exchange_name):
"""
batch = self.batches[exchange_name]
if not batch:
-logger.warning('Nothing to insert: %r', exchange_name)
+LOGGER.warning('Nothing to insert: %r', exchange_name)
return
self.database.insert(exchange_name, self.batches[exchange_name])
del self.batches[exchange_name]
@@ -122,7 +131,7 @@ def start_timer(self, exchange_name):
"""
if exchange_name in self.timers:
-logger.warning('Timer already active for: %r', exchange_name)
+LOGGER.warning('Timer already active for: %r', exchange_name)
return
timer = threading.Timer(
self.TIME_LIMIT,
@@ -132,7 +141,7 @@ def start_timer(self, exchange_name):
timer.name = 'timer-{}'.format(exchange_name)
timer.daemon = True
timer.start()
-logger.debug('Timer thread started: (%d, %s)', timer.ident, timer.name)
+LOGGER.debug('Timer thread started: (%d, %s)', timer.ident, timer.name)
self.timers[exchange_name] = timer

def cancel_timer(self, exchange_name):
@@ -147,10 +156,10 @@ def cancel_timer(self, exchange_name):
"""
timer = self.timers.get(exchange_name)
if timer is None:
-logger.warning('Timer not found for: %r', exchange_name)
+LOGGER.warning('Timer not found for: %r', exchange_name)
return
timer.cancel()
-logger.debug(
+LOGGER.debug(
'Timer thread cancelled: (%d, %s)',
timer.ident,
timer.name,
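
Note: with these changes, Batcher still buffers payloads per exchange and flushes when either SIZE_LIMIT or the TIME_LIMIT timer fires. A minimal sketch of driving it by hand, assuming only the insert(exchange_name, rows) method used above (FakeDatabase is hypothetical):

    from rabbithole.batcher import Batcher

    class FakeDatabase(object):
        """Hypothetical stand-in for rabbithole.db.Database."""

        def insert(self, exchange_name, rows):
            print('insert into %s: %d rows' % (exchange_name, len(rows)))

    batcher = Batcher(FakeDatabase())
    for payload in ({'value': 1}, {'value': 2}):
        batcher.message_received_cb('events', payload)
    # Rows are flushed once SIZE_LIMIT is reached or TIME_LIMIT expires.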
14 changes: 7 additions & 7 deletions rabbithole/cli.py
@@ -15,7 +15,7 @@
from rabbithole.db import Database
from rabbithole.batcher import Batcher

-logger = logging.getLogger(__name__)
+LOGGER = logging.getLogger(__name__)


def main(argv=None):
@@ -35,13 +35,13 @@ def main(argv=None):
try:
consumer = Consumer(config['rabbitmq'], config['output'].keys())
except pika.exceptions.AMQPError as exception:
-logger.error('Rabbitmq connectivity error: %s', exception)
+LOGGER.error('Rabbitmq connectivity error: %s', exception)
return 1

try:
-database = Database(config['database'], config['output'])
+database = Database(config['database'], config['output']).connect()
except sqlalchemy.exc.SQLAlchemyError as exception:
-logger.error(exception)
+LOGGER.error(exception)
return 1

batcher = Batcher(database)
@@ -50,7 +50,7 @@ def main(argv=None):
try:
consumer.run()
except KeyboardInterrupt:
-logger.info('Interrupted by user')
+LOGGER.info('Interrupted by user')

return 0

@@ -74,9 +74,9 @@ def yaml_file(path):
if not os.path.isfile(path):
raise argparse.ArgumentTypeError('File not found')

-with open(path) as fp:
+with open(path) as file_:
try:
-data = yaml.load(fp)
+data = yaml.load(file_)
except yaml.YAMLError:
raise argparse.ArgumentTypeError('YAML parsing error')

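
Note: yaml_file is an argparse type callable, so it validates the path and returns the parsed configuration in one step. A sketch of how it plugs into a parser (the parser setup here is an assumption; only yaml_file comes from this module):

    import argparse

    from rabbithole.cli import yaml_file

    parser = argparse.ArgumentParser()
    parser.add_argument('config', type=yaml_file)
    args = parser.parse_args(['config.yml'])  # args.config is the parsed dict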
25 changes: 18 additions & 7 deletions rabbithole/consumer.py
@@ -1,16 +1,27 @@
# -*- coding: utf-8 -*-

"""Consumer: get messages from rabbitmq.
The strategy to get messages is:
- connect to the rabbitmq server
- bind a queue to the desired exchanges
Note that it's assumed that the exchanges will have `fanout` type and that the
routing key isn't relevant in this case.
"""

import json
import logging

import blinker
import pika

-logger = logging.getLogger(__name__)
+LOGGER = logging.getLogger(__name__)


class Consumer(object):
"""Message consumer.
"""Rabbitmq message consumer.
:param server: Rabbitmq server IP address
:type server: str
@@ -21,15 +32,15 @@ class Consumer(object):

def __init__(self, server, exchange_names):
"""Configure exchanges and queue."""
-logger.info('Connecting to %r...', server)
+LOGGER.info('Connecting to %r...', server)
parameters = pika.ConnectionParameters(server)
connection = pika.BlockingConnection(parameters)
channel = connection.channel()

# Use a single queue to process messages from all exchanges
result = channel.queue_declare(auto_delete=True)
queue_name = result.method.queue
-logger.debug('Declared queue %r', queue_name)
+LOGGER.debug('Declared queue %r', queue_name)

for exchange_name in exchange_names:
channel.exchange_declare(
@@ -40,7 +51,7 @@ def __init__(self, server, exchange_names):
exchange=exchange_name,
queue=queue_name,
)
-logger.debug(
+LOGGER.debug(
'Queue %r bound to exchange %r', queue_name, exchange_name)

channel.basic_consume(self.message_received_cb, queue=queue_name)
@@ -65,11 +76,11 @@ def message_received_cb(self, channel, method_frame, header_frame, body):
"""
exchange_name = method_frame.exchange
-logger.debug('Message received from %r: %s', exchange_name, body)
+LOGGER.debug('Message received from %r: %s', exchange_name, body)

# Only accept json messages
if header_frame.content_type != 'application/json':
-logger.warning(
+LOGGER.warning(
'Message discarded. Unexpected content type: %r',
header_frame.content_type,
)
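
Note: since the consumer binds to fanout exchanges and discards anything that is not application/json, a matching publisher could look like this sketch (the 'events' exchange is hypothetical; recent pika releases spell the declare argument exchange_type, very old ones used type):

    import json

    import pika

    parameters = pika.ConnectionParameters('localhost')
    connection = pika.BlockingConnection(parameters)
    channel = connection.channel()
    channel.exchange_declare(exchange='events', exchange_type='fanout')
    channel.basic_publish(
        exchange='events',
        routing_key='',  # ignored by fanout exchanges
        body=json.dumps({'value': 1}),
        properties=pika.BasicProperties(content_type='application/json'),
    )
    connection.close()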
20 changes: 13 additions & 7 deletions rabbithole/db.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-

"""Database: run queries with batches of rows per exchange."""

import logging

from sqlalchemy import (
@@ -8,7 +10,7 @@
)
from sqlalchemy.exc import SQLAlchemyError

-logger = logging.getLogger(__name__)
+LOGGER = logging.getLogger(__name__)


class Database(object):
@@ -23,17 +25,21 @@ class Database(object):

def __init__(self, url, queries):
"""Connect to database."""
-engine = create_engine(url)
-
-self.connection = engine.connect()
-logger.debug('Connected to: %r', url)
+self.engine = create_engine(url)
+self.connection = None

self.queries = {
exchange_name: text(query)
for exchange_name, query
in queries.items()
}

+def connect(self):
+    """Connect to the database."""
+    self.connection = self.engine.connect()
+    LOGGER.debug('Connected to: %r', self.engine.url)
+    return self

def insert(self, exchange_name, rows):
"""Insert rows in database.
@@ -48,6 +54,6 @@ def insert(self, exchange_name, rows):
try:
self.connection.execute(query, rows)
except SQLAlchemyError as exception:
-logger.error(exception)
+LOGGER.error(exception)
else:
-logger.debug('Inserted %d rows', len(rows))
+LOGGER.debug('Inserted %d rows', len(rows))
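
Note: with connect() split out of __init__, it must be called before insert(), as the updated cli.py now does. A sketch of the query/rows contract against in-memory SQLite (the events table and query are hypothetical):

    from rabbithole.db import Database

    database = Database(
        'sqlite://',
        {'events': 'INSERT INTO events (payload) VALUES (:payload)'},
    ).connect()
    database.connection.execute('CREATE TABLE events (payload TEXT)')
    # A list of dicts makes SQLAlchemy execute the query once per row.
    database.insert('events', [{'payload': 'first'}, {'payload': 'second'}])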
23 changes: 12 additions & 11 deletions tests/test_cli.py
@@ -21,29 +21,30 @@ class TestParseArguments(TestCase):

"""Argument parsing test cases."""

-def test_systemexit_raised_on_config_file_does_not_exist(self):
+def test_config_file_does_not_exist(self):
"""SystemExit is raised if the configuration file does not exist."""
# Do not include error output in test output
with self.assertRaises(SystemExit), patch('rabbithole.cli.sys.stderr'):
parse_arguments(['file-does-not-exist'])

-def test_systemexit_raised_on_config_file_invalid(self):
+def test_config_file_invalid(self):
"""SystemExit is raised if the configuration file is invalid."""
with self.assertRaises(SystemExit), \
patch('rabbithole.cli.sys.stderr'), \
-patch('rabbithole.cli.os') as os, \
-patch('rabbithole.cli.open') as open:
-os.path.isfile.return_value = True
-open().__enter__.return_value = StringIO('>invalid yaml<')
+patch('rabbithole.cli.os') as os_, \
+patch('rabbithole.cli.open') as open_:
+os_.path.isfile.return_value = True
+open_().__enter__.return_value = StringIO('>invalid yaml<')
parse_arguments(['some file'])

def test_config_file_load_success(self):
"""Config file successfully loaded."""
expected_value = {'a': 'value'}
-with patch('rabbithole.cli.os') as os, \
-patch('rabbithole.cli.open') as open:
-os.path.isfile.return_value = True
-open().__enter__.return_value = StringIO(yaml.dump(expected_value))
+with patch('rabbithole.cli.os') as os_, \
+patch('rabbithole.cli.open') as open_:
+os_.path.isfile.return_value = True
+open_().__enter__.return_value = (
+StringIO(yaml.dump(expected_value)))
args = parse_arguments(['some file'])

self.assertDictEqual(args.config, expected_value)
@@ -63,7 +64,7 @@ def test_root_level_set_to_debug(self):
root_logger = logging.getLogger()
self.assertEqual(root_logger.level, logging.DEBUG)

-def test_streamh_handler_level_set_to_argument(self):
+def test_stream_handler_level(self):
"""Stream handler level set to argument value."""
expected_value = logging.ERROR
configure_logging(expected_value)
