In [1]:
import os
import sys
sys.path.append('../transformers/src/')
sys.path.append('../embeddings/')


from matplotlib import pyplot as plt
from library import GraphLib, Graph
from utils import print_util as pu

import json

In [2]:
# Baseline architectures scored below, keyed by model name.
# Every per-layer list has exactly 'l' entries. The key meanings are inferred
# from the values rather than documented here (TODO confirm against the
# GraphLib / model-dict spec):
#   'l' layer count              'o' per-layer operation ('l', 'sa', 'c')
#   'h' per-layer hidden size    'n' per-layer head count
#   'f' per-layer list of feed-forward sizes
#   'p' per-layer operation parameter ('dft', 'sdp', or an int such as 9)
models = dict(
    fnet_mini=dict(l=4, o=['l', 'l', 'l', 'l'], h=[256, 256, 256, 256], n=[4, 4, 4, 4],
                   f=[[1024], [1024], [1024], [1024]], p=['dft', 'dft', 'dft', 'dft']),
    fnet_tiny=dict(l=2, o=['l'] * 2, h=[128] * 2, n=[2] * 2, f=[[4 * 128]] * 2, p=['dft'] * 2),
    fnet_2_256=dict(l=2, o=['l'] * 2, h=[256] * 2, n=[4] * 2, f=[[4 * 256]] * 2, p=['dft'] * 2),
    fnet_4_128=dict(l=4, o=['l'] * 4, h=[128] * 4, n=[2] * 4, f=[[4 * 128]] * 4, p=['dft'] * 4),
    bert_mini=dict(l=4, o=['sa'] * 4, h=[256] * 4, n=[4] * 4, f=[[1024]] * 4, p=['sdp'] * 4),
    bert_tiny=dict(l=2, o=['sa'] * 2, h=[128] * 2, n=[2] * 2, f=[[4 * 128]] * 2, p=['sdp'] * 2),
    bert_2_256=dict(l=2, o=['sa'] * 2, h=[256] * 2, n=[4] * 2, f=[[4 * 256]] * 2, p=['sdp'] * 2),
    bert_4_128=dict(l=4, o=['sa'] * 4, h=[128] * 4, n=[2] * 4, f=[[4 * 128]] * 4, p=['sdp'] * 4),
    convbert_mini=dict(l=4, o=['c', 'c', 'c', 'c'], h=[256, 256, 256, 256], n=[4, 4, 4, 4],
                       f=[[1024], [1024], [1024], [1024]], p=[9, 9, 9, 9]),
    convbert_tiny=dict(l=2, o=['c'] * 2, h=[128] * 2, n=[2] * 2, f=[[4 * 128]] * 2, p=[9] * 2),
    convbert_2_256=dict(l=2, o=['c'] * 2, h=[256] * 2, n=[4] * 2, f=[[4 * 256]] * 2, p=[9] * 2),
    convbert_4_128=dict(l=4, o=['c'] * 4, h=[128] * 4, n=[2] * 4, f=[[4 * 128]] * 4, p=[9] * 4),
)




In [3]:
# GraphLib instance that return_glue_score passes to return_hash to map a
# model dict to its graph hash; loaded once from the dataset JSON below.
dataset_file = '../dataset/dataset_test.json'
graphLib = GraphLib.load_from_dataset(dataset_file)

def return_hash(graphlib, model_dict):
    """Look up ``model_dict`` in ``graphlib`` and return the matching graph's hash.

    Only the first item of ``graphlib.get_graph(model_dict=...)``'s result is
    used; its ``hash`` attribute is returned unchanged.
    """
    lookup = graphlib.get_graph(model_dict=model_dict)
    return lookup[0].hash

In [4]:
GLUE_TASKS = ['cola', 'mnli', 'mrpc', 'qnli', 'qqp', 'rte', 'sst2', 'stsb', 'wnli']

def return_glue_score(model_name, disp_task = False):
    
    model_hash =  return_hash(graphLib,models[model_name])
    glue_scores = {}
    
    score = 0
    
    for task in GLUE_TASKS:    
        
        model_dir = f'../models/{task}/{model_hash}/eval_results.json'    
        f = open(model_dir,)
        metrics = json.load(f)

        if task == 'cola':

            glue_scores[task] = metrics['eval_matthews_correlation']
            task_score = glue_scores[task]

        elif task == 'stsb':

            glue_scores[task+'_spearman'] = metrics['eval_spearmanr']
            glue_scores[task+'_pearson'] = metrics['eval_pearson']
            task_score = (metrics['eval_spearmanr']+metrics['eval_pearson'])/2

        elif task == 'mrpc' or task=='qqp':

            glue_scores[task+'_accuracy'] = metrics['eval_accuracy']
            glue_scores[task+'_f1'] = metrics['eval_f1']
            task_score = (metrics['eval_accuracy']+metrics['eval_f1'])/2

        elif task in ["sst2", "mnli",  "qnli", "rte", "wnli"]:

            glue_scores[task] = metrics['eval_accuracy']
            task_score = metrics['eval_accuracy']
            
        if disp_task:
                print(task,':',task_score)
                
        score+=task_score
                        
    
    print(f"{model_name}:", score*1.0/9)

    output_dir = f"../models/glue_score/{model_hash}/"
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    with open(output_dir+'glue_score_optuna.json', 'w') as fp:
        json.dump(glue_scores, fp)

