In [1]:
import os
import numpy as np
from PIL import Image
from torch.utils import data
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.autograd import Variable
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import time
import gc
import matplotlib.pyplot as plt
import librosa
import IPython.display as ipd
import scipy
import ot
import util

In [2]:
class CustomDataset(Dataset):
    def __init__(self, clean_data_s, noise_data_s, clean_data_t, noise_data_t):
        self.clean_data_s = clean_data_s
        self.noise_data_s = noise_data_s
        self.clean_data_t = clean_data_t
        self.noise_data_t = noise_data_t

    def __len__(self):
        return len(self.clean_data_s)

    def __getitem__(self, index):
        clean_sample_s = self.clean_data_s[index]
        noise_sample_s = self.noise_data_s[index]
        clean_sample_t = self.clean_data_t[index]
        noise_sample_t = self.noise_data_t[index]
        return clean_sample_s, noise_sample_s, clean_sample_t, noise_sample_t

In [3]:
ds_train = torch.load('data/DATASET/TRAIN_ST_FINAL_01.pt')
ds_test = torch.load('data/DATASET/TEST_ST_FINAL_01.pt')
ds_val = torch.load('data/DATASET/VALIDATION_ST_FINAL_01.pt')

In [4]:
batch_size = 128
torch.manual_seed(0)

<torch._C.Generator at 0x7f8ef0219d50>

In [5]:
#Train
custom_dataloader_train = DataLoader(ds_train, batch_size=batch_size, shuffle=True)
#Test
custom_dataloader_test = DataLoader(ds_test, batch_size=batch_size, shuffle=True)
#Val
custom_dataloader_val = DataLoader(ds_val, batch_size=batch_size, shuffle=True)

In [6]:
for i, (clean_spec_s, noise_spec_s, clean_spec_t, noise_spec_t) in enumerate(custom_dataloader_train):
    print(i)
    print(clean_spec_s.size())
    print(noise_spec_s.size())
    print(clean_spec_t.size())
    print(noise_spec_t.size())
    #break

0
torch.Size([128, 64, 257])
torch.Size([128, 64, 257])
torch.Size([128, 64, 257])
torch.Size([128, 64, 257])
1
torch.Size([128, 64, 257])
torch.Size([128, 64, 257])
torch.Size([128, 64, 257])
torch.Size([128, 64, 257])
2
torch.Size([128, 64, 257])
torch.Size([128, 64, 257])
torch.Size([128, 64, 257])
torch.Size([128, 64, 257])
3
torch.Size([128, 64, 257])
torch.Size([128, 64, 257])
torch.Size([128, 64, 257])
torch.Size([128, 64, 257])
4
torch.Size([128, 64, 257])
torch.Size([128, 64, 257])
torch.Size([128, 64, 257])
torch.Size([128, 64, 257])
5
torch.Size([128, 64, 257])
torch.Size([128, 64, 257])
torch.Size([128, 64, 257])
torch.Size([128, 64, 257])
6
torch.Size([128, 64, 257])
torch.Size([128, 64, 257])
torch.Size([128, 64, 257])
torch.Size([128, 64, 257])
7
torch.Size([128, 64, 257])
torch.Size([128, 64, 257])
torch.Size([128, 64, 257])
torch.Size([128, 64, 257])
8
torch.Size([128, 64, 257])
torch.Size([128, 64, 257])
torch.Size([128, 64, 257])
torch.Size([128, 64, 257])
9
torch.Si

In [7]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [8]:
def OT_loss(X_s, X_t, y_s, y_t_pred):
    N = X_s.shape[0]

    C0 = torch.cdist(X_s.reshape((N, -1)), X_t.reshape((N, -1)), p=2).cpu()
    C1 = torch.cdist(y_s.reshape((N, -1)), y_t_pred.reshape((N, -1)), p=2).cpu()

    alpha = 1  # OT source weight in loss
    beta = 1   # OT target weight in loss
    C = alpha * C0 + beta * C1

    γ = ot.emd(ot.unif(N), ot.unif(N), C.detach().numpy())
    γ = torch.from_numpy(γ).float()

    loss = torch.sum(γ * C)

    return loss

