### Set up an evaluation script with a random classifier #8

Create a script that:

- Loads the dataset from `Initialize a dataset to evaluate the detection pipeline` #6
- Loads predictions from a random classifier that assigns each claim a category from the taxonomy defined in `Define our contrarian claims taxonomy` #7
- Generates a text classification report
- Generates a confusion matrix plot

In [1]:
import pandas as pd
from pydantic import BaseModel, ValidationError, conlist
from typing import List
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
class EvaluateClassifier:
    """
    Evaluate a text classifier against a labelled benchmark of claims.

    The benchmark must provide a ``claim`` column (the input text) and a
    ``label`` column (the expected category). The wrapped classifier only
    needs a ``predict`` method that accepts a list of strings and returns one
    prediction per string.
    """

    def __init__(self, classifier):
        """
        Initialize the class with a fitted classifier passed as an argument.

        :param classifier: A classification model whose ``predict`` method takes
            a list of claims (strings) and returns one label per claim
        :raises ValueError: If ``classifier.predict`` fails on a list of strings
        """
        self.classifier = classifier
        self._validate_classifier_input()

    def _validate_classifier_input(self):
        """
        Validate that the classifier's `predict` method can accept a list of strings as input.

        :raises ValueError: If calling ``predict`` on a small sample raises anything
        """
        # Sample input to test if the classifier accepts a list of strings
        sample_input = ["sample claim 1", "sample claim 2"]
        try:
            self.classifier.predict(sample_input)
        except Exception as e:
            # Deliberately broad: any failure here (missing method, wrong input
            # type, ...) means the classifier cannot be evaluated on text.
            raise ValueError("The classifier's `predict` method must accept a list of strings as input. "
                             "Ensure the classifier is compatible with text data.") from e

    def load_data(self, file_path):
        """
        Load the benchmark from a CSV file (or an Excel workbook).

        CSV input is read with :func:`pandas.read_csv`. Paths ending in
        ``.xls``/``.xlsx`` are read from their ``benchmark`` sheet with
        :func:`pandas.read_excel`.

        :param file_path: Path (or file-like object) of the benchmark file
        :return: DataFrame with the loaded benchmark
        :raises ValueError: If the ``claim`` or ``label`` column is missing
        """
        # ``read_csv`` has no ``sheet_name`` parameter (passing one raises
        # TypeError); sheets only exist in Excel workbooks.
        if str(file_path).lower().endswith((".xls", ".xlsx")):
            benchmark = pd.read_excel(file_path, sheet_name="benchmark")
        else:
            benchmark = pd.read_csv(file_path)
        if not {"claim", "label"}.issubset(benchmark.columns):
            raise ValueError("Columns 'claim' and 'label' must be present in the benchmark")
        return benchmark

    def predict(self, benchmark):
        """
        Predict classes for every claim in the benchmark.

        :param benchmark: DataFrame with a ``claim`` column
        :return: Classifier predictions, one per row of ``benchmark``
        :raises ValueError: If the classifier returns a different number of predictions
        """
        # Pass a plain list: that is exactly the input type validated in __init__.
        claims = benchmark["claim"].tolist()
        y_pred = self.classifier.predict(claims)
        if len(y_pred) != len(claims):
            raise ValueError(
                f"The classifier returned {len(y_pred)} predictions for {len(claims)} claims."
            )
        return y_pred

    @staticmethod
    def _sorted_labels(y, y_pred):
        """Return the sorted union of true and predicted labels."""
        # Include predicted-only classes too, so no prediction is silently
        # dropped from the confusion matrix.
        labels = set(y) | set(y_pred)
        try:
            return sorted(labels)
        except TypeError:
            # Mixed, mutually incomparable label types: fall back to a stable textual order.
            return sorted(labels, key=str)

    def generate_classification_report(self, benchmark, y_pred):
        """
        Generate and print a classification report.

        :param benchmark: Benchmark DataFrame with a ``label`` column, or the
            sequence of true labels itself
        :param y_pred: Predictions
        :return: DataFrame of the classification report (one row per class plus averages)
        """
        y_true = benchmark["label"] if isinstance(benchmark, pd.DataFrame) else benchmark
        # zero_division=0: a random classifier may never predict some classes;
        # report 0 for those instead of emitting UndefinedMetricWarning.
        report = classification_report(y_true, y_pred, output_dict=True, zero_division=0)
        report_df = pd.DataFrame(report).transpose()
        print("Classification Report:\n", report_df)
        return report_df

    def plot_confusion_matrix(self, y, y_pred, labels=None):
        """
        Generate and display a confusion matrix heatmap.

        :param y: True labels
        :param y_pred: Predictions
        :param labels: Ordered list of categories for the matrix axes; defaults
            to the sorted union of true and predicted labels
        """
        if labels is None:
            labels = self._sorted_labels(y, y_pred)
        cm = confusion_matrix(y, y_pred, labels=labels)
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels)
        plt.xlabel("Predicted Categories")
        plt.ylabel("True Categories")
        plt.title("Confusion Matrix")
        plt.show()

    def evaluate(self, file_path):
        """
        Run the complete evaluation: load the benchmark, predict with the classifier, report and plot.

        :param file_path: Path to the csv file containing the benchmark
        :return: DataFrame of the classification report
        """
        # Load the data
        benchmark = self.load_data(file_path)
        y = benchmark["label"]

        # Predict
        y_pred = self.predict(benchmark)
        labels = self._sorted_labels(y, y_pred)

        # Generate the classification report
        report_df = self.generate_classification_report(benchmark, y_pred)

        # Generate and display the confusion matrix
        self.plot_confusion_matrix(y, y_pred, labels)
        return report_df


### Test with a random classifier

In [3]:
# Benchmark dataset (see #6); this relative path is resolved against the
# notebook's current working directory.
benchmark_path = "/".join(("..", "data", "benchmark", "cards_sample_1000.csv"))