Skip to content

Commit

Permalink
Change to consumers taking a single "message" argument
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgodwin committed Sep 8, 2015
1 parent 9b92eec commit 48d6f63
Show file tree
Hide file tree
Showing 11 changed files with 185 additions and 157 deletions.
29 changes: 12 additions & 17 deletions channels/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from django.core.handlers.base import BaseHandler
from django.http import HttpRequest, HttpResponse

from channels import Channel, channel_backends, DEFAULT_CHANNEL_BACKEND
from channels import Channel


class UrlConsumer(object):
Expand All @@ -15,13 +15,13 @@ def __init__(self):
self.handler = BaseHandler()
self.handler.load_middleware()

def __call__(self, channel, **kwargs):
request = HttpRequest.channel_decode(kwargs)
def __call__(self, message):
request = HttpRequest.channel_decode(message.content)
try:
response = self.handler.get_response(request)
except HttpResponse.ResponseLater:
return
Channel(request.response_channel).send(**response.channel_encode())
message.reply_channel.send(response.channel_encode())


def view_producer(channel_name):
Expand All @@ -30,24 +30,19 @@ def view_producer(channel_name):
and abandons the response (with an exception the Worker will catch)
"""
def producing_view(request):
Channel(channel_name).send(**request.channel_encode())
Channel(channel_name).send(request.channel_encode())
raise HttpResponse.ResponseLater()
return producing_view


def view_consumer(channel_name, alias=DEFAULT_CHANNEL_BACKEND):
def view_consumer(func):
"""
Decorates a normal Django view to be a channel consumer.
Does not run any middleware
"""
def inner(func):
@functools.wraps(func)
def consumer(channel, **kwargs):
request = HttpRequest.channel_decode(kwargs)
response = func(request)
Channel(request.response_channel).send(**response.channel_encode())
# Get the channel layer and register
channel_backend = channel_backends[DEFAULT_CHANNEL_BACKEND]
channel_backend.registry.add_consumer(consumer, [channel_name])
return func
return inner
@functools.wraps(func)
def consumer(message):
request = HttpRequest.channel_decode(message.content)
response = func(request)
message.reply_channel.send(response.channel_encode())
return func
21 changes: 16 additions & 5 deletions channels/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@ def __init__(self, name, alias=DEFAULT_CHANNEL_BACKEND, channel_backend=None):
else:
self.channel_backend = channel_backends[alias]

def send(self, **kwargs):
def send(self, content):
"""
Send a message over the channel, taken from the kwargs.
Send a message over the channel - messages are always dicts.
"""
self.channel_backend.send(self.name, kwargs)
if not isinstance(content, dict):
raise ValueError("You can only send dicts as content on channels.")
self.channel_backend.send(self.name, content)

@classmethod
def new_name(self, prefix):
Expand All @@ -51,6 +53,9 @@ def as_view(self):
from channels.adapters import view_producer
return view_producer(self.name)

def __str__(self):
return self.name


class Group(object):
"""
Expand All @@ -66,13 +71,19 @@ def __init__(self, name, alias=DEFAULT_CHANNEL_BACKEND, channel_backend=None):
self.channel_backend = channel_backends[alias]

def add(self, channel):
if isinstance(channel, Channel):
channel = channel.name
self.channel_backend.group_add(self.name, channel)

def discard(self, channel):
if isinstance(channel, Channel):
channel = channel.name
self.channel_backend.group_discard(self.name, channel)

def channels(self):
self.channel_backend.group_channels(self.name)

