In [2]:
import pandas as pd
from pathlib import Path
import numpy as np
import itertools
import subprocess
import os


from opt_targeted_transfers import standardize
from sklearn.linear_model import LinearRegression

from opt_targeted_transfers import Dataset, split
def get_row_from_metadata(metadata, covariate_name):
    """
    Look up the metadata entry of a single covariate.

    :param metadata: DataFrame with a 'variable_name' column.
    :param covariate_name: Value of 'variable_name' to select.
    :return: The matching rows passed through ``squeeze()``; with exactly one
        matching row (and several metadata columns) this is a Series.
    """
    is_requested = metadata['variable_name'] == covariate_name
    matching_rows = metadata.loc[is_requested]
    return matching_rows.squeeze()

def all_rows_from_metadata_containing(metadata, substring):
    """
    Select the metadata rows whose variable name contains a given pattern.

    :param metadata: DataFrame with a 'variable_name' column.
    :param substring: Pattern handed to ``Series.str.contains`` with its
        default settings; missing variable names never match.
    :return: DataFrame of the matching rows, re-indexed from 0.
    """
    name_matches = metadata['variable_name'].str.contains(substring, na=False)
    selected = metadata[name_matches]
    return selected.reset_index(drop=True)

def all_column_names_containing(df, substring):
    """
    Collect the column names of a DataFrame that contain a given substring.

    :param df: DataFrame whose column names are scanned.
    :param substring: Text that must occur somewhere in the column name.
    :return: List of matching column names, in column order.
    """
    matching_names = []
    for column_name in df.columns:
        if substring in column_name:
            matching_names.append(column_name)
    return matching_names


def _is_missingness_indicator(column_name):
    """Return True for column names that carry a missingness-indicator suffix."""
    return isinstance(column_name, str) and column_name.endswith(('_missing', '_m'))


