In [1]:
import pickle
import torch
import torchaudio
from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write
from audiocraft.data.audio_utils import convert_audio
from audiocraft.modules.conditioners import ConditioningAttributes

# Load the embedding cache from disk if it exists; otherwise start empty.
# NOTE: pickle.load can execute arbitrary code while deserialising, so only
# point this at a cache file that you produced yourself.
try:
    with open('embeddings.pickle', 'rb') as cache_file:
        embs = pickle.load(cache_file)
except FileNotFoundError:
    # No cache file yet: fall back to an empty mapping.
    print('empty embeddings')
    embs = {}
else:
    print('cached embeddings', len(embs))

objc[69559]: Class AVFFrameReceiver is implemented in both /Users/maxj/miniconda3/envs/audiocraft/lib/libavdevice.58.8.100.dylib (0x1183e4798) and /Users/maxj/miniconda3/envs/audiocraft/lib/python3.9/site-packages/av/.dylibs/libavdevice.59.7.100.dylib (0x13f388778). One of the two will be used. Which one is undefined.
objc[69559]: Class AVFAudioReceiver is implemented in both /Users/maxj/miniconda3/envs/audiocraft/lib/libavdevice.58.8.100.dylib (0x1183e47e8) and /Users/maxj/miniconda3/envs/audiocraft/lib/python3.9/site-packages/av/.dylibs/libavdevice.59.7.100.dylib (0x13f3887c8). One of the two will be used. Which one is undefined.


cached embeddings 1249


In [2]:
# Peek at the first cached embedding to check the size of one vector.
first_embedding = list(embs.values())[0]
first_embedding.shape

torch.Size([1536])

In [3]:
import pandas as pd
import sqlite3

# Open the SQLite database; the same connection is reused by later queries.
con = sqlite3.connect("plaintext.db")

# One row per track that is not flagged as locally deleted: its ID and path.
tracks_sql = "SELECT `ID` as track_id, `FolderPath` as path FROM djmdContent WHERE `rb_local_deleted` = 0;"
df_tracks = pd.read_sql_query(sql=tracks_sql, con=con)

df_tracks.head()

Unnamed: 0,track_id,path
0,178582527,/Users/maxj/Music/PioneerDJ/Sampler/OSC_SAMPLE...
1,35341928,/Users/maxj/Music/PioneerDJ/Sampler/OSC_SAMPLE...
2,196758694,/Users/maxj/Music/PioneerDJ/Sampler/OSC_SAMPLE...
3,96986769,/Users/maxj/Music/PioneerDJ/Sampler/OSC_SAMPLE...
4,78147589,/Users/maxj/Music/PioneerDJ/Demo Tracks/Demo T...


In [4]:
# Tag definitions (tag_id -> name), skipping rows flagged as locally deleted.
df_tags = pd.read_sql_query(
    sql="SELECT `ID` as tag_id, `Name` as name FROM djmdMyTag WHERE `rb_local_deleted` = 0",
    con=con,
)

# Distinct (tag_id, track_id) assignments, skipping rows flagged as locally deleted.
df_tag_tracks = pd.read_sql_query(
    sql="SELECT DISTINCT `MyTagID` as tag_id, `ContentID` as track_id FROM djmdSongMyTag WHERE `rb_local_deleted` = 0",
    con=con,
)

In [5]:
# Display the tag lookup table (columns: tag_id, name) read from djmdMyTag.
df_tags

Unnamed: 0,tag_id,name
0,1,Genre
1,866765180,Bass
2,2,Components
3,2729535446,Vocal
4,3,Situation
...,...,...
56,726849173,Twerk
57,457161321,Glitch
58,3882634457,90s
59,1259368789,Techno


In [6]:
# Display the tag-to-track assignments (columns: tag_id, track_id) read from djmdSongMyTag.
df_tag_tracks

Unnamed: 0,tag_id,track_id
0,3300050053,234114820
1,3626785568,234114820
2,860746464,234114820
3,1394931706,234955773
4,1084846050,234955773
...,...,...
4792,3740743506,83506799
4793,2747842722,144034496
4794,866765180,221658915
4795,969379338,191343646


