In [1]:
import warnings
from pathlib import Path

import pandas as pd
warnings.filterwarnings('ignore')

from deem import deem # proposed module, deem is in short for 'de'-biasing 'em'beddings

In [2]:
# Classifier settings passed to the deem constructor: candidate grids for 'C' and
# 'gamma' (powers of ten from 1e-6 to 1e2), the 'metric' name, and the 'cv' value.
# NOTE(review): presumably consumed by a cross-validated kernel-classifier search
# inside deem -- confirm against the deem module.
classf_param = dict(
    C=[10 ** exponent for exponent in (-6, -4, -2, 0, 2)],
    metric='roc_auc',
    cv=3,
    gamma=[10 ** exponent for exponent in (-6, -4, -2, 0, 2)],
)

In [3]:
def debias(irmas_train, irmas_test, openmic_train, openmic_test, openmic_mask):
    """Evaluate the current ``deem`` configuration on all four train/test pairings.

    Uses the notebook-level ``deem`` object. For every pairing
    (irmas->irmas, openmic->openmic, irmas->openmic, openmic->irmas) the two
    splits are passed through ``deem.projection`` and the projected features,
    together with element ``[1]`` of each split, are handed to the matching
    ``deem`` evaluation method. Evaluations that test on OpenMIC also receive
    ``openmic_mask``.

    Parameters
    ----------
    irmas_train, irmas_test, openmic_train, openmic_test
        Dataset splits as returned by ``deem.data_loader``; element ``[1]`` is
        forwarded next to the projected features (presumably the labels --
        TODO confirm against the deem module).
    openmic_mask
        Passed unchanged to the evaluators whose test set is OpenMIC.

    Returns
    -------
    pandas.DataFrame
        The four per-pairing result frames concatenated with a fresh index,
        plus an ``'embedding'`` column set to
        ``deem.embedding + deem.debias_method``.

    Side effects: prints ``deem.debias_method``.
    """
    # (train split, test split, deem evaluator name, evaluator takes the OpenMIC mask?)
    pairings = (
        (irmas_train, irmas_test, 'irmas_irmas', False),
        (openmic_train, openmic_test, 'openmic_openmic', True),
        (irmas_train, openmic_test, 'irmas_openmic', True),
        (openmic_train, irmas_test, 'openmic_irmas', False),
    )

    frames = []
    for train_split, test_split, evaluator_name, takes_mask in pairings:
        feats_train, feats_test = deem.projection(train_split, test_split)
        train_pair = (feats_train, train_split[1])
        test_pair = (feats_test, test_split[1])

        evaluate = getattr(deem, evaluator_name)
        mask_args = (openmic_mask,) if takes_mask else ()
        frames.append(evaluate(train_pair, test_pair, *mask_args))

    combined = pd.concat(frames, ignore_index=True)

    # Tag every row with the embedding name followed by the debias-method suffix.
    tag = deem.embedding + deem.debias_method
    combined['embedding'] = [tag] * len(combined)

    print(deem.debias_method)
    return combined

# VGGish

In [4]:
embedding = 'vggish'
# NOTE(review): this rebinds the imported name `deem` to the object returned by
# the call, so re-running this cell without first re-running the import cell
# would call that object instead of the imported `deem` -- confirm this is intended.
deem = deem(embedding, classf_param)

# Load the train/test splits of both datasets; the openmic loader also returns a mask.
irmas_train, irmas_test = deem.data_loader(dataset='irmas', data_root='')
openmic_train, openmic_test, openmic_mask = deem.data_loader(dataset='openmic', data_root='openmic-2018/')

Loading irmas data:


  0%|          | 0/6705 [00:00<?, ?it/s]

Loading openmic data:


  0%|          | 0/20000 [00:00<?, ?it/s]

In [5]:
# Baseline run: start `result_all` with the results for the un-debiased embedding.
deem.debias_method = ''  # '' = original, no debiasing
result_all = debias(irmas_train, irmas_test, openmic_train, openmic_test, openmic_mask)  # original




In [6]:
# Suffixes selecting the debiasing variant on the deem object; '' means no debiasing.
debias_methods = ['', '-lda', '-lda-genre', '-k', '-klda', '-klda-genre']

# The '' baseline for this embedding is already in `result_all`, so start at index 1.
# DataFrame.append was removed in pandas 2.0; gather the per-method frames and
# concatenate once. The index is left as-is, matching what append() did by default.
method_results = []
for method in debias_methods[1:]:
    deem.debias_method = method
    method_results.append(debias(irmas_train, irmas_test, openmic_train, openmic_test, openmic_mask))

result_all = pd.concat([result_all, *method_results])

-lda
-lda-genre
-k
-klda
-klda-genre


