Note: На данном этапе в sources БД загружены данные. Создана структура таблиц для DWH и витрины

Note 2: У меня не завелось со структурами в DWH из задания, я их просто с нуля сделал прям свои и загрузил через DBeaver.


```
-- Создание схемы DWH если её еще нет
CREATE SCHEMA IF NOT EXISTS dwh;

-- Таблица измерений customers
CREATE TABLE dwh.d_customers (
    customer_id BIGINT PRIMARY KEY,
    customer_name VARCHAR,
    customer_address VARCHAR,
    customer_birthday DATE,
    customer_email VARCHAR NOT NULL,
    valid_from TIMESTAMP NOT NULL,
    valid_to TIMESTAMP,
    is_current BOOLEAN NOT NULL,
    CONSTRAINT d_customers_email_uk UNIQUE (customer_email)
);

COMMENT ON TABLE dwh.d_customers IS 'Таблица измерений с информацией о заказчиках';

-- Таблица измерений products
CREATE TABLE dwh.d_products (
    product_id BIGINT PRIMARY KEY,
    product_name VARCHAR NOT NULL,
    product_description VARCHAR NOT NULL,
    product_type VARCHAR NOT NULL,
    product_price BIGINT NOT NULL,
    valid_from TIMESTAMP NOT NULL,
    valid_to TIMESTAMP,
    is_current BOOLEAN NOT NULL
);

COMMENT ON TABLE dwh.d_products IS 'Таблица измерений с информацией о продуктах';

-- Таблица измерений craftsmans
CREATE TABLE dwh.d_craftsmans (
    craftsman_id BIGINT PRIMARY KEY,
    craftsman_name VARCHAR NOT NULL,
    craftsman_address VARCHAR NOT NULL,
    craftsman_birthday DATE NOT NULL,
    craftsman_email VARCHAR NOT NULL,
    valid_from TIMESTAMP NOT NULL,
    valid_to TIMESTAMP,
    is_current BOOLEAN NOT NULL,
    CONSTRAINT d_craftsmans_email_uk UNIQUE (craftsman_email)
);

COMMENT ON TABLE dwh.d_craftsmans IS 'Таблица измерений с информацией о мастерах';

-- Таблица фактов orders
CREATE TABLE dwh.f_orders (
    order_id BIGINT PRIMARY KEY,
    customer_id BIGINT NOT NULL,
    craftsman_id BIGINT NOT NULL,
    product_id BIGINT NOT NULL,
    order_created_date DATE,
    order_completion_date DATE,
    order_status VARCHAR NOT NULL,
    load_dttm TIMESTAMP NOT NULL,
    CONSTRAINT f_orders_customer_fk FOREIGN KEY (customer_id) REFERENCES dwh.d_customers(customer_id),
    CONSTRAINT f_orders_craftsman_fk FOREIGN KEY (craftsman_id) REFERENCES dwh.d_craftsmans(craftsman_id),
    CONSTRAINT f_orders_product_fk FOREIGN KEY (product_id) REFERENCES dwh.d_products(product_id)
);

COMMENT ON TABLE dwh.f_orders IS 'Таблица фактов с информацией о заказах';

-- Создание индексов для оптимизации производительности
CREATE INDEX idx_f_orders_customer ON dwh.f_orders(customer_id);
CREATE INDEX idx_f_orders_craftsman ON dwh.f_orders(craftsman_id);
CREATE INDEX idx_f_orders_product ON dwh.f_orders(product_id);
CREATE INDEX idx_f_orders_created_date ON dwh.f_orders(order_created_date);
CREATE INDEX idx_f_orders_completion_date ON dwh.f_orders(order_completion_date);
```




In [None]:
# Задание 8. Выгрузка данных в DWH

In [142]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.window import Window
from pyspark.sql.types import TimestampType
import datetime

def log_count(df, message):
    count = df.count()
    print(f"{message}: {count}")
    return count

def init_spark():
    """Initialize Spark session"""
    spark = SparkSession.builder \
        .appName("DWH Data Load") \
        .config("spark.jars", "postgresql-42.7.4.jar") \
        .getOrCreate()
    return spark

def read_source_tables(spark, jdbc_url, db_properties):
    source_tables = {
        'source1_wide': spark.read.jdbc(jdbc_url, "source1.craft_market_wide", properties=db_properties),
        'source2_masters': spark.read.jdbc(jdbc_url, "source2.craft_market_masters_products", properties=db_properties),
        'source2_orders': spark.read.jdbc(jdbc_url, "source2.craft_market_orders_customers", properties=db_properties),
        'source3_craftsmans': spark.read.jdbc(jdbc_url, "source3.craft_market_craftsmans", properties=db_properties),
        'source3_customers': spark.read.jdbc(jdbc_url, "source3.craft_market_customers", properties=db_properties),
        'source3_orders': spark.read.jdbc(jdbc_url, "source3.craft_market_orders", properties=db_properties)
    }
    
    for name, df in source_tables.items():
        log_count(df, f"Read {name}")
    
    return source_tables