In [14]:
# Score every model in `models`, printing only the averaged GLUE score per model.
# NOTE(review): execution counts in this notebook are out of order (e.g. In [14]
# appears twice), so recorded outputs come from different runs; restart and run
# all to refresh them.
for model_name in list(models):
    return_glue_score(model_name)

fnet_mini: 0.3133889307877333
fnet_tiny: 0.2764506836141395
fnet_2_256: 0.4037778175597316
fnet_4_128: 0.4012487138288356
bert_mini: 0.5854608585042039
bert_tiny: 0.5439949222422282
bert_2_256: 0.6713462680277513
bert_4_128: 0.558872917432773
convbert_mini: 0.3686046134950291
convbert_tiny: 0.34452967959489184
convbert_2_256: 0.40267048942812883
convbert_4_128: 0.37637474090762774


In [10]:
# Without hyperparameter tuning: print each task's score, then the average.
return_glue_score(model_name='bert_mini', disp_task=True)

cola : 0.0437601222642778
mnli : 0.7481692432872253
mrpc : 0.812801065308013
qnli : 0.8325096101043383
qqp : 0.849596123423532
rte : 0.5631768953068592
sst2 : 0.8268348623853211
stsb : 0.7857150097449594
wnli : 0.4225352112676056
bert_mini: 0.6538997936769035


In [11]:
# Without hyperparameter tuning: print each task's score, then the average.
return_glue_score(model_name='bert_tiny', disp_task=True)

cola : 0.0
mnli : 0.6882628152969894
mrpc : 0.7599857450398557
qnli : 0.7924217462932455
qqp : 0.8131614635290008
rte : 0.5776173285198556
sst2 : 0.819954128440367
stsb : 0.282906695531303
wnli : 0.49295774647887325
bert_tiny: 0.5808075187921656


In [12]:
# Without hyperparameter tuning: print each task's score, then the average.
return_glue_score(model_name='bert_2_256', disp_task=True)

cola : 0.0
mnli : 0.7209113100081367
mrpc : 0.7633269720101782
qnli : 0.8103606077246934
qqp : 0.8324127636020967
rte : 0.6209386281588448
sst2 : 0.8314220183486238
stsb : 0.7630462851405746
wnli : 0.39436619718309857
bert_2_256: 0.6374205313529164


In [13]:
# Without hyperparameter tuning: print each task's score, then the average.
return_glue_score(model_name='bert_4_128', disp_task=True)

cola : 0.0
mnli : 0.7104353132628153
mrpc : 0.7597163865546218
qnli : 0.800109829763866
qqp : 0.8294449225786891
rte : 0.5992779783393501
sst2 : 0.8268348623853211
stsb : -0.055277708660460716
wnli : 0.5352112676056338
bert_4_128: 0.5561947613144264


In [14]:
# Without hyperparameter tuning: print each task's score, then the average.
return_glue_score(model_name='fnet_mini', disp_task=True)

cola : 0.0
mnli : 0.5948942229454841
mrpc : 0.7303247822160068
qnli : 0.5118066996155958
qqp : 0.45421567284928477
rte : 0.48014440433212996
sst2 : 0.5103211009174312
stsb : 0.13649436765700407
wnli : 0.43661971830985913
fnet_mini: 0.4283134409825328


In [15]:
# Without hyperparameter tuning: print each task's score, then the average.
return_glue_score(model_name='fnet_tiny', disp_task=True)

cola : -0.0463559874942472
mnli : 0.544446704637917
mrpc : 0.15808823529411764
qnli : 0.4946000366099213
qqp : 0.31613024326967304
rte : 0.5306859205776173
sst2 : 0.49311926605504586
stsb : 0.020093730118187596
wnli : 0.5633802816901409
fnet_tiny: 0.3415764923064859


In [16]:
# Without hyperparameter tuning: print each task's score, then the average.
return_glue_score(model_name='fnet_2_256', disp_task=True)

