In [1]:
import torch
from fastai.text import *
from pathlib import Path
import news_utils.plot
from sklearn.metrics import f1_score, cohen_kappa_score

import pymongo
from collections import defaultdict

# Connect to the MongoDB server on localhost (default port) and print the
# metadata record of every database it hosts.
client = pymongo.MongoClient(host='localhost', port=27017)
cursor = client.list_databases()
for db in cursor:
    print(db)

{'name': '10000', 'sizeOnDisk': 270336.0, 'empty': False}
{'name': '10000_cl', 'sizeOnDisk': 262144.0, 'empty': False}
{'name': '10000_cl_clagreement', 'sizeOnDisk': 253952.0, 'empty': False}
{'name': '10000_cl_claudience', 'sizeOnDisk': 368640.0, 'empty': False}
{'name': '10000_cl_clcontroversial', 'sizeOnDisk': 262144.0, 'empty': False}
{'name': '10000_cl_cldisagreement', 'sizeOnDisk': 262144.0, 'empty': False}
{'name': '10000_cl_clinformative', 'sizeOnDisk': 258048.0, 'empty': False}
{'name': '10000_cl_clmean', 'sizeOnDisk': 262144.0, 'empty': False}
{'name': '10000_cl_clpersuasive', 'sizeOnDisk': 360448.0, 'empty': False}
{'name': '10000_cl_clsentiment', 'sizeOnDisk': 262144.0, 'empty': False}
{'name': '10000_cl_cltopic', 'sizeOnDisk': 262144.0, 'empty': False}
{'name': '10000_ner', 'sizeOnDisk': 274432.0, 'empty': False}
{'name': '10000_ner_cl', 'sizeOnDisk': 262144.0, 'empty': False}
{'name': '10000_ner_cl_clagreement', 'sizeOnDisk': 253952.0, 'empty': False}
{'name': '10000_ner_

In [2]:
def get_mets(db, conf=None):
    """Return the logged metrics of the best run stored in database `db`.

    The best run is the one whose ``kappa_score`` metric reaches the highest
    value in any epoch. The remaining metrics are read at the first epoch in
    which that maximum occurs.

    Args:
        db: name of the MongoDB database, looked up on the module-level
            ``client``.
        conf: optional filter handed to ``runs.find``; when given, only runs
            matching it are candidates.

    Returns:
        Tuple ``(kappa, accuracy, f1_macro, epoch, max_epochs, exp_id,
        drop_mult, mod)``. ``epoch`` is 1-based, ``max_epochs`` is the number
        of recorded kappa values and ``mod`` is ``''`` when the run config has
        no ``mod`` entry.

    Raises:
        IndexError: no ``kappa_score`` metric (matching ``conf``) exists.
        LookupError: the best run has no ``F1_macro`` or ``accuracy`` metric,
            or its best kappa value cannot be located among its values.
    """
    mydb = client[db]
    res = mydb["metrics"].aggregate([
        {"$match": {"name": 'kappa_score'}},  # only consider the kappa metric
        {"$unwind": "$values"},
        {"$group": {  # highest value of each metric document, plus its run
            '_id': '$_id',
            'val': {'$max': "$values"},
            'run_id': {'$first': '$run_id'},
        }},
        {"$sort": {"val": -1}},  # best first
    ])

    if conf is not None:
        runs = [r['_id'] for r in mydb['runs'].find(conf)]
        res = [r for r in res if r['run_id'] in runs]

    candidates = list(res)
    if not candidates:
        # Same exception type as the bare `[0]` lookup, but with a usable message.
        raise IndexError(
            f"no 'kappa_score' metric found in database {db!r} (conf={conf!r})")
    best = candidates[0]
    run_id = best['run_id']

    epoch = None
    max_epochs = 0
    for x in mydb['metrics'].find({'run_id': run_id, 'name': 'kappa_score'}):
        values = x['values']
        max_epochs = len(values)
        if epoch is None and best['val'] in values:
            epoch = values.index(best['val']) + 1  # 1-based, first occurrence
    if epoch is None:
        raise LookupError(
            f"best kappa {best['val']!r} not found among the values of run "
            f"{run_id!r} in database {db!r}")

    def metric_at_epoch(name):
        # If several documents match, the last one wins.
        found, value = False, None
        for x in mydb['metrics'].find({'run_id': run_id, 'name': name}):
            found, value = True, x['values'][epoch - 1]
        if not found:
            raise LookupError(
                f"run {run_id!r} in database {db!r} has no {name!r} metric")
        return value

    f1_macro = metric_at_epoch('F1_macro')
    # The logged 'accuracy' is what callers receive in the micro-F1 position.
    f1_micro = metric_at_epoch('accuracy')

    run = list(mydb['runs'].find({'_id': run_id}))[0]
    config = run['config']

    return (best['val'], f1_micro, f1_macro, epoch, max_epochs,
            config['exp_id'], config['drop_mult'], config.get('mod', ''))

In [3]:
# Annotation categories. Each entry is used both as a label column of the
# classification CSVs and as the suffix of the per-category MongoDB database.
all_cols = [
    'claudience',
    'clpersuasive',
    'clsentiment',
    'clagreement',
    'cldisagreement',
    'clinformative',
    'clmean',
    'clcontroversial',
    'cltopic',
]

In [4]:
def test_model(model_id, clas, datafile, baseline):
    """Evaluate a saved classifier checkpoint against ``test.csv``.

    Args:
        model_id: file stem of the ``.pth`` checkpoint, searched recursively
            below ``/home/group7/data/ynacc_proc``; the first match is used.
        clas: label column to evaluate. Rows where it or ``text_proc`` is
            missing are dropped and the label is cast to ``int``.
        datafile: directory (below the ``cls`` data folder) that contains
            ``train.csv`` and ``test.csv``.
        baseline: if true, vocabulary and data are read from the
            ``proper_baseline`` tree, otherwise from ``proper_threads``.

    Returns:
        ``(f1_macro, f1_micro, kappa)`` of the learner's predictions
        (presumably on the ``test.csv`` rows, which are handed to the data
        bunch as its validation frame - TODO confirm against fastai).

    Raises:
        IndexError: no checkpoint called ``<model_id>.pth`` was found.
    """
    matches = list(Path('/home/group7/data/ynacc_proc').glob(f'**/{model_id}.pth'))
    if not matches:
        # Same exception type as the bare `[0]` lookup, but with a usable message.
        raise IndexError(
            f'no checkpoint {model_id}.pth found under /home/group7/data/ynacc_proc')
    p = matches[0]
    # Directory two levels above the checkpoint file; it is handed to the data
    # bunch as its path, and its name identifies the language-model experiment.
    p_fixed = str(p.parents[1])
    exp = p.parts[-3]

    if baseline:
        lm_dir = Path('/mnt/data/group07/johannes/ynacc_proc/proper_baseline/exp/' + exp)
        UT = Path('~/data/ynacc_proc/proper_baseline/cls')/datafile
    else:
        lm_dir = Path('/mnt/data/group07/johannes/ynacc_proc/proper_threads/exp/lm/' + exp)
        UT = Path('~/data/ynacc_proc/proper_threads/data/cls')/datafile

    # Only the vocabulary of the language-model data bunch is needed below.
    data_lm = TextLMDataBunch.load(lm_dir)

    data_clas_train = pd.read_csv(UT/'train.csv')
    data_clas_val = pd.read_csv(UT/'test.csv')

    print(data_clas_val.shape)

    def prepare(frame):
        # Keep label + text, drop incomplete rows, make the label integral.
        # `astype` returns a new frame, so no sliced frame is assigned into.
        return frame[[clas, 'text_proc']].dropna().astype({clas: int})

    data_clas_train = prepare(data_clas_train)
    data_clas_val = prepare(data_clas_val)

    # NOTE(review): the meaning of the magic number 1398 is not documented here.
    data_clas = TextClasDataBunch.from_df(p_fixed, data_clas_train, data_clas_val,
                                          vocab=data_lm.train_ds.vocab, bs=64, text_cols=['text_proc'], label_cols=[clas],tokenizer=Tokenizer(cut_n_from_behind=1398))
    del data_lm  # the vocabulary has been handed over; drop the reference
    learn = text_classifier_learner(data_clas).load(p.stem)
    res = learn.get_preds(ordered=True)
    # res[0] is passed to argmax as per-class scores, res[1] is used as the targets.
    preds = np.argmax(res[0], axis=1)
    f1ma = f1_score(res[1], preds, average='macro')
    f1mi = f1_score(res[1], preds, average='micro')
    kappa = cohen_kappa_score(res[1], preds)
    return f1ma, f1mi, kappa

In [5]:
def test_all(db, df, baseline=False, **kwargs):
    dic = {}
    for col in all_cols:
        mid = get_mets(db + col, **kwargs)[5]
        res = test_model(mid, col, df, baseline)
        dic[col] = res
    return dic

In [25]:
# Test scores per category for the 'lm_threads_cut_cl_<category>' databases,
# restricted to runs whose config.mod is 'simple_fit'.
# NOTE(review): the output recorded for this cell is a KeyboardInterrupt, so
# the `d` displayed further down presumably comes from another execution.
d = test_all('lm_threads_cut_cl_', conf={'config.mod': 'simple_fit'}, df='only_threads_unlimited_30000_cut')

KeyboardInterrupt: 

In [54]:
# Per-category (f1_macro, f1_micro, kappa) tuples as returned by test_all.
# NOTE(review): some values shown here (e.g. 'clpersuasive', 'cldisagreement')
# differ from the test columns printed by get_table(d, ...) further down, so
# the two outputs appear to stem from different executions - re-run before citing.
d

{'claudience': (0.8421709668455142, 0.8784029038112523, 0.6849359494081402),
 'clpersuasive': (0.6277251307544922, 0.8119349005424954, 0.2582159624413145),
 'clsentiment': (0.44014628350696794, 0.6454545454545455, 0.33399157941801105),
 'clagreement': (0.677451451314074, 0.8770343580470162, 0.36844580296261464),
 'cldisagreement': (0.7172282679952799,
  0.7179023508137432,
  0.44105222236620445),
 'clinformative': (0.5790177561055901,
  0.7992766726943942,
  0.15817984832069343),
 'clmean': (0.6757378478747876, 0.7884267631103075, 0.3674562749909568),
 'clcontroversial': (0.580391373801917,
  0.5877034358047016,
  0.24910667492496064),
 'cltopic': (0.5927595145516673, 0.6383363471971067, 0.19354838709677424)}

In [11]:
# Test scores per category for the 'lm_threads_cl_<category>' databases.
# NOTE(review): 'simle_fit' looks like a typo of 'simple_fit'. Because this
# cell produced results, the misspelled value presumably exists in the stored
# run configs - confirm against the database before changing the filter.
d2 = test_all('lm_threads_cl_', conf={'config.mod': 'simle_fit'}, df="threads_headline_unlimited_30000_cut")

In [12]:
# Per-category (f1_macro, f1_micro, kappa) tuples as returned by test_all.
d2

{'claudience': (0.8356205250596659, 0.8802177858439202, 0.6740930599369086),
 'clpersuasive': (0.6456240845365139, 0.8227848101265823, 0.2943673341840055),
 'clsentiment': (0.4406875486894789, 0.6454545454545455, 0.3146877276386918),
 'clagreement': (0.6314807410369185, 0.8553345388788427, 0.2757277102910841),
 'cldisagreement': (0.7164922704805974,
  0.7179023508137432,
  0.4424826801778513),
 'clinformative': (0.5886866314347231,
  0.8282097649186256,
  0.17923039667536367),
 'clmean': (0.6993963322175493, 0.8282097649186256, 0.402977441900108),
 'clcontroversial': (0.5491620955160788,
  0.5605786618444847,
  0.20784380765988553),
 'cltopic': (0.5872282965435757, 0.6401446654611211, 0.17913965822038902)}

In [6]:
# Test scores per category for the 'threads_headline_cl_<category>' databases.
# No `conf` filter is passed, so the best run over all stored runs is used.
# NOTE(review): get_table(d_headline, ...) filters on config.mod ==
# 'simple_fit'; the two may therefore refer to different runs - confirm.
d_headline = test_all('threads_headline_cl_', df='threads_headline_unlimited_30000_cut')

In [8]:
# Per-category (f1_macro, f1_micro, kappa) tuples as returned by test_all.
d_headline

{'claudience': (0.8540197332358854, 0.8947368421052632, 0.7110801721332225),
 'clpersuasive': (0.6455128205128206, 0.8155515370705244, 0.292466320463611),
 'clsentiment': (0.4279077011508827, 0.6236363636363637, 0.31490775174206587),
 'clagreement': (0.6340299547196099, 0.8625678119349005, 0.28489757027155793),
 'cldisagreement': (0.6995485572601279, 0.701627486437613, 0.4111972226344963),
 'clinformative': (0.5641551071878941,
  0.8191681735985533,
  0.13053048646268994),
 'clmean': (0.6902753996575506, 0.810126582278481, 0.3896952943525924),
 'clcontroversial': (0.5668643115923915,
  0.5750452079566004,
  0.22763986045157103),
 'cltopic': (0.5594910861540878, 0.6148282097649186, 0.12448620082207862)}

In [6]:
# Test scores per category for the 'headline_root_threads_cl_<category>'
# databases. No `conf` filter is passed, so the best run over all stored runs
# is used.
# NOTE(review): get_table(d_headline_root, ...) filters on config.mod ==
# 'simple_fit'; the two may therefore refer to different runs - confirm.
d_headline_root = test_all('headline_root_threads_cl_', df='threads_root_headline_unlimited_30000_cut')

In [7]:
# Per-category (f1_macro, f1_micro, kappa) tuples as returned by test_all.
d_headline_root

{'claudience': (0.7819326726776149, 0.8166969147005445, 0.5647675282524538),
 'clpersuasive': (0.6571858216970998, 0.8173598553345389, 0.31508209989331315),
 'clsentiment': (0.43345907296087877, 0.62, 0.30324887865195793),
 'clagreement': (0.6624114242440634, 0.8571428571428571, 0.3311542171257099),
 'cldisagreement': (0.7090698653198653, 0.7106690777576854, 0.428431157220191),
 'clinformative': (0.5791139240506329, 0.7938517179023509, 0.1587894638520455),
 'clmean': (0.7063598462743613, 0.8173598553345389, 0.422582679444634),
 'clcontroversial': (0.580518983667977,
  0.5840867992766727,
  0.2309320240413104),
 'cltopic': (0.5796766144251768, 0.6292947558770343, 0.16618245206275417)}

In [19]:
# Test scores per category for the 'threads_headline_article_cl_<category>'
# databases. No `conf` filter is passed, so the best run over all stored runs
# is used.
# NOTE(review): get_table(d_headline_article, ...) filters on config.mod ==
# 'simple_fit'; the two may therefore refer to different runs - confirm.
d_headline_article = test_all('threads_headline_article_cl_', df='threads_headline_unlimited_30000_cut')

In [20]:
# Per-category (f1_macro, f1_micro, kappa) tuples as returned by test_all.
d_headline_article

{'claudience': (0.848465994905435, 0.8892921960072595, 0.6994375240326576),
 'clpersuasive': (0.656187827816965, 0.8010849909584087, 0.3125155398838182),
 'clsentiment': (0.4478810198506771, 0.5945454545454546, 0.32271246341598103),
 'clagreement': (0.6166630105734217, 0.8571428571428571, 0.25176837309675093),
 'cldisagreement': (0.710668131638152, 0.7106690777576854, 0.4232674558064349),
 'clinformative': (0.5941882874171005, 0.806509945750452, 0.18851570964247022),
 'clmean': (0.6718925985518905, 0.786618444846293, 0.3595642359407204),
 'clcontroversial': (0.5061957357951423,
  0.5298372513562387,
  0.1709719082983533),
 'cltopic': (0.5548103999308974, 0.5786618444846293, 0.1447299423177767)}

In [9]:
def get_table(dic, db, **kwargs):
    """Print one LaTeX table row per category in `all_cols`.

    Each row holds the category title (name without its two-character
    prefix), then accuracy, F1-macro and kappa of the best run as reported by
    `get_mets`, then elements 1, 0 and 2 of `dic[category]`.

    Args:
        dic: mapping from category to a sequence of at least three scores,
            e.g. the result of `test_all`.
        db: database-name prefix; the category name is appended to it.
        **kwargs: forwarded to `get_mets` (e.g. `conf`).
    """
    table = []
    for category in all_cols:
        kappa, micro, macro, *_ = get_mets(db + category, **kwargs)
        scores = dic[category]
        table.append([category[2:].title(), micro, macro, kappa,
                      scores[1], scores[0], scores[2]])

    # Rows are printed only after every category has been looked up.
    for cells in table:
        print(' & '.join(map(str, cells)) + ' \\\\')

In [10]:
# LaTeX rows for the 'lm_threads_cut_cl_' models. Columns: category, then
# logged accuracy / F1-macro / kappa of the best 'simple_fit' run, then
# F1-micro / F1-macro / kappa taken from `d`.
get_table(d, db='lm_threads_cut_cl_', conf={'config.mod': 'simple_fit'})

Audience & 0.8364887833595276 & 0.7963261585921615 & 0.6009355783462524 & 0.8784029038112523 & 0.8421709668455142 & 0.6849359494081402 \\
Persuasive & 0.8473413586616516 & 0.703592626233198 & 0.40735119581222534 & 0.8083182640144665 & 0.6280206112295665 & 0.25792485314968605 \\
Sentiment & 0.6470588445663452 & 0.40863006396588486 & 0.3614158630371094 & 0.6454545454545455 & 0.44014628350696794 & 0.33399157941801105 \\
Agreement & 0.9210977554321289 & 0.7613207547169811 & 0.5264347791671753 & 0.8770343580470162 & 0.677451451314074 & 0.36844580296261464 \\
Disagreement & 0.7941681146621704 & 0.7861335289801907 & 0.5722670555114746 & 0.7160940325497287 & 0.7152560272081179 & 0.43782982277792526 \\
Informative & 0.8473413586616516 & 0.6950424637808927 & 0.39166170358657837 & 0.7974683544303798 & 0.5775401069518716 & 0.15528763536182866 \\
Mean & 0.8147512674331665 & 0.7266333229134104 & 0.45350390672683716 & 0.7884267631103075 & 0.6757378478747876 & 0.3674562749909568 \\
Controversial & 0.7

In [13]:
# LaTeX rows for the 'lm_threads_cl_' models, combining logged metrics of the
# best matching run with the scores in `d2`.
# NOTE(review): 'simle_fit' looks like a typo of 'simple_fit' but evidently
# matches stored runs (the cell printed rows) - confirm before changing it.
get_table(d2, db='lm_threads_cl_', conf={'config.mod': 'simle_fit'})

Audience & 0.8278829455375671 & 0.7772512575144154 & 0.5690997838973999 & 0.8802177858439202 & 0.8356205250596659 & 0.6740930599369086 \\
Persuasive & 0.8456260561943054 & 0.6740183896620278 & 0.34844154119491577 & 0.8227848101265823 & 0.6456240845365139 & 0.2943673341840055 \\
Sentiment & 0.6159169673919678 & 0.42018169515097575 & 0.32171106338500977 & 0.6454545454545455 & 0.4406875486894789 & 0.3146877276386918 \\
Agreement & 0.9125214219093323 & 0.7460996541565261 & 0.4948951005935669 & 0.8553345388788427 & 0.6314807410369185 & 0.2757277102910841 \\
Disagreement & 0.7821612358093262 & 0.7738135606164749 & 0.5476285219192505 & 0.7179023508137432 & 0.7164922704805974 & 0.4424826801778513 \\
Informative & 0.843910813331604 & 0.67211767250703 & 0.3481070399284363 & 0.8282097649186256 & 0.5886866314347231 & 0.17923039667536367 \\
Mean & 0.8404802680015564 & 0.7247836349331235 & 0.45336586236953735 & 0.8282097649186256 & 0.6993963322175493 & 0.402977441900108 \\
Controversial & 0.72555744

In [6]:
# Test scores per category for the baseline databases; baseline=True makes
# test_model read from the 'proper_baseline' directories. No `conf` filter.
d_baseline = test_all('dat_false_par_true_hea_false30000_cl_', df='dat_false_par_true_hea_false', baseline=True)

In [9]:
# LaTeX rows for the baseline models: logged metrics of the best run (no
# `conf` filter) followed by the scores in `d_baseline`.
get_table(d_baseline, db='dat_false_par_true_hea_false30000_cl_')

Audience & 0.8347676396369934 & 0.7905299843768778 & 0.5918749570846558 & 0.8820326678765881 & 0.8393549978694297 & 0.6811143856899913 \\
Persuasive & 0.8301886916160583 & 0.6847323198942499 & 0.3704584240913391 & 0.7956600361663653 & 0.6234082430860648 & 0.24718397243605972 \\
Sentiment & 0.6280276775360107 & 0.44814216067070645 & 0.322735071182251 & 0.6527272727272727 & 0.4368316980488497 & 0.3029382100010617 \\
Agreement & 0.9125214219093323 & 0.6960592895476616 & 0.40444666147232056 & 0.8679927667269439 & 0.6092402404437174 & 0.24918630386668417 \\
Disagreement & 0.7787306904792786 & 0.7728248951074299 & 0.5461447238922119 & 0.6907775768535263 & 0.6903725823404026 & 0.38638885464184447 \\
Informative & 0.8473413586616516 & 0.6793238775068756 & 0.36243438720703125 & 0.8083182640144665 & 0.5252834467120182 & 0.054273821432028524 \\
Mean & 0.8404802680015564 & 0.7521813652672715 & 0.5044736862182617 & 0.8010849909584087 & 0.6900350576821166 & 0.3935315347650097 \\
Controversial & 0.70

In [11]:
# LaTeX rows for the 'threads_headline_cl_' models.
# NOTE(review): the logged metrics are restricted to config.mod ==
# 'simple_fit', whereas `d_headline` was built by test_all without a `conf`
# filter - the two halves of each row may describe different runs; confirm.
get_table(d_headline, db='threads_headline_cl_', conf={'config.mod': 'simple_fit'})

Audience & 0.8399311304092407 & 0.7941482370421167 & 0.6008878946304321 & 0.8947368421052632 & 0.8540197332358854 & 0.7110801721332225 \\
Persuasive & 0.8473413586616516 & 0.688988389587192 & 0.3780103325843811 & 0.8155515370705244 & 0.6455128205128206 & 0.292466320463611 \\
Sentiment & 0.6038062572479248 & 0.40958570678427597 & 0.318831205368042 & 0.6236363636363637 & 0.4279077011508827 & 0.31490775174206587 \\
Agreement & 0.9210977554321289 & 0.7613207547169811 & 0.5264347791671753 & 0.8625678119349005 & 0.6340299547196099 & 0.28489757027155793 \\
Disagreement & 0.7924528121948242 & 0.7838974613473515 & 0.5678068399429321 & 0.701627486437613 & 0.6995485572601279 & 0.4111972226344963 \\
Informative & 0.8593481779098511 & 0.6932523997741389 & 0.39213693141937256 & 0.8191681735985533 & 0.5641551071878941 & 0.13053048646268994 \\
Mean & 0.8456260561943054 & 0.7496278057718735 & 0.5002095699310303 & 0.810126582278481 & 0.6902753996575506 & 0.3896952943525924 \\
Controversial & 0.727272748

In [10]:
# LaTeX rows for the 'headline_root_threads_cl_' models.
# NOTE(review): the logged metrics are restricted to config.mod ==
# 'simple_fit', whereas `d_headline_root` was built by test_all without a
# `conf` filter - the two halves of each row may describe different runs; confirm.
get_table(d_headline_root, db='headline_root_threads_cl_', conf={'config.mod': 'simple_fit'})

Audience & 0.826161801815033 & 0.7970646110644243 & 0.5960416793823242 & 0.8166969147005445 & 0.7819326726776149 & 0.5647675282524538 \\
Persuasive & 0.8713550567626953 & 0.7209725279984686 & 0.4427451491355896 & 0.8173598553345389 & 0.6571858216970998 & 0.31508209989331315 \\
Sentiment & 0.6107266545295715 & 0.41123962765207106 & 0.3285176753997803 & 0.62 & 0.43345907296087877 & 0.30324887865195793 \\
Agreement & 0.9159519672393799 & 0.7636453894841353 & 0.5290092825889587 & 0.8571428571428571 & 0.6624114242440634 & 0.3311542171257099 \\
Disagreement & 0.7855917811393738 & 0.7770677255248659 & 0.5541368722915649 & 0.7106690777576854 & 0.7090698653198653 & 0.428431157220191 \\
Informative & 0.8456260561943054 & 0.6901133947554925 & 0.3819933533668518 & 0.7938517179023509 & 0.5791139240506329 & 0.1587894638520455 \\
Mean & 0.838765025138855 & 0.738500152695068 & 0.4779966473579407 & 0.8173598553345389 & 0.7063598462743613 & 0.422582679444634 \\
Controversial & 0.728987991809845 & 0.6863

In [21]:
# LaTeX rows for the 'threads_headline_article_cl_' models.
# NOTE(review): the logged metrics are restricted to config.mod ==
# 'simple_fit', whereas `d_headline_article` was built by test_all without a
# `conf` filter - the two halves of each row may describe different runs; confirm.
get_table(d_headline_article, db='threads_headline_article_cl_', conf={'config.mod': 'simple_fit'})

Audience & 0.8330464959144592 & 0.7852943977751109 & 0.5837217569351196 & 0.8892921960072595 & 0.848465994905435 & 0.6994375240326576 \\
Persuasive & 0.8113207817077637 & 0.6512095896967324 & 0.3036462664604187 & 0.8010849909584087 & 0.656187827816965 & 0.3125155398838182 \\
Sentiment & 0.6089965105056763 & 0.42122653355484513 & 0.32407450675964355 & 0.5945454545454546 & 0.4478810198506771 & 0.32271246341598103 \\
Agreement & 0.9090909361839294 & 0.7318097784104224 & 0.4669921398162842 & 0.8571428571428571 & 0.6166630105734217 & 0.25176837309675093 \\
Disagreement & 0.7632933259010315 & 0.7565541031227306 & 0.5134850740432739 & 0.7106690777576854 & 0.710668131638152 & 0.4232674558064349 \\
Informative & 0.8267581462860107 & 0.6667119480622393 & 0.3340609073638916 & 0.806509945750452 & 0.5941882874171005 & 0.18851570964247022 \\
Mean & 0.8267581462860107 & 0.7273745861981156 & 0.45505446195602417 & 0.786618444846293 & 0.6718925985518905 & 0.3595642359407204 \\
Controversial & 0.73927956