In [1]:
import json
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt

import wandb

In [2]:
# Public W&B API client used below to query runs, configs and logged artifacts.
# It relies on the locally stored `wandb login` credentials (the cell output
# shows which account is logged in).
api = wandb.Api()

[34m[1mwandb[0m: Currently logged in as: [33mkwatcharasupat[0m ([33mkwatcharasupat-gatech[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
# Collect the per-track test-metrics table of every completed test run and
# cache it as ../results/<model_id>/<test_set>/detailed_results.csv.
# Runs whose CSV already exists are not re-downloaded, so re-running this cell
# only fetches tables for runs that are new since the last run.
WANDB_PROJECT = "kwatcharasupat-gatech/banda"
TEST_RUN_GROUP = "test runs - completed"
METRICS_TABLE_NAME = "test/metrics"

results = []

for run in api.runs(WANDB_PROJECT):
    if run.group != TEST_RUN_GROUP:
        continue

    print(run.name, run.state)
    config = run.config

    # The id of the evaluated model is read from the checkpoint path.
    # Assumes a layout like .../<model_id>/<subdir>/<ckpt_file> — TODO confirm.
    ckpt = config['ckpt_path']
    test_model_id = ckpt.split("/")[-3]

    if "test" not in config['data']:
        print("  No test set, skipping...")
        continue

    try:
        test_set = config['data']['test']['datasource'][0]['cls']
        print(test_model_id, test_set)
    except (KeyError, IndexError, TypeError) as e:
        # Incomplete/malformed test config: skip the run. Other errors
        # (network, auth, ...) are deliberately not swallowed here.
        print(f"  Error getting test set: {e}, skipping...")
        continue

    save_path = Path(f"../results/{test_model_id}/{test_set}/detailed_results.csv")
    if save_path.exists():
        results.append(str(save_path))
        print(f"  {save_path} exists, skipping...")
        continue

    # NOTE(review): assumes the first artifact logged by the run is the
    # metrics table — TODO confirm, or select the artifact by type/name.
    table = run.logged_artifacts()[0]
    table_dir = table.download("../_artifacts")
    table_path = Path(table_dir) / f"{METRICS_TABLE_NAME}.table.json"

    with open(table_path, encoding="utf-8") as file:
        json_dict = json.load(file)

    df = pd.DataFrame(json_dict["data"], columns=json_dict["columns"])
    print(df.head())

    df['model'] = test_model_id
    df['test_set'] = test_set

    save_path.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(save_path, index=False)

    results.append(str(save_path))

# Several test runs can map to the same (model, test set) CSV. Keep each path
# once (first-seen order) so the combined table does not double-count tracks.
results = list(dict.fromkeys(results))

bandit-mus64-l1snr-multi-adam-musdb18hq-vdbo-test-drums-1359 finished
9rg6syhm MUSDB18HQDatasource
  ../results/9rg6syhm/MUSDB18HQDatasource/detailed_results.csv exists, skipping...
bandit-mus64-l1snr-multi-adam-musdb18hq-vdbo-test-other-1359 finished
qtyn4z9v MUSDB18HQDatasource
  ../results/qtyn4z9v/MUSDB18HQDatasource/detailed_results.csv exists, skipping...
bandit-mus64-l1snr-multi-adam-moisesdb-vdbo-test-vdbo-1359 finished
xzslzg26 MoisesDBDatasource
  ../results/xzslzg26/MoisesDBDatasource/detailed_results.csv exists, skipping...
bandit-mus64-l1snr-multi-adam-moisesdb-vdbo-test-vocals-1359 finished
4vxw0l1g MoisesDBDatasource
  ../results/4vxw0l1g/MoisesDBDatasource/detailed_results.csv exists, skipping...
bandit-mus64-l1snr-multi-adam-musdb18hq-vdbo-test-bass-1359 finished
lhtpigg8 MUSDB18HQDatasource
  ../results/lhtpigg8/MUSDB18HQDatasource/detailed_results.csv exists, skipping...
bandit-mus64-l1snr-multi-adam-moisesdb-vdbo-test-other-1359 finished
qtyn4z9v MoisesDBDatasource


In [4]:
# Stack the cached per-(model, test set) CSVs into one wide table
# (presumably one row per test track, identified by `full_path`).
frames = [pd.read_csv(csv_path) for csv_path in results]
df = pd.concat(frames, ignore_index=True)

timestamp = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")
# df.to_csv(f"../results/combined_detailed_results_{timestamp}.csv", index=False)

# Long format: one row per (track, metric column). Metric columns are assumed
# to be named "<stem>/.../<metric_name>"; split out the first and last parts.
df = df.melt(id_vars=['model', 'test_set', 'full_path'], var_name='metric', value_name='value')
metric_parts = df['metric'].str.split('/')
df['test_stem'] = metric_parts.str[0]
df['metric_name'] = metric_parts.str[-1]


In [5]:
from pprint import pprint

In [6]:
# Build one metadata row per evaluated model by looking up its W&B run:
# model / TF-model classes, training set and stems, loss variant, and epoch.
unique_models = df['model'].unique()

# Accumulate plain dicts here; `model_df` is only ever the final DataFrame.
model_records = []

for test_model_id in unique_models:
    run = api.run(f"kwatcharasupat-gatech/banda/{test_model_id}")
    config = run.config
    model_params = config['model']['params']

    epoch = run.summary['epoch']

    model_cls = config['model']['cls']
    # `tf_model` is stored in the config as a string and evaluated back into a dict.
    # NOTE(review): `eval` executes arbitrary code taken from the W&B config.
    # If this string is always a plain Python literal, `ast.literal_eval` is a
    # safe drop-in replacement — TODO confirm and switch.
    tfmodel_config_dict = eval(model_params['tf_model'])
    tfmodel_cls = tfmodel_config_dict['cls']

    # The loss variant is inferred from substrings of the run name, not the config.
    loss_use_dbm = "dbm" in run.name
    loss_use_zlimit = "l1snrz" in run.name

    loss = "l1snr"
    if loss_use_zlimit:
        loss += "z"
    if loss_use_dbm:
        loss += "+dbm"

    # Single-stem models keep the stem name and get a "Single" class prefix;
    # multi-stem models are abbreviated by first letters (e.g. "vdbo").
    # The abbreviation is lossy: e.g. "bass" and "bowed_strings" both map to "b".
    stems = model_params['stems']
    if len(stems) == 1:
        stems = stems[0]
        model_cls = "Single" + model_cls
    else:
        stems = "".join(s[0] for s in stems)

    training_set = config['data']['train']['datasource'][0]['cls']

    pretrained_encoder = model_params.get('pretrained_encoder_ckpt_path') is not None

    model_records.append({
        'model': test_model_id,
        'model_cls': model_cls,
        'tfmodel_cls': tfmodel_cls,
        'training_set': training_set,
        'epoch': epoch,
        'training_stems': stems,
        'pretrained_encoder': pretrained_encoder,
        'loss': loss
    })

model_df = pd.DataFrame(model_records)

In [7]:
# Columns that identify a model configuration. Sorting uses the same columns
# minus the opaque run id, so the two lists are derived from one source.
id_cols = [
    'model', 'model_cls', 'tfmodel_cls', 'training_set',
    'training_stems', 'pretrained_encoder', 'loss', 'epoch',
]
sort_cols = [col for col in id_cols if col != 'model']

In [8]:
# Runs excluded from the comparison.
# TODO(review): document why these particular runs are dropped.
drop_ids = ['8akgy5kl', '6932jblk']

# Sort by configuration, keep only the identifying columns, then drop the
# excluded runs.
model_df = (
    model_df
    .sort_values(sort_cols)
    .loc[:, id_cols]
    .pipe(lambda frame: frame[~frame['model'].isin(drop_ids)])
)

model_df

Unnamed: 0,model,model_cls,tfmodel_cls,training_set,training_stems,pretrained_encoder,loss,epoch
14,prl820re,FixedStemBandit,MambaTFModel,MUSDB18HQDatasource,vdbo,False,l1snr,99
17,23wth3uc,FixedStemBandit,MambaTFModel,MoisesDBDatasource,vdbo,False,l1snr,99
12,zune532k,FixedStemBandit,RNNSeqBandModellingModule,MUSDB18HQDatasource,vdbo,False,l1snr,99
18,9qspypvq,FixedStemBandit,RNNSeqBandModellingModule,MUSDB18HQDatasource,vdbo,False,l1snr,249
2,xzslzg26,FixedStemBandit,RNNSeqBandModellingModule,MoisesDBDatasource,vdbo,False,l1snr,99
19,3zkxq8iu,FixedStemBandit,RNNSeqBandModellingModule,MoisesDBDatasource,vdbo,False,l1snr,249
21,uhnr1app,FixedStemBandit,RNNSeqBandModellingModule,MoisesDBStemWiseDatasource,vdbgpwbpooo,False,l1snr,99
22,gpxb68eo,FixedStemBandit,RNNSeqBandModellingModule,MoisesDBStemWiseDatasource,vdbgpwbpooo,False,l1snr+dbm,99
23,h5fcofk0,FixedStemBandit,RNNSeqBandModellingModule,MoisesDBStemWiseDatasource,vdbgpwbpooo,False,l1snrz+dbm,99
29,1neek7gq,FixedStemBandit,RNNSeqBandModellingModule,MoisesDBStemWiseDatasource,vdbgpwbpooo,False,l1snrz+dbm,99


In [9]:
# model_df = model_df[model_df.training_set.str.contains("MoisesDB")]  

# model_df

In [10]:
# Attach model metadata to every metric row. The inner join drops rows whose
# model is not in `model_df` (e.g. the excluded runs).
# Any metadata columns from a previous execution are dropped first, so
# re-running this cell does not produce duplicate `_x`/`_y` columns.
# `validate` guarantees `model_df` has one row per model, so the merge cannot
# silently duplicate metric rows.
meta_cols = model_df.columns.difference(['model'])
df = (
    df.drop(columns=meta_cols, errors='ignore')
    .merge(model_df, on='model', how='inner', validate='many_to_one')
)

In [11]:
# Column check after the metadata merge.
df.keys()

Index(['model', 'test_set', 'full_path', 'metric', 'value', 'test_stem',
       'metric_name', 'model_cls', 'tfmodel_cls', 'training_set',
       'training_stems', 'pretrained_encoder', 'loss', 'epoch'],
      dtype='object')

In [12]:
# Per model, test set and stem: median SNR across test tracks (rows that differ
# only in `full_path`).
# Filtering to the target metric before aggregating avoids computing, and then
# discarding, medians for every other metric. The resulting SNR values are
# unchanged; only the row index is now a plain 0..n-1 range.
target_metric = "SignalNoiseRatio"

dfg = (
    df[df['metric_name'] == target_metric]
    .groupby(id_cols + ['test_set', 'test_stem', 'metric_name'])['value']
    .median()
    .reset_index()
    .dropna(subset=['value'])
)

dfg

Unnamed: 0,model,model_cls,tfmodel_cls,training_set,training_stems,pretrained_encoder,loss,epoch,test_set,test_stem,metric_name,value
2,0l0r6gav,FixedStemBandit,RoFormerTFModel,MUSDB18HQDatasource,vdbo,False,l1snr,99,MUSDB18HQDatasource,bass,SignalNoiseRatio,5.516438
10,0l0r6gav,FixedStemBandit,RoFormerTFModel,MUSDB18HQDatasource,vdbo,False,l1snr,99,MUSDB18HQDatasource,drums,SignalNoiseRatio,7.731935
18,0l0r6gav,FixedStemBandit,RoFormerTFModel,MUSDB18HQDatasource,vdbo,False,l1snr,99,MUSDB18HQDatasource,other,SignalNoiseRatio,4.970463
38,0l0r6gav,FixedStemBandit,RoFormerTFModel,MUSDB18HQDatasource,vdbo,False,l1snr,99,MUSDB18HQDatasource,vocals,SignalNoiseRatio,7.193657
46,0l0r6gav,FixedStemBandit,RoFormerTFModel,MUSDB18HQDatasource,vdbo,False,l1snr,99,MoisesDBDatasource,bass,SignalNoiseRatio,8.854338
...,...,...,...,...,...,...,...,...,...,...,...,...
2018,zune532k,FixedStemBandit,RNNSeqBandModellingModule,MUSDB18HQDatasource,vdbo,False,l1snr,99,MUSDB18HQDatasource,vocals,SignalNoiseRatio,7.866929
2026,zune532k,FixedStemBandit,RNNSeqBandModellingModule,MUSDB18HQDatasource,vdbo,False,l1snr,99,MoisesDBDatasource,bass,SignalNoiseRatio,8.062817
2034,zune532k,FixedStemBandit,RNNSeqBandModellingModule,MUSDB18HQDatasource,vdbo,False,l1snr,99,MoisesDBDatasource,drums,SignalNoiseRatio,8.073868
2042,zune532k,FixedStemBandit,RNNSeqBandModellingModule,MUSDB18HQDatasource,vdbo,False,l1snr,99,MoisesDBDatasource,other,SignalNoiseRatio,5.141108


In [13]:
# Which stems appear in the aggregated results.
dfg['test_stem'].unique()

array(['bass', 'drums', 'other', 'vocals', 'bowed_strings', 'guitar',
       'other_keys', 'other_plucked', 'percussion', 'piano', 'wind'],
      dtype=object)

In [14]:
# Display order for stems. As an ordered categorical, this order also drives
# the column order when the pivot table below sorts its columns.
stem_order = [
    'vocals', 'drums', 'bass', 'guitar', 'piano', 'wind', 'bowed_strings', 'percussion', 'other_keys', 'other_plucked', 'other'
]

# Stems missing from `stem_order` become NaN under the categorical dtype and are
# dropped below. Report them instead of losing them silently.
unknown_stems = sorted(set(dfg['test_stem'].dropna()) - set(stem_order))
if unknown_stems:
    print(f"Dropping stems not in stem_order: {unknown_stems}")

dfg = dfg.assign(
    test_stem=pd.Categorical(dfg['test_stem'], categories=stem_order, ordered=True)
).dropna(subset=['test_stem'])

In [15]:
# Columns of the aggregated SNR table.
dfg.keys()

Index(['model', 'model_cls', 'tfmodel_cls', 'training_set', 'training_stems',
       'pretrained_encoder', 'loss', 'epoch', 'test_set', 'test_stem',
       'metric_name', 'value'],
      dtype='object')

In [16]:
# Rows evaluated on the MUSDB18-HQ test set (which only has the four
# vocals/drums/bass/other stems — hence the name).
dfg_vdbo = dfg.loc[dfg['test_set'].eq("MUSDB18HQDatasource")]

In [17]:
# Wide summary: one row per model configuration, one column per
# (test set, stem) pair holding the aggregated SNR value.
pivot_index = ['model_cls', 'tfmodel_cls', 'training_stems', 'pretrained_encoder', 'epoch', 'loss', 'training_set', 'model']
pivot_columns = ['test_set', 'test_stem']

dfgx = (
    dfg.pivot(index=pivot_index, columns=pivot_columns, values='value')
    .sort_index(axis=1, level=[0, 1], ascending=[True, True])
    .reset_index()
)

# Optional LaTeX export:
# print(dfgx.round(2).to_latex(index=False, float_format="%.2f"))

In [18]:
# Persist the wide summary, stamped with the time the detailed results were combined.
summary_csv = f"../results/summary_results_{timestamp}.csv"
dfgx.to_csv(summary_csv, index=False)

In [19]:
# Compact view of the summary table (one decimal place).
dfgx.round(decimals=1)


test_set,model_cls,tfmodel_cls,training_stems,pretrained_encoder,epoch,loss,training_set,model,MUSDB18HQDatasource,MUSDB18HQDatasource,...,MoisesDBDatasource,MoisesDBDatasource,MoisesDBDatasource,MoisesDBDatasource,MoisesDBDatasource,MoisesDBDatasource,MoisesDBDatasource,MoisesDBDatasource,MoisesDBDatasource,MoisesDBDatasource
test_stem,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,vocals,drums,...,drums,bass,guitar,piano,wind,bowed_strings,percussion,other_keys,other_plucked,other
0,FixedStemBandit,MambaTFModel,vdbo,False,99,l1snr,MUSDB18HQDatasource,prl820re,5.8,4.8,...,5.9,6.4,,,,,,,,4.1
1,FixedStemBandit,MambaTFModel,vdbo,False,99,l1snr,MoisesDBDatasource,23wth3uc,5.3,4.3,...,7.6,7.7,,,,,,,,5.3
2,FixedStemBandit,RNNSeqBandModellingModule,vdbgpwbpooo,False,99,l1snr,MoisesDBStemWiseDatasource,uhnr1app,,,...,8.7,8.8,2.6,2.0,-61.3,-74.9,0.0,0.1,-73.2,0.0
3,FixedStemBandit,RNNSeqBandModellingModule,vdbgpwbpooo,False,99,l1snr+dbm,MoisesDBStemWiseDatasource,gpxb68eo,,,...,8.7,8.5,2.7,1.8,-78.4,-75.0,0.0,0.4,-77.1,0.0
4,FixedStemBandit,RNNSeqBandModellingModule,vdbgpwbpooo,False,99,l1snrz+dbm,MoisesDBStemWiseDatasource,1neek7gq,,,...,8.8,8.6,2.8,1.5,-63.0,-76.8,0.0,0.2,-74.8,0.0
5,FixedStemBandit,RNNSeqBandModellingModule,vdbgpwbpooo,False,99,l1snrz+dbm,MoisesDBStemWiseDatasource,h5fcofk0,,,...,8.7,8.6,2.7,1.7,-63.3,-76.2,0.1,0.2,-72.3,0.0
6,FixedStemBandit,RNNSeqBandModellingModule,vdbgpwbpooo,True,99,l1snr,MoisesDBStemWiseDatasource,69lnvw21,,,...,9.7,10.2,3.7,2.3,-55.5,-69.7,0.4,0.2,-62.3,0.0
7,FixedStemBandit,RNNSeqBandModellingModule,vdbgpwbpooo,True,99,l1snr+dbm,MoisesDBStemWiseDatasource,pui49sm7,,,...,9.6,9.8,3.7,2.4,-64.5,-74.5,0.2,0.6,-66.1,0.0
8,FixedStemBandit,RNNSeqBandModellingModule,vdbgpwbpooo,True,99,l1snrz+dbm,MoisesDBStemWiseDatasource,gtq2g6z9,,,...,9.7,10.3,3.6,2.2,-60.5,-72.8,0.1,0.5,-67.6,0.0
9,FixedStemBandit,RNNSeqBandModellingModule,vdbgpwbpooo,True,99,l1snrz+dbm,MoisesDBStemWiseDatasource,q601q181,,,...,9.8,10.0,3.7,2.3,-61.8,-76.3,0.3,0.5,-70.4,0.0
