In [1]:
import os
import glob
import pytorch_lightning as pl
from dataset import MusicDataModule
from model import MusicAutoEncoder

import utils
import librosa
from tqdm import tqdm
import torch
import numpy as np

# Fix the random seed through Lightning's helper so repeated runs of the notebook are comparable.
pl.seed_everything(69)
# Enable the autoreload extension in mode 2, so edited modules (e.g. dataset, model, utils)
# are re-imported before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
# Root directory of the audio files; None when the AUDIO_DIR_SMALL environment variable is unset.
AUDIO_DIR = os.environ.get("AUDIO_DIR_SMALL")
tracks = utils.load("data/fma_metadata/tracks.csv")

# Keep only rows whose ('set', 'subset') value compares <= 'small'
# (presumably an ordered categorical, so this selects the small subset -- TODO confirm in utils.load).
# .copy() gives an independent frame: the column assignment below would otherwise write into a
# slice of the full table and trigger pandas' SettingWithCopyWarning.
tracks = tracks[tracks['set', 'subset'] <= 'small'].copy()
# remove genres with no tracks (should be 8 genres resulting)
tracks[("track", "genre_top")] = tracks[("track", "genre_top")].cat.remove_unused_categories()

In [3]:
# Build the project's data module and run its prepare/setup steps in that order.
# rebuild_existing=False presumably reuses features already built on disk -- TODO confirm in dataset.py.
dataset = MusicDataModule(batch_size=32, num_workers=12, rebuild_existing=False)
dataset.prepare_data()
dataset.setup()
# Keep a handle on the validation split (used by the later cells); the cell displays its size.
val_split = dataset.val
len(val_split)

NA values per feature:
track  genre_top    0
dtype: int64
Total clean tracks: 8000
Genre counts:
(track, genre_top)
Electronic       1000
Experimental     1000
Folk             1000
Hip-Hop          1000
Instrumental     1000
International    1000
Pop              1000
Rock             1000
dtype: int64
Total 8 genre features


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=8000.0), HTML(value='')))

Track 99134 broke with error 
Track 108925 broke with error 
Track 98567 broke with error Track ./data/fma_large/098/098567.mp3 has duration 0.5104761904761905, not 29.5. Rejecting.
Track 98565 broke with error Track ./data/fma_large/098/098565.mp3 has duration 1.6076190476190477, not 29.5. Rejecting.
Track 98569 broke with error Track ./data/fma_large/098/098569.mp3 has duration 1.5292517006802722, not 29.5. Rejecting.
Track 133297 broke with error 



900

In [4]:
load_version = 1

# All checkpoints written for the selected Lightning run.
checkpoint_glob = os.path.join("lightning_logs", f"version_{load_version}", "checkpoints", "*.ckpt")
checkpoint_candidates = glob.glob(checkpoint_glob)
if not checkpoint_candidates:
    # Fail with an actionable message instead of an IndexError on an empty list.
    raise FileNotFoundError(f"No checkpoint files match {checkpoint_glob!r}")
# Latest checkpoint by modification time. A lexicographic sort of the file names would rank
# e.g. "epoch=9" after "epoch=10" and pick an older checkpoint.
checkpoint_path = max(checkpoint_candidates, key=os.path.getmtime)

model = MusicAutoEncoder.load_from_checkpoint(checkpoint_path)
model

MusicAutoEncoder(
  (encoder): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
    (7): ReLU()
    (8): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (9): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1))
    (10): ReLU()
    (11): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (12): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1))
    (13): ReLU()
    (14): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (15): Conv2d(64, 32, kernel_size=(1, 5), stride=(1, 1))
    (16): ReLU()
    (17): AdaptiveMaxPool

In [5]:
# Encode every item of the validation split and remember which track id each encoding belongs to.
val_ids, val_encodings = [], []
device = "cuda" if torch.cuda.is_available() else "cpu"
# eval() puts any train-time-only layers (dropout, batch norm, ...) into inference behaviour.
model = model.to(device).eval()
# Inference only: no_grad avoids building an autograd graph for each of the forward passes.
with torch.no_grad():
    for i, (song_x, _) in enumerate(tqdm(val_split)):
        path = val_split.X[i]
        # NOTE(review): assumes the 5th "/"-separated component of the path is the integer
        # track id -- confirm against the paths produced by MusicDataModule.
        song_id = int(path.split("/")[4])
        encoding = model(song_x.unsqueeze(0).to(device))
        val_ids.append(song_id)
        val_encodings.append(encoding.detach().cpu())

# Stack into one tensor and drop singleton dimensions; the recorded similarity-matrix output
# (900 x 900) indicates one row per track after the squeeze.
val_encodings = torch.stack(val_encodings, dim=0).squeeze()
val_ids = np.array(val_ids, dtype=int)

100%|██████████| 900/900 [00:17<00:00, 51.27it/s]


In [6]:
def sim_matrix(a, b, eps=1e-8):
    """
    Cosine similarity between every row of `a` and every row of `b`.

    Each row is divided by its norm before the matrix product; norms smaller
    than `eps` are replaced by `eps` so that all-zero rows do not divide by zero.

    Args:
        a: 2-D tensor, one vector per row.
        b: 2-D tensor with the same number of columns as `a`.
        eps: lower bound applied to the row norms.

    Returns:
        Tensor of shape (a.shape[0], b.shape[0]); entry [i, j] is the
        similarity of a[i] and b[j].
    """
    def _unit_rows(mat):
        # Column vector of per-row norms, floored at eps.
        row_norms = mat.norm(dim=1).unsqueeze(1)
        return mat / row_norms.clamp(min=eps)

    return torch.mm(_unit_rows(a), _unit_rows(b).t())

