#### Imports and setup

In [6]:
from pathlib import Path

from trams.utils import change_cwd_for_jupyter

change_cwd_for_jupyter()

from trams.datamodule import TramsDataModule
from trams.dataset import print_labels_statistics, print_metadata_statistics
from trams.predict import predict_trams_from_wav, validate
from trams.config import RAW_DATA_DIR_TEST

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


#### Load WAV files into a dataset and apply the processing steps

In [7]:
# Preprocessing / data-loading settings. The `validate` call at the end of the
# notebook uses the same validation split, max length and SNR values.
BATCH_SIZE = 16
VALIDATION_SPLIT = 0.1  # fraction held out for validation (300 of 2996 rows below)
MAX_LENGTH_SECS = 4  # truncation threshold, chosen from the audio-length histogram
SNR = 5  # NOTE(review): presumably signal-to-noise ratio of added noise — confirm in TramsDataModule

dm = TramsDataModule(
    batch_size=BATCH_SIZE,
    validation_split=VALIDATION_SPLIT,
    max_length_secs=MAX_LENGTH_SECS,
    snr=SNR,
)
dataset = dm.prepare_data()
dataset

DatasetDict({
    train: Dataset({
        features: ['audio', 'label', 'sample_rate', 'num_frames', 'num_channels', 'bits_per_sample', 'label_name', 'path', 'spectrogram'],
        num_rows: 2696
    })
    validation: Dataset({
        features: ['audio', 'label', 'sample_rate', 'num_frames', 'num_channels', 'bits_per_sample', 'label_name', 'path', 'spectrogram'],
        num_rows: 300
    })
})

#### Check training data labels and metadata statistics

The labels are not evenly distributed (e.g. 745 `Negative` samples but only 56 `Braking_3_CKD_Short` samples), which may lower accuracy on the under-represented classes.

In [13]:
# Per-label sample counts for the training split.
train_split = dataset["train"]
print_labels_statistics(train_split)

Unnamed: 0_level_0,Unnamed: 1_level_0,count
label_name,label,Unnamed: 2_level_1
Accelerating_1_New,0,444
Accelerating_2_CKD_Long,1,152
Accelerating_3_CKD_Short,2,66
Accelerating_4_Old,3,369
Braking_1_New,4,343
Braking_2_CKD_Long,5,129
Braking_3_CKD_Short,6,56
Braking_4_Old,7,392
Negative,8,745


As expected, the sample rate (22,050 Hz), bits per sample (32) and number of channels (1) are constant throughout the training split.

In [12]:
# Counts of each (sample_rate, bits_per_sample, num_channels) combination in the training split.
train_split = dataset["train"]
print_metadata_statistics(train_split)


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,count
sample_rate,bits_per_sample,num_channels,Unnamed: 3_level_1
22050,32,1,2696


#### Plot a histogram of audio lengths (in seconds) to choose the truncation threshold

The distribution of audio lengths is bimodal because the negative samples are about twice as long as the others. We therefore truncate at 4000 ms (`max_length_secs=4`) and pad shorter signals with zeros, split randomly between the start and the end of the signal.

In [8]:
# `with_format` returns a new pandas-formatted view, so `dataset["train"]`
# keeps its current format for the cells that follow.
train_pd = dataset["train"].with_format("pandas")

# The sample rate is constant across the split (see the metadata statistics above),
# so the first value is representative. `.iloc[0]` avoids calling int() on a Series
# and only reads the `sample_rate` column rather than decoding a full row.
sample_rate = int(train_pd["sample_rate"].iloc[0])
audio_lengths_secs = train_pd["num_frames"] / sample_rate

fig = audio_lengths_secs.plot.hist(backend="plotly")
fig.update_layout(
    title="Distribution of audio lengths (train split)",
    xaxis_title="Length [s]",
    yaxis_title="Count",
    showlegend=False,
)
fig


#### Check Mel Spectrogram shape

In [9]:
# Return torch tensors from every split, then inspect one training spectrogram.
dataset = dataset.with_format(type="torch")
first_train_sample = dataset["train"][0]
# NOTE(review): presumably (n_mels, n_frames), i.e. 64 mel bins x 173 frames — confirm.
first_train_sample["spectrogram"].shape

torch.Size([64, 173])

#### Predict trams on a test WAV file

In [11]:
# Output target lives in the working directory set by change_cwd_for_jupyter().
output_csv = Path.cwd().joinpath("output.csv")
# A single test recording to run inference on.
wav_file = RAW_DATA_DIR_TEST / "tram-2018-12-07-15-32-08.wav"
# NOTE(review): presumably writes the per-window predictions to `output_csv` — confirm.
predict_trams_from_wav(wav_file, output_csv)


Loaded tram-2018-12-07-15-32-08.wav file spanning 300.0 seconds.
8    2960
dtype: int64


In [20]:
# Evaluate on the validation split using the same preprocessing settings
# that were passed to TramsDataModule earlier in the notebook.
validate(
    validation_split=0.1,
    use_cache=True,  # NOTE(review): presumably reuses cached preprocessed data — confirm
    max_length_secs=4,
    snr=5,
)


Accuracy: 0.917, Total items: 300
