In [1]:
# %matplotlib notebook
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
import torch, torchvision
import torch.nn as nn
import pickle
import os

# Run on the GPU when CUDA is available, otherwise stay on the CPU.
computing_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# get paths to files

# Hand-picked splits used previously (kept for reference):
# train_files = ['data/MT8_8K.wav']
# train_files = ['data/02_8K.wav', 'data/03_8K.wav', 'data/04_8K.wav', 'data/05_8K.wav']
# val_files = ['data/01_8K.wav']

# Fraction of the discovered files assigned to training; the remainder goes to
# validation. With 1.0 every file lands in train_files and val_files is empty.
train_ratio = 1.0
# Seed for the train/val shuffle so the split is identical after a kernel restart.
split_seed = 0

# Collect every .wav below data/ (recursively). Sorting removes the dependence
# on os.walk's filesystem-specific ordering, so the seeded shuffle below always
# starts from the same list.
in_files = sorted(
    os.path.join(root, file)
    for root, dirs, files in os.walk("data")
    for file in files
    if file.endswith(".wav")
)

# Shuffle in place with a private, seeded generator; the global NumPy RNG
# state is left untouched.
np.random.RandomState(split_seed).shuffle(in_files)

n_train = int(train_ratio * len(in_files))
train_files = in_files[:n_train]
val_files = in_files[n_train:]
print(in_files[:3])

['data/03-8K.wav', 'data/08-8K.wav', 'data/09-8K.wav']


In [3]:
# dataloader init

from util import *

# Settings handed positionally to DataGenerator, in this order, after the file list.
chunk_size, window_size = 20, 2046
window_overlap, batch_size = 1023, 1

# Both generators share the same settings and differ only in their file list.
gen_args = (chunk_size, window_size, window_overlap, batch_size)
train_gen = DataGenerator(train_files, *gen_args)
val_gen = DataGenerator(val_files, *gen_args)

File 14/14, sr=8000, zxx_c.shape=(1024, 1044), X.shape=torch.Size([1060, 1, 2048])



In [4]:
# model setup

from lstm import *

# Size of axis 2 of the first entry in train_gen.X_list, used as the model's input width.
# NOTE(review): the training cell's recorded output shows
# "Expected 4096, got 2048" from the LSTM -- confirm this value matches what
# LSTMCNN actually feeds to its recurrent layer.
input_dim = train_gen.X_list[0].shape[2]  # TODO
hidden_dim, num_layers = 2048, 1

# Alternative architectures tried previously:
# model = LSTMBasic(input_dim, hidden_dim, num_layers=num_layers, batch_size=batch_size)
# model = LSTMFC(input_dim, hidden_dim, hidden_dim, num_layers=num_layers, batch_size=batch_size, dropout_p=0.2)
model = LSTMCNN(
    input_dim,
    hidden_dim,
    num_layers=num_layers,
    batch_size=batch_size,
    decoder="2fc",
).to(computing_device)

criterion = nn.MSELoss().to(computing_device)

# Data-parallel wrapper that scatters inputs along dim 1; the optimizer is
# built from the wrapper's parameters.
dp = nn.DataParallel(model, dim=1).to(computing_device)
m = dp.module  # handle on the wrapped model itself
optimizer = torch.optim.Adam(dp.parameters(), lr=0.001)

In [5]:
# trainer setup

from lstm_trainer import *
# Build the trainer from the DataParallel-wrapped model, the loss and the optimizer.
# NOTE(review): LSTMTrainer is presumably provided by the star import above -- confirm.
trainer = LSTMTrainer(dp, criterion, optimizer)

Using cpu


In [6]:
# load trained model?

# Set to True to restore a previously dumped checkpoint before training/evaluating.
load_model = False

if load_model:
    # Epoch count of the checkpoint to restore. The file name is derived from
    # it (together with chunk_size and hidden_dim) so the value passed to
    # trainer.load_model can never disagree with the file actually loaded.
    # With chunk_size=20 and hidden_dim=2048 this resolves to
    # "models/cs20_h2048_e4000.ckpt".
    epochs_trained = 4000
    model_file = "models/cs{}_h{}_e{}.ckpt".format(chunk_size, hidden_dim, epochs_trained)
    print("Loading model: {}".format(model_file))
    trainer.load_model(model_file, epochs_trained)

In [7]:
%%time

# training

train_model = True   # flip to False to skip training entirely
iter_epochs = 1      # forwarded to trainer.train as its third positional argument
iters = 1            # how many times trainer.train is called in the loop below
dump_epochs = 500    # forwarded to trainer.train as dump_epochs

