In [1]:
import os

# Run from the repository root so relative paths such as data/dcase2023t2 resolve.
# Guarded so that re-running this cell does not keep climbing the directory tree.
# NOTE(review): assumes the repo root is the directory that contains the
# `acoustic_anomaly_detection` package folder -- confirm for your layout.
if not os.path.isdir("acoustic_anomaly_detection"):
    os.chdir("..")
os.getcwd()

/Users/bghorvath/git/acoustic-anomaly-detection


In [2]:
import os
from tqdm import tqdm
import librosa
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
import seaborn as sns

from acoustic_anomaly_detection.utils import get_attributes

In [3]:
# Feature-extraction configuration: every tunable value lives here.
FRAME_SIZE = 1024  # window length (samples) for the amplitude envelope / RMS
N_MELS = 128  # number of mel bands for the spectral envelope
N_MFCC = 40  # not referenced by the cells shown here
N_FTT = 1024  # FFT size (name kept as-is because later cells use it)
HOP_LENGTH = 512  # hop between consecutive frames, in samples
POWER = 2.0  # exponent passed as `power` to librosa's melspectrogram
SR = 16000  # target sample rate in Hz
DURATION = 10  # clip length in seconds; clips are padded/trimmed to this
# Frames per clip, derived from DURATION (previously a hard-coded 10) so the two
# cannot drift apart. With the values above this is 313.
# NOTE(review): librosa's centred framing yields 1 + n // HOP_LENGTH frames, which
# equals ceil(n / HOP_LENGTH) only when n is not a multiple of HOP_LENGTH.
FRAMES = int(np.ceil(SR * DURATION / HOP_LENGTH))

### Metric functions computed for every audio clip

In [4]:
def get_amplitude_envelope(signal, _):
    """Return the per-frame maximum sample value of ``signal``.

    Frames start every HOP_LENGTH samples and span up to FRAME_SIZE samples
    (frames near the end are truncated), giving ceil(len(signal) / HOP_LENGTH)
    values.

    Note: this is the raw maximum, not the maximum absolute value, so it follows
    the positive peaks of the waveform only.

    The second argument is ignored. It exists so that every metric function can
    be called as ``metric_func(signal, sr)``.
    """
    # np.max reduces each frame in C; the built-in max iterated sample by sample.
    return np.array(
        [np.max(signal[start:start + FRAME_SIZE]) for start in range(0, len(signal), HOP_LENGTH)]
    )

def get_rms(signal, sr=None):
    """Return the frame-wise root-mean-square value of ``signal`` as a 1-D array.

    Uses FRAME_SIZE-sample frames with a HOP_LENGTH hop.

    ``sr`` is accepted and ignored so this function matches the
    ``metric_func(signal, sr)`` convention used by ``calculate_stats``. It is
    optional, so existing ``get_rms(signal)`` calls keep working.
    """
    return librosa.feature.rms(y=signal, frame_length=FRAME_SIZE, hop_length=HOP_LENGTH)[0]

def get_spectral_centroid(signal, sr):
    """Spectral centroid of ``signal`` for every STFT frame.

    Uses N_FTT-point FFTs with a HOP_LENGTH hop; size-1 axes of librosa's output
    are squeezed away so the caller gets a flat per-frame array.
    """
    centroid = librosa.feature.spectral_centroid(
        y=signal,
        sr=sr,
        n_fft=N_FTT,
        hop_length=HOP_LENGTH,
    )
    return np.squeeze(centroid)

def get_spectral_bandwidth(signal, sr):
    """Spectral bandwidth of ``signal`` for every STFT frame.

    Uses N_FTT-point FFTs with a HOP_LENGTH hop; size-1 axes of librosa's output
    are squeezed away so the caller gets a flat per-frame array.
    """
    bandwidth = librosa.feature.spectral_bandwidth(
        y=signal,
        sr=sr,
        n_fft=N_FTT,
        hop_length=HOP_LENGTH,
    )
    return np.squeeze(bandwidth)

def get_spectral_envelope(signal, sr):
    """Time-averaged mel spectrogram of ``signal``: one value per mel band (N_MELS).

    The mel spectrogram uses N_FTT-point FFTs, a HOP_LENGTH hop and magnitudes
    raised to POWER before averaging.
    """
    mel_spec = librosa.feature.melspectrogram(
        y=signal,
        sr=sr,
        n_fft=N_FTT,
        hop_length=HOP_LENGTH,
        n_mels=N_MELS,
        power=POWER,
    )
    # Axis 1 is the frame axis; averaging over it leaves the per-band envelope.
    return np.mean(mel_spec, axis=1)

In [5]:
# Metric functions grouped by the axis their output is indexed along:
#   "time"           -> one value per frame
#   "time_frequency" -> one frequency value per frame
#   "frequency"      -> one value per mel band
# calculate_stats uses the group name to decide the expected output length.
metrics_dict = dict(
    time=dict(amplitude_envelope=get_amplitude_envelope),
    time_frequency=dict(
        spectral_centroid=get_spectral_centroid,
        spectral_bandwidth=get_spectral_bandwidth,
    ),
    frequency=dict(spectral_envelope=get_spectral_envelope),
)

