<a href="https://colab.research.google.com/github/jhChoi1997/EE488_AI_Convergence_Capstone_Design_Anomaly_Detection_2022spring/blob/main/EE488_DCASE2020_WaveNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!gdown https://drive.google.com/uc?id=1p0aANQlQRKqM9FGhkV3j2h55PJUOgEXg
!unzip valve.zip -d ./valve/

!gdown https://drive.google.com/uc?id=1hKmdy5bySo5JrZ9FjxVrrBp_5XhgWe2Y
!unzip valve.zip -d ./valve_test/

In [None]:
import os
import sys
import librosa
import librosa.core
import librosa.feature
import yaml
import glob
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from sklearn import metrics

In [None]:
# --- Paths ------------------------------------------------------------------
dataset_dir, test_dir = './valve', './valve_test'  # training / evaluation clips
model_dir = './model'                              # output directory for the model

# --- Log-mel feature extraction (forwarded to librosa.feature.melspectrogram) --
n_fft, hop_length = 2048, 512  # STFT window length and hop size
n_mels = 128                   # mel bands; also the WaveNet channel count
power = 2                      # spectrogram exponent; also sets the 20/power log scale

# --- WaveNet architecture ---------------------------------------------------
n_block = 5       # number of residual blocks; block i uses dilation 2**i
n_mul = 6         # channel multiplier inside each residual block
kernel_size = 3   # width of the dilated causal convolutions

# --- Training ---------------------------------------------------------------
EPOCHS, BATCH = 10, 32  # passes over the training set / clips per mini-batch

In [None]:
def file_load(wav_name):
  """Load an audio file at its native sample rate, keeping all channels.

  Args:
    wav_name: path of the audio file to read.

  Returns:
    The ``(samples, sample_rate)`` tuple produced by ``librosa.load``, or
    ``None`` if the file could not be read (a message is printed then).
  """
  try:
    # sr=None keeps the file's own sample rate; mono=False disables downmixing.
    return librosa.load(wav_name, sr=None, mono=False)
  except Exception:
    # ``Exception`` instead of a bare ``except:`` so that KeyboardInterrupt and
    # SystemExit still interrupt the notebook instead of being swallowed.
    print('file_broken or not exists!! : {}'.format(wav_name))
    return None
    

def file_list_generator(target_dir):
  """Return the absolute paths of every ``*.wav`` file in ``target_dir``, sorted.

  When nothing matches, a warning is printed and an empty list is returned.
  """
  wav_paths = sorted(glob.glob(os.path.abspath(f'{target_dir}/*.wav')))
  if not wav_paths:
    print('no_wav_file!!')
  return wav_paths


def file_to_log_mel(file_name, n_mels, n_fft, hop_length, power):
  """Load ``file_name`` and return its log-scaled mel spectrogram.

  The mel spectrogram (computed with the given ``n_mels``, ``n_fft``,
  ``hop_length`` and ``power``) is converted with ``20 / power * log10(.)``,
  i.e. a decibel scale for either a magnitude (power=1) or power (power=2)
  spectrogram. A machine epsilon is added first so silent bins do not produce
  ``-inf``.
  """
  signal, sample_rate = file_load(file_name)
  mel = librosa.feature.melspectrogram(
      y=signal, sr=sample_rate, n_fft=n_fft, hop_length=hop_length,
      n_mels=n_mels, power=power)
  db_scale = 20.0 / power
  return db_scale * np.log10(mel + sys.float_info.epsilon)


