In [1]:
import numpy as np
import pandas as pd
import IPython
import os
import glob
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm, tqdm_notebook

from pyAudioAnalysis import audioSegmentation as aS
from pydub import AudioSegment
import pydub

import torchaudio
import torchaudio.transforms as transforms
import librosa
import librosa.display

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
torch.__version__

'1.5.0'

In [3]:
%load_ext autoreload
%autoreload 2
%matplotlib inline  

In [4]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [5]:
class PollyDataset(Dataset):
    def __init__(self):
        uk_words = set(map(lambda x: x.split('/')[-1][:-7], glob.glob("data/uk/*")))
        us_words = set(map(lambda x: x.split('/')[-1][:-7], glob.glob("data/us/*")))
        common_words = list(uk_words.intersection(us_words))
        self.uk_words_common = list(map(lambda x: "data/uk/" + x + "_uk.mp3", common_words))
        self.us_words_common = list(map(lambda x: "data/us/" + x + "_us.mp3", common_words))
        
        self.sg_transform = torchaudio.transforms.Spectrogram(n_fft=255)
        
    def __getitem__(self, idx):
        # load and compute spectrogram (sg) for us accent word
        us_audio = torchaudio.load(self.us_words_common[idx])[0]
        eps = 1e-5
        us_sg = F.interpolate(
            self.sg_transform(
                us_audio
            ), size = 128
        )
        # load and compute spectrogram (sg) for uk accent word
        uk_audio = torchaudio.load(self.uk_words_common[idx])[0]
        uk_sg = F.interpolate(
            self.sg_transform(
                uk_audio
            ), size = 128
        )
        return (us_sg, uk_sg)
    
    def __len__(self):
        return len(self.us_words_common)

In [6]:
polly = PollyDataset()
batch_size = 64

loader_train = torch.utils.data.DataLoader(polly, batch_size=batch_size,
                                          shuffle=False, num_workers=8, pin_memory=True, drop_last=True)
iter_train = iter(loader_train)
print(len(loader_train))

480


In [None]:
sample = iter_train.next()
sample[0].shape

In [None]:
uk_audio, rate = torchaudio.load("data/uk/treaty_uk.mp3")
print(rate)
IPython.display.Audio(uk_audio.numpy(), rate=rate)

In [None]:
us_audio, rate = torchaudio.load("data/us/treaty_us.mp3")
IPython.display.Audio(us_audio.numpy(), rate=rate)

In [None]:
us_sg = F.interpolate(
            polly.sg_transform(
                us_audio
            ).log2(), size = 128
        )
uk_sg = F.interpolate(
            polly.sg_transform(
                uk_audio
            ).log2(), size = 128
        )
fig = plt.figure(figsize=(8, 8))
fig.add_subplot(1, 2, 1)
plt.imshow(us_sg[0][:128, :])
fig.add_subplot(1, 2, 2)
plt.imshow(uk_sg[0][:128, :])
plt.imsave('polly_accent_transfer.png')

In [None]:
IPython.display.Audio(polly[1][2].numpy(), rate=polly[1][3])

In [7]:
class ATModel(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, batch_size):
        super(ATModel, self).__init__()
        self.batch_size = batch_size
        self.embedding_size = 1024
        self.seq_len = 8 
        self.features = int(self.embedding_size/self.seq_len)
        self.hidden_size = int(self.features)
        
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=stride, padding=1
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        self.maxpool = nn.MaxPool2d(kernel_size=2)
        
        self.conv_fc = nn.Linear(768, self.embedding_size)
        
        
        self.lstm = nn.LSTM(
            input_size=self.features,
            hidden_size=self.hidden_size,
            num_layers=3,
            batch_first=True,
            bidirectional=False
        )
        
        self.lstm_fc1 = nn.Linear(self.embedding_size * 2, self.embedding_size)
        
        self.lstm_fc2 = nn.Linear(self.embedding_size, 128 * 128)
        
        
        self.deconv1 = nn.ConvTranspose2d(1, 3, 4, stride=2, padding=1)
        self.bn_dc1 = nn.BatchNorm2d(3)
        
        self.deconv2 = nn.ConvTranspose2d(3, 1, 4, stride=2, padding=1)
        self.bn_dc2 = nn.BatchNorm2d(1)
        
    def forward(self, x):
#         print('input:', x.shape)
        x = self.conv1(x)
        x = F.relu(x)
        x = self.bn1(x)
        
#         print('conv1:', x.shape)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.bn2(x)
        
