In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import matplotlib.pyplot as plt
import IPython.display as ipd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from scipy import signal
import glob
import librosa
import time
import pickle
from scipy.io import wavfile

import skimage.io
import sys
sys.path.append("../src")
from imtools import get_voronoi_image, splat_voronoi_image_1nn
from tsp import *

## Utility Functions

In [None]:
def upsample_time(X, hop_length, mode='nearest'):
    """
    Stretch the time axis (axis 1) of a batch of sequences by an integer factor

    Parameters
    ----------
    X: torch.tensor(M, T, N)
        Batch of M sequences, each with T time steps and N channels
    hop_length: int
        Factor by which the number of time steps is multiplied
    mode: string
        Interpolation mode forwarded to torch's interpolate.  The default,
        'nearest', avoids artifacts where notes in the violin jump by
        large intervals

    Returns
    -------
    torch.tensor(M, T*hop_length, N)
        The input resampled along time
    """
    # interpolate() resamples the trailing axis, so put time last first
    channels_first = X.transpose(1, 2)
    n_out = hop_length*channels_first.shape[-1]
    stretched = nn.functional.interpolate(channels_first, size=n_out, mode=mode)
    # Restore the original (batch, time, channel) layout
    return stretched.transpose(1, 2)

################################################
# Loudness code modified from original Google Magenta DDSP implementation in tensorflow
# https://github.com/magenta/ddsp/blob/86c7a35f4f2ecf2e9bb45ee7094732b1afcebecd/ddsp/spectral_ops.py#L253
# which, like this repository, is licensed under Apache2 by Google Magenta Group, 2020
# Modifications by Chris Tralie, 2023

def power_to_db(power, ref_db=0.0, range_db=80.0, use_tf=True):
    """
    Convert linear-scale power to decibels with a limited dynamic range

    Parameters
    ----------
    power: ndarray or float
        Power values on a linear scale
    ref_db: float
        Reference level (in dB) subtracted from the result
    range_db: float
        Dynamic range; the output is never lower than -range_db
    use_tf: bool
        Unused; kept so existing callers do not break

    Returns
    -------
    ndarray or float
        Power in dB relative to ref_db, floored at -range_db
    """
    # Floor the power first so that log10 never sees zero
    floor_power = 10**(-range_db/10)
    db = np.log10(np.maximum(power, floor_power))*10.0
    # Shift to the reference level, then re-apply the floor
    db -= ref_db
    return np.maximum(db, -range_db)

def extract_loudness(x, sr, hop_length, n_fft=512):
    """
    Per-frame loudness in dB, computed from an A-weighted power spectrum
    (section B.1 of the paper)

    Parameters
    ----------
    x: ndarray(N)
        Audio samples
    sr: int
        Sample rate, used to map FFT bins to the frequencies at which
        the A-weighting curve is evaluated
    hop_length: int
        Number of samples between consecutive loudness estimates
    n_fft: int
        Number of samples in each analysis window

    Returns
    -------
    ndarray(n_frames), dtype float32
        Loudness of each STFT frame in dB, as returned by power_to_db
        with its default arguments
    """
    # Centered STFT; rows are frequency bins, columns are frames
    spec = librosa.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=n_fft, center=True)
    n_bins = spec.shape[0]

    # A-weighting (given in dB) at each bin frequency, as a column so that
    # it broadcasts across frames
    bin_freqs = np.arange(n_bins)*sr/n_fft
    weights_db = librosa.A_weighting(bin_freqs)[:, None]

    # Convert the weighting to a linear power gain and apply it to |S|^2
    weighted_power = (np.abs(spec)**2) * 10**(weights_db/10)

    # Mean weighted power over frequency bins, then to decibels
    frame_power = np.mean(weighted_power, axis=0)
    return np.array(power_to_db(frame_power), dtype=np.float32)

################################################

## Datasets

