In [1]:
from datasets import load_dataset, DatasetDict, Dataset, concatenate_datasets
import random
import string

from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig, DataCollatorWithPadding, AdamW, get_scheduler, MarianMTModel, MarianTokenizer
import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from IPython.display import clear_output
from torch.nn.utils.rnn import pad_sequence
import gc

In [8]:
# Load the LLM evaluation export. Each row carries a gold 'labels' string and the
# LLM's 'llm_labels' string (space-separated "aspect-level" tokens parsed below).
data_files = {
    "test": "../data/test-for-llm - test-for-llm.csv"
}
raw_dataset = load_dataset("csv", data_files=data_files)
# Only the two label columns are evaluated; drop the text columns.
raw_dataset = raw_dataset.remove_columns(['en_text', 'messages'])
# Keep only rows where BOTH label columns are present: the expansion step calls
# .split() on each of them, so a missing (None) gold label would otherwise raise
# AttributeError there instead of being excluded here.
raw_dataset['test'] = raw_dataset["test"].filter(
    lambda x: x["llm_labels"] is not None and x["labels"] is not None
)

Filter:   0%|          | 0/724 [00:00<?, ? examples/s]

In [9]:
# Inspect the split sizes and remaining columns after loading and filtering.
raw_dataset

DatasetDict({
    test: Dataset({
        features: ['labels', 'llm_labels'],
        num_rows: 214
    })
})

In [11]:
# Expand every row's space-separated "aspect-level" label strings into one row
# per aspect, so gold and LLM labels can be compared aspect by aspect.
# Aspects that a label string does not mention are reported at level 0.
# NOTE: this cell rewrites raw_dataset in place and is not idempotent; re-run the
# loading cell before re-running this one.

ASPECTS_IN_ORDER = ['all', 'amn', 'ch', 'ppl', 'mgt', 'nat']  # emission order of rows
VALID_LEVELS = ['0', '1', '2', '3']


def _expand_aspect_labels(label_text):
    """Turn a label string such as "amn-2 ch-1" into one entry per aspect.

    Args:
        label_text: space-separated tokens of the form "<aspect>-<level>".

    Returns:
        A list of six "<aspect>-<level>" strings, ordered as ASPECTS_IN_ORDER;
        aspects absent from ``label_text`` get level 0.

    Raises:
        Exception: if a token names an unknown aspect or an unknown level.
    """
    aspect2label = {aspect: 0 for aspect in ASPECTS_IN_ORDER}
    for label in label_text.split():
        try:
            key, value = label.split('-')
        except ValueError:
            # Malformed token (no dash, or more than one). Report and skip it;
            # falling through would reuse key/value from the previous token.
            print("Unknown label with text:" + label_text)
            continue
        if key not in aspect2label or value not in VALID_LEVELS:
            raise Exception("Unknown label:", label)
        aspect2label[key] = int(value)
    return [f'{aspect}-{aspect2label[aspect]}' for aspect in ASPECTS_IN_ORDER]


for name in raw_dataset:
    data_dict = {"labels": [], "llm_labels": []}
    for item in raw_dataset[name]:
        data_dict["labels"].extend(_expand_aspect_labels(item['labels']))
        data_dict["llm_labels"].extend(_expand_aspect_labels(item['llm_labels']))
    # Dataset.from_dict takes a plain column -> values mapping.
    raw_dataset[name] = Dataset.from_dict(data_dict)

In [15]:
# Spot-check the first expanded row: one gold and one LLM "aspect-level" label.
raw_dataset['test'][0]

{'labels': 'all-1', 'llm_labels': 'all-1'}

In [31]:
# Per-class and macro-averaged precision / recall / F1 of the LLM labels against
# the gold labels, plus row-level accuracy. A class is an "aspect-level" string
# such as "amn-2"; each dataset row contributes one gold and one predicted class.

