In [None]:
import datetime
import os
import random

import numpy as np
import pyspark.sql.functions as F
from faker import Faker
from pyspark.sql import SparkSession

# Select the AWS profile before the Spark session is created (presumably picked
# up by the AWS SDK default credential chain used below — confirm).
os.environ["AWS_PROFILE"] = "blueriver"

CATALOG = "glue_catalog"
ICEBERG_S3_ROOT_PATH = "s3a://blueriver-datalake/iceberg"


def _build_spark_session():
    """Create (or reuse) a Spark session whose default catalog is Iceberg-on-Glue.

    The catalog named ``CATALOG`` is an Iceberg ``SparkCatalog`` backed by
    ``GlueCatalog`` with ``S3FileIO`` and warehouse ``ICEBERG_S3_ROOT_PATH``.
    The Iceberg SQL extensions are enabled, identifiers are case-sensitive and
    the session time zone is UTC.
    """
    catalog_key = f"spark.sql.catalog.{CATALOG}"
    options = [
        ("spark.sql.defaultCatalog", CATALOG),
        (catalog_key, "org.apache.iceberg.spark.SparkCatalog"),
        (f"{catalog_key}.catalog-impl", "org.apache.iceberg.aws.glue.GlueCatalog"),
        (f"{catalog_key}.io-impl", "org.apache.iceberg.aws.s3.S3FileIO"),
        (f"{catalog_key}.warehouse", ICEBERG_S3_ROOT_PATH),
        (f"{catalog_key}.s3.path-style-access", True),
        ("spark.sql.extensions", "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions"),
        (
            "spark.hadoop.fs.s3a.aws.credentials.provider",
            "software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider",
        ),
        ("spark.sql.caseSensitive", True),
        ("spark.sql.session.timeZone", "UTC"),
    ]
    builder = SparkSession.builder.appName("1")
    for key, value in options:
        builder = builder.config(key, value)
    return builder.getOrCreate()


spark = _build_spark_session()

In [None]:
# TODO. Number of generated transactions and their mix
TRANSACTIONS_PER_SECOND = 10  # transactions per (un-paced) one-second round
TOTAL_DURATION_SECONDS = 1  # total run length in seconds, i.e. number of rounds
INSERT_RATIO = 0.60  # share of inserts
UPDATE_RATIO = 0.20  # share of updates
DELETE_RATIO = 0.20  # share of deletes


def parse_table_list(text):
    """Split a comma-separated list of ``<schema>.<table>`` names.

    Surrounding whitespace (including newlines from a multi-line string) is
    stripped from each entry, and empty entries such as a trailing comma are
    dropped.

    >>> parse_table_list(" a.b ,\\nc.d,")
    ['a.b', 'c.d']
    """
    return [name.strip() for name in text.split(",") if name.strip()]


# TODO. Target tables (comma-separated <schema>.<table>)
TABLE_LIST = parse_table_list("""store_bronze.tb_lower""")

# Shared Faker instance used by generate_data() to produce random column values.
fake = Faker()


# Read an Iceberg table's schema.
def get_table_schema(table_name: str):
    """Return the Spark ``StructType`` of the Iceberg table *table_name*.

    The table is loaded through the ``iceberg`` data source and only its
    schema is returned.
    """
    return spark.read.format("iceberg").load(f"{table_name}").schema


# Inclusive upper bounds for generated values of Spark integral types, keyed by
# DataType.simpleString(). Generated integers are non-negative.
_INTEGRAL_MAX_VALUES = {
    "tinyint": 127,
    "smallint": 32767,
    "int": 2147483647,
    "bigint": 9223372036854775807,
}


