In [45]:
import os
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import random
import argparse
import logging

from sklearn.metrics import roc_auc_score, f1_score,average_precision_score
from sklearn.metrics import precision_recall_fscore_support 
from sklearn.metrics import roc_curve,precision_recall_curve
from sklearn.metrics import auc as auc_score

from datasets import load_dataset, load_metric, concatenate_datasets,DatasetDict,Dataset
from datasets import load_from_disk

import transformers
print("Transformers version is {}".format(transformers.__version__))

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from transformers import (
    AdamW,
    AutoConfig,
    AutoModel,
    AutoModelWithLMHead,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    default_data_collator,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    get_linear_schedule_with_warmup,
    get_scheduler
)

import utils

import seaborn as sns
from pylab import rcParams
from matplotlib import pyplot as plt
from matplotlib import rc

sns.set(style="whitegrid",palette='muted',font_scale=1.2)
# rcParams['figure.figsize']=16,10

%config InlineBackend.figure_format="retina"
%matplotlib inline

Transformers version is 4.6.1


In [46]:
def seed_everything(seed):
    """Seed every random-number source used in this notebook.

    Seeds Python's ``random`` module, NumPy, and PyTorch (the default
    generator and all CUDA devices), exports ``PYTHONHASHSEED`` to the
    environment, and puts cuDNN into deterministic mode with benchmarking
    switched off.

    Parameters
    ----------
    seed : int
        Value handed to each seeding function.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # The individual seeders are independent of one another.
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

In [47]:
if __name__ == "__main__":
    # Inference settings, declared as command-line style options.
    parser = argparse.ArgumentParser(description='Model Inference')
    parser.add_argument(
        '--gpus',
        type=int,
        default=[1],
        nargs='+',
        help='used gpu',
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=101,
        help="random seed for np.random.seed, torch.manual_seed and torch.cuda.manual_seed.",
    )
    parser.add_argument(
        "--truncation_strategy",
        type=str,
        default="tail",
        help="how to truncate the long length email",
    )
    parser.add_argument("--batch_size", type=int, default=60)

    parser.add_argument(
        "--max_length",
        type=int,
        default=2000,
        help="maximal input length",
    )
    parser.add_argument("--feature_name", default="Client_TextBody", type=str)
    parser.add_argument(
        "--output_dir",
        default=os.path.join(os.getcwd(), "longformer_repo"),
        type=str,
        help="output folder name",
    )

    # parse_known_args returns (namespace, leftovers): unrecognised arguments
    # are tolerated instead of aborting the run.
    args, _ = parser.parse_known_args()

    # The output folder name is suffixed with the feature being modelled.
    args.output_dir = "{}_{}".format(args.output_dir, args.feature_name)

    print(args)

    seed_everything(args.seed)

    # Restrict CUDA to the requested GPU ids before choosing the device.
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, args.gpus))
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

Namespace(batch_size=8, feature_name='Full_TextBody', gpus=[1], model_checkpoint='allenai/longformer-base-4096', model_output_name='longformer_Full_TextBody_output', seed=101, shuffle_train=True, test_negative_positive_ratio=10, train_negative_positive_ratio=4, validation_split=0.2)


In [60]:
def metric_table(table_name="training_output.txt"):
    """Load a comma-separated metrics log into a DataFrame.

    Every non-blank line of the file is split on "," and these positional
    fields are read: 0 model_type, 1 epoch, 2 loss, 3 true_prediction,
    4 false_prediction, 5 accuracy, 6 precision, 7 recall, 8 f1_score,
    12 auc, 13 pr_auc. Fields 9-11 are not used by this function.

    Parameters
    ----------
    table_name : str
        File name, joined onto the current working directory with
        ``os.path.join`` (so an absolute path is used as given).

    Returns
    -------
    pandas.DataFrame
        One row per (model_type, epoch) pair, keeping the first occurrence
        when the log repeats a pair, sorted by model_type and then epoch.
        The original row positions are kept as the index.

    Raises
    ------
    IndexError
        If a non-blank line has fewer than 14 comma-separated fields.
    ValueError
        If a numeric field cannot be parsed.
    """
    # (output column, position in the line, parser) for each field we keep.
    field_spec = (
        ("model_type", 0, str),
        ("epoch", 1, int),
        ("loss", 2, float),
        ("true_prediction", 3, int),
        ("false_prediction", 4, int),
        ("accuracy", 5, float),
        ("precision", 6, float),
        ("recall", 7, float),
        ("f1_score", 8, float),
        ("auc", 12, float),
        ("pr_auc", 13, float),
    )
    columns = {name: [] for name, _, _ in field_spec}

    with open(os.path.join(os.getcwd(), table_name), 'r') as f:
        for line in f:
            # Blank lines (e.g. a trailing newline at the end of the log)
            # carry no record and used to crash the parser.
            if not line.strip():
                continue
            fields = line.split(",")  # split once instead of once per column
            for name, position, parse in field_spec:
                columns[name].append(parse(fields[position]))

    metrics = pd.DataFrame(columns)
    # Chain instead of inplace=True so no intermediate frame is mutated.
    return (
        metrics
        .drop_duplicates(subset=["model_type", "epoch"])
        .sort_values(by=['model_type', 'epoch'])
    )

def style_format(metrics, model, type="training set"):
    metrics=metrics[metrics["model_type"].apply(lambda x : x.split("_")[0]==model)].reset_index(drop=True)
    return metrics.style.format({"loss":"{:.4f}","accuracy":"{:.2%}","true_prediction":"{:,}","false_prediction":"{:,}", "precision":"{:.2%}", "recall":"{:.2%}", \
                                "f1_score":"{:.2%}", "auc":"{:.2%}","pr_auc":"{:.2%}"}) \
    .set_caption(f"Performance Summary For {type} -- {model}") \
    .set_table_styles([{
        'selector': 'caption',
        'props': [
            ('color', 'red'),
            ('font-size', '20px')
        ]
    }])

In [61]:
# Longformer metrics on the training set, read from training_output.txt.
metric_training = metric_table("training_output.txt")
style_format(metrics=metric_training, model="longformer", type="training set")

Unnamed: 0,model_type,epoch,loss,true_prediction,false_prediction,accuracy,precision,recall,f1_score,auc,pr_auc
0,longformer_Client_TextBody_output,0,0.1233,38604,1752,95.66%,89.62%,93.46%,91.50%,99.17%,97.64%
1,longformer_Client_TextBody_output,1,0.0781,38583,1773,95.61%,85.59%,99.11%,91.86%,99.70%,99.12%
2,longformer_Client_TextBody_output,2,0.0785,38584,1772,95.61%,85.60%,99.11%,91.86%,99.70%,99.12%
3,longformer_Client_TextBody_output,3,0.0785,38583,1773,95.61%,85.59%,99.11%,91.86%,99.70%,99.12%
4,longformer_Client_TextBody_output,4,0.0781,38583,1773,95.61%,85.59%,99.11%,91.86%,99.70%,99.12%
5,longformer_Client_TextBody_output,5,0.0779,38584,1772,95.61%,85.60%,99.11%,91.86%,99.70%,99.12%
6,longformer_Client_TextBody_output,6,0.0787,38583,1773,95.61%,85.59%,99.11%,91.86%,99.70%,99.12%
7,longformer_Client_TextBody_output,7,0.0779,38583,1773,95.61%,85.59%,99.11%,91.86%,99.70%,99.12%
8,longformer_Client_TextBody_output,8,0.078,38583,1773,95.61%,85.59%,99.11%,91.86%,99.70%,99.12%
9,longformer_Client_TextBody_output,9,0.0785,38583,1773,95.61%,85.59%,99.11%,91.86%,99.70%,99.12%


In [36]:
# metric_training=metric_table(table_name="training_output.txt")
# metric_training=metric_training[metric_training["model_type"].apply(lambda x : x.split("_")[0]=="longformer")].reset_index(drop=True)
# metric_training.sort_values(by=['model_type','epoch'],inplace=True)
# metric_training.to_csv("metrics_training.txt",sep=',',header=None, mode='a',index=False)

In [70]:
# Scratch arithmetic. NOTE(review): presumably the F1 formula 2*P*R/(P+R)
# evaluated at P=0.25, R=1 — confirm, or delete this cell.
2*0.25*1/(1.25)

0.4

In [63]:
# Longformer metrics on the test set, read from test_output.txt.
metric_test = metric_table("test_output.txt")
style_format(metrics=metric_test, model="longformer", type="test set")

Unnamed: 0,model_type,epoch,loss,true_prediction,false_prediction,accuracy,precision,recall,f1_score,auc,pr_auc
0,longformer_Client_TextBody_output,0,1.3491,10686,5194,67.29%,21.72%,11.84%,15.32%,46.99%,23.29%
1,longformer_Client_TextBody_output,1,1.6398,11004,4876,69.29%,16.34%,5.54%,8.28%,47.31%,23.00%
2,longformer_Client_TextBody_output,2,1.6398,11004,4876,69.29%,16.34%,5.54%,8.28%,47.31%,23.00%
3,longformer_Client_TextBody_output,3,1.6398,11004,4876,69.29%,16.34%,5.54%,8.28%,47.31%,23.00%
4,longformer_Client_TextBody_output,4,1.6398,11004,4876,69.29%,16.34%,5.54%,8.28%,47.31%,23.00%
5,longformer_Client_TextBody_output,5,1.6398,11004,4876,69.29%,16.34%,5.54%,8.28%,47.31%,23.00%
6,longformer_Client_TextBody_output,6,1.6398,11004,4876,69.29%,16.34%,5.54%,8.28%,47.31%,23.00%
7,longformer_Client_TextBody_output,7,1.6398,11004,4876,69.29%,16.34%,5.54%,8.28%,47.31%,23.00%
8,longformer_Client_TextBody_output,8,1.6398,11004,4876,69.29%,16.34%,5.54%,8.28%,47.31%,23.00%
9,longformer_Client_TextBody_output,9,1.6398,11004,4876,69.29%,16.34%,5.54%,8.28%,47.31%,23.00%


In [44]:
# metric_test=metric_table(table_name="validation_output.txt")
# metric_test=metric_test[metric_test["model_type"].apply(lambda x : x.split("_")[0]=="longformer")].reset_index(drop=True)
# metric_test.sort_values(by=['model_type','epoch'],inplace=True)
# metric_test.to_csv("metrics_test.txt",sep=',',header=None, mode='a',index=False)

In [64]:
# BERT metrics on the training set, read from metrics_training.txt.
metric_training = metric_table("metrics_training.txt")
style_format(metrics=metric_training, model="bert", type="training set")

Unnamed: 0,model_type,epoch,loss,true_prediction,false_prediction,accuracy,precision,recall,f1_score,auc,pr_auc
0,bert_Client_TextBody_tail,0,0.4764,33742,6602,83.64%,66.27%,70.38%,68.26%,86.72%,75.80%
1,bert_Client_TextBody_tail,1,0.4766,33742,6602,83.64%,66.27%,70.36%,68.25%,86.73%,75.80%
2,bert_Client_TextBody_tail,2,0.4762,33745,6599,83.64%,66.28%,70.37%,68.26%,86.73%,75.80%
3,bert_Client_TextBody_tail,3,0.4767,33740,6604,83.63%,66.27%,70.36%,68.25%,86.72%,75.80%
4,bert_Client_TextBody_tail,4,0.4764,33741,6603,83.63%,66.26%,70.36%,68.25%,86.72%,75.80%
5,bert_Client_TextBody_tail,5,0.4766,33741,6603,83.63%,66.26%,70.37%,68.25%,86.72%,75.79%
6,bert_Client_TextBody_tail,6,0.4763,33743,6601,83.64%,66.26%,70.38%,68.26%,86.73%,75.80%
7,bert_Client_TextBody_tail,7,0.4762,33739,6605,83.63%,66.26%,70.36%,68.25%,86.72%,75.79%
8,bert_Client_TextBody_tail,8,0.4762,33741,6603,83.63%,66.26%,70.36%,68.25%,86.73%,75.79%
9,bert_Client_TextBody_tail,9,0.4765,33743,6601,83.64%,66.27%,70.37%,68.26%,86.73%,75.80%


In [65]:
# BERT metrics on the test set, read from metrics_test.txt.
# Stored as metric_test: this frame holds test results, not training results.
metric_test = metric_table(table_name="metrics_test.txt")
style_format(metric_test, model="bert", type="test set")

Unnamed: 0,model_type,epoch,loss,true_prediction,false_prediction,accuracy,precision,recall,f1_score,auc,pr_auc
0,bert_Client_TextBody_tail,0,0.636,10598,5282,66.74%,29.45%,23.68%,26.25%,53.57%,26.38%
1,bert_Client_TextBody_tail,1,0.636,10598,5282,66.74%,29.45%,23.68%,26.25%,53.57%,26.38%
2,bert_Client_TextBody_tail,2,0.636,10598,5282,66.74%,29.45%,23.68%,26.25%,53.57%,26.38%
3,bert_Client_TextBody_tail,3,0.636,10598,5282,66.74%,29.45%,23.68%,26.25%,53.57%,26.38%
4,bert_Client_TextBody_tail,4,0.636,10598,5282,66.74%,29.45%,23.68%,26.25%,53.57%,26.38%
5,bert_Client_TextBody_tail,5,0.636,10598,5282,66.74%,29.45%,23.68%,26.25%,53.57%,26.38%
6,bert_Client_TextBody_tail,6,0.636,10598,5282,66.74%,29.45%,23.68%,26.25%,53.57%,26.38%
7,bert_Client_TextBody_tail,7,0.636,10598,5282,66.74%,29.45%,23.68%,26.25%,53.57%,26.38%
8,bert_Client_TextBody_tail,8,0.636,10598,5282,66.74%,29.45%,23.68%,26.25%,53.57%,26.38%
9,bert_Client_TextBody_tail,9,0.636,10598,5282,66.74%,29.45%,23.68%,26.25%,53.57%,26.38%


In [66]:
# CNN metrics on the training set, read from metrics_training.txt.
metric_training = metric_table("metrics_training.txt")
style_format(metrics=metric_training, model="CNN", type="training set")

Unnamed: 0,model_type,epoch,loss,true_prediction,false_prediction,accuracy,precision,recall,f1_score,auc,pr_auc
0,CNN_Client_TextBody,0,0.6527,34635,5725,85.82%,73.53%,67.59%,70.44%,89.49%,78.46%
1,CNN_Client_TextBody,1,0.6228,36174,4186,89.63%,90.80%,65.11%,75.84%,94.80%,87.16%
2,CNN_Client_TextBody,2,0.5943,37291,3069,92.40%,86.05%,83.05%,84.52%,96.45%,89.97%
3,CNN_Client_TextBody,3,0.5663,37582,2778,93.12%,89.70%,81.86%,85.60%,97.28%,91.68%
4,CNN_Client_TextBody,4,0.5379,37905,2455,93.92%,88.23%,87.31%,87.77%,97.81%,92.93%
5,CNN_Client_TextBody,5,0.5091,38153,2207,94.53%,90.81%,86.92%,88.82%,98.17%,93.85%
6,CNN_Client_TextBody,6,0.4799,38387,1973,95.11%,91.32%,88.90%,90.09%,98.50%,94.80%
7,CNN_Client_TextBody,7,0.4511,38577,1783,95.58%,92.24%,89.89%,91.05%,98.74%,95.54%
8,CNN_Client_TextBody,8,0.4224,38884,1476,96.34%,91.23%,94.45%,92.81%,98.93%,96.08%
9,CNN_Client_TextBody,9,0.395,38936,1424,96.47%,89.99%,96.63%,93.19%,99.09%,96.62%


In [67]:
# CNN metrics on the test set, read from metrics_test.txt.
# Stored as metric_test: this frame holds test results, not training results.
metric_test = metric_table(table_name="metrics_test.txt")
style_format(metric_test, model="CNN", type="test set")

Unnamed: 0,model_type,epoch,loss,true_prediction,false_prediction,accuracy,precision,recall,f1_score,auc,pr_auc
0,CNN_Client_TextBody,0,0.6952,10945,4935,68.92%,22.66%,10.08%,13.95%,49.87%,25.27%
1,CNN_Client_TextBody,1,0.6979,11637,4243,73.28%,25.32%,3.53%,6.19%,50.32%,25.93%
2,CNN_Client_TextBody,2,0.698,11316,4564,71.26%,25.93%,8.06%,12.30%,50.64%,26.27%
3,CNN_Client_TextBody,3,0.7016,11528,4352,72.59%,27.32%,5.79%,9.56%,50.87%,26.35%
4,CNN_Client_TextBody,4,0.7042,11443,4437,72.06%,26.34%,6.55%,10.49%,50.82%,26.21%
5,CNN_Client_TextBody,5,0.7121,11600,4280,73.05%,30.38%,6.05%,10.08%,50.80%,26.15%
6,CNN_Client_TextBody,6,0.719,11650,4230,73.36%,31.43%,5.54%,9.42%,50.58%,25.83%
7,CNN_Client_TextBody,7,0.7295,11674,4206,73.51%,28.78%,4.03%,7.07%,50.79%,25.73%
8,CNN_Client_TextBody,8,0.7337,11655,4225,73.39%,32.17%,5.79%,9.82%,50.75%,25.79%
9,CNN_Client_TextBody,9,0.7346,11572,4308,72.87%,31.59%,7.30%,11.87%,51.04%,25.99%


In [68]:
# TF-IDF metrics on the training set, read from metrics_training.txt.
metric_training = metric_table("metrics_training.txt")
style_format(metrics=metric_training, model="TF-IDF", type="training set")

Unnamed: 0,model_type,epoch,loss,true_prediction,false_prediction,accuracy,precision,recall,f1_score,auc,pr_auc
0,TF-IDF_Client_TextBody,0,0.6919,31629,8731,78.37%,55.44%,68.58%,61.32%,83.92%,65.00%
1,TF-IDF_Client_TextBody,1,0.6907,34603,5757,85.74%,66.11%,88.11%,75.54%,93.10%,81.09%
2,TF-IDF_Client_TextBody,2,0.6894,35624,4736,88.27%,70.09%,92.57%,79.77%,95.40%,85.64%
3,TF-IDF_Client_TextBody,3,0.6882,35866,4494,88.87%,70.78%,94.45%,80.92%,96.15%,87.37%
4,TF-IDF_Client_TextBody,4,0.687,36149,4211,89.57%,72.03%,95.24%,82.03%,96.49%,88.18%
5,TF-IDF_Client_TextBody,5,0.6857,36125,4235,89.51%,71.55%,96.33%,82.11%,96.66%,88.56%
6,TF-IDF_Client_TextBody,6,0.6845,36215,4145,89.73%,71.96%,96.53%,82.46%,96.78%,88.85%
7,TF-IDF_Client_TextBody,7,0.6833,36303,4057,89.95%,72.40%,96.63%,82.78%,96.85%,89.05%
8,TF-IDF_Client_TextBody,8,0.6821,36369,3991,90.11%,72.69%,96.83%,83.04%,96.92%,89.24%
9,TF-IDF_Client_TextBody,9,0.6808,36392,3968,90.17%,72.85%,96.73%,83.11%,96.97%,89.37%


In [69]:
# TF-IDF metrics on the test set, read from metrics_test.txt.
# Stored as metric_test: this frame holds test results, not training results.
metric_test = metric_table(table_name="metrics_test.txt")
style_format(metric_test, model="TF-IDF", type="test set")

Unnamed: 0,model_type,epoch,loss,true_prediction,false_prediction,accuracy,precision,recall,f1_score,auc,pr_auc
0,TF-IDF_Client_TextBody,0,0.6932,9562,6318,60.21%,20.41%,20.40%,20.41%,49.69%,23.93%
1,TF-IDF_Client_TextBody,1,0.6932,9765,6115,61.49%,20.25%,18.39%,19.27%,49.82%,23.80%
2,TF-IDF_Client_TextBody,2,0.6932,9962,5918,62.73%,19.83%,16.12%,17.78%,49.70%,23.89%
3,TF-IDF_Client_TextBody,3,0.6932,10075,5805,63.44%,22.32%,18.64%,20.32%,49.71%,24.00%
4,TF-IDF_Client_TextBody,4,0.6933,10169,5711,64.04%,24.10%,20.40%,22.10%,49.83%,24.12%
5,TF-IDF_Client_TextBody,5,0.6933,10134,5746,63.82%,24.00%,20.65%,22.20%,49.86%,24.20%
6,TF-IDF_Client_TextBody,6,0.6933,10145,5735,63.89%,24.08%,20.65%,22.24%,49.93%,24.28%
7,TF-IDF_Client_TextBody,7,0.6933,10232,5648,64.43%,25.01%,21.16%,22.93%,49.95%,24.35%
8,TF-IDF_Client_TextBody,8,0.6934,10194,5686,64.19%,24.13%,20.15%,21.96%,49.91%,24.37%
9,TF-IDF_Client_TextBody,9,0.6934,10204,5676,64.26%,24.51%,20.65%,22.42%,49.95%,24.46%


In [None]:
def eval_func(data_loader,model,device,num_classes=2,loss_weight=None):
    """Run ``model`` over ``data_loader`` and collect probabilities, labels and losses.

    Parameters
    ----------
    data_loader : iterable
        Yields batches as dicts of tensors. Every tensor in a batch is cast
        to int64 and moved to ``device``; each batch must contain a
        ``"labels"`` entry and is passed to the model as keyword arguments.
    model : callable module
        Must support ``.to(device)`` and ``.eval()`` and return a mapping
        with a ``'logits'`` entry.
    device : torch.device
        Device on which the forward pass and the loss are computed.
    num_classes : int
        Logits are reshaped to ``(-1, num_classes)`` before loss and softmax.
    loss_weight : torch.Tensor or None
        Optional per-class weights for the cross-entropy loss.

    Returns
    -------
    tuple
        ``(probabilities, targets, losses)``: softmax probabilities stacked
        over all batches, the matching labels, and a list with one loss
        value (float) per batch. For an empty loader the arrays have shapes
        ``(0, num_classes)`` and ``(0,)`` and the list is empty.
    """
    model = model.to(device)
    model.eval()

    # Prepare the class weights once instead of on every batch.
    class_weight = None if loss_weight is None else loss_weight.float().to(device)

    fin_targets = []
    fin_outputs = []
    losses = []

    for batch in tqdm(data_loader, position=0, leave=True):
        # Cast to int64 directly on the target device (no detour through
        # the CPU LongTensor type).
        batch = {k: v.to(device=device, dtype=torch.long) for k, v in batch.items()}

        with torch.no_grad():
            outputs = model(**batch)
            logits = outputs['logits'].view(-1, num_classes)
            # weight=None is the unweighted loss, so one call covers both cases.
            loss = F.cross_entropy(logits, batch["labels"], weight=class_weight)
            probabilities = torch.softmax(logits, dim=1)

        losses.append(loss.item())
        fin_targets.append(batch["labels"].cpu().numpy())
        fin_outputs.append(probabilities.cpu().numpy())

    if not fin_outputs:
        # np.concatenate raises ValueError on an empty list; return empty
        # arrays of the expected rank instead.
        return (np.empty((0, num_classes), dtype=np.float32),
                np.empty((0,), dtype=np.int64),
                losses)

    return np.concatenate(fin_outputs), np.concatenate(fin_targets), losses
    
def model_inference(args,tokenizer,model,feature_name,device,data=None):
    """Evaluate ``model`` on a dataset and return predictions, targets and losses.

    Parameters
    ----------
    args : argparse.Namespace
        Only ``args.batch_size`` is read here.
    tokenizer, feature_name
        Forwarded to ``utils.Loader_Creation`` together with the dataset.
    model
        Model handed to ``eval_func``; it is switched to eval mode first.
    device : torch.device
        Device forwarded to ``eval_func``.
    data : optional
        Dataset to evaluate. When omitted, the notebook-level ``test_data``
        is used, which is what this function always did before the
        parameter existed.

    Returns
    -------
    dict
        ``{"pred": ..., "target": ..., "losses": ...}`` as returned by
        ``eval_func``.
    """
    if data is None:
        # Backward-compatible fallback to the global defined by the
        # dataset-loading cell.
        data = test_data

    test_module = utils.Loader_Creation(data, tokenizer, feature_name)

    # Evaluation keeps the dataset order and every sample. The training
    # loader that used to be sketched here used shuffle=True and
    # drop_last=True; its note attributed drop_last to a longformer model bug.
    test_dataloader = DataLoader(
        test_module,
        shuffle=False,
        batch_size=args.batch_size,
        collate_fn=test_module.collate_fn,
    )
    model.eval()

    test_pred, test_target, test_losses = eval_func(test_dataloader, model, device)
    return dict(pred=test_pred, target=test_target, losses=test_losses)

In [None]:
# Load the saved dataset for the chosen feature and truncation strategy.
data_dir = os.path.join(
    os.getcwd(),
    "dataset",
    "_truncation_".join((args.feature_name, args.truncation_strategy)),
)
email_all = load_from_disk(data_dir)

# Keep the two splits under their own names.
train_data = email_all['train']
test_data = email_all['test']

In [None]:
# Keep GPU visibility consistent with the --gpus option instead of forcing a
# hard-coded index that would override the user's selection.
# NOTE(review): CUDA presumably reads this variable only when it initialises;
# if CUDA was already initialised earlier in this kernel the assignment may
# have no effect — confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(x) for x in args.gpus)

In [None]:
# Restore the tokenizer and classifier saved under args.output_dir, then
# evaluate on the test split.
tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
model = AutoModelForSequenceClassification.from_pretrained(args.output_dir)
test = model_inference(
    args=args,
    tokenizer=tokenizer,
    model=model,
    feature_name=args.feature_name,
    device=device,
)