Skip to content

Commit

Permalink
fixed issues #118, #119
Browse files Browse the repository at this point in the history
  • Loading branch information
Ganesh Subramanian committed Oct 7, 2014
1 parent 0697d64 commit ee4526a
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.minnal.security.config.SecurityAware;
import org.minnal.security.config.SecurityConfiguration;
import org.minnal.security.filter.AuthenticationFilter;
import org.minnal.security.filter.AuthenticationListener;
import org.minnal.security.filter.CallbackFilter;
import org.minnal.security.filter.SecurityContextFilter;
import org.pac4j.core.client.Client;
Expand All @@ -23,6 +24,8 @@ public class SecurityPlugin implements Plugin {

private Clients clients;

private AuthenticationListener listener;

/**
* @param callbackUrl
* @param clients
Expand All @@ -31,6 +34,15 @@ public SecurityPlugin(String callbackUrl, Client... clients) {
this.clients = new Clients(callbackUrl, clients);
}

/**
* @param callbackUrl
* @param clients
*/
public SecurityPlugin(String callbackUrl, AuthenticationListener listener, Client... clients) {
this.clients = new Clients(callbackUrl, clients);
this.listener = listener;
}

/**
* @param clients
*/
Expand All @@ -46,7 +58,10 @@ public void init(Application<? extends ApplicationConfiguration> application) {
}
SecurityConfiguration configuration = ((SecurityAware) applicationConfiguration).getSecurityConfiguration();
clients.init();
application.addFilter(new CallbackFilter(clients, configuration));

CallbackFilter callbackFilter = new CallbackFilter(clients, configuration);
callbackFilter.registerListener(listener);
application.addFilter(callbackFilter);
application.addFilter(new AuthenticationFilter(clients, configuration));
application.addFilter(new SecurityContextFilter(configuration));
application.getResourceConfig().register(RolesAllowedDynamicFeature.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ public class AbstractSecurityFilter {

public static final String AUTH_COOKIE = "_session_id";

public static final String SESSION = "session";

/**
* @param configuration
*/
Expand All @@ -43,7 +45,10 @@ public SecurityConfiguration getConfiguration() {
* @return
*/
protected Session getSession(ContainerRequestContext request, boolean create) {
Session session = null;
Session session = (Session) request.getProperty(SESSION);
if (session != null) {
return session;
}
Cookie sessionCookie = request.getCookies().get(AUTH_COOKIE);

if (sessionCookie != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@
*/
package org.minnal.security.filter;

import java.io.IOException;
import java.util.Map;

import javax.annotation.Priority;
import javax.ws.rs.Priorities;
import javax.ws.rs.container.ContainerRequestContext;
import javax.ws.rs.container.ContainerRequestFilter;
import javax.ws.rs.container.ContainerResponseContext;
import javax.ws.rs.container.ContainerResponseFilter;
import javax.ws.rs.core.HttpHeaders;
import javax.ws.rs.core.NewCookie;
import javax.ws.rs.core.Response;
Expand All @@ -35,13 +38,13 @@
*
*/
@Priority(Priorities.AUTHENTICATION)
public class AuthenticationFilter extends AbstractSecurityFilter implements ContainerRequestFilter {
public class AuthenticationFilter extends AbstractSecurityFilter implements ContainerRequestFilter, ContainerResponseFilter {

private Clients clients;

public static final String PRINCIPAL = "principal";

public static final String SESSION = "session";
protected AuthenticationListener listener;

private static final Logger logger = LoggerFactory.getLogger(AuthenticationFilter.class);

Expand All @@ -63,25 +66,21 @@ public Clients getClients() {

@Override
public void filter(ContainerRequestContext request) {
Session session = getSession(request, true);
request.setProperty(SESSION, session);

if (isWhiteListed(request)) {
logger.debug("Request path {} is in whitelisted set of urls. Skipping authentication", request.getUriInfo());
return;
}

Session session = getSession(request, true);
request.setProperty(SESSION, session);
if (isAuthenticated(session)) {
logger.debug("Session is already authenticated. Skipping authentication");
return;
}

JaxrsWebContext context = getContext(request, session);

Client client = null;
try {
client = getClient(context);
} catch (TechnicalException e) {
logger.error("Failed while getiing the client", e);
}

Client client = getClient(context);

if (client != null) {
session.addAttribute(Clients.DEFAULT_CLIENT_NAME_PARAMETER, client.getName());
getConfiguration().getSessionStore().save(session);
Expand All @@ -99,6 +98,12 @@ public void filter(ContainerRequestContext request) {
request.abortWith(context.getResponse());
}

@Override
public void filter(ContainerRequestContext request, ContainerResponseContext response) throws IOException {
Session session = getSession(request, true);
response.getHeaders().add(HttpHeaders.SET_COOKIE, new NewCookie(AUTH_COOKIE, session.getId()).toString());
}

/**
* Checks if the session is already authenticated
*
Expand Down Expand Up @@ -128,7 +133,6 @@ protected User retrieveProfile(Session session) {
if (profile == null) {
return null;
}

Client client = getClient(session);
Class<UserProfile> type = Generics.getTypeParameter(client.getClass(), UserProfile.class);
if (type.isAssignableFrom(profile.getClass())) {
Expand All @@ -145,19 +149,29 @@ protected User retrieveProfile(Session session) {
return null;
}

protected Client getClient(Session session) {
String clientName = session.getAttribute(Clients.DEFAULT_CLIENT_NAME_PARAMETER);
if (Strings.isNullOrEmpty(clientName)) {
return null;
}
return clients.findClient(clientName);
}
protected Client getClient(Session session) {
String clientName = session.getAttribute(Clients.DEFAULT_CLIENT_NAME_PARAMETER);
if (Strings.isNullOrEmpty(clientName)) {
return null;
}
return clients.findClient(clientName);
}

protected Client getClient(JaxrsWebContext context) {
String clientName = context.getRequestParameter(Clients.DEFAULT_CLIENT_NAME_PARAMETER);
if (Strings.isNullOrEmpty(clientName)) {
return null;
}
return clients.findClient(clientName);
try {
return clients.findClient(context);
} catch (TechnicalException e) {
logger.debug("Error while getting the client from the context", e);
return null;
}
}

/**
* Registers the authentication listener
*
* @param listener
*/
public void registerListener(AuthenticationListener listener) {
this.listener = listener;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/**
*
*/
package org.minnal.security.filter;

import org.minnal.security.session.Session;
import org.pac4j.core.profile.UserProfile;

/**
* @author ganeshs
*
*/
public interface AuthenticationListener {

void authSuccess(Session session, UserProfile profile);

void authFailed(Session session);
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,27 @@
*/
package org.minnal.security.filter;

import java.io.IOException;
import java.net.URI;

import javax.annotation.Priority;
import javax.ws.rs.Priorities;
import javax.ws.rs.container.ContainerRequestContext;
import javax.ws.rs.container.ContainerResponseContext;
import javax.ws.rs.container.PreMatching;
import javax.ws.rs.core.Response;

import org.minnal.security.auth.JaxrsWebContext;
import org.minnal.security.config.SecurityConfiguration;
import org.minnal.security.session.Session;
import org.minnal.utils.http.HttpUtil;
import org.pac4j.core.client.Client;
import org.pac4j.core.client.Clients;
import org.pac4j.core.credentials.Credentials;
import org.pac4j.core.exception.RequiresHttpAction;
import org.pac4j.core.profile.UserProfile;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* @author ganeshs
Expand All @@ -27,37 +32,58 @@
@PreMatching
@Priority(Priorities.USER)
public class CallbackFilter extends AuthenticationFilter {

private static final Logger logger = LoggerFactory.getLogger(CallbackFilter.class);

/**
* @param clients
*/
public CallbackFilter(Clients clients, SecurityConfiguration configuration) {
super(clients, configuration);
}

@Override
public void filter(ContainerRequestContext request) {
URI uri = URI.create(getClients().getCallbackUrl());
if (! request.getUriInfo().getPath().equalsIgnoreCase(uri.getPath())) {
if (! HttpUtil.structureUrl(request.getUriInfo().getPath()).equalsIgnoreCase(uri.getPath())) {
logger.debug("Request path {} doesn't match callback url. Skipping", request.getUriInfo().getPath());
return;
}

Session session = getSession(request, true);
Client client = getClient(session);
JaxrsWebContext context = getContext(request, session);
Client client = getClient(session);
if (client == null) {
client = getClient(context);
}
if (client == null) {
context.setResponseStatus(422);
if (listener != null) {
listener.authFailed(session);
}
} else {
try {
Credentials credentials = client.getCredentials(context);
UserProfile userProfile = client.getUserProfile(credentials, context);
session.addAttribute(Clients.DEFAULT_CLIENT_NAME_PARAMETER, client.getName());
session.addAttribute(PRINCIPAL, userProfile);
if (listener != null) {
listener.authSuccess(session, userProfile);
}
getConfiguration().getSessionStore().save(session);
context.setResponseStatus(Response.Status.OK.getStatusCode());
} catch (RequiresHttpAction e) {
context.setResponseStatus(e.getCode());
if (listener != null) {
listener.authFailed(session);
}
}
}
request.abortWith(context.getResponse());
}

@Override
public void filter(ContainerRequestContext request, ContainerResponseContext response) throws IOException {
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,14 @@ public void shouldCreateSessionIfAuthCookieIsFoundButSessionHasExpired() {
verify(sessionStore).getSession(sessionId);
}

@Test
public void shouldReturnSessionIfAlreadySetToRequestProperty() {
Session session = mock(Session.class);
when(context.getProperty(AuthenticationFilter.SESSION)).thenReturn(session);
assertEquals(filter.getSession(context, true), session);
verify(sessionStore, never()).createSession(any(String.class));
}

@Test
public void shouldReturnSessionIfAuthCookieIsFoundAndSessionHasNotExpired() {
when(configuration.getSessionExpiryTimeInSecs()).thenReturn(100L);
Expand All @@ -160,17 +168,17 @@ public void shouldReturnSessionIfAuthCookieIsFoundAndSessionHasNotExpired() {
}

@Test
public void shouldReturnNullClientFromRequestContextIfClientNameAttributeIsNotSet() {
public void shouldReturnNullFromRequestContextIfClientNameAttributeIsNotSet() {
JaxrsWebContext context = mock(JaxrsWebContext.class);
when(context.getRequestParameter(Clients.DEFAULT_CLIENT_NAME_PARAMETER)).thenReturn(null);
assertNull(filter.getClient(context));
}

@Test(expectedExceptions=TechnicalException.class)
@Test
public void shouldThrowExceptionIfClientNameIsNotFoundInRequestContext() {
JaxrsWebContext context = mock(JaxrsWebContext.class);
when(context.getRequestParameter(Clients.DEFAULT_CLIENT_NAME_PARAMETER)).thenReturn("unknownClient");
filter.getClient(context);
assertNull(filter.getClient(context));
}

@Test
Expand Down Expand Up @@ -248,7 +256,8 @@ public void shouldReturnFalseIfNotAlreadyAuthenticated() {
public void shouldNotFilterWhiteListedUrls() {
doReturn(true).when(filter).isWhiteListed(context);
filter.filter(context);
verify(filter, never()).getSession(context, true);
verify(filter).getSession(context, true);
verify(filter, never()).getContext(eq(context), any(Session.class));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,21 @@ public class CallbackFilterTest {

private SessionStore sessionStore;

private AuthenticationListener listener;

private Client client;

@BeforeMethod
public void setup() {
client = mock(Client.class);
listener = mock(AuthenticationListener.class);
when(client.getName()).thenReturn("client1");
clients = new Clients("/callback", client);
sessionStore = mock(SessionStore.class);
configuration = mock(SecurityConfiguration.class);
when(configuration.getSessionStore()).thenReturn(sessionStore);
filter = spy(new CallbackFilter(clients, configuration));
filter.registerListener(listener);
context = mock(ContainerRequestContext.class);
uriInfo = mock(UriInfo.class);
when(uriInfo.getPath()).thenReturn("/callback");
Expand All @@ -80,6 +84,7 @@ public void shouldReturnUnAcceptableIfClientNameNotSet() {
filter.filter(context);
verify(webContext).setResponseStatus(422);
verify(context).abortWith(response);
verify(listener).authFailed(session);
}

@Test
Expand All @@ -97,8 +102,10 @@ public void shouldReturnOkIfClientNameIsSet() throws RequiresHttpAction {
when(client.getUserProfile(credentials, webContext)).thenReturn(profile);
filter.filter(context);
verify(session).addAttribute(AuthenticationFilter.PRINCIPAL, profile);
verify(session).addAttribute(Clients.DEFAULT_CLIENT_NAME_PARAMETER, "client1");
verify(sessionStore).save(session);
verify(webContext).setResponseStatus(Response.Status.OK.getStatusCode());
verify(listener).authSuccess(session, profile);
verify(context).abortWith(response);
}
}

0 comments on commit ee4526a

Please sign in to comment.