# OpenL3

In [7]:
embedding = 'openl3'
# Point the existing deem object at the new embedding, then reload both datasets.
# NOTE(review): presumably data_loader reads deem.embedding to pick the features -- confirm.
deem.embedding = embedding

irmas_train, irmas_test = deem.data_loader(dataset='irmas', data_root='')
openmic_train, openmic_test, openmic_mask = deem.data_loader(dataset='openmic', data_root='openmic-2018/')

Loading irmas data:


  0%|          | 0/6705 [00:00<?, ?it/s]

Loading openmic data:


  0%|          | 0/20000 [00:00<?, ?it/s]

In [8]:
# Run every variant, including the '' baseline, for the currently loaded embedding.
# DataFrame.append was removed in pandas 2.0; gather the per-method frames and
# concatenate once. The index is left as-is, matching what append() did by default.
method_results = []
for method in debias_methods:
    deem.debias_method = method
    method_results.append(debias(irmas_train, irmas_test, openmic_train, openmic_test, openmic_mask))

result_all = pd.concat([result_all, *method_results])


-lda
-lda-genre
-k
-klda
-klda-genre


# YAMNet

In [9]:
embedding = 'yamnet'
# Point the existing deem object at the new embedding, then reload both datasets.
# NOTE(review): presumably data_loader reads deem.embedding to pick the features -- confirm.
deem.embedding = embedding

irmas_train, irmas_test = deem.data_loader(dataset='irmas', data_root='')
openmic_train, openmic_test, openmic_mask = deem.data_loader(dataset='openmic', data_root='openmic-2018/')

Loading irmas data:


  0%|          | 0/6705 [00:00<?, ?it/s]

Loading openmic data:


  0%|          | 0/20000 [00:00<?, ?it/s]

In [10]:
# Run every variant, including the '' baseline, for the currently loaded embedding.
# DataFrame.append was removed in pandas 2.0; gather the per-method frames and
# concatenate once. The index is left as-is, matching what append() did by default.
method_results = []
for method in debias_methods:
    deem.debias_method = method
    method_results.append(debias(irmas_train, irmas_test, openmic_train, openmic_test, openmic_mask))

result_all = pd.concat([result_all, *method_results])


-lda
-lda-genre
-k
-klda
-klda-genre


In [11]:
# Sanity check: total number of result rows accumulated across all runs.
len(result_all)

720

In [12]:
# Show the accumulated results table as the cell output.
result_all

Unnamed: 0,instrument,train_set,test_set,precision,recall,f1-score,support,accuracy,roc_auc,ap,embedding
0,cello,irmas,irmas,0.29454545454545455,0.8617021276595744,0.43902439024390244,94.0,0.8757503001200481,0.9384305126955769,0.6305397835914646,vggish
1,clarinet,irmas,irmas,0.27906976744186046,0.8503937007874016,0.4202334630350194,127.0,0.8211284513805522,0.9022015522913438,0.5982413610664756,vggish
2,flute,irmas,irmas,0.32413793103448274,0.7768595041322314,0.45742092457420924,121.0,0.8661464585834334,0.9055818556259861,0.6286379602499996,vggish
3,guitar,irmas,irmas,0.6519721577726219,0.8121387283236994,0.7232947232947232,346.0,0.8709483793517407,0.9351550183920125,0.8289173972627856,vggish
4,organ,irmas,irmas,0.45738636363636365,0.9252873563218391,0.612167300380228,174.0,0.8775510204081632,0.9679054574589382,0.7637821887905603,vggish
...,...,...,...,...,...,...,...,...,...,...,...
35,piano,openmic,irmas,0.22592592592592592,0.7176470588235294,0.3436619718309859,170.0,0.7202881152460985,0.8256016042780749,0.5704009098702709,yamnet-klda-genre
36,saxophone,openmic,irmas,0.1900369003690037,0.6866666666666666,0.2976878612716763,150.0,0.7082833133253301,0.7909014951627088,0.3382633894288481,yamnet-klda-genre
37,trumpet,openmic,irmas,0.2454212454212454,0.8933333333333333,0.3850574712643678,150.0,0.7430972388955582,0.8996437994722956,0.4910551759999126,yamnet-klda-genre
38,violin,openmic,irmas,0.1837748344370861,0.7762237762237763,0.2971887550200803,143.0,0.6848739495798319,0.8161615141260579,0.3465298066694271,yamnet-klda-genre


In [13]:
# to_csv does not create missing directories, so make sure the target folder
# exists before writing; otherwise the whole run is lost at the last step.
results_path = Path('results/result_all.csv')
results_path.parent.mkdir(parents=True, exist_ok=True)
result_all.to_csv(results_path, index=False)