Skip to content

Commit

Permalink
Make it possible for Go agents to use auth tokens from HTTP header (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
mshaposhnik committed Jul 3, 2018
1 parent 4cecef2 commit 29eef88
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 25 deletions.
16 changes: 15 additions & 1 deletion agents/go-agents/core/auth/auth.go
Expand Up @@ -18,13 +18,15 @@ import (
"regexp"

"fmt"
"strings"
"github.com/dgrijalva/jwt-go"
"os"
)

const (
TokenKind = "machine_token"
WorkspaceIdEnv = "CHE_WORKSPACE_ID"
BearerPrefix = "bearer "
)

var (
Expand Down Expand Up @@ -111,6 +113,12 @@ func (handler handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
return
}
token := req.URL.Query().Get("token")
if token == "" {
header := req.Header.Get("Authorization")
if header != "" && strings.HasPrefix(strings.ToLower(header), BearerPrefix) {
token = header[len(BearerPrefix):]
}
}
if err := authenticate(token); err == nil {
handler.delegate.ServeHTTP(w, req)
} else {
Expand All @@ -125,6 +133,12 @@ func (handler cachingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request
return
}
token := req.URL.Query().Get("token")
if token == "" {
header := req.Header.Get("Authorization")
if header != "" && strings.HasPrefix(strings.ToLower(header), BearerPrefix) {
token = header[len(BearerPrefix):]
}
}
if handler.cache.Contains(token) {
handler.delegate.ServeHTTP(w, req)
} else if err := authenticate(token); err == nil {
Expand All @@ -137,7 +151,7 @@ func (handler cachingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request

func authenticate(token string) error {
if token == "" {
return errors.New("Authentication failed because: missing 'token' query parameter")
return errors.New("Authentication failed because: missing authentication token")
}

claims := &JWTClaims{}
Expand Down
Expand Up @@ -24,9 +24,11 @@
* @author Alexander Garagatyi
*/
public class HttpConnectionServerChecker extends ServerChecker {
private static final String AUTHORIZATION_HEADER = "Authorization";
private static final String CONNECTION_HEADER = "Connection";
private static final String CONNECTION_CLOSE = "close";
private final URL url;
private final String token;

public HttpConnectionServerChecker(
URL url,
Expand All @@ -36,9 +38,11 @@ public HttpConnectionServerChecker(
long timeout,
int successThreshold,
TimeUnit timeUnit,
Timer timer) {
Timer timer,
String token) {
super(machineName, serverRef, period, timeout, successThreshold, timeUnit, timer);
this.url = url;
this.token = token;
}

@Override
Expand All @@ -50,6 +54,9 @@ public boolean isAvailable() {
httpURLConnection.setConnectTimeout((int) TimeUnit.SECONDS.toMillis(3));
httpURLConnection.setReadTimeout((int) TimeUnit.SECONDS.toMillis(3));
httpURLConnection.setRequestProperty(CONNECTION_HEADER, CONNECTION_CLOSE);
if (token != null) {
httpURLConnection.setRequestProperty(AUTHORIZATION_HEADER, "Bearer " + token);
}
return isConnectionSuccessful(httpURLConnection);
} catch (IOException e) {
return false;
Expand Down
Expand Up @@ -174,6 +174,7 @@ private ServerChecker getChecker(String serverRef, Server server) throws Infrast
// workaround needed because we don't have server readiness check in the model
// Create server readiness endpoint URL
URL url;
String token;
try {
String serverUrl = server.getUrl();

Expand All @@ -186,32 +187,44 @@ private ServerChecker getChecker(String serverRef, Server server) throws Infrast
serverUrl = serverUrl + '/';
}

url =
UriBuilder.fromUri(serverUrl)
.queryParam(
"token",
machineTokenProvider.getToken(
runtimeIdentity.getOwnerId(), runtimeIdentity.getWorkspaceId()))
.build()
.toURL();
token =
machineTokenProvider.getToken(
runtimeIdentity.getOwnerId(), runtimeIdentity.getWorkspaceId());
url = UriBuilder.fromUri(serverUrl).build().toURL();
} catch (MalformedURLException e) {
throw new InternalInfrastructureException(
"Server " + serverRef + " URL is invalid. Error: " + e.getMessage(), e);
}

return doCreateChecker(url, serverRef);
return doCreateChecker(url, serverRef, token);
}

@VisibleForTesting
ServerChecker doCreateChecker(URL url, String serverRef) {
ServerChecker doCreateChecker(URL url, String serverRef, String token) {
// TODO add readiness endpoint to terminal and remove this
// workaround needed because terminal server doesn't have endpoint to check it readiness
if ("terminal".equals(serverRef)) {
return new TerminalHttpConnectionServerChecker(
url, machineName, serverRef, 3, 180, serverPingSuccessThreshold, TimeUnit.SECONDS, timer);
url,
machineName,
serverRef,
3,
180,
serverPingSuccessThreshold,
TimeUnit.SECONDS,
timer,
token);
}
// TODO do not hardcode timeouts, use server conf instead
return new HttpConnectionServerChecker(
url, machineName, serverRef, 3, 180, serverPingSuccessThreshold, TimeUnit.SECONDS, timer);
url,
machineName,
serverRef,
3,
180,
serverPingSuccessThreshold,
TimeUnit.SECONDS,
timer,
token);
}
}
Expand Up @@ -32,8 +32,9 @@ class TerminalHttpConnectionServerChecker extends HttpConnectionServerChecker {
long timeout,
int successThreshold,
TimeUnit timeUnit,
Timer timer) {
super(url, machineName, serverRef, period, timeout, successThreshold, timeUnit, timer);
Timer timer,
String token) {
super(url, machineName, serverRef, period, timeout, successThreshold, timeUnit, timer, token);
}

@Override
Expand Down
Expand Up @@ -50,7 +50,7 @@ public void setUp() throws Exception {
checker =
spy(
new HttpConnectionServerChecker(
SERVER_URL, MACHINE_NAME, SERVER_REF, 1, 10, 1, TimeUnit.SECONDS, timer));
SERVER_URL, MACHINE_NAME, SERVER_REF, 1, 10, 1, TimeUnit.SECONDS, timer, null));

doReturn(conn).when(checker).createConnection(nullable(URL.class));
when(conn.getResponseCode()).thenReturn(200);
Expand Down
Expand Up @@ -21,7 +21,6 @@
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNotEquals;
import static org.testng.Assert.fail;

import com.google.common.collect.ImmutableMap;
Expand Down Expand Up @@ -85,7 +84,8 @@ public void setUp() throws Exception {
servers,
machineTokenProvider,
SERVER_PING_SUCCESS_THRESHOLD));
when(checker.doCreateChecker(any(URL.class), anyString())).thenReturn(connectionChecker);
when(checker.doCreateChecker(any(URL.class), anyString(), anyString()))
.thenReturn(connectionChecker);
when(machineTokenProvider.getToken(anyString(), anyString())).thenReturn(MACHINE_TOKEN);
}

Expand All @@ -98,18 +98,19 @@ public void tearDown() throws Exception {
}

@Test(timeOut = 1000)
public void shouldUseMachineTokenWhenConstructionUrlToCheck() throws Exception {
public void shouldUseMachineTokenWhenCallChecker() throws Exception {
servers.clear();
servers.put("wsagent/http", new ServerImpl().withUrl("http://localhost"));

checker.startAsync(readinessHandler);
connectionChecker.getReportCompFuture().complete("wsagent/http");

verify(machineTokenProvider).getToken(USER_ID, WORKSPACE_ID);
ArgumentCaptor<URL> urlCaptor = ArgumentCaptor.forClass(URL.class);
verify(checker).doCreateChecker(urlCaptor.capture(), eq("wsagent/http"));
URL urlToCheck = urlCaptor.getValue();
assertNotEquals(urlToCheck.getQuery().indexOf("token=" + MACHINE_TOKEN), -1);
ArgumentCaptor<String> tokenCaptor = ArgumentCaptor.forClass(String.class);
verify(checker)
.doCreateChecker(
eq(new URL("http://localhost/")), eq("wsagent/http"), tokenCaptor.capture());
assertEquals(tokenCaptor.getValue(), MACHINE_TOKEN);
}

@Test(timeOut = 1000)
Expand Down
Expand Up @@ -46,7 +46,8 @@ public void setUp() throws Exception {
10,
1,
TimeUnit.SECONDS,
timer);
timer,
null);
}

@Test
Expand Down

0 comments on commit 29eef88

Please sign in to comment.