In [1]:
import numpy as np

from scipy.sparse import load_npz

In [2]:
from sklearn.model_selection import KFold, GridSearchCV
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
from sklearn import metrics

In [3]:
from imblearn.over_sampling import RandomOverSampler

### Reading data

In [4]:
# Load the feature matrix from a SciPy .npz file; the bare name on the
# last line makes the notebook display the matrix's repr (shape, dtype, format).
Xtrain = load_npz('features_silico_duplicated.npz')
Xtrain

<151627x1676 sparse matrix of type '<class 'numpy.float64'>'
	with 1901453 stored elements in Compressed Sparse Row format>

In [5]:
# Load the label array and display its shape; predict() below indexes it as
# Y[:, i], i.e. one column per label.
Ytrain = np.load('classes_silico_duplicated.npy')
Ytrain.shape

(151627, 71)

In [8]:
def _multilabel_scores(Ytrue, Ypred):
    """Return accuracy, Hamming score and micro-averaged precision/recall/F1."""
    return {
        'accuracy': metrics.accuracy_score(Ytrue, Ypred),
        # hamming_loss is an error rate; report its complement as a score.
        'hamming': 1 - metrics.hamming_loss(Ytrue, Ypred),
        'precision': metrics.precision_score(Ytrue, Ypred, average='micro'),
        'recall': metrics.recall_score(Ytrue, Ypred, average='micro'),
        'f1': metrics.f1_score(Ytrue, Ypred, average='micro'),
    }


def predict(clf, params, X, Y, n_splits=5, random_state=None):
    """Cross-validate one tuned binary classifier per label column and print metrics.

    For every K-fold split and every column of ``Y``, the training rows are
    randomly oversampled, ``clf`` is tuned with a 5-fold ``GridSearchCV`` over
    ``params`` (scored by accuracy), and the refitted best model predicts that
    column for the fold's training and test rows. A column that holds a single
    value in the training fold cannot be fitted, so that value is predicted
    for every row instead. Per-fold multilabel metrics are averaged over the
    folds and printed as ``train`` / ``test`` pairs.

    Parameters
    ----------
    clf : estimator
        Unfitted scikit-learn style classifier handed to ``GridSearchCV``.
    params : dict
        Parameter grid handed to ``GridSearchCV``.
    X : matrix
        Feature matrix; must support row selection with an index array.
    Y : 2-D array
        Label matrix indexed as ``Y[rows]`` and ``Y[:, i]``. Entries equal to
        1 are treated as the positive class when computing metrics.
    n_splits : int, default 5
        Number of outer ``KFold`` splits.
    random_state : int, RandomState or None, default None
        Seed forwarded to ``RandomOverSampler`` so a run can be reproduced.
        ``None`` keeps the original unseeded behaviour.

    Returns
    -------
    None
        Results are printed, not returned.
    """
    # Derive the number of labels from the data instead of hard-coding it.
    n_labels = Y.shape[1]
    metric_names = ('accuracy', 'hamming', 'precision', 'recall', 'f1')
    train_scores = {name: [] for name in metric_names}
    test_scores = {name: [] for name in metric_names}

    kf = KFold(n_splits=n_splits)
    for train_idx, test_idx in kf.split(X):
        X_fit, X_eval = X[train_idx], X[test_idx]
        Y_fit, Y_eval = Y[train_idx], Y[test_idx]

        fit_pred_cols = []
        eval_pred_cols = []

        for i in range(n_labels):
            y_fit = Y_fit[:, i]
            y_eval = Y_eval[:, i]

            if np.unique(y_fit).size == 1:
                # Only one class present in this fold: nothing to learn, so
                # predict that class everywhere. These predictions must NOT be
                # replaced by a model fitted for a different label.
                fit_pred = np.full(y_fit.shape, y_fit[0])
                eval_pred = np.full(y_eval.shape, y_fit[0])
            else:
                ros = RandomOverSampler(random_state=random_state)
                # Newer imbalanced-learn releases expose `fit_resample`; older
                # ones only have `fit_sample`. Use whichever is available.
                resample = getattr(ros, 'fit_resample', None) or ros.fit_sample
                X_resampled, y_resampled = resample(X_fit, y_fit)

                # NOTE(review): oversampling happens before the inner CV, so
                # duplicated rows can land in both inner train and validation
                # folds and best_score_ is likely optimistic.
                gscv = GridSearchCV(clf, params, scoring='accuracy', cv=5, n_jobs=-1)
                gscv.fit(X_resampled, y_resampled)
                print(gscv.best_score_)
                print(gscv.best_params_)

                # Predict on the original (not oversampled) rows.
                fit_pred = gscv.predict(X_fit)
                eval_pred = gscv.predict(X_eval)

            fit_pred_cols.append(fit_pred)
            eval_pred_cols.append(eval_pred)

        # Stack per-label predictions into (n_rows, n_labels) matrices.
        Y_fit_pred = np.array(fit_pred_cols).T
        Y_eval_pred = np.array(eval_pred_cols).T

        # Binarise so that only the value 1 counts as positive.
        Y_fit_bin = (Y_fit == 1).astype('int')
        Y_eval_bin = (Y_eval == 1).astype('int')
        Y_fit_pred_bin = (Y_fit_pred == 1).astype('int')
        Y_eval_pred_bin = (Y_eval_pred == 1).astype('int')

        fold_train = _multilabel_scores(Y_fit_bin, Y_fit_pred_bin)
        fold_test = _multilabel_scores(Y_eval_bin, Y_eval_pred_bin)
        for name in metric_names:
            train_scores[name].append(fold_train[name])
            test_scores[name].append(fold_test[name])

    print('Accuracy: \t \t {} \t {}'.format(np.array(train_scores['accuracy']).mean(), np.array(test_scores['accuracy']).mean()))
    print('Hamming: \t \t {} \t {}'.format(np.array(train_scores['hamming']).mean(), np.array(test_scores['hamming']).mean()))
    print('Precision: \t \t {} \t {}'.format(np.array(train_scores['precision']).mean(), np.array(test_scores['precision']).mean()))
    print('Recall: \t \t {} \t {}'.format(np.array(train_scores['recall']).mean(), np.array(test_scores['recall']).mean()))
    print('F1: \t \t \t {} \t {}'.format(np.array(train_scores['f1']).mean(), np.array(test_scores['f1']).mean()))

