In [15]:
import ast

import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns  

import wandb

In [16]:
# W&B public API client; later cells use it to list runs and to read run configs/summaries.
api = wandb.Api()

In [17]:
import json
from pathlib import Path

# Build `results`: one CSV path per (model, test set) pair found among the
# completed test runs. A CSV that is already on disk is reused; otherwise the
# run's logged metrics table is downloaded and written out as CSV.
results = []

for run in api.runs("kwatcharasupat-gatech/banda"):
    if run.group != "test runs - completed":
        continue

    print(run.name, run.state)
    config = run.config

    # The third-from-last component of the checkpoint path is used as the
    # identifier of the model under test.
    ckpt = config['ckpt_path']
    test_model_id = ckpt.split("/")[-3]
    test_set = config['data']['test']['datasource'][0]['cls']

    print(test_model_id, test_set)

    save_path = Path(f"../results/{test_model_id}/{test_set}/detailed_results.csv")
    result_key = str(save_path)

    if result_key in results:
        # A previous run in this loop already produced this exact file.
        # Listing it twice would make the downstream concat read the same
        # CSV twice and duplicate every row.
        print(f"  {save_path} already collected, skipping...")
        continue

    if save_path.exists():
        results.append(result_key)
        print(f"  {save_path} exists, skipping...")
        continue

    # NOTE(review): assumes the first logged artifact of the run is the
    # "test/metrics" table -- confirm if runs ever log more than one artifact.
    table = run.logged_artifacts()[0]
    table_dir = table.download("../_artifacts")
    table_name = "test/metrics"
    table_path = f"{table_dir}/{table_name}.table.json"

    with open(table_path, encoding="utf-8") as file:
        json_dict = json.load(file)

    df = pd.DataFrame(json_dict["data"], columns=json_dict["columns"])
    print(df.head())

    # Tag every row so the per-run CSVs can be concatenated later.
    df['model'] = test_model_id
    df['test_set'] = test_set

    save_path.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(save_path, index=False)

    results.append(result_key)

bandit-mus64-l1snr-multi-adam-musdb18hq-vdbo-test-drums-1359 finished
9rg6syhm MUSDB18HQDatasource
  ../results/9rg6syhm/MUSDB18HQDatasource/detailed_results.csv exists, skipping...
bandit-mus64-l1snr-multi-adam-musdb18hq-vdbo-test-other-1359 finished
qtyn4z9v MUSDB18HQDatasource
  ../results/qtyn4z9v/MUSDB18HQDatasource/detailed_results.csv exists, skipping...
bandit-mus64-l1snr-multi-adam-moisesdb-vdbo-test-vdbo-1359 finished
xzslzg26 MoisesDBDatasource
  ../results/xzslzg26/MoisesDBDatasource/detailed_results.csv exists, skipping...
bandit-mus64-l1snr-multi-adam-moisesdb-vdbo-test-vocals-1359 finished
4vxw0l1g MoisesDBDatasource
  ../results/4vxw0l1g/MoisesDBDatasource/detailed_results.csv exists, skipping...
bandit-mus64-l1snr-multi-adam-musdb18hq-vdbo-test-bass-1359 finished
lhtpigg8 MUSDB18HQDatasource
  ../results/lhtpigg8/MUSDB18HQDatasource/detailed_results.csv exists, skipping...
bandit-mus64-l1snr-multi-adam-musdb18hq-vdbo-test-vocals-1359 finished
4vxw0l1g MUSDB18HQDatasour

