In [344]:
import torch.nn as nn
import librosa
import numpy as np
import torch
from torch.utils import data
import matplotlib.pyplot as plt
import IPython.display as ipd

In [345]:
#load data 

c1, cr = librosa.load('Data/Chrysanth-instrument.m4a')
c2, cr = librosa.load('Data/Chrysanth-vocal.m4a')
len_c = min(len(c1), len(c2))

h1, hr = librosa.load('Data/HBD-instrument.m4a')
h2, hr = librosa.load('Data/HBD-vocal.m4a')
len_h = min(len(h1), len(h2))

# mix sound

C_mix = librosa.stft(c1[:len_c] + c2[:len_c], n_fft=1024, hop_length=512)  # mixed_1
H_mix = librosa.stft(h1[:len_h] + h2[:len_h], n_fft=1024, hop_length=512)  # mixed_1

C_instr = librosa.stft(c1[:len_c] , n_fft=1024, hop_length=512) 
H_instr = librosa.stft(h1[:len_h] , n_fft=1024, hop_length=512) 

C_vocal = librosa.stft(c2[:len_c] , n_fft=1024, hop_length=512) 
H_vocal = librosa.stft(h2[:len_h] , n_fft=1024, hop_length=512) 



In [346]:
ipd.Audio(c1[:len_c] + c2[:len_c], rate = cr)

In [347]:
ipd.Audio(h1[:len_h] + h2[:len_h], rate = hr)

In [348]:
# Build Data

class MixData(data.Dataset):
    def __init__(self, mix, instrument, vocal):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.mix = mix.T
        self.instrument = instrument.T
        self.vocal = vocal.T
#         self.transform = transform

    def __len__(self):
        return len(self.mix)

    def __getitem__(self, idx):
        
        beta=2
        input_mix = np.abs(self.mix[idx])
        mask = np.abs(self.instrument[idx])**beta/(np.abs(self.instrument[idx])**beta + \
                                               np.abs(self.vocal[idx])**beta)

#         if self.transform:
#             sample = self.transform(sample)

        return input_mix, mask

bs = 20

train_loader = torch.utils.data.DataLoader(MixData(C_mix, C_instr, C_vocal ), batch_size = bs, shuffle \
                                           = True, num_workers = 4)
# test_loader = torch.utils.data.DataLoader(MixData(H_mix, H_instr, H_vocal), batch_size = bs, shuffle \
#                                            = True, num_workers = 4)    


In [349]:
# Build Network

class Remixing(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(1, 8, 3, 1),
            nn.ReLU(),
            nn.Conv1d(8, 16, 3, 1),
            nn.ReLU()
        )
        self.fc = nn.Sequential(
            nn.Linear(8144, 2048),
            nn.ReLU(),
            nn.Linear(2048, 513),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        x = x.reshape((x.shape[0], 1, x.shape[1]))
        x = self.conv(x)
        
        x = x.view(x.shape[0],-1)
#         print(x.shape)
        x = self.fc(x)
        
        return x
         

In [350]:
maxEpoch = 100
model = Remixing().cuda()
optimizer= torch.optim.Adam(model.parameters(), lr=0.001)
loss_function= nn.MSELoss()

In [351]:
# Train Model
errt = np.zeros(maxEpoch)
for i in range(maxEpoch):
    model.train()
    for mix, mask in train_loader:
        
        mix = mix.cuda()
#         print(mix.shape)
        mask = mask.cuda()
        optimizer.zero_grad()
        mask_h = model(mix)
#         print(mask_h.shape)
#         print(mask.shape)
        loss = loss_function(mask_h, mask)
        errt[i] += loss.data.cpu().numpy()
        loss.backward()
        optimizer.step()
    print('Epoch-{}'.format(i), np.mean(errt[i]))
    

Epoch-0 10.066270679235458
Epoch-1 7.797088012099266
Epoch-2 6.877188324928284
Epoch-3 6.268712058663368
Epoch-4 5.779728874564171
Epoch-5 5.4055228643119335
Epoch-6 5.080741323530674
Epoch-7 4.7780790366232395
Epoch-8 4.523646574467421
Epoch-9 4.299571972340345
Epoch-10 4.078583549708128
Epoch-11 3.8986913599073887
Epoch-12 3.709589798003435
Epoch-13 3.529097802937031
Epoch-14 3.3769357092678547
Epoch-15 3.221481256186962
Epoch-16 3.058282356709242
Epoch-17 2.9156608954072
Epoch-18 2.786071017384529
Epoch-19 2.6444577164947987
Epoch-20 2.525842972099781
Epoch-21 2.4116259180009365
Epoch-22 2.3017156049609184
Epoch-23 2.177434654906392
Epoch-24 2.082939837127924
Epoch-25 1.9894858002662659
Epoch-26 1.8943636361509562
Epoch-27 1.814569890499115
Epoch-28 1.7286887932568789
Epoch-29 1.650138009339571
Epoch-30 1.5907513136044145
Epoch-31 1.5142084518447518
Epoch-32 1.445036861114204
Epoch-33 1.3799724699929357
Epoch-34 1.3263532174751163
Epoch-35 1.275683332234621
Epoch-36 1.23511373344808

In [355]:
# Test on data

model.eval()
test_abs = torch.tensor(np.transpose(np.abs(H_mix))).cuda()
# test_abs.shape
mask_h = model(test_abs)
h_instr_p = H_mix * mask_h.data.cpu().numpy().T
h_instr_p = librosa.istft(h_instr_p, hop_length=512)
ipd.Audio(h_instr_p, rate = hr)

In [353]:
model.eval()
test_abs = torch.tensor(np.transpose(np.abs(C_mix))).cuda()
# test_abs.shape
mask_c = model(test_abs)
C_instr_p = C_mix * mask_c.data.cpu().numpy().T
C_instr_p = librosa.istft(C_instr, hop_length=512)
ipd.Audio(C_instr_p, rate = cr)