<a href="https://colab.research.google.com/github/dragonsl-dev/randomCNN-voice-transfer/blob/master/Style_Transfer_CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### Recommended gpu: k80, P100

In [None]:
!nvidia-smi

# Install

In [1]:
!git clone https://github.com/dragonsl-dev/randomCNN-voice-transfer

Cloning into 'randomCNN-voice-transfer'...
remote: Enumerating objects: 74, done.[K
remote: Counting objects: 100% (74/74), done.[K
remote: Compressing objects: 100% (57/57), done.[K
remote: Total 125 (delta 38), reused 46 (delta 17), pack-reused 51[K
Receiving objects: 100% (125/125), 7.55 MiB | 41.58 MiB/s, done.
Resolving deltas: 100% (58/58), done.


In [2]:
!sudo apt install libav-tools

Reading package lists... Done
Building dependency tree       
Reading state information... Done
Package libav-tools is not available, but is referred to by another package.
This may mean that the package is missing, has been obsoleted, or
is only available from another source
However the following packages replace it:
  ffmpeg

E: Package 'libav-tools' has no installation candidate


In [3]:
%cd randomCNN-voice-transfer/
!pip install -r requirements.txt

/content/randomCNN-voice-transfer


fix librosa issue

In [4]:
#@title
%%writefile utils.py
import librosa
import numpy as np
import torch
from model import *
import soundfile as sf

def wav2spectrum(filename):
    x, sr = librosa.load(filename)
    S = librosa.stft(x, N_FFT)
    p = np.angle(S)

    S = np.log1p(np.abs(S))
    return S, sr


def spectrum2wav(spectrum, sr, outfile):
    # Return the all-zero vector with the same shape of `a_content`
    a = np.exp(spectrum) - 1
    p = 2 * np.pi * np.random.random_sample(spectrum.shape) - np.pi
    for i in range(50):
        S = a * np.exp(1j * p)
        x = librosa.istft(S)
        p = np.angle(librosa.stft(x, N_FFT))
    
    sf.write(outfile, x, sr, 'PCM_24')
    #librosa.output.write_wav(outfile, x, sr)


def wav2spectrum_keep_phase(filename):
    x, sr = librosa.load(filename)
    S = librosa.stft(x, N_FFT)
    p = np.angle(S)

    S = np.log1p(np.abs(S))
    return S, p, sr


def spectrum2wav_keep_phase(spectrum, p, sr, outfile):
    # Return the all-zero vector with the same shape of `a_content`
    a = np.exp(spectrum) - 1
    for i in range(50):
        S = a * np.exp(1j * p)
        x = librosa.istft(S)
        p = np.angle(librosa.stft(x, N_FFT))
    #librosa.output.write_wav(outfile, x, sr)
    sf.write(outfile, x, sr, 'PCM_24')

def compute_content_loss(a_C, a_G):
    """
    Compute the content cost
    Arguments:
    a_C -- tensor of dimension (1, n_C, n_H, n_W)
    a_G -- tensor of dimension (1, n_C, n_H, n_W)
    Returns:
    J_content -- scalar that you compute using equation 1 above
    """
    m, n_C, n_H, n_W = a_G.shape

    # Reshape a_C and a_G to the (m * n_C, n_H * n_W)
    a_C_unrolled = a_C.view(m * n_C, n_H * n_W)
    a_G_unrolled = a_G.view(m * n_C, n_H * n_W)

    # Compute the cost
    J_content = 1.0 / (4 * m * n_C * n_H * n_W) * torch.sum((a_C_unrolled - a_G_unrolled) ** 2)

    return J_content


def gram(A):
    """
    Argument:
    A -- matrix of shape (n_C, n_L)
    Returns:
    GA -- Gram matrix of shape (n_C, n_C)
    """
    GA = torch.matmul(A, A.t())

    return GA


def gram_over_time_axis(A):
    """
    Argument:
    A -- matrix of shape (1, n_C, n_H, n_W)
    Returns:
    GA -- Gram matrix of A along time axis, of shape (n_C, n_C)
    """
    m, n_C, n_H, n_W = A.shape

    # Reshape the matrix to the shape of (n_C, n_L)
    # Reshape a_C and a_G to the (m * n_C, n_H * n_W)
    A_unrolled = A.view(m * n_C * n_H, n_W)
    GA = torch.matmul(A_unrolled, A_unrolled.t())

    return GA


def compute_layer_style_loss(a_S, a_G):
    """
    Arguments:
    a_S -- tensor of dimension (1, n_C, n_H, n_W)
    a_G -- tensor of dimension (1, n_C, n_H, n_W)
    Returns:
    J_style_layer -- tensor representing a scalar style cost.
    """
    m, n_C, n_H, n_W = a_G.shape

    # Reshape the matrix to the shape of (n_C, n_L)
    # Reshape a_C and a_G to the (m * n_C, n_H * n_W)

    # Calculate the gram
    # !!!!!! IMPORTANT !!!!! Here we compute the Gram along n_C,
    # not along n_H * n_W. But is the result the same? No.
    GS = gram_over_time_axis(a_S)
    GG = gram_over_time_axis(a_G)

    # Computing the loss
    J_style_layer = 1.0 / (4 * (n_C ** 2) * (n_H * n_W)) * torch.sum((GS - GG) ** 2)

    return J_style_layer


Overwriting utils.py


In [None]:
#@title
%%writefile train.py
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
from torch.autograd import Variable
from utils import *
from model import *
import time
import math
import argparse
cuda = True if torch.cuda.is_available() else False

parser = argparse.ArgumentParser()
parser.add_argument('-content', help='Content input')
parser.add_argument('-content_weight', help='Content weight. Default is 1e2', default = 1e2)
parser.add_argument('-style', help='Style input')
parser.add_argument('-style_weight', help='Style weight. Default is 1', default = 1)
parser.add_argument('-epochs', type=int, help='Number of epoch iterations. Default is 20000', default = 20000)
parser.add_argument('-print_interval', type=int, help='Number of epoch iterations between printing losses', default = 1000)
parser.add_argument('-plot_interval', type=int, help='Number of epoch iterations between plot points', default = 1000)
parser.add_argument('-learning_rate', type=float, default = 0.002)
parser.add_argument('-output', help='Output file name. Default is "output"', default = 'output')
args = parser.parse_args()


CONTENT_FILENAME = args.content
STYLE_FILENAME = args.style

a_content, sr = wav2spectrum(CONTENT_FILENAME)
a_style, sr = wav2spectrum(STYLE_FILENAME)

a_content_torch = torch.from_numpy(a_content)[None, None, :, :]
if cuda:
    a_content_torch = a_content_torch.cuda()
print(a_content_torch.shape)
a_style_torch = torch.from_numpy(a_style)[None, None, :, :]
if cuda:
    a_style_torch = a_style_torch.cuda()
print(a_style_torch.shape)

model = RandomCNN()
model.eval()

a_C_var = Variable(a_content_torch, requires_grad=False).float()
a_S_var = Variable(a_style_torch, requires_grad=False).float()
if cuda:
    model = model.cuda()
    a_C_var = a_C_var.cuda()
    a_S_var = a_S_var.cuda()

a_C = model(a_C_var)
a_S = model(a_S_var)


# Optimizer
learning_rate = args.learning_rate
a_G_var = Variable(torch.randn(a_content_torch.shape) * 1e-3)
if cuda:
    a_G_var = a_G_var.cuda()
a_G_var.requires_grad = True
optimizer = torch.optim.Adam([a_G_var])

# coefficient of content and style
style_param = args.style_weight
content_param = args.content_weight

num_epochs = args.epochs
print_every = args.print_interval
plot_every = args.plot_interval

# Keep track of losses for plotting
current_loss = 0
all_losses = []


def timeSince(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


start = time.time()
# Train the Model

try:
    for epoch in range(1, num_epochs + 1):
        optimizer.zero_grad()
        a_G = model(a_G_var)

        content_loss = content_param * compute_content_loss(a_C, a_G)
        style_loss = style_param * compute_layer_style_loss(a_S, a_G)
        loss = content_loss + style_loss
        loss.backward()
        optimizer.step()

        # print
        if epoch % print_every == 0:
            print("{} {}% {} content_loss:{:4f} style_loss:{:4f} total_loss:{:4f}".format(epoch,
                                                                                          epoch / num_epochs * 100,
                                                                                          timeSince(start),
                                                                                          content_loss.item(),
                                                                                          style_loss.item(), loss.item()))
            current_loss += loss.item()

        # Add current loss avg to list of losses
        if epoch % plot_every == 0:
            all_losses.append(current_loss / plot_every)
            current_loss = 0
except KeyboardInterrupt:
    print("User interrupted training")

gen_spectrum = a_G_var.cpu().data.numpy().squeeze()
gen_audio_C = args.output + ".wav"
spectrum2wav(gen_spectrum, sr, gen_audio_C)

plt.figure()
plt.plot(all_losses)
plt.savefig('loss_curve.png')

plt.figure(figsize=(5, 5))
# we then use the 2nd column.
plt.subplot(1, 1, 1)
plt.title("Content Spectrum")
plt.imsave('Content_Spectrum.png', a_content[:400, :])

plt.figure(figsize=(5, 5))
# we then use the 2nd column.
plt.subplot(1, 1, 1)
plt.title("Style Spectrum")
plt.imsave('Style_Spectrum.png', a_style[:400, :])

plt.figure(figsize=(5, 5))
# we then use the 2nd column.
plt.subplot(1, 1, 1)
plt.title("CNN Voice Transfer Result")
plt.imsave('Gen_Spectrum.png', gen_spectrum[:400, :])


# Infer
Upload content and style to input directory



```
-content input/orig.wav    # voice to modify
-style input/twi.wav       # song / source
```



In [6]:
# download samples
!wget https://github.com/dragonsl-dev/wav-samples/raw/main/orig.wav -P input/
!wget https://github.com/dragonsl-dev/wav-samples/raw/main/twi2.wav -P input/

--2021-02-05 15:07:16--  https://github.com/dragonsl-dev/wav-samples/raw/main/orig.wav
Resolving github.com (github.com)... 140.82.114.4
Connecting to github.com (github.com)|140.82.114.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/dragonsl-dev/wav-samples/main/orig.wav [following]
--2021-02-05 15:07:16--  https://raw.githubusercontent.com/dragonsl-dev/wav-samples/main/orig.wav
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2428750 (2.3M) [audio/wav]
Saving to: ‘input/orig.wav’


2021-02-05 15:07:16 (49.8 MB/s) - ‘input/orig.wav’ saved [2428750/2428750]

--2021-02-05 15:07:16--  https://github.com/dragonsl-dev/wav-samples/raw/main/twi2.wav
Resolving github.com (github.com)... 140.82.113.3
Conne

In [None]:
# upload your own to input directory in the sidebar

In [8]:

!python train.py -epochs 200000 -print_interval 50 -content input/orig.wav -style input/twi2.wav

torch.Size([1, 1, 257, 1090])
torch.Size([1, 1, 257, 705])
50 0.025% 0m 4s content_loss:3.837719 style_loss:1594.863403 total_loss:1598.701172
100 0.05% 0m 13s content_loss:3.256724 style_loss:1333.706055 total_loss:1336.962769
150 0.075% 0m 21s content_loss:3.004975 style_loss:1009.639526 total_loss:1012.644531
200 0.1% 0m 29s content_loss:2.869258 style_loss:748.779358 total_loss:751.648621
250 0.125% 0m 38s content_loss:2.746639 style_loss:545.045471 total_loss:547.792114
300 0.15% 0m 47s content_loss:2.671658 style_loss:392.923981 total_loss:395.595642
350 0.17500000000000002% 0m 57s content_loss:2.639799 style_loss:286.733704 total_loss:289.373505
400 0.2% 1m 6s content_loss:2.634223 style_loss:215.552063 total_loss:218.186279
450 0.22499999999999998% 1m 16s content_loss:2.637390 style_loss:168.807190 total_loss:171.444580
500 0.25% 1m 26s content_loss:2.639517 style_loss:138.481857 total_loss:141.121368
550 0.27499999999999997% 1m 36s content_loss:2.633878 style_loss:118.981483 t