ASPECT_KEYS = ["all", "amn", "ch", "mgt", "nat", "ppl"]
LEVELS = range(4)
# Level-major order: all-0, amn-0, ..., ppl-0, all-1, ..., ppl-3.
CLASS_NAMES = [f"{aspect}-{level}" for level in LEVELS for aspect in ASPECT_KEYS]

count_labels = {class_name: 0 for class_name in CLASS_NAMES}  # gold support per class
count = {
    outcome: {class_name: 0 for class_name in CLASS_NAMES}
    for outcome in ("true-positive", "false-positive", "false-negative")
}

total_instances = 0
true_instances = 0

# Each row of raw_dataset['test'] holds a gold 'labels' string and a predicted
# 'llm_labels' string, both drawn from CLASS_NAMES.
for instances in raw_dataset['test']:
    total_instances += 1
    labels = instances['labels']
    llm_labels = instances['llm_labels']
    count_labels[labels] += 1

    if labels == llm_labels:
        true_instances += 1
        count['true-positive'][labels] += 1
    else:
        # A miss is a false negative for the gold class and a false positive
        # for the predicted class.
        count['false-negative'][labels] += 1
        count['false-positive'][llm_labels] += 1

macro_precision_sum = 0
macro_recall_sum = 0
macro_f1_sum = 0
num_classes = 0  # number of classes that enter the macro average

for class_name in CLASS_NAMES:
    true_positive = count['true-positive'][class_name]
    false_positive = count['false-positive'][class_name]
    false_negative = count['false-negative'][class_name]

    if true_positive + false_positive + false_negative == 0:
        # Class occurs in neither gold nor predictions: its scores are undefined,
        # so it is excluded from the macro average rather than counted as 0
        # (macro averages are taken over labels present in y_true or y_pred).
        print(class_name, 'skipped: no gold or predicted instances')
        continue

    if true_positive == 0:
        precision = 0
        recall = 0
        f1 = 0
    else:
        precision = true_positive / (true_positive + false_positive)
        recall = true_positive / (true_positive + false_negative)
        f1 = 2 * precision * recall / (precision + recall)

    print(class_name, 'precision:', precision, 'recall:', recall, 'f1:', f1)

    num_classes += 1
    macro_precision_sum += precision
    macro_recall_sum += recall
    macro_f1_sum += f1

# Guard the divisions so an empty split reports 0 instead of raising.
macro_precision = macro_precision_sum / num_classes if num_classes else 0.0
macro_recall = macro_recall_sum / num_classes if num_classes else 0.0
macro_f1 = macro_f1_sum / num_classes if num_classes else 0.0
accuracy = true_instances / total_instances if total_instances else 0.0

print('Macro Precision:', macro_precision)
print('Macro Recall:', macro_recall)
print('Macro F1:', macro_f1)
print('Accuracy:', accuracy)
print(count_labels)

all-0 precision: 0.8 recall: 0.3333333333333333 f1: 0.47058823529411764
amn-0 precision: 0.7955801104972375 recall: 0.9230769230769231 f1: 0.85459940652819
ch-0 precision: 0.8287671232876712 recall: 0.8175675675675675 f1: 0.8231292517006803
mgt-0 precision: 0.9148936170212766 recall: 0.9297297297297298 f1: 0.9222520107238605
nat-0 precision: 0.8602150537634409 recall: 0.9090909090909091 f1: 0.8839779005524862
ppl-0 precision: 0.9704433497536946 recall: 0.9899497487437185 f1: 0.9800995024875622
all-1 precision: 0.8166666666666667 recall: 0.98 f1: 0.890909090909091
amn-1 precision: 0.47368421052631576 recall: 0.2647058823529412 f1: 0.339622641509434
ch-1 precision: 0.3829787234042553 recall: 0.6 f1: 0.4675324675324675
mgt-1 precision: 0.1 recall: 0.14285714285714285 f1: 0.11764705882352941
nat-1 precision: 0.2916666666666667 recall: 0.25 f1: 0.2692307692307692
ppl-1 precision: 0.7142857142857143 recall: 0.8333333333333334 f1: 0.7692307692307692
all-2 precision: 0.7875 recall: 0.913043478