from src.autoencoder.conv_autoencoder import ConvAutoencoder
from src.database import Database
from src.song import Song
import os
from soundfile import LibsndfileError
from audioread import NoBackendError
import torch

In [2]:
GENRES = ['classical', 'jazz', 'rock', 'hiphop', 'reggae', 'country', 'metal']

# Build the autoencoder and restore its trained weights, forcing every tensor
# onto the CPU regardless of the device it was saved from.
autoencoder = ConvAutoencoder()
checkpoint_state = torch.load('./autoencoder.pth', map_location=torch.device('cpu'))
autoencoder.load_state_dict(checkpoint_state)

# Reference database over ./data/ (70 songs per genre -- presumably the first
# 70 indices of each genre; confirm in Database), then build its index.
db = Database('./data/', GENRES, autoencoder=autoencoder, songs_per_genre=70)
db.calculate_index()

  return f(*args, **kwargs)


Song at path ./data/jazz/jazz.00054.wav contains invalid data


In [3]:
# Query song for the distance comparisons below; its autoencoder embedding is
# precomputed up front.
query_path = './data/country/country.00000.wav'
other_song = Song(query_path)
other_song.precalculate_embedding(autoencoder)

In [4]:
# Per-genre distances from the query song under the 'mfcc' strategy; Tester
# below treats the genre with the smallest value as the predicted match.
db.calculate_distances(other_song, strategy='mfcc')

{'classical': 14005171.380368233,
 'jazz': 9615758.985298634,
 'rock': 5485655.035300732,
 'hiphop': 5116294.510719299,
 'reggae': 7877165.333296776,
 'country': 6791867.8783323765,
 'metal': 4442366.683151245}

In [5]:
# Same query under the 'autoencoder' (embedding) strategy; values are on a
# different scale from the MFCC distances, so compare rankings, not magnitudes.
db.calculate_distances(other_song, strategy='autoencoder')

{'classical': 17754.477492943293,
 'jazz': 13970.155619497838,
 'rock': 11017.145569801816,
 'hiphop': 11634.990981658339,
 'reggae': 12976.183367585641,
 'country': 10690.055497432393,
 'metal': 13225.27153517073}

In [6]:
class Tester:
    """Measure genre-retrieval accuracy of a Database on a range of test songs.

    For every genre in ``genres`` and every index in
    ``range(songs_from, songs_to)``, the file
    ``<root_dir>/<genre>/<genre>.<idx:05d>.wav`` is loaded, its autoencoder
    embedding is precomputed, and the database is asked which genre is closest
    under both the ``'mfcc'`` and the ``'autoencoder'`` strategies.
    """

    def __init__(self, root_dir: str, genres, autoencoder, songs_from, songs_to):
        """
        Args:
            root_dir: Directory containing one sub-directory per genre.
            genres: Genre names to test; duplicate names are tested once.
            autoencoder: Model passed to ``Song.precalculate_embedding``.
            songs_from: First song index to test (inclusive).
            songs_to: Song index to stop at (exclusive).
        """
        self.autoencoder = autoencoder
        self.songs_to = songs_to
        self.songs_from = songs_from
        self.root_dir = root_dir
        self.genres = genres
        # Number of songs attempted in the latest run; set by __iterate_songs.
        self.db_len = None

    def test_each_song(self, db: "Database", verbose: bool = False):
        """Classify every test song with both strategies and return hit rates.

        Args:
            db: Object whose ``calculate_distances(song, strategy=...)`` returns
                a ``{genre: distance}`` mapping.
            verbose: Print each song path and per-strategy match/fail lines.

        Returns:
            ``(mfcc_rate, enc_rate)``: the fraction of attempted songs whose
            closest genre equals their true genre, per strategy. Songs skipped
            because of unreadable audio still count in the denominator, i.e.
            they are scored as misses.
            NOTE(review): exclude skipped songs from db_len if accuracy over
            readable files only is wanted.

        Raises:
            ValueError: if there is nothing to test (``genres`` is empty or
                ``songs_from >= songs_to``).
        """
        mfcc_succ = 0
        enc_succ = 0
        for song_path, genre in self.__iterate_songs():
            if verbose:
                print(song_path)
            try:
                song = Song(song_path, genre)
                song.precalculate_embedding(self.autoencoder)
                # Strategies are scored one at a time: an error raised while
                # scoring the second leaves the first one's hit counted.
                if self.__is_match(db, song, genre, 'mfcc', 'MFCC', verbose):
                    mfcc_succ += 1
                if self.__is_match(db, song, genre, 'autoencoder', 'ENC', verbose):
                    enc_succ += 1
            except (LibsndfileError, NoBackendError):
                if not verbose:
                    # The path has not been printed yet; say which file failed.
                    print(song_path)
                print('\tSKIP - contains invalid data')
        if not self.db_len:
            raise ValueError(
                'no songs to test: genres is empty or songs_from >= songs_to'
            )
        return mfcc_succ / self.db_len, enc_succ / self.db_len

    def __is_match(self, db, song, genre, strategy, label, verbose):
        """Return True if *strategy* ranks *genre* closest to *song*; log if verbose."""
        hit = self.__get_matching_genre(db, song, strategy) == genre
        if verbose:
            print(f"\t[{label}] {'match' if hit else 'fail'}")
        return hit

    @staticmethod
    def __get_matching_genre(db, song, strategy):
        """Return the genre with the smallest distance to *song* under *strategy*.

        Ties resolve to the genre that appears first in the returned mapping.
        """
        distances = db.calculate_distances(song, strategy=strategy)
        return min(distances, key=distances.get)

    def __iterate_songs(self):
        """Yield ``(song_path, genre)`` pairs, grouped by genre.

        ``db_len`` is set to the total number of pairs before the first yield.
        """
        # dict.fromkeys de-duplicates genres while keeping their order.
        pairs = [
            (self.__path_for_song(genre, idx), genre)
            for genre in dict.fromkeys(self.genres)
            for idx in range(self.songs_from, self.songs_to)
        ]
        self.db_len = len(pairs)
        yield from pairs

    def __path_for_song(self, genre: str, idx: int):
        """Return ``<root_dir>/<genre>/<genre>.<idx zero-padded to 5>.wav``."""
        return os.path.join(self.root_dir, genre, f"{genre}.{idx:05d}.wav")

In [7]:
# Test on song indices 70..99 of every genre (songs_to is exclusive).
# Presumably disjoint from the 70 songs per genre indexed in `db` -- confirm.
tester = Tester(
    root_dir='./data/',
    genres=GENRES,
    autoencoder=autoencoder,
    songs_from=70,
    songs_to=100,
)

In [8]:
# Run the evaluation and report both hit rates as percentages with two decimals.
mfcc_rate, enc_rate = tester.test_each_song(db)
for label, rate in (('MFCC', mfcc_rate), ('ENC', enc_rate)):
    print(f'{label}: {rate:.2%}')

MFCC: 38.57%
ENC: 33.81%
