## Imports

In [303]:
import pandas as pd
from pathlib import Path
import numpy as np
import itertools
import subprocess
import os
from glob import glob



from opt_targeted_transfers import standardize
from sklearn.linear_model import LinearRegression

from opt_targeted_transfers import Dataset, split
def get_row_from_metadata(metadata, covariate_name):
    """
    Extracts a specific row from the metadata DataFrame based on the covariate name.

    :param metadata: DataFrame containing metadata.
    :param covariate_name: Name of the covariate to extract.
    :return: Row corresponding to the specified covariate name.
    """
    return metadata.loc[metadata['variable_name'] == covariate_name].squeeze()

def all_rows_from_metadata_containing(metadata, substring):
    """
    Extracts all rows from the metadata DataFrame that contain a specific substring in the variable name.

    :param metadata: DataFrame containing metadata.
    :param substring: Substring to search for in the variable names.
    :return: DataFrame containing all rows with variable names that contain the substring.
    """
    return metadata[metadata['variable_name'].str.contains(substring, na=False)].reset_index(drop=True)

def all_column_names_containing(df, substring):
    """
    Extracts all column names from the DataFrame that contain a specific substring.

    :param df: DataFrame to search for column names.
    :param substring: Substring to search for in the column names.
    :return: List of column names containing the specified substring.
    """
    return [col for col in df.columns if substring in col]


def find_equivalent_columns(data, summary, numeric_tolerance=1e-6, categorical_threshold=0.99):
    """
    Find pairs of columns in a DataFrame that are informationally equivalent.
    
    Parameters:
    -----------
    data : pandas DataFrame
        The DataFrame to analyze
    numeric_tolerance : float, default 1e-6
        Tolerance for considering numeric columns equal or proportional
    categorical_threshold : float, default 0.99
        Threshold for considering categorical columns equivalent (percentage match)
    
    Returns:
    --------
    list of tuples
        Each tuple contains (col1, col2, relationship_type)
        where relationship_type is one of: 'identical', 'proportional', 'categorical_equivalent'
    """
    equivalent_pairs = []
    columns = data.columns
    
    # Get column types
    
    numeric_cols = summary[summary.data_type == 'numeric'].variable_name.tolist()
    categorical_cols = summary[summary.data_type == 'categorical'].variable_name.tolist()

    # Identify constant columns
    constant_cols = []
    for col in columns:
        unique_values = data[col].dropna().unique()
        if len(unique_values) <= 1:
            constant_cols.append(col)

    # Print constant columns if verbose
    if len(constant_cols) > 0:
        print("Constant columns:")
        for col in constant_cols:
            print(col)
        print()
    

    # Remove constant columns from numeric and categorical lists
    numeric_cols = [col for col in numeric_cols if col not in constant_cols]
    categorical_cols = [col for col in categorical_cols if col not in constant_cols]

    # remove missingness-indicator columns
    missingness_cols = list(
        set(all_column_names_containing(data, '_missing') + 
        all_column_names_containing(data, '_m'))
    )
    numeric_cols = [col for col in numeric_cols if col not in missingness_cols]
    categorical_cols = [col for col in categorical_cols if col not in missingness_cols]
    
    # Check numeric columns for equality or proportionality
    for col1, col2 in itertools.combinations(numeric_cols, 2):

        # Check for identical values first
        if data[col1].equals(data[col2]):
            equivalent_pairs.append((col1, col2, 'identical'))
            continue
            
        # Check for identical values where neither is zero
        valid_mask = ~data[col1].isna() & ~data[col2].isna()
        if np.allclose(data.loc[valid_mask, col1], data.loc[valid_mask, col2], 
                      rtol=numeric_tolerance, atol=numeric_tolerance):
            equivalent_pairs.append((col1, col2, 'nearly_identical'))
            continue
        
        # For rows with zeros, check if the columns are exactly equal
        zero_mask = (data[col1] == 0) | (data[col2] == 0)
        non_zero_mask = ~zero_mask & valid_mask
        
        # Check if the columns have the same values where zeros are present
        if zero_mask.any():
            zero_equality = (data.loc[zero_mask & valid_mask, col1] == 
                             data.loc[zero_mask & valid_mask, col2]).all()
        else:
            zero_equality = True
            
        # Check for proportional relationship in non-zero values
        if non_zero_mask.sum() > 10:  # Require at least some non-zero values
            ratios = data.loc[non_zero_mask, col2].astype(int) / data.loc[non_zero_mask, col1].astype(int)
            ratio_std = ratios.std()
            
            # If standard deviation of ratios is very small, columns are proportional
            if ratio_std < numeric_tolerance and zero_equality:
                ratio = ratios.mean()
                equivalent_pairs.append((col1, col2, f'proportional (factor: {ratio:.4f})'))
    
    # Create a list of all columns to check for categorical equivalence
    # This includes both explicit categorical columns and numeric columns
    all_potential_categorical_cols = categorical_cols + numeric_cols
    
    # Check all columns for equivalent categorical mappings
    for col1, col2 in itertools.combinations(all_potential_categorical_cols, 2):
        # Skip if identical columns or already identified as identical or proportional
        if col1 == col2 or any((col1, col2, rel) in equivalent_pairs for rel in 
                               ['identical', 'nearly_identical', 'proportional']):
            continue
            
        # Get unique values for both columns
        unique_vals1 = data[col1].dropna().unique()
        unique_vals2 = data[col2].dropna().unique()
        
        # Skip if columns have different number of unique values
        if len(unique_vals1) != len(unique_vals2):
            continue
            
        # Skip if too many unique values (likely not categorical)
        if len(unique_vals1) > 100:  # Arbitrary threshold, adjust as needed
            continue
            
        # Create a mapping table between values in both columns
        mapping_df = data[[col1, col2]].dropna().drop_duplicates()
        
        # Check if mapping is one-to-one (each value in col1 maps to exactly one value in col2)
        is_one_to_one = True
        
        # Check col1 -> col2 mapping
        for val in unique_vals1:
            corresponding_vals = data.loc[data[col1] == val, col2].dropna().unique()
            if len(corresponding_vals) != 1:
                is_one_to_one = False
                break
                
        # Check col2 -> col1 mapping
        if is_one_to_one:
            for val in unique_vals2:
                corresponding_vals = data.loc[data[col2] == val, col1].dropna().unique()
                if len(corresponding_vals) != 1:
                    is_one_to_one = False
                    break
        
        if is_one_to_one:
            # If we create a new column using the mapping, it should match the original
            val_mapping = dict(zip(mapping_df[col1], mapping_df[col2]))
            
            # Apply mapping and handle NaN values
            mapped_values = data[col1].map(val_mapping)
            
            # Count matches (ignoring NaN values)
            valid_mask = ~data[col1].isna() & ~data[col2].isna()
            if valid_mask.sum() > 0:
                match_percentage = (mapped_values == data[col2])[valid_mask].mean()
                
                if match_percentage >= categorical_threshold:
                    # Determine if both are numeric or mixed types
                    if col1 in numeric_cols and col2 in numeric_cols:
                        relationship = 'numeric_categorical_equivalent'
                    else:
                        relationship = 'categorical_equivalent'
                    equivalent_pairs.append((col1, col2, relationship))
    
    return equivalent_pairs



