In [1]:
import os, json, librosa, torch
from utils import complexToPolar, saveSpectrogram, reconstructAudioFromBatches
from flask import Flask, render_template, request, redirect
from inference import inference
import numpy as np

In [2]:
# Pull the STFT settings out of the "stftParams" entry of config.json.
with open('config.json') as f:
    data = json.load(f)
stftParams = data['stftParams']
data = 0  # release the parsed config; only stftParams is needed from here on

# Load the input clip at a 16 kHz sample rate; the returned rate is discarded.
file = "LJ050-0084.wav"
audio, _ = librosa.load(file, sr=16000)

In [3]:
# Path to the trained model checkpoint. Set the UNET_CHECKPOINT environment
# variable to point at your own copy; the fallback is the original author's
# local file, which will not exist on other machines.
checkpoint = os.environ.get(
    "UNET_CHECKPOINT", "/Users/zombie/Downloads/colabBatchSize4_5000.pt"
)

In [4]:
def transform(X):
    """Min-max scale ``X`` so its smallest value maps to 0 and its largest to 1.

    The minimum and maximum are taken over the whole of ``X`` (no axis), so
    every element is scaled with the same offset and range.

    Args:
        X: Array-like object exposing ``.min()`` / ``.max()`` and elementwise
            arithmetic (e.g. a numpy array). Must be non-empty.

    Returns:
        An array of the same shape as ``X`` with values in [0, 1]. A constant
        input (max == min) yields all zeros instead of NaNs.
    """
    lo = X.min()
    span = X.max() - lo
    if span == 0:
        # Constant input: (X - lo) is already all zeros, so dividing by 1
        # avoids the 0/0 -> NaN that the plain formula would produce.
        span = 1
    return (X - lo) / span


def split(array):
    """Cut a 2-D array into square patches along its second axis.

    The first axis is truncated to the largest even size ``xdim``. The second
    axis is cut into consecutive ``xdim``-wide patches; the leftover columns
    form a final patch that is padded on the right with the array's minimum
    value. A final patch is always appended, so an input whose width is an
    exact multiple of ``xdim`` gets one extra patch made entirely of padding.
    All patches are then min-max normalised together via ``transform``.

    Args:
        array: numpy array of shape [x, y] with x >= 2 and y >= 1.

    Returns:
        numpy array of shape [y // xdim + 1, 1, xdim, xdim].

    Raises:
        ValueError: if ``array`` has fewer than 2 rows, so no square patch
            can be formed.
    """
    xdim = array.shape[0] // 2 * 2
    if xdim == 0:
        raise ValueError("split() needs an array with at least 2 rows")
    ydim = array.shape[1]
    nHead = ydim // xdim  # number of full-width patches
    ydimHead = nHead * xdim  # columns covered by the full patches
    ydimTail = ydim - ydimHead  # leftover columns for the last patch

    # Pad the last patch with the minimum value of the spectrogram.
    tailB = np.full((xdim, xdim), array.min())
    tailB[:, :ydimTail] = array[:xdim, ydimHead:]
    tailB = np.expand_dims(tailB, axis=0)

    if nHead:
        head = np.asarray(np.hsplit(array[:xdim, :ydimHead], nHead))
        batches = np.concatenate((head, tailB), axis=0)
    else:
        # Input is narrower than one patch: np.hsplit cannot split into zero
        # sections, so the padded tail is the only patch.
        batches = tailB

    batches = transform(batches)
    return np.expand_dims(batches, axis=1)

In [5]:
from unet import UNet

# Build the network and restore the trained weights onto the CPU.
net = UNet(1, 1)
ckp = torch.load(checkpoint, map_location='cpu')
net.load_state_dict(ckp['modelStateDict'])
net.eval()
ckp = 0  # drop the reference to the loaded checkpoint dict

# The inverse transform receives every STFT setting except n_fft.
istftParams = dict(stftParams)
istftParams.pop("n_fft", None)

# Spectrogram of the clip, keeping only the first 4000 frames.
stft = librosa.stft(audio, **stftParams)[:, :4000]
mag, phase = complexToPolar(stft)

# NOTE(review): split() min-max normalises its output, so phaseBatch holds
# rescaled values rather than raw phase, and mag/phase are each scaled by
# their own range. Confirm reconstructAudioFromBatches expects that.
magBatch, phaseBatch = split(mag), split(phase)

In [6]:
# Forward pass over all magnitude patches. Gradients are never used here, so
# autograd is switched off to avoid building and holding the graph in memory.
with torch.no_grad():
    newMag = net(torch.from_numpy(magBatch))

In [7]:
# Hand the first input magnitude patch and the matching network output to
# saveSpectrogram (presumably a before/after rendering — confirm in utils).
spec = saveSpectrogram(magBatch[0, 0], newMag[0, 0].detach())

# Rebuild a waveform from the predicted magnitudes, the phase patches and the
# inverse-STFT settings.
genAudio = reconstructAudioFromBatches(
    newMag.detach().numpy(),
    phaseBatch,
    istftParams,
)

In [10]:
# Target filename for the generated audio.
# NOTE(review): no visible cell writes genAudio to this path, and the execution
# count jumps from 7 to 10 — the save step may live in a cell not shown here.
genAudioFile="generated.wav"