def list_to_dataset(file_list, n_mels, n_fft, hop_length, power):
  """Stack the log-mel spectrograms of several audio files into one array.

  Args:
    file_list: sequence of audio paths (a list or a 1-D numpy array).
    n_mels, n_fft, hop_length, power: forwarded to ``file_to_log_mel``.

  Returns:
    A float array of shape ``(len(file_list), rows, cols)`` where ``rows`` and
    ``cols`` are taken from the first file's spectrogram; every file must
    yield a spectrogram of that same shape. An empty ``file_list`` yields an
    empty array of shape ``(0, n_mels, 0)`` instead of failing.
  """
  # ``len`` rather than truthiness: ``file_list`` may be a numpy array, whose
  # truth value is ambiguous.
  if len(file_list) == 0:
    return np.zeros((0, n_mels, 0), float)

  dataset = None
  for idx, file_name in enumerate(tqdm(file_list)):
    log_mel = file_to_log_mel(file_name,
                              n_mels=n_mels,
                              n_fft=n_fft,
                              hop_length=hop_length,
                              power=power)
    if dataset is None:
      # Allocate the output once, sized from the first spectrogram.
      dataset = np.zeros((len(file_list), log_mel.shape[0], log_mel.shape[1]), float)
    dataset[idx, :, :] = log_mel

  return dataset

In [None]:
os.makedirs(model_dir, exist_ok=True)

# The last component of the (absolute) training directory names the machine type.
dataset_dir = os.path.abspath(dataset_dir)
machine_type = os.path.basename(dataset_dir)
# NOTE(review): model_file_path is not used by any cell shown here, so the
# trained model is never saved — confirm whether a torch.save call is missing.
model_file_path = f'{model_dir}/model_{machine_type}'

# Extract one log-mel spectrogram per training clip and stack them.
files = file_list_generator(dataset_dir)
train_data = list_to_dataset(
    files, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length, power=power)

In [None]:
# torch.Tensor copies the float64 numpy array into a tensor of the default
# float dtype (float32 unless changed).
train_dataset = torch.Tensor(train_data)
# Iterating a DataLoader over a bare tensor yields plain batch tensors, which is
# what train() expects. shuffle=True reshuffles every epoch; without it each
# epoch sees identical batches in the filename-sorted order of the files.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH, shuffle=True)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

In [None]:
print(train_dataset.size())  # sanity check: expected (num_clips, n_mels, num_frames)

