In [1]:
import pandas as pd
import numpy as np
import polars as pl
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score, RocCurveDisplay

# Apply seaborn's default theme to all subsequent matplotlib figures.
# sns.set() is a legacy alias of set_theme(), which seaborn documents as the preferred interface.
sns.set_theme()

In [2]:
def pfbeta(labels, predictions, beta=1):
    """Probabilistic F-beta score.

    Each prediction is clipped to [0, 1] and used as a soft vote. Its value is
    added to the soft true-positive mass when the matching label is truthy, and
    to the soft false-positive mass otherwise. Precision and recall are built
    from these soft counts and combined with the usual F-beta formula.

    Args:
        labels: sequence/array of ground-truth labels; truthy values are positives.
        predictions: sequence/array of scores, same length as ``labels``.
        beta: weight of recall relative to precision (``beta=1`` gives pF1).

    Returns:
        The probabilistic F-beta score as a float. Returns 0.0 in three cases:
        precision or recall is 0; the score is undefined because there are no
        positive labels; or every clipped prediction is 0.

    Raises:
        ValueError: if ``labels`` and ``predictions`` have different lengths.
    """
    y_true = np.asarray(labels).astype(bool)
    y_prob = np.clip(np.asarray(predictions, dtype=float), 0, 1)
    if y_true.shape != y_prob.shape:
        # Previously extra predictions were silently ignored; make the mismatch explicit.
        raise ValueError(
            f"labels and predictions must have the same shape, got {y_true.shape} and {y_prob.shape}"
        )

    y_true_count = int(y_true.sum())
    ctp = y_prob[y_true].sum()   # soft true positives
    cfp = y_prob[~y_true].sum()  # soft false positives

    # Guard the 0/0 cases (no positives, or all-zero predictions) instead of raising.
    if y_true_count == 0 or ctp + cfp == 0:
        return 0.0

    beta_squared = beta * beta
    c_precision = ctp / (ctp + cfp)
    c_recall = ctp / y_true_count
    # This comparison is also False for NaN, so NaN predictions yield 0.0 as before.
    if c_precision > 0 and c_recall > 0:
        result = (1 + beta_squared) * (c_precision * c_recall) / (beta_squared * c_precision + c_recall)
        return float(result)
    return 0.0

In [3]:
def get_part_metrics(df: pl.DataFrame, threshold=0.3) -> dict:
    """Compute thresholded and threshold-free metrics for one slice of predictions.

    Args:
        df: polars DataFrame with a ``labels`` column (ground truth) and a
            ``preds`` column (model scores).
        threshold: a score strictly greater than this counts as a positive
            prediction for accuracy / precision / recall / F1.

    Returns:
        dict with these keys:

        - ``accuracy``, ``precision``, ``recall``, ``f1``: computed on the
          thresholded predictions.
        - ``pf1``: probabilistic F1 on the raw scores, independent of ``threshold``.
        - ``roc_auc``: ROC AUC on the raw scores.
    """
    # Convert each column to numpy once rather than once per metric.
    y_true = df["labels"].to_numpy()
    y_score = df["preds"].to_numpy()
    y_pred = (df["preds"] > threshold).to_numpy()

    return {
        # binary metrics using the threshold
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred),
        "recall": recall_score(y_true, y_pred),
        "f1": f1_score(y_true, y_pred),
        # probabilistic F1 (doesn't depend on the threshold)
        "pf1": pfbeta(y_true, y_score),
        # NOTE(review): roc_auc_score raises ValueError if the slice contains only one class.
        "roc_auc": roc_auc_score(y_true, y_score),
    }


# Evaluation groups: group name -> domain ids kept for that group.
# Domain 0 is part of every generator-specific group; presumably it holds the
# real (non-generated) images — TODO confirm against the dataset definition.
DEFAULT_DOMAIN_GROUPS = {
    "all": [0, 1, 2, 3, 4],
    "StableDiffusion": [0, 1],
    "Midjourney": [0, 4],
    "Dalle2": [0, 2],
    "Dalle3": [0, 3],
}