#         print('conv2:', x.shape)
        
        x = self.maxpool(x)
        
#         print('maxpool:', x.shape)
        x = torch.flatten(x, 1)
        x = self.conv_fc(x)
#         print('conv_fc:', x.shape)
        x = x.view(self.batch_size, self.seq_len, self.features)
#         print('x_view:', x.shape)
        x, (h_n, c_n) = self.lstm(x)
#         print('lstm:', x.shape)
        x = x.reshape(self.batch_size, 1, 32, 32)
#         x = x.reshape(self.batch_size, self.embedding_size)

        x = self.deconv1(x)
        x = F.relu(x)
        x = self.bn_dc1(x)
#         print('deconv1:', x.shape)
        
        x = self.deconv2(x)
        x = F.relu(x)
        x = self.bn_dc2(x)
#         print('deconv2:', x.shape)
#         x = self.lstm_fc1(x)
#         x = F.relu(x)
        
#         x = self.lstm_fc2(x)
#         x = F.relu(x)

#         print('lstm_fc:', x.shape)
        return x    

In [None]:
torch.cuda.empty_cache()

In [8]:
model = ATModel(in_channels=1, out_channels=3, kernel_size=3, stride=2, padding=1, batch_size=batch_size)
count_parameters(model)

device = ("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
model.to(device)
model.train()
criterion = nn.L1Loss()

optimizer = optim.Adam(model.parameters(), lr=1e-2)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, threshold=1e-5)

cuda:0


In [None]:
device = ("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
model.to(device)
model.train()
criterion = nn.L1Loss()

optimizer = optim.Adam(model.parameters(), lr=1e-2)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, threshold=1e-5)

# overfit to minibatch
iter_train = iter(loader_train)

steps = 10000
img, labels = iter_train.next() # retrieve minibatch
img, labels = img.to(device), labels.to(device)

epoch_pbar = tqdm_notebook(range(steps))

for steps in epoch_pbar:
    output = model.forward(img)
    loss = criterion(output, labels)
    epoch_pbar.set_description("Loss: {}".format(str(loss.item())))
    loss.backward()
    optimizer.step()
    scheduler.step(loss.item())
    optimizer.zero_grad()

In [None]:
sample = iter_train.next()

In [None]:
audio_t = transforms.GriffinLim(n_fft=255)

In [None]:
preds = model.forward(img.to(device))

In [None]:
fig = plt.figure(figsize=(8, 8))
fig.add_subplot(1, 2, 1)
plt.imshow(preds[0].cpu().squeeze(0).detach().numpy())
fig.add_subplot(1, 2, 2)
plt.imshow(labels[0].cpu().squeeze(0).detach().numpy())

In [None]:
IPython.display.Audio(audio_t(preds[8].cpu()).detach().numpy(), rate=22050)

In [None]:
epochs = 90
model.train()

writer = SummaryWriter()

epoch_pbar = tqdm_notebook(range(epochs))

for epoch in epoch_pbar:
    # training
    iter_train = iter(loader_train)
    offset = epoch * len(loader_train) # training_iter offset
    data_pbar = tqdm_notebook(range(len(loader_train)))
    train_loss = 0
    bad_batches = 0
    for data in data_pbar:
        # hack to bypass dataloading error.
        # will result in lower loss than actual (since dividing by larger number)
        try:
            img, labels = iter_train.next()
        except:
            bad_batches += 1
        img, labels = img.to(device), labels.to(device)
        output = model.forward(img)
        
        loss = criterion(output, labels)
        loss.backward()
        
        optimizer.step()
        optimizer.zero_grad()
        
        data_pbar.set_description("Training Loss: {}".format(str(loss.item())))
        global_batch_num = offset + data
        writer.add_scalar('Loss/train', loss.item(), global_batch_num) # plotting train loss over batch_num
        train_loss += loss.item()
    print('bad_batches:', bad_batches)
    train_loss /= len(loader_train)
    print('avg train loss:', train_loss)
    scheduler.step(train_loss)
    torch.cuda.empty_cache()

HBox(children=(IntProgress(value=0, max=90), HTML(value='')))

HBox(children=(IntProgress(value=0, max=480), HTML(value='')))

bad_batches: 4
avg train loss: 0.34676499298463265


HBox(children=(IntProgress(value=0, max=480), HTML(value='')))

bad_batches: 4
avg train loss: 0.3442872630432248


HBox(children=(IntProgress(value=0, max=480), HTML(value='')))