In [1]:
import numpy as np
import os
from matplotlib import pyplot as plt
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch.nn as nn
import torch
import pandas as pd
from torch.autograd import Variable

In [2]:
dir_path ="train/clean"

In [3]:
def padding_left(numpy_image, frame_size):
    result_image = np.zeros((frame_size,numpy_image.shape[1])).astype(np.float16)
    result_image[:numpy_image.shape[0],:] =numpy_image
    return result_image    

In [4]:
def normalize_data(numpy_array):
        numpy_array =(numpy_array -numpy_array.min())/ (numpy_array.max()- numpy_array.min()) *255
        return numpy_array.astype(np.uint8)

In [5]:
def cut_image_into_frames(numpy_image, frame_size):
    if (numpy_image.shape[0]< frame_size):
        result_image =padding_left(numpy_image, frame_size)
        return [result_image]
    elif(numpy_image.shape[0]== frame_size):
        return [numpy_image]
    else:
        results =[]
        frame_number =int(np.ceil(numpy_image.shape[0]/frame_size))
        for frame_id in range(0,frame_number):
            if(numpy_image.shape[0]>=(frame_id+1)*frame_size):
                result_image=numpy_image[frame_id*frame_size:(frame_id+1)*frame_size, :]
                results.append(result_image)
            else:
                result_image=padding_left(numpy_image[frame_id*frame_size:, :],frame_size)
                results.append(result_image)
        return results       

In [6]:
def combine_frames(results, size):
    result_frame =np.concatenate(results, axis=0)
    return result_frame[:size, :]

In [7]:
class DenoisingDatasetVal(Dataset):    
    """Sound denoising dataset."""

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        clean_path ="clean"
        noisy_path ="noisy"
        self.train_mels =[]
        train_packages =os.listdir(os.path.join(self.root_dir, clean_path))
        for train_package in train_packages:
            clean_package =os.path.join(self.root_dir,clean_path,train_package)
            mel_files =os.listdir(clean_package)
            for mel_file in mel_files:
                mel_clean =os.path.join(clean_package,mel_file)
                mel_noise =os.path.join(self.root_dir,noisy_path,train_package, mel_file)
                self.train_mels.append((mel_clean,mel_noise))

    def __len__(self):
        return len(self.train_mels)

    def __getitem__(self, idx):
        path_clean =self.train_mels[idx][0]
        path_noise =self.train_mels[idx][1]
        return path_clean, path_noise

In [8]:
class DenoisingDatasetTrain(Dataset):    
    """Sound denoising dataset."""

    def __init__(self, train_path):
        self.train_path = train_path
        self.data =pd.read_csv(self.train_path)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        path_clean =self.data.iloc[idx]["mel_clean_frame_path"]
        path_noise =self.data.iloc[idx]["mel_noise_frame_path"]
        study_id= self.data.iloc[idx]["study_id"]
        clean_image_frame =np.load(path_clean)
        noisy_image_frame =np.load(path_noise)
        clean_image_frame =clean_image_frame[np.newaxis, :, :]
        noisy_image_frame =noisy_image_frame[np.newaxis, :, :]
        return study_id,clean_image_frame, noisy_image_frame

In [9]:
def plot_frames(*args):
    args = [x.squeeze() for x in args]
    n = min([x.shape[0] for x in args])
    
    plt.figure(figsize=(2*n, 2*len(args)))
    for j in range(n):
        for i in range(len(args)):
            ax = plt.subplot(len(args), n, i*n + j + 1)
            plt.imshow(args[i][j])
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

    plt.show()

In [10]:
train_data=DenoisingDatasetTrain("denoising_dataset.csv")
val_data=DenoisingDatasetVal("val")

In [11]:
train_loader =DataLoader(train_data, batch_size=128)
val_loader =DataLoader(val_data, batch_size=1)

In [13]:
class DenoisingAutoEncoder(nn.Module):
 
    def __init__(self):
        super(DenoisingAutoEncoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=1, padding=1),
            nn.ReLU(True),
            nn.BatchNorm2d(16),
            nn.MaxPool2d(2, stride=2), 
            nn.Conv2d(16, 8, 3, stride=1, padding=1), 
            nn.ReLU(True),
            nn.BatchNorm2d(8),
            nn.MaxPool2d(2, stride=2)  
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(8, 16, 3, stride=2),  
            nn.ReLU(True),
            nn.BatchNorm2d(16),
            nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1),  
            nn.ReLU(True),
            nn.BatchNorm2d(8),
            nn.ConvTranspose2d(8, 1, 2, stride=1, padding=1),  
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [14]:
num_epochs =50
learning_rate =0.01
frame_size= 80