def get_data_for_geo_extrapolation(data, summary, geo_extrapolation):
    """
    Preprocess the testing data for geo-extrapolation.

    Args:
        data (pd.DataFrame): The input data.
        summary (pd.DataFrame): The summary data.
    Returns:
        pd.DataFrame: The preprocessed data without geographic identifiers
    """

    geo_cols = summary[summary["geographic_indicator"] == True][
        "variable_name"
    ].tolist()

    coarse_geo_cols = summary[summary["geographic_indicator_coarser"] == True][
        "variable_name"
    ].tolist()

    remove_for_coarse = set(geo_cols) - set(coarse_geo_cols)
    remove_for_coarse = list(remove_for_coarse)

    if geo_extrapolation:
        data = data.drop(columns=remove_for_coarse)
    else:
        1/0
    return data


def load_datasets(
    trainpath, testpath, summarypath, geo_extrapolation, outcome='consumption_per_day_per_capita', 
    weight='headcount_adjusted_hh_wgt'
):
    """
    Load datasets.

    Args:
        trainpath (str): Path to the training data file.
        testpath (str): Path to the test data file.
        outcome (str): Outcome variable.
        weight (str): Weight variable.

    Returns:
        train_dataset (Dataset): Training dataset.
        test_dataset (Dataset): Test dataset.
    """
    data1 = _load_data(trainpath)
    data2 = _load_data(testpath)
    summary = pd.read_parquet(summarypath)

    data1 = get_data_for_geo_extrapolation(data1, summary, geo_extrapolation)
    data2 = get_data_for_geo_extrapolation(data2, summary, geo_extrapolation)

    all_data = pd.concat([data1, data2], ignore_index=True)
    all_data = convert_to_onehot(all_data, summary)

    train_data = _load_data(trainpath)
    test_data = _load_data(testpath)

    train_data = get_data_for_geo_extrapolation(train_data, summary, geo_extrapolation)
    test_data = get_data_for_geo_extrapolation(test_data, summary, geo_extrapolation)
    covs = list(train_data.columns)
    covs.remove(outcome)
    covs.remove(weight)

    train_data = convert_to_onehot(train_data, summary)
    test_data = convert_to_onehot(test_data, summary)

    train_missing_columns = set(all_data.columns) - set(train_data.columns)
    res = [train_data]
    for col in train_missing_columns:
        res.append(pd.DataFrame({col: np.zeros(len(train_data))}))
    final_train_data = pd.concat(res, axis=1)

    test_missing_columns = set(all_data.columns) - set(test_data.columns)
    res = [test_data]
    for col in test_missing_columns:
        res.append(pd.DataFrame({col: np.zeros(len(test_data))}))
    final_test_data = pd.concat(res, axis=1)

    train_dataset = Dataset(
        final_train_data.astype("float32"), outcome=outcome, covs=covs, weight=weight
    )
    test_dataset = Dataset(
        final_test_data.astype("float32"), outcome=outcome, covs=covs, weight=weight
    )
    test_covariate_dataset = Dataset(
        final_test_data.astype("float32"), outcome=None, covs=covs, weight=weight
    )

    train_dataset, validation_dataset = split(train_dataset)

    return train_dataset, validation_dataset, test_covariate_dataset, test_dataset