if train_model:

    # Optional live loss figure (disabled):
    # fig = plt.figure(figsize=(6,3))
    # ax = fig.add_subplot(1,1,1)
    # fig.show(); fig.canvas.draw()

    # Run `iters` successive training rounds and accumulate the losses from each.
    curr_train_losses = []
    curr_val_losses = []
    for round_idx in range(iters):

        round_losses = trainer.train(train_gen, val_gen, iter_epochs, 1,
                                     dump_model=True, dump_epochs=dump_epochs, dump_loss=True)
        train_loss, val_loss = round_losses
        curr_train_losses.extend(train_loss)  # train_loss is a 2D python list
        curr_val_losses.extend(val_loss)

        # Optional live loss curve (disabled):
        # ax.clear()
        # ax.plot(np.array(curr_train_losses).mean(axis=1))
        # fig.canvas.draw()

print()

RuntimeError: input.size(-1) must be equal to input_size. Expected 4096, got 2048

In [8]:
# import pickle
# import numpy as np

# train_loss = []
# val_loss = []
# for i in range(80):
#     t,v = pickle.load(open("models/model_h150_e{}.ckpt.loss.pkl".format((i+1)*10), 'rb'))
#     train_loss += [t]
#     val_loss += [v]
# plt.plot(np.average(np.array(train_loss).reshape((800,3799)), axis=1))
# plt.plot(np.average(np.array(val_loss).reshape((800,674)), axis=1))

# misc. tests below

In [9]:
# Deliberate hard stop: the exception halts a "Run All" here so the scratch
# cells that follow are only executed by hand.
raise Exception("STOP") # dirty way to stop the notebook

Exception: STOP

In [None]:
# Generator settings for the scratch evaluation. These rebind the globals set
# in the dataloader cell; only batch_size differs (2 here instead of 1).
chunk_size, window_size, window_overlap, batch_size = 20, 2046, 1023, 2

# Previously a fixed file was used:
# test_files = ['data/05_8K.wav']#, 'data/02_8K.wav', 'data/03_8K.wav', 'data/04_8K.wav', 'data/05_8K.wav']
# Evaluate on the last entry of the shuffled file list only. While
# train_ratio is 1.0 that file is also part of train_files.
test_files = in_files[-1:]

test_gen = DataGenerator(test_files, chunk_size, window_size, window_overlap, batch_size)

In [None]:
# Take item 0 of the test generator, which unpacks into three values, and
# display the shape of the second one.
fname, X, T = test_gen[0]
X.shape

In [None]:
# Run the trainer's evaluation on the test generator; it returns three values.
# NOTE(review): presumably the model is primed on 100 steps and then generates
# 200 more -- confirm against eval_model in lstm_trainer.
eval_output, hidden_states, cell_states = trainer.eval_model(test_gen, prime_len=100, gen_len=200)

In [None]:
# Convert the evaluation output to a float tensor, keep index 0 along axis 1,
# and display the resulting shape.
eo = torch.FloatTensor(eval_output)[:, 0]
eo.shape

In [None]:
# Pass the first entry along axis 1 of eo (sliced with :1, so the axis is kept)
# to the generator's inverse-STFT reassembly, which returns two values.
# NOTE(review): presumably t is a time axis and x the waveform -- confirm in util.
t,x = test_gen.reassemble_istft(eo[:, :1])

In [None]:
import matplotlib.pyplot as plt

# Sample rate in Hz; matches the sr=8000 printed in the dataloader cell's output.
fs = 8000

# Spectrogram of x, using the generator's window size and overlap for the FFT.
plt.specgram(x, Fs=fs, NFFT=window_size, noverlap=window_overlap)
plt.show()

In [None]:
# Import explicitly so this cell does not depend on `wavfile` leaking into the
# namespace through one of the star imports above.
from scipy.io import wavfile

# Write the signal x to 2.wav at sample rate fs.
wavfile.write("2.wav", fs, x)

In [None]:
# Convert the recorded states to arrays and keep only index 0 on axes 1 and 2.
# NOTE: both names are overwritten, so re-running this cell without first
# re-running the evaluation cell indexes the already-reduced arrays again.
cell_states = np.array(cell_states)[:, 0, 0]
hidden_states = np.array(hidden_states)[:, 0, 0]

In [None]:
# Show two grayscale images, each in its own 15x15 figure: first the transposed
# eo[:, 0], then the transposed cell states.
for image in (eo[:, 0].transpose(0, 1), cell_states.transpose()):
    plt.figure(figsize=(15, 15))
    plt.imshow(image, cmap='gray')
    plt.show()

# Last expression of the cell: displays the maximum of X.
X.max()