In [9]:
import json
import pyreadstat
import pandas as pd

from typing import Any, TypedDict
from ipywidgets import interact, interact_manual
from collections import OrderedDict
from collections.abc import Mapping
from tqdm import tqdm
from colorama import Fore, Style
from enum import Enum
from copy import deepcopy


class VarType(Enum):
    INTEGER = "integer"
    CONTINUOUS = "continuous"
    CATEGORICAL = "categorical"


class VarData(TypedDict):
    question: str | None
    resp_type: VarType
    resp_list: list[float | None]
    resp_to_text: OrderedDict[int, str] | tuple[int | float, int | float] | None
    min_max: tuple[int | float, int | float] | None

## Loading GSS *Simulation* Data

In [10]:
num_people = 1000

In [11]:
import pickle
from collections import OrderedDict
from pathlib import Path

def load_gss_data(filename_prefix: str = "gss_analysis"):
    """Load previously saved GSS data."""
    save_dir = Path("gss_saved_data")
    
    # Load predicted data
    pred_path = save_dir / f"{filename_prefix}_pred.pkl"
    with open(pred_path, 'rb') as f:
        gss_pred = pickle.load(f)
    
    # Load actual data
    data_path = save_dir / f"{filename_prefix}_data.pkl"
    with open(data_path, 'rb') as f:
        gss_data = pickle.load(f)
    
    return gss_pred, gss_data

# Load later for analysis
gss_pred, gss_data = load_gss_data(filename_prefix=f"gss_2024_{num_people}_people")

## Basic Inspect GSS Simulation Results

In [12]:
def inspect_pred(idx: int, code: str) -> None:
    print(f"Question: {gss_data[code]['question']}")
    print(f"Target: {gss_data[code]['resp_list'][idx]}, i.e., {gss_data[code]['resp_to_text'].get(gss_data[code]['resp_list'][idx], 'N/A')} ({gss_data[code]['resp_type']})")
    print(f"Pred: {gss_pred[code]['resp_list'][idx]}, i.e., {gss_pred[code]['resp_to_text'].get(gss_pred[code]['resp_list'][idx], 'N/A')} ({gss_pred[code]['resp_type']})")

interact(inspect_pred, idx=(0, len(next(iter(gss_data.values()))["resp_list"]), 1), code=gss_data.keys());

