In [1]:
import jax.numpy as jnp
from functools import partial
import jax
from jax import jit, vmap

Take a sample of the training set (the first 10% of the `eurosat/all` train split).

In [19]:
import tensorflow_datasets as tfds

# Load the first 10% of the 'eurosat/all' training split as (image, label) pairs.
dataset = tfds.load('eurosat/all', split='train[:10%]', as_supervised=True)

# Pull the sample out as one batch (100_000 is presumably just an upper bound
# larger than the sample). next(iter(...)) raises StopIteration on an empty
# dataset instead of silently leaving tfe_imgs / tfe_labels unset, or stale from
# an earlier run of this cell, as the previous for/break idiom did.
tfe_imgs, tfe_labels = next(iter(dataset.batch(100_000)))

# Convert to a JAX array and release the reference to the TensorFlow tensor.
imgs = jnp.array(tfe_imgs)
del tfe_imgs
imgs.shape

(2700, 64, 64, 13)

Construct a flattened version of the array; we only care about per-channel statistics.

In [5]:
# Collapse every axis except the last (channel) axis, giving one row per pixel.
# The channel count is taken from the array itself rather than hard-coded, so a
# different band count cannot be silently mis-flattened.
flat_imgs = imgs.reshape(-1, imgs.shape[-1])

What are the max and the 99.9th percentile per channel?

In [6]:
# Largest raw value seen in each channel (reduction over the pixel axis).
flat_imgs.max(axis=0)

DeviceArray([10388., 27661., 28000., 28000., 20748., 23288., 25600.,
             28002.,  5209.,   183., 16716., 16337., 27534.],            dtype=float32)

In [7]:
# Per-channel 99.9th percentile; later cells use it as the clipping ceiling.
channel_p999s = jnp.percentile(flat_imgs, q=99.9, axis=0)
channel_p999s

DeviceArray([2612., 3492., 3582., 3970., 3709., 4407., 5678., 5604.,
             2264.,   52., 4848., 3876., 6124.], dtype=float32)

The ~28000 maxima are far above the 99.9th percentiles, so they look like sensor noise (unconfirmed).
Let's clip each channel at its 99.9th percentile.

In [8]:
@partial(vmap, in_axes=(1, 0), out_axes=1)
def clip_per_channel(x, a_max):
    """Clip one channel's values to the range [0, a_max].

    The vmap decorator maps this function over axis 1 of ``x`` and axis 0 of
    ``a_max`` and stacks the results back along axis 1. The decorated function
    is therefore called with an array whose axis 1 indexes channels (e.g. the
    (n, channels) arrays used in this notebook) and one ceiling per channel.

    Args:
        x: values of a single channel (after vmap slicing).
        a_max: upper clipping bound for that channel.

    Returns:
        ``x`` with values below 0 raised to 0 and values above ``a_max``
        lowered to ``a_max``.
    """
    # Bounds are passed positionally: newer JAX releases deprecate the
    # a_min= / a_max= keywords of jnp.clip, while the positional form works on
    # both old and new versions.
    return jnp.clip(x, 0, a_max)

In [9]:
# Clip every channel of the flattened sample to [0, its 99.9th percentile].
# Arguments stay positional: vmap's in_axes only apply to positional arguments.
clipped_flat_imgs = clip_per_channel(flat_imgs, channel_p999s)

In [10]:
# Sanity check: each channel's max after clipping (compare with channel_p999s).
clipped_flat_imgs.max(axis=0)

DeviceArray([2612., 3492., 3582., 3970., 3709., 4407., 5678., 5604.,
             2264.,   52., 4848., 3876., 6124.], dtype=float32)

Now calculate the per-channel mean and standard deviation for standardisation.

In [11]:
@jit
def channel_means_stds(x):
    """Return ``(means, stds)`` of ``x``, both reduced over axis 0.

    For the (n_pixels, n_channels) arrays used in this notebook this yields one
    mean and one standard deviation per channel.
    """
    return x.mean(axis=0), x.std(axis=0)

