# Error analysis for Context-Aware ALD Models 

This notebook contains the error and qualitative analysis for different models trained on CAD for Abusive Language Detection.

## Set up 

In [13]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset
import json
import seaborn as sns
import os 
import csv

## Read output test file 

In [2]:
# Predictions file of the "modernbert-class" run. The per-model loop further
# down reassigns file_path for every model it processes.
file_path = "../../modernbert-class_cad_eval_test_outputs.jsonl"


In [3]:
def get_error_stats(file_path):
    """Tally confusion-matrix counts from a JSONL file of model predictions.

    Every non-blank line must be a JSON object with a ``pred_label`` and a
    ``y`` field. Both are converted with ``int``; a predicted label of ``1``
    is counted as positive and any other predicted label as negative.

    Args:
        file_path: Path of the JSONL file to read (opened as UTF-8).

    Returns:
        dict: Integer counts under the keys ``true_predictions``,
        ``true_positives``, ``true_negatives``, ``false_predictions``,
        ``false_positives``, ``false_negatives``, ``model_pred_1`` and
        ``model_pred_0``.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``pred_label`` or ``y``.
    """
    stats = {
        "true_predictions": 0,
        "true_positives": 0,
        "true_negatives": 0,
        "false_predictions": 0,
        "false_positives": 0,
        "false_negatives": 0,
        "model_pred_1": 0,
        "model_pred_0": 0,
    }
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            line = line.strip()
            if not line:
                # A blank line (e.g. a trailing newline at the end of the
                # file) holds no record; json.loads("") would raise.
                continue
            json_obj = json.loads(line)
            # Parse each label once instead of on every comparison.
            predicted = int(json_obj['pred_label'])
            actual = int(json_obj['y'])
            predicted_positive = predicted == 1

            if predicted == actual:
                stats["true_predictions"] += 1
                stats["true_positives" if predicted_positive else "true_negatives"] += 1
            else:
                stats["false_predictions"] += 1
                stats["false_positives" if predicted_positive else "false_negatives"] += 1

            # Distribution of the model's predictions, regardless of correctness.
            stats["model_pred_1" if predicted_positive else "model_pred_0"] += 1

    return stats



In [4]:
# Runs to summarise; each one is expected to have a
# "<model>_cad_eval_test_outputs.jsonl" file two directories up.
models = [
    "bert-class",
    "modernbert-class",
    "bert-concat",
    "bertwithneighconcat",
    "gat-test",
]
stats_list = []

for model in models:
    print("Model: ", model)
    file_path = os.path.join("../..", f"{model}_cad_eval_test_outputs.jsonl")
    stats = get_error_stats(file_path)
    # Tag the counts with the run they belong to so rows stay identifiable in the CSV.
    stats["model"] = model
    stats_list.append(stats)

# Persist one CSV row per model, with the model name as the first column.
output_file = "model_error_stats.csv"
csv_columns = [
    "model",
    "true_predictions",
    "true_positives",
    "true_negatives",
    "false_predictions",
    "false_positives",
    "false_negatives",
    "model_pred_1",
    "model_pred_0",
]
with open(output_file, 'w', newline='', encoding='utf-8') as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=csv_columns)
    writer.writeheader()
    writer.writerows(stats_list)

print(f"Stats written to {output_file}")

Model:  bert-class
Model:  modernbert-class
Model:  bert-concat
Model:  bertwithneighconcat
Model:  gat-test
Stats written to model_error_stats.csv


Get the correspondence between comment/post IDs and graph files.

In [5]:
# Paths used to export, to a CSV file, the metadata of the target comment of
# every graph listed in the index file.
# NOTE(review): this comment previously mentioned "the 100 first graph objects"
# and a "sample-reanno" folder; neither appears in the code of this cell -
# confirm which behaviour is intended.
graph_input_dir = "../../data/processed_graphs/processed/"  # directory holding the graph-<index>.pt files
index_file = "../../data/cad-test-idx-many.txt"  # text file with one integer graph index per line
output_file = "eval_set_output.csv"  # CSV written by extract_info_from_graph
output_dic = {}  # NOTE(review): not used in this cell - presumably filled elsewhere; confirm

