Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
namespace Microsoft.Graph.Authentication.Test.Helpers
{
using Microsoft.Graph.PowerShell.Authentication;
using System;
using Xunit;
public class GraphSessionTests
{
[Fact]
public void GraphSessionShouldBeInitilizedAfterInitializerIsCalled()
{
GraphSession.Initialize(() => new GraphSession());

Assert.NotNull(GraphSession.Instance);
Assert.Null(GraphSession.Instance.AuthContext);

// reset static instance.
GraphSession.Reset();
}

[Fact]
public void ShouldOverwriteExistingGraphSession()
{
GraphSession.Initialize(() => new GraphSession());
Guid originalSessionId = GraphSession.Instance._graphSessionId;

GraphSession.Initialize(() => new GraphSession(), true);

Assert.NotNull(GraphSession.Instance);
Assert.NotEqual(originalSessionId, GraphSession.Instance._graphSessionId);

// reset static instance.
GraphSession.Reset();
}

[Fact]
public void ShouldNotOverwriteExistingGraphSession()
{
GraphSession.Initialize(() => new GraphSession());
Guid originalSessionId = GraphSession.Instance._graphSessionId;

InvalidOperationException exception = Assert.Throws<InvalidOperationException>(() => GraphSession.Initialize(() => new GraphSession()));

Assert.Equal("An instance of GraphSession already exists. Call Initialize(Func<GraphSession>, bool) to overwrite it.", exception.Message);
Assert.NotNull(GraphSession.Instance);
Assert.Equal(originalSessionId, GraphSession.Instance._graphSessionId);

// reset static instance.
GraphSession.Reset();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
using Xunit;

[assembly: CollectionBehavior(DisableTestParallelization = true)]
14 changes: 11 additions & 3 deletions src/Authentication/Authentication/Cmdlets/ConnectGraph.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace Microsoft.Graph.PowerShell.Authentication.Cmdlets
using System.Threading.Tasks;

[Cmdlet(VerbsCommunications.Connect, "Graph", DefaultParameterSetName = Constants.UserParameterSet)]
public class ConnectGraph : PSCmdlet
public class ConnectGraph : PSCmdlet, IModuleAssemblyInitializer
{

[Parameter(ParameterSetName = Constants.UserParameterSet, Position = 1)]
Expand Down Expand Up @@ -53,7 +53,7 @@ protected override void ProcessRecord()
{
base.ProcessRecord();

AuthConfig authConfig = new AuthConfig { TenantId = TenantId };
IAuthContext authConfig = new AuthContext { TenantId = TenantId };
CancellationToken cancellationToken = CancellationToken.None;

if (ParameterSetName == Constants.UserParameterSet)
Expand Down Expand Up @@ -117,7 +117,7 @@ protected override void ProcessRecord()
authConfig.Account = jwtPayload?.Upn ?? account?.Username;

// Save auth config to session state.
SessionState.PSVariable.Set(Constants.GraphAuthConfigId, authConfig);
GraphSession.Instance.AuthContext = authConfig;
}
catch (AuthenticationException authEx)
{
Expand Down Expand Up @@ -164,5 +164,13 @@ private void ThrowParameterError(string parameterName)
new ArgumentException($"Must specify {parameterName}"), Guid.NewGuid().ToString(), ErrorCategory.InvalidArgument, null)
);
}

/// <summary>
/// Globally initializes GraphSession.
/// </summary>
public void OnImport()
{
GraphSessionInitializer.InitializeSession();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
namespace Microsoft.Graph.PowerShell.Authentication.Cmdlets
{
using Microsoft.Graph.PowerShell.Authentication.Helpers;
using Microsoft.Graph.PowerShell.Authentication.Models;
using System;
using System.Management.Automation;
[Cmdlet(VerbsCommunications.Disconnect, "Graph")]
Expand All @@ -24,15 +23,15 @@ protected override void ProcessRecord()
{
base.ProcessRecord();

AuthConfig authConfig = SessionState.PSVariable.GetValue(Constants.GraphAuthConfigId) as AuthConfig;
IAuthContext authConfig = GraphSession.Instance.AuthContext;

if (authConfig == null)
ThrowTerminatingError(
new ErrorRecord(new System.Exception("No application to sign out from."), Guid.NewGuid().ToString(), ErrorCategory.InvalidArgument, null));

AuthenticationHelpers.Logout(authConfig);

SessionState.PSVariable.Remove(Constants.GraphAuthConfigId);
GraphSession.Instance.AuthContext = null;
}

protected override void StopProcessing()
Expand Down
17 changes: 3 additions & 14 deletions src/Authentication/Authentication/Cmdlets/GetMGContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,10 @@

namespace Microsoft.Graph.PowerShell.Authentication.Cmdlets
{
using Microsoft.Graph.Auth;
using Microsoft.Graph.PowerShell.Authentication.Helpers;
using Microsoft.Graph.PowerShell.Authentication.Models;
using System;
using System.Collections.Generic;
using System.Management.Automation;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;

[Cmdlet(VerbsCommon.Get, "MgContext", DefaultParameterSetName = Constants.UserParameterSet)]
[OutputType(typeof(AuthConfig))]
[OutputType(typeof(IAuthContext))]
public class GetMGContext: PSCmdlet
{
protected override void BeginProcessing()
Expand All @@ -26,11 +18,8 @@ protected override void BeginProcessing()
protected override void ProcessRecord()
{
base.ProcessRecord();
// Get auth config from session state.
PSVariable graphAuthVariable = SessionState.PSVariable.Get(Constants.GraphAuthConfigId);
AuthConfig authConfig = graphAuthVariable?.Value as AuthConfig;
Invoke<AuthConfig>();
WriteObject(authConfig as AuthConfig);
IAuthContext authConfig = GraphSession.Instance.AuthContext;
WriteObject(authConfig as IAuthContext);
}

protected override void EndProcessing()
Expand Down
168 changes: 168 additions & 0 deletions src/Authentication/Authentication/Common/GraphSession.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
// ------------------------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All Rights Reserved. Licensed under the MIT License. See License in the project root for license information.
// ------------------------------------------------------------------------------

namespace Microsoft.Graph.PowerShell.Authentication
{
using System;
using System.Threading;
/// <summary>
/// Contains methods to create, modify or obtain a thread safe static instance of <see cref="GraphSession"/>.
/// </summary>
public class GraphSession : IGraphSession
{
static GraphSession _instance;
static bool _initialized = false;
static ReaderWriterLockSlim sessionLock = new ReaderWriterLockSlim(LockRecursionPolicy.SupportsRecursion);
internal Guid _graphSessionId;

/// <summary>
/// Gets or Sets <see cref="IAuthContext"/>.
/// </summary>
public IAuthContext AuthContext { get; set; }

/// <summary>
/// Gets an instance of <see cref="GraphSession"/>.
/// </summary>
public static GraphSession Instance
{
get
{
try
{
sessionLock.EnterReadLock();
try
{
if (null == _instance)
{
throw new InvalidOperationException(ErrorConstants.Codes.SessionNotInitialized);
}
return _instance;
}
finally
{
sessionLock.ExitReadLock();
}
}
catch (LockRecursionException lockException)
{
throw new InvalidOperationException(ErrorConstants.Codes.SessionLockReadRecursion, lockException);
}
catch (ObjectDisposedException disposedException)
{
throw new InvalidOperationException(ErrorConstants.Codes.SessionLockReadDisposed, disposedException);
}
}
}

/// <summary>
/// Creates a new GraphSession.
/// </summary>
public GraphSession()
{
_graphSessionId = Guid.NewGuid();
}

/// <summary>
/// Initialize <see cref="GraphSession"/>.
/// </summary>
/// <param name="instanceCreator">A func to create an instance.</param>
/// <param name="overwrite">If true, overwrite the current instance. Otherwise do not initialize.</param>
public static void Initialize(Func<GraphSession> instanceCreator, bool overwrite)
{
try
{
sessionLock.EnterWriteLock();
try
{
if (overwrite || !_initialized)
{
_instance = instanceCreator();
_initialized = true;
}
else
{
throw new InvalidOperationException(string.Format(ErrorConstants.Message.InstanceExists, nameof(GraphSession), "Initialize(Func<GraphSession>, bool)"));
}
}
finally
{
sessionLock.ExitWriteLock();
}
}
catch (LockRecursionException lockException)
{
throw new InvalidOperationException(ErrorConstants.Codes.SessionLockWriteRecursion, lockException);
}
catch (ObjectDisposedException disposedException)
{
throw new InvalidOperationException(ErrorConstants.Codes.SessionLockWriteDisposed, disposedException);
}
}

/// <summary>
/// Initialize the current instance if none exists.
/// </summary>
/// <param name="instanceCreator">A func to create an instance.</param>
public static void Initialize(Func<GraphSession> instanceCreator)
{
Initialize(instanceCreator, false);
}

/// <summary>
/// Modify the current instance of <see cref="GraphSession"/>.
/// </summary>
/// <param name="modifier">A func to modify the <see cref="GraphSession"/> instance.</param>
public static void Modify(Action<GraphSession> modifier)
{
try
{
sessionLock.EnterWriteLock();
try
{
modifier(_instance);
}
finally
{
sessionLock.ExitWriteLock();
}
}
catch (LockRecursionException lockException)
{
throw new InvalidOperationException(ErrorConstants.Codes.SessionLockWriteRecursion, lockException);
}
catch (ObjectDisposedException disposedException)
{
throw new InvalidOperationException(ErrorConstants.Codes.SessionLockWriteDisposed, disposedException);
}
}

/// <summary>
/// Resets the current instance of <see cref="GraphSession"/> to initial state.
/// </summary>
internal static void Reset()
{
try
{
sessionLock.EnterWriteLock();
try
{
_instance = null;
_initialized = false;
}
finally
{
sessionLock.ExitWriteLock();
}
}
catch (LockRecursionException lockException)
{
throw new InvalidOperationException(ErrorConstants.Codes.SessionLockWriteRecursion, lockException);
}
catch (ObjectDisposedException disposedException)
{
throw new InvalidOperationException(ErrorConstants.Codes.SessionLockWriteDisposed, disposedException);
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// ------------------------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All Rights Reserved. Licensed under the MIT License. See License in the project root for license information.
// ------------------------------------------------------------------------------

namespace Microsoft.Graph.PowerShell.Authentication
{
public static class GraphSessionInitializer
{
/// <summary>
/// Initializes <see cref="GraphSession"/>.
/// </summary>
public static void InitializeSession()
{
GraphSession.Initialize(() => CreateInstance());
}

/// <summary>
/// Creates a new instance of a <see cref="GraphSession"/>.
/// </summary>
/// <returns><see cref="GraphSession"/></returns>
internal static GraphSession CreateInstance()
{
// This can be used to initialize GraphSession from a file in the future.
return new GraphSession();
}
}
}
7 changes: 7 additions & 0 deletions src/Authentication/Authentication/ErrorConstants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,19 @@ public static class ErrorConstants
{
internal static class Codes
{
internal const string SessionNotInitialized = "sessionNotInitialized";
internal const string SessionLockReadRecursion = "sessionLockReadRecursion";
internal const string SessionLockReadDisposed = "sessionLockReadDisposed";
internal const string SessionLockWriteDisposed = "sessionLockWriteDisposed";
internal const string SessionLockWriteRecursion = "sessionLockWriteRecursion";
internal const string InvalidJWT = "invalidJWT";
}

internal static class Message
{
internal const string InvalidJWT = "Invalid JWT access token.";
internal const string MissingAuthContext = "Authentication needed, call Connect-Graph.";
internal const string InstanceExists = "An instance of {0} already exists. Call {1} to overwrite it.";
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
namespace Microsoft.Graph.PowerShell.Authentication.Helpers
{
using Microsoft.Graph.Auth;
using Microsoft.Graph.PowerShell.Authentication.Models;
using Microsoft.Graph.PowerShell.Authentication.TokenCache;
using Microsoft.Identity.Client;
using System;
Expand All @@ -16,7 +15,7 @@ internal static class AuthenticationHelpers
{
private static readonly object FileLock = new object();

internal static IAuthenticationProvider GetAuthProvider(AuthConfig authConfig)
internal static IAuthenticationProvider GetAuthProvider(IAuthContext authConfig)
{
if (authConfig.AuthType == AuthenticationType.Delegated)
{
Expand All @@ -43,7 +42,7 @@ internal static IAuthenticationProvider GetAuthProvider(AuthConfig authConfig)
}
}

internal static void Logout(AuthConfig authConfig)
internal static void Logout(IAuthContext authConfig)
{
lock (FileLock)
{
Expand Down
Loading