In [None]:
class CurveData(Dataset):
    """
    Dataset of random fixed-length chunks cut from precomputed curves that are
    stored as pickle files, one directory per image class
    """
    def __init__(self, rg, voronoi_samples, T, samples_per_batch):
        """
        Parameters
        ----------
        rg: list(int)
            List of indices to take in each image class (used for test/train split)
        voronoi_samples: int
            Number of samples in the Voronoi image; used as the key into
            each pickled result
        T: int
            Number of samples to take in each chunk
        samples_per_batch: int
            Number of samples per batch; this is what __len__ reports
        """
        self.files = []
        for c in glob.glob("../data/imagenet/*"): # Go through each class
            # Sort so that the indices in rg refer to the same files on every run
            files = sorted(glob.glob("{}/*.pkl".format(c)))
            self.files += [files[i] for i in rg]
        # Load in all curves
        self.Ys = []
        for file in self.files:
            # NOTE: pickle can execute arbitrary code on load; only use trusted files.
            # The context manager closes the handle (it was previously leaked).
            with open(file, "rb") as fin:
                res = pickle.load(fin)
            Y = res[voronoi_samples]["Y"]
            # Curves shorter than one chunk cannot be sampled from; skip them
            if Y.shape[0] >= T:
                self.Ys.append(np.array(Y, dtype=np.float32))
        self.T = T
        self.samples_per_batch = samples_per_batch
    
    def __len__(self):
        """
        Nominal epoch length; items are drawn at random, so this is not
        the number of stored curves
        """
        return self.samples_per_batch
    
    def __getitem__(self, idx):
        """
        Pull out a random chunk of the appropriate length from a random file

        Parameters
        ----------
        idx: int
            Ignored; every call draws a fresh random curve and offset

        Returns
        -------
        torch.tensor(T, ...)
            T consecutive rows of one curve, wrapping around cyclically
            from the end of the curve to its start
        """
        Y = self.Ys[np.random.randint(len(self.Ys))]
        n_rows = Y.shape[0]
        shift = np.random.randint(n_rows)
        # Same rows as np.roll(Y, shift, axis=0)[0:T], but only T rows are
        # gathered instead of copying the entire curve
        rows = (np.arange(self.T) - shift) % n_rows
        return torch.from_numpy(Y[rows])
        
class AudioData(Dataset):
    """
    Dataset of random audio clips, each paired with the loudness of the
    windows it covers
    """
    def __init__(self, file_pattern, T, samples_per_batch, win_length, sr):
        """
        Parameters
        ----------
        file_pattern: string
            File pattern to match for audio files
        T: int
            Number of windows to take in each chunk
        samples_per_batch: int
            Number of samples per batch; this is what __len__ reports
        win_length: int
            Window length of STFT; hop_length assumed to be half of this
        sr: int
            Sample rate
        """
        self.samples_per_batch = samples_per_batch
        hop_length = win_length//2
        self.hop_length = hop_length
        self.T = T
        # Number of audio samples spanned by T half-overlapping windows
        self.n_samples = hop_length*(T-1)+win_length
        self.samples = []
        self.loudnesses = []
        for filename in glob.glob(file_pattern):
            x, _ = librosa.load(filename, sr=sr)
            loudness = extract_loudness(x, sr, hop_length, n_fft=win_length)
            # Skip recordings too short to cut a chunk from (mirrors the length
            # check in CurveData); otherwise __getitem__ would fail at random
            if len(loudness) - T - 1 <= 0:
                continue
            self.samples.append(np.array(x, dtype=np.float32))
            self.loudnesses.append(np.array(loudness, dtype=np.float32))

    def __len__(self):
        """
        Nominal epoch length; items are drawn at random, so this is not
        the number of loaded recordings
        """
        return self.samples_per_batch
    
    def __getitem__(self, idx):
        """
        Return a random audio clip, along with its loudness

        Parameters
        ----------
        idx: int
            Ignored; every call draws a fresh random recording and offset
        
        Returns
        -------
        torch.tensor(n_samples)
            Audio clip
        torch.tensor(T, 1)
            Loudness
        """
        idx = np.random.randint(len(self.samples))
        x = self.samples[idx]
        loudness = self.loudnesses[idx]
        # Use the chunk length stored on the instance, not a notebook global
        i1 = np.random.randint(len(loudness)-self.T-1)
        loudness = loudness[i1:i1+self.T]
        loudness = loudness[:, None]
        # Convert the window index to a sample offset
        i1 = i1*self.hop_length
        x = x[i1:i1+self.n_samples]
        return torch.from_numpy(x), torch.from_numpy(loudness)
        
        
# Notebook-wide configuration; these globals are read by later cells
n_samples = 3000  # Passed as voronoi_samples: selects which entry of each pickled curve file to load
T = 300  # Chunk length: curve points per CurveData item and loudness windows per AudioData item
samples_per_batch = 10000  # Reported as the length of every dataset below
win_length = 1024  # STFT window length; the hop is half of this
sr = 44100  # Sample rate to which all audio is loaded

