In [1]:
import tensorflow as tf

In [2]:
import tensorflow_probability as tfp

In [3]:
import numpy as np
import os

In [4]:
spectrogram_dir = "/Users/llewyn/Documents/data/stft/codes"
augmented_dir = "/Users/llewyn/Documents/data/stft/codes"

def get_dataset(ds_dir=spectrogram_dir):
    files = [f for f in os.listdir(ds_dir) if os.path.isfile(os.path.join(ds_dir, f))]
    files = [os.path.join(ds_dir, f) for f in files]
    
    # get augmented filenames
    # aug_files = [f for f in os.listdir(augmented_dir) if os.path.isfile(os.path.join(augmented_dir, f))]
    # aug_files = [os.path.join(augmented_dir, f) for f in aug_files]
    
    files = tf.constant(files) # [path_to_file1... path to filen]
    # aug_files = tf.constant(aug_files) # [path_to_badfile1 ... path to bad filen]
    
    # dataset = tf.data.Dataset.from_tensor_slices((files, aug_files))    # => [[path to good file1 , path to badfile1], [], []]

    dataset = tf.data.Dataset.from_tensor_slices(files)
    return dataset

In [5]:
tfd = tfp.distributions
tfk = tf.keras
tfkl = tf.keras.layers

AUTOTUNE = tf.data.experimental.AUTOTUNE

def read_codes(item):
    codes = np.load(item.decode(), allow_pickle=True)
    codes = codes[0].reshape(88, 16, 1) / 512.
    return codes.astype(np.float32)

def load_file(filepath):
    codes = tf.numpy_function(read_codes, [filepath], [tf.float32])
    return codes

def split_data(ds, shuffle_buffer_size=1024, batch_size=64):
    test_ds = ds.take(200) 
    train_ds = ds.skip(200)
        
    train_ds = train_ds.shuffle(buffer_size=shuffle_buffer_size)
    train_ds = train_ds.map(load_file, num_parallel_calls=AUTOTUNE)
    train_ds = train_ds.batch(batch_size, drop_remainder=True)
    
    test_ds = test_ds.shuffle(buffer_size=shuffle_buffer_size)
    test_ds = test_ds.map(load_file, num_parallel_calls=AUTOTUNE)
    test_ds = test_ds.batch(batch_size, drop_remainder=True)
   
    train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
    test_ds = test_ds.prefetch(tf.data.AUTOTUNE)
    
    return train_ds, test_ds


In [6]:
dataset = get_dataset()
train_ds, test_ds = split_data(dataset)

In [125]:
top_input_size, bot_input_size = (88, 16, 1), (88, 64, 1)

# Define a Pixel CNN network
dist = tfd.PixelCNN(
    image_shape=top_input_size,
    num_resnet=1,
    num_hierarchies=2,
    num_filters=32,
    num_logistic_mix=5,
    dropout_p=.3,
)

# Define the model input
image_input = tfkl.Input(shape=top_input_size)

# Define the log likelihood for the loss fn
log_prob = dist.log_prob(image_input)

# Define the model
model = tfk.Model(inputs=image_input, outputs=log_prob)
model.add_loss(-tf.reduce_mean(log_prob))

# Compile and train the model
model.compile(
    optimizer=tfk.optimizers.Adam(.001),
    metrics=[])

In [8]:
model.fit(train_ds, epochs=10, verbose=True)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x16193a100>

In [11]:
model.save_weights("./saved_models/top_encoder")

In [116]:
def read_codes(item):
    codes = np.load(item.decode(), allow_pickle=True)
    top_codes = codes[0]
    bot_codes = codes[1].reshape(88, 64, 1)
    return np.mean(top_codes).astype(np.float32), bot_codes.astype(np.float32)

def load_file(filepath):
    top, bot = tf.numpy_function(read_codes, [filepath], [tf.float32, tf.float32])
    return ((bot, top), )

def split_data(ds, shuffle_buffer_size=1024, batch_size=64):
    test_ds = ds.take(200) 
    train_ds = ds.skip(200)
        
    train_ds = train_ds.shuffle(buffer_size=shuffle_buffer_size)
    train_ds = train_ds.map(load_file, num_parallel_calls=AUTOTUNE)
    train_ds = train_ds.batch(batch_size, drop_remainder=True)
    
    test_ds = test_ds.shuffle(buffer_size=shuffle_buffer_size)
    test_ds = test_ds.map(load_file, num_parallel_calls=AUTOTUNE)
    test_ds = test_ds.batch(batch_size, drop_remainder=True)
   
    train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
    test_ds = test_ds.prefetch(tf.data.AUTOTUNE)
    
    return train_ds, test_ds

In [117]:
train_ds, test_ds = split_data(dataset)

In [127]:
# Define a Pixel CNN network
dist = tfd.PixelCNN(
    image_shape=bot_input_size,
    conditional_shape=(),
    num_resnet=1,
    num_hierarchies=2,
    num_filters=32,
    num_logistic_mix=5,
    dropout_p=.3,
)

bot_input = tfkl.Input(shape=bot_input_size)
top_input = tfkl.Input(shape=())

log_prob = dist.log_prob(bot_input, conditional_input=top_input)

class_cond_model = tfk.Model(
    inputs=[bot_input, top_input], outputs=log_prob)
class_cond_model.add_loss(-tf.reduce_mean(log_prob))
class_cond_model.compile(
    optimizer=tfk.optimizers.Adam(learning_rate=1e-7),
    metrics=[])

In [128]:
class_cond_model.fit(train_ds, epochs=10)

Epoch 1/10
   9/5491 [..............................] - ETA: 6:49:52 - loss: nan

KeyboardInterrupt: 