def process_customers(sources):
    customers_union = sources['source1_wide'].select(
        "customer_id", "customer_name", "customer_address", 
        "customer_birthday", "customer_email"
    ).union(
        sources['source2_orders'].select(
            "customer_id", "customer_name", "customer_address", 
            "customer_birthday", "customer_email"
        )
    ).union(
        sources['source3_customers'].select(
            "customer_id", "customer_name", "customer_address", 
            "customer_birthday", "customer_email"
        )
    )

    d_customers = customers_union \
        .withColumn("valid_from", current_timestamp().cast(TimestampType())) \
        .withColumn("valid_to", lit(None).cast(TimestampType())) \
        .withColumn("is_current", lit(True))
    
    return d_customers

def process_craftsmans(sources):
    craftsmans_union = sources['source1_wide'].select(
        "craftsman_id", "craftsman_name", "craftsman_address", 
        "craftsman_birthday", "craftsman_email"
    ).union(
        sources['source2_masters'].select(
            "craftsman_id", "craftsman_name", "craftsman_address", 
            "craftsman_birthday", "craftsman_email"
        )
    ).union(
        sources['source3_craftsmans'].select(
            "craftsman_id", "craftsman_name", "craftsman_address", 
            "craftsman_birthday", "craftsman_email"
        )
    )

    d_craftsmans = craftsmans_union \
        .withColumn("valid_from", current_timestamp().cast(TimestampType())) \
        .withColumn("valid_to", lit(None).cast(TimestampType())) \
        .withColumn("is_current", lit(True))
    
    return d_craftsmans

def process_products(sources):
    def add_product_business_key(df):
        return df.withColumn(
            "product_business_key",
            concat_ws('::', 
                col("product_name"),
                col("product_description"),
                col("product_type"),
                col("product_price").cast("string")
            )
        )
    
    products_source1 = add_product_business_key(
        sources['source1_wide'].select(
            "product_name", "product_description", 
            "product_type", "product_price"
        )
    )
    
    products_source2 = add_product_business_key(
        sources['source2_masters'].select(
            "product_name", "product_description", 
            "product_type", "product_price"
        )
    )
    
    products_source3 = add_product_business_key(
        sources['source3_orders'].select(
            "product_name", "product_description", 
            "product_type", "product_price"
        )
    )
    
    products_union = products_source1.union(products_source2).union(products_source3)
    
    print(f"Products before deduplication: {products_union.count()}")
    
    products_deduped = products_union.dropDuplicates(["product_business_key"])
    
    print(f"Products after deduplication: {products_deduped.count()}")
    
    window_spec = Window.orderBy("product_business_key")
    d_products = products_deduped \
        .withColumn("product_id", row_number().over(window_spec)) \
        .withColumn("valid_from", current_timestamp().cast(TimestampType())) \
        .withColumn("valid_to", lit(None).cast(TimestampType())) \
        .withColumn("is_current", lit(True)) \
        .select(
            "product_id",
            "product_name",
            "product_description",
            "product_type",
            "product_price",
            "valid_from",
            "valid_to",
            "is_current",
            "product_business_key"
        )
    
    return d_products

def create_unique_order_key(df, source):
    return df.withColumn(
        "order_key",
        concat_ws(':', 
            lit(source),
            col("customer_id"),
            col("craftsman_id"),
            col("product_id"),
            coalesce(col("order_created_date").cast("string"), lit("null")),
            coalesce(col("order_completion_date").cast("string"), lit("null")),
            col("order_status")
        )
    )

def process_orders(sources, d_customers, d_craftsmans, d_products):
    
    source1_orders = create_unique_order_key(
        sources['source1_wide'].select(
            "customer_id",
            "craftsman_id",
            "product_id",
            "order_created_date",
            "order_completion_date",
            "order_status"
        ),
        "source1"
    )
    
    source2_orders = create_unique_order_key(
        sources['source2_orders'].select(
            "customer_id",
            "craftsman_id",
            "product_id",
            "order_created_date",
            "order_completion_date",
            "order_status"
        ),
        "source2"
    )
    
    source3_orders = create_unique_order_key(
        sources['source3_orders'].select(
            "customer_id",
            "craftsman_id",
            "product_id",
            "order_created_date",
            "order_completion_date",
            "order_status"
        ),
        "source3"
    )

    orders_union = source1_orders.union(source2_orders).union(source3_orders)
    
    print(f"Total orders before deduplication: {orders_union.count()}")
    
    duplicates = orders_union.groupBy("order_key") \
        .agg(count("*").alias("count")) \
        .filter(col("count") > 1)
    
    print(f"Number of duplicated order keys: {duplicates.count()}")
    if duplicates.count() > 0:
        print("Sample of duplicates:")
        duplicates.show(5, truncate=False)
    
    orders_deduplicated = orders_union.dropDuplicates(["order_key"])
    
    print(f"Total orders after deduplication: {orders_deduplicated.count()}")
    
    window_spec = Window.orderBy("order_created_date", "order_key")
    orders_with_id = orders_deduplicated.withColumn(
        "order_id",
        row_number().over(window_spec)
    )
    
    def check_duplicates(df, key_column, name):
        dups = df.groupBy(key_column).count().filter(col("count") > 1)
        if dups.count() > 0:
            print(f"Duplicates in {name} by {key_column}:")
            dups.show()
        return dups.count()
    
    f_orders = orders_with_id \
        .join(
            d_customers.dropDuplicates(["customer_id"]),
            ["customer_id"],
            "left"
        ) \
        .join(
            d_craftsmans.dropDuplicates(["craftsman_id"]),
            ["craftsman_id"],
            "left"
        ) \
        .join(
            d_products.dropDuplicates(["product_id"]),
            ["product_id"],
            "left"
        ) \
        .select(
            orders_with_id.order_id,
            orders_with_id.customer_id,
            orders_with_id.craftsman_id,
            orders_with_id.product_id,
            orders_with_id.order_created_date,
            orders_with_id.order_completion_date,
            orders_with_id.order_status,
            current_timestamp().cast(TimestampType()).alias("load_dttm"),
            orders_with_id.order_key.alias("source_system_id")
        )
    
    print(f"Final orders count in f_orders: {f_orders.count()}")
    
    return f_orders