# Train/test split of the curves: files 0-9 of each class for training, 10-14 for testing
curve_train = CurveData(np.arange(10), n_samples, T, samples_per_batch)
curve_test  = CurveData(np.arange(10, 15), n_samples, T, samples_per_batch)

# Train/test split of the audio follows the directory layout of the dataset on disk
audio_train = AudioData("../data/musdb18hq/train/*/mixture.wav", T, samples_per_batch, win_length, sr)
audio_test  = AudioData("../data/musdb18hq/test/*/mixture.wav",  T, samples_per_batch, win_length, sr)

In [None]:
# Sanity check on the curve data: draw one batch and plot its first chunk
loader = DataLoader(curve_train, batch_size=16, shuffle=True)
Y = next(iter(loader)).numpy()
print(Y.shape)
# Columns 0 and 1 are the plot coordinates; the remaining columns are passed as per-point colors
plt.scatter(Y[0, :, 0], Y[0, :, 1], c=Y[0, :, 2::])
# Connect the points in order to show the path of the curve
plt.plot(Y[0, :, 0], Y[0, :, 1], c='k', linewidth=1)

# Sanity check on the audio data: draw one batch and play back its first clip
loader = DataLoader(audio_train, batch_size=16, shuffle=True)
X, L = next(iter(loader))
# Last expression of the cell, so the audio player is rendered as its output
ipd.Audio(X[0, :], rate=sr)

# Encoder / Decoder Architectures

In [None]:
class MLP(nn.Module):
    """
    A stack of `depth` identical stages, each being
    Linear -> LayerNorm -> LeakyReLU.  The first stage maps n_input
    features to n_units; every later stage maps n_units to n_units
    """
    def __init__(self, depth=3, n_input=1, n_units=512):
        super().__init__()
        stages = []
        in_features = n_input
        for _ in range(depth):
            stages += [
                nn.Linear(in_features, n_units),
                nn.LayerNorm(normalized_shape=n_units),
                nn.LeakyReLU(),
            ]
            # Only the first stage sees the raw input width
            in_features = n_units
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """
        Apply all stages to the last axis of x
        """
        return self.layers(x)

    def get_num_parameters(self):
        """
        Total number of scalar entries over all parameters of this module
        """
        return sum((np.prod(p.shape) for p in self.parameters()), 0)
            
def modified_sigmoid(x):
    """
    Sigmoid raised to the power ln(10), scaled by 2 and offset by 1e-7,
    so that the output lies between 1e-7 and 2 + 1e-7

    Parameters
    ----------
    x: torch.tensor
        Input of any shape

    Returns
    -------
    torch.tensor
        Elementwise result, same shape as x
    """
    squashed = torch.sigmoid(x)
    return 2*(squashed**np.log(10)) + 1e-7
        

