In [0]:

from pyspark.sql.functions import *
from pyspark.sql import Window
from pyspark.sql.types import *
import re
from pyspark.sql import functions as F
from functools import reduce

# COMMAND ----------

# MAGIC %md
# MAGIC ## 1. Advanced Text Similarity Functions

# COMMAND ----------

def _levenshtein_distance(left, right):
    """
    Return the Levenshtein edit distance between two Python strings.

    Implemented in plain Python because the star-imported
    pyspark.sql.functions.levenshtein builds a Column expression and
    cannot be applied to ordinary strings inside a Python UDF.
    """
    if left == right:
        return 0
    # Keep the shorter string on the inner axis so each row stays small.
    if len(left) < len(right):
        left, right = right, left
    if not right:
        return len(left)

    previous_row = list(range(len(right) + 1))
    for i, left_char in enumerate(left, start=1):
        current_row = [i]
        for j, right_char in enumerate(right, start=1):
            substitution_cost = 0 if left_char == right_char else 1
            current_row.append(min(
                previous_row[j] + 1,                      # deletion
                current_row[j - 1] + 1,                   # insertion
                previous_row[j - 1] + substitution_cost,  # substitution
            ))
        previous_row = current_row
    return previous_row[-1]


@udf(DoubleType())
def advanced_name_similarity(name1, name2):
    """
    Score how similar two names are, from 0.0 (unrelated) to 1.0 (identical).

    Both names are lower-cased and stripped, then scored as a weighted blend:
    - 60% Jaccard overlap of the whitespace-separated tokens
    - 40% character similarity, 1 - (Levenshtein distance / longer length)

    Returns:
    - 0.0 if either name is None, empty or whitespace-only
    - 1.0 if the normalized names are equal
    - the weighted blend otherwise
    """
    if not name1 or not name2:
        return 0.0

    name1 = name1.lower().strip()
    name2 = name2.lower().strip()

    # Re-check after stripping: whitespace-only names carry no information
    # and must not be reported as an exact match with each other.
    if not name1 or not name2:
        return 0.0

    # Exact match
    if name1 == name2:
        return 1.0

    # Token-based (Jaccard) similarity; both names are non-blank here, so
    # each token set has at least one element and the union is never empty.
    tokens1 = set(name1.split())
    tokens2 = set(name2.split())
    jaccard = len(tokens1 & tokens2) / len(tokens1 | tokens2)

    # Character-level similarity, normalized by the longer name
    max_len = max(len(name1), len(name2))
    char_sim = 1.0 - (_levenshtein_distance(name1, name2) / max_len)

    # Combined score (weighted)
    return (jaccard * 0.6) + (char_sim * 0.4)

# COMMAND ----------

@udf(DoubleType())
def url_domain_similarity(url1, url2):
    """
    Compare the host names of two URLs.

    Returns:
    - 1.0 if both URLs have the same domain (ignoring case, scheme,
      a leading "www.", port, path, query string and fragment)
    - 0.9 if one domain is a subdomain of the other
      (e.g. "shop.example.com" vs "example.com")
    - 0.0 otherwise, including when either URL is None/empty or no
      domain can be extracted from it
    """
    if not url1 or not url2:
        return 0.0

    # Imported locally so the function body is self-contained when executed
    # as a UDF.
    import re

    def extract_domain(url):
        # Normalize case and whitespace first so the prefix patterns below
        # also match inputs such as "HTTPS://WWW.Example.com".
        host = url.strip().lower()
        # Remove protocol
        host = re.sub(r'^https?://', '', host)
        # Remove www
        host = re.sub(r'^www\.', '', host)
        # Keep only the part before the path, query string or fragment
        host = re.split(r'[/?#]', host, maxsplit=1)[0]
        # Remove port
        return host.split(':')[0]

    domain1 = extract_domain(url1)
    domain2 = extract_domain(url2)

    # An empty domain (e.g. from "http://") must not match anything.
    if not domain1 or not domain2:
        return 0.0

    if domain1 == domain2:
        return 1.0

    # Subdomain check on a label boundary: a plain substring test would
    # wrongly pair "a.com" with "banana.com".
    if domain1.endswith('.' + domain2) or domain2.endswith('.' + domain1):
        return 0.9

    return 0.0