def find_equivalent_columns(data, summary, numeric_tolerance=1e-6, categorical_threshold=0.99):
    """
    Find pairs of columns in a DataFrame that are informationally equivalent.

    Constant columns (at most one distinct non-null value) are printed and
    excluded. Columns whose name ends in '_missing' or '_m' are treated as
    missingness indicators and excluded as well. Columns listed in `summary`
    but absent from `data` are ignored.

    Parameters:
    -----------
    data : pandas DataFrame
        The DataFrame to analyze
    summary : pandas DataFrame
        Metadata with a 'variable_name' column and a 'data_type' column whose
        values 'numeric' / 'categorical' decide how each column is compared
    numeric_tolerance : float, default 1e-6
        Tolerance for considering numeric columns equal or proportional
    categorical_threshold : float, default 0.99
        Threshold for considering categorical columns equivalent (percentage match)

    Returns:
    --------
    list of tuples
        Each tuple contains (col1, col2, relationship_type)
        where relationship_type is one of: 'identical', 'nearly_identical',
        'proportional (factor: <ratio>)', 'numeric_categorical_equivalent',
        'categorical_equivalent'
    """
    equivalent_pairs = []
    # Pairs already reported by the numeric pass. Tracked separately because
    # the 'proportional (factor: ...)' label varies per pair and therefore
    # cannot be matched against a fixed relationship string.
    numerically_matched = set()

    # Column lists come from the metadata; keep only columns that exist in data
    present_columns = set(data.columns)
    numeric_cols = [
        col for col in summary[summary.data_type == 'numeric'].variable_name.tolist()
        if col in present_columns
    ]
    categorical_cols = [
        col for col in summary[summary.data_type == 'categorical'].variable_name.tolist()
        if col in present_columns
    ]

    # Identify and report constant columns
    constant_cols = [
        col for col in data.columns if len(data[col].dropna().unique()) <= 1
    ]
    if constant_cols:
        print("Constant columns:")
        for col in constant_cols:
            print(col)
        print()

    # Drop constant columns and missingness indicators from both lists.
    # Indicators are recognised by suffix so that ordinary columns that merely
    # contain '_m' (e.g. 'head_marital_status') stay in the comparison.
    excluded = set(constant_cols) | {
        col for col in data.columns if _is_missingness_indicator(col)
    }
    numeric_cols = [col for col in numeric_cols if col not in excluded]
    categorical_cols = [col for col in categorical_cols if col not in excluded]

    # Check numeric columns for equality or proportionality
    for col1, col2 in itertools.combinations(numeric_cols, 2):

        # Exact equality, NaNs in the same positions included
        if data[col1].equals(data[col2]):
            equivalent_pairs.append((col1, col2, 'identical'))
            numerically_matched.add(frozenset((col1, col2)))
            continue

        valid_mask = data[col1].notna() & data[col2].notna()
        if not valid_mask.any():
            # No row has both values: np.allclose on empty input is vacuously
            # True, which would wrongly report the pair as nearly identical.
            continue

        # Equality within tolerance on the rows where both values are present
        if np.allclose(data.loc[valid_mask, col1], data.loc[valid_mask, col2],
                       rtol=numeric_tolerance, atol=numeric_tolerance):
            equivalent_pairs.append((col1, col2, 'nearly_identical'))
            numerically_matched.add(frozenset((col1, col2)))
            continue

        # Rows where either value is zero cannot enter a ratio; they must be
        # exactly equal for the columns to count as proportional.
        zero_mask = (data[col1] == 0) | (data[col2] == 0)
        zero_rows = zero_mask & valid_mask
        non_zero_mask = ~zero_mask & valid_mask
        zero_equality = (data.loc[zero_rows, col1] == data.loc[zero_rows, col2]).all()

        # Check for proportional relationship in non-zero values
        if non_zero_mask.sum() > 10:  # Require at least some non-zero values
            ratios = data.loc[non_zero_mask, col2] / data.loc[non_zero_mask, col1]

            # If standard deviation of ratios is very small, columns are proportional
            if ratios.std() < numeric_tolerance and zero_equality:
                ratio = ratios.mean()
                equivalent_pairs.append((col1, col2, f'proportional (factor: {ratio:.4f})'))
                numerically_matched.add(frozenset((col1, col2)))

    # Categorical equivalence is checked for explicit categorical columns and
    # for numeric columns (which may be integer-coded categories).
    all_potential_categorical_cols = categorical_cols + numeric_cols

    for col1, col2 in itertools.combinations(all_potential_categorical_cols, 2):
        # Skip self-pairs and pairs the numeric pass already reported
        if col1 == col2 or frozenset((col1, col2)) in numerically_matched:
            continue

        unique_vals1 = data[col1].dropna().unique()
        unique_vals2 = data[col2].dropna().unique()

        # Skip if columns have different number of unique values
        if len(unique_vals1) != len(unique_vals2):
            continue

        # Skip if too many unique values (likely not categorical)
        if len(unique_vals1) > 100:  # Arbitrary threshold, adjust as needed
            continue

        # Distinct (col1, col2) value pairs over rows where both are present
        mapping_df = data[[col1, col2]].dropna().drop_duplicates()

        # The mapping is one-to-one exactly when every distinct value of each
        # column appears in one and only one pair. The unique counts of the two
        # columns are already known to be equal at this point.
        is_one_to_one = (
            len(mapping_df) == len(unique_vals1)
            and mapping_df[col1].is_unique
            and mapping_df[col2].is_unique
        )
        if not is_one_to_one:
            continue

        # If we create a new column using the mapping, it should match the original
        val_mapping = dict(zip(mapping_df[col1], mapping_df[col2]))
        mapped_values = data[col1].map(val_mapping)

        # Count matches (ignoring NaN values)
        valid_mask = data[col1].notna() & data[col2].notna()
        if valid_mask.sum() > 0:
            match_percentage = (mapped_values == data[col2])[valid_mask].mean()

            if match_percentage >= categorical_threshold:
                # Determine if both are numeric or mixed types
                if col1 in numeric_cols and col2 in numeric_cols:
                    relationship = 'numeric_categorical_equivalent'
                else:
                    relationship = 'categorical_equivalent'
                equivalent_pairs.append((col1, col2, relationship))

    return equivalent_pairs



def get_data_for_geo_extrapolation(data, summary, geo_extrapolation):
    """
    Remove the fine-grained geographic identifiers for geo-extrapolation.

    A column is removed when `summary` flags it as a geographic indicator
    ("geographic_indicator") but not as a coarse one
    ("geographic_indicator_coarser"). Flagged columns that are not present in
    `data` are ignored.

    Args:
        data (pd.DataFrame): The input data.
        summary (pd.DataFrame): The summary data, with columns "variable_name",
            "geographic_indicator" and "geographic_indicator_coarser".
        geo_extrapolation (bool): Must be truthy; the non-extrapolation path
            is not implemented.
    Returns:
        pd.DataFrame: A new frame without the fine-grained geographic
            identifiers; `data` itself is not modified.
    Raises:
        NotImplementedError: If `geo_extrapolation` is falsy.
    """
    if not geo_extrapolation:
        raise NotImplementedError(
            "get_data_for_geo_extrapolation only supports geo_extrapolation=True"
        )

    # `== True` (rather than truthiness) keeps None / NaN flags out and treats
    # 0/1-encoded flags the same way as booleans.
    geo_cols = summary.loc[summary["geographic_indicator"] == True, "variable_name"]
    coarse_geo_cols = summary.loc[
        summary["geographic_indicator_coarser"] == True, "variable_name"
    ]

    fine_only = set(geo_cols) - set(coarse_geo_cols)
    # Walk data.columns so the drop list is deterministic and only names
    # columns that actually exist.
    remove_for_coarse = [col for col in data.columns if col in fine_only]

    return data.drop(columns=remove_for_coarse)


def _pad_with_zero_columns(frame, reference_columns):
    """Append an all-zero column to `frame` for each name in `reference_columns` it lacks."""
    missing = [col for col in reference_columns if col not in frame.columns]
    if not missing:
        return frame
    padding = pd.DataFrame(0.0, index=frame.index, columns=missing)
    return pd.concat([frame, padding], axis=1)


