In [1]:
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

from deem import deem # proposed module, deem is in short for 'de'-biasing 'em'beddings

In [2]:
# Classifier configuration handed to the deem constructor (see the VGGish cell).
# 'C' is a log-spaced grid 1e-14 ... 1e3; 'metric' and 'cv' presumably select the
# model-selection score and number of cross-validation folds — confirm in deem.
classf_param = dict(
    C=[10 ** exponent for exponent in range(-14, 4)],
    metric='roc_auc',
    cv=3,
)

In [3]:
def debias(irmas_train, irmas_test, openmic_train, openmic_test, openmic_mask):
    
    X_train, X_test = deem.projection(irmas_train, irmas_test)
    train, test = (X_train, irmas_train[1]), (X_test, irmas_test[1])
    irmas_irmas = deem.irmas_irmas(train, test)

    X_train, X_test = deem.projection(openmic_train, openmic_test)
    train, test = (X_train, openmic_train[1]), (X_test, openmic_test[1])
    openmic_openmic = deem.openmic_openmic(train, test, openmic_mask)
    
    X_train, X_test = deem.projection(irmas_train, openmic_test)
    train, test = (X_train, irmas_train[1]), (X_test, openmic_test[1])
    irmas_openmic = deem.irmas_openmic(train, test, openmic_mask)
    
    X_train, X_test = deem.projection(openmic_train, irmas_test)
    train, test = (X_train, openmic_train[1]), (X_test, irmas_test[1])
    openmic_irmas = deem.openmic_irmas(train, test)
    
    result = pd.concat([irmas_irmas, openmic_openmic, irmas_openmic, openmic_irmas], ignore_index=True)
    result['embedding'] = [deem.embedding + deem.debias_method] * len(result)

    print(deem.debias_method)
    return result

# VGGish

In [4]:
# Bind the class to its own name. The instance created below is stored as `deem`
# (the name `debias` looks up), which used to overwrite the class itself, so
# re-running this cell failed with "'deem' object is not callable".
from deem import deem as Deem

embedding = 'vggish'
deem = Deem(embedding, classf_param)

irmas_train, irmas_test = deem.data_loader(dataset='irmas', data_root='')
openmic_train, openmic_test, openmic_mask = deem.data_loader(dataset='openmic', data_root='openmic-2018/')

Loading irmas data:


  0%|          | 0/6705 [00:00<?, ?it/s]

Loading openmic data:


  0%|          | 0/20000 [00:00<?, ?it/s]

In [5]:
# Baseline: an empty suffix means the original embedding, with no debiasing applied.
deem.debias_method = ''
result_all = debias(
    irmas_train, irmas_test,
    openmic_train, openmic_test, openmic_mask,
)




In [6]:
debias_methods = ['', '-lda', '-lda-genre', '-k', '-klda', '-klda-genre']  # '' = original embedding

# The '' baseline for VGGish was computed in the previous cell, so skip it here.
# DataFrame.append was removed in pandas 2.0 (and re-copies the frame on every
# call), so collect the pieces and concatenate once; ignore_index keeps labels unique.
vggish_frames = [result_all]
for method in debias_methods[1:]:
    deem.debias_method = method
    vggish_frames.append(debias(irmas_train, irmas_test, openmic_train, openmic_test, openmic_mask))
result_all = pd.concat(vggish_frames, ignore_index=True)

-lda
-lda-genre
-k
-klda
-klda-genre


# OpenL3

In [7]:
# Reuse the same deem instance: switch it to the OpenL3 embedding, then reload both datasets.
embedding = deem.embedding = 'openl3'

irmas_train, irmas_test = deem.data_loader(
    dataset='irmas',
    data_root='',
)
openmic_train, openmic_test, openmic_mask = deem.data_loader(
    dataset='openmic',
    data_root='openmic-2018/',
)

Loading irmas data:


  0%|          | 0/6705 [00:00<?, ?it/s]

Loading openmic data:


  0%|          | 0/20000 [00:00<?, ?it/s]