# COMMAND ----------

# MAGIC %md
# MAGIC ## 2. Fuzzy Matching Pipeline for Unmatched Records

# COMMAND ----------

def find_fuzzy_matches(source_df, target_df, source_name_col, target_name_col,
                       source_url_col=None, target_url_col=None, threshold=0.85):
    """
    Find fuzzy matches between two dataframes

    Every source row is compared with every target row (cross join), scored,
    filtered by the threshold and ranked per source name.

    Parameters:
    - source_df: DataFrame to match FROM
    - target_df: DataFrame to match TO
    - source_name_col: company name column in source
    - target_name_col: company name column in target
    - source_url_col: optional URL column in source
    - target_url_col: optional URL column in target
    - threshold: minimum similarity threshold (0-1)

    Returns:
    - DataFrame with the columns of both inputs (source first) plus
      "name_similarity", "url_similarity" (only when both URL columns are
      given), "combined_similarity" and "match_rank" (1 = best match within
      each source_name_col value)

    Notes:
    - Columns are referenced by name after the join, so the name/URL column
      names are expected to be unambiguous across the two inputs.
    - URL similarity is only used when BOTH URL columns are supplied.
    """

    # Cross join for comparison, broadcasting the SMALLER side. The source
    # frame always stays on the left so the output column order is the same
    # in both branches. Note: each count() runs a Spark job.
    if source_df.count() < target_df.count():
        matches = broadcast(source_df).crossJoin(target_df)
    else:
        matches = source_df.crossJoin(broadcast(target_df))

    # Calculate name similarity
    matches = matches.withColumn(
        "name_similarity",
        advanced_name_similarity(col(source_name_col), col(target_name_col))
    )

    # Calculate URL similarity if URLs provided
    if source_url_col and target_url_col:
        matches = matches.withColumn(
            "url_similarity",
            url_domain_similarity(col(source_url_col), col(target_url_col))
        )

        # Combined score: 70% name, 30% URL
        matches = matches.withColumn(
            "combined_similarity",
            (col("name_similarity") * 0.7) + (col("url_similarity") * 0.3)
        )
    else:
        matches = matches.withColumn("combined_similarity", col("name_similarity"))

    # Filter by threshold
    matches = matches.filter(col("combined_similarity") >= threshold)

    # Rank matches: best candidate per source name gets match_rank == 1
    window = Window.partitionBy(source_name_col).orderBy(col("combined_similarity").desc())
    matches = matches.withColumn("match_rank", row_number().over(window))

    return matches

# COMMAND ----------

# MAGIC %md
# MAGIC ## 3. Data Quality Checks

# COMMAND ----------

def check_uen_validity(df, uen_col="uen"):
    """
    Check UEN format validity for Singapore companies

    Adds a boolean column "uen_is_valid" that is True when the value in
    uen_col matches one of the accepted patterns and False otherwise
    (including when the value is null).

    Accepted patterns (upper-case letters only):
    - 8 or 9 digits + letter   (business: nnnnnnnnX, local company: yyyynnnnnX)
    - 10 digits + letter       (kept from the previous rule set)
    - S/T/F + 8 digits + letter (kept from the previous rule set)
    - T/S/R + 2-digit year + 2-letter entity type + 4 digits + letter
      (other entities: TyyPQnnnnX, e.g. "T08LL0001A")

    Only the shape is checked; the check letter is not verified.

    Parameters:
    - df: DataFrame containing the UEN column
    - uen_col: name of the UEN column

    Returns:
    - df with the extra "uen_is_valid" column
    """

    uen = col(uen_col)
    matches_known_format = (
        uen.rlike(r'^[0-9]{8,9}[A-Z]$') |
        uen.rlike(r'^[0-9]{10}[A-Z]$') |
        uen.rlike(r'^[STF][0-9]{8}[A-Z]$') |
        # "Other entities" format, previously rejected by every pattern above
        uen.rlike(r'^[TSR][0-9]{2}[A-Z]{2}[0-9]{4}[A-Z]$')
    )

    # A null UEN makes the condition null, which falls through to False.
    return df.withColumn(
        "uen_is_valid",
        when(matches_known_format, True).otherwise(False)
    )