def load_datasets(
    trainpath, testpath, summarypath, geo_extrapolation, outcome='consumption_per_day_per_capita', weight='headcount_adjusted_hh_wgt'
):
    """
    Load the train and test files and build one-hot encoded datasets.

    Train and test are one-hot encoded separately. Each is then padded with
    all-zero columns for the columns that only appear when both files are
    encoded together, so a category seen in only one file still has a column
    in both. Padding columns are added in the column order of the combined
    encoding, which makes the layout reproducible across runs.

    Args:
        trainpath (str): Path to the training data file.
        testpath (str): Path to the test data file.
        summarypath (str): Path to the parquet file with the variable summary.
        geo_extrapolation (bool): Forwarded to get_data_for_geo_extrapolation.
        outcome (str): Outcome variable.
        weight (str): Weight variable.

    Returns:
        train_dataset (Dataset): Training part of split(train).
        validation_dataset (Dataset): Validation part of split(train).
        test_covariate_dataset (Dataset): Test data built with outcome=None.
        test_dataset (Dataset): Test dataset including the outcome.
    """
    # NOTE(review): the default `outcome` is 'consumption_per_day_per_capita',
    # while the notebook's checks refer to 'consumption_per_capita_per_day' --
    # confirm which name the input files use.
    summary = pd.read_parquet(summarypath)

    # Each file is read once; the raw frames are reused for the combined
    # encoding (pd.concat builds a new frame) and for the per-file encoding.
    train_raw = get_data_for_geo_extrapolation(
        _load_data(trainpath), summary, geo_extrapolation
    )
    test_raw = get_data_for_geo_extrapolation(
        _load_data(testpath), summary, geo_extrapolation
    )

    # NOTE(review): covs is taken from the columns *before* one-hot encoding,
    # so it names the original categorical columns -- confirm that Dataset
    # expects this.
    covs = list(train_raw.columns)
    covs.remove(outcome)
    covs.remove(weight)

    all_data = convert_to_onehot(
        pd.concat([train_raw, test_raw], ignore_index=True), summary
    )
    train_data = convert_to_onehot(train_raw, summary)
    test_data = convert_to_onehot(test_raw, summary)

    final_train_data = _pad_with_zero_columns(train_data, all_data.columns)
    final_test_data = _pad_with_zero_columns(test_data, all_data.columns)

    train_dataset = Dataset(
        final_train_data.astype("float32"), outcome=outcome, covs=covs, weight=weight
    )
    test_dataset = Dataset(
        final_test_data.astype("float32"), outcome=outcome, covs=covs, weight=weight
    )
    test_covariate_dataset = Dataset(
        final_test_data.astype("float32"), outcome=None, covs=covs, weight=weight
    )

    train_dataset, validation_dataset = split(train_dataset)
    return train_dataset, validation_dataset, test_covariate_dataset, test_dataset


def convert_to_onehot(df, summary):
    """
    Convert categorical columns to one-hot encoding.

    The categorical columns are looked up in `summary`: the type column is
    "type" if present, otherwise "data_type"; the name column is "covariate"
    if present, otherwise "variable_name". Summary entries that are not
    columns of `df` are ignored. The input frame is left untouched.

    :param df: The input data.
    :type df: pandas.DataFrame
    :param summary: Variable metadata identifying the categorical columns.
    :type summary: pandas.DataFrame
    :return new_df: The input data with the categorical columns replaced by
        the float32 output of ``pd.get_dummies``, appended after the
        remaining columns.
    :rtype: pandas.DataFrame
    :raises ValueError: If `summary` has neither a type column nor a name
        column under the accepted names.
    """
    data_type = next(
        (name for name in ("type", "data_type") if name in summary.columns), None
    )
    covariate = next(
        (name for name in ("covariate", "variable_name") if name in summary.columns),
        None,
    )
    if data_type is None or covariate is None:
        raise ValueError(
            "summary must contain a 'type' or 'data_type' column and a "
            "'covariate' or 'variable_name' column"
        )

    categorical_columns = [
        col
        for col in summary.loc[summary[data_type] == "categorical", covariate].tolist()
        if col in df.columns
    ]
    if not categorical_columns:
        # Nothing to encode; still hand back a new frame, never the input
        return df.copy()

    one_hot = pd.get_dummies(df[categorical_columns]).astype(np.float32)
    # Non-mutating drop: the caller's frame keeps its categorical columns
    remaining = df.drop(columns=categorical_columns)
    new_df = pd.concat([remaining, one_hot], axis=1)
    return new_df


