Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/main/java/land/oras/Registry.java
Original file line number Diff line number Diff line change
Expand Up @@ -901,14 +901,14 @@ private OrasHttpClient.ResponseWrapper<String> getManifestResponse(ContainerRef
private boolean switchTokenAuth(ContainerRef containerRef, OrasHttpClient.ResponseWrapper<String> response) {
if (response.statusCode() == 401 && !(authProvider instanceof BearerTokenProvider)) {
LOG.debug("Requesting token with token flow");
setAuthProvider(new BearerTokenProvider(authProvider).refreshToken(containerRef, response));
setAuthProvider(new BearerTokenProvider(authProvider).refreshToken(containerRef, client, response));
return true;
}
// Need token refresh (expired or wrong scope)
if ((response.statusCode() == 401 || response.statusCode() == 403)
&& authProvider instanceof BearerTokenProvider) {
LOG.debug("Requesting new token with username password flow");
setAuthProvider(((BearerTokenProvider) authProvider).refreshToken(containerRef, response));
setAuthProvider(((BearerTokenProvider) authProvider).refreshToken(containerRef, client, response));
return true;
}
return false;
Expand Down
11 changes: 5 additions & 6 deletions src/main/java/land/oras/auth/BearerTokenProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,12 @@ public BearerTokenProvider(AuthProvider provider) {
/**
* Retrieve
* @param response The response
* @param client The original client
* @param containerRef The container reference
* @return The token
*/
public BearerTokenProvider refreshToken(
ContainerRef containerRef, OrasHttpClient.ResponseWrapper<String> response) {
ContainerRef containerRef, OrasHttpClient client, OrasHttpClient.ResponseWrapper<String> response) {

String wwwAuthHeader = response.headers().getOrDefault(Const.WWW_AUTHENTICATE_HEADER.toLowerCase(), "");
LOG.debug("WWW-Authenticate header: {}", wwwAuthHeader);
Expand All @@ -103,14 +104,12 @@ public BearerTokenProvider refreshToken(
URI uri = URI.create(realm + "?scope=" + scope + "&service=" + service);

// Perform the request to get the token
OrasHttpClient httpClient =
OrasHttpClient.Builder.builder().withAuthentication(provider).build();
Map<String, String> headers = new HashMap<>();
String authHeader = provider.getAuthHeader(containerRef);
if (authHeader != null) {
headers.put(Const.AUTHORIZATION_HEADER, authHeader);
}
OrasHttpClient.ResponseWrapper<String> responseWrapper = httpClient.get(uri, headers);
OrasHttpClient.ResponseWrapper<String> responseWrapper = client.get(uri, headers);

// Log the response
LOG.debug(
Expand Down Expand Up @@ -143,9 +142,9 @@ public BearerTokenProvider refreshToken(
}

@Override
public String getAuthHeader(ContainerRef registry) {
public @Nullable String getAuthHeader(ContainerRef registry) {
if (token == null) {
throw new OrasException("No token available");
return null;
}
return "Bearer " + token.token;
}
Expand Down
80 changes: 80 additions & 0 deletions src/test/java/land/oras/RegistryWireMockTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.ZonedDateTime;
import java.util.List;
import land.oras.auth.BearerTokenProvider;
import land.oras.auth.FileStoreAuthenticationProvider;
import land.oras.auth.UsernamePasswordProvider;
import land.oras.credentials.FileStore;
Expand Down Expand Up @@ -301,4 +303,82 @@ void shouldRetryBlobUpload(WireMockRuntimeInfo wmRuntimeInfo) throws IOException
assertNotNull(layer.getDigest());
}
}

@Test
void shouldGetToken(WireMockRuntimeInfo wmRuntimeInfo) {

// Return data from wiremock
WireMock wireMock = wmRuntimeInfo.getWireMock();
wireMock.register(WireMock.any(WireMock.urlEqualTo("/v2/library/get-token/blobs/sha256:one"))
.inScenario("get token")
.willReturn(WireMock.unauthorized()
.withHeader(
Const.WWW_AUTHENTICATE_HEADER,
"Bearer realm=\"http://localhost:%d/token\",service=\"localhost\",scope=\"repository:library/get-token:pull\""
.formatted(wmRuntimeInfo.getHttpPort()))));

// Return token
wireMock.register(
WireMock.any(WireMock.urlEqualTo("/token?scope=repository:library/get-token:pull&service=localhost"))
.inScenario("get token")
.willSetStateTo("get")
.willReturn(WireMock.okJson(JsonUtils.toJson(new BearerTokenProvider.TokenResponse(
"fake-token", "access-token", 300, ZonedDateTime.now())))));

// On the second call we return ok
wireMock.register(WireMock.any(WireMock.urlEqualTo("/v2/library/get-token/blobs/sha256:one"))
.inScenario("get token")
.whenScenarioStateIs("get")
.willReturn(WireMock.ok().withBody("blob-data")));

// Insecure registry
Registry registry = Registry.Builder.builder()
.withAuthProvider(authProvider)
.withInsecure(true)
.build();

ContainerRef containerRef =
ContainerRef.parse("localhost:%d/library/get-token".formatted(wmRuntimeInfo.getHttpPort()));
byte[] blob = registry.getBlob(containerRef.withDigest("sha256:one"));
assertEquals("blob-data", new String(blob));
}

@Test
void shouldRefreshExpiredToken(WireMockRuntimeInfo wmRuntimeInfo) {

// Return data from wiremock
WireMock wireMock = wmRuntimeInfo.getWireMock();
wireMock.register(WireMock.any(WireMock.urlEqualTo("/v2/library/refresh-token/blobs/sha256:one"))
.inScenario("get token")
.willReturn(WireMock.forbidden()
.withHeader(
Const.WWW_AUTHENTICATE_HEADER,
"Bearer realm=\"http://localhost:%d/token\",service=\"localhost\",scope=\"repository:library/refresh-token:pull\""
.formatted(wmRuntimeInfo.getHttpPort()))));

// Return token
wireMock.register(WireMock.any(
WireMock.urlEqualTo("/token?scope=repository:library/refresh-token:pull&service=localhost"))
.inScenario("get token")
.willSetStateTo("get")
.willReturn(WireMock.okJson(JsonUtils.toJson(new BearerTokenProvider.TokenResponse(
"fake-token", "access-token", 300, ZonedDateTime.now())))));

// On the second call we return ok
wireMock.register(WireMock.any(WireMock.urlEqualTo("/v2/library/refresh-token/blobs/sha256:one"))
.inScenario("get token")
.whenScenarioStateIs("get")
.willReturn(WireMock.ok().withBody("blob-data")));

// Insecure registry
Registry registry = Registry.Builder.builder()
.withAuthProvider(new BearerTokenProvider(authProvider)) // Already bearer token
.withInsecure(true)
.build();

ContainerRef containerRef =
ContainerRef.parse("localhost:%d/library/refresh-token".formatted(wmRuntimeInfo.getHttpPort()));
byte[] blob = registry.getBlob(containerRef.withDigest("sha256:one"));
assertEquals("blob-data", new String(blob));
}
}
55 changes: 45 additions & 10 deletions src/test/java/land/oras/auth/BearerTokenProviderTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
package land.oras.auth;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.doReturn;

Expand Down Expand Up @@ -54,7 +55,7 @@ public class BearerTokenProviderTest {

@Test
@SuppressWarnings({"unchecked", "rawtypes"})
void shouldTestToken(WireMockRuntimeInfo wmRuntimeInfo) {
void shouldRefreshToken(WireMockRuntimeInfo wmRuntimeInfo) {

// Mock responses
OrasHttpClient.ResponseWrapper mockResponse = Mockito.mock(OrasHttpClient.ResponseWrapper.class);
Expand All @@ -75,7 +76,43 @@ void shouldTestToken(WireMockRuntimeInfo wmRuntimeInfo) {

// Test
BearerTokenProvider provider = new BearerTokenProvider(new UsernamePasswordProvider("user", "password"));
provider.refreshToken(containerRef, mockResponse);
provider.refreshToken(containerRef, OrasHttpClient.Builder.builder().build(), mockResponse);
BearerTokenProvider.TokenResponse token = provider.getToken();

// Assert tokens
assertEquals("fake-token", token.token());
assertEquals("access-token", token.access_token());
assertEquals(300, token.expire_in());
assertEquals(tokenResponse.issued_at(), token.issued_at());

// Check the token header is set
assertEquals("Bearer fake-token", provider.getAuthHeader(containerRef));
}

@Test
@SuppressWarnings({"unchecked", "rawtypes"})
void shouldTestTokenWithNoAuth(WireMockRuntimeInfo wmRuntimeInfo) {

// Mock responses
OrasHttpClient.ResponseWrapper mockResponse = Mockito.mock(OrasHttpClient.ResponseWrapper.class);
BearerTokenProvider.TokenResponse tokenResponse =
new BearerTokenProvider.TokenResponse("fake-token", "access-token", 300, ZonedDateTime.now());
WireMock wireMock = wmRuntimeInfo.getWireMock();
wireMock.register(WireMock.get(WireMock.urlEqualTo(
"/token?scope=repository:library/test:pull&service=%s".formatted(registry.getRegistry())))
.willReturn(WireMock.okJson(JsonUtils.toJson(tokenResponse))));

// Return WWW-Authenticate header from registry
Mockito.when(mockResponse.headers())
.thenReturn(Map.of(
Const.WWW_AUTHENTICATE_HEADER.toLowerCase(),
String.format(
"Bearer realm=\"%s/token\",service=\"%s\",scope=\"repository:library/test:pull\"",
wmRuntimeInfo.getHttpBaseUrl(), registry.getRegistry())));

// Test
BearerTokenProvider provider = new BearerTokenProvider(new NoAuthProvider());
provider.refreshToken(containerRef, OrasHttpClient.Builder.builder().build(), mockResponse);
BearerTokenProvider.TokenResponse token = provider.getToken();

// Assert tokens
Expand All @@ -91,9 +128,7 @@ void shouldTestToken(WireMockRuntimeInfo wmRuntimeInfo) {
@Test
void testNoRefreshedToken() {
BearerTokenProvider provider = new BearerTokenProvider(new UsernamePasswordProvider("user", "password"));
assertThrows(OrasException.class, () -> {
provider.getAuthHeader(containerRef);
});
assertNull(provider.getAuthHeader(containerRef), "No token should be returned");
}

@Test
Expand All @@ -103,13 +138,13 @@ void testInvalidWwwAuthentication() {
BearerTokenProvider provider = new BearerTokenProvider(new UsernamePasswordProvider("user", "password"));
ContainerRef containerRef = ContainerRef.parse("localhost:5000/library/test:latest");
assertThrows(OrasException.class, () -> {
provider.refreshToken(containerRef, mockResponse);
provider.refreshToken(containerRef, OrasHttpClient.Builder.builder().build(), mockResponse);
});
doReturn(Map.of(Const.WWW_AUTHENTICATE_HEADER.toLowerCase(), "invalid"))
.when(mockResponse)
.headers();
assertThrows(OrasException.class, () -> {
provider.refreshToken(containerRef, mockResponse);
provider.refreshToken(containerRef, OrasHttpClient.Builder.builder().build(), mockResponse);
});
}

Expand All @@ -126,7 +161,7 @@ void testWWWAuthenticateFormat(WireMockRuntimeInfo wmRuntimeInfo) {

BearerTokenProvider provider = new BearerTokenProvider(new UsernamePasswordProvider("user", "password"));
assertThrows(OrasException.class, () -> {
provider.refreshToken(containerRef, mockResponse);
provider.refreshToken(containerRef, OrasHttpClient.Builder.builder().build(), mockResponse);
});

// Without error
Expand All @@ -137,7 +172,7 @@ void testWWWAuthenticateFormat(WireMockRuntimeInfo wmRuntimeInfo) {
.when(mockResponse)
.headers();
// No exception should be thrown
provider.refreshToken(containerRef, mockResponse);
provider.refreshToken(containerRef, OrasHttpClient.Builder.builder().build(), mockResponse);

// With error
doReturn(Map.of(
Expand All @@ -147,6 +182,6 @@ void testWWWAuthenticateFormat(WireMockRuntimeInfo wmRuntimeInfo) {
.when(mockResponse)
.headers();
// No exception should be thrown
provider.refreshToken(containerRef, mockResponse);
provider.refreshToken(containerRef, OrasHttpClient.Builder.builder().build(), mockResponse);
}
}