In [1]:
from sklearn.linear_model import LogisticRegression
import sklearn.metrics as mc
import warnings
import numpy as np
from tqdm import tqdm
import spacy

In [51]:
def Add_subset(k, scores, test_idx, pred, X, y, thresh):
    #print("test_idx", test_idx)
    #print("old")
    #print(pred[test_idx])
    
    if pred[test_idx] > thresh:
        top_k_index = scores[test_idx].argsort()[-k:]
    else:
        top_k_index = scores[test_idx].argsort()[:k]
        
    y_k = y["train_add"][top_k_index]
    X_k = X["train_add"][top_k_index]
    y_added = np.concatenate((y["train"],y_k))
    X_added = np.concatenate((X["train"],X_k), axis=0)

    prediction = -np.sum(scores[test_idx][top_k_index])
    #print("prediction", prediction)

    return X_added, y_added, prediction, top_k_index


def loss_gradient(X, y, model):
    F_train = np.concatenate([X, np.ones((X.shape[0], 1))], axis=1)
    error_train = model.predict_proba(X)[:, 1] - y
    gradient_train = F_train * error_train[:, None]
    return gradient_train

In [52]:
def approximate_k(test_idx, pred, delta_pred, y, thresh):
    old = pred[test_idx].item()
    
    if pred[test_idx] > thresh:
        top_k_index = np.flip(delta_pred[test_idx].argsort())
    else:
        top_k_index = delta_pred[test_idx].argsort()
    
    for k in range(1, y["train"].shape[0]):
        change = -np.sum(delta_pred[test_idx][top_k_index[:k]])
        
        if old > thresh and old + change < thresh:
            return k
        elif old < thresh and old + change > thresh:
            return k
        
    return None

In [53]:
def new_train(k, dev_index, scores, l2, X, model, pred, y, thresh):
    X_k, y_k, prediction, top_k_index = Add_subset(k, scores, dev_index, pred, X, y, thresh)
    
    if y_k.shape[0] == np.sum(y_k) or np.sum(y_k) == 0: # data contains only one class: 1
        return None, None, None

    # Fit the model again
    model_k = LogisticRegression(penalty='l2', C=1/l2)
    model_k.fit(X_k, y_k)

    # predictthe probaility with test point
    test_point = X["dev"][dev_index]
    test_point=np.reshape(test_point, (1,-1))
    
    new = model_k.predict_proba(test_point)[0][1]
    change = -(model.predict_proba(test_point)[0][1] - new)
    #change = model_k.predict_proba(test_point)[0][1]-model.predict_proba(test_point)[0][1]
    flip = (model.predict(test_point) == model_k.predict(test_point))
    
    """
    print("change    ", change)
    print("old       ", model.predict_proba(test_point)[0][1])
    print()
    """
    error = np.abs((change - prediction)/prediction)
    return change, flip, prediction,new, error, top_k_index

In [48]:
def IP(X, y, l2, dataname, thresh, modi=None):
    model = LogisticRegression(penalty='l2', C=1/l2)
    model.fit(X["train"], y["train"])
    pred = np.reshape(model.predict_proba(X["dev"])[:, 1], (model.predict_proba(X["dev"])[:, 1].shape[0], 1))

    gradient_train_add = loss_gradient(X["train_add"], y["train_add"], model)

    w = np.concatenate((model.coef_, model.intercept_[None, :]), axis=1)
    F_train = np.concatenate([X["train"], np.ones((X["train"].shape[0], 1))],
                             axis=1)  # Concatenating one to calculate the gradient with respect to intercept
    F_dev = np.concatenate([X["dev"], np.ones((X["dev"].shape[0], 1))], axis=1)

    # error_train = model.predict_proba(X["train"])[:, 1] - y["train"]
    # error_dev = model.predict_proba(X["dev"])[:, 1] - y["dev"]

    # gradient_train = F_train * error_train[:, None]
    # gradient_dev = F_dev * error_dev[:, None]

    probs = model.predict_proba(X["train"])[:, 1]
    hessian = F_train.T @ np.diag(probs * (1 - probs)) @ F_train / X["train"].shape[0] + l2 * np.eye(F_train.shape[1]) / \
              X["train"].shape[0]
    inverse_hessian = np.linalg.inv(hessian)
    eps = 1 / X["train"].shape[0]
    delta_k = -eps * inverse_hessian @ (- gradient_train_add).T
    grad_f = F_dev * (pred * (1 - pred))
    delta_pred = grad_f @ delta_k
    return delta_pred
    


    


In [33]:
import spacy
sst_dataset = {}
for split in ["train", "dev", "test"]:
    URL = f"https://raw.githubusercontent.com/successar/instance_attributions_NLP/master/Datasets/SST/data/{split}.jsonl"
    import urllib.request, json
    with urllib.request.urlopen(URL) as url:
        data = url.read().decode()
        data = [json.loads(line) for line in data.strip().split("\n")]
        sst_dataset[split] = data


nlp = spacy.load('en_core_web_sm')