interactive(children=(IntSlider(value=500, description='idx', max=1000), Dropdown(description='code', options=…

## In-Depth Comparison of GSS Simulation Results vs GSS Data

In [13]:
to_predict = [
    'natspacy',
    'natenviy',
    'nathealy',
    'natcityy',
    'natdrugy',
    'nateducy',
    'natracey',
    'natarmsy',
    'nataidy',
    'natfarey',
    'natroad',
    'natsoc',
    'natspac',
    'natenvir',
    'natheal',
    'natcity',
    'natdrug',
    'nateduc',
    'natrace',
    'natarms',
    'nataid',
    'natfare',
    'natchld',
    'natsci',
    'natenrgy',
    'prayer',
    'courts',
    'discaffw',
    'discaffm',
    'fehire',
    'fechld',
    'fepresch',
    'fefam',
    'fepol',
    'reg16',
    'mobile16',
    'famdif16',
    'incom16',
    'dwelown16',
    'paeduc',
    'padeg',
    'maeduc',
    'madeg',
    'mawrkgrw',
    'marital',
    'widowed',
    'divorce',
    'martype',
    'posslqy',
    'wrkstat',
    'evwork',
    'wrkgovt1',
    'wrkgovt2',
    'partfull',
    'wksub1',
    'wksup1',
    'conarmy',
    'conbus',
    'conclerg',
    'coneduc',
    'confed',
    'confinan',
    'conjudge',
    'conlabor',
    'conlegis',
    'conmedic',
    'conpress',
    'consci',
    'contv',
    'vetyears',
    'joblose',
    'jobfind',
    'happy',
    'hapmar',
    'satjob',
    'speduc',
    'spdeg',
    'spwrksta',
    'spfund',
    'unemp',
    'union1',
    'spkathy',
    'libathy',
    'colath',
    'spkracy',
    'libracy',
    'spkcomy',
    'libcomy',
    'colcomy',
    'colrac',
    'spkmslmy',
    'libmslmy',
    'cappun',
    'polhitoky',
    'polabusey',
    'polattaky',
    'grass',
    'gunlaw',
    'owngun',
    'hunt1',
    'class',
    'satfin',
    'finalter',
    'finrela',
    'race',
    'racdif1',
    'racdif2',
    'racdif3',
    'racdif4',
    'wlthwhts',
    'wlthblks',
    'wlthhsps',
    'racwork',
    'letin1a',
    'getahead',
    'parsol',
    'kidssol',
    'spanking',
    'divlaw',
    'sexeduc',
    'pillok',
    'xmarsex',
    'homosex',
    'discaff',
    'abdefect',
    'abnomore',
    'abhlth',
    'abpoor',
    'abrape',
    'absingle',
    'abany',
    'letdie1',
    'suicide1',
    'suicide2',
    'suicide4',
    'pornlaw',
    'fair',
    'helpful',
    'trust',
    'tax',
    'vote16',
    'pres16',
    'if16who',
    'polviews',
    'partyid',
    'news',
    'relig',
    'relig16',
    'attend',
    'pray',
    'postlife',
    'bible',
    'reborn',
    'relpersn',
    'sprtprsn',
    'born',
    'granborn',
    'uscitzn',
    'educ',
    'degree',
    'income',
    'visitors',
    'dwelown',
    'othlang',
    'sex',
    'hispanic',
    'health',
    'compuse',
    'webmob',
    'xmovie',
    'life',
    'richwork'
]

print(f"% of vars to predict also in GSS: {int(100. * len({k for k in to_predict if k in gss_data}) / len(to_predict))}%")

% of vars to predict also in GSS: 100%


In [14]:
import numpy as np
from sklearn.metrics import accuracy_score, mean_absolute_error
from collections import OrderedDict
import pandas as pd
import warnings

def compare(gss_pred: OrderedDict[str, VarData], 
            gss_data: OrderedDict[str, VarData]) -> dict[str, dict[str, float]]:
    """
    Compare predicted GSS responses with actual responses using metrics from the paper.
    
    Categorical variables: accuracy and correlation
    Numerical variables: MAE and correlation
    """
    results = {}
    
    for var_code in to_predict:
        if var_code not in gss_pred or var_code not in gss_data:
            continue
            
        pred_data = gss_pred[var_code]
        true_data = gss_data[var_code]
        
        # Get response lists and convert to numpy arrays
        pred_array = np.array([float(x) if x is not None else np.nan for x in pred_data["resp_list"]])
        true_array = np.array([float(x) if x is not None else np.nan for x in true_data["resp_list"]])
        
        # Filter out invalid pairs (where either is NaN)
        valid_mask = ~(np.isnan(pred_array) | np.isnan(true_array))
        if not np.any(valid_mask):
            continue
            
        pred_valid = pred_array[valid_mask]
        true_valid = true_array[valid_mask]
        
        metrics = {}
        resp_type = pred_data["resp_type"]
        
        if resp_type == VarType.CATEGORICAL:
            # Categorical variables: accuracy and correlation
            pred_int = pred_valid.astype(int)
            true_int = true_valid.astype(int)
            
            # Accuracy
            metrics["accuracy"] = accuracy_score(true_int, pred_int)
            
            # Correlation (if possible)
            if len(np.unique(true_int)) > 1 and len(np.unique(pred_int)) > 1:
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    metrics["correlation"] = np.corrcoef(true_int, pred_int)[0, 1]
            else:
                metrics["correlation"] = np.nan
                
        else:  # INTEGER or CONTINUOUS
            # Numerical variables: MAE and correlation
            
            # Mean Absolute Error
            metrics["mae"] = mean_absolute_error(true_valid, pred_valid)
            
            # Correlation (if possible)
            if len(np.unique(true_valid)) > 1 and len(np.unique(pred_valid)) > 1:
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    metrics["correlation"] = np.corrcoef(true_valid, pred_valid)[0, 1]
            else:
                metrics["correlation"] = np.nan
        
        results[var_code] = metrics
    
    return results


def display_results_table(results: dict[str, dict[str, float]], 
                         decimal_places: int = 4,
                         show_summary: bool = True) -> pd.DataFrame:
    """
    Display comparison results in a clean table format.
    
    Args:
        results: Output from compare() function
        decimal_places: Number of decimal places to show
        show_summary: Whether to print summary statistics
    
    Returns:
        DataFrame with the results
    """
    if not results:
        print("No results to display")
        return pd.DataFrame()
    
    # Collect all metrics and create DataFrame
    all_metrics = set()
    for metrics in results.values():
        all_metrics.update(metrics.keys())
    all_metrics = sorted(list(all_metrics))
    
    # Build DataFrame
    df_data = {}
    for var_code, metrics in results.items():
        df_data[var_code] = [metrics.get(metric, np.nan) for metric in all_metrics]
    
    df = pd.DataFrame(df_data, index=all_metrics)
    
    # Separate categorical and numerical variables based on metrics
    cat_vars = [col for col in df.columns if not np.isnan(df.loc['accuracy', col]) if 'accuracy' in df.index]
    num_vars = [col for col in df.columns if not np.isnan(df.loc['mae', col]) if 'mae' in df.index]
    
    # Format display
    pd.set_option('display.precision', decimal_places)
    pd.set_option('display.max_columns', None)
    pd.set_option('display.width', None)
    
    # Print categorical variables
    if cat_vars:
        print("\n" + "="*80)
        print("CATEGORICAL VARIABLES")
        print("="*80)
        df_cat = df[cat_vars].dropna(how='all')
        print(df_cat.round(decimal_places))
        
        if show_summary:
            print("\n" + "-"*40)
            print("Summary (Categorical):")
            for metric in df_cat.index:
                values = df_cat.loc[metric].dropna()
                if len(values) > 0:
                    print(f"  {metric:12s} - Mean: {values.mean():.4f}, Median: {values.median():.4f}, Std: {values.std():.4f}, "
                          f"Min: {values.min():.4f}, Max: {values.max():.4f}")
    
    # Print numerical variables
    if num_vars:
        print("\n" + "="*80)
        print("NUMERICAL VARIABLES")
        print("="*80)
        df_num = df[num_vars].dropna(how='all')
        print(df_num.round(decimal_places))
        
        if show_summary:
            print("\n" + "-"*40)
            print("Summary (Numerical):")
            for metric in df_num.index:
                values = df_num.loc[metric].dropna()
                if len(values) > 0:
                    print(f"  {metric:12s} - Mean: {values.mean():.4f}, Median: {values.median():.4f}, Std: {values.std():.4f}, "
                          f"Min: {values.min():.4f}, Max: {values.max():.4f}")
    
    return df


# Run comparison
comp_res = compare(gss_pred, gss_data)

# Display results
df_comp = display_results_table(comp_res)


CATEGORICAL VARIABLES
             natspacy  natenviy  nathealy  natcityy  natdrugy  nateducy  \
accuracy       0.4908    0.4000    0.1947    0.3836    0.4136    0.2212   
correlation    0.0861    0.3396    0.0482    0.1713    0.1620    0.1838   

             natracey  natarmsy  nataidy  natfarey  natroad  natsoc  natspac  \
accuracy       0.5789    0.4279   0.4009     0.473   0.3991  0.3925   0.4741   
correlation    0.3803    0.1592   0.1760     0.371   0.0277 -0.0191   0.0706   

             natenvir  natheal  natcity  natdrug  nateduc  natrace  natarms  \
accuracy       0.4634   0.3933   0.4286   0.4274   0.5145   0.5022   0.4421   
correlation    0.3972   0.2504   0.2469   0.1548   0.1863   0.3600   0.3101   

             nataid  natfare  natchld  natsci  natenrgy  prayer  courts  \
accuracy     0.4093   0.4430   0.4728  0.4447    0.4677  0.4215  0.4175   
correlation  0.1612   0.3552   0.1986  0.1276    0.3309 -0.1685 -0.0628   

             discaffw  discaffm  fehire  fechl

In [15]:
# Identify which variables work well vs poorly
def categorize_performance(comp_res):
    excellent = []  # accuracy > 0.8
    good = []       # accuracy 0.6-0.8
    poor = []       # accuracy 0.3-0.6
    failing = []    # accuracy < 0.3
    
    for var, metrics in comp_res.items():
        if 'accuracy' in metrics:
            acc = metrics['accuracy']
            if acc > 0.8:
                excellent.append((var, acc))
            elif acc > 0.6:
                good.append((var, acc))
            elif acc > 0.3:
                poor.append((var, acc))
            else:
                failing.append((var, acc))
    
    print(f"Excellent (>80% acc): {len(excellent)} variables")
    print(f"Good (60-80% acc): {len(good)} variables")
    print(f"Poor (30-60% acc): {len(poor)} variables")
    print(f"Failing (<30% acc): {len(failing)} variables")
    print(f"Of {len(comp_res)} total predicted variables.")
    
    return excellent, good, poor, failing

excellent, good, poor, failing = categorize_performance(comp_res)

Excellent (>80% acc): 21 variables
Good (60-80% acc): 36 variables
Poor (30-60% acc): 97 variables
Failing (<30% acc): 14 variables
Of 171 total predicted variables.


In [17]:
import numpy as np

def inspect_variable_metrics(code: str) -> None:
    """
    Display metrics for a selected variable.
    Shows response type, correlation, accuracy (if categorical), and MAE (if numerical).
    """
    # Header
    print("=" * 80)
    print(f"VARIABLE: {code}")
    print("=" * 80)
    
    # Basic information
    if code in gss_data:
        var_data = gss_data[code]
        print(f"\nQuestion: {var_data['question']}")
        print(f"Response Type: {var_data['resp_type'].value}")
        
        # Response options (for categorical)
        if var_data['resp_type'] == VarType.CATEGORICAL and var_data['resp_to_text']:
            print(f"\nResponse Options:")
            for val, text in var_data['resp_to_text'].items():
                print(f"  {val}: {text}")
        
        # Min/Max (for numerical)
        if var_data['resp_type'] in [VarType.INTEGER, VarType.CONTINUOUS] and var_data['min_max']:
            print(f"\nRange: {var_data['min_max'][0]} to {var_data['min_max'][1]}")
    
    # Metrics from comparison
    print("\n" + "-" * 40)
    print("METRICS")
    print("-" * 40)
    
    if code in comp_res:
        metrics = comp_res[code]
        
        # Display metrics based on type
        if 'accuracy' in metrics:
            # Categorical variable
            print(f"Accuracy:    {metrics['accuracy']:.4f}")
            if not np.isnan(metrics.get('correlation', np.nan)):
                print(f"Correlation: {metrics['correlation']:.4f}")
            else:
                print(f"Correlation: N/A (insufficient variation)")
                
        elif 'mae' in metrics:
            # Numerical variable
            print(f"MAE:         {metrics['mae']:.4f}")
            if not np.isnan(metrics.get('correlation', np.nan)):
                print(f"Correlation: {metrics['correlation']:.4f}")
            else:
                print(f"Correlation: N/A (insufficient variation)")
    else:
        print("No metrics available for this variable")
    
    # Sample predictions vs actual
    print("\n" + "-" * 40)
    print("SAMPLE PREDICTIONS (first 10 valid responses)")
    print("-" * 40)
    
    if code in gss_pred and code in gss_data:
        pred_list = gss_pred[code]['resp_list']
        true_list = gss_data[code]['resp_list']
        resp_to_text = gss_data[code].get('resp_to_text', {})
        
        count = 0
        print(f"{'Idx':<5} {'True':<15} {'Predicted':<15} {'Match':<7}")
        print("-" * 50)
        
        for idx in range(min(len(pred_list), len(true_list))):
            if True or (pred_list[idx] is not None and true_list[idx] is not None):
                true_val = true_list[idx]
                pred_val = pred_list[idx]
                
                # Format display based on type
                if resp_to_text:
                    true_display = f"{true_val} ({resp_to_text.get(true_val, 'unknown')[:10]})"
                    pred_display = f"{pred_val} ({resp_to_text.get(pred_val, 'unknown')[:10]})"
                else:
                    true_display = f"{true_val:.2f}" if isinstance(true_val, float) else str(true_val)
                    pred_display = f"{pred_val:.2f}" if isinstance(pred_val, float) else str(pred_val)
                
                match = "✓" if true_val == pred_val else "✗"
                print(f"{idx:<5} {true_display:<15} {pred_display:<15} {match:<7}")
                
                count += 1
                if count >= 10:
                    break

# Create interactive widget
interact(inspect_variable_metrics, code=sorted(comp_res.keys()));

interactive(children=(Dropdown(description='code', options=('abany', 'abdefect', 'abhlth', 'abnomore', 'abpoor…