# Build one random row for a table schema.
def generate_data(schema):
    """Build one random record (a ``dict``) matching a Spark ``StructType``.

    Each field is filled according to its type name (``simpleString()``):

    * ``tinyint``/``smallint``/``int``/``bigint``: a non-negative int in range.
    * ``string``: up to 255 characters of Faker text.
    * ``float``/``double``: a positive float with 6 integer and 3 fraction digits.
    * ``decimal(p,s)``: a positive ``Decimal`` with ``p - s`` integer and ``s``
      fraction digits.
    * ``boolean``, ``binary`` (1 B .. 64 KiB), ``date``, and ``timestamp`` /
      ``timestamp_ntz``.

    Nullable fields are ``None`` with 10% probability. Unsupported types
    (arrays, maps, structs, intervals, ...) are ``None``. The
    ``last_applied_date`` key is always set to the current UTC time, overriding
    any generated value for that column.

    Args:
        schema: ``pyspark.sql.types.StructType`` of the target table.

    Returns:
        dict mapping field name to the generated Python value.
    """
    data = {}
    for field in schema.fields:
        type_name = field.dataType.simpleString()

        # Emit NULL with 10% probability for nullable columns.
        if field.nullable and random.random() < 0.1:
            data[field.name] = None
            continue

        # Match type names exactly (or by prefix for parameterized types).
        # Substring tests would send "bigint"/"smallint"/"tinyint"/"interval" to
        # the "int" branch and "array<string>" to the "string" branch.
        if type_name in _INTEGRAL_MAX_VALUES:
            value = fake.random_int(min=0, max=_INTEGRAL_MAX_VALUES[type_name])
        elif type_name == "string":
            value = fake.text(max_nb_chars=255).strip()  # 255-character length cap
        elif type_name in ("double", "float"):
            value = fake.pyfloat(left_digits=6, right_digits=3, positive=True)
        elif type_name.startswith("decimal("):
            # "decimal(p,s)" -> precision p, scale s
            precision, scale = map(int, type_name[len("decimal(") : -1].split(","))
            value = fake.pydecimal(left_digits=precision - scale, right_digits=scale, positive=True)
        elif type_name == "boolean":
            value = fake.boolean()
        elif type_name == "binary":
            value = fake.binary(length=random.randint(1, 65535))  # up to 64 KiB
        elif type_name == "date":
            value = fake.date_this_decade()
        elif type_name.startswith("timestamp"):  # timestamp and timestamp_ntz
            value = fake.date_time_this_decade()
        else:
            # Unsupported type: leave the column empty.
            value = None
        data[field.name] = value

    data["last_applied_date"] = datetime.datetime.now(datetime.UTC)
    return data


# Append freshly generated rows to a table.
def insert_data(table_name, schema, num_records):
    """Append *num_records* rows built by ``generate_data(schema)`` to *table_name*.

    The rows are written through the ``iceberg`` data source in append mode.
    """
    new_rows = [generate_data(schema) for _ in range(num_records)]
    writer = spark.createDataFrame(new_rows, schema=schema).write
    writer.format("iceberg").mode("append").save(f"{table_name}")


# Update existing rows in a table with an Iceberg row-level MERGE.
def update_data(table_name, schema, num_records):
    """Replace up to *num_records* existing rows of *table_name* with random data.

    Non-null ``id_iceberg`` values are sampled from the table. For each one, a
    new record is generated with ``generate_data(schema)``, keeping the sampled
    id. The records are applied with ``MERGE ... WHEN MATCHED THEN UPDATE SET *``,
    so all other rows are left untouched. Does nothing when the table has no
    row with a non-null id.

    Args:
        table_name: table identifier (``<schema>.<table>``) resolvable by the
            Spark session.
        schema: ``StructType`` of the table, used to generate and type the
            replacement rows.
        num_records: maximum number of rows to sample for updating.
    """
    df = spark.read.format("iceberg").load(f"{table_name}")
    # NULL ids cannot be matched by key, so only sample non-null ones.
    rows = df.select("id_iceberg").where(F.col("id_iceberg").isNotNull()).take(num_records)
    if not rows:
        return

    updated_data = []
    for row in rows:
        record = generate_data(schema)
        record["id_iceberg"] = row.id_iceberg
        updated_data.append(record)
    # Deduplicate so MERGE never sees two source rows for the same target row.
    updated_df = spark.createDataFrame(updated_data, schema=schema).dropDuplicates(["id_iceberg"])

    view_name = "_update_data_source"
    updated_df.createOrReplaceTempView(view_name)
    try:
        # Use a row-level MERGE: writing with mode("overwrite") would replace
        # the entire table with just these records.
        spark.sql(
            f"MERGE INTO {table_name} AS t USING {view_name} AS s "
            "ON t.id_iceberg = s.id_iceberg "
            "WHEN MATCHED THEN UPDATE SET *"
        )
    finally:
        spark.catalog.dropTempView(view_name)