# Per-channel mean/std of the clipped sample; preprocess reads these
# notebook-level names when standardising.
channel_means, channel_stds = channel_means_stds(clipped_flat_imgs)
channel_means, channel_stds

(DeviceArray([1354.7904  , 1117.1971  , 1041.1869  ,  946.0516  ,
              1198.119   , 2003.2725  , 2375.2615  , 2302.0972  ,
               730.6099  ,   12.077799, 1821.995   , 1119.2013  ,
              2602.028   ], dtype=float32),
 DeviceArray([ 242.14961 ,  324.46646 ,  386.99976 ,  587.74664 ,
               565.23846 ,  859.7307  , 1086.1215  , 1116.8077  ,
               404.7259  ,    4.397278,  998.8627  ,  756.0413  ,
              1231.3727  ], dtype=float32))

We can now roll the entire thing into a single preprocess step: clip -> standardise.

In [15]:
@partial(vmap, in_axes=(1, 0, 0), out_axes=1)
def standardise_per_channel(x, mean, std):
    """Standardise one channel: subtract its mean, then divide by its std.

    vmap maps this over axis 1 of ``x`` and axis 0 of ``mean`` and ``std``,
    stacking the per-channel results back along axis 1.
    """
    centred = x - mean
    return centred / std

@jit
def preprocess(x, p999s=None, means=None, stds=None):
    """Clip each channel at its ceiling, then standardise it.

    Args:
        x: array whose last axis is the channel axis, e.g. (n, h, w, channels).
        p999s: optional per-channel clip ceilings; defaults to the notebook-level
            ``channel_p999s``.
        means: optional per-channel means; defaults to ``channel_means``.
        stds: optional per-channel standard deviations; defaults to
            ``channel_stds``.

    Returns:
        An array with the same shape as ``x``.

    Note:
        Defaults taken from notebook globals are captured when jit traces the
        function. If the statistics are recomputed, re-run this cell or pass
        them explicitly so a stale compiled version is not reused.
    """
    # None checks are resolved at trace time, so they are safe under jit.
    if p999s is None:
        p999s = channel_p999s
    if means is None:
        means = channel_means
    if stds is None:
        stds = channel_stds

    orig_shape = x.shape
    # Flatten to (pixels, channels); the channel count comes from the input
    # itself rather than a hard-coded 13.
    flat = x.reshape(-1, orig_shape[-1])
    flat = clip_per_channel(flat, p999s)
    flat = standardise_per_channel(flat, means, stds)
    return flat.reshape(orig_shape)

In [16]:
# Apply clip -> standardise to the image array; preprocess reshapes its result
# back to the input's shape.
# NOTE(review): the recorded output for this cell (5400 images) disagrees with
# the 2700 images recorded for the load cell, and the execution counts are out
# of order. The saved outputs appear not to come from one top-to-bottom run;
# restart the kernel and run all cells to confirm.
std_clipped_imgs = preprocess(imgs)
std_clipped_imgs.shape

(5400, 64, 64, 13)

In [17]:
# Sanity check: after preprocessing, per-channel means should be ~0 and stds ~1.
channel_means_stds(std_clipped_imgs.reshape(-1, 13))

(DeviceArray([-5.15690566e-08,  1.16206984e-07,  5.12158449e-09,
               1.40137146e-07,  1.01107140e-07,  3.64692134e-08,
               5.69776262e-08, -1.21019511e-07,  1.09849154e-07,
               6.64039916e-08,  1.87202733e-08,  1.57091350e-07,
               8.04883484e-08], dtype=float32),
 DeviceArray([1.        , 1.0000001 , 1.        , 0.99999994, 1.        ,
              1.        , 1.        , 0.99999994, 1.        , 1.0000001 ,
              0.99999994, 0.9999999 , 1.        ], dtype=float32))