### PyTorch implementation of Random Audio Style Transfer

In [None]:
import librosa
import torch

from IPython.display import display, Audio

from utils import plot_spectrum, read_audio_spectrum, spectrum_to_audio, read_audio_spectrum_pt, spectrum_to_audio_pt
from model import RandomCNN, run_style_transfer
from model import get_input_optimizer, get_style_model_and_losses

In [None]:
# The two input clips: CONTENT_PATH feeds the content spectrum, STYLE_PATH the
# style spectrum used by the loss cells below.
CONTENT_PATH = "wavs/songs/imperial.mp3"
STYLE_PATH = "wavs/songs/usa.mp3"

# read_audio_spectrum yields a 3-tuple per clip: the spectrum, the sample rate
# used for playback, and a third value that is later passed to
# spectrum_to_audio as `p` (presumably phase information -- TODO confirm).
(content_s, content_sr, content_p), (style_s, style_sr, style_p) = map(
    read_audio_spectrum, (CONTENT_PATH, STYLE_PATH)
)

In [None]:
# Reconstruct waveforms from the unmodified input spectra so the inputs can be
# heard before any style transfer happens.
content_wav = spectrum_to_audio(content_s)
style_wav = spectrum_to_audio(style_s)

for clip, clip_sr in ((content_wav, content_sr), (style_wav, style_sr)):
    display(Audio(clip, rate=clip_sr))

In [None]:
# Visualise the content spectrum, then the style spectrum.
for spectrum in (content_s, style_s):
    plot_spectrum(spectrum)

# Same content clip, analysed this time by the PyTorch reader, for comparison.
mag = read_audio_spectrum_pt(CONTENT_PATH)
plot_spectrum(mag)

In [None]:
# Two checks of the PyTorch inversion path:
#   wav    -- spectrum from read_audio_spectrum, inverted by spectrum_to_audio_pt
#   wav_pt -- spectrum from read_audio_spectrum_pt, inverted the same way
content_tensor = torch.from_numpy(content_s)
wav = spectrum_to_audio_pt(content_tensor)

content_pt = read_audio_spectrum_pt(CONTENT_PATH)
wav_pt = spectrum_to_audio_pt(content_pt)

In [None]:
# Play both PyTorch-path reconstructions of the content clip.
# `wav` was inverted from content_s, which read_audio_spectrum returned together
# with content_sr, so play it at that rate rather than a hard-coded 22050. A
# hard-coded rate plays at the wrong speed if the reader loads at a different rate.
display(Audio(wav.cpu(), rate=content_sr))

# read_audio_spectrum_pt does not expose its sample rate here; 22050 is the
# value this cell has always assumed -- TODO confirm against utils.
PT_SAMPLE_RATE = 22050
display(Audio(wav_pt.cpu(), rate=PT_SAMPLE_RATE))

In [None]:
mcnn = RandomCNN()

# Add two leading singleton axes so each 2-D spectrum becomes a 4-D tensor
# (presumably (batch, channel, freq, time) as RandomCNN expects -- TODO confirm).
content = torch.from_numpy(content_s)[None, None, :, :]
style = torch.from_numpy(style_s)[None, None, :, :]

# Initial guess for the spectrum that gets optimised. By default it starts
# from the element-wise mean of the two inputs. That assumes content_s and
# style_s have identical shapes, which read_audio_spectrum presumably
# guarantees -- TODO confirm. Set INIT_FROM_NOISE = True to start from
# Gaussian noise instead. The previous noise init was immediately overwritten,
# so it was dead code.
INIT_FROM_NOISE = False
if INIT_FROM_NOISE:
    result = torch.randn(content.shape, dtype=content.dtype)
else:
    result_s = (content_s + style_s) / 2
    result = torch.from_numpy(result_s)[None, None, :, :]

In [None]:
# Loss weights: the style term is scaled far above the content term.
STYLE_WEIGHT = 1e6
CONTENT_WEIGHT = 1
# Number of optimizer.step() calls. If get_input_optimizer returns LBFGS, each
# step may evaluate the closure several times -- TODO confirm optimizer type.
NUM_STEPS = 300
LOG_EVERY = 50

# Build the loss-instrumented model. Each returned loss module exposes a
# `.loss` attribute (presumably refreshed on every forward pass -- TODO confirm).
model, style_losses, content_losses = get_style_model_and_losses(mcnn, style, content)

# Only the input spectrum is optimised; the random CNN's weights stay frozen.
result.requires_grad_(True)
model.eval()
model.requires_grad_(False)

optimizer = get_input_optimizer(result)


def _closure():
    """Run one forward/backward pass on `result` and return the weighted loss.

    Written as a closure so it works both with optimizers that re-evaluate the
    loss inside ``step`` (e.g. LBFGS) and with single-evaluation ones (e.g. Adam).
    """
    optimizer.zero_grad()
    # Output is discarded; the loss modules presumably record their losses here.
    model(result)
    style_score = sum(sl.loss for sl in style_losses) * STYLE_WEIGHT
    content_score = sum(cl.loss for cl in content_losses) * CONTENT_WEIGHT
    total = style_score + content_score
    total.backward()
    return total


for step in range(1, NUM_STEPS + 1):
    # backward() alone only fills result.grad; step() is what actually updates
    # `result` towards lower loss.
    loss = optimizer.step(_closure)
    if step % LOG_EVERY == 0:
        print(f"step {step}/{NUM_STEPS}: loss = {loss.item():.6f}")


In [None]:
# Detach the optimised tensor, move it to host memory, and drop the singleton
# axes to get back a 2-D NumPy spectrum for plotting.
result_s = result.detach().cpu().numpy().squeeze()
plot_spectrum(result_s)

In [None]:
# Invert the optimised spectrum back to a waveform with spectrum_to_audio. The
# content clip's third reader output is supplied as `p`, and 50 rounds are
# requested.
result_wav = spectrum_to_audio(
    result_s,
    p=content_p,
    rounds=50,
)

In [None]:
# Invert the optimised spectrum with the PyTorch path. `result` still has
# requires_grad=True from the optimisation cell, so detach it first. Otherwise
# the inversion may build an unnecessary autograd graph, and a grad-requiring
# output cannot be converted to NumPy when it is played back.
result_wav_pt = spectrum_to_audio_pt(result.detach().squeeze())

In [None]:
# Play the librosa-path result, then the PyTorch-path result, both at the
# style clip's sample rate.
for clip in (result_wav, result_wav_pt.cpu()):
    display(Audio(clip, rate=style_sr))