In [1]:
from sklearn.linear_model import LogisticRegression
import sklearn.metrics as mc
import warnings
import numpy as np
from tqdm import tqdm
import spacy

In [2]:
def Add_subset(k, scores, test_idx, pred, X, y, thresh):
    #print("test_idx", test_idx)
    #print("old")
    #print(pred[test_idx])
    
    if pred[test_idx] > thresh:
        top_k_index = scores[test_idx].argsort()[-k:]
    else:
        top_k_index = scores[test_idx].argsort()[:k]
        
    y_k = y["train_add"][top_k_index]
    X_k = X["train_add"][top_k_index]
    y_added = np.concatenate((y["train"],y_k))
    X_added = np.concatenate((X["train"],X_k), axis=0)

    prediction = -np.sum(scores[test_idx][top_k_index])
    #print("prediction", prediction)

    return X_k, y_k, prediction, top_k_index


def loss_gradient(X, y, model):
    F_train = np.concatenate([X, np.ones((X.shape[0], 1))], axis=1)
    error_train = model.predict_proba(X)[:, 1] - y
    gradient_train = F_train * error_train[:, None]
    return gradient_train

In [3]:
def approximate_k(test_idx, pred, delta_pred, y, thresh):
    old = pred[test_idx].item()
    
    if pred[test_idx] > thresh:
        top_k_index = np.flip(delta_pred[test_idx].argsort())
    else:
        top_k_index = delta_pred[test_idx].argsort()
    
    for k in range(1, y["train"].shape[0]):
        change = -np.sum(delta_pred[test_idx][top_k_index[:k]])
        
        if old > thresh and old + change < thresh:
            return k
        elif old < thresh and old + change > thresh:
            return k
        
    return None

In [4]:
def new_train(k, dev_index, scores, l2, X, model, pred, y, thresh):
    X_k, y_k, prediction, top_k_index = Add_subset(k, scores, dev_index, pred, X, y, thresh)
    
    if y_k.shape[0] == np.sum(y_k) or np.sum(y_k) == 0: # data contains only one class: 1
        return None, None, None, None, None, None
    
    model.fit(X_k, y_k)

    # predictthe probaility with test point
    test_point = X["dev"][dev_index]
    test_point=np.reshape(test_point, (1,-1))
    
    new = model.predict_proba(test_point)[0][1]
    change = -(model.predict_proba(test_point)[0][1] - new)
    #change = model_k.predict_proba(test_point)[0][1]-model.predict_proba(test_point)[0][1]
    flip = (model.predict(test_point) == model.predict(test_point))
    
    """
    print("change    ", change)
    print("old       ", model.predict_proba(test_point)[0][1])
    print()
    """
    error = np.abs((change - prediction)/prediction)
    return change, flip, prediction, new, error, top_k_index

In [5]:
def IP(X, y, l2, dataname, thresh, model, pred, modi=None):
    gradient_train_add = loss_gradient(X["train_add"], y["train_add"], model)

    w = np.concatenate((model.coef_, model.intercept_[None, :]), axis=1)
    F_train = np.concatenate([X["train"], np.ones((X["train"].shape[0], 1))],
                             axis=1)  # Concatenating one to calculate the gradient with respect to intercept
    F_dev = np.concatenate([X["dev"], np.ones((X["dev"].shape[0], 1))], axis=1)

    # error_train = model.predict_proba(X["train"])[:, 1] - y["train"]
    # error_dev = model.predict_proba(X["dev"])[:, 1] - y["dev"]

    # gradient_train = F_train * error_train[:, None]
    # gradient_dev = F_dev * error_dev[:, None]

    probs = model.predict_proba(X["train"])[:, 1]
    hessian = F_train.T @ np.diag(probs * (1 - probs)) @ F_train / X["train"].shape[0] + l2 * np.eye(F_train.shape[1]) / \
              X["train"].shape[0]
    inverse_hessian = np.linalg.inv(hessian)
    eps = 1 / X["train"].shape[0]
    delta_k = -eps * inverse_hessian @ (- gradient_train_add).T
    grad_f = F_dev * (pred * (1 - pred))
    delta_pred = grad_f @ delta_k
    return delta_pred
    


    