def _load_data(path):
    """
    Load data.

    Args:
        path (str): Path to the data file.

    Returns:
        data_for_wgan (pd.DataFrame): Data for WGAN training.
        data_wrapper (wgan.DataWrapper): DataWrapper object for WGAN training.
    """
    data = pd.read_parquet(path)

    if "hhid" in data.columns:
        data = data.drop(columns=["hhid"])
    if "case_id" in data.columns:
        data = data.drop(columns=["case_id"])
    if "hh_id" in data.columns:
        data = data.drop(columns=["hh_id"])
    if "hh_wgt" in data.columns:
        data = data.drop(columns=["hh_wgt"])

    return data.reset_index(drop=True)

"""
Done in this notebook
- Ensure that missingness-indicator columns exist.
    - You probably can't conclusively check that all are included, because the data you get will not necessarily reveal which columns had missingness, but check that there are some missingness columns, and none for categorical data.
- Ensure there are no NaNs in the data.
- Ensure column names:
    - In data: "hhid" (if household ID is included), "consumption_per_capita_per_day", "hh_wgt".
    - Consumption: Check mean and std for sanity. In a poor country, the mean should be low-mid single digits: e.g., in Uganda, the mean is $3.80/day.
- Check for columns that indicate units:
    - If they are present, the corresponding numeric field should be standardized, e.g., all area units adjusted to square meters.
- Check that metadata and the dataset itself match:
    - Every column in data is described in metadata and vice versa. It's also OK if `hhid` is not in the data at all.
- In metadata:
    - "variable_name".
    - "data_type", with permitted values "numeric" and "categorical".
    - "geographic_indicator".
- Scan datatypes:
    - In particular, make sure nothing is numeric which should be categorical.
    - Ensure categorical-type columns have the appropriate type even if the categories are encoded as integers (if a column is binary, with no missing values, it can be numeric or categorical).
    - IDs of all kinds are strings even if they appear numeric.
- Check for duplication
- Check feasibility of stratification
""";

## Read in

In [None]:
data_path = Path('/data/eop/country_data')

data, summary, proposed_stratifier = None, None, None

country = 'Ethiopia'  # Change this to the desired country

# country -> (directory relative to the root, data file, proposed stratifier,
#             read_from_data_path)
# `country_data_path` is always data_path / directory. The parquet files are
# read from that same directory when read_from_data_path is True, and from
# '/data/eop' / directory otherwise. summary.parquet always sits next to the
# data file.
_country_sources = {
    'Burkina Faso': ('burkina_faso/cleaned', 'burkinafaso_final_data.parquet', 'region', True),
    'Benin': ('benin/cleaned', 'benin_data.parquet', 'region', True),
    'Cote dIvoire': ('cote_divoire/cleaned', 'cotedivoire_cleaned_data.parquet', 'region', True),
    'Guinea-Bissau': ('guinea-bissau/cleaned', 'final_gb_dataset.parquet', 'region', True),
    'Mali': ('mali/cleaned', 'final_mali_dataset.parquet', 'region', True),
    'Somalia': ('somalia/cleaned', 'somalia_lsms_final.parquet', 'region', True),
    'Albania': ('albania/cleaned', 'albania_all.parquet', None, True),
    'Uganda': ('uganda/cleaned', 'uganda_full.parquet', 'region', True),
    'Niger': ('niger/cleaned', 'niger_2018.parquet', 's00q01', True),  # region
    'Malawi': ('malawi/cleaned', 'malawi_2019.parquet', 'ea_id', True),
    'Ghana_nolan': ('ghana_nolan/cleaned', 'ghana_data.parquet', 'region', False),
    'Ghana_henry': ('ghana_henry/cleaned', 'ghana_data.parquet', 'region', False),
    'Togo': ('Togo 2018-19/clean', 'final_togo.parquet', 'cluster_id', False),
    'Togo_only_cdr': ('Togo 2018-19/clean/cdr_features', 'togo.parquet', None, False),
    'Togo_survey_and_cdr': (
        'Togo 2018-19/clean/cdr_features_and_survey_predictors', 'togo.parquet', 'cluster_id', False
    ),
    'Ethiopia': ('Ethiopia 2018-19/clean', 'final_ethiopia.parquet', 'region_zone', False),
    'Nigeria': ('Nigeria 2018-19/clean', 'final_nigeria.parquet', 'ea_id', False),
    'Kenya': ('kenya/cleaned', 'kenya.parquet', 'county', False),
    'Tanzania': ('Tanzania_2020-21/cleaned', 'tanzania_data.parquet', 'region', False),
    'Madagascar': ('Madagascar 2010-11/cleaned', 'madagascar_data.parquet', 'REGION', False),
    'South Sudan': ('south_sudan/cleaned', 'south_sudan_data.parquet', 'ea', False),
    'South Africa': ('south_africa/cleaned', 'south_africa_output.parquet', 'province', False),
}

if country not in _country_sources:
    raise ValueError('Invalid country name')

_relative_dir, _data_file, _stratifier, _read_from_data_path = _country_sources[country]

country_data_path = data_path / _relative_dir
_read_dir = country_data_path if _read_from_data_path else Path('/data/eop') / _relative_dir

