In [None]:
# Imports
from matplotlib import pyplot as plt
import librosa
import torch
import torch.nn as nn
from torchsummary import summary
import torchaudio
import soundfile as sf
import numpy as np
import sys
sys.path.insert(0, '/home/ubuntu/joanna/AudioMNIST-AE/src/')
from scipy.io import wavfile
from IPython.display import Audio
# my modules:
import helpers
import importlib
from datetime import datetime
importlib.reload(helpers)

In [None]:
# Round trip on a single file: wav -> power spectrogram -> wav
sig_orig, S, minmax = helpers.wav2powspec("test.wav")

# show the power spectrogram returned by the pre-processing step
helpers.plot_spectrogram(S, title=None, ylabel="freq_bin")

# Post-processing: invert the spectrogram back to a waveform, handing over
# the min/max values that the forward step returned
sig_recon = helpers.powspec2wave(
    S,
    orig_min=minmax["min"],
    orig_max=minmax["max"],
)

# overlay the original and the reconstructed waveform in a fresh figure
fig, ax = plt.subplots()
ax.plot(sig_orig, label='original')
ax.plot(sig_recon, label='Griffin-Lim reconstruction')
ax.legend()
plt.show()

# write the reconstruction to disk and play it back
recon_path = 'test_reconstructed.wav'
sf.write(recon_path, sig_recon.numpy().T, 22050, subtype='PCM_24')
Audio(recon_path)


In [None]:
# Import my module with data loader which uses these preprocessing steps
import torchdataset_prep as dsprep
importlib.reload(dsprep)

# Parameters of data loader
AUDIO_PATH = "/home/ubuntu/Data/AudioMNIST/data"
SAMPLE_RATE = 22050
SIG_LEN=1
SNR=100
N_SPK=60

# Build the dataset from the parameters above
dataset = dsprep.AudioMnistPowSpec(AUDIO_PATH, SAMPLE_RATE, SIG_LEN,N_SPK,SNR)

# Pick a random data point. np.random.randint excludes the upper bound, so
# (0, len(dataset)) covers every valid index, including the first element
# (the previous lower bound of 1 skipped it and failed on a 1-item dataset).
random_idx=np.random.randint(0,len(dataset))
data, label = dataset[random_idx]

# plot the selected data point
helpers.plot_datapoint(data,label)

# Reconstruct audio from the spectrogram and listen to it.
# NOTE(review): orig_min=0 / orig_max=5 are presumably the normalisation range
# used by the dataset - confirm against torchdataset_prep.
sig_recon=helpers.powspec2wave(data,orig_min=0,orig_max=5)
# write with the same sample rate the dataset was built with
sf.write('reconstructed.wav', sig_recon.numpy().T, SAMPLE_RATE, subtype='PCM_24')
Audio('reconstructed.wav')

In [None]:
# Import my module with model definition and training procedure
import os
import training as TR
importlib.reload(TR)

# Choose computing device: Apple MPS if present, otherwise CUDA, otherwise CPU.
# The CPU fallback keeps the cell runnable on machines without an accelerator
# (previously "cuda:0" was selected unconditionally and model.to() crashed).
mps_backend = getattr(torch.backends, "mps", None)  # absent in older torch builds
if mps_backend is not None and mps_backend.is_available():
    print("Using M1")
    device = torch.device("mps")
elif torch.cuda.is_available():
    print("Using Cuda")
    device = torch.device('cuda:0')
else:
    print("Using CPU")
    device = torch.device("cpu")

# split dataset into training set (80%), test set (10%) and validation set (10%)
N_train = round(len(dataset) * 0.8)
N_rest = len(dataset) - N_train
trainset, restset = torch.utils.data.random_split(dataset, [N_train, N_rest])
N_test = round(len(restset) * 0.5)
N_val = len(restset) - N_test
testset, valset = torch.utils.data.random_split(restset, [N_test, N_val])

# create dataloaders
# NOTE(review): shuffle=True on the validation/test loaders is kept as in the
# original; evaluation usually does not need shuffling - confirm intent.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True, num_workers=6)
valloader = torch.utils.data.DataLoader(valset, batch_size=256, shuffle=True, num_workers=6)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=True, num_workers=6)