In [6]:
import spacy
sst_dataset = {}
for split in ["train", "dev", "test"]:
    URL = f"https://raw.githubusercontent.com/successar/instance_attributions_NLP/master/Datasets/SST/data/{split}.jsonl"
    import urllib.request, json
    with urllib.request.urlopen(URL) as url:
        data = url.read().decode()
        data = [json.loads(line) for line in data.strip().split("\n")]
        sst_dataset[split] = data


nlp = spacy.load('en_core_web_sm')

X, y = {}, {}
for split in ["train", "dev"]:
    X[split] = np.array([nlp(example["document"]).vector for example in tqdm(sst_dataset[split])])
    y[split] = np.array([example["label"] for example in sst_dataset[split]])

100%|██████████| 6920/6920 [00:41<00:00, 165.08it/s]
100%|██████████| 872/872 [00:05<00:00, 167.83it/s]


In [7]:
from sklearn.model_selection import train_test_split
X_train, X_train_add, y_train, y_train_add = train_test_split(
    X["train"], y["train"], test_size=0.4, random_state=42)

In [8]:
X["train"] = X_train
X["train_add"] = X_train_add
y["train"] = y_train
y["train_add"]= y_train_add
X["train"].shape, X["train_add"].shape, X["dev"].shape

((4152, 96), (2768, 96), (872, 96))

In [9]:
thresh = 0.5
l2 = 10
model = LogisticRegression(penalty='l2', C=1/l2, solver='saga',warm_start=True)
model.fit(X["train"], y["train"])
pred = np.reshape(model.predict_proba(X["dev"])[:, 1], (model.predict_proba(X["dev"])[:, 1].shape[0], 1))

In [10]:
delta_pred = IP(X, y, l2, "SST", thresh, model, pred)
delta_pred.shape

(872, 2768)

In [11]:
# Loop over all dev points:
appro_ks = []
new_predictions = []
flip_list = []
for test_idx in tqdm(range(X["dev"].shape[0])):
    appro_k = approximate_k(test_idx, pred, delta_pred, y, thresh)
    if appro_k != None:
        change, _, prediction,new_prediction, error, top_k_index = new_train(appro_k, test_idx, delta_pred, l2, X, model, pred, y, thresh)
        print(test_idx, appro_k, pred[test_idx], new_prediction)
        appro_ks.append(appro_k)
        new_predictions.append(new_prediction)
        flip_list.append(top_k_index)
    else:
        appro_ks.append(None)
        new_predictions.append(None)
        flip_list.append(None)

appro_ks= np.array(appro_ks)
new_predictions=np.array(new_predictions)
# flip_list = np.array(flip_list)
# np.save("./results/" + "appro_ks_IP" + "_alg1_" + dataname  + str(l2) + ".npy", appro_ks)
# np.save("./results/" + "new_predictions" + "_alg1_" + dataname + str(l2) + ".npy", new_predictions)
# np.save("./results/" + "old_predictions" + "_alg1_" + dataname + str(l2) + ".npy", pred)
# np.save("./results/" + "flip_list" + "_alg1_" + dataname + str(l2) + ".npy", flip_list)



0 20 [0.5509179] 0.23607384
1 34 [0.40708712] 0.8511833
2 72 [0.6836342] 0.026963012
3 4 [0.49117252] 0.52820253
4 42 [0.6547108] 0.074020796
5 92 [0.63173795] 0.11043312
6 6 [0.484108] 0.5958128
7 9 [0.48168048] 0.8066516
8 161 [0.6890086] 0.05747488
9 29 [0.5548766] 0.39529526
10 25 [0.4435639] 0.77873194
11 8 [0.5202813] 0.35272473
12 73 [0.6395619] 0.061469458
13 45 [0.37409356] 0.94439125
14 127 [0.6812974] 0.052479442
15 19 [0.53945047] 0.2141494
16 43 [0.5639916] 0.1365217
17 15 [0.45675865] None
18 173 [0.69866514] 0.03239461
19 30 [0.56355983] 0.21295026
20 44 [0.5752861] 0.2331422
21 40 [0.6803442] 0.05015972
22 112 [0.65148365] 0.07550993
23 112 [0.69692945] 0.02965681
24 3 [0.48383203] 0.33698374
25 11 [0.5441342] 0.19649531
26 94 [0.62871855] 0.08128653
27 110 [0.32732755] 0.9107533
28 22 [0.5574646] 0.16651124
29 14 [0.53091234] 0.2708724
30 112 [0.67683524] 0.060638648
31 58 [0.38949615] 0.8606737
32 22 [0.5504116] 0.1831223
33 45 [0.58985686] 0.11287512
34 5 [0.5132863]