data = pd.read_parquet(_read_dir / _data_file)
summary = pd.read_parquet(_read_dir / 'summary.parquet')
if country == 'Malawi':
    # The Malawi summary names its description column 'description'
    summary.rename(columns={'description': 'variable_description'}, inplace=True)
proposed_stratifier = _stratifier


# Later cells print variable_description; fall back to the variable name
if 'variable_description' not in summary.columns:
    summary['variable_description'] = summary['variable_name']

print(f'Read in: {country}')

Read in: Ethiopia


In [5]:
print(f'country: {country}')

# Share of missing values per column; only the two worst columns are shown.
print('nullity: ')
null_share = data.isna().mean()
display(null_share.sort_values(ascending=False).head(2))

# Share of empty strings per column. Empty string may or may not be a problem.
print('empty string')
empty_string_share = data.isin(['']).mean()
display(empty_string_share.sort_values(ascending=False).head(2))

print('Number of samples:')
print(len(data))

country: Benin
nullity: 


hhid                        0.0
instrument_count_missing    0.0
dtype: float64

empty string


hhid                        0.0
instrument_count_missing    0.0
dtype: float64

Number of samples:
8012


## Missingness columns

In [6]:
# Missingness columns are recognised by suffix: '<name>_missing' or '<name>_m'.
# Matching on the suffix (not on a substring) keeps ordinary columns that
# merely contain '_m' or 'missing' out of the lists, and makes the slicing
# below recover exactly the base column name.
print(f'country: {country}')

missingness_columns_missing = [
    c for c in data.columns if c.endswith('_missing')
]
missingness_columns_m = [
    c for c in data.columns if c.endswith('_m')
]
# Base columns that have a missingness indicator
with_missingness = [
    c[:-len('_missing')] for c in missingness_columns_missing
] + [
    c[:-len('_m')] for c in missingness_columns_m
]
missingness_columns = missingness_columns_missing + missingness_columns_m
for c in missingness_columns:
    if c not in summary.variable_name.values:
        print(f"Missingness column {c} not in summary")

# Categorical columns are not expected to have a missingness indicator;
# any that do are displayed for review.
relevant_summary = summary[summary.variable_name.isin(with_missingness)]
print('categorical columns with missingness indicators:')

print(relevant_summary.data_type.value_counts())
display(relevant_summary[relevant_summary.data_type == 'categorical'])

# print numerical columns with no missingness indicators
print('numerical columns with no missingness indicators:')
print(summary[
    (summary.data_type == 'numeric')
    & (~summary.variable_name.isin(with_missingness))
    & (~summary.variable_name.str.endswith('_missing'))
    & (~summary.variable_name.str.endswith('_m'))
].variable_name)

country: Benin
Missingness column salon_count_missing not in summary
Missingness column dining_count_missing not in summary
Missingness column bed_count_missing not in summary
Missingness column mattress_count_missing not in summary
Missingness column furniture_count_missing not in summary
Missingness column carpet_count_missing not in summary
Missingness column e_iron_count_missing not in summary
Missingness column charcoal_iron_count_missing not in summary
Missingness column stove_count_missing not in summary
Missingness column gas_cylinder_count_missing not in summary
Missingness column hotplate_count_missing not in summary
Missingness column microwave_count_missing not in summary
Missingness column improved_stove_count_missing not in summary
Missingness column food_processor_count_missing not in summary
Missingness column manual_juicer_count_missing not in summary
Missingness column fridge_count_missing not in summary
Missingness column freezer_count_missing not in summary
Missingn

Unnamed: 0,module_name,module_description,variable_name,variable_description,include,data_type,geographic_indicator,create,impute,notes


numerical columns with no missingness indicators:
Series([], Name: variable_name, dtype: object)


## Consumption, weights, hh size, poverty rate

In [5]:
print(f'country: {country}')
assert 'consumption_per_capita_per_day' in data.columns, "missing 'consumption_per_capita_per_day'"
assert 'headcount_adjusted_hh_wgt' in data.columns, "missing 'headcount_adjusted_hh_wgt'"
assert pd.api.types.is_numeric_dtype(data['consumption_per_capita_per_day']), "'consumption_per_capita_per_day' is not numeric"
assert pd.api.types.is_numeric_dtype(data['headcount_adjusted_hh_wgt']), "'headcount_adjusted_hh_wgt' is not numeric"
if 'hh_size' not in data.columns:
    print('Warning: Missing hh_size')
else:
    # Validate the inputs before using them in arithmetic so a failure names
    # the actual problem.
    assert pd.api.types.is_numeric_dtype(data['hh_size']), "'hh_size' is not numeric"
    assert 'hh_wgt' in data.columns, "'hh_wgt' is needed to cross-check 'headcount_adjusted_hh_wgt'"
    assert np.isclose(
        data['hh_size'] * data['hh_wgt'], data['headcount_adjusted_hh_wgt']
    ).all(), "'headcount_adjusted_hh_wgt' != hh_size * hh_wgt"