In [9]:
# Parameter grid passed to GridSearchCV inside predict(): candidate values for
# the estimator's `C` parameter.
params = {'C': [1.0, 10.0, 100.0]}

In [10]:
# Run the per-label cross-validation with an L2-penalised logistic regression
# (SAG solver); predict() prints each grid-search result and the averaged metrics.
predict(LogisticRegression(penalty='l2', solver='sag', max_iter=1000, n_jobs=-1), params, Xtrain, Ytrain)



0.8725704336815447
{'C': 1.0}




0.9023355319992276
{'C': 1.0}




0.891727722729954
{'C': 1.0}




0.9079996351363678
{'C': 1.0}




0.9211894880110229
{'C': 1.0}




0.9311174197618485
{'C': 1.0}




0.9331387467380622
{'C': 1.0}




0.9951283188860829
{'C': 1.0}




0.9707711177928637
{'C': 1.0}




0.9760049500725442
{'C': 1.0}




0.9492570212124006
{'C': 1.0}




0.94001779510211
{'C': 1.0}




0.9415309542007173
{'C': 1.0}




0.9825335842576773
{'C': 1.0}




0.954129218246566
{'C': 1.0}




0.8967614319073205
{'C': 1.0}




0.8850760408032075
{'C': 1.0}




0.9978755353821477
{'C': 1.0}




0.9436777321894816
{'C': 1.0}




0.915547189184564
{'C': 1.0}




0.9752998129607218
{'C': 1.0}




0.9466134983930484
{'C': 1.0}




0.9469675990030463
{'C': 1.0}




0.9426055358362351
{'C': 1.0}




0.8928458965608218
{'C': 1.0}




0.8879432806188814
{'C': 1.0}




0.884371301216349
{'C': 1.0}




0.9770777739429091
{'C': 1.0}




0.8844754960837717
{'C': 1.0}




0.8998485950636915
{'C': 1.0}




0.8901427738188826
{'C': 1.0}




0.8963112853589646
{'C': 1.0}




0.9134436427606252
{'C': 1.0}




0.906672624869313
{'C': 1.0}