def check_for_changes(new_df, table_name, jdbc_url, db_properties):
    try:
        spark = new_df.sparkSession
        existing_df = spark.read.jdbc(
            jdbc_url,
            f"dwh.{table_name}",
            properties=db_properties
        )
        
        if table_name.startswith('d_'):
            business_columns = [
                c for c in new_df.columns
                if c not in ['valid_from', 'valid_to', 'is_current']
            ]
            
            new_business_data = new_df.select(business_columns)
            existing_business_data = existing_df.select(business_columns)
            
            diff_count = new_business_data.exceptAll(existing_business_data).count()
            
        else:  # For fact table
            compare_columns = [c for c in new_df.columns if c != 'load_dttm']
            
            new_compare_data = new_df.select(compare_columns)
            existing_compare_data = existing_df.select(compare_columns)
            
            diff_count = new_compare_data.exceptAll(existing_compare_data).count()
        
        print(f"Found {diff_count} different records for {table_name}")
        return diff_count > 0
        
    except Exception as e:
        print(f"No existing data found for {table_name} or error occurred: {str(e)}")
        return True

def write_to_dwh(df, table_name, jdbc_url, db_properties):
    """Write dataframe to DWH"""
    df.write \
        .jdbc(
            jdbc_url,
            f"dwh.{table_name}",
            mode="overwrite",
            properties=db_properties
        )
    print(f"Written {df.count()} records to dwh.{table_name}")

def upsert_dimension_table(new_df, table_name, business_key_cols, jdbc_url, db_properties):
    spark = new_df.sparkSession
    
    try:
        existing_df = spark.read.jdbc(
            jdbc_url,
            f"dwh.{table_name}",
            properties=db_properties
        )
    except:
        print(f"Dimension table dwh.{table_name} not found. Creating a new one.")
        new_df.write.jdbc(
            jdbc_url,
            f"dwh.{table_name}",
            mode="append",
            properties=db_properties
        )
        print(f"Inserted {new_df.count()} rows into dwh.{table_name}")
        return
    
    existing_current = existing_df.filter(col("is_current") == True)

    exclude_cols = {"valid_from", "valid_to", "is_current"}
    compare_cols = [c for c in new_df.columns if c not in exclude_cols]
    
    print(f"\nComparing columns: {compare_cols}")
    print(f"Business key columns: {business_key_cols}")
    
    join_cond = [new_df[k].eqNullSafe(existing_current[k]) for k in business_key_cols]
    joined = new_df.alias("n").join(
        existing_current.alias("e"),
        join_cond,
        "left"
    )
    
    change_condition = " OR ".join([
        f"NOT (n.{c} IS NOT DISTINCT FROM e.{c})"
        for c in compare_cols
    ])
    
    changed_or_new = joined.filter(expr(change_condition))
    print(f"\nFound {changed_or_new.count()} changed or new records")
    
    if changed_or_new.count() > 0:
        print("Sample of changes:")
        select_cols = []
        for c in compare_cols:
            select_cols.append(col(f"n.{c}").alias(f"new_{c}"))
            select_cols.append(col(f"e.{c}").alias(f"old_{c}"))
        
        changed_or_new.select(select_cols).show(5, truncate=False)
    
        to_close = changed_or_new.filter(col("e."+business_key_cols[0]).isNotNull()) \
            .selectExpr("e.*")
        
        if to_close.count() > 0:
            print(f"\nClosing {to_close.count()} existing records")
            
            to_close_updates = to_close \
                .withColumn("valid_to", current_timestamp()) \
                .withColumn("is_current", lit(False))
            
            new_rows = changed_or_new.selectExpr("n.*") \
                .withColumn("valid_from", current_timestamp()) \
                .withColumn("valid_to", lit(None).cast(TimestampType())) \
                .withColumn("is_current", lit(True))
            
            closed_keys = to_close.select(*business_key_cols).distinct()
            
            unchanged_current = existing_current.join(
                closed_keys,
                on=business_key_cols,
                how="leftanti"
            )
            
            existing_history = existing_df.filter(col("is_current") == False)
            
            final_dim = unchanged_current.unionByName(existing_history) \
                .unionByName(to_close_updates) \
                .unionByName(new_rows)
            
            # Write back to database
            final_dim.write.jdbc(
                jdbc_url,
                f"dwh.{table_name}",
                mode="overwrite",
                properties=db_properties
            )
            
            print(f"Upserted (closed + new) {to_close_updates.count() + new_rows.count()} rows in dwh.{table_name}")
            
        else:
            new_inserts = changed_or_new.selectExpr("n.*") \
                .withColumn("valid_from", current_timestamp()) \
                .withColumn("valid_to", lit(None).cast(TimestampType())) \
                .withColumn("is_current", lit(True))
            
            if new_inserts.count() > 0:
                final_dim = existing_df.unionByName(new_inserts)
                final_dim.write.jdbc(
                    jdbc_url,
                    f"dwh.{table_name}",
                    mode="overwrite",
                    properties=db_properties
                )
                print(f"Inserted {new_inserts.count()} new rows in dwh.{table_name}")
    else:
        print(f"No new or changed rows for dwh.{table_name}. No action taken.")