def convert_to_onehot(df, summary):
    """
    Convert categorical columns to one-hot encoding.

    :param df: The input data.
    :type df: pandas.DataFrame
    :return new_df: The input data with one-hot encoding.
    :rtype: pandas.DataFrame
    """
    if "type" in summary.columns:
        data_type = "type"
    elif "data_type" in summary.columns:
        data_type = "data_type"
    if "covariate" in summary.columns:
        covariate = "covariate"
    elif "variable_name" in summary.columns:
        covariate = "variable_name"

    categorical_columns = summary[summary[data_type] == "categorical"][
        covariate
    ].tolist()


    categorical_columns = [col for col in categorical_columns if col in df.columns]
    # print top 5 categorical columns by number of distinct values
    counts = {col: df[col].nunique(dropna=True) for col in categorical_columns}
    if counts:
        top5 = pd.Series(counts).sort_values(ascending=False).head(5)
        print("Top 5 categorical columns by distinct values:")
        for col, cnt in top5.items():
            print(f"{col}: {cnt}")
    else:
        print("No categorical columns found.")
    one_hot = pd.get_dummies(df[categorical_columns]).astype(np.float32)
    df.drop(columns=categorical_columns, inplace=True)
    new_df = pd.concat([df, one_hot], axis=1)
    return new_df


def _load_data(path):
    """
    Load data.

    Args:
        path (str): Path to the data file.

    Returns:
        data_for_wgan (pd.DataFrame): Data for WGAN training.
        data_wrapper (wgan.DataWrapper): DataWrapper object for WGAN training.
    """
    data = pd.read_parquet(path)

    if "hhid" in data.columns:
        data = data.drop(columns=["hhid"])
    if "case_id" in data.columns:
        data = data.drop(columns=["case_id"])
    if "hh_id" in data.columns:
        data = data.drop(columns=["hh_id"])
    if "hh_wgt" in data.columns:
        data = data.drop(columns=["hh_wgt"])

    return data.reset_index(drop=True)