0.8875832134829525
{'C': 10.0}




0.8677472839528791
{'C': 1.0}




0.9001698520941418
{'C': 1.0}




0.9180931809061958
{'C': 1.0}




0.8531185761054386
{'C': 1.0}




0.8760931715256753
{'C': 1.0}




0.9974229842852571
{'C': 1.0}




0.8948895166986177
{'C': 1.0}




0.8953630206816195
{'C': 1.0}




0.9325180040056249
{'C': 1.0}




0.9217291036300426
{'C': 1.0}




0.9561364060617374
{'C': 1.0}




0.9194538902705172
{'C': 1.0}




0.9199736120978382
{'C': 1.0}




0.8611096678097613
{'C': 1.0}




0.9155704799656627
{'C': 1.0}




0.8788637325960796
{'C': 1.0}




0.8626315071529661
{'C': 1.0}




0.8978050626524262
{'C': 1.0}




0.906014001670556
{'C': 1.0}




0.9595804219380831
{'C': 1.0}




0.9631654773238523
{'C': 1.0}




0.901847205617391
{'C': 1.0}




0.8617737271054245
{'C': 1.0}




0.7767488370070866
{'C': 1.0}




0.9109704641350211
{'C': 1.0}




0.9778868218498353
{'C': 1.0}




0.9243338718723578
{'C': 1.0}




0.9320237963522473
{'C': 1.0}




0.9615349025018086
{'C': 1.0}




0.9213444768667668
{'C': 1.0}




0.9643839907880143
{'C': 1.0}




0.9405425920240167
{'C': 1.0}




0.9367333818329471
{'C': 1.0}




0.9200158989163634
{'C': 1.0}




0.8695744882486732
{'C': 1.0}




0.8958115236459361
{'C': 1.0}




0.8871590841103922
{'C': 1.0}




0.9045352922389013
{'C': 1.0}




0.9175082760590835
{'C': 1.0}




0.9290101282680311
{'C': 10.0}




0.9299272422350778
{'C': 1.0}




0.9941086633790079
{'C': 1.0}




0.9700401168584497
{'C': 1.0}




0.9749955196750271
{'C': 1.0}




0.9504201469467087
{'C': 1.0}




0.9372109373676453
{'C': 1.0}




0.9382037567279353
{'C': 1.0}




0.9810493730815254
{'C': 1.0}




0.9490780791477957
{'C': 1.0}




0.895126517644389
{'C': 1.0}




0.8832894948281841
{'C': 1.0}




0.9978413890192398
{'C': 1.0}




0.9400971466354713
{'C': 1.0}




0.9130272952853598
{'C': 1.0}




0.9746657075152336
{'C': 1.0}




0.941090940011054
{'C': 1.0}




0.9463719220504557
{'C': 1.0}




0.9337059894779441
{'C': 1.0}




0.8878591367102396
{'C': 1.0}




0.8878353515040007
{'C': 1.0}




0.8812481117190909
{'C': 1.0}




0.9747497706055173
{'C': 1.0}




0.8810440048377411
{'C': 1.0}




0.8957055733655888
{'C': 1.0}




0.8848179712203453
{'C': 1.0}




0.891650853889943
{'C': 1.0}




0.9126936905541059
{'C': 1.0}




0.8966666947673684
{'C': 1.0}




0.8826034901548692
{'C': 1.0}




0.8655281009969698
{'C': 1.0}




0.8983390101341407
{'C': 1.0}




0.9158002117052516
{'C': 1.0}




0.8467952080407452
{'C': 1.0}




0.8721263640810268
{'C': 1.0}




0.9971788048238227
{'C': 1.0}




0.8917673569945094
{'C': 1.0}




0.8977255137625538
{'C': 1.0}




0.9300394684124833
{'C': 1.0}




0.9195154719597816
{'C': 1.0}




0.9549234581310886
{'C': 1.0}




0.9190860487936927
{'C': 1.0}




0.9199081474715181
{'C': 1.0}




0.8539450335425233
{'C': 1.0}




0.9100850802413553
{'C': 1.0}




0.8767204114081288
{'C': 1.0}




0.8624435545710147
{'C': 1.0}




0.8941409995458139
{'C': 1.0}