def get_filtered_noise(H, A, win_length):
    """
    Perform subtractive synthesis by applying FIR filters to windows
    of white noise and summing overlap-added versions of them together.
    Fresh Gaussian noise is drawn on every call, so the output is random
    
    Parameters
    ----------
    H: torch.tensor(n_batches x time x n_coeffs)
        FIR filters for each window for each batch
    A: torch.tensor(n_batches x time x 1)
        Amplitudes for each window for each batch
    win_length: int
        Window length of each chunk to which to apply FIR filter.
        Hop length is assumed to be half of this
        
    Returns
    -------
    torch.tensor(n_batches, hop_length*(time-1)+win_length)
        Filtered noise for each batch
    """
    n_batches = H.shape[0]
    T = H.shape[1]
    n_coeffs = H.shape[2]
    hop_length = win_length//2
    # Total length spanned by T windows that each advance by hop_length
    n_samples = hop_length*(T-1)+win_length

    ## Pad impulse responses and generate noise
    # Filters are zero-padded to twice the window length, the same length
    # as the zero-padded noise windows built below
    H = nn.functional.pad(H, (0, win_length*2-n_coeffs))
    # .to(H) matches the device and dtype of the filters
    noise = torch.randn(n_batches, n_samples).to(H)

    ## Take out each overlapping window of noise
    # Each window occupies the first win_length entries of a 2*win_length row;
    # the second half stays zero
    N = torch.zeros(n_batches, T, win_length*2).to(H)
    # Even-indexed windows tile the noise back to back starting at sample 0
    # NOTE(review): the even/odd counts appear to rely on win_length being even — confirm
    n_even = n_samples//win_length
    N[:, 0::2, 0:win_length] = noise[:, 0:n_even*win_length].view(n_batches, n_even, win_length)
    # Odd-indexed windows tile the noise back to back starting at hop_length
    n_odd = T - n_even
    N[:, 1::2, 0:win_length] = noise[:, hop_length:hop_length+n_odd*win_length].view(n_batches, n_odd, win_length)
    
    # Apply amplitude to each window
    # A has a trailing axis of size 1, so it broadcasts across the samples of each window
    N = N*A
    
    ## Perform a zero-phase version of each filter and window
    FH = torch.fft.rfft(H)
    # Squared magnitude of the frequency response: real-valued, so the phase is discarded
    FH = torch.real(FH)**2 + torch.imag(FH)**2 # Make it zero-phase
    FN = torch.fft.rfft(N)
    # Filter by multiplying in the frequency domain; keep only the first win_length output samples
    y = torch.fft.irfft(FH*FN)[..., 0:win_length]
    # Taper every filtered window with a Hann window before summing
    y = y*torch.hann_window(win_length).to(y)

    ## Overlap-add everything
    ola = torch.zeros(n_batches, n_samples).to(y)
    # Even windows go back to where they were cut from (offset 0) ...
    ola[:, 0:n_even*win_length] += y[:, 0::2, :].reshape(n_batches, n_even*win_length)
    # ... and odd windows are added on top, shifted by one hop
    ola[:, hop_length:hop_length+n_odd*win_length] += y[:, 1::2, :].reshape(n_batches, n_odd*win_length)
    
    return ola
    
    
        
class CurveEncoder(nn.Module):
    """
    Maps a curve and a loudness track to filtered noise: per-window FIR taps
    and amplitudes are predicted and handed to get_filtered_noise
    """
    def __init__(self, mlp_depth, n_units, n_taps, win_length, pre_scale=0.01):
        """
        Parameters
        ----------
        mlp_depth: int
            Depth of each multilayer perceptron
        n_units: int
            Number of units in each multilayer perceptron
        n_taps: int
            Number of taps in each FIR filter
        win_length: int
            Length of window for each windowed audio chunk
        pre_scale: float
            Factor applied to the generated noise (try to start off much
            lower than audio)
        """
        super().__init__()
        self.win_length = win_length
        self.hop_length = win_length//2
        self.pre_scale = pre_scale

        # Separate front ends for the 5-channel curve and the 1-channel loudness
        self.YMLP = MLP(mlp_depth, 5, n_units)
        self.LMLP = MLP(mlp_depth, 1, n_units)

        # The recurrent layer sees both feature sets side by side
        self.gru = nn.GRU(input_size=2*n_units, hidden_size=n_units, num_layers=1, bias=True, batch_first=True)
        # The final MLP sees curve features, loudness features and GRU output
        self.FinalMLP = MLP(mlp_depth, 3*n_units, n_units)
        self.TapsDecoder = nn.Linear(n_units, n_taps)
        self.AmplitudeDecoder = nn.Linear(n_units, 1)

    def forward(self, Y, L):
        """
        Parameters
        ----------
        Y: torch.tensor(n_batches, T, 5)
            xyrgb samples
        L: torch.tensor(n_batches, T, 1)
            Loudness for each window

        Returns
        -------
        torch.tensor(n_batches, hop_length*(T-1)+win_length)
            Filtered noise, multiplied by pre_scale
        """
        curve_feats = self.YMLP(Y)
        loud_feats = self.LMLP(L)
        recurrent, _ = self.gru(torch.cat((curve_feats, loud_feats), dim=2))
        hidden = self.FinalMLP(torch.cat((curve_feats, loud_feats, recurrent), dim=2))
        # tanh keeps every filter tap in (-1, 1)
        taps = nn.functional.tanh(self.TapsDecoder(hidden))
        amplitudes = nn.functional.leaky_relu(self.AmplitudeDecoder(hidden))
        return self.pre_scale*get_filtered_noise(taps, amplitudes, self.win_length)

    def get_num_parameters(self):
        """
        Total number of scalar entries over all parameters of this module
        """
        return sum((np.prod(p.shape) for p in self.parameters()), 0)
        

        