51 65 [0.6296311] 0.07350321
52 163 [0.6934624] 0.04695206
53 30 [0.5782436] 0.20339026
54 78 [0.61627805] 0.05982234
55 22 [0.5495632] 0.3518882
56 7 [0.5244039] 0.3992227
57 61 [0.38089994] 0.8948031
58 69 [0.3218399] 0.923836
59 34 [0.41967943] 0.93134534
60 21 [0.541312] 0.09392139
61 112 [0.65289956] 0.10600053
62 65 [0.6444478] 0.08612042
63 57 [0.60406274] 0.11171479
64 12 [0.546608] 0.2404296
65 11 [0.51781946] 0.31414303
66 116 [0.62637115] 0.057322454
67 176 [0.7404305] 0.026491715
68 68 [0.6054877] 0.12769097
69 24 [0.42987174] 0.8912959
70 36 [0.4423776] 0.88641465
71 52 [0.37785205] 0.91667026
72 73 [0.6557302] 0.046603248
73 52 [0.57818687] 0.14066175
74 3 [0.5082837] None
75 65 [0.40364125] 0.868941
76 101 [0.665291] 0.055766836
77 1 [0.49911857] None
78 15 [0.47279027] 0.78608465
79 46 [0.5985158] 0.163559
80 8 [0.5477181] 0.12542789
81 16 [0.5502247] 0.085509114
82 38 [0.3529809] 0.9317658
83 8 [0.5252721] 0.55291855
84 47 [0.59397006] 0.18307318
85 111 [0.6649714] 0.0



103 7 [0.45756036] 0.74057096
104 17 [0.44251966] 0.8297815
105 88 [0.6308782] 0.09587739
106 23 [0.5558214] 0.1308753
107 26 [0.59216225] 0.1980491
108 19 [0.54578084] 0.19224997
109 69 [0.35872588] 0.9321818
110 33 [0.565053] 0.13871813
111 17 [0.54202217] 0.3642817
112 45 [0.62261707] 0.12195668
113 39 [0.5799945] 0.15657632
114 1 [0.49721602] None
115 5 [0.48867428] 0.7862403
116 10 [0.4639336] 0.8323743
117 18 [0.43417713] 0.8093683
118 24 [0.44858813] 0.8704485
119 30 [0.6073501] 0.17819414
120 65 [0.36184856] 0.9125127
121 8 [0.47906154] 0.8311671
122 44 [0.5680357] 0.16558512
123 90 [0.6263384] 0.13746364
124 11 [0.4504417] 0.8461737
125 68 [0.6133001] 0.107870266
126 143 [0.73097515] 0.03178184
127 5 [0.46706486] None
128 171 [0.76324886] 0.0161233
129 31 [0.5574268] 0.12355064
130 2 [0.48851776] None
131 5 [0.46995676] None
132 158 [0.7081248] 0.022087896
133 223 [0.7428901] 0.01934965
134 172 [0.79958093] 0.002545127
135 7 [0.51595706] 0.13441929
136 62 [0.61452633] 0.154886

 21%|██▏       | 186/872 [00:00<00:02, 247.20it/s]

153 96 [0.6198899] 0.10280194
154 13 [0.53401446] 0.5078076
155 39 [0.6147506] 0.11211915
156 150 [0.6726271] 0.037866324
157 14 [0.5392382] 0.17001222
158 2 [0.49575025] None
159 39 [0.58836126] 0.2430805
160 100 [0.34028384] 0.94260854
161 34 [0.42205048] 0.8416871
162 5 [0.48549208] 0.72146165
163 61 [0.6021] 0.11812888
164 1 [0.4966043] None
165 10 [0.5492155] 0.18894818
166 8 [0.5300609] 0.16493544
167 14 [0.5471866] 0.34163296
168 3 [0.49269667] 0.45025998
169 40 [0.39585838] 0.91817427
170 7 [0.53602624] 0.1632816
171 23 [0.6118059] 0.11349537
172 114 [0.70532435] 0.03505508
173 57 [0.3991051] 0.8673289
174 47 [0.6295987] 0.11736485
175 1 [0.5052556] None
176 88 [0.6331542] 0.066918
177 119 [0.6867983] 0.040022574
178 79 [0.6426105] 0.05207998
179 4 [0.4705278] 0.8012918
180 8 [0.47377643] 0.8945316
181 96 [0.6307869] 0.091742575
182 154 [0.68766636] 0.05553233
183 301 [0.73177266] 0.024291547
184 11 [0.53136873] 0.1947278
185 12 [0.54709375] 0.19995995
186 1 [0.49685743] None
1

 27%|██▋       | 236/872 [00:00<00:02, 235.93it/s]

