In [0]:
#Importar librerías
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.types import StructType, StructField, StringType, BooleanType
from pyspark.sql.functions import col, when, lower
from typing import Dict, Optional, List

In [0]:
# --- Clase para conexión y lectura ---
class PostgresReader:
    """Read PostgreSQL tables over JDBC using credentials kept in a secret scope.

    The constructor resolves the user name and password through
    ``dbutils.secrets.get`` and prepares the JDBC URL and connection
    properties; ``read_table`` passes them to ``spark.read.jdbc``.
    Both ``dbutils`` and ``spark`` are looked up as globals at call time.
    """

    def __init__(
        self,
        scope: str,
        username_key: str,
        password_key: str,
        hostname: str,
        port: int,
        database: str,
        driver: str = "org.postgresql.Driver",
    ):
        """Resolve the credentials and assemble the connection settings.

        Args:
            scope: Secret scope that holds both credential entries.
            username_key: Key of the secret containing the user name.
            password_key: Key of the secret containing the password.
            hostname: Host part of the JDBC URL.
            port: Port part of the JDBC URL.
            database: Database name part of the JDBC URL.
            driver: JDBC driver class name placed in the connection properties.
        """
        fetch_secret = lambda key: dbutils.secrets.get(scope=scope, key=key)

        self.jdbc_username = fetch_secret(username_key)
        self.jdbc_password = fetch_secret(password_key)
        # The URL always carries sslmode=require.
        self.jdbc_url = "jdbc:postgresql://{}:{}/{}?sslmode=require".format(
            hostname, port, database
        )
        self.connection_properties: Dict[str, str] = dict(
            user=self.jdbc_username,
            password=self.jdbc_password,
            driver=driver,
        )

    def read_table(self, table_name: str, schema: Optional[StructType] = None) -> DataFrame:
        """Load ``table_name`` over JDBC, optionally casting columns.

        Args:
            table_name: Value passed as ``table`` to ``spark.read.jdbc``.
            schema: When given, every field listed in it is cast to the
                field's data type; columns not listed are left untouched.

        Returns:
            The loaded (and possibly cast) DataFrame.
        """
        loaded = spark.read.jdbc(
            url=self.jdbc_url,
            table=table_name,
            properties=self.connection_properties,
        )
        if schema is None:
            return loaded

        # Re-type one column at a time, in the order the schema lists them.
        for spec in schema.fields:
            column_name = spec.name
            loaded = loaded.withColumn(column_name, loaded[column_name].cast(spec.dataType))
        return loaded

In [0]:
# --- Clase para validación de esquemas ---
class SchemaValidator:
    """Compare the schema of a JDBC table against an expected schema."""

    def __init__(self, spark):
        """Keep the Spark session used to read the table being validated."""
        self.spark = spark

    def validate(self, jdbc_url, table_name, connection_props, schema_expected, preview=True):
        """Read ``table_name`` over JDBC and report differences from ``schema_expected``.

        Columns are matched by name. For each column only the class of its
        data type is compared, so any parameters carried by the data type
        instance are not taken into account. Findings are printed: first the
        missing or re-typed columns in expected-schema order, then unexpected
        columns in actual-schema order, then a summary line.

        Args:
            jdbc_url: JDBC URL passed to ``spark.read.jdbc``.
            table_name: Table passed to ``spark.read.jdbc``.
            connection_props: Connection properties passed to ``spark.read.jdbc``.
            schema_expected: Object exposing ``fields``, each with ``name`` and ``dataType``.
            preview: When true and the schemas match, ``display`` the first 10 rows.

        Returns:
            Tuple ``(matches, df_actual)``: ``matches`` is ``True`` when no
            difference was found, ``df_actual`` is the DataFrame that was read.
        """
        df_actual = self.spark.read.jdbc(url=jdbc_url, table=table_name, properties=connection_props)

        wanted = {field.name: type(field.dataType) for field in schema_expected.fields}
        found = {field.name: type(field.dataType) for field in df_actual.schema.fields}

        # Gather every finding first; they are printed in discovery order below.
        findings = []
        for name, wanted_type in wanted.items():
            if name not in found:
                findings.append(f"⚠️ Columna faltante: '{name}'")
            elif wanted_type != found[name]:
                findings.append(
                    f"⚠️ Tipo cambiado en '{name}': esperado {wanted_type}, recibido {found[name]}"
                )
        findings.extend(
            f"⚠️ Columna adicional no esperada: '{name}'" for name in found if name not in wanted
        )

        for message in findings:
            print(message)

        matches = not findings
        if not matches:
            print("❌ El esquema actual NO coincide con el esperado.")
            return matches, df_actual

        print("✅ El esquema actual coincide con el esperado.")
        if preview:
            display(df_actual.limit(10))
        return matches, df_actual