In [6]:
def _load_fixed_length(file):
    """Load ``file`` resampled to SR Hz and pad/trim it to DURATION * SR samples.

    Returns:
        (signal, sr), where ``sr`` is the sample rate reported by librosa.load.
    """
    signal, sr = librosa.load(file, sr=SR)
    return librosa.util.fix_length(signal, size=DURATION * SR), sr


def calculate_stats(file_list):
    """Compute every metric in ``metrics_dict`` for each audio file.

    Each file is loaded at SR Hz and padded/trimmed to DURATION seconds, so every
    metric yields a fixed-length vector: N_MELS values for the "frequency" group,
    FRAMES values otherwise.

    Args:
        file_list: iterable of audio file paths.

    Returns:
        dict mapping metric name -> float64 array of shape (n_files, length), one
        row per file in ``file_list`` order. Empty dict if ``file_list`` is empty.

    Raises:
        ValueError: if a metric returns a vector of unexpected length.
    """
    rows = defaultdict(list)  # metric name -> list of 1-D per-file vectors
    for file in file_list:
        signal, sr = _load_fixed_length(file)
        for group, group_metrics in metrics_dict.items():
            # "frequency" metrics are indexed by mel band, all others by frame.
            expected_len = N_MELS if group == "frequency" else FRAMES
            for metric_name, metric_func in group_metrics.items():
                value = np.asarray(metric_func(signal, sr)).reshape(-1)
                if value.shape[0] != expected_len:
                    raise ValueError(
                        f"{metric_name} returned {value.shape[0]} values for {file}, "
                        f"expected {expected_len}"
                    )
                rows[metric_name].append(value)
    # Stack once at the end instead of re-allocating with np.vstack for every file
    # (quadratic copying); cast to float64 to keep the previous result dtype.
    return {name: np.vstack(values).astype(np.float64) for name, values in rows.items()}

## Compare metrics between source and target domains (train split only)

In [7]:
data_dir = os.path.join("data", "dcase2023t2")
plot_dir = os.path.join("results", "plots", "feature_exploration")
os.makedirs(plot_dir, exist_ok=True)

# Per-file metrics for the training split, grouped by machine type and domain.
# Keys are "<dev|eval>_<machine_type>_<domain>".
train_test = "train"
machine_stats = {}
for dev_eval in ["dev", "eval"]:
    split_dir = os.path.join(data_dir, dev_eval)
    # Only sub-directories are machine types (skips stray files such as .DS_Store).
    machine_types = sorted(
        entry for entry in os.listdir(split_dir)
        if os.path.isdir(os.path.join(split_dir, entry))
    )
    for machine_type in tqdm(machine_types):
        data_path = os.path.join(split_dir, machine_type, train_test)
        # List the directory and parse each file's attributes once, bucketing by
        # domain, instead of repeating the scan for every domain.
        files_by_domain = {"source": [], "target": []}
        for file in sorted(os.listdir(data_path)):
            file_path = os.path.join(data_path, file)
            attr_domain = get_attributes(file_path)["domain"]
            if attr_domain in files_by_domain:
                files_by_domain[attr_domain].append(file_path)
        for domain, file_list in files_by_domain.items():
            machine_stats[dev_eval + "_" + machine_type + "_" + domain] = calculate_stats(file_list)

100%|██████████| 7/7 [02:24<00:00, 20.65s/it]
100%|██████████| 7/7 [02:22<00:00, 20.30s/it]


### Aggregate statistics: mean metric profile over all clips of each machine type

In [18]:
# Average each metric over files: agg_stats[metric][machine_type][domain] is a
# 1-D mean profile. dev and eval are merged under the machine type name.
agg_stats = defaultdict(lambda: defaultdict(dict))
for machine, stats in machine_stats.items():
    # Keys look like "<dev|eval>_<machine_type>_<domain>". Split from both ends so
    # an underscore inside the machine type cannot break the unpacking.
    dev_eval, rest = machine.split("_", 1)
    machine_type, domain = rest.rsplit("_", 1)
    for metric, value in stats.items():
        if domain in agg_stats[metric][machine_type]:
            # Guard against silently overwriting dev stats with eval stats.
            raise ValueError(
                f"machine type {machine_type!r} appears in more than one split; "
                f"its {domain} {metric} stats would be overwritten"
            )
        agg_stats[metric][machine_type][domain] = value.mean(axis=0)

### Plot mean metric profiles for each machine type