0.9072964034725093
{'C': 1.0}




0.957030262159656
{'C': 1.0}




0.9600446522599941
{'C': 1.0}




0.9013082570043104
{'C': 1.0}




0.8572871382777851
{'C': 1.0}




0.7752670786936832
{'C': 1.0}




0.9056929293951895
{'C': 1.0}




0.9763666534029818
{'C': 1.0}




0.9231612784409583
{'C': 1.0}




0.9289749283241027
{'C': 1.0}




0.9517818383726213
{'C': 1.0}




0.9161563765182186
{'C': 1.0}




0.96216509162348
{'C': 1.0}




0.9345734609283046
{'C': 1.0}




0.9321985851611156
{'C': 1.0}




0.9135569122689637
{'C': 1.0}




0.8634096522895687
{'C': 1.0}




0.8954971435579311
{'C': 1.0}




0.8868525277626776
{'C': 1.0}




0.9035541328080111
{'C': 1.0}




0.9139716102944274
{'C': 1.0}




0.9263600369769591
{'C': 1.0}




0.9310387210524531
{'C': 1.0}




0.9941975538375654
{'C': 1.0}




0.9556031273268801
{'C': 1.0}




0.972321102398328
{'C': 1.0}




0.9423978005213703
{'C': 1.0}




0.9393038833305071
{'C': 1.0}




0.93596595158268
{'C': 1.0}




0.9795262325517431
{'C': 1.0}




0.948553601332914
{'C': 1.0}




0.8921541768523982
{'C': 1.0}




0.8940594391801724
{'C': 1.0}




0.9976021111571611
{'C': 1.0}




0.9396984076056671
{'C': 1.0}




0.9195685054892764
{'C': 1.0}




0.9675780691499651
{'C': 1.0}




0.9466375811174589
{'C': 1.0}




0.9366833960685905
{'C': 1.0}




0.9370514892441154
{'C': 1.0}




0.8884395511327546
{'C': 1.0}




0.8853766881398337
{'C': 1.0}




0.8784967008705687
{'C': 1.0}




0.975555648395722
{'C': 1.0}




0.8961115260980134
{'C': 1.0}




0.8960768473403853
{'C': 1.0}




0.8719640482722636
{'C': 1.0}




0.8847738870610354
{'C': 1.0}




0.9093510643694097
{'C': 1.0}




0.8872426399846065
{'C': 1.0}




0.8725271997329247
{'C': 1.0}




0.8597401184247513
{'C': 1.0}




0.896572934973638
{'C': 1.0}




0.9066506036362403
{'C': 1.0}




0.8443376726370762
{'C': 1.0}




0.8649310717664549
{'C': 1.0}




0.9973948298681499
{'C': 1.0}




0.8966917268082188
{'C': 1.0}




0.8850964091483401
{'C': 1.0}




0.9202832032910059
{'C': 1.0}




0.9134161631288419
{'C': 1.0}




0.9537365496038784
{'C': 1.0}




0.9063119935442655
{'C': 1.0}




0.906540210275072
{'C': 1.0}




0.8322057708218422
{'C': 1.0}




0.9044301042869817
{'C': 1.0}




0.8662796092310547
{'C': 1.0}




0.8435849815564476
{'C': 1.0}




0.8835320385895438
{'C': 1.0}




0.8965405963598813
{'C': 1.0}




0.9569080389974024
{'C': 1.0}




0.9616246734642285
{'C': 1.0}




0.8974985383780172
{'C': 1.0}




0.8580497491455131
{'C': 1.0}




0.755154099733138
{'C': 1.0}




0.9054749094846547
{'C': 1.0}




0.9758673908276688
{'C': 1.0}




0.9195671678084202
{'C': 1.0}




0.9281434318638982
{'C': 1.0}




0.9629247552541014
{'C': 1.0}




0.916262967002308
{'C': 1.0}




0.9625946929798691
{'C': 1.0}




0.9356560724129718
{'C': 1.0}




0.9370895484396478
{'C': 1.0}




0.9097257256247268
{'C': 1.0}




0.8564806287222938
{'C': 1.0}




0.8880895567852972
{'C': 10.0}




0.8835184913856258
{'C': 1.0}




0.9002110491508001
{'C': 1.0}




