# Custom Cohort Analysis

This notebook demonstrates advanced cohort analysis techniques using the AutoCLV library. You'll learn how to:

1. Create custom cohort definitions (monthly, quarterly, geographic)
2. Compare multiple cohorts side-by-side
3. Analyze cohort retention curves
4. Segment cohorts by geography or other attributes
5. Calculate cohort-specific CLV predictions

## Prerequisites

```bash
pip install -e .
```

## ⚠️ Data Privacy Notice

This notebook uses synthetic data. When using production customer data, ensure GDPR/CCPA compliance, anonymize identifiers, and never commit real data to version control.

## Setup: Generate Synthetic Data

We'll use the Texas CLV generator to create a realistic customer base.

In [None]:
from datetime import datetime
from dataclasses import asdict
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

from customer_base_audit.synthetic.texas_clv_client import generate_texas_clv_client
from customer_base_audit.foundation.data_mart import CustomerDataMartBuilder, PeriodGranularity

# Generate 1500 customers over 18 months for richer cohort analysis
customers, transactions, city_map = generate_texas_clv_client(
    total_customers=1500,
    seed=123
)

# Build data mart
builder = CustomerDataMartBuilder(period_granularities=[PeriodGranularity.MONTH])
mart = builder.build([asdict(t) for t in transactions])
period_aggregations = mart.periods[PeriodGranularity.MONTH]

print(f"Generated {len(customers):,} customers")
print(f"Generated {len(transactions):,} transactions")
print(f"Date range: {min(t.event_ts for t in transactions).date()} to {max(t.event_ts for t in transactions).date()}")

## Approach 1: Monthly Cohorts

Group customers by their acquisition month.

In [None]:
from customer_base_audit.foundation.cohorts import create_monthly_cohorts, assign_cohorts

# Create monthly cohorts for 2024
monthly_cohorts = create_monthly_cohorts(
    customers=customers,
    start_date=datetime(2024, 1, 1),
    end_date=datetime(2024, 12, 31)
)

monthly_assignments = assign_cohorts(customers, monthly_cohorts)

# Analyze cohort sizes
cohort_sizes = pd.Series(monthly_assignments).value_counts().sort_index()

print("=== Monthly Cohort Sizes ===")
print(cohort_sizes)

# Visualize cohort acquisition
plt.figure(figsize=(12, 5))
cohort_sizes.plot(kind='bar')
plt.xlabel('Cohort (YYYY-MM)')
plt.ylabel('Number of Customers')
plt.title('Customer Acquisitions by Month')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

## Approach 2: Quarterly Cohorts

Group customers by quarter for longer-term analysis.

In [None]:
from customer_base_audit.foundation.cohorts import CohortDefinition

# Define quarterly cohorts manually
quarterly_cohorts = [
    CohortDefinition(cohort_id="2024-Q1", period_start=datetime(2024, 1, 1), period_end=datetime(2024, 3, 31, 23, 59, 59)),
    CohortDefinition(cohort_id="2024-Q2", period_start=datetime(2024, 4, 1), period_end=datetime(2024, 6, 30, 23, 59, 59)),
    CohortDefinition(cohort_id="2024-Q3", period_start=datetime(2024, 7, 1), period_end=datetime(2024, 9, 30, 23, 59, 59)),
    CohortDefinition(cohort_id="2024-Q4", period_start=datetime(2024, 10, 1), period_end=datetime(2024, 12, 31, 23, 59, 59)),
]

quarterly_assignments = assign_cohorts(customers, quarterly_cohorts)

print("=== Quarterly Cohort Sizes ===")
for cohort_id in ["2024-Q1", "2024-Q2", "2024-Q3", "2024-Q4"]:
    count = sum(1 for c in quarterly_assignments.values() if c == cohort_id)
    print(f"{cohort_id}: {count:,} customers")
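
Writing quarter boundaries by hand is error-prone once more than one year is involved. The sketch below derives them with the standard library. `quarter_bounds` is a hypothetical helper, not part of the library; its dictionary keys simply mirror the `CohortDefinition` keyword arguments used above.

```python
import calendar
from datetime import datetime


def quarter_bounds(year):
    """Return cohort_id / period_start / period_end for each quarter of a year."""
    bounds = []
    for quarter in range(1, 5):
        first_month = 3 * (quarter - 1) + 1
        last_month = first_month + 2
        # monthrange returns (weekday of first day, number of days in month).
        last_day = calendar.monthrange(year, last_month)[1]
        bounds.append({
            "cohort_id": f"{year}-Q{quarter}",
            "period_start": datetime(year, first_month, 1),
            # Same 23:59:59 end-of-day convention as the manual definitions.
            "period_end": datetime(year, last_month, last_day, 23, 59, 59),
        })
    return bounds


for b in quarter_bounds(2024):
    print(b["cohort_id"], b["period_start"], b["period_end"])
```

Because the keys match the keyword arguments above, each dictionary could be unpacked into a definition, for example `CohortDefinition(**b)`, assuming the constructor takes no other required fields.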

## Approach 3: Geographic Cohorts

Group customers by their city (from the Texas CLV data).

In [None]:
# Create geographic cohort assignments
geo_assignments = dict(city_map)

print("=== Geographic Cohort Sizes ===")
geo_sizes = pd.Series(geo_assignments).value_counts()
print(geo_sizes)

# Visualize
plt.figure(figsize=(10, 6))
geo_sizes.plot(kind='pie', autopct='%1.1f%%')
plt.title('Customer Distribution by City')
plt.ylabel('')
plt.show()

## Multi-Cohort Comparison: Lens 3 Analysis

Compare retention curves for the first 3 monthly cohorts.

In [None]:
from customer_base_audit.analyses.lens3 import analyze_cohort_evolution

# Select first 3 cohorts for comparison
cohorts_to_compare = ["2024-01", "2024-02", "2024-03"]
cohort_results = {}

for cohort_name in cohorts_to_compare:
    # Get customer IDs for this cohort
    cohort_customer_ids = [
        cust_id for cust_id, coh_id in monthly_assignments.items()
        if coh_id == cohort_name
    ]
    
    if not cohort_customer_ids:
        continue
    
    # Find cohort definition
    cohort_def = next(c for c in monthly_cohorts if c.cohort_id == cohort_name)
    
    # Run Lens 3 analysis
    result = analyze_cohort_evolution(
        cohort_name=cohort_name,
        acquisition_date=cohort_def.period_start,
        period_aggregations=period_aggregations,
        cohort_customer_ids=cohort_customer_ids
    )
    
    cohort_results[cohort_name] = result
    print(f"✓ Analyzed {cohort_name}: {result.cohort_size} customers, {len(result.periods)} periods")
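
For intuition about what a retention curve measures, here is a small pandas sketch on toy data. It assumes retention in period k is the share of the cohort with at least one transaction in the k-th calendar month after acquisition. `analyze_cohort_evolution` may define `retention_rate` differently, so treat this as an illustration rather than a description of the library.

```python
import pandas as pd

# Toy transactions for one cohort acquired in January 2024 (made-up data).
tx = pd.DataFrame({
    "customer_id": ["A", "B", "C", "A", "B", "A"],
    "event_ts": pd.to_datetime([
        "2024-01-03", "2024-01-15", "2024-01-20",  # period 0: A, B, C active
        "2024-02-02", "2024-02-25",                # period 1: A, B active
        "2024-03-10",                              # period 2: A active
    ]),
})
acquisition_year, acquisition_month = 2024, 1
cohort_size = tx["customer_id"].nunique()

# Period number = whole calendar months elapsed since the acquisition month.
tx["period_number"] = (
    (tx["event_ts"].dt.year - acquisition_year) * 12
    + (tx["event_ts"].dt.month - acquisition_month)
)

# Active customers per period, divided by the original cohort size.
active = tx.groupby("period_number")["customer_id"].nunique()
retention = (active / cohort_size).rename("retention_rate")
print(retention)
```

Under this definition period 0 is always 100% for a cohort built from first purchases, so the informative part of the curve starts at period 1.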

## Visualize Retention Curves

Compare how different cohorts retain customers over time.

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Plot 1: Retention Rate Over Time
for cohort_name, result in cohort_results.items():
    periods = [p.period_number for p in result.periods]
    retention_rates = [p.retention_rate * 100 for p in result.periods]
    axes[0].plot(periods, retention_rates, marker='o', label=cohort_name)

axes[0].set_xlabel('Periods Since Acquisition (months)')
axes[0].set_ylabel('Retention Rate (%)')
axes[0].set_title('Cohort Retention Curves')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Plot 2: Revenue Per Active Customer Over Time
for cohort_name, result in cohort_results.items():
    periods = [p.period_number for p in result.periods]
    revenue_per_customer = [p.avg_revenue_per_customer for p in result.periods]
    axes[1].plot(periods, revenue_per_customer, marker='o', label=cohort_name)

axes[1].set_xlabel('Periods Since Acquisition (months)')
axes[1].set_ylabel('Avg Revenue per Active Customer ($)')
axes[1].set_title('Revenue Evolution by Cohort')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Cohort Comparison Table

Create a summary table comparing key metrics across cohorts.

In [None]:
comparison_data = []

for cohort_name, result in cohort_results.items():
    # Calculate key metrics
    initial_size = result.cohort_size
    
    # Period 0 (acquisition month)
    p0 = result.periods[0]
    
    # Period 1 (first full month after acquisition)
    p1 = result.periods[1] if len(result.periods) > 1 else None
    
    # Latest period
    latest = result.periods[-1]
    
    comparison_data.append({
        'Cohort': cohort_name,
        'Initial Size': initial_size,
        'P0 Revenue': f"${p0.total_revenue:,.0f}",
        'P0 Avg Revenue per Customer': f"${p0.avg_revenue_per_customer:,.0f}",
        'P1 Retention': f"{p1.retention_rate:.1%}" if p1 else "N/A",
        'Latest Period': latest.period_number,
        'Latest Retention': f"{latest.retention_rate:.1%}",
        'Cumulative Revenue': f"${sum(p.total_revenue for p in result.periods):,.0f}"
    })

comparison_df = pd.DataFrame(comparison_data)
print("=== Cohort Comparison Summary ===")
print(comparison_df.to_string(index=False))
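
The table above stores pre-formatted strings, which is fine for display but prevents sorting or further arithmetic. One alternative, sketched here with made-up numbers, is to keep the columns numeric and apply formatting only when rendering, using the `formatters` argument of `DataFrame.to_string`.

```python
import pandas as pd

# Made-up numbers, only to show the formatting pattern.
summary = pd.DataFrame({
    "Cohort": ["2024-01", "2024-02"],
    "P0 Revenue": [12500.0, 9800.0],
    "P1 Retention": [0.42, 0.38],
})

# Numeric columns stay numeric, so sorting and arithmetic still work.
ranked = summary.sort_values("P1 Retention", ascending=False)

# Formatting is applied only at render time.
text = ranked.to_string(
    index=False,
    formatters={
        "P0 Revenue": "${:,.0f}".format,
        "P1 Retention": "{:.1%}".format,
    },
)
print(text)
```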

## Geographic Cohort Analysis

Compare customer behavior across different cities.

In [None]:
from customer_base_audit.foundation.rfm import calculate_rfm

# Calculate RFM for all customers
observation_end = datetime(2024, 12, 31, 23, 59, 59)
rfm_metrics = calculate_rfm(period_aggregations, observation_end)
rfm_df = pd.DataFrame([asdict(rfm) for rfm in rfm_metrics])

# Add city information
rfm_df['city'] = rfm_df['customer_id'].map(city_map)

# Group by city
city_stats = rfm_df.groupby('city').agg({
    'customer_id': 'count',
    'frequency': 'mean',
    'monetary': 'mean',
    'recency_days': 'mean'
}).round(2)

city_stats.columns = ['Customers', 'Avg Frequency', 'Avg Monetary', 'Avg Recency (days)']

print("=== Geographic Cohort Comparison ===")
print(city_stats)

# Visualize city comparison
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

city_stats['Avg Frequency'].plot(kind='bar', ax=axes[0], color='skyblue')
axes[0].set_title('Average Purchase Frequency by City')
axes[0].set_ylabel('Avg Frequency')
axes[0].set_xlabel('City')

city_stats['Avg Monetary'].plot(kind='bar', ax=axes[1], color='lightgreen')
axes[1].set_title('Average Monetary Value by City')
axes[1].set_ylabel('Avg Monetary ($)')
axes[1].set_xlabel('City')

city_stats['Avg Recency (days)'].plot(kind='bar', ax=axes[2], color='salmon')
axes[2].set_title('Average Recency by City')
axes[2].set_ylabel('Avg Recency (days)')
axes[2].set_xlabel('City')

plt.tight_layout()
plt.show()
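
The per-city averages above depend on how recency, frequency, and monetary value are defined. The sketch below spells out one common set of definitions on toy data: recency is days from the last purchase to the observation end, frequency is the transaction count, and monetary is the mean transaction amount. These are assumptions for illustration; `calculate_rfm` may use different definitions, and the `amount` column is hypothetical.

```python
import pandas as pd

# Toy transactions (made-up customers and amounts).
tx = pd.DataFrame({
    "customer_id": ["A", "A", "B", "C", "C", "C"],
    "city": ["Austin", "Austin", "Dallas", "Austin", "Austin", "Austin"],
    "event_ts": pd.to_datetime([
        "2024-03-01", "2024-11-30", "2024-06-15",
        "2024-01-10", "2024-05-05", "2024-12-21",
    ]),
    "amount": [40.0, 60.0, 25.0, 10.0, 20.0, 30.0],
})
observation_end = pd.Timestamp("2024-12-31")

# One row per customer: last purchase, transaction count, mean amount.
rfm = tx.groupby("customer_id").agg(
    city=("city", "first"),
    last_purchase=("event_ts", "max"),
    frequency=("event_ts", "count"),
    monetary=("amount", "mean"),
)
rfm["recency_days"] = (observation_end - rfm["last_purchase"]).dt.days

# Average the customer-level metrics within each city.
by_city = rfm.groupby("city")[["frequency", "monetary", "recency_days"]].mean()
print(by_city)
```

Averaging customer-level metrics, as done here, weights every customer equally. Averaging transaction-level amounts instead would weight frequent buyers more heavily, so the two approaches can rank cities differently.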

## Advanced: Cohort-Specific CLV Predictions

Fit a separate BG/NBD model for each of the first two quarterly cohorts (2024-Q1 and 2024-Q2) to see whether predicted purchase behavior differs.

In [None]:
from customer_base_audit.models.model_prep import prepare_clv_model_inputs
from customer_base_audit.models.bg_nbd import BGNBDModelWrapper, BGNBDConfig

cohort_clv_predictions = {}

for cohort_id in ["2024-Q1", "2024-Q2"]:
    # Get customers in this cohort
    # Use a set so the membership test in the transaction filter is O(1)
    cohort_customer_ids = {
        cust_id for cust_id, coh_id in quarterly_assignments.items()
        if coh_id == cohort_id
    }
    
    # Filter transactions for this cohort
    cohort_transactions = [
        t for t in transactions
        if t.customer_id in cohort_customer_ids
    ]
    
    if len(cohort_transactions) < 100:
        print(f"Skipping {cohort_id}: insufficient data")
        continue
    
    # Prepare model input
    model_data = prepare_clv_model_inputs(
        transactions=[asdict(t) for t in cohort_transactions],
        observation_start=datetime(2024, 1, 1),
        observation_end=datetime(2024, 12, 31, 23, 59, 59),
        customer_id_field='customer_id',
        timestamp_field='event_ts',
        monetary_field='unit_price'
    )
    
    # Train BG/NBD model
    config = BGNBDConfig(method="map")
    model = BGNBDModelWrapper(config)
    model.fit(model_data)
    
    # Predict
    predictions = model.predict_purchases(model_data, time_periods=90.0)
    
    avg_predicted = predictions['predicted_purchases'].mean()
    cohort_clv_predictions[cohort_id] = avg_predicted
    
    print(f"✓ {cohort_id}: Avg predicted purchases (90d) = {avg_predicted:.2f}")

print("\n=== Cohort Predicted Purchases (90 days) ===")
for cohort_id, avg_pred in cohort_clv_predictions.items():
    print(f"{cohort_id}: {avg_pred:.2f} predicted purchases per customer")
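
The cell above compares predicted purchase counts, which is only one ingredient of CLV. A rough way to turn counts into a monetary figure is to multiply by an average order value and a margin. The sketch below uses made-up numbers. The `avg_order_value` column and the 30% margin are assumptions for illustration, there is no discounting, and order value is treated as independent of purchase frequency.

```python
import pandas as pd

# Made-up per-customer predictions; only predicted_purchases mirrors the
# column name used above, the rest is hypothetical.
preds = pd.DataFrame({
    "customer_id": ["A", "B", "C"],
    "predicted_purchases": [1.2, 0.4, 2.0],
    "avg_order_value": [50.0, 80.0, 30.0],
})
gross_margin = 0.30  # assumed margin, not taken from the notebook

# Expected 90-day revenue = expected purchases x average order value.
preds["expected_revenue_90d"] = (
    preds["predicted_purchases"] * preds["avg_order_value"]
)
# Expected 90-day margin applies the assumed gross margin.
preds["expected_margin_90d"] = preds["expected_revenue_90d"] * gross_margin

print(preds[["customer_id", "expected_revenue_90d", "expected_margin_90d"]])
print(f"Mean expected revenue (90d): ${preds['expected_revenue_90d'].mean():,.2f}")
```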

## Cohort Heatmap

Visualize retention as a heatmap across all monthly cohorts.

In [None]:
# Build retention matrix for all monthly cohorts
cohort_names = sorted([c.cohort_id for c in monthly_cohorts])
max_periods = 6  # Look at first 6 months

retention_matrix = []

for cohort_name in cohort_names:
    cohort_customer_ids = [
        cust_id for cust_id, coh_id in monthly_assignments.items()
        if coh_id == cohort_name
    ]
    
    if not cohort_customer_ids:
        print(f"Warning: Skipping cohort {cohort_name}: no customers assigned")
        retention_matrix.append([np.nan] * max_periods)
        continue
    
    cohort_def = next(c for c in monthly_cohorts if c.cohort_id == cohort_name)
    
    try:
        result = analyze_cohort_evolution(
            cohort_name=cohort_name,
            acquisition_date=cohort_def.period_start,
            period_aggregations=period_aggregations,
            cohort_customer_ids=cohort_customer_ids
        )
        
        retention_row = []
        for period_num in range(max_periods):
            if period_num < len(result.periods):
                retention_row.append(result.periods[period_num].retention_rate * 100)
            else:
                retention_row.append(np.nan)
        
        retention_matrix.append(retention_row)
    except Exception as e:
        print(f"Warning: Skipping cohort {cohort_name} due to error: {e}")
        retention_matrix.append([np.nan] * max_periods)

# Create heatmap
retention_df = pd.DataFrame(
    retention_matrix,
    index=cohort_names,
    columns=[f"Period {i}" for i in range(max_periods)]
)

plt.figure(figsize=(12, 8))
sns.heatmap(retention_df, annot=True, fmt='.1f', cmap='RdYlGn', vmin=0, vmax=100, cbar_kws={'label': 'Retention Rate (%)'})
plt.title('Cohort Retention Heatmap')
plt.xlabel('Periods Since Acquisition')
plt.ylabel('Acquisition Cohort')
plt.tight_layout()
plt.show()

print("📊 Heatmap shows retention rates (%) for each cohort over time.")
print("   Darker green = higher retention, darker red = lower retention")

## Summary and Insights

### What We Learned

1. **Cohort Definition Flexibility**: Monthly, quarterly, and geographic cohorts each reveal different patterns
2. **Retention Patterns**: Early cohorts may show different retention than later ones
3. **Geographic Differences**: Customer behavior varies by city/region
4. **Cohort-Specific Modeling**: Different cohorts may require different CLV predictions

### Key Metrics to Watch

- **P0 → P1 Retention Drop**: The most critical transition period
- **Revenue Per Active Customer**: Should stabilize or increase over time
- **Cohort Size Trends**: Are you acquiring more or fewer customers over time?
- **Geographic Performance**: Which regions have the best retention?

### Next Steps

- Define cohorts based on marketing campaigns or channels
- Build predictive models to forecast cohort retention
- Compare cohort performance against business KPIs
- Use cohort insights to optimize acquisition strategy
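
As a starting point for the first of these steps, campaign-based cohorts can be sketched as date windows with a fallback label. The campaign names, windows, and the `assign_campaign` helper below are hypothetical; real campaign attribution would normally come from marketing data rather than acquisition dates alone.

```python
from datetime import datetime

# Hypothetical campaign windows: (cohort_id, period_start, period_end).
campaigns = [
    ("spring-promo", datetime(2024, 3, 1), datetime(2024, 3, 31, 23, 59, 59)),
    ("back-to-school", datetime(2024, 8, 1), datetime(2024, 8, 31, 23, 59, 59)),
]


def assign_campaign(acquisition_ts, default="organic"):
    """Return the first campaign whose window contains the timestamp."""
    for campaign_id, start, end in campaigns:
        if start <= acquisition_ts <= end:
            return campaign_id
    # Customers acquired outside every window get the fallback label.
    return default


print(assign_campaign(datetime(2024, 3, 15)))  # inside the spring window
print(assign_campaign(datetime(2024, 5, 1)))   # outside all windows
```

Each tuple has the same three fields as the quarterly definitions earlier in the notebook, so the windows could likely be expressed as `CohortDefinition` objects and passed to `assign_cohorts`. How that function treats customers outside every window is not shown here and would need checking.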