In [None]:
import os
from scipy.io import wavfile
import pyworld
import pysptk
import IPython

import torch
import numpy as np
from IPython.display import Audio

from sklearn.linear_model import LinearRegression

import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn import svm
from sklearn.multioutput import MultiOutputRegressor

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F

from torch.utils.data import TensorDataset, DataLoader

In [None]:
# Analysis configuration shared by the feature-extraction and synthesis cells.
fs = 22050   # sampling rate used to derive the constants below
n_mfcc = 40  # passed as `order` to pysptk.sp2mc and as the MLP input width

# Constants derived from the sampling rate.
fftlen = pyworld.get_cheaptrick_fft_size(fs)
alpha = pysptk.util.mcepalpha(fs)

# Empty containers; nothing in the visible notebook fills them.
inputs, outputs = [], []

# Load Preprocessed Data

In [None]:
# Upper bound on the number of frames kept for training.
MAX_FRAMES = 10000

# NOTE(security): torch.load unpickles its input; only load files you trust.
# map_location="cpu" keeps the tensors on the CPU even if they were saved from
# a GPU, which is what the DataLoader's pin_memory=True requires.
X = torch.load('X_IVANKA', map_location='cpu').float()
Y = torch.load('Y_IVANKA', map_location='cpu').float()

# Drop column 0 of every frame (the test cell likewise only predicts columns 1:).
X, Y = X[:, 1:], Y[:, 1:]

X = X[:MAX_FRAMES]
Y = Y[:MAX_FRAMES]

In [None]:
# Release cached CUDA blocks left over from earlier runs of this kernel.
torch.cuda.empty_cache()

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Pair every input frame with its target frame.
dataset = TensorDataset(X, Y)

# Batches are served in storage order (no shuffling): the model's forward pass
# reshapes each batch into a single sequence for its LSTM.
loader = DataLoader(
    dataset=dataset,
    batch_size=5000,
    shuffle=False,
    pin_memory=True,
)

# Training

In [None]:
class MLP(nn.Module):
    """Per-frame projection followed by a 2-layer bidirectional LSTM.

    The input is a 2-D tensor of shape ``(frames, n_mfcc)``. Every frame is
    projected to 128 features and batch-normalised, then the whole batch is
    fed to the LSTM as ONE sequence (batch size 1). A final linear layer maps
    the 64 LSTM features back to ``n_mfcc`` values per frame.

    Args:
        n_mfcc: number of coefficients per input and output frame.
    """

    def __init__(self, n_mfcc=40):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(n_mfcc, 128),
            nn.BatchNorm1d(128),
        )

        self.rnn = nn.LSTM(128, 32, 2, batch_first=True, bidirectional=True)

        self.lin = nn.Linear(64, n_mfcc)

        # Fixed random initial LSTM states, shape (layers * directions, 1, hidden).
        # Registered as buffers so they follow the module through .to()/.cuda()
        # and are stored in state_dict() instead of being pinned to whatever
        # device happened to be selected when the model was constructed.
        self.register_buffer("h0", torch.randn(4, 1, 32))
        self.register_buffer("c0", torch.randn(4, 1, 32))

    def forward(self, x):
        """Map ``(frames, n_mfcc)`` inputs to ``(frames, n_mfcc)`` outputs."""
        x = self.layers(x)
        # Add a leading batch axis of size 1: (frames, 128) -> (1, frames, 128).
        x = x.unsqueeze(0)
        x, _ = self.rnn(x, (self.h0, self.c0))
        # Remove only that batch axis; a bare squeeze() would also collapse the
        # frame axis when a single frame is passed in.
        x = x.squeeze(0)
        return self.lin(x)
    
# Build the network for 40 coefficients per frame, then move it to the selected device.
mlp = MLP(n_mfcc=40)
mlp = mlp.to(device)

In [None]:
# Mean-squared error between predicted and target frames, optimised with Adam.
loss = nn.MSELoss()
optimizer = optim.Adam(mlp.parameters(), lr=0.01)

EPOCHS = 100

losses = []           # one Python float per optimisation step
validation_loss = []  # not filled by this loop

# Make sure BatchNorm is in training mode, even if an earlier cell called eval().
mlp.train()

t = tqdm(range(EPOCHS))
for i in t:
    for batch in loader:
        batchX, batchY = batch
        batchX, batchY = batchX.to(device), batchY.to(device)
        optimizer.zero_grad()
        out = mlp(batchX)
        l = loss(out, batchY)
        l.backward()
        optimizer.step()
        # Store a plain float: keeping the tensor itself would hold on to device
        # memory and could not be plotted by matplotlib afterwards.
        loss_value = l.item()
        losses.append(loss_value)
        t.set_postfix(LOSS=loss_value)

In [None]:
# Training loss per optimisation step. Each entry is converted with float() so
# the plot works whether `losses` holds Python floats or 0-d tensors (which
# matplotlib cannot convert when they require grad or live on the GPU).
plt.plot([float(step_loss) for step_loss in losses])
plt.xlabel('optimisation step')
plt.ylabel('MSE loss')

In [None]:
# Sanity check on training data: predict the first 2000 frames and compare
# column 1 of the prediction against the target for frames 1000-1499.
# Inference runs in eval mode (BatchNorm uses its running statistics and does
# not update them) and without autograd; the previous mode is restored afterwards.
was_training = mlp.training
mlp.eval()
with torch.no_grad():
    pred = mlp(X[:2000].to(device))
mlp.train(was_training)