203 37 [0.56440246] 0.19673812
204 31 [0.59323597] 0.14705825
205 141 [0.2765632] 0.9583916
206 31 [0.60411876] 0.11431539
207 22 [0.56394] 0.13265826
208 38 [0.4144702] 0.8928772
209 71 [0.6232479] 0.0896495
210 29 [0.39145812] 0.9129805
211 6 [0.4799585] 0.6968464
212 37 [0.5748664] 0.19454488
213 185 [0.2954189] 0.94634306
214 25 [0.5558865] 0.08752767
215 6 [0.4780136] 0.8573426
216 17 [0.54320246] 0.12191772
217 103 [0.6626889] 0.0637372
218 58 [0.6484889] 0.0552864
219 1 [0.4952907] None
220 119 [0.6878717] 0.073944926
221 98 [0.6674137] 0.060780305
222 2 [0.4924529] None
223 124 [0.71297896] 0.036200136
224 11 [0.4705696] 0.8415623
225 53 [0.6187755] 0.11618287
226 31 [0.54614305] 0.3235895
227 98 [0.602203] 0.13526438
228 36 [0.42462862] 0.79856795
229 54 [0.31034988] 0.9294969
230 37 [0.42665407] 0.8729948
231 7 [0.52404267] 0.31970677
232 35 [0.42543912] 0.8765785
233 49 [0.7066281] 0.021632181
234 46 [0.3333174] 0.9458513
235 18 [0.46256495] 0.75554115
236 5 [0.515815] 0.180

 33%|███▎      | 284/872 [00:01<00:02, 230.33it/s]

