# Define input data loading functions - TF

Based on the TensorFlow image-loading tutorial, section "Using tf.data for finer control": https://www.tensorflow.org/tutorials/load_data/images#using_tfdata_for_finer_control

In [None]:
def encode_class_label(class_name, available_classes):
    """Map a class name to its integer index within ``available_classes``.

    Class names become 0, 1, 2, 3, ... so that they can be used with
    sparse_categorical_crossentropy.

    Args:
        class_name: The class name to encode.
        available_classes: All known class names; the position of a name in
            this collection is its numerical label.

    Returns:
        The result of ``tf.argmax`` over the element-wise equality mask.
    """
    # NOTE(review): presumably a class_name that matches nothing yields
    # index 0 instead of an error - confirm against tf.argmax behaviour.
    return tf.argmax(class_name == available_classes)

In [None]:
def decode_img_data(img_data, channels=3):
    """Decode PNG-encoded bytes and resize the image to a square.

    Args:
        img_data: Raw contents of a PNG file.
        channels: Number of colour channels requested from the decoder.

    Returns:
        The decoded image resized to ``IMAGE_SIZE`` x ``IMAGE_SIZE``.
    """
    decoded = tf.io.decode_png(img_data, channels=channels)
    target_size = [IMAGE_SIZE, IMAGE_SIZE]
    return tf.image.resize(decoded, target_size)

In [None]:
def process_sample(sample, available_classes, channels=3):
    """Turn an (image path, class name) pair into (image data, label).

    Args:
        sample: Indexable pair whose first element is the path of a PNG file
            and whose second element is the class name.
        available_classes: All known class names, used to derive the
            numerical label.
        channels: Number of colour channels to decode.

    Returns:
        Tuple of the decoded, resized image and the numerical class label.
    """
    img_path = sample[0]
    class_label = sample[1]

    img_data = decode_img_data(tf.io.read_file(img_path), channels)
    # Integer (not one-hot) labels: the models are trained with
    # sparse_categorical_crossentropy and the visualisation indexes the class
    # list with the label, so the index encoder is the right helper here.
    numerical_label = encode_class_label(class_label, available_classes)
    return img_data, numerical_label

In [None]:
def configure_for_performance(ds, batch_size=8):
    """Apply caching, shuffling, batching and prefetching to a dataset.

    Args:
        ds: Dataset to configure.
        batch_size: Number of elements per batch.

    Returns:
        The dataset with cache -> shuffle (buffer of 1000) -> batch ->
        prefetch applied, in that order.
    """
    return (
        ds.cache()
        .shuffle(buffer_size=1000)
        .batch(batch_size)
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )

In [None]:
def prepare_dataset_based_on_df(df, available_classes, batch_size=8, channels=3):
    """Build a batched dataset of (image data, numerical label) pairs.

    Path & genre pairs from the dataframe are turned into image data &
    numerically-encoded label pairs. A few example rows and the dataset size
    are printed along the way.

    Args:
        df: DataFrame with a 'path' column (image file locations) and a
            'genre' column (class names). It is not modified.
        available_classes: All known class names, used for label encoding.
        batch_size: Number of samples per batch.
        channels: Number of colour channels to decode from each image.

    Returns:
        The dataset after mapping ``process_sample`` over every row and
        applying ``configure_for_performance``.
    """
    # Stringify the paths into a fresh Series instead of writing them back
    # into the caller's DataFrame; this also sidesteps pandas' copy-on-write
    # warning without mutating the input.
    paths = df['path'].map(str)
    ds = tf.data.Dataset.from_tensor_slices((paths, df['genre']))

    print("Example data:")
    for path, genre in ds.take(3):
        print(path.numpy(), genre.numpy())

    print(f'Data set size: {tf.data.experimental.cardinality(ds).numpy()}')

    # Use Dataset.map to create a Dataset of (image data, numerically-encoded label) pairs:
    ds = ds.map(
        lambda path, genre: process_sample((path, genre), available_classes, channels),
        num_parallel_calls=tf.data.AUTOTUNE,
    )

    return configure_for_performance(ds, batch_size)

