# Audio SRGAN

In [1]:
import os, re
import numpy as np
import IPython.display as ipd
SPEED = '0.5x'

audio_files = []
for root, _, files in os.walk(os.path.expanduser('~/FMA/fma_small/fma_small/000/')):
    for f in files:
        audio_files.append(root+f)
audio_files[:5]

['/home/b073040018/FMA/fma_small/fma_small/000/000002.mp3',
 '/home/b073040018/FMA/fma_small/fma_small/000/000005.mp3',
 '/home/b073040018/FMA/fma_small/fma_small/000/000010.mp3',
 '/home/b073040018/FMA/fma_small/fma_small/000/000140.mp3',
 '/home/b073040018/FMA/fma_small/fma_small/000/000141.mp3']

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset, DataLoader

class FMA(Dataset):
    def __init__(self, audio_files):
        self.audio_files = audio_files

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        x, sr = torchaudio.load(audio_files[idx])
        x_stretched, _ = torchaudio.load(re.sub(
            'fma_small/fma_small',
            'fma_small/fma_small_' + SPEED,
            audio_files[idx]
        ))
        # might be mono
        if x.shape[0] != 2:
            x = x.repeat(2, 1)
            x_stretched = x_stretched.repeat(2, 1)
            
        x = x[:, :1*sr] # extract first 5 sec
        x_stretched = x_stretched[:, :2*sr]
        
        return x, x_stretched, sr

  '"sox" backend is being deprecated. '


In [3]:
dataset_train = FMA(audio_files)
print(f'Number of samples: {len(dataset_train)}')
x, x_stretched, sr = dataset_train[0]
print('First item:')
print('Original speed:')
print(x)
ipd.Audio(x[:, :sr*4], rate=sr)

Number of samples: 62
First item:
Original speed:
tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0903, 0.0953, 0.1142],
        [0.0000, 0.0000, 0.0000,  ..., 0.1133, 0.1215, 0.1434]])


In [4]:
print('Stretched speed:')
print(x_stretched)
ipd.Audio(x_stretched[:, :sr*4], rate=sr)

Stretched speed:
tensor([[ 0.0000,  0.0000,  0.0000,  ..., -0.2078, -0.2285, -0.2436],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.2473, -0.2711, -0.2868]])


In [5]:
train_loader = DataLoader(dataset_train, batch_size=4, shuffle=True)

In [6]:
def pixel_shuffle_1d(x, upscale_factor=2):
    batch_size = x.shape[0]
    short_channel_len = x.shape[1]
    short_width = x.shape[2]

    long_channel_len = short_channel_len // upscale_factor
    long_width = upscale_factor * short_width

    x = x.contiguous().view([batch_size, upscale_factor, long_channel_len, short_width])
    x = x.permute(0, 2, 3, 1).contiguous()
    x = x.view(batch_size, long_channel_len, long_width)

    return x
    
class Bottleneck(nn.Module):
    '''The residual block with bottleneck'''
    expansion = 4
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        groups: int = 1,
        base_width: int = 64,
        norm_layer = nn.BatchNorm1d,
    ):
        super(Bottleneck, self).__init__()
        neck_width = int(out_channels*(base_width/64.))*groups
        self.seq = nn.Sequential(
            nn.Conv1d(in_channels, neck_width, 1, bias=False),
            norm_layer(neck_width),
            nn.ReLU(inplace=True),
            nn.Conv1d(neck_width, neck_width, 3, padding=1, stride=stride, groups=groups, bias=False),
            norm_layer(neck_width),
            nn.ReLU(inplace=True),
            nn.Conv1d(neck_width, out_channels*self.expansion, 1, bias=False),
            norm_layer(out_channels*self.expansion),
        )
        self.bypass = nn.Sequential(
            nn.Conv1d(in_channels, out_channels*self.expansion, 1, stride=stride, bias=False),
            norm_layer(out_channels*self.expansion)
        )
    def forward(self, x):
        out = self.seq(x)
        x = self.bypass(x)
        #residual
        out += x
        F.relu_(out)
        return out

class ResNet50(nn.Module):
    def __init__(
        self, 
        blocks_list = [3, 4, 6, 3],
        num_classes: int = 5,
        groups: int = 1, 
        base_width: int = 64, 
        norm_layer = nn.BatchNorm1d
    ):
        super(ResNet50, self).__init__()
        self.in_channels = base_width
        self.groups = groups
        self.base_width = base_width
        self.pre_layer = nn.Sequential(
            nn.Conv1d(2, self.in_channels, 7, stride=2, padding=3, bias=False),
            norm_layer(self.in_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(3, stride=2, padding=1),
        )
        self.layer1 = self._make_layer(self.base_width, blocks_list[0])
        self.layer2 = self._make_layer(self.base_width*2, blocks_list[1], stride=1)
        self.layer3 = self._make_layer(self.base_width*4, blocks_list[2], stride=1)
        self.layer4 = self._make_layer(self.base_width*8, blocks_list[3], stride=1)
        
        #initialize parameters
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x): 
        if not isinstance(x, torch.Tensor):
            x = torch.Tensor(x)
        out = self.pre_layer(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        return out


    def _make_layer(self, out_channels, blocks, stride=1):
        '''helper function to create multiple blocks'''
        layers = []
        layers.append(Bottleneck(self.in_channels, out_channels, stride, self.groups, self.base_width))
        self.in_channels = out_channels*Bottleneck.expansion
        for _ in range(blocks-1):
            layers.append(Bottleneck(self.in_channels, out_channels, groups=self.groups, base_width=self.base_width))
        return nn.Sequential(*layers)
    
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.backbone = ResNet50()
        self.final = nn.Conv1d(256, 2, 1, bias=False)
    def forward(self, x):
        out = self.backbone(x)
        out = pixel_shuffle_1d(out, 4)
        out = pixel_shuffle_1d(out, 2)
        out = self.final(out)
        return out

In [7]:
# Num epochs
start_epoch, max_epoch = 1, 5
log_step = 5

# Loss functions
criterion = torch.nn.MSELoss()

# Initialize generator
generator = Generator()

if torch.cuda.is_available():
    generator.cuda()
    criterion.cuda()

# Optimizers
optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))

In [8]:
from tqdm import tqdm
from IPython.display import clear_output
import time

t = time.time()
loader = train_loader

for epoch in tqdm(range(start_epoch, max_epoch + 1)):
    for data in loader:
        x = data[0].cuda()
        y = data[1].cuda()
        # ------------------
        #  Train Generators
        # ------------------
        optimizer.zero_grad()
        gen = generator(x)
        loss = criterion(gen, y)
        loss.backward()
        optimizer.step()
        
    print(epoch)
    print(loss.cpu().detach().numpy())

 20%|██        | 1/5 [00:38<02:33, 38.29s/it]

1
0.08020061


 40%|████      | 2/5 [01:16<01:54, 38.15s/it]

2
0.06979436


 60%|██████    | 3/5 [01:53<01:15, 37.91s/it]

3
0.0174601


 80%|████████  | 4/5 [02:30<00:37, 37.77s/it]

4
0.04413254


100%|██████████| 5/5 [03:08<00:00, 37.67s/it]

5
0.063230194





In [12]:
x, streched, sr = dataset_train[1]
ipd.Audio(x[:, :sr*4], rate=sr)

In [13]:
ipd.Audio(streched[:, :sr*4], rate=sr)

In [14]:
gen = generator(torch.unsqueeze(x, 0).cuda()).cpu().detach().numpy()[0]
ipd.Audio(gen[:, :sr*4], rate=sr)