# Run the PWDA Experiment with an XLNet Model for Text Classification

In [None]:
# Install or upgrade simpletransformers. The leading `!` makes IPython/Jupyter run this line as a shell command.
!pip install -U simpletransformers 

In [None]:
from simpletransformers.classification import ClassificationModel
import pandas as pd
import logging
import sklearn

# Root logger at INFO; the "transformers" library logger is raised to WARNING
# so only its warnings and errors get through.
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# simpletransformers configuration passed to every ClassificationModel below.
args = dict(
    num_train_epochs=4,
    train_batch_size=16,
    max_seq_length=128,
    overwrite_output_dir=True,
)

In [None]:
# Test-set files: the clean test set plus two TextFlint-generated variants
# (per the file names).
eval_file_normal = 'test.txt'
eval_file_ocr = 'textflint-Ocr-test.txt'
eval_file_insertAdv = 'textflint-InsertAdv-test.txt'


def _read_eval_tsv(path):
    """Load a ``sentence<TAB>label`` file into a ``text``/``labels`` DataFrame.

    Each line is stripped and split on tabs. Field 0 is the sentence and
    field 1 is parsed as an ``int`` label. Blank lines are skipped, so a
    stray empty line no longer raises ``IndexError``.

    Raises:
        FileNotFoundError: if *path* does not exist.
        ValueError: if a label field is not an integer.
    """
    sentences = []
    labels = []
    with open(path, 'r', encoding='UTF-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            parts = line.split('\t')
            sentences.append(parts[0])
            labels.append(int(parts[1]))
    # 'text' / 'labels' are the column names simpletransformers' eval_model looks for.
    return pd.DataFrame({'text': sentences, 'labels': labels})


eval_df_normal = _read_eval_tsv(eval_file_normal)          # clean test set
eval_df_ocr = _read_eval_tsv(eval_file_ocr)                # TextFlint "Ocr" variant
eval_df_insertAdv = _read_eval_tsv(eval_file_insertAdv)    # TextFlint "InsertAdv" variant

In [None]:
import itertools

# Experiment grid. For every (alpha, operation, size, aug) combination, a new
# ClassificationModel is built from the 'xlnet-base-cased' checkpoint, trained
# on the matching PWDA-augmented file, and evaluated on all three test sets.
dataset = 'SST-2'  # or 'CR'
aug_num = [1, 2, 4, 8, 16]
alphas = [0.1, 0.2, 0.3, 0.4, 0.5]
operations = ['RP', 'RI', 'RS', 'RD']
# Size values that appear in the augmented training-file names.
all_num = [100, 500, 2000, 5484] if dataset == 'SST-2' else [100, 500, 1472]

# One entry per run, all in the same loop order, so index i of each F1 list
# belongs to run_files[i].
run_files = []
f1_list_normal = []
f1_list_ocr = []
f1_list_insertAdv = []


def _read_train_tsv(path):
    """Load a ``sentence<TAB>label`` training file into a ``text``/``labels`` DataFrame.

    Lines are stripped and split on tabs. Field 0 is the sentence and
    field 1 is the integer label. Blank lines are skipped instead of raising
    ``IndexError``.
    """
    sentences = []
    labels = []
    with open(path, 'r', encoding='UTF-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            parts = line.split('\t')
            sentences.append(parts[0])
            labels.append(int(parts[1]))
    # Same column names as the evaluation DataFrames.
    return pd.DataFrame({'text': sentences, 'labels': labels})


def _eval_f1(model, eval_df):
    """Evaluate *model* on *eval_df* and return its binary F1 score.

    sklearn's f1_score treats an undefined precision/recall (no predicted or
    no true positives) as 0 with a warning. The previous hand-computed
    tp/fp/fn formula produced NaN in that case.
    """
    result, _outputs, _wrong = model.eval_model(
        eval_df,
        acc=sklearn.metrics.accuracy_score,
        f1=sklearn.metrics.f1_score,
    )
    return result['f1']


# product() iterates in the same order as the nested loops it replaces:
# alpha outermost, aug count innermost.
for alpha, operation, total_num, n_aug in itertools.product(alphas, operations, all_num, aug_num):
    train_file = f'pwda-{alpha}/pwda-train-{alpha}_{operation}_{total_num}_{n_aug}.txt'
    train_df = _read_train_tsv(train_file)

    # Build the model, then train it.
    model = ClassificationModel('xlnet', 'xlnet-base-cased', args=args)
    model.train_model(train_df)

    # Evaluate on the clean and the two perturbed test sets.
    f1_normal = _eval_f1(model, eval_df_normal)
    f1_ocr = _eval_f1(model, eval_df_ocr)
    f1_insertAdv = _eval_f1(model, eval_df_insertAdv)

    run_files.append(train_file)
    f1_list_normal.append(f1_normal)
    f1_list_ocr.append(f1_ocr)
    f1_list_insertAdv.append(f1_insertAdv)

    # Report the run.
    print(train_file)
    print(f1_normal, f1_ocr, f1_insertAdv)

In [None]:
# Display the clean-test-set F1 scores, one per training run in loop order.
f1_list_normal

In [None]:
# Display the OCR-test-set F1 scores, one per training run in loop order.
f1_list_ocr

In [None]:
# Display the InsertAdv-test-set F1 scores, one per training run in loop order.
# The list is named `f1_list_insertAdv` (capital A); `f1_list_insertadv` is undefined.
f1_list_insertAdv