def extract_info_from_graph(index_file, output_file):
    """Export the metadata of each listed graph's target comment to a CSV file.

    For every integer index found in ``index_file`` the file
    ``f"{graph_input_dir}graph-{index}.pt"`` is loaded with ``torch.load``
    (``graph_input_dir`` is read from module scope). The single position at
    which ``graph.y_mask`` is true identifies the target comment, and the
    first element of ``graph.x_text[position]`` is treated as its metadata
    dict. One CSV row is written per graph with the columns ``filename``,
    ``id``, ``reddit_url``, ``label``, ``anno_ctx``, ``anno_tgt``,
    ``anno_tgt_cat``, ``body``, ``index_in_conv`` and ``conv_len``.

    Graph files that do not exist are reported with ``print`` and skipped.

    Args:
        index_file: Path of a text file with one integer graph index per
            line. Blank lines are ignored.
        output_file: Path of the CSV file to create (overwritten if present).

    Raises:
        AssertionError: If a graph does not have exactly one true entry in
            ``y_mask``.
        ValueError: If a non-blank line of ``index_file`` is not an integer,
            or an ``x_text`` entry does not have exactly four elements.
    """
    fieldnames = ['filename', 'id', 'reddit_url', 'label', 'anno_ctx', 'anno_tgt', 'anno_tgt_cat', 'body', 'index_in_conv', 'conv_len']
    with open(output_file, mode='w', newline='', encoding='utf-8') as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
        writer.writeheader()

        with open(index_file, "r") as file:
            # Skip blank lines (e.g. a trailing newline at the end of the
            # file); int("") would raise ValueError.
            indices = [int(line.strip()) for line in file if line.strip()]

        for index in indices:
            input_file = f"{graph_input_dir}graph-{index}.pt"
            try:
                # NOTE(review): torch.load unpickles arbitrary objects - only
                # use it on trusted graph files. Depending on the installed
                # PyTorch version an explicit weights_only argument may be
                # needed to load full graph objects; confirm.
                graph = torch.load(input_file)
            except FileNotFoundError:
                print(f"File {input_file} not found.")
                continue

            # Exactly one node per graph is expected to be flagged as the target.
            target_positions = [i for i, flag in enumerate(graph.y_mask) if flag]
            assert len(target_positions) == 1, (
                f"expected exactly one target node in {input_file}, "
                f"found {len(target_positions)}"
            )
            true_index = target_positions[0]

            # Each x_text entry is a 4-tuple; only its first element (the
            # metadata dict) is used here.
            metadata, _, _, _ = graph.x_text[true_index]

            writer.writerow({
                'filename': input_file,
                'id': metadata.get('id', ''),
                'reddit_url': 'https://www.reddit.com' + metadata.get('permalink', ''),
                'label': metadata.get('label', ''),
                'anno_ctx': metadata.get('anno_ctx', ''),
                'anno_tgt': metadata.get('anno_tgt', ''),
                'anno_tgt_cat': metadata.get('anno_tgt_cat', ''),
                'body': metadata.get('body', ''),
                'index_in_conv': true_index,
                'conv_len': len(graph.y_mask)
            })


# Write output_file: one row per graph index listed in index_file whose .pt file exists.
extract_info_from_graph(index_file, output_file)


  from .autonotebook import tqdm as notebook_tqdm


## Check model errors

#### bot-gat-dir-3l-cad-512-7_3625338_eval_outputs

In [9]:
import json
import csv

# Input JSONL file
input_file = "../bertclass-cad-512-123_3625347_eval_outputs.jsonl"
# Output CSV file
output_file = "../bertclass-cad-512-123_3625347_eval_outputs.csv"

