<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Pre-requisites" data-toc-modified-id="Pre-requisites-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Pre-requisites</a></span></li><li><span><a href="#Instructions" data-toc-modified-id="Instructions-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Instructions</a></span></li><li><span><a href="#Imports-and-Constants" data-toc-modified-id="Imports-and-Constants-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Imports and Constants</a></span></li><li><span><a href="#Validate-and-Split-Exported-TFRecords" data-toc-modified-id="Validate-and-Split-Exported-TFRecords-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Validate and Split Exported TFRecords</a></span></li><li><span><a href="#Verify-the-Individual-TFRecord-Files-(Optional)" data-toc-modified-id="Verify-the-Individual-TFRecord-Files-(Optional)-5"><span class="toc-item-num">5&nbsp;&nbsp;</span>Verify the Individual TFRecord Files (Optional)</a></span></li><li><span><a href="#Calculate-Mean-and-Std-Dev-for-Each-Band" data-toc-modified-id="Calculate-Mean-and-Std-Dev-for-Each-Band-6"><span class="toc-item-num">6&nbsp;&nbsp;</span>Calculate Mean and Std-Dev for Each Band</a></span></li></ul></div>

## Pre-requisites

Go through the [`preprocessing/0_export_tfrecords.ipynb`](./0_export_tfrecords.ipynb) notebook.

Before running this notebook, you should have the following structure under the `data/` directory:

```
data/
    dhs_tfrecords_raw/
        angola_2011_00.tfrecord.gz
        ...
        zimbabwe_2015_XX.tfrecord.gz
    dhsnl_tfrecords_raw/
        angola_2010_00.tfrecord.gz
        ...
        zimbabwe_2016_XX.tfrecord.gz
    lsms_tfrecords_raw/
        ethiopia_2011_00.tfrecord.gz
        ...
        uganda_2013_XX.tfrecord.gz
```
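
A quick way to confirm this structure before continuing is to look for `*.tfrecord.gz` files in each raw folder. The sketch below assumes the three folder names shown above; the helper name `missing_raw_folders` is hypothetical and not part of this repository.

```python
import glob
import os

RAW_FOLDERS = ['dhs_tfrecords_raw', 'dhsnl_tfrecords_raw', 'lsms_tfrecords_raw']


def missing_raw_folders(data_dir: str = 'data') -> list:
    '''Returns the raw export folders under data_dir that contain no
    *.tfrecord.gz files. Hypothetical helper, for illustration only.
    '''
    missing = []
    for folder in RAW_FOLDERS:
        pattern = os.path.join(data_dir, folder, '*.tfrecord.gz')
        if not glob.glob(pattern):  # empty list if the folder is absent or empty
            missing.append(folder)
    return missing


print('Folders without exported TFRecords:', missing_raw_folders('data'))
```

An empty list means every raw folder holds at least one exported file; it does not check that all country-year exports are present.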

## Instructions

This notebook processes the exported TFRecords as follows:
1. Verifies that the fields in the TFRecords match the original CSV files.
2. Splits each monolithic TFRecord file exported from Google Earth Engine into one file per record.

After running this notebook, you should have three new folders (`dhs_tfrecords`, `dhsnl_tfrecords`, and `lsms_tfrecords`) under `data/`:

```
data/
    dhs_tfrecords/
        angola_2011/
            00000.tfrecord.gz
            ...
            00229.tfrecord.gz
        ...
        zimbabwe_2015/
            00000.tfrecord.gz
            ...
            00399.tfrecord.gz
    dhsnl_tfrecords/
        angola_2010/
            00000.tfrecord.gz
            ...
            07734.tfrecord.gz
        zimbabwe_2016/
            00000.tfrecord.gz
            ...
            03584.tfrecord.gz
    lsms_tfrecords/
        ethiopia_2011/
            00000.tfrecord.gz
            ...
            00326.tfrecord.gz
        uganda_2013/
            00000.tfrecord.gz
            ...
            00164.tfrecord.gz
```
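
The layout above follows from numbering records sequentially within each survey, starting at 0. A minimal sketch of that naming scheme is shown below; the helper name `split_record_path` is hypothetical, and the notebook itself builds the same path inline in `validate_and_split_tfrecords()`.

```python
import os


def split_record_path(processed_dir: str, country: str, year: int, index: int) -> str:
    '''Returns the output path for the index-th record of a
    (country, year) survey. Hypothetical helper, for illustration only.
    '''
    country_year = f'{country}_{year}'
    # zero-padded to 5 digits so that file names sort in record order
    return os.path.join(processed_dir, country_year, f'{index:05d}.tfrecord.gz')


print(split_record_path('data/dhs_tfrecords', 'angola', 2011, 229))
```

Because the index restarts for every survey, a file name is only unique in combination with its `country_year` folder.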

This notebook also calculates the per-band mean and standard deviation for each of the 3 datasets.

## Imports and Constants

In [None]:
%load_ext autoreload
%autoreload 2

# change directory to repo root, and verify
%cd '../'
!pwd

In [None]:
from __future__ import annotations

from collections.abc import Iterable
from glob import glob
from pprint import pprint
import os
from typing import Optional

import numpy as np
import pandas as pd
import tensorflow as tf
from tqdm.auto import tqdm

from batchers import batcher, tfrecord_paths_utils
from preprocessing.helper import (
    analyze_tfrecord_batch,
    per_band_mean_std,
    print_analysis_results)

In [None]:
REQUIRED_BANDS = [
    'BLUE', 'GREEN', 'LAT', 'LON', 'NIGHTLIGHTS', 'NIR', 'RED',
    'SWIR1', 'SWIR2', 'TEMP1']

BANDS_ORDER = [
    'BLUE', 'GREEN', 'RED', 'SWIR1', 'SWIR2', 'TEMP1', 'NIR',
    'DMSP', 'VIIRS']

DHS_EXPORT_FOLDER = 'data/dhs_tfrecords_raw'
DHSNL_EXPORT_FOLDER = 'data/dhsnl_tfrecords_raw'
LSMS_EXPORT_FOLDER = 'data/lsms_tfrecords_raw'

DHS_PROCESSED_FOLDER = 'data/dhs_tfrecords'
DHSNL_PROCESSED_FOLDER = 'data/dhsnl_tfrecords'
LSMS_PROCESSED_FOLDER = 'data/lsms_tfrecords'

## Validate and Split Exported TFRecords

In [None]:
def process_dataset(csv_path: str, input_dir: str, processed_dir: str) -> None:
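    '''Validates and splits the exported TFRecords for every
    (country, year) survey listed in a CSV file.

    Args
    - csv_path: str, path to CSV of DHS or LSMS clusters, or DHSNL locations
    - input_dir: str, path to folder of TFRecords exported from Google Earth Engine
    - processed_dir: str, folder where to save processed TFRecords
    '''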
    df = pd.read_csv(csv_path, float_precision='high', index_col=False)
    surveys = list(df.groupby(['country', 'year']).groups.keys())  # (country, year) tuples

    for country, year in surveys:
        country_year = f'{country}_{year}'
        print('Processing:', country_year)

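        # glob() returns paths in arbitrary order; sort them so that records
        # are read in the same order as the rows of the CSV
        tfrecord_paths = sorted(glob(os.path.join(input_dir, country_year + '*')))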
        out_dir = os.path.join(processed_dir, country_year)
        os.makedirs(out_dir, exist_ok=True)
        subset_df = df[(df['country'] == country) & (df['year'] == year)].reset_index(drop=True)
        validate_and_split_tfrecords(
            tfrecord_paths=tfrecord_paths, out_dir=out_dir, df=subset_df)


def validate_and_split_tfrecords(
        tfrecord_paths: Iterable[str],
        out_dir: str,
        df: pd.DataFrame
        ) -> None:
    '''Validates and splits a list of exported TFRecord files (for a
    given country-year survey) into individual TFRecords, one per cluster.

    "Validating" a TFRecord consists of 2 parts
    1) verifying that it contains the required bands
    2) verifying that its other features match the values from the dataset CSV

    Args
    - tfrecord_paths: list of str, paths to exported TFRecord files
    - out_dir: str, path to dir to save processed individual TFRecords
    - df: pd.DataFrame, index is sequential and starts at 0
    '''
    # The exported files are GZIP-compressed. The same options are used below
    # to read each file (the iterator yields serialized Example messages as
    # byte strings) and to write the per-record output files.
    options = tf.io.TFRecordOptions(tf.io.TFRecordCompressionType.GZIP)

    # cast float64 => float32 and str => bytes
    for col in df.columns:
        if df[col].dtype == np.float64:
            df[col] = df[col].astype(np.float32)
        elif df[col].dtype == object:  # pandas uses 'object' type for str
            df[col] = df[col].astype(bytes)

    i = 0
    progbar = tqdm(total=len(df))

    for tfrecord_path in tfrecord_paths:
        iterator = tf.io.tf_record_iterator(tfrecord_path, options=options)
        for record_str in iterator:
            # parse into an actual Example message
            ex = tf.train.Example.FromString(record_str)
            feature_map = ex.features.feature

            # verify required bands exist
            for band in REQUIRED_BANDS:
                assert band in feature_map, f'Band "{band}" not in record {i} of {tfrecord_path}'

            # compare feature map values against CSV values
            csv_feats = df.loc[i, :].to_dict()
            for col, val in csv_feats.items():
                assert col in feature_map, f'Feature "{col}" not in record {i} of {tfrecord_path}'
                ft_type = feature_map[col].WhichOneof('kind')
                ex_val = getattr(feature_map[col], ft_type).value[0]
                assert val == ex_val, f'Expected {col}={val}, but found {ex_val} instead'

            # serialize to string and write to file
            out_path = os.path.join(out_dir, f'{i:05d}.tfrecord.gz')  # 5 digits cover up to 1e5 clusters per survey
            with tf.io.TFRecordWriter(out_path, options=options) as writer:
                writer.write(ex.SerializeToString())

            i += 1
            progbar.update(1)
    progbar.close()

In [None]:
process_dataset(
    csv_path='data/dhs_clusters.csv',
    input_dir=DHS_EXPORT_FOLDER,
    processed_dir=DHS_PROCESSED_FOLDER)

In [None]:
process_dataset(
    csv_path='data/dhsnl_locs.csv',
    input_dir=DHSNL_EXPORT_FOLDER,
    processed_dir=DHSNL_PROCESSED_FOLDER)

In [None]:
process_dataset(
    csv_path='data/lsms_clusters.csv',
    input_dir=LSMS_EXPORT_FOLDER,
    processed_dir=LSMS_PROCESSED_FOLDER)

## Verify the Individual TFRecord Files (Optional)

Check that the label, location, and year values in each individual TFRecord file match the original CSV.
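
Exact equality is only meaningful once both sides share a dtype: TFRecord float features are stored as 32-bit floats, whereas pandas parses CSV numbers as float64. The small sketch below uses a made-up value of `0.1` to show why the CSV side is cast to `np.float32` before comparing.

```python
import numpy as np

csv_val = np.float64(0.1)         # made-up value, as pandas would parse it
stored_val = np.float32(csv_val)  # what a float32 TFRecord feature would hold

# Comparing across dtypes upcasts the float32 value to float64,
# which exposes the float32 rounding error.
equal_mixed = (csv_val == stored_val)

# Casting the CSV value to float32 first makes the comparison exact.
equal_cast = (np.float32(csv_val) == stored_val)

print(equal_mixed, equal_cast)
```

This is the reason the checks in this notebook convert CSV columns with `dtype=np.float32` rather than comparing against the raw float64 values.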

In [None]:
def validate_individual_tfrecords(tfrecord_paths: Iterable[str],
                                  csv_path: str,
                                  label_name: Optional[str] = None) -> None:
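    '''Checks that the location, year, and (optionally) label values in
    the individual TFRecord files match the original CSV.

    Args
    - tfrecord_paths: list of str, paths to individual TFRecord files
        in the same order as in the CSV
    - csv_path: str, path to CSV file with columns ['lat', 'lon', 'year'],
        plus 'wealthpooled' if label_name is given
    - label_name: str, name of the label feature in the TFRecords, or None
        to skip the label check; labels are compared against the CSV's
        'wealthpooled' column
    '''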
    df = pd.read_csv(csv_path, float_precision='high', index_col=False)
    iter_init, batch_op = batcher.Batcher(
        tfrecord_files=tfrecord_paths,
        label_name=label_name,
        ls_bands=None,
        nl_band=None,
        batch_size=128,
        shuffle=False,
        augment=False,
        clipneg=False,
        normalize=None).get_batch()

    locs, years = [], []
    if label_name is not None:
        labels = []

    num_processed = 0
    with tf.Session() as sess:
        sess.run(iter_init)
        while True:
            try:
                if label_name is not None:
                    batch_np = sess.run((batch_op['locs'], batch_op['years'], batch_op['labels']))
                    labels.append(batch_np[2])
                else:
                    batch_np = sess.run((batch_op['locs'], batch_op['years']))
                locs.append(batch_np[0])
                years.append(batch_np[1])
                num_processed += len(batch_np[0])
                print(f'\rProcessed {num_processed} images', end='')
            except tf.errors.OutOfRangeError:
                break
    print()

    locs = np.concatenate(locs)
    years = np.concatenate(years)
    assert (locs == df[['lat', 'lon']].to_numpy(dtype=np.float32)).all()
    assert (years == df['year'].to_numpy(dtype=np.float32)).all()
    if label_name is not None:
        labels = np.concatenate(labels)
        assert (labels == df['wealthpooled'].to_numpy(dtype=np.float32)).all()

In [None]:
validate_individual_tfrecords(
    tfrecord_paths=tfrecord_paths_utils.dhs(),
    csv_path='data/dhs_clusters.csv',
    label_name='wealthpooled')

In [None]:
validate_individual_tfrecords(
    tfrecord_paths=tfrecord_paths_utils.dhsnl(),
    csv_path='data/dhsnl_locs.csv')

In [None]:
validate_individual_tfrecords(
    tfrecord_paths=tfrecord_paths_utils.lsms(),
    csv_path='data/lsms_clusters.csv')

## Calculate Mean and Std-Dev for Each Band

The means and standard deviations calculated here are saved in `batchers/dataset_constants.py` as the constants `_MEANS_DHS`, `_STD_DEVS_DHS`, `_MEANS_LSMS`, and `_STD_DEVS_LSMS`.
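
As a rough illustration of what a per-band statistic means here, the sketch below accumulates per-band sums over batches and derives the mean and standard deviation from them. It runs on random data and assumes batches shaped `(batch, height, width, bands)`. It is not the implementation in `preprocessing.helper`, whose internals are not shown in this notebook, and the function name `streaming_band_stats` is hypothetical.

```python
import numpy as np


def streaming_band_stats(batches):
    '''Per-band mean and population std-dev over an iterable of arrays
    shaped (batch, height, width, bands). Hypothetical helper.
    '''
    count = 0
    total = None
    total_sq = None
    for batch in batches:
        batch = np.asarray(batch, dtype=np.float64)  # accumulate in float64
        if total is None:
            nbands = batch.shape[-1]
            total = np.zeros(nbands)
            total_sq = np.zeros(nbands)
        # number of pixels contributed to each band by this batch
        count += batch.shape[0] * batch.shape[1] * batch.shape[2]
        total += batch.sum(axis=(0, 1, 2))
        total_sq += (batch ** 2).sum(axis=(0, 1, 2))
    means = total / count
    # Var[x] = E[x^2] - E[x]^2; simple, but can lose precision on large values
    stds = np.sqrt(total_sq / count - means ** 2)
    return means, stds


rng = np.random.default_rng(0)
data = rng.normal(size=(8, 4, 4, 3))  # synthetic stand-in for image batches
means, stds = streaming_band_stats([data[:5], data[5:]])
print(means.shape, stds.shape)
```

Accumulating sums batch by batch avoids holding every image in memory at once, which matters when a dataset contains many thousands of TFRecords.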

In [None]:
def calculate_mean_std(tfrecord_paths):
    '''Calculates and prints the per-band means and std-devs'''
    iter_init, batch_op = batcher.Batcher(
        tfrecord_files=tfrecord_paths,
        label_name=None,
        ls_bands='ms',
        nl_band='merge',
        batch_size=128,
        shuffle=False,
        augment=False,
        clipneg=False,
        normalize=None).get_batch()

    stats = analyze_tfrecord_batch(
        iter_init, batch_op, total_num_images=len(tfrecord_paths),
        nbands=len(BANDS_ORDER), k=10)
    means, stds = per_band_mean_std(stats=stats, band_order=BANDS_ORDER)

    print('Means:')
    pprint(means)
    print()

    print('Std Devs:')
    pprint(stds)

    print('\n========== Additional Per-band Statistics ==========\n')
    print_analysis_results(stats, BANDS_ORDER)

In [None]:
calculate_mean_std(tfrecord_paths_utils.dhs())

In [None]:
calculate_mean_std(tfrecord_paths_utils.dhsnl())

In [None]:
calculate_mean_std(tfrecord_paths_utils.lsms())