# Training classifiers on BERT embeddings

## Installing libraries

In [None]:
# Install the libraries used below. Versions are not pinned; the environment that
# produced the saved outputs reported pandas 1.1.3, numpy 1.18.5, keras 2.4.3,
# sklearn 0.22.2.post1 and transformers 3.4.0 (see the version cell further down).
!pip install transformers pandas numpy matplotlib tqdm keras
# NOTE(review): `--install-option` may be rejected by recent pip releases — confirm the
# pip version in use, or build LightGBM's GPU variant by another route.
!pip install lightgbm --install-option=--gpu


### Imports

In [None]:
# general libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm

import random

# Seed Python's and NumPy's global random generators so that any code relying on
# them is repeatable across runs (estimators below also get explicit random_state).
random.seed(42)
np.random.seed(42)

In [None]:
import sklearn
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import cross_val_score, cross_validate

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import MinMaxScaler
from sklearn import metrics

from lightgbm import LGBMClassifier

In [None]:
import warnings
# Ignore every warning category for the rest of the session. NOTE(review): this would
# also hide any convergence / undefined-metric warnings raised during training below.
warnings.filterwarnings("ignore")

In [None]:
# Record the library versions so results can be tied to the environment that produced them.
# transformers and keras were installed above but never imported, so a fresh kernel raised
# NameError here; import them locally (keras is reported in the saved output below).
import keras
import transformers

for l, v in {'pandas': pd, 'numpy': np, 'keras': keras, 'sklearn': sklearn, 'transformers': transformers}.items():
    print(f'{l}  version {v.__version__}')

pandas  version 1.1.3
numpy  version 1.18.5
keras  version 2.4.3
sklearn  version 0.22.2.post1
transformers  version 3.4.0


## Reading data from Google Drive

In [None]:
# Mount Google Drive (Colab only) so the data under `baseDir` can be read and outputs written back.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
# Project root on the mounted Drive: inputs are read from `<baseDir>/data` and
# outputs are written to `<baseDir>/BERT/output_bert`.
baseDir = '/content/gdrive/My Drive/Colab Notebooks/AA'

In [None]:
# Load all documents with their precomputed BERT vectors.
# NOTE(review): the file is named *.zip but decoded as gzip — confirm the real container format.
datasets = pd.read_json(
    baseDir + '/data/AllDS_BERT.json.zip',
    orient='records',
    compression='gzip',
)
# row_index is a surrogate key used later to join model predictions back to these rows.
datasets = datasets.assign(row_index=np.arange(len(datasets)))[
    ['row_index', 'dataset', 'problem', 'language', 'set', 'filename', 'text', 'label', 'BERT_vector']
]


*   **row_index**: auxiliary field used to merge the models' results back into this table.
*   **dataset**: group of documents belonging to the same scenario (social, literature, lyrics, etc.).
*   **problem**: a specific test case, e.g. 20 authors with short texts or 5 authors with long texts.
*   **language**: the language of the texts (e.g. `en`).
*   **set**: `known` is the development set and `unknown` is the validation set. Note this is not a train/test split of our own: any further splitting (e.g. cross-validation folds) is done only within the development set.
*   **filename**: the original filename in the corpus.
*   **label**: the target (the candidate author, e.g. `candidate00001`).
*   **BERT_vector**: the BERT embedding vector of the text.



In [None]:
# Peek at the first rows of the combined table.
datasets.head()

Unnamed: 0,row_index,dataset,problem,language,set,filename,text,label,BERT_vector
0,0,pan18_train,problem00001,en,known,known00001.txt,"graceful ones.\n\n""One more,"" Marvelous said, ...",candidate00001,"[0.4001088142, -0.9896771312, -0.2849564552, -..."
1,1,pan18_train,problem00001,en,known,known00002.txt,"before. If he can, he’ll remember a classmate ...",candidate00001,"[-0.1036661863, -0.2112793922, -0.1762309372, ..."
2,2,pan18_train,problem00001,en,known,known00003.txt,she thought - he was in Team Baron only becaus...,candidate00001,"[0.1540681422, -0.3122391403, -0.0233840719, -..."
3,3,pan18_train,problem00001,en,known,known00007.txt,"As far as she remembers, she's always hated pr...",candidate00001,"[-0.013924662, -0.4036496282, -0.3019542694000..."
4,4,pan18_train,problem00001,en,known,known00006.txt,"“Wait for me, please!”\n\nShe glanced towards ...",candidate00001,"[-0.1647676378, -0.2584642172, 0.1031820625, 0..."


In [None]:
# Number of distinct problems in each dataset, shown as a single row.
datasets.groupby('dataset')['problem'].nunique().to_frame().T

dataset,lyrics,pan18_eval,pan18_train,socialaa
problem,10,20,10,32


## Training the models with k-fold cross-validation

In [None]:
# Candidate classifiers, keyed by short name. Every model is preceded by a MinMaxScaler
# so all of them receive the BERT features rescaled to [0, 1].
clfs = {
    name: make_pipeline(MinMaxScaler(), estimator)
    for name, estimator in (
        ('LR', LogisticRegression(max_iter=1000, C=0.1, random_state=42)),
        ('LGBM', LGBMClassifier(max_depth=7, n_estimators=25, num_leaves=32)),
        ('MLP', MLPClassifier(hidden_layer_sizes=[15, 15], max_iter=700, random_state=42, early_stopping=True)),
    )
}

In [None]:
def filter_dataset(df, dataset, problem):
    """Return the train/test arrays of one (dataset, problem) pair.

    NOTE(review): this helper is called below but was not defined anywhere in the
    notebook, so a fresh kernel raised NameError. It is reconstructed from the call
    site and the column descriptions above (``set == 'known'`` is the development
    data, ``set == 'unknown'`` the validation data) -- confirm it matches the helper
    originally used to produce the saved results.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with ``dataset``, ``problem``, ``set``, ``label``, ``row_index`` and
        ``BERT_vector`` columns.
    dataset, problem : str
        Values selecting the rows of one authorship-attribution problem.

    Returns
    -------
    tuple
        ``(X_train, y_train, index_train, X_test, y_test, index_test)``. ``X_*`` stack
        the ``BERT_vector`` lists into 2-D arrays, ``y_*`` hold ``label`` and
        ``index_*`` hold ``row_index`` (used to join predictions back to ``df``).
    """
    rows = df[(df['dataset'] == dataset) & (df['problem'] == problem)]

    def _arrays(part):
        # one embedding per row -> (n_rows, embedding_size) matrix
        X = np.vstack(part['BERT_vector'].tolist())
        return X, part['label'].to_numpy(), part['row_index'].to_numpy()

    X_train, y_train, index_train = _arrays(rows[rows['set'] == 'known'])
    X_test, y_test, index_test = _arrays(rows[rows['set'] == 'unknown'])
    return X_train, y_train, index_train, X_test, y_test, index_test


def append_predictions(records, meta, classes, index, pred, proba):
    """Append one prediction record per instance to ``records`` (mutates it in place).

    Each record holds ``meta`` (dataset/problem/model), the instance ``row_index``,
    the predicted label ``pred`` and one probability column per entry of ``classes``,
    which must be in the same order as the columns of ``proba``.
    Dict unpacking (rather than ``dict(**...)``) also accepts non-string class labels.
    """
    for row_index, label, row_proba in zip(index, pred, proba):
        records.append({**meta, 'row_index': row_index, 'pred': label, **dict(zip(classes, row_proba))})


best_estimators = []
predictions = []
classification_reports = []

# for each dataset, for each problem, for each classifier: run cross validation
for dataset in datasets['dataset'].unique():
    for problem in datasets.loc[datasets['dataset'] == dataset, 'problem'].unique():
        print('\n', dataset, problem, ' ---- ', end=' ')

        X_train, y_train, index_train, X_test, y_test, index_test = filter_dataset(datasets, dataset, problem)

        for c, estimator in clfs.items():
            print(c, end=' ')
            meta = {'dataset': dataset, 'problem': problem, 'model': c}

            # param_grid={} means no hyper-parameter search: GridSearchCV only runs the
            # 5-fold CV scored by macro F1 and refits the pipeline on all known texts.
            grid_search = GridSearchCV(estimator=estimator, param_grid={}, scoring='f1_macro', cv=5)
            grid_search.fit(X_train, y_train)
            clf = grid_search.best_estimator_

            pred_train = clf.predict(X_train)
            pred_test = clf.predict(X_test)

            # saving predictions (known and unknown texts) for future analysis
            append_predictions(predictions, meta, clf.classes_, index_train, pred_train, clf.predict_proba(X_train))
            append_predictions(predictions, meta, clf.classes_, index_test, pred_test, clf.predict_proba(X_test))

            # saving estimators
            best_estimators.append({**meta, 'estimator': clf})

            # official score: classification report on the unknown (validation) texts
            class_report = metrics.classification_report(y_test, pred_test, output_dict=True)
            classification_reports.append({**meta, 'classification_report': class_report, **class_report['macro avg']})
        print()


 pan18_train problem00001  ----  LR LGBM MLP 

 pan18_train problem00002  ----  LR LGBM MLP 

 pan18_train problem00003  ----  LR LGBM MLP 

 pan18_train problem00004  ----  LR LGBM MLP 

 pan18_train problem00005  ----  LR LGBM MLP 

 pan18_train problem00006  ----  LR LGBM MLP 

 pan18_train problem00007  ----  LR LGBM MLP 

 pan18_train problem00008  ----  LR LGBM MLP 

 pan18_train problem00009  ----  LR LGBM MLP 

 pan18_train problem00010  ----  LR LGBM MLP 

 pan18_eval problem00001  ----  LR LGBM MLP 

 pan18_eval problem00002  ----  LR LGBM MLP 

 pan18_eval problem00003  ----  LR LGBM MLP 

 pan18_eval problem00004  ----  LR LGBM MLP 

 pan18_eval problem00005  ----  LR LGBM MLP 

 pan18_eval problem00006  ----  LR LGBM MLP 

 pan18_eval problem00007  ----  LR LGBM MLP 

 pan18_eval problem00008  ----  LR LGBM MLP 

 pan18_eval problem00009  ----  LR LGBM MLP 

 pan18_eval problem00010  ----  LR LGBM MLP 

 pan18_eval problem00011  ----  LR LGBM MLP 

 pan18_eval problem0001

## Saving the classification results for future analysis

In [None]:
# saving estimators: one refitted pipeline per (dataset, problem, model), as a bz2-compressed pickle
import bz2
import pickle

with bz2.open(baseDir + '/BERT/output_bert/bert_best_estimators.pkl.bz2', mode='wb') as estimators_file:
    pickle.dump(best_estimators, estimators_file)

In [None]:
# Export every prediction (known and unknown texts, all models) to a zipped CSV.
# kind='mergesort' is a stable sort: rows sharing a row_index keep the order in which the
# models were trained (LR, LGBM, MLP) instead of being interleaved arbitrarily by quicksort.
pd.DataFrame(predictions) \
    .sort_values('row_index', kind='mergesort').round(5) \
    .to_csv(baseDir + '/BERT/output_bert/bert_predictions.csv.zip', index=False, compression='zip', encoding='utf-8')

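### Predictions

* The predictions file contains one row per dataset | problem | instance and model (LR, MLP or LGBM), for both the known and the unknown texts.
* **model** identifies the classifier pipeline.
* **pred** is the label with the highest predicted probability; the `candidateXXXXX` columns hold the per-class probabilities.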

In [None]:
# Same view as the exported CSV: stable sort so each row_index lists its models in training order.
pd.DataFrame(predictions).round(5).sort_values('row_index', kind='mergesort')

Unnamed: 0,dataset,problem,model,row_index,pred,candidate00001,candidate00002,candidate00003,candidate00004,candidate00005,candidate00006,candidate00007,candidate00008,candidate00009,candidate00010,candidate00011,candidate00012,candidate00013,candidate00014,candidate00015,candidate00016,candidate00017,candidate00018,candidate00019,candidate00020,candidate00021,candidate00022,candidate00023,candidate00024,candidate00025,candidate00026,candidate00027,candidate00028,candidate00029,candidate00030,candidate00031,candidate00032,candidate00033,candidate00034,candidate00035,candidate00036,candidate00037,candidate00038,candidate00039,candidate00040,candidate00041,candidate00042,candidate00043,candidate00044,candidate00045,candidate00046,candidate00047,candidate00048,candidate00049,candidate00050
0,pan18_train,problem00001,LR,0,candidate00001,0.31580,0.02728,0.01090,0.03243,0.04595,0.03042,0.04867,0.01975,0.04859,0.05737,0.03326,0.01372,0.02993,0.07244,0.02539,0.03529,0.03982,0.03997,0.03499,0.03804,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
245,pan18_train,problem00001,LGBM,0,candidate00001,0.48454,0.02085,0.01822,0.02566,0.01875,0.03488,0.02635,0.02891,0.01893,0.06259,0.01025,0.01134,0.01335,0.07360,0.02828,0.01434,0.04600,0.01904,0.02012,0.02400,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
490,pan18_train,problem00001,MLP,0,candidate00005,0.06086,0.03304,0.06282,0.03827,0.06884,0.04990,0.05381,0.03692,0.05603,0.04218,0.05754,0.05448,0.04543,0.05300,0.04507,0.06163,0.02969,0.03630,0.05645,0.05773,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
491,pan18_train,problem00001,MLP,1,candidate00005,0.06837,0.04095,0.05835,0.03500,0.08937,0.06533,0.04328,0.03291,0.06356,0.04717,0.05221,0.04062,0.04202,0.05834,0.04477,0.06412,0.02353,0.03138,0.05032,0.04840,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
1,pan18_train,problem00001,LR,1,candidate00001,0.34838,0.06337,0.01064,0.05981,0.04476,0.01041,0.02305,0.01504,0.04484,0.03527,0.04471,0.01362,0.05181,0.05797,0.03475,0.05394,0.02444,0.02972,0.01716,0.01631,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
92896,socialaa,problem00032,MLP,30964,candidate00031,0.00399,0.00785,0.00001,0.00001,0.00009,0.03388,0.01107,0.00142,0.00078,0.09011,0.00130,0.06559,0.00043,0.00239,0.00000,0.00012,0.11235,0.00003,0.06147,0.00449,0.00001,0.00000,0.00428,0.08226,0.00030,0.00348,0.01666,0.02823,0.01296,0.00038,0.11860,0.00073,0.04028,0.01851,0.00254,0.00298,0.00000,0.00000,0.07581,0.00239,0.00126,0.00618,0.00484,0.00231,0.00011,0.00035,0.10595,0.00006,0.07113,0.00002
87796,socialaa,problem00032,LR,30964,candidate00049,0.00265,0.00362,0.00089,0.01404,0.00422,0.00811,0.02147,0.00913,0.00230,0.02881,0.01582,0.06632,0.00078,0.01912,0.00449,0.00497,0.02752,0.00133,0.02165,0.00071,0.00031,0.00719,0.00034,0.02887,0.04554,0.00386,0.00462,0.14114,0.00809,0.00187,0.03274,0.00180,0.00398,0.00969,0.01540,0.01052,0.00131,0.00031,0.06494,0.02468,0.00353,0.00213,0.00823,0.00153,0.00704,0.00294,0.01739,0.00554,0.28495,0.00155
90347,socialaa,problem00032,LGBM,30965,candidate00044,0.00939,0.00938,0.00939,0.00939,0.01968,0.02922,0.01054,0.00942,0.00939,0.00939,0.01343,0.01143,0.00941,0.00937,0.00940,0.01715,0.01273,0.00940,0.00940,0.01146,0.00940,0.00938,0.01644,0.00944,0.01955,0.01133,0.01212,0.00939,0.02304,0.00938,0.01508,0.00936,0.01389,0.00956,0.00939,0.00940,0.00940,0.00940,0.00940,0.00938,0.00942,0.01426,0.01199,0.22971,0.05096,0.01314,0.00941,0.02584,0.00940,0.15378
87797,socialaa,problem00032,LR,30965,candidate00050,0.00098,0.00794,0.01147,0.00879,0.09830,0.03226,0.00502,0.00466,0.00199,0.00403,0.01611,0.02212,0.00128,0.00855,0.00331,0.06340,0.00762,0.00097,0.01096,0.02645,0.00087,0.01762,0.03306,0.00055,0.00410,0.01244,0.00910,0.00037,0.09638,0.00367,0.01272,0.04420,0.00978,0.00623,0.01206,0.00036,0.00038,0.00056,0.01441,0.01127,0.00153,0.06136,0.02056,0.04878,0.06154,0.00473,0.00635,0.04460,0.00220,0.12199


### Classification Reports

The classification reports contain the results on the test set (the unknown texts, i.e. the validation set), which are the official results.

In [None]:
# Persist the per-problem reports (scores on the unknown texts) as zipped JSON records;
# the nested 'classification_report' column keeps the per-class details.
reports_path = f'{baseDir}/BERT/output_bert/bert_classification_reports.json.zip'
pd.DataFrame(classification_reports).to_json(reports_path, orient='records', compression='zip')

In [None]:
import re
def statistics(x):
    """Summarise one group of rows (one dataset/problem/language) of ``datasets``.

    Returns a Series with:
      * ndocs    -- mean number of distinct known files per label, truncated to int;
      * nauthors -- number of distinct labels in the group;
      * nchar    -- mean character length of the unknown texts, rounded down to a multiple of 10;
      * nword    -- mean word count of the unknown texts (a text with no words counts as 1),
                    rounded up to a multiple of 5.
    """
    known = x[x['set'] == 'known']
    unknown_texts = x.loc[x['set'] == 'unknown', 'text']

    files_per_label = known.groupby('label')['filename'].nunique()
    docs = int(files_per_label.mean())

    mean_chars = int(unknown_texts.map(len).mean())

    def word_count(text):
        # count \w+ tokens, but never report fewer than one word per text
        return max(len(re.findall(r'\b\w+\b', text)), 1)

    mean_words = unknown_texts.map(word_count).mean()

    return pd.Series({
        'ndocs': docs,
        'nauthors': x['label'].nunique(dropna=False),
        'nchar': (mean_chars // 10) * 10,
        'nword': int(np.ceil(mean_words / 5) * 5),
    })

# One row of corpus statistics per (dataset, problem, language), used to slice the scores below.
metadata = (
    datasets
    .groupby(['dataset', 'problem', 'language'])
    .apply(statistics)
    .reset_index()
)

In [None]:
# Macro F1 on the unknown (validation) texts: one row per (dataset, problem), one column per model.
temp = (
    pd.DataFrame(classification_reports)
    .pivot_table(index=['dataset', 'problem'], columns='model', values='f1-score')
    [['LR', 'LGBM', 'MLP']]
)

# Attach each problem's corpus statistics (language, ndocs, nauthors, nchar, nword).
temp2 = temp.reset_index().merge(metadata)

with open(baseDir + '/BERT/output_bert/bert_report.txt', 'w', encoding='utf-8') as report_file:
    report_file.write(temp.round(2).to_latex())
    # One LaTeX table per corpus property: mean F1 of each model per (dataset, property value).
    # The model columns are selected before .mean() so the string 'problem' column is never
    # aggregated (newer pandas versions raise on it instead of silently dropping it).
    for prop in ['language', 'ndocs', 'nauthors', 'nchar', 'nword']:
        report_file.write("\n\n")
        by_prop = temp2.groupby(['dataset', prop])[['LR', 'LGBM', 'MLP']].mean().reset_index()
        report_file.write(by_prop.round(2).to_latex(index=False))


In [None]:
# Show the per-problem F1 table with up to 100 rows visible (instead of pandas' default truncation).
with pd.option_context("display.max_rows", 100):
    display(temp.round(4))

Unnamed: 0_level_0,model,LR,LGBM,MLP
dataset,problem,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
lyrics,problem00001,0.4302,0.367,0.1733
lyrics,problem00002,0.1585,0.1477,0.0292
lyrics,problem00003,0.1347,0.139,0.0361
lyrics,problem00004,0.1066,0.086,0.0106
lyrics,problem00005,0.0767,0.0877,0.006
lyrics,problem00006,0.424,0.3284,0.0718
lyrics,problem00007,0.2456,0.2635,0.0333
lyrics,problem00008,0.1905,0.2056,0.0091
lyrics,problem00009,0.1968,0.1771,0.0026
lyrics,problem00010,0.1118,0.1219,0.0157
