In [None]:
# Enable IPython's autoreload extension in mode 2 so edits to imported project
# modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import tensorflow as tf
import numpy as np
import argparse
import glob
import sys
import os

import matplotlib.pyplot as plt

from helpers.dataset import get_mv_analysis_users, load_data_set, filter_by_gender
from helpers.datapipeline import data_pipeline_generator_gan, data_pipeline_gan
from helpers import plotting
from models.gan.wavegan import WaveGAN
from models.gan.specgan import SpecGAN

In [None]:
from collections import namedtuple

# Immutable settings record read by the cells below. `audio_dir` may hold
# several comma-separated directories (it is split on ',' when data is loaded).
Args = namedtuple('Args', [
    'net', 'gender', 'latent_dim', 'slice_len',
    'audio_dir', 'audio_meta', 'mv_data_path',
    'n_epochs', 'batch', 'prefetch', 'n_seconds', 'sample_rate',
])
args = Args(
    net='specgan',
    gender='female',
    latent_dim=100,
    slice_len=16384,
    audio_dir='/beegfs/mm10572/voxceleb1/dev',
    audio_meta='./data/ad_voxceleb12/vox12_meta_data.csv',
    mv_data_path='./data/ad_voxceleb12/vox2_mv_data.npz',
    n_epochs=2000,
    batch=64,
    prefetch=0,
    n_seconds=3,
    sample_rate=16000,
)

In [None]:
# Echo the active configuration: a header followed by one "> label: value"
# line per setting, emitted with a single print call.
print('\n'.join([
    'Parameters summary',
    f'> Net GAN: {args.net}',
    f'> Gender GAN: {args.gender}',
    f'> Latent dim: {args.latent_dim}',
    f'> Slice len: {args.slice_len}',
    f'> Audio dirs: {args.audio_dir}',
    f'> Audio meta: {args.audio_meta}',
    f'> Master voice data: {args.mv_data_path}',
    f'> Number of epochs: {args.n_epochs}',
    f'> Batch size: {args.batch}',
    f'> Prefetch: {args.prefetch}',
    f'> Max number of seconds: {args.n_seconds}',
    f'> Sample rate: {args.sample_rate}',
]))

In [None]:
# Load data set
print('Loading data')
# Build a real list of directories. The previous `map(str, ...)` produced a
# one-shot iterator in Python 3 (no len(), no indexing, empty after a single
# pass), which is fragile to hand to a helper; str.split already yields strings.
audio_dir = args.audio_dir.split(',')
mv_user_ids = get_mv_analysis_users(args.mv_data_path)
x_train, y_train = load_data_set(audio_dir, mv_user_ids)
# Narrow the examples/labels using the metadata file and the configured gender.
x_train, y_train = filter_by_gender(x_train, y_train, args.audio_meta, args.gender)

# Keep at most the first 8192 examples.
# NOTE(review): presumably a cap to bound run time — confirm the intended size.
x_train, y_train = x_train[:8192], y_train[:8192]

# Number of distinct labels left after filtering and truncation.
classes = len(np.unique(y_train))

# Show the sizes and the first three entries as a quick sanity check.
print(f'X_train {len(x_train)}: {x_train[:3]}')
print(f'Y_train {len(y_train)}: {y_train[:3]}')

# Generator output test: run the raw generator over the first ten training
# entries and report the shape of every item it yields.
print('Checking generator output')
for index, x in enumerate(
    data_pipeline_generator_gan(
        x_train[:10],
        slice_len=args.slice_len,
        sample_rate=args.sample_rate,
    )
):
    # `!s` forces str(), matching what print() does with positional arguments.
    print(f'> {index!s} {x.shape!s}')

# Data pipeline output test
print('Checking data pipeline output')
# Choose the pipeline output from the network family only. args.net may carry
# a version suffix (e.g. 'specgan/v1', the form parsed when the GAN is
# created); a plain equality test would then fall through to 'raw' for SpecGAN.
train_data = data_pipeline_gan(
    x_train,
    slice_len=args.slice_len,
    sample_rate=args.sample_rate,
    batch=args.batch,
    prefetch=args.prefetch,
    output_type='spectrum' if args.net.split('/')[0] == 'specgan' else 'raw',
)

# Print the shape of the first eleven batches (indices 0..10) and stop.
# The last batch stays bound to `x`, which later plotting cells read.
for index, x in enumerate(train_data):
    print('>', index, x.shape)
    if index == 10:
        break

