In [None]:
import pandas as pd, re, json, os, boto3, time
from webcrawler import Crawler, Phone, Laptop, Tablet, Watch, Screen, Earphones
from dataclasses import asdict
from datetime import datetime
from botocore.client import BaseClient
from botocore.exceptions import ClientError as AwsClientError

In [None]:
query = """ select screen_nits_main
            from c2dwh_silver.phones
        """

# Credentials come from a local JSON file; set the AWS_CDT_PATH environment
# variable to load them from somewhere else.
# NOTE: the key name "secrect_access_key" is read verbatim and must match the
# key actually stored in that file.
aws_credentials_path = os.environ.get(
    "AWS_CDT_PATH", "/home/jh97/MyWorks/Documents/.aws_cdt.json"
)
with open(aws_credentials_path, "r") as file:
    content = json.load(file)
key_id = content["access_key"]
secret_key = content["secrect_access_key"]

athena_client = boto3.client(
    "athena",
    region_name="us-east-1",
    aws_access_key_id=key_id,
    aws_secret_access_key=secret_key,
)

# Start the query; the following cells poll its status and page through results.
# To encrypt the result files, add e.g.
# "EncryptionConfiguration": {"EncryptionOption": "SSE_S3"} to ResultConfiguration.
resp = athena_client.start_query_execution(
    QueryString=query,
    QueryExecutionContext={"Database": "c2dwh_silver", "Catalog": "AwsDataCatalog"},
    ResultConfiguration={"OutputLocation": "s3://c2dwh-athena-queries/"},
)

# Execution id used by the status / results cells below.
resp["QueryExecutionId"]

In [None]:
# Check execution status.
# "AthenaError" is optional in the Status dict (it describes a failure), so it is
# read with .get(): indexing it raises KeyError for a query that has not failed.
resp2 = athena_client.get_query_execution(QueryExecutionId=resp["QueryExecutionId"])
print(resp2["QueryExecution"]["Status"]["State"])
print(resp2["QueryExecution"]["Status"].get("AthenaError"))

In [None]:
# Page through the query results and collect each row as a dict keyed by column.
# The column-header row is only the first row of the FIRST page; later pages hold
# data rows only, so the header must be consumed exactly once.
prd = []
paginator = athena_client.get_paginator("get_query_results")
columns = None

for page in paginator.paginate(QueryExecutionId=resp["QueryExecutionId"]):
    page_rows = page["ResultSet"]["Rows"]
    if columns is None and page_rows:
        columns = [cell.get("VarCharValue") for cell in page_rows[0]["Data"]]
        page_rows = page_rows[1:]
    for row in page_rows:
        # a cell without "VarCharValue" becomes None
        prd.append(
            {col: cell.get("VarCharValue") for col, cell in zip(columns, row["Data"])}
        )

# preview instead of dumping every row
pd.DataFrame(prd).head()

In [None]:
def _wait_for_athena_query(client, execution_id: str, poll_interval: float, timeout: float | None) -> dict:
    """Poll an Athena execution until it is SUCCEEDED, FAILED or CANCELLED.

    Returns the final ``QueryExecution.Status`` dict. Raises ``TimeoutError`` if
    ``timeout`` seconds elapse first; ``timeout=None`` waits indefinitely.
    """
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        status = client.get_query_execution(QueryExecutionId=execution_id)[
            "QueryExecution"
        ]["Status"]
        if status["State"] in ("SUCCEEDED", "FAILED", "CANCELLED"):
            return status
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"Athena query {execution_id} still {status['State']} after {timeout}s"
            )
        time.sleep(poll_interval)


def _collect_athena_rows(client, execution_id: str) -> list:
    """Fetch all result pages and return the rows as dicts keyed by column name.

    The column-header row is only the first row of the first page, so it is
    consumed once; later pages contain data rows only. Each value is the cell's
    ``VarCharValue`` string, or None when the cell has no such key.
    """
    records = []
    columns = None
    paginator = client.get_paginator("get_query_results")
    for page in paginator.paginate(QueryExecutionId=execution_id):
        rows = page["ResultSet"]["Rows"]
        if columns is None:
            if not rows:
                # no header available yet (empty page) -> nothing to map
                continue
            columns = [cell.get("VarCharValue") for cell in rows[0]["Data"]]
            rows = rows[1:]
        records.extend(
            {col: cell.get("VarCharValue") for col, cell in zip(columns, row["Data"])}
            for row in rows
        )
    return records


def athena_sql_executor(
    query: str,
    *,
    client: BaseClient | None = None,
    database: str | None = None,
    output_location: str | None = None,
    encrypt_config: dict | None = None,
    poll_interval: float = 0.2,
    timeout: float | None = None,
):
    """
    Execute SQL query on AWS Athena and return its result rows.

    Args:
        query: SQL text to execute.
        client: Athena client to use. When omitted, a client for region
            us-east-1 is built from the JSON credentials file (path taken from
            the AWS_CDT_PATH environment variable, else the default local path).
            The file is only read in that case.
        database: Database used in the query execution context (default "default").
        output_location: S3 URI for query result files
            (default "s3://c2dwh-athena-queries/").
        encrypt_config: Athena ``EncryptionConfiguration`` dict
            (default ``{"EncryptionOption": "SSE_S3"}``).
        poll_interval: Seconds to sleep between status checks.
        timeout: Maximum seconds to wait for completion; None waits indefinitely.

    Returns:
        ``{"query_execution_id": str, "data": list[dict]}``, where each dict maps
        a column name to the cell's string value (None if absent). Returns None
        if the query FAILED or was CANCELLED, after printing the error message.

    Raises:
        TimeoutError: if ``timeout`` elapses before the query reaches a final state.
    """
    # initialize client (credentials are only needed when we build one ourselves)
    if client is None:
        credentials_path = os.environ.get(
            "AWS_CDT_PATH", "/home/jh97/MyWorks/Documents/.aws_cdt.json"
        )
        with open(credentials_path, "r") as file:
            content = json.load(file)
        client = boto3.client(
            "athena",
            region_name="us-east-1",
            aws_access_key_id=content["access_key"],
            # key spelling must match the credentials file
            aws_secret_access_key=content["secrect_access_key"],
        )

    # start the query; Athena runs it asynchronously
    resp = client.start_query_execution(
        QueryString=query,
        QueryExecutionContext={
            "Database": database or "default",
            "Catalog": "AwsDataCatalog",
        },
        ResultConfiguration={
            "OutputLocation": output_location or "s3://c2dwh-athena-queries/",
            "EncryptionConfiguration": encrypt_config or {"EncryptionOption": "SSE_S3"},
        },
    )
    execution_id = resp["QueryExecutionId"]

    status = _wait_for_athena_query(client, execution_id, poll_interval, timeout)
    if status["State"] != "SUCCEEDED":
        print(
            f'Execution {status["State"]} with error >>',
            status.get("AthenaError", {}).get("ErrorMessage"),
        )
        return None

    return {
        "query_execution_id": execution_id,
        "data": _collect_athena_rows(client, execution_id),
    }

In [None]:
# Smoke-run the executor against the silver layer (first 10 phone rows).
test = athena_sql_executor(
    "select* from phones limit 10;",
    database="c2dwh_silver",
)