In [1]:
from pathlib import Path
from tqdm.auto import tqdm
import numpy as np
import pandas as pd

from sklearn.model_selection import cross_val_score, StratifiedKFold
from sklearn.linear_model import LogisticRegression
from transformers import AutoModel, AutoTokenizer
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score
from lightning.classification import LinearSVC
import torch

In [2]:
train_df = pd.read_csv('../data/reviews_with_embeddings_sample_train.csv')
test_df = pd.read_csv('../data/reviews_with_embeddings_sample_test.csv')

In [3]:
train_df.head()

Unnamed: 0,title,abstract,doi,review,references_count,references_contradicted_total,references_contradicted_avg,references_mentioned_total,references_mentioned_total_avg,references_supported_total,...,citations_to_references_total,citations_to_references_total_avg,contradiction_percentage_avg,mentioning_percentage_avg,supporting_percentage_avg,contradiction_to_supporting_ratio_avg,contradiction_to_supporting_contradiction_ratio_avg,supporting_to_supporting_contradiction_ratio_avg,title_abstract,embeddings
0,On “Bettering Humanity” in Science and Enginee...,Authors such as Krishnasamy Selvan argue that ...,10.1007/s11948-007-9014-9,0,1,0,0.0,12,12.0,0,...,6,6.0,0.076923,1.0,0.076923,1.0,1.0,1.0,On “Bettering Humanity” in Science and Enginee...,[-1.61832273e+00 2.50268847e-01 5.84022641e-...
1,X-ray Scattering Studies of Methylophilus meth...,Small angle x-ray solution scattering has bee...,10.1074/jbc.m001564200,0,34,6,0.176471,7233,212.735294,218,...,8075,237.5,0.030342,0.933643,0.084957,0.411699,0.394052,0.964706,X-ray Scattering Studies of Methylophilus meth...,[-1.15774763e+00 -1.06575049e-01 3.78147513e-...
2,Administration of dexmedetomidine alone during...,We report the clinical management of 2 adults ...,10.1007/s00540-011-1174-8,0,20,22,1.1,3198,159.9,172,...,4228,211.4,0.048151,0.935298,0.063496,0.695398,0.536021,0.884608,Administration of dexmedetomidine alone during...,[-6.61146998e-01 3.08037162e-01 7.32850969e-...
3,Evaluation of Tau Imaging in Staging Alzheimer...,,10.1001/jamaneurol.2016.2078,0,50,244,4.88,69683,1393.66,3961,...,84626,1692.52,0.011087,0.917848,0.072126,0.157045,0.147288,0.942722,Evaluation of Tau Imaging in Staging Alzheimer...,[-4.46374565e-02 1.95754838e+00 1.34340003e-...
4,Reward and Affective Regulation in Depression-...,Background There is a disproportionately hig...,10.1016/j.biopsych.2014.04.018,0,126,280,2.222222,51947,412.277778,1804,...,61032,484.380952,0.028922,0.906765,0.077215,0.421393,0.31818,0.856201,Reward and Affective Regulation in Depression-...,[-2.79321611e-01 1.04500818e+00 2.65802503e-...


In [4]:
test_df.columns

Index(['title', 'abstract', 'doi', 'review', 'references_count',
       'references_contradicted_total', 'references_contradicted_avg',
       'references_mentioned_total', 'references_mentioned_total_avg',
       'references_supported_total', 'references_supported_avg',
       'in_text_citations_to_references_total',
       'in_text_citations_to_references_total_avg',
       'citations_to_references_total', 'citations_to_references_total_avg',
       'contradiction_percentage_avg', 'mentioning_percentage_avg',
       'supporting_percentage_avg', 'contradiction_to_supporting_ratio_avg',
       'contradiction_to_supporting_contradiction_ratio_avg',
       'supporting_to_supporting_contradiction_ratio_avg', 'title_abstract',
       'embeddings'],
      dtype='object')

In [5]:
INPUT_COL_NAMES = [
   'references_count',
   'references_contradicted_total', 'references_contradicted_avg',
   'references_mentioned_total', 'references_mentioned_total_avg',
   'references_supported_total', 'references_supported_avg',
   'in_text_citations_to_references_total',
   'in_text_citations_to_references_total_avg',
   'citations_to_references_total', 'citations_to_references_total_avg',
   'contradiction_percentage_avg', 'mentioning_percentage_avg',
   'supporting_percentage_avg', 'contradiction_to_supporting_ratio_avg',
   'contradiction_to_supporting_contradiction_ratio_avg',
   'supporting_to_supporting_contradiction_ratio_avg',
   
]
EMB_COL_NAME = 'embeddings'
TARGET_COL_NAME = 'review'

In [6]:
estimator = LinearSVC(loss="squared_hinge", random_state=42, verbose=1)
Cs = np.logspace(-4, 2, 7)
model = GridSearchCV(estimator=estimator, cv=3, param_grid={'C': Cs}, verbose=1, n_jobs=2)

In [7]:
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=17)

## Training and Train Set Predictions

In [8]:
test_df['embeddings'] = test_df['embeddings'].apply(
    lambda x: np.fromstring(x[1:-1], dtype=np.float64, sep=' ')
)

In [9]:
train_df['embeddings'] = train_df['embeddings'].apply(
    lambda x: np.fromstring(x[1:-1], dtype=np.float64, sep=' ')
)

In [10]:
def normalize(v):
    norm=np.linalg.norm(v)
    if norm==0:
        norm=np.finfo(v.dtype).eps
    return v/norm

acc_scores, prec_scores, cv_f1_scores, recall_scores = [], [], [], []
skf_split_generator = skf.split(X=train_df[INPUT_COL_NAMES], y=train_df[TARGET_COL_NAME])

for fold_id, (train_idx, val_idx) in enumerate(skf_split_generator):
    curr_train_df = train_df.iloc[train_idx]
    curr_val_df = train_df.iloc[val_idx]
    X = np.concatenate((normalize(curr_train_df[INPUT_COL_NAMES].to_numpy()), np.stack(curr_train_df[EMB_COL_NAME])), axis=1)
    model.fit(X, curr_train_df[TARGET_COL_NAME])
    
    # making predictions for the current validation set
    X = np.concatenate((normalize(curr_val_df[INPUT_COL_NAMES].to_numpy()), np.stack(curr_val_df[EMB_COL_NAME])), axis=1)
    
    curr_preds = model.predict(X)
    curr_f1 = f1_score(y_true=curr_val_df[TARGET_COL_NAME], y_pred=curr_preds)
    curr_acc = accuracy_score(y_true=curr_val_df[TARGET_COL_NAME], y_pred=curr_preds)
    curr_prec = precision_score(y_true=curr_val_df[TARGET_COL_NAME], y_pred=curr_preds)
    curr_recall = recall_score(y_true=curr_val_df[TARGET_COL_NAME], y_pred=curr_preds)
    cv_f1_scores.append(curr_f1)
    acc_scores.append(curr_acc)
    prec_scores.append(curr_prec)
    recall_scores.append(curr_recall)
    print(f"F1-score for fold {fold_id} is {curr_f1:.3}. Accuracy is {curr_acc:.3}. Precision is {curr_prec:.3}. Recall is {curr_recall:.3}.")
    
print(f'Average cross-validation F1-score is {np.mean(cv_f1_scores):.3} +/- {np.std(cv_f1_scores):.3}.')
print(f'Average cross-validation ACC is {np.mean(acc_scores):.3} +/- {np.std(acc_scores):.3}.')
print(f'Average cross-validation Prec is {np.mean(prec_scores):.3} +/- {np.std(prec_scores):.3}.')
print(f'Average cross-validation Recall is {np.mean(recall_scores):.3} +/- {np.std(recall_scores):.3}.')


Fitting 3 folds for each of 7 candidates, totalling 21 fits




iter 1 6.534921880062737 (80000)
iter 2 9.059277572226325 (80000)
iter 3 8.043859375900553 (79942)
iter 4 7.919211578693272 (79901)
iter 5 7.555323466388444 (79825)
iter 6 7.733890843384418 (79760)
iter 7 7.993602854268911 (79750)
iter 8 7.341711106962665 (79745)
iter 9 6.922085037829044 (79705)
iter 10 6.709293138449182 (79483)
iter 11 6.056236618811344 (79405)
iter 12 6.417232777069142 (79249)
iter 13 6.319404169152648 (79192)
iter 14 6.4569585311311855 (79100)
iter 15 7.246090093041161 (79066)
iter 16 6.641180373746394 (78988)
iter 17 5.881001267225317 (78916)
iter 18 6.5321398769987375 (78797)
iter 19 6.092010772827544 (78767)
iter 20 6.116154560199299 (78735)
iter 21 5.998818687659311 (78649)
iter 22 5.901182721025668 (78637)
iter 23 5.640732862676959 (78552)
iter 24 5.614473014658605 (78525)
iter 25 6.259766724326926 (78433)
iter 26 5.553244355810513 (78379)
iter 27 5.9372935522553885 (78359)
iter 28 6.25793492295155 (78346)
iter 29 5.483880235690769 (78336)
iter 30 5.44338574988

iter 238 0.7784946898146331 (35816)
iter 239 0.7272160807642656 (35675)
iter 240 0.6691282868261198 (35630)
iter 241 0.6780667570576977 (35479)
iter 242 0.7278336720655216 (35074)
iter 243 0.6718727560327925 (35053)
iter 244 0.7108593363188065 (34954)
iter 245 0.6749233727218324 (34910)
iter 246 0.6279121265515585 (34885)
iter 247 0.6481448421808808 (34753)
iter 248 0.6935319517997885 (34633)
iter 249 0.651955682205478 (34611)
iter 250 0.6530812605800932 (34518)
iter 251 0.6254841728070908 (34483)
iter 252 0.6175398547809086 (34328)
iter 253 0.6121460982257637 (34236)
iter 254 0.6117598590273563 (34095)
iter 255 0.6095906571163319 (33994)
iter 256 0.6255213783283919 (33976)
iter 257 0.6270635156629913 (33962)
iter 258 0.5748914536518521 (33937)
iter 259 0.5907059490838662 (33775)
iter 260 0.6108522788404995 (33668)
iter 261 0.641929971552745 (33570)
iter 262 0.6242279215442925 (33521)
iter 263 0.6335767719130583 (33509)
iter 264 0.6010336705906034 (33388)
iter 265 0.5379733807081424 (3

iter 465 0.08015521824173777 (26913)
iter 466 0.08170048446033401 (26910)
iter 467 0.06919027334671893 (26905)
iter 468 0.07164029691393999 (26884)
iter 469 0.07466277135624011 (26879)
iter 470 0.08011907031309165 (26877)
iter 471 0.06812085175999233 (26877)
iter 472 0.07430239132511607 (26861)
iter 473 0.06951000178592952 (26860)
iter 474 0.07030553614237822 (26858)
iter 475 0.06715634966306924 (26853)
iter 476 0.06787053283262431 (26837)
iter 477 0.06913104996441158 (26831)
iter 478 0.08263536242407787 (26828)
iter 479 0.06499322756360347 (26826)
iter 480 0.0695151537525831 (26821)
iter 481 0.07048062820400655 (26820)
iter 482 0.061244160265116476 (26817)
iter 483 0.06751744041474397 (26800)
iter 484 0.06450141430520011 (26797)
iter 485 0.06363585241754866 (26796)
iter 486 0.06647052061069608 (26793)
iter 487 0.06816378400459991 (26788)
iter 488 0.05758342167694097 (26786)
iter 489 0.057517785906851936 (26781)
iter 490 0.05972636144050837 (26772)
iter 491 0.06577352096719952 (26769)


iter 685 0.009383539580705785 (26322)
iter 686 0.010295413994300347 (26322)
iter 687 0.009714325946211425 (26316)
iter 688 0.009477091286305084 (26316)
iter 689 0.008908372521079788 (26316)
iter 690 0.00925496636073525 (26315)
iter 691 0.009592521458888031 (26315)
iter 692 0.009075933558939542 (26315)
iter 693 0.00916734795432085 (26314)
iter 694 0.009619378816555835 (26313)
iter 695 0.009607857929306074 (26313)
iter 696 0.009660316608923025 (26313)
iter 697 0.008372655262737366 (26312)
iter 698 0.007466036869200332 (26310)
iter 699 0.008801393488083148 (26309)
iter 700 0.008301597442580522 (26309)
iter 701 0.008207955351612228 (26306)
iter 702 0.008543378923834055 (26306)
iter 703 0.009871886183875467 (26306)
iter 704 0.00790984262128501 (26306)
iter 705 0.008565959860801553 (26306)
iter 706 0.0078013230810541345 (26306)
iter 707 0.00784449473193137 (26306)
iter 708 0.008259549909430217 (26306)
iter 709 0.006960711232521163 (26306)
iter 710 0.007221953930973024 (26306)
iter 711 0.0082

iter 902 0.0012331274932074476 (26248)
iter 903 0.0013185892213052192 (26248)
iter 904 0.0012886155845945324 (26247)
iter 905 0.0012880657116704497 (26247)
iter 906 0.0010805636976614075 (26247)
iter 907 0.001339229944624723 (26247)
iter 908 0.0012186678554824137 (26247)
iter 909 0.001383405193406205 (26247)
iter 910 0.001298306099476898 (26247)
iter 911 0.0011506350452686674 (26247)
iter 912 0.0013487628391905399 (26247)
iter 913 0.0010775064753113122 (26247)
iter 914 0.001188532156591926 (26246)
iter 915 0.0009979813880720417 (26246)
iter 916 0.001182315793361477 (80000)
iter 917 0.0011068732257517588 (26254)
iter 918 0.001032241407999429 (26251)
iter 919 0.0011523575931504243 (26251)
iter 920 0.0010429263178993193 (26250)
iter 921 0.0011019541792760545 (26250)
iter 922 0.0009622964197915362 (26250)
iter 923 0.0010056698729943164 (80000)
iter 924 0.001067336298589766 (26251)
iter 925 0.001081304158047136 (26244)
iter 926 0.0010944276949987064 (26244)
iter 927 0.001215575219017273 (26



iter 1 5.961783610058087 (80000)
iter 2 9.528923360345686 (80000)
iter 3 8.575661198098812 (79993)
iter 4 8.348934315617656 (79981)
iter 5 7.910638027245614 (79921)
iter 6 7.915459542338535 (79914)
iter 7 6.650949219666263 (79823)
iter 8 6.8505019953715705 (79709)
iter 9 6.660187733325277 (79557)
iter 10 6.592473157976647 (79426)
iter 11 7.317213811101887 (79387)
iter 12 6.990598814908232 (79373)
iter 13 6.591280562324002 (79358)
iter 14 6.478409964299079 (79304)
iter 15 6.317150321108942 (79275)
iter 16 6.0530104836443375 (79261)
iter 17 6.0910168994320495 (79183)
iter 18 6.128357920513176 (79100)
iter 19 5.72013087641969 (79007)
iter 20 6.555784216137027 (78871)
iter 21 6.767611063249166 (78791)
iter 22 5.88831326329065 (78779)
iter 23 5.6169057430905625 (78742)
iter 24 5.290619548408873 (78589)
iter 25 5.526448599686046 (78207)
iter 26 5.439996005408669 (78121)
iter 27 5.639009851238631 (78080)
iter 28 5.774061874947076 (78053)
iter 29 5.674476488937637 (78041)
iter 30 5.36012708328

iter 235 0.7237540218895677 (36906)
iter 236 0.7471410583071201 (36633)
iter 237 0.7135030525093458 (36600)
iter 238 0.7842203102187688 (36495)
iter 239 0.6671552664699292 (36291)
iter 240 0.6452389120908675 (36089)
iter 241 0.8046825377862536 (35858)
iter 242 0.7124007270533533 (35773)
iter 243 0.7151456840491253 (35652)
iter 244 0.7226445197296906 (35608)
iter 245 0.6545131684716146 (35567)
iter 246 0.6866080296584578 (35425)
iter 247 0.691298284019301 (35351)
iter 248 0.676935467316876 (35222)
iter 249 0.6621993477963878 (35161)
iter 250 0.6359605411907225 (35034)
iter 251 0.6694563779142025 (34931)
iter 252 0.6830402146653309 (34892)
iter 253 0.6868550069666565 (34826)
iter 254 0.5977851261560213 (34763)
iter 255 0.5277132345224955 (34623)
iter 256 0.6424448731922873 (34183)
iter 257 0.5737334359690708 (34127)
iter 258 0.6088979828519832 (33876)
iter 259 0.5443948614586755 (33807)
iter 260 0.587658689297292 (33747)
iter 261 0.5842845573647228 (33698)
iter 262 0.608678313451807 (336

iter 463 0.08411263442906958 (27027)
iter 464 0.08178921784884496 (27025)
iter 465 0.07528224236800174 (27018)
iter 466 0.07970710152862791 (27010)
iter 467 0.08414655589095571 (26989)
iter 468 0.0760084550708705 (26987)
iter 469 0.06572591482851659 (26978)
iter 470 0.07515134611363528 (26955)
iter 471 0.0795199790393195 (26953)
iter 472 0.08316210350407016 (26947)
iter 473 0.07022271439634226 (26947)
iter 474 0.07400867050254192 (26942)
iter 475 0.07875371605769364 (26941)
iter 476 0.07915818537018057 (26937)
iter 477 0.07327471074457742 (26936)
iter 478 0.07782281517627324 (26928)
iter 479 0.07417771492869621 (26926)
iter 480 0.06794134915027567 (26919)
iter 481 0.06604058714384667 (26916)
iter 482 0.07110498319215655 (26900)
iter 483 0.06180848554700247 (26899)
iter 484 0.06652219930198655 (26880)
iter 485 0.06900223078695414 (26874)
iter 486 0.06857959462255916 (26867)
iter 487 0.07185507729121435 (26852)
iter 488 0.06858971667584411 (26848)
iter 489 0.07272058061787709 (26848)
ite

iter 683 0.010440787633983428 (26396)
iter 684 0.012013729547143082 (26395)
iter 685 0.010654544133557431 (26394)
iter 686 0.011271141435355458 (26394)
iter 687 0.008733771799699148 (26394)
iter 688 0.008964019996405115 (26392)
iter 689 0.012534791907341747 (26392)
iter 690 0.010622378914075147 (26392)
iter 691 0.010547504069876923 (26390)
iter 692 0.010166632488058568 (26390)
iter 693 0.009732735943146609 (26389)
iter 694 0.009605139662670625 (26387)
iter 695 0.009889338548426538 (26386)
iter 696 0.010497305939903445 (26386)
iter 697 0.009075471325973083 (26385)
iter 698 0.00999956859168899 (26385)
iter 699 0.008375996036021516 (26384)
iter 700 0.008748194249818764 (26383)
iter 701 0.009555416394557648 (26383)
iter 702 0.0090183766149603 (26383)
iter 703 0.007394115404984414 (26383)
iter 704 0.00863285673324439 (26376)
iter 705 0.008796190644682997 (26374)
iter 706 0.008351108468691748 (26373)
iter 707 0.008262400199331166 (26373)
iter 708 0.007276545663534181 (26372)
iter 709 0.00904

iter 898 0.001444243620199348 (26319)
iter 899 0.0013707141210926171 (26319)
iter 900 0.0013859811944702816 (26319)
iter 901 0.0014319250519910787 (26319)
iter 902 0.0012906936282780362 (26319)
iter 903 0.0013538784143293514 (26319)
iter 904 0.0015790169381074137 (26319)
iter 905 0.0013507689073704032 (26319)
iter 906 0.001321183042423707 (26319)
iter 907 0.001451778891556496 (26319)
iter 908 0.0012257121494782808 (26319)
iter 909 0.001204442891456936 (26319)
iter 910 0.0012849606104912428 (26319)
iter 911 0.0012692661984739506 (26319)
iter 912 0.0013384978072993603 (26319)
iter 913 0.0011130985643393217 (26319)
iter 914 0.0012621610376936602 (26319)
iter 915 0.0012665493393568261 (26319)
iter 916 0.0011592297131864937 (26319)
iter 917 0.0012702370533042018 (26319)
iter 918 0.0012144936763245906 (26318)
iter 919 0.0012344789375630105 (26318)
iter 920 0.0011490716461724226 (26318)
iter 921 0.0011376521967076322 (26318)
iter 922 0.0011379281470910718 (26318)
iter 923 0.001080893753517164

iter 189 1.254609422932119 (44044)
iter 190 1.2557751698473878 (43827)
iter 191 1.2281664762121915 (43810)
iter 192 1.1382561896929513 (43786)
iter 193 1.1573318837993405 (43553)
iter 194 1.1227684636372497 (43477)
iter 195 1.1347343075257774 (43363)
iter 196 1.065204239571004 (43169)
iter 197 1.1708457069273877 (42600)
iter 198 1.0335186368598244 (42436)
iter 199 1.1037257769893227 (42050)
iter 200 1.190602360734657 (41987)
iter 201 1.0791227596934747 (41893)
iter 202 1.1522366346892727 (41715)
iter 203 0.9563373804064477 (41670)
iter 204 1.0248557669825733 (41311)
iter 205 1.0252922891705198 (41243)
iter 206 1.0651624503266617 (41073)
iter 207 0.9896421935588282 (40971)
iter 208 1.0045002805094014 (40876)
iter 209 0.9788852072126292 (40588)
iter 210 0.9419273966282702 (40285)
iter 211 0.966282091388494 (39850)
iter 212 0.9453989427208165 (39521)
iter 213 0.9679661852778274 (39273)
iter 214 0.9265843353375793 (39161)
iter 215 0.9007059830860492 (38903)
iter 216 0.9895989304732018 (387

iter 417 0.11837036774822568 (27412)
iter 418 0.13295235515582574 (27394)
iter 419 0.14569062558308205 (27382)
iter 420 0.11503686335184703 (27378)
iter 421 0.11693291113570005 (27372)
iter 422 0.10706989093999467 (27363)
iter 423 0.11811655934696375 (27333)
iter 424 0.12267343987056081 (27330)
iter 425 0.13233670590978708 (27322)
iter 426 0.11350513739814673 (27307)
iter 427 0.11831703791940407 (27302)
iter 428 0.1118077670845971 (27301)
iter 429 0.10355263075636897 (27295)
iter 430 0.11124980687244233 (27268)
iter 431 0.11481348198717438 (27262)
iter 432 0.11156745592721698 (27260)
iter 433 0.10984652076268797 (27253)
iter 434 0.10334322103047833 (27252)
iter 435 0.12808740924948253 (27249)
iter 436 0.11892841003256455 (27248)
iter 437 0.10138896063698664 (27231)
iter 438 0.09353200640234602 (27214)
iter 439 0.10756734560306347 (27206)
iter 440 0.1145847772320385 (27198)
iter 441 0.10787860586545917 (27198)
iter 442 0.09389399788289865 (27185)
iter 443 0.0876824040964731 (27181)
iter

iter 639 0.015543236695267804 (26414)
iter 640 0.015340484831649267 (26409)
iter 641 0.016404384076139616 (26409)
iter 642 0.013967204574884456 (26407)
iter 643 0.013646735099063223 (26406)
iter 644 0.014066572814566039 (26404)
iter 645 0.013106827102790777 (26404)
iter 646 0.01263225099314208 (26404)
iter 647 0.013247238296157648 (26403)
iter 648 0.015003091430951243 (26402)
iter 649 0.012886832492226788 (26401)
iter 650 0.014213053204994636 (26401)
iter 651 0.013600250509477219 (26401)
iter 652 0.012856988655441826 (26399)
iter 653 0.012783282851192401 (26398)
iter 654 0.013663438536615176 (26398)
iter 655 0.013308330865193025 (26397)
iter 656 0.013237524819990916 (26395)
iter 657 0.011719637717007993 (26395)
iter 658 0.012472116965227542 (26394)
iter 659 0.012412102055668488 (26394)
iter 660 0.014403193385628121 (26394)
iter 661 0.01287163945384201 (26394)
iter 662 0.010912343094513943 (26392)
iter 663 0.013216331076601834 (26391)
iter 664 0.01087122979569943 (26391)
iter 665 0.0118

iter 856 0.0019781584354375825 (26311)
iter 857 0.0019335712164129937 (26311)
iter 858 0.002110478742335331 (26311)
iter 859 0.0018410340545739687 (26310)
iter 860 0.002233456318426169 (26310)
iter 861 0.0018923187889652582 (26310)
iter 862 0.0016737448385659714 (26310)
iter 863 0.002222527424925777 (26310)
iter 864 0.0018840302918138587 (26310)
iter 865 0.0020718932121812847 (26310)
iter 866 0.0016318166281600044 (26310)
iter 867 0.0018169553672228306 (26310)
iter 868 0.0016743631419201543 (26310)
iter 869 0.0016729960626352125 (26310)
iter 870 0.0017423339086931194 (26310)
iter 871 0.0018557849173440472 (26310)
iter 872 0.0017979151836773344 (26310)
iter 873 0.0015211782399209561 (26310)
iter 874 0.0016118189563815166 (26310)
iter 875 0.0018102158855234363 (26310)
iter 876 0.0016926133997273284 (26310)
iter 877 0.001476717600934624 (26310)
iter 878 0.0017072649147653607 (26310)
iter 879 0.0015187471702207758 (26310)
iter 880 0.0013793995061665912 (26310)
iter 881 0.001511170267396581

iter 147 1.9037917255122363 (54526)
iter 148 1.7706861052811744 (54418)
iter 149 1.8875443940667105 (54258)
iter 150 1.720435742667106 (53909)
iter 151 1.613550262315825 (53547)
iter 152 1.7507034797531882 (53179)
iter 153 1.697599341242412 (53127)
iter 154 1.715959627719309 (52962)
iter 155 1.704720965776716 (52721)
iter 156 1.5875086785669468 (52522)
iter 157 1.638790638145083 (52093)
iter 158 1.6098728535747613 (51729)
iter 159 1.6199876848295895 (51404)
iter 160 1.5994777925979138 (50869)
iter 161 1.562585609287598 (50740)
iter 162 1.5136262629966615 (50688)
iter 163 1.633478695858829 (50508)
iter 164 1.5927906305171566 (50487)
iter 165 1.6419454047306936 (50459)
iter 166 1.4950346153348628 (50347)
iter 167 1.571390602252141 (50296)
iter 168 1.441315431517535 (49996)
iter 169 1.4369689037366142 (49739)
iter 170 1.48765932347345 (49513)
iter 171 1.3784908412773667 (49313)
iter 172 1.2996377851139767 (48943)
iter 173 1.42709229418475 (48599)
iter 174 1.4103286551779235 (48353)
iter 1

iter 377 0.16410081479725078 (28017)
iter 378 0.1622682055375826 (27971)
iter 379 0.21945891756919733 (27946)
iter 380 0.16746838871767566 (27940)
iter 381 0.1794512365235505 (27919)
iter 382 0.17636032781857872 (27904)
iter 383 0.17254736747662694 (27888)
iter 384 0.16279118697932582 (27863)
iter 385 0.18775916556127842 (27842)
iter 386 0.17461909445713622 (27839)
iter 387 0.15610699431595168 (27835)
iter 388 0.1794793458406162 (27790)
iter 389 0.17181386977865076 (27783)
iter 390 0.14962486525931556 (27766)
iter 391 0.17428825867203762 (27753)
iter 392 0.1682709988129053 (27718)
iter 393 0.14652754475233604 (27716)
iter 394 0.15728404411002023 (27694)
iter 395 0.1505893827915589 (27679)
iter 396 0.142729950926921 (27660)
iter 397 0.1358887786657227 (27635)
iter 398 0.15200394979986045 (27601)
iter 399 0.14359897352317813 (27599)
iter 400 0.14211184054358394 (27580)
iter 401 0.13678624850707544 (27561)
iter 402 0.15080231519475515 (27548)
iter 403 0.14317082384544366 (27547)
iter 404 

iter 604 0.0221634874284998 (26320)
iter 605 0.02247813641093343 (26317)
iter 606 0.020471168922004068 (26317)
iter 607 0.02086871381514452 (26317)
iter 608 0.019437996821221576 (26317)
iter 609 0.02132220744999589 (26315)
iter 610 0.018187388317524346 (26313)
iter 611 0.02066620632293803 (26311)
iter 612 0.020990672221411502 (26310)
iter 613 0.016767003838166594 (26308)
iter 614 0.017567754486737946 (26307)
iter 615 0.017975621248050827 (26306)
iter 616 0.019236949943483572 (26304)
iter 617 0.01883045278354456 (26304)
iter 618 0.01713330485697223 (26301)
iter 619 0.01776540181965891 (26301)
iter 620 0.01861080879815927 (26300)
iter 621 0.016572785689951724 (26300)
iter 622 0.016334606440273805 (26298)
iter 623 0.01723187932815931 (26298)
iter 624 0.015716221252560184 (26297)
iter 625 0.015387223512454667 (26295)
iter 626 0.017199805793738587 (26294)
iter 627 0.015726725959153198 (26294)
iter 628 0.016950791104106366 (26293)
iter 629 0.01692925774526382 (26293)
iter 630 0.0154936245554

iter 822 0.002528975185146526 (26172)
iter 823 0.002345942469743456 (26172)
iter 824 0.0027591784211986603 (26172)
iter 825 0.0025683649510128415 (26171)
iter 826 0.0028714448318102326 (26170)
iter 827 0.0022953362310730793 (26170)
iter 828 0.0023328848086411613 (26169)
iter 829 0.002770777658156115 (26169)
iter 830 0.0027919295253376157 (26169)
iter 831 0.002303071761679043 (26169)
iter 832 0.0023711545858671745 (26169)
iter 833 0.002434881334911071 (26169)
iter 834 0.002352481075985541 (26169)
iter 835 0.002439642235164463 (26169)
iter 836 0.0021791098236408415 (26169)
iter 837 0.002353554540063474 (26169)
iter 838 0.002196519271405434 (26169)
iter 839 0.0021151696726751518 (26169)
iter 840 0.0022306681894516123 (26169)
iter 841 0.002059984219854219 (26169)
iter 842 0.0020444724307421974 (26169)
iter 843 0.0020909271638750726 (26169)
iter 844 0.001939805448634313 (26169)
iter 845 0.0019873864798039265 (26169)
iter 846 0.0019609001103671336 (26168)
iter 847 0.00213196986635222 (26168)

## Test Set Predictions

In [28]:
X = np.concatenate((normalize(test_df[INPUT_COL_NAMES].to_numpy()), np.stack(test_df[EMB_COL_NAME])), axis=1)
curr_preds = model.best_estimator_.predict(X)
curr_f1 = f1_score(y_true=test_df[TARGET_COL_NAME], y_pred=curr_preds)
curr_acc = accuracy_score(y_true=test_df[TARGET_COL_NAME], y_pred=curr_preds)
curr_prec = precision_score(y_true=test_df[TARGET_COL_NAME], y_pred=curr_preds)
curr_recall = recall_score(y_true=test_df[TARGET_COL_NAME], y_pred=curr_preds)
print(f"[Final Test] F1-score for fold {curr_f1:.3}. Accuracy is {curr_acc:.3}. Precision is {curr_prec:.3}. Recall is {curr_recall:.3}.")

[Final Test] F1-score for fold 0.651. Accuracy is 0.943. Precision is 0.8. Recall is 0.549.