def upsert_fact_table(new_facts, table_name, unique_key_col, jdbc_url, db_properties):
    spark = new_facts.sparkSession
    
    try:
        existing_facts = spark.read.jdbc(
            jdbc_url,
            f"dwh.{table_name}",
            properties=db_properties
        )
    except:
        print(f"Fact table dwh.{table_name} not found. Creating a new one.")
        new_facts.write.jdbc(
            jdbc_url,
            f"dwh.{table_name}",
            mode="append",
            properties=db_properties
        )
        print(f"Inserted {new_facts.count()} rows into dwh.{table_name}")
        return
    
    join_cond = [new_facts[unique_key_col] == existing_facts[unique_key_col]]
    joined = new_facts.alias("n").join(existing_facts.alias("e"), on=join_cond, how="left")
    
    new_only = joined.filter(col(f"e.{unique_key_col}").isNull()).selectExpr("n.*")
    
    if new_only.count() > 0:
        new_only.write.jdbc(
            jdbc_url,
            f"dwh.{table_name}",
            mode="append",
            properties=db_properties
        )
        print(f"Inserted {new_only.count()} new fact rows into dwh.{table_name}")
    else:
        print(f"No new rows found for dwh.{table_name}. No action taken.")

def load_data_into_warehouse():
    spark = init_spark()
    
    jdbc_url = "jdbc:postgresql://spark_db:5432/db"
    db_properties = {
        "user": "user",
        "password": "password",
        "driver": "org.postgresql.Driver"
    }
    
    try:
        sources = read_source_tables(spark, jdbc_url, db_properties)
        
        print("Processing dimension tables...")
        d_customers = process_customers(sources)
        d_craftsmans = process_craftsmans(sources)
        d_products = process_products(sources)
        
        print("Processing fact table...")
        f_orders = process_orders(sources, d_customers, d_craftsmans, d_products)
        
        print("Upserting dimension tables to DWH...")
        upsert_dimension_table(
            d_customers,
            "d_customers",
            business_key_cols=["customer_email"],
            jdbc_url=jdbc_url,
            db_properties=db_properties
        )
        
        upsert_dimension_table(
            d_craftsmans,
            "d_craftsmans",
            business_key_cols=["craftsman_email"],
            jdbc_url=jdbc_url,
            db_properties=db_properties
        )
        
        upsert_dimension_table(
            d_products,
            "d_products",
            business_key_cols=["product_business_key"],
            jdbc_url=jdbc_url,
            db_properties=db_properties
        )
        
        print("Upserting fact table to DWH...")
        upsert_fact_table(
            f_orders,
            "f_orders",
            unique_key_col="source_system_id", 
            jdbc_url=jdbc_url,
            db_properties=db_properties
        )
        
        print("DWH load completed successfully!")
        
    except Exception as e:
        print(f"Error during DWH load: {str(e)}")
        raise
    finally:
        spark.stop()

load_data_into_warehouse()


Read source1_wide: 999
Read source2_masters: 999
Read source2_orders: 999
Read source3_craftsmans: 999
Read source3_customers: 999
Read source3_orders: 999
Processing dimension tables...
Products before deduplication: 2997
Products after deduplication: 2994
Processing fact table...
Total orders before deduplication: 2997
Number of duplicated order keys: 0
Total orders after deduplication: 2997
Final orders count in f_orders: 2997
Upserting dimension tables to DWH...
Dimension table dwh.d_customers not found. Creating a new one.
Inserted 2997 rows into dwh.d_customers
Dimension table dwh.d_craftsmans not found. Creating a new one.
Inserted 2997 rows into dwh.d_craftsmans
Dimension table dwh.d_products not found. Creating a new one.
Inserted 2994 rows into dwh.d_products
Upserting fact table to DWH...
Fact table dwh.f_orders not found. Creating a new one.
Inserted 2997 rows into dwh.f_orders
DWH load completed successfully (incremental)!


