## Table of Contents

1. [Pre-requisites](#Pre-requisites)
2. [Instructions](#Instructions)
3. [Imports and Constants](#Imports-and-Constants)
4. [Validate and Split Exported TFRecords](#Validate-and-Split-Exported-TFRecords)
5. [Verify images](#Verify-images)
6. [Create final labels CSV file](#Create-final-labels-CSV-file)
7. [Tar and gzip the npz files](#Tar-and-gzip-the-npz-files)
8. [Calculate Mean and Std-Dev for Each Band](#Calculate-Mean-and-Std-Dev-for-Each-Band)

## Pre-requisites

Go through the [`2_export_tfrecords.ipynb`](./2_export_tfrecords.ipynb) notebook.

Before running this notebook, you should have the following directory structure:

```
dhs/dhs_tfrecords_raw/
    {DHSEA_ID}__to__{DHSEA_ID}.tfrecord.gz
    ...
```

## Instructions

This notebook processes the exported TFRecords as follows:
1. Verifies that the fields in the TFRecords match the original CSV files.
2. Splits each monolithic TFRecord file exported from Google Earth Engine into one numpy file per record.
3. Tar+Gzips the numpy files into sharded `.tar.gz` files of roughly 20 GiB each.
4. Calculates the mean and standard deviation of each band for the DHS images.

After running this notebook, you should have the following directory structure:

```
dhs/dhs_npzs/
    {survey_name}/
        {DHSID_EA}.npz
```

- Storage space needed for processed `.npz` files: ~104 GiB
- Storage space needed for the `.tar.gz` files: ~93.5 GiB
- Expected processing time: ~8h

It may be convenient to directly run this notebook on Google Colab, especially if the TFRecords were exported to Google Drive instead of Google Cloud Storage. When doing so, uncomment the cell below which starts with

```python
from google.colab import drive
...
```

## Imports and Constants

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# from google.colab import drive
# drive.mount('/drive', force_remount=True)
# %cd '/drive/MyDrive/sustainbench'

In [None]:
!pwd

In [None]:
from __future__ import annotations

from collections.abc import Iterable
from collections import namedtuple
from glob import glob
import os
import shutil

import numpy as np
import pandas as pd
import tensorflow as tf
from tqdm.auto import tqdm

print(tf.__version__)

In [None]:
# Bands that every exported record must contain. The order matters: it is the
# channel order of the image stack built in parse_record(), and npz_viz()
# indexes channels by position (0, 1, 2 for BLUE, GREEN, RED; -1 for NIGHTLIGHTS).
REQUIRED_BANDS = [
    'BLUE', 'GREEN', 'RED', 'SWIR1', 'SWIR2', 'TEMP1', 'NIR', 'NIGHTLIGHTS']

# input folder: raw TFRecords exported from Google Earth Engine
DHS_EXPORT_FOLDER = 'dhs_tfrecords_raw'
# output folder: one sub-folder per survey, one .npz file per record
DHS_PROCESSED_FOLDER = 'dhs_npzs'
DHS_INPUT_CSV_PATH = 'output_labels/merged.csv'  # CSV mapping DHSID_EA to ['lat', 'lon', labels]
# labels CSV restricted to clusters whose image was successfully processed
DHS_FINAL_CSV_PATH = 'output_labels/dhs_final_labels.csv'

## Validate and Split Exported TFRecords

In [None]:
def process_dataset(csv_path: str, index_col: str, input_dir: str, processed_dir: str,
                    log_path: str) -> None:
    '''Validates and splits the exported TFRecords of every survey in a CSV.

    A survey is identified by the first 10 characters of each index value.
    For every survey that has matching TFRecord files in `input_dir`, the
    records are validated and saved via validate_and_split_tfrecords(). The
    resulting per-record status log is appended to the running log, which is
    re-written to `log_path` after each survey.

    Args
    - csv_path: str, path to CSV of DHS or LSMS clusters
    - index_col: str, name of column in CSV to use as unique index
    - input_dir: str, path to TFRecords exported from Google Earth Engine
    - processed_dir: str, folder under which one output sub-folder per survey
        is created
    - log_path: str, path to log CSV; loaded first if it already exists
    '''
    df = pd.read_csv(csv_path, float_precision='high', index_col=index_col)

    # cast float64 => float32 and str => bytes, so that CSV values can be
    # compared directly against the values parsed from the TFRecords
    for col in df.columns:
        if df[col].dtype == np.float64:
            df[col] = df[col].astype(np.float32)
        elif df[col].dtype == object:  # pandas uses 'object' type for str
            df[col] = df[col].astype(bytes)

    df['survey'] = df.index.str[:10]  # TODO: check if this works with LSMS
    surveys = list(df.groupby('survey').groups.keys())

    if os.path.exists(log_path):
        log = pd.read_csv(log_path, index_col=index_col)
    else:
        log = pd.DataFrame(index=pd.Index([], name=index_col),
                           columns=['status'])

    # use this list to track any surveys that have already been processed
    # (this is useful for processing surveys in batches)
    # NOTE: pd.concat(..., verify_integrity=True) below raises ValueError if
    # an ID is already present in the log, so re-running over a survey that
    # is already logged requires listing it here first.
    PROCESSED_SURVEYS = []

    pbar = tqdm()
    try:
        for i, survey in enumerate(surveys):
            tqdm.write(f'Processing: {survey}, ({i+1} / {len(surveys)})')

            if survey in PROCESSED_SURVEYS:
                tqdm.write('- Already processed')
                continue

            tfrecord_paths = glob(os.path.join(input_dir, survey + '*'))
            if not tfrecord_paths:
                tqdm.write('- No TFRecords found')
                continue

            out_dir = os.path.join(processed_dir, survey)
            subset_df = df[df['survey'] == survey].sort_index()
            # index_col is a required argument of validate_and_split_tfrecords
            log_new = validate_and_split_tfrecords(
                tfrecord_paths=tfrecord_paths, index_col=index_col,
                out_dir=out_dir, df=subset_df, pbar=pbar)
            log = pd.concat([log, log_new], verify_integrity=True)
            # persist after every survey so that progress survives a crash
            log.to_csv(log_path)
    finally:
        # this function created the progress bar, so it must also close it
        pbar.close()


# fallback value for a band that is absent from a record: an all-NaN
# flattened 255x255 tile, so missing bands can be detected with np.isnan()
DEFAULT = np.full(255**2, np.nan)

def parse_record(ex, index_col: str):
    '''Parses one serialized example into a dict of features.

    Args
    - ex: serialized tf.train.Example
    - index_col: str, name of the string feature holding the unique ID

    Returns: dict with keys 'cluster_id', 'lat', 'lon', index_col, and 'img',
        where 'img' stacks the REQUIRED_BANDS (in that order) into a
        [len(REQUIRED_BANDS), 255, 255] tensor
    '''
    # every band is stored flattened; a missing band falls back to DEFAULT
    feature_spec = {}
    for band in REQUIRED_BANDS:
        feature_spec[band] = tf.io.FixedLenFeature(
            shape=[255**2], dtype=tf.float32, default_value=DEFAULT)

    # scalar metadata features
    for scalar_key in ('cluster_id', 'lat', 'lon'):
        feature_spec[scalar_key] = tf.io.FixedLenFeature([], tf.float32)
    feature_spec[index_col] = tf.io.FixedLenFeature([], tf.string)

    parsed = tf.io.parse_single_example(ex, feature_spec)

    # move each band out of the dict and into a single stacked image
    tiles = [
        tf.reshape(parsed.pop(band), [255, 255])
        for band in REQUIRED_BANDS
    ]
    parsed['img'] = tf.stack(tiles)
    return parsed


def validate_and_split_tfrecords(
        tfrecord_paths: Iterable[str],
        index_col: str,
        out_dir: str,
        df: pd.DataFrame,
        pbar: tqdm | None = None
        ) -> pd.DataFrame:
    '''Validates and splits a list of exported TFRecord files (for a
    given country-year survey) into individual .npz files, one per cluster.

    "Validating" a record consists of 2 parts
    1) verifying that it contains the required bands
    2) verifying that its other features match the values from the dataset CSV
       (mismatches are only reported, they do not prevent saving)

    Each valid record's image is saved with np.savez_compressed() under the
    key 'x', at <out_dir>/<unique ID> (numpy appends the '.npz' extension).

    Args
    - tfrecord_paths: list of str, paths to exported TFRecords files
    - index_col: str, name of column in CSV to use as unique index
    - out_dir: str, path to dir to save the individual .npz files; only
        created once there is a first valid record to save
    - df: pd.DataFrame, index is the unique ID (e.g., DHSID_EA)
    - pbar: tqdm, optional progress bar; if None, one is created and closed
        by this function

    Returns
    - log: pd.DataFrame, index is the unique ID (named index_col), with a
        single column 'status' that is one of 'processed', 'missing_bands',
        'missing_labels', or 'no_record'
    '''
    processed = []  # GEE exported, all good to go!
    missing_bands = []  # GEE exported, but missing some bands
    missing_labels = []  # GEE exported, but missing labels in CSV

    # create a progress bar if not given one already
    should_close_pbar = pbar is None
    if should_close_pbar:
        pbar = tqdm()
    pbar.reset(total=len(df))

    # flag for whether to create output directory
    should_make_out_dir = not os.path.exists(out_dir)

    ds = tf.data.TFRecordDataset(tfrecord_paths, compression_type='GZIP')
    ds = ds.map(lambda ex: parse_record(ex, index_col))
    for record in ds.as_numpy_iterator():
        uniq_id = record[index_col].decode()
        if uniq_id not in df.index:
            missing_labels.append(uniq_id)
        elif np.isnan(record['img']).any():
            # parse_record() fills absent bands with NaN
            missing_bands.append(uniq_id)
        else:
            # optional: compare feature map values against CSV values
            csv_feats = df.loc[uniq_id, :].to_dict()
            for col, val in csv_feats.items():
                if col in record and record[col] != val:
                    tqdm.write(f'- {uniq_id}: record[{col}] = {record[col]}, '
                               f'CSV val = {val}')

            if should_make_out_dir:
                os.makedirs(out_dir, exist_ok=True)
                should_make_out_dir = False

            save_path = os.path.join(out_dir, uniq_id)
            np.savez_compressed(save_path, x=record['img'])
            processed.append(uniq_id)
        pbar.update(1)

    if should_close_pbar:
        pbar.close()

    # anything expected from the CSV but never seen has no TFRecord
    seen = missing_bands + missing_labels + processed
    expected = df.index.to_numpy()
    no_record = np.setdiff1d(expected, seen)

    # verify_integrity=True raises ValueError if any ID appears more than once
    log = pd.concat([
        pd.DataFrame(index=pd.Index(arr, name=index_col),
                     data={'status': status})
        for status, arr in [
            ('processed', processed),
            ('missing_bands', missing_bands),
            ('missing_labels', missing_labels),
            ('no_record', no_record)
        ]
    ], verify_integrity=True)
    return log

def check_log(csv_path: str, index_col: str, log_path: str) -> None:
    '''Prints sanity checks comparing the export log against the labels CSV.

    Reports surveys present in only one of the two files, processed records
    that have no label, empty output directories, and a breakdown of record
    statuses. Asserts that, for every logged survey, all labeled IDs appear
    in the log.

    Must be run inside IPython / Jupyter: it relies on the `!find` shell
    magic and on `display()`. The output directory 'dhs_npzs' is hard-coded.

    Args
    - csv_path: str, path to labels CSV, columns include [index_col]
    - index_col: str, name of column in CSV to use as unique index
    - log_path: str, path to log CSV, columns are [index_col, 'status']
    '''
    df = pd.read_csv(csv_path, index_col=False)
    df.set_index(index_col, inplace=True, verify_integrity=True)
    df['survey'] = df.index.str[:10]  # TODO: check if this works with LSMS
    labeled_surveys = df['survey'].unique()

    log = pd.read_csv(log_path, index_col=False)
    log.set_index(index_col, inplace=True, verify_integrity=True)
    log['survey'] = log.index.str[:10]  # TODO: check if this works with LSMS
    logged_surveys = log['survey'].unique()
    print('logged surveys not in labels:',
          sorted(np.setdiff1d(logged_surveys, labeled_surveys)))
    print('labeled surveys not in log:',
          sorted(np.setdiff1d(labeled_surveys, logged_surveys)))

    # get list of processed npzs which aren't in the labels CSV
    all_labeled_clusters = df.index
    all_processed_clusters = log[log['status'] == 'processed'].index
    unlabeled_npzs = sorted(set(all_processed_clusters) - set(all_labeled_clusters))
    print('num npzs missing labels:', len(unlabeled_npzs))
    print('npzs missing labels:', unlabeled_npzs)

    # for each survey in the log, check that the uniq_id's from the labels CSV
    # are all in the log
    for survey in logged_surveys:
        label_ids = df.loc[df['survey'] == survey].index
        log_ids = log.loc[log['survey'] == survey].index
        assert label_ids.isin(log_ids).all()

    # use some jupyter magic to get a list of empty directories
    # only surveys where no images were properly processed should be empty
    # TODO: update for LSMS
    empty_dirs = !find dhs_npzs -type d -empty
    for empty_dir in empty_dirs:
        # paths look like 'dhs_npzs/<survey>', so element 1 is the survey
        survey = empty_dir.split('/')[1]
        if survey in logged_surveys:
            assert ((log['survey'] == survey) &
                    (log['status'] == 'processed')).sum() == 0
            print(f'survey {survey} has nothing processed')
        elif survey not in labeled_surveys:
            print(f'survey {survey} should not exist')

    # summary: number of records per status, over the whole log

    print('=== breakdown by status ===')
    display(log.groupby('status').size())

    # per-survey status counts, restricted to surveys that have at least one
    # record which was not processed
    incomplete_surveys = log.loc[log['status'] != 'processed', 'survey'].unique()
    not_processed_sizes = (
        log.loc[log['survey'].isin(incomplete_surveys)]
        .groupby(['survey', 'status'])
        .size()
    )
    display(not_processed_sizes.unstack().astype(pd.Int64Dtype()))

    # surveys in which not a single record was processed
    empty_surveys = log.groupby('survey').filter(lambda df: (df['status'] != 'processed').all())
    print('surveys without any processed:', empty_surveys['survey'].unique())

In [None]:
# process_dataset() is annotated to return None and has no return statement,
# so `ds` is always None; its outputs are the .npz files and the log CSV
ds = process_dataset(
    csv_path=DHS_INPUT_CSV_PATH,
    index_col='DHSID_EA',
    input_dir=DHS_EXPORT_FOLDER,
    processed_dir=DHS_PROCESSED_FOLDER,
    log_path='dhs_tfrecords_export_log.csv')

In [None]:
check_log(csv_path=DHS_INPUT_CSV_PATH,
          index_col='DHSID_EA',
          log_path='dhs_tfrecords_export_log.csv')

## Verify images

Randomly sample 20 `.npz` files and plot them vs. expected images from Google Earth Engine.

In [None]:
import ee  # earthengine
from IPython.display import Image
import matplotlib.pyplot as plt
import PIL

try:
    # if already authenticated, can directly initialize the Earth Engine API
    ee.Initialize()
except Exception:
    # otherwise, authenticate first, then initialize
    # (catching Exception rather than using a bare `except:` lets
    # KeyboardInterrupt / SystemExit propagate instead of starting auth)
    ee.Authenticate()
    ee.Initialize()

In [None]:
def ee_viz(lat: float, lon: float, year: int) -> None:
    """Displays 255x255px Landsat 5/7/8 surface reflectance image tiles
    (3-year median composite tile) centered on the given lat/lon coordinates.

    One tile is displayed for each Landsat satellite whose years of operation
    (as listed in `params` below) overlap the window [year - 1, year + 1].

    This image will not be as "clean" as the GEE-exported composites because
    here we do not do any fancy cloud masking / QA control.

    Args
    - lat: float, latitude of the tile center
    - lon: float, longitude of the tile center
    - year: int, center year of the 3-year window; integral floats such as
        2004.0 are accepted and converted to int
    """
    # `year` can arrive as a float (e.g., when read out of a pandas row that
    # also holds float columns); without this conversion the date strings
    # below would look like '2003.0-01-01'
    year = int(year)

    # get 255x255px box around (lat, lon)
    res = 30  # meters per pixel
    radius = 255 / 2.0  # radius of image in pixels
    pt = ee.Geometry.Point([lon, lat])
    roi = pt.buffer(radius * res).bounds()

    # NOTE(review): these are Landsat Collection 1 ('C01') asset IDs; confirm
    # they are still served by Earth Engine before relying on this function
    SatParam = namedtuple(
        'SatParam', ['col_name', 'min_year', 'max_year', 'rgb_bands', 'scale'])
    params = {
        'Landsat 5': SatParam(
            col_name='LANDSAT/LT05/C01/T1_SR', min_year=1984, max_year=2012,
            rgb_bands=['B3', 'B2', 'B1'], scale=0.0001),
        'Landsat 7': SatParam(
            col_name='LANDSAT/LE07/C01/T1_SR', min_year=1999, max_year=2021,
            rgb_bands=['B3', 'B2', 'B1'], scale=0.0001),
        'Landsat 8': SatParam(
            col_name='LANDSAT/LC08/C01/T1_SR', min_year=2013, max_year=2021,
            rgb_bands=['B4', 'B3', 'B2'], scale=0.0001)
    }

    # these values empirically seem to work well for L7 and L8 images
    vis_params = {
        'min': 0,     # becomes 0 in RGB
        'max': 0.35,  # becomes 255 in RGB
        # 'gamma': 2.5  # set between [1, 2.5] to match your own aesthetic
    }

    for name, sat in params.items():
        # skip satellites with no overlap with [year - 1, year + 1]
        if (year + 1 < sat.min_year) or (year - 1 > sat.max_year):
            continue

        # get Landsat surface reflectance image
        start_year = max(sat.min_year, year - 1)
        end_year = min(sat.max_year, year + 1)
        start = f'{start_year}-01-01'
        end = f'{end_year}-12-31'
        img = (
            ee.ImageCollection(sat.col_name)
            .filterDate(start, end)
            .select(sat.rgb_bands)
            .median()
            .multiply(sat.scale)
        )

        # Create a URL to the image, and display it
        url = img.getThumbUrl(
            {**vis_params, 'dimensions': 255, 'region': roi})
        print(name)
        display(Image(url=url))


def npz_viz(npz_path: str) -> None:
    '''Displays the RGB composite and the nightlights band of a satellite
    image stored under the key 'x' of a .npz file.

    The image is indexed as (channel, row, column): channels 0, 1, 2 are
    used as blue, green, red, and the last channel as nightlights.

    Note: GEE images are exported with (0,0) = lower-left corner. By default,
    matplotlib's plt.imshow() and PIL.Image assume (0,0) = upper-left corner,
    so the rows are reversed before display.

    Args
    - npz_path: str, path to the .npz file
    '''
    with open(npz_path, 'rb') as npz_file:
        bands = np.load(npz_file)['x']
    nightlights = bands[-1]
    rgb = np.stack((bands[2], bands[1], bands[0]), axis=2)

    # rescale RGB to (0, 1) with a hybrid cutoff: at most 0.35, lower for
    # dark images (alternatives: rgb.max(), or a hard cutoff of 0.35)
    rgb_cutoff = min(0.35, (rgb.max() + 0.35) / 2)
    rgb = np.clip(rgb / rgb_cutoff, a_min=0, a_max=1)

    # same idea for nightlights, capped at 100
    nl_cutoff = min(100, (100 + nightlights.max()) / 4)
    nightlights = np.clip(nightlights / nl_cutoff, a_min=0, a_max=1)

    # alternative: plt.imshow(rgb, origin='lower'); plt.show()
    # here: convert to 8-bit, flip vertically, and display via PIL
    for scaled, pil_mode in ((rgb, 'RGB'), (nightlights, 'L')):
        pixels = np.uint8(scaled * 255)
        display(PIL.Image.fromarray(pixels[::-1], mode=pil_mode))

In [None]:
# fixed seed so that the same files are sampled on every run
rng = np.random.default_rng(seed=123)
num_samples = 20

surveys = os.listdir(DHS_PROCESSED_FOLDER)
dhs_csv = pd.read_csv(DHS_INPUT_CSV_PATH, index_col='DHSID_EA')

for i in range(num_samples):
    # sample a survey
    survey = rng.choice(surveys)

    # sample a npz file
    # NOTE(review): rng.choice() raises on an empty list, so this presumably
    # fails if the sampled survey folder is empty — confirm
    npz_filenames = os.listdir(os.path.join(DHS_PROCESSED_FOLDER, survey))
    npz_filename = rng.choice(npz_filenames)
    npz_path = os.path.join(DHS_PROCESSED_FOLDER, survey, npz_filename)
    print(f'===== {npz_path} =====')

    # the file name without its extension is the DHSID_EA
    dhsid_ea = os.path.splitext(os.path.basename(npz_path))[0]
    # NOTE(review): selecting float and int columns from a single row
    # presumably upcasts `year` to float — confirm
    lat, lon, year = dhs_csv.loc[dhsid_ea, ['lat', 'lon', 'year']]
    print(lat, lon, year)

    # display the RGB bands from the NPZ
    npz_viz(npz_path)

    # compare against the expected GEE image
    ee_viz(lat, lon, year)

    # blank line between samples, but not after the last one
    if i != num_samples - 1:
        print()

In [None]:
# this is the example cluster we use in the figures
dhs_csv.loc['PE-2004-5#-00000969']

## Create final labels CSV file

Some clusters in the input CSV file might not have a downloaded image. This section removes the clusters without images and outputs a final labels CSV file.

In [None]:
def create_final_labels(labels_path: str,
                        export_log_path: str,
                        index_col: str,
                        final_csv_path: str) -> None:
    '''Writes the final labels CSV, keeping only clusters whose image was
    successfully processed.

    If every labeled cluster was processed, the labels CSV is copied as-is.
    Otherwise, the labels are restricted to the clusters whose log status is
    'processed'.

    Args
    - labels_path: str, path to CSV of DHS or LSMS cluster labels
    - export_log_path: str, path to CSV log of processing TFRecords
    - index_col: str, name of column in CSV to use as unique index
    - final_csv_path: str, path to save final labels CSV
    '''
    # verify_integrity=True rejects duplicated IDs in either file
    export_log = pd.read_csv(export_log_path).set_index(
        index_col, verify_integrity=True)
    labels = pd.read_csv(labels_path).set_index(
        index_col, verify_integrity=True)

    # every labeled cluster must have an entry in the export log
    assert labels.index.isin(export_log.index).all()

    is_processed = export_log['status'] == 'processed'
    failed_ids = export_log.index[~is_processed]
    num_failed_labels = labels.index.isin(failed_ids).sum()

    if num_failed_labels == 0:
        print('Images were exported for all labels!')
        shutil.copy2(labels_path, final_csv_path)
        return

    print(f'Failed to download images for {num_failed_labels} clusters.')
    print('Removing those clusters to create final labels CSV.')

    # inner join on the index keeps only the processed clusters with labels
    kept = export_log[is_processed].merge(
        labels, how='inner', left_index=True, right_index=True)
    kept.drop(columns='status').to_csv(final_csv_path)

In [None]:
create_final_labels(labels_path=DHS_INPUT_CSV_PATH,
                    export_log_path='dhs_tfrecords_export_log.csv',
                    index_col='DHSID_EA',
                    final_csv_path=DHS_FINAL_CSV_PATH)

## Tar and gzip the npz files

The `.tar.gz` files are sharded to take up ~20GiB (20 * 2^30 bytes). The files are sharded such that all `.npz` files from a country are in the same shard.

In [None]:
from io import StringIO
from pprint import pprint

# use `du` to estimate size of npzs from each survey
# - this overcounts, because there are some processed npzs that might not
#   get included in the final version
# - `sed` replaces tabs with commas so that the output can be parsed as CSV
folder_sizes = !du -h -m -d 1 dhs_npzs/* | sed 's/\t/,/g'
sizes = pd.read_csv(StringIO('\n'.join(folder_sizes)), names=['MiB', 'folder'])
# 'dhs_npzs/' is 9 characters long, so [9:11] takes the first 2 characters of
# the survey folder name (presumably the 2-letter country code)
sizes['country'] = sizes['folder'].str[9:11]
sizes_by_country = sizes.groupby('country')['MiB'].sum().sort_index().to_frame()
display(sizes_by_country.head())

In [None]:
# determine country grouping: walk the countries in sorted order and greedily
# fill one shard at a time, so every country ends up in exactly one shard
max_size = 20 * 2**10   # in MiB
cum_size = 0
groups = []  # list of ([list of country codes], size in MiB)
group = []
for country in sorted(sizes_by_country.index):
    size = sizes_by_country.loc[country, 'MiB']
    # close the current shard if adding this country would exceed max_size;
    # the `cum_size > 0` guard means a country larger than max_size still
    # gets a shard of its own instead of producing an empty shard
    if cum_size > 0 and cum_size + size > max_size:
        groups.append((group, cum_size))
        cum_size = 0
        group = []

    cum_size += size
    group.append(country)

# the last shard is never closed inside the loop
groups.append((group, cum_size))
print(groups)

In [None]:
# for each group, create a list of files to tar-gzip
# (only files listed in the final labels CSV are included)
dhs_final = pd.read_csv(DHS_FINAL_CSV_PATH)
dhs_final['survey'] = dhs_final['DHSID_EA'].str[:10]
for i, group in enumerate(groups):
    cnames = group[0]  # country codes in this shard
    # assumes the 'cname' column holds the same codes as the folder-name
    # prefixes used to build `groups` — TODO confirm
    files_list = dhs_final.loc[dhs_final['cname'].isin(cnames), ['survey', 'DHSID_EA']]
    # paths are relative to the npz folder: <survey>/<DHSID_EA>.npz
    files_list = files_list['survey'] + '/' + files_list['DHSID_EA'] + '.npz'
    files_list.sort_values(inplace=True)
    display(files_list)
    # one path per line; the list file is named after the first and last country
    files_list.to_csv(f'dhs_tar_list_{cnames[0]}_{cnames[-1]}.txt', index=False, header=False)

In [None]:
# !tar -czvf dhs_AL_DR.tar.gz -C dhs_npzs -T dhs_tar_list_AL_DR.txt
# !tar -czvf dhs_EG_HT.tar.gz -C dhs_npzs -T dhs_tar_list_EG_HT.txt
# !tar -czvf dhs_IA_IA.tar.gz -C dhs_npzs -T dhs_tar_list_IA_IA.txt
# !tar -czvf dhs_ID_MZ.tar.gz -C dhs_npzs -T dhs_tar_list_ID_MZ.txt
# !tar -czvf dhs_NG_SZ.tar.gz -C dhs_npzs -T dhs_tar_list_NG_SZ.txt
# !tar -czvf dhs_TD_ZW.tar.gz -C dhs_npzs -T dhs_tar_list_TD_ZW.txt

In [None]:
# to extract the tar.gz files
# !tar -xzvf dhs_AL_DR.tar.gz -C <output_dir>

## Calculate Mean and Std-Dev for Each Band

The means and standard deviations calculated here are saved as constants in `sustainbench/datasets/poverty_dataset.py` for `_MEANS_DHS`, `_STD_DEVS_DHS`, `_MEANS_LSMS`, and `_STD_DEVS_LSMS`.

In [None]:
from concurrent.futures import ThreadPoolExecutor

def calculate_band_means(path_and_year) -> tuple[np.ndarray, np.ndarray, int]:
    '''Computes the per-band sum and sum of squares of one saved image.

    Despite its name, this function returns sums, not means; the caller
    divides by the number of pixels and images.

    Args
    - path_and_year: tuple (path, year)
      - path: str, path to npz file containing single entry 'x'
        representing a (C, H, W) image
      - year: int, passed through unchanged

    Returns: (sums, sum_sqs, year)
    - sums: np.ndarray, shape [C], sum of values for each band
    - sum_sqs: np.ndarray, shape [C], sum of squares of values for each band
    - year: int
    '''
    npz_path, year = path_and_year
    # np.load() on a .npz keeps the file open until the NpzFile is closed;
    # close it right away since this runs over many files in a thread pool
    with np.load(npz_path) as npz:
        img = npz['x']
    sums = np.sum(img, axis=(1, 2), dtype=np.float64)
    # square in float64 as well, so that the squares are not rounded to the
    # image's own (lower) precision before being accumulated
    sum_sqs = np.sum(np.square(img, dtype=np.float64), axis=(1, 2))
    return sums, sum_sqs, year

In [None]:
dhs_final = pd.read_csv(DHS_FINAL_CSV_PATH, index_col='DHSID_EA')
# path of each image: <processed folder>/<first 10 chars of ID>/<ID>.npz
dhs_final['path'] = (
    DHS_PROCESSED_FOLDER + '/' +
    dhs_final.index.str[:10] + '/' +
    dhs_final.index + '.npz'
)
# one (path, year) tuple per image, the input format of calculate_band_means()
path_years = dhs_final[['path', 'year']].apply(tuple, axis=1)

# per-image sums and sums of squares, split by year at 2012
# (the variable names indicate DMSP before 2012 and VIIRS from 2012 on)
sums_dmsp = []
sum_sqs_dmsp = []
sums_viirs = []
sum_sqs_viirs = []

with ThreadPoolExecutor(max_workers=30) as pool:
    inputs = path_years
    # despite the name, pool.map() returns an iterator over the results
    # (in input order), not over Future objects
    futures = pool.map(calculate_band_means, inputs)
    for sums, sum_sqs, year in tqdm(futures, total=len(inputs)):
        if year < 2012:
            sums_dmsp.append(sums)
            sum_sqs_dmsp.append(sum_sqs)
        else:
            sums_viirs.append(sums)
            sum_sqs_viirs.append(sum_sqs)

In [None]:
# stack the per-image vectors into 2-D arrays, one row per image
# (the combined arrays must be built first, while the inputs are still lists
# and `+` means list concatenation)
sums_all = np.stack(sums_dmsp + sums_viirs)
sum_sqs_all = np.stack(sum_sqs_dmsp + sum_sqs_viirs)

sums_dmsp = np.stack(sums_dmsp)
sum_sqs_dmsp = np.stack(sum_sqs_dmsp)
sums_viirs = np.stack(sums_viirs)
sum_sqs_viirs = np.stack(sum_sqs_viirs)

In [None]:
# calculate means
# mean pixel value = (mean over images of the per-image sum) / pixels per image
num_pixels_per_img = 255 * 255
band_means = sums_all.mean(axis=0) / num_pixels_per_img
# nightlights is the last band; its mean is also computed separately for
# each of the two year groups
dmsp_mean = sums_dmsp[:, -1].mean() / num_pixels_per_img
viirs_mean = sums_viirs[:, -1].mean() / num_pixels_per_img

MEANS = {
    band: np.float32(band_means[i])
    for i, band in enumerate(REQUIRED_BANDS)
}
# note: unlike the per-band entries, these two are not cast to float32
MEANS['DMSP'] = dmsp_mean
MEANS['VIIRS'] = viirs_mean
display(MEANS)

In [None]:
# calculate standard deviations
# population std-dev
# = sqrt( E[X**2] - E[X]**2 )
# = sqrt( sum_sqs / N - mean**2 )
# relies on num_pixels_per_img, band_means, dmsp_mean and viirs_mean from the
# means cell
band_sd = np.sqrt( sum_sqs_all.mean(axis=0) / num_pixels_per_img - band_means**2 )
dmsp_sd = np.sqrt( sum_sqs_dmsp[:, -1].mean() / num_pixels_per_img - dmsp_mean**2 )
viirs_sd = np.sqrt( sum_sqs_viirs[:, -1].mean() / num_pixels_per_img - viirs_mean**2 )

STD_DEVS = {
    band: np.float32(band_sd[i])
    for i, band in enumerate(REQUIRED_BANDS)
}
# note: unlike the per-band entries, these two are not cast to float32
STD_DEVS['DMSP'] = dmsp_sd
STD_DEVS['VIIRS'] = viirs_sd
display(STD_DEVS)