In [0]:
import tensorflow as tf
import numpy as np
import time

tf.enable_eager_execution()

In [18]:
# For sequences

def online_statistics_sequences(data):
  n_samples = 0
  n_all, mean_all, var_all, m2_all = 0, 0, 0, 0
  n_channel, mean_channel, var_channel, m2_channel = 0, 0, 0, 0
  min_all, max_all = np.inf, -np.inf
  min_seq_len, max_seq_len = np.inf, -np.inf

  for x in data: 
    n_samples += 1
    seq_len, feature_size = x.shape

    # Global mean&var
    n_all += seq_len*feature_size
    delta_all = x - mean_all
    mean_all = mean_all + delta_all.sum()/n_all
    m2_all = m2_all + (delta_all*(x - mean_all)).sum()

    # Channelwise mean&var
    n_channel += seq_len
    delta_channel = x - mean_channel
    mean_channel = mean_channel + delta_channel.sum(axis=0)/n_channel
    m2_channel = m2_channel + (delta_channel*(x - mean_channel)).sum(axis=0)

    # Global min&max values.
    min_all = np.min(x) if np.min(x) < min_all else min_all
    max_all = np.max(x) if np.max(x) > max_all else max_all
    
    # Min&max sequence length.
    min_seq_len = seq_len if seq_len < min_seq_len else min_seq_len
    max_seq_len = seq_len if seq_len > max_seq_len else max_seq_len

  var_all = m2_all/(n_all-1)
  var_channel = m2_channel/(n_channel-1)
  stats = dict(mean_all=mean_all, var_all=var_all, mean_channel=mean_channel, var_channel=var_channel, min_all=min_all, max_all=max_all, min_seq_len=min_seq_len, max_seq_len=max_seq_len, n_samples=n_samples)
  return stats

def offline_statistics_sequences(data):
  """
  data can be list of data samples (in variable length) or a numpy tensor with
  samples listed in the first axis.
  """
  all_samples = np.vstack(data)

  mean_all = all_samples.mean()
  var_all = all_samples.var(ddof=1)
  mean_channel = all_samples.mean(axis=0)
  var_channel = all_samples.var(axis=0, ddof=1)

  min_all = all_samples.min()
  max_all = all_samples.max()

  seq_lens = np.array([x.shape[0] for x in data])
  min_seq_len = seq_lens.min()
  max_seq_len = seq_lens.max()
  n_samples = len(data)

  stats = dict(mean_all=mean_all, var_all=var_all, mean_channel=mean_channel, var_channel=var_channel, min_all=min_all, max_all=max_all, min_seq_len=min_seq_len, max_seq_len=max_seq_len, n_samples=n_samples)
  return stats

# Quick Test
def log_stats(start_time, stats, tag="Online"):
  print("--- %s seconds ---" % (time.time() - start_time))
  print("[{2}] mean: {0}, std: {1}".format(stats["mean_all"], stats["var_all"], tag))
  print("[{2}] mean channel: {0}, std channel: {1}".format(stats["mean_channel"], stats["var_channel"], tag))
  print("[{1}] # samples: {0}".format(stats["n_samples"], tag))
  print("[{2}] min value: {0}, max value: {1}".format(stats["min_all"], stats["max_all"], tag))
  print("[{2}] min length: {0}, max length: {1}".format(stats["min_seq_len"], stats["max_seq_len"], tag))
  print("============")

num_samples = 10000
seq_len = 100
feature_size = 3
batch_size = 16
num_epochs = 2
eval_frequency = 2  # in number of training steps.

# samples = np.concatenate([np.random.normal(0, 10, (num_samples, seq_len, 1)), np.random.normal(30, 5, (num_samples, seq_len, 2))], axis=-1)
samples = [np.random.normal(0, 10, (192, 3)), np.random.normal(0, 30, (19, 3)), np.random.normal(10, 10, (92, 3))]
labels = np.random.randint(0, 10, (num_samples, seq_len, feature_size))