"""
Done in this notebook
- Ensure that missingness-indicator columns exist.
    - You probably can't conclusively check that all are included, because 
       the data you get will not necessarily reveal which columns had missingness, but check 
       that there are some missingness columns, and none for categorical data.
- Ensure there are no NaNs in the data.
- Ensure column names:
    - In data: "hhid" (if household ID is included), "consumption_per_capita_per_day", "hh_wgt".
    - Consumption: Check mean and std for sanity. In a poor country, the mean should be low-mid single 
      digits: e.g., in Uganda, the mean is $3.80/day.
- Check for columns that indicate units:
    - If they are present, the corresponding numeric field should be standardized, e.g., all area units 
      adjusted to square meters.
- Check that metadata and the dataset itself match:
    - Every column in data is described in metadata and vice versa. It's also OK if `hhid` is not in the data at all.
- In metadata:
    - "variable_name".
    - "data_type", with permitted values "numeric" and "categorical".
    - "geographic_indicator".
- Scan datatypes:
    - In particular, make sure nothing is numeric which should be categorical.
    - Ensure categorical-type columns have the appropriate type even if the categories are encoded as integers (if a 
      column is binary, with no missing values, it can be numeric or categorical).
    - IDs of all kinds are strings even if they appear numeric.
- Check for duplication
- Check feasibility of stratification
""";


## Read in

In [304]:

data_path = Path('/data/eop/country_data')

data, summary, proposed_stratifier = None, None, None

country_code_map = {
    'bangladesh': 'BGD',
    'benin': 'BEN',
    'burkina_faso': 'BFA',
    'colombia': 'COL',
    'cote_divoire': 'CIV',
    'ethiopia': 'ETH',
    'ghana': 'GHA',
    'ghana_henry': 'GHA',
    'guatemala': 'GTM',
    'guinea-bissau': 'GNB',
    'india': 'IND',
    'indonesia': 'IDN',
    'kenya': 'KEN',
    'madagascar': 'MDG',
    'malawi': 'MWI',
    'mali': 'MLI',
    'niger': 'NER',
    'nigeria': 'NGA',
    'rwanda': 'RWA',
    'senegal': 'SEN',
    'somalia': 'SOM',
    'south_africa': 'ZAF',
    'south_sudan': 'SSD',
    'tanzania': 'TZA',
    'timor-leste': 'TLS',
    'togo': 'TGO',
    'togo_survey_and_cdr': 'TGO',
    'uganda': 'UGA'
}    
country_stratifier_map = {
    'bangladesh': 'DivCode',
    'benin': 'region', 
    'burkina_faso': 'region', 
    'colombia': 'domain', 
    'cote_divoire': 'region',
    'ethiopia': 'region', #
    'ghana': 'region', 
    'ghana_henry': 'region',
    'guatemala': 'region',
    'guinea-bissau': 'region',
    'india': 'state', 
    'indonesia': 'regency_city_code',
    'kenya': 'county', 
    'madagascar': 'region',
    'malawi': 'ea_id', 
    'mali': 'region', 
    'niger': 's00q01',
    'nigeria': 'ea_id', 
    'rwanda': 'district',
    'senegal': 'region',
    'somalia': 'region',
    'south_africa': 'province',
    'south_sudan': 'ea',
    'tanzania': 'region',
    'timor-leste': None,
    'togo': 'cluster_id',
    'togo_survey_and_cdr': 'cluster_id',
    'uganda': 'region'
}


# for the compiled data: convert countries whose directory names don't match their lower-case country name.
country_name_map = {
    'burkina_faso': 'burkina faso',
    'cote_divoire': "côte d'ivoire",
    'south_africa': 'south africa',
    'south_sudan': 'south sudan'
}
#########################
country = 'india' # Change this to the desired country
#########################
country_code = country_code_map[country]
country_data_path = data_path / country_code / 'cleaned'

data = pd.read_parquet(country_data_path / 'full.parquet')
summary = pd.read_parquet(country_data_path / 'summary.parquet')
proposed_stratifier = country_stratifier_map[country]

# Read in the most recent auxiliary data file available
aux_files = glob('/data/eop/compiled_country_data/auxiliary_data/auxiliary_data_*.csv')
latest_file = max(aux_files, key=lambda x: x.split('_')[-1].split('.')[0])

if country == 'somalia':
    print('Warning: Sidestepping conversion doc for somalia')

else:     
    compiled_data = pd.read_csv(latest_file)
    wb_poverty_rate_survey_year_2017 = compiled_data[
        compiled_data.country_code == country_code
    ].wb_poverty_rate_povertyline_2017_survey_year.values[0]
    wb_poverty_rate_survey_year_2021 = compiled_data[
        compiled_data.country_code == country_code
    ].wb_poverty_rate_povertyline_2021_survey_year.values[0]
    conversion_2021_to_2017 = compiled_data[
        compiled_data.country_code == country_code
    ].overall_conversion_factor_ratio_from_2021_to_2017.values[0]

