Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 61 additions & 8 deletions tools/Custom/HttpMessageFormatter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ internal class HttpMessageFormatter : HttpContent
private const string DefaultRequestMsgType = "request";
private const string DefaultResponseMsgType = "response";

private const string DefaultRequestMediaType = DefaultMediaType + "; " + MsgTypeParameter + "=" + DefaultRequestMsgType;
private const string DefaultResponseMediaType = DefaultMediaType + "; " + MsgTypeParameter + "=" + DefaultResponseMsgType;

// Set of header fields that only support single values such as Set-Cookie.
private static readonly HashSet<string> _singleValueHeaderFields = new HashSet<string>(StringComparer.OrdinalIgnoreCase)
{
Expand Down Expand Up @@ -101,7 +104,7 @@ private HttpContent Content

private void InitializeStreamTask()
{
_streamTask = new Lazy<Task<Stream>>(() => Content?.ReadAsStreamAsync());
_streamTask = new Lazy<Task<Stream>>(() => Content == null ? null : Content.ReadAsStreamAsync());
}

/// <summary>
Expand Down Expand Up @@ -198,14 +201,11 @@ protected override async Task SerializeToStreamAsync(Stream stream, TransportCon
byte[] header = SerializeHeader();
await stream.WriteAsync(header, 0, header.Length);

if (Content != null)
if (Content != null && Content.Headers.ContentLength > 0)
{
Stream readStream = await _streamTask.Value;
ValidateStreamForReading(readStream);
if (!_contentConsumed)
{
await Content.CopyToAsync(stream);
}
await Content.CopyToAsync(stream);
}
}

Expand All @@ -230,6 +230,19 @@ protected override bool TryComputeLength(out long length)
length = 0;

// Cases #1, #2, #3
if (hasContent)
{
Stream readStream;
if (!_streamTask.Value.TryGetResult(out readStream) // Case #1
|| readStream == null || !readStream.CanSeek) // Case #2
{
length = -1;
return false;
}

length = readStream.Length; // Case #3
}

// We serialize header to a StringBuilder so that we can determine the length
// following the pattern for HttpContent to try and determine the message length.
// The perf overhead is no larger than for the other HttpContent implementations.
Expand All @@ -238,6 +251,30 @@ protected override bool TryComputeLength(out long length)
return true;
}

/// <summary>
/// Releases unmanaged and - optionally - managed resources
/// </summary>
/// <param name="disposing"><c>true</c> to release both managed and unmanaged resources; <c>false</c> to release only unmanaged resources.</param>
protected override void Dispose(bool disposing)
{
if (disposing)
{
if (HttpRequestMessage != null)
{
HttpRequestMessage.Dispose();
HttpRequestMessage = null;
}

if (HttpResponseMessage != null)
{
HttpResponseMessage.Dispose();
HttpResponseMessage = null;
}
}

base.Dispose(disposing);
}

/// <summary>
/// Serializes the HTTP request line.
/// </summary>
Expand Down Expand Up @@ -310,8 +347,8 @@ private static void SerializeHeaderFields(StringBuilder message, HttpHeaders hea
private byte[] SerializeHeader()
{
StringBuilder message = new StringBuilder(DefaultHeaderAllocation);
HttpHeaders headers;
HttpContent content;
HttpHeaders headers = null;
HttpContent content = null;
if (HttpRequestMessage != null)
{
SerializeRequestLine(message, HttpRequestMessage);
Expand Down Expand Up @@ -354,5 +391,21 @@ private void ValidateStreamForReading(Stream stream)

_contentConsumed = true;
}

}

public static class TaskExtensions
{
public static bool TryGetResult<TResult>(this Task<TResult> task, out TResult result)
{
if (task.Status == TaskStatus.RanToCompletion)
{
result = task.Result;
return true;
}

result = default(TResult);
return false;
}
}
}
50 changes: 22 additions & 28 deletions tools/Custom/Module.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ partial void CustomInit()
{
this.EventListener = EventHandler;
}