In [24]:
# One scatter plot per (metric, machine type) comparing the mean source and
# target profiles, saved under <plot_dir>/domain/.
domain_plot_dir = os.path.join(plot_dir, "domain")
os.makedirs(domain_plot_dir, exist_ok=True)  # savefig does not create missing folders
for metric, metric_machine_stats in agg_stats.items():
    xlabel = "Mel bins" if metric in metrics_dict["frequency"] else "Frames"
    if metric in metrics_dict["time_frequency"]:
        ylabel = "Frequency"
    elif metric in metrics_dict["time"]:
        ylabel = "Amplitude"  # time-domain metrics are waveform sample values
    else:
        ylabel = "Power"
    for machine_type, machine_type_stats in metric_machine_stats.items():
        fig, ax = plt.subplots(figsize=(5, 5))
        # Plot whichever groups exist; a group with no files has no stats.
        for group in ("source", "target"):
            if group in machine_type_stats:
                profile = machine_type_stats[group]
                ax.scatter(range(profile.shape[0]), profile, label=group, s=10)
        ax.set(title=metric.replace("_", " ").title(), xlabel=xlabel, ylabel=ylabel)
        ax.legend()
        fig.tight_layout()
        fig.savefig(os.path.join(domain_plot_dir, "domain_" + metric + "_" + machine_type + ".png"))
        plt.close(fig)

## Compare metrics between normal and anomalous clips (test split, source domain only)

In [26]:
data_dir = os.path.join("data", "dcase2023t2")
plot_dir = os.path.join("results", "plots", "feature_exploration")
os.makedirs(plot_dir, exist_ok=True)

# Per-file metrics for the test split, restricted to one domain and grouped by
# machine type and label. Keys are "<dev|eval>_<machine_type>_<label>".
train_test = "test"
domain = "source"
machine_stats = {}
for dev_eval in ["dev", "eval"]:
    split_dir = os.path.join(data_dir, dev_eval)
    # Only sub-directories are machine types (skips stray files such as .DS_Store).
    machine_types = sorted(
        entry for entry in os.listdir(split_dir)
        if os.path.isdir(os.path.join(split_dir, entry))
    )
    for machine_type in tqdm(machine_types):
        data_path = os.path.join(split_dir, machine_type, train_test)
        # List the directory and parse each file's attributes once, bucketing by
        # label, instead of repeating the scan for every label.
        files_by_label = {"normal": [], "anomaly": []}
        for file in sorted(os.listdir(data_path)):
            file_path = os.path.join(data_path, file)
            attributes = get_attributes(file_path)
            if attributes["domain"] != domain:
                continue
            if attributes["label"] in files_by_label:
                files_by_label[attributes["label"]].append(file_path)
        for label, file_list in files_by_label.items():
            machine_stats[dev_eval + "_" + machine_type + "_" + label] = calculate_stats(file_list)

100%|██████████| 7/7 [00:14<00:00,  2.12s/it]
100%|██████████| 7/7 [00:19<00:00,  2.72s/it]


In [27]:
# Average each metric over files: agg_stats[metric][machine_type][label] is a
# 1-D mean profile. dev and eval are merged under the machine type name.
agg_stats = defaultdict(lambda: defaultdict(dict))
for machine, stats in machine_stats.items():
    # Keys look like "<dev|eval>_<machine_type>_<label>". Split from both ends so
    # an underscore inside the machine type cannot break the unpacking.
    dev_eval, rest = machine.split("_", 1)
    machine_type, label = rest.rsplit("_", 1)
    for metric, value in stats.items():
        if label in agg_stats[metric][machine_type]:
            # Guard against silently overwriting dev stats with eval stats.
            raise ValueError(
                f"machine type {machine_type!r} appears in more than one split; "
                f"its {label} {metric} stats would be overwritten"
            )
        agg_stats[metric][machine_type][label] = value.mean(axis=0)

In [28]:
# One scatter plot per (metric, machine type) comparing the mean normal and
# anomaly profiles, saved under <plot_dir>/label/.
label_plot_dir = os.path.join(plot_dir, "label")
os.makedirs(label_plot_dir, exist_ok=True)  # savefig does not create missing folders
for metric, metric_machine_stats in agg_stats.items():
    xlabel = "Mel bins" if metric in metrics_dict["frequency"] else "Frames"
    if metric in metrics_dict["time_frequency"]:
        ylabel = "Frequency"
    elif metric in metrics_dict["time"]:
        ylabel = "Amplitude"  # time-domain metrics are waveform sample values
    else:
        ylabel = "Power"
    for machine_type, machine_type_stats in metric_machine_stats.items():
        fig, ax = plt.subplots(figsize=(5, 5))
        # Plot whichever groups exist; a group with no files has no stats.
        for group in ("normal", "anomaly"):
            if group in machine_type_stats:
                profile = machine_type_stats[group]
                ax.scatter(range(profile.shape[0]), profile, label=group, s=10)
        ax.set(title=metric.replace("_", " ").title(), xlabel=xlabel, ylabel=ylabel)
        ax.legend()
        fig.tight_layout()
        fig.savefig(os.path.join(label_plot_dir, "label_" + metric + "_" + machine_type + ".png"))
        plt.close(fig)