Skip to content

Commit

Permalink
Auth changes
Browse files Browse the repository at this point in the history
  • Loading branch information
mythz committed Jan 17, 2012
1 parent 2c6d89e commit fb137f7
Show file tree
Hide file tree
Showing 15 changed files with 128 additions and 57 deletions.
45 changes: 39 additions & 6 deletions src/ServiceStack.ServiceInterface/Auth/AuthProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public virtual object Logout(IServiceBase service, Auth request)
/// <returns></returns>
public virtual object Authenticate(IServiceBase authService, IAuthSession session, Auth request)
{
var tokens = Init(authService, session);
var tokens = Init(authService, ref session, request);

//Default OAuth logic based on Twitter's OAuth workflow
if (!tokens.RequestToken.IsNullOrEmpty() && !request.oauth_token.IsNullOrEmpty())
Expand Down Expand Up @@ -125,16 +125,23 @@ public virtual object Authenticate(IServiceBase authService, IAuthSession sessio
/// <summary>
/// Sets the CallbackUrl and session.ReferrerUrl if not set and initializes the session tokens for this AuthProvider
/// </summary>
/// <param name="service"></param>
/// <param name="authService"></param>
/// <param name="session"></param>
/// <param name="request"> </param>
/// <returns></returns>
protected IOAuthTokens Init(IServiceBase service, IAuthSession session)
protected IOAuthTokens Init(IServiceBase authService, ref IAuthSession session, Auth request)
{
if (request != null && !LoginMatchesSession(session, request.UserName))
{
authService.RemoveSession();
session = authService.GetSession();
}

if (this.CallbackUrl.IsNullOrEmpty())
this.CallbackUrl = service.RequestContext.AbsoluteUri;
this.CallbackUrl = authService.RequestContext.AbsoluteUri;

if (session.ReferrerUrl.IsNullOrEmpty())
session.ReferrerUrl = service.RequestContext.GetHeader("Referer") ?? this.CallbackUrl;
session.ReferrerUrl = authService.RequestContext.GetHeader("Referer") ?? this.CallbackUrl;

var tokens = session.ProviderOAuthAccess.FirstOrDefault(x => x.Provider == Provider);
if (tokens == null)
Expand Down Expand Up @@ -165,6 +172,7 @@ public virtual void OnAuthenticated(IServiceBase authService, IAuthSession sessi
var userSession = session as AuthUserSession;
if (userSession != null)
{
session.ProviderOAuthAccess.ForEach(x => LoadUserOAuthProvider(userSession, x));
LoadUserAuthInfo(userSession, tokens, authInfo);
}

Expand All @@ -184,10 +192,35 @@ public virtual void OnAuthenticated(IServiceBase authService, IAuthSession sessi

protected virtual void LoadUserAuthInfo(AuthUserSession userSession, IOAuthTokens tokens, Dictionary<string, string> authInfo) { }

public virtual bool IsAuthorized(IAuthSession session, IOAuthTokens tokens)
public virtual bool IsAuthorized(IAuthSession session, IOAuthTokens tokens, Auth request=null)
{
if (request != null)
{
if (!LoginMatchesSession(session, request.UserName)) return false;
}

return tokens != null && !string.IsNullOrEmpty(tokens.AccessTokenSecret);
}

protected static bool LoginMatchesSession(IAuthSession session, string userName)
{
if (userName == null) return false;
var isEmail = userName.Contains("@");
if (isEmail)
{
if (!userName.EqualsIgnoreCase(session.Email))
return false;
}
else
{
if (!userName.EqualsIgnoreCase(session.UserName))
return false;
}
return true;
}

protected virtual void LoadUserOAuthProvider(AuthUserSession userSession, IOAuthTokens tokens){}

}

public static class AuthConfigExtensions
Expand Down
2 changes: 1 addition & 1 deletion src/ServiceStack.ServiceInterface/Auth/AuthService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ public override object OnPost(Auth request)
return oAuthConfig.Logout(this, request);

var session = this.GetSession();
if (!oAuthConfig.IsAuthorized(session, session.GetOAuthTokens(provider)))
if (!oAuthConfig.IsAuthorized(session, session.GetOAuthTokens(provider), request))
{
return oAuthConfig.Authenticate(this, session, request);
}
Expand Down
2 changes: 2 additions & 0 deletions src/ServiceStack.ServiceInterface/Auth/AuthUserSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ public AuthUserSession()

public string UserAuthId { get; set; }

public string UserAuthName { get; set; }

public string UserName { get; set; }

public string TwitterUserId { get; set; }
Expand Down
22 changes: 16 additions & 6 deletions src/ServiceStack.ServiceInterface/Auth/CredentialsAuthProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,21 @@ public virtual bool TryAuthenticate(IServiceBase authService, string userName, s
if (authRepo.TryAuthenticate(userName, password, out useUserName))
{
session.IsAuthenticated = true;
session.UserName = userName;
session.UserAuthName = userName;

return true;
}
return false;
}

public override bool IsAuthorized(IAuthSession session, IOAuthTokens tokens)
public override bool IsAuthorized(IAuthSession session, IOAuthTokens tokens, Auth request=null)
{
return !session.UserName.IsNullOrEmpty();
if (request != null)
{
if (!LoginMatchesSession(session, request.UserName)) return false;
}

return !session.UserAuthName.IsNullOrEmpty();
}

public override object Authenticate(IServiceBase authService, IAuthSession session, Auth request)
Expand All @@ -66,13 +71,18 @@ public override object Authenticate(IServiceBase authService, IAuthSession sessi

protected object Authenticate(IServiceBase authService, IAuthSession session, string userName, string password)
{
if (!LoginMatchesSession(session, userName))
{
authService.RemoveSession();
session = authService.GetSession();
}

if (TryAuthenticate(authService, userName, password))
{
OnAuthenticated(authService, session, null, null);
//OnAuthenticatedCredentials(authService, session, userName);

if (session.UserName == null)
session.UserName = userName;
if (session.UserAuthName == null)
session.UserAuthName = userName;

authService.SaveSession(session);

Expand Down
19 changes: 12 additions & 7 deletions src/ServiceStack.ServiceInterface/Auth/FacebookAuthProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public FacebookAuthProvider(IResourceManager appSettings)

public override object Authenticate(IServiceBase authService, IAuthSession session, Auth request)
{
var tokens = Init(authService, session);
var tokens = Init(authService, ref session, request);

var code = authService.RequestContext.Get<IHttpRequest>().QueryString["code"];
var isPreAuthCallback = !code.IsNullOrEmpty();
Expand Down Expand Up @@ -82,17 +82,22 @@ protected override void LoadUserAuthInfo(AuthUserSession userSession, IOAuthToke
tokens.LastName = obj.Get("last_name");
tokens.Email = obj.Get("email");

userSession.FacebookUserId = tokens.UserId ?? userSession.FacebookUserId;
userSession.FacebookUserName = tokens.UserName ?? userSession.FacebookUserName;
userSession.DisplayName = tokens.DisplayName ?? userSession.DisplayName;
userSession.FirstName = tokens.FirstName ?? userSession.FirstName;
userSession.LastName = tokens.LastName ?? userSession.LastName;
userSession.Email = tokens.Email ?? userSession.Email;
LoadUserOAuthProvider(userSession, tokens);
}
catch (Exception ex)
{
Log.Error("Could not retrieve facebook user info for '{0}'".Fmt(tokens.DisplayName), ex);
}
}

protected override void LoadUserOAuthProvider(AuthUserSession userSession, IOAuthTokens tokens)
{
userSession.FacebookUserId = tokens.UserId ?? userSession.FacebookUserId;
userSession.FacebookUserName = tokens.UserName ?? userSession.FacebookUserName;
userSession.DisplayName = tokens.DisplayName ?? userSession.DisplayName;
userSession.FirstName = tokens.FirstName ?? userSession.FirstName;
userSession.LastName = tokens.LastName ?? userSession.LastName;
userSession.Email = tokens.Email ?? userSession.Email;
}
}
}
2 changes: 1 addition & 1 deletion src/ServiceStack.ServiceInterface/Auth/IAuthProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,6 @@ public interface IAuthProvider

void OnSaveUserAuth(IServiceBase authService, IAuthSession session);
void OnAuthenticated(IServiceBase authService, IAuthSession session, IOAuthTokens tokens, Dictionary<string, string> authInfo);
bool IsAuthorized(IAuthSession session, IOAuthTokens tokens);
bool IsAuthorized(IAuthSession session, IOAuthTokens tokens, Auth request = null);
}
}
1 change: 1 addition & 0 deletions src/ServiceStack.ServiceInterface/Auth/IAuthSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ public interface IAuthSession
string ReferrerUrl { get; set; }
string Id { get; set; }
string UserAuthId { get; set; }
string UserAuthName { get; set; }
string UserName { get; set; }
string DisplayName { get; set; }
string FirstName { get; set; }
Expand Down
13 changes: 4 additions & 9 deletions src/ServiceStack.ServiceInterface/Auth/OrmLiteAuthRepository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,10 @@ private void LoadUserAuth(IAuthSession session, UserAuth userAuth)
{
if (userAuth == null) return;

session.UserAuthId = userAuth.Id.ToString(CultureInfo.InvariantCulture);
session.DisplayName = userAuth.DisplayName;
session.FirstName = userAuth.FirstName;
session.LastName = userAuth.LastName;
session.Email = userAuth.Email;
session.Roles = userAuth.Roles;
session.Permissions = userAuth.Permissions;
session.PopulateWith(userAuth);
session.ProviderOAuthAccess = GetUserOAuthProviders(session.UserAuthId)
.ConvertAll(x => (IOAuthTokens)x);

}

public UserAuth GetUserAuth(string userAuthId)
Expand Down Expand Up @@ -220,9 +215,9 @@ public UserAuth GetUserAuth(IAuthSession authSession, IOAuthTokens tokens)
var userAuth = GetUserAuth(authSession.UserAuthId);
if (userAuth != null) return userAuth;
}
if (!authSession.UserName.IsNullOrEmpty())
if (!authSession.UserAuthName.IsNullOrEmpty())
{
var userAuth = GetUserAuthByUserName(authSession.UserName);
var userAuth = GetUserAuthByUserName(authSession.UserAuthName);
if (userAuth != null) return userAuth;
}

Expand Down
12 changes: 3 additions & 9 deletions src/ServiceStack.ServiceInterface/Auth/RedisAuthRepository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -202,13 +202,7 @@ private void LoadUserAuth(IAuthSession session, UserAuth userAuth)
{
if (userAuth == null) return;

session.UserAuthId = userAuth.Id.ToString(CultureInfo.InvariantCulture);
session.DisplayName = userAuth.DisplayName;
session.FirstName = userAuth.FirstName;
session.LastName = userAuth.LastName;
session.Email = userAuth.Email;
session.Roles = userAuth.Roles;
session.Permissions = userAuth.Permissions;
session.PopulateWith(userAuth);
session.ProviderOAuthAccess = GetUserOAuthProviders(session.UserAuthId)
.ConvertAll(x => (IOAuthTokens)x);
}
Expand Down Expand Up @@ -281,9 +275,9 @@ private UserAuth GetUserAuth(IRedisClientFacade redis, IAuthSession authSession,
var userAuth = GetUserAuth(redis, authSession.UserAuthId);
if (userAuth != null) return userAuth;
}
if (!authSession.UserName.IsNullOrEmpty())
if (!authSession.UserAuthName.IsNullOrEmpty())
{
var userAuth = GetUserAuthByUserName(authSession.UserName);
var userAuth = GetUserAuthByUserName(authSession.UserAuthName);
if (userAuth != null) return userAuth;
}

Expand Down
10 changes: 8 additions & 2 deletions src/ServiceStack.ServiceInterface/Auth/RegistrationService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public override object OnPost(Registration request)
}

var session = this.GetSession();
var newUserAuth = request.TranslateTo<UserAuth>();
var newUserAuth = ToUserAuth(request);
var existingUser = UserAuthRepo.GetUserAuth(session, null);

var user = existingUser != null
Expand All @@ -118,6 +118,12 @@ public override object OnPost(Registration request)
};
}