# Read the JSONL file and tag every record with one-hot tp/tn/fp/fn flags
# (a predicted label of 1 is the positive class).
data = []
with open(input_file, 'r', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if not line:
            # A blank line (e.g. a trailing newline at the end of the file)
            # holds no record; json.loads("") would raise.
            continue
        obj = json.loads(line)
        # Initialise all four flags, in this order, so every record carries
        # the same keys; exactly one of them is then set to 1.
        obj['tp'], obj['tn'], obj['fp'], obj['fn'] = 0, 0, 0, 0
        predicted = int(obj['pred_label'])
        is_correct = predicted == int(obj['y'])
        if predicted == 1:
            outcome = 'tp' if is_correct else 'fp'
        else:
            outcome = 'tn' if is_correct else 'fn'
        obj[outcome] = 1
        data.append(obj)

# Flatten and normalize arrays or complex fields into strings
def flatten_value(value):
    """Return a CSV-friendly form of ``value``.

    A list becomes a single string made of its elements, each converted with
    ``str`` and joined by semicolons. Any other value is returned unchanged.
    """
    if not isinstance(value, list):
        return value
    return ";".join(str(element) for element in value)

# Column names are taken from the first record (all records are assumed to
# share the same fields); with no records there is nothing to convert.
if not data:
    raise ValueError("The JSONL file is empty or malformed.")
headers = data[0].keys()

# Write one CSV row per record, turning list values into ';'-joined strings first
with open(output_file, 'w', newline='', encoding='utf-8') as f:
    writer = csv.DictWriter(f, fieldnames=headers)
    writer.writeheader()
    for item in data:
        flattened_item = {key: flatten_value(item[key]) for key in item}
        writer.writerow(flattened_item)

print(f"Conversion complete! Data saved to {output_file}")


Conversion complete! Data saved to ../bertclass-cad-512-123_3625347_eval_outputs.csv


In [10]:
import pandas as pd

# Load the per-example evaluation outputs (CSV) of three runs into DataFrames.
# NOTE(review): the JSONL-to-CSV conversion cell, as it currently stands, only
# writes the bertclass CSV; the two bot-gat CSVs presumably come from earlier
# runs of that cell with a different input_file - confirm they exist before a
# Restart & Run All.
bert_results = pd.read_csv("../bertclass-cad-512-123_3625347_eval_outputs.csv")
gat3l_results = pd.read_csv("../bot-gat-dir-3l-cad-512-7_3625338_eval_outputs.csv")
gat2l_results = pd.read_csv("../bot-gat-dir-2l-cad-512-7_3624560_eval_outputs.csv")



In [12]:
# Preview the first rows of the bertclass results, including the tp/tn/fp/fn flag columns.
bert_results.head()

Unnamed: 0,id,reddit_url,index,text,anno_ctx,anno_tgt,anno_tgt_cat,label,y,pred_label,y_pred,x,masked_index,tp,tn,fp,fn
0,el5eh5l,https://www.reddit.com/r/TumblrInAction/commen...,4,"Seriously though, what the hell am I supposed ...",,,,Neutral,0.0,1,"[-0.4662129580974579, 0.13268911838531494]","{'id': 'el5eh5l', 'name': 't1_el5eh5l', 'autho...",4,0,0,1,0
1,etotgus,https://www.reddit.com/r/TumblrInAction/commen...,36,"I agree. Like, I just deal with it because I d...",,,,Neutral,0.0,1,"[-0.24242620170116425, 0.11224387586116791]","{'id': 'etotgus', 'name': 't1_etotgus', 'autho...",36,0,0,1,0
2,ep0yt9p,https://www.reddit.com/r/TumblrInAction/commen...,39,Udon 4 life,,,,Neutral,0.0,0,"[1.141236662864685, -0.6677318215370178]","{'id': 'ep0yt9p', 'name': 't1_ep0yt9p', 'autho...",39,0,1,0,0
3,etnmw0w,https://www.reddit.com/r/TumblrInAction/commen...,46,"Feminist, Feminism, *Funny, LOL*",PreviousContent,feminists,political affiliation,AffiliationDirectedAbuse,1.0,1,"[-1.1837022304534912, 0.9363416433334351]","{'id': 'etnmw0w', 'name': 't1_etnmw0w', 'autho...",46,1,0,0,0
4,emzvb9e,https://www.reddit.com/r/CCJ2/comments/bmuhp3/...,2,Also can't take a shit in a public bathroom wi...,,,,Neutral,0.0,1,"[-1.0287977457046509, 0.5232383608818054]","{'id': 'emzvb9e', 'name': 't1_emzvb9e', 'autho...",2,0,0,1,0
