## Using BERT and XLNet to do document classification
This notebook uses the transformers library and, on top of that, simpletransformers.

The first part loads the arxiv science docs data and formats them into the pandas tables needed to do the training and testing.

After that we do the classification with BERT.  While it loads the small language model automatically, you will need to train the network first with the arxiv data.
After it has been trained you can reload the model if you want to play with it later.

The next part does the classification with XLNet.   The process is identical to the training and testing step with BERT.

In [None]:
#pip install simpletransformers

In [None]:
from simpletransformers.classification import ClassificationModel
import pandas as pd

In [None]:
import numpy as np
import pickle
import json

### Load the data.   
You can find the data in the same GitHub repo where you found this notebook.   It is not very big.   This data loader is not at all general: it is specific to this collection, so don't bother trying to understand it.

In [1]:
def load_docs(path, name):
    """Load the pickled document collection stored at ``path + name + ".p"``.

    The pickle is expected to hold a sequence of records that are read as
    ``record[0]`` (title), ``record[1]`` (site name) and ``record[2]``
    (abstract).

    Parameters
    ----------
    path : str
        Directory prefix, including the trailing separator (it is
        concatenated directly with ``name``).
    name : str
        File name without the ``.p`` extension.

    Returns
    -------
    tuple of (list, list, list)
        ``(abstracts, sitenames, titles)`` as three parallel lists.
    """
    filename = path+name+".p"
    # NOTE(review): unpickling can execute arbitrary code -- only load files
    # you trust (here: the data file shipped alongside the notebook).
    # The context manager guarantees the handle is closed, which the
    # previous bare open() never did.
    with open(filename, "rb") as fileobj:
        lst = pickle.load(fileobj)

    titles = [record[0] for record in lst]
    sitenames = [record[1] for record in lst]
    abstracts = [record[2] for record in lst]

    print("done loading "+filename)
    return abstracts, sitenames, titles

In [None]:
abstracts, sitenames, disp_title = load_docs("./sci_doc/", 
                                     "sciml_data_arxiv")

In [None]:
len(abstracts)

In [None]:
def read_config(subj, basepath):
    """Read ``<basepath>/config_<subj>.json`` and return its main fields.

    Returns a ``(subject, loadset, subtopics)`` tuple, where ``subtopics`` is
    a list of ``(entry[0], entry[1])`` tuples built from the ``supertopics``
    entries of the JSON document.
    """
    config_file = basepath + "/config_" + subj + ".json"
    with open(config_file, 'rb') as handle:
        raw = handle.read()
    config = json.loads(raw)
    subject = config['subject']
    loadset = config['loadset']
    pairs = [(entry[0], entry[1]) for entry in config['supertopics']]
    return subject, loadset, pairs


In [None]:
def make_dict(subtopics):
    """Invert ``(supertopic, [tags])`` pairs into a ``tag -> supertopic`` map.

    When a tag is listed under several supertopics, the last one wins.
    """
    return {tag: entry[0] for entry in subtopics for tag in entry[1]}

def split_titles(disp_title):
    """Turn display titles into ``[title, label, tag]`` rows.

    Each display title is expected to look like ``"<title> [<tag>] ..."``.
    The tag between the first ``[`` and the following ``]`` is mapped to its
    supertopic through the "all4" config in ``./sci_doc`` and then to an
    integer label: compsci=0, math=1, Physics=2, bio=3, finance=4.

    Titles without a ``[`` are skipped silently.  Tags that are missing from
    the config are printed and skipped.  Tags whose supertopic is not one of
    the five above are skipped silently.

    The skip rules are deliberately the same as in ``split_text``: the
    notebook indexes the outputs of both functions with the same position.

    Returns
    -------
    list of [str, int, str]
        One ``[title_prefix, label, tag]`` row per accepted title.
    """
    subject,loadset, subtopics = read_config("all4","./sci_doc")
    dic = make_dict(subtopics)
    # Supertopic name (spelled as in the config) -> integer class label.
    label_of = {'compsci': 0, 'math': 1, 'Physics': 2, 'bio': 3, 'finance': 4}
    lis = []
    for ti in disp_title:
        l = ti.find('[')
        if l < 0:
            continue
        e = ti[l+1:]
        # NOTE: when there is no ']' find() returns -1, so the last character
        # is dropped; kept as-is so filtering stays aligned with split_text.
        e = e[:e.find(']')]
        try:
            supertopic = dic[e]
        except KeyError:
            # Unknown tag: report it and move on (only the lookup can fail,
            # so nothing else is swallowed any more).
            print(e)
            continue
        if supertopic in label_of:
            lis.append([ti[:l], label_of[supertopic], e])

    return lis