# Delete rows from a table with an Iceberg row-level MERGE.
def delete_data(table_name, num_records):
    """Delete the rows of up to *num_records* sampled ids from *table_name*.

    Non-null ``id_iceberg`` values are sampled from the table. Every row
    carrying one of them is removed with ``MERGE ... WHEN MATCHED THEN DELETE``,
    so the rest of the table, including rows whose id is NULL, is untouched.
    Does nothing when the table has no row with a non-null id.

    Args:
        table_name: table identifier (``<schema>.<table>``) resolvable by the
            Spark session.
        num_records: maximum number of ids to sample for deletion.
    """
    df = spark.read.format("iceberg").load(f"{table_name}")
    # A NULL id would make `id IN (NULL)` unknown for every row, so never pick one.
    id_df = df.select("id_iceberg").where(F.col("id_iceberg").isNotNull())
    rows = id_df.take(num_records)
    if not rows:
        return

    view_name = "_delete_data_ids"
    # Deduplicate so MERGE never sees two source rows for the same target row.
    spark.createDataFrame(rows, schema=id_df.schema).dropDuplicates().createOrReplaceTempView(view_name)
    try:
        spark.sql(
            f"MERGE INTO {table_name} AS t USING {view_name} AS s "
            "ON t.id_iceberg = s.id_iceberg "
            "WHEN MATCHED THEN DELETE"
        )
    finally:
        spark.catalog.dropTempView(view_name)


# Drive a randomized insert/update/delete workload against one table.
def schedule_transactions(
    table_name, transactions_per_second, total_duration_seconds, insert_ratio, update_ratio, delete_ratio
):
    """Run a random mix of single-record transactions against *table_name*.

    For each of ``total_duration_seconds`` rounds, ``transactions_per_second``
    actions are drawn from insert/update/delete with probabilities
    ``(insert_ratio, update_ratio, delete_ratio)``. They are executed in order
    via ``insert_data`` / ``update_data`` / ``delete_data`` with
    ``num_records=1``. After each round, the first letter of every action is
    printed.

    NOTE: rounds are not paced against the wall clock. A "second" is one batch
    of transactions, and its real duration depends on Spark.

    Raises:
        ValueError: from ``np.random.choice`` if the ratios are negative or do
            not sum to 1.
    """
    schema = get_table_schema(table_name)
    for _ in range(total_duration_seconds):
        transactions: np.ndarray = np.random.choice(
            ["insert", "update", "delete"],
            size=transactions_per_second,
            p=[insert_ratio, update_ratio, delete_ratio],
        )
        for action in transactions:
            if action == "insert":
                insert_data(table_name, schema, num_records=1)
            elif action == "update":
                update_data(table_name, schema, num_records=1)
            elif action == "delete":
                delete_data(table_name, num_records=1)
        # Compact per-round trace, e.g. ['i', 'u', 'd', ...].
        print(f"{[t[0] for t in transactions]}")


# Run the configured workload against each target table, one after another.
for table in TABLE_LIST:
    schedule_transactions(
        table,
        transactions_per_second=TRANSACTIONS_PER_SECOND,
        total_duration_seconds=TOTAL_DURATION_SECONDS,
        insert_ratio=INSERT_RATIO,
        update_ratio=UPDATE_RATIO,
        delete_ratio=DELETE_RATIO,
    )

In [None]:
# spark.stop()