class CurveDecoder(nn.Module):
    """
    Recovers a 5-channel curve from audio, one curve point per STFT frame
    """
    def __init__(self, mlp_depth, n_units, win_length):
        """
        Parameters
        ----------
        mlp_depth: int
            Depth of each multilayer perceptron
        n_units: int
            Number of units in each multilayer perceptron
        win_length: int
            Length of window for each windowed audio chunk
        """
        super().__init__()
        self.win_length = win_length

        # Front end over the win_length//2+1 magnitude bins of each frame
        self.SMLP = MLP(mlp_depth, win_length//2+1, n_units)

        self.gru = nn.GRU(input_size=n_units, hidden_size=n_units, num_layers=1, bias=True, batch_first=True)
        # The final MLP sees the frame features and the GRU output
        self.FinalMLP = MLP(mlp_depth, 2*n_units, n_units)
        self.YDecoder = nn.Linear(n_units, 5)

    def forward(self, X):
        """
        Parameters
        ----------
        X: torch.tensor(n_batches, n_samples)
            Audio samples

        Returns
        -------
        torch.tensor(n_batches, n_frames, 5)
            Decoded curve, passed through modified_sigmoid
        """
        n_fft = self.win_length
        window = torch.hann_window(n_fft).to(X)
        # Uncentered, half-overlapping STFT
        spec = torch.stft(X, n_fft=n_fft, hop_length=n_fft//2, win_length=n_fft,
                          window=window, center=False, return_complex=True)
        # Magnitudes, rearranged to (batch, frame, bin)
        mags = torch.abs(spec).transpose(1, 2)
        feats = self.SMLP(mags)
        recurrent, _ = self.gru(feats)
        hidden = self.FinalMLP(torch.cat((feats, recurrent), dim=2))
        return modified_sigmoid(self.YDecoder(hidden))

    def get_num_parameters(self):
        """
        Total number of scalar entries over all parameters of this module
        """
        return sum((np.prod(p.shape) for p in self.parameters()), 0)


# Smoke test of the architectures on one batch (CPU, small model)
curve_loader = DataLoader(curve_train, batch_size=16, shuffle=True)
audio_loader = DataLoader(audio_train, batch_size=16, shuffle=True)
Y = next(iter(curve_loader))
X, L = next(iter(audio_loader))

mlp_depth = 3
n_units = 256
n_taps = 50
encoder = CurveEncoder(mlp_depth, n_units, n_taps, win_length)
decoder = CurveDecoder(mlp_depth, n_units, win_length)
print("Encoder params", encoder.get_num_parameters())
print("Decoder params", decoder.get_num_parameters())
N = encoder(Y, L)
# An expression in the middle of a cell is not rendered, so display the
# player explicitly (previously the Audio object was built and discarded)
ipd.display(ipd.Audio(N.detach().numpy()[1, :], rate=sr))
XN = X + N
YOut = decoder(XN)
# The training loss compares Y and YOut elementwise, so the shapes must agree
assert YOut.shape == Y.shape, "Decoder output {} does not match curve batch {}".format(tuple(YOut.shape), tuple(Y.shape))

# Loss Functions

In [None]:
# Cache of Hann windows keyed by (window length, device, dtype), so a window
# created for one device/dtype is never reused for inputs on another
HANN_TABLE = {}

def mss_loss(X, Y, eps=1e-7):
    """
    Multi-scale spectral loss between two batches of audio: for window
    lengths 64, 128, ..., 2048 (hop = window/4), the mean absolute
    difference of the STFT magnitudes plus the mean absolute difference
    of their logs, summed over all window lengths

    Parameters
    ----------
    X: torch.tensor(n_batches, n_samples)
        First batch of audio
    Y: torch.tensor(n_batches, n_samples)
        Second batch of audio, same shape, device and dtype as X
    eps: float
        Added to the magnitudes before taking the log

    Returns
    -------
    torch.tensor (scalar)
        The summed loss
    """
    loss = 0
    win = 64
    while win <= 2048:
        hop = win//4
        key = (win, X.device, X.dtype)
        if key not in HANN_TABLE:
            HANN_TABLE[key] = torch.hann_window(win).to(X)
        hann = HANN_TABLE[key]
        SX = torch.abs(torch.stft(X.squeeze(), win, hop, win, hann, return_complex=True))
        SY = torch.abs(torch.stft(Y.squeeze(), win, hop, win, hann, return_complex=True))
        # Linear term plus log term, normalized by the number of spectrogram entries
        loss_win = torch.sum(torch.abs(SX-SY)) + torch.sum(torch.abs(torch.log(SX+eps)-torch.log(SY+eps)))
        loss += loss_win/torch.numel(SX)
        win *= 2
    return loss

# Test Example - Layla And Smooth Criminal

In [None]:
# Build the test curve from an image: a Voronoi approximation of the image
# followed by a TSP tour through its points
device = 'cpu'
n_points = 3000
n_iters = 100
I = skimage.io.imread("../data/images/layla.png")
# NOTE(review): N does not appear to be read before it is reassigned in the training cell — confirm
N = min(I.shape[0], I.shape[1])
J, YLayla, final_cost = get_voronoi_image(I, device, n_points, n_neighbs=2, n_iters=n_iters, verbose=False, plot_iter_interval=0, use_lsqr=False)
# Presumably orders the points along a tour so they form a curve; see ../src/tsp
YLayla = get_tsp_tour(YLayla)
# Alternative for debugging: use a curve from the training set instead
#YLayla = curve_train.Ys[100]
plt.imshow(splat_voronoi_image_1nn(YLayla, 200, 200))

In [None]:
# Prepare the test example: the Layla curve as a (1, T, 5) tensor, and the
# audio clip and loudness track that it will be hidden in
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
hop_length = win_length//2

# Make the cell safe to re-run: after a first run YLayla is already a
# (1, T, 5) tensor, possibly on the GPU, so bring it back to a (T, 5) array
if torch.is_tensor(YLayla):
    YLayla = YLayla.detach().cpu().numpy()
YLayla = np.array(YLayla, dtype=np.float32).reshape(-1, 5)
TLayla = YLayla.shape[0]
YLayla = torch.from_numpy(YLayla).to(device).view(1, TLayla, 5)

xsmooth, _ = librosa.load("../data/tunes/smoothcriminal.mp3", sr=sr)
#xsmooth, _ = librosa.load("../data/musdb18hq/test/Arise - Run Run Run/mixture.wav", sr=sr)
xsmooth = xsmooth[sr*15::] # Cut off quieter beginning

# One curve point per half-overlapping window, so this many samples are needed
n_audio = hop_length*(TLayla-1)+win_length
if len(xsmooth) < n_audio:
    raise ValueError("Audio has {} samples after trimming, but {} are needed for a curve of length {}".format(len(xsmooth), n_audio, TLayla))

lsmooth = extract_loudness(xsmooth, sr, hop_length, win_length)
lsmooth = np.array(lsmooth[0:TLayla], dtype=np.float32)
xsmooth = np.array(xsmooth[0:n_audio], dtype=np.float32)
# Add the batch axis (and the trailing channel axis for loudness)
xsmooth = torch.from_numpy(xsmooth[None, :]).to(device)
lsmooth = torch.from_numpy(lsmooth[None, :, None]).to(device)
print(xsmooth.shape, lsmooth.shape)

# Train Loop

In [None]:
# Try to use the GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device: ", device)

## Step 2: Create model with a test batch
noise_only = False # If True, decode from the filtered noise alone instead of audio + noise
mlp_depth = 3
n_units = 512
n_taps = 200
pre_scale = 1
lam_xy = 1000 # Weight for geometry curve fit
lam_rgb = 200 # Weight for rgb curve fit
encoder = CurveEncoder(mlp_depth, n_units, n_taps, win_length, pre_scale)
encoder = encoder.to(device)
decoder = CurveDecoder(mlp_depth, n_units, win_length)
decoder = decoder.to(device)
print("Encoder params", encoder.get_num_parameters())
print("Decoder params", decoder.get_num_parameters())

## Step 3: Setup the loss function
batch_size=16
# Encoder and decoder are trained jointly by a single optimizer
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-4)
# Start at a learning rate of 1e-4 and decay it by a factor of 0.98.  The
# scheduler is stepped once per epoch; the period is an integer number of
# epochs (at least 1) so that the decay is regular for any dataset length
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=max(1, 10000//len(curve_train)), gamma=0.98)

n_epochs = 10000
# Per-epoch average losses on the training data
xy_losses = []
rgb_losses = []
audio_losses = []

# Per-epoch average losses on the test data
test_xy_losses = []
test_rgb_losses = []
test_audio_losses = []

# One figure, cleared and redrawn every epoch by the training loop
fac = 0.7
plt.figure(figsize=(fac*18, fac*18))

def _forward_losses(Y, X, L):
    """
    Run encoder -> mix -> decoder on one batch and compute the three loss terms

    Parameters
    ----------
    Y: torch.tensor(n_batches, T, 5)
        Curve batch
    X: torch.tensor(n_batches, n_samples)
        Audio batch
    L: torch.tensor(n_batches, T, 1)
        Loudness batch

    Returns
    -------
    N, XN, YOut, loss_audio, loss_xy, loss_rgb
        Filtered noise, signal given to the decoder, decoded curve, MSS loss
        of XN against X, and L1 losses on the first two / remaining channels
    """
    # Encode the curve
    N = encoder(Y, L)
    # Add filtered noise to audio and decode
    XN = N if noise_only else X + N
    YOut = decoder(XN)
    # MSS fit of XN to X, as well as the L1 fit of Y to YOut
    loss_audio = mss_loss(X, XN)
    loss_xy = torch.mean(torch.abs(Y[:, :, 0:2]-YOut[:, :, 0:2]))
    loss_rgb = torch.mean(torch.abs(Y[:, :, 2::]-YOut[:, :, 2::]))
    return N, XN, YOut, loss_audio, loss_xy, loss_rgb


def _to_int16_pcm(x):
    """
    Peak-normalize a float signal and convert it to 16-bit PCM.  The peak is
    measured on |x| and mapped to 32767, so neither the positive nor the
    negative extreme can overflow int16
    """
    peak = np.max(np.abs(x))
    if peak > 0: # Leave an all-zero signal as it is
        x = x/peak
    return np.array(x*32767, dtype=np.int16)


for epoch in range(n_epochs):
    #####################    STEP 1:  TRAIN     #####################
    curve_loader = DataLoader(curve_train, batch_size=batch_size, shuffle=True)
    audio_loader = DataLoader(audio_train, batch_size=batch_size, shuffle=True)

    train_xy_loss = 0
    train_rgb_loss = 0
    train_audio_loss = 0
    for batch_num, (Y, (X, L)) in enumerate(zip(curve_loader, audio_loader)): # Go through each mini batch
        # Move inputs/outputs to GPU
        Y = Y.to(device)
        X = X.to(device)
        L = L.to(device)
        # Reset the optimizer's gradients
        optimizer.zero_grad()

        N, XN, YOut, loss_audio, loss_xy, loss_rgb = _forward_losses(Y, X, L)
        loss = lam_xy*loss_xy + lam_rgb*loss_rgb + loss_audio

        train_xy_loss += loss_xy.item()
        train_rgb_loss += loss_rgb.item()
        train_audio_loss += loss_audio.item()

        # Compute the gradients of the loss function with respect
        # to all of the parameters of the model
        loss.backward()
        optimizer.step()

    # Keep the last training batch for the scatter plots below
    last_Y = Y
    last_YOut = YOut

    audio_losses.append(train_audio_loss/len(audio_loader))
    xy_losses.append(train_xy_loss/len(audio_loader))
    rgb_losses.append(train_rgb_loss/len(audio_loader))

    print("Epoch {}, audio loss {:.3f}, xy loss {:.3f}, rgb loss {:.3f}".format(epoch, audio_losses[-1], xy_losses[-1], rgb_losses[-1]))
    scheduler.step()

    #####################    STEP 2:  TEST     #####################
    curve_loader = DataLoader(curve_test, batch_size=batch_size)
    audio_loader = DataLoader(audio_test, batch_size=batch_size)
    test_xy_loss = 0
    test_rgb_loss = 0
    test_audio_loss = 0
    # No gradients are needed for evaluation, so do not build autograd graphs
    with torch.no_grad():
        for batch_num, (Y, (X, L)) in enumerate(zip(curve_loader, audio_loader)): # Go through each mini batch
            # Move inputs/outputs to GPU
            Y = Y.to(device)
            X = X.to(device)
            L = L.to(device)
            N, XN, YOut, loss_audio, loss_xy, loss_rgb = _forward_losses(Y, X, L)
            test_xy_loss += loss_xy.item()
            test_rgb_loss += loss_rgb.item()
            test_audio_loss += loss_audio.item()

    last_Y_test = Y
    last_YOut_test = YOut

    test_audio_losses.append(test_audio_loss/len(audio_loader))
    test_xy_losses.append(test_xy_loss/len(audio_loader))
    test_rgb_losses.append(test_rgb_loss/len(audio_loader))


    #####################    STEP 3:  LAYLA     #####################
    with torch.no_grad():
        N = encoder(YLayla, lsmooth)
        XN = N if noise_only else xsmooth + N
        YOut = decoder(XN)
        loss_xy_layla  = torch.mean(torch.abs(YLayla[:, :, 0:2]-YOut[:, :, 0:2])).item()
        loss_rgb_layla = torch.mean(torch.abs(YLayla[:, :, 2::]-YOut[:, :, 2::])).item()
    YOut = YOut.detach().cpu().numpy()
    YOut = YOut[0, :, :]
    # Clamp colors to [0, 1] so they can be plotted
    YOut[:, 2::] = np.clip(YOut[:, 2::], 0, 1)


    #####################    STEP 4: PLOT    #####################
    plt.clf()
    # Decoded Layla curve as colored points and as a splatted image
    plt.subplot(331)
    plt.scatter(YOut[:, 0], YOut[:, 1], c=YOut[:, 2::])
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.gca().invert_yaxis()
    plt.subplot(332)
    plt.imshow(splat_voronoi_image_1nn(YOut, 200, 200))

    # Curves show the weighted losses; the legend shows the unweighted latest values
    plt.subplot(333)
    plt.plot(audio_losses)
    plt.plot(lam_xy*np.array(xy_losses))
    plt.plot(lam_rgb*np.array(rgb_losses))
    plt.legend(["Audio ({:.3f})".format(audio_losses[-1]),
                "xy ({:.3f})".format(xy_losses[-1]),
                "rgb ({:.3f})".format(rgb_losses[-1])])
    plt.xlabel("Epoch")
    plt.ylabel("Scaled Loss")
    plt.title("Epoch {}, Train Losses".format(epoch))

    # Target vs. decoded values for the first item of the last training batch
    plt.subplot(334)
    x1 = last_Y.detach().cpu()[0, :, 0].numpy()
    x2 = last_YOut.detach().cpu()[0, :, 0].numpy()
    plt.scatter(x1, x2)
    plt.axis("equal")
    plt.title("X Coord Train (Mean {:.3f})".format(np.mean(np.abs(x1-x2))))
    plt.subplot(335)
    x1 = last_Y.detach().cpu()[0, :, 2].numpy()
    x2 = last_YOut.detach().cpu()[0, :, 2].numpy()
    plt.scatter(x1, x2)
    plt.axis("equal")
    plt.title("R Coord Train (Mean {:.3f})".format(np.mean(np.abs(x1-x2))))

    plt.subplot(336)
    plt.plot(test_audio_losses)
    plt.plot(lam_xy*np.array(test_xy_losses))
    plt.plot(lam_rgb*np.array(test_rgb_losses))
    plt.legend(["Audio ({:.3f})".format(test_audio_losses[-1]),
                "xy ({:.3f})".format(test_xy_losses[-1]),
                "rgb ({:.3f})".format(test_rgb_losses[-1])])
    plt.xlabel("Epoch")
    plt.ylabel("Scaled Loss")
    plt.title("Test Losses")


    # Target vs. decoded values for the Layla example
    plt.subplot(337)
    plt.scatter(YLayla.detach().cpu()[0, :, 0], YOut[:, 0])
    plt.axis("equal")
    plt.title("X Coord Layla (Loss {:.3f})".format(loss_xy_layla))
    plt.subplot(338)
    plt.scatter(YLayla.detach().cpu()[0, :, 2], YOut[:, 2])
    plt.axis("equal")
    plt.title("R Coord Layla (Loss {:.3f})".format(loss_rgb_layla))

    # Text panel with the hyperparameters of this run
    plt.subplot(339)
    textstr = "Epoch {}\n\nlam_xy = {}\nlam_rgb = {}\npre_scale={}\nn_units={}\nn_taps={}".format(epoch, lam_xy, lam_rgb, pre_scale, n_units, n_taps)
    props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
    ax = plt.gca()
    ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=14,
            verticalalignment='top', bbox=props)
    plt.axis("off")

    plt.savefig("Epoch{}.png".format(epoch), bbox_inches='tight')

    # Save the Layla result as audio: the signal given to the decoder, and the noise alone
    wavfile.write("Epoch{}.wav".format(epoch), sr, _to_int16_pcm(XN.detach().cpu().numpy()[0, :]))
    wavfile.write("Epoch{}Noise.wav".format(epoch), sr, _to_int16_pcm(N.detach().cpu().numpy()[0, :]))