pred = pred.cpu().numpy()
plt.plot(Y[1000:1500, 1])
plt.plot(pred[1000:1500, 1])
plt.legend(['target', 'prediction'])

In [None]:
def predict(mcc, model=None):
    """Run a model on a 2-D array of frames and return a NumPy array.

    Args:
        mcc: array-like of shape ``(frames, features)``; converted to float32.
        model: ``torch.nn.Module`` to evaluate. Defaults to the notebook-level
            ``mlp``.

    Returns:
        ``numpy.ndarray`` holding the model output, copied back to the CPU.

    The model is switched to eval mode for the call (so BatchNorm/Dropout
    behave deterministically and running statistics are not updated), autograd
    is disabled, and the model's previous train/eval mode is restored afterwards.
    """
    net = mlp if model is None else model

    # Run on whatever device the model's parameters live on.
    try:
        target_device = next(net.parameters()).device
    except StopIteration:
        # Parameter-free modules have no device preference.
        target_device = torch.device("cpu")

    frames = torch.Tensor(mcc).to(target_device)

    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            result = net(frames)
    finally:
        net.train(was_training)

    return result.cpu().numpy()

# Test

In [None]:
# Root of the parallel recordings; the source and target speakers' files live
# in sibling sub-directories and share a file name.
DATASET_PATH = "../../../../datasets/inflection/ivakna"
test_inp = f"{DATASET_PATH}/joanna/ivanka_3.wav"
test_out = f"{DATASET_PATH}/ivanka/ivanka_3.wav"

def get_from_codebook(row, codebook=None):
    """Return the codebook row closest to ``row`` in squared Euclidean distance.

    Args:
        row: 1-D array with one entry per codebook column.
        codebook: 2-D array of candidate rows. Defaults to the notebook-level
            ``mcc_output``, which must be defined before the first call.

    Returns:
        The nearest row of ``codebook`` (the first one if several tie).
    """
    if codebook is None:
        codebook = mcc_output
    squared_distances = np.sum((codebook - row) ** 2, axis=1)
    return codebook[np.argmin(squared_distances)]

def preprocess(data):
    """Analyse a waveform and return ``(mcc, f0, ap)``.

    The samples are cast to float64 and passed to ``pyworld.wav2world`` with
    the notebook-level ``fs``. The ``sp`` output is then converted with
    ``pysptk.sp2mc`` using the notebook-level ``n_mfcc`` (as ``order``) and
    ``alpha``; ``f0`` and ``ap`` are returned exactly as wav2world produced them.
    """
    samples = data.astype(np.float64)
    f0, sp, ap = pyworld.wav2world(samples, fs)
    return pysptk.sp2mc(sp, order=n_mfcc, alpha=alpha), f0, ap

def decode(mcc, f0, ap):
    """Synthesize a waveform from ``(mcc, f0, ap)`` - the inverse of ``preprocess``.

    ``mcc`` is cast to float64 and converted with ``pysptk.mc2sp`` using the
    notebook-level ``alpha`` and ``fftlen``; the result is handed to
    ``pyworld.synthesize`` together with ``f0``, ``ap`` and the notebook-level ``fs``.
    """
    spectrum = pysptk.mc2sp(mcc.astype(np.float64), alpha=alpha, fftlen=fftlen)
    return pyworld.synthesize(f0, spectrum, ap, fs)

# Pitch scaling applied to the source f0 when synthesising the converted voice.
F0_SCALE = 1.15

# Read both recordings into separate names so neither sample rate is silently lost.
fs_inp, data_inp = wavfile.read(test_inp)
fs_out, data_out = wavfile.read(test_out)
if fs_inp != fs_out:
    raise ValueError(
        f"sample rate mismatch: {test_inp} is {fs_inp} Hz but {test_out} is {fs_out} Hz"
    )
# preprocess() and decode() read the notebook-level `fs`, so publish the files' rate.
# NOTE(review): `alpha` and `fftlen` were derived from the configured rate in an
# earlier cell and are not recomputed here - confirm the files use that rate.
fs = fs_out

mcc, f0, ap = preprocess(data_inp)
mcc_output, f0_output, ap_output = preprocess(data_out)

# Column 0 is copied from the source unchanged; the model predicts the rest.
predicted_mcc = mcc.copy()
predicted_mcc[:, 1:] = predict(mcc[:, 1:])

# Snap every predicted frame to its nearest frame in the target recording.
pred2 = np.empty_like(predicted_mcc)
for i, frame in enumerate(tqdm(predicted_mcc)):
    pred2[i] = get_from_codebook(frame)

original = decode(mcc, f0, ap)
prediction = decode(predicted_mcc, f0 * F0_SCALE, ap)
prediction2 = decode(pred2, f0 * F0_SCALE, ap)
target = decode(mcc_output, f0_output, ap_output)

# Players appear in this order: source, target, prediction, codebook-snapped prediction.
for waveform in (original, target, prediction, prediction2):
    IPython.display.display(Audio(waveform, rate=fs))

In [None]:
# Index of the coefficient to inspect (reused by the next cell).
mfcc = 1

# First 500 frames: raw prediction versus its codebook-snapped version.
plt.plot(predicted_mcc[:500, mfcc], label='pred')
plt.plot(pred2[:500, mfcc], label='pred with codebook')
plt.legend()

In [None]:
# Same coefficient and frame range, taken from the target recording.
plt.plot(mcc_output[:500, mfcc], label='target')
plt.legend()