0.908996910608585
{'C': 1.0}




0.925572734945315
{'C': 1.0}




0.932928318218944
{'C': 1.0}




0.9940793368857312
{'C': 1.0}




0.9802128339189348
{'C': 1.0}




0.9694716688314975
{'C': 1.0}




0.9405700500958208
{'C': 1.0}




0.939662723554074
{'C': 1.0}




0.9311784602667161
{'C': 1.0}




0.9748407424578032
{'C': 1.0}




0.9427217988842019
{'C': 1.0}




0.8890341394740812
{'C': 1.0}




0.8917620596523208
{'C': 1.0}




0.9972511206969467
{'C': 1.0}




0.9496398090072704
{'C': 1.0}




0.9314255683455677
{'C': 1.0}




0.9433459822914586
{'C': 1.0}




0.9487015625264512
{'C': 1.0}




0.9176464726982861
{'C': 1.0}




0.9482464604788128
{'C': 1.0}




0.8907663157894737
{'C': 1.0}




0.8889857993276485
{'C': 1.0}




0.8751599820487675
{'C': 1.0}




0.9717094002891374
{'C': 1.0}




0.8924309951251327
{'C': 1.0}




0.903576326175664
{'C': 1.0}




0.8714260619343182
{'C': 1.0}




0.8821580572477676
{'C': 1.0}




0.9101717360865091
{'C': 1.0}




0.8956907118497086
{'C': 1.0}




0.8779229477896665
{'C': 1.0}




0.8648192355520056
{'C': 1.0}




0.8937045924148503
{'C': 1.0}




0.8881328045899426
{'C': 1.0}




0.8426500773941719
{'C': 1.0}




0.8661170518764171
{'C': 1.0}




0.9975800169067311
{'C': 1.0}




0.8869502589465621
{'C': 1.0}




0.8785166706120145
{'C': 1.0}




0.9083570629512865
{'C': 1.0}




0.9070431499742792
{'C': 1.0}




0.9472685504636564
{'C': 1.0}




0.8950482560600151
{'C': 1.0}




0.9051056133539085
{'C': 1.0}




0.827412252575777
{'C': 1.0}




0.9028900918253824
{'C': 1.0}




0.8604518581081081
{'C': 1.0}




0.8355490287009841
{'C': 1.0}




0.8729269650599767
{'C': 1.0}




0.893408625472016
{'C': 1.0}




0.9540415028031127
{'C': 1.0}




0.9613178838951311
{'C': 1.0}




0.9177593809801148
{'C': 1.0}




0.8637254036883438
{'C': 1.0}




0.740135139854035
{'C': 1.0}




0.9054735234215886
{'C': 1.0}




0.9738249648993782
{'C': 1.0}




0.9179262243233346
{'C': 1.0}




0.9246380261692414
{'C': 1.0}




0.9604561438594736
{'C': 1.0}




0.9128285568933356
{'C': 1.0}




0.9628292518724227
{'C': 1.0}




0.9371845252210441
{'C': 1.0}




0.9380431914332167
{'C': 1.0}




0.9065973013341434
{'C': 1.0}




0.8534219807454008
{'C': 1.0}




0.9010384226097788
{'C': 1.0}




0.8987412794237896
{'C': 1.0}




0.909132331212896
{'C': 1.0}




0.9170863400537407
{'C': 1.0}




0.9250788912579957
{'C': 1.0}




0.931183254661712
{'C': 1.0}




0.993958273936304
{'C': 1.0}




0.9517043065902097
{'C': 1.0}




0.9732548776191514
{'C': 1.0}




0.936980125989285
{'C': 1.0}




0.9347484079364681
{'C': 1.0}




0.9278536774367441
{'C': 1.0}




0.9733362450650429
{'C': 1.0}




0.9446957184601702
{'C': 1.0}




0.8866090483420574
{'C': 1.0}




0.8715638324023294
{'C': 1.0}




0.9972202190282925
{'C': 1.0}




0.9307820341675257
{'C': 1.0}




0.9128744022149509
{'C': 1.0}




0.9705797666028692
{'C': 1.0}




0.9371978049555805
{'C': 1.0}




0.9426131066860391
{'C': 1.0}




0.9389049792042029
{'C': 1.0}




