diff --git a/src/Authentication/Authentication/ErrorConstants.cs b/src/Authentication/Authentication/ErrorConstants.cs index 4f76a504412..4d24c344bf6 100644 --- a/src/Authentication/Authentication/ErrorConstants.cs +++ b/src/Authentication/Authentication/ErrorConstants.cs @@ -29,13 +29,14 @@ internal static class Codes public const string InvokeGraphRequestCouldNotInferFileName = nameof(InvokeGraphRequestKeysWithDifferentCasingInJsonString); } - internal static class Message + public static class Message { internal const string CannotModifyBuiltInEnvironment = "Cannot {0} built-in environment {1}."; internal const string InvalidUrlParameter = "Parameter '{0}' has an invalid endpoint URL. Please use a valid URL with a network protocol i.e. [protocol]://[resource-name]."; internal const string InvalidEnvironment = "Unable to find environment with name '{0}'. Use Get-MgEnvironment to list available environments."; internal const string CannotAccessFile = "Could not {0} file at '{1}'. Please ensure you have access to this file and try again in a few minutes.."; internal const string InvalidCertificateThumbprint = "'{0}' must have a length of 40. Ensure you have the right certificate thumbprint then try again."; + public const string CannotInferFileName = "Could not infer file name from the response. Please specify the file name in -OutFile explicitly."; } } } diff --git a/src/readme.graph.md b/src/readme.graph.md index 86d2434365d..8d7b62ec83f 100644 --- a/src/readme.graph.md +++ b/src/readme.graph.md @@ -617,10 +617,13 @@ directive: } else { let outFileParameterRegex = /(^\s*)public\s*global::System\.String\s*OutFile\s*/gmi let streamResponseRegex = /global::System\.Threading\.Tasks\.Task\s*response/gmi + let octetStreamSchemaResponseRegex = /global::System\.Threading\.Tasks\.Task<.*(OctetStreamSchema|GraphReport)>\s*response/gmi + let overrideOnOkCallRegex = /(^\s*)(overrideOnOk\(\s*responseMessage\s*,\s*response\s*,\s*ref\s*_returnNow\s*\);)/gmi if($.match(outFileParameterRegex) && $.match(streamResponseRegex)) { // Handle file download. - let overrideOnOkCallRegex = /(^\s*)(overrideOnOk\(\s*responseMessage\s*,\s*response\s*,\s*ref\s*_returnNow\s*\);)/gmi $ = $.replace(overrideOnOkCallRegex, '$1$2\n$1using(var stream = await response){ this.WriteToFile(responseMessage, stream, this.GetProviderPath(OutFile, false), _cancellationTokenSource.Token); _returnNow = global::System.Threading.Tasks.Task.FromResult(true);}\n$1'); + } else if ($.match(outFileParameterRegex) && $.match(octetStreamSchemaResponseRegex)){ + $ = $.replace(overrideOnOkCallRegex, '$1$2\n$1using(var stream = await responseMessage.Content.ReadAsStreamAsync()){ this.WriteToFile(responseMessage, stream, this.GetProviderPath(OutFile, false), _cancellationTokenSource.Token); _returnNow = global::System.Threading.Tasks.Task.FromResult(true);}\n$1'); } return $; } diff --git a/tools/Custom/PSCmdletExtensions.cs b/tools/Custom/PSCmdletExtensions.cs index 225f1410fce..39376066a7b 100644 --- a/tools/Custom/PSCmdletExtensions.cs +++ b/tools/Custom/PSCmdletExtensions.cs @@ -3,10 +3,12 @@ // ------------------------------------------------------------------------------ namespace Microsoft.Graph.PowerShell { + using Microsoft.Graph.PowerShell.Authentication; using Microsoft.Graph.PowerShell.Authentication.Common; using System; using System.Collections.ObjectModel; using System.IO; + using System.Linq; using System.Management.Automation; using System.Net.Http; using System.Threading; @@ -65,6 +67,17 @@ internal static string GetProviderPath(this PSCmdlet cmdlet, string filePath, bo /// A cancellation token that will be used to cancel the operation by the user. internal static void WriteToFile(this PSCmdlet cmdlet, HttpResponseMessage response, Stream inputStream, string filePath, CancellationToken cancellationToken) { + if (IsPathDirectory(filePath)) + { + // Get file name from content disposition header if present; otherwise throw an exception for a file name to be provided. + var fileName = GetFileName(response); + filePath = Path.Combine(filePath, fileName); + } + if (File.Exists(filePath)) + { + cmdlet.WriteWarning($"{filePath} already exists. The file will be overridden."); + File.Delete(filePath); + } using (var fileProvider = ProtectedFileProvider.CreateFileProvider(filePath, FileProtection.ExclusiveWrite, new DiskDataStore())) { string downloadUrl = response?.RequestMessage?.RequestUri.ToString(); @@ -105,6 +118,47 @@ private static void WriteToStream(this PSCmdlet cmdlet, Stream inputStream, Stre } } + private static bool IsPathDirectory(string path) + { + if (path == null) throw new ArgumentNullException("path"); + path = path.Trim(); + + if (Directory.Exists(path)) + return true; + + if (File.Exists(path)) + return false; + + // If path has a trailing slash then it's a directory. + if (new[] { Path.DirectorySeparatorChar, Path.AltDirectorySeparatorChar }.Any(x => path.EndsWith(x.ToString()))) + return true; + + // If path has an extension then its a file; directory otherwise. + return string.IsNullOrWhiteSpace(Path.GetExtension(path)); + } + + private static string GetFileName(HttpResponseMessage responseMessage) + { + if (responseMessage.Content.Headers.ContentDisposition != null + && !string.IsNullOrWhiteSpace(responseMessage.Content.Headers.ContentDisposition.FileName)) + { + var fileName = responseMessage.Content.Headers.ContentDisposition.FileNameStar ?? responseMessage.Content.Headers.ContentDisposition.FileName; + if (!string.IsNullOrWhiteSpace(fileName)) + return SanitizeFileName(fileName); + } + throw new ArgumentException(ErrorConstants.Message.CannotInferFileName, "-OutFile"); + } + + /// + /// When Inferring file names from content disposition header, ensure that only valid path characters are in the file name + /// + /// + private static string SanitizeFileName(string fileName) + { + var illegalCharacters = Path.GetInvalidFileNameChars().Concat(Path.GetInvalidPathChars()).ToArray(); + return string.Concat(fileName.Split(illegalCharacters)); + } + /// /// Calculates and updates the progress record of the provided stream. ///