Skip to content

Commit

Permalink
[Fix] Don't dispose streams in DefaultRestClient (#2652)
Browse files Browse the repository at this point in the history
* Duplicate file streams before sending

* Other code needs to dispose their objects

* Another resource to dispose

* Stop disposing and copying streams in SendAsync

* Fix inverted boolean check

Co-authored-by: Dmitry <dimson-n@users.noreply.github.com>

* Await results for using statement to work

---------

Co-authored-by: Dmitry <dimson-n@users.noreply.github.com>
  • Loading branch information
ben-reilly and dimson-n committed Apr 14, 2023
1 parent 69cce5b commit 84431de
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 56 deletions.
6 changes: 3 additions & 3 deletions src/Discord.Net.Rest/Entities/Guilds/RestGuild.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1068,11 +1068,11 @@ public async Task<CustomSticker> CreateStickerAsync(string name, Image image, IE
/// <returns>
/// A task that represents the asynchronous creation operation. The task result contains the created sticker.
/// </returns>
public Task<CustomSticker> CreateStickerAsync(string name, string path, IEnumerable<string> tags, string description = null,
public async Task<CustomSticker> CreateStickerAsync(string name, string path, IEnumerable<string> tags, string description = null,
RequestOptions options = null)
{
var fs = File.OpenRead(path);
return CreateStickerAsync(name, fs, Path.GetFileName(fs.Name), tags, description,options);
using var fs = File.OpenRead(path);
return await CreateStickerAsync(name, fs, Path.GetFileName(fs.Name), tags, description,options);
}
/// <summary>
/// Creates a new sticker in this guild
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,14 +228,15 @@ public override string Defer(bool ephemeral = false, RequestOptions options = nu
fileName ??= Path.GetFileName(filePath);
Preconditions.NotNullOrEmpty(fileName, nameof(fileName), "File Name must not be empty or null");

using var fileStream = !string.IsNullOrEmpty(filePath) ? new MemoryStream(File.ReadAllBytes(filePath), false) : null;
var args = new API.Rest.CreateWebhookMessageParams
{
Content = text,
AllowedMentions = allowedMentions?.ToModel() ?? Optional<API.AllowedMentions>.Unspecified,
IsTTS = isTTS,
Embeds = embeds.Select(x => x.ToModel()).ToArray(),
Components = component?.Components.Select(x => new API.ActionRowComponent(x)).ToArray() ?? Optional<API.ActionRowComponent[]>.Unspecified,
File = !string.IsNullOrEmpty(filePath) ? new MultipartFile(new MemoryStream(File.ReadAllBytes(filePath), false), fileName) : Optional<MultipartFile>.Unspecified
File = fileStream != null ? new MultipartFile(fileStream, fileName) : Optional<MultipartFile>.Unspecified
};

if (ephemeral)
Expand Down
104 changes: 55 additions & 49 deletions src/Discord.Net.Rest/Net/DefaultRestClient.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using Discord.Net.Converters;
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Globalization;
using System.IO;
using System.Linq;
Expand Down Expand Up @@ -101,62 +101,68 @@ public void SetCancelToken(CancellationToken cancelToken)
IEnumerable<KeyValuePair<string, IEnumerable<string>>> requestHeaders = null)
{
string uri = Path.Combine(_baseUrl, endpoint);
using (var restRequest = new HttpRequestMessage(GetMethod(method), uri))

// HttpRequestMessage implements IDisposable but we do not need to dispose it as it merely disposes of its Content property,
// which we can do as needed. And regarding that, we do not want to take responsibility for disposing of content provided by
// the caller of this function, since it's possible that the caller wants to reuse it or is forced to reuse it because of a
// 429 response. Therefore, by convention, we only dispose the content objects created in this function (if any).
//
// See this comment explaining why this is safe: https://github.com/aspnet/Security/issues/886#issuecomment-229181249
// See also the source for HttpRequestMessage: https://github.com/microsoft/referencesource/blob/master/System/net/System/Net/Http/HttpRequestMessage.cs
#pragma warning disable IDISP004
var restRequest = new HttpRequestMessage(GetMethod(method), uri);
#pragma warning restore IDISP004

if (reason != null)
restRequest.Headers.Add("X-Audit-Log-Reason", Uri.EscapeDataString(reason));
if (requestHeaders != null)
foreach (var header in requestHeaders)
restRequest.Headers.Add(header.Key, header.Value);
var content = new MultipartFormDataContent("Upload----" + DateTime.Now.ToString(CultureInfo.InvariantCulture));

static StreamContent GetStreamContent(Stream stream)
{
if (reason != null)
restRequest.Headers.Add("X-Audit-Log-Reason", Uri.EscapeDataString(reason));
if (requestHeaders != null)
foreach (var header in requestHeaders)
restRequest.Headers.Add(header.Key, header.Value);
var content = new MultipartFormDataContent("Upload----" + DateTime.Now.ToString(CultureInfo.InvariantCulture));
MemoryStream memoryStream = null;
if (multipartParams != null)
if (stream.CanSeek)
{
// Reset back to the beginning; it may have been used elsewhere or in a previous request.
stream.Position = 0;
}

#pragma warning disable IDISP004
return new StreamContent(stream);
#pragma warning restore IDISP004
}

foreach (var p in multipartParams ?? ImmutableDictionary<string, object>.Empty)
{
switch (p.Value)
{
foreach (var p in multipartParams)
{
switch (p.Value)
{
#pragma warning disable IDISP004
case string stringValue:
{ content.Add(new StringContent(stringValue, Encoding.UTF8, "text/plain"), p.Key); continue; }
case byte[] byteArrayValue:
{ content.Add(new ByteArrayContent(byteArrayValue), p.Key); continue; }
case Stream streamValue:
{ content.Add(new StreamContent(streamValue), p.Key); continue; }
case MultipartFile fileValue:
{
var stream = fileValue.Stream;
if (!stream.CanSeek)
{
memoryStream = new MemoryStream();
await stream.CopyToAsync(memoryStream).ConfigureAwait(false);
memoryStream.Position = 0;
#pragma warning disable IDISP001
stream = memoryStream;
#pragma warning restore IDISP001
}

var streamContent = new StreamContent(stream);
var extension = fileValue.Filename.Split('.').Last();

if (fileValue.ContentType != null)
streamContent.Headers.ContentType = new MediaTypeHeaderValue(fileValue.ContentType);

content.Add(streamContent, p.Key, fileValue.Filename);
case string stringValue:
{ content.Add(new StringContent(stringValue, Encoding.UTF8, "text/plain"), p.Key); continue; }
case byte[] byteArrayValue:
{ content.Add(new ByteArrayContent(byteArrayValue), p.Key); continue; }
case Stream streamValue:
{ content.Add(GetStreamContent(streamValue), p.Key); continue; }
case MultipartFile fileValue:
{
var streamContent = GetStreamContent(fileValue.Stream);

if (fileValue.ContentType != null)
streamContent.Headers.ContentType = new MediaTypeHeaderValue(fileValue.ContentType);

content.Add(streamContent, p.Key, fileValue.Filename);
#pragma warning restore IDISP004

continue;
}
default:
throw new InvalidOperationException($"Unsupported param type \"{p.Value.GetType().Name}\".");
continue;
}
}
default:
throw new InvalidOperationException($"Unsupported param type \"{p.Value.GetType().Name}\".");
}
restRequest.Content = content;
var result = await SendInternalAsync(restRequest, cancelToken, headerOnly).ConfigureAwait(false);
memoryStream?.Dispose();
return result;
}

restRequest.Content = content;
return await SendInternalAsync(restRequest, cancelToken, headerOnly).ConfigureAwait(false);
}

private async Task<RestResponse> SendInternalAsync(HttpRequestMessage request, CancellationToken cancelToken, bool headerOnly)
Expand Down
6 changes: 3 additions & 3 deletions src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1558,11 +1558,11 @@ public SocketCustomSticker GetSticker(ulong id)
/// <returns>
/// A task that represents the asynchronous creation operation. The task result contains the created sticker.
/// </returns>
public Task<SocketCustomSticker> CreateStickerAsync(string name, string path, IEnumerable<string> tags, string description = null,
public async Task<SocketCustomSticker> CreateStickerAsync(string name, string path, IEnumerable<string> tags, string description = null,
RequestOptions options = null)
{
var fs = File.OpenRead(path);
return CreateStickerAsync(name, fs, Path.GetFileName(fs.Name), tags, description, options);
using var fs = File.OpenRead(path);
return await CreateStickerAsync(name, fs, Path.GetFileName(fs.Name), tags, description, options);
}
/// <summary>
/// Creates a new sticker in this guild
Expand Down

0 comments on commit 84431de

Please sign in to comment.