In [1]:
!pip install torch==1.9.0
!pip install librosa==0.8.1
!pip install soundfile==0.10.2
!pip install bokeh==2.3.0

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torch==1.9.0
  Downloading torch-1.9.0-cp37-cp37m-manylinux1_x86_64.whl (831.4 MB)
[K     |████████████████████████████████| 831.4 MB 2.2 kB/s 
Installing collected packages: torch
  Attempting uninstall: torch
    Found existing installation: torch 1.12.1+cu113
    Uninstalling torch-1.12.1+cu113:
      Successfully uninstalled torch-1.12.1+cu113
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchvision 0.13.1+cu113 requires torch==1.12.1, but you have torch 1.9.0 which is incompatible.
torchtext 0.13.1 requires torch==1.12.1, but you have torch 1.9.0 which is incompatible.
torchaudio 0.12.1+cu113 requires torch==1.12.1, but you have torch 1.9.0 which is incompatible.[0m
Successfully installed torch-1.9.0
Looking in indexes: https://pypi.org/s

In [2]:
import torch
import librosa
import soundfile as sf
import torch.nn as nn
import numpy as np
from torch.nn.utils import weight_norm
from torch import optim
from math import ceil
import glob
import time
import random
import os

In [3]:
# Connect Colab to your Google Drive (mounted here at /gdrive).
# NOTE(review): the copy step in the input-preparation cell reads from
# /content/drive, not /gdrive, so this mount point looks unused - confirm
# before removing it.
from google.colab import drive
drive.mount('/gdrive')

Mounted at /gdrive


In [4]:
# Mount Google Drive at /content/drive; the input-preparation cell copies the
# wav file from /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
#prepare input folder
input_folder='inputs'
if not os.path.exists(input_folder):
    os.mkdir(input_folder)
#copy file from drive to colab --remember it has to be from mydrive---for whatever reason it does not go any deeper!
!cp /content/drive/MyDrive/S11_ucm_1_mono.wav /content/inputs

In [6]:
# Parameters: module-level settings read by the functions and models below.
# presumably hole boundaries used in 'inpainting' mode - TODO confirm (not used in the visible code)
inpainting_indices= [0, 1]
# NOTE(review): is_cuda and gpu_num are not referenced by the visible code;
# `device` below always selects "cuda:0" when CUDA is available.
is_cuda = torch.cuda.is_available()
gpu_num = 0
# -1 disables seeding; any other value is used by train() to seed `random` and torch.
manual_random_seed = -1
# Name of the wav file to train on (the copy step above places it in the inputs folder).
input_file = 'S11_ucm_1_mono.wav'
# get_input_signal() only loads the file when this list is empty.
segments_to_train = []
# Passed to librosa.load as `offset` by get_input_signal().
start_time = 0
# Signals with a higher native rate are resampled down to this rate by get_input_signal().
init_sample_rate =  16000
# presumably the sampling rates [Hz] of the training scales - TODO confirm against the caller
fs_list = [320, 400, 500, 640, 800, 1000, 1280, 1600, 2000, 2500, 4000, 8000, 10000, 12000, 14400, 16000]
# presumably passed as get_input_signal()'s max_length (maximum duration in seconds) - TODO confirm
max_length = 25
run_mode = 'normal' # one of ['normal', 'inpainting', 'denoising']
# Optimisation settings; consumed by training code outside the visible part of the notebook.
num_epochs = 3500
learning_rate = 0.0015
scheduler_lr_decay = 0.1
beta1 = 0.5
# When True, create_input_signals() skips scales below 500 Hz and bypasses the energy-based scale selection.
speech = False
# Generator builds a head block, num_layers - 2 body blocks and a tail convolution.
num_layers = 8
# create_input_signals() writes the downsampled 'real@<fs>Hz.wav' files under this folder.
output_folder = 'outputs'
# Kernel size of the 1-D convolutions.
filter_size = 9
# When enabled (and speech is False), coarse scales whose mean energy is below
# min_energy_th are skipped until the first scale that passes.
set_first_scale_by_energy = True
min_energy_th = 0.0025
# presumably the channel count of the first scale and its growth factor - TODO confirm
hidden_channels_init = 16
growing_hidden_channels_factor = 6
plot_losses = False
# Device used by the models (e.g. passed to PreEmphasisFilter inside Generator).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# presumably control the per-scale noise amplitude - TODO confirm in train_single_scale
initial_noise_amp = 1
noise_amp_factor = 0.01

In [7]:
#functions
def get_input_signal(input_file, max_length):
    """Load, crop, resample and peak-normalise the training signal.

    Reads the module-level settings ``segments_to_train``, ``start_time`` and
    ``init_sample_rate``.

    Args:
        input_file: File name inside the ``inputs`` folder. ``.wav`` is
            appended when the name contains no dot.
        max_length: Maximum duration to keep, in seconds.

    Returns:
        Tuple ``(samples, Fs)``: the samples scaled so that the peak absolute
        value is 1 (left untouched for an all-zero or empty signal), and the
        sampling rate after optional downsampling to ``init_sample_rate``.

    Raises:
        NotImplementedError: if ``segments_to_train`` is non-empty. The
            original code left ``samples`` undefined in that case and failed
            later with a NameError.
    """
    if '.' not in input_file:
        input_file = '.'.join([input_file, 'wav'])

    if len(segments_to_train) != 0:
        raise NotImplementedError(
            'Loading specific segments (segments_to_train) is not supported by get_input_signal')

    # Load at the native rate; at most twice max_length is read and then cropped below.
    samples, Fs = librosa.load(os.path.join('inputs', input_file), sr=None,
                               offset=start_time, duration=2 * max_length)

    if samples.shape[0] / Fs > max_length:
        samples = samples[:int(max_length * Fs)]

    if init_sample_rate < Fs:
        # Keyword arguments keep the call valid on librosa versions where
        # orig_sr/target_sr are keyword-only.
        samples = librosa.resample(samples, orig_sr=Fs, target_sr=init_sample_rate)
        Fs = init_sample_rate

    # Peak-normalise; skip the division for silent/empty input to avoid NaNs.
    norm_factor = np.abs(samples).max() if samples.size else 0
    if norm_factor > 0:
        samples = samples / norm_factor
    return samples, Fs

def create_input_signals(scales, set_first_scale_by_energy, min_energy_th,  filter_size, input_signal, Fs):
    """Downsample the input signal to every requested scale.

    Reads the module-level settings ``speech`` and ``output_folder``.

    Args:
        scales: Downsampling factors; a factor of 1 keeps ``input_signal`` as is.
        set_first_scale_by_energy: When true (and ``speech`` is false), leading
            scales whose mean energy is below ``min_energy_th`` are skipped.
        min_energy_th: Mean-energy threshold for the first accepted scale.
        filter_size: Unused; kept for interface compatibility.
        input_signal: CPU tensor holding the full-rate signal (it is squeezed
            and converted with ``.numpy()`` before resampling).
        Fs: Sampling rate of ``input_signal``.

    Returns:
        Tuple ``(signals_list, fs_list)`` with the kept signals and their
        integer sampling rates. Each kept signal is also written to
        ``<output_folder>/real@<fs>Hz.wav``.
    """
    signals_list = []
    fs_list = []
    set_first_scale = False
    for downsample in scales:
        # The rate is truncated to an integer here (so the former
        # "sampling rate is not integer" assert could never fire).
        fs = int(Fs / downsample)
        if speech and fs < 500:
            # Decided before resampling so no work is wasted on skipped scales.
            continue
        if downsample == 1:
            coarse_sig = input_signal
        else:
            coarse_sig = torch.Tensor(
                librosa.resample(input_signal.squeeze().numpy(), orig_sr=Fs, target_sr=fs))
        if set_first_scale_by_energy and not speech and not set_first_scale:
            e = (coarse_sig ** 2).mean()
            if e < min_energy_th:
                continue
        set_first_scale = True
        signals_list.append(coarse_sig)
        fs_list.append(fs)

        # Write downsampled real sound
        filename = 'real@%dHz.wav' % fs
        write_signal(os.path.join(output_folder, filename), coarse_sig.cpu(), fs)

    return signals_list, fs_list

def calc_receptive_field(filter_size, dilation_factors, Fs=None):
    """Return the receptive field of a stack of dilated convolutions.

    The first layer contributes ``filter_size * dilation_factors[0]`` and each
    later layer ``dilation * (filter_size - 1)``. The result is in samples when
    ``Fs`` is None, otherwise in milliseconds at sampling rate ``Fs``.
    """
    span_samples = filter_size * dilation_factors[0] + sum(dilation_factors[1:]) * (filter_size - 1)
    if Fs is not None:
        # convert samples -> [ms]
        return span_samples / Fs * 1e3
    return span_samples

def write_signal(path, signal, fs, overwrite=False, subtype='PCM_16'):
    """Write a signal to a wav file, avoiding clipping and accidental overwrites.

    Args:
        path: Target path; ``.wav`` is appended when missing.
        signal: Tensor or array-like of samples. ``None`` is a no-op.
        fs: Sampling rate passed to ``soundfile.write``.
        overwrite: When false and ``path`` exists, a ``_<k>`` suffix is added,
            where ``k`` is the number of existing files sharing the stem.
        subtype: soundfile subtype of the written file.
    """
    if signal is None:
        return
    if torch.is_tensor(signal):
        signal = signal.squeeze().detach().cpu().numpy()
    signal = np.asarray(signal)
    if not path.endswith('.wav'):
        path = path + '.wav'
    if not overwrite and os.path.exists(path):
        stem, ext = path[:-4], path[-4:]
        # glob.escape treats every glob metacharacter in the stem literally
        # (the previous code only handled the literal '[Hz]').
        files = glob.glob(glob.escape(stem) + '*')
        path = stem + '_' + str(len(files)) + ext
    # Vectorised peak search; an empty signal has peak 0 instead of raising.
    max_amp = np.abs(signal).max() if signal.size else 0.0
    if max_amp > 1:
        signal = signal / max_amp  # normalize to avoid clipping
    sf.write(path, signal, fs, subtype=subtype)

def calc_pad_size(dilation_factors, filter_size):
    """Return the one-sided padding: half of ``sum(dilations) * (filter_size - 1)``, rounded up."""
    total_dilation = sum(dilation_factors)
    half_span = total_dilation * (filter_size - 1) / 2
    return int(np.ceil(half_span))

def get_noise(device, shape):
    """Draw standard-normal noise of the given shape directly on ``device``."""
    noise = torch.randn(shape, device=device)
    return noise

def draw_signal(generators_list, signals_lengths_list, fs_list, noise_amp_list, filter_size, dilation_factors, device, reconstruction_noise_list=None,
                condition=None, output_all_scales=False):
    """Draw a signal up to the current scale using the learned generators.

    Args:
        generators_list: Generators, coarsest scale first.
        signals_lengths_list: Signal length (in samples) for each scale.
        fs_list: Sampling rate of each scale.
        noise_amp_list: Noise amplitude for each generator.
        filter_size: Convolution kernel size (used for the padding size).
        dilation_factors: Dilations of the generator layers (used for the padding size).
        device: Device on which noise is drawn.
        reconstruction_noise_list: Optional fixed, already padded noise per
            scale; when given it replaces the random noise.
        condition: Optional dict with keys ``condition_scale_idx``,
            ``condition_signal`` and ``condition_fs``; generation then starts
            from that signal at that scale and coarser scales are skipped.
        output_all_scales: When true, return a list with the numpy output of
            every scale instead of the final tensor.

    Returns:
        The signal after the last generator, upsampled to the next scale when
        ``fs_list`` has one, or the list of per-scale outputs.

    Raises:
        AssertionError: if the upsampled length differs from the expected
            length of the next scale by 20 samples or more.
    """
    pad_size = calc_pad_size(dilation_factors, filter_size)
    # The padder does not depend on the scale, so build it once.
    signal_padder = nn.ConstantPad1d(pad_size, 0)
    signals_all_scales = []
    for scale_idx, (netG, noise_amp) in enumerate(zip(generators_list, noise_amp_list)):
        if condition is None:
            n_samples = signals_lengths_list[scale_idx]
            if reconstruction_noise_list is not None:
                noise_signal = reconstruction_noise_list[scale_idx]
            else:
                noise_signal = get_noise(device, (1, 1, n_samples))
                noise_signal = noise_signal * noise_amp

            if scale_idx == 0:
                prev_sig = torch.full(noise_signal.shape, 0, device=device, dtype=noise_signal.dtype)
            else:
                prev_sig = signal_padder(prev_sig)

            # pad noise with zeros, to match signal after filtering
            if reconstruction_noise_list is None:
                # reconstruction_noise is already padded
                noise_signal = signal_padder(noise_signal)
                if scale_idx == 0:
                    prev_sig = signal_padder(prev_sig)
        else:
            if scale_idx < condition["condition_scale_idx"]:
                continue
            elif scale_idx == condition["condition_scale_idx"]:
                prev_sig = resample_sig(device, condition["condition_signal"], condition['condition_fs'],
                                        fs_list[scale_idx]).expand(1, 1, -1)
            noise_signal = get_noise(device, prev_sig.shape[2]).expand(1, 1, -1)
            noise_signal = signal_padder(noise_signal)
            noise_signal = noise_signal * noise_amp
            prev_sig = signal_padder(prev_sig)

        # Generate this scale signal
        cur_sig = netG((noise_signal + prev_sig).detach(), prev_sig)

        if output_all_scales:
            signals_all_scales.append(torch.squeeze(cur_sig).detach().cpu().numpy())

        # Upsample for next scale
        if scale_idx < len(fs_list) - 1:
            up_sig = resample_sig( device, cur_sig, orig_fs=fs_list[scale_idx], target_fs=fs_list[scale_idx + 1])
            target_len = signals_lengths_list[scale_idx + 1]
            len_diff = up_sig.shape[2] - target_len
            # The original asserts compared with '>' / '<' instead of subtracting,
            # so abs(bool) < 20 was always true; check the real length difference.
            assert abs(len_diff) < 20, 'Should not happen, check this!'
            if len_diff > 0:
                up_sig = up_sig[:, :, :target_len]
            elif len_diff < 0:
                up_sig = torch.cat((up_sig, up_sig.new_zeros(1, 1, -len_diff)), dim=2)
        else:
            up_sig = cur_sig
        prev_sig = up_sig.detach()

        del up_sig, cur_sig, noise_signal, netG

    if output_all_scales:
        return signals_all_scales
    else:
        return prev_sig

def resample_sig(device,input_signal, orig_fs=None, target_fs=None, resamplers=None):
    """Resample ``input_signal`` along its last axis from ``orig_fs`` to ``target_fs``.

    Args:
        device: Device on which the resampling weights are created.
        input_signal: 3-D tensor; only the last dimension is rescaled.
        orig_fs: Sampling rate of ``input_signal``.
        target_fs: Desired sampling rate.
        resamplers: Optional dict cache keyed by ``(orig_fs, target_fs)``. A
            cached ``ResizeLayer`` is reused only when its input length matches;
            otherwise a new layer is built and stored. Without a dict nothing
            is cached between calls.

    Returns:
        The resampled tensor.
    """
    if resamplers is None:
        resamplers = {}
    key = (orig_fs, target_fs)
    resampler = resamplers.get(key)
    if resampler is None or resampler.in_shape[2] != input_signal.shape[2]:
        resampler = ResizeLayer(input_signal.shape,
                                scale_factors=(1, 1, target_fs / orig_fs),
                                device=device)
        resamplers[key] = resampler
    return resampler(input_signal)

def support_sz(sz):
    """Decorator factory that tags a function with a ``support_sz`` attribute equal to ``sz``."""
    def attach_support(func):
        setattr(func, 'support_sz', sz)
        return func
    return attach_support

@support_sz(4)
def cubic(x):
    """Piecewise-cubic interpolation kernel evaluated at the distances ``x``.

    One polynomial is used for ``|x| <= 1``, another for ``1 < |x| <= 2``, and
    the result is zero elsewhere.
    """
    fw, to_dtype, _ = set_framework_dependencies(x)
    dist = fw.abs(x)
    dist_sq = dist ** 2
    dist_cu = dist ** 3
    inner_poly = 1.5 * dist_cu - 2.5 * dist_sq + 1.
    inner_mask = to_dtype(dist <= 1.)
    outer_poly = -0.5 * dist_cu + 2.5 * dist_sq - 4. * dist + 2.
    outer_mask = to_dtype((1. < dist) & (dist <= 2.))
    return inner_poly * inner_mask + outer_poly * outer_mask

class ResizeLayer(nn.Module):
    """Torch layer that resizes tensors of a fixed input shape.

    The per-dimension neighbour indices and interpolation weights are computed
    once at construction and stored as non-trainable parameters; ``forward``
    only applies them.
    """

    def __init__(self, in_shape, scale_factors=None, out_shape=None,
                 interp_method=cubic, support_sz=None,
                 antialiasing=True, device=None):
        super(ResizeLayer, self).__init__()

        # this is a torch layer, so the framework is always torch
        fw = torch
        eps = fw.finfo(fw.float32).eps

        # fill in whichever of scale factors / output shape is missing
        scale_factors, out_shape = set_scale_and_out_sz(in_shape, out_shape,
                                                        scale_factors, fw)

        # the support size defaults to the one attached to the kernel
        if support_sz is None:
            support_sz = interp_method.support_sz

        self.n_dims = len(in_shape)

        # dims ordered by ascending scale factor, dropping those left unchanged
        dims_by_scale = sorted(range(self.n_dims), key=lambda d: scale_factors[d])
        self.sorted_filtered_dims_and_scales = [
            (d, scale_factors[d]) for d in dims_by_scale
            if scale_factors[d] != 1.]

        fov_params = []
        weight_params = []
        for dim, scale_factor in self.sorted_filtered_dims_and_scales:
            # 1-D neighbour indices and weights for every output location of this dim
            field_of_view, weights = prepare_weights_and_field_of_view_1d(
                dim, scale_factor, in_shape[dim], out_shape[dim],
                interp_method, support_sz, antialiasing, fw, eps, device)

            weight_params.append(nn.Parameter(weights, requires_grad=False))
            fov_params.append(nn.Parameter(field_of_view, requires_grad=False))

        self.field_of_view = nn.ParameterList(fov_params)
        self.weights = nn.ParameterList(weight_params)
        self.in_shape = in_shape

    def forward(self, input):
        """Resize ``input`` one dimension at a time using the stored weights."""
        resized = input
        per_dim = zip(self.sorted_filtered_dims_and_scales,
                      self.field_of_view,
                      self.weights)
        for (dim, _), field_of_view, weights in per_dim:
            # weight the neighbours of every output location and sum them
            resized = apply_weights(resized, field_of_view, weights, dim,
                                    self.n_dims, torch)
        return resized

def prepare_weights_and_field_of_view_1d(dim, scale_factor, in_sz, out_sz,
                                         interp_method, support_sz, 
                                         antialiasing, fw, eps, device=None):
    """Compute the 1-D neighbour indices and weights for one resized dimension.

    Returns ``(field_of_view, weights)``: for every output location, the input
    indices that influence it and the matching normalised weights.
    """
    # With antialiasing, both the kernel and its support may be stretched.
    kernel, kernel_support = apply_antialiasing_if_needed(
        interp_method, support_sz, scale_factor, antialiasing)

    # Step 1: non-integer positions of the output locations on the input axis.
    grid = get_projected_grid(in_sz, out_sz, scale_factor, fw, device)

    # Step 2: the input indices that influence each output location.
    neighbours = get_field_of_view(grid, kernel_support, in_sz, fw, eps)

    # Step 3: a weight for every neighbour of every output location.
    neighbour_weights = get_weights(kernel, grid, neighbours)

    return neighbours, neighbour_weights

def apply_weights(input, field_of_view, weights, dim, n_dims, fw):
    """Resize ``input`` along ``dim`` using precomputed neighbours and weights.

    ``field_of_view`` and ``weights`` are order-2: for each output location
    along the resized dim, the input indices feeding it and their weights.
    """
    # Bring the resized dim to the front; it is swapped back at the end.
    lead = fw_swapaxes(input, dim, 0, fw)

    # Gather the neighbours of every output location along the leading dim.
    gathered = lead[field_of_view]

    # Append singleton axes so the weights broadcast over the remaining dims.
    bcast_shape = (*weights.shape, *([1] * (n_dims - 1)))
    bcast_weights = fw.reshape(weights, bcast_shape)

    # Weighted sum over the field of view gives one value per output location.
    reduced = (gathered * bcast_weights).sum(1)

    return fw_swapaxes(reduced, 0, dim, fw)

def get_weights(interp_method, projected_grid, field_of_view):
    """Evaluate the kernel on signed grid-to-neighbour distances and normalise.

    Each row (one output location) is divided by its sum so the weights add up
    to 1; rows whose sum is exactly 0 are left as they are.
    """
    distances = projected_grid[:, None] - field_of_view
    raw_weights = interp_method(distances)

    row_totals = raw_weights.sum(1, keepdims=True)
    row_totals[row_totals == 0] = 1  # avoid dividing by zero
    return raw_weights / row_totals

def fw_ceil(x, fw):
    """Round ``x`` up element-wise and convert the result to a long tensor."""
    rounded_up = x.ceil()
    return rounded_up.long()


def fw_cat(x, fw):
    """Concatenate the sequence of tensors ``x`` using the framework's ``cat``."""
    joined = fw.cat(x)
    return joined


def fw_swapaxes(x, ax_1, ax_2, fw):
    """Return ``x`` with axes ``ax_1`` and ``ax_2`` exchanged."""
    swapped = x.transpose(ax_1, ax_2)
    return swapped
    
def fw_set_device(x, device, fw):
    """Return ``x`` moved to ``device``."""
    moved = x.to(device)
    return moved

def set_scale_and_out_sz(in_shape, out_shape, scale_factors, fw):
    """Complete scale factors and output shape so both cover every input dim.

    Partial arguments are padded according to the framework convention: for
    numpy (``fw is np``) the given values apply to the FIRST dims, otherwise
    (torch) to the LAST dims. The previous code built a partial ``out_shape``
    with its dims in the wrong order and always padded ``scale_factors`` the
    numpy way; full-length arguments behave exactly as before.

    Args:
        in_shape: Shape of the input.
        out_shape: Desired output shape (possibly partial) or None.
        scale_factors: A number, or a list/tuple of per-dim factors (possibly
            partial), or None. A single number is applied to two dims.
        fw: The framework module (``np`` or ``torch``).

    Returns:
        Tuple ``(scale_factors, out_shape)``: a list of floats and a list of
        sizes, one entry per input dim.

    Raises:
        ValueError: if both ``scale_factors`` and ``out_shape`` are None.
    """
    if scale_factors is None and out_shape is None:
        raise ValueError("either scale_factors or out_shape should be "
                         "provided")

    n_in = len(in_shape)
    leading_dims_first = fw is np

    if out_shape is not None:
        out_shape = list(out_shape)
        n_given = len(out_shape)
        if leading_dims_first:
            out_shape = out_shape + list(in_shape[n_given:])
        else:
            out_shape = list(in_shape[:n_in - n_given]) + out_shape
        if scale_factors is None:
            # no scale given: use the out-to-in ratio (not recommended)
            scale_factors = [out_sz / in_sz for out_sz, in_sz
                             in zip(out_shape, in_shape)]

    if scale_factors is not None:
        # a single number means resizing two dims (the common image case)
        if not isinstance(scale_factors, (list, tuple)):
            scale_factors = [scale_factors, scale_factors]
        untouched = [1] * (n_in - len(scale_factors))
        if leading_dims_first:
            scale_factors = list(scale_factors) + untouched
        else:
            scale_factors = untouched + list(scale_factors)
        if out_shape is None:
            # no out_shape given: scale the input shape (not recommended)
            out_shape = [ceil(scale * in_sz)
                         for scale, in_sz in zip(scale_factors, in_shape)]
        # converted to float only after out_shape is determined, for stability
        scale_factors = [float(scale) for scale in scale_factors]
    return scale_factors, out_shape

def apply_antialiasing_if_needed(interp_method, support_sz, scale_factor,
                                 antialiasing):
    """Stretch the kernel and its support for antialiased downscaling.

    When downscaling (``scale_factor`` below 1) with antialiasing enabled, the
    returned kernel is ``scale_factor * interp_method(scale_factor * arg)`` and
    the support is divided by ``scale_factor`` (a low-pass filter). Otherwise
    the kernel and support are returned unchanged.
    """
    needs_lowpass = (not scale_factor >= 1.0) and antialiasing
    if needs_lowpass:
        def stretched_kernel(arg):
            return scale_factor * interp_method(scale_factor * arg)

        return stretched_kernel, support_sz / scale_factor
    return interp_method, support_sz

def get_projected_grid(in_sz, out_sz, scale_factor, fw, device=None):
    """Project the integer output locations onto the input axis.

    Returns, for each output index, its (generally non-integer) position in
    input coordinates, on ``device``. The formula follows "From Discrete to
    Continuous Convolutions" by Shocher et al.
    """
    # integer output coordinates, placed on the requested device
    out_idx = fw_set_device(fw.arange(out_sz), device, fw)

    in_center = (in_sz - 1) / 2
    out_center_scaled = (out_sz - 1) / (2 * scale_factor)
    return out_idx / scale_factor + in_center - out_center_scaled


def get_field_of_view(projected_grid, cur_support_sz, in_sz, fw, eps):
    """Map every output location to the input indices that influence it (1-D).

    Args:
        projected_grid: Output locations projected onto the input axis.
        cur_support_sz: Window size of the (possibly stretched) kernel.
        in_sz: Input size along this dim.
        fw: Framework module (torch).
        eps: Small tolerance used when a window boundary is an exact integer.

    Returns:
        Long tensor of shape (n_outputs, window) on ``projected_grid.device``.
        Out-of-range indices are folded back into the input as if it were
        mirror-padded.
    """
    # leftmost neighbour: half a window to the left of each projected location
    left_boundaries = fw_ceil(projected_grid - cur_support_sz / 2 - eps, fw)

    # count window-size pixel centres starting from the left boundary
    ordinal_numbers = fw.arange(ceil(cur_support_sz - eps))
    ordinal_numbers = fw_set_device(ordinal_numbers, projected_grid.device, fw)
    field_of_view = left_boundaries[:, None] + ordinal_numbers

    # Mirror-padding trick: a lookup table [0..in_sz-1, in_sz-1..0] indexed
    # modulo its length, so the input never has to be enlarged.
    mirror = fw_cat((fw.arange(in_sz), fw.arange(in_sz - 1, -1, step=-1)), fw)
    # Keep the table on the same device as the indices: indexing a CPU tensor
    # with an index tensor on another device is not reliably supported across
    # PyTorch versions.
    mirror = fw_set_device(mirror, projected_grid.device, fw)
    field_of_view = mirror[fw.remainder(field_of_view, mirror.shape[0])]
    field_of_view = fw_set_device(field_of_view, projected_grid.device, fw)
    return field_of_view

def set_framework_dependencies(x):
    """Pick the framework matching ``x``.

    Returns ``(fw, to_dtype, eps)``: ``np`` with an identity cast for numpy
    arrays, otherwise ``torch`` with a cast to ``x``'s dtype, plus the float32
    machine epsilon of that framework.
    """
    if type(x) is np.ndarray:
        fw = np

        def to_dtype(a):
            return a
    else:
        fw = torch

        def to_dtype(a):
            return a.to(x.dtype)

    eps = fw.finfo(fw.float32).eps
    return fw, to_dtype, eps

def calc_gradient_penalty(run_mode, current_holes, netD, real_data, fake_data, LAMBDA, alpha=None, _grad_outputs=None, mask_ratio=None, not_valid_idx_start=None, not_valid_idx_end=None):
    """Gradient penalty for WGAN, evaluated between real and fake data.

    Args:
        run_mode, current_holes, not_valid_idx_start, not_valid_idx_end:
            Unused; kept for interface compatibility.
        netD: Discriminator, called as ``netD(interpolates, use_mask)``.
        real_data: Real samples.
        fake_data: Generated samples, same shape as ``real_data``.
        LAMBDA: Weight of the penalty.
        alpha: Optional interpolation coefficients; a single random coefficient
            shared by all elements is drawn when None.
        _grad_outputs: Optional ``grad_outputs`` for ``autograd.grad``; ones
            shaped like the discriminator output when None.
        mask_ratio: Ignored - the penalty always uses a ratio of 1 and no mask,
            as in the original code.

    Returns:
        Scalar tensor ``mean((||grad||_2 - 1) ** 2) * LAMBDA`` with the norm
        taken over dim 1.
    """
    if alpha is None:
        # Drawn on the CPU (as before, so seeded runs use the same random
        # stream) and then moved to wherever real_data lives. Using
        # real_data.device also works for CPU tensors on a CUDA machine.
        alpha = torch.rand(1, 1).expand(real_data.size()).to(real_data.device)
    interpolates = alpha * real_data + ((1 - alpha) * fake_data)
    # Fresh leaf tensor that tracks gradients (replaces the deprecated autograd.Variable).
    interpolates = interpolates.detach().requires_grad_(True)
    use_mask = False
    mask_ratio = 1
    disc_interpolates = netD(interpolates, use_mask)
    if _grad_outputs is None:
        _grad_outputs = torch.ones_like(disc_interpolates)
    gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                    grad_outputs=_grad_outputs,
                                    create_graph=True, retain_graph=True, only_inputs=True)[0]
    gradient_penalty = ((mask_ratio * gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
    return gradient_penalty

def stft(sig, n_fft, hop_length, window_size):
    """Short-time Fourier transform with a Hann window.

    Returns a real tensor whose last dimension holds the (real, imaginary)
    parts, i.e. the same layout the previous ``return_complex=False`` call
    produced. ``return_complex=False`` is deprecated in newer PyTorch, so the
    complex result is requested and viewed as real instead.
    """
    window = torch.hann_window(window_size, device=sig.device)
    s = torch.stft(sig, n_fft, hop_length, win_length=window_size,
                   window=window, return_complex=True)
    return torch.view_as_real(s)

def spec(x, n_fft, hop_length, window_size):
    """Magnitude spectrogram: the 2-norm over the last dimension of ``stft``'s output."""
    real_imag = stft(x, n_fft, hop_length, window_size)
    magnitude = torch.norm(real_imag, p=2, dim=-1)
    return magnitude

def norm(x):
    """Per-sample 2-norm: flatten everything but dim 0, then sqrt of the sum of squares."""
    flat = x.view(x.shape[0], -1)
    return torch.sqrt(torch.sum(flat ** 2, dim=-1))


def squeeze(x):
    """Return a 2-D view of the signal.

    A 3-D input must have a last dimension of size 1 or 2 and is averaged over
    it; a 2-D input is returned as is. Any other rank raises ValueError.
    """
    rank = len(x.shape)
    if rank == 3:
        assert x.shape[-1] in [1, 2]
        x = torch.mean(x, -1)
        rank = len(x.shape)
    if rank != 2:
        raise ValueError(f'Unknown input shape {x.shape}')
    return x

def multi_scale_spectrogram_loss(multispec_loss_n_fft, multispec_loss_hop_length, multispec_loss_window_size, current_holes, x_in, x_out):
    """Average spectrogram distance between ``x_in`` and ``x_out`` over several STFT settings.

    The three lists are zipped into (n_fft, hop_length, window_size) triples. A
    window size of -1 means a single window spanning the whole signal.
    ``current_holes`` is unused.
    """
    resolutions = zip(multispec_loss_n_fft,
                      multispec_loss_hop_length,
                      multispec_loss_window_size)
    per_resolution = []
    for n_fft, hop_length, window_size in resolutions:
        if window_size == -1:
            # one window covering the full signal, FFT size rounded up to a power of two
            window_size = x_in.shape[1]
            hop_length = window_size + 1
            n_fft = int(2 ** np.ceil(np.log2(window_size)))
        target_spec = spec(squeeze(x_in.float()), n_fft, hop_length, window_size)
        output_spec = spec(squeeze(x_out.float()), n_fft, hop_length, window_size)
        per_resolution.append(norm(target_spec - output_spec))
    return sum(per_resolution) / len(per_resolution)

def reset_grads(model, require_grad):
    """Set ``requires_grad`` on every parameter of ``model`` and return the model."""
    for param in model.parameters():
        param.requires_grad_(require_grad)
    return model

In [8]:
#model 
class Generator(nn.Module):
    """Single-scale generator: dilated conv stack with a gated output, added to the previous-scale signal.

    Reads the module-level ``dilation_factors``, ``num_layers`` and ``device``.
    """

    def __init__(self, filter_size, hidden_channels, current_fs ):
        super(Generator, self).__init__()
        self.head = ConvBlock(filter_size, 1, hidden_channels, dilation_factors[0])
        self.body = nn.Sequential()
        self.Fs = current_fs
        # body blocks use dilation_factors[1 .. num_layers - 2]
        for layer_idx in range(1, num_layers - 1):
            self.body.add_module(
                'block%d' % layer_idx,
                ConvBlock(filter_size, hidden_channels, hidden_channels, dilation_factors[layer_idx]))
        tail_conv = NormConv1d(in_channels=hidden_channels, out_channels=hidden_channels,
                               kernel_size=filter_size,
                               dilation=dilation_factors[-1])
        self.tail = nn.Sequential()
        self.tail.add_module('tail0', tail_conv)
        # padding used by the two gating branches
        same_pad = int((filter_size - 1) / 2)
        self.filter = nn.Sequential(
            NormConv1d(in_channels=hidden_channels, out_channels=hidden_channels,
                       kernel_size=filter_size, padding=same_pad),
            nn.Tanh()
        )
        self.gate = nn.Sequential(
            NormConv1d(in_channels=hidden_channels, out_channels=hidden_channels,
                       kernel_size=filter_size, padding=same_pad),
            nn.Sigmoid()
        )
        self.out_conv = NormConv1d(hidden_channels, 1, kernel_size=1)
        self.pe_filter = PreEmphasisFilter(device)

    def forward(self, noise_plus_sig, prev_sig):
        """Generate this scale's signal from the noisy input and add the cropped ``prev_sig``."""
        features = self.tail(self.body(self.head(noise_plus_sig)))
        gated = self.filter(features) * self.gate(features)
        residual = self.pe_filter(self.out_conv(gated))
        # crop prev_sig symmetrically to the length of the generated residual
        crop = int((prev_sig.shape[2] - residual.shape[2]) / 2)
        skip = prev_sig[:, :, crop:(prev_sig.shape[2] - crop)]
        return residual + skip


class Discriminator(nn.Module):
    """Single-scale discriminator: dilated ConvBlocks, a final conv and a pre-emphasis filter.

    In 'inpainting' mode the hole mask is threaded through the blocks, each
    block passing its ``mask_out`` to the next.
    """

    def __init__(self, run_mode, current_holes, hidden_channels, dilation_factors, num_layers, device,filter_size ):
        super(Discriminator, self).__init__()
        mask = current_holes if run_mode == 'inpainting' else None
        self.head = ConvBlock(filter_size, 1, hidden_channels, dilation_factors[0], mask=mask)
        mask = self.head.mask_out
        self.body = nn.ModuleList()
        # body blocks use dilation_factors[1 .. num_layers - 2]
        for layer_idx in range(1, num_layers - 1):
            block = ConvBlock(filter_size, hidden_channels, hidden_channels,
                              dilation_factors[layer_idx], mask=mask)
            mask = block.mask_out
            self.body.add_module('block%d' % layer_idx, block)
        self.mask_out = mask
        self.tail = NormConv1d(hidden_channels, 1, kernel_size=filter_size,
                               dilation=dilation_factors[-1])
        self.pe_filter = PreEmphasisFilter(device)

    def forward(self, sig, use_mask=False):
        """Score ``sig``; ``use_mask`` is forwarded to every ConvBlock."""
        hidden = self.head(sig, use_mask)
        for block in self.body:
            hidden = block(hidden, use_mask)
        return self.pe_filter(self.tail(hidden))


def weights_init(m):
    """Initialise a module in place, chosen by its class name.

    Classes whose name contains 'Conv' (but not 'ConvBlock') and that expose a
    trainable ``weight`` with more than one element get N(0, 0.02) weights.
    Otherwise, classes whose name contains 'Norm' and that expose ``weight``
    get N(1, 0.02) weights and a zero bias.
    """
    classname = type(m).__name__
    is_plain_conv = 'Conv' in classname and 'ConvBlock' not in classname
    if is_plain_conv and hasattr(m, 'weight'):
        # scalar blocks are initialised upon creation
        if m.weight.numel() > 1 and m.weight.requires_grad:
            m.weight.data.normal_(0.0, 0.02)
    elif 'Norm' in classname and hasattr(m, 'weight'):
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

class PreEmphasisFilter(nn.Module):
    """Fixed first-order pre-emphasis: y[0] = x[0], y[n] = x[n] - 0.97 * x[n - 1] along dim 2."""

    def __init__(self, device):
        super(PreEmphasisFilter, self).__init__()
        coeff = torch.Tensor([0.97])
        # plain tensor attribute (not a parameter); never trained
        self.alpha = coeff.to(device)
        self.alpha.requires_grad = False

    def forward(self, x):
        """Apply the filter along the last axis of a (batch, channel, time) tensor."""
        first = x[:, :, 0].unsqueeze(2)
        rest = x[:, :, 1:] - self.alpha * x[:, :, :-1]
        return torch.cat((first, rest), dim=2)


class NormConv1d(nn.Module):
    """``nn.Conv1d`` wrapped with weight normalisation, stored as ``self.conv``."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True):
        super(NormConv1d, self).__init__()
        plain_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                               kernel_size=kernel_size, stride=stride,
                               padding=padding, dilation=dilation, bias=bias)
        self.conv = weight_norm(plain_conv)

    def forward(self, x):
        """Apply the weight-normalised convolution."""
        return self.conv(x)


class ConvBlock(nn.Sequential):
    """Weight-normalised dilated conv -> BatchNorm1d -> LeakyReLU(0.2).

    With a ``mask`` (a list of holes, each indexable as ``[start, end]``), the
    block can compute batch norm only on the samples outside the holes. Its
    ``mask_out`` holds the holes with their start moved back by
    ``(filter_size - 1) * dilation``; it is ``None`` when no mask is given.
    """
    def __init__(self, filter_size, in_channels, out_channels, dilation=1, mask=None):
        super(ConvBlock, self).__init__()
        # NOTE(review): this branch is a no-op (assigns filter_size to itself);
        # a None filter_size is not replaced by any default here.
        if filter_size is None:
            filter_size = filter_size
        if mask is not None:
            self.mask_in = mask
            self.mask_out = []
            # number of samples by which each hole start is moved back
            self.rf = int((filter_size - 1) * dilation)
            for hole in self.mask_in:
                self.mask_out.append([hole[0] - self.rf, hole[1]])
            # Disabled by the original author (marked "???"): would push the start
            # of a hole past the end of the previous one when they overlap.
            # for idx in range(len(self.mask_out) - 1):
            #     if self.mask_out[idx+1][0] < self.mask_out[idx][1]:
            #         self.mask_out[idx+1][0] = self.mask_out[idx][1] + 1

        else:
            self.mask_out = None
        self.conv = NormConv1d(in_channels, out_channels, filter_size, dilation=dilation)
        self.norm = nn.BatchNorm1d(out_channels)
        self.activation = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x, use_mask=False):
        """Apply conv, batch norm and activation.

        With ``use_mask`` true, batch norm is computed on the conv output with
        the ``mask_out`` hole spans removed, and the normalised values are
        written back in place; samples inside the holes keep the un-normalised
        conv output. Requires the block to have been built with a mask.
        """
        out_conv = self.conv(x)
        if use_mask:
            # Concatenate the valid segments: before the first hole, between
            # consecutive holes, and after the last hole. cut_idx records where
            # each segment ends inside the concatenated tensor.
            #tmp = torch.cat((out_conv[:, :, :int(self.mask_out[0][0])], out_conv[:, :, int(self.mask_out[0][1] + 1):]), dim=2)
            tmp = out_conv[:, :, :int(self.mask_out[0][0])].clone()
            cut_idx = []
            cut_idx.append(tmp.shape[2])
            for idx in range(len(self.mask_out)-1):
                tmp = torch.cat((tmp, out_conv[:, :, int(self.mask_out[idx][1] + 1):int(self.mask_out[idx+1][0])]), dim=2)
                cut_idx.append(tmp.shape[2])
            tmp = torch.cat((tmp, out_conv[:, :, int(self.mask_out[-1][1] + 1):]), dim=2)

            tmp_norm = self.norm(tmp)
            # out_norm aliases out_conv: the assignments below modify it in place.
            out_norm = out_conv
            out_norm[:, :, :int(self.mask_out[0][0])] = tmp_norm[:, :, :int(cut_idx[0])]
            for idx in range(len(self.mask_out) - 1):
                out_norm[:, :, int(self.mask_out[idx][1] + 1):int(self.mask_out[idx+1][0])] = tmp_norm[:, :, int(cut_idx[idx]):int(cut_idx[idx+1])] #tmp_norm[:, :, int(self.mask_out[idx][0]):int(self.mask_out[idx+1][0])]
                #out_norm[:, :, :int(self.mask_out[idx+1][0])] = tmp_norm[:, :, :int(self.mask_out[idx+1][0])]
            out_norm[:, :, int(self.mask_out[-1][1] + 1):] = tmp_norm[:, :, int(cut_idx[-1]):]

        else:
            out_norm = self.norm(out_conv)
        return self.activation(out_norm)

In [9]:
#training functions
def train(manual_random_seed, fs_list, scales, growing_hidden_channels_factor,learning_rate, beta1, scheduler_lr_decay, plot_losses,
          initial_noise_amp, noise_amp_factor, signals_list, dilation_factors, output_folder, inputs_lengths):
    """Train every scale in turn and gather the per-scale results.

    For each entry of ``scales`` this calls ``train_single_scale``, writes the
    resulting fake and reconstructed signals as wav files into
    ``output_folder``, and saves the reconstruction noise list there as
    ``reconstruction_noise_list.pt``.

    Besides its arguments, the function reads ``run_mode``, ``masks``,
    ``device`` and ``hidden_channels_init`` from module scope.

    Args:
        manual_random_seed: seed for ``random`` and ``torch``; -1 leaves the
            generators unseeded.
        fs_list: sample rate of each scale, indexed like ``signals_list``.
        scales: one entry per scale to train; only its length is used here.
        signals_list: the input signal at each scale.
        output_folder: directory receiving the wav files and the noise list.
        (remaining arguments are forwarded unchanged to ``train_single_scale``)

    Returns:
        Tuple ``(output_signals, loss_vectors, generators_list,
        noise_amp_list, energy_list, reconstruction_noise_list)``.
    """
    seed_requested = manual_random_seed != -1
    if seed_requested:
        random.seed(manual_random_seed)
        torch.manual_seed(manual_random_seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    generators_list, noise_amp_list = [], []
    output_signals, loss_vectors = [], []
    reconstruction_noise_list = []

    # Mean energy of each scale's signal; for inpainting only the samples
    # selected by the corresponding mask are counted.
    if run_mode == 'inpainting':
        energy_list = [(sig[mask] ** 2).mean().item() for sig, mask in zip(signals_list, masks)]
    else:
        energy_list = [(sig ** 2).mean().item() for sig in signals_list]

    # (key in the per-scale output dict, file name template) for each signal written to disk
    signals_to_write = (('fake_signal', 'fake@%dHz.wav'),
                        ('reconstructed_signal', 'reconstructed@%dHz.wav'))
    noise_list_path = os.path.join(output_folder, 'reconstruction_noise_list.pt')

    for scale_idx in range(len(scales)):
        (scale_outputs, scale_losses, netG,
         reconstruction_noise_list, noise_amp) = train_single_scale(
            scales, device, run_mode, hidden_channels_init, growing_hidden_channels_factor, learning_rate, beta1,
            scheduler_lr_decay, plot_losses, initial_noise_amp, noise_amp_factor, signals_list, fs_list,
            generators_list, noise_amp_list, energy_list, reconstruction_noise_list, dilation_factors, output_folder,
            inputs_lengths)

        # Write the fake sound first, then the reconstructed one
        for key, name_template in signals_to_write:
            sound = scale_outputs[key].squeeze()
            target = os.path.join(output_folder, name_template % fs_list[scale_idx])
            write_signal(target, sound, fs_list[scale_idx], overwrite=False)
        torch.save(reconstruction_noise_list, noise_list_path)

        generators_list.append(netG)
        noise_amp_list.append(noise_amp)
        output_signals.append(scale_outputs)
        loss_vectors.append(scale_losses)

    return output_signals, loss_vectors, generators_list, noise_amp_list, energy_list, reconstruction_noise_list


def train_single_scale(scales, device, run_mode, hidden_channels_init, growing_hidden_channels_factor,
                       learning_rate, beta1, scheduler_lr_decay, plot_losses, initial_noise_amp, noise_amp_factor, signals_list,
                        fs_list, generators_list, noise_amp_list, energy_list, reconstruction_noise_list, dilation_factors, output_folder, inputs_lengths):
    """Adversarially train the generator/discriminator pair of the next untrained scale.

    The scale to train is inferred from ``len(generators_list)``: with ``k``
    generators already trained, ``signals_list[k]`` / ``fs_list[k]`` are used.

    Besides its arguments, the function reads these names from module scope:
    ``masks``, ``inpainting_indices``, ``Fs``, ``filter_size``, ``num_layers``,
    ``num_epochs``, ``scheduler_milestones``, ``alpha1``, ``alpha2``,
    ``add_cond_noise`` and ``is_cuda``.

    Side effects: appends to ``log.txt`` in ``output_folder``, saves
    ``netGScale<idx>.pth`` / ``netDScale<idx>.pth`` there, and appends this
    scale's reconstruction noise to ``reconstruction_noise_list``.

    Returns:
        Tuple ``(output_signals, loss_vectors, netG, reconstruction_noise_list,
        noise_amp)`` where ``output_signals`` is a dict with the numpy arrays
        ``'fake_signal'`` and ``'reconstructed_signal'`` from the last epoch,
        and ``loss_vectors`` is a dict of per-epoch loss arrays when
        ``plot_losses`` is truthy (an empty list otherwise).
    """
    # Terminology: scale number 0 is the original signal (no downsampling). A larger scale number means
    # stronger downsampling, i.e. a shorter signal. `scale_idx` counts the other way, in training order.
    n_scales = len(scales)
    current_scale = n_scales - len(generators_list) - 1
    scale_idx = n_scales - current_scale - 1
    input_signal = signals_list[scale_idx].to(device)
    current_fs = fs_list[scale_idx]
    N = len(input_signal)

    if run_mode == 'inpainting':
        current_mask = masks[scale_idx]
        # Hole boundaries converted from the original sample rate Fs to this scale's sample rate
        current_holes = torch.Tensor([(int(idx[0] / Fs * current_fs), int(idx[1] / Fs * current_fs)) for idx in inpainting_indices]).to(device)
    else:
        current_holes = None

    # Create inputs
    real_signal = input_signal.reshape(1, 1, N)

    # Only the first trained scale uses the initial channel count; all later scales use the grown one
    hidden_channels = hidden_channels_init if scale_idx == 0 else int(
        hidden_channels_init * growing_hidden_channels_factor)

    scale_num = n_scales - scale_idx - 1
    pad_size = calc_pad_size(dilation_factors, filter_size)
    signal_padder = nn.ConstantPad1d(pad_size, 0)

    # Initialize models
    netD = Discriminator(run_mode, current_holes, hidden_channels, dilation_factors, num_layers, device, filter_size).to(device)
    netD.apply(weights_init)
    netG = Generator(filter_size, hidden_channels, current_fs).to(device)
    netG.apply(weights_init)
    receptive_field = calc_receptive_field(filter_size, dilation_factors, current_fs)
    receptive_field_percent = 100 * receptive_field / 1e3 / (N / current_fs)
    print('Signal in scale %d has %d samples, sample rate is %d[Hz].' % (
        scale_num, N, current_fs))
    print('Total receptive field is %d[msec] (%.1f%% of input).' % (receptive_field, receptive_field_percent))
    with open(os.path.join(output_folder, 'log.txt'), 'a') as f:
        f.write('*' * 30 + ' Scale ' + str(scale_num) + ' (' + str(current_fs) + ' [Hz]) ' + '*' * 30)
        f.write('\nreceptive_field = %d[msec] (%.1f%% of input)' % (receptive_field, receptive_field_percent))
        f.write('\nsignal_energy = %.4f' % energy_list[scale_idx])

    # Reconstruction noise: full noise at the first scale; at later scales zeros, except that in
    # inpainting mode the samples outside the mask receive noise.
    if scale_idx == 0:
        reconstruction_noise = get_noise(device, real_signal.shape)
    else:
        reconstruction_noise = torch.zeros(real_signal.shape, device=device)
        if run_mode == 'inpainting':
            reconstruction_noise[:, :, torch.logical_not(current_mask)] = get_noise(device, torch.nonzero(
                torch.logical_not(current_mask)).shape[0]).expand(1, 1, -1).to(device)

    reconstruction_noise = signal_padder(reconstruction_noise)

    # Warm start from the previous scale's weights. This needs scale_idx > 1 because the first scale
    # is built with a different hidden channel count than the following ones.
    if scale_idx > 1:
        netG.load_state_dict(
            torch.load('%s/netGScale%d.pth' % (output_folder, scale_idx - 1), map_location=device))
        netD.load_state_dict(
            torch.load('%s/netDScale%d.pth' % (output_folder, scale_idx - 1), map_location=device))

    # Create optimizers
    optimizerD = optim.Adam(netD.parameters(), lr=learning_rate, betas=(beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=learning_rate, betas=(beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=scheduler_milestones,
                                                      gamma=scheduler_lr_decay)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=scheduler_milestones,
                                                      gamma=scheduler_lr_decay)

    # Initialize error vectors
    v_err_real = np.zeros(num_epochs, )
    v_err_fake = np.zeros(num_epochs, )
    v_gp = np.zeros(num_epochs, )
    v_rec_loss = np.zeros(num_epochs, )

    epochs_start_time = time.time()
    # prepare inputs for gradient penalty
    # (in inpainting mode the discriminator output shape is only known after the first masked pass,
    # so _grad_outputs is created inside the loop at epoch 0 instead)
    if not run_mode == 'inpainting':
        D_out_shape = torch.Size((1, 1, N - 2 * pad_size))
        _grad_outputs = torch.ones(D_out_shape, device=device)
    grad_pen_alpha_vec = torch.rand(num_epochs).to(device)

    for epoch_num in range(num_epochs):
        print_progress = epoch_num % 100 == 0
        # Create noise
        noise_signal = get_noise(device, real_signal.shape)
        noise_signal = signal_padder(noise_signal)
        #################################################################
        # Optimize D by maximizing D(realSignal)+(1-D(G(noise_signal))) #
        #################################################################
        netD.zero_grad()
        # Run on real signal
        not_valid_idx_start = []
        not_valid_idx_end = []
        if run_mode == 'inpainting':
            out_D_real = netD(real_signal, use_mask=True)
            tot_samples = out_D_real.shape[2]
            # Discriminator outputs whose receptive field touches a hole are dropped from the loss.
            # NOTE(review): a start index can become negative when a hole lies closer to the signal start
            # than the receptive field; the slice below would then count from the end. Left unchanged here
            # because calc_gradient_penalty may rely on the same indexing - confirm before changing.
            not_valid_idx_start = [int(idx[0] - receptive_field / 1e3 * current_fs + 1) for idx in current_holes]
            not_valid_idx_end = [int(idx[1] + 1) for idx in current_holes]  # +1 is because of pe filter
            out_D_real_cp = out_D_real.clone()
            out_D_real = out_D_real_cp[:, :, :not_valid_idx_start[0]]
            if len(current_holes) > 1:
                for i in range(len(current_holes) - 1):
                    out_D_real = torch.cat((out_D_real, out_D_real_cp[:, :, not_valid_idx_end[i] + 1:not_valid_idx_start[i+1]]), dim=2)
            out_D_real = torch.cat((out_D_real, out_D_real_cp[:, :, not_valid_idx_end[-1] + 1:]), dim=2)
            mask_ratio = tot_samples / out_D_real.shape[2]
        else:
            mask_ratio = 1
            out_D_real = netD(real_signal)
        err_real_D = -out_D_real.mean()
        err_real_D.backward(retain_graph=True)
        err_real_D = err_real_D.detach()
        if print_progress or plot_losses:
            err_real_D_val = err_real_D.item()

        if epoch_num == 0:
            if run_mode == 'inpainting':
                D_out_shape = out_D_real.shape
                _grad_outputs = torch.ones(D_out_shape, device=device)
            if scale_idx == 0:  # We are at coarsest scale
                prev_signal = torch.full(noise_signal.shape, 0, device=device, dtype=noise_signal.dtype)
                prev_reconstructed_signal = torch.zeros(reconstruction_noise.shape, device=device)
                noise_amp = initial_noise_amp
            else:
                prev_signal = draw_signal(generators_list, inputs_lengths, fs_list, noise_amp_list, filter_size, dilation_factors, device)
                prev_signal = signal_padder(prev_signal)
                prev_reconstructed_signal = draw_signal(generators_list, inputs_lengths,
                                                        fs_list,
                                                        noise_amp_list, filter_size, dilation_factors, device,
                                                        reconstruction_noise_list)
                prev_reconstructed_signal = signal_padder(prev_reconstructed_signal)
                # Noise amplitude follows the energy added between the previous scale and this one
                innovation = energy_list[scale_idx] - energy_list[scale_idx - 1]
                energy_diff = torch.sqrt(torch.Tensor([innovation])).to(device)
                noise_amp = noise_amp_factor * max(torch.Tensor([0]).to(device),
                                                          energy_diff)

            if scale_idx == 1 and add_cond_noise:
                noise_amp = prev_reconstructed_signal.std()

            with open(os.path.join(output_folder, 'log.txt'), 'a') as f:
                f.write('\nnoise_amp: %.6f' % noise_amp)

            reconstruction_noise = reconstruction_noise * noise_amp
            reconstruction_noise_list.append(reconstruction_noise)
        else:
            if scale_idx > 0:
                prev_signal = draw_signal(generators_list, inputs_lengths, fs_list, noise_amp_list, filter_size, dilation_factors, device)
                prev_signal = signal_padder(prev_signal)

        input_noise = noise_signal * noise_amp

        # Run on fake signal
        fake_signal = netG((input_noise + prev_signal).detach(), prev_signal)
        out_D_fake = netD(fake_signal.detach())
        err_fake_D = out_D_fake.mean()
        del out_D_real, out_D_fake
        err_fake_D.backward(retain_graph=True)
        err_fake_D = err_fake_D.detach()
        if print_progress or plot_losses:
            err_fake_D_val = err_fake_D.item()

        lambda_grad=0.01
        gradient_penalty = calc_gradient_penalty(run_mode, current_holes, netD, real_signal, fake_signal, lambda_grad,
                                                 grad_pen_alpha_vec[epoch_num], _grad_outputs, mask_ratio)
        gradient_penalty.backward()
        if print_progress or plot_losses:
            gradient_penalty_val = gradient_penalty.item()
        del gradient_penalty

        optimizerD.step()

        if plot_losses:
            v_err_real[epoch_num] = err_real_D_val
            v_err_fake[epoch_num] = err_fake_D_val
            v_gp[epoch_num] = gradient_penalty_val

        #############################################
        # Update G by maximizing D(G(noise_signal)) #
        #############################################
        netG.zero_grad()
        output = netD(fake_signal)
        errG = -output.mean()
        del output
        errG.backward(retain_graph=True)
        errG = errG.detach()
        if print_progress or plot_losses:
            errG_val = errG.item()
        # Reconstruction pass: the same call at every scale
        reconstructed_signal = netG((reconstruction_noise + prev_reconstructed_signal).detach(),
                                    prev_reconstructed_signal)
        if alpha1 > 0:
            if run_mode == 'inpainting':
                rec_loss_t = alpha1 * torch.mean(
                    (real_signal[:, :, current_mask] - reconstructed_signal[:, :, current_mask]) ** 2)
            else:
                rec_loss_t = alpha1 * torch.mean((real_signal - reconstructed_signal) ** 2)
        else:
            rec_loss_t = 0
        if alpha2 > 0:
            multispec_loss_n_fft = (2048, 1024, 512)
            multispec_loss_hop_length = (240, 120, 50)
            multispec_loss_window_size = (1200, 600, 240)
            rec_loss_f = alpha2 * multi_scale_spectrogram_loss(multispec_loss_n_fft, multispec_loss_hop_length, multispec_loss_window_size,
                                                               current_holes, real_signal.permute(0, 2, 1),reconstructed_signal.permute(0, 2, 1))
        else:
            rec_loss_f = 0
        rec_loss = rec_loss_t + rec_loss_f
        # When both alpha1 and alpha2 are 0 the sum is the plain number 0, not a tensor,
        # so there is nothing to backpropagate or detach.
        rec_loss_is_tensor = torch.is_tensor(rec_loss)
        if rec_loss_is_tensor:
            rec_loss.backward(retain_graph=True)
            rec_loss = rec_loss.detach()
        if alpha1 > 0:
            rec_loss_t = rec_loss_t.detach()
        if alpha2 > 0:
            rec_loss_f = rec_loss_f.detach()
        if print_progress or plot_losses:
            rec_loss_val = rec_loss.item() if rec_loss_is_tensor else float(rec_loss)

        optimizerG.step()

        if plot_losses:
            v_rec_loss[epoch_num] = rec_loss_val

        if print_progress:
            print('[%d/%d] D(real): %.2f. D(fake): %.2f. rec_loss: %.4f. gp: %.4f ' % (
                epoch_num, num_epochs, -err_real_D_val, err_fake_D_val, rec_loss_val, gradient_penalty_val))

        schedulerD.step()
        schedulerG.step()

        # Some memory cleanup
        fake_signal = fake_signal.detach()
        reconstructed_signal = reconstructed_signal.detach()
        # The last epoch's signals are kept: they are returned below
        if epoch_num < num_epochs - 1:
            del fake_signal, reconstructed_signal, rec_loss, rec_loss_t, rec_loss_f
        del noise_signal, input_noise
        if scale_idx > 0:
            del prev_signal

    epochs_stop_time = time.time()
    # The *_val figures are the ones captured at the most recent epoch that recorded them
    runtime_msg = 'Total time in scale %d: %d[sec] (%.2f[sec]/epoch on avg.). D(real): %f, D(fake): %f, rec_loss: %.4f. gp: %.4f' % (
        current_scale, epochs_stop_time - epochs_start_time,
        (epochs_stop_time - epochs_start_time) / num_epochs,
        -err_real_D_val, err_fake_D_val, rec_loss_val, gradient_penalty_val)
    print(runtime_msg)
    with open(os.path.join(output_folder, 'log.txt'), 'a') as f:
        f.write('\n%s\n' % runtime_msg)

    # Save this scale models
    torch.save(netG.state_dict(), '%s/netGScale%d.pth' % (output_folder, scale_idx))
    torch.save(netD.state_dict(), '%s/netDScale%d.pth' % (output_folder, scale_idx))
    # Pack outputs
    if plot_losses:
        loss_vectors = {'v_err_real': v_err_real,
                        'v_err_fake': v_err_fake,
                        'v_rec_loss': v_rec_loss,
                        'v_gp': v_gp}
    else:
        loss_vectors = []
    fake_signal = fake_signal.detach().cpu().numpy()[:, 0, :]
    reconstructed_signal = reconstructed_signal.detach().cpu().numpy()[:, 0, :]
    output_signals = {'fake_signal': fake_signal, 'reconstructed_signal': reconstructed_signal}
    del fake_signal, real_signal, netD, _grad_outputs, grad_pen_alpha_vec, input_signal, reconstructed_signal, prev_reconstructed_signal, reconstruction_noise
    # The trained generator is returned with gradients switched off and in eval mode
    netG = reset_grads(netG, False)
    netG.eval()
    if is_cuda:
        torch.cuda.empty_cache()
    print('*' * 30 + ' Finished working on scale ' + str(current_scale) + ' ' + '*' * 30)
    return output_signals, loss_vectors, netG, reconstruction_noise_list, noise_amp

In [10]:
#training
startTime = time.time()

# Holes are given as START/END pairs, so the number of indices must be even
if len(inpainting_indices)%2 != 0:
    raise Exception('Provide START and END indices of each hole!')

if is_cuda:
    torch.cuda.set_device(gpu_num)
    device = torch.device("cuda:%d" % gpu_num)
else:
    # Without CUDA fall back to the CPU so `device` is always defined
    device = torch.device("cpu")

if manual_random_seed != -1:
    random.seed(manual_random_seed)
    torch.manual_seed(manual_random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

samples, Fs = get_input_signal(input_file, max_length)

# Keep only the sample rates not above the input's rate and make sure the
# input's own rate is the last entry (also when nothing survives the filter)
fs_list = [f for f in fs_list if f <= Fs]
if not fs_list or fs_list[-1] != Fs:
    fs_list.append(Fs)

scales = [Fs / f for f in fs_list]

print('Working on file: %s' % input_file)

# The learning rate is decayed once, after two thirds of the epochs
scheduler_milestones = [int(num_epochs * 2 / 3)]

# Weights of the time-domain (alpha1) and spectrogram (alpha2) reconstruction losses
alpha1 = 0
alpha2 = 1e-4
add_cond_noise = True

dilation_factors = [2 ** i for i in range(num_layers)]

if not os.path.exists(output_folder):
    os.mkdir(output_folder)

# Every run gets its own numbered folder "<output_folder>_<k>". k starts from the
# number of existing paths sharing the prefix and is bumped until the name is free.
dirs = glob.glob(output_folder + '*')
run_number = len(dirs) + 1
while os.path.exists(output_folder + '_' + str(run_number)):
    run_number += 1
output_folder = output_folder + '_' + str(run_number)

os.mkdir(output_folder)
print('Writing results to %s\n' % output_folder)

signals_list, fs_list = create_input_signals(scales, set_first_scale_by_energy, min_energy_th,  filter_size, torch.tensor(samples), Fs)
if len(signals_list) == 0:
    # No scale was selected: retry without the energy criterion, skipping the two coarsest scales
    set_first_scale_by_energy = False
    scales = scales[2:]  # Manually start from 500
    signals_list, fs_list = create_input_signals(scales, set_first_scale_by_energy, min_energy_th,  filter_size, torch.tensor(samples), Fs)
scales = [Fs / f for f in fs_list]

inputs_lengths = [len(s) for s in signals_list]

print('Running on ' + str(device))

output_signals, loss_vectors, generators_list, noise_amp_list, energy_list, reconstruction_noise_list = train(
                          manual_random_seed, fs_list, scales, growing_hidden_channels_factor,learning_rate, beta1, scheduler_lr_decay,
                          plot_losses, initial_noise_amp, noise_amp_factor, signals_list, dilation_factors, output_folder, inputs_lengths)

Working on file: S11_ucm_1_mono.wav
Writing results to outputs_2

Running on cuda:0
Signal in scale 15 has 7680 samples, sample rate is 320[Hz].
Total receptive field is 6378[msec] (26.6% of input).
[0/3500] D(real): 0.00. D(fake): 0.00. rec_loss: 0.0278. gp: 0.0083 
[100/3500] D(real): 0.01. D(fake): -0.02. rec_loss: 0.0181. gp: 0.0053 
[200/3500] D(real): 0.00. D(fake): -0.06. rec_loss: 0.0173. gp: 0.0061 
[300/3500] D(real): -0.09. D(fake): -0.09. rec_loss: 0.0181. gp: 0.0087 
[400/3500] D(real): 0.03. D(fake): -0.02. rec_loss: 0.0166. gp: 0.0429 
[500/3500] D(real): -0.00. D(fake): -0.05. rec_loss: 0.0169. gp: 0.0090 
[600/3500] D(real): -0.03. D(fake): -0.09. rec_loss: 0.0164. gp: 0.0306 
[700/3500] D(real): -0.06. D(fake): -0.12. rec_loss: 0.0167. gp: 0.0069 
[800/3500] D(real): -0.11. D(fake): -0.16. rec_loss: 0.0171. gp: 0.0159 
[900/3500] D(real): -0.07. D(fake): -0.18. rec_loss: 0.0163. gp: 0.0130 
[1000/3500] D(real): -0.05. D(fake): -0.18. rec_loss: 0.0174. gp: 0.0077 
[110

In [11]:
#!zip -r outputs.zip outputs_2

In [19]:
class AudioGenerator(object):
    """Draws new signals from a list of trained per-scale generators and writes them as wav files.

    Output files go to the ``GeneratedSignals`` sub-folder of ``output_folder``.
    """

    def __init__(self, output_folder, fs_list, dilation_factors, filter_size, device, generators_list=None, noise_amp_list=None, reconstruction_noise_list=None):
        """Store the generation settings and make sure the output sub-folder exists.

        Args:
            output_folder: run folder; ``GeneratedSignals`` is created inside it.
            fs_list: sample rate of each scale; the last entry is used as the
                rate of the final output.
            dilation_factors, filter_size, device: forwarded to ``draw_signal``.
            generators_list: trained generators, one per scale.
            noise_amp_list: noise amplitude of each scale.
            reconstruction_noise_list: stored on the instance; not used by the
                methods of this class.
        """
        super(AudioGenerator, self).__init__()
        self.generators_list = generators_list
        self.noise_amp_list = noise_amp_list
        self.reconstruction_noise_list = reconstruction_noise_list
        self.output_folder = output_folder
        self.fs_list= fs_list
        self.device = device
        self.dilation_factors = dilation_factors
        self.filter_size = filter_size
        # makedirs also creates a missing run folder and tolerates an existing target
        os.makedirs(os.path.join(output_folder, 'GeneratedSignals'), exist_ok=True)

    def generate(self, nSignals=1, length=20, generate_all_scales=False):
        """Draw ``nSignals`` signals and write each one to disk.

        Args:
            nSignals: number of signals to draw.
            length: signal length; multiplied by each sample rate to obtain the
                number of samples per scale, i.e. it is expressed in seconds.
            generate_all_scales: when True, ``draw_signal`` is asked for every
                scale and one file per scale is written; otherwise only the
                signal at the last sample rate is written.
        """
        for sig_idx in range(nSignals):
            # Draws a signal up to current scale, using learned generators
            output_signals_list = draw_signal(self.generators_list,
                                              [round(f * length) for f in self.fs_list], self.fs_list,
                                              self.noise_amp_list,  self.filter_size, self.dilation_factors, self.device, 
                                              output_all_scales=generate_all_scales)
            # Write signals
            if generate_all_scales:
                for scale_idx, sig in enumerate(output_signals_list):
                    write_signal(
                        os.path.join(self.output_folder, 'GeneratedSignals',
                                     'generated@%dHz.wav' % self.fs_list[scale_idx]),
                        sig, self.fs_list[scale_idx], overwrite=False)
            else:
                write_signal(
                    os.path.join(self.output_folder, 'GeneratedSignals',
                                 'generated@%dHz.wav' % self.fs_list[-1]),
                    output_signals_list, self.fs_list[-1], overwrite=False)

    def condition(self, condition, write=True):
        """Draw a signal conditioned on a given signal.

        Args:
            condition: dict with the keys ``"condition_signal"``,
                ``"condition_fs"`` and, when ``write`` is True, ``"name"``.
                The dict is modified in place: ``"condition_scale_idx"`` is
                added and ``"condition_signal"`` is replaced by a
                ``(1, 1, n)`` tensor on ``self.device``.
            write: when True the result is written to
                ``GeneratedSignals/conditioned_on_<name>`` and nothing is
                returned; when False the drawn signal is returned instead.
        """
        # One past the index of the last scale whose sample rate does not exceed the condition's rate
        condition["condition_scale_idx"] = np.where(np.array(self.fs_list) <= condition["condition_fs"])[0][
                                               -1] + 1
        condition["condition_signal"] = torch.Tensor(condition["condition_signal"]).expand(1, 1, -1).to(
            self.device)
        # Length of the condition signal expressed in samples of each scale
        lengths = [int(condition["condition_signal"].shape[2] / condition["condition_fs"] * fs) for fs in
                   self.fs_list]
        conditioned_signal = draw_signal(self.generators_list, lengths, self.fs_list, self.noise_amp_list, 
                                         self.filter_size, self.dilation_factors, self.device,
                                         condition=condition)
        if write:
            output_file = os.path.join(self.output_folder, 'GeneratedSignals',
                                       'conditioned_on_' + condition['name'])
            # The instance has no `params` attribute; the output rate is the last
            # entry of fs_list, the same rate generate() uses for the final scale.
            write_signal(output_file, conditioned_signal, self.fs_list[-1])
        else:
            return conditioned_signal

In [32]:
# Generation settings: how many signals to draw, the length of each one
# (multiplied by the sample rates inside generate()), and whether to write
# every scale instead of only the final one.
nSignals, length, generate_all_scales = 1, 25, False

audio_generator = AudioGenerator(
    output_folder, fs_list, dilation_factors, filter_size, device,
    generators_list=generators_list,
    noise_amp_list=noise_amp_list,
    reconstruction_noise_list=reconstruction_noise_list,
)

audio_generator.generate(nSignals, length, generate_all_scales)

In [30]:
# Folder scanned for generated .wav files below. The run folder name is
# hard-coded (outputs_2), so update it when generating from a different run.
path = "/content/outputs_2/GeneratedSignals"

In [33]:
# Collect every .wav file under `path` (skipping hidden files and files whose
# name starts with "noise") together with their total size in bytes.
paths = []
size = 0
for root, dirs, files in os.walk(path):
    for file in files:
        hidden_or_noise = file.startswith(".") or file.startswith("noise")
        if not file.endswith(".wav") or hidden_or_noise:
            continue
        wav_path = os.path.join(root, file)
        paths.append(wav_path)
        size += os.path.getsize(wav_path)

print(f'We have {len(paths)} .Wav Files with {size/1024**2:.2f} Mb in size')

We have 5 .Wav Files with 3.81 Mb in size


In [34]:
!!zip -r /content/outputs_2/GeneratedSignals.zip /content/outputs_2/GeneratedSignals

['updating: content/outputs_2/GeneratedSignals/ (stored 0%)',
 'updating: content/outputs_2/GeneratedSignals/generated@16000Hz.wav (deflated 14%)',
 'updating: content/outputs_2/GeneratedSignals/generated@16000Hz_1.wav (deflated 16%)',
 '  adding: content/outputs_2/GeneratedSignals/generated@16000Hz_2.wav (deflated 13%)',
 '  adding: content/outputs_2/GeneratedSignals/generated@16000Hz_3.wav (deflated 15%)',
 '  adding: content/outputs_2/GeneratedSignals/generated@16000Hz_4.wav (deflated 11%)']

In [35]:
# Colab-only helper: hand the zipped generated signals to the browser for download.
# The archive path is hard-coded to the outputs_2 run.
from google.colab import files
files.download('/content/outputs_2/GeneratedSignals.zip')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [36]:
!cp /content/outputs_2/GeneratedSignals.zip  /content/drive/MyDrive/FinalProject/CAW_outputs

# Remember to change this path every time you run generation.