# Re-create the data pipeline that is later passed to gan_model.train(); the
# object built above was only iterated for the sanity check.
# The output type is keyed on the network family so that a versioned name such
# as 'specgan/v1' (the form parsed when the GAN is created) still selects
# 'spectrum' instead of falling through to 'raw'.
train_data = data_pipeline_gan(
    x_train,
    slice_len=args.slice_len,
    sample_rate=args.sample_rate,
    batch=args.batch,
    prefetch=args.prefetch,
    output_type='spectrum' if args.net.split('/')[0] == 'specgan' else 'raw',
)

In [None]:
# Pass the batch currently bound to `x` (converted to a NumPy array) to the
# project's plotting.imsc helper and keep its return value in `f`.
# NOTE(review): `x` is left over from the pipeline-check loop in an earlier
# cell; what imsc draws and the meaning of the '' argument are not visible
# here — confirm against helpers/plotting.
f = plotting.imsc(x.numpy(), '')

In [None]:
# Figure 1: element 2 of the batch in `x`, shown as an image after dropping
# singleton dimensions.
plt.figure()
plt.imshow(np.squeeze(x[2].numpy()))
# Figure 2: histogram of the same element's values; the result is bound to
# `_` so the cell does not echo hist()'s return value.
plt.figure()
_ = plt.hist(np.ravel(x[2].numpy()))

In [None]:
# Create GAN
# args.net is either a bare family name ('specgan') or 'family/vN'; the
# numeric part after '/v' becomes the model id, and -1 is used when absent.
if '/v' in args.net:
    model_id = int(args.net.split('/')[1].replace('v', ''))
else:
    model_id = -1
print(f'Creating GAN with id={model_id}')

# Map the family name (text before any '/') to its model class and build it.
available_nets = dict(wavegan=WaveGAN, specgan=SpecGAN)
gan_model = available_nets[args.net.split('/')[0]](
    id=model_id,
    gender=args.gender,
    latent_dim=args.latent_dim,
    slice_len=args.slice_len,
    lr=1e-5,
)

In [None]:
# Shell command: recursively and unconditionally delete
# data/pt_models/specgan/female/ (destructive, no confirmation).
# NOTE(review): presumably clears output from earlier runs of this model —
# confirm. The path is hard-coded and does not follow args.net / args.gender.
!rm -rf data/pt_models/specgan/female/

In [None]:
# Call the model's build() step before it is used.
# NOTE(review): build() is a project method whose effects are not visible
# here; it is invoked again in the training cell further down — confirm that
# calling it twice is intended.
gan_model.build()

In [None]:
# Smoke-test one training step on a single batch.
# The batch is fetched here so the cell works on a fresh kernel: `batch_x` was
# previously only defined by a later cell, so running the notebook top to
# bottom raised NameError at this point.
batch_x = next(iter(train_data))
gan_model.train_step(batch_x)

In [None]:
# Re-print the size and the first three entries of the training examples and
# labels as a quick check before training.
print('X_train {}: {}'.format(len(x_train), x_train[:3]))
print('Y_train {}: {}'.format(len(y_train), y_train[:3]))

In [None]:
# (Re)build the model, then run the full training loop on the pipeline.
# NOTE(review): the keyword meanings below are inferred from their names only
# (dsteps presumably = discriminator steps per generator step) — confirm
# against the model's train() signature.
gan_model.build()
gan_model.train(
    train_data,
    epochs=args.n_epochs,
    batch=args.batch,
    dsteps=10,
    gradient_penalty=True,
    preview_interval=20,
)

In [None]:
# Ask the model to display its training progress.
# NOTE(review): show_progress() is a project method; what it renders is not
# visible here.
gan_model.show_progress()

In [None]:
# Request a preview from the model with argument 1 and keep the result in
# `fig`.
# NOTE(review): presumably 1 is the number of samples to preview and the
# return value is a figure — confirm against the model's preview() method.
fig = gan_model.preview(1)

In [None]:
# Fetch the first batch of the training pipeline into `batch_x` for the
# real-vs-generated comparison in the next cell.
for index, batch_x in enumerate(train_data):
    # Report the shape of the batch just fetched; this previously printed
    # `x.shape`, a stale variable left over from an earlier cell.
    print('>', index, batch_x.shape)
    # Only the first batch is needed.
    break

In [None]:

# Take the output of gan_model.sample() as a NumPy array.
batch_y = gan_model.sample().numpy()

# Overlay two semi-transparent value histograms: the first element of the
# fetched batch and the sampled output.
plt.hist(np.ravel(batch_x[0].numpy()), alpha=0.5)
plt.hist(np.ravel(batch_y), alpha=0.5)

# Discriminator output for the first batch element (kept as a batch of one)
# and for the sampled output.
print('D_x', gan_model.discriminator(batch_x[:1]).numpy())
print('D_G_z', gan_model.discriminator(batch_y).numpy())

# plt.imshow(batch_y.squeeze())