In [9]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.blstm = nn.LSTM(257, 1024, dropout=0.0, num_layers=2, bidirectional=True, batch_first=True)
        self.LReLU = nn.LeakyReLU(0.3)
        self.ReLU = nn.ReLU()
        self.Dropout = nn.Dropout(p=0.0)
        self.fc1 = nn.Linear(1024 * 2, 1024)
        self.fc2 = nn.Linear(1024, 257)

    def forward(self, x):
        #  x: clean mag, y: noise mag
        output, _ = self.blstm(x)
        output = self.fc1(output)
        output = self.LReLU(output)
        output = self.Dropout(output)
        output = self.fc2(output)
        output = self.ReLU(output)
        return output

In [10]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.conv1 = nn.Sequential(  # input shape (batch_size, 1, 64, 257)
            nn.Conv2d(
                in_channels=1,  # input height
                out_channels=8,  # n_filters
                kernel_size=5,  # filter size
                stride=1,  # filter movement/step
                padding=2,
                # if want same width and length of this image after Conv2d, padding=(kernel_size-1)/2 if stride=1
            ),  # output shape (16, 28, 28)
            nn.ReLU(),  # activation
            nn.MaxPool2d(kernel_size=2),  # choose max value in 2x2 area, output shape (batch_size, 16, 32, 128)
        )
        self.conv2 = nn.Sequential(  # input shape (batch_size, 16, 32, 128)
            nn.Conv2d(8, 16, 5, 1, 2),
            nn.ReLU(),  # activation
            nn.MaxPool2d(2),  # output shape (batch_size, 32, 16, 64)
        )
        self.out1 = nn.Sequential(
            nn.Linear(16 * 16 * 64, 16 * 16),  # fully connected layer, output 10 classes
            nn.ReLU()  # activation
        )

        self.out2 = nn.Sequential(
            nn.Linear(16 * 16, 1),  # fully connected layer, output 10 classes
            nn.Sigmoid()  # activation
        )

    def forward(self, x):
        x = x.unsqueeze(1)
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)  # flatten the output of conv2 to (batch_size, 32 * 16 * 64)
        x = self.out1(x)
        output = self.out2(x)
        return output

In [11]:
G = Generator().to(device)

In [12]:
D = Discriminator().to(device)

In [13]:
hyperparameters = {
    'NUM_EPOCHS' : 150,
    'optimizer_g' : torch.optim.Adam(G.parameters(), lr=1e-5),
    'optimizer_d' : torch.optim.Adam(D.parameters(), lr=1e-3),
    'optimizer_s' : torch.optim.Adam(G.parameters(), lr=1e-4),
    'optimizer_ot' : torch.optim.Adam(G.parameters(), lr=1e-5),
    'criterion_g' : nn.MSELoss()
}

In [14]:
LOSS_TRAIN_MSE_S = []
LOSS_VAL_MSE_S = []
LOSS_TRAIN_MSE_T = []
LOSS_VAL_MSE_T = []
TIME = []

