In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import matplotlib.pyplot as plt
import IPython.display as ipd
from scipy.ndimage import maximum_filter
import scipy.io as sio
import librosa
import torch
from torch import nn

In [None]:
# Load both clips at one common sample rate, then keep only the final
# 20 seconds of each
sr = 44100
x_carrier, sr = librosa.load("barry.mp3", sr=sr)
x_signal, sr = librosa.load("chocolaterain.mp3", sr=sr)
x_carrier = x_carrier[-20*sr:]
x_signal = x_signal[-20*sr:]

In [None]:
def get_windowed_softmax(X, time_win, freq_win, temperature=20):
    """
    Compute softmax in a window around each element in a tensor, 
    using summed area tables to avoid loops
    https://en.wikipedia.org/wiki/Summed-area_table
    
    The summed area table is accumulated in float64.  The exponentials can
    be as large as exp(temperature), and in float32 the four-corner
    subtraction below then loses the small window sums to cancellation,
    which produces wrong, infinite or negative outputs.
    
    Parameters
    ----------
    X: torch.tensor(M, N)
        2D tensor on which to compute the windowed softmax
    time_win: int
        Half-width of window (along the last axis)
    freq_win: int
        Half-height of window (along the first axis)
    temperature: float
        X is scaled by temperature/max(X) before exponentiating; for
        non-negative X this puts the data into the range [0, temperature]
    
    Returns
    -------
    torch.tensor(M, N)
        Each exponentiated element divided by the sum of the exponentiated
        elements in the (2*freq_win+1) x (2*time_win+1) window centered on
        it.  Elements outside of X contribute nothing to the sum.  The
        result has the dtype of X (or the default float dtype if X is not
        a floating point tensor)
    """
    out_dtype = X.dtype if X.is_floating_point() else torch.get_default_dtype()
    Xd = X.to(torch.float64)
    E = torch.exp(temperature*Xd/torch.max(Xd))

    # Zero pad so every window fits, with one extra leading row/column so
    # that the "corner before the window" always exists in the table
    C = torch.nn.functional.pad(E, (time_win+1, time_win, freq_win+1, freq_win))
    C = torch.cumsum(C, dim=1)
    C = torch.cumsum(C, dim=0)

    fw = freq_win*2+1
    tw = time_win*2+1

    # Four corners of each window in the summed area table
    IA = C[0:-fw, 0:-tw]
    IB = C[0:-fw, tw::]
    IC = C[fw::, 0:-tw]
    ID = C[fw::, tw::]

    denom = ID + IA - IB - IC
    return (E/denom).to(out_dtype)
    
def make_beepy_tune(S, win, hop, thresh, min_freq=0, sample_rate=None):
    """
    Make a beepy tune out of the softmax values, by adding one
    Hann-windowed cosine burst for every entry of S above a threshold
    
    Parameters
    ----------
    S: ndarray(M, N)
        Softmax of spectrogram; row index is the frequency bin and
        column index is the time frame
    win: int
        Window length used in spectrogram
    hop: int
        Hop length used in spectrogram
    thresh: float
        Threshold above which to include tones
    min_freq: float
        Frequency below which to ignore peaks (a peak is kept only if
        its frequency is strictly greater than this)
    sample_rate: float
        Sample rate used to convert bins to frequencies and samples to
        seconds.  If None, the notebook-level variable sr is used
    
    Returns
    -------
    ndarray(N*hop+win)
        Synthesized audio samples
    """
    # Existing callers do not pass a sample rate; fall back to the global
    fs = sr if sample_rate is None else sample_rate
    S = np.asarray(S)
    hann = 0.5*(1-np.cos(np.linspace(0, 2*np.pi, win)))
    t = np.arange(win)/fs
    yout = np.zeros(S.shape[1]*hop+win)
    # Evaluate the threshold once; nonzero walks the mask in row-major
    # order, which is the same order boolean indexing of S uses
    peaks = S > thresh
    freq_bins, frames = np.nonzero(peaks)
    for frame, freq_bin, mag in zip(frames, freq_bins, S[peaks]):
        freq = freq_bin*fs/win
        # Infinite softmax values pass the threshold test, so skip them here
        if freq > min_freq and np.isfinite(mag):
            # yout has win extra samples, so this slice always has length win
            start = frame*hop
            yout[start:start+win] += mag*hann*np.cos(2*np.pi*freq*t)
    return yout