start_time = time.time()
online_stats = online_statistics_sequences(samples)
log_stats(start_time, online_stats, "Online")


start_time = time.time()
offline_stats = offline_statistics_sequences(samples)
log_stats(start_time, online_stats, "Offline")

--- 0.00043201446533203125 seconds ---
[Online] mean: 1.788093508690388, std: 209.0499481567896
[Online] mean channel: [1.91043825 0.9617036  2.49213867], std channel: [197.54878069 239.00104677 190.78693344]
[Online] # samples: 3
[Online] min value: -106.05878061698726, max value: 71.94902591014984
[Online] min length: 19, max length: 192
--- 0.0010149478912353516 seconds ---
[Offline] mean: 1.788093508690388, std: 209.0499481567896
[Offline] mean channel: [1.91043825 0.9617036  2.49213867], std channel: [197.54878069 239.00104677 190.78693344]
[Offline] # samples: 3
[Offline] min value: -106.05878061698726, max value: 71.94902591014984
[Offline] min length: 19, max length: 192


In [13]:
def tf_online_statistics_sequences(iterable_data, key=None):
    """
    Given a da data iterator, gathers data statistics online. The whole data isn't required to be loaded.
    It is eager compatible, and hence it is okay to pass numpy array or python list.
    Args:
        iterable_data: where each sample is assumed to be a dictionary with <key,value> pairs. 
        key: data of interest.
    Returns:
        (dict) or data statistics with mean, var, min, max calculated across the sequences or all values, and
        seq_len, number of samples.
    """
    n_samples = 0.0
    n_all, mean_all, var_all, m2_all = 0.0, 0.0, 0.0, 0.0
    n_channel, mean_channel, var_channel, m2_channel = 0.0, 0.0, 0.0, 0.0
    min_all, max_all = np.inf, -np.inf
    min_seq_len, max_seq_len = np.inf, -np.inf

    for sample_dict in iterable_data:
        if key is not None:
            sample = sample_dict[key]
        else:
            sample = sample_dict
        n_samples += 1
        seq_len, feature_size = tf.to_float(sample.shape)

        # Global mean&var
        n_all += seq_len * feature_size
        delta_all = sample - mean_all
        mean_all = mean_all + tf.reduce_sum(delta_all) / n_all
        m2_all = m2_all + tf.reduce_sum(delta_all * (sample - mean_all))

        # Channel-wise mean&var
        n_channel += seq_len
        delta_channel = sample - mean_channel
        mean_channel = mean_channel + tf.reduce_sum(delta_channel, axis=0) / n_channel
        m2_channel = m2_channel + tf.reduce_sum(delta_channel * (sample - mean_channel), axis=0)

        # Global min&max values.
        min_all = np.min(sample) if np.min(sample) < min_all else min_all
        max_all = np.max(sample) if np.max(sample) > max_all else max_all

        # Min&max sequence length.
        min_seq_len = seq_len if seq_len < min_seq_len else min_seq_len
        max_seq_len = seq_len if seq_len > max_seq_len else max_seq_len

    var_all = m2_all / (n_all - 1)
    var_channel = m2_channel / (n_channel - 1)
    stats = dict(mean_all=mean_all, var_all=var_all, mean_channel=mean_channel, var_channel=var_channel,
                  min_all=min_all, max_all=max_all, min_seq_len=min_seq_len, max_seq_len=max_seq_len,
                  n_samples=n_samples)
    return stats

training_dataset = tf.data.Dataset.from_tensor_slices({"dummy_samples":np.float32(samples), "labels": samples})
start_time = time.time()
stats = tf_online_statistics_sequences(training_dataset, key="dummy_samples")
log_stats(start_time, stats, "TF-Online")

Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Use tf.cast instead.
--- 8.227863788604736 seconds ---
[TF-Online] mean: 19.989553451538086, std: 250.1465606689453
[TF-Online] mean channel: [-1.8944379e-02  2.9992947e+01  2.9994984e+01], std channel: [99.924255 25.002115 24.99739 ]
[TF-Online] # samples: 10000.0
[TF-Online] min value: -49.19842529296875, max value: 53.97816848754883
[TF-Online] min length: 100.0, max length: 100.0