/// <summary>
/// Common Module Event Listener, allows to handle emitted by CmdLets
/// </summary>
Expand All @@ -52,24 +52,21 @@ partial void CustomInit()
/// <param name="parameterSetName">The cmdlet's parameterset name</param>
/// <param name="exception">the exception that is being thrown (if available)</param>
/// <returns>
/// A <see cref="global::System.Threading.Tasks.Task" /> that will be complete when handling of the event is completed.
/// A <see cref="Task" /> that will be complete when handling of the event is completed.
/// </returns>
public async Task EventHandler(string id, CancellationToken cancellationToken, Func<EventArgs> getEventData, Func<string, CancellationToken, Func<EventArgs>, Task> signal, InvocationInfo invocationInfo, string parameterSetName, System.Exception exception)
{
if (invocationInfo.BoundParameters.ContainsKey("Debug"))
switch (id)
{
switch (id)
{
case Events.BeforeCall:
await BeforeCall(id, cancellationToken, getEventData, signal);
break;
case Events.Finally:
await Finally(id, cancellationToken, getEventData, signal);
break;
default:
getEventData.Print(signal, cancellationToken, Events.Information, id);
break;
}
case Events.BeforeCall:
await BeforeCall(id, cancellationToken, getEventData, signal);
break;
case Events.Finally:
await Finally(id, cancellationToken, getEventData, signal);
break;
default:
getEventData.Print(signal, cancellationToken, Events.Information, id);
break;
}
}

Expand All @@ -81,18 +78,16 @@ public async Task EventHandler(string id, CancellationToken cancellationToken, F
/// <param name="getEventData">A delegate to get the detailed event data</param>
/// <param name="signal">The callback for the event dispatcher</param>
/// <returns>
/// A <see cref="global::System.Threading.Tasks.Task" /> that will be complete when handling of the event is completed.
/// A <see cref="Task" /> that will be complete when handling of the event is completed.
/// </returns>
private async Task Finally(string id, CancellationToken cancellationToken, Func<EventArgs> getEventData, Func<string, CancellationToken, Func<EventArgs>, Task> signal)
{
using (Extensions.NoSynchronizationContext)
{
var eventData = EventDataConverter.ConvertFrom(getEventData());
using (var responseFormatter = new HttpMessageFormatter(eventData.ResponseMessage as HttpResponseMessage))
{
var responseString = await responseFormatter.ReadAsStringAsync();
await signal(Events.Debug, cancellationToken, () => EventFactory.CreateLogEvent(responseString));
}
var responseFormatter = new HttpMessageFormatter(eventData.ResponseMessage as HttpResponseMessage);
var responseString = await responseFormatter.ReadAsStringAsync();
await signal(Events.Debug, cancellationToken, () => EventFactory.CreateLogEvent(responseString));
}
}

Expand All @@ -104,19 +99,18 @@ private async Task Finally(string id, CancellationToken cancellationToken, Func<
/// <param name="getEventData">A delegate to get the detailed event data</param>
/// <param name="signal">The callback for the event dispatcher</param>
/// <returns>
/// A <see cref="global::System.Threading.Tasks.Task" /> that will be complete when handling of the event is completed.
/// A <see cref="Task" /> that will be complete when handling of the event is completed.
/// </returns>
private async Task BeforeCall(string id, CancellationToken cancellationToken, Func<EventArgs> getEventData, Func<string, CancellationToken, Func<EventArgs>, Task> signal)
{
using (Extensions.NoSynchronizationContext)
{
var eventData = EventDataConverter.ConvertFrom(getEventData());
using (var requestFormatter = new HttpMessageFormatter(eventData.RequestMessage as HttpRequestMessage))
{
var requestString = await requestFormatter.ReadAsStringAsync();
await signal(Events.Debug, cancellationToken, () => EventFactory.CreateLogEvent(requestString));
}
var requestFormatter = new HttpMessageFormatter(eventData.RequestMessage as HttpRequestMessage);
var requestString = await requestFormatter.ReadAsStringAsync();
await signal(Events.Debug, cancellationToken, () => EventFactory.CreateLogEvent(requestString));
}

}
}
}
}
23 changes: 23 additions & 0 deletions tools/Tests/DebugTests.ps1
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
BeforeAll {
$ModulePrefix = "Microsoft.Graph"
$AuthModuleName = "Authentication"
$AuthModulePath = Join-Path $PSScriptRoot "..\..\src\$AuthModuleName\$AuthModuleName\artifacts\$ModulePrefix.$AuthModuleName.psd1"
$TestModuleName = "DirectoryObjects"
$TestModulePath = Join-Path $PSScriptRoot "..\..\src\$TestModuleName\$TestModuleName\$ModulePrefix.$TestModuleName.psd1"
Import-Module $AuthModulePath -Force
Import-Module $TestModulePath -Force

Connect-MgGraph
Select-MgProfile beta
}
Describe 'Cmdlets Streams' {
It 'Should Not Throw Exception when Debug Preference is Set'{
$ps = [powershell]::Create()
$ps.AddScript(@'
$DebugPreference = 'Continue'
Test-MgDirectoryObjectProperty -DisplayName "New Name" -EntityType "Group"
'@).Invoke()
$ps.Streams.Debug | Should -notLike -BeLike "*Exception*"
$ps.Streams.Debug -like "*HTTP/1.1 200 OK*"
}
}