diff --git a/src/main/java/land/oras/Registry.java b/src/main/java/land/oras/Registry.java index a845db20..fcc4363d 100644 --- a/src/main/java/land/oras/Registry.java +++ b/src/main/java/land/oras/Registry.java @@ -901,14 +901,14 @@ private OrasHttpClient.ResponseWrapper getManifestResponse(ContainerRef private boolean switchTokenAuth(ContainerRef containerRef, OrasHttpClient.ResponseWrapper response) { if (response.statusCode() == 401 && !(authProvider instanceof BearerTokenProvider)) { LOG.debug("Requesting token with token flow"); - setAuthProvider(new BearerTokenProvider(authProvider).refreshToken(containerRef, response)); + setAuthProvider(new BearerTokenProvider(authProvider).refreshToken(containerRef, client, response)); return true; } // Need token refresh (expired or wrong scope) if ((response.statusCode() == 401 || response.statusCode() == 403) && authProvider instanceof BearerTokenProvider) { LOG.debug("Requesting new token with username password flow"); - setAuthProvider(((BearerTokenProvider) authProvider).refreshToken(containerRef, response)); + setAuthProvider(((BearerTokenProvider) authProvider).refreshToken(containerRef, client, response)); return true; } return false; diff --git a/src/main/java/land/oras/auth/BearerTokenProvider.java b/src/main/java/land/oras/auth/BearerTokenProvider.java index 44244d24..2e131ce6 100644 --- a/src/main/java/land/oras/auth/BearerTokenProvider.java +++ b/src/main/java/land/oras/auth/BearerTokenProvider.java @@ -75,11 +75,12 @@ public BearerTokenProvider(AuthProvider provider) { /** * Retrieve * @param response The response + * @param client The original client * @param containerRef The container reference * @return The token */ public BearerTokenProvider refreshToken( - ContainerRef containerRef, OrasHttpClient.ResponseWrapper response) { + ContainerRef containerRef, OrasHttpClient client, OrasHttpClient.ResponseWrapper response) { String wwwAuthHeader = response.headers().getOrDefault(Const.WWW_AUTHENTICATE_HEADER.toLowerCase(), ""); LOG.debug("WWW-Authenticate header: {}", wwwAuthHeader); @@ -103,14 +104,12 @@ public BearerTokenProvider refreshToken( URI uri = URI.create(realm + "?scope=" + scope + "&service=" + service); // Perform the request to get the token - OrasHttpClient httpClient = - OrasHttpClient.Builder.builder().withAuthentication(provider).build(); Map headers = new HashMap<>(); String authHeader = provider.getAuthHeader(containerRef); if (authHeader != null) { headers.put(Const.AUTHORIZATION_HEADER, authHeader); } - OrasHttpClient.ResponseWrapper responseWrapper = httpClient.get(uri, headers); + OrasHttpClient.ResponseWrapper responseWrapper = client.get(uri, headers); // Log the response LOG.debug( @@ -143,9 +142,9 @@ public BearerTokenProvider refreshToken( } @Override - public String getAuthHeader(ContainerRef registry) { + public @Nullable String getAuthHeader(ContainerRef registry) { if (token == null) { - throw new OrasException("No token available"); + return null; } return "Bearer " + token.token; } diff --git a/src/test/java/land/oras/RegistryWireMockTest.java b/src/test/java/land/oras/RegistryWireMockTest.java index dca6aa29..1af3a0bf 100644 --- a/src/test/java/land/oras/RegistryWireMockTest.java +++ b/src/test/java/land/oras/RegistryWireMockTest.java @@ -33,7 +33,9 @@ import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; +import java.time.ZonedDateTime; import java.util.List; +import land.oras.auth.BearerTokenProvider; import land.oras.auth.FileStoreAuthenticationProvider; import land.oras.auth.UsernamePasswordProvider; import land.oras.credentials.FileStore; @@ -301,4 +303,82 @@ void shouldRetryBlobUpload(WireMockRuntimeInfo wmRuntimeInfo) throws IOException assertNotNull(layer.getDigest()); } } + + @Test + void shouldGetToken(WireMockRuntimeInfo wmRuntimeInfo) { + + // Return data from wiremock + WireMock wireMock = wmRuntimeInfo.getWireMock(); + wireMock.register(WireMock.any(WireMock.urlEqualTo("/v2/library/get-token/blobs/sha256:one")) + .inScenario("get token") + .willReturn(WireMock.unauthorized() + .withHeader( + Const.WWW_AUTHENTICATE_HEADER, + "Bearer realm=\"http://localhost:%d/token\",service=\"localhost\",scope=\"repository:library/get-token:pull\"" + .formatted(wmRuntimeInfo.getHttpPort())))); + + // Return token + wireMock.register( + WireMock.any(WireMock.urlEqualTo("/token?scope=repository:library/get-token:pull&service=localhost")) + .inScenario("get token") + .willSetStateTo("get") + .willReturn(WireMock.okJson(JsonUtils.toJson(new BearerTokenProvider.TokenResponse( + "fake-token", "access-token", 300, ZonedDateTime.now()))))); + + // On the second call we return ok + wireMock.register(WireMock.any(WireMock.urlEqualTo("/v2/library/get-token/blobs/sha256:one")) + .inScenario("get token") + .whenScenarioStateIs("get") + .willReturn(WireMock.ok().withBody("blob-data"))); + + // Insecure registry + Registry registry = Registry.Builder.builder() + .withAuthProvider(authProvider) + .withInsecure(true) + .build(); + + ContainerRef containerRef = + ContainerRef.parse("localhost:%d/library/get-token".formatted(wmRuntimeInfo.getHttpPort())); + byte[] blob = registry.getBlob(containerRef.withDigest("sha256:one")); + assertEquals("blob-data", new String(blob)); + } + + @Test + void shouldRefreshExpiredToken(WireMockRuntimeInfo wmRuntimeInfo) { + + // Return data from wiremock + WireMock wireMock = wmRuntimeInfo.getWireMock(); + wireMock.register(WireMock.any(WireMock.urlEqualTo("/v2/library/refresh-token/blobs/sha256:one")) + .inScenario("get token") + .willReturn(WireMock.forbidden() + .withHeader( + Const.WWW_AUTHENTICATE_HEADER, + "Bearer realm=\"http://localhost:%d/token\",service=\"localhost\",scope=\"repository:library/refresh-token:pull\"" + .formatted(wmRuntimeInfo.getHttpPort())))); + + // Return token + wireMock.register(WireMock.any( + WireMock.urlEqualTo("/token?scope=repository:library/refresh-token:pull&service=localhost")) + .inScenario("get token") + .willSetStateTo("get") + .willReturn(WireMock.okJson(JsonUtils.toJson(new BearerTokenProvider.TokenResponse( + "fake-token", "access-token", 300, ZonedDateTime.now()))))); + + // On the second call we return ok + wireMock.register(WireMock.any(WireMock.urlEqualTo("/v2/library/refresh-token/blobs/sha256:one")) + .inScenario("get token") + .whenScenarioStateIs("get") + .willReturn(WireMock.ok().withBody("blob-data"))); + + // Insecure registry + Registry registry = Registry.Builder.builder() + .withAuthProvider(new BearerTokenProvider(authProvider)) // Already bearer token + .withInsecure(true) + .build(); + + ContainerRef containerRef = + ContainerRef.parse("localhost:%d/library/refresh-token".formatted(wmRuntimeInfo.getHttpPort())); + byte[] blob = registry.getBlob(containerRef.withDigest("sha256:one")); + assertEquals("blob-data", new String(blob)); + } } diff --git a/src/test/java/land/oras/auth/BearerTokenProviderTest.java b/src/test/java/land/oras/auth/BearerTokenProviderTest.java index 0cbc496f..b6618130 100644 --- a/src/test/java/land/oras/auth/BearerTokenProviderTest.java +++ b/src/test/java/land/oras/auth/BearerTokenProviderTest.java @@ -21,6 +21,7 @@ package land.oras.auth; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.Mockito.doReturn; @@ -54,7 +55,7 @@ public class BearerTokenProviderTest { @Test @SuppressWarnings({"unchecked", "rawtypes"}) - void shouldTestToken(WireMockRuntimeInfo wmRuntimeInfo) { + void shouldRefreshToken(WireMockRuntimeInfo wmRuntimeInfo) { // Mock responses OrasHttpClient.ResponseWrapper mockResponse = Mockito.mock(OrasHttpClient.ResponseWrapper.class); @@ -75,7 +76,43 @@ void shouldTestToken(WireMockRuntimeInfo wmRuntimeInfo) { // Test BearerTokenProvider provider = new BearerTokenProvider(new UsernamePasswordProvider("user", "password")); - provider.refreshToken(containerRef, mockResponse); + provider.refreshToken(containerRef, OrasHttpClient.Builder.builder().build(), mockResponse); + BearerTokenProvider.TokenResponse token = provider.getToken(); + + // Assert tokens + assertEquals("fake-token", token.token()); + assertEquals("access-token", token.access_token()); + assertEquals(300, token.expire_in()); + assertEquals(tokenResponse.issued_at(), token.issued_at()); + + // Check the token header is set + assertEquals("Bearer fake-token", provider.getAuthHeader(containerRef)); + } + + @Test + @SuppressWarnings({"unchecked", "rawtypes"}) + void shouldTestTokenWithNoAuth(WireMockRuntimeInfo wmRuntimeInfo) { + + // Mock responses + OrasHttpClient.ResponseWrapper mockResponse = Mockito.mock(OrasHttpClient.ResponseWrapper.class); + BearerTokenProvider.TokenResponse tokenResponse = + new BearerTokenProvider.TokenResponse("fake-token", "access-token", 300, ZonedDateTime.now()); + WireMock wireMock = wmRuntimeInfo.getWireMock(); + wireMock.register(WireMock.get(WireMock.urlEqualTo( + "/token?scope=repository:library/test:pull&service=%s".formatted(registry.getRegistry()))) + .willReturn(WireMock.okJson(JsonUtils.toJson(tokenResponse)))); + + // Return WWW-Authenticate header from registry + Mockito.when(mockResponse.headers()) + .thenReturn(Map.of( + Const.WWW_AUTHENTICATE_HEADER.toLowerCase(), + String.format( + "Bearer realm=\"%s/token\",service=\"%s\",scope=\"repository:library/test:pull\"", + wmRuntimeInfo.getHttpBaseUrl(), registry.getRegistry()))); + + // Test + BearerTokenProvider provider = new BearerTokenProvider(new NoAuthProvider()); + provider.refreshToken(containerRef, OrasHttpClient.Builder.builder().build(), mockResponse); BearerTokenProvider.TokenResponse token = provider.getToken(); // Assert tokens @@ -91,9 +128,7 @@ void shouldTestToken(WireMockRuntimeInfo wmRuntimeInfo) { @Test void testNoRefreshedToken() { BearerTokenProvider provider = new BearerTokenProvider(new UsernamePasswordProvider("user", "password")); - assertThrows(OrasException.class, () -> { - provider.getAuthHeader(containerRef); - }); + assertNull(provider.getAuthHeader(containerRef), "No token should be returned"); } @Test @@ -103,13 +138,13 @@ void testInvalidWwwAuthentication() { BearerTokenProvider provider = new BearerTokenProvider(new UsernamePasswordProvider("user", "password")); ContainerRef containerRef = ContainerRef.parse("localhost:5000/library/test:latest"); assertThrows(OrasException.class, () -> { - provider.refreshToken(containerRef, mockResponse); + provider.refreshToken(containerRef, OrasHttpClient.Builder.builder().build(), mockResponse); }); doReturn(Map.of(Const.WWW_AUTHENTICATE_HEADER.toLowerCase(), "invalid")) .when(mockResponse) .headers(); assertThrows(OrasException.class, () -> { - provider.refreshToken(containerRef, mockResponse); + provider.refreshToken(containerRef, OrasHttpClient.Builder.builder().build(), mockResponse); }); } @@ -126,7 +161,7 @@ void testWWWAuthenticateFormat(WireMockRuntimeInfo wmRuntimeInfo) { BearerTokenProvider provider = new BearerTokenProvider(new UsernamePasswordProvider("user", "password")); assertThrows(OrasException.class, () -> { - provider.refreshToken(containerRef, mockResponse); + provider.refreshToken(containerRef, OrasHttpClient.Builder.builder().build(), mockResponse); }); // Without error @@ -137,7 +172,7 @@ void testWWWAuthenticateFormat(WireMockRuntimeInfo wmRuntimeInfo) { .when(mockResponse) .headers(); // No exception should be thrown - provider.refreshToken(containerRef, mockResponse); + provider.refreshToken(containerRef, OrasHttpClient.Builder.builder().build(), mockResponse); // With error doReturn(Map.of( @@ -147,6 +182,6 @@ void testWWWAuthenticateFormat(WireMockRuntimeInfo wmRuntimeInfo) { .when(mockResponse) .headers(); // No exception should be thrown - provider.refreshToken(containerRef, mockResponse); + provider.refreshToken(containerRef, OrasHttpClient.Builder.builder().build(), mockResponse); } }