public UserAuth ToUserAuth(Registration request)
{
var to = request.TranslateTo<UserAuth>();
return to;
}

/// <summary>
/// Logic to update UserAuth from Registration info, not enabled on OnPut because of security.
/// </summary>
Expand All @@ -139,7 +145,7 @@ public object UpdateUserAuth(Registration request)
throw HttpError.NotFound("User does not exist");
}

var newUserAuth = request.TranslateTo<UserAuth>();
var newUserAuth = ToUserAuth(request);
UserAuthRepo.UpdateUserAuth(newUserAuth, existingUser, request.Password);

return new RegistrationResponse {
Expand Down
14 changes: 11 additions & 3 deletions src/ServiceStack.ServiceInterface/Auth/TwitterAuthProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,30 @@ public TwitterAuthProvider(IResourceManager appSettings)
protected override void LoadUserAuthInfo(AuthUserSession userSession, IOAuthTokens tokens, Dictionary<string, string> authInfo)
{
if (authInfo.ContainsKey("user_id"))
tokens.UserId = userSession.TwitterUserId = authInfo.GetValueOrDefault("user_id");
tokens.UserId = authInfo.GetValueOrDefault("user_id");

if (authInfo.ContainsKey("screen_name"))
tokens.UserName = userSession.TwitterScreenName = authInfo.GetValueOrDefault("screen_name");
tokens.UserName = authInfo.GetValueOrDefault("screen_name");

try
{
var json = AuthHttpGateway.DownloadTwitterUserInfo(userSession.TwitterUserId);
var obj = JsonObject.Parse(json);
tokens.DisplayName = obj.Get("name");
userSession.DisplayName = tokens.DisplayName ?? userSession.DisplayName;

LoadUserOAuthProvider(userSession, tokens);
}
catch (Exception ex)
{
Log.Error("Could not retrieve twitter user info for '{0}'".Fmt(userSession.TwitterUserId), ex);
}
}

protected override void LoadUserOAuthProvider(AuthUserSession userSession, IOAuthTokens tokens)
{
userSession.TwitterUserId = tokens.UserId ?? userSession.TwitterUserId;
userSession.TwitterScreenName = tokens.UserName ?? userSession.TwitterScreenName;
userSession.DisplayName = tokens.DisplayName ?? userSession.DisplayName;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,8 @@ public void Logging_in_pulls_all_AuthInfo_from_repo_after_logging_in_all_AuthPro
Password = request.Password,
});