## Parameters
win = 2048          # STFT window / FFT length
hop = 1024          # STFT hop length
time_win = 3        # Half-width of the softmax window (frames)
freq_win = 5        # Half-height of the softmax window (bins)
max_freq = 256      # Number of lowest frequency bins that are used
temperature = 10    # Softmax temperature
thresh = 0.05       # Softmax threshold above which a bin counts as a peak
lam = 0.5           # Weight of maxes term in the loss
lam_mask = 30       # Max loss of mask term


## Windowed softmax of both magnitude spectrograms.  freq_win extra bins
## are included so the topmost kept bins still get a full window, and
## then they are cropped away again
hann = torch.hann_window(win)

XSignal = torch.stft(torch.from_numpy(x_signal), win, hop, win, hann, return_complex=True).abs()
SSignal = get_windowed_softmax(XSignal[:max_freq+freq_win], time_win, freq_win, temperature)[:max_freq]

XCarrier = torch.stft(torch.from_numpy(x_carrier), win, hop, win, hann, return_complex=True).abs()
SCarrier = get_windowed_softmax(XCarrier[:max_freq+freq_win], time_win, freq_win, temperature)[:max_freq]

## Sanity checks: number of infinite entries, largest value, and the
## mean of everything above the threshold
print(torch.isinf(SSignal).sum(), SSignal.max())
print(SSignal[SSignal > thresh].mean())

plt.figure(figsize=(10, 6))
plt.imshow(SSignal[:max_freq], aspect='auto', cmap='magma', interpolation='none')
plt.ylim([0, max_freq])
plt.colorbar()

## Listen to the peaks of the hidden signal
ipd.Audio(make_beepy_tune(SSignal[:max_freq].numpy(), win, hop, thresh), rate=sr)

In [None]:
## Spectral fit loss
# Cache of Hann windows, keyed by (window length, device, dtype) so that a
# window created for one tensor is never reused with a tensor that lives
# on another device or has another dtype
HANN_TABLE = {}
def mss_loss(x, y, eps=1e-7):
    """
    Multi-scale spectral loss between two signals.  For window lengths
    64, 128, ..., 2048 (hop = window/4), take Hann-windowed STFT
    magnitudes of both signals and add up the mean absolute difference of
    the magnitudes and of the log magnitudes
    
    Parameters
    ----------
    x: torch.tensor
        First signal, in any layout accepted by torch.stft
    y: torch.tensor
        Second signal; its STFT has to have the same shape as that of x
    eps: float
        Added to the magnitudes before taking the log
    
    Returns
    -------
    torch.tensor (scalar)
        Sum over all window lengths of the per-element mean loss
    """
    loss = 0
    win = 64
    while win <= 2048:
        hop = win//4
        key = (win, x.device, x.dtype)
        if key not in HANN_TABLE:
            HANN_TABLE[key] = torch.hann_window(win).to(x)
        hann = HANN_TABLE[key]
        SX = torch.abs(torch.stft(x, win, hop, win, hann, return_complex=True))
        SY = torch.abs(torch.stft(y, win, hop, win, hann, return_complex=True))
        loss_win = torch.sum(torch.abs(SX-SY)) + torch.sum(torch.abs(torch.log(SX+eps)-torch.log(SY+eps)))
        loss += loss_win/torch.numel(SX)
        win *= 2
    return loss

In [None]:
# Use the GPU when there is one, but still run (slowly) without it
device = 'cuda' if torch.cuda.is_available() else 'cpu'
x_orig = torch.from_numpy(x_carrier).to(device)
# The optimization variable lives in atanh space, so that tanh(x) always
# stays inside (-1, 1).  Samples at or beyond +/-1 would map to inf/nan,
# so they are pulled just inside the interval first; x_orig is untouched
x = torch.atanh(torch.clamp(x_orig, -1 + 1e-6, 1 - 1e-6))
x = x.requires_grad_()
hann = torch.hann_window(win).to(x)

## Step 1: Compute maxes to hide
S = torch.abs(torch.stft(torch.from_numpy(x_signal).to(x), win, hop, win, hann, return_complex=True))
SMaxSignal = get_windowed_softmax(S[0:max_freq+freq_win, :], time_win, freq_win, temperature)[0:max_freq, :]

# Set every target max above the threshold to the 99th percentile of
# those maxes, and zero out everything below the threshold.
# This makes them sound less like the hidden signal
q = torch.quantile(SMaxSignal[SMaxSignal > thresh], 0.99)
print("q", q)
SMaxSignal[SMaxSignal > thresh] = q
SMaxSignal[SMaxSignal < thresh] = 0

