Skip to content

Commit

Permalink
Renamed existing Flatten method to FlattenAsync and added new Flatten…
Browse files Browse the repository at this point in the history
… method. Also fixed ClientHelper using incorrect guild batch count. (#744)
  • Loading branch information
ObsidianMinor authored and foxbot committed Jan 7, 2018
1 parent edfbd05 commit 5bbd9bb
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 11 deletions.
14 changes: 8 additions & 6 deletions src/Discord.Net.Commands/Readers/UserTypeReader.cs
Expand Up @@ -13,7 +13,7 @@ internal class UserTypeReader<T> : TypeReader
public override async Task<TypeReaderResult> ReadAsync(ICommandContext context, string input, IServiceProvider services)
{
var results = new Dictionary<ulong, TypeReaderValue>();
IReadOnlyCollection<IUser> channelUsers = (await context.Channel.GetUsersAsync(CacheMode.CacheOnly).Flatten().ConfigureAwait(false)).ToArray(); //TODO: must be a better way?
IAsyncEnumerable<IUser> channelUsers = context.Channel.GetUsersAsync(CacheMode.CacheOnly).Flatten(); // it's better
IReadOnlyCollection<IGuildUser> guildUsers = ImmutableArray.Create<IGuildUser>();
ulong id;

Expand Down Expand Up @@ -45,7 +45,7 @@ public override async Task<TypeReaderResult> ReadAsync(ICommandContext context,
string username = input.Substring(0, index);
if (ushort.TryParse(input.Substring(index + 1), out ushort discriminator))
{
var channelUser = channelUsers.FirstOrDefault(x => x.DiscriminatorValue == discriminator &&
var channelUser = await channelUsers.FirstOrDefault(x => x.DiscriminatorValue == discriminator &&
string.Equals(username, x.Username, StringComparison.OrdinalIgnoreCase));
AddResult(results, channelUser as T, channelUser?.Username == username ? 0.85f : 0.75f);

Expand All @@ -57,17 +57,19 @@ public override async Task<TypeReaderResult> ReadAsync(ICommandContext context,

//By Username (0.5-0.6)
{
foreach (var channelUser in channelUsers.Where(x => string.Equals(input, x.Username, StringComparison.OrdinalIgnoreCase)))
AddResult(results, channelUser as T, channelUser.Username == input ? 0.65f : 0.55f);
await channelUsers
.Where(x => string.Equals(input, x.Username, StringComparison.OrdinalIgnoreCase))
.ForEachAsync(channelUser => AddResult(results, channelUser as T, channelUser.Username == input ? 0.65f : 0.55f));

foreach (var guildUser in guildUsers.Where(x => string.Equals(input, x.Username, StringComparison.OrdinalIgnoreCase)))
AddResult(results, guildUser as T, guildUser.Username == input ? 0.60f : 0.50f);
}

//By Nickname (0.5-0.6)
{
foreach (var channelUser in channelUsers.Where(x => string.Equals(input, (x as IGuildUser)?.Nickname, StringComparison.OrdinalIgnoreCase)))
AddResult(results, channelUser as T, (channelUser as IGuildUser).Nickname == input ? 0.65f : 0.55f);
await channelUsers
.Where(x => string.Equals(input, (x as IGuildUser)?.Nickname, StringComparison.OrdinalIgnoreCase))
.ForEachAsync(channelUser => AddResult(results, channelUser as T, (channelUser as IGuildUser).Nickname == input ? 0.65f : 0.55f));

foreach (var guildUser in guildUsers.Where(x => string.Equals(input, (x as IGuildUser).Nickname, StringComparison.OrdinalIgnoreCase)))
AddResult(results, guildUser as T, (guildUser as IGuildUser).Nickname == input ? 0.60f : 0.50f);
Expand Down
54 changes: 52 additions & 2 deletions src/Discord.Net.Core/Extensions/AsyncEnumerableExtensions.cs
@@ -1,14 +1,64 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

namespace Discord
{
public static class AsyncEnumerableExtensions
{
public static async Task<IEnumerable<T>> Flatten<T>(this IAsyncEnumerable<IReadOnlyCollection<T>> source)
/// <summary>
/// Flattens the specified pages into one <see cref="IEnumerable{T}"/> asynchronously
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="source"></param>
/// <returns></returns>
public static async Task<IEnumerable<T>> FlattenAsync<T>(this IAsyncEnumerable<IEnumerable<T>> source)
{
return (await source.ToArray().ConfigureAwait(false)).SelectMany(x => x);
return await source.Flatten().ToArray().ConfigureAwait(false);
}

public static IAsyncEnumerable<T> Flatten<T>(this IAsyncEnumerable<IEnumerable<T>> source)
{
return new PagedCollectionEnumerator<T>(source);
}

internal class PagedCollectionEnumerator<T> : IAsyncEnumerator<T>, IAsyncEnumerable<T>
{
readonly IAsyncEnumerator<IEnumerable<T>> _source;
IEnumerator<T> _enumerator;

public IAsyncEnumerator<T> GetEnumerator() => this;

internal PagedCollectionEnumerator(IAsyncEnumerable<IEnumerable<T>> source)
{
_source = source.GetEnumerator();
}

public T Current => _enumerator.Current;

public void Dispose()
{
_enumerator?.Dispose();
_source.Dispose();
}

public async Task<bool> MoveNext(CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();

if(!_enumerator?.MoveNext() ?? true)
{
if (!await _source.MoveNext(cancellationToken).ConfigureAwait(false))
return false;

_enumerator?.Dispose();
_enumerator = _source.Current.GetEnumerator();
return _enumerator.MoveNext();
}

return true;
}
}
}
}
4 changes: 2 additions & 2 deletions src/Discord.Net.Rest/ClientHelper.cs
Expand Up @@ -79,7 +79,7 @@ public static async Task<IReadOnlyCollection<RestConnection>> GetConnectionsAsyn
ulong? fromGuildId, int? limit, RequestOptions options)
{
return new PagedAsyncEnumerable<RestUserGuild>(
DiscordConfig.MaxUsersPerBatch,
DiscordConfig.MaxGuildsPerBatch,
async (info, ct) =>
{
var args = new GetGuildSummariesParams
Expand All @@ -106,7 +106,7 @@ public static async Task<IReadOnlyCollection<RestConnection>> GetConnectionsAsyn
}
public static async Task<IReadOnlyCollection<RestGuild>> GetGuildsAsync(BaseDiscordClient client, RequestOptions options)
{
var summaryModels = await GetGuildSummariesAsync(client, null, null, options).Flatten();
var summaryModels = await GetGuildSummariesAsync(client, null, null, options).FlattenAsync().ConfigureAwait(false);
var guilds = ImmutableArray.CreateBuilder<RestGuild>();
foreach (var summaryModel in summaryModels)
{
Expand Down
2 changes: 1 addition & 1 deletion src/Discord.Net.Rest/Entities/Guilds/RestGuild.cs
Expand Up @@ -413,7 +413,7 @@ async Task<IGuildUser> IGuild.GetOwnerAsync(CacheMode mode, RequestOptions optio
async Task<IReadOnlyCollection<IGuildUser>> IGuild.GetUsersAsync(CacheMode mode, RequestOptions options)
{
if (mode == CacheMode.AllowDownload)
return (await GetUsersAsync(options).Flatten().ConfigureAwait(false)).ToImmutableArray();
return (await GetUsersAsync(options).FlattenAsync().ConfigureAwait(false)).ToImmutableArray();
else
return ImmutableArray.Create<IGuildUser>();
}
Expand Down

0 comments on commit 5bbd9bb

Please sign in to comment.