In [None]:
subject,loadset, subtopics = read_config("all4","./sci_doc")
dic = make_dict(subtopics)
len(dic)

In [None]:
def split_text(disp_title, abstracts):
    """Turn abstracts into ``[abstract, label]`` rows, labelled via the title.

    ``disp_title[i]`` and ``abstracts[i]`` describe the same document.  The
    tag between the first ``[`` and the following ``]`` of the display title
    is mapped to its supertopic through the "all4" config in ``./sci_doc``
    and then to an integer label: compsci=0, math=1, Physics=2, bio=3,
    finance=4.

    Titles without a ``[`` are skipped silently.  Tags that are missing from
    the config are printed and skipped.  Tags whose supertopic is not one of
    the five above are skipped silently.

    The skip rules are deliberately the same as in ``split_titles``: the
    notebook indexes the outputs of both functions with the same position.

    Returns
    -------
    list of [str, int]
        One ``[abstract, label]`` row per accepted document.
    """
    subject,loadset, subtopics = read_config("all4","./sci_doc")
    dic = make_dict(subtopics)
    # Supertopic name (spelled as in the config) -> integer class label.
    label_of = {'compsci': 0, 'math': 1, 'Physics': 2, 'bio': 3, 'finance': 4}
    lis = []
    for ind, ti in enumerate(disp_title):
        te = abstracts[ind]
        l = ti.find('[')
        if l < 0:
            continue
        e = ti[l+1:]
        # NOTE: when there is no ']' find() returns -1, so the last character
        # is dropped; kept as-is so filtering stays aligned with split_titles.
        e = e[:e.find(']')]
        try:
            supertopic = dic[e]
        except KeyError:
            # Unknown tag: report it and move on (only the lookup can fail,
            # so nothing else is swallowed any more).
            print(e)
            continue
        if supertopic in label_of:
            lis.append([te, label_of[supertopic]])

    return lis

In [None]:
# The first 4500 documents form the training set; everything after index 4500
# is held out for evaluation.  The *_titles and *_text lists are built from the
# same slices, so row i of one corresponds to row i of the other.
train_titles = split_titles(disp_title[0:4500]) # rows of [title, label, tag]
eval_titles  = split_titles(disp_title[4500:])
train_text = split_text(disp_title[0:4500], abstracts[0:4500])
eval_text = split_text(disp_title[4500:],abstracts[4500:])  # rows of [abstract, label]

In [None]:
print(eval_titles[0])
print(eval_text[0])

In [None]:
# Train and evaluation data must be pandas DataFrames with at least two columns.
# If the DataFrame has a header, it should contain a 'text' and a 'labels' column.
# If no header is present, the first column is the text (type str)
# and the second column is the label (type int), for example:
#   [['Example sentence belonging to class 1', 1],
#    ['Example sentence belonging to class 0', 0],
#    ['Example eval sentence belonging to class 2', 2]]

# `train_data` was only ever defined in the commented-out example, so the old
# `pd.DataFrame(train_data)` raised NameError on a fresh kernel.  Build the
# frame from train_titles, mirroring eval_df below.
train_df = pd.DataFrame(train_titles)
eval_df = pd.DataFrame(eval_titles)