In [15]:
# Training with OT Loss
for epoch in range(hyperparameters['NUM_EPOCHS']):
    start_time = time.time()
    epoch_OT_loss = 0.0
    epoch_S_loss = 0.0
    epoch_D_loss = 0.0
    epoch_G_loss = 0.0
    #Train NN
    D.train()
    G.train()
    Count = 0
    for batch_idx, (clean_stft_mag_features_s, noise_stft_mag_features_s, _, noise_stft_mag_features_t) in enumerate(custom_dataloader_train):
        #load features in CUDA
        y_s = clean_stft_mag_features_s.to(device)
        X_s = noise_stft_mag_features_s.to(device)
        #y_t = clean_stft_mag_features_t.to(device)
        X_t = noise_stft_mag_features_t.to(device)
        if (batch_idx % 13 == 0):
            hyperparameters['optimizer_d'].zero_grad()
            loss_d = -torch.mean(D(y_s)) + torch.mean(D(G(X_t).detach()))
            loss_d.backward()
            hyperparameters['optimizer_d'].step()
            for p in D.parameters():
                p.data.clamp_(-0.001, 0.001)
            epoch_D_loss += loss_d.item()
        if (batch_idx % 6 == 0):
            hyperparameters['optimizer_g'].zero_grad()
            loss_g = -torch.mean(D(G(X_t)))
            loss_g.backward()
            hyperparameters['optimizer_g'].step()
            epoch_G_loss += loss_g.item()
        if (batch_idx % 2 == 0):
            loss_s = hyperparameters['criterion_g'](G(X_t), y_s)
            hyperparameters['optimizer_s'].zero_grad()
            loss_s.backward()
            hyperparameters['optimizer_s'].step()
            epoch_S_loss += loss_s.item()
        ot_loss = OT_loss(X_s, X_t, y_s, G(X_s))
        hyperparameters['optimizer_ot'].zero_grad()
        ot_loss.backward()
        hyperparameters['optimizer_ot'].step()
        epoch_OT_loss += ot_loss.item()
        # Remove features from CUDA
        del y_s, X_s, X_t
        torch.cuda.empty_cache()
        gc.collect()
    #Eval NN
    G.eval()
    with torch.no_grad():
        running_mse_loss_train_s = 0.0
        running_mse_loss_train_t = 0.0
        for batch_idx, (clean_stft_mag_features_s, noise_stft_mag_features_s, clean_stft_mag_features_t, noise_stft_mag_features_t) in enumerate(custom_dataloader_train):
            #load features in CUDA
            y_s = clean_stft_mag_features_s.to(device)
            X_s = noise_stft_mag_features_s.to(device)
            y_t = clean_stft_mag_features_t.to(device)
            X_t = noise_stft_mag_features_t.to(device)
            outputs_s = G(X_s)
            outputs_t = G(X_t)
            mse_train_s = hyperparameters['criterion_g'](outputs_s, y_s)
            mse_train_t = hyperparameters['criterion_g'](outputs_t, y_t)
            running_mse_loss_train_s += mse_train_s.item()
            running_mse_loss_train_t += mse_train_t.item()
            del y_s, X_s, y_t, X_t, outputs_s, outputs_t
            torch.cuda.empty_cache()
            gc.collect()    
    #Eval NN
    G.eval()
    with torch.no_grad():
        running_mse_loss_val_s = 0.0
        running_mse_loss_val_t = 0.0
        for batch_idx, (clean_stft_mag_features_s, noise_stft_mag_features_s, clean_stft_mag_features_t, noise_stft_mag_features_t) in enumerate(custom_dataloader_val):
            #load features in CUDA
            y_s = clean_stft_mag_features_s.to(device)
            X_s = noise_stft_mag_features_s.to(device)
            y_t = clean_stft_mag_features_t.to(device)
            X_t = noise_stft_mag_features_t.to(device)
            outputs_s = G(X_s)
            outputs_t = G(X_t)
            mse_val_s = hyperparameters['criterion_g'](outputs_s, y_s)
            mse_val_t = hyperparameters['criterion_g'](outputs_t, y_t)
            running_mse_loss_val_s += mse_val_s.item()
            running_mse_loss_val_t += mse_val_t.item()
            del y_s, X_s, y_t, X_t, outputs_s, outputs_t
            torch.cuda.empty_cache()
            gc.collect()
            
    # Scores for train and validation
    epoch_time = time.time() - start_time
    TIME.append(epoch_time)
    print(f"Epoch {epoch+1} took {epoch_time:.2f} seconds\n========================================================")
    print('Epoch %d S_MSE Train loss: %.5f T_MSE Train loss: %.5f' % 
          (epoch + 1, 
           running_mse_loss_train_s/len(custom_dataloader_train),
           running_mse_loss_train_t/len(custom_dataloader_train)))
    print('Epoch %d S_MSE Val loss: %.5f T_MSE Val loss: %.5f' % 
          (epoch + 1, 
           running_mse_loss_val_s/len(custom_dataloader_val),
           running_mse_loss_val_t/len(custom_dataloader_val)))
    
    LOSS_TRAIN_MSE_S.append(running_mse_loss_train_s/len(custom_dataloader_train))
    LOSS_VAL_MSE_S.append(running_mse_loss_val_s/len(custom_dataloader_val))
    LOSS_TRAIN_MSE_T.append(running_mse_loss_train_t/len(custom_dataloader_train))
    LOSS_VAL_MSE_T.append(running_mse_loss_val_t/len(custom_dataloader_val))
    print('\n')