cola : -0.006434621036303265
mnli : 0.5667209113100081
mrpc : 0.73670814479638
qnli : 0.5043016657514187
qqp : 0.4485801998879112
rte : 0.49097472924187724
sst2 : 0.48394495412844035
stsb : 0.032333552656536146
wnli : 0.5633802816901409
fnet_2_256: 0.42450109093626776


In [17]:
# Without hyperparameter tuning: print each task's score, then the average.
return_glue_score(model_name='fnet_4_128', disp_task=True)

cola : 0.0
mnli : 0.5693653376729048
mrpc : 0.7480253018237863
qnli : 0.5053999633900788
qqp : 0.4531960153447163
rte : 0.4729241877256318
sst2 : 0.5091743119266054
stsb : 0.015041115092908158
wnli : 0.43661971830985913
fnet_4_128: 0.41219399458738787


In [18]:
# Without hyperparameter tuning: print each task's score, then the average.
return_glue_score(model_name='convbert_mini', disp_task=True)

cola : -0.008978697532733301
mnli : 0.6870423108218063
mrpc : 0.3139309616061994
qnli : 0.4935017389712612
qqp : 0.38892708291421213
rte : 0.48375451263537905
sst2 : 0.5
stsb : 0.6851238603703937
wnli : 0.5492957746478874
convbert_mini: 0.4547330604927117


In [12]:
# Without hyperparameter tuning: print each task's score, then the average.
# NOTE(review): the recorded output of this cell lists per-metric keys
# (mrpc_accuracy, qqp_f1, stsb_spearman, ...), which the per-task print in
# return_glue_score does not produce, so the output is stale; re-run to refresh.
return_glue_score(model_name='convbert_tiny', disp_task=True)

cola : 0.056663872782814644
mnli : 0.5702807160292921
mrpc_accuracy : 0.3088235294117647
mrpc_f1 : 0.11320754716981131
qnli : 0.4870950027457441
qqp_accuracy : 0.6025228790502103
qqp_f1 : 0.1266304347826087
rte : 0.5054151624548736
sst2 : 0.4873853211009174
stsb_spearman : -0.1850780853119452
stsb_pearson : -0.14830445080657897
wnli : 0.5492957746478874
convbert_tiny: 0.28949480867145


In [19]:
# Without hyperparameter tuning: print each task's score, then the average.
return_glue_score(model_name='convbert_2_256', disp_task=True)

cola : -0.025881504697334985
mnli : 0.6651749389747762
mrpc : 0.5907928388746804
qnli : 0.5052169137836353
qqp : 0.4380322860627459
rte : 0.48736462093862815
sst2 : 0.5160550458715596
stsb : 0.13629454636087635
wnli : 0.4507042253521127
convbert_2_256: 0.4181948790579644


In [20]:
# Without hyperparameter tuning: print each task's score, then the average.
return_glue_score(model_name='convbert_4_128', disp_task=True)

cola : -0.0220169981008474
mnli : 0.6581570382424735
mrpc : 0.36038011695906436
qnli : 0.4905729452681677
qqp : 0.3364265034347715
rte : 0.5018050541516246
sst2 : 0.4782110091743119
stsb : -0.13267689840577668
wnli : 0.5492957746478874
convbert_4_128: 0.3577949494857419


In [13]:
# With hyperparameter tuning: print each task's score, then the average.
return_glue_score(model_name='convbert_mini', disp_task=True)

cola : 0.05512328662084196
mnli : 0.5419039869812856
mrpc : 0.3104586735544076
qnli : 0.4865458539264141
qqp : 0.3642803512469725
rte : 0.51985559566787
sst2 : 0.47591743119266056
stsb : -0.02840711839242723
wnli : 0.5352112676056338
convbert_mini: 0.3623210364892955


In [6]:
# Per-task scores, then the average (presumably after hyperparameter tuning — TODO confirm).
return_glue_score(model_name='convbert_tiny', disp_task=True)

cola : -0.02907606795597719
mnli : 0.3496745321399512
mrpc : 0.6614357152256876
qnli : 0.5074135090609555
qqp : 0.46152933495368514
rte : 0.48736462093862815
sst2 : 0.5022935779816514
stsb : 0.04047439836643846
wnli : 0.5492957746478874
convbert_tiny: 0.3922672661509898


In [7]:
# Per-task scores, then the average (presumably after hyperparameter tuning — TODO confirm).
return_glue_score(model_name='convbert_2_256', disp_task=True)

cola : -0.027803716697716103
mnli : 0.5003051261187957
mrpc : 0.5117478025693036
qnli : 0.504484715357862
qqp : 0.4647691851157587
rte : 0.4693140794223827
sst2 : 0.5229357798165137
stsb : -0.0695342099356992
wnli : 0.6197183098591549
convbert_2_256: 0.38843745240292843