In [15]:
model = DenoisingAutoEncoder().cuda()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                             weight_decay=1e-5)

In [16]:
model

DenoisingAutoEncoder(
  (encoder): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU(inplace=True)
    (6): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (decoder): Sequential(
    (0): ConvTranspose2d(8, 16, kernel_size=(3, 3), stride=(2, 2))
    (1): ReLU(inplace=True)
    (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ConvTranspose2d(16, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stat

In [17]:
def cross_validation(model,criterion, val_loader):
    overal_valid_loss =0
    val_count=0
    with torch.no_grad():
        for path_clean, path_noise in val_loader:
            val_count+=1
            model.eval()
            mel_clean =np.load(path_clean[0])
            mel_noise =np.load(path_noise[0])
            mel_clean_frames =cut_image_into_frames(mel_clean, frame_size)
            mel_noise_frames =cut_image_into_frames(mel_noise, frame_size)
            overal_image_loss =0
            for mel_clean_frame,mel_noise_frame in  zip(mel_clean_frames,mel_noise_frames):
                    mel_noise_frame= torch.tensor(mel_noise_frame, dtype =torch.float)
                    mel_noise_frame=mel_noise_frame[np.newaxis,np.newaxis, :, :]
                    mel_clean_frame =torch.tensor(mel_clean_frame,dtype =torch.float)
                    mel_clean_frame=mel_clean_frame[np.newaxis,np.newaxis, :, :]
                    img_noisy = Variable(mel_noise_frame).cuda()
                    img_clean = Variable(mel_clean_frame).cuda()
                    # ===================forward=====================
                    output = model(img_noisy)
                    loss = criterion(output, img_clean)
                    overal_image_loss+=loss.cpu().detach().numpy()
            overal_image_loss =overal_image_loss/ len(mel_clean_frames)
            overal_valid_loss+=overal_image_loss
        print('val loss:{:.4f}', overal_valid_loss/val_count)

In [19]:
for epoch in range(num_epochs):
    overal_train_loss =0
    count_train =0
    model.train()
    for _, mel_clean, mel_noise in train_loader:
        count_train+=1
        mel_noise= mel_noise.float()
        mel_clean =mel_clean.float()
        img_noisy = Variable(mel_noise).cuda()
        img_clean = Variable(mel_clean).cuda()
        # ===================forward=====================
        output = model(img_noisy)
        loss = criterion(output, img_clean)
        overal_train_loss+=loss.cpu().detach().numpy()
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # ===================log========================
    print('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, num_epochs, overal_train_loss/count_train))
#     plot_frames(img_clean[:10].cpu().detach().numpy(),img_noisy.cpu().detach().numpy(),output.cpu().detach().numpy())
    cross_validation(model,criterion, val_loader)
    if epoch % 10 == 0:
        torch.save(model.state_dict(), 'denoising_autoencoder_simple.pth')


epoch [1/50], loss:0.0731
val loss:{:.4f} 0.07022970320064316
epoch [2/50], loss:0.0693
val loss:{:.4f} 0.0678097390924095
epoch [3/50], loss:0.0678
val loss:{:.4f} 0.06923925768529139
epoch [4/50], loss:0.0668
val loss:{:.4f} 0.07002911025186967
epoch [5/50], loss:0.0665
val loss:{:.4f} 0.0676533434135941
epoch [6/50], loss:0.0658
val loss:{:.4f} 0.06679688660258354
epoch [7/50], loss:0.0655
val loss:{:.4f} 0.06360243008487781
epoch [8/50], loss:0.0655
val loss:{:.4f} 0.06585996852697754
epoch [9/50], loss:0.0650
val loss:{:.4f} 0.06527411333986496
epoch [10/50], loss:0.0653
val loss:{:.4f} 0.06299372130448146
epoch [11/50], loss:0.0650
val loss:{:.4f} 0.0696481352007448
epoch [12/50], loss:0.0649
val loss:{:.4f} 0.07118571667577005
epoch [13/50], loss:0.0647
val loss:{:.4f} 0.06707869607629295
epoch [14/50], loss:0.0653
val loss:{:.4f} 0.06405318503310774
epoch [15/50], loss:0.0646
val loss:{:.4f} 0.06336034280734369
epoch [16/50], loss:0.0645
val loss:{:.4f} 0.06378170728332595
epoc

KeyboardInterrupt: 