# Generate Datasets

In [None]:
# Only needed when running from a local IDE: build a Databricks Connect
# session and bind it to `spark`, the name the rest of this notebook uses.
# NOTE(review): inside a Databricks notebook `spark` is presumably already
# provided, in which case this cell can be skipped — confirm.
from databricks.connect import DatabricksSession

# The builder is configured with serverless=True before getOrCreate() is called.
spark = DatabricksSession.builder.remote(serverless=True).getOrCreate()

In [None]:
# Unity Catalog location used by this notebook: every table it writes or
# reads is addressed as {CATALOG}.{SCHEMA}.<table_name>.
CATALOG = "users"  # TODO: your catalog name
SCHEMA = "david_huang"  # TODO: your schema name

## Create master merchant entity table

In [None]:
# Create comprehensive master entity table
import pandas as pd
import json
from datetime import datetime, timedelta
import random

# Load merchant attributes from JSON file.
# The file is read as a mapping of merchant name -> attribute dict; each
# attribute dict must provide the "category", "industry" and "chain" keys.
# The encoding is pinned because JSON is UTF-8 text, while open() without an
# explicit encoding falls back to a platform-dependent default (relevant when
# this notebook is run from a local IDE).
with open("../data/merchant_attributes.json", "r", encoding="utf-8") as f:
    merchant_attributes = json.load(f)

# created_date values are sampled from the 365 days ending now. The window is
# computed once, outside the loop, so every record draws from the same range.
end_date = datetime.now()
start_date = end_date - timedelta(days=365)
window_days = (end_date - start_date).days

# One master entity record per merchant found in the JSON file
master_entities_data = []

# Entity ids are 1-based and zero-padded to three digits (ENT_001, ENT_002, ...)
for i, (merchant_name, attributes) in enumerate(merchant_attributes.items(), start=1):
    # Random day inside the window; only the date part is kept below
    random_date = start_date + timedelta(days=random.randint(0, window_days))

    entity = {
        "entity_id": f"ENT_{i:03d}",
        "merchant_name": merchant_name,
        "category": attributes["category"],
        "industry": attributes["industry"],
        "is_chain": attributes["chain"],
        "status": "active",
        "created_date": random_date.strftime("%Y-%m-%d"),
        "data_source": "master_reference",
    }

    master_entities_data.append(entity)

# Create DataFrame
df_master_entities = pd.DataFrame(master_entities_data)

print(f"Created master entity table with {len(df_master_entities)} records")
display(df_master_entities.head(5))

In [None]:
# Convert the pandas master entities to a Spark DataFrame and persist it as a
# Unity Catalog table, replacing the table if it already exists.
spark_df_master_entities = spark.createDataFrame(df_master_entities)
spark_df_master_entities.write.saveAsTable(
    f"{CATALOG}.{SCHEMA}.ner_demo_merchant_entities", mode="overwrite"
)

## Generate variations of merchant entities with `ai_query()`

In [None]:
# (Re)create {CATALOG}.{SCHEMA}.ner_demo_generated_entities from an LLM call.
#
# For every row of ner_demo_merchant_entities the query calls ai_query() on
# the "databricks-gpt-oss-120b" endpoint with a prompt asking for 10
# credit-card-statement-style variations of merchant_name. The response is
# then parsed with from_json() into a struct column named name_variations
# whose single field, also called name_variations, is an array of strings.
#
# NOTE(review): responseFormat declares an outer `generated_names` struct,
# while from_json() uses only the inner struct — presumably ai_query returns
# just the inner object as a JSON string; confirm against the ai_query docs.
spark.sql(
    f"""
    CREATE OR REPLACE TABLE {CATALOG}.{SCHEMA}.ner_demo_generated_entities
    WITH query_results AS (
        SELECT
            merchant_name,
            ai_query(
                "databricks-gpt-oss-120b",
                concat(
                    "Your job is to create merchant names that look just like the short, "
                    "often‑abbreviated entries you’d see on a credit‑card statement. ",
                    "For each given merchant name, create 10 variations from this merchant: ",
                    merchant_name
                ),
                responseFormat => 'STRUCT<generated_names:STRUCT<name_variations:ARRAY<STRING>>>'
            ) as generated_name
        FROM
            {CATALOG}.{SCHEMA}.ner_demo_merchant_entities
    )
    SELECT
        merchant_name,
        from_json(generated_name, 'STRUCT<name_variations: ARRAY<STRING>>') as name_variations
    FROM
        query_results;
    """
)

