Code adapted from the [official metrics code of the MAFALDA paper](https://github.com/ChadiHelwe/MAFALDA/blob/main/src/metrics.py).

**Span** - the span of a fallacy in a text is the smallest
contiguous sequence of sentences that comprises the conclusion and the premises of the fallacy.

Evaluation input: fallacy spans and their labels.
Evaluation output: precision, recall, and F1 scores.

In [15]:
from enum import Enum
from itertools import chain, combinations
from typing import Any, List, Set, Tuple, Union, Dict
import json

Map fallacy names to numeric label IDs

In [6]:
# Lower-case fallacy name -> numeric label ID. IDs follow the order of the
# names below and start at 1, because 0 is NOTHING_LABEL ("no fallacy").
LABEL_MAP = {
    name: label_id
    for label_id, name in enumerate(
        (
            "abusive ad hominem",
            "ad populum",
            "appeal to false authority",
            "appeal to nature",
            "appeal to tradition",
            "guilt by association",
            "tu quoque",
            "causal oversimplification",
            "circular reasoning",
            "equivocation",
            "false analogy",
            "false causality",
            "false dilemma",
            "hasty generalization",
            "slippery slope",
            "straw man",
            "fallacy of division",
            "appeal to positive emotion",
            "appeal to anger",
            "appeal to fear",
            "appeal to pity",
            "appeal to ridicule",
            "appeal to worse problem",
        ),
        start=1,
    )
}

Label information: the ID reserved for spans without a fallacy, and the number of fallacy classes

In [7]:
# Label ID for spans that contain no fallacy; real fallacy IDs start at 1.
NOTHING_LABEL: int = 0

In [8]:
class NbClasses(Enum):
    """Number of classes per level of the fallacy classification.

    Only level 2, the deepest level, is defined; its 23 classes are the
    entries of LABEL_MAP.
    """

    LVL_2 = 23

Classes for spans

In [9]:
class Span:
    """Base class holding the text an annotation refers to.

    The collectors in this notebook store the whole input text in ``span``;
    the position of the fallacy inside that text is kept in the subclasses'
    ``interval``.
    """

    def __init__(self, span: str):
        self.span = span

    def __str__(self) -> str:
        return f"{self.span}"

    def __repr__(self) -> str:
        return self.__str__()


class PredictionSpan(Span):
    """A predicted fallacy: a single label on one ``[start, end]`` interval."""

    def __init__(self, span: str, label: Union[int, None], interval: List[int]):
        super().__init__(span)
        self.label = label
        self.interval = interval

    def __eq__(self, other):
        # Compare every field, like GroundTruthSpan does. Comparing the text
        # alone would make all predictions for the same text equal, because
        # the collectors store the whole text in ``span``.
        if not isinstance(other, PredictionSpan):
            return False
        return (
            self.span == other.span
            and self.label == other.label
            and self.interval == other.interval
        )

    def __hash__(self):
        # Defining __eq__ alone would make instances unhashable; hash exactly
        # the fields that __eq__ compares.
        return hash((self.span, self.label, tuple(self.interval)))

    def __str__(self) -> str:
        return super().__str__() + f" - {self.interval} - {self.label}"

    def __repr__(self) -> str:
        return self.__str__()


class GroundTruthSpan(Span):
    """A gold fallacy span; ``labels`` holds its alternative labels."""

    def __init__(self, span: str, labels: Set[Union[int, None]], interval: List[int]):
        super().__init__(span)
        self.labels = labels
        self.interval = interval

    def __eq__(self, other):
        if not isinstance(other, GroundTruthSpan):
            return False
        return (
            self.span == other.span
            and self.labels == other.labels
            and self.interval == other.interval
        )

    def __str__(self) -> str:
        return super().__str__() + f" - {self.interval} - {self.labels}"

    def __repr__(self) -> str:
        return self.__str__()

    def __hash__(self):
        # frozenset is order-independent and, unlike sorted(), also works when
        # the label set mixes None and integers. Consistent with __eq__.
        return hash((self.span, frozenset(self.labels), tuple(self.interval)))

In [10]:
class ScoreType(Enum):
    """Metric for which gold label alternatives are expanded.

    PRECISION combines a span's alternative labels with OR, RECALL with XOR
    (see AnnotatedText.generate_label_conjunctions).
    """

    PRECISION = 1
    RECALL = 2


class AnnotatedText:
    """A list of spans (all predictions or all ground truth) scored together."""

    def __init__(self, spans: List[Union[PredictionSpan, GroundTruthSpan]]):
        self.spans = spans

    def __len__(self):
        return len(self.spans)

    def __str__(self) -> str:
        return f"{self.spans}"

    def __repr__(self) -> str:
        return self.__str__()

    def generate_label_conjunctions(
        self, score_type: ScoreType, to_numeric_labels=False
    ) -> List[List[Any]]:
        """Expand the gold spans' alternative labels into every label assignment.

        A gold span may carry several alternative labels. Each returned entry is
        one way of resolving these alternatives across all labelled spans:

        * ``ScoreType.PRECISION``: every span contributes any non-empty subset
          of its labels (OR over the alternatives).
        * ``ScoreType.RECALL``: every span contributes exactly one of its
          labels (XOR over the alternatives).

        Spans without labels, or whose only label is ``None``, take no part.

        Args:
            score_type: PRECISION or RECALL, selecting how alternatives combine.
            to_numeric_labels: if True, each assignment is reduced to the sorted
                list of distinct integer labels greater than ``NOTHING_LABEL``;
                otherwise it is a list of single-label ``GroundTruthSpan``
                objects, with ``NOTHING_LABEL`` spans removed.

        Returns:
            The distinct assignments, in generation order (``[[]]`` when no
            span is labelled).

        Raises:
            TypeError: if any span is not a ``GroundTruthSpan``.
        """
        if not all(isinstance(span, GroundTruthSpan) for span in self.spans):
            raise TypeError("This method is only useful for gold spans.")

        labelled_spans = [s for s in self.spans if s.labels and s.labels != {None}]

        # Build the assignments from the last span back to the first, i.e. the
        # iterative form of "options of span i" x "assignments of spans i+1..".
        # It yields the same order as that recursion but cannot exceed Python's
        # recursion limit when many spans (e.g. a whole corpus) are scored.
        all_combinations = [[]]
        for gold_span in reversed(labelled_spans):
            # One single-label copy of the span per alternative label.
            singles = [
                GroundTruthSpan(gold_span.span, {label}, gold_span.interval)
                for label in gold_span.labels
            ]
            if score_type == ScoreType.PRECISION:
                # OR: any non-empty subset; this span's choice varies fastest.
                options = [
                    list(subset)
                    for size in range(1, len(singles) + 1)
                    for subset in combinations(singles, size)
                ]
                all_combinations = [
                    option + rest for rest in all_combinations for option in options
                ]
            else:
                # XOR: exactly one label; later spans' choices vary fastest.
                all_combinations = [
                    [single] + rest for single in singles for rest in all_combinations
                ]

        def process_combination(combination):
            if to_numeric_labels:
                # Sorting gives a canonical order, so equal label sets are
                # recognised as duplicates below.
                return sorted(
                    {
                        label
                        for span in combination
                        for label in span.labels
                        if isinstance(label, int) and label > NOTHING_LABEL
                    }
                )
            return [
                span
                for span in combination
                if len(span.labels) == 1 and NOTHING_LABEL not in span.labels
            ]

        # Keep the first occurrence of each assignment; list equality compares
        # element-wise with the spans' __eq__.
        unique = []
        for combination in all_combinations:
            processed = process_combination(combination)
            if processed not in unique:
                unique.append(processed)
        return unique

Score calculation utilities

In [11]:
def label_score(pred: PredictionSpan, gold: GroundTruthSpan) -> float:
    """Label-agreement term: delta in the C function of the paper.

    Returns 1 when the predicted label is one of the gold span's labels and 0
    otherwise. Replace this to give partial credit, e.g. for close labels.
    """
    # δ(a, b), i.e. F ⊆ F′: the predicted label belongs to the gold label set.
    return 1 if pred.label in gold.labels else 0


class PartialScoreType(Enum):
    """Denominator used when scoring the overlap of two intervals.

    JACCARD_INDEX divides by the union, PRED_SIZE by the predicted interval's
    length, and GOLD_SIZE by the gold interval's length.
    """

    JACCARD_INDEX = 1
    PRED_SIZE = 2
    GOLD_SIZE = 3


def partial_overlap_score(
    pred: PredictionSpan,
    gold: GroundTruthSpan,
    partial_score_type: PartialScoreType = PartialScoreType.JACCARD_INDEX,
) -> float:
    """Return the overlap score of the predicted and gold intervals.

    Corresponds to the fraction in the C function of the paper. Intervals are
    ``[start, end]`` pairs whose length is ``end - start``.

    Args:
        pred: predicted span; its ``interval`` is used.
        gold: gold span; its ``interval`` is used.
        partial_score_type:
            JACCARD_INDEX: |pred ∩ gold| / |pred ∪ gold|
            PRED_SIZE:     |pred ∩ gold| / |pred|
            GOLD_SIZE:     |pred ∩ gold| / |gold|

    Returns:
        The ratio, or 0 when its denominator is 0.

    Raises:
        ValueError: for an unknown ``partial_score_type`` (previously this
            silently returned None).
    """
    pred_start, pred_end = pred.interval
    gold_start, gold_end = gold.interval
    pred_length = pred_end - pred_start
    gold_length = gold_end - gold_start
    intersection_length = max(
        0, min(pred_end, gold_end) - max(pred_start, gold_start)
    )

    if partial_score_type == PartialScoreType.JACCARD_INDEX:
        # |A ∪ B| = |A| + |B| - |A ∩ B|
        denominator = pred_length + gold_length - intersection_length
    elif partial_score_type == PartialScoreType.PRED_SIZE:
        denominator = pred_length
    elif partial_score_type == PartialScoreType.GOLD_SIZE:
        denominator = gold_length
    else:
        raise ValueError(f"Unknown partial score type: {partial_score_type!r}")

    return intersection_length / denominator if denominator else 0


def text_full_task_precision(
    pred_corpus: AnnotatedText, gold_corpus: AnnotatedText
) -> float:
    """Return the precision score of a prediction (spans + labels).

    The gold label assignments come from
    ``generate_label_conjunctions(ScoreType.PRECISION)``. For each assignment:

    * every prediction is credited with its best
      ``label_score * partial_overlap_score`` (overlap relative to the
      prediction's length) over the gold spans;
    * the sum is divided by the number of predictions whose label is not
      ``NOTHING_LABEL``.

    The best score over all assignments is returned.
    """
    if not pred_corpus.spans and not gold_corpus.spans:
        # there was nothing to predict and nothing was found
        return 1

    # Predictions labelled NOTHING_LABEL do not count as predictions. This does
    # not depend on the gold assignment, so it is computed once.
    p_denominator = sum(1 for s in pred_corpus.spans if s.label != NOTHING_LABEL)

    disjunct = gold_corpus.generate_label_conjunctions(ScoreType.PRECISION)
    if disjunct == [[]]:
        # There was nothing to predict: perfect only if nothing was predicted.
        return 0 if p_denominator else 1

    precision_scores = []
    for y_trues in disjunct:
        if p_denominator == 0:
            # Nothing was predicted: perfect only if this assignment expects nothing.
            precision_scores.append(0 if y_trues else 1)
            continue

        p_sum = 0
        for pred_span in pred_corpus.spans:
            # Best credit this prediction earns against any gold span.
            best = 0
            for gold_span in y_trues:
                ls = label_score(pred_span, gold_span)
                if ls == 0:
                    continue
                p_pos = partial_overlap_score(
                    pred_span, gold_span, PartialScoreType.PRED_SIZE
                )
                best = max(best, p_pos * ls)
            p_sum += best

        precision_score = p_sum / p_denominator
        if precision_score > 1:
            raise Exception("This should not happen: precision > 1")
        precision_scores.append(precision_score)
    return max(precision_scores)


def text_full_task_recall(
    pred_corpus: AnnotatedText, gold_corpus: AnnotatedText
) -> float:
    """Return the recall score of a prediction (spans + labels).

    The gold label assignments come from
    ``generate_label_conjunctions(ScoreType.RECALL)``. For each assignment:

    * every gold span is credited with its best
      ``label_score * partial_overlap_score`` (overlap relative to the gold
      span's length) over the predictions;
    * the sum is divided by the number of gold spans whose labels contain
      neither ``None`` nor ``NOTHING_LABEL``.

    The best score over all assignments is returned.
    """
    if not pred_corpus.spans and not gold_corpus.spans:
        # there was nothing to predict and nothing was found
        return 1

    # Predictions labelled NOTHING_LABEL do not count as predictions. The same
    # rule is used whenever there is nothing to recall (below).
    predicted_something = any(s.label != NOTHING_LABEL for s in pred_corpus.spans)

    disjunct = gold_corpus.generate_label_conjunctions(ScoreType.RECALL)
    if disjunct == [[]]:
        # There was nothing to predict: perfect only if nothing was predicted.
        return 0 if predicted_something else 1

    recall_scores = []
    for y_trues in disjunct:
        # we don't count spans wo fallacies:
        r_denominator = sum(
            1
            for s in y_trues
            if None not in s.labels and NOTHING_LABEL not in s.labels
        )
        if r_denominator == 0:
            # Nothing to recall in this assignment: perfect only if no real
            # fallacy was predicted.
            recall_scores.append(0 if predicted_something else 1)
            continue

        r_sum = 0
        for gold_span in y_trues:
            # Best credit this gold span receives from any prediction.
            best = 0
            for pred_span in pred_corpus.spans:
                ls = label_score(pred_span, gold_span)
                if ls == 0:
                    continue
                r_pos = partial_overlap_score(
                    pred_span, gold_span, PartialScoreType.GOLD_SIZE
                )
                best = max(best, r_pos * ls)
            r_sum += best

        recall_scores.append(r_sum / r_denominator)
    return max(recall_scores)


def text_full_task_p_r_f1(
    pred_corpus: AnnotatedText, gold_corpus: AnnotatedText
) -> Tuple[float, float, float]:
    """Score spans and labels together.

    Returns ``(precision, recall, f1)``. When neither side has any span the
    result is ``(1, 1, 1)``. F1 is the harmonic mean of precision and recall,
    or 0 when both are 0.
    """
    nothing_to_compare = not (pred_corpus.spans or gold_corpus.spans)
    if nothing_to_compare:
        return 1, 1, 1

    precision = text_full_task_precision(pred_corpus, gold_corpus)
    recall = text_full_task_recall(pred_corpus, gold_corpus)
    denominator = precision + recall
    f1 = 0 if not denominator else 2 * precision * recall / denominator
    return precision, recall, f1


def text_label_only_precision(
    pred_text: AnnotatedText, gold_text: AnnotatedText, nb_classes: NbClasses
) -> float:
    """Compute the precision score for labels only (span positions are ignored).

    Predicted and gold labels are reduced to sets of fallacy IDs with
    ``NOTHING_LABEL < id <= nb_classes.value``; ``None``, ``NOTHING_LABEL``
    and out-of-range IDs are dropped. For each gold assignment from
    ``generate_label_conjunctions(ScoreType.PRECISION)``, precision is
    ``|pred ∩ gold| / |pred|``. When no label overlaps, the score is 1 if both
    sets are empty and 0 otherwise. The best score over all assignments is
    returned.
    """
    if not pred_text.spans and not gold_text.spans:
        # there was nothing to predict and nothing was found
        return 1

    def is_fallacy_id(label) -> bool:
        # Same lower bound as generate_label_conjunctions (label > NOTHING_LABEL).
        return isinstance(label, int) and NOTHING_LABEL < label <= nb_classes.value

    y_pred = {s.label for s in pred_text.spans if is_fallacy_id(s.label)}

    precision_results = []
    for assignment in gold_text.generate_label_conjunctions(ScoreType.PRECISION):
        # Each span of an assignment carries exactly one label.
        y_true = {
            label for s in assignment for label in s.labels if is_fallacy_id(label)
        }
        tp = len(y_pred & y_true)
        if tp == 0:
            precision_results.append(1 if not y_pred and not y_true else 0)
        else:
            # tp + fp == |y_pred|
            precision_results.append(tp / len(y_pred))
    return max(precision_results)


def text_label_only_recall(
    pred_text: AnnotatedText, gold_text: AnnotatedText, nb_classes: NbClasses
) -> float:
    """Compute the recall score for labels only (span positions are ignored).

    Predicted and gold labels are reduced to sets of fallacy IDs with
    ``NOTHING_LABEL < id <= nb_classes.value``; ``None``, ``NOTHING_LABEL``
    and out-of-range IDs are dropped. For each gold assignment from
    ``generate_label_conjunctions(ScoreType.RECALL)``, recall is
    ``|pred ∩ gold| / |gold|``. When no label overlaps, the score is 1 if both
    sets are empty and 0 otherwise. The best score over all assignments is
    returned.
    """
    if not pred_text.spans and not gold_text.spans:
        # there was nothing to predict and nothing was found
        return 1

    def is_fallacy_id(label) -> bool:
        # Same lower bound as generate_label_conjunctions (label > NOTHING_LABEL).
        return isinstance(label, int) and NOTHING_LABEL < label <= nb_classes.value

    y_pred = {s.label for s in pred_text.spans if is_fallacy_id(s.label)}

    recall_results = []
    for assignment in gold_text.generate_label_conjunctions(ScoreType.RECALL):
        # Each span of an assignment carries exactly one label.
        y_true = {
            label for s in assignment for label in s.labels if is_fallacy_id(label)
        }
        tp = len(y_pred & y_true)
        if tp == 0:
            recall_results.append(1 if not y_pred and not y_true else 0)
        else:
            # tp + fn == |y_true|
            recall_results.append(tp / len(y_true))
    return max(recall_results)


def text_label_only_p_r_f1(
    pred_text: AnnotatedText, gold_text: AnnotatedText, nb_classes: NbClasses
) -> Tuple[float, float, float]:
    """Score labels only (span positions ignored).

    Returns ``(precision, recall, f1)``; ``(1, 1, 1)`` when neither side has
    any span. F1 is the harmonic mean of precision and recall, or 0 when both
    are 0.
    """
    if pred_text.spans or gold_text.spans:
        precision = text_label_only_precision(pred_text, gold_text, nb_classes)
        recall = text_label_only_recall(pred_text, gold_text, nb_classes)
        total = precision + recall
        f1 = 2 * precision * recall / total if total else 0
        return precision, recall, f1
    # there was nothing to predict and nothing was found
    return 1, 1, 1


## Collect everything together for evaluation

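The output (predictions) .json file should contain a list of dictionaries in the following format:

```
[
  {"text 1": [{"fallacy type": [0, 1]}, {"fallacy type": [2, 3]}]},
  {"text 2": [{...}, {...}, {...}]},
  ...
  {"text n": [{...}, {...}]}
]
```

Each `[start, end]` pair (e.g. `[0, 1]`) is the span of one fallacy, given as character offsets into the text with `end` excluded. In the example files below, `[0, 78]` covers the whole 78-character text. Fallacy names are matched case-insensitively against the keys of `LABEL_MAP`.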


Initially, the gold labels come in the following format, with the texts stored separately. Convert them to the JSON format above before running the evaluation:



```
[[107, 261, 'guilt by association'], [107, 338, 'causal oversimplification'], [158, 338, 'ad populum'], [158, 338, 'nothing'], [391, 542, 'circular reasoning']]
```

In [46]:
def collect_predictions(
    json_data_preds: List[Dict[str, List[Dict[str, List[int]]]]],
) -> AnnotatedText:
    """Convert parsed prediction JSON into a single ``AnnotatedText``.

    Args:
        json_data_preds: list of ``{text: [{fallacy_name: [start, end]}, ...]}``
            dictionaries. Fallacy names are matched case-insensitively against
            ``LABEL_MAP``; ``"nothing"`` maps to ``NOTHING_LABEL``.

    Returns:
        An ``AnnotatedText`` with one ``PredictionSpan`` per entry, in input
        order. Each span's ``span`` attribute is the whole text it belongs to.

    Raises:
        ValueError: if a fallacy name is not in the label set. Previously an
            error string was returned instead, which failed later with an
            unrelated AttributeError.
    """
    all_spans = []
    for item in json_data_preds:
        for text, fallacies in item.items():
            for fallacy in fallacies:
                for fallacy_type, interval in fallacy.items():
                    name = fallacy_type.lower()
                    if name in LABEL_MAP:
                        label = LABEL_MAP[name]
                    elif name == "nothing":
                        label = NOTHING_LABEL
                    else:
                        raise ValueError(
                            f"No fallacy {fallacy_type} exists in the label set."
                        )
                    all_spans.append(PredictionSpan(text, label, interval))

    return AnnotatedText(all_spans)

In [47]:
def collect_golds(
    json_data_golds: List[Dict[str, List[Dict[str, List[int]]]]],
) -> AnnotatedText:
    """Convert parsed gold JSON into a single ``AnnotatedText``.

    Args:
        json_data_golds: list of ``{text: [{fallacy_name: [start, end]}, ...]}``
            dictionaries. Fallacy names are matched case-insensitively against
            ``LABEL_MAP``; ``"nothing"`` maps to ``NOTHING_LABEL``.

    Returns:
        An ``AnnotatedText`` of ``GroundTruthSpan`` objects in order of first
        appearance. Entries with the same text and the same interval are merged
        into one span whose ``labels`` set holds all of their labels. For
        example, ``"ad populum"`` and ``"nothing"`` on one interval become one
        span with two labels. ``generate_label_conjunctions`` then treats them
        as alternatives rather than as separate spans that must all be found.

    Raises:
        ValueError: if a fallacy name is not in the label set. Previously an
            error string was returned instead, which failed later with an
            unrelated AttributeError.
    """
    merged = {}
    for item in json_data_golds:
        for text, fallacies in item.items():
            for fallacy in fallacies:
                for fallacy_type, interval in fallacy.items():
                    name = fallacy_type.lower()
                    if name in LABEL_MAP:
                        label = LABEL_MAP[name]
                    elif name == "nothing":
                        label = NOTHING_LABEL
                    else:
                        raise ValueError(
                            f"No fallacy {fallacy_type} exists in the label set."
                        )
                    key = (text, tuple(interval))
                    if key in merged:
                        merged[key].labels.add(label)
                    else:
                        merged[key] = GroundTruthSpan(text, {label}, interval)

    return AnnotatedText(list(merged.values()))

In [51]:
def run_evaluation(json_predictions_path: str,
                   json_golds_path: str,
                   labels_only: bool = False,
                   per_text: bool = False):
    """Score a predictions JSON file against a gold JSON file and print the result.

    Both files hold a list of ``{text: [{fallacy_name: [start, end]}, ...]}``
    dictionaries.

    Args:
        json_predictions_path: path of the predictions file.
        json_golds_path: path of the gold file.
        labels_only: if True, only labels are scored (text_label_only_p_r_f1);
            otherwise spans and labels are scored (text_full_task_p_r_f1).
        per_text: if False (default), the spans of all texts are pooled into
            one AnnotatedText and scored at once. The scoring functions compare
            only labels and intervals, so a prediction can then be credited
            against a gold span of a *different* text whose interval overlaps.
            If True, each text is scored on its own and precision, recall and
            F1 are each averaged over the texts.
            NOTE(review): confirm this averaging matches the MAFALDA paper.

    Returns:
        The printed ``(precision, recall, f1)`` tuple.

    Raises:
        ValueError: if a file uses a fallacy name that is not in the label set.
    """
    with open(json_predictions_path, 'r', encoding='utf-8') as j:
        predictions = json.load(j)
    with open(json_golds_path, 'r', encoding='utf-8') as j:
        golds = json.load(j)

    def collect(collector, data):
        collected = collector(data)
        # Guard against a collector reporting an unknown label by returning a
        # message string instead of an AnnotatedText.
        if isinstance(collected, str):
            raise ValueError(collected)
        return collected

    def group_by_text(data):
        # Merge all entries of the same text, keeping first-seen order.
        grouped = {}
        for item in data:
            for text, fallacies in item.items():
                grouped.setdefault(text, []).extend(fallacies)
        return grouped

    if labels_only:
        title = "\n---LABELS ONLY:---\n"

        def score(pred, gold):
            # Testing text_label_only_p_r_f1: not considering spans
            return text_label_only_p_r_f1(pred, gold, NbClasses.LVL_2)
    else:
        title = "\n---SPANS AND LABELS:---\n"
        # Testing text_full_task_p_r_f1: considering spans
        score = text_full_task_p_r_f1

    per_text_scores = []
    if per_text:
        preds_by_text = group_by_text(predictions)
        golds_by_text = group_by_text(golds)
        for text in dict.fromkeys([*preds_by_text, *golds_by_text]):
            per_text_scores.append(score(
                collect(collect_predictions, [{text: preds_by_text.get(text, [])}]),
                collect(collect_golds, [{text: golds_by_text.get(text, [])}]),
            ))

    if per_text_scores:
        n_texts = len(per_text_scores)
        p, r, f1 = (sum(values) / n_texts for values in zip(*per_text_scores))
    else:
        p, r, f1 = score(collect(collect_predictions, predictions),
                         collect(collect_golds, golds))

    print(title)
    print(f"Precision = {p}")
    print(f"Recall = {r}")
    print(f"F1 score = {f1}")
    return p, r, f1

## Run the evaluation

Example contents of the `example_preds.json` file (fallacy names are matched case-insensitively, so `"Appeal TO Fear"` is accepted):


```
[
    {"TITLE: There is a difference between a'smurf' and an'alt'. Please learn it and stop using them interchangeably. POST: Someone once told me they have an 'alt' cause their main account was too high of rank to play with their friends. It's exactly the same as smurfing.":[{"false analogy": [0,12]},{"Appeal TO Fear":[12,29]}]},
    {"America is the best place to live, because it's better than any other country.": [{"circular reasoning": [0,78]}]}
]
```

Example contents of the `example_golds.json` file:


```
[
    {"TITLE: There is a difference between a'smurf' and an'alt'. Please learn it and stop using them interchangeably. POST: Someone once told me they have an 'alt' cause their main account was too high of rank to play with their friends. It's exactly the same as smurfing.":[{"appeal to fear":[12,29]}]},
    {"America is the best place to live, because it's better than any other country.": [{"Circular Reasoning": [0,78]}]}
]
```



In [52]:
# NOTE(review): these paths presumably point at a Google Colab `/content`
# directory; adjust them when running elsewhere.
run_evaluation(
    json_predictions_path="/content/example_preds.json",
    json_golds_path="/content/example_golds.json",
)


---SPANS AND LABELS:---

Precision = 0.6666666666666666
Recall = 1.0
F1 score = 0.8