# instantiate a model and move it to the selected device
model=TR.AutoencoderConv()
model.to(device)

# Make sure the output directory exists *before* the long training run, so a
# missing folder cannot cost us the trained weights at save time.
MODEL_DIR = "../models"
os.makedirs(MODEL_DIR, exist_ok=True)

# training
N_EPOCHS=30
outputs=TR.training(model, trainloader, N_EPOCHS, device,store_outputs=True)

# save model weights under a timestamped file name
now=datetime.now(); dt_string = now.strftime("%d-%m-%Y--%H-%M")
torch.save(model.state_dict(), os.path.join(MODEL_DIR, "trained_model_"+dt_string+".pth"))


In [None]:
# test = print the reconstruction
# For each stored training output, draw the first n_examples originals (top
# row) above their reconstructions (bottom row).
n_examples=5

# Show at most the first 9 stored entries, but never more than were actually
# stored - a shorter training run would otherwise raise an IndexError.
n_shown = min(9, len(outputs))

for i in range(n_shown):
    fig, axs = plt.subplots(2,n_examples)
    # entries 1, 2 and 3 of a stored output hold the input batch, the
    # reconstructed batch and the labels, as used below
    x_orig=outputs[i][1]
    x_recon=outputs[i][2]
    x_label=outputs[i][3]
    for k in range(0,n_examples):
        # reshape each sample to a 513 x 44 image and move it to the CPU for plotting
        x_orig_plot=torch.squeeze(x_orig[k].reshape(-1,513,44).detach()).cpu()
        x_recon_plot=torch.squeeze(x_recon[k].reshape(-1,513,44).detach()).cpu()
        x_label_plot=torch.squeeze(x_label[k].detach()).cpu()

        axs[0,k].set_title("orig "+ str(x_label_plot))
        axs[0,k].imshow(librosa.power_to_db(x_orig_plot), origin="lower", aspect="auto")

        axs[1,k].set_title("recon "+ str(x_label_plot))
        axs[1,k].imshow(librosa.power_to_db(x_recon_plot), origin="lower", aspect="auto")

    plt.show(block=False)

In [None]:
# Check how good is the reconstruction
# Use the last stored training output instead of a hard-coded index, so this
# cell keeps pointing at the final epoch when N_EPOCHS changes.
FINAL_EPOCH=len(outputs)-1
x_recon=outputs[FINAL_EPOCH][2]
x_label=outputs[FINAL_EPOCH][3]

# Pick a random sample from the stored batch. The upper bound of randint is
# exclusive, so (0, len(x_recon)) covers every sample regardless of batch size.
rand_idx=np.random.randint(0,len(x_recon))
x_recon_play=torch.squeeze(x_recon[rand_idx].reshape(-1,513,44).detach()).cpu()
x_label_play=torch.squeeze(x_label[rand_idx].detach()).cpu()

# Convert the reconstructed spectrogram back to audio and write it to disk.
# NOTE(review): orig_min=0 / orig_max=5 presumably match the dataset's
# normalisation range - confirm against the data loader.
sig_recon=helpers.powspec2wave(x_recon_play,orig_min=0,orig_max=5)
sf.write('reconstructed.wav', sig_recon.numpy().T, SAMPLE_RATE, subtype='PCM_24')


print("Original label: "+ str(x_label_play.numpy()))
Audio('reconstructed.wav')

In [None]:
# Sanity-check the layer output sizes with helpers.compute_cnn_out.
# Each row holds the four positional arguments of one call, in call order.
# NOTE(review): the argument layout (input shape, kernel shape, padding,
# stride) is inferred from the values - confirm against helpers.
layer_specs = (
    ([1, 513, 44, 1], [16, 3, 3, 1], [1, 1], [2, 2]),
    ([1, 257.0, 22.0, 16], [32, 3, 3, 16], [1, 1], [2, 2]),
    ([1, 129.0, 11.0, 32], [64, 129, 11, 32], [0, 0], [1, 1]),
)

layer_outputs = []
for spec in layer_specs:
    out_shape = helpers.compute_cnn_out(*spec)
    print(out_shape)
    layer_outputs.append(out_shape)

# keep the per-layer results available under their original names
O1, O2, O3 = layer_outputs