def send(self, **kwargs):
self.channel_backend.send_group(self.name, kwargs)
def send(self, content):
if not isinstance(content, dict):
raise ValueError("You can only send dicts as content on channels.")
self.channel_backend.send_group(self.name, content)
51 changes: 27 additions & 24 deletions channels/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,27 @@ def http_session(func):
be None, rather than an empty session you can write to.
"""
@functools.wraps(func)
def inner(*args, **kwargs):
if "COOKIES" not in kwargs and "GET" not in kwargs:
def inner(message, *args, **kwargs):
if "COOKIES" not in message.content and "GET" not in message.content:
raise ValueError("No COOKIES or GET sent to consumer; this decorator can only be used on messages containing at least one.")
# Make sure there's a session key
session_key = None
if "GET" in kwargs:
if "GET" in message.content:
try:
session_key = kwargs['GET'].get("session_key", [])[0]
session_key = message.content['GET'].get("session_key", [])[0]
except IndexError:
pass
if "COOKIES" in kwargs and session_key is None:
session_key = kwargs['COOKIES'].get(settings.SESSION_COOKIE_NAME)
if "COOKIES" in message.content and session_key is None:
session_key = message.content['COOKIES'].get(settings.SESSION_COOKIE_NAME)
# Make a session storage
if session_key:
session_engine = import_module(settings.SESSION_ENGINE)
session = session_engine.SessionStore(session_key=session_key)
else:
session = None
kwargs['session'] = session
message.session = session
# Run the consumer
result = func(*args, **kwargs)
result = func(message, *args, **kwargs)
# Persist session if needed (won't be saved if error happens)
if session is not None and session.modified:
session.save()
Expand All @@ -65,46 +65,49 @@ def http_django_auth(func):
"""
@http_session
@functools.wraps(func)
def inner(*args, **kwargs):
def inner(message, *args, **kwargs):
# If we didn't get a session, then we don't get a user
if kwargs['session'] is None:
kwargs['user'] = None
if not hasattr(message, "session"):
raise ValueError("Did not see a session to get auth from")
if message.session is None:
message.user = None
# Otherwise, be a bit naughty and make a fake Request with just
# a "session" attribute (later on, perhaps refactor contrib.auth to
# pass around session rather than request)
else:
fake_request = type("FakeRequest", (object, ), {"session": kwargs['session']})
kwargs['user'] = auth.get_user(fake_request)
fake_request = type("FakeRequest", (object, ), {"session": message.session})
message.user = auth.get_user(fake_request)
# Run the consumer
return func(*args, **kwargs)
return func(message, *args, **kwargs)
return inner


def send_channel_session(func):
def channel_session(func):
"""
Provides a session-like object called "channel_session" to consumers
as a message attribute that will auto-persist across consumers with
the same incoming "send_channel" value.
the same incoming "reply_channel" value.
"""
@functools.wraps(func)
def inner(*args, **kwargs):
# Make sure there's a send_channel in kwargs
if "send_channel" not in kwargs:
raise ValueError("No send_channel sent to consumer; this decorator can only be used on messages containing it.")
# Turn the send_channel into a valid session key length thing.
def inner(message, *args, **kwargs):
# Make sure there's a reply_channel in kwargs
if not message.reply_channel:
raise ValueError("No reply_channel sent to consumer; this decorator can only be used on messages containing it.")
# Turn the reply_channel into a valid session key length thing.
# We take the last 24 bytes verbatim, as these are the random section,
# and then hash the remaining ones onto the start, and add a prefix
# TODO: See if there's a better way of doing this
session_key = "skt" + hashlib.md5(kwargs['send_channel'][:-24]).hexdigest()[:8] + kwargs['send_channel'][-24:]
reply_name = message.reply_channel.name
session_key = "skt" + hashlib.md5(reply_name[:-24]).hexdigest()[:8] + reply_name[-24:]
# Make a session storage
session_engine = import_module(settings.SESSION_ENGINE)
session = session_engine.SessionStore(session_key=session_key)
# If the session does not already exist, save to force our session key to be valid
if not session.exists(session.session_key):
session.save()
kwargs['channel_session'] = session
message.channel_session = session
# Run the consumer
result = func(*args, **kwargs)
result = func(message, *args, **kwargs)
# Persist session if needed (won't be saved if error happens)
if session.modified:
session.save()
Expand Down
42 changes: 16 additions & 26 deletions channels/interfaces/websocket_twisted.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,30 +23,26 @@ def onConnect(self, request):

def onOpen(self):
# Make sending channel
self.send_channel = Channel.new_name("!django.websocket.send")
self.reply_channel = Channel.new_name("!django.websocket.send")
self.request_info["reply_channel"] = self.reply_channel
self.last_keepalive = time.time()
self.factory.protocols[self.send_channel] = self
self.factory.protocols[self.reply_channel] = self
# Send news that this channel is open
Channel("django.websocket.connect").send(
send_channel = self.send_channel,
**self.request_info
)
Channel("django.websocket.connect").send(self.request_info)