def get_all_metrics(df: pl.DataFrame, threshold=0.3, groups=None) -> pd.DataFrame:
    """Compute ``get_part_metrics`` separately for each named group of domains.

    Args:
        df: polars DataFrame with ``domains``, ``labels`` and ``preds`` columns.
        threshold: decision threshold forwarded to ``get_part_metrics``.
        groups: optional mapping of group name -> list of domain ids to keep.
            Defaults to ``DEFAULT_DOMAIN_GROUPS``.

    Returns:
        pandas DataFrame with one row per group. It holds the metric columns
        from ``get_part_metrics``, followed by a ``group`` column with the
        group name.
    """
    if groups is None:
        groups = DEFAULT_DOMAIN_GROUPS

    rows = []
    # A single mapping keeps each group name tied to its domain ids
    # (no parallel lists that can drift out of sync).
    for name, domain_ids in groups.items():
        subset = df.filter(pl.col("domains").is_in(domain_ids))
        metrics = get_part_metrics(subset, threshold=threshold)
        metrics["group"] = name
        rows.append(metrics)

    return pd.DataFrame(rows)

In [4]:
# Load the saved predictions of image-classifier-1 and score every group at threshold 0.5.
df1 = pl.read_csv("outputs/preds-image-classifier-1.csv")
metrics_df1 = get_all_metrics(threshold=0.5, df=df1)

In [5]:
# Per-group metrics (one row per group) for preds-image-classifier-1.csv at threshold 0.5.
metrics_df1

Unnamed: 0,accuracy,precision,recall,f1,pf1,roc_auc,group
0,0.922883,0.905793,0.885671,0.895619,0.862582,0.978179,all
1,0.942132,0.763441,0.926759,0.837209,0.79686,0.985916,StableDiffusion
2,0.939611,0.751802,0.909746,0.823267,0.77424,0.981999,Midjourney
3,0.931319,0.636029,0.814597,0.714323,0.648632,0.965689,Dalle2
4,0.935942,0.617021,0.848404,0.714446,0.651111,0.971403,Dalle3


In [6]:
# Load the saved predictions of image-classifier-14 and score every group at threshold 0.5.
df14 = pl.read_csv("outputs/preds-image-classifier-14.csv")
metrics_df14 = get_all_metrics(threshold=0.5, df=df14)

In [7]:
# Per-group metrics (one row per group) for preds-image-classifier-14.csv at threshold 0.5.
metrics_df14

Unnamed: 0,accuracy,precision,recall,f1,pf1,roc_auc,group
0,0.948212,0.949066,0.910212,0.929233,0.905132,0.989983,all
1,0.9621,0.857482,0.916244,0.88589,0.847167,0.990723,StableDiffusion
2,0.967343,0.856164,0.948047,0.899766,0.860124,0.993656,Midjourney
3,0.956384,0.771242,0.833431,0.801132,0.749605,0.982562,Dalle2
4,0.966024,0.767055,0.919548,0.836408,0.7782,0.99057,Dalle3


In [8]:
# Load the saved predictions of image-classifier-142 and score every group at threshold 0.5.
df142 = pl.read_csv("outputs/preds-image-classifier-142.csv")
metrics_df142 = get_all_metrics(threshold=0.5, df=df142)

In [9]:
# Per-group metrics (one row per group) for preds-image-classifier-142.csv at threshold 0.5.
metrics_df142

Unnamed: 0,accuracy,precision,recall,f1,pf1,roc_auc,group
0,0.959943,0.931624,0.96348,0.947284,0.925286,0.993246,all
1,0.957618,0.812693,0.95649,0.878748,0.841169,0.991817,StableDiffusion
2,0.961245,0.809524,0.979901,0.886601,0.849104,0.995908,Midjourney
3,0.957067,0.726496,0.950559,0.823559,0.769529,0.991425,Dalle2
4,0.958237,0.704136,0.962101,0.81315,0.757835,0.993257,Dalle3


In [10]:
# Load the saved predictions of image-classifier-1423 and score every group at threshold 0.5.
df1423 = pl.read_csv("outputs/preds-image-classifier-1423.csv")
metrics_df1423 = get_all_metrics(threshold=0.5, df=df1423)

In [11]:
# Per-group metrics (one row per group) for preds-image-classifier-1423.csv at threshold 0.5.
metrics_df1423

Unnamed: 0,accuracy,precision,recall,f1,pf1,roc_auc,group
0,0.965634,0.961898,0.945452,0.953604,0.935444,0.99484,all
1,0.968388,0.887373,0.919869,0.903329,0.865455,0.992069,StableDiffusion
2,0.974261,0.8867,0.955631,0.919876,0.885934,0.99582,Midjourney
3,0.974997,0.834021,0.952325,0.889255,0.847135,0.99567,Dalle2
4,0.976638,0.818694,0.966755,0.886585,0.842263,0.997264,Dalle3