Epoch 1 took 186.83 seconds
Epoch 1 S_MSE Train loss: 0.04782 T_MSE Train loss: 0.02975
Epoch 1 S_MSE Val loss: 0.04847 T_MSE Val loss: 0.03097


Epoch 2 took 189.73 seconds
Epoch 2 S_MSE Train loss: 0.03660 T_MSE Train loss: 0.02342
Epoch 2 S_MSE Val loss: 0.03723 T_MSE Val loss: 0.02455


Epoch 3 took 197.30 seconds
Epoch 3 S_MSE Train loss: 0.03339 T_MSE Train loss: 0.02012
Epoch 3 S_MSE Val loss: 0.03406 T_MSE Val loss: 0.02156


Epoch 4 took 201.37 seconds
Epoch 4 S_MSE Train loss: 0.03198 T_MSE Train loss: 0.01784
Epoch 4 S_MSE Val loss: 0.03351 T_MSE Val loss: 0.01900


Epoch 5 took 202.24 seconds
Epoch 5 S_MSE Train loss: 0.03249 T_MSE Train loss: 0.01657
Epoch 5 S_MSE Val loss: 0.03350 T_MSE Val loss: 0.01759


Epoch 6 took 202.38 seconds
Epoch 6 S_MSE Train loss: 0.02983 T_MSE Train loss: 0.01587
Epoch 6 S_MSE Val loss: 0.03005 T_MSE Val loss: 0.01668


Epoch 7 took 203.95 seconds
Epoch 7 S_MSE Train loss: 0.02731 T_MSE Train loss: 0.01482
Epoch 7 S_MSE Val loss: 0.02843 T_MS

In [16]:
torch.save(G.state_dict(), 'models_regression/G_model_NOISE_DA_WITH_OT_LOSS_SourceTarget_ALL_NOISES_dB_Generator_FINAL.pt')
torch.save(D.state_dict(), 'models_regression/D_model_NOISE_DA_WITH_OT_LOSS_SourceTarget_ALL_NOISES_dB_Discriminador_FINAL.pt')

In [17]:
util.generate_pkl('stastics/', LOSS_TRAIN_MSE_S, 'LOSS_TRAIN_MSE_S')
util.generate_pkl('stastics/', LOSS_VAL_MSE_S, 'LOSS_VAL_MSE_S')
util.generate_pkl('stastics/', LOSS_TRAIN_MSE_T, 'LOSS_TRAIN_MSE_T')
util.generate_pkl('stastics/', LOSS_VAL_MSE_T, 'LOSS_VAL_MSE_T')
util.generate_pkl('stastics/', TIME, 'TIME_DAOT')

[0.047819566623378204, 0.03660100558381371, 0.0333938226573357, 0.03198343392524148, 0.03249464376571299, 0.029833060657965534, 0.02731453431691943, 0.026015028130544834, 0.026654032927726497, 0.026758673694767372, 0.02392238998362998, 0.024915155070293852, 0.021354595930980786, 0.020218397764598623, 0.020594897421718647, 0.021767329886367842, 0.019840442825954, 0.019234575151207568, 0.019657525132183266, 0.018160430874925953, 0.018415987526825748, 0.017651044016218736, 0.017275169002627877, 0.017611990878436744, 0.018747999814158978, 0.0162090191680218, 0.017016337437069966, 0.016835643909871578, 0.016834169872762525, 0.01682635845945162, 0.015627570407261617, 0.015136551514987936, 0.015074099527531061, 0.014845793002418109, 0.014840879317122597, 0.01422232947069682, 0.01454078739884646, 0.014516612741814441, 0.014085417073376539, 0.01372166202567956, 0.01345377296851087, 0.013366491612078263, 0.013948168499911782, 0.012995826733932524, 0.012670524207045552, 0.01351346360018649, 0.012