249 49 [0.59659135] 0.08938553
250 11 [0.47368178] 0.6438684
251 34 [0.58425623] 0.237224
252 814 [0.87155] 0.0005154754
253 44 [0.41413093] 0.83850294
254 113 [0.6486649] 0.064642325
255 73 [0.66936874] 0.070420034
256 31 [0.5727742] 0.35153136
257 8 [0.4753445] 0.8787498
258 48 [0.34298423] 0.94878566
259 27 [0.4331006] 0.8806848
260 51 [0.617164] 0.10332889
261 30 [0.368476] 0.95356417
262 113 [0.6445377] 0.067004636
263 1 [0.4992955] None
264 42 [0.60587305] 0.15200679
265 20 [0.4281933] 0.874649
266 19 [0.44188723] 0.8685638
267 15 [0.45801836] 0.8412167
268 129 [0.68058616] 0.0733216
269 117 [0.6265843] 0.080370225
270 31 [0.5856381] 0.13953927
271 21 [0.5482068] 0.21710752
272 9 [0.5158023] 0.16827822
273 10 [0.4551327] 0.9084183
274 44 [0.68541443] 0.038929965
275 35 [0.5729731] 0.1814786
276 77 [0.73483] 0.025084727
277 13 [0.44581184] 0.92073345
278 155 [0.67379767] 0.043734472
279 47 [0.63838893] 0.10552867
280 8 [0.5312563] 0.14645459
281 12 [0.52706337] 0.4831799
282 1 [0.

 39%|███▊      | 337/872 [00:01<00:02, 244.88it/s]

291 9 [0.52078146] 0.64211744
292 4 [0.5186693] 0.21580954
293 23 [0.5552974] 0.17667177
294 43 [0.40113732] 0.91804504
295 56 [0.6223823] 0.105068885
296 33 [0.57432264] 0.13865553
297 12 [0.45695585] 0.7259929
298 47 [0.5863168] 0.22387147
299 17 [0.5372818] 0.29605052
300 24 [0.549339] 0.14739326
301 185 [0.7017942] 0.054386858
302 13 [0.38784522] 0.94811666
303 1 [0.49855852] None
304 47 [0.5914762] 0.09960749
305 8 [0.47636974] 0.75048673
306 11 [0.5253681] 0.2560835
307 1 [0.49628758] None
308 6 [0.45629632] None
309 15 [0.5790727] 0.13183694
310 93 [0.35476947] 0.9248476
311 40 [0.3770731] 0.8665493
312 12 [0.5286046] 0.34102538
313 2 [0.48916087] None
314 18 [0.5417372] 0.33947852
315 6 [0.48203796] None
316 13 [0.45279017] 0.7444226
317 1 [0.50390506] None
318 50 [0.5762193] 0.07842192
319 155 [0.6717191] 0.04193787
320 9 [0.48188037] 0.8925573
321 8 [0.5150872] 0.44084728
322 50 [0.36946845] 0.9353635
323 4 [0.51154137] 0.37632972
324 12 [0.45106572] 0.87875384
325 7 [0.47476



346 48 [0.65177214] 0.046997894
347 18 [0.5603818] 0.26046896
348 97 [0.70652384] 0.018360633
349 85 [0.6467934] 0.0817543
350 51 [0.29605272] 0.9771413
351 49 [0.6602526] 0.04965821
352 102 [0.6223314] 0.111901104
353 56 [0.590416] 0.27061972
354 74 [0.6196062] 0.105437495
355 111 [0.67942333] 0.071314536
356 25 [0.4286145] 0.7802234
357 37 [0.5783862] 0.42912054
358 12 [0.53294206] 0.3990839
359 76 [0.69203633] 0.030486891
360 13 [0.5320437] 0.2711976
361 10 [0.5289591] 0.36912242
362 15 [0.46531403] 0.662614
363 63 [0.6170653] 0.13641764
364 66 [0.20983146] 0.99779946
365 57 [0.382305] 0.88706017
366 39 [0.6278507] 0.09186187
367 10 [0.52510494] 0.45597905
368 18 [0.54002094] 0.30093524
369 106 [0.6894242] 0.03227756
370 2 [0.5042993] None
371 111 [0.6489999] 0.06802833
372 22 [0.55546075] 0.22318785
373 3 [0.514579] None
374 13 [0.46044248] 0.69609874
375 1 [0.494873] None
376 3 [0.50975484] 0.60955733
377 101 [0.6689683] 0.056097
378 74 [0.6015565] 0.13489528
379 48 [0.5756339] 0.

 50%|█████     | 438/872 [00:01<00:01, 234.16it/s]

398 14 [0.43638936] 0.8716754
399 20 [0.40571982] 0.9535598
400 218 [0.7400034] 0.029274564
401 85 [0.64624584] 0.054277785
402 50 [0.6895111] 0.025720757
403 2 [0.49299294] 0.5469188
404 42 [0.37537917] 0.927542
405 4 [0.48726404] 0.63955146
406 9 [0.5269213] 0.3257573
407 59 [0.6073324] 0.12595446
408 100 [0.6980448] 0.034006394
409 3 [0.49024644] 0.6079308
410 154 [0.69743216] 0.034772914
411 35 [0.5686495] 0.21000804
412 11 [0.52901256] 0.53647566
413 25 [0.54273796] 0.2878747
414 21 [0.5746554] 0.091694035
415 191 [0.71013] 0.02718936
416 22 [0.5507534] 0.22579998
417 35 [0.42365336] 0.8936269
418 21 [0.5447057] 0.22359242
419 26 [0.57156086] 0.17454962
420 46 [0.6860314] 0.05808673
421 6 [0.5243246] 0.76309437
422 38 [0.41209215] 0.9260506
423 32 [0.5764876] 0.3097861
424 6 [0.48127213] 0.7030011
425 63 [0.38909376] 0.8942154
426 81 [0.6156398] 0.11109632
427 20 [0.5632332] 0.19632442
428 106 [0.7189021] 0.029836005
429 28 [0.5636652] 0.27722722
430 228 [0.78375286] 0.013109001
4



441 25 [0.55757475] 0.2070992
442 20 [0.4477724] None
443 74 [0.38243043] 0.9203076
444 12 [0.53180504] 0.16996281
445 103 [0.3391647] 0.93237066
446 3 [0.48675334] 0.39307907
447 40 [0.39290065] 0.8589006
448 22 [0.43242887] 0.80765885
449 70 [0.6181183] 0.08917373
450 38 [0.38418645] 0.90815866
451 1 [0.49874705] None
452 153 [0.29990897] 0.9410497
453 80 [0.60637945] 0.118464395
454 67 [0.27973253] 0.98191065
455 55 [0.6040585] 0.23038512
456 12 [0.46142188] 0.7304479
457 19 [0.46026674] 0.7655804
458 44 [0.38594255] 0.921445
459 52 [0.6746475] 0.05452835
460 21 [0.44563422] 0.77808875
461 6 [0.47932371] 0.7414547
462 38 [0.4142269] 0.9086195
463 38 [0.43560874] 0.7532082
464 7 [0.47786987] 0.61642253
465 3 [0.51259863] 0.28425062
466 75 [0.31729096] 0.951002
467 11 [0.5498122] 0.16755071
468 13 [0.45417005] 0.71106684
469 88 [0.37135035] 0.8904168
470 38 [0.38724628] 0.9138655
471 31 [0.3745705] 0.96164495
472 60 [0.33603868] 0.924231
473 38 [0.6048954] 0.093098536
474 5 [0.4830745

 62%|██████▎   | 545/872 [00:02<00:01, 258.20it/s]

494 19 [0.541753] 0.2566848
495 3 [0.48705393] 0.6566712
496 130 [0.29062253] 0.9739358
497 14 [0.55597484] 0.22973862
498 32 [0.43645012] 0.8983231
499 24 [0.44270426] 0.8541026
500 12 [0.46399805] 0.87970626
501 9 [0.5135211] 0.4420453
502 83 [0.67654] 0.054503992
503 97 [0.6323393] 0.07500039
504 60 [0.34854662] 0.9262453
505 32 [0.5784048] 0.029059751
506 15 [0.46551546] 0.8539188
507 132 [0.6944105] 0.064240694
508 17 [0.5336661] 0.6599485
509 3 [0.51081365] 0.62456167
510 35 [0.58052194] 0.13928795
511 61 [0.3455577] 0.9411711
512 40 [0.6046614] 0.11169223
513 30 [0.39084452] 0.9363806
514 90 [0.63273126] 0.12871307
515 9 [0.47283515] None
516 67 [0.38744187] 0.89811003
517 12 [0.4490581] 0.75076735
518 4 [0.53080136] None
519 15 [0.5469548] 0.3785384
520 11 [0.4659098] 0.8180803
521 106 [0.28282884] 0.9699359
522 17 [0.41445372] 0.7748188
523 31 [0.43105572] 0.8524178
524 12 [0.53378797] 0.3714757
525 7 [0.53704613] None
526 276 [0.11612704] 0.99988985
527 23 [0.5539147] 0.26614

 69%|██████▉   | 601/872 [00:02<00:01, 257.74it/s]

550 65 [0.6368123] 0.09132362
551 8 [0.46982178] 0.7944634
552 39 [0.41894034] 0.83515745
553 26 [0.42322898] 0.85209405
554 12 [0.4182109] None
555 6 [0.52047527] 0.29504725
556 109 [0.19840008] 0.9962794
557 56 [0.35959595] 0.93062085
558 7 [0.44143617] None
559 63 [0.41155124] 0.85150117
560 9 [0.47444645] 0.80653214
561 74 [0.6252594] 0.1451327
562 20 [0.44413704] 0.8887155
563 3 [0.50999373] 0.6360403
564 12 [0.54202384] 0.2517938
565 40 [0.6114503] 0.12941577
566 101 [0.26729038] 0.9931518
567 18 [0.46805835] 0.6141969
568 145 [0.19901974] 0.9941525
569 47 [0.5923256] 0.10089844
570 82 [0.6355847] 0.067036815
571 81 [0.643673] 0.07944163
572 1 [0.5057614] None
573 72 [0.64815307] 0.08444406
574 68 [0.35609058] 0.93328583
575 114 [0.6492582] 0.061925936
576 23 [0.5633728] 0.17526987
577 68 [0.63185596] 0.082356215
578 52 [0.60265267] 0.13008106
579 95 [0.33074817] 0.92537594
580 42 [0.40376434] 0.8762537
581 3 [0.49006262] 0.7219363
582 148 [0.29306147] 0.969172
583 17 [0.4591862]



601 34 [0.42826942] 0.953399
602 4 [0.4917664] 0.8894563
603 44 [0.38211116] 0.92031574
604 60 [0.39563] 0.8919752
605 11 [0.4643485] 0.79048055
606 13 [0.55480915] 0.4449078
607 119 [0.26928374] 0.98243976
608 21 [0.54367834] 0.11464317
609 90 [0.27942964] 0.9671971
610 7 [0.5185809] 0.21912344
611 80 [0.38395682] 0.8820731
612 75 [0.32129553] 0.9543368
613 53 [0.5777909] 0.17548344
614 30 [0.38928255] 0.92076224
615 14 [0.4608717] None
616 35 [0.42716095] 0.82058936
617 21 [0.39570823] 0.880131
618 29 [0.5654582] 0.20515943
619 11 [0.52810824] 0.25598615
620 17 [0.45333388] 0.7667456
621 8 [0.4781199] 0.8707047
622 48 [0.6056698] 0.13608989
623 3 [0.4924818] 0.67149043
624 69 [0.41001433] 0.91790175
625 47 [0.39968804] 0.925588
626 85 [0.3385774] 0.9397598
627 8 [0.522704] 0.5100118
628 30 [0.39786205] 0.9263901
629 105 [0.34224844] 0.9261327
630 2 [0.51239103] None
631 10 [0.46663326] 0.7593481
632 1 [0.5009861] None
633 92 [0.31412873] 0.9529634
634 24 [0.5579895] 0.18831275
635 10



658 90 [0.6277299] 0.13737516
659 122 [0.653042] 0.05712198
660 50 [0.5770734] 0.17791986
661 147 [0.21732856] 0.9927812
662 18 [0.568634] 0.13202083
663 1 [0.5079978] None
664 1 [0.5006451] None
665 87 [0.3573201] 0.9207872
666 24 [0.55231184] 0.2042319
667 54 [0.5817618] 0.22205615
668 119 [0.30727917] 0.9536458
669 278 [0.18139572] 0.9966838
670 10 [0.5311969] 0.57769066
671 16 [0.449814] 0.58760643
672 43 [0.42769164] 0.88233984
673 101 [0.35770088] 0.9094923
674 5 [0.5128665] 0.19539106
675 42 [0.36136648] 0.91905695
676 65 [0.6245052] 0.08721102
677 6 [0.5131644] 0.39228448
678 74 [0.33496642] 0.9416035
679 5 [0.4871728] None
680 23 [0.5553633] 0.17056504
681 43 [0.3979225] 0.91161776
682 44 [0.41286445] 0.91729695
683 45 [0.7193673] 0.050652303
684 81 [0.34641287] 0.90959626
685 43 [0.5987916] 0.14742044
686 9 [0.5348956] 0.15472388
687 105 [0.6569025] 0.071421996
688 44 [0.39105675] 0.8830439
689 80 [0.65935624] 0.06271947
690 8 [0.51652884] 0.7773762
691 6 [0.5250969] 0.137529



702 5 [0.47202525] 0.8166205
703 2 [0.48569313] None
704 20 [0.54547286] 0.2133325
705 68 [0.61194086] 0.15534085
706 4 [0.46858323] None
707 17 [0.55276823] 0.21124417
708 9 [0.47740254] 0.9121837
709 73 [0.3357308] 0.9442974
710 18 [0.5533491] 0.20430392
711 134 [0.2974694] 0.97051495
712 3 [0.5131394] 0.5834105
713 22 [0.4563396] 0.87287337
714 120 [0.7694027] 0.009630282
715 3 [0.49168786] None
716 267 [0.7548745] 0.015525359
717 2 [0.49225697] None
718 6 [0.5166108] 0.53780013
719 47 [0.4062633] 0.86388725
720 10 [0.4444551] None
721 10 [0.46725976] 0.822527
722 57 [0.36969122] 0.9043323
723 41 [0.5788247] 0.16466126
724 1 [0.5022383] None
725 8 [0.5185849] 0.32362866
726 208 [0.80653316] 0.0049013873
727 8 [0.462168] 0.8508705
728 61 [0.3799576] 0.934798
729 20 [0.59502715] 0.07122288
730 67 [0.30897075] 0.957427
731 164 [0.26990354] 0.9593505
732 2 [0.50744885] 0.6433986
733 6 [0.5243864] 0.18435399
734 34 [0.4140021] 0.9298428
735 34 [0.30453065] 0.97188944
736 3 [0.50815153] N



762 27 [0.55760896] 0.17364971
763 79 [0.30475608] 0.9540682
764 63 [0.6354386] 0.061017927
765 32 [0.55391175] 0.397696
766 41 [0.6486804] 0.064243846
767 13 [0.4521294] 0.89600927
768 82 [0.63169503] 0.106182545
769 17 [0.46181098] 0.79092735
770 17 [0.4495179] 0.85091865
771 3 [0.50984055] 0.33869198
772 24 [0.44330627] 0.79628146
773 101 [0.6648565] 0.06133768
774 42 [0.38985157] 0.9039337
775 55 [0.34118658] 0.88786906
776 97 [0.6338315] 0.08113212
777 2 [0.50529516] 0.4764626
778 15 [0.5379583] 0.38333815
779 100 [0.35306618] 0.8963701
780 4 [0.50945276] 0.41499817
781 36 [0.578228] 0.18695521
782 51 [0.59971654] 0.15675944
783 6 [0.5177216] 0.7517573
784 18 [0.46370628] 0.7700561
785 37 [0.40695792] 0.9140491
786 13 [0.46729422] 0.85830784
787 13 [0.54338855] 0.18431991
788 5 [0.5261692] 0.1815935
789 97 [0.34905764] 0.95004493
790 46 [0.37981725] 0.88584554
791 30 [0.4181723] 0.73025554
792 114 [0.30469024] 0.9300911
793 4 [0.48876864] 0.5761387
794 30 [0.56898624] 0.35837334
7

 98%|█████████▊| 856/872 [00:03<00:00, 264.60it/s]

819 34 [0.57056254] 0.1515238
820 52 [0.3971405] 0.86501414
821 7 [0.48051193] 0.60036284
822 40 [0.5968163] 0.10866005
823 106 [0.26224294] 0.97674495
824 13 [0.47306937] 0.8445809
825 14 [0.5236346] 0.51391417
826 4 [0.48093235] 0.8053637
827 76 [0.3707889] 0.9049733
828 115 [0.7207189] 0.03525137
829 21 [0.55326897] 0.24762115
830 26 [0.6476711] 0.09092302
831 43 [0.38733575] 0.9324196
832 14 [0.4702725] 0.5614425
833 57 [0.62894976] 0.11847074
834 20 [0.5740479] 0.20983711
835 17 [0.45477152] 0.77891415
836 479 [0.12365636] 0.99964535
837 6 [0.48472655] 0.853882
838 12 [0.47304466] 0.702527
839 27 [0.406788] 0.86744773
840 5 [0.54287887] None
841 27 [0.41354173] 0.8372244
842 117 [0.3071623] 0.930484
843 64 [0.28404498] 0.98509014
844 7 [0.4732421] 0.89528507
845 43 [0.55489606] 0.17281999
846 29 [0.5648009] 0.24940093
847 2 [0.505845] None
848 36 [0.42216545] 0.8111851
849 200 [0.26270947] 0.9757721
850 30 [0.5733596] 0.10964876
851 75 [0.30751362] 0.95844054
852 4 [0.5137675] 0.4

100%|██████████| 872/872 [00:03<00:00, 250.99it/s]

868 8 [0.47568718] 0.6451312
869 23 [0.45380282] 0.8009489
870 34 [0.58116955] 0.11251625
871 5 [0.48611158] 0.6050083





In [12]:
new_predictions = new_predictions.reshape((872, 1))
eval = []
num = 0
for i in range(872):
    if new_predictions[i] != None:
        eval.append((new_predictions[i] > 0.5) != (pred[i] > 0.5))
        num += 1
np.sum(eval) / num

np.float64(0.9623115577889447)