In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import matplotlib.pyplot as plt
import IPython.display as ipd
import torch
from torch import nn
from scipy import signal
import glob
import librosa
import librosa.display
import time
import subprocess
import os
from scipy.io import wavfile
from tqdm import tqdm

import skimage.io
import sys
sys.path.append("../src")
from audioutils import get_batch_stft, mss_loss
from decoders import raw_avg_decode

In [None]:
## Use the GPU when one is available; every tensor created below is moved to this device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 
print("Device: ", device)

## Payload: a reproducible pseudo-random bit string (the bit count doubles as the RNG seed)
n_bits = 1000*36
np.random.seed(n_bits)
B = np.random.randint(0, 2, n_bits) ## Payload

sr = 44100         # Sample rate (Hz) the cover audio is loaded at
n_iters = 200      # Number of optimizer steps
win_length = 2048  # Window length handed to get_batch_stft
avg_win = 25       # Averaging window handed to raw_avg_decode; N below allots avg_win samples per bit
lam = 1e-1         # Weight of the spectrogram-difference term in the loss

## Let xavg be my downsampled average signal
Gamma = 15
fwin = 16 # Between a window of [i-fwin, i+fwin] take the standard deviation of xavg


def decoder_fn(u):
    """Decode a signal u with raw_avg_decode using the avg_win, Gamma and fwin set above."""
    return raw_avg_decode(u, avg_win, Gamma, fwin)


## Number of audio samples used to carry the payload, and the resulting payload rate
N = (B.size+1)*avg_win
print(B.size/(N/sr), "bps")

x, _ = librosa.load("../data/tunes/smoothcriminal.mp3", sr=sr)
x = x[sr*15::]  # Skip the first 15 seconds of the track

## Fail early with a clear message instead of silently embedding into a truncated signal
if x.size < N:
    raise ValueError("Cover audio has {} samples after the 15 s offset, but {} are needed".format(x.size, N))
x = x[0:N]
x_orig = torch.from_numpy(x).to(device) ## Cover signal
x = torch.from_numpy(x).to(device) ## Watermarked signal

## The watermarked signal is optimized in atanh space, so tanh(x) always stays in (-1, 1).
## atanh is +/-inf at exactly +/-1 and NaN beyond that, so samples that are at full scale
## (or overshoot it after decoding) are first pulled just inside the open interval;
## otherwise a single such sample turns every loss and gradient into NaN.
atanh_eps = 1e-6
x = torch.atanh(torch.clamp(x, -1 + atanh_eps, 1 - atanh_eps))
x = x.requires_grad_()

BTarget = torch.from_numpy(B).to(device)

In [None]:
## Optimization setup: Adam updates the pre-tanh signal x directly
sigmoid_scale = 3  # Gain applied to the decoder output before it is treated as logits
losses = []        # Total loss recorded at every iteration
bce_loss = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam([x], lr=1e-3)

## Magnitude spectrogram of the "cover signal"; fixed reference for the fidelity term
SpecOrig = get_batch_stft(x_orig.unsqueeze(0), win_length)[0, :, :].abs()

## The payload never changes, so convert it to floating point once, outside the loop
bit_targets = 1.0*BTarget

for i in tqdm(range(n_iters)):
    optimizer.zero_grad()

    ## Squash back into (-1, 1), then decode the payload estimate from the current signal
    xtan = torch.tanh(x)
    BEst = decoder_fn(xtan)

    ## Magnitude spectrogram of the current watermarked signal
    Speci = get_batch_stft(xtan.unsqueeze(0), win_length)[0, :, :].abs()

    ## Bit term: scaled decoder outputs act as logits against the payload bits
    BEst = sigmoid_scale*BEst.flatten()[0:B.size]
    bit_loss = bce_loss(BEst, bit_targets)

    ## Fidelity term: mean absolute difference between the two magnitude spectrograms
    ## TODO: This needs to be a perceptual loss function
    ## (an earlier draft penalised the dB difference between the spectrograms instead)
    spec_loss = lam*(SpecOrig - Speci).abs().mean()

    loss = bit_loss + spec_loss
    losses.append(loss.item())
    ## Plain floats for logging
    loss1 = bit_loss.item()
    loss2 = spec_loss.item()

    loss.backward()
    optimizer.step()

    ## Refresh the printed status every 100 iterations
    if not i % 100:
        ipd.clear_output()
        print(i, loss1, loss2)

plt.plot(losses)

In [None]:
## Robustness check: encode the watermarked audio to MP3 at several bitrates,
## decode it again and report the bit error rate of the recovered payload.
bitrates = [192, 128, 96, 64]  # MP3 bitrates to test, in kbps
xtan = torch.tanh(x).detach().cpu().flatten()


def _to_int16_pcm(samples):
    """Scale a float signal by 32768 and convert it to int16, saturating at the int16 limits.

    A sample of exactly +1.0 scales to 32768, which does not fit in int16; without the
    clip the conversion overflows instead of saturating at 32767.
    """
    scaled = np.asarray(samples, dtype=np.float64)*32768
    return np.clip(scaled, -32768, 32767).astype(np.int16)


## Cover signal, written out as a reference
ref = _to_int16_pcm(x_orig.detach().cpu().flatten())
wavfile.write("ref.wav", sr, ref)

## Watermarked signal, the input to the MP3 encoder
wm_pcm = _to_int16_pcm(xtan)
wavfile.write("temp.wav", sr, wm_pcm)
#subprocess.call(["peaq", "--advanced", "ref.wav", "temp.wav"])
compressed_filename = "temp.mp3"

for bitrate in bitrates:
    ## Remove any stale file so ffmpeg never stops to ask about overwriting
    if os.path.exists(compressed_filename):
        os.remove(compressed_filename)
    encode = subprocess.run(
        ["ffmpeg", "-i", "temp.wav", "-b:a", "{}k".format(bitrate), compressed_filename],
        stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True,
    )
    ## Surface encoder failures here, with ffmpeg's own log, rather than as a load error below
    if encode.returncode != 0:
        raise RuntimeError("ffmpeg failed at {}kbps:\n{}".format(bitrate, encode.stdout))
    xmp3, _ = librosa.load(compressed_filename, sr=sr)
    xmp3 = torch.from_numpy(xmp3).to(device)
    
    BEst = decoder_fn(xmp3)
    BEst = BEst.flatten()[0:B.size]
    BEst = sigmoid_scale*BEst  # Same scaling as during optimization; the sign is unaffected
    ## Hard decision: positive output -> 1, negative output -> 0
    BEst = 0.5*(np.sign(BEst.detach().cpu().numpy())+1)
    ## Bit error rate: fraction of payload bits that differ after compression
    berr = np.sum(np.abs(B-BEst))/B.size
    print("{}kbps {:.3f}".format(bitrate, berr))

ipd.Audio(xtan, rate=sr)

In [None]:
## Running count of bit errors along the payload, using the most recent BEst
## (the last bitrate decoded in the loop above)
bit_errors = np.abs(B - BEst)
plt.plot(np.cumsum(bit_errors))