In [7]:
# Attach the tag name and the track path to every (tag, track) assignment.
# Both merges use pandas' default inner join, so assignments whose tag or
# track is absent from df_tags / df_tracks are dropped.
merged_df = (
    df_tag_tracks
    .merge(df_tags, on='tag_id')
    .merge(df_tracks, on='track_id')
)

# DataFrame.pivot raises ValueError when an (index, columns) pair occurs more
# than once. That would happen here if two track rows share the same path or
# two tags share the same name, so collapse such repeats first. Only the
# presence of a tag matters below, so dropping repeats loses no information.
tag_assignments = merged_df.drop_duplicates(subset=['path', 'name'])

# One row per path and one boolean column per tag name: True where the track
# at that path carries the tag, False otherwise. Built as a chain instead of
# mutating pivot_df in place.
pivot_df = (
    tag_assignments
    .pivot(index='path', columns='name', values='tag_id')
    .notna()
    .reset_index()
)

# Drop the leftover "name" label on the column axis.
pivot_df.columns.name = None
pivot_df

Unnamed: 0,path,90s,Aftercare,Aggro,Alt,Ambient,Apex,Bass,Bassline,Beats,...,Sunrise,Sunset,Techno,Transition,Trap,Tribal,Trip,Twerk,Vocal,Wave
0,/Users/maxj/Dropbox/MusicBackup/Live Sets/2022...,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
1,/Users/maxj/Dropbox/MusicBackup/Live Sets/2022...,False,False,False,False,False,False,True,False,False,...,False,True,False,False,True,False,False,False,False,False
2,/Users/maxj/Music/Downloads/Bandcamp/00000oooo...,False,False,False,False,False,False,False,False,False,...,False,True,False,False,False,False,False,False,False,False
3,/Users/maxj/Music/Downloads/Bandcamp/Digital A...,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
4,/Users/maxj/Music/Downloads/Bandcamp/Falcons -...,False,False,False,False,False,False,True,False,False,...,False,False,False,False,False,False,False,False,True,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1112,/Users/maxj/Music/Rips/Spotify/AIFF/voljum - c...,False,False,False,False,False,False,False,False,False,...,False,True,False,False,False,False,True,False,False,False
1113,/Users/maxj/Music/Rips/Spotify/AIFF/warner cas...,False,False,False,False,False,False,False,False,False,...,False,True,False,False,False,False,False,False,False,False
1114,/Users/maxj/Music/Rips/Spotify/AIFF/wex - END ...,False,False,False,False,False,False,False,False,False,...,False,False,False,False,True,False,False,False,False,False
1115,/Users/maxj/Music/Rips/Spotify/AIFF/yune pinku...,False,False,False,False,False,False,False,True,False,...,False,True,False,False,False,False,False,False,True,False


In [8]:
# Keep only the rows whose path is a key of the embedding cache.
has_embedding = pivot_df['path'].isin(embs)
filtered_df = pivot_df.loc[has_embedding]
filtered_df

Unnamed: 0,path,90s,Aftercare,Aggro,Alt,Ambient,Apex,Bass,Bassline,Beats,...,Sunrise,Sunset,Techno,Transition,Trap,Tribal,Trip,Twerk,Vocal,Wave
0,/Users/maxj/Dropbox/MusicBackup/Live Sets/2022...,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
1,/Users/maxj/Dropbox/MusicBackup/Live Sets/2022...,False,False,False,False,False,False,True,False,False,...,False,True,False,False,True,False,False,False,False,False
2,/Users/maxj/Music/Downloads/Bandcamp/00000oooo...,False,False,False,False,False,False,False,False,False,...,False,True,False,False,False,False,False,False,False,False
3,/Users/maxj/Music/Downloads/Bandcamp/Digital A...,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
4,/Users/maxj/Music/Downloads/Bandcamp/Falcons -...,False,False,False,False,False,False,True,False,False,...,False,False,False,False,False,False,False,False,True,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1112,/Users/maxj/Music/Rips/Spotify/AIFF/voljum - c...,False,False,False,False,False,False,False,False,False,...,False,True,False,False,False,False,True,False,False,False
1113,/Users/maxj/Music/Rips/Spotify/AIFF/warner cas...,False,False,False,False,False,False,False,False,False,...,False,True,False,False,False,False,False,False,False,False
1114,/Users/maxj/Music/Rips/Spotify/AIFF/wex - END ...,False,False,False,False,False,False,False,False,False,...,False,False,False,False,True,False,False,False,False,False
1115,/Users/maxj/Music/Rips/Spotify/AIFF/yune pinku...,False,False,False,False,False,False,False,True,False,...,False,True,False,False,False,False,False,False,True,False