if 'variable_description' not in summary.columns:
    summary['variable_description'] = summary['variable_name']

print(f'Read in: {country}')
print('nullity: ')
display(data.isna().mean().sort_values(ascending=False).head(2))
# Empty string may or may not be a problem.

print('empty string')
display(data.isin(['']).mean().sort_values(ascending=False).head(2))
print('Number of samples:')
print(data.shape[0])

Read in: india
nullity: 


hhid      0.0
sector    0.0
dtype: float64

empty string


latrine_type     0.100070
roof_material    0.000057
dtype: float64

Number of samples:
261746


## Summary correctness: Matches data, format

In [305]:
print(f'country: {country}')

# check that metadata and data match
data_columns = set(data.columns)

summary_variable_names = set(summary['variable_name'])
missing_in_data = summary_variable_names - data_columns
missing_in_summary = data_columns - summary_variable_names

print("Variables in summary but not in data:", missing_in_data)
print("Columns in data but not in summary:", missing_in_summary)

country: india
Variables in summary but not in data: set()
Columns in data but not in summary: set()


## Missingness columns

In [306]:
# Missingness columns (assumes _missing suffix)
print(f'country: {country}')

missingness_columns_missing = [
    c for c in data.columns if ('missing' in c) 
]
missingness_columns_m = [
    c for c in data.columns if ('_m' in c) 
]
with_missingness = [
    c[:-8] for c in missingness_columns_missing
] + [
    c[:-2] for c in missingness_columns_m
]
missingness_columns = missingness_columns_missing + missingness_columns_m
for c in missingness_columns:
    if not (c in summary.variable_name.values):
        print(f"Missingness column {c} not in summary")
    
relevant_summary = summary[summary.variable_name.isin(with_missingness)]
print('categorical columns with missingness indicators:')

print(relevant_summary.data_type.value_counts())
display(relevant_summary[relevant_summary.data_type == 'categorical'])

# print numerical columns with no missingness indicators
print('numerical columns with no missingness indicators:')
print(summary[
    (summary.data_type == 'numeric') 
    & (~summary.variable_name.isin(with_missingness))
    & (~summary.variable_name.str.endswith('_missing'))
    & (~summary.variable_name.str.endswith('_m'))
].variable_name)

country: india
categorical columns with missingness indicators:
Series([], Name: count, dtype: int64)


Unnamed: 0,variable_name,data_type,geographic_indicator,geographic_indicator_coarser,variable_description


numerical columns with no missingness indicators:
5                             hh_wgt
6                           head_age
10                           hh_size
11                          num_male
12                        num_female
13                         num_child
14                         num_adult
15                       num_elderly
19                  area_owned_acres
42    consumption_per_capita_per_day
43         headcount_adjusted_hh_wgt
Name: variable_name, dtype: object


## Consumption, weights, hh size, poverty rate

In [307]:
print(f'country: {country}')
assert 'consumption_per_capita_per_day' in data.columns
if 'hhid' in data.columns:
    assert data.hhid.is_unique, "'hhid' is not unique"
assert 'headcount_adjusted_hh_wgt' in data.columns
assert pd.api.types.is_numeric_dtype(data['consumption_per_capita_per_day']), "'consumption_per_capita_per_day' is not numeric"
assert pd.api.types.is_numeric_dtype(data['headcount_adjusted_hh_wgt']), "'headcount_adjusted_hh_wgt' is not numeric"
if not 'hh_size' in data.columns:
    print('Warning: Missing hh_size')
else:
    if not np.isclose(data.hh_size * data.hh_wgt, data.headcount_adjusted_hh_wgt).all():
        print('Warning: hh_size * hh_wgt does not equal headcount_adjusted_hh_wgt for all rows')
    assert pd.api.types.is_numeric_dtype(data['hh_size']), "'hh_size' is not numeric"

for col in ['headcount_adjusted_hh_wgt_missing', 'consumption_per_capita_per_day_missing', 'hh_wgt_missing']:
    if col in data.columns:
        assert data[col].sum() == 0, f"{col} has missing values"


print('mean:', data.consumption_per_capita_per_day.mean())
print('std:', data.consumption_per_capita_per_day.std())

consumption_adjusted = data.consumption_per_capita_per_day * conversion_2021_to_2017

count_poor = (
    data[consumption_adjusted < 2.15].headcount_adjusted_hh_wgt
).sum()

