# Spectrogram-based resnet digit classifier using fastai

In [1]:
%matplotlib inline
from pathlib import Path
import librosa
from librosa import display
import IPython.display
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import multiprocessing as mp
import matplotlib.pyplot as plt

In [2]:
# Dataset root directory and its recordings sub-directory
DATA = Path("../../data/words/")
RECORDINGS = DATA.joinpath("audio-recordings")

In [3]:
def load_wav(filename):
    """Read an audio file through librosa.

    The file is loaded with ``sr=None`` (no target sampling rate is imposed)
    and ``mono=True``.

    Returns:
        The ``(samples, sample_rate)`` pair produced by ``librosa.core.load``.
    """
    samples, sample_rate = librosa.core.load(filename, sr=None, mono=True)
    return samples, sample_rate

In [4]:
# Label table; later cells read its "filename", "label" and "valid<N>" columns.
LABELS_CSV = DATA.joinpath("labels.csv")
labels_df = pd.read_csv(LABELS_CSV)

In [5]:
def save_spectrogram(wav_path, save_path):
    """Render the mel spectrogram of one recording and save it as an image.

    The audio is read with ``load_wav``, trimmed with
    ``librosa.effects.trim(top_db=30)``, turned into a mel power spectrogram
    (``n_fft=2048``, ``hop_length=256``), converted to dB and drawn with
    ``specshow`` on axes whose decorations are switched off.

    Args:
        wav_path: Path of the audio file to read.
        save_path: Destination of the image file (passed to ``savefig``).
    """
    data, sr = load_wav(wav_path)
    trimmed, _ = librosa.effects.trim(data, top_db=30)
    # Pass audio and sampling rate by keyword: newer librosa releases make
    # these arguments keyword-only, and keywords work on older releases too.
    spec = librosa.feature.melspectrogram(y=trimmed, sr=sr, n_fft=2048, hop_length=256)

    # Use a dedicated figure and always close it, so figures do not pile up
    # (or leak state between files) when thousands of files are rendered.
    fig, ax = plt.subplots()
    try:
        librosa.display.specshow(librosa.core.power_to_db(spec), ax=ax)
        ax.set_axis_off()
        fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
    finally:
        plt.close(fig)

In [6]:
def process_file(fname):
    """Create the spectrogram image for one file of the label table.

    Reads ``DATA/processed/<fname>`` and writes the image into
    ``DATA/spectrograms`` under the same name with ".wav" replaced by ".png".
    """
    png_name = fname.replace(".wav", ".png")
    source = DATA / "processed" / fname
    target = DATA / "spectrograms" / png_name
    save_spectrogram(source, target)

In [7]:
# Start from an empty spectrogram directory: create it if it is missing, then
# delete images left over from a previous run. This uses pathlib instead of a
# shell `rm`, so the cell also succeeds when the directory or the PNG files do
# not exist yet, and it follows DATA rather than a second hard-coded path.
spectrogram_dir = DATA / "spectrograms"
spectrogram_dir.mkdir(parents=True, exist_ok=True)
for stale_png in spectrogram_dir.glob("*.png"):
    stale_png.unlink()

In [8]:
# Render every spectrogram using 4 worker processes.
# imap_unordered distributes the files across the workers concurrently;
# Pool.apply (used previously) blocks until each single call has finished,
# which made the workers process the files strictly one after another.
filenames = labels_df["filename"].tolist()

# The context manager releases the worker processes even if a task raises;
# all results have been consumed by the time the block is left.
with mp.Pool(4) as pool:
    for _ in tqdm(pool.imap_unordered(process_file, filenames), total=len(filenames)):
        pass

HBox(children=(FloatProgress(value=0.0, max=3825.0), HTML(value='')))




---

# Train models

In [7]:
from fastai.vision import *
from fastai import metrics
import torchvision

In [8]:
# Copy of the label table whose filenames end in .png (spectrogram images)
# instead of .wav (recordings); labels_df itself is left untouched.
spec_labels_df = labels_df.assign(
    filename=labels_df["filename"].map(lambda name: name.replace(".wav", ".png"))
)