In [9]:
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics.cluster import v_measure_score
from sklearn.model_selection import cross_val_score
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC

def test_classifier(X, y):
    """Return the mean cross-validated balanced accuracy of an SVC on (X, y).

    Args:
        X: Feature matrix passed straight through to ``cross_val_score``.
        y: Target labels, one per row of ``X``.

    Returns:
        The mean of the per-fold ``balanced_accuracy`` scores. Folds are
        chosen by ``cross_val_score``'s default ``cv`` setting.
    """
    # class_weight='balanced' reweights classes to offset label imbalance.
    classifier = SVC(class_weight='balanced')
    fold_scores = cross_val_score(classifier, X, y, scoring='balanced_accuracy')
    return fold_scores.mean()

# Tags with fewer positive tracks than this are reported and skipped.
MIN_POSITIVE_TRACKS = 10

# The feature matrix is identical for every tag, so build it once up front
# rather than once per loop iteration. Rows follow the order of
# filtered_df['path'], which is also the row order of every label vector y.
X = np.array([embs[f].numpy() for f in filtered_df['path']])

# Maps tag name -> mean cross-validated score for that tag.
scores = {}
# Column 0 is 'path'; every remaining column is one boolean tag indicator.
for cat in filtered_df.columns[1:]:
    y = np.array(filtered_df[cat])
    # Count the positive tracks once and reuse it for the check and the log.
    n_true = int(y.sum())
    if n_true < MIN_POSITIVE_TRACKS:
        print("category", cat, "only", n_true, "true")
        continue
    print(cat, n_true)
    scores[cat] = test_classifier(X, y)

90s 12
Aftercare 123
Aggro 49
Alt 26
Ambient 36
Apex 27
Bass 189
Bassline 70
Beats 144
Breaks 136
Cinema 84
Corny 34
Dance 61
Dark 58
Day 268
Dirtybird 28
category Disco only 8 true
DnB 87
Dub 85
Dubstep 12
Feels 29
Funk 83
Future 18
category Glitch only 8 true
Grime 10
Hip Hop 118
House 280
IDM 49
Island 18
Jazz 92
Latin 55
Liquid 23
Lounge 71
Moody 210
Mystery 75
Nostalgic 103
Offbeat 33
Peak 235
Phonk 29
Piano 28
Pop 32
Prog 22
Rap 203
Rave 18
Sexy 27
Sparkle 36
Special 38
Sunrise 69
Sunset 341
category Techno only 5 true
Transition 18
Trap 212
Tribal 44
Trip 298
Twerk 11
Vocal 294
Wave 25


In [10]:
# `scores` arrives here as a {tag: score} dict and is rebound to a two-column
# DataFrame. The isinstance guard keeps this cell idempotent: without it,
# re-running the cell would call DataFrame.items() (which yields
# (column name, Series) pairs) and silently build a garbage table.
if isinstance(scores, dict):
    scores = pd.DataFrame.from_records(list(scores.items()), columns=["tag", "score"])
scores

Unnamed: 0,tag,score
0,90s,0.55
1,Aftercare,0.669282
2,Aggro,0.630144
3,Alt,0.515419
4,Ambient,0.530425
5,Apex,0.669908
6,Bass,0.634016
7,Bassline,0.63324
8,Beats,0.725318
9,Breaks,0.648359