In [None]:
class CausalConv1d(nn.Module):
    """1-D convolution whose output at step t depends only on inputs at steps <= t.

    ``nn.Conv1d`` zero-pads both sides by ``(kernel_size - 1) * dilation``; the
    surplus frames on the right (which would see future inputs) are then cut
    off, so the output has the same time length as the input.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: width of the convolution kernel.
        dilation: spacing between kernel taps.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1):
        super(CausalConv1d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.dilation = dilation

        self.conv1 = self.causal_conv(self.in_channels, self.out_channels, self.kernel_size, self.dilation)
        self.padding = self.conv1.padding[0]

    def causal_conv(self, in_channels, out_channels, kernel_size, dilation):
        """Build the underlying ``nn.Conv1d`` with the causal amount of padding."""
        pad = (kernel_size - 1) * dilation
        return nn.Conv1d(in_channels, out_channels, kernel_size, padding=pad, dilation=dilation)

    def forward(self, x):
        x = self.conv1(x)
        # Trim the right-hand padding frames. Guard padding == 0 (kernel_size 1):
        # ``x[:, :, :-0]`` would return an empty tensor.
        if self.padding > 0:
            x = x[:, :, :-self.padding]
        return x


class ResidualBlock(nn.Module):
    """Gated block of the WaveNet model producing a skip and a residual output.

    The input (``n_channel * n_mul`` channels) is normalized and sent through
    two dilated causal convolutions combined as ``sigmoid(a) * tanh(b)``; two
    normalized 1x1 convolutions then produce the skip and residual outputs.

    Args:
        n_channel: number of channels of the skip output.
        n_mul: channel multiplier; the block expects and works on
            ``n_channel * n_mul`` channels.
        kernel_size: kernel width of the causal convolutions.
        dilation_rate: dilation of the causal convolutions.
    """

    def __init__(self, n_channel, n_mul, kernel_size, dilation_rate):
        super(ResidualBlock, self).__init__()
        self.n_channel = n_channel
        self.n_mul = n_mul
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_filter = self.n_channel * self.n_mul

        self.sigmoid_group_norm = nn.GroupNorm(1, self.n_filter)
        self.sigmoid_conv = CausalConv1d(self.n_filter, self.n_filter, self.kernel_size, self.dilation_rate)
        self.tanh_group_norm = nn.GroupNorm(1, self.n_filter)
        self.tanh_conv = CausalConv1d(self.n_filter, self.n_filter, self.kernel_size, self.dilation_rate)

        # No per-layer ``.to(device)``: placement is left to ``module.to(...)`` on
        # the owner, so the class no longer depends on a notebook-global ``device``.
        self.skip_group_norm = nn.GroupNorm(1, self.n_filter)
        self.skip_conv = nn.Conv1d(self.n_filter, self.n_channel, 1)
        self.residual_group_norm = nn.GroupNorm(1, self.n_filter)
        self.residual_conv = nn.Conv1d(self.n_filter, self.n_filter, 1)

    def forward(self, x):
        """Return ``(skip, residual)``.

        ``skip`` has ``n_channel`` channels and ``residual`` has
        ``n_channel * n_mul`` channels; the 1x1 convolutions do not change the
        time length of the gated activation.
        """
        # Functional activations instead of building new nn.Sigmoid/nn.Tanh
        # modules on every call; the results are identical.
        gate = torch.sigmoid(self.sigmoid_conv(self.sigmoid_group_norm(x)))
        filt = torch.tanh(self.tanh_conv(self.tanh_group_norm(x)))
        x = gate * filt

        skip = self.skip_conv(self.skip_group_norm(x))
        # NOTE(review): the block input is not added back here, so this is not an
        # identity-shortcut residual as in the original WaveNet — confirm intended.
        residual = self.residual_conv(self.residual_group_norm(x))

        return skip, residual


class WaveNet(nn.Module):
    """WaveNet-style next-frame predictor over multi-channel sequences.

    Args:
        n_block: number of residual blocks; block ``i`` uses dilation ``2 ** i``.
        n_channel: number of input/output channels (the mel bands in this notebook).
        n_mul: channel multiplier used inside the residual blocks.
        kernel_size: kernel width of the dilated causal convolutions.

    ``forward`` takes a batched input ``(batch, n_channel, frames)`` and returns
    ``(batch, n_channel, frames - get_receptive_field())``. Output column ``i``
    is the prediction for input frame ``i + get_receptive_field()``.
    """

    def __init__(self, n_block, n_channel, n_mul, kernel_size):
        super(WaveNet, self).__init__()

        self.n_block = n_block
        self.n_channel = n_channel
        self.n_mul = n_mul
        self.kernel_size = kernel_size
        self.n_filter = self.n_channel * self.n_mul

        self.group_norm1 = nn.GroupNorm(1, self.n_channel)
        self.conv1 = nn.Conv1d(self.n_channel, self.n_filter, 1)

        # nn.ModuleList (not a plain list) registers the blocks as submodules, so
        # their weights appear in parameters() / state_dict(), get updated by the
        # optimizer, and follow model.to(device).
        self.residual_blocks = nn.ModuleList(
            [ResidualBlock(self.n_channel, self.n_mul, self.kernel_size, 2 ** i)
             for i in range(self.n_block)])

        self.relu1 = nn.ReLU()

        self.group_norm2 = nn.GroupNorm(1, self.n_channel)
        self.conv2 = nn.Conv1d(self.n_channel, self.n_channel, 1)
        self.group_norm3 = nn.GroupNorm(1, self.n_channel)
        self.conv3 = nn.Conv1d(self.n_channel, self.n_channel, 1)

    def forward(self, x):
        x = self.group_norm1(x)
        x = self.conv1(x)

        # Each block feeds its residual output to the next; skip outputs are summed.
        skip_connections = []
        for rb in self.residual_blocks:
            skip, x = rb(x)
            skip_connections.append(skip)
        skip_x = torch.stack(skip_connections).sum(dim=0)

        x = self.relu1(skip_x)
        x = self.group_norm2(x)
        x = self.conv2(x)
        x = self.group_norm3(x)
        x = self.conv3(x)
        # Keep columns rf-1 .. T-2: each sees a full receptive field and predicts
        # the following frame, so the result aligns with input frames rf .. T-1.
        output = x[:, :, self.get_receptive_field() - 1:-1]

        return output

    def get_receptive_field(self):
        """Return the receptive field (in frames) of the stacked causal convolutions.

        The recurrence evaluates to ``1 + (kernel_size - 1) * (2 ** n_block - 1)``,
        matching dilations ``1, 2, ..., 2 ** (n_block - 1)``.
        """
        receptive_field = 1
        for _ in range(self.n_block):
            receptive_field = receptive_field * 2 + self.kernel_size - 2
        return receptive_field


In [None]:
# Build the network on the selected device and show its layer structure.
model = WaveNet(n_block=n_block, n_channel=n_mels, n_mul=n_mul,
                kernel_size=kernel_size).to(device)
print(model)

In [None]:
# Training objective: mean squared error between predicted and actual frames.
loss_fn = nn.MSELoss()
# Adam only updates the parameters registered on ``model``.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

In [None]:
def train(dataloader, model, loss_fn, optimizer):
  """Run one training epoch of next-frame prediction.

  Each batch is moved to ``device``; the model output is compared with the
  batch frames from ``model.get_receptive_field()`` onward using ``loss_fn``
  and one optimizer step is taken. Loss and progress are printed every 30th
  batch.
  """
  n_samples = len(dataloader.dataset)
  rf = model.get_receptive_field()
  for step, batch_x in enumerate(dataloader):
    batch_x = batch_x.to(device)
    loss = loss_fn(model(batch_x), batch_x[:, :, rf:])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 30 != 0:
      continue
    loss_value = loss.item()
    seen = step * len(batch_x)
    print(f"loss: {loss_value:>7f}  [{seen:>5d}/{n_samples:>5d}]")

In [None]:
# Train for EPOCHS full passes over the training DataLoader.
for epoch in range(1, EPOCHS + 1):
  print(f"Epoch {epoch}\n-------------------------------")
  train(train_dataloader, model, loss_fn, optimizer)

In [None]:
def get_anomaly_score(true, pred):
  """Return the mean squared error between ``true`` and ``pred`` as a 0-d tensor.

  A larger error means the model predicted the clip worse; the evaluation
  treats higher scores as more anomalous.
  """
  return nn.functional.mse_loss(true, pred)


In [None]:
# Evaluation: normal clips are labeled 0, anomalous clips 1.
normal_files = sorted(glob.glob('{dir}/normal_*'.format(dir=test_dir)))
anomaly_files = sorted(glob.glob('{dir}/anomaly_*'.format(dir=test_dir)))

normal_labels = np.zeros(len(normal_files))
anomaly_labels = np.ones(len(anomaly_files))

# Plain list concatenation keeps the paths as Python strings.
test_files = normal_files + anomaly_files
y_true = np.concatenate((normal_labels, anomaly_labels), axis=0)
y_pred = [0.0] * len(test_files)

test_dataset = list_to_dataset(test_files, n_mels, n_fft, hop_length, power)

receptive_field = model.get_receptive_field()
model.eval()
with torch.no_grad():  # inference only: no autograd graph is needed
  for file_idx, log_mel in enumerate(test_dataset):
    # numpy (n_mels, frames) -> float32 tensor (1, n_mels, frames); the model's
    # GroupNorm layers require a batch dimension.
    data = torch.tensor(log_mel, dtype=torch.float32, device=device).unsqueeze(0)
    output = model(data)

    # The model predicts frames receptive_field .. end, so score against those.
    target = data[:, :, receptive_field:]
    y_pred[file_idx] = get_anomaly_score(target, output).item()

auc = metrics.roc_auc_score(y_true, y_pred)

print(auc)