# Error analysis for Context-Aware ALD Models 

This notebook contrains the error and qualitative analysis for different models trained on CAD for Abusive Language Detection.

## Set up 

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset
import json
import seaborn as sns
import os 
import csv

## Read output test file 

In [23]:
file_path = "../../modernbert-class_cad_eval_test_outputs.jsonl"


In [24]:
# Open the jsonl file and read it line by line
def get_error_stats(file_path):
    true_predictions, tp, tn, false_predictions, fp, fn = 0, 0, 0, 0, 0, 0
    model_pred_1, model_pred_0 = 0, 0
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            # Parse the JSON object from each line
            json_obj = json.loads(line.strip())
            if int(json_obj['pred_label']) == int(json_obj['y']):
                true_predictions += 1
                if int(json_obj['pred_label']) == 1:
                    tp += 1
                else:
                    tn += 1
            else:
                false_predictions += 1
                if int(json_obj['pred_label']) == 1:
                    fp += 1
                else:
                    fn += 1
            
            if int(json_obj['pred_label']) == 1:
                model_pred_1 += 1
            else:
                model_pred_0 += 1

    return {
        "true_predictions": true_predictions,
        "true_positives": tp,
        "true_negatives": tn,
        "false_predictions": false_predictions,
        "false_positives": fp,
        "false_negatives": fn,
        "model_pred_1": model_pred_1,
        "model_pred_0": model_pred_0
    }



In [25]:
models = ["bert-class", "modernbert-class", "bert-concat", "bertwithneighconcat", "gat-test"]
stats_list = []

for model in models:
    file_path = os.path.join("../..", model + "_cad_eval_test_outputs.jsonl")
    print("Model: ", model)
    stats = get_error_stats(file_path)
    stats["model"] = model  # Add model name to the stats
    stats_list.append(stats)

# Write stats to a CSV file
output_file = "model_error_stats.csv"
with open(output_file, 'w', newline='', encoding='utf-8') as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=[
        "model", "true_predictions", "true_positives", "true_negatives",
        "false_predictions", "false_positives", "false_negatives",
        "model_pred_1", "model_pred_0"
    ])
    writer.writeheader()
    writer.writerows(stats_list)

print(f"Stats written to {output_file}")

Model:  bert-class
Model:  modernbert-class
Model:  bert-concat
Model:  bertwithneighconcat
Model:  gat-test
Stats written to model_error_stats.csv


Get the correspondance between comment/post ids and graph files.

In [6]:
# output content of 100 first graph objetcs into the sample-reanno folder in data folder
graph_input_dir = "../../data/processed_graphs/processed/"
index_file = "../../data/cad-test-idx-many.txt"
output_file = "eval_set_output.csv"
output_dic = {}

def extract_info_from_graph(index_file, output_file):
    with open(output_file, mode='w', newline='', encoding='utf-8') as csv_file:
        fieldnames = ['filename', 'id', 'reddit_url', 'label', 'anno_ctx', 'anno_tgt', 'anno_tgt_cat', 'body', 'index_in_conv', 'conv_len']
        writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
        writer.writeheader()
        with open(index_file, "r") as file:
            indices = file.readlines()
            indices = [int(index.strip()) for index in indices]
            for index in indices:
                input_file = f"{graph_input_dir}graph-{index}.pt"
                try:
                    # Load the.pt file
                    graph = torch.load(input_file)
                    true_index = [i for i in range(len(graph.y_mask)) if graph.y_mask[i] == True]
                    assert len(true_index) == 1
                    true_index = true_index[0]

                    comment = graph.x_text[true_index]
                    x, a2, a3, label = comment
                    my_id = x.get('id', '')
                    permalink = x.get('permalink', '')
                    reddit_url = 'https://www.reddit.com' + permalink
                    
                    label = x.get('label', '')
                    anno_ctx = x.get('anno_ctx', '')
                    anno_tgt = x.get('anno_tgt', '')
                    anno_tgt_cat = x.get('anno_tgt_cat', '')
                    body = x.get('body', '')
                    
                    # Write to CSV
                    writer.writerow({
                        'filename': input_file,
                        'id': my_id,
                        'reddit_url': reddit_url,
                        'label': label,
                        'anno_ctx': anno_ctx,
                        'anno_tgt': anno_tgt,
                        'anno_tgt_cat': anno_tgt_cat,
                        'body': body,
                        'index_in_conv': true_index,
                        'conv_len': len(graph.y_mask)
                    })

                    #target_comment = graph.x_text[true_index]
                    #print(target_comment)


                except FileNotFoundError:
                    print(f"File {input_file} not found.")


extract_info_from_graph(index_file, output_file)


In [4]:
import json
import csv
import os

input_dir = "../../data/sample-reanno/"
output_file = "../../data/sample-reanno/eval_set_output.csv"

# Prepare the CSV file
with open(output_file, mode='w', newline='', encoding='utf-8') as csv_file:
    fieldnames = ['filename', 'reddit_url', 'index', 'label', 'anno_ctx', 'anno_tgt', 'anno_tgt_cat', 'body']
    writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
    writer.writeheader()
    
    # Loop over the files
    for _, i in enumerate(indices):
        filename = f"graph-{i}.jsonl"
        filepath = os.path.join(input_dir, filename)
        if not os.path.exists(filepath):
            print(f"File {filepath} does not exist.")
            continue
        
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                lines = f.readlines()
                if not lines:
                    print(f"No content in file {filepath}")
                    continue
                
                # First line is the index label
                index_label_line = lines[0].strip()
                index_label = int(index_label_line)
                
                # Remaining lines are JSON entries
                json_entries = []
                for line in lines[1:]:
                    json_entry = json.loads(line.strip())
                    json_entries.append(json_entry)
                
                if index_label >= len(json_entries):
                    print(f"Index label {index_label} out of range in file {filepath}")
                    continue
                
                selected_entry = json_entries[index_label]
                
                # Extract required fields
                node_data = selected_entry[0]  # The first element of the list is the node data dict
                
                permalink = x.get('permalink', '')
                reddit_url = 'https://www.reddit.com' + permalink
                
                label = x.get('label', '')
                anno_ctx = x.get('anno_ctx', '')
                anno_tgt = x.get('anno_tgt', '')
                anno_tgt_cat = x.get('anno_tgt_cat', '')
                body = x.get('body', '')
                
                # Write to CSV
                writer.writerow({
                    'filename': filename,
                    'reddit_url': reddit_url,
                    'index': index_label,
                    'label': label,
                    'anno_ctx': anno_ctx,
                    'anno_tgt': anno_tgt,
                    'anno_tgt_cat': anno_tgt_cat,
                    'body': body
                })
                
        except Exception as e:
            print(f"Error processing file {filepath}: {e}")

NameError: name 'indices' is not defined