In [0]:
# Get eager tensor, run in numpy.
def online_statistics_sequences(data, key):
  n_samples = 0
  n_all, mean_all, var_all, m2_all = 0, 0, 0, 0
  n_channel, mean_channel, var_channel, m2_channel = 0, 0, 0, 0
  min_all, max_all = np.inf, -np.inf
  min_seq_len, max_seq_len = np.inf, -np.inf

  for x in data: 
    if key is not None:
      x = x[key]
    if isinstance(x, tf.Tensor):
      x = x.numpy()

    n_samples += 1
    seq_len, feature_size = x.shape

    # Global mean&var
    n_all += seq_len*feature_size
    delta_all = x - mean_all
    mean_all = mean_all + delta_all.sum()/n_all
    m2_all = m2_all + (delta_all*(x - mean_all)).sum()

    # Channelwise mean&var
    n_channel += seq_len
    delta_channel = x - mean_channel
    mean_channel = mean_channel + delta_channel.sum(axis=0)/n_channel
    m2_channel = m2_channel + (delta_channel*(x - mean_channel)).sum(axis=0)

    # Global min&max values.
    min_all = np.min(x) if np.min(x) < min_all else min_all
    max_all = np.max(x) if np.max(x) > max_all else max_all
    
    # Min&max sequence length.
    min_seq_len = seq_len if seq_len < min_seq_len else min_seq_len
    max_seq_len = seq_len if seq_len > max_seq_len else max_seq_len

  var_all = m2_all/(n_all-1)
  var_channel = m2_channel/(n_channel-1)
  stats = dict(mean_all=mean_all, var_all=var_all, mean_channel=mean_channel, var_channel=var_channel, min_all=min_all, max_all=max_all, min_seq_len=min_seq_len, max_seq_len=max_seq_len, n_samples=n_samples)
  return stats

training_dataset = tf.data.Dataset.from_tensor_slices({"dummy_samples":np.float32(samples), "labels": samples})
start_time = time.time()
stats = tf_online_statistics_sequences(training_dataset, key="dummy_samples")
print("--- %s seconds ---" % (time.time() - start_time))

print("[TF Online] mean: {0}, std: {1}".format(stats["mean_all"], stats["var_all"]))
print("[TF Online] mean channel: {0}, std channel: {1}".format(stats["mean_channel"], stats["var_channel"]))
print("[TF Online] # samples: {0}".format(stats["n_samples"]))
print("[TF Online] min value: {0}, max value: {1}".format(stats["min_all"], stats["max_all"]))
print("[TF Online] min length: {0}, max length: {1}".format(stats["min_seq_len"], stats["max_seq_len"]))
print("============")

In [0]:
def online_statistics_sequences(data):
  n_samples = 0
  n, mean, var, m2 = 0, 0, 0, 0

  for x in data: 
    n_samples += 1

    # Global mean&var
    n += 1
    delta = x - mean
    mean = mean + delta/n
    m2 = m2 + delta*(x - mean)

  var = m2/(n-1)
  return mean, var

In [17]:
mean, var = online_statistics_sequences([10, 2, 3, 9])
print("[Online] mean: {0}, std: {1}".format(mean, var))

mean, var = online_statistics_sequences([10, 2, 9, 3])
print("[Online] mean: {0}, std: {1}".format(mean, var))

data = np.array([10, 2, 3, 9])
mean = data.mean()
var = data.var(ddof=1)
print("[Offline] mean: {0}, std: {1}".format(mean, var))

[Online] mean: 6.0, std: 16.666666666666668
[Online] mean: 6.0, std: 16.666666666666668
[Offline] mean: 6.0, std: 16.666666666666668
