Skip to content

Commit

Permalink
Add asynchronous PubSub
Browse files Browse the repository at this point in the history
  • Loading branch information
evdokimovs committed Feb 14, 2020
1 parent 246152d commit cf93995
Show file tree
Hide file tree
Showing 3 changed files with 250 additions and 43 deletions.
217 changes: 214 additions & 3 deletions src/aio.rs
Expand Up @@ -18,6 +18,7 @@ use tokio_util::codec::Decoder;

#[cfg(unix)]
use futures_util::future::Either;
use futures_util::TryStreamExt;
use futures_util::{
future::{Future, FutureExt, TryFutureExt},
ready,
Expand All @@ -26,13 +27,13 @@ use futures_util::{
};

use pin_project_lite::pin_project;
use tokio_util::codec::FramedRead;

use crate::cmd::{cmd, Cmd};
use crate::types::{ErrorKind, RedisError, RedisFuture, RedisResult, Value};

use crate::connection::{ConnectionAddr, ConnectionInfo};

use crate::parser::ValueCodec;
use crate::types::{ErrorKind, RedisError, RedisFuture, RedisResult, Value};
use crate::{from_redis_value, ToRedisArgs};

enum ActualConnection {
Tcp(Buffered<TcpStream>),
Expand Down Expand Up @@ -104,16 +105,217 @@ impl AsyncBufRead for ActualConnection {
}
}

/// Represents a pubsub connection.
pub struct PubSub<'a> {
con: &'a mut Connection,
closed: bool,
}

impl<'a> PubSub<'a> {
fn new(con: &'a mut Connection) -> Self {
Self { con, closed: false }
}

/// Subscribes to a new channel.
pub async fn subscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
Ok(cmd("SUBSCRIBE").arg(channel).query_async(self.con).await?)
}

/// Subscribes to a new channel with a pattern.
pub async fn psubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
Ok(cmd("PSUBSCRIBE")
.arg(pchannel)
.query_async(self.con)
.await?)
}

/// Unsubscribes from a channel.
pub async fn unsubscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
Ok(cmd("UNSUBSCRIBE")
.arg(channel)
.query_async(self.con)
.await?)
}

/// Unsubscribes from a channel with a pattern.
pub async fn punsubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
Ok(cmd("PUNSUBSCRIBE")
.arg(pchannel)
.query_async(self.con)
.await?)
}

/// Returns [`Stream`] into which will be sent all [`Msg`]s to which this [`PubSub`] subscribed.
pub fn get_messages<'b>(&'b mut self) -> impl Stream<Item = Msg> + 'b {
FramedRead::new(&mut self.con.con, ValueCodec::default())
.into_stream()
.filter_map(|msg| {
Box::pin(async move {
let msg = msg.ok()?;
let raw_msg: Vec<Value> = from_redis_value(&msg).ok()?;
let mut iter = raw_msg.into_iter();
let msg_type: String =
from_redis_value(&unwrap_or!(iter.next(), return None)).ok()?;
let mut pattern = None;
let payload;
let channel;

if msg_type == "message" {
channel = iter.next()?;
payload = iter.next()?;
} else if msg_type == "pmessage" {
pattern = Some(iter.next()?);
channel = iter.next()?;
payload = iter.next()?;
} else {
return None;
}

return Some(Msg {
payload,
channel,
pattern,
});
})
})
}

/// Closes this [`PubSub`].
pub async fn close(mut self) {
let res = self.con.clear_active_subscriptions().await;
if res.is_ok() {
self.closed = true;
} else {
// Raise the pubsub flag to indicate the connection is "stuck" in that state.
self.closed = false;
}
}
}

impl<'a> Drop for PubSub<'a> {
fn drop(&mut self) {
if !self.closed {
self.con.pubsub = true;
}
}
}

/// Represents a pubsub message.
#[derive(Debug)]
pub struct Msg {
payload: Value,
channel: Value,
pattern: Option<Value>,
}

/// Represents a stateful redis TCP connection.
pub struct Connection {
con: ActualConnection,
db: i64,
pubsub: bool,
}

impl Connection {
/// Fetches a single response from the connection. This is useful
/// if used in combination with `send_packed_command`.
pub async fn recv_response(&mut self) -> RedisResult<Value> {
self.con.read_response().await
}

async fn exit_pubsub(&mut self) -> RedisResult<()> {
let res = self.clear_active_subscriptions().await;
if res.is_ok() {
self.pubsub = false;
} else {
// Raise the pubsub flag to indicate the connection is "stuck" in that state.
self.pubsub = true;
}

res
}

/// Get the inner connection out of a PubSub
///
/// Any active subscriptions are unsubscribed. In the event of an error, the connection is
/// dropped.
async fn clear_active_subscriptions(&mut self) -> RedisResult<()> {
// Responses to unsubscribe commands return in a 3-tuple with values
// ("unsubscribe" or "punsubscribe", name of subscription removed, count of remaining subs).
// The "count of remaining subs" includes both pattern subscriptions and non pattern
// subscriptions. Thus, to accurately drain all unsubscribe messages received from the
// server, both commands need to be executed at once.
{
// Prepare both unsubscribe commands
let unsubscribe = cmd("UNSUBSCRIBE").get_packed_command();
let punsubscribe = cmd("PUNSUBSCRIBE").get_packed_command();

// Grab a reference to the underlying connection so that we may send
// the commands without immediately blocking for a response.
let con = &mut self.con;

// Execute commands
con.send_bytes(&unsubscribe).await?;
con.send_bytes(&punsubscribe).await?;
}

// Receive responses
//
// There will be at minimum two responses - 1 for each of punsubscribe and unsubscribe
// commands. There may be more responses if there are active subscriptions. In this case,
// messages are received until the _subscription count_ in the responses reach zero.
let mut received_unsub = false;
let mut received_punsub = false;
loop {
let res: (Vec<u8>, (), isize) = from_redis_value(&self.recv_response().await?)?;

match res.0.first() {
Some(&b'u') => received_unsub = true,
Some(&b'p') => received_punsub = true,
_ => (),
}

if received_unsub && received_punsub && res.2 == 0 {
break;
}
}

// Finally, the connection is back in its normal state since all subscriptions were
// cancelled *and* all unsubscribe messages were received.
Ok(())
}

/// Creates a [`PubSub`] instance for this connection.
pub fn as_pubsub(&mut self) -> PubSub {
self.pubsub = true;
PubSub::new(self)
}
}