In [8]:
# Per-task scores, then the average (presumably after hyperparameter tuning — TODO confirm).
return_glue_score(model_name='convbert_4_128', disp_task=True)

cola : 0.019093405342267576
mnli : 0.3715419039869813
mrpc : 0.6078102529025922
qnli : 0.49990847519677833
qqp : 0.42676803231659954
rte : 0.5415162454873647
sst2 : 0.5137614678899083
stsb : 0.7688935669943996
wnli : 0.4788732394366197
convbert_4_128: 0.46979628772816795


In [9]:
# Per-task scores, then the average (presumably after hyperparameter tuning — TODO confirm).
return_glue_score(model_name='fnet_mini', disp_task=True)

cola : 0.0
mnli : 0.43185516680227826
mrpc : 0.15808823529411764
qnli : 0.4946000366099213
qqp : 0.3159163987138264
rte : 0.5270758122743683
sst2 : 0.4908256880733945
stsb : 0.2040465316381462
wnli : 0.5633802816901409
fnet_mini: 0.3539764612329104


In [10]:
# Per-task scores, then the average (presumably after hyperparameter tuning — TODO confirm).
return_glue_score(model_name='fnet_tiny', disp_task=True)

cola : -0.01514459092315947
mnli : 0.37072823433685925
mrpc : 0.16640592407175636
qnli : 0.48929159802306427
qqp : 0.3561744956335734
rte : 0.5234657039711191
sst2 : 0.4908256880733945
stsb : 0.02360055513750737
wnli : 0.5633802816901409
fnet_tiny: 0.32985865444602847


In [11]:
# Per-task scores, then the average (presumably after hyperparameter tuning — TODO confirm).
return_glue_score(model_name='fnet_2_256', disp_task=True)

cola : -0.0008379276266605604
mnli : 0.5350895036615134
mrpc : 0.6247725894481504
qnli : 0.5035694673256452
qqp : 0.4711049467149334
rte : 0.4296028880866426
sst2 : 0.49770642201834864
stsb : 0.14049262187074613
wnli : 0.5211267605633803
fnet_2_256: 0.4136252524514111


In [12]:
# Per-task scores, then the average (presumably after hyperparameter tuning — TODO confirm).
return_glue_score(model_name='fnet_4_128', disp_task=True)

cola : 0.0
mnli : 0.5165785191212368
mrpc : 0.15808823529411764
qnli : 0.4946000366099213
qqp : 0.3159163987138264
rte : 0.5270758122743683
sst2 : 0.4908256880733945
stsb : -0.031375285964505784
wnli : 0.5633802816901409
fnet_4_128: 0.33723218731249993


In [14]:
# Per-task scores, then the average (presumably after hyperparameter tuning — TODO confirm).
return_glue_score(model_name='bert_mini', disp_task=True)

cola : 0.12490292474785736
mnli : 0.519426362896664
mrpc : 0.7822712418300655
qnli : 0.4946000366099213
qqp : 0.8539259771895291
rte : 0.4729241877256318
sst2 : 0.8279816513761468
stsb : 0.7857150097449594
wnli : 0.4225352112676056
bert_mini: 0.5871425114875979


In [15]:
# Per-task scores, then the average (presumably after hyperparameter tuning — TODO confirm).
return_glue_score(model_name='bert_tiny', disp_task=True)

cola : 0.037467036153535895
mnli : 0.7117575264442636
mrpc : 0.7437758551104833
qnli : 0.7843675636097382
qqp : 0.8396133232315167
rte : 0.4729241877256318
sst2 : 0.5091743119266054
stsb : 0.282906695531303
wnli : 0.49295774647887325
bert_tiny: 0.5416604718013279


In [16]:
# Per-task scores, then the average (presumably after hyperparameter tuning — TODO confirm).
return_glue_score(model_name='bert_4_128', disp_task=True)

cola : -0.016334423518002312
mnli : 0.7248779495524816
mrpc : 0.6902456337726348
qnli : 0.7682591982427238
qqp : 0.3159163987138264
rte : 0.5848375451263538
sst2 : 0.7362385321100917
stsb : 0.14983080761493864
wnli : 0.352112676056338
bert_4_128: 0.47844270196348737


In [17]:
# Per-task scores, then the average (presumably after hyperparameter tuning — TODO confirm).
return_glue_score(model_name='bert_2_256', disp_task=True)

cola : 0.09880168479861816
mnli : 0.5763832384052074
mrpc : 0.7480253018237863
qnli : 0.7980962840929892
qqp : 0.3159163987138264
rte : 0.6209386281588448
sst2 : 0.5091743119266054
stsb : 0.8308536686774791
wnli : 0.4084507042253521
bert_2_256: 0.5451822467580788