In [9]:
def get_data(valid_split_no, bs=32):
    """Build the fastai data bunch for one validation split.

    Rows whose ``valid<valid_split_no>`` column equals -1 are dropped; the
    remaining rows are split into train/validation by that same column and
    labelled from the ``label`` column. Images are looked up in
    ``DATA/spectrograms``.

    Args:
        valid_split_no: Number N selecting the ``valid<N>`` column.
        bs: Batch size handed to ``databunch``.

    Returns:
        The object returned by ``databunch(bs=bs)``.
    """
    split_col = f"valid{valid_split_no}"
    in_split = spec_labels_df[split_col] != -1
    images = ImageList.from_df(spec_labels_df[in_split], DATA/"spectrograms")
    split_images = images.split_from_df(col=split_col)
    labelled = split_images.label_from_df(cols="label")
    return labelled.databunch(bs=bs)

In [12]:
# Experiment grid: the training loop runs every combination of these values.
VALID_SPLITS = list(range(1, 4))  # N in the "valid<N>" split columns
MODELS = [
    models.resnet34,
    models.resnet50,
    models.densenet121,
    models.densenet161,
    # Currently disabled: models.resnet18, models.resnet101
]
PRETRAINED = [True, False]
USE_MIXUP = [True, False]

In [13]:
def _train_and_export(valid_split, model, pretrained, use_mixup):
    """Train one model configuration on one validation split and export it.

    Args:
        valid_split: Validation split number forwarded to ``get_data``.
        model: Architecture function passed to ``cnn_learner``.
        pretrained: Forwarded to ``cnn_learner(pretrained=...)``.
        use_mixup: When true, ``learn.mixup()`` is applied before training.

    The CSV log and the exported learner are both named
    ``Split<valid_split>_<model>,pretrained=<0|1>,mixup=<0|1>``.
    """
    run_name = f"{model.__name__},pretrained={int(pretrained)},mixup={int(use_mixup)}"
    learn = cnn_learner(get_data(valid_split, bs=32),
                        model,
                        pretrained=pretrained,
                        # Using macro average precision and recall
                        metrics=[accuracy, metrics.Precision(average="macro"), metrics.Recall(average="macro")],
                        callback_fns=[ShowGraph,
                                      # The split number is part of the log name; without it
                                      # every split would overwrite the previous split's CSV.
                                      partial(callbacks.CSVLogger, filename=f"Split{valid_split}_{run_name}")
                                     ]).to_fp16()
    try:
        if use_mixup:
            learn = learn.mixup()

        learn.fit_one_cycle(100, max_lr=1e-2)
        learn.export(f"Split{valid_split}_{run_name}")
    finally:
        # Free up GPU memory, also when training raised (e.g. CUDA out of
        # memory), so the learner does not stay resident for later runs.
        learn.destroy()
        del learn
        gc.collect()
        torch.cuda.empty_cache()


for valid_split in VALID_SPLITS:
    print(f"************************************\nSPLIT {valid_split}\n************************************")
    for model in MODELS:
        for pretrained in PRETRAINED:
            for use_mixup in USE_MIXUP:
                print(f"model={model.__name__}, pretrained={pretrained}, mixup={use_mixup}")
                _train_and_export(valid_split, model, pretrained, use_mixup)

************************************
SPLIT 1
************************************
model=resnet18, pretrained=True, mixup=True
epoch     train_loss  valid_loss  accuracy  precision  recall    time    
0         5.798199    4.659916    0.037500  nan        0.037500  00:08     
1         5.237535    4.463215    0.043056  nan        0.043056  00:07     
2         4.824464    4.332203    0.048611  nan        0.048611  00:07     
3         4.553841    4.287099    0.054167  nan        0.054167  00:07     
4         4.284824    4.237043    0.059722  0.054026   0.059722  00:07     
5         4.100189    4.131279    0.066667  nan        0.066667  00:07     
6         3.901414    4.159503    0.066667  nan        0.066667  00:07     
7         3.765960    4.139883    0.073611  nan        0.073611  00:07     
8         3.701023    4.146627    0.090278  nan        0.090278  00:07     
9         3.655941    4.090623    0.080556  nan        0.080556  00:07     
10        3.626409    4.118668    0.0763

4         3.286608    4.225522    0.066667  nan        0.066667  00:07     
5         2.992630    4.241886    0.063889  nan        0.063889  00:07     
6         2.736857    4.388347    0.050000  0.037545   0.050000  00:07     
7         2.562820    4.498033    0.054167  nan        0.054167  00:07     
8         2.530335    4.593337    0.080556  nan        0.080556  00:07     
9         2.589300    4.797667    0.061111  nan        0.061111  00:07     
10        2.598200    4.642245    0.061111  0.061270   0.061111  00:07     
11        2.657790    4.781180    0.068056  nan        0.068056  00:07     
12        2.821449    4.938893    0.075000  nan        0.075000  00:07     
13        2.890401    4.774429    0.076389  nan        0.076389  00:07     
14        2.841595    5.187504    0.070833  nan        0.070833  00:07     
15        2.943459    4.604561    0.084722  nan        0.084722  00:07     
16        2.922383    4.599244    0.094444  nan        0.094444  00:07     
17        2.

