In [1]:
from src.autoencoder.conv_autoencoder import ConvAutoencoder
from src.database import Database
from src.song import Song
import os
import torch

In [2]:
GENRES = ['classical', 'jazz', 'rock', 'hiphop', 'reggae', 'country', 'metal']
autoencoder = ConvAutoencoder()
# torch.load unpickles the checkpoint file; only load checkpoints you produced or trust.
autoencoder.load_state_dict(torch.load('./autoencoder.pth', map_location=torch.device('cpu')))
# The model is only used for inference from here on: put dropout / batch-norm layers
# (if ConvAutoencoder has any) into evaluation mode so embeddings are deterministic.
autoencoder.eval()
# NOTE(review): presumably the index is built from the first `songs_per_genre` songs of
# each genre, so the Tester's song-index range must not overlap it — confirm in Database.
db = Database('./data/', GENRES, autoencoder=autoencoder, songs_per_genre=40)
db.calculate_index()

In [3]:
MFCC = 'MFCC'
ENCODER = 'ENCODER'
# Report label -> strategy name expected by `Database.calculate_distances`.
# Iteration order (MFCC first, then ENCODER) is also the report order.
DB_STRATEGY_NAMES = {MFCC: 'mfcc', ENCODER: 'autoencoder'}
# Default cut-offs k for precision@k.
DEFAULT_KAS = (1, 2, 3)


class Tester:
    """Evaluate genre retrieval as precision@k for the MFCC and autoencoder strategies.

    For every genre in `genres` and every index in ``range(songs_from, songs_to)`` the
    song ``<root_dir>/<genre>/<genre>.<idx:05d>.wav`` is embedded, its distance to every
    genre in a `Database` is computed, and the song is a hit at k when its true genre is
    among the k nearest genres.
    """

    def __init__(self, root_dir: str, genres: "list[str]", autoencoder: "ConvAutoencoder",
                 songs_from: int, songs_to: int, kas=DEFAULT_KAS):
        """
        Args:
            root_dir: directory containing one sub-directory per genre.
            genres: genres whose songs are evaluated (duplicates are ignored).
            autoencoder: model passed to `Song.precalculate_embedding`.
            songs_from: first song index to evaluate (inclusive).
            songs_to: last song index to evaluate (exclusive).
            kas: cut-offs k for which precision@k is counted and reported.
        """
        self.autoencoder = autoencoder
        self.songs_to = songs_to
        self.songs_from = songs_from
        self.root_dir = root_dir
        self.genres = genres
        self.kas = list(kas)
        # {strategy label: {k: number of songs whose genre is in the top k}}
        self.res = self._empty_results()
        # Number of songs evaluated in the most recent run (not the database size).
        self.db_len = None

    def _empty_results(self):
        """Return a fresh ``{strategy: {k: 0}}`` hit-count table."""
        return {strategy: {k: 0 for k in self.kas} for strategy in DB_STRATEGY_NAMES}

    def test_each_song(self, db: "Database"):
        """Evaluate every test song against `db` and print precision@k per strategy.

        Hit counts are reset first, so calling this again (e.g. re-running the cell)
        reports the same numbers instead of accumulating hits from earlier runs.

        Raises:
            ValueError: if the configured genres / index range select no songs.
        """
        self.res = self._empty_results()
        songs = list(self.__iterate_songs())  # also sets self.db_len
        if not songs:
            raise ValueError(
                f"No songs to test for genres={self.genres!r} and "
                f"indices range({self.songs_from}, {self.songs_to})")
        for song_path, genre in songs:
            song = Song(song_path, genre)
            # Embed once; both strategies reuse the same Song object.
            song.precalculate_embedding(self.autoencoder)
            for strategy in DB_STRATEGY_NAMES:
                self.__handle_song(db, song, strategy)
        self._print_report()

    def precision_at_k(self):
        """Return ``{strategy: {k: precision in percent}}`` for the most recent run.

        Only meaningful after `test_each_song` has run (before that `db_len` is None).
        """
        return {strategy: {k: hits[k] / self.db_len * 100 for k in self.kas}
                for strategy, hits in self.res.items()}

    def _print_report(self):
        """Print precision@k for every strategy, one block per strategy."""
        for strategy, precisions in self.precision_at_k().items():
            print(strategy)
            for k, pct in precisions.items():
                print(f"\tP@{k}: {pct:.2f}%")
            print('-' * 20)

    def __handle_song(self, db: "Database", song: "Song", strategy: str):
        """Add one hit for every k where `song.genre` is among the k nearest genres."""
        predicted_genres = self.__get_matching_genres(db, song, strategy)
        for k in self.kas:
            if song.genre in predicted_genres[:k]:
                self.res[strategy][k] += 1

    @staticmethod
    def __get_matching_genres(db: "Database", song, strategy):
        """Return the genres known to `db`, ordered from nearest to farthest from `song`."""
        distances = db.calculate_distances(song, strategy=DB_STRATEGY_NAMES[strategy])
        # sorted() is stable: genres at equal distance keep the order `db` returned them in.
        return [genre for genre, _ in sorted(distances.items(), key=lambda item: item[1])]

    def __iterate_songs(self):
        """Yield ``(path, genre)`` for every test song, grouped by genre.

        Sets `db_len` to the total number of songs before the first item is yielded.
        """
        indices = range(self.songs_from, self.songs_to)
        # dict.fromkeys drops duplicate genres while keeping order, so a genre listed
        # twice is evaluated only once.
        songs = [(self.__path_for_song(genre, idx), genre)
                 for genre in dict.fromkeys(self.genres)
                 for idx in indices]
        self.db_len = len(songs)
        yield from songs

    def __path_for_song(self, genre: str, idx: int):
        """Return ``<root_dir>/<genre>/<genre>.<idx zero-padded to 5 digits>.wav``."""
        return os.path.join(self.root_dir, genre, f"{genre}.{idx:05d}.wav")

In [4]:
# Held-out song indices [TEST_SONGS_FROM, TEST_SONGS_TO) of every genre.
# NOTE(review): assumed disjoint from the songs indexed into `db` — confirm in Database.
TEST_SONGS_FROM, TEST_SONGS_TO = 50, 60
tester = Tester(
    root_dir='./data/',
    genres=GENRES,
    autoencoder=autoencoder,
    songs_from=TEST_SONGS_FROM,
    songs_to=TEST_SONGS_TO,
)

In [5]:
# Evaluate every held-out song against the database index and print P@k per strategy.
tester.test_each_song(db=db)

MFCC
	P@1: 35.71%
	P@2: 48.57%
	P@3: 65.71%
--------------------
ENCODER
	P@1: 27.14%
	P@2: 51.43%
	P@3: 68.57%
--------------------
