Skip to content

Commit

Permalink
Added Support for performing Assume role with Athena Connections (#2471)
Browse files Browse the repository at this point in the history
Co-authored-by: Tom Thornton <tom.thornton@sony.com>
  • Loading branch information
thatguyfig and Tom Thornton committed May 13, 2024
1 parent 8e56283 commit b7a759b
Show file tree
Hide file tree
Showing 5 changed files with 488 additions and 57 deletions.
2 changes: 2 additions & 0 deletions packages/back-end/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
"generate-api-types": "yarn generate-api-models && swagger-cli bundle -t yaml src/api/openapi/openapi.tmp.yaml -o generated/spec.yaml && node src/scripts/generate-openapi.mjs"
},
"dependencies": {
"@aws-sdk/client-sts": "^3.567.0",
"@aws-sdk/client-athena": "^3.564.0",
"@clickhouse/client": "^1.0.1",
"@databricks/sql": "^1.8.1",
"@dqbd/tiktoken": "^1.0.7",
Expand Down
88 changes: 57 additions & 31 deletions packages/back-end/src/services/athena.ts
Original file line number Diff line number Diff line change
@@ -1,20 +1,52 @@
import { Athena } from "aws-sdk";
import { ResultSet } from "aws-sdk/clients/athena";
import { STSClient, AssumeRoleCommand } from "@aws-sdk/client-sts";
import { Athena, ResultSet } from "@aws-sdk/client-athena";
import { AthenaConnectionParams } from "../../types/integrations/athena";
import { logger } from "../util/logger";
import { IS_CLOUD } from "../util/secrets";
import { ExternalIdCallback, QueryResponse } from "../types/Integration";

function getAthenaInstance(params: AthenaConnectionParams) {
async function assumeRole(params: AthenaConnectionParams) {
// build sts client
const client = new STSClient();
const command = new AssumeRoleCommand({
RoleArn: params.assumeRoleARN,
RoleSessionName: params.roleSessionName,
ExternalId: params.externalId,
DurationSeconds: params.durationSeconds,
});

return await client.send(command);
}

async function getAthenaInstance(params: AthenaConnectionParams) {
// handle the instance profile
if (!IS_CLOUD && params.authType === "auto") {
return new Athena({
region: params.region,
});
}

// handle assuming a role first
if (!IS_CLOUD && params.authType === "assumeRole") {
// use client to assume another role
const credentials = await assumeRole(params);

return new Athena({
credentials: {
accessKeyId: credentials?.Credentials?.AccessKeyId || "",
secretAccessKey: credentials?.Credentials?.SecretAccessKey || "",
sessionToken: credentials?.Credentials?.SessionToken || "",
},
region: params.region,
});
}

// handle access key + secret key
return new Athena({
accessKeyId: params.accessKeyId,
secretAccessKey: params.secretAccessKey,
credentials: {
accessKeyId: params.accessKeyId || "",
secretAccessKey: params.secretAccessKey || "",
},
region: params.region,
});
}
Expand All @@ -23,42 +55,38 @@ export async function cancelAthenaQuery(
conn: AthenaConnectionParams,
id: string
) {
const athena = getAthenaInstance(conn);
await athena
.stopQueryExecution({
QueryExecutionId: id,
})
.promise();
const athena = await getAthenaInstance(conn);
await athena.stopQueryExecution({
QueryExecutionId: id,
});
}

export async function runAthenaQuery(
conn: AthenaConnectionParams,
sql: string,
setExternalId: ExternalIdCallback
): Promise<QueryResponse> {
const athena = getAthenaInstance(conn);
const athena = await getAthenaInstance(conn);

const { database, bucketUri, workGroup, catalog } = conn;

const retryWaitTime =
(parseInt(process.env.ATHENA_RETRY_WAIT_TIME || "60") || 60) * 1000;

const { QueryExecutionId } = await athena
.startQueryExecution({
QueryString: sql,
QueryExecutionContext: {
Database: database || undefined,
Catalog: catalog || undefined,
},
ResultConfiguration: {
EncryptionConfiguration: {
EncryptionOption: "SSE_S3",
},
OutputLocation: bucketUri,
const { QueryExecutionId } = await athena.startQueryExecution({
QueryString: sql,
QueryExecutionContext: {
Database: database || undefined,
Catalog: catalog || undefined,
},
ResultConfiguration: {
EncryptionConfiguration: {
EncryptionOption: "SSE_S3",
},
WorkGroup: workGroup || "primary",
})
.promise();
OutputLocation: bucketUri,
},
WorkGroup: workGroup || "primary",
});

if (!QueryExecutionId) {
throw new Error("Failed to start query");
Expand All @@ -74,7 +102,6 @@ export async function runAthenaQuery(
setTimeout(() => {
athena
.getQueryExecution({ QueryExecutionId })
.promise()
.then((resp) => {
const State = resp.QueryExecution?.Status?.State;
const StateChangeReason =
Expand Down Expand Up @@ -118,7 +145,6 @@ export async function runAthenaQuery(
} else {
athena
.getQueryResults({ QueryExecutionId })
.promise()
.then(({ ResultSet }) => {
if (ResultSet) {
resolve(ResultSet);
Expand Down Expand Up @@ -152,7 +178,7 @@ export async function runAthenaQuery(
const obj: any = {};
if (row.Data) {
row.Data.forEach((value, i) => {
obj[keys[i]] = value.VarCharValue || null;
obj[keys[i] as string] = value.VarCharValue || null;
});
}
return obj;
Expand All @@ -162,6 +188,6 @@ export async function runAthenaQuery(
}

// Cancel the query if it reaches this point
await athena.stopQueryExecution({ QueryExecutionId }).promise();
await athena.stopQueryExecution({ QueryExecutionId });
throw new Error("Query timed out after 30 minutes");
}
6 changes: 5 additions & 1 deletion packages/back-end/types/integrations/athena.d.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
export interface AthenaConnectionParams {
authType?: "auto" | "accessKey";
authType?: "auto" | "accessKey" | "assumeRole";
accessKeyId?: string;
secretAccessKey?: string;
assumeRoleARN?: string;
roleSessionName?: string;
durationSeconds?: number;
externalId?: string;
region: string;
database?: string;
bucketUri: string;
Expand Down
107 changes: 82 additions & 25 deletions packages/front-end/components/Settings/AthenaForm.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,12 @@ const AthenaForm: FC<{
value: "auto",
display: "Auto-discovery",
},
{
value: "assumeRole",
display: "Assume IAM Role",
},
]}
helpText="'Auto-discovery' will look for credentials in environment variables and instance metadata."
helpText="'Auto-discovery' will look for credentials in environment variables and instance metadata. 'Assume IAM Role' uses the current role to assume another role and execute Athena with temporary credentials."
value={params.authType || "accessKey"}
onChange={(e) => {
setParams({
Expand All @@ -35,29 +39,8 @@ const AthenaForm: FC<{
/>
</div>
)}
<div className="form-group col-md-12">
<label>AWS Region</label>
<input
type="text"
className="form-control"
name="region"
required
value={params.region || ""}
onChange={onParamChange}
/>
</div>
<div className="form-group col-md-12">
<label>Workgroup (optional)</label>
<input
type="text"
className="form-control"
name="workGroup"
placeholder="primary"
value={params.workGroup || ""}
onChange={onParamChange}
/>
</div>
{(isCloud() || params.authType !== "auto") && (
{(isCloud() ||
(params.authType !== "assumeRole" && params.authType !== "auto")) && (
<>
<div className="form-group col-md-12">
<label>AWS Access Key</label>
Expand Down Expand Up @@ -86,13 +69,87 @@ const AthenaForm: FC<{
</div>
</>
)}
{!isCloud() && params.authType === "assumeRole" && (
<>
<div className="form-group col-md-12">
<label>AWS IAM Role ARN</label>
<input
type="text"
className="form-control"
name="assumeRoleARN"
required={!existing}
value={params.assumeRoleARN || ""}
onChange={onParamChange}
placeholder={existing ? "(Keep existing)" : ""}
/>
</div>
<div className="form-group col-md-12">
<label>Role Session Name</label>
<input
type="text"
className="form-control"
name="roleSessionName"
required={!existing}
value={params.roleSessionName || ""}
onChange={onParamChange}
placeholder={existing ? "(Keep existing)" : ""}
/>
</div>
<div className="form-group col-md-12">
<label>External ID</label>
<input
type="text"
className="form-control"
name="externalId"
required={!existing}
value={params.externalId || ""}
onChange={onParamChange}
placeholder={existing ? "(Keep existing)" : ""}
/>
</div>
<div className="form-group col-md-12">
<label>Session Duration</label>
<input
type="number"
className="form-control"
name="durationSeconds"
required={!existing}
value={params.durationSeconds || 900}
onChange={onParamChange}
placeholder={existing ? "(Keep existing)" : ""}
/>
</div>
</>
)}
<div className="form-group col-md-12">
<label>AWS Region</label>
<input
type="text"
className="form-control"
name="region"
required
value={params.region || ""}
onChange={onParamChange}
/>
</div>
<div className="form-group col-md-12">
<label>Workgroup (optional)</label>
<input
type="text"
className="form-control"
name="workGroup"
placeholder="primary"
value={params.workGroup || ""}
onChange={onParamChange}
/>
</div>
<div className="form-group col-md-12">
<label>Default Catalog (optional)</label>
<input
type="text"
className="form-control"
name="catalog"
value={params.catalog || ""}
value={params.catalog || "AwsDataCatalog"}
onChange={onParamChange}
/>
</div>
Expand Down

0 comments on commit b7a759b

Please sign in to comment.