# Abstract-based frames: these are the ones used for training and evaluation.
text_df = pd.DataFrame(train_text)
text_eval_df = pd.DataFrame(eval_text)



### Do BERT classification.  
The first time you do this you need to run the trainer.   After that is done, move the output to a separate directory "outputs-bert-originalall4".   Once you have done that, you can later reload the trained model as shown.

In [None]:
# Either load the previously trained model (the directory name must match the
# one you moved the training output to) ...
#model = ClassificationModel('bert',  'outputs-bert-originalall4/', num_labels=5, use_cuda=False)

# ... or create and train it.  num_labels=5 matches the five class labels (0-4)
# produced by split_text.
model = ClassificationModel('bert', 'bert-base-cased', num_labels=5, 
                            args={'reprocess_input_data': True, 
                                  'overwrite_output_dir': True}, use_cuda=False)

model.train_model(text_df)
# This takes a while (training runs with use_cuda=False).

In [None]:
results = ''
model_outputs = []
wrong_predictions = []
print(len(text_eval_df))

In [None]:
# Evaluate the model
result, model_outputs, wrong_predictions = model.eval_model(text_eval_df[0:2600])


In [None]:
# Divide by the number of rows that were actually evaluated: the slice
# [0:2600] is shorter than 2600 whenever the evaluation set is smaller.
n_evaluated = len(text_eval_df[0:2600])
print("fraction wrong = ", len(wrong_predictions)/n_evaluated)

In [None]:
category = {0: 'cs  ', 1: 'math', 2: 'phys', 3: 'bio', 4: 'fina'}

The following cell just selects the first few incorrect predictions and prints them out.

In [None]:
# Print the first misclassified example found for each true class.
# need_sample[k] stays True until one example with true label k has been shown.
need_sample = [True, True, True, True, True]
for wrong in wrong_predictions:
    if not any(need_sample):
        break  # every class already has its example
    if need_sample[wrong.label]:
        print('item',wrong.guid, 
              'is ',category[wrong.label], ' but predicted to be ',
              category[model_outputs[wrong.guid].argmax()])
        # guid is used as the row position in the evaluation lists
        print(eval_titles[wrong.guid])
        print(eval_text[wrong.guid])
        print('------------------------------------------------------------')
        need_sample[wrong.label] = False
 

The following function prints out the confusion matrix.   For row X it prints the percent of predictions it makes in each category.  In other words if an X abstract is predicted to be a Y, then a 1 is added to row X, column Y.

In [None]:
def show_confusion(data, model_outputs):
    """Build a row-normalised confusion matrix (in percent) and print accuracy.

    Parameters
    ----------
    data : sequence
        Rows whose element ``[1]`` is the true integer label (0-4).
    model_outputs : sequence
        One score vector per row of ``data``; the predicted label is the
        position of its maximum.

    Returns
    -------
    pandas.DataFrame
        Column ``' '`` holds the true-class names; the remaining columns hold,
        for each true class (row), the rounded percentage of its samples that
        were predicted as that column's class.  A true class with no samples
        yields a row of zeros rather than NaN.

    Side effects
    ------------
    Prints ``accuracy = <correct / len(data)>``.
    """
    labels = ['compsci', 'math', 'physics', 'bio', 'finance']
    n = len(labels)

    # counts[true, predicted]
    mat = np.zeros([n, n])
    for row, scores in zip(data, model_outputs):
        mat[row[1], np.argmax(scores)] += 1

    truevals = np.trace(mat)

    # Normalise each row to percentages; skip rows that sum to zero so that
    # an absent class does not produce 0/0 = NaN (and a RuntimeWarning).
    row_sums = mat.sum(axis=1, keepdims=True)
    pct = np.round(np.divide(100*mat, row_sums,
                             out=np.zeros_like(mat), where=row_sums > 0))

    pds = {' ': labels}
    for col, label in enumerate(labels):
        pds[label] = pct[:, col]

    accuracy = truevals/len(data) if len(data) else float('nan')
    print("accuracy =", accuracy)
    return pd.DataFrame(pds)