def check_url_validity(df, url_col="website"):
    """
    Flag rows whose URL looks well-formed.

    Adds a boolean column "url_is_valid". A value passes when it is not null
    and starts with a lower-case host name ending in a TLD of at least two
    letters, optionally preceded by "http://" or "https://". The patterns are
    anchored at the start only, so anything may follow the host name.

    Parameters:
    - df: DataFrame containing the URL column
    - url_col: name of the URL column

    Returns:
    - df with the extra "url_is_valid" column
    """

    url = col(url_col)

    # Host name without a scheme, e.g. "example.com"
    bare_host = url.rlike(r'^[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}')
    # Same host name pattern behind an explicit http/https scheme
    with_scheme = url.rlike(r'^https?://[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}')

    looks_valid = url.isNotNull() & (bare_host | with_scheme)
    validity_flag = when(looks_valid, True).otherwise(False)

    return df.withColumn("url_is_valid", validity_flag)

def check_email_validity(df, email_col="contact_email"):
    """
    Flag rows whose e-mail address looks well-formed.

    Adds a boolean column "email_is_valid" that is True when the value is not
    null and has the shape local-part@domain.tld, with a TLD of at least two
    letters; False otherwise.

    Parameters:
    - df: DataFrame containing the e-mail column
    - email_col: name of the e-mail column

    Returns:
    - df with the extra "email_is_valid" column
    """

    email = col(email_col)
    matches_pattern = email.rlike(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$')
    validity_flag = when(email.isNotNull() & matches_pattern, True).otherwise(False)

    return df.withColumn("email_is_valid", validity_flag)

def check_phone_validity(df, phone_col="contact_phone"):
    """
    Check if phone number is valid Singapore format

    Adds a boolean column "phone_is_valid" that is True for:
    - "+65" followed by an 8-digit number starting with 6, 8 or 9, with
      optional whitespace after the country code and between the two
      4-digit groups ("+6561234567", "+65 61234567", "+65 6123 4567")
    - a bare 8-digit number starting with 6, 8 or 9 ("61234567")
    and False otherwise (including when the value is null).

    Parameters:
    - df: DataFrame containing the phone column
    - phone_col: name of the phone column

    Returns:
    - df with the extra "phone_is_valid" column
    """

    phone = col(phone_col)

    # The inner \s* accepts the documented "+65 XXXX XXXX" layout, which the
    # stricter \d{7} form rejected.
    international = phone.rlike(r'^\+65\s*[689]\d{3}\s*\d{4}$')
    local = phone.rlike(r'^[689]\d{7}$')

    return df.withColumn(
        "phone_is_valid",
        when(phone.isNotNull() & (international | local), True).otherwise(False)
    )

# COMMAND ----------

def run_all_quality_checks(df):
    """
    Apply every validity check and derive an overall quality score.

    Runs the UEN, URL, e-mail and phone checks on the fixed columns "uen",
    "website", "contact_email" and "contact_phone", then adds
    "data_quality_score": the percentage (0-100) of the four checks that
    passed for each row.

    Returns:
    - df with the four "*_is_valid" columns and "data_quality_score"
    """

    # (check function, column it validates), applied in this order
    checks = (
        (check_uen_validity, "uen"),
        (check_url_validity, "website"),
        (check_email_validity, "contact_email"),
        (check_phone_validity, "contact_phone"),
    )
    for apply_check, column_name in checks:
        df = apply_check(df, column_name)

    # Count passed checks: each flag contributes 1 when True, else 0
    passed_checks = None
    for flag_name in ("uen_is_valid", "url_is_valid", "email_is_valid", "phone_is_valid"):
        point = when(col(flag_name), 1).otherwise(0)
        passed_checks = point if passed_checks is None else passed_checks + point

    # Add overall quality score
    return df.withColumn("data_quality_score", passed_checks / 4.0 * 100)

# COMMAND ----------

# MAGIC %md
# MAGIC ## 4. Deduplication Functions

# COMMAND ----------
def calculate_data_completeness_score(df, important_fields=None):
    """
    Calculate a completeness percentage per record.

    Adds "data_completeness_score": the share (0-100) of the important
    fields that are not null in each row.

    Parameters:
    - df: DataFrame containing every column listed in important_fields
    - important_fields: optional column name or iterable of column names to
      score; defaults to the standard ten profile fields below

    Returns:
    - df with the extra "data_completeness_score" column

    Raises:
    - ValueError: if important_fields is given but empty
    """
    if important_fields is None:
        important_fields = [
            "website",
            "linkedin",
            "facebook",
            "instagram",
            "industry",
            "revenue",
            "contact_email",
            "contact_phone",
            "products_offered",
            "services_offered"
        ]
    elif isinstance(important_fields, str):
        # A single column name, not an iterable of characters
        important_fields = [important_fields]
    else:
        important_fields = list(important_fields)

    # Fail early: an empty list would otherwise divide by zero.
    if not important_fields:
        raise ValueError("important_fields must contain at least one column name")

    total_fields = len(important_fields)

    # One 1/0 indicator per field: 1 when the value is present
    populated_flags = [
        F.when(F.col(c).isNotNull(), F.lit(1)).otherwise(F.lit(0))
        for c in important_fields
    ]

    # reduce() handles a single-element list fine (it returns that element),
    # so no special case is needed.
    populated_count = reduce(lambda a, b: a + b, populated_flags)

    # Calculate percentage completeness
    return df.withColumn(
        "data_completeness_score",
        (populated_count / F.lit(total_fields)) * 100
    )
def find_duplicates(df, match_cols, similarity_threshold=0.95):
    """
    List pairs of records whose company names are near-identical.

    The DataFrame is joined with itself; a pair is kept when the first
    record's "uen" sorts strictly before the second's and the name
    similarity reaches the threshold.

    Parameters:
    - df: DataFrame to check; must contain "uen" and "company_name"
    - match_cols: list of columns to check for duplicates
      (NOTE: currently not used; the comparison is always on "company_name")
    - similarity_threshold: minimum name similarity (0-1) for a pair to count

    Returns:
    - DataFrame with columns uen1, name1, uen2, name2, similarity
    """

    left = df.alias("df1")
    right = df.alias("df2")

    # The same similarity expression serves as join predicate and output
    name_similarity = advanced_name_similarity(
        col("df1.company_name"), col("df2.company_name")
    )

    # Strict ordering on uen drops self-pairs and mirrored (b, a) pairs
    ordered_pair = col("df1.uen") < col("df2.uen")
    similar_enough = name_similarity >= similarity_threshold

    candidate_pairs = left.join(right, ordered_pair & similar_enough, "inner")

    return candidate_pairs.select(
        col("df1.uen").alias("uen1"),
        col("df1.company_name").alias("name1"),
        col("df2.uen").alias("uen2"),
        col("df2.company_name").alias("name2"),
        name_similarity.alias("similarity")
    )

# COMMAND ----------

# MAGIC %md
# MAGIC ## 5. Missing Data Imputation Strategies

# COMMAND ----------

def impute_company_size(df):
    """
    Impute company size based on available data

    Adds "company_size_imputed". When "company_size" is missing (null or the
    literal "Unknown") and "revenue" is known, the size is derived from
    revenue:
    - revenue > 50,000,000 -> "Large"
    - revenue > 5,000,000  -> "Medium"
    - any other revenue    -> "Small"
    When the size is missing and revenue is null as well, the original
    "company_size" value is carried over unchanged. Rows that already have
    a size keep it.

    Returns:
    - df with the extra "company_size_imputed" column
    """

    size = col("company_size")
    revenue = col("revenue")

    # Treat null like "Unknown": a plain equality test is null for null
    # input, which used to skip imputation for those rows.
    size_is_missing = size.isNull() | (size == "Unknown")

    # A null revenue fails both comparisons and the isNotNull() test, so it
    # reaches otherwise() and keeps the original (null or "Unknown") value.
    size_from_revenue = (
        when(revenue > 50000000, "Large")
        .when(revenue > 5000000, "Medium")
        .when(revenue.isNotNull(), "Small")
        .otherwise(size)
    )

    return df.withColumn(
        "company_size_imputed",
        when(size_is_missing, size_from_revenue).otherwise(size)
    )

def impute_industry(df):
    """
    Fill in a missing industry from the record's keywords.

    Adds "industry_imputed". Rows with a non-null "industry" keep it. For
    rows where it is null, the lower-cased "keywords" text is searched for
    the substrings below and the first matching rule wins; if none match
    the result is "Unknown".

    Returns:
    - df with the extra "industry_imputed" column
    """

    keywords = lower(col("keywords"))

    # (substring, alternative substring, industry label), in priority order
    keyword_rules = (
        ("tech", "software", "Technology"),
        ("finance", "bank", "Finance"),
        ("retail", "ecommerce", "Retail"),
        ("food", "restaurant", "Food & Beverage"),
        ("health", "medical", "Healthcare"),
    )

    # Chain the rules into a single when(...).when(...) expression
    guessed_industry = None
    for first_term, second_term, label in keyword_rules:
        mentions_term = keywords.contains(first_term) | keywords.contains(second_term)
        if guessed_industry is None:
            guessed_industry = when(mentions_term, label)
        else:
            guessed_industry = guessed_industry.when(mentions_term, label)
    guessed_industry = guessed_industry.otherwise("Unknown")

    return df.withColumn(
        "industry_imputed",
        when(col("industry").isNull(), guessed_industry).otherwise(col("industry"))
    )

# COMMAND ----------

# MAGIC %md
# MAGIC ## 6. Export Functions

# COMMAND ----------

def export_matching_report(unified_df, output_path):
    """
    Write the matching quality report as CSV.

    Selects the identifying and scoring columns, orders rows from most to
    least complete, and writes the result with a header row to
    "<output_path>/matching_quality_report", overwriting any previous
    output. A confirmation line is printed afterwards.

    Parameters:
    - unified_df: DataFrame containing the report columns listed below
    - output_path: base directory for the export
    """

    report_columns = [
        "uen",
        "company_name",
        "website",
        "data_completeness_score",
        "data_quality_score",
        "entity_status_description",
    ]
    most_complete_first = col("data_completeness_score").desc()
    report = unified_df.select(*report_columns).orderBy(most_complete_first)

    # coalesce(1) collapses the data to one partition before writing
    writer = (
        report.coalesce(1).write
        .format("csv")
        .option("header", "true")
        .mode("overwrite")
    )
    writer.save(f"{output_path}/matching_quality_report")

    print(f"✓ Report saved to: {output_path}/matching_quality_report")

def export_statistics(unified_df, output_path):
    """
    Write a one-row summary of the unified dataset as CSV.

    The summary contains the total row count, the number of distinct UENs,
    the percentage of rows with a non-null website / linkedin / revenue
    value, and the average completeness and quality scores. It is written
    with a header row to "<output_path>/summary_statistics", overwriting any
    previous output. A confirmation line is printed afterwards.

    Parameters:
    - unified_df: DataFrame containing the columns referenced below
    - output_path: base directory for the export
    """

    def coverage(column_name, alias_name):
        # count(column) skips nulls, so this is the non-null percentage
        return (count(column_name) / count("*") * 100).alias(alias_name)

    summary_columns = [
        count("*").alias("total_records"),
        countDistinct("uen").alias("unique_companies"),
        coverage("website", "website_coverage"),
        coverage("linkedin", "linkedin_coverage"),
        coverage("revenue", "revenue_coverage"),
        avg("data_completeness_score").alias("avg_completeness"),
        avg("data_quality_score").alias("avg_quality"),
    ]
    stats = unified_df.select(*summary_columns)

    writer = (
        stats.coalesce(1).write
        .format("csv")
        .option("header", "true")
        .mode("overwrite")
    )
    writer.save(f"{output_path}/summary_statistics")

    print(f"✓ Statistics saved to: {output_path}/summary_statistics")