## Create transaction dataset with Faker

In [None]:
# Reload the master entity table from Unity Catalog.
spark_df_master_entities = spark.sql(
    f"select * from {CATALOG}.{SCHEMA}.ner_demo_merchant_entities"
)
# In the generated-entities table, name_variations is a struct wrapping an
# array field of the same name (per the from_json schema used to build it).
# Selecting name_variations.name_variations pulls out that inner array.
spark_df_generated_entities = spark.sql(
    f"select merchant_name, name_variations.name_variations from {CATALOG}.{SCHEMA}.ner_demo_generated_entities"
)

# Preview the first 5 merchants with their generated variations
display(spark_df_generated_entities.limit(5))

In [None]:
from faker import Faker
import random
import decimal
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, explode, rand
from pyspark.sql.types import (
    StructType,
    StructField,
    StringType,
    DecimalType,
    TimestampType,
    IntegerType,
)
from datetime import datetime, timedelta


def generate_credit_card_transactions(num_rows=1000):
    """
    Generate synthetic credit card transactions with Faker.

    Each transaction's merchant name is picked at random from the exploded
    name_variations column of the notebook-level spark_df_generated_entities
    DataFrame. Null or blank variations are skipped so that no transaction is
    generated without a merchant name.

    Args:
        num_rows (int): Number of transaction rows to generate

    Returns:
        DataFrame: Spark DataFrame with generated transaction data

    Raises:
        ValueError: If num_rows is positive but no usable merchant name
            variation is available to sample from.
    """
    fake = Faker()

    # Collect merchant name variations from the existing table
    variation_rows = spark_df_generated_entities.select(
        explode(col("name_variations")).alias("merchant_variation")
    ).collect()

    # Keep only usable names: the variations are model-generated, so null or
    # whitespace-only entries are dropped rather than sampled.
    merchant_names = [
        row["merchant_variation"]
        for row in variation_rows
        if row["merchant_variation"] and row["merchant_variation"].strip()
    ]

    # Fail with a clear message instead of the IndexError random.choice would
    # raise on an empty list.
    if num_rows > 0 and not merchant_names:
        raise ValueError(
            "No merchant name variations available in "
            "spark_df_generated_entities; cannot generate transactions."
        )

    # Transactions are timestamped within the last 24 hours. The window is
    # fixed once so every row is sampled from the same range.
    window_end = datetime.now()
    window_start = window_end - timedelta(days=1)

    # Generate transaction data
    transactions = []

    for _ in range(num_rows):
        # Random merchant from our variations
        merchant = random.choice(merchant_names)

        transaction = {
            "transaction_id": fake.uuid4(),
            "card_number": fake.credit_card_number(card_type="visa"),
            "merchant_name": merchant,
            # Round to cents, then go through str so the Decimal holds exactly
            # the rounded value rather than the float's binary expansion.
            "amount": decimal.Decimal(str(round(random.uniform(5.0, 200.0), 2))),
            "transaction_date": fake.date_time_between(
                start_date=window_start, end_date=window_end
            ),
            "transaction_type": random.choice(["PURCHASE", "REFUND", "PAYMENT"]),
            "currency": "USD",
            "card_holder_name": fake.name(),
            "location": f"{fake.city()}, {fake.state_abbr()}",
        }
        transactions.append(transaction)

    # Explicit schema: fixes column order and types (notably DECIMAL(10, 2)
    # for amount) and lets an empty transaction list still produce a DataFrame.
    schema = StructType(
        [
            StructField("transaction_id", StringType(), True),
            StructField("card_number", StringType(), True),
            StructField("merchant_name", StringType(), True),
            StructField("amount", DecimalType(10, 2), True),
            StructField("transaction_date", TimestampType(), True),
            StructField("transaction_type", StringType(), True),
            StructField("currency", StringType(), True),
            StructField("card_holder_name", StringType(), True),
            StructField("location", StringType(), True),
        ]
    )

    # Create DataFrame
    transactions_df = spark.createDataFrame(transactions, schema)

    return transactions_df

In [None]:
# Generate 1,000 sample transactions using the merchant name variations
sample_transactions = generate_credit_card_transactions(num_rows=1000)

# Preview the first 5 generated rows
display(sample_transactions.limit(5))

In [None]:
# Optionally persist the generated transactions to Unity Catalog, replacing
# the table if it already exists.
sample_transactions.write.saveAsTable(
    f"{CATALOG}.{SCHEMA}.ner_demo_generated_transactions", mode="overwrite"
)