In [None]:
print(show_confusion(eval_text[0:2600], model_outputs))

#### XLNET
now we will try  to  do the classification using xlnet


In [None]:
# As with BERT, either reload a previously trained model (the directory name
# must match wherever you moved the training output) ...
#modelxl = ClassificationModel('xlnet',  'outputs-xlnet-origianall4/', num_labels=5, use_cuda=False)
# ... or train it from the pretrained 'xlnet-base-cased' weights.
modelxl = ClassificationModel('xlnet', 'xlnet-base-cased', num_labels=5, 
                            args={'reprocess_input_data': True, 
                                 'overwrite_output_dir': True}, use_cuda=False)
modelxl.train_model(text_df)

In [None]:
resultxl, model_outputslx, wrong_predictionslx = modelxl.eval_model(text_eval_df[0:2600])

In [None]:
print(show_confusion(eval_text[0:2600], model_outputslx))

#### compute the "best of 2"
This function takes the two most highly rated signals, and if one of them is correct then the prediction is considered correct.  Of course, this is a rather ad hoc measure, but still interesting.  See the analysis in the paper: https://esciencegroup.com/2020/02/20/modeling-natural-language-with-transformers-bert-roberta-and-xlnet/


In [None]:
def bestof2(model_outputs, text, threshold=0.25):
    """Confusion matrix where either of the two top-scoring classes may count.

    For every sample the best and second-best scoring classes are found.  If
    the best class is wrong but the runner-up is the true class *and* its
    score exceeds ``threshold`` times the best score, the sample is counted
    as correct.

    The earlier version compared the class *indices* (``u > 0.25*v``) instead
    of the scores, and masked the winner with ``-3.14``, which is not
    guaranteed to be below every real score.  Both are fixed here.
    NOTE(review): the score-ratio reading of the 0.25 factor is the presumed
    intent -- confirm against the linked analysis.

    Parameters
    ----------
    model_outputs : sequence
        One score vector (at least 5 entries; the first 5 are used) per sample.
    text : sequence
        Rows whose element ``[1]`` is the true integer label (0-4).
    threshold : float, optional
        Minimum ratio of runner-up score to best score for the runner-up to
        be considered.  Defaults to 0.25.

    Returns
    -------
    pandas.DataFrame
        Column ``' '`` holds the true-class names; the remaining columns hold
        the rounded row percentages.  A true class with no samples yields a
        row of zeros rather than NaN.

    Side effects
    ------------
    Prints ``accuracy = <correct / len(text)>``.
    """
    labels = ['compsci', 'math', 'physics', 'bio', 'finance']
    n = len(labels)

    # counts[true, credited prediction]
    mat = np.zeros([n, n])
    for i in range(len(model_outputs)):
        # np.array copies, so masking below never touches the caller's data
        out = np.array(model_outputs[i], dtype=float)[:n]
        v = out.argmax()
        best_score = out[v]
        out[v] = -np.inf           # mask the winner to expose the runner-up
        u = out.argmax()
        second = u if out[u] > threshold*best_score else -1
        tru = text[i][1]

        if (v != tru) and (second == tru):
            v = second
        mat[tru, v] += 1

    truevals = np.trace(mat)

    # Normalise each row to percentages; skip rows that sum to zero so that
    # an absent class does not produce 0/0 = NaN (and a RuntimeWarning).
    row_sums = mat.sum(axis=1, keepdims=True)
    pct = np.round(np.divide(100*mat, row_sums,
                             out=np.zeros_like(mat), where=row_sums > 0))

    pds = {' ': labels}
    for col, label in enumerate(labels):
        pds[label] = pct[:, col]

    accuracy = truevals/len(text) if len(text) else float('nan')
    print("accuracy =", accuracy)
    return pd.DataFrame(pds)

In [None]:
print(bestof2(model_outputs,  eval_text[0:2600]))