# Run the PWDA Experiment with a RoBERTa Model for Multi-Class Text Classification

In [None]:
# IPython shell escape (notebook-only syntax): install or upgrade simpletransformers.
!pip install -U simpletransformers

In [None]:
from simpletransformers.classification import ClassificationModel
import pandas as pd
import logging
import sklearn

# Keep the third-party "transformers" logger quiet (WARNING and above only)
# while the root logger reports INFO-level progress.
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)
logging.basicConfig(level=logging.INFO)

# Training hyper-parameters handed to ClassificationModel via `args=`.
args = dict(
    num_train_epochs=4,
    train_batch_size=16,
    max_seq_length=128,
    overwrite_output_dir=True,  # allow re-running into the same output directory
)

In [None]:
def _read_labelled_tsv(path):
    """Load a "<text>\\t<int label>" file into a DataFrame with columns 'text' and 'labels'.

    Each line is stripped of surrounding whitespace and split on tabs; the
    first field is the text and the second is parsed as an int label. Any
    further fields are ignored. Blank lines (e.g. a trailing newline at EOF)
    are skipped.

    Raises:
        ValueError: if a non-blank line has fewer than two tab-separated
            fields, or its label is not an integer.
    """
    texts, labels = [], []
    with open(path, 'r', encoding='UTF-8') as f:
        for line_no, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                continue
            parts = line.split('\t')
            if len(parts) < 2:
                raise ValueError(f'{path}:{line_no}: expected "<text>\\t<label>", got {line!r}')
            texts.append(parts[0])
            labels.append(int(parts[1]))
    return pd.DataFrame({'text': texts, 'labels': labels})


eval_file_normal = 'test.txt'
# Perturbed copies of the test set; the file names suggest they were produced
# with TextFlint's Ocr and InsertAdv transformations (NOTE(review): confirm).
eval_file_ocr = 'textflint-Ocr-test.txt'
eval_file_insertAdv = 'textflint-InsertAdv-test.txt'

# Test set: clean (unperturbed) examples.
eval_df_normal = _read_labelled_tsv(eval_file_normal)
# Test set: OCR-perturbed examples.
eval_df_ocr = _read_labelled_tsv(eval_file_ocr)
# Test set: InsertAdv-perturbed examples.
eval_df_insertAdv = _read_labelled_tsv(eval_file_insertAdv)

In [None]:
import itertools

# Import the submodule explicitly: `import sklearn` alone does not guarantee
# that `sklearn.metrics` has been loaded.
from sklearn.metrics import accuracy_score

dataset = 'TREC-6'
aug_num = [1, 2, 4, 8, 16]
all_num = [100, 500, 2000, 3569]
alphas = [0.1, 0.2, 0.3, 0.4, 0.5]
operations = ['RP', 'RI', 'RS', 'RD']

# One MCC score per training configuration, appended in loop order
# (alpha -> operation -> training-set size -> augmentation count).
mcc_list_normal = []
mcc_list_ocr = []
mcc_list_insertAdv = []

# itertools.product yields the same order as the four nested loops it replaces
# (rightmost list varies fastest). The names deliberately avoid `all`, which
# would shadow the builtin for every later notebook cell. They also avoid `_`,
# which the original reused both for the augmentation count and as the line
# iterator.
for alpha, operation, total, n_aug in itertools.product(
        alphas, operations, all_num, aug_num):
    train_file = f'pwda-{alpha}/pwda-train-{alpha}_{operation}_{total}_{n_aug}.txt'

    # Training set: one "<text>\t<int label>" record per line; blank lines skipped.
    with open(train_file, 'r', encoding='UTF-8') as train_f:
        rows = [line.strip().split('\t') for line in train_f if line.strip()]
    train_df = pd.DataFrame({
        'text': [row[0] for row in rows],
        'labels': [int(row[1]) for row in rows],
    })

    # Create a fresh pretrained model for each configuration, then train it.
    # num_labels=6 matches dataset = 'TREC-6'.
    model = ClassificationModel('roberta', 'roberta-base', num_labels=6, args=args)
    model.train_model(train_df)

    # eval_model returns (result, model_outputs, wrong_predictions); only the
    # result dict is used. It already includes 'mcc'; `acc=` adds accuracy.
    result_normal, _, _ = model.eval_model(eval_df_normal, acc=accuracy_score)
    result_ocr, _, _ = model.eval_model(eval_df_ocr, acc=accuracy_score)
    result_insertAdv, _, _ = model.eval_model(eval_df_insertAdv, acc=accuracy_score)

    # Report and record MCC on the clean, OCR and InsertAdv test sets.
    print(train_file)
    print(result_normal['mcc'], result_ocr['mcc'], result_insertAdv['mcc'])
    mcc_list_normal.append(result_normal['mcc'])
    mcc_list_ocr.append(result_ocr['mcc'])
    mcc_list_insertAdv.append(result_insertAdv['mcc'])

    # Release this model before the next iteration constructs a new one, so
    # that two models are not referenced at the same time.
    del model

In [None]:
# Display MCC on the clean test set, one entry per training configuration (loop order).
mcc_list_normal

In [None]:
# Display MCC on the OCR-perturbed test set, one entry per training configuration (loop order).
mcc_list_ocr

In [None]:
# Display MCC on the InsertAdv-perturbed test set, one entry per training configuration (loop order).
mcc_list_insertAdv