# The outcome and the weights must not have any missing values
for col in ['headcount_adjusted_hh_wgt_missing', 'consumption_per_capita_per_day_missing', 'hh_wgt_missing']:
    if col in data.columns:
        assert data[col].sum() == 0, f"{col} has missing values"

print('mean:', data.consumption_per_capita_per_day.mean())
print('std:', data.consumption_per_capita_per_day.std())

# Threshold in the units of consumption_per_capita_per_day.
# NOTE(review): presumably the $2.15/day international poverty line -- confirm.
poverty_line = 2.15

# Weighted poverty rate: headcount-adjusted weight below the line over the total
count_poor = data.loc[
    data.consumption_per_capita_per_day < poverty_line, 'headcount_adjusted_hh_wgt'
].sum()

total = data.headcount_adjusted_hh_wgt.sum()
rate = count_poor / total

print('rate:',rate)
# To crosscheck: https://docs.google.com/spreadsheets/d/11wGVZadIZMvR2oXoDtSfjJVvixyv3ievuUOF4k_1HNY/edit?gid=0#gid=0

country: Ethiopia
mean: 4.681175409246381
std: 4.812618484963736
rate: 0.4560640131118927


## Suspiciously named columns

In [89]:
print(f'country: {country}')

# Suspicious data
# Every str.contains below passes na=False: a missing name or description
# then counts as "no match" instead of putting NaN into the boolean mask,
# which pandas refuses to index with.
print('containing the word "unit":')
mentions_unit = (
    summary.variable_name.str.contains('unit', na=False)
    | summary.variable_description.str.contains('unit', na=False)
)
# 'community' contains 'unit' but is not a unit-of-measure column
names_community = summary.variable_name.str.contains('community', na=False)
display(summary[mentions_unit & ~names_community])

print('containing the word "consumption":')
display(
    summary[
        summary.variable_name.str.contains('consumption', na=False)
        | summary.variable_description.str.contains('consumption', na=False)
    ]
)

# Print variables whose name contains "id" or "code" and are listed as numeric in the summary
print('variables with "id" or "code" and listed numeric:')

filtered_variables = summary[
    (summary["variable_name"].str.contains("id|code", case=False, na=False)) &
    (summary["data_type"] == "numeric")
]

# Print the name and description of the filtered variables
for _, row in filtered_variables.iterrows():
    print(f"Name: {row['variable_name']}, Description: {row['variable_description']}")

country: Ghana_nolan
containing the word "unit":


Unnamed: 0,variable_name,variable_description,module_name,module_description,data_type,geographic_indicator,geographic_indicator_coarser


containing the word "consumption":


Unnamed: 0,variable_name,variable_description,module_name,module_description,data_type,geographic_indicator,geographic_indicator_coarser
165,consumption_per_capita_per_day,Daily consumption per capita by the household.,percapita_expenditure_df,Expenditure per capita,numeric,False,False
304,consumption_per_capita_per_day_missing,Missingness indicator for consumption_per_capi...,Missingness Indicators,,categorical,False,False


variables with "id" or "code" and listed numeric:
Name: sample_hh_id, Description: Unique identifier for the sample household.
Name: num_camcorder/video_camera_owned, Description: Number of camcorders/video cameras owned by the household.
Name: num_video_player_owned, Description: Number of video players owned by the household.


## Summary correctness: Matches data, format

In [90]:
print(f'country: {country}')

# Every data column should be described in the summary and vice versa
data_columns = set(data.columns)
summary_variable_names = set(summary['variable_name'])

missing_in_data = summary_variable_names.difference(data_columns)
missing_in_summary = data_columns.difference(summary_variable_names)

print("Variables in summary but not in data:", missing_in_data)
print("Columns in data but not in summary:", missing_in_summary)

country: Ghana_nolan
Variables in summary but not in data: set()
Columns in data but not in summary: set()


In [91]:
print(f'country: {country}')
# Check that "summary" fits the required format
required_columns = {
    "variable_name", "data_type", "geographic_indicator", "geographic_indicator_coarser"
}
summary_columns = set(summary.columns)

missing_columns = required_columns - summary_columns
if missing_columns:
    raise ValueError(f"Missing required columns in summary: {missing_columns}")

# First pass: "data_type" may only hold the permitted values
permitted_data_types = {"numeric", "categorical"}
found_errors = False
for _, row in summary.iterrows():
    if row["data_type"] in permitted_data_types:
        continue
    print(
        f"Invalid data_type '{row['data_type']}' for variable '{row['variable_name']}'. "
        f"Description: {row['variable_description']}"
    )
    found_errors = True

# Second pass: the geographic flags must be boolean, 0/1 or None.
# "geographic_indicator_finer" is optional and only checked when present.
geo_flag_columns = [
    name
    for name in ["geographic_indicator", "geographic_indicator_coarser", "geographic_indicator_finer"]
    if name in summary.columns
]
permitted_flag_values = [0, 1, True, False, None]
for _, row in summary.iterrows():
    for c in geo_flag_columns:
        if row[c] in permitted_flag_values:
            continue
        print(
            f"Invalid {c} '{row[c]}' for variable '{row['variable_name']}'. "
            f"Description: {row['variable_description']}"
        )
        found_errors = True