total = (
    data.headcount_adjusted_hh_wgt
).sum()
rate = count_poor / total

print('survey rate (2017 PPP, 2.15 line):',rate)
print('wb rate (2017 PPP, 2.15 line):', wb_poverty_rate_survey_year_2017)
print('discrepancy:', rate - wb_poverty_rate_survey_year_2017)

count_poor = (
    data[data.consumption_per_capita_per_day < 3].headcount_adjusted_hh_wgt
).sum()

total = (
    data.headcount_adjusted_hh_wgt
).sum()
rate = count_poor / total

print('survey rate (2021 PPP, 3.00 line):',rate)
print('wb rate (2021 PPP, 3.00 line):', wb_poverty_rate_survey_year_2021)
print('discrepancy:', rate - wb_poverty_rate_survey_year_2021)


country: india
mean: 8.803151
std: 7.090850830078125
survey rate (2017 PPP, 2.15 line): 0.03343365990658686
wb rate (2017 PPP, 2.15 line): 0.023491840878
discrepancy: 0.00994181902858686
survey rate (2021 PPP, 3.00 line): 0.06564702310030764
wb rate (2021 PPP, 3.00 line): 0.052516401
discrepancy: 0.013130622100307639


## Suspiciously named columns

In [308]:
print(f'country: {country}')

# Suspicious data
print('containing the word "unit":')
display(
    summary[
        (
            summary.variable_name.str.contains('unit')
            | summary.variable_description.str.contains('unit')
        ) & (
            ~summary.variable_name.str.contains('community')
        )
    ]
)

print('containing the word "consumption":')
display(
    summary[
        summary.variable_name.str.contains('consumption')
        | summary.variable_description.str.contains('consumption')
    ]
)

print('containing suspicious demographic words')
display(
    summary[
        summary.variable_name.str.contains('relig|ethn|nationality')
        | summary.variable_description.str.contains('relig|ethn|nationality')
    ]
)

# Print variables whose name contains "id" or "code" and are listed as numeric in the summary
print('variables with "id" or "code" and listed numeric:')

filtered_variables = summary[
    (summary["variable_name"].str.contains("id|code", case=False, na=False)) &
    (summary["data_type"] == "numeric")
]

# Print the name and description of the filtered variables
for _, row in filtered_variables.iterrows():
    print(f"Name: {row['variable_name']}, Description: {row['variable_description']}")

country: india
containing the word "unit":


Unnamed: 0,variable_name,data_type,geographic_indicator,geographic_indicator_coarser,variable_description
20,type_dwelling_unit,categorical,False,False,type_dwelling_unit


containing the word "consumption":


Unnamed: 0,variable_name,data_type,geographic_indicator,geographic_indicator_coarser,variable_description
42,consumption_per_capita_per_day,numeric,False,False,consumption_per_capita_per_day


containing suspicious demographic words


Unnamed: 0,variable_name,data_type,geographic_indicator,geographic_indicator_coarser,variable_description


variables with "id" or "code" and listed numeric:


In [309]:
print(f'country: {country}')
# Check that "summary" fits the required format
required_columns = {
    "variable_name", "data_type", "geographic_indicator", "geographic_indicator_coarser"
    }
summary_columns = set(summary.columns)

missing_columns = required_columns - summary_columns
if missing_columns:
    raise ValueError(f"Missing required columns in summary: {missing_columns}")

# Ensure "data_type" has only permitted values
permitted_data_types = {"numeric", "categorical"}
found_errors = False
for _, row in summary.iterrows():
    if row["data_type"] not in permitted_data_types:
        print(
            f"Invalid data_type '{row['data_type']}' for variable '{row['variable_name']}'. "
            f"Description: {row['variable_description']}"
        )
        found_errors = True

# Ensure "geographic_indicator_coarser", "geographic_indicator_finer" is boolean or 0-1
for _, row in summary.iterrows():
    for c in ["geographic_indicator", "geographic_indicator_coarser", "geographic_indicator_finer"]:
        if c not in row:
            continue
        if row[c] not in [0, 1, True, False, None]:
            print(
                f"Invalid {c} '{row[c]}' for variable '{row['variable_name']}'. "
                f"Description: {row['variable_description']}"
            )
            found_errors = True
if found_errors:
    raise ValueError("Errors found in summary metadata. Please fix them before proceeding.")

country: india