11        4.282626    3.988843    0.075000  nan        0.075000  00:09     
12        4.027787    4.850452    0.062500  nan        0.062500  00:09     
13        3.897066    4.012793    0.076389  nan        0.076389  00:09     
14        3.754148    3.630597    0.118056  nan        0.118056  00:09     
15        3.581100    5.621696    0.044444  nan        0.044444  00:09     
16        3.364202    3.669604    0.131944  nan        0.131944  00:09     
17        3.284334    3.496647    0.145833  nan        0.145833  00:09     
18        3.149924    3.475070    0.140278  nan        0.140278  00:09     
19        3.080962    6.390936    0.029167  nan        0.029167  00:09     
20        2.980977    3.633442    0.145833  nan        0.145833  00:09     
21        2.888712    3.327877    0.179167  nan        0.179167  00:09     
22        2.818577    3.650597    0.163889  nan        0.163889  00:09     
23        2.668182    3.098652    0.211111  nan        0.211111  00:09     
24        2.

18        1.934137    4.176290    0.179167  nan        0.179167  00:09     
19        1.862199    4.246413    0.144444  nan        0.144444  00:09     
20        1.644065    3.292498    0.245833  nan        0.245833  00:09     
21        1.578485    4.729488    0.175000  nan        0.175000  00:09     
22        1.417178    3.623028    0.261111  nan        0.261111  00:09     
23        1.332195    5.077959    0.145833  nan        0.145833  00:09     
24        1.179957    4.208917    0.194444  nan        0.194444  00:09     
25        1.106235    4.018519    0.233333  nan        0.233333  00:09     
26        1.015506    4.034355    0.251389  nan        0.251389  00:09     
27        0.894108    3.964525    0.277778  nan        0.277778  00:09     
28        0.867856    4.558999    0.226389  nan        0.226389  00:09     
29        0.851641    5.044768    0.194444  nan        0.194444  00:09     
30        0.755077    4.301697    0.259722  nan        0.259722  00:09     
31        0.

25        2.814758    3.353134    0.177778  nan        0.177778  00:11     
26        2.819920    3.182096    0.208333  nan        0.208333  00:11     
27        2.735967    3.132349    0.209722  nan        0.209722  00:11     
38        1.793312    2.864043    0.358333  nan        0.358333  00:45     
39        1.799658    2.626933    0.355556  nan        0.355556  00:45     
40        1.799417    2.369328    0.405556  0.458716   0.405556  00:45     
41        1.770010    2.586621    0.365278  nan        0.365278  00:45     
42        1.749387    2.384364    0.420833  0.535204   0.420833  00:45     
43        1.719476    2.462113    0.377778  0.488405   0.377778  00:45     
44        1.648453    2.399096    0.397222  0.456076   0.397222  00:45     
45        1.708221    2.412323    0.404167  nan        0.404167  00:46     
46        1.713530    2.504405    0.375000  0.454383   0.375000  00:46     
47        1.647959    2.506886    0.394444  nan        0.394444  00:45     
48        1.

42        0.521083    4.145474    0.327778  nan        0.327778  00:45     
43        0.463152    3.810530    0.343056  0.433478   0.343056  00:45     
44        0.457513    3.936342    0.362500  0.420911   0.362500  00:45     
45        0.424602    4.299943    0.352778  nan        0.352778  00:45     
46        0.412185    4.079800    0.370833  0.445607   0.370833  00:45     
47        0.416420    4.110639    0.384722  nan        0.384722  00:45     
48        0.378118    4.063767    0.379167  0.469577   0.379167  00:45     
49        0.365915    4.094425    0.373611  0.456728   0.373611  00:45     
50        0.317210    3.853395    0.391667  0.429020   0.391667  00:45     
51        0.317054    3.835014    0.370833  0.423635   0.370833  00:45     
52        0.301419    3.778551    0.383333  nan        0.383333  00:45     
53        0.268808    5.036678    0.329167  nan        0.329167  00:45     
54        0.263499    4.514032    0.362500  0.462826   0.362500  00:45     
55        0.

RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 7.79 GiB total capacity; 5.63 GiB already allocated; 5.69 MiB free; 5.76 GiB reserved in total by PyTorch)