In [0]:
# --- Clase para escritura en Delta ---
class DeltaWriter:
    """Persist DataFrames as Delta data under a base path and register them as tables."""

    def __init__(self, base_path: str = "/mnt/bronze"):
        """Store the directory under which each table gets its own sub-path."""
        self.base_path = base_path

    def write(
        self,
        df: DataFrame,
        table_name: str,
        database: str = "bronze",
        partition_cols: Optional[List[str]] = None,
        mode: str = "overwrite",
    ):
        """Save ``df`` in Delta format and register it as ``database.table_name``.

        The data is written to ``<base_path>/<table_name>`` with the
        ``mergeSchema`` option enabled. Afterwards the database and the table
        are created if they do not exist yet, and the table is refreshed.

        Args:
            df: DataFrame to persist.
            table_name: Name of the table and of the sub-path below ``base_path``.
            database: Database in which the table is registered.
            partition_cols: Columns to partition by; no partitioning when empty or ``None``.
            mode: Save mode handed to the DataFrame writer.
        """
        qualified_name = f"{database}.{table_name}"
        target_path = f"{self.base_path}/{table_name}"

        delta_writer = df.write.format("delta")
        delta_writer = delta_writer.option("mergeSchema", "true")
        delta_writer = delta_writer.mode(mode)
        if partition_cols:
            delta_writer = delta_writer.partitionBy(*partition_cols)
        delta_writer.save(target_path)

        # Registration runs only after the data has been saved, in this exact order.
        registration_statements = (
            f"CREATE DATABASE IF NOT EXISTS {database}",
            f"""
            CREATE TABLE IF NOT EXISTS {qualified_name}
            USING DELTA
            LOCATION '{target_path}'
        """,
            f"REFRESH TABLE {qualified_name}",
        )
        for statement in registration_statements:
            spark.sql(statement)

        print(f"✅ Datos escritos y tabla registrada como: {qualified_name}")

In [0]:
# --- Clase de transformaciones ---
class UserTransformer:
    """Chainable set of column transformations over a user DataFrame.

    Every transformation replaces ``self.df`` with the transformed DataFrame
    and returns ``self`` so calls can be chained; ``get_df`` returns the
    current DataFrame.
    """

    def __init__(self, df: DataFrame):
        """Wrap ``df`` as the starting point of the transformation chain."""
        self.df = df

    def add_solution_column(self) -> "UserTransformer":
        """Add ``solucion``, derived from substrings found in ``grupo``.

        ``grupo`` is lowercased before matching. The rules are evaluated in
        the order written, so the first matching substring decides the value;
        rows matching none of them get ``"Otros"``.
        """
        grupo = lower(col("grupo"))
        self.df = self.df.withColumn(
            "solucion",
            when(grupo.contains("sl"), "Storelive")
            .when(grupo.contains("sv"), "StoreView")
            .when(grupo.contains("sc"), "StoreConnect")
            .when(grupo.contains("ml"), "Marketlink")
            .otherwise("Otros"),
        )
        return self

    def add_user_type_column(self) -> "UserTransformer":
        """Add ``tipo_de_usuario``: ``"Interno"`` when ``email`` contains "dichter", else ``"Externo"``.

        The email is lowercased before matching, consistent with how
        ``grupo`` is matched in ``add_solution_column``, so differences in
        letter case do not turn an internal address into an external one.
        """
        email = lower(col("email"))
        self.df = self.df.withColumn(
            "tipo_de_usuario",
            when(email.contains("dichter"), "Interno").otherwise("Externo"),
        )
        return self

    def clean_column_names(self) -> "UserTransformer":
        """Rename every column, replacing spaces in its name with underscores."""
        self.df = self.df.toDF(*(c.replace(" ", "_") for c in self.df.columns))
        return self

    def get_df(self) -> DataFrame:
        """Return the DataFrame with all transformations applied so far."""
        return self.df