# Similarity of every validation encoding against every other one
# (the diagonal compares each track with itself).
encoding_sims = sim_matrix(val_encodings, val_encodings)
print(encoding_sims.shape)

torch.Size([900, 900])


In [21]:
def get_top_k_similar_ids(track_id, sim_top_k):
    """Return the ids of the `sim_top_k` validation tracks most similar to `track_id`.

    The track is encoded with the notebook-level `model` and compared against
    `val_encodings` using `sim_matrix`. The query track itself is excluded from
    the result by id, not by assuming it sits at rank 0, so no real neighbour
    is dropped when the query is absent from the validation set or does not
    rank first.

    Args:
        track_id: id of the query track, passed to `dataset.build_features_for_track_id`.
        sim_top_k: number of neighbours to return.

    Returns:
        NumPy array of at most `sim_top_k` track ids, most similar first.
    """
    x = dataset.build_features_for_track_id(track_id)
    x = torch.FloatTensor(x).to(device).unsqueeze(0)
    # Inference only -- no autograd graph is needed.
    with torch.no_grad():
        encoding = model(x).detach().cpu()
    # Collapse singleton dimensions, then restore a leading batch dimension for sim_matrix.
    encoding = encoding.squeeze().unsqueeze(0)
    track_sims = sim_matrix(encoding, val_encodings)[0]
    # Convert to a NumPy index array: indexing a NumPy array with a torch tensor is fragile.
    ranking = torch.argsort(track_sims, descending=True).numpy()
    ranked_ids = val_ids[ranking]
    # Remove the query track wherever it ranks.
    ranked_ids = ranked_ids[ranked_ids != track_id]
    return ranked_ids[:sim_top_k]

# For the first 10 validation tracks, show the genres of their nearest neighbours
# next to the genre of the query track.
sim_top_k = 3
for track_id in val_ids[:10]:
    # Genre of the query track, read from the multi-level ("track", "genre_top") column.
    genre = tracks.loc[track_id][("track", "genre_top")]
    k_most_similar_ids = get_top_k_similar_ids(track_id, sim_top_k)
    # One-column frame with the neighbours' genres, indexed by their track ids.
    similar_track_genres = tracks.loc[k_most_similar_ids][[("track", "genre_top")]]
    # Flatten the two-level column header to a single name for display.
    similar_track_genres.columns = ["genre_top"]
    print(f"Top {sim_top_k} most similar tracks to track {track_id} (genre {genre}):")
    display(similar_track_genres)

Top 3 most similar tracks to track 133970 (genre Instrumental):


Unnamed: 0_level_0,genre_top
track_id,Unnamed: 1_level_1
72612,Folk
130667,Folk
73658,Folk


Top 3 most similar tracks to track 45393 (genre Instrumental):


Unnamed: 0_level_0,genre_top
track_id,Unnamed: 1_level_1
110208,Experimental
120778,Instrumental
85438,Folk


Top 3 most similar tracks to track 143532 (genre Electronic):


Unnamed: 0_level_0,genre_top
track_id,Unnamed: 1_level_1
84605,Pop
92125,Pop
139226,Electronic


Top 3 most similar tracks to track 130758 (genre Folk):


Unnamed: 0_level_0,genre_top
track_id,Unnamed: 1_level_1
121474,Instrumental
52642,Electronic
126717,Instrumental


Top 3 most similar tracks to track 62751 (genre Pop):


Unnamed: 0_level_0,genre_top
track_id,Unnamed: 1_level_1
91084,Electronic
141568,Hip-Hop
110106,Electronic


Top 3 most similar tracks to track 56028 (genre Electronic):


Unnamed: 0_level_0,genre_top
track_id,Unnamed: 1_level_1
148513,Pop
134794,Instrumental
67360,Pop


Top 3 most similar tracks to track 104357 (genre Folk):


Unnamed: 0_level_0,genre_top
track_id,Unnamed: 1_level_1
46930,Pop
134922,International
69567,Folk


Top 3 most similar tracks to track 121922 (genre International):


Unnamed: 0_level_0,genre_top
track_id,Unnamed: 1_level_1
108745,Hip-Hop
57177,Folk
38961,Pop


Top 3 most similar tracks to track 9560 (genre Rock):


Unnamed: 0_level_0,genre_top
track_id,Unnamed: 1_level_1
47032,Rock
138413,Folk
140260,Electronic


Top 3 most similar tracks to track 58174 (genre Pop):


Unnamed: 0_level_0,genre_top
track_id,Unnamed: 1_level_1
114268,Hip-Hop
141568,Hip-Hop
53937,Pop


In [40]:
import IPython

def get_audio_from_id(track_id):
    """Locate the audio file of `track_id` under AUDIO_DIR and load it with librosa.

    The audio is loaded as mono and limited to 29.5 seconds.

    Returns:
        Tuple (file path, audio samples, sample rate).
    """
    audio_path = utils.get_audio_path(AUDIO_DIR, int(track_id))
    samples, rate = librosa.load(audio_path, duration=29.5, mono=True)
    return (audio_path, samples, rate)

# Play the first validation track. The ids come from `val_ids` (built in the encoding cell);
# the previous `val_songs` is not defined by any visible cell and fails on a fresh kernel.
track_filename, _, _ = get_audio_from_id(val_ids[0])
IPython.display.Audio(track_filename)
