#### Imports and setup

In [104]:
from trams.utils import change_cwd_for_jupyter

change_cwd_for_jupyter()

from trams.datamodule import TramsDataModule
from trams.dataset import print_labels_statistics, print_metadata_statistics

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


#### Load wav files into dataset and apply processing steps

In [105]:
dm = TramsDataModule(batch_size=16, validation_split=0.1)
dataset = dm.prepare_data()
dataset

DatasetDict({
    train: Dataset({
        features: ['audio', 'label', 'sample_rate', 'num_frames', 'num_channels', 'bits_per_sample', 'label_name', 'path', 'spectrogram'],
        num_rows: 2696
    })
    validation: Dataset({
        features: ['audio', 'label', 'sample_rate', 'num_frames', 'num_channels', 'bits_per_sample', 'label_name', 'path', 'spectrogram'],
        num_rows: 300
    })
})

#### Check training data labels and metadata statistics

We can see that labels are not distributed evenly, which may potentially cause accuracy underperformance in several classes.

In [106]:
print_labels_statistics(dataset["train"])

Unnamed: 0_level_0,Unnamed: 1_level_0,count
label_name,label,Unnamed: 2_level_1
Accelerating_1_New,0,444
Accelerating_2_CKD_Long,1,152
Accelerating_3_CKD_Short,2,66
Accelerating_4_Old,3,369
Braking_1_New,4,343
Braking_2_CKD_Long,5,129
Braking_3_CKD_Short,6,56
Braking_4_Old,7,392
Negative,8,745


As expected, the number of channels, bits per sample and sample rate are all constant throughout the dataset.

In [107]:
print_metadata_statistics(dataset["train"])


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,count
sample_rate,bits_per_sample,num_channels,Unnamed: 3_level_1
22050,32,1,2696


#### Plot histogram of audio lengths in seconds in order to decide on the truncation threshold

We can see a bimodal distribution of audio lengths due to the double size of negative samples. We will take 4000 ms as the truncation limit and pad zeros randomly before and after signal.

In [108]:
dataset["train"].set_format("pandas")
sample_rate = int(dataset["train"][0]["sample_rate"])
(dataset["train"]["num_frames"] / sample_rate).plot.hist(backend="plotly")


#### Check Mel Spectrogram shape

In [109]:
dataset = dataset.with_format("torch")
dataset["train"][0]["spectrogram"].shape

torch.Size([64, 173])

#### Predict trams on test wav file

In [110]:
from pathlib import Path

import torchaudio
import torch
import pandas as pd
from torchaudio.backend.common import AudioMetaData
from torchaudio import transforms


from trams.config import (
    RAW_DATA_DIR_TEST,
    NUM_FFT,
    NUM_MELS,
    MAX_DB,
    MAX_LENGTH_SECS,
    ONE_TENTH_SEC,
    TRAINDED_MODEL_PATH,
)
from trams.model import TramsAudioClassifier, ModelConfig


def predict_trams_from_wav(input_wav: Path, output_csv: Path):
    audio, sample_rate = torchaudio.load(input_wav)
    audio_metadata: AudioMetaData = torchaudio.info(input_wav)
    if sample_rate != 22050 or audio_metadata.num_channels != 1:
        raise ValueError("Current model only supports sample rate of 22050 and mono channel.")

    num_frames = audio_metadata.num_frames
    num_seconds = num_frames / sample_rate
    print(f"Loaded {input_wav.parts[-1]} file spanning {num_seconds} seconds.")

    mel_spectrogram = transforms.MelSpectrogram(sample_rate, n_fft=NUM_FFT, n_mels=NUM_MELS)
    amplitude_transformer = transforms.AmplitudeToDB(top_db=MAX_DB)
    
    slice_frames = int(MAX_LENGTH_SECS * sample_rate)
    offset_frames = int(ONE_TENTH_SEC * sample_rate)
    slices = int((num_frames - slice_frames) / offset_frames)

    checkpoint  = torch.load(TRAINDED_MODEL_PATH)
    model = TramsAudioClassifier(ModelConfig())
    model.load_state_dict(checkpoint["state_dict"])
    
    predictions = []
    with torch.no_grad():
        for i in range(slices):
            slice = audio[:, i * offset_frames : i * offset_frames + slice_frames]
            spectrogram = amplitude_transformer(mel_spectrogram(slice))
            output: torch.Tensor = model(spectrogram)
            prediction = torch.argmax(output.softmax(dim=1), dim=1).item()
            predictions.append(prediction)
    
    print(pd.Series(predictions).value_counts())
            
            

wav_file = RAW_DATA_DIR_TEST / "tram-2018-12-07-15-32-08.wav"
predict_trams_from_wav(wav_file, Path.cwd())


Loaded tram-2018-12-07-15-32-08.wav file spanning 300.0 seconds.
8    2960
dtype: int64