oAuthUserSession = requestContext.Get<IHttpRequest>().GetSession() as AuthUserSession;

Assert.That(oAuthUserSession.TwitterUserId, Is.EqualTo(authInfo["user_id"]));
Assert.That(oAuthUserSession.TwitterScreenName, Is.EqualTo(authInfo["screen_name"]));

Expand Down
27 changes: 19 additions & 8 deletions tests/ServiceStack.WebHost.IntegrationTests/Default.aspx
Original file line number Diff line number Diff line change
Expand Up @@ -111,17 +111,28 @@
<%= UserSession.Dump() %>
</pre>

<div id="userauths"></div>
<div id="oAuthProviders"></div>

<script type="text/javascript">
$.getJSON("api/userauths", function(r) {
$("#userauths").html(_.jsonreport(r.results));
$("#oAuthProviders").html(_.jsonreport(r.oAuthProviders));
});
</script>


<script type="text/javascript">
_.each({
UserName: 'as@if.com',
DisplayName: 'mythz',
Email: 'as@if.com',
Password: 'test',
ConfirmPassword: 'test'
}, function (val, name) {
$("[name=" + name + "]").val(val);
});
UserName: 'as@if.com',
DisplayName: 'mythz',
Email: 'as@if.com',
Password: 'test',
ConfirmPassword: 'test'
}, function (val, name) {
$("[name=" + name + "]").val(val);
});
var clear = function () {
$(".success, .error-summary").hide();
Expand Down
5 changes: 4 additions & 1 deletion tests/ServiceStack.WebHost.IntegrationTests/Global.asax.cs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,10 @@ private void ConfigureAuth(Funq.Container container)
new OrmLiteAuthRepository(c.Resolve<IDbConnectionFactory>()));

var authRepo = (OrmLiteAuthRepository)container.Resolve<IUserAuthRepository>();
authRepo.CreateMissingTables();
if (new AppSettings().Get("Recr eateTables", true))
authRepo.DropAndReCreateTables();
else
authRepo.CreateMissingTables();
}
}

Expand Down
Loading

0 comments on commit fb137f7

Please sign in to comment.