0.8643894172399839
{'C': 1.0}




0.8630449340103244
{'C': 1.0}




0.8760850220782055
{'C': 1.0}




0.9746577533803694
{'C': 1.0}




0.8767768896272876
{'C': 1.0}




0.8909888370683414
{'C': 1.0}




0.8823097378751424
{'C': 1.0}




0.8973983329123516
{'C': 1.0}




0.9028533632512266
{'C': 1.0}




0.8875916968453753
{'C': 1.0}




0.8592915066154503
{'C': 1.0}




0.8566352968935756
{'C': 1.0}




0.8914316976685465
{'C': 1.0}




0.9046644214989059
{'C': 1.0}




0.8443396226415094
{'C': 1.0}




0.8605361455347668
{'C': 1.0}




0.9974693031951626
{'C': 1.0}




0.8910162867542226
{'C': 1.0}




0.8741423730065655
{'C': 1.0}




0.9051747368421053
{'C': 1.0}




0.9003370182977544
{'C': 1.0}




0.9496204711507916
{'C': 1.0}




0.9007394309596528
{'C': 1.0}




0.902207561532606
{'C': 10.0}




0.8358267648482334
{'C': 1.0}




0.906056855117129
{'C': 1.0}




0.8680008252527337
{'C': 1.0}




0.8395938781977167
{'C': 1.0}




0.8804350575439484
{'C': 1.0}




0.8909151734017275
{'C': 1.0}




0.9529177329361916
{'C': 1.0}




0.9584667616732333
{'C': 1.0}




0.897010764280781
{'C': 1.0}




0.859461177566077
{'C': 1.0}




0.7410905515958007
{'C': 1.0}




0.9083754965502823
{'C': 1.0}




0.9717235868102289
{'C': 1.0}




0.9221225408665558
{'C': 1.0}




0.9236830444851312
{'C': 1.0}




0.9629353736343282
{'C': 1.0}




0.9192112544197671
{'C': 1.0}




0.9638679544808917
{'C': 1.0}




0.942834690431206
{'C': 1.0}




0.938906523505222
{'C': 1.0}




0.9182808410031242
{'C': 1.0}
Accuracy: 	 	 0.06658945027154081 	 0.0556502411633647
Hamming: 	 	 0.9166980731955698 	 0.9099438218957998
Precision: 	 	 0.3398169813737109 	 0.31747229538179733
Recall: 	 	 0.8955976998786033 	 0.8780245623619727
F1: 	 	 	 0.49221850322207716 	 0.46371046734242005


In [11]:
# Same evaluation as above, this time with an L2-penalised linear SVM.
predict(LinearSVC(penalty='l2'), params, Xtrain, Ytrain)

0.8758388730610953
{'C': 10.0}
0.9138505788013375
{'C': 1.0}
0.875770503295745
{'C': 1.0}
0.9510079357840008
{'C': 100.0}
0.9523633491842118
{'C': 100.0}
0.9670416809049409
{'C': 1.0}
0.9755735927068153
{'C': 10.0}
0.9999402771141902
{'C': 1.0}
0.9958986720726399
{'C': 1.0}
0.9948365622599642
{'C': 100.0}
0.9645901254640998
{'C': 1.0}
0.9688373866621472
{'C': 1.0}
0.9288989862509763
{'C': 1.0}
0.9895583095562281
{'C': 1.0}
0.9774758351704257
{'C': 1.0}
0.8990228751501961
{'C': 100.0}
0.9364902813255968
{'C': 1.0}
0.9996715185229426
{'C': 100.0}
0.9795149367603676
{'C': 10.0}
0.9864165311885
{'C': 1.0}
0.9920487148455022
{'C': 1.0}
0.9833948339483395
{'C': 100.0}
0.9682116093082584
{'C': 1.0}
0.9790235142358505
{'C': 1.0}
0.9365143533136452
{'C': 1.0}
0.9226997330047237
{'C': 100.0}
0.932093480067491
{'C': 10.0}
0.9974156726406438
{'C': 1.0}
0.9466234161690322
{'C': 10.0}
0.9159364268435031
{'C': 100.0}
0.9016052544822865
{'C': 100.0}
0.9638168711268497
{'C': 1.0}
0.9427291810077421
{'C