impl ActualConnection {
/// Fetches a single response from the connection.
pub async fn read_response(&mut self) -> RedisResult<Value> {
crate::parser::parse_redis_value_async(self).await
}

async fn send_bytes(&mut self, bytes: &[u8]) -> RedisResult<Value> {
match *self {
ActualConnection::Tcp(ref mut connection) => {
let res = connection.write_all(bytes).await.map_err(RedisError::from);
match res {
Err(e) => Err(e),
Ok(_) => Ok(Value::Okay),
}
}
#[cfg(unix)]
ActualConnection::Unix(ref mut connection) => {
let result = connection.write_all(bytes).await.map_err(RedisError::from);
match result {
Err(e) => Err(e),
Ok(_) => Ok(Value::Okay),
}
}
}
}
}

/// Opens a connection.
Expand Down Expand Up @@ -158,6 +360,7 @@ pub async fn connect(connection_info: &ConnectionInfo) -> RedisResult<Connection
let mut rv = Connection {
con,
db: connection_info.db,
pubsub: false,
};

if let Some(passwd) = &connection_info.passwd {
Expand Down Expand Up @@ -219,6 +422,10 @@ pub trait ConnectionLike: Sized {
impl ConnectionLike for Connection {
fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
(async move {
if self.pubsub {
self.exit_pubsub().await?;
}

cmd.write_command_async(Pin::new(&mut self.con)).await?;
self.con.flush().await?;
self.con.read_response().await
Expand All @@ -233,6 +440,10 @@ impl ConnectionLike for Connection {
count: usize,
) -> RedisFuture<'a, Vec<Value>> {
(async move {
if self.pubsub {
self.exit_pubsub().await?;
}

cmd.write_pipeline_async(Pin::new(&mut self.con)).await?;
self.con.flush().await?;

Expand Down
2 changes: 1 addition & 1 deletion src/connection.rs
Expand Up @@ -511,7 +511,7 @@ impl Connection {
self.con.set_read_timeout(dur)
}

/// Creats a pubsub instance.for this connection.
/// Creates a [`PubSub`] instance for this connection.
pub fn as_pubsub(&mut self) -> PubSub<'_> {
// NOTE: The pubsub flag is intentionally not raised at this time since
// running commands within the pubsub state should not try and exit from
Expand Down
74 changes: 35 additions & 39 deletions tests/test_async.rs
Expand Up @@ -15,24 +15,22 @@ fn test_args() {
let ctx = TestContext::new();
let connect = ctx.async_connection();

block_on_all(connect.and_then(|mut con| {
async move {
redis::cmd("SET")
.arg("key1")
.arg(b"foo")
.query_async(&mut con)
.await?;
redis::cmd("SET")
.arg(&["key2", "bar"])
.query_async(&mut con)
.await?;
let result = redis::cmd("MGET")
.arg(&["key1", "key2"])
.query_async(&mut con)
.await;
assert_eq!(result, Ok(("foo".to_string(), b"bar".to_vec())));
result
}
block_on_all(connect.and_then(|mut con| async move {
redis::cmd("SET")
.arg("key1")
.arg(b"foo")
.query_async(&mut con)
.await?;
redis::cmd("SET")
.arg(&["key2", "bar"])
.query_async(&mut con)
.await?;
let result = redis::cmd("MGET")
.arg(&["key1", "key2"])
.query_async(&mut con)
.await;
assert_eq!(result, Ok(("foo".to_string(), b"bar".to_vec())));
result
}))
.unwrap();
}
Expand All @@ -45,27 +43,25 @@ fn dont_panic_on_closed_multiplexed_connection() {

block_on_all(async move {
connect
.and_then(|con| {
async move {
let cmd = move || {
let mut con = con.clone();
async move {
redis::cmd("SET")
.arg("key1")
.arg(b"foo")
.query_async(&mut con)
.await
}
};
let result: RedisResult<()> = cmd().await;
assert_eq!(
result.as_ref().unwrap_err().kind(),
redis::ErrorKind::IoError,
"{}",
result.as_ref().unwrap_err()
);
cmd().await
}
.and_then(|con| async move {
let cmd = move || {
let mut con = con.clone();
async move {
redis::cmd("SET")
.arg("key1")
.arg(b"foo")
.query_async(&mut con)
.await
}
};
let result: RedisResult<()> = cmd().await;
assert_eq!(
result.as_ref().unwrap_err().kind(),
redis::ErrorKind::IoError,
"{}",
result.as_ref().unwrap_err()
);
cmd().await
})
.map(|result| {
assert_eq!(
Expand Down

0 comments on commit cf93995

Please sign in to comment.