# Setup

## Importing libraries

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

In [None]:
pip install nesmdb



In [None]:
import pickle

## Importing our dataset

Dataset: https://github.com/chrisdonahue/nesmdb

In [None]:
import requests 
url = "http://deepyeti.ucsd.edu/cdonahue/nesmdb/nesmdb24_seprsco.tar.gz"
  
r = requests.get(url)

with open("dataset.tar.gz",'wb') as f: 
    f.write(r.content) 

In [None]:
import tarfile
my_tar = tarfile.open('dataset.tar.gz')
my_tar.extractall('./dataset')
my_tar.close()

## Setting up global variables

# Preparation

## Utility functions

In [None]:
def seprsco_to_arr(file,k=4,z=8,x=96,y=96):
  with open(file, 'rb') as f:
    rate, nsamps, seprsco = pickle.load(f)
  while(len(seprsco)<x*z): seprsco=np.concatenate((seprsco,seprsco),axis=0)

  arr = np.zeros([k,z,x,y],dtype='int8')
  for i in range(z*x):
    if(seprsco[i][0]!=0): arr[0,i//x,i%x,seprsco[i][0]-20]=1
    if(seprsco[i][1]!=0): arr[1,i//x,i%x,seprsco[i][1]-20]=1
    if(seprsco[i][2]!=0): arr[2,i//x,i%x,seprsco[i][2]-20]=1
    if(seprsco[i][3]!=0): arr[3,i//x,i%x,seprsco[i][3]]=1
  return arr

In [131]:
def arr_to_seprsco(arr,k=4,z=8,x=96,y=96):
  seprsco=np.zeros([z*x,k])
  for i in range(z):
    for j in range(x):
      for b in range(y):
        if(arr[0][i][j][b]==1): seprsco[i*x+j,0]=b+20
        if(arr[1][i][j][b]==1): seprsco[i*x+j,1]=b+20
        if(arr[2][i][j][b]==1): seprsco[i*x+j,2]=b+20
        if(arr[3][i][j][b]==1): seprsco[i*x+j,3]=b
  return seprsco

In [62]:
def fn_to_input(filename): return np.array([seprsco_to_arr(filename).reshape(96*96*8*4)])
def output_to_arr(out): return out.reshape(4,8,96,96)

## Preparing the dataset

# Creating the model

In [None]:
inputs = keras.Input(4*96*96*8)

In [None]:
encoded = layers.Dense(256,activation='relu')(inputs)
encoded = layers.Dense(128,activation='relu')(encoded)

decoded = layers.Dense(256,activation='relu')(encoded)
decoded = layers.Dense(4*96*96*8,activation='relu')(decoded)

In [None]:
autoencoder = keras.Model(inputs,decoded)
autoencoder.compile(optimizer='adadelta',loss='binary_crossentropy')

# Training the model

In [None]:
import glob;
from scipy.sparse import csr_matrix
from numpy import reshape

In [None]:
x_train = np.array([seprsco_to_arr(filename).reshape(96*96*8*4) for filename in glob.iglob('dataset/nesmdb24_seprsco/train/*.seprsco.pkl')],dtype='int8')

x_test = np.array([np.reshape(seprsco_to_arr(filename),96*96*8*4) for filename in glob.iglob('dataset/nesmdb24_seprsco/test/*.seprsco.pkl')],dtype='int8')
x_valid = np.array([np.reshape(seprsco_to_arr(filename),96*96*8*4) for filename in glob.iglob('dataset/nesmdb24_seprsco/valid/*.seprsco.pkl')],dtype='int8')

In [None]:
autoencoder.optimizer.learning_rate = 0.1

In [None]:
autoencoder.fit(x_train, x_train,
                epochs=100,
                batch_size=64,
                shuffle=True,
                validation_data=(x_test, x_test),)

In [None]:
autoencoder.save('autoencoder v1.0-100') # 100 epochs

In [49]:
autoencoder = keras.models.load_model('autoencoder v1.0-100')

In [203]:
path = 'dataset/nesmdb24_seprsco/valid/012_AlphaMission_01_02SYDsThemeArea.seprsco.pkl'

y = autoencoder.predict(fn_to_input(path))
y = np.round(y,4)
unique, counts = np.unique(y, return_counts=True)

In [204]:
x = fn_to_input(path)
unique, counts = np.unique(x, return_counts=True)
dict(zip(unique, counts))

{0: 293620, 1: 1292}

In [207]:
unique, counts = np.unique(y, return_counts=True)
dict(zip(unique, counts))

{0.0: 293278, 1.0: 1634}

In [206]:
bar = np.percentile(y[0],99)+0.002
for i in range(len(y[0])):
  if(y[0][i]>=bar): y[0][i]=1
  else: y[0][i]=0

In [208]:
xs = output_to_arr(y)
xs = arr_to_seprsco(xs)
xs

array([[ 0.,  0.,  0., 16.],
       [ 0.,  0., 48., 15.],
       [ 0.,  0., 48.,  0.],
       ...,
       [ 0.,  0., 48., 15.],
       [67.,  0., 48., 15.],
       [67.,  0., 48., 15.]])

In [201]:
from nesmdb.convert import seprsco_to_wav
with open(path, 'rb') as f:
  exprsco = pickle.load(f)
wav = seprsco_to_wav(exprsco)

In [223]:
wav = seprsco_to_wav((24,1466371,xs))

In [212]:
from scipy.io.wavfile import write

write('test_wav_00.wav', 44100 , wav)

# Creating and exporting the generator