In [17]:
import numpy as np
import tensorflow as tf
import os

In [18]:
# Directory that is listed to discover the TFRecord files processed below
# (a relative path, so it resolves against the kernel's working directory).
TFRECORDS_DIR = '../data/lsms_tfrecords/'

In [19]:
# Feature keys looked up in every parsed example; the statistics functions
# read each one as a float list and reshape it to 255 x 255.
BANDS = ['BLUE', 'GREEN', 'RED', 'NIR', 'SW_IR1', 'SW_IR2', 'TEMP', 'avg_rad']

In [20]:
def create_single_feature_set(filename):
    """Open one TFRecord file and return the feature map parsed from it.

    Args:
        filename: Path of the TFRecord file to open.

    Returns:
        Whatever ``parse_features`` returns for a dataset built from that
        single file.
    """
    dataset = tf.data.TFRecordDataset(filenames=[filename])
    return parse_features(record=dataset)

def parse_features(record):
    """Decode the first serialized example in ``record`` and return its features.

    Only the first element of ``record`` is read; any further elements are
    ignored.

    Args:
        record: Iterable of serialized examples whose elements expose
            ``.numpy()`` returning the raw bytes (e.g. the dataset built in
            ``create_single_feature_set``).

    Returns:
        The ``features.feature`` map of the decoded ``tf.train.Example``.

    Raises:
        ValueError: If ``record`` yields no elements.
    """
    # Use a sentinel rather than letting StopIteration escape: a bare
    # StopIteration can silently terminate an enclosing for-loop or map().
    raw_example = next(iter(record), None)
    if raw_example is None:
        raise ValueError('record contains no examples to parse')

    example = tf.train.Example.FromString(raw_example.numpy())

    return example.features.feature

In [64]:
# Sorted paths of every entry in TFRECORDS_DIR, skipping dot-prefixed names.
# The paths are joined onto TFRECORDS_DIR itself so the listing and the
# resulting paths cannot drift apart if the directory constant is changed.
tfrecords = sorted(
    os.path.join(TFRECORDS_DIR, name)
    for name in os.listdir(TFRECORDS_DIR)
    if not name.startswith('.')
)
num_records = len(tfrecords)

def calculate_band_means(filenames=None):
    """Compute the mean value of each band in ``BANDS`` across TFRecord files.

    For every file the per-image mean of each band is taken, and those
    per-image means are then averaged over the files. Because every band is
    reshaped to 255 x 255, all images have the same number of values, so this
    equals the mean over all values of that band.

    Args:
        filenames: Optional iterable of TFRecord paths. Defaults to the
            module-level ``tfrecords`` list.

    Returns:
        dict mapping each band name in ``BANDS`` to its mean (float64).

    Raises:
        ValueError: If there are no files to average over.
    """
    if filenames is None:
        filenames = tfrecords
    filenames = list(filenames)
    if not filenames:
        raise ValueError('No TFRecord files to compute band means from.')

    # Keyed from BANDS so the accumulator can never disagree with the band list.
    band_sums = {band_name: 0.0 for band_name in BANDS}

    for tf_r in filenames:
        feature_set = create_single_feature_set(filename=tf_r)

        for band_name in BANDS:
            # The reshape also acts as a size check: it fails unless the band
            # holds exactly 255 * 255 values.
            band = np.array(feature_set[band_name].float_list.value, dtype=np.float32).reshape(255, 255)
            # Reduce in float64 so the result does not depend on NumPy's
            # scalar-promotion rules or on float32 rounding.
            band_sums[band_name] += band.mean(dtype=np.float64)

    # Divide by the number of files actually processed rather than a separate
    # global counter that could be stale.
    return {band_name: total / len(filenames) for band_name, total in band_sums.items()}

def calculate_band_stdevs(band_means, filenames=None):
    """Compute the population standard deviation of each band in ``BANDS``.

    For every file the mean squared deviation of each band from
    ``band_means[band_name]`` is taken; these are averaged over the files and
    the square root is returned. The divisor is the full value count (N, not
    N - 1).

    Args:
        band_means: Mapping from band name to the mean to centre on, e.g. the
            result of ``calculate_band_means``.
        filenames: Optional iterable of TFRecord paths. Defaults to the
            module-level ``tfrecords`` list.

    Returns:
        dict mapping each band name in ``BANDS`` to its standard deviation
        (float64).

    Raises:
        ValueError: If there are no files to process.
    """
    if filenames is None:
        filenames = tfrecords
    filenames = list(filenames)
    if not filenames:
        raise ValueError('No TFRecord files to compute band stdevs from.')

    # Keyed from BANDS so the accumulator can never disagree with the band list.
    mean_sq_devs = {band_name: 0.0 for band_name in BANDS}

    for tf_r in filenames:
        feature_set = create_single_feature_set(filename=tf_r)

        for band_name in BANDS:
            # The reshape also acts as a size check: it fails unless the band
            # holds exactly 255 * 255 values.
            band = np.array(feature_set[band_name].float_list.value, dtype=np.float32).reshape(255, 255)
            # Widen before subtracting so the squared deviations are computed
            # in float64 regardless of NumPy's scalar-promotion rules.
            deviations = band.astype(np.float64) - band_means[band_name]
            mean_sq_devs[band_name] += np.mean(deviations ** 2)

    # Divide by the number of files actually processed rather than a separate
    # global counter that could be stale.
    return {band_name: np.sqrt(total / len(filenames)) for band_name, total in mean_sq_devs.items()}
    

In [59]:
# Per-band means over the files listed in `tfrecords`.
# NOTE(review): this cell's execution count (59) is lower than that of the cell
# defining calculate_band_means (64) — re-run top to bottom to confirm the
# stored result matches the current definition.
band_means = calculate_band_means()

In [65]:
# Per-band standard deviations, centred on the means held in `band_means`.
band_stdevs = calculate_band_stdevs(band_means)

In [66]:
# Show both sets of per-band statistics; the second line is the standard
# deviations and is labelled as such (it was previously labelled "Band Means").
print(f'Band Means: {band_means}\n')
print(f'Band Stdevs: {band_stdevs}\n')

Band Means: {'BLUE': 0.05720699718743952, 'GREEN': 0.09490949383988444, 'RED': 0.11647556706520566, 'NIR': 0.25043694995276194, 'SW_IR1': 0.2392968657712096, 'SW_IR2': 0.17881930908670116, 'TEMP': 309.4823962960872, 'avg_rad': 1.8277193893627437}

Band Means: {'BLUE': 0.02379879403788589, 'GREEN': 0.03264212296594092, 'RED': 0.050468921297598834, 'NIR': 0.04951648377311826, 'SW_IR1': 0.07332469136800321, 'SW_IR2': 0.07090649886221509, 'TEMP': 6.000001012494749, 'avg_rad': 4.328436715534132}