[34m[1mwandb[0m:   1 of 1 files downloaded.  


                                           full_path  vocals/PredDecibel  \
0  /home/hice1/kwatchar3/Documents/data/musdb18hq...          -25.449699   
1  /home/hice1/kwatchar3/Documents/data/musdb18hq...          -24.294737   
2  /home/hice1/kwatchar3/Documents/data/musdb18hq...          -23.462811   
3  /home/hice1/kwatchar3/Documents/data/musdb18hq...          -25.938585   
4  /home/hice1/kwatchar3/Documents/data/musdb18hq...          -29.722309   

   vocals/SNR  vocals/SignalNoiseRatio  vocals/TargetDecibel  \
0    7.896120                 7.896120            -24.617992   
1    5.011571                 5.011571            -22.769804   
2    7.731605                 7.731605            -22.894981   
3    6.334326                 6.334326            -25.338280   
4    1.516020                 1.516020            -27.117050   

   drums/PredDecibel  drums/SNR  drums/SignalNoiseRatio  drums/TargetDecibel  \
0         -24.503357   4.066603                4.066603           -22.046003  

[34m[1mwandb[0m:   1 of 1 files downloaded.  


                                           full_path  vocals/PredDecibel  \
0  /home/hice1/kwatchar3/Documents/data/moisesdb/...          -33.811985   
1  /home/hice1/kwatchar3/Documents/data/moisesdb/...          -32.660110   
2  /home/hice1/kwatchar3/Documents/data/moisesdb/...          -27.798347   
3  /home/hice1/kwatchar3/Documents/data/moisesdb/...          -36.665211   
4  /home/hice1/kwatchar3/Documents/data/moisesdb/...          -25.120926   

   vocals/SNR  vocals/SignalNoiseRatio  vocals/TargetDecibel  \
0    4.060677                 4.060677            -33.488358   
1    6.353225                 6.353225            -32.111732   
2   11.459699                11.459699            -27.513449   
3    6.001208                 6.001208            -35.463284   
4    8.051785                 8.051785            -24.466866   

   drums/PredDecibel  drums/SNR  drums/SignalNoiseRatio  drums/TargetDecibel  \
0         -30.635454   5.421061                5.421061           -30.736063  

In [18]:
# Stack every per-run CSV into a single frame with a fresh row index.
per_run_frames = [pd.read_csv(csv_path) for csv_path in results]
df = pd.concat(per_run_frames, ignore_index=True)

timestamp = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")
# df.to_csv(f"../results/combined_detailed_results_{timestamp}.csv", index=False)

# Wide -> long: every column other than the identifiers becomes one
# (metric, value) row.
df = pd.melt(
    df,
    id_vars=['model', 'test_set', 'full_path'],
    var_name='metric',
    value_name='value',
)

# Split each metric label on "/": the first piece is stored as the stem, the
# last piece as the metric name.
metric_parts = df['metric'].str.split('/')
df['test_stem'] = metric_parts.str[0]
df['metric_name'] = metric_parts.str[-1]



In [19]:
from pprint import pprint

In [20]:
# Build `model_df`: one metadata row (model class, TF-model class, training
# set) per model id that appears in the collected metrics.
unique_models = df['model'].unique()

model_records = []

for test_model_id in unique_models:
    run = api.run(f"kwatcharasupat-gatech/banda/{test_model_id}")
    config = run.config

    # Only accept runs whose summary reports epoch 99.
    # NOTE(review): presumably the last epoch index of a full training run -- confirm.
    final_epoch = run.summary['epoch']
    assert final_epoch == 99, (
        f"run {test_model_id}: summary epoch is {final_epoch}, expected 99"
    )

    model_cls = config['model']['cls']

    # `tf_model` is stored as a string that evaluates to a dict. The value is
    # fetched from a remote service, so it is parsed with ast.literal_eval
    # (Python literals only) rather than eval, which would execute arbitrary
    # code embedded in the string.
    # NOTE(review): literal_eval raises ValueError if the stored string ever
    # contains anything other than plain literals.
    tfmodel_config = config['model']['params']['tf_model']
    if isinstance(tfmodel_config, str):
        tfmodel_config_dict = ast.literal_eval(tfmodel_config)
    else:
        # Already a mapping; nothing to parse.
        tfmodel_config_dict = tfmodel_config
    tfmodel_cls = tfmodel_config_dict['cls']

    # Models configured with exactly one stem get a "Single" prefix so they
    # form their own group in the result tables.
    stems = config['model']['params']['stems']
    if len(stems) == 1:
        model_cls = "Single" + model_cls

    training_set = config['data']['train']['datasource'][0]['cls']

    model_records.append({
        'model': test_model_id,
        'model_cls': model_cls,
        'tfmodel_cls': tfmodel_cls,
        'training_set': training_set,
    })

# Explicit columns keep the frame well-formed even when no model was found.
model_df = pd.DataFrame(
    model_records,
    columns=['model', 'model_cls', 'tfmodel_cls', 'training_set'],
)

In [21]:
# Model ids excluded from the comparison (the reason is not recorded here).
drop_ids = ['8akgy5kl', '6932jblk']

sort_keys = ['model_cls', 'tfmodel_cls', 'training_set']

# Order the rows by architecture and training set, keep the four metadata
# columns, then remove the excluded ids.
model_df = (
    model_df
    .sort_values(sort_keys)
    .loc[:, ['model', *sort_keys]]
    .loc[lambda frame: ~frame['model'].isin(drop_ids)]
)

model_df

Unnamed: 0,model,model_cls,tfmodel_cls,training_set
14,prl820re,FixedStemBandit,MambaTFModel,MUSDB18HQDatasource
17,23wth3uc,FixedStemBandit,MambaTFModel,MoisesDBDatasource
12,zune532k,FixedStemBandit,RNNSeqBandModellingModule,MUSDB18HQDatasource
2,xzslzg26,FixedStemBandit,RNNSeqBandModellingModule,MoisesDBDatasource
11,0l0r6gav,FixedStemBandit,RoFormerTFModel,MUSDB18HQDatasource
9,3i6i6v60,FixedStemBandit,RoFormerTFModel,MoisesDBDatasource
0,9rg6syhm,SingleFixedStemBandit,RNNSeqBandModellingModule,MUSDB18HQDatasource
4,lhtpigg8,SingleFixedStemBandit,RNNSeqBandModellingModule,MUSDB18HQDatasource
5,2tx9mc4f,SingleFixedStemBandit,RNNSeqBandModellingModule,MUSDB18HQDatasource
13,gbsvyktp,SingleFixedStemBandit,RNNSeqBandModellingModule,MUSDB18HQDatasource


In [22]:
# Attach the per-model metadata to every metric row. The inner join also
# discards rows whose model id is absent from `model_df`.
#
# Metadata columns left over from an earlier execution of this cell are
# dropped first, so re-running the cell does not create `_x`/`_y` suffixed
# duplicates. `validate` makes the merge fail loudly if `model_df` ever
# contains a model id twice, instead of silently multiplying rows.
metadata_columns = model_df.columns.difference(['model'])
df = (
    df
    .drop(columns=metadata_columns, errors='ignore')
    .merge(model_df, on='model', how='inner', validate='many_to_one')
)

In [23]:
# Display the column labels of `df` as a quick check of its current layout.
df.columns

Index(['model', 'test_set', 'full_path', 'metric', 'value', 'test_stem',
       'metric_name', 'model_cls', 'tfmodel_cls', 'training_set'],
      dtype='object')

In [24]:
# Median of the numeric columns for every (model, metadata, test set, stem,
# metric) combination.
group_keys = [
    'model', 'model_cls', 'tfmodel_cls', 'training_set',
    'test_set', 'test_stem', 'metric_name',
]
median_by_group = df.groupby(group_keys).median(numeric_only=True).reset_index()

# Keep only the rows for the signal-to-noise-ratio metric.
is_snr = median_by_group["metric_name"] == "SignalNoiseRatio"
dfg = median_by_group[is_snr]

dfg

Unnamed: 0,model,model_cls,tfmodel_cls,training_set,test_set,test_stem,metric_name,value
2,0l0r6gav,FixedStemBandit,RoFormerTFModel,MUSDB18HQDatasource,MUSDB18HQDatasource,bass,SignalNoiseRatio,5.516438
6,0l0r6gav,FixedStemBandit,RoFormerTFModel,MUSDB18HQDatasource,MUSDB18HQDatasource,drums,SignalNoiseRatio,7.731935
10,0l0r6gav,FixedStemBandit,RoFormerTFModel,MUSDB18HQDatasource,MUSDB18HQDatasource,other,SignalNoiseRatio,4.970463
14,0l0r6gav,FixedStemBandit,RoFormerTFModel,MUSDB18HQDatasource,MUSDB18HQDatasource,vocals,SignalNoiseRatio,7.193657
18,0l0r6gav,FixedStemBandit,RoFormerTFModel,MUSDB18HQDatasource,MoisesDBDatasource,bass,SignalNoiseRatio,8.854338
...,...,...,...,...,...,...,...,...
494,zune532k,FixedStemBandit,RNNSeqBandModellingModule,MUSDB18HQDatasource,MUSDB18HQDatasource,vocals,SignalNoiseRatio,7.866929
498,zune532k,FixedStemBandit,RNNSeqBandModellingModule,MUSDB18HQDatasource,MoisesDBDatasource,bass,SignalNoiseRatio,8.062817
502,zune532k,FixedStemBandit,RNNSeqBandModellingModule,MUSDB18HQDatasource,MoisesDBDatasource,drums,SignalNoiseRatio,8.073868
506,zune532k,FixedStemBandit,RNNSeqBandModellingModule,MUSDB18HQDatasource,MoisesDBDatasource,other,SignalNoiseRatio,5.141108


In [25]:
# Display order of the stems in the final table.
stem_order = [
    'vocals', 'drums', 'bass', 'other',
]

# `.assign` builds a new frame instead of writing a column into `dfg` in
# place, which avoids a SettingWithCopyWarning when `dfg` is a filtered slice
# of another frame.
# Stems not listed in `stem_order` become NaN under the categorical
# conversion; `dropna()` then removes those rows along with any row that has
# another missing value.
dfg = dfg.assign(
    test_stem=pd.Categorical(dfg['test_stem'], categories=stem_order, ordered=True)
).dropna()

In [26]:
# One row per (model class, TF-model class); one column per
# (training set, test set, stem) combination.
wide_scores = dfg.pivot(
    index=['model_cls', 'tfmodel_cls'],
    columns=['training_set', 'test_set', 'test_stem'],
    values='value',
)

# Sort the column hierarchy level by level, all ascending.
wide_scores = wide_scores.sort_index(
    axis=1,
    level=[0, 1, 2],
    ascending=[True, True, True],
)

dfgx = wide_scores.reset_index()

# Emit the table as LaTeX with two decimals.
latex_table = dfgx.round(2).to_latex(index=False, float_format="%.2f")
print(latex_table)

\begin{tabular}{llrrrrrrrrrrrrrrrr}
\toprule
model_cls & tfmodel_cls & \multicolumn{8}{r}{MUSDB18HQDatasource} & \multicolumn{8}{r}{MoisesDBDatasource} \\
 &  & \multicolumn{4}{r}{MUSDB18HQDatasource} & \multicolumn{4}{r}{MoisesDBDatasource} & \multicolumn{4}{r}{MUSDB18HQDatasource} & \multicolumn{4}{r}{MoisesDBDatasource} \\
 &  & vocals & drums & bass & other & vocals & drums & bass & other & vocals & drums & bass & other & vocals & drums & bass & other \\
\midrule
FixedStemBandit & MambaTFModel & 5.81 & 4.83 & 3.02 & 3.86 & 6.20 & 5.94 & 6.38 & 4.11 & 5.32 & 4.31 & 2.91 & 3.74 & 6.97 & 7.59 & 7.72 & 5.29 \\
FixedStemBandit & RNNSeqBandModellingModule & 7.87 & 6.78 & 4.57 & 4.66 & 7.68 & 8.07 & 8.06 & 5.14 & 7.71 & 6.42 & 4.88 & 4.57 & 8.49 & 9.21 & 9.17 & 6.10 \\
FixedStemBandit & RoFormerTFModel & 7.19 & 7.73 & 5.52 & 4.97 & 7.54 & 8.68 & 8.85 & 5.29 & 7.48 & 7.10 & 5.20 & 4.76 & 8.40 & 9.33 & 9.60 & 6.29 \\
SingleFixedStemBandit & RNNSeqBandModellingModule & 8.03 & 6.68 & 5.24 & 4