if found_errors:
    raise ValueError("Errors found in summary metadata. Please fix them before proceeding.")

country: Ghana_nolan


## Data types

In [92]:
print(f'country: {country}')

# Check that numeric columns in summary are actually numeric in data
numeric_columns = summary[summary["data_type"] == "numeric"]["variable_name"]
found_error = False
for col in numeric_columns:
    if col in data.columns and not pd.api.types.is_numeric_dtype(data[col]):
        description = summary.loc[summary["variable_name"] == col, "variable_description"].values[0]
        print(f"BAD: numeric in summary, non-numeric in data: '{col}'; {description}")
        found_error = True
if found_error:
    raise ValueError('Found numeric columns in summary that are not numeric in data.')

# Check that categorical columns in summary are actually categorical in data (less important)
# Missingness indicators ('_missing' / '_m' suffix) are exempt from this check.
categorical_columns = summary[summary["data_type"] == "categorical"]["variable_name"]
for col in categorical_columns:
    if (
        col in data.columns
        # isinstance on the dtype replaces the deprecated
        # pd.api.types.is_categorical_dtype
        and not isinstance(data[col].dtype, pd.CategoricalDtype)
        and not (col.endswith('_missing') or col.endswith('_m'))
    ):
        description = summary.loc[summary["variable_name"] == col, "variable_description"].values[0]
        print(f"categorical in summary, numeric in data: '{col}'; {description}")

country: Ghana_nolan
categorical in summary, numeric in data: 'region'; Region where the household is located.
categorical in summary, numeric in data: 'hhid'; Unique identifier for the household.
categorical in summary, numeric in data: 'urbrur'; Urban-rural classification indicating whether the household is in an urban or rural area.
categorical in summary, numeric in data: 'locality_zone'; Specific zone or area within the locality where the household is located.
categorical in summary, numeric in data: 'head_gender'; Gender of the head of the household.
categorical in summary, numeric in data: 'head_marital_status'; Marital status of the head of the household.
categorical in summary, numeric in data: 'head_religion'; Religion of the head of the household.
categorical in summary, numeric in data: 'head_ethnic_group'; Ethnic group to which the head of the household belongs.
categorical in summary, numeric in data: 'head_industry_or_trade'; Industry or trade in which the head of the ho

  and not pd.api.types.is_categorical_dtype(data[col])


## Duplicate columns

In [93]:
# Look for columns that carry duplicate information.
# Skip this cell when the data has too many columns.
print(f'country: {country}')

# Left as the cell's last expression so the returned value is displayed.
find_equivalent_columns(data=data, summary=summary)

country: Ghana_nolan
Constant columns:
num_child_missing
num_adult_missing
num_elder_missing
businesses_owned_missing
hh_size_missing
consumption_per_capita_per_day_missing



[('num_adult', 'num_elder', 'identical'),
 ('num_pincers_owned', 'num_pinch_bar_owned', 'identical'),
 ('num_pincers_owned',
  'num_screw_driver_owned',
  'numeric_categorical_equivalent'),
 ('num_pinch_bar_owned',
  'num_screw_driver_owned',
  'numeric_categorical_equivalent')]

## Geography and stratification

In [94]:
print(f'country: {country}')

# Rows of `summary` flagged as geographic indicators, selected once and used
# both for the table and for the per-variable cardinality listing below.
geo_summary = summary[summary.geographic_indicator]

print('geographic indicators:')
display(geo_summary)
for geo_variable in geo_summary['variable_name']:
    # Variable name, then its number of distinct values in `data`.
    print(geo_variable)
    print(data[geo_variable].nunique())

# Disabled check (never runs): variables flagged "finer" but not "coarser".
if False:
    finer_only = summary.geographic_indicator_finer & ~summary.geographic_indicator_coarser
    partially_represented = summary[finer_only]
    if len(partially_represented) != 0:
        assert len(partially_represented) == 1
        partial_variable = partially_represented.variable_name.values[0]
        print('partially represented:')
        print(partial_variable)
        print('partially represented value counts:')
        print(data[partial_variable].value_counts())
    else:
        print('No partially represented geo level')

country: Ghana_nolan
geographic indicators:


Unnamed: 0,variable_name,variable_description,module_name,module_description,data_type,geographic_indicator,geographic_indicator_coarser
0,region,Region where the household is located.,key_hhld_info_df,Key Household Information,categorical,True,True
1,district,District where the household is located.,key_hhld_info_df,Key Household Information,numeric,True,False
2,ea_number,Enumeration area identifier for the household'...,key_hhld_info_df,Key Household Information,numeric,True,False
163,ghana_zone,Ghana zone where the household is located.,sec0_df,Checklist,numeric,True,False


region
10
district
115
ea_number
332
ghana_zone
3