## Data types

In [310]:
print(f'country: {country}')

# Check that numeric columns in summary are actually numeric in data
numeric_columns = summary[summary["data_type"] == "numeric"]["variable_name"]
found_error = False
for col in numeric_columns:
    if col in data.columns and (not pd.api.types.is_numeric_dtype(data[col])):
        description = summary.loc[summary["variable_name"] == col, "variable_description"].values[0]
        print(f"BAD: numeric in summary, non-numeric in data: '{col}'; {description}")
        found_error = True
if found_error:
    raise ValueError('Found numeric columns in summary that are not numeric in data.')

# Check that categorical columns in summary are actually categorical in data (less important)
categorical_columns = summary[summary["data_type"] == "categorical"]["variable_name"]
for col in categorical_columns:
    if (
        col in data.columns 
        and pd.api.types.is_numeric_dtype(data[col])
        and not (col.endswith('_missing') or col.endswith('_m'))
    ):
        description = summary.loc[summary["variable_name"] == col, "variable_description"].values[0]
        print(f"categorical in summary, numeric in data: '{col}'; {description}")

country: india


## Duplicate columns

In [311]:
# Check for duplicate information. Don't do if there are too many columns (i.e. remote sensing/CDR).
print(f'country: {country}')

if False:
    find_equivalent_columns(data, summary)
else:
    print('Warning: Skipping equivalent-column check')

country: india


## Geography and stratification

In [312]:
print(f'country: {country}')

print('geographic indicators:')
display(summary[summary.geographic_indicator])
for _, row in summary[summary.geographic_indicator].iterrows():
    print(row.variable_name)
    print(data[row.variable_name].nunique())

if False:  # Holdover from previous partially-represented step
    partially_represented = summary[
        (summary.geographic_indicator_finer) & ~(summary.geographic_indicator_coarser)
    ]
    if len(partially_represented) == 0:
        print('No partially represented geo level')
    else:
        assert len(partially_represented) == 1
        print('partially represented:')
        print(partially_represented.variable_name.values[0])
        print('partially represented value counts:')
        print(data[partially_represented.variable_name.values[0]].value_counts())

country: india
geographic indicators:


Unnamed: 0,variable_name,data_type,geographic_indicator,geographic_indicator_coarser,variable_description
2,state,categorical,True,True,state
3,nss_region,categorical,True,False,nss_region
4,district,categorical,True,False,district


state
36
nss_region
87
district
71


In [313]:
# Stratification
print(f'country: {country}')

if country == 'colombia':
    data['strat'] = data.domain.astype(str) + '_' + data.region.astype(str)
    proposed_stratifier = 'strat'
    print('Colombia: Creating stratifier "strat" = domain + region')

print('proposed stratifier:', proposed_stratifier)
print('count per unit')
display(data.groupby(proposed_stratifier, observed=True).size().reset_index(name='count').sort_values('count', ascending=False))

print('weights per unit')
display(data.groupby(proposed_stratifier,  observed=True).hh_wgt.nunique().sort_values())

print('regions per weight class')
display(data.groupby('hh_wgt', observed=True)[proposed_stratifier].nunique().sort_values())

print('count per unit x weight')
display(
    data.groupby(['hh_wgt', proposed_stratifier],  observed=True)
    .size()
    .reset_index(name='count')
    .sort_values('count', ascending=False)
)

country: india
proposed stratifier: state
count per unit


Unnamed: 0,state,count
8,09. Uttar Pradesh,30239
25,27. Maharashtra,22759
18,19. West Bengal,18136
9,10. Bihar,17184
31,33. Tamil Nadu,14364
22,23. Madhya Pradesh,14197
7,08. Rajasthan,13162
27,29. Karnataka,12389
23,24. Gujarat,11286
26,28. Andhra Pradesh,10283


weights per unit


state
31. Lakshadweep (U.T.)                        36
25. Dadra & Nagar Haveli and Daman & Diu      51
37. Ladakh (U.T.)                             58
30. Goa                                       72
04. Chandigarh (U.T.)                         77
35. Andaman & Nicobar Islands (U.T.)         104
34. Puducherry (U.T.)                        112
11. Sikkim                                   202
17. Meghalaya                                250
05. Uttarakhand                              258
02. Himachal Pradesh                         265
15. Mizoram                                  337
07. Delhi                                    338
01. Jammu & Kashmir                          349
13. Nagaland                                 384
12. Arunachal Pradesh                        412
16. Tripura                                  430
14. Manipur                                  442
22. Chhattisgarh                             507
06. Haryana                                  526
20. Jharkhand 

