In [1]:
import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.linear_model import LogisticRegressionCV
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


In [2]:
# Location of the pre-computed metrics archive.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
path="/home/ftamagnan/dataset/"
name="total_metrics_training.npz"

# Copy every array out of the archive while it is open, so the underlying
# file handle is released as soon as the arrays are in memory.
with np.load(path+name) as archive:
    data=dict(archive)

# Pure-noise control feature (3 columns). The row count is taken from the
# 'fills' array, which feature_selection uses as the target, instead of being
# hard-coded; the generator is seeded so that re-running the notebook
# reproduces the same noise.
data["random"]= np.random.RandomState(0).rand(data['fills'].shape[0],3)


In [3]:
def feature_selection(scaler=True,cv=True,list_list_label=[],penalty='l2',stats=True):

    
    
    
    for list_label in list_list_label:
            list_x=[]
            for key in data.keys():
                if key in list_label:
                    list_x.append(data[key])
            X=np.concatenate(list_x,axis=1)
            y=data['fills'][:,1].reshape(-1)
            X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)
            if scaler:
                scaler = StandardScaler()
                scaler.fit(X_train)
                X_train=scaler.transform(X_train)
                X_test=scaler.transform(X_test)
            
            if cv:
                clf = LogisticRegressionCV(cv=2, random_state=0,
                                   multi_class='ovr',penalty=penalty,solver='liblinear',max_iter=300,n_jobs=-1).fit(X_train, y_train)
            else:
                clf = LogisticRegression(random_state=0,C=100000000).fit(X_train, y_train)

            y_pred=clf.predict(X_test)
            tn, fp, fn, tp=confusion_matrix(y_test, y_pred).ravel()
            if stats:
                print("__________Features used : "+str(list_label)+"_______")
                print("tn,fp,fn,tp = ",tn,fp,fn,tp)
                print("Accuracy = ",(tp+tn)/(tn+fp+fn+tp))
                print("Recall = ",(tp)/(fn+tp))
                print("Precision = ",(tp)/(fp+tp))

    return clf



# 1. Feature selection

In [4]:
# Feature groups to compare, one logistic regression per group.
# 'random' is the noise array added to `data` at load time.
list_list_label = [
    ['vae_embeddings', 'offbeat_notes', 'drums_pitches_used', 'velocity_metrics'],
    ['offbeat_notes', 'drums_pitches_used', 'velocity_metrics'],
    ['vae_embeddings'],
    ['offbeat_notes'],
    ['drums_pitches_used'],
    ['velocity_metrics'],
    ['drums_pitches_used', 'velocity_metrics'],
    ['random'],
]

# feature_selection returns the classifier of the last group in the list.
clf = feature_selection(
    scaler=True,
    cv=True,
    list_list_label=list_list_label,
)

__________Features used : ['vae_embeddings', 'offbeat_notes', 'drums_pitches_used', 'velocity_metrics']_______
tn,fp,fn,tp =  1577 38 62 342
Accuracy =  0.9504705299653293
Recall =  0.8465346534653465
Precision =  0.9
__________Features used : ['offbeat_notes', 'drums_pitches_used', 'velocity_metrics']_______
tn,fp,fn,tp =  1567 48 67 337
Accuracy =  0.9430411094601288
Recall =  0.8341584158415841
Precision =  0.8753246753246753
__________Features used : ['vae_embeddings']_______
tn,fp,fn,tp =  1601 14 365 39
Accuracy =  0.8122833085685983
Recall =  0.09653465346534654
Precision =  0.7358490566037735
__________Features used : ['offbeat_notes']_______
tn,fp,fn,tp =  1614 1 404 0
Accuracy =  0.799405646359584
Recall =  0.0
Precision =  0.0
__________Features used : ['drums_pitches_used']_______
tn,fp,fn,tp =  1525 90 171 233
Accuracy =  0.8707280832095097
Recall =  0.5767326732673267
Precision =  0.7213622291021672
__________Features used : ['velocity_metrics']_______
tn,fp,fn,tp =  1551



In [5]:
# Single feature group containing every feature family.
list_list_label = [
    ['vae_embeddings', 'offbeat_notes', 'drums_pitches_used', 'velocity_metrics'],
]

# Fit the same model twice, first with the L1 penalty and then with L2,
# without printing the test metrics.
clf_l1, clf_l2 = (
    feature_selection(
        scaler=True,
        cv=True,
        list_list_label=list_list_label,
        penalty=reg_penalty,
        stats=False,
    )
    for reg_penalty in ('l1', 'l2')
)

In [6]:
# Display names used when printing the model coefficients.
name_pitches = [
    'bass drum',
    'snare drum',
    'closed hi-hat',
    'open hi-hat',
    'low tom',
    'mid tom',
    'high tom',
    'crash cymbal',
    'ride cymbal',
]
# NOTE(review): 'max_velocity' appears twice; the last entry is presumably a
# different metric (e.g. a mean) — confirm against how velocity_metrics is built.
name_features = [
    'min_velocity',
    'max_velocity',
    'std_velocity',
    'max_velocity',
]

