In [None]:
import os
import numpy as np
import librosa
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torch.utils.data import Dataset, DataLoader
from audiomentations import (
    Compose, AddBackgroundNoise
)
import matplotlib.pyplot as plt
import requests
import tarfile
import io
import soundfile as sf
from sklearn.model_selection import train_test_split
from scipy.signal import convolve
from tqdm import tqdm
import random
from sklearn.metrics import f1_score
from audiomentations import Compose, AddBackgroundNoise

import wandb

In [None]:
# Silence every Python warning for the rest of the session.
# NOTE(review): this also hides deprecation warnings raised by the libraries imported above.
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Print tensors in full: an infinite threshold disables the "..." summarization of large tensors.
torch.set_printoptions(threshold=torch.inf)

In [None]:
# Start a Weights & Biases run in the "etri" project and record the main hyperparameters.
# NOTE(review): these values are typed by hand and duplicate the CONFIG class defined in the
# next cell; keep both in sync when a learning rate, the epoch count or the batch size changes.
wandb.init(
    project = "etri",

    config = {

        "DENOISER_LR": 3e-4,
        "CLASSIFIER_LR": 3e-4,

        "EPOCH": 200,
        "BATCH_SIZE": 8,
    }
)

In [None]:
class CONFIG:
    """Notebook-wide constants: reproducibility, optimisation, audio front-end and paths."""

    # --- reproducibility ---
    SEED = 42

    # --- optimisation ---
    EPOCH = 200
    BATCH_SIZE = 8
    DENOISER_LR = 3e-4
    CLASSIFIER_LR = 3e-4

    # --- audio / mel-spectrogram front-end ---
    SR = 25600            # sample rate passed to librosa.load and melspectrogram
    DURATION = 0.1
    N_FFT = 2048
    HOP_LENGTH = 32
    N_MEL = 128
    TARGET_SIZE = (128, 128)   # (height, width) every spectrogram is padded to
    # NUM_AUGMENTATIONS = 10   (disabled)

    # --- directories ---
    DATA_DIR = 'data'
    NOISE_DIR = 'background_noises'

