diff --git a/src/cloudfront/index.spec.ts b/src/cloudfront/index.spec.ts index e6b5515..2cc70b5 100644 --- a/src/cloudfront/index.spec.ts +++ b/src/cloudfront/index.spec.ts @@ -21,7 +21,7 @@ describe("cloudfront", () => { const listDistributionMock = jest.spyOn(cloudfront, "listDistributions"); const listTagsForResourceMock = jest.spyOn( cloudfront, - "listTagsForResource" + "listTagsForResource", ); const waitForMock = jest.spyOn(cloudfront, "waitFor"); @@ -46,7 +46,7 @@ describe("cloudfront", () => { }, ], }, - }) + }), ) .mockReturnValueOnce( awsResolve({ @@ -61,7 +61,7 @@ describe("cloudfront", () => { }, ], }, - }) + }), ); listTagsForResourceMock.mockReturnValue( @@ -69,7 +69,7 @@ describe("cloudfront", () => { Tags: { Items: [identifyingTag], }, - }) + }), ); const distribution: any = @@ -92,7 +92,7 @@ describe("cloudfront", () => { }, ], }, - }) + }), ); listTagsForResourceMock.mockReturnValue( @@ -100,7 +100,7 @@ describe("cloudfront", () => { Tags: { Items: [identifyingTag], }, - }) + }), ); waitForMock.mockReturnValue(awsResolve()); @@ -126,7 +126,7 @@ describe("cloudfront", () => { const invalidationParams: any = createInvalidationMock.mock.calls[0][0]; expect(invalidationParams.DistributionId).toEqual("some-distribution-id"); expect(invalidationParams.InvalidationBatch.Paths.Items[0]).toEqual( - "index.html" + "index.html", ); }); @@ -134,7 +134,7 @@ describe("cloudfront", () => { createInvalidationMock.mockReturnValue(awsResolve({ Invalidation: {} })); await invalidateCloudfrontCache( "some-distribution-id", - "index.html, static/*" + "index.html, static/*", ); expect(createInvalidationMock).toHaveBeenCalledTimes(1); @@ -152,7 +152,7 @@ describe("cloudfront", () => { await invalidateCloudfrontCache( "some-distribution-id", "index.html", - true + true, ); expect(waitForMock).toHaveBeenCalledTimes(1); expect(waitForMock.mock.calls[0][0]).toEqual("invalidationCompleted"); @@ -174,7 +174,7 @@ describe("cloudfront", () => { await invalidateCloudfrontCacheWithRetry( "some-distribution-id", - "index.html, static/*" + "index.html, static/*", ); expect(createInvalidationMock).toHaveBeenCalledTimes(2); @@ -194,7 +194,7 @@ describe("cloudfront", () => { try { await invalidateCloudfrontCacheWithRetry( "some-distribution-id", - "index.html, static/*" + "index.html, static/*", ); } catch (error) { expect(error).toBeDefined(); @@ -218,13 +218,13 @@ describe("cloudfront", () => { it("should create a distribution and wait for it to be available", async () => { const distribution = { Id: "distribution-id" }; createDistributionMock.mockReturnValue( - awsResolve({ Distribution: distribution }) + awsResolve({ Distribution: distribution }), ); tagResourceMock.mockReturnValue(awsResolve()); waitForMock.mockReturnValue(awsResolve()); const result = await createCloudFrontDistribution( "hello.lalilo.com", - "arn:certificate" + "arn:certificate", ); expect(result).toBe(distribution); expect(tagResourceMock).toHaveBeenCalledTimes(1); @@ -232,15 +232,15 @@ describe("cloudfront", () => { const distributionParam: any = createDistributionMock.mock.calls[0][0]; const distributionConfig = distributionParam.DistributionConfig; expect(distributionConfig.Origins.Items[0].DomainName).toEqual( - "hello.lalilo.com.s3-website.eu-west-3.amazonaws.com" + "hello.lalilo.com.s3-website.eu-west-3.amazonaws.com", ); expect( - distributionConfig.DefaultCacheBehavior.ViewerProtocolPolicy + distributionConfig.DefaultCacheBehavior.ViewerProtocolPolicy, ).toEqual("redirect-to-https"); expect(distributionConfig.DefaultCacheBehavior.MinTTL).toEqual(0); expect(distributionConfig.DefaultCacheBehavior.Compress).toEqual(true); expect(distributionConfig.ViewerCertificate.ACMCertificateArn).toEqual( - "arn:certificate" + "arn:certificate", ); expect(waitForMock).toHaveBeenCalledTimes(1); @@ -253,7 +253,7 @@ describe("cloudfront", () => { describe("setSimpleAuthBehavior", () => { const getDistributionConfig = jest.spyOn( cloudfront, - "getDistributionConfig" + "getDistributionConfig", ); const updateDistribution = jest.spyOn(cloudfront, "updateDistribution"); @@ -273,7 +273,7 @@ describe("cloudfront", () => { }, }, ETag: "", - }) + }), ); await setSimpleAuthBehavior("distribution-id", null); expect(updateDistribution).not.toHaveBeenCalled(); @@ -290,14 +290,14 @@ describe("cloudfront", () => { }, }, ETag: "", - }) + }), ); updateDistribution.mockReturnValueOnce(awsResolve()); await setSimpleAuthBehavior("distribution-id", null); expect(updateDistribution).toHaveBeenCalledTimes(1); expect( (updateDistribution.mock.calls[0][0] as any).DistributionConfig - .DefaultCacheBehavior.LambdaFunctionAssociations.Items + .DefaultCacheBehavior.LambdaFunctionAssociations.Items, ).toEqual([]); }); @@ -312,11 +312,11 @@ describe("cloudfront", () => { }, }, ETag: "", - }) + }), ); await setSimpleAuthBehavior( "distribution-id", - `some-arn:${lambdaPrefix}:1` + `some-arn:${lambdaPrefix}:1`, ); expect(updateDistribution).not.toHaveBeenCalled(); }); @@ -332,14 +332,14 @@ describe("cloudfront", () => { }, }, ETag: "", - }) + }), ); updateDistribution.mockReturnValueOnce(awsResolve()); await setSimpleAuthBehavior("distribution-id", "some-arn:1"); expect(updateDistribution).toHaveBeenCalledTimes(1); expect( (updateDistribution.mock.calls[0][0] as any).DistributionConfig - .DefaultCacheBehavior.LambdaFunctionAssociations.Items + .DefaultCacheBehavior.LambdaFunctionAssociations.Items, ).toEqual([ { EventType: "viewer-request", @@ -372,7 +372,7 @@ describe("cloudfront", () => { describe("updateCloudFrontDistribution", () => { const getDistributionConfigMock = jest.spyOn( cloudfront, - "getDistributionConfig" + "getDistributionConfig", ); const updateDistribution = jest.spyOn(cloudfront, "updateDistribution"); @@ -387,7 +387,7 @@ describe("cloudfront", () => { }, { shouldBlockBucketPublicAccess: false }, ])( - "should not update the distribution if the right origin is already associated", + `should not update the distribution if the right origin is already associated %p`, async ({ shouldBlockBucketPublicAccess }) => { const domainName = "hello.lalilo.com"; const originId = shouldBlockBucketPublicAccess @@ -403,22 +403,19 @@ describe("cloudfront", () => { }; getDistributionConfigMock.mockReturnValue( - awsResolve({ DistributionConfig: distribution }) + awsResolve({ DistributionConfig: distribution }), ); - await updateCloudFrontDistribution( - distribution.Id, - domainName, + await updateCloudFrontDistribution(distribution.Id, domainName, { shouldBlockBucketPublicAccess, - null - ); + oac: null, + }); expect(updateDistribution).not.toHaveBeenCalled(); - } + }, ); it("should update the distribution with an OAC when shouldBlockBucketPublicAccess and oac is given", async () => { - const shouldBlockBucketPublicAccess = true; const domainName = "hello.lalilo.com"; const originIdForPrivateBucket = getS3DomainNameForBlockedBucket(domainName); @@ -433,16 +430,14 @@ describe("cloudfront", () => { }; getDistributionConfigMock.mockReturnValue( - awsResolve({ DistributionConfig: distribution }) + awsResolve({ DistributionConfig: distribution }), ); updateDistribution.mockReturnValueOnce(awsResolve()); - await updateCloudFrontDistribution( - distribution.Id, - domainName, - shouldBlockBucketPublicAccess, - oac - ); + await updateCloudFrontDistribution(distribution.Id, domainName, { + shouldBlockBucketPublicAccess: true, + oac, + }); expect(updateDistribution).toHaveBeenCalled(); expect(updateDistribution).toHaveBeenCalledWith( @@ -464,7 +459,7 @@ describe("cloudfront", () => { TargetOriginId: originIdForPrivateBucket, }), }), - }) + }), ); }); }); diff --git a/src/cloudfront/index.ts b/src/cloudfront/index.ts index 7e6fdd2..5aa3e2f 100644 --- a/src/cloudfront/index.ts +++ b/src/cloudfront/index.ts @@ -416,9 +416,12 @@ const updateLambdaFunctionAssociations = async ( export const updateCloudFrontDistribution = async ( distributionId: string, domainName: string, - shouldBlockBucketPublicAccess: boolean, - oac: OAC | null, + options: { + shouldBlockBucketPublicAccess: boolean; + oac: OAC | null; + }, ) => { + const { shouldBlockBucketPublicAccess, oac } = options; try { const { DistributionConfig, ETag } = await cloudfront .getDistributionConfig({ Id: distributionId }) diff --git a/src/cloudfront/origin-access.spec.ts b/src/cloudfront/origin-access.spec.ts index 2983904..64a4ccf 100644 --- a/src/cloudfront/origin-access.spec.ts +++ b/src/cloudfront/origin-access.spec.ts @@ -8,16 +8,16 @@ import { describe("upsertOriginAccessControl", () => { const listOriginAccessControlsMock = jest.spyOn( cloudfront, - "listOriginAccessControls" + "listOriginAccessControls", ); const getOriginAccessControlMock = jest.spyOn( cloudfront, - "getOriginAccessControl" + "getOriginAccessControl", ); const createOriginAccessControlMock = jest.spyOn( cloudfront, - "createOriginAccessControl" + "createOriginAccessControl", ); beforeEach(() => { @@ -26,9 +26,8 @@ describe("upsertOriginAccessControl", () => { createOriginAccessControlMock.mockReset(); }); - it("does not create OAC if already existing and required (shouldBlockBucketPublicAccess is true)", async () => { + it("does not create OAC if already existing", async () => { const domainName = "my-domain"; - const shouldBlockBucketPublicAccess = true; const distributionId = "my-distribution-id"; const oacName = getOriginAccessControlName(domainName, distributionId); @@ -42,52 +41,30 @@ describe("upsertOriginAccessControl", () => { }, ], }, - }) + }), ); getOriginAccessControlMock.mockReturnValue( - awsResolve({ OriginAccessControl: {}, ETag: "my-etag" }) + awsResolve({ OriginAccessControl: {}, ETag: "my-etag" }), ); - await upsertOriginAccessControl( - domainName, - distributionId, - shouldBlockBucketPublicAccess - ); - - expect(createOriginAccessControlMock).not.toHaveBeenCalled(); - }); + await upsertOriginAccessControl(domainName, distributionId); - it("does not create OAC if not necessary (shouldBlockBucketPublicAccess is false)", async () => { - const domainName = "my-domain"; - const shouldBlockBucketPublicAccess = false; - - await upsertOriginAccessControl( - domainName, - "my-distribution-id", - shouldBlockBucketPublicAccess - ); - expect(listOriginAccessControlsMock).not.toHaveBeenCalled(); expect(createOriginAccessControlMock).not.toHaveBeenCalled(); }); it("creates OAC if necessary and required (shouldBlockBucketPublicAccess is true)", async () => { const domainName = "my-domain"; const distributionId = "my-distribution-id"; - const shouldBlockBucketPublicAccess = true; const oacName = getOriginAccessControlName(domainName, distributionId); listOriginAccessControlsMock.mockReturnValue( awsResolve({ OriginAccessControlList: {}, - }) + }), ); createOriginAccessControlMock.mockReturnValue(awsResolve({})); - await upsertOriginAccessControl( - domainName, - distributionId, - shouldBlockBucketPublicAccess - ); + await upsertOriginAccessControl(domainName, distributionId); expect(listOriginAccessControlsMock).toHaveBeenCalled(); expect(createOriginAccessControlMock).toHaveBeenCalledTimes(1); @@ -100,7 +77,7 @@ describe("upsertOriginAccessControl", () => { SigningProtocol: "sigv4", Description: `OAC used by ${domainName} associated to distributionId: ${distributionId}`, }, - }) + }), ); }); }); diff --git a/src/cloudfront/origin-access.ts b/src/cloudfront/origin-access.ts index ce9daa8..c56cb93 100644 --- a/src/cloudfront/origin-access.ts +++ b/src/cloudfront/origin-access.ts @@ -63,11 +63,7 @@ export const getOriginAccessControlName = ( export const upsertOriginAccessControl = async ( domainName: string, distributionId: string, - shouldBlockBucketPublicAccess: boolean, ) => { - if (!shouldBlockBucketPublicAccess) { - return null; - } const originAccessControlName = getOriginAccessControlName( domainName, distributionId, @@ -80,11 +76,23 @@ export const upsertOriginAccessControl = async ( return await createOAC(originAccessControlName, domainName, distributionId); }; -export const deleteOriginAccessControl = async (oac: OAC) => { +export const cleanExistingOriginAccessControl = async ( + domainName: string, + distributionId: string, +) => { + const originAccessControlName = getOriginAccessControlName( + domainName, + distributionId, + ); + const existingOAC = await getExistingOAC(originAccessControlName); + if (existingOAC === null) { + return; + } + await cloudfront .deleteOriginAccessControl({ - Id: oac.originAccessControl.Id, - IfMatch: oac.ETag, + Id: existingOAC.originAccessControl.Id, + IfMatch: existingOAC.ETag, }) .promise(); return; diff --git a/src/deploy.ts b/src/deploy.ts index 5165a80..bf8282e 100644 --- a/src/deploy.ts +++ b/src/deploy.ts @@ -10,7 +10,7 @@ import { updateCloudFrontDistribution, } from "./cloudfront"; import { - deleteOriginAccessControl, + cleanExistingOriginAccessControl, upsertOriginAccessControl, } from "./cloudfront/origin-access"; import { deploySimpleAuthLambda } from "./lambda"; @@ -94,31 +94,24 @@ export const deploy = async ( ); } - const oac = await upsertOriginAccessControl( - domainName, - distribution.Id, - shouldBlockBucketPublicAccess, - ); - - await updateCloudFrontDistribution( - distribution.Id, - domainName, - shouldBlockBucketPublicAccess, - oac, - ); - if (shouldBlockBucketPublicAccess) { + const oac = await upsertOriginAccessControl(domainName, distribution.Id); + await updateCloudFrontDistribution(distribution.Id, domainName, { + shouldBlockBucketPublicAccess: true, + oac, + }); await removeBucketWebsite(domainName); await blockBucketPublicAccess(domainName); await setBucketPolicyForOAC(domainName, distribution.Id); } else { + await updateCloudFrontDistribution(distribution.Id, domainName, { + shouldBlockBucketPublicAccess: false, + oac: null, + }); await setBucketWebsite(domainName); await allowBucketPublicAccess(domainName); await setBucketPolicy(domainName); - - if (oac) { - await deleteOriginAccessControl(oac); - } + await cleanExistingOriginAccessControl(domainName, distribution.Id); } if (credentials) {