diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/KASClient.java b/sdk/src/main/java/io/opentdf/platform/sdk/KASClient.java index 727eda83..806d48d0 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/KASClient.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/KASClient.java @@ -9,13 +9,17 @@ import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.SignedJWT; import io.grpc.ManagedChannel; +import io.grpc.StatusRuntimeException; +import io.grpc.Status; import io.opentdf.platform.kas.AccessServiceGrpc; import io.opentdf.platform.kas.PublicKeyRequest; import io.opentdf.platform.kas.PublicKeyResponse; import io.opentdf.platform.kas.RewrapRequest; +import io.opentdf.platform.kas.RewrapResponse; import io.opentdf.platform.sdk.Config.KASInfo; import io.opentdf.platform.sdk.nanotdf.ECKeyPair; import io.opentdf.platform.sdk.nanotdf.NanoTDFType; +import io.opentdf.platform.sdk.TDF.KasBadRequestException; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; @@ -182,9 +186,19 @@ public byte[] unwrap(Manifest.KeyAccess keyAccess, String policy) { .newBuilder() .setSignedRequestToken(jwt.serialize()) .build(); - var response = getStub(keyAccess.url).rewrap(request); - var wrappedKey = response.getEntityWrappedKey().toByteArray(); - return decryptor.decrypt(wrappedKey); + RewrapResponse response; + try { + response = getStub(keyAccess.url).rewrap(request); + var wrappedKey = response.getEntityWrappedKey().toByteArray(); + return decryptor.decrypt(wrappedKey); + } catch (StatusRuntimeException e) { + if (e.getStatus().getCode() == Status.Code.INVALID_ARGUMENT) { + // 400 Bad Request + throw new KasBadRequestException("rewrap request 400: " + e.toString()); + } + throw e; + } + } public byte[] unwrapNanoTDF(NanoTDFType.ECCurve curve, String header, String kasURL) { diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/Manifest.java b/sdk/src/main/java/io/opentdf/platform/sdk/Manifest.java index bc4c4c3c..aca6d1cd 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/Manifest.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/Manifest.java @@ -10,6 +10,9 @@ import com.nimbusds.jose.crypto.RSASSAVerifier; import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.SignedJWT; + +import io.opentdf.platform.sdk.TDF.AssertionException; + import org.apache.commons.codec.binary.Hex; import org.erdtman.jcs.JsonCanonicalizer; @@ -381,7 +384,7 @@ public void sign(final HashValues hashValues, final AssertionConfig.AssertionKey public Assertion.HashValues verify(AssertionConfig.AssertionKey assertionKey) throws ParseException, JOSEException { if (binding == null) { - throw new SDKException("Binding is null in assertion"); + throw new AssertionException("Binding is null in assertion", this.id); } String signatureString = binding.signature; diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/TDF.java b/sdk/src/main/java/io/opentdf/platform/sdk/TDF.java index d694cea7..4c6eb6da 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/TDF.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/TDF.java @@ -119,30 +119,48 @@ public FailedToCreateGMAC(String errorMessage) { } } - public static class NotValidateRootSignature extends RuntimeException { - public NotValidateRootSignature(String errorMessage) { + public static class TDFReadFailed extends RuntimeException { + public TDFReadFailed(String errorMessage) { + super(errorMessage); + } + } + + public static class TamperException extends SDKException { + public TamperException(String errorMessage) { + super("[tamper detected] "+errorMessage); + } + } + + public static class RootSignatureValidationException extends TamperException { + public RootSignatureValidationException(String errorMessage) { super(errorMessage); } } - public static class SegmentSizeMismatch extends RuntimeException { + public static class SegmentSizeMismatch extends TamperException { public SegmentSizeMismatch(String errorMessage) { super(errorMessage); } } - public static class SegmentSignatureMismatch extends RuntimeException { + public static class SegmentSignatureMismatch extends TamperException { public SegmentSignatureMismatch(String errorMessage) { super(errorMessage); } } - public static class TDFReadFailed extends RuntimeException { - public TDFReadFailed(String errorMessage) { + public static class KasBadRequestException extends TamperException { + public KasBadRequestException(String errorMessage) { super(errorMessage); } } + public static class AssertionException extends TamperException { + public AssertionException(String errorMessage, String id) { + super("assertion id: "+ id + "; " + errorMessage); + } + } + public static class EncryptedMetadata { private String ciphertext; private String iv; @@ -558,7 +576,7 @@ public Reader loadTDF(SeekableByteChannel tdf, SDK.KAS kas) public Reader loadTDF(SeekableByteChannel tdf, SDK.KAS kas, Config.TDFReaderConfig tdfReaderConfig) - throws NotValidateRootSignature, SegmentSizeMismatch, + throws RootSignatureValidationException, SegmentSizeMismatch, IOException, FailedToCreateGMAC, JOSEException, ParseException, NoSuchAlgorithmException, DecoderException { TDFReader tdfReader = new TDFReader(tdf); @@ -666,7 +684,7 @@ public Reader loadTDF(SeekableByteChannel tdf, SDK.KAS kas, } if (rootSignature.compareTo(rootSigValue) != 0) { - throw new NotValidateRootSignature("root signature validation failed"); + throw new RootSignatureValidationException("root signature validation failed"); } int segmentSize = manifest.encryptionInformation.integrityInformation.segmentSizeDefault; @@ -701,11 +719,11 @@ public Reader loadTDF(SeekableByteChannel tdf, SDK.KAS kas, var encodeSignature = Base64.getEncoder().encodeToString(signature.getBytes()); if (!Objects.equals(hashOfAssertion, hashValues.getAssertionHash())) { - throw new SDKException("assertion hash mismatch"); + throw new AssertionException("assertion hash mismatch", assertion.id); } if (!Objects.equals(encodeSignature, hashValues.getSignature())) { - throw new SDKException("failed integrity check on assertion signature"); + throw new AssertionException("failed integrity check on assertion signature", assertion.id); } } diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/TDFTest.java b/sdk/src/test/java/io/opentdf/platform/sdk/TDFTest.java index 8afb5545..acd1fd15 100644 --- a/sdk/src/test/java/io/opentdf/platform/sdk/TDFTest.java +++ b/sdk/src/test/java/io/opentdf/platform/sdk/TDFTest.java @@ -7,6 +7,7 @@ import io.opentdf.platform.policy.attributes.GetAttributeValuesByFqnsResponse; import io.opentdf.platform.policy.attributes.AttributesServiceGrpc; import io.opentdf.platform.sdk.Config.KASInfo; +import io.opentdf.platform.sdk.TDF.Reader; import io.opentdf.platform.sdk.nanotdf.NanoTDFType; import org.apache.commons.compress.utils.SeekableInMemoryByteChannel; import org.junit.jupiter.api.BeforeAll; @@ -30,7 +31,6 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.mock; - import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -333,6 +333,62 @@ void testSimpleTDFWithAssertionWithHS256() throws Exception { } } + @Test + void testSimpleTDFWithAssertionWithHS256Failure() throws Exception { + + ListenableFuture resp1 = mock(ListenableFuture.class); + lenient().when(resp1.get()).thenReturn(GetAttributeValuesByFqnsResponse.newBuilder().build()); + lenient().when(attributeGrpcStub.getAttributeValuesByFqns(any(GetAttributeValuesByFqnsRequest.class))) + .thenReturn(resp1); + + // var keypair = CryptoUtils.generateRSAKeypair(); + SecureRandom secureRandom = new SecureRandom(); + byte[] key = new byte[32]; + secureRandom.nextBytes(key); + + String assertion1Id = "assertion1"; + var assertionConfig1 = new AssertionConfig(); + assertionConfig1.id = assertion1Id; + assertionConfig1.type = AssertionConfig.Type.BaseAssertion; + assertionConfig1.scope = AssertionConfig.Scope.TrustedDataObj; + assertionConfig1.appliesToState = AssertionConfig.AppliesToState.Unencrypted; + assertionConfig1.statement = new AssertionConfig.Statement(); + assertionConfig1.statement.format = "base64binary"; + assertionConfig1.statement.schema = "text"; + assertionConfig1.statement.value = "ICAgIDxlZGoOkVkaD4="; + assertionConfig1.assertionKey = new AssertionConfig.AssertionKey(AssertionConfig.AssertionKeyAlg.HS256, key); + + Config.TDFConfig config = Config.newTDFConfig( + Config.withAutoconfigure(false), + Config.withKasInformation(getKASInfos()), + Config.withAssertionConfig(assertionConfig1)); + + String plainText = "this is extremely sensitive stuff!!!"; + InputStream plainTextInputStream = new ByteArrayInputStream(plainText.getBytes()); + ByteArrayOutputStream tdfOutputStream = new ByteArrayOutputStream(); + + TDF tdf = new TDF(); + tdf.createTDF(plainTextInputStream, tdfOutputStream, config, kas, attributeGrpcStub); + + byte[] notkey = new byte[32]; + secureRandom.nextBytes(notkey); + var assertionVerificationKeys = new Config.AssertionVerificationKeys(); + assertionVerificationKeys.defaultKey = new AssertionConfig.AssertionKey(AssertionConfig.AssertionKeyAlg.HS256, + notkey); + Config.TDFReaderConfig readerConfig = Config.newTDFReaderConfig( + Config.withAssertionVerificationKeys(assertionVerificationKeys)); + + var unwrappedData = new ByteArrayOutputStream(); + Reader reader; + try { + reader = tdf.loadTDF(new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()), kas, readerConfig); + throw new RuntimeException("assertion verify key error thrown"); + + } catch (SDKException e) { + assertThat(e).hasMessageContaining("verify"); + } + } + @Test public void testCreatingTDFWithMultipleSegments() throws Exception {