In [None]:
def seed_everything(seed):
    """Seed every random number generator used by the notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices), exports
    ``PYTHONHASHSEED`` and puts cuDNN into deterministic mode.

    Args:
        seed (int): Seed value applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one (the notebook runs on cuda:1).
    # Safe to call without CUDA: it is silently ignored.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick different kernels between runs,
    # which defeats the determinism requested on the line above.
    torch.backends.cudnn.benchmark = False

seed_everything(CONFIG.SEED) # Fix the random seed for reproducibility

In [None]:
# Pick the compute device: the second GPU (index 1) when present, otherwise the first
# GPU, otherwise the CPU. Hard-coding 'cuda:1' fails on single-GPU machines.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')

device

In [None]:
def add_reverb(signal, sr, reverb_amount=0.75, start_time=0.05):
    """Mix a delayed, box-filtered copy of ``signal`` back into it.

    Args:
        signal: 1-D array of audio samples.
        sr: Sample rate in Hz; it sizes the impulse response and the delay.
        reverb_amount: Gain applied to the delayed copy before it is added.
        start_time: Delay of the copy in seconds.

    Returns:
        Array of the same length as ``signal``: ``signal + reverb_amount * delayed_copy``.
    """
    n_samples = len(signal)

    # Impulse response: 10 ms of zeros followed by a 100 ms all-ones tail.
    silence = np.zeros(int(sr * 0.01))
    tail = np.ones(int(sr * 0.1))
    kernel = np.concatenate([silence, tail])

    # Full convolution, cut back to the input length.
    wet = convolve(signal, kernel, mode='full')[:n_samples]

    # Delay the wet signal by prepending zeros, then cut back to the input length again.
    delay = int(sr * start_time)
    delayed_wet = np.concatenate([np.zeros(delay), wet])[:n_samples]

    return signal + reverb_amount * delayed_wet

In [None]:
def pad_to_target_size(mel_spec, target_size):
    """Zero-pad a ``(channels, height, width)`` tensor up to ``target_size``.

    Only the width is ever padded: a height shortfall trips the assertion below.
    Padding is split between both sides, the extra element going to the right.
    Inputs already at least as large as the target are returned unpadded (no cropping).

    Args:
        mel_spec: 3-D tensor ``(c, h, w)``.
        target_size: ``(target_h, target_w)``.

    Returns:
        The padded tensor.

    Raises:
        AssertionError: if ``h`` is smaller than ``target_h``.
    """
    _, height, width = mel_spec.shape
    want_h, want_w = target_size

    missing_h = want_h - height if want_h > height else 0
    missing_w = want_w - width if want_w > width else 0

    assert missing_h == 0, "height padding will be occured (N_MEL = 128)"

    top, left = missing_h // 2, missing_w // 2
    bottom, right = missing_h - top, missing_w - left

    # F.pad takes the last dimension first: (left, right, top, bottom).
    return F.pad(mel_spec, (left, right, top, bottom), mode='constant', value=0)

In [None]:
def min_max_scaling(tensor, min_value=0.0, max_value=1.0):
    """Linearly rescale ``tensor`` so its values span ``[min_value, max_value]``.

    Args:
        tensor: Tensor to rescale; its global min and max define the input range.
        min_value: Lower bound of the output range.
        max_value: Upper bound of the output range.

    Returns:
        tuple: ``(scaled_tensor, tensor_min, tensor_max)`` where the last two are the
        original extrema, needed to undo the scaling later.
    """
    tensor_min = tensor.min()
    tensor_max = tensor.max()

    shifted = tensor - tensor_min
    value_range = tensor_max - tensor_min
    # A constant input has zero range; dividing would yield NaN everywhere.
    # In that case `shifted` is already all zeros, so it maps to `min_value`.
    scaled_tensor = shifted / value_range if value_range != 0 else shifted
    scaled_tensor = scaled_tensor * (max_value - min_value) + min_value

    return scaled_tensor, tensor_min, tensor_max

In [None]:
class TrainDataset(Dataset):
    """Training set that builds (clean, degraded) log-mel pairs on the fly.

    Clips are read from ``./train1/<label>/*.wav``. For every item the clean clip is
    degraded with ``add_reverb`` followed by background noise from
    ``./background_noises`` at a fixed 10 dB SNR; both versions are converted to
    padded, min-max scaled log-mel spectrograms.
    """

    def __init__(self):
        self.data_dir = './train1'
        self.noise_dir = './background_noises'
        self.transform = Compose([AddBackgroundNoise(sounds_path=self.noise_dir, min_snr_in_db=10, max_snr_in_db=10, p=1)])
        self.file_paths, self.labels = self._load_file_paths_and_labels()

    def _load_file_paths_and_labels(self):
        """Collect the ``.wav`` files of every label folder that exists.

        Returns:
            tuple: ``(file_paths, labels)``, two parallel lists. Files are sorted by
            name inside each label so the sample order (and therefore the seeded
            shuffle) does not depend on the file system's listing order.
        """
        file_paths = []
        labels = []

        label_map = {'Caution': 0, 'Fault': 1, 'Normal': 2}
        for label, idx in label_map.items():
            label_dir = os.path.join(self.data_dir, label)
            if not os.path.exists(label_dir):
                continue
            # os.listdir returns entries in arbitrary order, hence the sort.
            wav_files = [os.path.join(label_dir, f) for f in sorted(os.listdir(label_dir)) if f.endswith('.wav')]
            file_paths.extend(wav_files)
            labels.extend([idx] * len(wav_files))

        return file_paths, labels

    def __len__(self):
        return len(self.file_paths)

    @staticmethod
    def _to_padded_log_mel(samples):
        """Turn raw samples into a ``(1, n_mels, frames)`` log-mel tensor padded to CONFIG.TARGET_SIZE."""
        mel = librosa.feature.melspectrogram(
            y=samples,
            sr=CONFIG.SR,
            n_mels=CONFIG.N_MEL,
            n_fft=CONFIG.N_FFT,
            hop_length=CONFIG.HOP_LENGTH
        )
        # ref=np.max expresses each spectrogram in dB relative to its own peak.
        log_mel = librosa.power_to_db(mel, ref=np.max)
        log_mel = torch.tensor(log_mel).unsqueeze(0).float()
        return pad_to_target_size(log_mel, target_size=CONFIG.TARGET_SIZE)

    def __getitem__(self, idx):
        """Return ``(clean, degraded, label, clean_min, clean_max)`` for sample ``idx``.

        ``clean`` and ``degraded`` are each min-max scaled independently; the returned
        min/max are those of the clean spectrogram before scaling.
        """
        file_path = self.file_paths[idx]
        label = self.labels[idx]

        y, sr = librosa.load(file_path, sr=CONFIG.SR)
        log_mel_pad = self._to_padded_log_mel(y)

        # Degrade the clip: reverb first, then additive background noise.
        # The noise transform expects float32 samples, hence the cast.
        y_reverb = np.array(add_reverb(y, sr), dtype=np.float32)
        y_reverb_noise = self.transform(samples=y_reverb, sample_rate=CONFIG.SR)
        log_mel_reverb_noise_pad = self._to_padded_log_mel(y_reverb_noise)

        scaled_log_mel_pad, tensor_min, tensor_max = min_max_scaling(log_mel_pad)
        scaled_log_mel_reverb_noise_pad, _, _ = min_max_scaling(log_mel_reverb_noise_pad)

        return scaled_log_mel_pad, scaled_log_mel_reverb_noise_pad, label, tensor_min, tensor_max


In [None]:
class ValidDataset(Dataset):
    """Validation set pairing each clip under ``./val`` with a reference clip under ``./val1``.

    The reference for ``<label>/<stem>_<suffix>.wav`` is ``./val1/<label>/<stem>.wav``.
    Judging by the variable names the ``./val`` clips are already reverberant/noisy and
    the ``./val1`` clips are clean -- TODO confirm against the data preparation step.
    """

    def __init__(self):
        self.data_dir = './val'
        self.base_dir = './val1'
        self.file_paths, self.labels = self._load_file_paths_and_labels()

    def _load_file_paths_and_labels(self):
        """Collect the ``.wav`` files of every label folder that exists.

        Returns:
            tuple: ``(file_paths, labels)``, two parallel lists, sorted by file name
            inside each label so evaluation order is stable across file systems.
        """
        file_paths = []
        labels = []

        label_map = {'Caution': 0, 'Fault': 1, 'Normal': 2}
        for label, idx in label_map.items():
            label_dir = os.path.join(self.data_dir, label)
            if not os.path.exists(label_dir):
                continue
            # os.listdir returns entries in arbitrary order, hence the sort.
            wav_files = [os.path.join(label_dir, f) for f in sorted(os.listdir(label_dir)) if f.endswith('.wav')]
            file_paths.extend(wav_files)
            labels.extend([idx] * len(wav_files))

        return file_paths, labels

    def __len__(self):
        return len(self.file_paths)

    def _base_path_for(self, file_path):
        """Map ``.../<label>/<stem>_<suffix>.wav`` to ``<base_dir>/<label>/<stem>.wav``.

        Uses os.path instead of fixed ``split('/')`` indices so it keeps working when
        ``data_dir`` has a different depth or the platform uses another separator.
        """
        label_name = os.path.basename(os.path.dirname(file_path))
        stem = os.path.splitext(os.path.basename(file_path))[0].split('_')[0]
        return os.path.join(self.base_dir, label_name, stem + '.wav')

    @staticmethod
    def _to_padded_log_mel(samples):
        """Turn raw samples into a ``(1, n_mels, frames)`` log-mel tensor padded to CONFIG.TARGET_SIZE."""
        mel = librosa.feature.melspectrogram(
            y=samples,
            sr=CONFIG.SR,
            n_mels=CONFIG.N_MEL,
            n_fft=CONFIG.N_FFT,
            hop_length=CONFIG.HOP_LENGTH
        )
        # ref=np.max expresses each spectrogram in dB relative to its own peak.
        log_mel = librosa.power_to_db(mel, ref=np.max)
        log_mel = torch.tensor(log_mel).unsqueeze(0).float()
        return pad_to_target_size(log_mel, target_size=CONFIG.TARGET_SIZE)

    def __getitem__(self, idx):
        """Return ``(reference, input, label, reference_min, reference_max)`` for sample ``idx``.

        Both spectrograms are min-max scaled independently; the returned min/max are
        those of the reference spectrogram before scaling.
        """
        file_path = self.file_paths[idx]
        label = self.labels[idx]

        base_y, _ = librosa.load(self._base_path_for(file_path), sr=CONFIG.SR)
        base_log_mel_pad = self._to_padded_log_mel(base_y)

        y_reverb_noise, _ = librosa.load(file_path, sr=CONFIG.SR)
        log_mel_reverb_noise_pad = self._to_padded_log_mel(y_reverb_noise)

        scaled_log_mel_pad, tensor_min, tensor_max = min_max_scaling(base_log_mel_pad)
        scaled_log_mel_reverb_noise_pad, _, _ = min_max_scaling(log_mel_reverb_noise_pad)

        return scaled_log_mel_pad, scaled_log_mel_reverb_noise_pad, label, tensor_min, tensor_max


In [None]:
class TestDataset(Dataset):
    """Test set pairing each clip under ``./test_50`` with a reference clip under ``./test1``.

    The reference for ``<label>/<stem>_<suffix>.wav`` is ``./test1/<label>/<stem>.wav``.
    Each item also carries the path ``./dunet_denoised/<label>/<file>`` where a
    denoised version of the clip may be written.
    Judging by the variable names the ``./test_50`` clips are already reverberant/noisy
    and the ``./test1`` clips are clean -- TODO confirm against the data preparation step.
    """

    def __init__(self):
        self.data_dir = './test_50'
        self.base_dir = './test1'

        self.file_paths, self.labels = self._load_file_paths_and_labels()

    def _load_file_paths_and_labels(self):
        """Collect the ``.wav`` files of every label folder that exists.

        Returns:
            tuple: ``(file_paths, labels)``, two parallel lists, sorted by file name
            inside each label so evaluation order is stable across file systems.
        """
        file_paths = []
        labels = []

        label_map = {'Caution': 0, 'Fault': 1, 'Normal': 2}
        for label, idx in label_map.items():
            label_dir = os.path.join(self.data_dir, label)
            if not os.path.exists(label_dir):
                continue
            # os.listdir returns entries in arbitrary order, hence the sort.
            wav_files = [os.path.join(label_dir, f) for f in sorted(os.listdir(label_dir)) if f.endswith('.wav')]
            file_paths.extend(wav_files)
            labels.extend([idx] * len(wav_files))

        return file_paths, labels

    def __len__(self):
        return len(self.file_paths)

    def _base_path_for(self, file_path):
        """Map ``.../<label>/<stem>_<suffix>.wav`` to ``<base_dir>/<label>/<stem>.wav``.

        Uses os.path instead of fixed ``split('/')`` indices so it keeps working when
        ``data_dir`` has a different depth or the platform uses another separator.
        """
        label_name = os.path.basename(os.path.dirname(file_path))
        stem = os.path.splitext(os.path.basename(file_path))[0].split('_')[0]
        return os.path.join(self.base_dir, label_name, stem + '.wav')

    @staticmethod
    def _denoised_path_for(file_path):
        """Map ``.../<label>/<file>`` to ``./dunet_denoised/<label>/<file>``."""
        label_name = os.path.basename(os.path.dirname(file_path))
        return os.path.join('./dunet_denoised', label_name, os.path.basename(file_path))

    @staticmethod
    def _to_padded_log_mel(samples):
        """Turn raw samples into a ``(1, n_mels, frames)`` log-mel tensor padded to CONFIG.TARGET_SIZE."""
        mel = librosa.feature.melspectrogram(
            y=samples,
            sr=CONFIG.SR,
            n_mels=CONFIG.N_MEL,
            n_fft=CONFIG.N_FFT,
            hop_length=CONFIG.HOP_LENGTH
        )
        # ref=np.max expresses each spectrogram in dB relative to its own peak.
        log_mel = librosa.power_to_db(mel, ref=np.max)
        log_mel = torch.tensor(log_mel).unsqueeze(0).float()
        return pad_to_target_size(log_mel, target_size=CONFIG.TARGET_SIZE)

    def __getitem__(self, idx):
        """Return ``(reference, input, label, input_min, input_max, denoise_output_path)``.

        Both spectrograms are min-max scaled independently. Unlike the train and
        validation datasets, the returned min/max belong to the *input* spectrogram
        (the one fed to the denoiser), so a denoised output can be mapped back to dB.
        """
        file_path = self.file_paths[idx]
        label = self.labels[idx]

        denoise_output_path = self._denoised_path_for(file_path)

        base_y, _ = librosa.load(self._base_path_for(file_path), sr=CONFIG.SR)
        base_log_mel_pad = self._to_padded_log_mel(base_y)

        y_reverb_noise, _ = librosa.load(file_path, sr=CONFIG.SR)
        log_mel_reverb_noise_pad = self._to_padded_log_mel(y_reverb_noise)

        scaled_log_mel_pad, _, _ = min_max_scaling(base_log_mel_pad)
        scaled_log_mel_reverb_noise_pad, tensor_min, tensor_max = min_max_scaling(log_mel_reverb_noise_pad)

        return scaled_log_mel_pad, scaled_log_mel_reverb_noise_pad, label, tensor_min, tensor_max, denoise_output_path


In [None]:
# Instantiate the three datasets; each constructor scans its directory for .wav files.
train_dataset = TrainDataset()
val_dataset = ValidDataset()
test_dataset = TestDataset()

# Only the training loader shuffles; validation and test keep dataset order.
# NOTE(review): drop_last is left at its default, so the final training batch can hold a
# single sample; the BatchNorm1d in LinearBottleneck may reject a batch of one in training
# mode -- confirm the training-set size or consider drop_last=True.
train_loader = DataLoader(train_dataset, batch_size=CONFIG.BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=CONFIG.BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=CONFIG.BATCH_SIZE, shuffle=False)

In [None]:
class DUNET(nn.Module):
    """Dilated U-Net denoiser: encoder -> linear bottleneck -> decoder with skip connections.

    Args:
        in_channels: Channels of the input spectrogram.
        out_channels: Channels of the reconstructed spectrogram.
        base_channels: Width of the first encoder stage; deeper stages are multiples of it.

    The bottleneck is built with its default size (``LinearBottleneck()``), independent
    of ``base_channels``.
    """

    def __init__(self,in_channels=1, out_channels=1, base_channels=64):
        super().__init__()
        self.encoder = UNetEncoder(in_channels, base_channels)
        self.bottleneck = LinearBottleneck()
        self.decoder = UNetDecoder(base_channels, out_channels=out_channels)

    def forward(self, x):
        # The encoder returns the seven skip tensors followed by its pooled output.
        *skips, encoded = self.encoder(x)
        return self.decoder(self.bottleneck(encoded), *skips)
        

class UNetEncoder(nn.Module):
    """Seven-stage encoder; every stage is a dilated conv block followed by 2x2 max-pooling.

    Stage ``k`` (1-based) is registered as ``block{k}`` / ``mp{k}`` and consists of
    1x1 conv -> 3x3 conv with dilation = padding = k -> PReLU -> BatchNorm -> 1x1 conv.
    Stages 8 and 9 of a deeper variant are disabled.

    Args:
        in_channels: Channels of the input tensor.
        base_width: Output channels of stage 1; later stages use the multipliers below.
    """

    # Output-channel multiplier (relative to base_width) of stages 1..7.
    _WIDTH_MULTIPLIERS = (1, 2, 4, 8, 16, 16, 16)

    def __init__(self, in_channels, base_width):
        super().__init__()

        previous = in_channels
        # Modules are created strictly in stage order so attribute names, state_dict
        # keys and the order of random weight initialisation stay unchanged.
        for stage, multiplier in enumerate(self._WIDTH_MULTIPLIERS, start=1):
            width = base_width * multiplier
            setattr(self, f'block{stage}', nn.Sequential(
                nn.Conv2d(previous, width, kernel_size=1),
                nn.Conv2d(width, width, kernel_size=3, dilation=stage, padding=stage),
                nn.PReLU(),
                nn.BatchNorm2d(width),
                nn.Conv2d(width, width, kernel_size=1)))
            setattr(self, f'mp{stage}', nn.Sequential(nn.MaxPool2d(2)))
            previous = width

    def forward(self, x):
        """Return ``(eb1, ..., eb7, encoder_output)``: the pre-pooling output of every
        stage followed by the pooled output of the last stage."""
        skips = []
        for stage in range(1, len(self._WIDTH_MULTIPLIERS) + 1):
            features = getattr(self, f'block{stage}')(x)
            skips.append(features)
            x = getattr(self, f'mp{stage}')(features)

        return (*skips, x)



class LinearBottleneck(nn.Module):
    """Fully connected bottleneck applied to a ``(batch, features, 1, 1)`` feature map.

    Flattens the input, applies Linear -> BatchNorm1d -> PReLU and reshapes the result
    back to ``(batch, features, 1, 1)``.

    Args:
        features: Size of the flattened input and of the output. Defaults to 1024,
            the value that used to be hard-coded.
    """

    def __init__(self, features=1024):
        super().__init__()

        self.block1 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(features, features),
            nn.BatchNorm1d(features),
            nn.PReLU(),
            nn.Unflatten(1, (features, 1, 1)))

    def forward(self, x):
        return self.block1(x)


class UNetDecoder(nn.Module):
    """Decoder that upsamples the bottleneck output and merges the encoder skip tensors.

    Stage ``k`` is registered as ``tp{k}`` (2x2 transposed conv, stride 2) and
    ``block{k}`` (two 3x3 conv -> ReLU -> BatchNorm groups). Stages 9 and 8 are
    constructed but not used by ``forward``; they remain so that attribute names and
    state_dict keys are unchanged.

    Args:
        base_width: Channel width of the shallowest stage.
        out_channels: Channels produced by the final 1x1 convolution.
    """

    # (stage, input multiplier, output multiplier), in construction order.
    _STAGES = (
        (9, 16, 16), (8, 16, 16), (7, 16, 16), (6, 16, 16), (5, 16, 16),
        (4, 16, 8), (3, 8, 4), (2, 4, 2), (1, 2, 1),
    )
    # Stages actually executed by forward(), deepest first.
    _ACTIVE_STAGES = (7, 6, 5, 4, 3, 2, 1)

    def __init__(self, base_width, out_channels=1):
        super().__init__()

        for stage, in_mult, out_mult in self._STAGES:
            in_ch = base_width * in_mult
            out_ch = base_width * out_mult
            setattr(self, f'tp{stage}', nn.Sequential(
                nn.ConvTranspose2d(in_ch, out_ch, stride=2, kernel_size=2)))
            # The block consumes the upsampled tensor concatenated with a skip tensor
            # of equal width, hence twice out_ch input channels.
            setattr(self, f'block{stage}', nn.Sequential(
                nn.Conv2d(out_ch * 2, out_ch, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.BatchNorm2d(out_ch),
                nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.BatchNorm2d(out_ch)))

        self.block10 = nn.Sequential(
            nn.Conv2d(base_width, out_channels, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self,bottleneck_output,eb1,eb2,eb3,eb4,eb5,eb6,eb7):
        """Decode from the bottleneck, consuming the skip tensors from eb7 down to eb1.

        Returns the output of the final 1x1 convolution passed through a sigmoid.
        """
        skip_by_stage = {1: eb1, 2: eb2, 3: eb3, 4: eb4, 5: eb5, 6: eb6, 7: eb7}

        x = bottleneck_output
        for stage in self._ACTIVE_STAGES:
            upsampled = getattr(self, f'tp{stage}')(x)
            merged = torch.cat((upsampled, skip_by_stage[stage]), dim=1)
            x = getattr(self, f'block{stage}')(merged)

        return self.block10(x)

In [None]:
# Classifier backbone: ImageNet-pretrained ResNet-50 adapted to 1-channel input and 3 classes.
# NOTE(review): `pretrained=True` is deprecated in recent torchvision releases in favour of
# the `weights=` argument; it is kept here because the installed version is not visible.
resnet50 = models.resnet50(pretrained=True)
num_ftrs = resnet50.fc.in_features
resnet50.fc = nn.Linear(num_ftrs, 3)  # new 3-way classification head
first_conv_layer = resnet50.conv1
original_weights = first_conv_layer.weight.data

# nn.Conv2d expects `bias` to be a bool. Passing the bias attribute itself only works while
# it is None and would raise for a real bias tensor, so convert it explicitly.
has_bias = first_conv_layer.bias is not None
new_first_conv = nn.Conv2d(1, first_conv_layer.out_channels,
                           kernel_size=first_conv_layer.kernel_size,
                           stride=first_conv_layer.stride,
                           padding=first_conv_layer.padding,
                           bias=has_bias)

# Collapse the RGB filters into one input channel by averaging over the channel axis.
# `.data` is already detached from autograd, so no no_grad() context is needed.
new_first_conv.weight = nn.Parameter(torch.mean(original_weights, dim=1, keepdim=True))
if has_bias:
    new_first_conv.bias = nn.Parameter(first_conv_layer.bias.data.clone())

resnet50.conv1 = new_first_conv

In [None]:
class CustomClassifier(nn.Module):
    """Wrap a backbone network and turn its outputs into log-probabilities.

    Args:
        backbone: Module producing one score per class along dimension 1.
    """

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

    def forward(self, x):
        # log-softmax over the class dimension
        return self.backbone(x).log_softmax(dim=1)

In [None]:
class EarlyStopping:
    """Track the best validation F1 score, checkpoint on improvement and signal when to stop."""

    def __init__(self, patience=10, delta=0.0, verbose=True,
                 denoiser_path='dunet_denoiser_75_es10_N.pt',
                 classifier_path='dunet_classifier_75_es10_N.pt'):
        """
        patience (int): number of consecutive calls without improvement before
            ``early_stop`` is set. default: 10
        delta  (float): minimum increase over the best F1 that counts as an
            improvement. default: 0.0
        verbose (bool): print progress messages. default: True
        denoiser_path (str): file the denoiser state_dict is saved to.
        classifier_path (str): file the classifier state_dict is saved to.
        """
        self.early_stop = False
        self.patience = patience
        self.verbose = verbose
        self.counter = 0

        # np.Inf was removed in NumPy 2.0; float('inf') works with every version.
        self.best_mse = float('inf')
        self.best_f1 = 0
        self.delta = delta

        self.denoiser_path = denoiser_path
        self.classifier_path = classifier_path

    def __call__(self, mse, f1, denoiser_model, classifier_model):
        """Record one validation result.

        Only ``f1`` decides whether the epoch is an improvement; ``mse`` is stored and
        reported alongside it. On improvement both models are checkpointed and the
        patience counter resets; otherwise the counter increases. ``early_stop`` is
        True once the counter reaches ``patience``.
        """
        # The score must beat the best by more than delta. The previous test,
        # `f1 > best_f1 - delta`, accepted scores *worse* than the best for delta > 0.
        if f1 > (self.best_f1 + self.delta):
            self.counter = 0
            self.best_f1 = f1
            self.best_mse = mse
            if self.verbose:
                print(f'(Update) Best MSE: {self.best_mse:.5f} Best F1: {self.best_f1:.5f}')
            self.save_checkpoint(denoiser_model, classifier_model)

        else:
            self.counter += 1
            if self.verbose:
                print(f'(Patience) {self.counter}/{self.patience} Current MSE: {mse:.5f} Current F1: {f1:.5f} Best MSE: {self.best_mse:.5f} Best F1: {self.best_f1:.5f}')

        if self.counter >= self.patience:
            if self.verbose:
                print(f'[EarlyStopping] (Patience) {self.counter}/{self.patience} Best MSE: {self.best_mse:.5f} Best F1: {self.best_f1:.5f}')
            self.early_stop = True

        else:
            self.early_stop = False

    def save_checkpoint(self, denoiser_model, classifier_model):
        """Write the state_dict of both models to the configured checkpoint paths."""
        if self.verbose:
            print(f'Saving model ...')
        torch.save(denoiser_model.state_dict(), self.denoiser_path)
        torch.save(classifier_model.state_dict(), self.classifier_path)


In [None]:
def train(denoiser_model, classifier_model, denoiser_optimizer, classifier_optimizer, denoiser_scheduler, classifier_scheduler, train_loader, val_loader, device, early_stopper=None):
    """Train the denoiser and the classifier together for up to CONFIG.EPOCH epochs.

    Per batch the denoiser is updated on the MSE between its output and the clean
    spectrogram; the classifier is then updated on the *detached* denoiser output, so
    the classification loss never reaches the denoiser. After every epoch both
    schedulers step, ``validation`` runs, metrics go to Weights & Biases and the early
    stopper decides whether to checkpoint or stop.

    Args:
        denoiser_model, classifier_model: Models to train (moved to ``device``).
        denoiser_optimizer, classifier_optimizer: One optimizer per model.
        denoiser_scheduler, classifier_scheduler: Schedulers stepped once per epoch.
        train_loader: Yields ``(clean, degraded, label, _, _)`` batches.
        val_loader: Loader handed to ``validation``.
        device: Device used for computation.
        early_stopper: Callable ``(val_mse, f1, denoiser, classifier)`` exposing an
            ``early_stop`` flag. When omitted, the notebook-level ``es`` object is used.
    """
    if early_stopper is None:
        # Backward compatible fallback to the global created in a later cell.
        early_stopper = es

    denoiser_model.to(device)
    classifier_model.to(device)

    for epoch in range(CONFIG.EPOCH):
        denoiser_model.train()
        classifier_model.train()

        train_reconstruction_loss, train_classifier_loss = [], []

        # Pass the loader itself (not iter(loader)) so tqdm can show the total.
        for log_mel_pad, log_mel_reverb_noise_pad, label, _, _ in tqdm(train_loader):
            log_mel_pad = log_mel_pad.float().to(device)
            log_mel_reverb_noise_pad = log_mel_reverb_noise_pad.float().to(device)
            label = label.to(device)

            # --- denoiser step ---
            denoiser_optimizer.zero_grad()
            denoiser_output = denoiser_model(log_mel_reverb_noise_pad)
            reconstruction_loss = F.mse_loss(denoiser_output, log_mel_pad)
            reconstruction_loss.backward()
            denoiser_optimizer.step()
            train_reconstruction_loss.append(reconstruction_loss.item())

            # NOTE(review): three images are uploaded on every batch; log less often
            # if this slows training down.
            wandb.log({
                "Clear": wandb.Image(log_mel_pad[0].cpu().detach().numpy(), caption="Clear"),
                "Denoiser Output": wandb.Image(denoiser_output[0].cpu().detach().numpy(), caption="Denoised Output"),
                "Noise": wandb.Image(log_mel_reverb_noise_pad[0].cpu().detach().numpy(), caption="Noise")
            })

            # --- classifier step ---
            classifier_optimizer.zero_grad()
            classifier_output = classifier_model(denoiser_output.detach())  # block backprop into the denoiser
            # The classifier already returns log-probabilities; cross_entropy applies
            # log_softmax again, which leaves log-probabilities unchanged.
            classifier_loss = F.cross_entropy(classifier_output, label)
            classifier_loss.backward()
            classifier_optimizer.step()
            train_classifier_loss.append(classifier_loss.item())

        denoiser_scheduler.step()
        classifier_scheduler.step()

        train_reconstruction_loss = np.mean(train_reconstruction_loss)
        train_classifier_loss = np.mean(train_classifier_loss)

        val_reconstruction_loss, val_classifier_loss, f1 = validation(denoiser_model, classifier_model, val_loader, device)

        wandb.log({
            "Train Reconstruction Loss": train_reconstruction_loss,
            "Train Classification Loss": train_classifier_loss,
            "Valid Reconstruction Loss": val_reconstruction_loss,
            "Valid Classification Loss": val_classifier_loss,
            "F1 Score": f1,
            "Epoch": epoch
        })

        early_stopper(val_reconstruction_loss, f1, denoiser_model, classifier_model)

        if early_stopper.early_stop:
            print("Early stopping")
            break






def validation(denoiser_model, classifier_model, val_loader, device):
    """Evaluate both models on ``val_loader`` without updating them.

    Args:
        denoiser_model, classifier_model: Models to evaluate (switched to eval mode).
        val_loader: Yields ``(clean, degraded, label, _, _)`` batches.
        device: Device used for computation.

    Returns:
        tuple: ``(mean reconstruction MSE, mean classification loss, weighted F1)``,
        the two losses being averaged over batches.
    """
    denoiser_model.eval()
    classifier_model.eval()

    val_reconstruction_loss, val_classifier_loss = [], []
    preds, labels = [], []

    with torch.no_grad():
        # Pass the loader itself (not iter(loader)) so tqdm can show the total.
        for log_mel_pad, log_mel_reverb_noise_pad, label, _, _ in tqdm(val_loader):
            log_mel_pad = log_mel_pad.float().to(device)
            log_mel_reverb_noise_pad = log_mel_reverb_noise_pad.float().to(device)
            label = label.to(device)

            denoiser_output = denoiser_model(log_mel_reverb_noise_pad)
            reconstruction_loss = F.mse_loss(denoiser_output, log_mel_pad)
            val_reconstruction_loss.append(reconstruction_loss.item())

            classifier_output = classifier_model(denoiser_output.detach())  # block backprop into the denoiser
            classifier_loss = F.cross_entropy(classifier_output, label)
            val_classifier_loss.append(classifier_loss.item())

            pred = classifier_output.argmax(dim=1)

            preds.append(pred.cpu().numpy())
            labels.append(label.cpu().numpy())

    val_reconstruction_loss = np.mean(val_reconstruction_loss)
    val_classifier_loss = np.mean(val_classifier_loss)

    preds = np.concatenate(preds, axis=0)
    labels = np.concatenate(labels, axis=0)

    f1 = f1_score(labels, preds, average='weighted')

    return  val_reconstruction_loss, val_classifier_loss, f1

In [None]:
# Denoiser: DUNET with default arguments, optimized with AdamW at CONFIG.DENOISER_LR.
denoiser_model = DUNET()
denoiser_optimizer = torch.optim.AdamW(params=denoiser_model.parameters(), lr=CONFIG.DENOISER_LR)
# NOTE(review): the lambda ignores `epoch` and always returns 1.0 ** CONFIG.EPOCH == 1.0,
# so this scheduler keeps the learning rate constant. If a decay was intended, the base
# and exponent need fixing (e.g. `gamma ** epoch` with gamma < 1) -- confirm the intent.
denoiser_scheduler = torch.optim.lr_scheduler.LambdaLR(denoiser_optimizer, lr_lambda = lambda epoch: 1.0 ** CONFIG.EPOCH)

In [None]:
# Classifier: the adapted ResNet-50 wrapped in CustomClassifier, optimized with AdamW.
# NOTE(review): the variable is spelled `classifer_model` (missing "i"); the train() call
# below uses the same spelling, so rename both together or not at all.
classifer_model = CustomClassifier(backbone=resnet50)
classifier_optimizer = torch.optim.AdamW(params=classifer_model.parameters(), lr=CONFIG.CLASSIFIER_LR)
# NOTE(review): the lambda always returns 1.0 ** CONFIG.EPOCH == 1.0, so the learning
# rate never changes -- confirm whether a decay was intended.
classifier_scheduler = torch.optim.lr_scheduler.LambdaLR(classifier_optimizer, lr_lambda = lambda epoch: 1.0 ** CONFIG.EPOCH)

In [None]:
# Early-stopping tracker with default settings; train() picks it up through the global name `es`.
es = EarlyStopping()

In [None]:
# Run training; the best checkpoints are written by the early-stopping tracker.
train(denoiser_model, classifer_model, denoiser_optimizer, classifier_optimizer, denoiser_scheduler, classifier_scheduler, train_loader, val_loader, device)

In [None]:
# Rebuild the denoiser and restore saved weights for evaluation.
# NOTE(review): this file name differs from the one EarlyStopping.save_checkpoint writes
# ('dunet_denoiser_75_es10_N.pt'); confirm that loading an older run is intended.
ckpt_denoiser_model = DUNET()
# map_location='cpu' lets the checkpoint load on machines that lack the GPU it was saved
# from; test() moves the model to the target device afterwards.
ckpt_denoiser_model.load_state_dict(torch.load('./dunet_denoiser_50_es10.pt', map_location='cpu'))
ckpt_denoiser_model.eval()

In [None]:
# Rebuild the classifier and restore saved weights for evaluation.
# NOTE(review): `resnet50` is the same module object wrapped by the classifier trained
# above, so loading this state_dict also overwrites that model's weights.
# NOTE(review): this file name differs from the one EarlyStopping.save_checkpoint writes
# ('dunet_classifier_75_es10_N.pt'); confirm that loading an older run is intended.
ckpt_classifier_model = CustomClassifier(backbone=resnet50)
# map_location='cpu' lets the checkpoint load on machines that lack the GPU it was saved
# from; test() moves the model to the target device afterwards.
ckpt_classifier_model.load_state_dict(torch.load('./dunet_classifier_50_es10.pt', map_location='cpu'))
ckpt_classifier_model.eval()

In [None]:
def inverse_min_max_scaling(scaled_tensor, tensor_min, tensor_max, min_value=0.0, max_value=1.0):
    """Undo ``min_max_scaling``: map values from ``[min_value, max_value]`` back to
    ``[tensor_min, tensor_max]``.

    Args:
        scaled_tensor: Values previously scaled into ``[min_value, max_value]``.
        tensor_min, tensor_max: Extrema of the original data. May be torch tensors
            (on any device) or plain numbers / NumPy values.
        min_value, max_value: Bounds of the scaled range.

    Returns:
        The values expressed in the original range.
    """
    def _as_numpy(value):
        # Tensors must be detached and on the CPU before .numpy(); plain numbers
        # have no .numpy() at all, so wrap them instead.
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    tensor_min = _as_numpy(tensor_min)
    tensor_max = _as_numpy(tensor_max)
    original_tensor = (scaled_tensor - min_value) / (max_value - min_value)
    original_tensor = original_tensor * (tensor_max - tensor_min) + tensor_min

    return original_tensor

In [None]:
def test(best_denoiser_model, best_classifier_model, test_loader, device):
    """Evaluate the denoiser/classifier pair on the test set.

    Prints the weighted F1 score of the classifier and the mean per-sample MSE between
    the denoiser output and the reference spectrogram, and also returns both.

    Args:
        best_denoiser_model, best_classifier_model: Models to evaluate.
        test_loader: Yields ``(reference, input, label, input_min, input_max,
            denoise_output_path)`` batches.
        device: Device used for computation.

    Returns:
        tuple: ``(f1, mean_mse)``.
    """
    best_denoiser_model.to(device)
    best_classifier_model.to(device)
    best_denoiser_model.eval()
    best_classifier_model.eval()

    preds, labels = [], []
    mses = []

    with torch.no_grad():
        # tensor_min / tensor_max / denoise_output_path are only needed by the disabled
        # audio export (inverse scaling -> db_to_power -> mel_to_stft -> Griffin-Lim ->
        # sf.write to denoise_output_path), so they are ignored here.
        for log_mel_pad, log_mel_reverb_noise_pad, label, _, _, _ in tqdm(test_loader):
            log_mel_reverb_noise_pad = log_mel_reverb_noise_pad.float().to(device)

            denoiser_output = best_denoiser_model(log_mel_reverb_noise_pad)

            # Per-sample MSE against the reference spectrogram, computed on the CPU.
            denoised = denoiser_output.detach().cpu().numpy()
            reference = log_mel_pad.numpy()
            per_sample_mse = ((reference - denoised) ** 2).reshape(len(denoised), -1).mean(axis=1)
            mses.extend(per_sample_mse.tolist())

            classifier_output = best_classifier_model(denoiser_output.detach())  # block backprop into the denoiser
            pred = classifier_output.argmax(dim=1)

            preds.append(pred.cpu().numpy())
            # Labels never need to visit the device; they are only compared on the CPU.
            labels.append(label.numpy())

    preds = np.concatenate(preds, axis=0)
    labels = np.concatenate(labels, axis=0)

    f1 = f1_score(labels, preds, average='weighted')
    mean_mse = float(np.mean(mses))
    print(f'Test F1 Score: {f1:.4f}')
    print(f'Test MSE: {mean_mse:.4f}')

    return f1, mean_mse


In [None]:
# Evaluate the restored checkpoints on the test set.
test(ckpt_denoiser_model, ckpt_classifier_model, test_loader, device)