diff --git a/src/Authentication/Authentication.Test/Helpers/GraphSessionTests.cs b/src/Authentication/Authentication.Test/Helpers/GraphSessionTests.cs new file mode 100644 index 00000000000..6f165d9d568 --- /dev/null +++ b/src/Authentication/Authentication.Test/Helpers/GraphSessionTests.cs @@ -0,0 +1,51 @@ +namespace Microsoft.Graph.Authentication.Test.Helpers +{ + using Microsoft.Graph.PowerShell.Authentication; + using System; + using Xunit; + public class GraphSessionTests + { + [Fact] + public void GraphSessionShouldBeInitilizedAfterInitializerIsCalled() + { + GraphSession.Initialize(() => new GraphSession()); + + Assert.NotNull(GraphSession.Instance); + Assert.Null(GraphSession.Instance.AuthContext); + + // reset static instance. + GraphSession.Reset(); + } + + [Fact] + public void ShouldOverwriteExistingGraphSession() + { + GraphSession.Initialize(() => new GraphSession()); + Guid originalSessionId = GraphSession.Instance._graphSessionId; + + GraphSession.Initialize(() => new GraphSession(), true); + + Assert.NotNull(GraphSession.Instance); + Assert.NotEqual(originalSessionId, GraphSession.Instance._graphSessionId); + + // reset static instance. + GraphSession.Reset(); + } + + [Fact] + public void ShouldNotOverwriteExistingGraphSession() + { + GraphSession.Initialize(() => new GraphSession()); + Guid originalSessionId = GraphSession.Instance._graphSessionId; + + InvalidOperationException exception = Assert.Throws(() => GraphSession.Initialize(() => new GraphSession())); + + Assert.Equal("An instance of GraphSession already exists. Call Initialize(Func, bool) to overwrite it.", exception.Message); + Assert.NotNull(GraphSession.Instance); + Assert.Equal(originalSessionId, GraphSession.Instance._graphSessionId); + + // reset static instance. + GraphSession.Reset(); + } + } +} diff --git a/src/Authentication/Authentication.Test/Properties/AssemblyInfo.cs b/src/Authentication/Authentication.Test/Properties/AssemblyInfo.cs new file mode 100644 index 00000000000..a4bcec543f2 --- /dev/null +++ b/src/Authentication/Authentication.Test/Properties/AssemblyInfo.cs @@ -0,0 +1,3 @@ +using Xunit; + +[assembly: CollectionBehavior(DisableTestParallelization = true)] diff --git a/src/Authentication/Authentication/Cmdlets/ConnectGraph.cs b/src/Authentication/Authentication/Cmdlets/ConnectGraph.cs index 1bf71ca478b..b1156be8063 100644 --- a/src/Authentication/Authentication/Cmdlets/ConnectGraph.cs +++ b/src/Authentication/Authentication/Cmdlets/ConnectGraph.cs @@ -16,7 +16,7 @@ namespace Microsoft.Graph.PowerShell.Authentication.Cmdlets using System.Threading.Tasks; [Cmdlet(VerbsCommunications.Connect, "Graph", DefaultParameterSetName = Constants.UserParameterSet)] - public class ConnectGraph : PSCmdlet + public class ConnectGraph : PSCmdlet, IModuleAssemblyInitializer { [Parameter(ParameterSetName = Constants.UserParameterSet, Position = 1)] @@ -53,7 +53,7 @@ protected override void ProcessRecord() { base.ProcessRecord(); - AuthConfig authConfig = new AuthConfig { TenantId = TenantId }; + IAuthContext authConfig = new AuthContext { TenantId = TenantId }; CancellationToken cancellationToken = CancellationToken.None; if (ParameterSetName == Constants.UserParameterSet) @@ -117,7 +117,7 @@ protected override void ProcessRecord() authConfig.Account = jwtPayload?.Upn ?? account?.Username; // Save auth config to session state. - SessionState.PSVariable.Set(Constants.GraphAuthConfigId, authConfig); + GraphSession.Instance.AuthContext = authConfig; } catch (AuthenticationException authEx) { @@ -164,5 +164,13 @@ private void ThrowParameterError(string parameterName) new ArgumentException($"Must specify {parameterName}"), Guid.NewGuid().ToString(), ErrorCategory.InvalidArgument, null) ); } + + /// + /// Globally initializes GraphSession. + /// + public void OnImport() + { + GraphSessionInitializer.InitializeSession(); + } } } diff --git a/src/Authentication/Authentication/Cmdlets/DisconnectGraph.cs b/src/Authentication/Authentication/Cmdlets/DisconnectGraph.cs index 019333829c6..085c9b5f34e 100644 --- a/src/Authentication/Authentication/Cmdlets/DisconnectGraph.cs +++ b/src/Authentication/Authentication/Cmdlets/DisconnectGraph.cs @@ -4,7 +4,6 @@ namespace Microsoft.Graph.PowerShell.Authentication.Cmdlets { using Microsoft.Graph.PowerShell.Authentication.Helpers; - using Microsoft.Graph.PowerShell.Authentication.Models; using System; using System.Management.Automation; [Cmdlet(VerbsCommunications.Disconnect, "Graph")] @@ -24,7 +23,7 @@ protected override void ProcessRecord() { base.ProcessRecord(); - AuthConfig authConfig = SessionState.PSVariable.GetValue(Constants.GraphAuthConfigId) as AuthConfig; + IAuthContext authConfig = GraphSession.Instance.AuthContext; if (authConfig == null) ThrowTerminatingError( @@ -32,7 +31,7 @@ protected override void ProcessRecord() AuthenticationHelpers.Logout(authConfig); - SessionState.PSVariable.Remove(Constants.GraphAuthConfigId); + GraphSession.Instance.AuthContext = null; } protected override void StopProcessing() diff --git a/src/Authentication/Authentication/Cmdlets/GetMGContext.cs b/src/Authentication/Authentication/Cmdlets/GetMGContext.cs index 4305a2cbfc7..05942a68eda 100644 --- a/src/Authentication/Authentication/Cmdlets/GetMGContext.cs +++ b/src/Authentication/Authentication/Cmdlets/GetMGContext.cs @@ -4,18 +4,10 @@ namespace Microsoft.Graph.PowerShell.Authentication.Cmdlets { - using Microsoft.Graph.Auth; - using Microsoft.Graph.PowerShell.Authentication.Helpers; - using Microsoft.Graph.PowerShell.Authentication.Models; - using System; - using System.Collections.Generic; using System.Management.Automation; - using System.Net.Http; - using System.Threading; - using System.Threading.Tasks; [Cmdlet(VerbsCommon.Get, "MgContext", DefaultParameterSetName = Constants.UserParameterSet)] - [OutputType(typeof(AuthConfig))] + [OutputType(typeof(IAuthContext))] public class GetMGContext: PSCmdlet { protected override void BeginProcessing() @@ -26,11 +18,8 @@ protected override void BeginProcessing() protected override void ProcessRecord() { base.ProcessRecord(); - // Get auth config from session state. - PSVariable graphAuthVariable = SessionState.PSVariable.Get(Constants.GraphAuthConfigId); - AuthConfig authConfig = graphAuthVariable?.Value as AuthConfig; - Invoke(); - WriteObject(authConfig as AuthConfig); + IAuthContext authConfig = GraphSession.Instance.AuthContext; + WriteObject(authConfig as IAuthContext); } protected override void EndProcessing() diff --git a/src/Authentication/Authentication/Common/GraphSession.cs b/src/Authentication/Authentication/Common/GraphSession.cs new file mode 100644 index 00000000000..31fdd654551 --- /dev/null +++ b/src/Authentication/Authentication/Common/GraphSession.cs @@ -0,0 +1,168 @@ +// ------------------------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All Rights Reserved. Licensed under the MIT License. See License in the project root for license information. +// ------------------------------------------------------------------------------ + +namespace Microsoft.Graph.PowerShell.Authentication +{ + using System; + using System.Threading; + /// + /// Contains methods to create, modify or obtain a thread safe static instance of . + /// + public class GraphSession : IGraphSession + { + static GraphSession _instance; + static bool _initialized = false; + static ReaderWriterLockSlim sessionLock = new ReaderWriterLockSlim(LockRecursionPolicy.SupportsRecursion); + internal Guid _graphSessionId; + + /// + /// Gets or Sets . + /// + public IAuthContext AuthContext { get; set; } + + /// + /// Gets an instance of . + /// + public static GraphSession Instance + { + get + { + try + { + sessionLock.EnterReadLock(); + try + { + if (null == _instance) + { + throw new InvalidOperationException(ErrorConstants.Codes.SessionNotInitialized); + } + return _instance; + } + finally + { + sessionLock.ExitReadLock(); + } + } + catch (LockRecursionException lockException) + { + throw new InvalidOperationException(ErrorConstants.Codes.SessionLockReadRecursion, lockException); + } + catch (ObjectDisposedException disposedException) + { + throw new InvalidOperationException(ErrorConstants.Codes.SessionLockReadDisposed, disposedException); + } + } + } + + /// + /// Creates a new GraphSession. + /// + public GraphSession() + { + _graphSessionId = Guid.NewGuid(); + } + + /// + /// Initialize . + /// + /// A func to create an instance. + /// If true, overwrite the current instance. Otherwise do not initialize. + public static void Initialize(Func instanceCreator, bool overwrite) + { + try + { + sessionLock.EnterWriteLock(); + try + { + if (overwrite || !_initialized) + { + _instance = instanceCreator(); + _initialized = true; + } + else + { + throw new InvalidOperationException(string.Format(ErrorConstants.Message.InstanceExists, nameof(GraphSession), "Initialize(Func, bool)")); + } + } + finally + { + sessionLock.ExitWriteLock(); + } + } + catch (LockRecursionException lockException) + { + throw new InvalidOperationException(ErrorConstants.Codes.SessionLockWriteRecursion, lockException); + } + catch (ObjectDisposedException disposedException) + { + throw new InvalidOperationException(ErrorConstants.Codes.SessionLockWriteDisposed, disposedException); + } + } + + /// + /// Initialize the current instance if none exists. + /// + /// A func to create an instance. + public static void Initialize(Func instanceCreator) + { + Initialize(instanceCreator, false); + } + + /// + /// Modify the current instance of . + /// + /// A func to modify the instance. + public static void Modify(Action modifier) + { + try + { + sessionLock.EnterWriteLock(); + try + { + modifier(_instance); + } + finally + { + sessionLock.ExitWriteLock(); + } + } + catch (LockRecursionException lockException) + { + throw new InvalidOperationException(ErrorConstants.Codes.SessionLockWriteRecursion, lockException); + } + catch (ObjectDisposedException disposedException) + { + throw new InvalidOperationException(ErrorConstants.Codes.SessionLockWriteDisposed, disposedException); + } + } + + /// + /// Resets the current instance of to initial state. + /// + internal static void Reset() + { + try + { + sessionLock.EnterWriteLock(); + try + { + _instance = null; + _initialized = false; + } + finally + { + sessionLock.ExitWriteLock(); + } + } + catch (LockRecursionException lockException) + { + throw new InvalidOperationException(ErrorConstants.Codes.SessionLockWriteRecursion, lockException); + } + catch (ObjectDisposedException disposedException) + { + throw new InvalidOperationException(ErrorConstants.Codes.SessionLockWriteDisposed, disposedException); + } + } + } +} diff --git a/src/Authentication/Authentication/Common/GraphSessionInitializer.cs b/src/Authentication/Authentication/Common/GraphSessionInitializer.cs new file mode 100644 index 00000000000..72fa61c51fd --- /dev/null +++ b/src/Authentication/Authentication/Common/GraphSessionInitializer.cs @@ -0,0 +1,27 @@ +// ------------------------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All Rights Reserved. Licensed under the MIT License. See License in the project root for license information. +// ------------------------------------------------------------------------------ + +namespace Microsoft.Graph.PowerShell.Authentication +{ + public static class GraphSessionInitializer + { + /// + /// Initializes . + /// + public static void InitializeSession() + { + GraphSession.Initialize(() => CreateInstance()); + } + + /// + /// Creates a new instance of a . + /// + /// + internal static GraphSession CreateInstance() + { + // This can be used to initialize GraphSession from a file in the future. + return new GraphSession(); + } + } +} diff --git a/src/Authentication/Authentication/ErrorConstants.cs b/src/Authentication/Authentication/ErrorConstants.cs index 747c49a1ab4..ca1b815b4cb 100644 --- a/src/Authentication/Authentication/ErrorConstants.cs +++ b/src/Authentication/Authentication/ErrorConstants.cs @@ -7,12 +7,19 @@ public static class ErrorConstants { internal static class Codes { + internal const string SessionNotInitialized = "sessionNotInitialized"; + internal const string SessionLockReadRecursion = "sessionLockReadRecursion"; + internal const string SessionLockReadDisposed = "sessionLockReadDisposed"; + internal const string SessionLockWriteDisposed = "sessionLockWriteDisposed"; + internal const string SessionLockWriteRecursion = "sessionLockWriteRecursion"; internal const string InvalidJWT = "invalidJWT"; } internal static class Message { internal const string InvalidJWT = "Invalid JWT access token."; + internal const string MissingAuthContext = "Authentication needed, call Connect-Graph."; + internal const string InstanceExists = "An instance of {0} already exists. Call {1} to overwrite it."; } } } diff --git a/src/Authentication/Authentication/Helpers/AuthenticationHelpers.cs b/src/Authentication/Authentication/Helpers/AuthenticationHelpers.cs index 5d9712486f4..0a1ae2b514a 100644 --- a/src/Authentication/Authentication/Helpers/AuthenticationHelpers.cs +++ b/src/Authentication/Authentication/Helpers/AuthenticationHelpers.cs @@ -4,7 +4,6 @@ namespace Microsoft.Graph.PowerShell.Authentication.Helpers { using Microsoft.Graph.Auth; - using Microsoft.Graph.PowerShell.Authentication.Models; using Microsoft.Graph.PowerShell.Authentication.TokenCache; using Microsoft.Identity.Client; using System; @@ -16,7 +15,7 @@ internal static class AuthenticationHelpers { private static readonly object FileLock = new object(); - internal static IAuthenticationProvider GetAuthProvider(AuthConfig authConfig) + internal static IAuthenticationProvider GetAuthProvider(IAuthContext authConfig) { if (authConfig.AuthType == AuthenticationType.Delegated) { @@ -43,7 +42,7 @@ internal static IAuthenticationProvider GetAuthProvider(AuthConfig authConfig) } } - internal static void Logout(AuthConfig authConfig) + internal static void Logout(IAuthContext authConfig) { lock (FileLock) { diff --git a/src/Authentication/Authentication/Helpers/HttpHelpers.cs b/src/Authentication/Authentication/Helpers/HttpHelpers.cs index 60ae07a57ef..55f490fdb5d 100644 --- a/src/Authentication/Authentication/Helpers/HttpHelpers.cs +++ b/src/Authentication/Authentication/Helpers/HttpHelpers.cs @@ -4,11 +4,11 @@ namespace Microsoft.Graph.PowerShell.Authentication.Helpers { using Microsoft.Graph.PowerShell.Authentication.Cmdlets; - using Microsoft.Graph.PowerShell.Authentication.Models; using System.Collections.Generic; using System.Linq; using System.Net.Http; using System.Reflection; + using System.Security.Authentication; /// /// A HTTP helper class. @@ -31,8 +31,12 @@ public static class HttpHelpers /// /// /// - public static HttpClient GetGraphHttpClient(AuthConfig authConfig) + public static HttpClient GetGraphHttpClient(IAuthContext authConfig = null) { + authConfig = authConfig ?? GraphSession.Instance.AuthContext; + if (authConfig is null) + throw new AuthenticationException(ErrorConstants.Message.MissingAuthContext); + IAuthenticationProvider authProvider = AuthenticationHelpers.GetAuthProvider(authConfig); IList defaultHandlers = GraphClientFactory.CreateDefaultHandlers(authProvider); diff --git a/src/Authentication/Authentication/Helpers/JwtHelpers.cs b/src/Authentication/Authentication/Helpers/JwtHelpers.cs index 56a123b732d..0dee2afc615 100644 --- a/src/Authentication/Authentication/Helpers/JwtHelpers.cs +++ b/src/Authentication/Authentication/Helpers/JwtHelpers.cs @@ -29,8 +29,7 @@ internal static string Decode(string jwToken) if (jwtHandler.CanReadToken(jwToken)) { JwtSecurityToken token = jwtHandler.ReadJwtToken(jwToken); - JwtPayload jwtPayload = new JwtPayload(token.Claims); - return jwtPayload.SerializeToJson(); + return token.Payload.SerializeToJson(); } else { return null; } diff --git a/src/Authentication/Authentication/Interfaces/IAuthContext.cs b/src/Authentication/Authentication/Interfaces/IAuthContext.cs new file mode 100644 index 00000000000..fb29bba4b3a --- /dev/null +++ b/src/Authentication/Authentication/Interfaces/IAuthContext.cs @@ -0,0 +1,23 @@ +// ------------------------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All Rights Reserved. Licensed under the MIT License. See License in the project root for license information. +// ------------------------------------------------------------------------------ + +namespace Microsoft.Graph.PowerShell.Authentication +{ + public enum AuthenticationType + { + Delegated, + AppOnly + } + public interface IAuthContext + { + string ClientId { get; set; } + string TenantId { get; set; } + string CertificateThumbprint { get; set; } + string[] Scopes { get; set; } + AuthenticationType AuthType { get; set; } + string CertificateName { get; set; } + string Account { get; set; } + string AppName { get; set; } + } +} diff --git a/src/Authentication/Authentication/Interfaces/IGraphSession.cs b/src/Authentication/Authentication/Interfaces/IGraphSession.cs new file mode 100644 index 00000000000..2623b77e85d --- /dev/null +++ b/src/Authentication/Authentication/Interfaces/IGraphSession.cs @@ -0,0 +1,11 @@ +// ------------------------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All Rights Reserved. Licensed under the MIT License. See License in the project root for license information. +// ------------------------------------------------------------------------------ + +namespace Microsoft.Graph.PowerShell.Authentication +{ + public interface IGraphSession + { + IAuthContext AuthContext { get; set; } + } +} \ No newline at end of file diff --git a/src/Authentication/Authentication/Models/AuthConfig.cs b/src/Authentication/Authentication/Models/AuthContext.cs similarity index 81% rename from src/Authentication/Authentication/Models/AuthConfig.cs rename to src/Authentication/Authentication/Models/AuthContext.cs index 0fb989fbbc0..71741dd2c34 100644 --- a/src/Authentication/Authentication/Models/AuthConfig.cs +++ b/src/Authentication/Authentication/Models/AuthContext.cs @@ -1,15 +1,9 @@ // ------------------------------------------------------------------------------ // Copyright (c) Microsoft Corporation. All Rights Reserved. Licensed under the MIT License. See License in the project root for license information. // ------------------------------------------------------------------------------ -namespace Microsoft.Graph.PowerShell.Authentication.Models +namespace Microsoft.Graph.PowerShell.Authentication { - public enum AuthenticationType - { - Delegated, - AppOnly - } - - public class AuthConfig + public class AuthContext: IAuthContext { private const string PowerShellClientId = "14d82eec-204b-4c2f-b7e8-296a70dab67e"; public string ClientId { get; set; } @@ -21,7 +15,7 @@ public class AuthConfig public string Account { get; set; } public string AppName { get; set; } - public AuthConfig() + public AuthContext() { ClientId = PowerShellClientId; } diff --git a/src/Authentication/Authentication/Properties/AssemblyInfo.cs b/src/Authentication/Authentication/Properties/AssemblyInfo.cs new file mode 100644 index 00000000000..47dc954fb32 --- /dev/null +++ b/src/Authentication/Authentication/Properties/AssemblyInfo.cs @@ -0,0 +1,5 @@ +using System.Runtime.CompilerServices; + +#if DEBUG +[assembly: InternalsVisibleTo("Microsoft.Graph.Authentication.Test")] +#endif diff --git a/src/Authentication/Authentication/Properties/launchSettings.json b/src/Authentication/Authentication/Properties/launchSettings.json index 2d57baad318..4153aac9d29 100644 --- a/src/Authentication/Authentication/Properties/launchSettings.json +++ b/src/Authentication/Authentication/Properties/launchSettings.json @@ -2,7 +2,7 @@ "profiles": { "Graph.Authentication": { "commandName": "Executable", - "executablePath": "C:\\Program Files\\PowerShell\\7-preview\\pwsh.exe", + "executablePath": "C:\\Program Files\\PowerShell\\7\\pwsh.exe", "commandLineArgs": "-NoProfile -NoExit" } } diff --git a/tools/Custom/Module.cs b/tools/Custom/Module.cs index d048630b695..4a0e0f6ebcc 100644 --- a/tools/Custom/Module.cs +++ b/tools/Custom/Module.cs @@ -15,17 +15,7 @@ public partial class Module { partial void BeforeCreatePipeline(System.Management.Automation.InvocationInfo invocationInfo, ref Runtime.HttpPipeline pipeline) { - using (var powershell = PowerShell.Create(RunspaceMode.CurrentRunspace)) - { - powershell.Commands.AddCommand(new Command($"$executioncontext.SessionState.PSVariable.GetValue('{Constants.GraphAuthConfigId}')", true)); - - AuthConfig authConfig = powershell.Invoke().FirstOrDefault(); - - if (authConfig == null) - throw new Exception("Authentication needed, call Connect-Graph."); - - pipeline = new Runtime.HttpPipeline(new Runtime.HttpClientFactory(HttpHelpers.GetGraphHttpClient(authConfig))); - } + pipeline = new Runtime.HttpPipeline(new Runtime.HttpClientFactory(HttpHelpers.GetGraphHttpClient())); } } }