In [1]:
import pandas as pd
import numpy as np
import tensorflow as tf
import scipy.io.wavfile
import soundfile as sf
import os
import warnings
import joblib
from tqdm import tqdm
warnings.filterwarnings("ignore")

In [2]:
# Scalers already read from disk, keyed by file path, so that calling
# load_fft_amplitude once per audio file does not deserialize the scaler again.
_SCALER_CACHE = {}


def _load_scaler(scaler_path):
    """Return the object stored at `scaler_path`, reading it from disk only once."""
    if scaler_path not in _SCALER_CACHE:
        # NOTE(review): joblib.load unpickles the file, which can execute
        # arbitrary code -- only point this at scaler files you trust.
        _SCALER_CACHE[scaler_path] = joblib.load(scaler_path)
    return _SCALER_CACHE[scaler_path]


def load_fft_amplitude(path, max_freq=11025):
    """Load a WAV file and return its scaled FFT amplitude, one value per Hz.

    The real FFT of the signal is taken, amplitudes and bin frequencies are
    rounded to whole numbers, and for every rounded frequency the largest
    amplitude is kept. The resulting row is passed through the scaler stored
    at STANDLISER_PATH (looked up when the function is called).

    Parameters
    ----------
    path : str
        Path of the WAV file to read.
    max_freq : int, optional
        Highest rounded frequency (Hz) expected in the spectrum. The default
        of 11025 matches the original hard-coded value.

    Returns
    -------
    numpy.ndarray
        Array of shape (1, max_freq + 1, 1) holding the scaled amplitudes
        (dtype is whatever the scaler's `transform` returns).

    Raises
    ------
    ValueError
        If the spectrum does not contain exactly `max_freq + 1` distinct
        rounded frequencies (e.g. unexpected sample rate or a very short file).
    """
    sample_rate, signal = scipy.io.wavfile.read(path)

    # wavfile.read returns (n_samples, n_channels) for multi-channel audio;
    # average the channels so the FFT runs along the time axis.
    if signal.ndim > 1:
        signal = signal.mean(axis=1)

    amplitude = np.round(np.abs(np.fft.rfft(signal)))
    freq_hz = np.round(np.fft.rfftfreq(signal.size, d=1. / sample_rate))

    # Largest amplitude among all FFT bins that round to the same frequency;
    # groupby sorts the keys, so the row is ordered by ascending frequency.
    per_hz = pd.Series(amplitude).groupby(freq_hz).max()

    n_bins = max_freq + 1
    if len(per_hz) != n_bins:
        raise ValueError(
            f"Expected {n_bins} frequency bins (0..{max_freq} Hz) but got "
            f"{len(per_hz)} for '{path}' (sample rate {sample_rate} Hz)"
        )

    x = per_hz.to_numpy(dtype=np.float32).reshape(1, -1)
    x = _load_scaler(STANDLISER_PATH).transform(x)
    # Trailing axis of size 1 is appended before the array is returned.
    return np.reshape(x, (x.shape[0], x.shape[1], 1))

In [3]:
# predict music genre
def get_prediction(x, interpreter):
    """Run one inference pass of `interpreter` on `x` and return its first output.

    Tensors are allocated on every call, `x` is written to the interpreter's
    first input tensor, the interpreter is invoked, and the contents of its
    first output tensor are returned.
    """
    interpreter.allocate_tensors()
    input_index = interpreter.get_input_details()[0]['index']
    output_index = interpreter.get_output_details()[0]['index']

    interpreter.set_tensor(input_index, x)
    interpreter.invoke()
    return interpreter.get_tensor(output_index)


In [4]:
# --- Configuration -----------------------------------------------------------
# Model file handed to tf.lite.Interpreter.
TFLITE_MODEL_PATH = "model2.tflite"
# Serialized scaler that load_fft_amplitude reads with joblib.
STANDLISER_PATH = "scaler.pkl"
# Not referenced by the cells visible here.
STORE_DIR = "tmp"
# Root folder that holds one sub-folder of audio files per genre.
DATADIR = "sound_data/genres/"
# Genre names; a genre's position in this list is the class index that the
# model's prediction is compared against.
GENRES = [
    "blues",
    "classical",
    "country",
    "disco",
    "hiphop",
    "jazz",
    "metal",
    "pop",
    "reggae",
    "rock",
]

In [5]:
# Build the TensorFlow Lite interpreter from the model file at TFLITE_MODEL_PATH.
# Tensor allocation is not done here; get_prediction() calls allocate_tensors().
interpreter = tf.lite.Interpreter(model_path=TFLITE_MODEL_PATH)

In [6]:
# Evaluate the model on every WAV file under DATADIR/<genre>/ and collect the
# files whose predicted class index differs from the genre's index in GENRES.
not_match = 0
unmatch = []
total = 0  # number of files actually evaluated; used as the accuracy denominator
for i, genre in enumerate(GENRES):
    genre_dir = os.path.join(DATADIR, genre)
    # Sorted so the run order (and the order of `unmatch`) is reproducible;
    # entries that are not .wav files are skipped because load_fft_amplitude
    # reads its input with scipy.io.wavfile.
    wav_files = sorted(
        name for name in os.listdir(genre_dir) if name.lower().endswith('.wav')
    )
    for file in tqdm(wav_files, desc=genre):
        path = os.path.join(genre_dir, file)
        x = load_fft_amplitude(path)
        total += 1
        if np.argmax(get_prediction(x, interpreter)) != i:
            not_match += 1
            unmatch.append(path)

print(f"Number of unmatched: {not_match}")
if total:
    print(f"Overall accuracy: {(total-not_match)/total*100}%")
else:
    # Avoid dividing by zero when the data folders contain no WAV files.
    print("Overall accuracy: n/a (no audio files found)")
print(f"List of unmatched songs: {unmatch}")

100%|██████████| 100/100 [00:37<00:00,  2.67it/s]
100%|██████████| 100/100 [00:38<00:00,  2.63it/s]
100%|██████████| 100/100 [00:42<00:00,  2.36it/s]
100%|██████████| 100/100 [00:27<00:00,  3.64it/s]
100%|██████████| 100/100 [00:31<00:00,  3.18it/s]
100%|██████████| 100/100 [00:38<00:00,  2.62it/s]
100%|██████████| 100/100 [00:25<00:00,  3.96it/s]
100%|██████████| 100/100 [00:18<00:00,  5.49it/s]
100%|██████████| 100/100 [00:28<00:00,  3.57it/s]
100%|██████████| 100/100 [00:40<00:00,  2.49it/s]

Number of unmatched: 1
Overall accuracy: 99.9%
List of unmatched songs: ['sound_data/genres//rock/rock.00016.wav']