# Create a mask to choose which maxes to include: one learnable logit per
# target max, initialized at 5 (sigmoid(5) is close to 1, i.e. "included")
n_mask = torch.sum(SMaxSignal > thresh)
mask = 5*torch.ones(int(n_mask), device=device)
mask = mask.requires_grad_()

## Step 2: Perform optimization
optimizer = torch.optim.Adam([x, mask], lr=1e-2)
n_iters = 2000
mss_losses = []
max_losses = []
mask_sums = []

plt.figure(figsize=(12, 8))

## Quantities that do not change during the optimization
# Positions of the target maxes; there is one mask entry per position
peak_idx = SMaxSignal > thresh
# Binary image of the maxes that are supposed to be hidden
MHide = (SMaxSignal.detach().cpu().numpy() == torch.max(SMaxSignal).detach().cpu().numpy())
MHide = MHide[0:max_freq, :]*1.0
# CPU copy of the analysis window, so the progress report below looks at
# the same Hann-windowed spectrogram that the loss is computed on
eval_hann = hann.detach().cpu()

for i in range(n_iters):
    optimizer.zero_grad()
    
    ## Loss 1: stay spectrally close to the original carrier
    xtan = torch.tanh(x)
    loss1 = mss_loss(xtan, x_orig)
    mss_losses.append(loss1.item())
    
    ## Loss 2: squared error between the windowed softmax of the carrier
    ## and the target maxes
    S = torch.abs(torch.stft(xtan, win, hop, win, hann, return_complex=True))
    S = get_windowed_softmax(S[0:max_freq+freq_win, :], time_win, freq_win, temperature)[0:max_freq, :]
    diff = (S-SMaxSignal)**2
    # Weight the error at each target max by its mask value.  (Multiplying
    # in place by sigmoid(mask)*diff would square the error a second time)
    diff[peak_idx] = torch.sigmoid(mask)*diff[peak_idx]
    loss2 = lam*torch.sum(diff)
    max_losses.append(loss2.item())
    
    ## Reward for keeping mask entries switched on; worth at most lam_mask
    mask_sum = torch.sum(torch.sigmoid(mask))
    mask_sums.append(mask_sum.item())
    
    loss = loss1 + loss2 - mask_sum*lam_mask/n_mask
    loss.backward()
    optimizer.step()
    
    
    ## Plot progress every 100 iterations
    if i%100 == 0 or i == n_iters-1:  
        print(mss_losses[-1], max_losses[-1], mask_sums[-1])
        res = torch.tanh(x).detach().cpu().numpy()
        
        # Local maxes of the current carrier spectrogram
        X = torch.stft(torch.from_numpy(res), win, hop, win, eval_hann, return_complex=True)
        X = torch.abs(X).numpy()
        MCarry = maximum_filter(X, (freq_win*2+1, time_win*2+1))
        MCarry = (X == MCarry)*1.0
        MCarry = MCarry[0:max_freq, :]
        
        agreed = int(np.sum(MHide*MCarry))
        # Maxes that are supposed to be hidden but that are not
        false_neg = np.sum((MHide==1)*(MCarry==0))
        # Maxes that are not supposed to be hidden but that exist anyway
        false_pos = np.sum((MHide==0)*(MCarry==1))
        
        plt.clf()
        
        # Top row: +1 (blue) is a false negative, -1 (red) a false positive
        plt.subplot2grid((2, 3), (0, 0), colspan=3)
        plt.imshow(MHide-MCarry, cmap='RdBu')
        plt.ylim([0, max_freq])
        plt.title("{} Agreed, {} False Neg (Blue), {} False Pos (Red), {:.3f} Max Loss".format(agreed, false_neg, false_pos, max_losses[-1]))
        
        
        plt.subplot(234)
        plt.plot(mss_losses)
        plt.title("MSS Losses {:.3f}".format(mss_losses[-1]))
        plt.xlabel("Iteration")
        
        plt.subplot(235)
        plt.plot(max_losses)
        plt.title("Max Losses {:.3f}".format(max_losses[-1]))
        plt.xlabel("Iteration")
        
        # Mask sum against its upper bound (every mask entry fully on)
        plt.subplot(236)
        plt.plot(mask_sums)
        plt.plot(np.ones(len(mask_sums))*torch.numel(mask))
        plt.title("Mask Sum {:.3f}".format(mask_sums[-1]))
        plt.xlabel("Iteration")
        
        plt.savefig("%i.png"%i)
    

In [None]:
# Map the optimized variable back through tanh and listen to the result
res = x.detach().tanh().cpu().numpy()
ipd.Audio(res, rate=sr)