In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import ks_2samp
import os
import random

# — make sure the folder that receives every saved figure exists
os.makedirs('plots', exist_ok=True)

# — read both merged tables. low_memory=True is pandas' default: the parser
#   works through the file in internal chunks, which lowers peak memory while
#   parsing (at the cost of possible mixed-type inference / DtypeWarning); it
#   does not make the resulting DataFrame any smaller.
df_alite, df_manual = (
    pd.read_csv(csv_path, low_memory=True)
    for csv_path in ('alite_merge.csv', 'manual_merge.csv')
)

# — 1. Summary metrics
def compute_summary(df):
    """Return a flat dict of basic shape / quality metrics for *df*.

    Keys:
      Rows, Columns      -- ``df.shape``
      AvgMissing         -- mean over columns of each column's NaN fraction
      UniqueCountryYear  -- number of distinct (country, year) pairs, using the
                            first column whose name contains 'country' and the
                            first whose name contains 'year' (case-insensitive);
                            NaN when either column is absent
      dtype_<name>       -- number of columns having that dtype
    """
    rows, cols = df.shape
    avg_missing = df.isna().mean().mean()

    # str() so non-string labels (e.g. integer column names) don't raise on .lower()
    ccol = next((c for c in df.columns if 'country' in str(c).lower()), None)
    ycol = next((c for c in df.columns if 'year' in str(c).lower()), None)
    if ccol is not None and ycol is not None:
        unique_ky = df[[ccol, ycol]].drop_duplicates().shape[0]
    else:
        unique_ky = np.nan

    # flatten per-dtype column counts into the record, e.g. 'dtype_float64': 12
    dtype_flat = {f"dtype_{k}": v for k, v in df.dtypes.value_counts().items()}
    return {
        'Rows': rows,
        'Columns': cols,
        'AvgMissing': avg_missing,
        'UniqueCountryYear': unique_ky,
        **dtype_flat,
    }

# One row per merge, one column per metric (dict of records, then transposed).
summary = pd.DataFrame(
    {
        label: compute_summary(frame)
        for label, frame in (('ALITE_merge', df_alite), ('Manual_merge', df_manual))
    }
).transpose()

# — 2. Approximate Country-Year Jaccard
def approx_jaccard(df1, df2, cols, max_sample=50000):
    """Return ``(intersection, union, jaccard)`` for the distinct keys of two frames.

    A key is one row of ``df[cols]`` taken as a tuple (duplicates removed).
    The score is ``|K1 & K2| / |K1 | K2|``, or NaN when both key sets are empty.

    ``max_sample`` is accepted for backward compatibility but ignored. Both key
    sets have to be materialised anyway, and exact set intersection/union over
    them is cheap. Subsampling each set independently would bias the overlap
    downward, because a shared key only counts if it survives both draws.
    ``random.sample`` also rejects sets from Python 3.11 on.
    """
    def _keys(df):
        # itertuples(name=None) yields plain tuples without building a Series per row
        return set(df[cols].drop_duplicates().itertuples(index=False, name=None))

    s1, s2 = _keys(df1), _keys(df2)
    inter = len(s1 & s2)
    uni = len(s1 | s2)
    return inter, uni, inter / uni if uni else np.nan

# Locate the country / year key columns in EACH frame. The two merges need not
# share header names, and indexing df_manual with ALITE's names would raise
# KeyError when they differ. str() guards against non-string column labels.
cc1 = next((c for c in df_alite  if 'country' in str(c).lower()), None)
yy1 = next((c for c in df_alite  if 'year'    in str(c).lower()), None)
cc2 = next((c for c in df_manual if 'country' in str(c).lower()), None)
yy2 = next((c for c in df_manual if 'year'    in str(c).lower()), None)

if all(c is not None for c in (cc1, yy1, cc2, yy2)):
    # rename Manual's key columns to ALITE's names so a single `cols` list fits both
    manual_keys = df_manual[[cc2, yy2]].rename(columns={cc2: cc1, yy2: yy1})
    inter, union, jacc = approx_jaccard(df_alite, manual_keys, [cc1, yy1])
else:
    inter = union = jacc = np.nan

align = pd.DataFrame([{
    'KeyIntersection': inter,
    'KeyUnion':        union,
    'KeyJaccard':      jacc
}])

# — 3. Column overlap: how many column names occur in only one merge, or in both
cols_a, cols_m = set(df_alite.columns), set(df_manual.columns)
common = cols_a.intersection(cols_m)
overlap = dict(
    ALITE_only=len(cols_a.difference(common)),
    Common=len(common),
    Manual_only=len(cols_m.difference(common)),
)

# — 4. Top-10 missingness: the ten columns with the largest NaN share, per merge
missing_alite, missing_manual = (
    frame.isna().mean().nlargest(10) for frame in (df_alite, df_manual)
)

