From 81b093de379135ba5f8ab6d4bcb08f983b056074 Mon Sep 17 00:00:00 2001 From: Peter Ombwa Date: Mon, 20 Jun 2022 14:15:30 -0700 Subject: [PATCH 1/2] Add custom directive to download files when response is of type octetStreamSchemaResponse. --- src/readme.graph.md | 5 +++- tools/Custom/PSCmdletExtensions.cs | 48 ++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/src/readme.graph.md b/src/readme.graph.md index 86d2434365d..4ba85e14b70 100644 --- a/src/readme.graph.md +++ b/src/readme.graph.md @@ -617,10 +617,13 @@ directive: } else { let outFileParameterRegex = /(^\s*)public\s*global::System\.String\s*OutFile\s*/gmi let streamResponseRegex = /global::System\.Threading\.Tasks\.Task\s*response/gmi + let octetStreamSchemaResponseRegex = /global::System\.Threading\.Tasks\.Task<.*OctetStreamSchema>\s*response/gmi + let overrideOnOkCallRegex = /(^\s*)(overrideOnOk\(\s*responseMessage\s*,\s*response\s*,\s*ref\s*_returnNow\s*\);)/gmi if($.match(outFileParameterRegex) && $.match(streamResponseRegex)) { // Handle file download. - let overrideOnOkCallRegex = /(^\s*)(overrideOnOk\(\s*responseMessage\s*,\s*response\s*,\s*ref\s*_returnNow\s*\);)/gmi $ = $.replace(overrideOnOkCallRegex, '$1$2\n$1using(var stream = await response){ this.WriteToFile(responseMessage, stream, this.GetProviderPath(OutFile, false), _cancellationTokenSource.Token); _returnNow = global::System.Threading.Tasks.Task.FromResult(true);}\n$1'); + } else if ($.match(outFileParameterRegex) && $.match(octetStreamSchemaResponseRegex)){ + $ = $.replace(overrideOnOkCallRegex, '$1$2\n$1using(var stream = await responseMessage.Content.ReadAsStreamAsync()){ this.WriteToFile(responseMessage, stream, this.GetProviderPath(OutFile, false), _cancellationTokenSource.Token); _returnNow = global::System.Threading.Tasks.Task.FromResult(true);}\n$1'); } return $; } diff --git a/tools/Custom/PSCmdletExtensions.cs b/tools/Custom/PSCmdletExtensions.cs index 225f1410fce..437ec01f6d0 100644 --- a/tools/Custom/PSCmdletExtensions.cs +++ b/tools/Custom/PSCmdletExtensions.cs @@ -7,6 +7,7 @@ namespace Microsoft.Graph.PowerShell using System; using System.Collections.ObjectModel; using System.IO; + using System.Linq; using System.Management.Automation; using System.Net.Http; using System.Threading; @@ -65,6 +66,12 @@ internal static string GetProviderPath(this PSCmdlet cmdlet, string filePath, bo /// A cancellation token that will be used to cancel the operation by the user. internal static void WriteToFile(this PSCmdlet cmdlet, HttpResponseMessage response, Stream inputStream, string filePath, CancellationToken cancellationToken) { + if (IsPathDirectory(filePath)) + { + // Get file name from content disposition header is presents; otherwise throw an exception for a file name to be provided. + var fileName = GetFileName(response); + filePath = Path.Combine(filePath, fileName); + } using (var fileProvider = ProtectedFileProvider.CreateFileProvider(filePath, FileProtection.ExclusiveWrite, new DiskDataStore())) { string downloadUrl = response?.RequestMessage?.RequestUri.ToString(); @@ -105,6 +112,47 @@ private static void WriteToStream(this PSCmdlet cmdlet, Stream inputStream, Stre } } + private static bool IsPathDirectory(string path) + { + if (path == null) throw new ArgumentNullException("path"); + path = path.Trim(); + + if (Directory.Exists(path)) + return true; + + if (File.Exists(path)) + return false; + + // If path has a trailing slash then it's a directory. + if (new[] { Path.DirectorySeparatorChar, Path.AltDirectorySeparatorChar }.Any(x => path.EndsWith(x.ToString()))) + return true; + + // If path has an extension then its a file; directory otherwise. + return string.IsNullOrWhiteSpace(Path.GetExtension(path)); + } + + private static string GetFileName(HttpResponseMessage responseMessage) + { + if (responseMessage.Content.Headers.ContentDisposition != null + && !string.IsNullOrWhiteSpace(responseMessage.Content.Headers.ContentDisposition.FileName)) + { + var fileName = responseMessage.Content.Headers.ContentDisposition.FileNameStar ?? responseMessage.Content.Headers.ContentDisposition.FileName; + if (!string.IsNullOrWhiteSpace(fileName)) + return SanitizeFileName(fileName); + } + throw new ArgumentException("Count not infer file name from the response. Please specify the file name in -OutFile explicitly."); + } + + /// + /// When Inferring file names from Content disposition, ensure that only valid path characters are in the file name + /// + /// + private static string SanitizeFileName(string fileName) + { + var illegalCharacters = Path.GetInvalidFileNameChars().Concat(Path.GetInvalidPathChars()).ToArray(); + return string.Concat(fileName.Split(illegalCharacters)); + } + /// /// Calculates and updates the progress record of the provided stream. /// From 36fd8e218fbbce8ce145cf760ccbaab47ea8e6f1 Mon Sep 17 00:00:00 2001 From: Peter Ombwa Date: Mon, 20 Jun 2022 16:16:54 -0700 Subject: [PATCH 2/2] Warn customer if file name already exists. --- src/Authentication/Authentication/ErrorConstants.cs | 3 ++- src/readme.graph.md | 2 +- tools/Custom/PSCmdletExtensions.cs | 12 +++++++++--- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/Authentication/Authentication/ErrorConstants.cs b/src/Authentication/Authentication/ErrorConstants.cs index 4f76a504412..4d24c344bf6 100644 --- a/src/Authentication/Authentication/ErrorConstants.cs +++ b/src/Authentication/Authentication/ErrorConstants.cs @@ -29,13 +29,14 @@ internal static class Codes public const string InvokeGraphRequestCouldNotInferFileName = nameof(InvokeGraphRequestKeysWithDifferentCasingInJsonString); } - internal static class Message + public static class Message { internal const string CannotModifyBuiltInEnvironment = "Cannot {0} built-in environment {1}."; internal const string InvalidUrlParameter = "Parameter '{0}' has an invalid endpoint URL. Please use a valid URL with a network protocol i.e. [protocol]://[resource-name]."; internal const string InvalidEnvironment = "Unable to find environment with name '{0}'. Use Get-MgEnvironment to list available environments."; internal const string CannotAccessFile = "Could not {0} file at '{1}'. Please ensure you have access to this file and try again in a few minutes.."; internal const string InvalidCertificateThumbprint = "'{0}' must have a length of 40. Ensure you have the right certificate thumbprint then try again."; + public const string CannotInferFileName = "Could not infer file name from the response. Please specify the file name in -OutFile explicitly."; } } } diff --git a/src/readme.graph.md b/src/readme.graph.md index 4ba85e14b70..8d7b62ec83f 100644 --- a/src/readme.graph.md +++ b/src/readme.graph.md @@ -617,7 +617,7 @@ directive: } else { let outFileParameterRegex = /(^\s*)public\s*global::System\.String\s*OutFile\s*/gmi let streamResponseRegex = /global::System\.Threading\.Tasks\.Task\s*response/gmi - let octetStreamSchemaResponseRegex = /global::System\.Threading\.Tasks\.Task<.*OctetStreamSchema>\s*response/gmi + let octetStreamSchemaResponseRegex = /global::System\.Threading\.Tasks\.Task<.*(OctetStreamSchema|GraphReport)>\s*response/gmi let overrideOnOkCallRegex = /(^\s*)(overrideOnOk\(\s*responseMessage\s*,\s*response\s*,\s*ref\s*_returnNow\s*\);)/gmi if($.match(outFileParameterRegex) && $.match(streamResponseRegex)) { // Handle file download. diff --git a/tools/Custom/PSCmdletExtensions.cs b/tools/Custom/PSCmdletExtensions.cs index 437ec01f6d0..39376066a7b 100644 --- a/tools/Custom/PSCmdletExtensions.cs +++ b/tools/Custom/PSCmdletExtensions.cs @@ -3,6 +3,7 @@ // ------------------------------------------------------------------------------ namespace Microsoft.Graph.PowerShell { + using Microsoft.Graph.PowerShell.Authentication; using Microsoft.Graph.PowerShell.Authentication.Common; using System; using System.Collections.ObjectModel; @@ -68,10 +69,15 @@ internal static void WriteToFile(this PSCmdlet cmdlet, HttpResponseMessage respo { if (IsPathDirectory(filePath)) { - // Get file name from content disposition header is presents; otherwise throw an exception for a file name to be provided. + // Get file name from content disposition header if present; otherwise throw an exception for a file name to be provided. var fileName = GetFileName(response); filePath = Path.Combine(filePath, fileName); } + if (File.Exists(filePath)) + { + cmdlet.WriteWarning($"{filePath} already exists. The file will be overridden."); + File.Delete(filePath); + } using (var fileProvider = ProtectedFileProvider.CreateFileProvider(filePath, FileProtection.ExclusiveWrite, new DiskDataStore())) { string downloadUrl = response?.RequestMessage?.RequestUri.ToString(); @@ -140,11 +146,11 @@ private static string GetFileName(HttpResponseMessage responseMessage) if (!string.IsNullOrWhiteSpace(fileName)) return SanitizeFileName(fileName); } - throw new ArgumentException("Count not infer file name from the response. Please specify the file name in -OutFile explicitly."); + throw new ArgumentException(ErrorConstants.Message.CannotInferFileName, "-OutFile"); } /// - /// When Inferring file names from Content disposition, ensure that only valid path characters are in the file name + /// When Inferring file names from content disposition header, ensure that only valid path characters are in the file name /// /// private static string SanitizeFileName(string fileName)