In [7]:
def stats_weights(clf, pitch_names=None, feature_names=None):
    """Print the coefficients of a fitted linear classifier, grouped by feature family.

    The flattened coefficient vector is read with this positional layout:

    * ``[0:32]``  -> VAE embeddings
    * ``[32]``    -> off-beat notes
    * ``[33 : 33 + P*M]`` -> velocity metrics, ``M`` metrics for each of ``P`` pitches
    * the next ``P`` entries -> pitches used

    NOTE(review): this layout assumes the feature arrays were concatenated in
    the order vae_embeddings, offbeat_notes, velocity_metrics,
    drums_pitches_used, and that the velocity block is pitch-major (all
    metrics of the first pitch, then all metrics of the second, ...).
    feature_selection stacks columns in ``data.keys()`` order — confirm both
    assumptions against the dataset.

    Parameters
    ----------
    clf : object exposing a ``coef_`` array (e.g. a fitted logistic regression).
    pitch_names : list of str, optional
        Pitch labels; defaults to the module-level ``name_pitches``.
    feature_names : list of str, optional
        Per-pitch velocity metric labels; defaults to the module-level
        ``name_features``.
    """
    if pitch_names is None:
        pitch_names = name_pitches
    if feature_names is None:
        feature_names = name_features

    coef = clf.coef_.reshape(-1)

    n_metrics = len(feature_names)
    velocity_start = 33
    # With the default 9 pitches x 4 metrics this is 33 + 36, as before.
    pitches_start = velocity_start + len(pitch_names) * n_metrics

    print("------VAE EMBEDDINGS-------")
    print(coef[0:32])

    print("------OFFBEATS NOTES-------")
    print(coef[32])

    print("------VELOCITY METRICS-------")
    for i, pitch in enumerate(pitch_names):
        for j, metric in enumerate(feature_names):
            # Stride by the number of metrics per pitch: the previous index
            # `33+i+j` made consecutive pitches overlap and never reached most
            # of the velocity block.
            print(metric+' of '+pitch, coef[velocity_start + i * n_metrics + j])

    print("------PITCHES USED-------")
    for i, pitch in enumerate(pitch_names):
        print('use of '+pitch, coef[pitches_start + i])

# 2. Magnitude of weights with L1 regularization

In [8]:
# Print the coefficients of the model fitted with penalty='l1'.
stats_weights(clf_l1)

------VAE EMBEDDINGS-------
[ 0.04384562 -0.10273759  0.04127356  0.27254441  0.28161399  0.11470882
  0.28008234  0.109668   -0.17435904  0.04193808  0.05404063  0.36916974
  0.09312229 -0.23717745  0.43223129  0.02199915 -0.29484831  0.00865722
 -0.3605209  -0.07537753 -0.13489147  0.07119396  0.08821702 -0.22635638
  0.00925412  0.0926016   0.24566683  0.08003343 -0.1735406  -0.28832851
  0.1522567   0.07934086]
------OFFBEATS NOTES-------
0.0
------VELOCITY METRICS-------
min_velocity of bass drum 0.0
max_velocity of bass drum 0.0
std_velocity of bass drum 0.0
max_velocity of bass drum 0.0
min_velocity of snare drum 0.0
max_velocity of snare drum 0.0
std_velocity of snare drum 0.0
max_velocity of snare drum 0.0
min_velocity of closed hi-hat 0.0
max_velocity of closed hi-hat 0.0
std_velocity of closed hi-hat 0.0
max_velocity of closed hi-hat 0.0
min_velocity of open hi-hat 0.0
max_velocity of open hi-hat 0.0
std_velocity of open hi-hat 0.0
max_velocity of open hi-hat 0.0
min_velocit

# 3. Magnitude of weights with L2 regularization

In [9]:
# Print the coefficients of the model fitted with penalty='l2'.
stats_weights(clf_l2)

------VAE EMBEDDINGS-------
[ 0.04579893 -0.11510138  0.03998235  0.2952678   0.29363009  0.11867883
  0.28924514  0.11314356 -0.18448098  0.04871108  0.04720873  0.3909745
  0.09534698 -0.24245628  0.43758741  0.03825606 -0.29471778  0.0132464
 -0.36721753 -0.07634466 -0.14393415  0.07250993  0.09211486 -0.23854252
  0.01521906  0.10318535  0.25537608  0.08939967 -0.17232741 -0.30583347
  0.15840271  0.08290737]
------OFFBEATS NOTES-------
0.0
------VELOCITY METRICS-------
min_velocity of bass drum 0.0
max_velocity of bass drum 0.0
std_velocity of bass drum 0.0
max_velocity of bass drum 0.0
min_velocity of snare drum 0.0
max_velocity of snare drum 0.0
std_velocity of snare drum 0.0
max_velocity of snare drum 0.0
min_velocity of closed hi-hat 0.0
max_velocity of closed hi-hat 0.0
std_velocity of closed hi-hat 0.0
max_velocity of closed hi-hat 0.0
min_velocity of open hi-hat 0.0
max_velocity of open hi-hat 0.0
std_velocity of open hi-hat 0.0
max_velocity of open hi-hat 0.0
min_velocity 