In [None]:
def visualize_samples_from_dataset_batch(ds, available_classes, samples=6):
    """Plot the first few images of the first batch, titled by class name.

    Args:
        ds: Dataset yielding (image batch, label batch) pairs.
        available_classes: Class names, indexed by the numerical label.
        samples: Maximum number of images to show; capped at the batch size.
    """
    image_batch, label_batch = next(iter(ds))
    samples = min(samples, len(label_batch))
    if samples <= 0:
        # Nothing to draw, and a grid with zero rows cannot be created.
        return

    # squeeze=False keeps `axes` a 2-D array whatever the grid size is.
    _, axes = plt.subplots(ncols=3, nrows=(samples + 2) // 3, squeeze=False)
    for axis in axes.ravel():
        axis.set_axis_off()

    for i, ax in zip(range(samples), axes.flat):
        image = image_batch[i].numpy().astype("uint8")
        if image.ndim == 3 and image.shape[-1] == 1:
            # imshow does not accept (H, W, 1); drop the channel axis and
            # render the single channel in grayscale.
            ax.imshow(image[..., 0], cmap="gray")
        else:
            ax.imshow(image)
        ax.set_title(available_classes[int(label_batch[i])])

In [None]:
# Automatic alternative: builds the dataset straight from the image directory, but gives no manual control over the split
#
# gtzan_image_dir = gtzan_dir / 'spectrograms'
# gtzan_train_ds = tf.keras.utils.image_dataset_from_directory(
#   gtzan_image_dir,
#   seed=42,
#   validation_split=0.2,
#   labels='inferred',
#   label_mode='categorical',
#   color_mode='rgb',
#   image_size=(IMAGE_SIZE, IMAGE_SIZE))

# CNN - TF

In [None]:
def create_CNN(img_size, channels, num_classes):
    """Assemble the sequential CNN classifier.

    The network is an input layer, four identical stages of
    Conv2D (5x5, relu) -> BatchNormalization -> MaxPooling2D (2x2, stride 2)
    -> Dropout(0.2) with 64, 64, 128 and 128 filters, then Flatten and a
    softmax Dense layer.

    Args:
        img_size: Height and width of the square input images.
        channels: Number of input channels.
        num_classes: Number of units in the final softmax layer.

    Returns:
        The (uncompiled) ``models.Sequential`` instance.
    """
    stack = [layers.Input(shape=(img_size, img_size, channels))]

    # Layers are created in the same order as they appear in the network.
    for filters in (64, 64, 128, 128):
        stack += [
            layers.Conv2D(filters, (5, 5), activation='relu'),
            layers.BatchNormalization(),
            layers.MaxPooling2D((2, 2), strides=(2, 2)),
            layers.Dropout(0.2),
        ]

    stack += [
        layers.Flatten(),
        layers.Dense(num_classes, activation='softmax'),
    ]
    return models.Sequential(stack)

In [None]:
# Metrics passed to model.compile(); the sparse accuracy variant matches the
# integer class labels used with sparse_categorical_crossentropy.
multiclass_metrics = [
    tf.keras.metrics.SparseCategoricalAccuracy(),
    # ...
]

# Training/Testing - TF

In [None]:
# One output unit per GTZAN class.
model_gtzan = create_CNN(
    img_size=IMAGE_SIZE,
    channels=CHANNELS,
    num_classes=len(gtzan_classes),
)

In [None]:
# Sparse loss: labels are class indices rather than one-hot vectors.
model_gtzan.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=multiclass_metrics,
)

In [None]:
# Train for 8 epochs, validating on gtzan_val_dl.
model_gtzan.fit(gtzan_train_dl, validation_data=gtzan_val_dl, epochs=8)

In [None]:
# Evaluate the trained model on gtzan_test_dl.
model_gtzan.evaluate(gtzan_test_dl)