Rejection resampling test. See https://www.tensorflow.org/guide/data#resampling

In [1]:
import tensorflow as tf
import numpy as np

In [2]:
zip_path = tf.keras.utils.get_file(
    origin='https://storage.googleapis.com/download.tensorflow.org/data/creditcard.zip',
    fname='creditcard.zip',
    extract=True)

csv_path = zip_path.replace('.zip', '.csv')

In [38]:
creditcard_ds = tf.data.experimental.make_csv_dataset(
    csv_path, batch_size=1024, label_name="Class",
    # Set the column types: 30 floats and an int.
    column_defaults=[float()]*30+[int()])

In [37]:
for features, labels in creditcard_ds.unbatch().batch(10).take(10):
    print (labels.numpy())

[0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0]


In [39]:
creditcard_ds?

[0;31mType:[0m           PrefetchDataset
[0;31mString form:[0m    <PrefetchDataset shapes: (OrderedDict([(Time, (1024,)), (V1, (1024,)), (V2, (1024,)), (V3, (1024, <...> t32), (V26, tf.float32), (V27, tf.float32), (V28, tf.float32), (Amount, tf.float32)]), tf.int32)>
[0;31mFile:[0m           ~/anaconda3/envs/TF2/lib/python3.7/site-packages/tensorflow_core/python/data/ops/dataset_ops.py
[0;31mDocstring:[0m      A `Dataset` that asynchronously prefetches its input.
[0;31mInit docstring:[0m
See `Dataset.prefetch()` for details.

Args:
  input_dataset: The input dataset.
  buffer_size: See `Dataset.prefetch()` for details.
  slack_period: (Optional.) An integer. If non-zero, determines the number
    of GetNext calls before injecting slack into the execution. This may
    reduce CPU contention at the start of a step. Note that a tensorflow
    user should not have to set this manually; enable this behavior
    automatically via `tf.data.Options.experimental_slack` instead. Defaul

In [5]:
def count(counts, batch):
  features, labels = batch
  class_1 = labels == 1
  class_1 = tf.cast(class_1, tf.int32)

  class_0 = labels == 0
  class_0 = tf.cast(class_0, tf.int32)

  counts['class_0'] += tf.reduce_sum(class_0)
  counts['class_1'] += tf.reduce_sum(class_1)

  return counts

In [6]:
counts = creditcard_ds.take(10).reduce(
    initial_state={'class_0': 0, 'class_1': 0},
    reduce_func = count)

counts = np.array([counts['class_0'].numpy(),
                   counts['class_1'].numpy()]).astype(np.float32)

fractions = counts/counts.sum()
print(fractions)

[0.9961914  0.00380859]


### Resampling - dataset as two different tf.data datasets

In [None]:
negative_ds = (
  creditcard_ds
    .unbatch()
    .filter(lambda features, label: label==0)
    .repeat())
positive_ds = (
  creditcard_ds
    .unbatch()
    .filter(lambda features, label: label==1)
    .repeat())

In [None]:
for features, label in positive_ds.batch(10).take(1):
  print(label.numpy())

In [None]:
balanced_ds = tf.data.experimental.sample_from_datasets(
    [negative_ds, positive_ds], [0.5, 0.5]).batch(10)

In [None]:
for features, labels in balanced_ds.take(10):
  print(labels.numpy())

In [None]:
counts = balanced_ds.take(10).reduce(
    initial_state={'class_0': 0, 'class_1': 0},
    reduce_func = count)

counts = np.array([counts['class_0'].numpy(),
                   counts['class_1'].numpy()]).astype(np.float32)

print(counts/counts.sum())

### Rejection resampling

In [30]:
def class_func(features, label):
    return label

In [31]:
resampler = tf.data.experimental.rejection_resample(
            class_func, #=lambda features, label: label, 
            target_dist=[0.5, 0.5], 
            initial_dist=fractions)

In [32]:
resample_ds = creditcard_ds.unbatch().apply(resampler).batch(10)

In [None]:
balanced_ds = resample_ds.map(lambda extra_label, features_and_label: features_and_label)

In [23]:
# Testing cell
for features, labels in resample_ds.take(3):
    print(labels[1])

tf.Tensor([0 0 0 0 1 1 1 1 0 0], shape=(10,), dtype=int32)
tf.Tensor([1 1 0 1 0 0 0 0 1 0], shape=(10,), dtype=int32)
tf.Tensor([1 1 1 0 0 0 0 1 0 0], shape=(10,), dtype=int32)


In [10]:
for features, labels in balanced_ds.take(10):
    print(labels.numpy())

[0 0 0 0 1 1 1 1 0 0]
[1 1 0 1 0 0 0 0 1 0]
[1 1 1 0 0 0 0 1 0 0]
[0 0 0 0 0 1 0 1 0 0]
[1 0 0 1 0 0 0 1 0 1]
[0 0 0 1 1 0 0 1 0 1]
[0 0 1 1 0 1 0 0 0 0]
[0 1 0 1 0 0 1 0 0 1]
[0 1 0 0 0 0 0 1 0 0]
[0 1 0 0 1 0 0 1 0 0]


In [12]:
counts = balanced_ds.take(10).reduce(
    initial_state={'class_0': 0, 'class_1': 0},
    reduce_func = count)

counts = np.array([counts['class_0'].numpy(),
                   counts['class_1'].numpy()]).astype(np.float32)

print(counts/counts.sum())

[0.57666665 0.42333335]