# — 5. Dtype consistency: for every shared column (alphabetical), the dtype pandas
#   inferred in each merge and whether the two agree
dtype_df = pd.DataFrame([
    {
        'Column':       name,
        'ALITE_dtype':  str(dt_a),
        'Manual_dtype': str(dt_m),
        'Consistent':   dt_a == dt_m,
    }
    for name, dt_a, dt_m in (
        (c, df_alite[c].dtype, df_manual[c].dtype) for c in sorted(common)
    )
])

# — 6. Two-sample KS tests on the numeric columns shared by both merges.
#   Columns are visited in sorted order. Iterating the `common` set directly
#   gives an order that can change between interpreter runs (str hashing is
#   randomised), which reshuffles this table and the columns picked for the
#   CDF plots. The KS test runs on at most KS_SAMPLE_SIZE points per side to
#   stay cheap; the reported means use every non-null value, not the sample.
KS_SAMPLE_SIZE = 1000

numeric_common = [
    col for col in sorted(common)
    if pd.api.types.is_numeric_dtype(df_alite[col])
    and pd.api.types.is_numeric_dtype(df_manual[col])
]
stats = []
for col in numeric_common:
    a_full = df_alite[col].dropna()
    m_full = df_manual[col].dropna()
    if a_full.empty or m_full.empty:
        continue  # KS needs at least one observation on each side
    a = a_full.sample(KS_SAMPLE_SIZE, random_state=1) if len(a_full) > KS_SAMPLE_SIZE else a_full
    m = m_full.sample(KS_SAMPLE_SIZE, random_state=1) if len(m_full) > KS_SAMPLE_SIZE else m_full
    ks_stat, pval = ks_2samp(a, m)
    stats.append({
        'Column':      col,
        'ALITE_mean':  a_full.mean(),
        'Manual_mean': m_full.mean(),
        'KS_stat':     ks_stat,
        'p_value':     pval,
    })
stats_df = pd.DataFrame(stats)

# — 7. Visualizations with Matplotlib only
# Each chart gets its own explicit figure, closed after saving, so the output
# neither depends on nor leaks into whatever figure pyplot currently treats as
# active.

# (a) Missingness histograms: distribution of per-column NaN ratios in each merge
fig, ax = plt.subplots()
ax.hist(df_alite.isna().mean(),  bins=20, alpha=0.5, label='ALITE')
ax.hist(df_manual.isna().mean(), bins=20, alpha=0.5, label='Manual')
ax.set_title('Missingness Distribution by Column')
ax.set_xlabel('Missing Ratio')
ax.set_ylabel('Count')
ax.legend()
fig.savefig('plots/missingness_hist.png')
plt.close(fig)

# (b) Column overlap bar (plain lists rather than dict views for the bar data)
fig, ax = plt.subplots()
ax.bar(list(overlap.keys()), list(overlap.values()))
ax.set_title('Column Overlap')
ax.set_ylabel('Number of Columns')
fig.savefig('plots/column_overlap.png')
plt.close(fig)

# (c) Empirical CDFs for up to 3 numeric columns. The alphabetically first
#     columns are used so the choice is stable from run to run.
for col in sorted(numeric_common)[:3]:
    a = df_alite[col].dropna()
    m = df_manual[col].dropna()
    # cap each side at 2000 points to keep the plot light
    if len(a) > 2000: a = a.sample(2000, random_state=1)
    if len(m) > 2000: m = m.sample(2000, random_state=1)

    fig, ax = plt.subplots()
    for values, label in ((a, 'ALITE'), (m, 'Manual')):
        xs = np.sort(values.to_numpy())
        # ECDF: the i-th smallest of n points (1-based) sits at height i/n, so
        # the curve starts at 1/n and reaches exactly 1. linspace(0, 1, n) would
        # put the minimum at 0 and a single point at 0 instead of 1.
        ax.step(xs, np.arange(1, xs.size + 1) / xs.size, where='post', label=label)
    ax.set_title(f'Empirical CDF — {col}')
    ax.set_xlabel(col)
    ax.set_ylabel('Proportion ≤ x')
    ax.legend()
    # column names may contain '/' or other characters that are invalid in file names
    safe_name = ''.join(ch if ch.isalnum() or ch in '-_.' else '_' for ch in str(col))
    fig.savefig(f'plots/cdf_{safe_name}.png')
    plt.close(fig)

# — 8. Print / inspect: each heading followed by its table, in report order
report_sections = [
    ("=== Summary ===", summary),
    ("\n=== Key Alignment ===", align),
    ("\n=== Top Missingness (ALITE) ===", missing_alite),
    ("\n=== Top Missingness (Manual) ===", missing_manual),
    ("\n=== Dtype Consistency ===", dtype_df.head(10).to_string(index=False)),
    ("\n=== KS-Test Results ===", stats_df.to_string(index=False)),
]
for heading, body in report_sections:
    print(heading)
    print(body)
print("\nPlots saved under ./plots/")
