Skip to content

Commit

Permalink
Rewrite collectors (serenity-rs#2278)
Browse files Browse the repository at this point in the history
  • Loading branch information
kangalio authored and mkrasnitski committed Oct 24, 2023
1 parent 96e2981 commit e5eece2
Show file tree
Hide file tree
Showing 30 changed files with 303 additions and 1,212 deletions.
3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ futures = { version = "0.3", default-features = false, features = ["std"] }
dep_time = { version = "0.3.6", package = "time", features = ["formatting", "parsing", "serde-well-known"] }
# Optional dependencies
fxhash = { version = "0.2.1", optional = true }
derivative = { version = "2.2.0", optional = true }
simd-json = { version = "0.6", optional = true }
uwl = { version = "0.6.0", optional = true }
base64 = { version = "0.21", optional = true }
Expand Down Expand Up @@ -82,7 +81,7 @@ builder = ["base64"]
cache = ["fxhash", "dashmap", "parking_lot"]
# Enables collectors, a utility feature that lets you await interaction events in code with
# zero setup, without needing to setup an InteractionCreate event listener.
collector = ["gateway", "model", "derivative"]
collector = ["gateway", "model"]
# Wraps the gateway and http functionality into a single interface
# TODO: should this require "gateway"?
client = ["http", "typemap_rev"]
Expand Down
42 changes: 17 additions & 25 deletions examples/e10_collectors/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::env;
use std::time::Duration;

use serenity::async_trait;
use serenity::collector::{EventCollectorBuilder, MessageCollectorBuilder};
use serenity::collector::MessageCollector;
use serenity::framework::standard::macros::{command, group, help};
use serenity::framework::standard::{
help_commands,
Expand Down Expand Up @@ -115,11 +115,7 @@ async fn challenge(ctx: &Context, msg: &Message, _: Args) -> CommandResult {
.author_id(msg.author.id);

if let Some(reaction) = collector.collect_single().await {
// By default, the collector will collect only added reactions.
// We could also pattern-match the reaction in case we want
// to handle added or removed reactions.
// In this case we will just get the inner reaction.
let _ = if reaction.as_inner_ref().emoji.as_data() == "1️⃣" {
let _ = if reaction.emoji.as_data() == "1️⃣" {
score += 1;
msg.reply(ctx, "That's correct!").await
} else {
Expand All @@ -132,14 +128,14 @@ async fn challenge(ctx: &Context, msg: &Message, _: Args) -> CommandResult {
let _ = msg.reply(ctx, "Write 5 messages in 10 seconds").await;

// We can create a collector from scratch too using this builder future.
let collector = MessageCollectorBuilder::new(&ctx.shard)
let collector = MessageCollector::new(&ctx.shard)
// Only collect messages by this user.
.author_id(msg.author.id)
.channel_id(msg.channel_id)
.collect_limit(5u32)
.timeout(Duration::from_secs(10))
// Build the collector.
.build();
// Build the collector.
.collect_stream()
.take(5);

// Let's acquire borrow HTTP to send a message inside the `async move`.
let http = &ctx.http;
Expand All @@ -164,26 +160,22 @@ async fn challenge(ctx: &Context, msg: &Message, _: Args) -> CommandResult {
score += 1;
}

// We can also collect arbitrary events using the generic EventCollector. For example, here we
// We can also collect arbitrary events using the collect() function. For example, here we
// collect updates to the messages that the user sent above and check for them updating all 5 of
// them.
let builder = EventCollectorBuilder::new(&ctx.shard)
.add_event_type(EventType::MessageUpdate)
.timeout(Duration::from_secs(20));

// Only collect MessageUpdate events for the 5 MessageIds we're interested in.
let mut collector =
collected.iter().try_fold(builder, |b, msg| b.add_message_id(msg.id))?.build();
let mut collector = serenity::collector::collect(&ctx.shard, move |event| match event {
// Only collect MessageUpdate events for the 5 MessageIds we're interested in.
Event::MessageUpdate(event) if collected.iter().any(|msg| event.id == msg.id) => {
Some(event.id)
},
_ => None,
})
.take_until(Box::pin(tokio::time::sleep(Duration::from_secs(20))));

let _ = msg.reply(ctx, "Edit each of those 5 messages in 20 seconds").await;
let mut edited = HashSet::new();
while let Some(event) = collector.next().await {
match event.as_ref() {
Event::MessageUpdate(e) => {
edited.insert(e.id);
},
e => panic!("Unexpected event type received: {:?}", e.event_type()),
}
while let Some(edited_message_id) = collector.next().await {
edited.insert(edited_message_id);
if edited.len() >= 5 {
break;
}
Expand Down
2 changes: 1 addition & 1 deletion examples/e17_message_components/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ impl EventHandler for Handler {
let mut interaction_stream = m
.component_interaction_collector(&ctx.shard)
.timeout(Duration::from_secs(60 * 3))
.build();
.collect_stream();

while let Some(interaction) = interaction_stream.next().await {
let sound = &interaction.data.custom_id;
Expand Down
22 changes: 22 additions & 0 deletions examples/testing/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,28 @@ async fn message(ctx: &Context, msg: Message) -> Result<(), serenity::Error> {
})),
)
.await?;
} else if msg.content == "manybuttons" {
let mut custom_id = msg.id.to_string();
loop {
let msg = channel_id
.send_message(
ctx,
CreateMessage::new()
.button(CreateButton::new(custom_id.clone()).label(custom_id)),
)
.await?;
let button_press = msg
.component_interaction_collector(&ctx.shard)
.timeout(std::time::Duration::from_secs(10))
.collect_single()
.await;
match button_press {
Some(x) => x.defer(ctx).await?,
None => break,
}

custom_id = msg.id.to_string();
}
} else if msg.content == "reactionremoveemoji" {
// Test new ReactionRemoveEmoji gateway event: https://github.com/serenity-rs/serenity/issues/2248
msg.react(ctx, '👍').await?;
Expand Down
6 changes: 3 additions & 3 deletions src/builder/quick_modal.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use super::{CreateActionRow, CreateInputText, CreateInteractionResponse, CreateModal};
use crate::client::Context;
use crate::collector::ModalInteractionCollectorBuilder;
use crate::collector::ModalInteractionCollector;
use crate::model::id::InteractionId;
use crate::model::prelude::component::{ActionRowComponent, InputTextStyle};
use crate::model::prelude::ModalInteraction;

#[cfg(feature = "collector")]
pub struct QuickModalResponse {
pub interaction: std::sync::Arc<ModalInteraction>,
pub interaction: ModalInteraction,
pub inputs: Vec<String>,
}

Expand Down Expand Up @@ -99,7 +99,7 @@ impl CreateQuickModal {
);
builder.execute(ctx, interaction_id, token).await?;

let modal_interaction = ModalInteractionCollectorBuilder::new(&ctx.shard)
let modal_interaction = ModalInteractionCollector::new(&ctx.shard)
.custom_ids(vec![modal_custom_id])
.collect_single()
.await;
Expand Down
13 changes: 12 additions & 1 deletion src/client/bridge/gateway/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,10 @@ pub use self::shard_queuer::ShardQueuer;
pub use self::shard_runner::{ShardRunner, ShardRunnerOptions};
pub use self::shard_runner_message::{ChunkGuildFilter, ShardRunnerMessage};
use crate::gateway::ConnectionStage;
use crate::model::event::Event;

/// A message either for a [`ShardManager`] or a [`ShardRunner`].
#[derive(Clone, Debug)]
#[derive(Debug)]
pub enum ShardClientMessage {
/// A message intended to be worked with by a [`ShardManager`].
Manager(ShardManagerMessage),
Expand Down Expand Up @@ -155,3 +156,13 @@ impl AsRef<ShardMessenger> for ShardRunnerInfo {
&self.runner_tx
}
}

/// Newtype around a callback that will be called on every incoming request. As long as this
/// collector should still receive events, it should return `true`. Once it returns `false`, it is
/// removed.
pub struct CollectorCallback(pub Box<dyn Fn(&Event) -> bool + Send + Sync>);
impl std::fmt::Debug for CollectorCallback {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("CollectorCallback").finish()
}
}
42 changes: 5 additions & 37 deletions src/client/bridge/gateway/shard_messenger.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
use futures::channel::mpsc::{TrySendError, UnboundedSender as Sender};
use tokio_tungstenite::tungstenite::Message;

use super::{ChunkGuildFilter, ShardClientMessage, ShardRunnerMessage};
#[cfg(feature = "collector")]
use crate::collector::{
ComponentInteractionFilter,
EventFilter,
MessageFilter,
ModalInteractionFilter,
ReactionFilter,
};
use super::CollectorCallback;
use super::{ChunkGuildFilter, ShardClientMessage, ShardRunnerMessage};
use crate::gateway::{ActivityData, InterMessage};
use crate::model::prelude::*;

Expand Down Expand Up @@ -268,39 +262,13 @@ impl ShardMessenger {
/// Returns a [`TrySendError`] if the shard's receiver was closed.
#[inline]
pub fn send_to_shard(&self, msg: ShardRunnerMessage) -> Result<(), TrySendError<InterMessage>> {
// TODO: don't propagate send error but handle here directly via a tracing::warn
self.tx.unbounded_send(InterMessage::Client(ShardClientMessage::Runner(Box::new(msg))))
}

/// Sets a new filter for an event collector.
#[inline]
#[cfg(feature = "collector")]
pub fn set_event_filter(&self, collector: EventFilter) {
drop(self.send_to_shard(ShardRunnerMessage::SetEventFilter(collector)));
}

/// Sets a new filter for a message collector.
#[inline]
#[cfg(feature = "collector")]
pub fn set_message_filter(&self, collector: MessageFilter) {
drop(self.send_to_shard(ShardRunnerMessage::SetMessageFilter(collector)));
}

/// Sets a new filter for a reaction collector.
#[cfg(feature = "collector")]
pub fn set_reaction_filter(&self, collector: ReactionFilter) {
drop(self.send_to_shard(ShardRunnerMessage::SetReactionFilter(collector)));
}

/// Sets a new filter for a component interaction collector.
#[cfg(feature = "collector")]
pub fn set_component_interaction_filter(&self, collector: ComponentInteractionFilter) {
drop(self.send_to_shard(ShardRunnerMessage::SetComponentInteractionFilter(collector)));
}

/// Sets a new filter for a modal interaction collector.
#[cfg(feature = "collector")]
pub fn set_modal_interaction_filter(&self, collector: ModalInteractionFilter) {
drop(self.send_to_shard(ShardRunnerMessage::SetModalInteractionFilter(collector)));
pub fn add_collector(&self, collector: CollectorCallback) {
drop(self.send_to_shard(ShardRunnerMessage::AddCollector(collector)));
}
}

Expand Down
Loading

0 comments on commit e5eece2

Please sign in to comment.