In [None]:
# pip install jupyter matplotlib tensorflow
# (the multispecies_whale_detection package imported below must also be importable)

import itertools
import os

from matplotlib import pyplot as plt
import tensorflow as tf

In [None]:
from multispecies_whale_detection import front_end

In [None]:
from multispecies_whale_detection import dataset

In [None]:
# Root directory of the prepared whale data; configured_window_dataset reads
# <base_dir>/input/<subdirectory>/tfrecords-*. Set the WHALE_DATA_DIR
# environment variable to point elsewhere instead of editing this cell
# ('~' in the value is expanded). Defaults to ~/tmp/whale_data.
base_dir = os.path.expanduser(
    os.environ.get('WHALE_DATA_DIR', '~/tmp/whale_data'))
# Examples per training batch; the training pipeline also sizes its shuffle
# buffer as a multiple of this.
batch_size = 128

def configured_window_dataset(
    input_subdirectory: str,
    windowing: dataset.Windowing,
) -> tf.data.Dataset:
    """Builds a window Dataset using the settings shared by train and validation.

    Args:
      input_subdirectory: Name of the directory under `<base_dir>/input` whose
        `tfrecords-*` files are read (for example 'train').
      windowing: Windowing strategy forwarded unchanged to
        `dataset.new_window_dataset`.

    Returns:
      The Dataset produced by `dataset.new_window_dataset`, configured with
      duration=1.0, min_overlap=0.25 and the four class names listed below.
    """
    tfrecord_glob = os.path.join(
        base_dir, 'input', input_subdirectory, 'tfrecords-*')
    return dataset.new_window_dataset(
        tfrecord_filepattern=tfrecord_glob,
        windowing=windowing,
        duration=1.0,
        class_names=['Orca', 'SRKW', 'IBKW', 'Ej'],
        min_overlap=0.25,
    )

# Training input pipeline: windows from the 'train' tfrecords, cached,
# repeated without end, shuffled with a buffer of four batches' worth of
# elements, batched, and prefetched one batch ahead.
# NOTE(review): `.cache()` comes after dataset.RandomWindowing(4). If the
# random windows are drawn while the dataset is iterated, every pass after
# the first replays the windows cached on the first pass. Confirm this is
# intended.
train_dataset = (
    configured_window_dataset('train', dataset.RandomWindowing(4))
    .cache()
    .repeat()
    .shuffle(batch_size * 4)
    .batch(batch_size)
    .prefetch(1)
)

In [None]:
# Front end only: a Spectrogram layer with mel frequency scaling, followed by
# SpectrogramToImage.
# NOTE(review): sgram_min/sgram_max are hard-coded; presumably chosen from
# observed spectrogram values. TODO: confirm or derive them. sample_rate must
# presumably match the audio in the tfrecords.
model = tf.keras.Sequential(layers=[
    front_end.Spectrogram(
        front_end.SpectrogramConfig(
            sample_rate=4000.0,
            frame_seconds=0.05,
            hop_seconds=0.025,
            frequency_scaling=front_end.MelScalingConfig(
                lower_edge_hz=125.0,
                num_mel_bins=64,
            ),
        ),
    ),
    front_end.SpectrogramToImage(sgram_min=-323, sgram_max=-99),
])

In [None]:
# Keep a single iterator so each run of the next cell pulls the following
# training batch instead of restarting the pipeline.
iter_train = iter(train_dataset)

In [None]:
# Fetch the next (waveform, labels) batch from the shared training iterator.
waveform, labels = next(iter_train)

In [None]:
# Shapes of the batch tensors. labels is presumably (batch, num_classes) and
# waveform (batch, samples); TODO confirm.
(waveform.shape, labels.shape)

In [None]:
# Compute the model output here so this cell runs correctly on a fresh kernel
# when executed top to bottom; `images` is otherwise first assigned in a later
# cell.
images = model(waveform)
# Summary statistics of the model output over the whole batch.
{
    'mean': tf.math.reduce_mean(images),
    'std': tf.math.reduce_std(images),
    'min': tf.math.reduce_min(images),
    'max': tf.math.reduce_max(images),
}

In [None]:
# Run the front-end model on the batch; images line up one-to-one with labels
# (they are zipped together below).
images = model(waveform)

In [None]:
# Shape of the model output for this batch.
images.shape

In [None]:
# Display examples [offset, offset + limit) of the current batch, each titled
# with that example's label vector.
offset = 20
limit = 10

for image, label in itertools.islice(zip(images, labels), offset, offset + limit):
    # Iterating zip(images, labels) yields one example at a time, so `label`
    # is a single example's label vector, not a batch.
    fig, ax = plt.subplots()
    ax.imshow(tf.cast(image, tf.int32))
    ax.set_title(str(label.numpy()))
    plt.show()
    # Release the figure so they do not accumulate on non-inline backends.
    plt.close(fig)

In [None]:
# Per-class label sums over the first few batches of a fresh pass through
# train_dataset: one row per batch, one column per label position.
num_batches_to_count = 5
# A comprehension keeps `_` local. Assigning `_` at notebook top level would
# make IPython stop updating its last-output variables `_`, `__`, `___`.
batch_counts = tf.stack([
    tf.math.reduce_sum(batch_labels, axis=0)
    for _, batch_labels in itertools.islice(train_dataset, num_batches_to_count)
])

In [None]:
# (number of sampled batches, number of label columns).
batch_counts.shape

In [None]:
# Overlaid per-class histograms of the per-batch label sums, to eyeball class
# balance across batches. Classes are identified by label column index.
num_classes = batch_counts.shape[-1]
fig, ax = plt.subplots()
for class_index in range(num_classes):
    ax.hist(batch_counts[:, class_index], alpha=0.5, label=f'class {class_index}')
ax.set_xlabel('sum of class label per batch')
ax.set_ylabel('number of batches')
ax.set_title('Per-batch label sums by class')
ax.legend()
plt.show()