In [95]:
# Stratification diagnostics for the proposed stratifier column.
print(f'country: {country}')
print('proposed stratifier:', proposed_stratifier)

# Grouping by stratification unit; shared by the first two tables.
by_unit = data.groupby(proposed_stratifier, observed=True)

print('count per unit')
unit_counts = by_unit.size().reset_index(name='count')
display(unit_counts.sort_values('count', ascending=False))

print('weights per unit')
# Number of distinct `hh_wgt` values found within each unit.
display(by_unit.hh_wgt.nunique().sort_values())

print('regions per weight class')
# Number of distinct units that share each `hh_wgt` value.
units_per_weight = data.groupby('hh_wgt', observed=True)[proposed_stratifier].nunique()
display(units_per_weight.sort_values())

print('count per unit x weight')
unit_weight_counts = (
    data.groupby(['hh_wgt', proposed_stratifier], observed=True)
    .size()
    .reset_index(name='count')
)
display(unit_weight_counts.sort_values('count', ascending=False))

country: Ghana_nolan
proposed stratifier: region
count per unit


Unnamed: 0,region,count
6,6. Ashanti Region,891
5,5. Eastern Region,625
3,3. Greater Accra Region,583
8,8. Northern Region,566
7,7. Brong Ahafo Region,508
4,4. Volta Region,480
0,1. Western Region,464
2,2. Central Region,419
9,9. Upper East Region,240
1,10. Upper West Region,177


weights per unit


region
10. Upper West Region      12
9. Upper East Region       16
2. Central Region          28
1. Western Region          31
4. Volta Region            32
7. Brong Ahafo Region      34
3. Greater Accra Region    38
8. Northern Region         38
5. Eastern Region          42
6. Ashanti Region          59
Name: hh_wgt, dtype: int64

regions per weight class


hh_wgt
214.980820     1
1528.199341    1
1516.917603    1
1514.871582    1
1503.002930    1
              ..
878.241089     1
874.731934     1
874.081055     1
864.869019     1
4751.101562    1
Name: region, Length: 330, dtype: int64

count per unit x weight


Unnamed: 0,hh_wgt,region,count
139,1089.019043,6. Ashanti Region,30
326,4502.678711,3. Greater Accra Region,30
0,214.980820,8. Northern Region,15
216,1458.206665,2. Central Region,15
224,1533.486450,7. Brong Ahafo Region,15
...,...,...,...
200,1395.010376,5. Eastern Region,14
207,1422.286011,6. Ashanti Region,14
136,1060.170776,6. Ashanti Region,13
162,1196.175293,8. Northern Region,13


## Split data, check with Roshni's code

In [96]:
print(f'Country: {country}, country data path: {country_data_path}')

# NOTE: this changes the kernel's working directory for every later cell.
os.chdir(country_data_path)

# Run the country's split script with bash; check=True raises
# CalledProcessError on a non-zero exit status.
split_script = country_data_path / 'split.sh'
subprocess.run(['bash', str(split_script)], check=True)


Country: Ghana_nolan, country data path: /data/eop/ghana_nolan/cleaned


CompletedProcess(args=['bash', '/data/eop/ghana_nolan/cleaned/split.sh'], returncode=0)

In [97]:
print(f'Country: {country}, country data path: {country_data_path}')

train_path = country_data_path / 'train.parquet'
test_path = country_data_path / 'test.parquet'
summary_path = country_data_path / 'summary.parquet'

# Freshness guard: train.parquet and test.parquet must each have a strictly
# later modification time than every other parquet file directly under
# country_data_path; otherwise the cell stops with an AssertionError.
all_parquet_files = list(country_data_path.glob('*.parquet'))
for other in all_parquet_files:
    if other in (train_path, test_path):
        continue
    for split_file in (train_path, test_path):
        assert split_file.stat().st_mtime > other.stat().st_mtime, (
            f"{split_file.name} is not newer than {other.name}"
        )

# NOTE(review): `load_datasets` is not among the imports visible at the top of
# the notebook; presumably it is defined or imported in another cell - confirm.
train_dataset, validation_dataset, test_covariate_dataset, test_dataset = load_datasets(
    trainpath=train_path,
    testpath=test_path,
    summarypath=summary_path,
    outcome='consumption_per_capita_per_day',
    weight='headcount_adjusted_hh_wgt',
    geo_extrapolation=True,
)

# `r` is passed to the regression as the per-sample weight below.
X, y, r = train_dataset.get_data()

# `standardize` returns three values; the last two are kept as *_mean / *_std
# (presumably the statistics used for the transform - confirm in its source).
X, X_mean, X_std = standardize(X)
y, y_mean, y_std = standardize(y)

# Weighted linear regression on the standardized training data. The fit call
# is left as the last expression so the fitted estimator is displayed.
model = LinearRegression(fit_intercept=True)
model.fit(X, y, sample_weight=r)

Country: Ghana_nolan, country data path: /data/eop/ghana_nolan/cleaned
