In [1]:
!pip install torchsummary

Collecting torchsummary
  Downloading torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1
[0m

In [2]:
import torch
import torch.cuda
import torch.nn as nn
import torch.optim as optim
from datetime import datetime
from torch.autograd import Variable 
import h5py
from torch.utils.data import DataLoader, Dataset, BatchSampler, Sampler
import soundfile as sf
import librosa
import librosa.display
import math
import soundfile as sf
import pandas as pd
import os
import numpy as np
from scipy.io.wavfile import write
from sklearn.preprocessing import MinMaxScaler, StandardScaler
import matplotlib.pylab as plt
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary


In [3]:
# Show the files shipped with the attached Kaggle dataset (the HDF5 file is used below).
os.listdir(path='/kaggle/input/download-dataset/')

['__results__.html',
 'dataset_babble_v2.hdf5',
 '__notebook__.ipynb',
 '__output__.json',
 'custom.css']

In [4]:
num_workers = 2  # worker processes used by the DataLoaders built in fast_loader

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [5]:
# Display the device selected above (cuda when available, else cpu).
device

device(type='cuda')

In [6]:
def evaluate(model, loader, criterion=None, epoch=None, log=None):
    """Run ``model`` over every batch of ``loader`` and report the mean loss.

    :param model: network to evaluate; it is switched to eval mode (and left there).
    :param loader: iterable of ``(data, target)`` batches.
    :param criterion: optional loss module (e.g. ``nn.MSELoss()``). When ``None``
        no loss is computed and the reported value is 0.
    :param epoch: unused; kept for call-site compatibility.
    :param log: unused; kept for call-site compatibility.
    :returns: mean per-batch loss as a float (0.0 when ``criterion`` is None or
        ``loader`` yields no batches).
    """
    model.eval()
    # Move batches to wherever the model's weights live instead of relying on a
    # notebook-level ``device`` global; parameter-free models run on the CPU.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    total_loss = 0.0
    n_batches = 0
    with torch.no_grad():  # evaluation needs no autograd graph
        for data, target in loader:
            data = data.to(target_device)
            target = target.to(target_device)

            output = torch.squeeze(model(data))

            if criterion is not None:
                total_loss += criterion(output, target.to(torch.float32)).item()
            n_batches += 1

    # Assumes criterion reduces each batch to a mean (nn.MSELoss default), so the
    # epoch figure is the average over batches, not a division by sample count.
    loss = total_loss / n_batches if (criterion is not None and n_batches) else 0.0

    print('Average loss: {:.4f}\n'.format(loss))
    return loss


class RandomBatchSampler(Sampler):
    """Sampling class to create random sequential batches from a given dataset
    E.g. if data is [1,2,3,4] with bs=2. Then first batch, [[1,2], [3,4]] then shuffle batches -> [[3,4],[1,2]]
    This is useful for cases when you are interested in 'weak shuffling'

    Indices are yielded one at a time and are contiguous within each batch, so
    this sampler is meant to be wrapped in a ``BatchSampler`` that uses the same
    ``batch_size``. The order of the full batches is re-drawn every time the
    sampler is iterated (i.e. every epoch). A trailing partial batch, if any, is
    always yielded last.

    :param dataset: dataset you want to batch
    :type dataset: torch.utils.data.Dataset
    :param batch_size: batch size (must be positive)
    :type batch_size: int
    :raises ValueError: if ``batch_size`` is not positive
    :returns: generator object of shuffled batch indices
    """
    def __init__(self, dataset, batch_size):
        if batch_size <= 0:
            raise ValueError('batch_size must be a positive integer, got {}'.format(batch_size))
        self.batch_size = batch_size
        self.dataset_length = len(dataset)
        self.n_batches = self.dataset_length / self.batch_size
        self.batch_ids = torch.randperm(int(self.n_batches))

    def __len__(self):
        # Number of indices yielded; BatchSampler derives its own length from this.
        return self.dataset_length

    def __iter__(self):
        n_full = self.dataset_length // self.batch_size
        # Re-draw the batch order on every pass so each epoch gets a fresh shuffle.
        self.batch_ids = torch.randperm(n_full)
        for batch_id in self.batch_ids.tolist():
            start = batch_id * self.batch_size
            yield from range(start, start + self.batch_size)
        # Trailing partial batch (an empty range when the length divides evenly).
        yield from range(n_full * self.batch_size, self.dataset_length)


def fast_loader(dataset, batch_size=32, drop_last=False, transforms=None):
    """Build a DataLoader that reads whole contiguous batches from an .h5-backed dataset.

    Each h5 lookup pays a roughly fixed search cost, after which reading the
    following rows is nearly free, so ``data[start:start + batch_size]`` costs
    about the same as ``data[start]``. To exploit that without loading the whole
    file into memory, only the ORDER of contiguous batches is shuffled ("weak
    shuffling"); rows inside a batch stay sequential.

    :param dataset: a dataset that loads data from .h5 files and accepts a list
        of indices in ``__getitem__``
    :type dataset: torch.utils.data.Dataset
    :param batch_size: number of consecutive rows per batch
    :type batch_size: int
    :param drop_last: drop the final batch when it is smaller than ``batch_size``
    :type drop_last: bool
    :param transforms: currently ignored
    :returns: loader yielding one pre-batched ``dataset[indices]`` item per step
    :rtype: torch.utils.data.DataLoader
    """
    # Contiguous index runs in shuffled batch order...
    index_sampler = RandomBatchSampler(dataset, batch_size)
    # ...regrouped into lists so each dataset access receives a full batch.
    batch_sampler = BatchSampler(index_sampler, batch_size=batch_size, drop_last=drop_last)
    # batch_size=None disables automatic batching: every list of indices the
    # sampler yields is passed straight to dataset.__getitem__.
    return DataLoader(dataset, batch_size=None, sampler=batch_sampler, num_workers=num_workers)


class HDF5Dataset(Dataset):
    """Map-style dataset over paired ``x_<operation>`` / ``y_<operation>`` arrays in an HDF5 file.

    Only the length is read at construction time. The file itself is opened
    lazily on the first ``__getitem__`` call, so each DataLoader worker opens
    its own handle.

    :param file_path: path of the HDF5 file.
    :param operation: split suffix; datasets ``'x_' + operation`` and
        ``'y_' + operation`` are read.
    :raises KeyError: if either dataset is missing from the file.
    """
    def __init__(self, file_path, operation = 'train'):
        self.file_path = file_path
        self.x_dataset_name = 'x_' + operation
        self.y_dataset_name = 'y_' + operation
        self.length = None

        with h5py.File(self.file_path, 'r') as hf:
            # Fail early with a clear message instead of len(None) -> TypeError.
            missing = [name for name in (self.x_dataset_name, self.y_dataset_name) if name not in hf]
            if missing:
                raise KeyError('{}: dataset(s) {} not found'.format(self.file_path, missing))
            self.length = len(hf[self.x_dataset_name])

    def __len__(self):
        return self.length

    def _open_hdf5(self):
        self._hf = h5py.File(self.file_path, 'r')

    def __getstate__(self):
        # Open h5py handles cannot be pickled (needed for spawned DataLoader
        # workers); drop it so the copy reopens the file lazily.
        state = self.__dict__.copy()
        state.pop('_hf', None)
        return state

    def close(self):
        """Close the lazily opened file handle, if any; the next access reopens it."""
        hf = self.__dict__.pop('_hf', None)
        if hf is not None:
            hf.close()

    def __getitem__(self, index):
        """Return ``(x, y)`` float32 tensors for ``index`` (an int or, from fast_loader, a list of indices)."""
        if not hasattr(self, '_hf'):
            self._open_hdf5()

        x = self._hf[self.x_dataset_name][index]
        y = self._hf[self.y_dataset_name][index]

        x = (torch.from_numpy(x)).to(torch.float32)
        y = (torch.from_numpy(y)).to(torch.float32)
        return (x, y)

DATASET_PATH = '/kaggle/input/download-dataset/dataset_babble_v2.hdf5'


def _get_loader_hdf5(operation, batch_size, file_path):
    """Build the HDF5 dataset for ``operation``, wrap it in fast_loader and report its size."""
    dataset = HDF5Dataset(file_path, operation=operation)
    loader = fast_loader(dataset, batch_size=batch_size, drop_last=False)
    print('Found', len(dataset), ' {} samples'.format(operation))
    return loader


def get_train_loader_hdf5(batch_size, file_path=DATASET_PATH):
    """Return a weakly shuffled loader over the ``x_train``/``y_train`` datasets of ``file_path``."""
    print('Train: ', end="")
    return _get_loader_hdf5('train', batch_size, file_path)


def get_test_loader_hdf5(batch_size, file_path=DATASET_PATH):
    """Return a weakly shuffled loader over the ``x_test``/``y_test`` datasets of ``file_path``."""
    print('Test: ', end="")
    return _get_loader_hdf5('test', batch_size, file_path)


In [7]:
# Run name: used for the TensorBoard log directory below and the checkpoint name.
save_as = 'kaggle_CNN_V2'
max_epochs = 60   # upper bound; train_main may stop earlier on stagnation
batch_size = 64

# NOTE(review): train_main defines its own local `model_file` (save_as + '.pt'),
# so this 'models/...' path is not what the checkpoint is written to.
model_file = f'models/{save_as}.pt'

# Put the train and validation curves on one TensorBoard chart.
layout = {"ABCDE": {"loss": ["Multiline", ["loss/train", "loss/validation"]]}}

writer = SummaryWriter(f'experiments/{save_as}')
writer.add_custom_scalars(layout)




def count_parameters(model):
    """Return how many scalar parameters of ``model`` have ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [8]:

class Autoencoder(nn.Module):
    """U-Net style encoder/decoder applied to 25x257 single-channel patches.

    The input is reshaped to ``(-1, 1, 25, 257)`` and the output back to
    ``(-1, 125, 257)``, as done by the other models in this notebook. Skip
    connections use the encoder features *before* pooling, so every
    concatenation joins tensors of the same spatial size.
    """
    def __init__(self):
        super(Autoencoder, self).__init__()

        # Encoder level 1 (full resolution)
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(2)

        # Encoder level 2
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(2)

        # Encoder level 3
        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(2)

        # Bottleneck
        self.conv7 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.conv8 = nn.Conv2d(512, 512, kernel_size=3, padding=1)

        # Decoder: each up-sampled map is concatenated with the matching encoder map
        self.up1 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv9 = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.conv10 = nn.Conv2d(256, 256, kernel_size=3, padding=1)

        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv11 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.conv12 = nn.Conv2d(128, 128, kernel_size=3, padding=1)

        self.up3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv13 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.conv14 = nn.Conv2d(64, 64, kernel_size=3, padding=1)

        self.conv15 = nn.Conv2d(64, 1, kernel_size=1)
        self.relu = nn.ReLU()

    @staticmethod
    def _match_size(upsampled, skip):
        """Zero-pad (or crop, if negative) ``upsampled`` so its last two dims equal ``skip``'s.

        Needed because MaxPool2d floors odd sizes (25 -> 12, 257 -> 128), so
        a x2 up-sampling can come back one row/column short.
        """
        dh = skip.size(-2) - upsampled.size(-2)
        dw = skip.size(-1) - upsampled.size(-1)
        return nn.functional.pad(upsampled, (dw // 2, dw - dw // 2, dh // 2, dh - dh // 2))

    def forward(self, x):
        x = x.to(torch.float32).reshape(-1, 1, 25, 257)

        # Encoder; e1..e3 are kept (pre-pool) for the skip connections.
        e1 = self.relu(self.conv2(self.relu(self.conv1(x))))                 # 64 ch, 25x257
        e2 = self.relu(self.conv4(self.relu(self.conv3(self.pool1(e1)))))    # 128 ch, 12x128
        e3 = self.relu(self.conv6(self.relu(self.conv5(self.pool2(e2)))))    # 256 ch, 6x64
        bottleneck = self.relu(self.conv8(self.relu(self.conv7(self.pool3(e3)))))  # 512 ch, 3x32

        d1 = self._match_size(self.up1(bottleneck), e3)
        d1 = self.relu(self.conv10(self.relu(self.conv9(torch.cat((d1, e3), dim=1)))))

        d2 = self._match_size(self.up2(d1), e2)
        d2 = self.relu(self.conv12(self.relu(self.conv11(torch.cat((d2, e2), dim=1)))))

        d3 = self._match_size(self.up3(d2), e1)  # 24x256 -> padded to 25x257
        d3 = self.relu(self.conv14(self.relu(self.conv13(torch.cat((d3, e1), dim=1)))))

        out = self.conv15(d3)
        return out.reshape(-1, 125, 257)



class LSTM(nn.Module):
    """Dense -> 4-layer LSTM -> dense model over 25-frame chunks of 257 features.

    The input is reshaped to ``(-1, 25, 257)`` and the output back to
    ``(-1, 125, 257)``. ReLU is applied after both dense layers.
    """
    def __init__(self):

        super(LSTM, self).__init__()

        self.fc_1 =  nn.Linear(257, 257) #fully connected 1

        self.lstm = nn.LSTM(input_size=257, hidden_size=257,
                          num_layers=4, batch_first=True) #lstm

        self.fc_2 =  nn.Linear(257, 257) #fully connected 2

        self.relu = nn.ReLU()

    def forward(self,x):
        out = (x.to(torch.float32)).view(-1, 25, 257)
        out = self.relu(self.fc_1(out))

        # No explicit (h_0, c_0): nn.LSTM defaults both to zeros on the input's
        # device and dtype, so this no longer depends on a global `device`.
        output, _ = self.lstm(out)
        out = self.relu(self.fc_2(output))
        return out.view(-1, 125, 257)

class CNN(nn.Module):
    """Small convolutional encoder/decoder over 25x257 single-channel patches.

    Three valid convolutions (kernels 7, 5, 3) shrink the patch. Three
    transposed convolutions with the same kernels restore it. Every stage is
    followed by BatchNorm and ReLU. The input is reshaped to
    ``(-1, 1, 25, 257)`` and the output back to ``(-1, 125, 257)``.
    """
    def __init__(self):
        super(CNN, self).__init__()

        # Encoder (layers are created in this exact order so seeded init is unchanged)
        self.c1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=7)
        self.b1 = nn.BatchNorm2d(num_features=8)
        self.c2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=5)
        self.b2 = nn.BatchNorm2d(num_features=16)
        self.c3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3)
        self.b3 = nn.BatchNorm2d(num_features=32)

        # Decoder
        self.t1 = nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3)
        self.b4 = nn.BatchNorm2d(num_features=16)
        self.t2 = nn.ConvTranspose2d(in_channels=16, out_channels=8, kernel_size=5)
        self.b5 = nn.BatchNorm2d(num_features=8)
        self.t3 = nn.ConvTranspose2d(in_channels=8, out_channels=1, kernel_size=7)
        self.b6 = nn.BatchNorm2d(num_features=1)

        self.relu = nn.ReLU()

    def forward(self, x):
        # (conv or transposed conv, batch norm) pairs, applied in order.
        stages = (
            (self.c1, self.b1),
            (self.c2, self.b2),
            (self.c3, self.b3),
            (self.t1, self.b4),
            (self.t2, self.b5),
            (self.t3, self.b6),
        )
        out = x.to(torch.float32).view(-1, 1, 25, 257)
        for layer, norm in stages:
            out = self.relu(norm(layer(out)))
        return out.view(-1, 125, 257)

class FCNN(nn.Module):
    """Five-layer fully connected network applied independently along the last (257) axis.

    The four hidden layers use ReLU; the final 512 -> 257 projection is linear.
    """
    def __init__(self):
        super(FCNN, self).__init__()

        # Layers are created in this exact order so seeded initialisation is unchanged.
        self.fc_1 = nn.Linear(257, 512)  # input projection
        self.fc_2 = nn.Linear(512, 512)
        self.fc_3 = nn.Linear(512, 512)
        self.fc_4 = nn.Linear(512, 512)
        self.fc_5 = nn.Linear(512, 257)  # output projection, no activation

        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = x.to(torch.float32)
        for layer in (self.fc_1, self.fc_2, self.fc_3, self.fc_4):
            hidden = self.relu(layer(hidden))
        return self.fc_5(hidden)


In [9]:
def _train_one_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over ``loader``; return the mean per-batch loss (0.0 if empty)."""
    model.train(True)
    total_loss, n_batches = 0.0, 0
    for data, target in loader:
        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        output = torch.squeeze(model(data))
        loss = criterion(output, target)
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
        n_batches += 1
    return total_loss / n_batches if n_batches else 0.0


def _validate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradients; return the mean per-batch loss (0.0 if empty)."""
    model.eval()
    total_loss, n_batches = 0.0, 0
    with torch.no_grad():
        for data, target in loader:
            data = data.to(device)
            target = target.to(device)
            output = torch.squeeze(model(data))
            total_loss += criterion(output, target).item()
            n_batches += 1
    return total_loss / n_batches if n_batches else 0.0


def train_main(use_model):
    """Build the model named by ``use_model``, train it with early stopping, and save its weights.

    :param use_model: ``'lstm'``, ``'fcn'``, ``'cnn'``; anything else builds the
        ``Autoencoder``.

    Losses are averaged per batch (criterion is ``nn.MSELoss`` with its default
    mean reduction) and logged to ``writer``. Training stops once the validation
    loss has not improved for more than ``max_stagnation`` epochs. The state dict
    from the best validation epoch is saved to ``save_as + '.pt'``.
    """
    model_file = save_as + '.pt'

    if (use_model == 'lstm'):
        model = LSTM().to(device)
    elif(use_model == 'fcn'):
        model = FCNN().to(device)
        summary(model, (2, 125, 257))
    elif(use_model == 'cnn'):
        model = CNN().to(device)
        summary(model, (2,1, 125, 257))
    else:
        model = Autoencoder().to(device)

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.MSELoss()

    print(model)
    print(f'\n ========== Nb parametres du modele : {count_parameters(model)} ========== \n')

    train_loader = get_train_loader_hdf5(batch_size)
    validation_loader = get_test_loader_hdf5(batch_size)
    best_val_loss = None
    best_val_epoch = 0
    best_state = None  # weights from the best validation epoch
    max_stagnation = 5
    for epoch in range(1, max_epochs + 1):
        start_time = datetime.now()

        epoch_loss = _train_one_epoch(model, train_loader, criterion, optimizer)
        writer.add_scalar("loss/train", epoch_loss, epoch)
        print('Train Epoch: {}, Loss: {:.4f}'.format(epoch, epoch_loss))

        print('\nValidation:')
        loss = _validate(model, validation_loader, criterion)
        writer.add_scalar("loss/validation", loss, epoch)
        print('Average loss: {:.4f}\n'.format(loss))

        if best_val_loss is None or loss < best_val_loss:
            best_val_loss, best_val_epoch = loss, epoch
            # Snapshot so the checkpoint is the best model, not the last (stagnated) one.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        early_stop = best_val_epoch < (epoch - max_stagnation)

        epoch_time = (datetime.now() - start_time).total_seconds()
        print('Epoch took {:.2f} seconds.'.format(epoch_time))
        if early_stop:
            print(f"Stagnation reached")
            break

    writer.flush()
    torch.save(best_state if best_state is not None else model.state_dict(), model_file)

In [10]:
import gc

# Release cached GPU memory and collect garbage left by earlier runs before training.
use_model = 'cnn'
torch.cuda.empty_cache()
gc.collect()
train_main(use_model)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 8, 19, 251]             400
       BatchNorm2d-2           [-1, 8, 19, 251]              16
              ReLU-3           [-1, 8, 19, 251]               0
            Conv2d-4          [-1, 16, 15, 247]           3,216
       BatchNorm2d-5          [-1, 16, 15, 247]              32
              ReLU-6          [-1, 16, 15, 247]               0
            Conv2d-7          [-1, 32, 13, 245]           4,640
       BatchNorm2d-8          [-1, 32, 13, 245]              64
              ReLU-9          [-1, 32, 13, 245]               0
  ConvTranspose2d-10          [-1, 16, 15, 247]           4,624
      BatchNorm2d-11          [-1, 16, 15, 247]              32
             ReLU-12          [-1, 16, 15, 247]               0
  ConvTranspose2d-13           [-1, 8, 19, 251]           3,208
      BatchNorm2d-14           [-1, 8, 