X, y = {}, {}
for split in ["train", "dev"]:
    X[split] = np.array([nlp(example["document"]).vector for example in tqdm(sst_dataset[split])])
    y[split] = np.array([example["label"] for example in sst_dataset[split]])

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 6920/6920 [00:40<00:00, 172.23it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 872/872 [00:05<00:00, 173.12it/s]


In [34]:
from sklearn.model_selection import train_test_split
X_train, X_train_add, y_train, y_train_add = train_test_split(
    X["train"], y["train"], test_size=0.4, random_state=42)

In [35]:
X["train"] = X_train
X["train_add"] = X_train_add
y["train"] = y_train
y["train_add"]= y_train_add
X["train"].shape, X["train_add"].shape, X["dev"].shape

((4152, 96), (2768, 96), (872, 96))

In [36]:
thresh = 0.5
l2 = 1000
model = LogisticRegression(penalty='l2', C=1/l2)
model.fit(X["train"], y["train"])
pred = np.reshape(model.predict_proba(X["dev"])[:, 1], (model.predict_proba(X["dev"])[:, 1].shape[0], 1))

In [37]:
delta_pred = IP(X, y, l2, "SST", thresh)
delta_pred.shape

(872, 2768)

In [54]:
# Loop over all dev points:
appro_ks = []
new_predictions = []
flip_list = []
for test_idx in tqdm(range(X["dev"].shape[0])):
    appro_k = approximate_k(test_idx, pred, delta_pred, y, thresh)
    if appro_k != None:
        change, _, prediction,new_prediction, error, top_k_index = new_train(appro_k, test_idx, delta_pred, l2, X, model, pred, y, thresh)
        print(test_idx, appro_k, pred[test_idx], new_prediction)
        appro_ks.append(appro_k)
        new_predictions.append(new_prediction)
        flip_list.append(top_k_index)
    else:
        appro_ks.append(None)
        new_predictions.append(None)
        flip_list.append(None)

appro_ks= np.array(appro_ks)
new_predictions=np.array(new_predictions)
flip_list = np.array(flip_list)
np.save("./results/" + "appro_ks_IP" + "_alg1_" + dataname  + str(l2) + ".npy", appro_ks)
np.save("./results/" + "new_predictions" + "_alg1_" + dataname + str(l2) + ".npy", new_predictions)
np.save("./results/" + "old_predictions" + "_alg1_" + dataname + str(l2) + ".npy", pred)
np.save("./results/" + "flip_list" + "_alg1_" + dataname + str(l2) + ".npy", flip_list)

  3%|███▊                                                                                                         | 30/872 [00:00<00:05, 149.17it/s]

0 134 [0.52077396] 0.49728713269858066
1 55 [0.48867849] 0.5025829044023986
2 151 [0.55569746] 0.504179411130753
3 97 [0.52066404] 0.49810821989098775
4 149 [0.54020816] 0.4999450458918281
5 186 [0.52770517] 0.4937434027218943
6 55 [0.51122005] 0.49961439710328226
7 3 [0.50068478] 0.49988251172861625
8 178 [0.53307483] 0.4978678339083847
9 100 [0.51799194] 0.49905691537294394
10 61 [0.51233241] 0.49912682859497276
11 11 [0.50222835] 0.4997309157782503
12 199 [0.54440747] 0.4992949226403268
13 20 [0.4938701] 0.5005344833735793
14 140 [0.52529472] 0.4983150198912873
15 53 [0.50777162] 0.49837402965294064
16 128 [0.52250313] 0.49723790928461015
17 19 [0.5064933] 0.49946195880727146
18 205 [0.54220615] 0.5001668728541306
19 81 [0.51278771] 0.497648166585883
20 22 [0.50394654] 0.4989553878449617
21 186 [0.55922473] 0.5009571418038667
22 117 [0.52156702] 0.49912520663704135
23 124 [0.53618588] 0.5000414442686401
24 18 [0.50573817] 0.4995815657887367
25 57 [0.51673559] 0.5000572911637763
26 1

  5%|█████▉                                                                                                       | 47/872 [00:00<00:05, 154.55it/s]

30 192 [0.53598534] 0.4988216535728453
31 39 [0.50757266] 0.49861882529365403
32 114 [0.52371276] 0.4993334452805755
33 20 [0.49619858] 0.5012053441259131
34 114 [0.51967975] 0.4984141754077538
35 14 [0.50313168] 0.49973362853377523
36 8 [0.49790666] 0.5003264305558184
37 86 [0.47104257] 0.5042566343217711
38 40 [0.50814262] 0.4984039870142087
39 13 [0.49678635] 0.5009759129592611
40 46 [0.52583782] 0.5000971645381274
41 49 [0.50868465] 0.498312332936441
42 107 [0.51980081] 0.4978765981332745
43 47 [0.51574345] 0.49999322746658276
44 30 [0.50522769] 0.4990541057302473
45 230 [0.55640552] 0.5027563859059556
46 8 [0.50171165] 0.4996183756359377
47 109 [0.52424709] 0.5000809132684233
48 127 [0.5249485] 0.49840715768261223
49 40 [0.50855993] 0.4993475711315648
50 77 [0.52841175] 0.4983113733426321
51 152 [0.53598209] 0.5001608488304656
52 155 [0.53189331] 0.4994926174690251
53 71 [0.51469578] 0.4989724111560665
54 134 [0.52986481] 0.49825247957515373
55 63 [0.51111193] 0.49804983595492525


  9%|█████████▉                                                                                                   | 79/872 [00:00<00:05, 152.64it/s]

62 147 [0.53399947] 0.5001222425411264
63 168 [0.53261813] 0.49902211908766153
64 3 [0.5014225] 0.49999974394864577
65 26 [0.49468781] 0.5010067880976886
66 177 [0.5372944] 0.5000307980294272
67 261 [0.55350788] 0.4991022206936955
68 138 [0.52652038] 0.49889278822756633
69 18 [0.4953797] 0.5011374975847274
70 8 [0.49827304] 0.5001716298716004
71 36 [0.50938928] 0.4994574748350145
72 169 [0.53853741] 0.49881486934683233
73 126 [0.52252998] 0.4986234686892074
74 48 [0.50834738] 0.4987576154033688
75 15 [0.50351159] 0.49975357978593776
76 178 [0.53666891] 0.49974999533252784
77 35 [0.50554677] 0.49895131735806303
78 62 [0.51141074] 0.4986700637549086
79 118 [0.52339166] 0.497367612796993
80 36 [0.52049832] 0.4998559158653809
81 37 [0.5083207] 0.4987479108190324
82 20 [0.5089035] 0.4995432637503656
83 105 [0.52057617] 0.49833730634627743
84 136 [0.52564605] 0.49854070001295697
85 180 [0.53557207] 0.4998187441216083
86 201 [0.54579584] 0.4994984003583767
87 13 [0.49557106] 0.500321568669338

 13%|█████████████▋                                                                                              | 111/872 [00:00<00:04, 154.38it/s]

94 96 [0.52024771] 0.49955975620297605
95 46 [0.51048794] 0.4989611454571143
96 31 [0.48390755] 0.5009122301888038
97 28 [0.50602459] 0.4992868618675019
98 2 [0.49956615] 0.5002094250831056
99 137 [0.54192675] 0.500550149730165
100 44 [0.50729464] 0.49766028093266945
101 5 [0.50197205] 0.49967158393809846
102 97 [0.52025003] 0.49858115324505875
103 95 [0.52377635] 0.49754517343403626
104 24 [0.50737894] 0.4999927714261373
105 11 [0.50275697] 0.49936751649591804
106 10 [0.5028016] 0.49963624196012646
107 63 [0.51609867] 0.49798269892718783
108 78 [0.51774957] 0.4978427986932492
109 44 [0.49121768] 0.5020296142784155
110 101 [0.52152214] 0.49981876685002036
111 47 [0.48841391] 0.5022103491751868
112 46 [0.51190577] 0.4991306662373445
113 61 [0.51341407] 0.4994906869096703
114 2 [0.50038396] 0.49974548354131015
115 51 [0.5095889] 0.49897247840709147
116 80 [0.48006723] 0.5030083704691911
117 46 [0.4886689] 0.5022392895315995
118 12 [0.5022328] 0.49930329800208934
119 9 [0.50565035] 0.4997

 17%|█████████████████▊                                                                                          | 144/872 [00:00<00:04, 152.73it/s]

127 1 [0.50028753] 0.49972826529690567
128 135 [0.53355826] 0.49881301664912714
129 100 [0.51778258] 0.49822242866967775
130 50 [0.51368192] 0.4992230365180815
131 15 [0.49466112] 0.5006924000819158
132 191 [0.54125534] 0.500524430609847
133 242 [0.54899468] 0.5003959993983819
134 204 [0.57500589] 0.505774904144882
135 35 [0.50844571] 0.4988166427898566
136 74 [0.51515465] 0.49895679771748985
137 44 [0.51148881] 0.499355803003518
138 32 [0.50637398] 0.49898516888328237
139 57 [0.51669728] 0.4990591860814177
140 84 [0.48310082] 0.5029706988191377
141 40 [0.50839799] 0.4988941096137116
142 17 [0.50324056] 0.49920143690722835
143 141 [0.53030133] 0.4972937321815357
144 109 [0.5218307] 0.49875412671582153
145 169 [0.53190608] 0.49718132020355904
146 12 [0.50361004] 0.4993739023255611
147 32 [0.49090738] 0.5014901184785012
148 125 [0.52186663] 0.4986474683025372
149 60 [0.51054793] 0.49797983151248265
150 15 [0.49646833] 0.5008050962134842
151 35 [0.51310282] 0.49899002174425805
152 122 [0.

 20%|█████████████████████▋                                                                                      | 175/872 [00:01<00:04, 148.01it/s]

156 227 [0.5423628] 0.498487798280383
157 118 [0.52171738] 0.4994561380077992
158 26 [0.5056806] 0.49954269278945723
159 140 [0.52403737] 0.4984477321788091
160 68 [0.48687964] 0.5026119409555699
161 44 [0.49134492] 0.5016079495451176
162 95 [0.51749929] 0.49861052053211924
163 137 [0.52296814] 0.4983727841869377
164 6 [0.50192433] 0.49984784649610176
165 6 [0.49820632] 0.5004117090333223
166 91 [0.5232939] 0.4988936229706362
167 46 [0.51253997] 0.49943001122116126
168 56 [0.51265908] 0.4993358121148242
169 35 [0.48876568] 0.5010520219209105
170 146 [0.53320562] 0.4990154803534339
171 11 [0.50626087] 0.49968402719569316
172 219 [0.54655021] 0.4972245943955032
173 29 [0.5057049] 0.49942647650320815
174 66 [0.51519059] 0.4986570381928504
175 1 [0.49973049] 0.5002821772152688
176 63 [0.51181586] 0.49940899919832077
177 117 [0.53157707] 0.49969923335469907
178 198 [0.54610745] 0.500050939246654
179 73 [0.53401831] 0.5005161546058505
180 38 [0.50831769] 0.4986921985902372
181 185 [0.5310062

 24%|█████████████████████████▌                                                                                  | 206/872 [00:01<00:04, 149.69it/s]

185 21 [0.50655796] 0.49939509219823397
186 24 [0.50683901] 0.49900679084000055
187 43 [0.50860609] 0.4986392780795411
188 88 [0.51667355] 0.4989812234131866
189 104 [0.52497098] 0.49813781209369473
190 73 [0.51055484] 0.4983839664886085
191 137 [0.52564587] 0.49809215604160356
192 29 [0.49266203] 0.5020296819140594
193 64 [0.48586572] 0.5025195471644865
194 63 [0.51203732] 0.4990017323742187
195 7 [0.4988236] 0.5003831167747864
196 207 [0.53438395] 0.497282527063421
197 117 [0.52521681] 0.49891900226874125
198 139 [0.52673493] 0.498598536259023
199 105 [0.54769751] 0.501449040939442
200 93 [0.51781418] 0.49856659408439624
201 58 [0.48421289] 0.503676223202213
202 6 [0.50107477] 0.49971596353440034
203 78 [0.5139688] 0.4984003521259874
204 15 [0.50444397] 0.4995759157008317
205 62 [0.48712755] 0.5015968124688238
206 69 [0.52292315] 0.5000417440763945
207 19 [0.50456329] 0.4993884181307206
208 1 [0.49978216] 0.5001324369220372
209 26 [0.50609453] 0.49899632685346
210 77 [0.47610463] 0.5

 27%|█████████████████████████████▍                                                                              | 238/872 [00:01<00:04, 150.91it/s]

217 190 [0.53599688] 0.49815049579108645
218 137 [0.5399313] 0.49969189711785666
219 103 [0.52608567] 0.49950480993969193
220 130 [0.52320464] 0.4987154618458274
221 142 [0.53464183] 0.5003662542135642
222 86 [0.4771763] 0.5056972054322628
223 154 [0.52985329] 0.49690459659608327
224 47 [0.51036199] 0.49901124061343655
225 47 [0.51042174] 0.49909365600725075
226 68 [0.51037192] 0.49844724357445624
227 128 [0.52118665] 0.4987220479515322
228 1 [0.49993731] 0.5002315186500453
229 29 [0.48970515] 0.5006400650239864
230 39 [0.49288385] 0.5012919712946268
231 39 [0.50855571] 0.49905884043613896
232 36 [0.49412822] 0.5022249788652274
233 35 [0.53458236] 0.5006474063475649
234 68 [0.4800333] 0.5028199584772697
235 36 [0.50595629] 0.49882364801348705
236 69 [0.5149965] 0.49884982628154284
237 70 [0.51610837] 0.5001998035563714
238 50 [0.51095034] 0.4988664845652645
239 7 [0.5016425] 0.49976543315229965
240 78 [0.51104798] 0.4974687580369445
241 191 [0.53945881] 0.4992072202294924
242 31 [0.505

 31%|█████████████████████████████████▍                                                                          | 270/872 [00:01<00:03, 153.45it/s]

249 26 [0.50550333] 0.4991798637096603
250 21 [0.50419889] 0.499309765234365
251 92 [0.51657367] 0.49659505177772945
252 147 [0.52855897] 0.4969403741446923
253 47 [0.51000285] 0.4993320607189796
254 157 [0.53063957] 0.4987095379173842
255 148 [0.52826106] 0.49542976524995713
256 73 [0.51486766] 0.4994666660076288
257 37 [0.49089155] 0.501149332939584
258 1 [0.50005087] 0.4994164361317641
259 23 [0.50506837] 0.49917225389839864
260 108 [0.53169031] 0.4986435197373773
261 87 [0.47371925] 0.5031778197379537
262 184 [0.53348208] 0.4978697213229088
263 63 [0.5127129] 0.49857187232022676
264 47 [0.5106571] 0.4995074094548685
265 25 [0.49067609] 0.5004782369647652
266 68 [0.51682673] 0.4991712249916731
267 20 [0.50521617] 0.4993615013436301
268 131 [0.52199355] 0.49791014458952815
269 151 [0.5271425] 0.49757169716817395
270 103 [0.52102509] 0.4979371759244518
271 42 [0.48805712] 0.5032234216141386
272 14 [0.50243854] 0.4995157057634935
273 5 [0.49838383] 0.5000348650364405
274 30 [0.52867581

 35%|█████████████████████████████████████▎                                                                      | 301/872 [00:01<00:03, 149.02it/s]

281 73 [0.51753508] 0.4990107223361591
282 108 [0.51675665] 0.4979187619992515
283 103 [0.51746689] 0.4987706328943905
284 101 [0.52102156] 0.49935897654907263
285 9 [0.50391463] 0.49967069310397755
286 127 [0.54547204] 0.5016892258351586
287 16 [0.50373706] 0.49921230586904286
288 70 [0.51377027] 0.4990928286634915
289 79 [0.53164267] 0.5009905816388938
290 28 [0.50678081] 0.49903062714729957
291 127 [0.52765111] 0.49945586532607344
292 35 [0.48954341] 0.501915678066283
293 133 [0.53396937] 0.5000659631473484
294 76 [0.48346126] 0.5032398583678563
295 141 [0.53238514] 0.49967917049189753
296 122 [0.52680208] 0.49951122096515965
297 23 [0.50546869] 0.49896503925955554
298 68 [0.5136694] 0.4992695738798327
299 146 [0.52849108] 0.49896395763741774
300 106 [0.5188232] 0.4981508009671168
301 174 [0.52919282] 0.4987545116588454
302 29 [0.52646724] 0.5002399326658633
303 77 [0.51407836] 0.4985317324288102
304 141 [0.53094299] 0.4985821571614885
305 24 [0.50620082] 0.499567286373967
306 108 [

 38%|████████████████████████████████████████▉                                                                   | 331/872 [00:02<00:03, 146.56it/s]

310 100 [0.48165325] 0.5046235445401704
311 20 [0.49384269] 0.5008199316097217
312 44 [0.50861803] 0.49897056899865266
313 5 [0.50307618] 0.49946459001581917
314 58 [0.51177933] 0.49907813036425736
315 21 [0.50377717] 0.49931525317815173
316 94 [0.52877883] 0.4995245053832146
317 68 [0.51141501] 0.49862299892697304
318 156 [0.53564958] 0.5009559702291702
319 167 [0.53749302] 0.49986003032586224
320 7 [0.50131859] 0.4996913992542868
321 71 [0.50956976] 0.49782321900623694
322 85 [0.48162435] 0.5025218112138045
323 87 [0.51492197] 0.49757476721336386
324 7 [0.49613647] 0.5002256244988164
325 34 [0.51285918] 0.49963419511393253
326 56 [0.48413698] 0.5016329917166817
327 35 [0.5073467] 0.49921799791382626
328 119 [0.52510827] 0.49805006304316934
329 121 [0.52492837] 0.4976796953009021
330 137 [0.52792715] 0.4994508284921555
331 117 [0.53402969] 0.5011034490599011
332 261 [0.55101761] 0.49854930172308254
333 51 [0.52284517] 0.500279757000269
334 172 [0.52964878] 0.49775642296435024
335 5 [0

 42%|████████████████████████████████████████████▊                                                               | 362/872 [00:02<00:03, 148.71it/s]

340 124 [0.52684239] 0.4976161525011014
341 65 [0.48479418] 0.502723952830067
342 277 [0.56434482] 0.499092834376619
343 57 [0.51011555] 0.4983173539803421
344 21 [0.4947089] 0.5008725065179099
345 12 [0.50262915] 0.4995243232967931
346 37 [0.51532424] 0.49978951789477094
347 100 [0.52165316] 0.49899367092992253
348 187 [0.55976002] 0.5012299523703838
349 42 [0.51049222] 0.49955086475086896
350 31 [0.47413704] 0.5000008729827647
351 158 [0.55832756] 0.5032446772704106
352 201 [0.53516531] 0.498442403170702
353 56 [0.50842393] 0.4981994965852699
354 110 [0.52528068] 0.5000426493083567
355 201 [0.53547099] 0.49642233153531384
356 67 [0.51674346] 0.4992519913110698
357 84 [0.51554212] 0.4981686814284115
358 65 [0.51635518] 0.49976744742943285
359 96 [0.54330813] 0.5014486295594692
360 101 [0.52052878] 0.4978790384950953
361 123 [0.5255263] 0.5005681764728198
362 7 [0.49861103] 0.5004246899495545
363 137 [0.52829186] 0.4991359972632646
364 10 [0.48627692] 0.4997847333013951
365 5 [0.498582

 45%|████████████████████████████████████████████████▊                                                           | 394/872 [00:02<00:03, 151.21it/s]

372 3 [0.49928276] 0.5001817227125277
373 103 [0.52405657] 0.5002482835864912
374 53 [0.51071435] 0.49878653670415124
375 81 [0.51712497] 0.4994186474685724
376 38 [0.50929145] 0.4997007092269378
377 173 [0.53468317] 0.49806053320927257
378 168 [0.5256067] 0.49663671637245
379 134 [0.52454063] 0.49864289811364404
380 4 [0.49875162] 0.5003089282768673
381 39 [0.50714799] 0.4993963736317168
382 23 [0.5048463] 0.4993433135335248
383 66 [0.51299988] 0.49775411009202775
384 94 [0.51795119] 0.49819723046322306
385 38 [0.50829666] 0.49930652263450964
386 113 [0.5191599] 0.49809577869506094
387 144 [0.53104762] 0.499837021705761
388 202 [0.54078035] 0.4990761595153272
389 109 [0.53033576] 0.5013727207954547
390 54 [0.51237119] 0.4977835128100338
391 133 [0.54706927] 0.5011873519615374
392 75 [0.51670308] 0.4992120784590644
393 53 [0.51187559] 0.4994332770812201
394 17 [0.49400151] 0.5004467735653002
395 28 [0.51027801] 0.4994001032999895
396 133 [0.5298174] 0.49939558756165436
397 49 [0.485800

 49%|████████████████████████████████████████████████████▊                                                       | 426/872 [00:02<00:02, 150.84it/s]

404 36 [0.48946679] 0.5016115898995583
405 32 [0.50736803] 0.4997196169107085
406 27 [0.50789548] 0.49961531958376504
407 51 [0.50920462] 0.49893960042499136
408 36 [0.51025098] 0.4979990996327929
409 33 [0.50787455] 0.49941597083408856
410 189 [0.54019243] 0.4993793418680645
411 121 [0.52515829] 0.49959045429121923
412 36 [0.50771446] 0.49917038413240256
413 105 [0.52300686] 0.49917504373053206
414 68 [0.52747095] 0.5002915530609104
415 203 [0.54360243] 0.49871741415502113
416 95 [0.51930526] 0.49946203562823227
417 30 [0.50724524] 0.49935329107900983
418 44 [0.50889911] 0.4987647710019022
419 86 [0.5185595] 0.49837769977407437
420 72 [0.53123342] 0.5003606884137065
421 145 [0.53145813] 0.49759873367046414
422 39 [0.49128797] 0.501789424507667
423 121 [0.52303895] 0.49751128115317467
424 22 [0.49587054] 0.5009822563158556
425 28 [0.50603181] 0.49940213435347347
426 105 [0.52050453] 0.49899727362328355
427 69 [0.51916814] 0.4992071244443629
428 130 [0.53193199] 0.49933913041222744
429 

 53%|████████████████████████████████████████████████████████▋                                                   | 458/872 [00:03<00:02, 152.09it/s]

434 187 [0.54882555] 0.5020995859768014
435 80 [0.51316443] 0.49746035036776337
436 79 [0.511042] 0.49750155467743695
437 36 [0.49219397] 0.501370726408422
438 29 [0.50617466] 0.49907833615971886
439 83 [0.51400211] 0.49877510060453156
440 119 [0.53774785] 0.5011907841449811
441 119 [0.52204353] 0.4981833636217885
442 6 [0.49827551] 0.5004675220358789
443 74 [0.48681946] 0.5038466333608642
444 72 [0.51967379] 0.4982794159890194
445 14 [0.49687311] 0.5004560346123089
446 7 [0.49808465] 0.5004819797937379
447 2 [0.49953312] 0.5002262255129409
448 39 [0.50710608] 0.49858732389296995
449 111 [0.52186059] 0.4986870253482389
450 41 [0.51363252] 0.49986236626194713
451 127 [0.51877688] 0.49784791952076585
452 6 [0.49819489] 0.5001898858208239
453 85 [0.5149264] 0.49840516292119275
454 145 [0.45212325] 0.5068930486610415
455 70 [0.51290233] 0.4993867442132707
456 3 [0.49894873] 0.5005495804319788
457 59 [0.48846094] 0.5023028153755996
458 51 [0.48363809] 0.5010760359868496
459 75 [0.52981696] 

 56%|████████████████████████████████████████████████████████████▊                                               | 491/872 [00:03<00:02, 153.17it/s]

467 21 [0.50715998] 0.4996483395010959
468 88 [0.51873855] 0.49940044223995256
469 11 [0.5026779] 0.4997273190368705
470 64 [0.48109665] 0.5017376818520723
471 52 [0.47714568] 0.5025205458082203
472 28 [0.49190233] 0.5004353697380681
473 122 [0.53416007] 0.49994385356740345
474 79 [0.51726004] 0.4994113696944812
475 105 [0.51651757] 0.49735792868995016
476 118 [0.52575054] 0.49785475663781836
477 19 [0.50441905] 0.49903854905538403
478 171 [0.53532678] 0.49925367958379685
479 63 [0.48672652] 0.503423821787667
480 32 [0.5060445] 0.49933633394811094
481 11 [0.49361123] 0.5002733112024771
482 279 [0.42812651] 0.5082639381227901
483 10 [0.49775913] 0.5005816405432717
484 123 [0.52601038] 0.49902948137931546
485 20 [0.4927985] 0.501724873408743
486 18 [0.48881768] 0.5004261626728531
487 87 [0.52080489] 0.4990675523237806
488 162 [0.4593432] 0.5082997624224032
489 114 [0.52338601] 0.49967866573737807
490 8 [0.50374039] 0.49965893976478987
491 133 [0.5292392] 0.5003166081014735
492 48 [0.5091

 60%|████████████████████████████████████████████████████████████████▉                                           | 524/872 [00:03<00:02, 156.57it/s]

499 72 [0.51589281] 0.4980507863687806
500 52 [0.48704198] 0.5015103910577305
501 74 [0.51201387] 0.49808549289859083
502 169 [0.54338334] 0.5007111893294955
503 91 [0.51878692] 0.4987841831460023
504 81 [0.48446896] 0.5028587973955659
505 155 [0.5405622] 0.4992915129865693
506 19 [0.50347686] 0.49925064091429894
507 169 [0.52930102] 0.49752067535578287
508 21 [0.50387543] 0.4993129256136307
509 1 [0.49975074] 0.5001599188123491
510 101 [0.52482588] 0.4988394805069622
511 75 [0.48334531] 0.5030259941770389
512 143 [0.5413031] 0.5013595904545897
513 20 [0.49255896] 0.5003749438223292
514 175 [0.53003287] 0.49760273627429064
515 2 [0.49959681] 0.5001992801440301
516 59 [0.48731071] 0.5026921905646545
517 3 [0.50088586] 0.49987615540446134
518 32 [0.48583237] 0.501611737291609
519 100 [0.52591761] 0.5000571254617235
520 17 [0.5042595] 0.4994853873676466
521 104 [0.47348559] 0.5015553595895333
522 51 [0.51844301] 0.4996758769662985
523 15 [0.50415208] 0.4997852419477861
524 10 [0.49712396]

 64%|████████████████████████████████████████████████████████████████████▊                                       | 556/872 [00:03<00:02, 156.18it/s]

531 139 [0.54113004] 0.500461910727532
532 10 [0.49759773] 0.5005795890142308
533 37 [0.49262371] 0.5015586052770651
534 4 [0.50129961] 0.4996558410905933
535 136 [0.47177089] 0.5049920414555518
536 66 [0.51130304] 0.49860915185333554
537 187 [0.54200625] 0.499231376555366
538 16 [0.49430924] 0.5002785812064499
539 32 [0.49082858] 0.5016471318721636
540 19 [0.49200184] 0.5009539344330468
541 16 [0.5047177] 0.49961382292640516
542 32 [0.5063403] 0.49938224203394677
543 22 [0.51779971] 0.499439138756012
544 76 [0.51233241] 0.49773478994059955
545 44 [0.52854976] 0.4999193981230818
546 72 [0.51412637] 0.49897119115716737
547 118 [0.52555368] 0.499015933073655
548 2 [0.50063338] 0.499668702760613
549 26 [0.4925483] 0.5008976110685474
550 183 [0.54018504] 0.5005264113327069
551 30 [0.49292497] 0.5006239377272744
552 43 [0.4921076] 0.5019834993433275
553 71 [0.51667602] 0.4987256115602517
554 82 [0.47145571] 0.5032870634681181
555 98 [0.5193165] 0.4973335069032646
556 156 [0.45861449] 0.5087

 68%|█████████████████████████████████████████████████████████████████████████                                   | 590/872 [00:03<00:01, 159.10it/s]

564 31 [0.50684727] 0.4986452368413812
565 104 [0.52782069] 0.500289978118869
566 23 [0.48270498] 0.5005264820522077
567 59 [0.51190271] 0.499203321364123
568 111 [0.46605173] 0.504346554053495
569 126 [0.52478336] 0.4985691112756455
570 131 [0.52896496] 0.4991861449975535
571 172 [0.53873217] 0.49960744225753234
572 40 [0.51251049] 0.49891808733100207
573 61 [0.51713584] 0.4991634164681645
574 4 [0.49867081] 0.5001668327287265
575 152 [0.52974159] 0.49858480258547583
576 151 [0.52948718] 0.4981050347068182
577 133 [0.52729103] 0.4979411296655655
578 68 [0.51390139] 0.49876673854065856
579 69 [0.48674325] 0.5026095990234275
580 26 [0.50581672] 0.49974625332042527
581 4 [0.49902305] 0.5004586895764289
582 123 [0.47447839] 0.5024362278879004
583 8 [0.49830277] 0.5003967848791439
584 144 [0.44972536] 0.5030115877377469
585 46 [0.48782779] 0.5011776897460454
586 95 [0.52099998] 0.49884471068109115
587 19 [0.49539831] 0.501142907822067
588 42 [0.48960789] 0.5019763285906401
589 43 [0.491715

 71%|█████████████████████████████████████████████████████████████████████████████▏                              | 623/872 [00:04<00:01, 160.90it/s]

597 188 [0.54359591] 0.49948033652240664
598 14 [0.49727723] 0.5006035294105393
599 60 [0.51523617] 0.49953576974683517
600 72 [0.51537833] 0.4986475964297591
601 10 [0.49743129] 0.5002278023219074
602 15 [0.49702675] 0.5008960964076932
603 98 [0.47720608] 0.503777028891651
604 6 [0.50156942] 0.4997439998981174
605 1 [0.50029158] 0.4997927076261388
606 30 [0.49059345] 0.5006666312904348
607 152 [0.45918554] 0.5046772794797121
608 106 [0.52003467] 0.4985171447894577
609 53 [0.52478255] 0.5003508252129679
610 116 [0.52755408] 0.5001647036672224
611 59 [0.51231141] 0.49890096497564895
612 4 [0.49835465] 0.5003365638264776
613 71 [0.51368659] 0.49892829274574196
614 43 [0.48275049] 0.5002518282639918
615 26 [0.50543041] 0.49931173781865723
616 9 [0.50213578] 0.49978637912495283
617 7 [0.49621157] 0.5001023568776662
618 117 [0.52104549] 0.49795346805841967
619 15 [0.49603139] 0.5007394628068875
620 121 [0.5266172] 0.4994138537345116
621 53 [0.51072869] 0.49856518116413145
622 222 [0.5439915

 75%|█████████████████████████████████████████████████████████████████████████████████▍                          | 658/872 [00:04<00:01, 163.46it/s]

633 45 [0.48689307] 0.5020393687853205
634 57 [0.5130664] 0.49864866426620413
635 167 [0.52953206] 0.498577995689626
636 12 [0.49664239] 0.5007857472620344
637 70 [0.48391655] 0.5020414984471525
638 98 [0.52428126] 0.49759380466476627
639 21 [0.50561372] 0.4997169480544724
640 67 [0.47837634] 0.5028067627585874
641 19 [0.51798789] 0.49945617242238377
642 11 [0.50251313] 0.49949907049164427
643 152 [0.54144118] 0.5009653393278465
644 36 [0.51002047] 0.4996188040681253
645 9 [0.49732435] 0.5006237705194696
646 84 [0.51848424] 0.49891616378840553
647 52 [0.48904897] 0.5027648114667554
648 145 [0.52806364] 0.49772469664948016
649 154 [0.52904038] 0.4988260216743615
650 82 [0.51427486] 0.4990873994373284
651 19 [0.4956684] 0.5005353766554524
652 43 [0.48821852] 0.5012761641077731
653 56 [0.50927367] 0.49891491032297053
654 58 [0.513065] 0.4985122958763772
655 37 [0.50730105] 0.49918371437245096
656 36 [0.50969085] 0.49935498903346964
657 114 [0.52064289] 0.4988993823077753
658 34 [0.5061077

 79%|█████████████████████████████████████████████████████████████████████████████████████▋                      | 692/872 [00:04<00:01, 160.88it/s]

665 9 [0.50234599] 0.4996062498105988
666 14 [0.50425158] 0.4994954054576972
667 115 [0.51794073] 0.49782595417085723
668 84 [0.48336978] 0.5043600021155218
669 170 [0.46023446] 0.5068793002188801
670 75 [0.51373297] 0.49817447026811856
671 8 [0.49817932] 0.500458722894663
672 40 [0.49274439] 0.5017759559533136
673 2 [0.49957416] 0.5001016997036764
674 81 [0.51621087] 0.49972974679721266
675 12 [0.50426226] 0.4998558499648978
676 83 [0.51832957] 0.49862747604346436
677 140 [0.52237779] 0.49671735189913513
678 8 [0.50261099] 0.4996870016072006
679 10 [0.5020628] 0.49967696930288374
680 87 [0.51638719] 0.49844003654067126
681 38 [0.49214302] 0.5020001771728432
682 84 [0.48194999] 0.5024688545664296
683 1 [0.49995154] 0.5009219239335104
684 16 [0.50423058] 0.4996280622427633
685 122 [0.52602478] 0.49801867738991623
686 4 [0.49846931] 0.5002174748460516
687 133 [0.52865279] 0.49945889996007203
688 109 [0.52856271] 0.49996195031309354
689 168 [0.53936609] 0.5006815651087875
690 55 [0.508646

 83%|█████████████████████████████████████████████████████████████████████████████████████████▊                  | 725/872 [00:04<00:00, 157.55it/s]

698 99 [0.52520568] 0.49937493540580646
699 79 [0.51554702] 0.49938279581914635
700 113 [0.51983017] 0.4984033094011386
701 9 [0.5041083] 0.49992287230534715
702 91 [0.52409505] 0.4985654757027474
703 52 [0.52096711] 0.49994038963226106
704 14 [0.50354695] 0.4995599723298382
705 178 [0.53563915] 0.5003187023691478
706 2 [0.49880438] 0.50065833127847
707 24 [0.50688078] 0.4991462268122315
708 32 [0.49265162] 0.5013448896671289
709 32 [0.48838134] 0.5009890590674282
710 81 [0.51527654] 0.49783137933762117
711 154 [0.46905665] 0.5098704207625219
712 28 [0.49177484] 0.5007473657180755
713 30 [0.5063836] 0.49952097974349796
714 63 [0.53374602] 0.49950979567625586
715 40 [0.50791512] 0.49900108387006864
716 253 [0.54752114] 0.49869121573558783
717 47 [0.5134031] 0.49953109879551183
718 36 [0.50839314] 0.49942088829206144
719 49 [0.48955573] 0.502537552425506
720 52 [0.48320242] 0.5012912822778026
721 10 [0.50260817] 0.4997759325970361
722 12 [0.50280805] 0.4997634008047893
723 152 [0.5281094

 87%|█████████████████████████████████████████████████████████████████████████████████████████████▉              | 758/872 [00:04<00:00, 158.07it/s]

729 81 [0.53822072] 0.5013874981558923
730 82 [0.47497425] 0.5018699910624423
731 86 [0.48241199] 0.502151847890621
732 86 [0.51649548] 0.4991190145188079
733 123 [0.52266776] 0.49870720214294695
734 12 [0.50266666] 0.49954795366190147
735 46 [0.46760659] 0.49998983906423583
736 93 [0.5203812] 0.4993710280756974
737 47 [0.48662686] 0.5023236249140801
738 3 [0.49948592] 0.5003463710441539
739 79 [0.51664614] 0.4994893109844054
740 41 [0.50586961] 0.498913293845184
741 69 [0.52014777] 0.49933004029248657
742 73 [0.517141] 0.49910258992315065
743 11 [0.49640612] 0.5008368264317217
744 125 [0.52464745] 0.49855977217919617
745 36 [0.50686767] 0.4994462677042794
746 55 [0.50892832] 0.4981946296202878
747 123 [0.52104163] 0.4988116913458842
748 18 [0.49477362] 0.500300641957255
749 19 [0.50862474] 0.4995416863075264
750 16 [0.49453661] 0.5006972138616631
751 60 [0.50927801] 0.4989381600029102
752 185 [0.54405715] 0.5009310304883435
753 45 [0.48811633] 0.5021839320510139
754 8 [0.50486538] 0.4

 91%|██████████████████████████████████████████████████████████████████████████████████████████████████          | 792/872 [00:05<00:00, 162.10it/s]

763 7 [0.49771689] 0.5004670674017477
764 120 [0.53696948] 0.5001159374698441
765 102 [0.51535593] 0.49775096474658076
766 152 [0.55096779] 0.5009884145858783
767 9 [0.50215681] 0.4996016435020264
768 110 [0.52129346] 0.4985622519589018
769 27 [0.49446228] 0.5008945203507423
770 65 [0.51207279] 0.49872418193769447
771 23 [0.5045082] 0.49909015582635335
772 39 [0.49140623] 0.5015524462917952
773 158 [0.53468421] 0.4988553117423259
774 2 [0.50100915] 0.49981076985825285
775 9 [0.49754821] 0.5003113055742279
776 183 [0.53944807] 0.5000361540428557
777 127 [0.52576154] 0.49841965903150903
778 30 [0.50646527] 0.4995626209463503
779 9 [0.49831156] 0.5004072127669476
780 114 [0.52202664] 0.4992897730339088
781 90 [0.51587242] 0.4983584440286074
782 116 [0.52007806] 0.49811910196071413
783 36 [0.50688674] 0.49834497973971403
784 48 [0.49104236] 0.5025560103339362
785 81 [0.4829406] 0.5038530490848978
786 64 [0.51305448] 0.49919232974876865
787 13 [0.50366888] 0.49946798248660135
788 14 [0.5057

 95%|██████████████████████████████████████████████████████████████████████████████████████████████████████▎     | 826/872 [00:05<00:00, 159.32it/s]

796 23 [0.49579235] 0.5007965203377723
797 18 [0.5034109] 0.4994232266732571
798 34 [0.49217166] 0.5013110564675807
799 70 [0.51486559] 0.4977392729591154
800 84 [0.52957407] 0.4998431069861134
801 59 [0.5142417] 0.49840367968654714
802 6 [0.49864909] 0.5004476404739685
803 27 [0.50548391] 0.499216896634198
804 124 [0.52341965] 0.4981743464699546
805 86 [0.51624386] 0.4984218834704816
806 72 [0.51837088] 0.4992362218951683
807 62 [0.51666399] 0.49928434186306336
808 32 [0.49219849] 0.5012921410040784
809 60 [0.5125113] 0.49923172175166974
810 33 [0.50652613] 0.49886179258637714
811 13 [0.52371474] 0.499066282596696
812 19 [0.47167874] 0.49938916587499965
813 43 [0.46288734] 0.4987150970077151
814 104 [0.5193775] 0.498312007614997
815 12 [0.50308983] 0.49975867086641745
816 89 [0.47509798] 0.5029874485717072
817 55 [0.5129302] 0.4994684999776915
818 28 [0.49331485] 0.5007168096392022
819 105 [0.51803851] 0.49774025855480614
820 51 [0.48788404] 0.5013586979152915
821 3 [0.50086644] 0.499

 99%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 860/872 [00:05<00:00, 162.10it/s]

828 210 [0.54441463] 0.49723590687249153
829 84 [0.51672596] 0.4984766246746404
830 29 [0.51946619] 0.5004252450319798
831 44 [0.48745573] 0.5013577033461906
832 2 [0.50029319] 0.4997337486407822
833 208 [0.53678411] 0.49608175743709715
834 67 [0.51516344] 0.4994227024098077
835 24 [0.49340464] 0.5009196739304517
836 197 [0.44964339] 0.5058720769222067
837 16 [0.50303855] 0.4996356315385461
838 8 [0.49844626] 0.5003017698262878
839 29 [0.49100805] 0.5005887705889881
840 14 [0.49269979] 0.5006048105909631
841 28 [0.49224209] 0.5009162402168005
842 40 [0.49119474] 0.5010105528812432
843 75 [0.46490492] 0.4985638703026554
844 13 [0.49650564] 0.5003531796516806
845 126 [0.52291681] 0.49860086272908366
846 61 [0.51440413] 0.4994644450826649
847 14 [0.50297998] 0.499576671637271
848 45 [0.51241278] 0.5001918362787638
849 116 [0.47251305] 0.5026801708551717
850 97 [0.51852405] 0.4969270076343028
851 73 [0.47651716] 0.5009058398934219
852 2 [0.49919893] 0.5002201330771036
853 23 [0.50485461] 0

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 872/872 [00:05<00:00, 154.91it/s]

862 31 [0.50727646] 0.4993479737515096
863 154 [0.45015635] 0.5089639351643336
864 28 [0.50770905] 0.4993195831403888
865 77 [0.48459131] 0.5033954519981139
866 26 [0.50620115] 0.4990647806082188
867 123 [0.52284113] 0.49614285378999384
868 78 [0.4830456] 0.5033466379776133
869 35 [0.49325439] 0.5012591641927812
870 137 [0.52907722] 0.49864290021926855
871 10 [0.50165835] 0.499411554389577





ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (872,) + inhomogeneous part.

In [65]:
new_predictions = new_predictions.reshape((872, 1))
np.sum(((new_predictions > 0.5) != (pred >0.5))) / 872

np.float64(0.9105504587155964)