regions per weight class


hh_wgt
472       1
142439    1
142428    1
142411    1
142402    1
         ..
136500    5
7920      5
117000    5
81900     5
161700    5
Name: state, Length: 23607, dtype: int64

count per unit x weight


Unnamed: 0,hh_wgt,state,count
2423,11830,12. Arunachal Pradesh,165
23,1130,12. Arunachal Pradesh,90
716,4805,12. Arunachal Pradesh,73
6056,52880,18. Assam,71
5438,44370,09. Uttar Pradesh,69
...,...,...,...
22737,200994,08. Rajasthan,1
22739,201041,19. West Bengal,1
22742,201094,08. Rajasthan,1
22744,201106,09. Uttar Pradesh,1


## Look at the columns

In [314]:
print(f'country: {country}')
if True:
    with pd.option_context('display.max_rows', 300):
        display(summary[~summary.variable_name.str.contains('_m')])

country: india


Unnamed: 0,variable_name,data_type,geographic_indicator,geographic_indicator_coarser,variable_description
0,hhid,categorical,False,False,hhid
1,sector,categorical,False,False,sector
2,state,categorical,True,True,state
3,nss_region,categorical,True,False,nss_region
4,district,categorical,True,False,district
5,hh_wgt,numeric,False,False,hh_wgt
6,head_age,numeric,False,False,head_age
7,head_gender,categorical,False,False,head_gender
9,head_education,categorical,False,False,head_education
10,hh_size,numeric,False,False,hh_size


## Split data, check with roshni's code

In [315]:
print(f'Country: {country}, country data path: {country_data_path}')
os.chdir(country_data_path)
subprocess.run(['bash', str(country_data_path / 'split.sh')], check=True)

Country: india, country data path: /data/eop/country_data/IND/cleaned


CompletedProcess(args=['bash', '/data/eop/country_data/IND/cleaned/split.sh'], returncode=0)

In [316]:
print(f'Country: {country}, country data path: {country_data_path}')

train_path = country_data_path / 'train.parquet'
test_path = country_data_path / 'test.parquet'
full_data_path = country_data_path / 'full.parquet'
assert len(pd.read_parquet(train_path)) + len(pd.read_parquet(test_path)) == len(pd.read_parquet(full_data_path))
summary_path = country_data_path / 'summary.parquet'
# Assert that train.parquet and test.parquet are newer than any other parquet file under country_data_path
all_parquet_files = list(country_data_path.glob('*.parquet'))
for f in all_parquet_files:
    if f not in [train_path, test_path]:
        assert train_path.stat().st_mtime > f.stat().st_mtime, f"{train_path.name} is not newer than {f.name}"
        assert test_path.stat().st_mtime > f.stat().st_mtime, f"{test_path.name} is not newer than {f.name}"

train_dataset, validation_dataset, test_covariate_dataset, test_dataset = load_datasets(
    trainpath = train_path,
    testpath = test_path,
    summarypath = summary_path,
    geo_extrapolation = True,
    outcome = 'consumption_per_capita_per_day',
    weight = 'headcount_adjusted_hh_wgt'
)
X, y, r = train_dataset.get_data()
X, X_mean, X_std = standardize(X)
y, y_mean, y_std = standardize(y)
model = LinearRegression(fit_intercept=True)
model.fit(X, y, sample_weight=r)
# Report in-sample R^2
r2_weighted = model.score(X, y, sample_weight=r)
print(f"In-sample weighted R^2: {r2_weighted:.4f}")
if r2_weighted > 0.9:
    raise AssertionError("In-sample weighted R^2 is suspiciously high (>0.9), please check for data leakage.")


Country: india, country data path: /data/eop/country_data/IND/cleaned
Top 5 categorical columns by distinct values:
state: 36
water_source: 17
max_adult_education: 13
max_female_education: 13
energy_cooking: 12
Top 5 categorical columns by distinct values:
state: 36
water_source: 17
max_adult_education: 13
max_female_education: 13
energy_cooking: 12
Top 5 categorical columns by distinct values:
state: 36
water_source: 17
max_adult_education: 13
max_female_education: 13
energy_cooking: 12
In-sample weighted R^2: 0.4545