In [8]:
# A freshly loaded embedding has no baseline yet, so run every method including '' (original).
# DataFrame.append was removed in pandas 2.0; collect the frames and concatenate once.
openl3_frames = [result_all]
for method in debias_methods:
    deem.debias_method = method
    openl3_frames.append(debias(irmas_train, irmas_test, openmic_train, openmic_test, openmic_mask))
result_all = pd.concat(openl3_frames, ignore_index=True)


-lda
-lda-genre
-k
-klda
-klda-genre


# YAMNet

In [9]:
# Reuse the same deem instance: switch it to the YAMNet embedding, then reload both datasets.
embedding = deem.embedding = 'yamnet'

irmas_train, irmas_test = deem.data_loader(
    dataset='irmas',
    data_root='',
)
openmic_train, openmic_test, openmic_mask = deem.data_loader(
    dataset='openmic',
    data_root='openmic-2018/',
)

Loading irmas data:


  0%|          | 0/6705 [00:00<?, ?it/s]

Loading openmic data:


  0%|          | 0/20000 [00:00<?, ?it/s]

In [10]:
# A freshly loaded embedding has no baseline yet, so run every method including '' (original).
# DataFrame.append was removed in pandas 2.0; collect the frames and concatenate once.
yamnet_frames = [result_all]
for method in debias_methods:
    deem.debias_method = method
    yamnet_frames.append(debias(irmas_train, irmas_test, openmic_train, openmic_test, openmic_mask))
result_all = pd.concat(yamnet_frames, ignore_index=True)


-lda
-lda-genre
-k
-klda
-klda-genre


In [11]:
result_all.shape[0]  # total number of result rows across all embeddings and methods

720

In [12]:
# Preview the combined results table (pandas truncates long frames when rendering).
result_all

Unnamed: 0,instrument,train_set,test_set,precision,recall,f1-score,support,accuracy,roc_auc,ap,embedding
0,cello,irmas,irmas,0.29454545454545455,0.8617021276595744,0.43902439024390244,94.0,0.8757503001200481,0.9384305126955769,0.6305397835914646,vggish
1,clarinet,irmas,irmas,0.29971181556195964,0.8188976377952756,0.43881856540084385,127.0,0.8403361344537815,0.895581034826787,0.5767983221841048,vggish
2,flute,irmas,irmas,0.32413793103448274,0.7768595041322314,0.45742092457420924,121.0,0.8661464585834334,0.9055818556259861,0.6286379602499996,vggish
3,guitar,irmas,irmas,0.6490825688073395,0.8179190751445087,0.7237851662404093,346.0,0.8703481392557023,0.9374146085128744,0.8363853171365732,vggish
4,organ,irmas,irmas,0.4923547400611621,0.9252873563218391,0.6427145708582834,174.0,0.8925570228091236,0.966337709161505,0.7493789197196196,vggish
...,...,...,...,...,...,...,...,...,...,...,...
35,piano,openmic,irmas,0.24125874125874125,0.8117647058823529,0.3719676549865229,170.0,0.7202881152460985,0.8591538219565901,0.6470052473201762,yamnet-klda-genre
36,saxophone,openmic,irmas,0.23308270676691728,0.62,0.3387978142076502,150.0,0.7821128451380552,0.8038258575197891,0.3298471810916173,yamnet-klda-genre
37,trumpet,openmic,irmas,0.2945205479452055,0.86,0.4387755102040817,150.0,0.801920768307323,0.911165347405453,0.6514027823276708,yamnet-klda-genre
38,violin,openmic,irmas,0.22156862745098038,0.7902097902097902,0.3460949464012251,143.0,0.7436974789915967,0.8511862398927401,0.44775732257410006,yamnet-klda-genre


In [13]:
from pathlib import Path

# DataFrame.to_csv does not create missing parent directories, so make sure
# results/ exists before writing; index=False because the row index carries no meaning.
results_path = Path('results') / 'result_all.csv'
results_path.parent.mkdir(parents=True, exist_ok=True)
result_all.to_csv(results_path, index=False)