# Merchant Recommendation System (MRv1)

### Notebook Flow

1. Data Loading (Spark -> CSV source of truth)
2. EDA
3. Listing Merchants
4. Cleaning and Extraction
5. Standardization
6. Grouping
7. Categorization (MCC + Pattern + Fallback)
8. RFM Analysis
9. Category-Level RFM
10. Final Integration & Summary


In [0]:
# Imports and setup
import pandas as pd
import numpy as np
import re
import warnings
import time
import multiprocessing
from pathlib import Path

try:
    from pandarallel import pandarallel
    PANDARALLEL_AVAILABLE = True
except ImportError:
    PANDARALLEL_AVAILABLE = False

from pyspark.sql.functions import col

pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', 50)
pd.set_option('display.float_format', lambda x: '%.2f' % x)
warnings.filterwarnings('ignore')

if PANDARALLEL_AVAILABLE:
    # Keep at least one worker even on a single-core machine
    pandarallel.initialize(nb_workers=max(1, min(multiprocessing.cpu_count() - 1, 8)), progress_bar=False, verbose=0)

# Use this canonical path for all file operations (pandas, CSV, etc.)
VOLUME_PATH = "/Volumes/jupiter/temp/temp/"
Path(VOLUME_PATH).mkdir(parents=True, exist_ok=True)

DATA_TABLE = "jupiter.temp.backfill_test_mrinal_v2"

## Data Loading (Spark -> CSV source of truth)

We first load from the Databricks table and immediately persist to CSV under a shared volume. Subsequent steps read/write CSVs and also update the in-memory DataFrame so outputs remain visible in notebook runs.

1. `jupiter.temp.unified_transaction_table_v2`: base table containing all {User - Merchant} pairs where the transaction amount is greater than 100
2. `jupiter.temp.unique_user_merchant_mapping`: table subsequently derived from the base table, containing all {User - Merchant} pairs where the transaction count is greater than 2
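
The snapshot-then-reload pattern can be sketched as below. One caveat worth keeping in mind: a CSV round trip does not preserve dtypes, so date columns come back as strings unless they are parsed on reload. The helper name `snapshot_and_reload` and the toy frame are illustrative assumptions; only the column names follow the ones used later in this notebook.

```python
import tempfile
from pathlib import Path

import pandas as pd


def snapshot_and_reload(df, csv_path, date_cols=()):
    """Write df to CSV, then read it back so the CSV is the source of truth."""
    df.to_csv(csv_path, index=False)
    # Only parse date columns that are actually present in the frame
    present = [c for c in date_cols if c in df.columns]
    return pd.read_csv(csv_path, parse_dates=present if present else False)


# Toy stand-in for the Spark table (hypothetical values)
toy = pd.DataFrame({
    "user_id": ["u1", "u2"],
    "merchant": ["Swiggy", "Amazon"],
    "first_txn_date": pd.to_datetime(["2024-01-05", "2024-02-10"]),
})

with tempfile.TemporaryDirectory() as tmp:
    reloaded = snapshot_and_reload(
        toy,
        Path(tmp) / "snapshot.csv",
        date_cols=["first_txn_date", "last_txn_date"],
    )

print(reloaded.dtypes)
```

Parsing the date columns at reload time would make the later `pd.to_datetime` calls a no-op rather than a requirement.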


In [0]:
# Load from Spark, persist to CSV, and use CSV as source of truth
print(f"📋 Loading from table: {DATA_TABLE}")
sample_df = (
    spark.table(DATA_TABLE)
         .toPandas()
)

csv_path = f"{VOLUME_PATH}mrv1_backfill.csv"
sample_df.to_csv(csv_path, index=False)
print(f"✅ Saved raw snapshot to: {csv_path}")

# Re-load from CSV as source of truth
sample_df = pd.read_csv(csv_path)
print(f"✅ Reloaded from CSV: {sample_df.shape}")

# Basic info
print(f"👥 Users: {sample_df['user_id'].nunique() if 'user_id' in sample_df.columns else 'N/A'}")
print(f"🏪 Merchants (raw): {sample_df['merchant'].nunique() if 'merchant' in sample_df.columns else 'N/A'}")


## EDA & Listing Merchants

We inspect dataset shape, status, and list top merchants (raw vs extracted). All previews are saved to CSVs in the shared volume.


In [0]:
# EDA quick
print(f"Shape: {sample_df.shape}")
if 'transactionstatus' in sample_df.columns:
    print(sample_df['transactionstatus'].value_counts().to_string())

# Top merchants (raw)
if 'merchant' in sample_df.columns:
    top_raw = sample_df['merchant'].fillna('Unknown').astype(str).str.strip().str.lower().value_counts().head(50)
    raw_path = f"{VOLUME_PATH}mrv1_top_merchants_backfill.csv"
    top_raw.to_csv(raw_path)
    print(f"✅ Saved top raw merchants to {raw_path}")


## Cleaning and Extraction

We clean rows, filter successful txns, and extract merchant names from embedded JSON structures, then persist the cleaned snapshot to CSV.
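
A minimal sketch of the extraction step, assuming the embedded payload is a JSON object with a top-level `name` key (the shape implied by the regex used later in this notebook). The function name `extract_name` is hypothetical. It tries a real JSON parse first and falls back to a whitespace-tolerant regex for truncated or malformed payloads.

```python
import json
import re

# Tolerates optional whitespace around the colon, e.g. '"name": "Swiggy"'
NAME_PATTERN = re.compile(r'"name"\s*:\s*"([^"]+)"')


def extract_name(raw):
    """Return the embedded merchant name if raw looks like JSON, else raw unchanged."""
    if not isinstance(raw, str):
        return raw
    text = raw.strip()
    if text.startswith("{"):
        try:
            payload = json.loads(text)
            name = payload.get("name") if isinstance(payload, dict) else None
            if isinstance(name, str) and name.strip():
                return name.strip()
        except json.JSONDecodeError:
            pass  # fall through to the regex for truncated or malformed JSON
    match = NAME_PATTERN.search(text)
    return match.group(1).strip() if match else raw


print(extract_name('{"name": "Swiggy Ltd", "id": 7}'))
print(extract_name('{"name":"Zomato","id":'))
print(extract_name("Amazon Pay"))
```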


## User-Merchant Transaction Analysis

**Dataset**: `jupiter.temp.unified_transactions_v2` - Transaction-level data for all users, drawn from the Fact MM, Stg Rewards, and ULM CC tables

**Technical Pipeline**:
1. Data loading and exploratory analysis
2. Merchant name standardization
3. User-merchant pair aggregation
4. Merchant categorization (hybrid approach)
5. Transaction pattern analysis

In [0]:
print("Dataset Sample Information:")
print("-" * 50)
print(f"Columns: {len(sample_df.columns)}")
print(f"Sample rows: {len(sample_df)}")

# Display column information
print("\nColumn Data Types:")
print("-" * 50)
for col, dtype in sample_df.dtypes.items():
    non_null = sample_df[col].count()
    pct_filled = non_null / len(sample_df) * 100
    print(f"{col:<25} {str(dtype):<12} {non_null:>8,} non-null ({pct_filled:.1f}%)")

# Merchant columns for analysis
merchant_cols = [col for col in sample_df.columns if any(k in col.lower() for k in ['merchant', 'payee'])]
print(f"\nMerchant columns: {merchant_cols}")

# Preview data
sample_df.head(3)

In [0]:
import pandas as pd

# Analyze key metrics
print("\nKey Dataset Metrics:")
print("-" * 50)
print(f"Unique users: {sample_df['user_id'].nunique():,}")
print(f"Unique merchants: {sample_df['merchant'].nunique():,}")

# Transaction amount stats
print(f"\nTransaction amount stats:")
print(f"  - Min: ₹{sample_df['total_spend'].min():,.2f}")
print(f"  - Max: ₹{sample_df['total_spend'].max():,.2f}")
print(f"  - Mean: ₹{sample_df['total_spend'].mean():,.2f}")
print(f"  - Median: ₹{sample_df['total_spend'].median():,.2f}")
print(f"  - Total: ₹{sample_df['total_spend'].sum():,.2f}")

# Transaction count stats
print(f"\nTransaction count stats:")
print(f"  - Min: {sample_df['total_txns'].min():,}")
print(f"  - Max: {sample_df['total_txns'].max():,}")
print(f"  - Mean: {sample_df['total_txns'].mean():.2f}")
print(f"  - Median: {sample_df['total_txns'].median():,}")
print(f"  - Total: {sample_df['total_txns'].sum():,}")

# Analyze transaction sources
print("\nTransaction Sources:")
print("-" * 50)
source_counts = sample_df['source'].value_counts()
for source, count in source_counts.items():
    print(f"{source:<10}: {count:,} records ({count/len(sample_df):.1%})")

# Analyze transaction status
print("\nTransaction Status:")
print("-" * 50)
status_counts = sample_df['transactionstatus'].value_counts()
for status, count in status_counts.items():
    print(f"{status:<10}: {count:,} records ({count/len(sample_df):.1%})")

# Check date range
print("\nTransaction Date Range:")
print("-" * 50)
sample_df['first_txn_date'] = pd.to_datetime(sample_df['first_txn_date'])
sample_df['last_txn_date'] = pd.to_datetime(sample_df['last_txn_date'])
min_date = sample_df['first_txn_date'].min()
max_date = sample_df['last_txn_date'].max()
date_range = max_date - min_date
print(f"Earliest: {min_date}")
print(f"Latest: {max_date}")
print(f"Range: {date_range.days} days")


In [0]:
import matplotlib.pyplot as plt
import seaborn as sns
# Visualize transaction distribution
plt.figure(figsize=(14, 6))

# Plot 1: Transactions per user distribution (log scale)
plt.subplot(1, 2, 1)
sns.histplot(sample_df.groupby('user_id')['total_txns'].sum(), log_scale=True)
plt.title('Transactions per User (Log Scale)')
plt.xlabel('Number of Transactions')
plt.ylabel('Count of Users')

# Plot 2: Spend per user distribution (log scale)
plt.subplot(1, 2, 2)
sns.histplot(sample_df.groupby('user_id')['total_spend'].sum(), log_scale=True)
plt.title('Total Spend per User (Log Scale)')
plt.xlabel('Total Spend (₹)')
plt.ylabel('Count of Users')

plt.tight_layout()
plt.show()

# Display dataset shape
print(f"Dataset shape: {sample_df.shape}")

## Data Prep 

**Key Steps:**
1. **Data Cleaning**: Filter invalid records and handle missing values
2. **Merchant Name Extraction**: Extract merchant names from JSON formats
3. **Merchant Name Standardization**: Remove prefixes/suffixes, normalize formatting
4. **Source Analysis**: Analyze MM vs CC transaction distribution
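
Steps 1 and 4 can be sketched on a toy frame as follows. The values, including the `MM` and `CC` source labels, are illustrative assumptions; the column names match the ones used in this notebook.

```python
import pandas as pd

# Hypothetical rows standing in for the real dataset
toy = pd.DataFrame({
    "merchant": ["Swiggy", None, "Amazon", "Zomato"],
    "transactionstatus": ["SUCCESS", "SUCCESS", "FAILED", "SUCCESS"],
    "source": ["MM", "MM", "CC", "CC"],
    "mcccode": [5812, None, 5311, None],
})

# Step 1: drop rows without a merchant and keep only successful transactions
clean = (
    toy.dropna(subset=["merchant"])
       .loc[lambda d: d["transactionstatus"] == "SUCCESS"]
       .copy()
)

# Step 4: record count and share of rows with an MCC code, per source
summary = clean.groupby("source").agg(
    records=("merchant", "size"),
    mcc_coverage=("mcccode", lambda s: s.notna().mean()),
)
print(summary)
```

Breaking MCC coverage down by source shows which sources will depend on the pattern-based categorization later.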


In [0]:
import re

print("Cleaning and preparing data...")
sample_df_clean = sample_df.dropna(subset=['merchant']).copy() # Filter out records with missing merchant names
print(f"Removed {len(sample_df) - len(sample_df_clean):,} records with missing merchant names")

# Filter out failed transactions
if 'transactionstatus' in sample_df_clean.columns:
    success_mask = sample_df_clean['transactionstatus'] == 'SUCCESS'
    sample_df_clean = sample_df_clean[success_mask].copy()
    print(f"Kept only successful transactions: {len(sample_df_clean):,} records")

# Extract merchant name from JSON format if present
def extract_merchant_name(merchant_str):
    if pd.isna(merchant_str) or not isinstance(merchant_str, str):
        return merchant_str
    
    # Extract name from JSON format if present
    json_pattern = r'"name":"([^"]+)"'
    json_match = re.search(json_pattern, merchant_str)
    if json_match:
        return json_match.group(1).strip()
    
    return merchant_str

# Apply extraction to merchant column
sample_df_clean['merchant_extracted'] = sample_df_clean['merchant'].apply(extract_merchant_name)

# Analyze source distribution
source_counts = sample_df_clean['source'].value_counts()
total_records = len(sample_df_clean)

print("\nTransaction Source Distribution:")
print("-" * 50)
for source, count in source_counts.items():
    print(f"{source:<10}: {count:,} records ({count/total_records:.1%})")

# Check MCC code availability
mcc_available = sample_df_clean['mcccode'].notna().sum()
print(f"\nMCC code availability: {mcc_available:,} records ({mcc_available/total_records:.1%})")

# Show sample of extracted merchant names
print("\nSample of extracted merchant names:")
display(sample_df_clean[['merchant', 'merchant_extracted']].sample(5))

## Merchant Analysis

**Key Metrics:**
1. **User Reach**: Number of unique users per merchant
2. **Transaction Frequency**: Total transaction count per merchant
3. **Revenue Impact**: Total spend amount per merchant
4. **Transaction Patterns**: Average spend, transactions per user
5. **Merchant Popularity**: Ranking based on combined metrics
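
The metrics above can be sketched with pandas named aggregation on a toy frame (hypothetical values). The 0.4 / 0.3 / 0.3 weights are the ones used in the cell that follows; the column name `popularity_rank_score` is an assumption chosen to make the direction explicit, since the score is built from ranks and a lower value therefore means a more popular merchant.

```python
import pandas as pd

toy = pd.DataFrame({
    "merchant_extracted": ["Swiggy", "Swiggy", "Amazon", "Zepto"],
    "user_id": ["u1", "u2", "u1", "u3"],
    "total_txns": [5, 3, 2, 1],
    "total_spend": [1500.0, 900.0, 4000.0, 250.0],
})

metrics = toy.groupby("merchant_extracted").agg(
    users=("user_id", "nunique"),
    transactions=("total_txns", "sum"),
    spend=("total_spend", "sum"),
).reset_index()

metrics["avg_txn_amount"] = metrics["spend"] / metrics["transactions"]
metrics["txn_per_user"] = metrics["transactions"] / metrics["users"]

# Rank 1 is the best merchant on each metric, so a LOWER combined value
# means a MORE popular merchant
weights = {"users": 0.4, "transactions": 0.3, "spend": 0.3}
metrics["popularity_rank_score"] = sum(
    w * metrics[col].rank(ascending=False) for col, w in weights.items()
)

ranked = metrics.sort_values("popularity_rank_score").reset_index(drop=True)
print(ranked)
```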


# Analyze top merchants by different metrics


In [0]:
# Group by merchant and calculate key metrics
merchant_metrics = sample_df_clean.groupby('merchant_extracted').agg({
    'user_id': 'nunique',      # Distinct users
    'total_txns': 'sum',      # Total transactions
    'total_spend': 'sum',     # Total spend
    'mcccode': lambda x: x.mode().iloc[0] if not x.isna().all() else np.nan,  # Most common MCC code
    'source': lambda x: x.mode().iloc[0] if not x.isna().all() else np.nan  # Most common source (NaN-safe)
}).reset_index()

# Rename columns for clarity
merchant_metrics.columns = ['merchant', 'users', 'transactions', 'spend', 'mcc_code', 'primary_source']

# Calculate derived metrics
merchant_metrics['avg_txn_amount'] = (merchant_metrics['spend'] / merchant_metrics['transactions']).round(2)
merchant_metrics['txn_per_user'] = (merchant_metrics['transactions'] / merchant_metrics['users']).round(2)
merchant_metrics['spend_per_user'] = (merchant_metrics['spend'] / merchant_metrics['users']).round(2)

# Calculate popularity score (weighted combination of user, transaction, and spend ranks).
# Rank 1 is the best merchant on each metric, so a LOWER score means a MORE popular merchant.
merchant_metrics['user_rank'] = merchant_metrics['users'].rank(ascending=False)
merchant_metrics['txn_rank'] = merchant_metrics['transactions'].rank(ascending=False)
merchant_metrics['spend_rank'] = merchant_metrics['spend'].rank(ascending=False)
merchant_metrics['popularity_score'] = (0.4 * merchant_metrics['user_rank'] + 
                                      0.3 * merchant_metrics['txn_rank'] + 
                                      0.3 * merchant_metrics['spend_rank'])

# Display top merchants by different metrics
print("Top 20 Merchants by User Reach:")
print("-" * 80)
print(f"{'Rank':<5}{'Merchant':<30}{'Users':<10}{'Transactions':<15}{'Spend (₹)':<15}{'Avg Txn (₹)':<15}")
print("-" * 80)

top_by_users = merchant_metrics.sort_values('users', ascending=False).head(20)
for i, (_, row) in enumerate(top_by_users.iterrows(), 1): 
    print(f"{i:<5}{row['merchant'][:29]:<30}{row['users']:<10,}{row['transactions']:<15,}{row['spend']:<15,.2f}{row['avg_txn_amount']:<15,.2f}")

print("\nTop 20 Merchants by Transaction Volume:")
print("-" * 80)
print(f"{'Rank':<5}{'Merchant':<30}{'Transactions':<15}{'Users':<10}{'Spend (₹)':<15}{'Txn/User':<10}")
print("-" * 80)

top_by_txns = merchant_metrics.sort_values('transactions', ascending=False).head(20)
for i, (_, row) in enumerate(top_by_txns.iterrows(), 1):
    print(f"{i:<5}{row['merchant'][:29]:<30}{row['transactions']:<15,}{row['users']:<10,}{row['spend']:<15,.2f}{row['txn_per_user']:<10.2f}")

print("\nTop 20 Merchants by Total Spend:")
print("-" * 80)
print(f"{'Rank':<5}{'Merchant':<30}{'Spend (₹)':<15}{'Users':<10}{'Transactions':<15}{'Spend/User (₹)':<15}")
print("-" * 80)

top_by_spend = merchant_metrics.sort_values('spend', ascending=False).head(20)
for i, (_, row) in enumerate(top_by_spend.iterrows(), 1):
    print(f"{i:<5}{row['merchant'][:29]:<30}{row['spend']:<15,.2f}{row['users']:<10,}{row['transactions']:<15,}{row['spend_per_user']:<15,.2f}")


## Top 50 Merchants by User Reach

In [0]:
top_by_users = (
    merchant_metrics.sort_values('users', ascending=False)
    .head(50)
    .reset_index(drop=True)
)

print(f"{'Rank':<5}{'Merchant':<30}{'Users':<12}{'Transactions':<15}{'Spend (₹)':<15}{'Avg Txn (₹)':<15}")
print("-" * 100)

for i, row in top_by_users.iterrows():
    print(f"{i+1:<5}{row['merchant'][:29]:<30}{row['users']:<12,}{row['transactions']:<15,}{row['spend']:<15,.2f}{row['avg_txn_amount']:<15,.2f}")


## Top 50 Merchants by Transaction Volume

In [0]:
top_by_txns = (
    merchant_metrics.sort_values('transactions', ascending=False)
    .head(50)
    .reset_index(drop=True)
)

print(f"{'Rank':<5}{'Merchant':<30}{'Transactions':<15}{'Users':<12}{'Spend (₹)':<15}{'Txn/User':<12}")
print("-" * 100)

for i, row in top_by_txns.iterrows():
    print(f"{i+1:<5}{row['merchant'][:29]:<30}{row['transactions']:<15,}{row['users']:<12,}{row['spend']:<15,.2f}{row['txn_per_user']:<12.2f}")


## Top 50 Merchants by Total Spend


In [0]:
top_by_spend = (
    merchant_metrics.sort_values('spend', ascending=False)
    .head(50)
    .reset_index(drop=True)
)

print(f"{'Rank':<5}{'Merchant':<30}{'Spend (₹)':<15}{'Users':<12}{'Transactions':<15}{'Spend/User (₹)':<15}")
print("-" * 100)

for i, row in top_by_spend.iterrows():
    print(f"{i+1:<5}{row['merchant'][:29]:<30}{row['spend']:<15,.2f}{row['users']:<12,}{row['transactions']:<15,}{row['spend_per_user']:<15,.2f}")


## Merchant Name Standardization and Grouping

The merchant standardization process consolidates different variations of the same merchant into a single canonical name. This is crucial for accurate merchant analysis as it:

1. **Reduces data fragmentation** - Combines multiple variants of the same merchant (e.g., "Swiggy Ltd", "SWIGGY", "Swiggy Online Order" → "Swiggy")
2. **Improves data quality** - Removes inconsistencies in merchant names due to different sources (MM vs CC)
3. **Enables accurate aggregation** - Allows proper grouping of transactions by merchant for reliable metrics
4. **Facilitates categorization** - Makes it easier to categorize merchants by business type

The standardization uses multiple techniques:
- JSON extraction for structured merchant data
- Regex pattern matching for common prefixes/suffixes
- Explicit mappings for well-known merchants
- Special character and formatting normalization
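
The techniques listed above can be sketched in a compact form. This is a variation on the notebook's function, not a copy of it: it uses anchored regexes and word-boundary matching instead of substring checks, so that a short key such as `ola` does not match inside `cola`. The lists are small illustrative subsets and the names `PREFIXES`, `SUFFIXES`, `MAPPINGS`, and `standardize` are assumptions.

```python
import re

# Illustrative subsets; the notebook's full lists are much longer
PREFIXES = ["payment to", "paid to", "upi"]
SUFFIXES = ["pvt ltd", "private limited", "limited", "ltd", "india", "online", "order"]
MAPPINGS = {"swiggy": "Swiggy", "ola": "Ola Cabs", "amazon": "Amazon"}


def _alternation(items):
    # Longest alternatives first so "pvt ltd" wins over "ltd"
    return "|".join(re.escape(i) for i in sorted(items, key=len, reverse=True))


PREFIX_RE = re.compile(rf"^(?:(?:{_alternation(PREFIXES)})\s+)+")
SUFFIX_RE = re.compile(rf"(?:\s+(?:{_alternation(SUFFIXES)}))+$")
MAPPING_RES = [(re.compile(rf"\b{re.escape(k)}\b"), v) for k, v in MAPPINGS.items()]


def standardize(name):
    """Normalize formatting, strip boilerplate, then map to a canonical name."""
    name = re.sub(r"[^\w\s]", " ", name.lower())   # special characters to spaces
    name = re.sub(r"\s+", " ", name).strip()       # collapse whitespace
    name = PREFIX_RE.sub("", name)                 # strip leading boilerplate
    name = SUFFIX_RE.sub("", name).strip()         # strip trailing boilerplate
    for pattern, canonical in MAPPING_RES:
        if pattern.search(name):
            return canonical
    return name.title()


for raw in ["Swiggy Ltd", "SWIGGY", "Swiggy Online Order",
            "Payment to Coca-Cola India Pvt. Ltd."]:
    print(raw, "->", standardize(raw))
```

Because the suffix pattern is repeated and anchored at the end, stacked suffixes such as "India Pvt Ltd" are removed in one pass.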


In [0]:
import re

# First, let's examine merchant name variations for common merchants
merchant_keywords = ['swiggy', 'amazon', 'flipkart', 'zomato', 'phonepe']
    
for keyword in merchant_keywords:
    # Use merchant_extracted instead of merchant_unified since we're working with the new dataset structure
    variants = sample_df_clean[sample_df_clean['merchant_extracted'].str.contains(keyword, case=False, na=False)]
    variant_counts = variants['merchant_extracted'].value_counts().head(5)
    
    print(f"\n{keyword.upper()} variants:")
    for name, count in variant_counts.items():
        print(f"- {name}: {count} records")


In [0]:
%pip install python-Levenshtein

## Standardization function 
`standardize_merchant_name`

In [0]:
%pip install pandarallel

In [0]:
import re
import pandas as pd
from pandarallel import pandarallel

# Initialize pandarallel
pandarallel.initialize(progress_bar=True, nb_workers=4)

def standardize_merchant_name(name):
    """
    Standardize merchant names and filter out banks, competitors, and specified merchants.
    
    Args:
        name (str): Raw merchant name
        
    Returns:
        str: Standardized merchant name, 'Unknown' for missing or non-string input,
            or None if the merchant should be filtered out
    """
    if pd.isna(name) or not isinstance(name, str):
        return "Unknown"
    
    # Extract name from JSON format if present
    json_pattern = r'"name":"([^"]+)"'
    json_match = re.search(json_pattern, str(name))
    if json_match:
        name = json_match.group(1)
    
    # Basic cleaning
    name = str(name).strip().lower()
    name = re.sub(r'[^\w\s]', ' ', name)  # Replace special chars with space
    name = re.sub(r'\s+', ' ', name).strip()  # Normalize whitespace
    
    # Merchants to filter out (existing + new list merged)
    merchants_to_remove = {
        # Existing
        'vi', 'jio', 'airtel', 'irctc', 'indian oil', 'bsnl', 'bpcl', 'hpcl',
        'phonepe', 'paytm', 'paytm wallet', 'paytm cash', 'paytm cashback', 'paytm m', 'paytm m-w',
        'cred', 'groww', 'mobikwik', 'kiwi', 'supermoney', 'google pay',
        'bharat connect utilities', 'mpokket', 'kreditbee', 'branch', 'on demand salary',
        're cash', 'wdl', 'atw', 'mbk', 'ccbp', 'googlepay', 'bharatpe', 'capitalfloat',
        'slice', 'navi', 'rblmycard',
        
        # Newly added merchants to filter
        'debit card annual fee',
        'google india digital',
        'google i',
        'vodafone idea',
        'onecard',
        'bank acc',
        'sbi card',
        'cheq',
        'indian railways',
        'bajaj finance',
        'google india di',
        'dummy name',
        'zerodha',
        'sbimops',
        'indian railways catering and tourism',
        'zerodha broking',
        'cheq1 yesbank',
        'atm cash',
        'indmoney',
        'snapmint',
        'bank'
    }
    
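    # Match whole words/phrases only, so short entries such as 'vi' or 'jio'
    # do not filter out unrelated names such as 'vijay sales' or 'jiomart'
    padded_name = f" {name} "
    if any(f" {remove} " in padded_name for remove in merchants_to_remove):
        return None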
    
    # Merchant name standardization mappings (unchanged)
    merchant_mappings = {
        # Food & Dining
        'swiggy': 'Swiggy',
        'instamart': 'Swiggy',
        'zomato': 'Zomato',
        'district': 'Zomato',
        'dominos': 'Dominos Pizza',
        'pizza hut': 'Pizza Hut',
        'mcdonald': 'McDonalds',
        'kfc': 'KFC',
        'subway': 'Subway',
        'burger king': 'Burger King',
        'dunkin': 'Dunkin Donuts',
        'starbucks': 'Starbucks',
        'cafe coffee day': 'Cafe Coffee Day',
        'ccd': 'Cafe Coffee Day',
        
        # E-commerce
        'amazon': 'Amazon',
        'flipkart': 'Flipkart',
        'ekart': 'Flipkart',
        'myntra': 'Myntra',
        'ajio': 'AJIO',
        'tatacliq': 'Tata CLiQ',
        'meesho': 'Meesho',
        'nykaa': 'Nykaa',
        'firstcry': 'FirstCry',
        
        # Grocery & Supermarkets
        'dmart': 'DMart',
        'avenue supermarts': 'DMart',
        'blinkit': 'Blinkit',
        'grofers': 'Blinkit',
        'bigbasket': 'BigBasket',
        'bb now': 'BigBasket',
        'zepto': 'Zepto',
        'jiomart': 'JioMart',
        'reliance smart': 'JioMart',
        'reliance fresh': 'JioMart',
        'more retail': 'More',
        'spencers': 'Spencers',
        'nature basket': 'Natures Basket',
        
        # Entertainment
        'netflix': 'Netflix',
        'hotstar': 'Disney+ Hotstar',
        'disney': 'Disney+ Hotstar',
        'amazon prime': 'Amazon Prime',
        'prime video': 'Amazon Prime',
        'bookmyshow': 'BookMyShow',
        'pvr': 'PVR Cinemas',
        'inox': 'INOX',
        'sony liv': 'Sony LIV',
        'zee5': 'ZEE5',
        'voot': 'Voot',
        
        # Travel
        'makemytrip': 'MakeMyTrip',
        'mmt': 'MakeMyTrip',
        'oyo': 'OYO Rooms',
        'uber': 'Uber',
        'ola': 'Ola Cabs',
        'rapido': 'Rapido',
        'goibibo': 'Goibibo',
        'cleartrip': 'Cleartrip',
        'easemytrip': 'EaseMyTrip',
        'yatra': 'Yatra',
        
        # Utilities
        'tata play': 'Tata Play',
        'dish tv': 'Dish TV',
        'd2h': 'D2H',
        
        # Fuel
        'shell': 'Shell',
        'reliance petroleum': 'Reliance Petroleum',
        'essar': 'Essar Oil',
        
        # Healthcare
        'apollo': 'Apollo Pharmacy',
        'medplus': 'MedPlus',
        'pharmeasy': 'PharmEasy',
        '1mg': '1MG',
        'netmeds': 'Netmeds',
        'wellness forever': 'Wellness Forever',
        'frank ross': 'Frank Ross',
        'guardian pharmacy': 'Guardian Pharmacy',
        
        # Fashion & Beauty
        'lifestyle': 'Lifestyle',
        'westside': 'Westside',
        'shoppers stop': 'Shoppers Stop',
        'central': 'Central',
        'pantaloons': 'Pantaloons',
        'max fashion': 'Max Fashion',
        'h&m': 'H&M',
        'zara': 'Zara',
        'marks spencer': 'Marks & Spencer',
        'fabindia': 'FabIndia',
        'biba': 'Biba',
        'w for woman': 'W',
        
        # Electronics
        'croma': 'Croma',
        'reliance digital': 'Reliance Digital',
        'vijay sales': 'Vijay Sales',
        'apple': 'Apple',
        'samsung': 'Samsung',
        'oneplus': 'OnePlus',
        'mi store': 'Mi Store',
        'lenovo': 'Lenovo',
        'dell': 'Dell',
        'hp world': 'HP'
    }
    
    # Common prefixes to remove
    prefixes_to_remove = [
        'payment to', 'payment from', 'paid to', 'upi', 'transfer to', 'transfer from',
        'payment at', 'payment for', 'purchase from', 'purchase at', 'bill payment',
        'recharge', 'subscription', 'order from', 'order at'
    ]
    
    # Common suffixes to remove
    suffixes_to_remove = [
        'pvt ltd', 'private limited', 'limited', 'ltd', 'india', 'retail', 'online',
        'services', 'service', 'solutions', 'solution', 'technologies', 'technology',
        'payments', 'payment', 'store', 'stores', 'shop', 'shopping', 'enterprise',
        'enterprises', 'corporation', 'corp', 'inc', 'llp', 'company', 'co', 'private',
        'pvt', 'order', 'delivery', 'bill', 'recharge', 'subscription'
    ]
    
    # Remove prefixes
    for prefix in prefixes_to_remove:
        if name.startswith(prefix):
            name = name[len(prefix):].strip()
    
    # Remove suffixes
    for suffix in suffixes_to_remove:
        if name.endswith(suffix):
            name = name[:-len(suffix)].strip()
    
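    # Check mappings, longest keys first so that specific entries such as
    # 'amazon prime' are not shadowed by shorter ones such as 'amazon'
    for key in sorted(merchant_mappings, key=len, reverse=True):
        if key in name:
            return merchant_mappings[key]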
    
    # If no mapping found, just capitalize words
    return ' '.join(word.capitalize() for word in name.split())
    

In [0]:
# Ensure merchant_standardized and merchant_extracted exist
if 'merchant_standardized' not in sample_df_clean.columns:
    sample_df_clean['merchant_standardized'] = sample_df_clean['merchant'].apply(standardize_merchant_name)

if 'merchant_extracted' not in sample_df_clean.columns:
    # Fallback: use raw merchant if no separate extraction logic
    sample_df_clean['merchant_extracted'] = sample_df_clean['merchant']

# Find merchants with multiple name variants
variant_counts = {}
for std_name, group in sample_df_clean.groupby('merchant_standardized'):
    variants = group['merchant_extracted'].dropna().unique()
    if len(variants) > 1:  # Only include merchants with multiple variants
        variant_counts[std_name] = {
            'variant_count': len(variants),
            'variants': variants,
            'record_count': len(group)
        }

# Sort by number of variants
top_variants = sorted(variant_counts.items(), key=lambda x: x[1]['variant_count'], reverse=True)

print("Merchants with Most Name Variations:")
print(f"{'Standardized Name':<20} {'Variants':<10} {'Records':<12}")
print("-" * 50)

for std_name, data in top_variants[:10]:
    print(f"{std_name[:20]:<20} {data['variant_count']:<10} {data['record_count']:<12,}")
    example_variants = ', '.join([str(v) for v in list(data['variants'])[:3]])
    print(f"  Example variants: {example_variants}{'...' if len(data['variants']) > 3 else ''}")
    print()


## Apply merchant name standardization


In [0]:
# First, extract merchant names from JSON format
print("1. Extracting merchant names from JSON...")
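name_pattern = re.compile(r'"name":"([^"]+)"')

def _extract_json_name(value):
    # Search once per value; return the original value when no JSON name is found
    match = name_pattern.search(str(value))
    return match.group(1) if match else value

sample_df_clean['merchant_extracted'] = sample_df_clean['merchant'].apply(_extract_json_name)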

# Then apply standardization
print("2. Standardizing merchant names...")
sample_df_clean['merchant_standardized'] = sample_df_clean['merchant_extracted'].apply(standardize_merchant_name)

# Remove rows where standardization returned None (banks and competitors)
print("3. Filtering out banks and competitors...")
sample_df_clean = sample_df_clean[sample_df_clean['merchant_standardized'].notna()].copy()

# Calculate reduction in unique merchants
before_count = sample_df_clean['merchant_extracted'].nunique()
after_count = sample_df_clean['merchant_standardized'].nunique()
reduction = before_count - after_count
reduction_pct = (reduction / before_count * 100) if before_count > 0 else 0

print("\nStandardization Results:")
print("-" * 50)
print(f"Unique merchants before: {before_count:,}")
print(f"Unique merchants after: {after_count:,}")
print(f"Reduction: {reduction:,} ({reduction_pct:.1f}%)")

# Show top standardized merchants by record count
print("\nTop 50 Merchants by Record Count:")
print("-" * 50)
top_by_records = sample_df_clean['merchant_standardized'].value_counts().head(50)
for merchant, count in top_by_records.items():
    print(f"- {merchant}: {count:,} records")

# Show top standardized merchants by user count
print("\nTop 50 Merchants by Unique Users:")
print("-" * 50)
top_by_users = sample_df_clean.groupby('merchant_standardized')['user_id'].nunique().sort_values(ascending=False).head(50)
for merchant, count in top_by_users.items():
    print(f"- {merchant}: {count:,} users")

# Show merchants with most name variations
print("\nMerchants with Most Name Variations:")
print("-" * 50)
variant_counts = sample_df_clean.groupby('merchant_standardized')['merchant_extracted'].nunique().sort_values(ascending=False).head(10)
for merchant, count in variant_counts.items():
    all_variants = sample_df_clean[sample_df_clean['merchant_standardized'] == merchant]['merchant_extracted'].unique()
    examples = ', '.join(str(v) for v in all_variants[:3])
    print(f"\n{merchant}: {count} variants")
    print(f"Example variants: {examples}{'...' if len(all_variants) > 3 else ''}")

# Create a dataframe with top merchants data for further analysis
top_merchants_df = sample_df_clean.groupby('merchant_standardized').agg({
    'user_id': 'nunique',
    'total_txns': 'sum',
    'total_spend': 'sum',
    'merchant_extracted': lambda x: len(set(x))  # Number of variations
}).reset_index()

top_merchants_df.columns = ['merchant', 'unique_users', 'total_transactions', 'total_spend', 'name_variations']
top_merchants_df['avg_spend_per_user'] = top_merchants_df['total_spend'] / top_merchants_df['unique_users']
top_merchants_df['avg_txn_value'] = top_merchants_df['total_spend'] / top_merchants_df['total_transactions']
top_merchants_df.sort_values('unique_users', ascending=False, inplace=True)

# Save top merchants data and standardized data to the Volume
volume_path = "/Volumes/jupiter/temp/temp/"
top_merchants_path = f"{volume_path}top_merchants_data_backfill.csv"
standardized_path = f"{volume_path}data-standardized_backfill.csv"

top_merchants_df.to_csv(top_merchants_path, index=False)
sample_df_clean.to_csv(standardized_path, index=False)

print("\nData saved to Databricks Volume:")
print(f"- {top_merchants_path}: Top merchants analysis")
print(f"- {standardized_path}: Full standardized dataset")


## Reload top merchants from CSV

In [0]:
csv_path = "/Volumes/jupiter/temp/temp/top_merchants_data_backfill.csv"
top_merchants_df = pd.read_csv(csv_path)
print(f"Shape: {top_merchants_df.shape}")
print(f"Columns: {list(top_merchants_df.columns)}")
display(top_merchants_df.head(50))

## Reload standardized data from CSV

In [0]:
csv_path = "/Volumes/jupiter/temp/temp/data-standardized_backfill.csv"
stand_data_df = pd.read_csv(csv_path)
print(f"Shape: {stand_data_df.shape}")
print(f"Columns: {list(stand_data_df.columns)}")
display(stand_data_df.head(20))

## Analyze impact of merchant standardization


In [0]:
# Before standardization metrics
before_merchants = sample_df_clean['merchant_extracted'].nunique()
before_stats = sample_df_clean.groupby('merchant_extracted').agg({
    'user_id': 'nunique',
    'total_txns': 'sum',
    'total_spend': 'sum'
}).sort_values('user_id', ascending=False)

# After standardization metrics
after_merchants = sample_df_clean['merchant_standardized'].nunique()
after_stats = sample_df_clean.groupby('merchant_standardized').agg({
    'user_id': 'nunique',
    'total_txns': 'sum',
    'total_spend': 'sum'
}).sort_values('user_id', ascending=False)

# Print overall impact
print("\n📈 Overall Impact:")
print("-" * 50)
print(f"Unique merchants before standardization: {before_merchants:,}")
print(f"Unique merchants after standardization:  {after_merchants:,}")
print(f"Reduction in merchant variations: {before_merchants - after_merchants:,} ({(before_merchants - after_merchants)/before_merchants:.1%})")

# Print top 20 merchants by user count
print("\n👥 Top 20 Merchants by User Count:")
print("-" * 80)
print(f"{'Merchant':<30} {'Users':>10} {'Transactions':>15} {'Total Spend':>20}")
print("-" * 80)
for merchant, stats in after_stats.head(20).iterrows():
    print(f"{merchant:<30} {stats['user_id']:>10,} {stats['total_txns']:>15,} {stats['total_spend']:>20,.2f}")


## Enhanced Merchant Categorization (NPCI Standards + Pattern Matching)

We implement a hybrid merchant categorization approach that combines:

1. **MCC Code-Based Categorization** (Primary)
   - Uses NPCI RuPay Merchant Category Codes (June 2025)
   - Covers both POS and ECOM channels
   - Prioritizes high-transacting categories

2. **Pattern-Based Categorization** (Secondary)
   - Extensive keyword matching for each category
   - Handles merchants without MCC codes
   - Reduces 'Others' category significantly

3. **Fallback to Jupiter Defined Categories** (Tertiary)
   - Uses columns from `fact-mm-transactions` such as `jupiter-coarsegrain-category`, `usercategory`, and `appcategory`
   - Handles rows that have neither an MCC code nor a pattern match

**Business Categories:**
1. **Food & Dining** - Restaurants, cafes, food delivery
2. **Grocery & Supermarkets** - Grocery stores, supermarkets, convenience stores
3. **E-Commerce & Retail** - Online shopping, retail stores
4. **Financial Services** - Banks, payment apps, insurance
5. **Fuel & Transportation** - Petrol, ride-hailing, travel
6. **Entertainment & Media** - Streaming, movies, gaming
7. **Utilities & Services** - Telecom, electricity, bills
8. **Healthcare & Wellness** - Hospitals, pharmacies, fitness
9. **Shopping & Fashion** - Clothing, accessories, beauty
10. **Education & Learning** - Schools, courses, training
11. **Travel & Hospitality** - Hotels, airlines, tourism
12. **Technology & Electronics** - Gadgets, computers, software
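13. **Automotive & Dealerships** - Car and truck dealers (returned by the MCC mapping for code 5511)
14. **Others** - Unclassified merchants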

This enhanced categorization provides better granularity and accuracy for merchant analysis.
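
The three-tier fallback can be sketched as a chain in which the first tier to return a category wins. This is a minimal sketch under stated assumptions: the lookup tables are tiny illustrative subsets, the function names are hypothetical, and `fallback_category` is a hypothetical column standing in for whichever Jupiter-defined category column has been merged in.

```python
import pandas as pd

# Tiny illustrative subsets of the MCC and keyword tables
MCC_TO_CATEGORY = {
    5812: "Food & Dining",
    5411: "Grocery & Supermarkets",
    5541: "Fuel & Transportation",
}
KEYWORDS = {
    "Food & Dining": ["swiggy", "zomato", "restaurant"],
    "Grocery & Supermarkets": ["zepto", "blinkit", "kirana"],
}


def by_mcc(mcc_code):
    """Tier 1: dictionary lookup on the parsed MCC code."""
    if pd.isna(mcc_code):
        return None
    try:
        return MCC_TO_CATEGORY.get(int(float(str(mcc_code).strip())))
    except ValueError:
        return None


def by_pattern(name):
    """Tier 2: first category whose keyword appears in the merchant name."""
    if not isinstance(name, str):
        return None
    lowered = name.lower()
    for category, words in KEYWORDS.items():
        if any(w in lowered for w in words):
            return category
    return None


def categorize(row):
    """Return (category, tier) from the first tier that yields a result."""
    category = by_mcc(row.get("mcccode"))
    if category:
        return category, "mcc"
    category = by_pattern(row.get("merchant_standardized"))
    if category:
        return category, "pattern"
    fallback = row.get("fallback_category")  # hypothetical pre-merged column
    if isinstance(fallback, str) and fallback.strip():
        return fallback.strip(), "fallback"
    return "Others", "none"


toy = pd.DataFrame({
    "merchant_standardized": ["Corner Diner", "Zepto", "City Tutorials", "Xyz Traders"],
    "mcccode": [5812.0, None, None, None],
    "fallback_category": [None, None, "Education & Learning", None],
})
results = [categorize(row) for _, row in toy.iterrows()]
toy["category"] = [c for c, _ in results]
toy["tier"] = [t for _, t in results]
print(toy[["merchant_standardized", "category", "tier"]])
```

Recording the tier alongside the category makes it easy to measure how much each tier contributes and how large the "Others" bucket remains.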


In [0]:
import re
import pandas as pd
from pandarallel import pandarallel
from tqdm import tqdm
import multiprocessing

# ================================
# ⚡ Initialize Parallelization
# ================================
pandarallel.initialize(
    nb_workers=max(1, min(multiprocessing.cpu_count() - 1, 8)),  # at least one worker
    progress_bar=True,
    verbose=1
)
tqdm.pandas()

# ================================
# 📌 MCC-based Categorization
# ================================
def categorize_merchant_by_mcc(mcc_code):
    """Categorize merchants using NPCI MCC standards. Handles int, float64, str, JSON."""
    if pd.isna(mcc_code):
        return None
    
    # Normalize MCC string
    mcc_str = str(mcc_code).strip()
    
    # Extract if JSON-like
    mcc_match = re.search(r'"mccCode":"(\d+)"', mcc_str)
    if mcc_match:
        mcc_str = mcc_match.group(1)
    
    try:
        mcc = int(float(mcc_str))  # robust parsing for 5812, 5812.0, "05812"
    except (ValueError, TypeError):
        return None
    
    # === Category mapping ===
    if mcc in [5811, 5812, 5813, 5814, 5462, 5499, 5921]:
        return "Food & Dining"
    elif mcc in [5411, 5422, 5441, 5451, 5310]:
        return "Grocery & Supermarkets"
    elif mcc in [5311, 5399, 5999, 5945, 5734, 5964, 5969, 5970, 5309, 5732, 5722, 4215]:
        return "E-Commerce & Retail"
    elif mcc in [5541, 5542, 4111, 4112, 4121, 4131, 4784, 4789, 7523, 5983]:
        return "Fuel & Transportation"
    elif mcc in [7832, 7922, 7841, 7829, 7932, 7933, 7941, 7991, 7994, 7995,
                 7996, 7999, 5815, 5816, 5817, 5818, 4722]:
        return "Entertainment & Media"
    elif mcc in [4814, 4815, 4821, 4899, 4900, 6513, 7299, 7311, 7399, 8111, 
                 8999, 9399, 9311, 7349]:
        return "Utilities & Services"
    elif mcc in [8062, 8071, 8099, 5912, 5975, 5976, 5977, 8011, 8021, 8031,
                 8041, 8042, 8043, 8049, 8050, 5047]:
        return "Healthcare & Wellness"
    elif mcc in [5611, 5621, 5631, 5641, 5651, 5655, 5661, 5681, 5691, 5697,
                 5698, 5699, 5931, 5944, 5949, 5950, 5947]:
        return "Shopping & Fashion"
    elif mcc in [8211, 8220, 8241, 8244, 8249, 8299]:
        return "Education & Learning"
    elif mcc in [6012, 6540, 6300]:
        return "Financial Services"
    elif mcc in [7011]:
        return "Travel & Hospitality"
    elif mcc in [5511]:
        return "Automotive & Dealerships"
    
    return None

# ================================
# 📌 Pattern-based Categorization
# ================================
def categorize_merchant_by_pattern(merchant_name):
    """Categorize merchants using pattern matching."""
    if not isinstance(merchant_name, str):
        return None
    merchant_lower = merchant_name.lower()
    patterns = {
        "Food & Dining": ["restaurant","food","swiggy","zomato","pizza","burger","cafe","dining","hotel","bar","pub","dhaba"],
        "Grocery & Supermarkets": ["grocery","supermarket","mart","zepto","blinkit","bigbasket","kirana","provision","store","vegetable","fruit"],
        "E-Commerce & Retail": ["amazon","flipkart","myntra","shopping","retail","store","mall","shop","outlet","ecommerce","online"],
        "Fuel & Transportation": ["petrol","fuel","uber","ola","taxi","transport","ride","cab","auto","bus","train","metro","rail"],
        "Entertainment & Media": ["netflix","prime","hotstar","movie","cinema","theatre","concert","ticket","streaming","subscription","music"],
        "Utilities & Services": ["recharge","bill","utility","airtel","jio","vi","bsnl","telecom","mobile","wifi","electricity","power"],
        "Healthcare & Wellness": ["hospital","pharmacy","medical","apollo","doctor","clinic","checkup","medicine","drug","surgery","gym","fitness"],
        "Shopping & Fashion": ["fashion","clothing","apparel","lifestyle","accessories","wear","jeans","saree","kurta","lehenga","footwear","jewelry"],
        "Education & Learning": ["education","school","college","university","academy","coaching","tuition","tutorial","class","course","training"]
    }
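    for category, keywords in patterns.items():
        for k in keywords:
            # Short keywords ("vi", "bar", "ola", "bus") are matched as whole words so they
            # do not fire inside unrelated names such as "service", "barber", "cola", "business"
            hit = re.search(rf"\b{re.escape(k)}\b", merchant_lower) if len(k) <= 3 else k in merchant_lower
            if hit:
                return category
    return None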

# ================================
# 📌 Explicit Merchant Mapping
# ================================
merchant_category_mapping = {
    "Swiggy": "Food & Dining", "Zomato": "Food & Dining", "Dominos Pizza": "Food & Dining",
    "Amazon": "E-Commerce & Retail", "Flipkart": "E-Commerce & Retail", "Myntra": "E-Commerce & Retail",
    "DMart": "Grocery & Supermarkets", "Blinkit": "Grocery & Supermarkets", "BigBasket": "Grocery & Supermarkets",
    "Netflix": "Entertainment & Media", "Disney+ Hotstar": "Entertainment & Media", "BookMyShow": "Entertainment & Media",
    "Jio": "Utilities & Services", "Airtel": "Utilities & Services", "Apollo Pharmacy": "Healthcare & Wellness",
    "Lifestyle": "Shopping & Fashion", "H&M": "Shopping & Fashion", "Zara": "Shopping & Fashion"
}

# ================================
# 📌 Hybrid Categorization
# ================================
def categorize_merchant_hybrid(row):
    """Hybrid categorization with MCC, explicit mapping, pattern, and fallbacks."""
    merchant_name = row.get("merchant_standardized", None)
    
    # 1. Explicit
    if merchant_name in merchant_category_mapping:
        return merchant_category_mapping[merchant_name]
    
    # 2. MCC
    if "mcccode" in row and pd.notna(row["mcccode"]):
        mcc_category = categorize_merchant_by_mcc(row["mcccode"])
        if mcc_category:
            return mcc_category
    
    # 3. Pattern
    pattern_category = categorize_merchant_by_pattern(merchant_name)
    if pattern_category:
        return pattern_category
    
    # 4. Fallbacks: Jupiter, usercategory, appcategory (if present)
    for col in ["jupiter_coarsegrain_category", "usercategory", "appcategory"]:
        if col in row and pd.notna(row[col]):
            cat_val = str(row[col]).strip().upper()
            if "FOOD" in cat_val: return "Food & Dining"
            if "GROCERY" in cat_val or "SUPERMARKET" in cat_val: return "Grocery & Supermarkets"
            if "SHOPPING" in cat_val or "RETAIL" in cat_val or "ECOM" in cat_val: return "E-Commerce & Retail"
            if "TRAVEL" in cat_val or "FUEL" in cat_val or "TRANSPORT" in cat_val: return "Fuel & Transportation"
            if "ENTERTAIN" in cat_val or "MOVIE" in cat_val or "MEDIA" in cat_val: return "Entertainment & Media"
            if "UTILITY" in cat_val or "BILL" in cat_val or "SERVICE" in cat_val: return "Utilities & Services"
            if "HEALTH" in cat_val or "MEDICAL" in cat_val or "PHARM" in cat_val: return "Healthcare & Wellness"
            if "FASHION" in cat_val or "CLOTH" in cat_val or "APPAREL" in cat_val: return "Shopping & Fashion"
            if "EDUCATION" in cat_val or "LEARN" in cat_val or "SCHOOL" in cat_val or "COLLEGE" in cat_val: return "Education & Learning"
    
    # 5. Default
    return "Others"

print("⚡ Applying hybrid categorization with pandarallel...")
sample_df_clean["merchant_category"] = sample_df_clean.parallel_apply(categorize_merchant_hybrid, axis=1)

# Distribution
category_counts = sample_df_clean["merchant_category"].value_counts()
print("\nCategory Distribution After Hybrid Categorization:")
for cat, cnt in category_counts.items():
    print(f"{cat:<25} {cnt:>8,} ({cnt/len(sample_df_clean):.2%})")

In [0]:
# ================================
# 📊 MCC vs Merchant Category Analysis (exclude 0 + null)
# ================================

# 1. Filter rows with valid MCC codes (not null and not 0)
mcc_df = sample_df_clean[
    (sample_df_clean["mcccode"].notna()) & (sample_df_clean["mcccode"] != 0)
].copy()

# 2. Distribution of MCC → Merchant Category
mcc_category_distribution = (
    mcc_df.groupby("mcccode")["merchant_category"]
    .value_counts(normalize=False)
    .unstack(fill_value=0)
)

# 3. Percentage of MCC-coded rows labeled as "Others"
total_mcc_rows = len(mcc_df)
others_rows = len(mcc_df[mcc_df["merchant_category"] == "Others"])
others_pct = (others_rows / total_mcc_rows) * 100

print(f"Total rows with valid MCC code: {total_mcc_rows:,}")
print(f"Rows with MCC code & category='Others': {others_rows:,} ({others_pct:.2f}%)")

# 4. Show the pivot table (MCC → categories)
display(mcc_category_distribution.head(20))  # show top 20 for readability

# 5. Visualization: Heatmap of MCC vs Categories
import matplotlib.pyplot as plt
import seaborn as sns

plt.figure(figsize=(12,6))
sns.heatmap(mcc_category_distribution.T, cmap="Blues", cbar=True)
plt.title("MCC vs Merchant Category Distribution", fontsize=14, weight="bold")
plt.xlabel("MCC Code")
plt.ylabel("Merchant Category")
plt.show()

# 6. Visualization: Focus on 'Others' share by MCC
others_share = (
    mcc_df.groupby("mcccode")["merchant_category"]
    .apply(lambda x: (x == "Others").mean())
    .reset_index(name="others_share")
)

plt.figure(figsize=(12,5))
sns.barplot(data=others_share.sort_values("others_share", ascending=False).head(20),
            x="mcccode", y="others_share", palette="Reds_r")
plt.title("Top 20 MCC Codes with Highest 'Others' Share", fontsize=14, weight="bold")
plt.ylabel("Share of 'Others'")
plt.xlabel("MCC Code")
plt.xticks(rotation=45)
plt.show()

# RFM Analysis for User-Merchant Pairs

**RFM Framework**: Analyze customer behavior patterns using three key dimensions:
- **Recency (R)**: How recently did a user transact with a merchant?
- **Frequency (F)**: How often does a user transact with a merchant?  
- **Monetary (M)**: How much does a user spend with a merchant?

**Implementation for 50M+ Records**:
- **Optimized Grouping**: Efficient aggregation by user-merchant pairs
- **Quintile Scoring**: 1-5 scale for each RFM dimension
- **Customer Segmentation**: 8 distinct behavioral segments
- **Performance Tracking**: Report processing time for each run
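
To illustrate the quintile scoring on a small scale, here is a sketch on ten made-up user-merchant pairs (the data and the `quintile` helper are illustrative assumptions, not part of the pipeline). Ranking with `method="first"` before `pd.qcut` breaks ties, so each score bucket receives the same number of pairs:

```python
import pandas as pd

# Ten made-up user-merchant pairs, ordered from most to least engaged
toy = pd.DataFrame({
    "pair": list("ABCDEFGHIJ"),
    "recency": [1, 3, 7, 14, 30, 45, 60, 90, 180, 365],      # days since last txn
    "frequency": [40, 25, 18, 12, 9, 7, 5, 3, 2, 1],          # number of txns
    "monetary": [9000, 7000, 5200, 4100, 3000, 2500, 1800, 900, 400, 150],
})

def quintile(series, labels):
    """Score a metric into five equal-sized buckets (hypothetical helper)."""
    # rank(method="first") guarantees unique values, so qcut never hits duplicate bin edges
    return pd.qcut(series.rank(method="first"), q=5, labels=labels).astype(int)

# Low recency is good, so its labels run 5 -> 1; frequency and monetary run 1 -> 5
toy["R_score"] = quintile(toy["recency"], [5, 4, 3, 2, 1])
toy["F_score"] = quintile(toy["frequency"], [1, 2, 3, 4, 5])
toy["M_score"] = quintile(toy["monetary"], [1, 2, 3, 4, 5])
print(toy)
```

Because the ranks are unique, tied values are split between buckets by row order; that keeps bucket sizes equal at the cost of occasionally separating identical values.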

**Customer Segments** (rules are evaluated top to bottom; the first match wins):
- **Champions** (R4-5,F4-5,M4-5): Best customers - high value, frequent, recent
- **Loyal Customers** (R3-5,F3-5,M3-5, excluding Champions): Consistent high-value customers
- **New Customers** (R4-5,F1-2): Recently acquired users
- **Potential Loyalists** (R3-5,F3-5,M1-2): Recent, frequent customers whose spend can still grow
- **At Risk** (R1-2,F3-5): Previously frequent customers becoming inactive
- **Cannot Lose Them** (R1-2,F1-2,M3-5): Inactive but historically high-value
- **Lost Customers** (R1-2,F1-2,M1-2): Churned low-value customers
- **Others** (R3,F1-2): All remaining combinations
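
Row-wise `apply` can become a bottleneck at this scale. One possible alternative, sketched here under the assumption that the thresholds match those in `get_customer_segment`, is to evaluate the rules in priority order with `numpy.select` (the function name `segment_vectorized` is hypothetical):

```python
import numpy as np
import pandas as pd

def segment_vectorized(scores: pd.DataFrame) -> pd.Series:
    """Assign segments column-wise; the first matching condition wins, like if/elif."""
    r, f, m = scores["R_score"], scores["F_score"], scores["M_score"]
    conditions = [
        (r >= 4) & (f >= 4) & (m >= 4),
        (r >= 3) & (f >= 3) & (m >= 3),
        (r >= 4) & (f <= 2),
        (r >= 3) & (f >= 3) & (m <= 2),
        (r <= 2) & (f >= 3),
        (r <= 2) & (f <= 2) & (m >= 3),
        (r <= 2) & (f <= 2) & (m <= 2),
    ]
    labels = [
        "Champions", "Loyal Customers", "New Customers", "Potential Loyalists",
        "At Risk", "Cannot Lose Them", "Lost Customers",
    ]
    # Rows matching no condition fall through to the default label
    return pd.Series(np.select(conditions, labels, default="Others"), index=scores.index)

# One made-up row per segment
demo = pd.DataFrame(
    [(5, 5, 5), (3, 4, 3), (5, 1, 1), (3, 3, 1), (1, 5, 5), (1, 1, 5), (1, 1, 1), (3, 1, 3)],
    columns=["R_score", "F_score", "M_score"],
)
demo["customer_segment"] = segment_vectorized(demo)
print(demo)
```

The conditions are listed in the same order as the `if`/`elif` chain, so both approaches should label a given row identically.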

**Output**: User-merchant pair level RFM scores and segments for recommendation engine


In [0]:
import pandas as pd
import numpy as np
import time

# -------------------------
# 1. Helper: Customer Segmentation
# -------------------------
def get_customer_segment(row):
    """Assign customer segments based on RFM scores"""
    r, f, m = row['R_score'], row['F_score'], row['M_score']

    if r >= 4 and f >= 4 and m >= 4:
        return 'Champions'
    elif r >= 3 and f >= 3 and m >= 3:
        return 'Loyal Customers'
    elif r >= 4 and f <= 2:
        return 'New Customers'
    elif r >= 3 and f >= 3 and m <= 2:
        return 'Potential Loyalists'
    elif r <= 2 and f >= 3:
        return 'At Risk'
    elif r <= 2 and f <= 2 and m >= 3:
        return 'Cannot Lose Them'
    elif r <= 2 and f <= 2 and m <= 2:
        return 'Lost Customers'
    else:
        return 'Others'

In [0]:
# -------------------------
# 2. Generic RFM Analysis Function
# -------------------------
def run_rfm_analysis(df, groupby_cols, analysis_name="Merchant-Level"):
    """
    Run RFM analysis on given grouping (merchant or category)

    Args:
        df (pd.DataFrame): input data
        groupby_cols (list): grouping cols (['user_id','merchant_standardized'] or ['user_id','merchant_category'])
        analysis_name (str): Name of analysis
    """
    print(f"\n📊 RFM Analysis: {analysis_name}")
    print("=" * 60)

    start_time = time.time()

    # --- Step 1: Pick recency column
    date_columns = ["last_txn_date", "first_txn_date", "transactiondatetime"]
    date_col = next((c for c in date_columns if c in df.columns), None)
    if date_col:
        # Parse dates up front: values loaded from Spark or CSV may be date objects or strings
        df = df.assign(**{date_col: pd.to_datetime(df[date_col])})
    current_date = df[date_col].max() if date_col else pd.Timestamp.now()

    print(f"📅 Using {date_col if date_col else 'synthetic recency'} as reference")
    print(f"   Reference date: {current_date}\n")

    # --- Step 2: Aggregate for R, F, M (named aggregation keeps output column names explicit)
    agg_spec = {
        'frequency': ('total_txns', 'sum'),
        'monetary': ('total_spend', 'sum'),
    }
    if date_col:
        agg_spec['last_transaction_date'] = (date_col, 'max')
    rfm = df.groupby(groupby_cols).agg(**agg_spec).reset_index()

    if date_col:
        rfm['recency'] = (current_date - rfm['last_transaction_date']).dt.days
    else:
        # No date column available: derive a synthetic recency from frequency
        max_freq = rfm['frequency'].max()
        rfm['recency'] = ((max_freq - rfm['frequency']) / max_freq * 365).astype(int)

    # --- Step 3: RFM Scores (1–5)
    rfm['R_score'] = pd.qcut(rfm['recency'].rank(method='first'), q=5, labels=[5,4,3,2,1]).astype(int)
    rfm['F_score'] = pd.qcut(rfm['frequency'].rank(method='first'), q=5, labels=[1,2,3,4,5]).astype(int)
    rfm['M_score'] = pd.qcut(rfm['monetary'].rank(method='first'), q=5, labels=[1,2,3,4,5]).astype(int)
    rfm['RFM_score'] = rfm['R_score'].astype(str) + rfm['F_score'].astype(str) + rfm['M_score'].astype(str)

    # --- Step 4: Segmentation
    rfm['customer_segment'] = rfm.apply(get_customer_segment, axis=1)

    processing_time = time.time() - start_time
    print(f"✅ RFM metrics & scores calculated in {processing_time:.2f}s")
    print(f"   Pairs analyzed: {len(rfm):,}\n")

    # --- Step 5: Reporting
    segment_counts = rfm['customer_segment'].value_counts()
    print("👥 Customer Segment Distribution:")
    for seg, count in segment_counts.items():
        print(f"   {seg:<20}: {count:>8,} ({count/len(rfm)*100:5.1f}%)")

    print("\n🔢 RFM Score Distributions:")
    print("   Recency :", dict(rfm['R_score'].value_counts().sort_index()))
    print("   Frequency:", dict(rfm['F_score'].value_counts().sort_index()))
    print("   Monetary :", dict(rfm['M_score'].value_counts().sort_index()))

    return rfm


In [0]:
# -------------------------
# 3. Run for Merchant-Level and Category-Level
# -------------------------
merchant_rfm = run_rfm_analysis(sample_df_clean, ['user_id','merchant_standardized'], "User-Merchant")
category_rfm = run_rfm_analysis(sample_df_clean, ['user_id','merchant_category'], "User-Category")

# -------------------------
# 4. Merge Back to Original for Visibility (per row tagging)
# -------------------------
sample_df_clean = sample_df_clean.merge(
    merchant_rfm[['user_id','merchant_standardized','R_score','F_score','M_score','RFM_score','customer_segment']],
    on=['user_id','merchant_standardized'],
    how='left'
)

sample_df_clean = sample_df_clean.merge(
    category_rfm[['user_id','merchant_category','R_score','F_score','M_score','RFM_score','customer_segment']],
    on=['user_id','merchant_category'],
    how='left',
    suffixes=('_merchant','_category')
)

print("\n✅ Final dataframe now includes RFM scores & segments at both Merchant and Category level.")
print(f"Shape: {sample_df_clean.shape}")
print(sample_df_clean.head())

In [0]:
display(sample_df_clean)

In [0]:
import matplotlib.pyplot as plt
import seaborn as sns

# Make plots prettier
sns.set(style="whitegrid", palette="pastel", font_scale=1.2)

# ===============================
# 1. Customer Segment Distribution (Merchant-level)
# ===============================
plt.figure(figsize=(10,6))
segment_counts = merchant_rfm['customer_segment'].value_counts().sort_values(ascending=False)
sns.barplot(x=segment_counts.index, y=segment_counts.values)
plt.title("Merchant-Level Customer Segment Distribution", fontsize=16)
plt.ylabel("Number of User-Merchant Pairs")
plt.xlabel("Customer Segment")
plt.xticks(rotation=30)
plt.show()

# ===============================
# 2. Customer Segment Distribution (Category-level)
# ===============================
plt.figure(figsize=(10,6))
cat_segment_counts = category_rfm['customer_segment'].value_counts().sort_values(ascending=False)
sns.barplot(x=cat_segment_counts.index, y=cat_segment_counts.values)
plt.title("Category-Level Customer Segment Distribution", fontsize=16)
plt.ylabel("Number of User-Category Pairs")
plt.xlabel("Customer Segment")
plt.xticks(rotation=30)
plt.show()

# ===============================
# 3. RFM Score Heatmap (Merchant-level)
# ===============================
rfm_heatmap = merchant_rfm.groupby(['R_score','F_score']).size().unstack(fill_value=0)
plt.figure(figsize=(8,6))
sns.heatmap(rfm_heatmap, annot=True, fmt="d", cmap="Blues")
plt.title("Heatmap of Recency vs Frequency (Merchant-Level)")
plt.ylabel("Recency Score")
plt.xlabel("Frequency Score")
plt.show()

# ===============================
# 4. Top Categories by Champions
# ===============================
top_champion_cats = category_rfm[category_rfm['customer_segment']=="Champions"] \
                        .groupby('merchant_category')['user_id'].nunique() \
                        .sort_values(ascending=False).head(10)

plt.figure(figsize=(10,6))
sns.barplot(x=top_champion_cats.values, y=top_champion_cats.index)
plt.title("Top 10 Categories by Champion Users")
plt.xlabel("Unique Champion Users")
plt.ylabel("Merchant Category")
plt.show()

# ===============================
# 5. Segment Distribution by Category (stacked bar)
# ===============================
cat_segment_dist = category_rfm.groupby(['merchant_category','customer_segment']).size().reset_index(name="count")
cat_segment_pivot = cat_segment_dist.pivot(index="merchant_category", columns="customer_segment", values="count").fillna(0)

cat_segment_pivot_pct = cat_segment_pivot.div(cat_segment_pivot.sum(axis=1), axis=0) * 100
cat_segment_pivot_pct.sort_values("Champions", ascending=False, inplace=True)

cat_segment_pivot_pct.head(10).plot(kind="bar", stacked=True, figsize=(12,6), colormap="tab20")
plt.title("Segment Composition of Top 10 Categories")
plt.ylabel("Percentage of Users")
plt.xlabel("Merchant Category")
plt.legend(title="Segment", bbox_to_anchor=(1.05,1), loc='upper left')
plt.tight_layout()
plt.show()

# ===============================
# 6. Average Spend & Transactions by Segment (Merchant-level)
# ===============================
avg_metrics = merchant_rfm.groupby('customer_segment')[['frequency','monetary']].mean().sort_values('monetary', ascending=False)
# DataFrame.plot creates its own figure, so no separate plt.figure() call is needed
avg_metrics.plot(kind="bar", figsize=(10,6))
plt.title("Average Frequency & Monetary by Segment (Merchant-level)")
plt.ylabel("Average Value")
plt.xticks(rotation=30)
plt.show()

# ===============================
# 7. Category Engagement Summary
# ===============================
cat_stats = category_rfm.groupby('merchant_category').agg({
    'user_id':'nunique',
    'frequency':'sum',
    'monetary':'sum'
}).sort_values('user_id', ascending=False).head(10)

cat_stats['avg_spend_per_user'] = cat_stats['monetary'] / cat_stats['user_id']
cat_stats['avg_txns_per_user'] = cat_stats['frequency'] / cat_stats['user_id']

print("📊 Top 10 Categories Engagement Stats:")
display(cat_stats[['user_id','frequency','monetary','avg_spend_per_user','avg_txns_per_user']])


In [0]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

sns.set(style="whitegrid", palette="Set2", font_scale=1.1)

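# The earlier merges store RFM columns with _merchant/_category suffixes, so the
# unsuffixed merchant-level columns used by the plots below are rebuilt here when missing
if 'customer_segment' not in sample_df_clean.columns:
    print("⚠️ Running RFM analysis to add missing segmentation columns...")

    # Calculate RFM scores again (minimal version); parse dates first so the
    # recency subtraction below works even if they were loaded as strings or date objects
    sample_df_clean['last_txn_date'] = pd.to_datetime(sample_df_clean['last_txn_date'])
    current_date = sample_df_clean['last_txn_date'].max()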

    rfm_data = sample_df_clean.groupby(['user_id', 'merchant_standardized']).agg({
        'total_txns': 'sum',
        'total_spend': 'sum',
        'last_txn_date': 'max'
    }).reset_index()

    rfm_data['recency'] = (current_date - rfm_data['last_txn_date']).dt.days

    # R, F, M scores
    rfm_data['R_score'] = pd.qcut(rfm_data['recency'].rank(method='first'), 5, labels=[5,4,3,2,1]).astype(int)
    rfm_data['F_score'] = pd.qcut(rfm_data['total_txns'].rank(method='first'), 5, labels=[1,2,3,4,5]).astype(int)
    rfm_data['M_score'] = pd.qcut(rfm_data['total_spend'].rank(method='first'), 5, labels=[1,2,3,4,5]).astype(int)

    rfm_data['RFM_score'] = (
        rfm_data['R_score'].astype(str) +
        rfm_data['F_score'].astype(str) +
        rfm_data['M_score'].astype(str)
    )

    def get_segment(row):
        r, f, m = row['R_score'], row['F_score'], row['M_score']
        if r >= 4 and f >= 4 and m >= 4: return 'Champions'
        elif r >= 3 and f >= 3 and m >= 3: return 'Loyal Customers'
        elif r >= 4 and f <= 2: return 'New Customers'
        elif r >= 3 and f >= 3 and m <= 2: return 'Potential Loyalists'
        elif r <= 2 and f >= 3: return 'At Risk'
        elif r <= 2 and f <= 2 and m >= 3: return 'Cannot Lose Them'
        elif r <= 2 and f <= 2 and m <= 2: return 'Lost Customers'
        else: return 'Others'

    rfm_data['customer_segment'] = rfm_data.apply(get_segment, axis=1)

    # Merge back into sample_df_clean
    sample_df_clean = sample_df_clean.merge(
        rfm_data[['user_id','merchant_standardized','R_score','F_score','M_score','RFM_score','customer_segment']],
        on=['user_id','merchant_standardized'],
        how='left'
    )

    print("✅ Added RFM scores & segments to sample_df_clean")


# ===================================
# 1. Top 30 Merchants by User Count
# ===================================
top_merchants = (
    sample_df_clean.groupby("merchant_standardized")["user_id"]
    .nunique()
    .sort_values(ascending=False)
    .head(30)
)

plt.figure(figsize=(12,6))
sns.barplot(x=top_merchants.values, y=top_merchants.index)
plt.title("🏆 Top 30 Merchants by Unique Users")
plt.xlabel("Unique Users")
plt.ylabel("Merchant")
plt.show()


# ===================================
# 2. Pie Chart of Categories (by unique users)
# ===================================
category_users = (
    sample_df_clean.groupby("merchant_category")["user_id"]
    .nunique()
    .sort_values(ascending=False)
)

plt.figure(figsize=(8,8))
plt.pie(category_users.values, labels=category_users.index, autopct="%1.1f%%", startangle=140)
plt.title("👥 Users Interacting in Categories")
plt.show()


# ===================================
# 3. Segment Distribution within Each Category
# ===================================
cat_seg = (
    sample_df_clean.groupby(["merchant_category","customer_segment"])["user_id"]
    .nunique()
    .reset_index()
)

plt.figure(figsize=(14,6))
sns.barplot(data=cat_seg, x="merchant_category", y="user_id", hue="customer_segment")
plt.title("📊 Segment Distribution Across Categories")
plt.xticks(rotation=45)
plt.ylabel("Unique Users")
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()


# ===================================
# 4. Top Categories by Champions
# ===================================
top_champion_cats = (
    sample_df_clean[sample_df_clean["customer_segment"]=="Champions"]
    .groupby("merchant_category")["user_id"]
    .nunique()
    .sort_values(ascending=False)
    .head(10)
)

plt.figure(figsize=(10,6))
sns.barplot(x=top_champion_cats.values, y=top_champion_cats.index)
plt.title("🏆 Top 10 Categories by Champion Users")
plt.xlabel("Champion Users")
plt.ylabel("Category")
plt.show()


# ===================================
# 5. Average Spend per Segment
# ===================================
avg_spend = (
    sample_df_clean.groupby("customer_segment")["total_spend"]
    .mean()
    .sort_values(ascending=False)
)

plt.figure(figsize=(10,6))
sns.barplot(x=avg_spend.index, y=avg_spend.values)
plt.title("💰 Average Spend by Segment")
plt.ylabel("Avg Spend (₹)")
plt.xticks(rotation=30)
plt.show()


# ===================================
# 6. Heatmap of RFM Scores (Recency vs Frequency)
# ===================================
rfm_heatmap = (
    sample_df_clean.groupby(["R_score","F_score"])
    .size()
    .unstack(fill_value=0)
)

plt.figure(figsize=(8,6))
sns.heatmap(rfm_heatmap, annot=True, fmt="d", cmap="YlGnBu")
plt.title("📈 Heatmap: Recency vs Frequency (All Merchants)")
plt.xlabel("Frequency Score")
plt.ylabel("Recency Score")
plt.show()


In [0]:
# =====================================
# 🔑 Extra RFM Metrics & Enrichment
# =====================================

print("🔧 Adding extra RFM metrics...")

# --------------------------
# 1. Weighted RFM Score
# --------------------------
# Adjust weights depending on business need (recency usually matters more)
merchant_rfm['RFM_weighted'] = (
    merchant_rfm['R_score']*0.4 + 
    merchant_rfm['F_score']*0.3 + 
    merchant_rfm['M_score']*0.3
).round(2)

category_rfm['RFM_weighted'] = (
    category_rfm['R_score']*0.4 + 
    category_rfm['F_score']*0.3 + 
    category_rfm['M_score']*0.3
).round(2)

# --------------------------
# 2. Per-user Percentiles (normalize within user)
# --------------------------
# Merchant-level
merchant_rfm['F_percentile'] = merchant_rfm.groupby('user_id')['frequency'].rank(pct=True)
merchant_rfm['M_percentile'] = merchant_rfm.groupby('user_id')['monetary'].rank(pct=True)

# Category-level
category_rfm['F_percentile'] = category_rfm.groupby('user_id')['frequency'].rank(pct=True)
category_rfm['M_percentile'] = category_rfm.groupby('user_id')['monetary'].rank(pct=True)

# --------------------------
# 3. Engagement Index
# --------------------------
merchant_rfm['engagement_index'] = (
    merchant_rfm['F_percentile']*0.5 + merchant_rfm['M_percentile']*0.5
).round(3)

category_rfm['engagement_index'] = (
    category_rfm['F_percentile']*0.5 + category_rfm['M_percentile']*0.5
).round(3)

print("✅ Extra RFM metrics added: RFM_weighted, F_percentile, M_percentile, engagement_index")

# --------------------------
# 4. Merge back into sample_df_clean
# --------------------------
rfm_cols = ['R_score','F_score','M_score','RFM_score','RFM_weighted','customer_segment',
            'F_percentile','M_percentile','engagement_index']

# Drop RFM columns left by earlier merges so this merge does not create _x/_y duplicates
stale_cols = [c for c in sample_df_clean.columns
              if c in rfm_cols
              or any(c == f"{base}{sfx}" for base in rfm_cols
                     for sfx in ('_merchant','_category','_x','_y'))]
sample_df_clean = sample_df_clean.drop(columns=stale_cols)

# Merchant-level enrich (columns suffixed with _merchant)
sample_df_clean = sample_df_clean.merge(
    merchant_rfm[['user_id','merchant_standardized'] + rfm_cols]
        .rename(columns={c: f"{c}_merchant" for c in rfm_cols}),
    how='left',
    on=['user_id','merchant_standardized']
)

# Category-level enrich (columns suffixed with _category)
sample_df_clean = sample_df_clean.merge(
    category_rfm[['user_id','merchant_category'] + rfm_cols]
        .rename(columns={c: f"{c}_category" for c in rfm_cols}),
    how='left',
    on=['user_id','merchant_category']
)

print("✅ RFM enrichment merged back into sample_df_clean")
print(f"📊 Final enriched shape: {sample_df_clean.shape}")
display(sample_df_clean)


In [0]:
# Convert all boolean columns in one go
bool_cols = sample_df_clean.select_dtypes(include=['bool']).columns
sample_df_clean[bool_cols] = sample_df_clean[bool_cols].astype('float')


In [0]:
# Convert the pandas DataFrame to a Spark DataFrame
spark_df = spark.createDataFrame(sample_df_clean)

# Use the Databricks display() command
display(spark_df)

In [0]:
display(sample_df_clean)

## 📊 Visualization of Enriched RFM


In [0]:
import matplotlib.pyplot as plt
import seaborn as sns

# -------------------------------
# 1. Distribution of Weighted RFM Scores
# -------------------------------
plt.figure(figsize=(10,5))
sns.histplot(sample_df_clean['RFM_weighted_merchant'].dropna(), bins=20, kde=True)
plt.title("Distribution of Weighted RFM Scores (Merchant-level)")
plt.xlabel("RFM Weighted Score")
plt.ylabel("Count")
plt.show()

plt.figure(figsize=(10,5))
sns.histplot(sample_df_clean['RFM_weighted_category'].dropna(), bins=20, kde=True, color="orange")
plt.title("Distribution of Weighted RFM Scores (Category-level)")
plt.xlabel("RFM Weighted Score")
plt.ylabel("Count")
plt.show()

In [0]:
# -------------------------------
# 2. Category Segment Distribution
# -------------------------------
cat_seg = (
    sample_df_clean.groupby(["merchant_category","customer_segment_category"])["user_id"]
    .nunique()
    .reset_index()
)

plt.figure(figsize=(14,6))
sns.barplot(data=cat_seg, x="merchant_category", y="user_id", hue="customer_segment_category")
plt.title("User Distribution Across Segments per Category")
plt.xticks(rotation=45, ha="right")
plt.ylabel("Unique Users")
plt.show()

In [0]:
# ================================
# 💾 Save Enriched RFM Data
# ================================

output_path = "/Volumes/jupiter/temp/temp/rfm_analysis_test.csv"

sample_df_clean.to_csv(output_path, index=False)

print(f"✅ Enriched RFM dataset saved to: {output_path}")
print(f"   Shape: {sample_df_clean.shape}")
print(f"   Columns: {list(sample_df_clean.columns)}")

In [0]:
import pandas as pd
import base64
from IPython.display import HTML, display

# 1. Convert the DataFrame to a CSV string.
# `index=False` prevents pandas from writing row indices into the CSV.
csv_string = sample_df_clean.to_csv(index=False)

# 2. Encode the CSV string into base64.
# This is necessary to embed the file data directly into the HTML link.
b64 = base64.b64encode(csv_string.encode()).decode()

# 3. Define the filename for the download.
file_name = "rfm_test.csv"

# 4. Create an HTML anchor tag (`<a>`) with the download link.
# The `href` attribute contains the base64-encoded data with the correct MIME type.
# The `download` attribute tells the browser to download the file with the specified name.
download_link = f'<a href="data:text/csv;base64,{b64}" download="{file_name}">Click here to download {file_name}</a>'

# 5. Display the link in the notebook output.
# When you run this cell, a clickable link will appear.
display(HTML(download_link))

## Backfill Evals
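
The hit score used in this section is the fraction of a user's Top-K recommendations (built from the training data) that also appear among the items the user transacted with in the test data. A small sketch with made-up merchants (the function name `hit_score` is illustrative):

```python
def hit_score(recommended, actual):
    """Fraction of recommended items that appear in the actual set."""
    recommended = list(recommended)
    if not recommended:
        return 0.0
    return len(set(recommended) & set(actual)) / len(recommended)

# Top-3 recommendations from training vs. merchants seen in test (made-up values)
top_items = ["Swiggy", "Amazon", "Netflix"]
actual_items = {"Amazon", "Zomato", "Netflix"}
score = hit_score(top_items, actual_items)
print(f"Hit score: {score:.2f}")
```

Two of the three recommended merchants appear in the test data, so the score is 2/3. Averaging this value over users gives the hit rate reported later in the section.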

In [0]:
%pip install pandarallel

In [0]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pandarallel import pandarallel
import multiprocessing

# Display settings
pd.set_option("display.max_columns", None)
pd.set_option("display.max_rows", 50)
sns.set_style("whitegrid")

# Parameters
K = 3  # Top-K merchants/categories for evaluation
pandarallel.initialize(nb_workers=min(multiprocessing.cpu_count()-1, 8), progress_bar=True, verbose=1)


In [0]:
import pandas as pd

# train_spark = spark.read.csv("/Volumes/jupiter/temp/temp/rfm_analysis_train.csv", header=True, inferSchema=True)
# train_df = train_spark.toPandas()  # convert only if memory allows

# test_spark = spark.read.csv("/Volumes/jupiter/temp/temp/rfm_analysis_test.csv", header=True, inferSchema=True)
# test_df = test_spark.toPandas()  # convert only if memory allows

train_df = pd.read_csv("/Volumes/jupiter/temp/temp/rfm_analysis_train.csv")
test_df = pd.read_csv("/Volumes/jupiter/temp/temp/rfm_analysis_test.csv")

print(f"Train shape: {train_df.shape}, Test shape: {test_df.shape}")
print("Train sample:")
display(train_df.head(3))
print("Test sample:")
display(test_df.head(3))


In [0]:
# Get common users
common_users = set(train_df["user_id"]).intersection(set(test_df["user_id"]))
print(f"Common users: {len(common_users)}")

# Filter both datasets
train_df = train_df[train_df["user_id"].isin(common_users)].reset_index(drop=True)
test_df = test_df[test_df["user_id"].isin(common_users)].reset_index(drop=True)

print(f"Filtered Train: {train_df.shape}, Filtered Test: {test_df.shape}")


In [0]:
def top_k_recs(df, user_col, item_col, score_col, K=3):
    """Top-K distinct items by RFM score for each user, plus confidence score."""
    def _get_top(x):
        # Keep only the best-scoring row per item so an item cannot fill several Top-K slots
        x_sorted = x.sort_values(score_col, ascending=False).drop_duplicates(item_col)
        top_items = x_sorted.head(K)[item_col].tolist()
        # Confidence: share of the user's total score that falls in the Top-K
        conf = x_sorted.head(K)[score_col].sum() / max(1, x_sorted[score_col].sum())
        return pd.Series({"top_items": top_items, "confidence": conf})
    return df.groupby(user_col).apply(_get_top).reset_index()

def compute_hits(recs, actuals, item_label):
    """Compute hit score: overlap fraction (parallelized)."""
    merged = pd.merge(recs, actuals, on="user_id", how="inner")

    def _hit(row):
        return len(set(row["top_items"]).intersection(row["actual_items"])) / max(1, len(row["top_items"]))

    merged[f"{item_label}_hit_score"] = merged.parallel_apply(_hit, axis=1)
    return merged[["user_id", "confidence", f"{item_label}_hit_score"]]

def engagement_lift(test_df, recs, user_col, item_col, spend_col="total_spend"):
    """Test-window spend on recommended vs non-recommended items, per user."""
    test_grouped = test_df.groupby(user_col)

    def _calc(row):
        uid = row[user_col]
        recs_set = set(row["top_items"])
        # Users with no test transactions contribute zero spend to both buckets
        if uid not in test_grouped.groups:
            return (uid, 0, 0)
        user_txns = test_grouped.get_group(uid)
        is_rec = user_txns[item_col].isin(recs_set)
        spend_rec = user_txns.loc[is_rec, spend_col].sum()
        spend_nonrec = user_txns.loc[~is_rec, spend_col].sum()
        return (uid, spend_rec, spend_nonrec)

    # Fall back to plain apply when pandarallel is not installed
    apply_fn = recs.parallel_apply if PANDARALLEL_AVAILABLE else recs.apply
    results = apply_fn(_calc, axis=1)
    return pd.DataFrame(results.tolist(), columns=[user_col, f"{item_col}_spend_rec", f"{item_col}_spend_nonrec"])
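
To make the two per-user metrics above concrete, here is a minimal worked example on toy data. The user, merchants, and scores are hypothetical and not taken from the dataset. It assumes confidence means "share of total score captured by the top-K items" and hit score means "fraction of recommended items that show up in the test window", which is how the helpers above appear to define them.

```python
import pandas as pd

# Toy training scores for one user (hypothetical values, not from the dataset)
toy_train = pd.DataFrame({
    "user_id": ["u1"] * 4,
    "merchant": ["A", "B", "C", "D"],
    "score": [4.0, 3.0, 2.0, 1.0],
})
# Merchants the same user actually transacted with in the test window
toy_actual = {"A", "C", "Z"}

K = 3
ranked = toy_train.sort_values("score", ascending=False)
top_items = ranked.head(K)["merchant"].tolist()

# Confidence: share of the user's total score captured by the top-K items
confidence = ranked.head(K)["score"].sum() / ranked["score"].sum()

# Hit score: fraction of recommended items that appear in the test window
hit_score = len(set(top_items) & toy_actual) / len(top_items)

print(top_items, round(confidence, 2), round(hit_score, 2))
# → ['A', 'B', 'C'] 0.9 0.67
```

Note that the hit score is normalised by the number of recommendations, not by the number of actual items, so it behaves like precision@K rather than recall.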


In [0]:
# Training top merchants
train_merchants = top_k_recs(train_df, "user_id", "merchant_standardized", "RFM_score_merchant", K)

# Actual merchants in test
actual_merchants = test_df.groupby("user_id")["merchant_standardized"].apply(set).reset_index(name="actual_items")

# Evaluate hits + engagement
merchant_eval = compute_hits(train_merchants, actual_merchants, "merchant")
merchant_spend = engagement_lift(test_df, train_merchants, "user_id", "merchant_standardized")

print("Merchant evaluation sample:")
display(merchant_eval.head(3))


In [0]:
# Training top categories
train_categories = top_k_recs(train_df, "user_id", "appcategory", "RFM_score_category", K)

# Actual categories in test
actual_categories = test_df.groupby("user_id")["appcategory"].apply(set).reset_index(name="actual_items")

# Evaluate hits + engagement
category_eval = compute_hits(train_categories, actual_categories, "category")
category_spend = engagement_lift(test_df, train_categories, "user_id", "appcategory")

print("Category evaluation sample:")
display(category_eval.head(3))


In [0]:
print("✅ Overall Metrics")

print(f"Merchant Hit Rate: {merchant_eval['merchant_hit_score'].mean():.2%}")
print(f"Category Hit Rate: {category_eval['category_hit_score'].mean():.2%}")

print(f"Avg Merchant Confidence: {merchant_eval['confidence'].mean():.2f}")
print(f"Avg Category Confidence: {category_eval['confidence'].mean():.2f}")

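print("\nEngagement Lift (Merchant):")
# user_id is a string column, so restrict the mean to the numeric spend columns
print(merchant_spend.mean(numeric_only=True))

print("\nEngagement Lift (Category):")
print(category_spend.mean(numeric_only=True))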
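
The recommended vs non-recommended spend split can also be computed without a row-wise apply. The sketch below is one possible vectorised variant on toy data, not the method used in this notebook. The frames `toy_recs` and `toy_test`, the column `merchant`, and the derived `rec_share` ratio are all hypothetical names introduced for illustration.

```python
import numpy as np
import pandas as pd

# Hypothetical recommendations and test-window spend (illustrative values only)
toy_recs = pd.DataFrame({"user_id": ["u1", "u2"],
                         "top_items": [["A", "B"], ["C"]]})
toy_test = pd.DataFrame({
    "user_id": ["u1", "u1", "u1", "u2"],
    "merchant": ["A", "B", "X", "Y"],
    "total_spend": [100.0, 50.0, 50.0, 80.0],
})

# One row per (user, recommended item)
rec_pairs = (
    toy_recs.explode("top_items")
            .rename(columns={"top_items": "merchant"})
            .drop_duplicates()
)

# Left-merge test rows against the recommended pairs; "_merge" marks matches
flagged = toy_test.merge(rec_pairs, on=["user_id", "merchant"],
                         how="left", indicator=True)
flagged["bucket"] = np.where(flagged["_merge"] == "both",
                             "spend_rec", "spend_nonrec")

# Sum spend per user and bucket, keeping both columns even if one is empty
spend = (
    flagged.groupby(["user_id", "bucket"])["total_spend"].sum()
           .unstack(fill_value=0.0)
           .reindex(columns=["spend_rec", "spend_nonrec"], fill_value=0.0)
)

# Assumed summary metric: share of test spend that went to recommended items
spend["rec_share"] = spend["spend_rec"] / (spend["spend_rec"] + spend["spend_nonrec"])
print(spend)
```

Unlike `engagement_lift`, this variant only returns users who have at least one test transaction; users present in the recommendations but absent from the test data would need to be added back with zeros.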


In [0]:
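# Plotting libraries (re-importing is harmless if they were loaded in an earlier cell)
import matplotlib.pyplot as plt
import seaborn as sns

# Hit rate bar chart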
hit_summary = pd.DataFrame({
    "Metric": ["Merchant", "Category"],
    "Hit Rate": [merchant_eval["merchant_hit_score"].mean(), category_eval["category_hit_score"].mean()]
})
plt.figure(figsize=(6,5))
sns.barplot(x="Metric", y="Hit Rate", data=hit_summary, palette="viridis")
plt.title("Hit Rate Comparison")
plt.ylim(0, 1)
for i, v in enumerate(hit_summary["Hit Rate"]):
    plt.text(i, v+0.02, f"{v:.2%}", ha="center")
plt.show()

# Confidence distribution
plt.figure(figsize=(8,5))
sns.histplot(merchant_eval["confidence"], bins=20, kde=True, color="blue", label="Merchant")
sns.histplot(category_eval["confidence"], bins=20, kde=True, color="green", label="Category")
plt.legend()
plt.title("Confidence Score Distribution")
plt.xlabel("Confidence")
plt.show()

# Engagement lift boxplots
plt.figure(figsize=(10,5))
sns.boxplot(data=merchant_spend.melt(id_vars="user_id", value_name="Spend", var_name="Type"), x="Type", y="Spend", palette="Set2")
plt.title("Engagement Lift – Merchants")
plt.xticks(rotation=20)
plt.show()

plt.figure(figsize=(10,5))
sns.boxplot(data=category_spend.melt(id_vars="user_id", value_name="Spend", var_name="Type"), x="Type", y="Spend", palette="Set1")
plt.title("Engagement Lift – Categories")
plt.xticks(rotation=20)
plt.show()


In [0]:
def user_drilldown(user_id):
    print("="*60)
    print(f"🔎 User: {user_id}")
    
    # Train affinities
    tm = train_merchants[train_merchants["user_id"] == user_id]
    tc = train_categories[train_categories["user_id"] == user_id]
    if not tm.empty:
        print("\nTrain Merchants:", tm["top_items"].values[0], "Confidence:", round(tm["confidence"].values[0], 2))
    if not tc.empty:
        print("Train Categories:", tc["top_items"].values[0], "Confidence:", round(tc["confidence"].values[0], 2))

    # Test actuals
    tmerch = test_df[test_df["user_id"] == user_id]["merchant_standardized"].unique()
    tcat = test_df[test_df["user_id"] == user_id]["appcategory"].unique()
    print("\nTest Merchants:", list(tmerch))
    print("Test Categories:", list(tcat))

    # Scores
    mh = merchant_eval[merchant_eval["user_id"] == user_id]
    ch = category_eval[category_eval["user_id"] == user_id]
    if not mh.empty:
        print("\nMerchant Hit Score:", round(mh["merchant_hit_score"].values[0], 2))
    if not ch.empty:
        print("Category Hit Score:", round(ch["category_hit_score"].values[0], 2))


In [0]:
user_drilldown("0000b6d5-f969-4996-ac9c-0635f1eed680")

In [0]:
# Merge segments from train data into eval results
seg_merchants = train_df.groupby("user_id")["customer_segment_merchant"].first().reset_index()
seg_categories = train_df.groupby("user_id")["customer_segment_category"].first().reset_index()

merchant_eval_seg = pd.merge(merchant_eval, seg_merchants, on="user_id", how="left")
category_eval_seg = pd.merge(category_eval, seg_categories, on="user_id", how="left")

# Compute segment-wise hit rates
merchant_seg_hr = merchant_eval_seg.groupby("customer_segment_merchant")["merchant_hit_score"].mean().reset_index()
category_seg_hr = category_eval_seg.groupby("customer_segment_category")["category_hit_score"].mean().reset_index()

# ================================
# 📊 Plot Merchant Segment Hit Rates
# ================================
# Sort once so the bars and their text labels share the same order
merchant_seg_sorted = merchant_seg_hr.sort_values("merchant_hit_score", ascending=False)
plt.figure(figsize=(10,6))
sns.barplot(data=merchant_seg_sorted,
            x="merchant_hit_score", y="customer_segment_merchant", palette="Blues_r")
plt.title("Merchant Hit Rate by Customer Segment (Train → Test)", fontsize=14, weight="bold")
plt.xlabel("Average Hit Rate")
plt.ylabel("Customer Segment")
for i, v in enumerate(merchant_seg_sorted["merchant_hit_score"]):
    plt.text(v + 0.01, i, f"{v:.2%}", va="center", fontsize=10)
plt.xlim(0, 1)
plt.show()

# ================================
# 📊 Plot Category Segment Hit Rates
# ================================
# Sort once so the bars and their text labels share the same order
category_seg_sorted = category_seg_hr.sort_values("category_hit_score", ascending=False)
plt.figure(figsize=(10,6))
sns.barplot(data=category_seg_sorted,
            x="category_hit_score", y="customer_segment_category", palette="Greens_r")
plt.title("Category Hit Rate by Customer Segment (Train → Test)", fontsize=14, weight="bold")
plt.xlabel("Average Hit Rate")
plt.ylabel("Customer Segment")
for i, v in enumerate(category_seg_sorted["category_hit_score"]):
    plt.text(v + 0.01, i, f"{v:.2%}", va="center", fontsize=10)
plt.xlim(0, 1)
plt.show()


In [0]:
# Merge merchant & category evals
user_report = (
    # Only "confidence" collides between the two frames; the hit-score columns are already distinct
    merchant_eval.rename(columns={"confidence": "merchant_confidence"})
    .merge(category_eval.rename(columns={"confidence": "category_confidence"}),
           on="user_id", how="outer")
)

# Add segments from training
user_report = (
    user_report
    .merge(seg_merchants.rename(columns={"customer_segment_merchant": "merchant_segment"}), on="user_id", how="left")
    .merge(seg_categories.rename(columns={"customer_segment_category": "category_segment"}), on="user_id", how="left")
)

# Add engagement lift
user_report = (
    user_report
    .merge(merchant_spend, on="user_id", how="left")
    .merge(category_spend, on="user_id", how="left")
)

print("User report sample:")
display(user_report.head())

print(f"\nFinal report shape: {user_report.shape}")

# Build the path from the canonical VOLUME_PATH defined in the setup cell
output_path = f"{VOLUME_PATH}rfm_backfill_user_report.csv"
user_report.to_csv(output_path, index=False)

print(f"✅ Combined report saved to: {output_path}")