Ура, загрузили все данные. По их числу видно что загрузились правильно. Теперь можно проверить что без изменений оно ничего добавлять не будет: 

In [143]:
load_data_into_warehouse()

Read source1_wide: 999
Read source2_masters: 999
Read source2_orders: 999
Read source3_craftsmans: 999
Read source3_customers: 999
Read source3_orders: 999
Processing dimension tables...
Products before deduplication: 2997
Products after deduplication: 2994
Processing fact table...
Total orders before deduplication: 2997
Number of duplicated order keys: 0
Total orders after deduplication: 2997
Final orders count in f_orders: 2997
Upserting dimension tables to DWH...

Comparing columns: ['customer_id', 'customer_name', 'customer_address', 'customer_birthday', 'customer_email']
Business key columns: ['customer_email']

Found 0 changed or new records
No new or changed rows for dwh.d_customers. No action taken.

Comparing columns: ['craftsman_id', 'craftsman_name', 'craftsman_address', 'craftsman_birthday', 'craftsman_email']
Business key columns: ['craftsman_email']

Found 0 changed or new records
No new or changed rows for dwh.d_craftsmans. No action taken.

Comparing columns: ['product_

Теперь обновим данные в витрине

In [147]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.window import Window
from pyspark.sql.types import *
from datetime import datetime, timedelta

def init_spark():
    return SparkSession.builder \
        .appName("Craftsman Report Datamart") \
        .config("spark.sql.legacy.timeParserPolicy", "LEGACY") \
        .getOrCreate()

def get_last_load_date(spark, jdbc_url, properties):
    try:
        control_table_query = """
        (SELECT load_dttm 
         FROM dwh.load_dates_craftsman_report_datamart 
         ORDER BY load_dttm DESC 
         LIMIT 1) AS last_load
        """
        
        df = spark.read.jdbc(
            jdbc_url,
            control_table_query,
            properties=properties
        )
        
        if df.count() > 0:
            last_load = df.collect()[0]['load_dttm']
            print(f"Found last load date in control table: {last_load}")
            return last_load
        else:
            print("No previous load date found in control table")
            return None
            
    except Exception as e:
        print(f"Error getting last load date: {str(e)}")
        return None

def read_dwh_data(spark, jdbc_url, properties, last_load_date):
    
    d_craftsmans = spark.read.jdbc(
        jdbc_url, 
        "(SELECT * FROM dwh.d_craftsmans WHERE is_current = true) AS d_craftsmans", 
        properties=properties
    )
    
    d_customers = spark.read.jdbc(
        jdbc_url, 
        "(SELECT * FROM dwh.d_customers WHERE is_current = true) AS d_customers", 
        properties=properties
    )
    
    d_products = spark.read.jdbc(
        jdbc_url, 
        "(SELECT * FROM dwh.d_products WHERE is_current = true) AS d_products", 
        properties=properties
    )
    
    if last_load_date:
        orders_query = f"""
        (SELECT * FROM dwh.f_orders 
         WHERE DATE(load_dttm) >= DATE('{last_load_date}')
         OR DATE(order_created_date) >= DATE('{last_load_date}')
         OR DATE(order_completion_date) >= DATE('{last_load_date}')
        ) AS f_orders
        """
    else:
        orders_query = "dwh.f_orders"
    
    f_orders = spark.read.jdbc(jdbc_url, orders_query, properties=properties)
    
    print(f"""Read from DWH:
    - {d_craftsmans.count()} craftsmen records
    - {d_customers.count()} customer records
    - {d_products.count()} product records
    - {f_orders.count()} order records""")
    
    return d_craftsmans, d_customers, d_products, f_orders

def get_existing_periods(spark, jdbc_url, properties, last_load_date):
    if last_load_date:
        query = f"""
        (SELECT DISTINCT report_period 
         FROM dwh.craftsman_report_datamart 
         WHERE report_period >= '{last_load_date.strftime('%Y-%m')}') AS periods
        """
        return spark.read.jdbc(jdbc_url, query, properties=properties) \
            .select("report_period") \
            .collect()
    return []

def calculate_metrics(d_craftsmans, d_customers, d_products, f_orders):
    
    orders_with_period = f_orders.withColumn(
        "report_period",
        date_format(col("order_created_date"), "yyyy-MM")
    )
    
    enriched_orders = orders_with_period \
        .join(broadcast(d_craftsmans), "craftsman_id") \
        .join(broadcast(d_customers), "customer_id") \
        .join(broadcast(d_products), "product_id")
    
    enriched_orders = enriched_orders.withColumn(
        "customer_age",
        floor(datediff(col("order_created_date"), col("customer_birthday")) / 365.25)
    )
    
    enriched_orders = enriched_orders.withColumn(
        "completion_time",
        when(
            col("order_completion_date").isNotNull(),
            datediff(col("order_completion_date"), col("order_created_date"))
        ).otherwise(None)
    )
    
    window_spec = Window.partitionBy("craftsman_id", "report_period")
    
    metrics = enriched_orders.groupBy("craftsman_id", "report_period") \
        .agg(
            first("craftsman_name").alias("craftsman_name"),
            first("craftsman_address").alias("craftsman_address"),
            first("craftsman_birthday").alias("craftsman_birthday"),
            first("craftsman_email").alias("craftsman_email"),
            
            round(sum("product_price") * 0.9, 2).alias("craftsman_money"),
            (sum("product_price") * 0.1).cast("bigint").alias("platform_money"),
            
            count("order_id").alias("count_order"),
            round(avg("product_price"), 2).alias("avg_price_order"),
            
            round(avg("customer_age"), 1).alias("avg_age_customer"),
            
            expr("percentile_approx(completion_time, 0.5)").alias("median_time_order_completed"),
            
            sum(when(col("order_status") == "created", 1).otherwise(0)).alias("count_order_created"),
            sum(when(col("order_status") == "in progress", 1).otherwise(0)).alias("count_order_in_progress"),
            sum(when(col("order_status") == "delivery", 1).otherwise(0)).alias("count_order_delivery"),
            sum(when(col("order_status") == "done", 1).otherwise(0)).alias("count_order_done"),
            sum(when(col("order_status") != "done", 1).otherwise(0)).alias("count_order_not_done")
        )
    
    top_categories = enriched_orders \
        .groupBy("craftsman_id", "report_period", "product_type") \
        .count() \
        .withColumn(
            "rank",
            row_number().over(
                Window.partitionBy("craftsman_id", "report_period")
                .orderBy(desc("count"))
            )
        ) \
        .filter(col("rank") == 1) \
        .select("craftsman_id", "report_period", "product_type")
    
    # Объединяем все метрики
    final_metrics = metrics.join(
        top_categories,
        ["craftsman_id", "report_period"]
    ).withColumnRenamed("product_type", "top_product_category")
    
    return final_metrics

def get_periods_to_recalculate(spark, jdbc_url, properties, last_load_date):
    affected_periods = set()  

    if not last_load_date:
        all_periods_query = """
        SELECT DISTINCT DATE_TRUNC('month', order_created_date)::date as report_period
        FROM dwh.f_orders
        ORDER BY report_period
        """
        
        try:
            all_periods = spark.read.jdbc(
                jdbc_url,
                f"({all_periods_query}) AS periods",
                properties=properties
            ).collect()
            
            affected_periods.update(row.report_period for row in all_periods)
            print("First load - will process all periods:")
            print(f"Found {len(affected_periods)} periods to process")
            
        except Exception as e:
            print(f"Error getting all periods: {str(e)}")
            return []
    else:
        dim_tables = [
            {
                'table': 'dwh.d_customers',
                'key': 'customer_id',
                'name': 'customers'
            },
            {
                'table': 'dwh.d_craftsmans',
                'key': 'craftsman_id',
                'name': 'craftsmen'
            },
            {
                'table': 'dwh.d_products',
                'key': 'product_id',
                'name': 'products'
            }
        ]

        for dim in dim_tables:
            dim_changes_query = f"""
            WITH changed_dim AS (
                SELECT {dim['key']}
                FROM {dim['table']}
                WHERE valid_from >= '{last_load_date}'
                    OR valid_to >= '{last_load_date}'
            )
            SELECT DISTINCT 
                DATE_TRUNC('month', f.order_created_date)::date as report_period
            FROM dwh.f_orders f
            INNER JOIN changed_dim cd ON f.{dim['key']} = cd.{dim['key']}
            """
            
            try:
                changed_periods = spark.read.jdbc(
                    jdbc_url,
                    f"({dim_changes_query}) AS periods",
                    properties=properties
                ).collect()
                
                if changed_periods:
                    periods = {row.report_period for row in changed_periods}
                    affected_periods.update(periods)
                    print(f"Found periods to recalculate due to {dim['name']} changes: {sorted(periods)}")
            except Exception as e:
                print(f"Error checking {dim['name']} changes: {str(e)}")

        if not affected_periods:
            new_facts_query = f"""
            SELECT DISTINCT DATE_TRUNC('month', order_created_date)::date as report_period
            FROM dwh.f_orders
            WHERE load_dttm >= '{last_load_date}'
            """
            
            try:
                new_fact_periods = spark.read.jdbc(
                    jdbc_url,
                    f"({new_facts_query}) AS periods",
                    properties=properties
                ).collect()
                
                if new_fact_periods:
                    fact_periods = {row.report_period for row in new_fact_periods}
                    affected_periods.update(fact_periods)
                    print(f"Found periods to recalculate due to new fact records: {sorted(fact_periods)}")
            except Exception as e:
                print(f"Error checking new fact records: {str(e)}")

    periods_list = sorted(list(affected_periods))
    
    if periods_list:
        print(f"Total unique periods to recalculate: {periods_list}")
        print(f"Number of periods to recalculate: {len(periods_list)}")
    else:
        print("No periods need recalculation")
    
    return periods_list

def update_datamart(spark, df, jdbc_url, properties, periods_to_recalc):
    if not periods_to_recalc:
        print("No periods need recalculation, skipping update")
        return

    try:
        current_data = spark.read.jdbc(
            jdbc_url,
            "dwh.craftsman_report_datamart",
            properties=properties
        )

        period_strings = [period.strftime('%Y-%m') for period in periods_to_recalc]
        print(f"Updating periods: {period_strings}")

        current_filtered = current_data.filter(
            ~col("report_period").isin(period_strings)
        )

        print(f"Records retained after filtering: {current_filtered.count()}")

        if current_filtered.count() > 0:
            new_data = df.filter(col("report_period").isin(period_strings))
            print(f"New records for affected periods: {new_data.count()}")
            
            final_data = current_filtered.unionByName(new_data)
        else:
            final_data = df

        print(f"Total records to write: {final_data.count()}")

        final_data.write.jdbc(
            jdbc_url,
            "dwh.craftsman_report_datamart",
            mode="overwrite",
            properties=properties
        )
        print(f"Written {final_data.count()} records to datamart")

        current_timestamp = datetime.now()
        spark.createDataFrame(
            [(current_timestamp,)],
            ["load_dttm"]
        ).write.jdbc(
            jdbc_url,
            "dwh.load_dates_craftsman_report_datamart",
            mode="append",
            properties=properties
        )
        print(f"Updated control table with load timestamp: {current_timestamp}")

    except Exception as e:
        print(f"Error updating datamart: {str(e)}")
        raise

    except Exception as e:
        print(f"Error updating datamart: {str(e)}")
        raise

def load_data_into_datamart():
    spark = init_spark()
    
    jdbc_url = "jdbc:postgresql://spark_db:5432/db"
    properties = {
        "user": "user",
        "password": "password",
        "driver": "org.postgresql.Driver"
    }
    
    try:
        last_load_date = get_last_load_date(spark, jdbc_url, properties)
        print(f"Last load date: {last_load_date}")
        
        periods_to_recalc = get_periods_to_recalculate(
            spark, jdbc_url, properties, last_load_date
        )
        
        d_craftsmans, d_customers, d_products, f_orders = read_dwh_data(
            spark, jdbc_url, properties, last_load_date
        )
        
        print("Calculating metrics...")
        final_metrics = calculate_metrics(
            d_craftsmans, d_customers, d_products, f_orders
        )
        
        print("Updating datamart...")
        update_datamart(
            spark, final_metrics, jdbc_url, properties, periods_to_recalc
        )
        
        print("Datamart update completed successfully!")
        
    except Exception as e:
        print(f"Error during datamart update: {str(e)}")
        raise
    finally:
        spark.stop()

load_data_into_datamart()

Found last load date in control table: 2024-12-26 04:20:54.469924
Last load date: 2024-12-26 04:20:54.469924
No periods need recalculation
Read from DWH:
    - 2997 craftsmen records
    - 2997 customer records
    - 2994 product records
    - 2997 order records
Calculating metrics...
Updating datamart...
No periods need recalculation, skipping update
Datamart update completed successfully!


In [148]:
load_data_into_datamart()

Found last load date in control table: 2024-12-26 04:20:54.469924
Last load date: 2024-12-26 04:20:54.469924
No periods need recalculation
Read from DWH:
    - 2997 craftsmen records
    - 2997 customer records
    - 2994 product records
    - 2997 order records
Calculating metrics...
Updating datamart...
No periods need recalculation, skipping update
Datamart update completed successfully!


Да, не обновляется без изменений. Теперь, нужно проверить, обновляется если изменение таки есть. При этом нам нужно проверить весь флоу - что оно обновится в DWH, а потом и в витрине. То есть нам нужно сделать изменение прям в оригиналах - в sources. 

In [94]:
!pip3 install psycopg2-binary

Collecting psycopg2-binary
  Downloading psycopg2_binary-2.9.10-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl.metadata (4.9 kB)
Downloading psycopg2_binary-2.9.10-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (2.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.9/2.9 MB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: psycopg2-binary
Successfully installed psycopg2-binary-2.9.10


In [149]:
import uuid
from datetime import datetime
from random import randint

from pyspark.sql import SparkSession
from pyspark.sql.functions import rand, col
from pyspark.sql.types import *

import psycopg2

def init_spark():
    """Инициализация Spark-сессии"""
    return SparkSession.builder \
        .appName("Update Customer with psycopg2") \
        .getOrCreate()

def choose_random_customer(spark, jdbc_url, db_properties):
    customers_df = spark.read.jdbc(
        url=jdbc_url,
        table="source3.craft_market_customers",
        properties=db_properties
    )
    total_count = customers_df.count()
    print(f"Всего customer'ов в таблице: {total_count}")

    if total_count == 0:
        print("Таблица пустая, нет записей для обновления.")
        return None, None 

    random_row = customers_df.orderBy(rand()).limit(1).collect()[0]

    if "customer_id" in random_row:
        customer_id = random_row["customer_id"]
        old_name = random_row["customer_name"]
        print(f"Случайно выбран customer_id = {customer_id}")
        print(f"Старое имя = '{old_name}'")
        return customer_id, old_name
    else:
        customer_email = random_row["customer_email"]
        old_name = random_row["customer_name"]
        print(f"Случайно выбран customer_email = {customer_email}")
        return customer_email, old_name

def update_customer_name_psycopg2(customer_id, old_name, new_name, db_host, db_port, db_name, db_user, db_password):
    if customer_id is None:
        print("Нет customer_id — нечего обновлять через psycopg2.")
        return

    conn = psycopg2.connect(
        database=db_name,
        user=db_user,
        password=db_password,
        host=db_host,
        port=db_port
    )

    try:
        cur = conn.cursor()
        sql_update = """UPDATE source3.craft_market_customers
                        SET customer_name = %s
                        WHERE customer_id = %s"""

        print(f"Обновляем имя '{old_name}' -> '{new_name}' у customer_id = {customer_id}")

        cur.execute(sql_update, (new_name, customer_id))
        conn.commit()

        rows_updated = cur.rowcount
        print(f"Обновлено строк: {rows_updated}")

        cur.close()
    except Exception as e:
        print(f"Ошибка в процессе UPDATE через psycopg2: {str(e)}")
        conn.rollback()
    finally:
        conn.close()

def main():
    spark = init_spark()

    jdbc_url = "jdbc:postgresql://spark_db:5432/db"
    db_properties = {
        "user": "user",
        "password": "password",
        "driver": "org.postgresql.Driver"
    }

    db_host = "spark_db"
    db_port = "5432"
    db_name = "db"
    db_user = "user"
    db_password = "password"

    try:
        customer_id, old_name = choose_random_customer(spark, jdbc_url, db_properties)
        if customer_id is None:
            print("Нет данных для обновления (таблица пустая). Завершение.")
            return

        new_name = old_name + "_CHANGED"
        
        update_customer_name_psycopg2(
            customer_id=customer_id,
            old_name=old_name,
            new_name=new_name,
            db_host=db_host,
            db_port=db_port,
            db_name=db_name,
            db_user=db_user,
            db_password=db_password
        )

    except Exception as e:
        print(f"Ошибка во время операции: {str(e)}")
    finally:
        spark.stop()

if __name__ == "__main__":
    main()


Всего customer'ов в таблице: 999
Случайно выбран customer_id = 994
Старое имя = 'Selia Longcake'
Обновляем имя 'Selia Longcake' -> 'Selia Longcake_CHANGED' у customer_id = 994
Обновлено строк: 1


In [150]:
load_data_into_warehouse()

Read source1_wide: 999
Read source2_masters: 999
Read source2_orders: 999
Read source3_craftsmans: 999
Read source3_customers: 999
Read source3_orders: 999
Processing dimension tables...
Products before deduplication: 2997
Products after deduplication: 2994
Processing fact table...
Total orders before deduplication: 2997
Number of duplicated order keys: 0
Total orders after deduplication: 2997
Final orders count in f_orders: 2997
Upserting dimension tables to DWH...

Comparing columns: ['customer_id', 'customer_name', 'customer_address', 'customer_birthday', 'customer_email']
Business key columns: ['customer_email']

Found 1 changed or new records
Sample of changes:
+---------------+---------------+----------------------+-----------------+-------------------------+-------------------------+---------------------+---------------------+-----------------------+-----------------------+
|new_customer_id|old_customer_id|new_customer_name     |old_customer_name|new_customer_address     |old_cu

In [151]:
load_data_into_datamart()

Found last load date in control table: 2024-12-26 04:20:54.469924
Last load date: 2024-12-26 04:20:54.469924
Found periods to recalculate due to customers changes: [datetime.date(2018, 1, 1), datetime.date(2018, 2, 1), datetime.date(2018, 3, 1), datetime.date(2018, 4, 1), datetime.date(2018, 5, 1), datetime.date(2018, 6, 1), datetime.date(2018, 7, 1), datetime.date(2018, 8, 1), datetime.date(2018, 9, 1), datetime.date(2018, 10, 1), datetime.date(2018, 11, 1), datetime.date(2018, 12, 1), datetime.date(2019, 1, 1), datetime.date(2019, 2, 1), datetime.date(2019, 3, 1), datetime.date(2019, 4, 1), datetime.date(2019, 5, 1), datetime.date(2019, 6, 1), datetime.date(2019, 7, 1), datetime.date(2019, 8, 1), datetime.date(2019, 9, 1), datetime.date(2019, 10, 1), datetime.date(2019, 11, 1), datetime.date(2019, 12, 1), datetime.date(2020, 1, 1), datetime.date(2020, 2, 1), datetime.date(2020, 3, 1), datetime.date(2020, 4, 1), datetime.date(2020, 5, 1), datetime.date(2020, 6, 1), datetime.date(2020,

In [None]:
Вывод: Где-то есть баг, который не позволяет обновить витрину одной записью. При этом в DWH пишется корректно