In [1]:
from pathlib import Path
from tqdm.auto import tqdm
import numpy as np
import pandas as pd

from sklearn.model_selection import cross_val_score, StratifiedKFold
from sklearn.linear_model import LogisticRegression
from transformers import AutoModel, AutoTokenizer
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score
from lightning.classification import LinearSVC
import torch

In [2]:
train_df = pd.read_csv('../data/reviews_with_embeddings_sample_train.csv')
test_df = pd.read_csv('../data/reviews_with_embeddings_sample_test.csv')

In [3]:
train_df.head()

Unnamed: 0,title,abstract,doi,review,references_count,references_contradicted_total,references_contradicted_avg,references_mentioned_total,references_mentioned_total_avg,references_supported_total,...,citations_to_references_total,citations_to_references_total_avg,contradiction_percentage_avg,mentioning_percentage_avg,supporting_percentage_avg,contradiction_to_supporting_ratio_avg,contradiction_to_supporting_contradiction_ratio_avg,supporting_to_supporting_contradiction_ratio_avg,title_abstract,embeddings
0,On “Bettering Humanity” in Science and Enginee...,Authors such as Krishnasamy Selvan argue that ...,10.1007/s11948-007-9014-9,0,1,0,0.0,12,12.0,0,...,6,6.0,0.076923,1.0,0.076923,1.0,1.0,1.0,On “Bettering Humanity” in Science and Enginee...,[-1.61832273e+00 2.50268847e-01 5.84022641e-...
1,X-ray Scattering Studies of Methylophilus meth...,Small angle x-ray solution scattering has bee...,10.1074/jbc.m001564200,0,34,6,0.176471,7233,212.735294,218,...,8075,237.5,0.030342,0.933643,0.084957,0.411699,0.394052,0.964706,X-ray Scattering Studies of Methylophilus meth...,[-1.15774763e+00 -1.06575049e-01 3.78147513e-...
2,Administration of dexmedetomidine alone during...,We report the clinical management of 2 adults ...,10.1007/s00540-011-1174-8,0,20,22,1.1,3198,159.9,172,...,4228,211.4,0.048151,0.935298,0.063496,0.695398,0.536021,0.884608,Administration of dexmedetomidine alone during...,[-6.61146998e-01 3.08037162e-01 7.32850969e-...
3,Evaluation of Tau Imaging in Staging Alzheimer...,,10.1001/jamaneurol.2016.2078,0,50,244,4.88,69683,1393.66,3961,...,84626,1692.52,0.011087,0.917848,0.072126,0.157045,0.147288,0.942722,Evaluation of Tau Imaging in Staging Alzheimer...,[-4.46374565e-02 1.95754838e+00 1.34340003e-...
4,Reward and Affective Regulation in Depression-...,Background There is a disproportionately hig...,10.1016/j.biopsych.2014.04.018,0,126,280,2.222222,51947,412.277778,1804,...,61032,484.380952,0.028922,0.906765,0.077215,0.421393,0.31818,0.856201,Reward and Affective Regulation in Depression-...,[-2.79321611e-01 1.04500818e+00 2.65802503e-...


In [4]:
INPUT_COL_NAMES = 'embeddings'
TARGET_COL_NAME = 'review'

In [5]:
estimator = LinearSVC(loss="squared_hinge", random_state=42, verbose=1)
Cs = np.logspace(-4, 2, 7)
model = GridSearchCV(estimator=estimator, cv=3, param_grid={'C': Cs}, verbose=1, n_jobs=4)

In [6]:
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=17)

## Training and Train Set Predictions

In [7]:
test_df['embeddings'] = test_df['embeddings'].apply(
    lambda x: np.fromstring(x[1:-1], dtype=np.float64, sep=' ')
)

In [8]:
train_df['embeddings'] = train_df['embeddings'].apply(
    lambda x: np.fromstring(x[1:-1], dtype=np.float64, sep=' ')
)

In [9]:
acc_scores, prec_scores, cv_f1_scores, recall_scores = [], [], [], []
skf_split_generator = skf.split(X=train_df[INPUT_COL_NAMES], y=train_df[TARGET_COL_NAME])

for fold_id, (train_idx, val_idx) in enumerate(skf_split_generator):
    curr_train_df = train_df.iloc[train_idx]
    curr_val_df = train_df.iloc[val_idx]
    model.fit(np.stack(curr_train_df[INPUT_COL_NAMES]), curr_train_df[TARGET_COL_NAME])
    
    # making predictions for the current validation set

    curr_preds = model.predict(np.stack(curr_val_df[INPUT_COL_NAMES]))
    curr_f1 = f1_score(y_true=curr_val_df[TARGET_COL_NAME], y_pred=curr_preds)
    curr_acc = accuracy_score(y_true=curr_val_df[TARGET_COL_NAME], y_pred=curr_preds)
    curr_prec = precision_score(y_true=curr_val_df[TARGET_COL_NAME], y_pred=curr_preds)
    curr_recall = recall_score(y_true=curr_val_df[TARGET_COL_NAME], y_pred=curr_preds)
    cv_f1_scores.append(curr_f1)
    acc_scores.append(curr_acc)
    prec_scores.append(curr_prec)
    recall_scores.append(curr_recall)
    print(f"F1-score for fold {fold_id} is {curr_f1:.3}. Accuracy is {curr_acc:.3}. Precision is {curr_prec:.3}. Recall is {curr_recall:.3}.")
    
print(f'Average cross-validation F1-score is {np.mean(cv_f1_scores):.3} +/- {np.std(cv_f1_scores):.3}.')
print(f'Average cross-validation ACC is {np.mean(acc_scores):.3} +/- {np.std(acc_scores):.3}.')
print(f'Average cross-validation Prec is {np.mean(prec_scores):.3} +/- {np.std(prec_scores):.3}.')
print(f'Average cross-validation Recall is {np.mean(recall_scores):.3} +/- {np.std(recall_scores):.3}.')


Fitting 3 folds for each of 7 candidates, totalling 21 fits




iter 1 6.534953666163638 (80000)
iter 2 9.059299326682595 (80000)
iter 3 8.043899663086933 (79942)
iter 4 7.919227610216584 (79901)
iter 5 7.555550690106932 (79825)
iter 6 7.733767635124929 (79760)
iter 7 7.993840334535776 (79750)
iter 8 7.341623977336122 (79745)
iter 9 6.921685180030187 (79705)
iter 10 6.709224438085299 (79483)
iter 11 6.056093409407867 (79405)
iter 12 6.416456595192772 (79249)
iter 13 6.319989770084816 (79192)
iter 14 6.456551654874687 (79100)
iter 15 7.246448842733688 (79066)
iter 16 6.640354653572729 (78988)
iter 17 5.8879489884626945 (78916)
iter 18 6.545738424567232 (78797)
iter 19 5.982153206332864 (78772)
iter 20 6.6370053735231025 (78707)
iter 21 5.912570031851029 (78632)
iter 22 5.566767010780101 (78548)
iter 23 5.815981179774488 (78467)
iter 24 6.035028028387485 (78440)
iter 25 5.427694186505443 (78386)
iter 26 5.61443521371225 (78349)
iter 27 5.381984800905192 (78138)
iter 28 5.25429805081588 (78059)
iter 29 6.0504635509483675 (78017)
iter 30 5.477228727252

iter 239 0.7267455925322776 (36102)
iter 240 0.7008539143084132 (36051)
iter 241 0.6715267509347925 (35775)
iter 242 0.7822488614345962 (35641)
iter 243 0.6293681267787473 (35596)
iter 244 0.7170616520659713 (35151)
iter 245 0.7057545361081864 (35073)
iter 246 0.6942618087837273 (35040)
iter 247 0.7218957148396492 (35009)
iter 248 0.67234262559081 (34953)
iter 249 0.661828548982476 (34785)
iter 250 0.6911663571224528 (34762)
iter 251 0.6141839129681002 (34753)
iter 252 0.6742516514828762 (34595)
iter 253 0.6038584772676521 (34533)
iter 254 0.6113893532484573 (34453)
iter 255 0.6147954482038988 (34426)
iter 256 0.6106139868074179 (34391)
iter 257 0.5925983150755912 (34121)
iter 258 0.6115765780605482 (34102)
iter 259 0.5758074529931615 (34067)
iter 260 0.6415255082403395 (34022)
iter 261 0.5869480855778959 (34012)
iter 262 0.5526615863830306 (33986)
iter 263 0.5944281154510447 (33896)
iter 264 0.524112718906912 (33866)
iter 265 0.6733912495074825 (33664)
iter 266 0.6050660357707376 (336

iter 470 0.07042090860208639 (26859)
iter 471 0.0689129171867763 (26856)
iter 472 0.06771576430988967 (26851)
iter 473 0.07715407679386543 (26843)
iter 474 0.07240901488003268 (26841)
iter 475 0.06478501877645348 (26839)
iter 476 0.0721466457585395 (26836)
iter 477 0.07330319569233539 (26827)
iter 478 0.07195612498364179 (26810)
iter 479 0.06289894958555606 (26809)
iter 480 0.07811047356101436 (26806)
iter 481 0.06719983335212809 (26801)
iter 482 0.0727485498729401 (26798)
iter 483 0.06295908166484004 (26797)
iter 484 0.0692028865770455 (26790)
iter 485 0.06480367053166419 (26790)
iter 486 0.06867098411047334 (26785)
iter 487 0.062119976720870246 (26784)
iter 488 0.06659877530131356 (26777)
iter 489 0.06332319638120955 (26761)
iter 490 0.06461985854892802 (26758)
iter 491 0.06256854702446242 (26755)
iter 492 0.05866896927560519 (26751)
iter 493 0.0572261056978895 (26750)
iter 494 0.05879338616351781 (26748)
iter 495 0.0713606722637575 (26743)
iter 496 0.058056050314323066 (26736)
iter 

iter 691 0.009220285949062351 (26304)
iter 692 0.008932391450075239 (26303)
iter 693 0.008888646458943494 (26299)
iter 694 0.009157002656634147 (26299)
iter 695 0.009625378428100129 (26299)
iter 696 0.009236345754704556 (26299)
iter 697 0.00883115525919731 (26299)
iter 698 0.009278885457894494 (26297)
iter 699 0.008483032698554105 (26296)
iter 700 0.009733252810686932 (26295)
iter 701 0.009102580608032665 (26295)
iter 702 0.008550428276289771 (26295)
iter 703 0.008128155606246468 (26295)
iter 704 0.009406855225988053 (26292)
iter 705 0.008415846338974225 (26292)
iter 706 0.007883427417908394 (26292)
iter 707 0.007468684964182654 (26290)
iter 708 0.008462888923094741 (26290)
iter 709 0.0074645048787573 (26290)
iter 710 0.00846249451951539 (26289)
iter 711 0.0069098016981412914 (26289)
iter 712 0.008185911992727779 (26289)
iter 713 0.007178790990856276 (26288)
iter 714 0.0073296851304335 (26287)
iter 715 0.007453849645245758 (26287)
iter 716 0.007208281379356489 (26287)
iter 717 0.007825

iter 909 0.0012849064804535404 (26238)
iter 910 0.0010862732584548673 (26238)
iter 911 0.0014522827883123313 (26238)
iter 912 0.001103657116552581 (26238)
iter 913 0.0012428304791643752 (26238)
iter 914 0.0011080679087708545 (26238)
iter 915 0.001124397271978994 (26238)
iter 916 0.0012669899964506404 (26237)
iter 917 0.0010914335606522357 (26237)
iter 918 0.0009813152821672722 (26237)
iter 919 0.001059718194751813 (80000)
iter 920 0.0011301194728152442 (26245)
iter 921 0.0009602670687740844 (26243)
iter 922 0.0010120451247264306 (80000)
iter 923 0.0010084401375998198 (26247)
iter 924 0.0010001409774345088 (26241)
iter 925 0.0009949321611214604 (26241)
iter 926 0.0010727247103780035 (80000)
iter 927 0.0010431510572345426 (26245)
iter 928 0.0009378512499555736 (26243)
iter 929 0.0010285239676950109 (80000)
iter 930 0.0009660978410805932 (26245)
iter 931 0.001033997365587791 (80000)
iter 932 0.0009043928127037693 (26249)
iter 933 0.0009016619999541176 (80000)

Converged at iteration 932
F



iter 1 5.961813345005586 (80000)
iter 2 9.528920453469777 (80000)
iter 3 8.575780595155642 (79993)
iter 4 8.34902737108124 (79981)
iter 5 7.91054805782661 (79921)
iter 6 7.9156386324160195 (79914)
iter 7 6.651219606791143 (79823)
iter 8 6.85035587878607 (79709)
iter 9 6.660004729676052 (79557)
iter 10 6.592869100281007 (79426)
iter 11 7.316423529700998 (79387)
iter 12 6.99055279825689 (79373)
iter 13 6.5919634008166526 (79358)
iter 14 6.4795859357722225 (79304)
iter 15 6.311451217417117 (79276)
iter 16 6.002881865374894 (79258)
iter 17 6.3654399067085645 (79207)
iter 18 6.025257916765929 (79147)
iter 19 5.9337236610727135 (79027)
iter 20 5.965263952505577 (78956)
iter 21 6.10792010404915 (78924)
iter 22 5.822416833917892 (78896)
iter 23 5.728981503829138 (78775)
iter 24 5.674515189104321 (78755)
iter 25 5.930423347564965 (78691)
iter 26 6.4067862662440085 (78533)
iter 27 5.449711625885488 (78448)
iter 28 5.263577377244157 (78340)
iter 29 5.252247771235146 (78145)
iter 30 6.197116296155

iter 236 0.7253824203112671 (36369)
iter 237 0.8672565360949809 (36246)
iter 238 0.8452848474530925 (36245)
iter 239 0.7538503665216969 (36236)
iter 240 0.7339670307805419 (36215)
iter 241 0.8885283613536515 (35984)
iter 242 0.6923822644769075 (35961)
iter 243 0.7002502897614239 (35626)
iter 244 0.6431405932043937 (35462)
iter 245 0.726234425559303 (35222)
iter 246 0.7269230545500949 (35107)
iter 247 0.684558444821458 (35081)
iter 248 0.801228847511913 (34856)
iter 249 0.7009664047930475 (34740)
iter 250 0.661838979981821 (34652)
iter 251 0.6612657182907533 (34634)
iter 252 0.6512607568435852 (34616)
iter 253 0.6345105288821007 (34509)
iter 254 0.5729746599416248 (34355)
iter 255 0.6112317704616954 (34193)
iter 256 0.6940146666099285 (34057)
iter 257 0.5778105506649965 (34032)
iter 258 0.6514074988639904 (33861)
iter 259 0.5742937834870633 (33711)
iter 260 0.5620274374906271 (33674)
iter 261 0.6444817360404916 (33568)
iter 262 0.5570406550108756 (33524)
iter 263 0.5930185792798983 (334

iter 461 0.07808546574531569 (27036)
iter 462 0.08476829924988204 (27028)
iter 463 0.08519947243347403 (27025)
iter 464 0.08023243919957576 (27023)
iter 465 0.09434485135856682 (27021)
iter 466 0.07565994231061747 (27019)
iter 467 0.0779833626905048 (27015)
iter 468 0.08752948612807171 (27012)
iter 469 0.08294894624464427 (27009)
iter 470 0.08530475823964603 (27007)
iter 471 0.07176181083321023 (26999)
iter 472 0.07188937838035803 (26985)
iter 473 0.07334311284971123 (26980)
iter 474 0.07170577321962032 (26978)
iter 475 0.07806738572739048 (26958)
iter 476 0.08014519086971289 (26950)
iter 477 0.0818504646889165 (26939)
iter 478 0.06755497935358945 (26939)
iter 479 0.06393596630431284 (26930)
iter 480 0.06814982645495692 (26916)
iter 481 0.06773154727909284 (26912)
iter 482 0.06685408602232576 (26907)
iter 483 0.06199587366927378 (26904)
iter 484 0.06657371572732212 (26888)
iter 485 0.06709345194080964 (26882)
iter 486 0.06174919030464276 (26871)
iter 487 0.06699874356747698 (26867)
ite

iter 682 0.01128522166670794 (26393)
iter 683 0.009998414873560138 (26393)
iter 684 0.009725565167796879 (26393)
iter 685 0.010303414965523533 (26392)
iter 686 0.010512310178428033 (26391)
iter 687 0.009525311954250316 (26391)
iter 688 0.010001763106904298 (26390)
iter 689 0.008846801435960505 (26389)
iter 690 0.009875310870757992 (26386)
iter 691 0.00932016397719862 (26385)
iter 692 0.010132103449906352 (26383)
iter 693 0.008623820903503582 (26381)
iter 694 0.008762224899760807 (26381)
iter 695 0.008832988262356819 (26381)
iter 696 0.009784152177806968 (26381)
iter 697 0.00841495456705671 (26380)
iter 698 0.00878992281667057 (26378)
iter 699 0.00863488502767415 (26378)
iter 700 0.008106997817289852 (26378)
iter 701 0.008333812872623705 (26377)
iter 702 0.008247896660092131 (26377)
iter 703 0.00844523994537319 (26376)
iter 704 0.009359297972801398 (26376)
iter 705 0.009284517082568305 (26376)
iter 706 0.009088100813560429 (26375)
iter 707 0.008211635414529198 (26375)
iter 708 0.0091238

iter 898 0.0014236036093931564 (26315)
iter 899 0.0013991313405865569 (26315)
iter 900 0.0013945112024233985 (26315)
iter 901 0.0011923699081670255 (26315)
iter 902 0.0012841430129270906 (26315)
iter 903 0.0014404466896602575 (26315)
iter 904 0.001338775307793881 (26315)
iter 905 0.0013413167967374143 (26315)
iter 906 0.0012585800249085288 (26314)
iter 907 0.0012908930886125214 (26314)
iter 908 0.0012059326474883614 (26313)
iter 909 0.0012042514429740914 (26313)
iter 910 0.0013255583566158524 (26313)
iter 911 0.001393612379646584 (26313)
iter 912 0.0012618061000188169 (26313)
iter 913 0.0012110215582218364 (26313)
iter 914 0.0011575608327081266 (26313)
iter 915 0.001145170505886859 (26313)
iter 916 0.0010220124253504728 (26313)
iter 917 0.0010240258822282744 (26313)
iter 918 0.0010873393192596392 (26313)
iter 919 0.0010931529265990614 (26313)
iter 920 0.0010654232255528293 (26313)
iter 921 0.0010454209108165108 (26313)
iter 922 0.0011543604919433265 (26313)
iter 923 0.00116816124897348



iter 1 6.3568074872434375 (80000)
iter 2 9.838456073710997 (80000)
iter 3 7.762199710947659 (79993)
iter 4 7.332881282691187 (79849)
iter 5 7.6958724990989795 (79776)
iter 6 7.448653885808632 (79690)
iter 7 7.465502424015254 (79606)
iter 8 7.230175409312828 (79579)
iter 9 7.440343629621323 (79547)
iter 10 6.500846036495252 (79520)
iter 11 6.224396890126046 (79414)
iter 12 6.298523419022917 (79326)
iter 13 6.460349724066118 (79228)
iter 14 6.297613880356363 (79157)
iter 15 6.236260099608437 (78941)
iter 16 5.693076585528027 (78881)
iter 17 6.172255036360983 (78703)
iter 18 6.471984743998002 (78665)
iter 19 6.286937121910238 (78638)
iter 20 5.839371876654624 (78632)
iter 21 5.877312110153325 (78593)
iter 22 5.430432894016315 (78514)
iter 23 5.845844356862733 (78388)
iter 24 5.923322245422646 (78312)
iter 25 5.206146676386352 (78298)
iter 26 6.048200852512094 (78185)
iter 27 5.434463684613066 (78179)
iter 28 5.88119259761157 (78116)
iter 29 5.578964315205786 (78096)
iter 30 6.160764948579

iter 238 0.7388972556386028 (36120)
iter 239 0.6765570623023871 (36034)
iter 240 0.7016525873064119 (35924)
iter 241 0.6725560082484552 (35860)
iter 242 0.736140495334817 (35680)
iter 243 0.7105636007618659 (35575)
iter 244 0.7308768792133664 (35550)
iter 245 0.7327210622955522 (35531)
iter 246 0.6491650947844305 (35519)
iter 247 0.74686194964194 (35322)
iter 248 0.7619601752580192 (35305)
iter 249 0.7278260684154005 (35275)
iter 250 0.6755406277690114 (35227)
iter 251 0.6563186042490892 (34893)
iter 252 0.6645414469269939 (34844)
iter 253 0.6232987452153156 (34788)
iter 254 0.6318564437278875 (34568)
iter 255 0.5898801876535087 (34549)
iter 256 0.6713013605602292 (34376)
iter 257 0.6183792579046947 (34366)
iter 258 0.6172389976491643 (34239)
iter 259 0.6603847510805385 (34200)
iter 260 0.5715533335933831 (34169)
iter 261 0.6255122264595454 (33968)
iter 262 0.5860643736690112 (33901)
iter 263 0.5912927923868749 (33855)
iter 264 0.5985037940782322 (33794)
iter 265 0.5792688259578096 (33

iter 467 0.07548641512772272 (26964)
iter 468 0.06925933853826674 (26954)
iter 469 0.07473755212154568 (26943)
iter 470 0.07747910441138706 (26933)
iter 471 0.07542221819845599 (26930)
iter 472 0.07332534879308913 (26928)
iter 473 0.06977166806489651 (26915)
iter 474 0.07723163433041745 (26906)
iter 475 0.08046410281886052 (26902)
iter 476 0.08446333762209293 (26896)
iter 477 0.06561384376204289 (26896)
iter 478 0.07301532853666914 (26885)
iter 479 0.07503910910019672 (26881)
iter 480 0.07240225611199272 (26881)
iter 481 0.06429975416530975 (26878)
iter 482 0.0807422578521571 (26871)
iter 483 0.06865253358587994 (26865)
iter 484 0.06018524362048426 (26865)
iter 485 0.06919585755584609 (26858)
iter 486 0.07182499562471789 (26852)
iter 487 0.06730757076500726 (26851)
iter 488 0.061789348703783345 (26848)
iter 489 0.05559726616906832 (26844)
iter 490 0.06664788396291518 (26838)
iter 491 0.06019888282076863 (26834)
iter 492 0.056270769964586756 (26825)
iter 493 0.06265439849581765 (26816)


iter 686 0.009641810157489453 (26366)
iter 687 0.010104399365394645 (26366)
iter 688 0.009171863637143077 (26365)
iter 689 0.009395148005444092 (26365)
iter 690 0.009365824226222175 (26365)
iter 691 0.00883907501639103 (26365)
iter 692 0.00890778582382036 (26363)
iter 693 0.00852176174262076 (26362)
iter 694 0.008532945246571298 (26362)
iter 695 0.008501609838995394 (26362)
iter 696 0.008138579454573747 (26362)
iter 697 0.008459891865528008 (26360)
iter 698 0.008095244764483263 (26359)
iter 699 0.008715787204359955 (26359)
iter 700 0.008657405984428213 (26359)
iter 701 0.008725142669434305 (26359)
iter 702 0.008554177172497568 (26359)
iter 703 0.007837541791135022 (26359)
iter 704 0.008903097328266704 (26356)
iter 705 0.008735437681714642 (26356)
iter 706 0.00805628332176575 (26354)
iter 707 0.007806977220280807 (26351)
iter 708 0.007974042828831684 (26351)
iter 709 0.008774944267092855 (26349)
iter 710 0.008562313302426627 (26349)
iter 711 0.007286590700411488 (26349)
iter 712 0.00723

iter 904 0.0011591354956272115 (26302)
iter 905 0.0011292599772317335 (26301)
iter 906 0.0012458832881940274 (26301)
iter 907 0.0011199407907121484 (26301)
iter 908 0.0011427479373804644 (26301)
iter 909 0.0011087168265846853 (26301)
iter 910 0.0011242584190180038 (26301)
iter 911 0.0012498303926488097 (26301)
iter 912 0.001281148680460354 (26301)
iter 913 0.0010815824937624452 (26301)
iter 914 0.001105690995643463 (26301)
iter 915 0.001095939135630608 (26301)
iter 916 0.0011149306880807418 (26301)
iter 917 0.001147284853593833 (26301)
iter 918 0.001220491708178617 (26301)
iter 919 0.0010880502759172336 (26301)
iter 920 0.001003330522100332 (26301)
iter 921 0.0011155684814181943 (26301)
iter 922 0.001068272034807305 (26301)
iter 923 0.0011881867336793056 (26301)
iter 924 0.001116984916559649 (26301)
iter 925 0.001150234187660061 (26301)
iter 926 0.001090625307968579 (26301)
iter 927 0.0009995480076260173 (26300)
iter 928 0.0009124444821136859 (80000)

Converged at iteration 927
F1-scor



iter 1 6.261324460305259 (80000)
iter 2 10.091826436146317 (80000)
iter 3 8.042059735224173 (79999)
iter 4 8.169780769789039 (79848)
iter 5 7.349167672818486 (79774)
iter 6 7.074445051243137 (79735)
iter 7 7.03948791207243 (79663)
iter 8 8.530716692176615 (79595)
iter 9 7.560968089086499 (79595)
iter 10 6.573076511787363 (79582)
iter 11 6.9397339444165835 (79527)
iter 12 6.194763526996074 (79494)
iter 13 6.208248569601558 (79425)
iter 14 7.2561573055054325 (79368)
iter 15 6.201016908784083 (79343)
iter 16 6.465382001893175 (79268)
iter 17 6.117170241148816 (79148)
iter 18 6.057558436509938 (78938)
iter 19 6.03578671451253 (78886)
iter 20 5.869369937533393 (78726)
iter 21 5.86649863804222 (78688)
iter 22 6.049226372831223 (78603)
iter 23 5.402405370680832 (78577)
iter 24 5.779849273743959 (78529)
iter 25 5.7253734882956975 (78518)
iter 26 5.377493800553257 (78497)
iter 27 5.564105668461071 (78378)
iter 28 5.575303438598381 (78301)
iter 29 5.460865274202595 (78247)
iter 30 5.533053804515

iter 236 0.7040778943861434 (36257)
iter 237 0.7268507599175067 (36030)
iter 238 0.7205897977182529 (35829)
iter 239 0.75094040708235 (35668)
iter 240 0.6881730966318963 (35604)
iter 241 0.667071564607481 (35472)
iter 242 0.6692105153668256 (35399)
iter 243 0.7645166259946774 (35243)
iter 244 0.6386234202900405 (35087)
iter 245 0.8135420110123034 (34910)
iter 246 0.6544134158299736 (34879)
iter 247 0.6695783434544549 (34795)
iter 248 0.5970385764030146 (34738)
iter 249 0.6253364122573322 (34369)
iter 250 0.6457286795492944 (34271)
iter 251 0.7081582626665832 (34173)
iter 252 0.6038100448859256 (34118)
iter 253 0.562923564055279 (34079)
iter 254 0.6401704166708424 (33959)
iter 255 0.6208324302996704 (33931)
iter 256 0.61253721300272 (33852)
iter 257 0.6086326239749387 (33752)
iter 258 0.595776092935152 (33687)
iter 259 0.5699463607005286 (33626)
iter 260 0.5899102245296006 (33529)
iter 261 0.6499921061809728 (33464)
iter 262 0.5470182937655341 (33459)
iter 263 0.5409450983077297 (33361)

iter 464 0.08037690233349291 (26866)
iter 465 0.09264845700244591 (26852)
iter 466 0.07322955492769953 (26851)
iter 467 0.07299184178252366 (26845)
iter 468 0.07646013976091998 (26832)
iter 469 0.08215087852080499 (26832)
iter 470 0.07926475619809656 (26829)
iter 471 0.07262024413127421 (26826)
iter 472 0.08347348406304039 (26823)
iter 473 0.07234622095243087 (26819)
iter 474 0.07784062860295629 (26808)
iter 475 0.07037744418047816 (26806)
iter 476 0.07043928916507705 (26797)
iter 477 0.07575207719612498 (26791)
iter 478 0.06398435266853189 (26788)
iter 479 0.06891292897134552 (26773)
iter 480 0.06562026693428802 (26770)
iter 481 0.06673924089186942 (26765)
iter 482 0.06865729899498373 (26750)
iter 483 0.07047357971545673 (26746)
iter 484 0.05780524466763105 (26742)
iter 485 0.06431195664242957 (26724)
iter 486 0.05914765397695284 (26720)
iter 487 0.06048822629269211 (26706)
iter 488 0.060268489109814916 (26703)
iter 489 0.05831400262229969 (26696)
iter 490 0.05993036614695957 (26695)


iter 683 0.009231403902549298 (26242)
iter 684 0.009556549637125489 (26239)
iter 685 0.010484149988739352 (26239)
iter 686 0.008703326775225673 (26238)
iter 687 0.008795902022336217 (26238)
iter 688 0.008774870703663673 (26238)
iter 689 0.00927903137447747 (26238)
iter 690 0.009522493331636461 (26238)
iter 691 0.008143969968918821 (26238)
iter 692 0.008783694218354482 (26237)
iter 693 0.008401908598145547 (26236)
iter 694 0.008147634139089088 (26235)
iter 695 0.00863002336405299 (26234)
iter 696 0.007852896720074692 (26234)
iter 697 0.008332551838561075 (26233)
iter 698 0.00932881582360387 (26232)
iter 699 0.008219999545558632 (26231)
iter 700 0.0083634846328193 (26231)
iter 701 0.009755297135180946 (26229)
iter 702 0.00829515237427958 (26229)
iter 703 0.008524256983561018 (26229)
iter 704 0.0072867536966098045 (26229)
iter 705 0.007416287570800734 (26228)
iter 706 0.007936507563721232 (26228)
iter 707 0.007814767489453422 (26228)
iter 708 0.00856209014736209 (26228)
iter 709 0.0074669

iter 901 0.001292138102906848 (26164)
iter 902 0.001214141270806171 (26164)
iter 903 0.0011719922591594406 (26164)
iter 904 0.0011923234585253968 (26162)
iter 905 0.0014508547407907323 (26162)
iter 906 0.0011172178068366845 (26162)
iter 907 0.0011728296862627552 (26162)
iter 908 0.0012565297953624523 (26162)
iter 909 0.0011194312562807646 (26162)
iter 910 0.001105304806838181 (26162)
iter 911 0.0010841194802761484 (26162)
iter 912 0.0013170604629420546 (26162)
iter 913 0.001157669315585877 (26162)
iter 914 0.001143715640328058 (26162)
iter 915 0.0010258710139229368 (26162)
iter 916 0.00100419938758467 (26162)
iter 917 0.0011417552035357936 (26162)
iter 918 0.001046400318071572 (26162)
iter 919 0.0010820861408992744 (26162)
iter 920 0.0011803616970127162 (26162)
iter 921 0.0010628146839805064 (26162)
iter 922 0.0009394703246154712 (26162)
iter 923 0.0009385362457467739 (80000)

Converged at iteration 922
F1-score for fold 3 is 0.659. Accuracy is 0.943. Precision is 0.76. Recall is 0.582



iter 1 6.822843484697255 (80000)
iter 2 8.329973319371334 (80000)
iter 3 6.6676972570770285 (79999)
iter 4 5.4633275245360835 (79904)
iter 5 4.801964798695121 (79810)
iter 6 4.243916012132296 (79681)
iter 7 4.090117141527441 (79117)
iter 8 3.568492048998563 (78663)
iter 9 3.5372089726080134 (77854)
iter 10 2.9633486461062857 (77394)
iter 11 2.7557166115615477 (75731)
iter 12 2.347127262963143 (75346)
iter 13 2.3763596355688437 (72939)
iter 14 2.3585745338073796 (71548)
iter 15 1.9455195656991044 (70632)
iter 16 1.8410633421684168 (69044)
iter 17 1.6328090699375157 (65185)
iter 18 1.5394778819634043 (64342)
iter 19 1.311308094762234 (61429)
iter 20 1.2317642082734241 (58308)
iter 21 1.0980016372235024 (56103)
iter 22 0.9833211927787051 (53291)
iter 23 1.0111845242451016 (50484)
iter 24 0.8511506399195488 (49502)
iter 25 0.7688550667010827 (47879)
iter 26 0.6997898525635732 (45241)
iter 27 0.6780833052974954 (43317)
iter 28 0.6475880986266598 (42923)
iter 29 0.5548410795752781 (40703)
it

## Test Set Predictions

In [10]:
curr_preds = model.predict(X=np.stack(test_df[INPUT_COL_NAMES]))
curr_f1 = f1_score(y_true=test_df[TARGET_COL_NAME], y_pred=curr_preds)
curr_acc = accuracy_score(y_true=test_df[TARGET_COL_NAME], y_pred=curr_preds)
curr_prec = precision_score(y_true=test_df[TARGET_COL_NAME], y_pred=curr_preds)
curr_recall = recall_score(y_true=test_df[TARGET_COL_NAME], y_pred=curr_preds)
print(f"[Final Test] F1-score for fold {curr_f1:.3}. Accuracy is {curr_acc:.3}. Precision is {curr_prec:.3}. Recall is {curr_recall:.3}.")

[Final Test] F1-score for fold 0.651. Accuracy is 0.943. Precision is 0.8. Recall is 0.549.