def onMessage(self, payload, isBinary):
if isBinary:
Channel("django.websocket.receive").send(
send_channel = self.send_channel,
Channel("django.websocket.receive").send(dict(
self.request_info,
content = payload,
binary = True,
**self.request_info
)
))
else:
Channel("django.websocket.receive").send(
send_channel = self.send_channel,
Channel("django.websocket.receive").send(dict(
self.request_info,
content = payload.decode("utf8"),
binary = False,
**self.request_info
)
))

def serverSend(self, content, binary=False, **kwargs):
"""
Expand All @@ -64,21 +60,15 @@ def serverClose(self):
self.sendClose()

def onClose(self, wasClean, code, reason):
if hasattr(self, "send_channel"):
del self.factory.protocols[self.send_channel]
Channel("django.websocket.disconnect").send(
send_channel = self.send_channel,
**self.request_info
)
if hasattr(self, "reply_channel"):
del self.factory.protocols[self.reply_channel]
Channel("django.websocket.disconnect").send(self.request_info)

def sendKeepalive(self):
"""
Sends a keepalive packet on the keepalive channel.
"""
Channel("django.websocket.keepalive").send(
send_channel = self.send_channel,
**self.request_info
)
Channel("django.websocket.keepalive").send(self.request_info)
self.last_keepalive = time.time()


Expand All @@ -94,7 +84,7 @@ def __init__(self, *args, **kwargs):
super(InterfaceFactory, self).__init__(*args, **kwargs)
self.protocols = {}

def send_channels(self):
def reply_channels(self):
return self.protocols.keys()

def dispatch_send(self, channel, message):
Expand Down Expand Up @@ -128,7 +118,7 @@ def backend_reader(self):
Run in a separate thread; reads messages from the backend.
"""
while True:
channels = self.factory.send_channels()
channels = self.factory.reply_channels()
# Quit if reactor is stopping
if not reactor.running:
return
Expand Down
6 changes: 3 additions & 3 deletions channels/interfaces/wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(self, channel_backend, *args, **kwargs):
super(WSGIInterface, self).__init__(*args, **kwargs)

def get_response(self, request):
request.response_channel = Channel.new_name("django.wsgi.response")
Channel("django.wsgi.request", channel_backend=self.channel_backend).send(**request.channel_encode())
channel, message = self.channel_backend.receive_many_blocking([request.response_channel])
request.reply_channel = Channel.new_name("django.wsgi.response")
Channel("django.wsgi.request", channel_backend=self.channel_backend).send(request.channel_encode())
channel, message = self.channel_backend.receive_many_blocking([request.reply_channel])
return HttpResponse.channel_decode(message)
18 changes: 18 additions & 0 deletions channels/message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from .channel import Channel


class Message(object):
"""
Represents a message sent over a Channel.
The message content is a dict called .content, while
reply_channel is an optional extra attribute representing a channel
to use to reply to this message's end user, if that makes sense.
"""

def __init__(self, content, channel, channel_backend, reply_channel=None):
self.content = content
self.channel = channel
self.channel_backend = channel_backend
if reply_channel:
self.reply_channel = Channel(reply_channel, channel_backend=self.channel_backend)
4 changes: 2 additions & 2 deletions channels/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def encode_request(request):
"path": request.path,
"path_info": request.path_info,
"method": request.method,
"response_channel": request.response_channel,
"reply_channel": request.reply_channel,
}
return value

Expand All @@ -34,7 +34,7 @@ def decode_request(value):
request.path = value['path']
request.method = value['method']
request.path_info = value['path_info']
request.response_channel = value['response_channel']
request.reply_channel = value['reply_channel']
return request


Expand Down
11 changes: 9 additions & 2 deletions channels/worker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import traceback
from .message import Message


class Worker(object):
Expand All @@ -17,12 +18,18 @@ def run(self):
"""
channels = self.channel_backend.registry.all_channel_names()
while True:
channel, message = self.channel_backend.receive_many_blocking(channels)
channel, content = self.channel_backend.receive_many_blocking(channels)
message = Message(
content=content,
channel=channel,
channel_backend=self.channel_backend,
reply_channel=content.get("reply_channel", None),
)
# Handle the message
consumer = self.channel_backend.registry.consumer_for_channel(channel)
if self.callback:
self.callback(channel, message)
try:
consumer(channel=channel, **message)
consumer(message)
except:
traceback.print_exc()

0 comments on commit 48d6f63

Please sign in to comment.