# snr-consistency / decision-margin consistency (dmc)

References

- Spearman, C. (1910). Correlation calculated from faulty data. British Journal of Psychology, 3(3), 271-295.
- Brown, W. (1910). Some experimental results in the correlation of mental abilities. British Journal of Psychology, 3(3), 296-322.
- Nunnally, J. C., & Bernstein, I. H. (1994). Psychometric Theory (3rd ed.). McGraw-Hill. (See discussions on reliability estimation and the Spearman-Brown formula.)


In [None]:
import os
import pandas as pd
import numpy as np
from glob import glob
from itertools import combinations
from scipy.stats import pearsonr
from collections import defaultdict
from tqdm import tqdm

def load_human_data(data_dir):
    """Load and concatenate every per-subject human raw-data CSV in ``data_dir``.

    Files matching ``*subject-*`` are read in sorted order. Each file's rows are
    sorted by ``imagename`` and two columns are added:

    - ``filename``: the last two ``_``-separated parts of ``imagename``
      (e.g. ``..._n04111531_14126.png`` -> ``n04111531_14126.png``); later cells
      use it as the per-image key across subjects;
    - ``is_correct``: 1.0 when ``object_response == category``, else 0.0.

    Returns:
        A single DataFrame (each file keeps its own row labels), or None when no
        file matches.
    """
    frames = []
    for file in sorted(glob(os.path.join(data_dir, "*subject-*"))):
        df_ = pd.read_csv(file).sort_values(by='imagename')
        df_['filename'] = df_.imagename.apply(lambda x: "_".join(x.split("_")[-2:]))
        df_['is_correct'] = (df_.object_response == df_.category).astype(float)
        frames.append(df_)

    if not frames:
        return None
    # One concat at the end; concatenating inside the loop copies all rows each time.
    return pd.concat(frames)

def load_model_data(data_dir, model_name):
    """Load the single raw-data CSV in ``data_dir`` whose file name contains ``model_name``.

    Rows are sorted by ``imagename`` (original row labels are kept), and the same
    two columns as in ``load_human_data`` are added: ``filename`` (last two
    ``_``-separated parts of ``imagename``) and ``is_correct`` (1.0 when
    ``object_response == category``, else 0.0).

    Raises:
        AssertionError: unless exactly one file matches.
    """
    pattern = os.path.join(data_dir, f"*{model_name}*")
    files = sorted(glob(pattern))
    assert len(files)==1, f"Expected one file, got {files}"
    (only_file,) = files

    frame = pd.read_csv(only_file).sort_values(by='imagename')
    frame['filename'] = frame.imagename.apply(lambda name: "_".join(name.split("_")[-2:]))
    frame['is_correct'] = frame.object_response.eq(frame.category).astype(float)
    return frame

def get_split_halves(N):
    """Return the distinct ways to split subject indices 0..N-1 into two halves.

    Half A takes every ``N // 2``-subset (``itertools.combinations`` order) and
    half B is its complement, so each split is ``(list_A, list_B)``. For even N
    the pair (A, B) reappears later as (B, A) — the complement of the k-th subset
    is the k-th from the end — so only the first half of the list is returned.
    """
    everyone = list(range(N))
    halves = []
    for combo in combinations(everyone, N // 2):
        first = list(combo)
        second = list(np.setdiff1d(everyone, first))
        # Sanity checks: the halves are disjoint and together cover all N subjects.
        assert len(np.setdiff1d(first, second)) == len(first), "oops"
        assert len(np.setdiff1d(second, first)) == len(second), "oops"
        assert (len(first) + len(second)) == N, f"oops, total should be {N}"
        halves.append((first, second))

    if N % 2 == 0:
        return halves[:len(halves) // 2]
    return halves

def error_consistency(expected_consistency, observed_consistency):
    """Return error consistency as measured by Cohen's kappa.

    kappa = (observed - expected) / (1 - expected), where both arguments are
    agreement rates in [0, 1]. Perfect observed agreement returns exactly 1.0
    (this also avoids 0/0 when expected agreement is 1).

    Raises:
        AssertionError: if either rate lies outside [0, 1] (or is NaN).
    """
    for rate in (expected_consistency, observed_consistency):
        assert 0.0 <= rate <= 1.0

    if observed_consistency == 1.0:
        return 1.0
    return (observed_consistency - expected_consistency) / (1.0 - expected_consistency)
    
def expected_consistency(df1, df2):
    """Return chance-level agreement of two observers plus their accuracies.

    With accuracies p1 and p2 (means of ``is_correct``), independent observers
    agree with probability p1*p2 (both right) + (1-p1)*(1-p2) (both wrong).

    Returns:
        (expected_agreement, p1, p2)
    """
    acc1 = df1.is_correct.mean()
    acc2 = df2.is_correct.mean()
    both_right = acc1 * acc2
    both_wrong = (1 - acc1) * (1 - acc2)
    return both_right + both_wrong, acc1, acc2

def observed_consistency(df1, df2):
    """Return the fraction of trials on which both observers were right, or both wrong.

    Rows are paired by index label (pandas raises if the labels differ), so
    callers sort both frames the same way and reset their indices first.
    """
    same_outcome = df1.is_correct == df2.is_correct
    n_trials = len(df1)
    return same_outcome.sum() / n_trials

def compute_error_consistency(df1, df2):
    """Return (expected agreement, observed agreement, Cohen's kappa) for two observers.

    Both frames must hold the same trials in the same row order (see
    observed_consistency).
    """
    expected_con = expected_consistency(df1, df2)[0]
    observed_con = observed_consistency(df1, df2)
    kappa = error_consistency(expected_con, observed_con)
    return expected_con, observed_con, kappa

def _sorted_trials(frame, subj, condition):
    """Return one subject's trials for one condition, sorted by filename with a fresh 0..n-1 index."""
    selected = frame[(frame.subj == subj) & (frame.condition == condition)]
    return selected.sort_values(by='filename').reset_index(drop=True)


def compute_human_vs_model_error_consistency(human, model):
    """Compute error consistency (Cohen's kappa) for every human subject x model subject x condition.

    Either frame may also hold model data (e.g. to compare two models). Both need
    ``subj``, ``condition``, ``filename`` and ``is_correct`` columns.

    Returns:
        DataFrame with one row per (human_subj, model_subj, condition): the two
        accuracies (``human_pct_correct``, ``model_pct_correct``) and the
        expected/observed agreement and error consistency.

    Raises:
        ValueError: if the frames contain different sets of conditions, or if a
            human/model pair did not see the same images in a condition. Trials
            are paired row-by-row after sorting by filename, so misaligned data
            would otherwise give meaningless numbers.
    """
    human_subjects = human.subj.unique()
    human_cond = human.condition.unique()
    model_subjects = model.subj.unique()
    model_cond = model.condition.unique()
    # Compare as sets: unique() keeps first-appearance order, which may legitimately differ.
    if set(human_cond) != set(model_cond):
        raise ValueError(
            "Human and Model data must contain the same conditions, "
            f"got {list(human_cond)} and {list(model_cond)}"
        )

    # Select and sort each subject's trials once rather than once per pair.
    human_trials = {(s, c): _sorted_trials(human, s, c) for s in human_subjects for c in human_cond}
    model_trials = {(s, c): _sorted_trials(model, s, c) for s in model_subjects for c in human_cond}

    results = defaultdict(list)
    for human_subj in human_subjects:
        for model_subj in model_subjects:
            for condition in human_cond:
                df1 = human_trials[(human_subj, condition)]
                df2 = model_trials[(model_subj, condition)]

                files1 = df1.filename.to_numpy()
                files2 = df2.filename.to_numpy()
                if len(files1) != len(files2) or (files1 != files2).any():
                    raise ValueError(
                        f"{human_subj} and {model_subj} did not see the same images "
                        f"in condition {condition!r}"
                    )

                expected_con, p1, p2 = expected_consistency(df1, df2)
                observed_con = observed_consistency(df1, df2)
                error_con = error_consistency(expected_con, observed_con)

                results['condition'].append(condition)
                results['human_subj'].append(human_subj)
                results['model_subj'].append(model_subj)

                results['human_pct_correct'].append(p1)
                results['model_pct_correct'].append(p2)

                results['expected_consistency'].append(expected_con)
                results['observed_consistency'].append(observed_con)
                results['error_consistency'].append(error_con)

    return pd.DataFrame(results)

## Colour vs. greyscale

There are only 4 human subjects in this experiment, which gives just 3 distinct split-halves for the reliability estimate below.

In [None]:
# Human raw-data CSVs for the colour-vs-greyscale experiment (requires $MODELVSHUMANDIR).
data_dir = os.path.join(os.environ['MODELVSHUMANDIR'], 'raw-data', 'colour')
data_dir

In [None]:
# All subjects' trials in one DataFrame, with the added filename / is_correct columns.
df = load_human_data(data_dir)
df

In [None]:
subjects = df.subj.unique()
subjects

In [None]:
# Spot-check: look up one image name in subject 1's data, subject 2's data and the full table.
sub1 = df[df.subj == subjects[0]]
sub1

In [None]:
sub1[sub1.imagename=="0001_cl_s01_cr_oven_40_n04111531_14126.png"]

In [None]:
sub2 = df[df.subj == subjects[1]]
sub2

In [None]:
sub2[sub2.imagename=="0001_cl_s01_cr_oven_40_n04111531_14126.png"]

In [None]:
df[df.imagename=="0001_cl_s01_cr_oven_40_n04111531_14126.png"]

In [None]:
# Distinct values of "how many rows share each filename" across the whole table.
df['filename'].value_counts().unique()

In [None]:
# Index splits of the subjects into two halves for the split-half reliability below.
splits = get_split_halves(len(subjects))
splits

In [None]:
# Split-half reliability of per-image human accuracy, per condition.
# For each split of the subjects into two halves: average accuracy per
# (condition, filename) within each half, correlate the two halves across images,
# then apply the Spearman-Brown correction 2r / (1 + r) to estimate the
# reliability of the whole group (printed as the noise ceiling).
results = defaultdict(list)
conditions = df.condition.unique()
groupby = ['condition', 'filename']
correlations = defaultdict(list)
for split_num, (half_a, half_b) in enumerate(splits):
    subA, subB = subjects[half_a], subjects[half_b]

    grouped_A = (df[df.subj.isin(subA)]
                 .groupby(groupby)['is_correct'].mean()
                 .reset_index()
                 .rename(columns={'is_correct': 'mean_correct_A'}))
    grouped_B = (df[df.subj.isin(subB)]
                 .groupby(groupby)['is_correct'].mean()
                 .reset_index()
                 .rename(columns={'is_correct': 'mean_correct_B'}))
    merged_df_sorted = (pd.merge(grouped_A, grouped_B, on=groupby, how='outer')
                        .sort_values(by='filename')
                        .reset_index(drop=True))

    for condition in conditions:
        cond_df = merged_df_sorted[merged_df_sorted.condition == condition]
        r = pearsonr(cond_df.mean_correct_A, cond_df.mean_correct_B)[0]
        correlations[condition].append(r)
        row = {'split_num': split_num, 'splitA': subA, 'splitB': subB,
               'condition': condition, 'pearsonr': r}
        for column, value in row.items():
            results[column].append(value)

for condition in conditions:
    split_rs = correlations[condition]
    avg_split_half_corr = np.mean(split_rs)
    noise_ceiling = np.mean([(2 * r) / (1 + r) for r in split_rs])
    print(condition, avg_split_half_corr, noise_ceiling)

In [None]:
res_df = pd.DataFrame(results)
res_df

In [None]:
res_summary = res_df.groupby(by='condition')['pearsonr'].mean().reset_index()
res_summary.rename(columns={'pearsonr': 'avg_split_half_corr'}, inplace=True)
r = res_summary['avg_split_half_corr']
res_summary['noise_ceiling'] = (2 * r) / (1 + r)
res_summary

In [None]:
results = defaultdict(list)
conditions = df.condition.unique()

# Human-vs-human error consistency (Cohen's kappa) for every unordered pair of
# subjects, separately per condition. Trials are paired row-by-row after sorting
# by filename, so both subjects must have seen exactly the same images.
num_subjects = len(subjects)
for condition in conditions:
    for idx1 in range(0, num_subjects - 1):
        sub1 = subjects[idx1]
        df1 = df[(df.subj == sub1) & (df.condition == condition)]
        df1 = df1.sort_values(by='filename').reset_index(drop=True)
        for idx2 in range(idx1 + 1, num_subjects):
            sub2 = subjects[idx2]
            df2 = df[(df.subj == sub2) & (df.condition == condition)]
            df2 = df2.sort_values(by='filename').reset_index(drop=True)
            # Compare as arrays so a length mismatch reports this message
            # instead of a pandas "identically-labeled" error.
            assert (len(df1) == len(df2)
                    and (df1.filename.to_numpy() == df2.filename.to_numpy()).all()), \
                "Dataframe filenames not aligned"

            expected_con, p1, p2 = expected_consistency(df1, df2)
            observed_con = observed_consistency(df1, df2)
            error_con = error_consistency(expected_con, observed_con)

            results['condition'].append(condition)
            results['subj1'].append(sub1)
            results['subj2'].append(sub2)
            results['pct_correct_subj1'].append(p1)
            results['pct_correct_subj2'].append(p2)
            results['expected_consistency'].append(expected_con)
            results['observed_consistency'].append(observed_con)
            results['error_consistency'].append(error_con)

error_con_df = pd.DataFrame(results)
error_con_df

In [None]:
error_con_summary = error_con_df.groupby(by='condition')['error_consistency'].mean().reset_index()
error_con_summary.rename(columns={'error_consistency': 'avg_error_consistency'}, inplace=True)
error_con_summary

In [None]:
df2

# Next we need to reproduce the "raw-data" for models

Target output: model raw-data files in the same format as the existing `colour_vit-b-16_session-1.csv`.

In [None]:
import math
import pandas as pd
from modelvshuman.helper import wordnet_functions as wnf
from modelvshuman.helper import human_categories as hc
from modelvshuman.datasets.decision_mappings import DecisionMapping
from pdb import set_trace

class ResultAgg:
    """Accumulate per-image model results as rows shaped like the human raw-data tables.

    ``dataset`` must provide ``decision_mapping`` and ``info_mapping``;
    ``info_mapping(path)`` must return ``(session_name, img_name, condition, category)``.
    The session number is parsed from the text after the last ``-`` in
    ``session_name`` (e.g. ``"session-1"`` -> ``1``).
    """

    def __init__(self, model_name, dataset):
        self.model_name = model_name
        self.dataset = dataset
        self.decision_mapping = dataset.decision_mapping
        self.info_mapping = dataset.info_mapping
        self.session_list = []  # session numbers, in order of first appearance
        self.results = []       # one dict per image; see as_dataframe()
        self.index = 0          # running trial number; restarts when a new session appears

    def print_batch(self, object_response, batch_targets, paths,
                    target_act, max_non_target_act, decision_margin_act,
                    target_prob, max_non_target_prob, decision_margin_prob):
        """Append one row per image of the batch to ``self.results`` (nothing is printed).

        Parameters:
        - object_response: per-image category rankings; element ``[0]`` is recorded
          as the response.
        - batch_targets: only iterated in lockstep with the other inputs; each row's
          category comes from ``info_mapping(path)``.
        - paths: file path of each image.
        - target_act: output activation (logit) for the target category.
        - max_non_target_act: max activation (logit) among non-target categories.
        - decision_margin_act: shortest signed distance from the point
          (target_act, max_non_target_act) to the decision boundary (the line y = x).
        - target_prob, max_non_target_prob, decision_margin_prob: the same three
          quantities computed from probabilities.
        """
        per_image = zip(object_response, batch_targets, paths)
        for i, (response, _target, path) in enumerate(per_image):
            session_name, img_name, condition, category = self.info_mapping(path)
            session_num = int(session_name.split("-")[-1])

            if session_num not in self.session_list:
                # First image of a new session: restart the trial counter.
                self.session_list.append(session_num)
                self.index = 0
            self.index += 1

            top_choice = response[0]
            self.results.append({
                "subj": self.model_name,
                "session": str(session_num),
                "trial": str(self.index),
                "rt": "NaN",  # no reaction time for models
                "object_response": top_choice,
                "category": category,
                "condition": condition,
                "imagename": img_name,
                "filename": "_".join(img_name.split("_")[-2:]),
                "is_correct": float(top_choice == category),
                "target_act": target_act[i],
                "max_non_target_act": max_non_target_act[i],
                "decision_margin_act": decision_margin_act[i],
                "target_prob": target_prob[i],
                "max_non_target_prob": max_non_target_prob[i],
                "decision_margin_prob": decision_margin_prob[i],
            })

    def as_dataframe(self):
        """Return all rows collected so far as a pandas DataFrame (one row per image)."""
        return pd.DataFrame(self.results)
    
class ImageNetProbabilitiesTo16ClassesMappingWithSortedProbs(DecisionMapping):
    """Map ImageNet probabilities to the 16 human categories, sorted by probability.

    Each category's probability is ``aggregation_function`` (``np.mean`` by
    default) over the probabilities of its ImageNet classes. ``__call__`` returns
    ``(sorted_categories, sorted_probs)``, ordered by decreasing aggregated
    probability within each row. No logit values are returned.
    """

    def __init__(self, aggregation_function=None):
        if aggregation_function is None:
            aggregation_function = np.mean
        self.aggregation_function = aggregation_function
        self.categories = hc.get_human_object_recognition_categories()

    def __call__(self, logits, probabilities):
        """Return (sorted_categories, sorted_probs), each of shape (batch_size, 16).

        ``logits`` is only used to check that its shape matches ``probabilities``.
        """
        self.check_input(probabilities)
        assert logits.shape == probabilities.shape, "Logits and probabilities must have the same shape."

        c = hc.HumanCategories()
        aggregated_class_probabilities = []
        for category in self.categories:
            indices = c.get_imagenet_indices_for_category(category)
            prob_values = np.take(probabilities, indices, axis=-1)
            aggregated_class_probabilities.append(self.aggregation_function(prob_values, axis=-1))

        # Stack to (16, batch_size) and transpose to (batch_size, 16).
        aggregated_class_probabilities = np.array(aggregated_class_probabilities).T

        # Column order per row, highest aggregated probability first.
        sorted_indices = np.flip(np.argsort(aggregated_class_probabilities, axis=-1), axis=-1)

        # Reorder categories and probabilities with the same per-row ordering.
        sorted_categories = np.take(self.categories, sorted_indices, axis=-1)
        sorted_probs = np.take_along_axis(aggregated_class_probabilities, sorted_indices, axis=-1)

        return sorted_categories, sorted_probs
    
class ImageNetActivationsTo16ClassesMappingWithSortedProbsAndLogits(DecisionMapping):
    """Map ImageNet logits to the 16 human categories, sorted by probability.

    Logits of each category's ImageNet classes are aggregated first
    (``aggregation_function``, ``np.mean`` by default), and ``softmax`` is then
    applied to the aggregated logits. This ensures the logits and probabilities
    agree on the top category. ``__call__`` returns ``(sorted_categories,
    sorted_probs, sorted_logits)``, ordered by decreasing probability per row.
    """

    def __init__(self, aggregation_function=None):
        self.aggregation_function = np.mean if aggregation_function is None else aggregation_function
        self.categories = hc.get_human_object_recognition_categories()

    def __call__(self, logits, softmax):
        """Return (sorted_categories, sorted_probs, sorted_logits), each (batch_size, 16).

        ``softmax`` is a callable applied to the aggregated (batch_size, 16) logits.
        """
        mapper = hc.HumanCategories()
        per_category = []
        for category in self.categories:
            members = mapper.get_imagenet_indices_for_category(category)
            per_category.append(self.aggregation_function(np.take(logits, members, axis=-1), axis=-1))

        # (16, batch_size) -> (batch_size, 16)
        class_logits = np.array(per_category).T
        class_probs = softmax(class_logits)

        # Per-row column order, most probable category first.
        order = np.flip(np.argsort(class_probs, axis=-1), axis=-1)
        return (
            np.take(self.categories, order, axis=-1),
            np.take_along_axis(class_probs, order, axis=-1),
            np.take_along_axis(class_logits, order, axis=-1),
        )
    
def compute_decision_margin_distance(act1, act2):
    """Return the signed decision margin of the point (act1, act2).

    This is ``signed_distance_to_unit_line(act1, act2)``: in this notebook ``act1``
    is the target's value and ``act2`` the largest non-target value, so the
    result is positive when the target wins.
    """
    margin = signed_distance_to_unit_line(act1, act2)
    return margin

def signed_distance_to_unit_line(xi, yi):
    '''Return the *signed* distance from the point (xi, yi) to the line y = x.

    The distance from (x0, y0) to the line a*x + b*y + c = 0 is
    abs(a * x0 + b * y0 + c) / sqrt(a^2 + b^2); for y = x (a=1, b=-1, c=0)
    that is abs(x0 - y0) / sqrt(2). The abs() is deliberately dropped here so
    that the sign tells which side of the line the point lies on:
    positive when xi > yi, negative when xi < yi, zero on the line.

    Works element-wise on NumPy arrays as well as on scalars.
    https://www.mathportal.org/calculators/analytic-geometry/line-point-distance.php
    '''
    return (xi - yi) / math.sqrt(2)

In [None]:
import torch
from modelvshuman.utils import load_dataset, load_model

def device():
    """Return the torch device to run on: the first CUDA GPU if available, otherwise the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    return torch.device("cpu")

def _default_num_workers():
    """Return how many CPUs this process may use (os.cpu_count() where sched_getaffinity is missing)."""
    if hasattr(os, "sched_getaffinity"):
        return len(os.sched_getaffinity(0))
    return os.cpu_count() or 1


def _target_margins(sorted_values, target_mask):
    """Return (target value, max non-target value, signed margin) for each row.

    ``sorted_values`` and ``target_mask`` are 2-D with one row per image, and
    ``target_mask`` marks the column holding that image's target category.

    Raises:
        ValueError: if some row does not mark exactly one column; boolean-mask
            indexing would otherwise silently misalign targets and rows.
    """
    if not (target_mask.sum(axis=1) == 1).all():
        raise ValueError(
            "Each image's target category must appear exactly once among the "
            "mapped categories."
        )
    target_value = sorted_values[target_mask]
    non_target = np.where(~target_mask, sorted_values, np.nan)
    max_non_target = np.nanmax(non_target, axis=1)
    margin = compute_decision_margin_distance(target_value, max_non_target)
    return target_value, max_non_target, margin


def run_analysis(model_name, dataset_name, num_workers=None, batch_size=32, use_logits=True):
    """Evaluate ``model_name`` on ``dataset_name`` and collect per-image decision margins.

    Parameters:
    - model_name, dataset_name: passed to ``load_model`` / ``load_dataset``.
    - num_workers: data-loader workers; None (default) means the number of CPUs
      available to this process.
    - batch_size: data-loader batch size.
    - use_logits: if True, call ``decision_mapping(logits, model.softmax)``; if
      False, pass the model's softmax output and assert that the mapped
      predictions match ``dataset.decision_mapping``.

    Returns:
        (dataset, ResultAgg): the dataset, whose metrics were updated with every
        batch, and the aggregated per-image rows.
    """
    if num_workers is None:
        num_workers = _default_num_workers()
    model, _framework = load_model(model_name)
    dataset = load_dataset(dataset_name, num_workers=num_workers, batch_size=batch_size)
    results_agg = ResultAgg(model_name, dataset)

    # NOTE(review): neither ...WithIndices class is defined in this notebook. The
    # similar classes defined above return (categories, probs, logits) and
    # (categories, probs), while the unpacking below expects
    # (categories, logits, probs) — TODO confirm against the classes in scope.
    if use_logits:
        decision_mapping = ImageNetActivationsTo16ClassesMappingWithIndices(aggregation_function=np.mean)
    else:
        decision_mapping = ImageNetProbabilitiesTo16ClassesMappingWithIndices(aggregation_function=np.mean)

    for metric in dataset.metrics:
        metric.reset()

    for images, target, paths in tqdm(dataset.loader):
        logits = model.forward_batch(images.to(device()))
        batch_targets = model.to_numpy(target) if isinstance(target, torch.Tensor) else target

        if use_logits:
            predictions, sorted_logits, sorted_probs = decision_mapping(logits, model.softmax)
        else:
            softmax_output = model.softmax(logits)
            predictions = dataset.decision_mapping(softmax_output)
            preds, sorted_logits, sorted_probs = decision_mapping(logits, softmax_output)
            assert (preds == predictions).all()

        # Per row, True in the column whose category equals the image's target.
        target_mask = predictions == np.array(batch_targets)[:, np.newaxis]
        target_act, max_non_target_act, decision_margin_act = _target_margins(sorted_logits, target_mask)
        target_prob, max_non_target_prob, decision_margin_prob = _target_margins(sorted_probs, target_mask)

        for metric in dataset.metrics:
            metric.update(predictions, batch_targets, paths)

        results_agg.print_batch(predictions,
                                batch_targets,
                                paths,
                                target_act,
                                max_non_target_act,
                                decision_margin_act,
                                target_prob,
                                max_non_target_prob,
                                decision_margin_prob)

    return dataset, results_agg

In [None]:
dataset_name = "colour"

model_name = "vit_b_16"
# model_name = "alexnet2023_baseline_pgd"
# model_name = "resnet18"
# model_name = "resnet50"
# model_name = "resnet50_l2_eps0_01"
dataset2, results_agg2 = run_analysis(model_name, dataset_name, num_workers=len(os.sched_getaffinity(0)), batch_size=32,
                                      use_logits=False)

for metric in dataset2.metrics:
    print(metric)

In [None]:
results2 = results_agg2.as_dataframe()
results2

In [None]:
model_name = "alexnet"
# model_name = "resnet34"
# model_name = "resnet50_trained_on_SIN"
# model_name = "resnet50_l2_eps0_03"
        # "alexnet2023_baseline_pgd",
        # "resnet50_l2_eps0",
        # "resnet50_l2_eps0_01",
        # "resnet50_l2_eps0_03",
        
dataset_name = "colour"
dataset3, results_agg3 = run_analysis(model_name, dataset_name, num_workers=len(os.sched_getaffinity(0)), batch_size=32,
                                      use_logits=False)

for metric in dataset3.metrics:
    print(metric)

In [None]:
results3 = results_agg3.as_dataframe()
results3

In [None]:
# Per-image correlations, separately for each condition: human accuracy vs model 1
# (results2) margin, human vs model 2 (results3) margin, and model 1 vs model 2.
# Each line prints the logit-margin ("act") then the probability-margin ("prob") r.
margin_cols = ['decision_margin_act', 'decision_margin_prob']
human_avg = df.groupby(by=['condition', 'filename'])['is_correct'].mean().reset_index()
# Average model margins per image as well, so every table has exactly one row per
# (condition, filename), as in the later single-condition cells.
model1_avg = results2.groupby(by=['condition', 'filename'])[margin_cols].mean().reset_index()
model2_avg = results3.groupby(by=['condition', 'filename'])[margin_cols].mean().reset_index()

for condition in df.condition.unique():
    human = human_avg[human_avg.condition == condition].sort_values(by='filename').reset_index(drop=True)
    model1 = model1_avg[model1_avg.condition == condition].sort_values(by='filename').reset_index(drop=True)
    model2 = model2_avg[model2_avg.condition == condition].sort_values(by='filename').reset_index(drop=True)

    # pearsonr pairs values by position: all three tables must list the same images in the same order.
    for label, frame in (("model1", model1), ("model2", model2)):
        assert (len(frame) == len(human)
                and (frame.filename.to_numpy() == human.filename.to_numpy()).all()), \
            f"{label} images not aligned with human images for condition {condition}"

    corr1_act = pearsonr(human.is_correct, model1.decision_margin_act)[0]
    corr1_prob = pearsonr(human.is_correct, model1.decision_margin_prob)[0]
    print(f"model1vshuman, {condition}, {corr1_act:3.3f}, {corr1_prob:3.3f}")

    corr2_act = pearsonr(human.is_correct, model2.decision_margin_act)[0]
    corr2_prob = pearsonr(human.is_correct, model2.decision_margin_prob)[0]
    print(f"model2vshuman, {condition}, {corr2_act:3.3f}, {corr2_prob:3.3f}")

    corr3_act = pearsonr(model1.decision_margin_act, model2.decision_margin_act)[0]
    corr3_prob = pearsonr(model1.decision_margin_prob, model2.decision_margin_prob)[0]
    print(f"modelvsmodel, {condition}, {corr3_act:3.3f}, {corr3_prob:3.3f}")


In [None]:
# Human subjects vs model 1 (results2): mean error consistency per condition.
err_con1 = compute_human_vs_model_error_consistency(df, results2)
err_con1.groupby(by=['condition'])['error_consistency'].mean()

In [None]:
# Mean human accuracy per condition, from the same table.
err_con1.groupby(by=['condition'])['human_pct_correct'].mean()

In [None]:
# Human subjects vs model 2 (results3).
err_con2 = compute_human_vs_model_error_consistency(df, results3)
err_con2.groupby(by=['condition'])['error_consistency'].mean()

In [None]:
# Model 1 vs model 2: results2 takes the "human" role here.
err_con3 = compute_human_vs_model_error_consistency(results2, results3)
err_con3.groupby(by=['condition'])['error_consistency'].mean()

In [None]:
# NOTE(review): both tuple items are identical; presumably meant to compare with
# results2.condition.unique() (condition order check) — TODO confirm.
df.condition.unique(), df.condition.unique()

In [None]:
%config InlineBackend.figure_format='retina'

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

results = [
    dict(Color='color', correlation=0.7908335570106865, label="human-vs-human"),
    dict(Color='grayscale', correlation=0.7445981495245118, label="human-vs-human"),
  
    dict(Color='color', correlation=0.071, label="vit-b-16 vs human"),
    dict(Color='grayscale', correlation=0.124, label="vit-b-16 vs human"),
    
    dict(Color='color', correlation=0.232, label="alexnet vs human"),
    dict(Color='grayscale', correlation=0.253, label="alexnet vs human"),
    
    dict(Color='color', correlation=0.664, label="model-vs-model"),
    dict(Color='grayscale', correlation=0.549, label="model-vs-model"),
]
res = pd.DataFrame(results)

plt.figure(figsize=(10, 6))  # Increased width for space on the right
ax = sns.lineplot(data=res, x="Color", y="correlation", hue="label")
plt.legend(title='Day of Week', bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
ax.set_ylim([0,1])

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Second set of hand-entered summary correlations for colour vs. grayscale images.
# NOTE(review): source of the model values (act vs prob margins, which run) is not
# recorded here — TODO confirm.
results = [
    dict(Color='color', correlation=0.7908335570106865, label="human-vs-human"),
    dict(Color='grayscale', correlation=0.7445981495245118, label="human-vs-human"),

    dict(Color='color', correlation=0.080, label="vit-b-16 vs human"),
    dict(Color='grayscale', correlation=0.156, label="vit-b-16 vs human"),

    dict(Color='color', correlation=0.337, label="alexnet vs human"),
    dict(Color='grayscale', correlation=0.366, label="alexnet vs human"),

    dict(Color='color', correlation=0.504, label="model-vs-model"),
    dict(Color='grayscale', correlation=0.526, label="model-vs-model"),
]
res = pd.DataFrame(results)

plt.figure(figsize=(10, 6))  # extra width leaves room for the legend on the right
ax = sns.lineplot(data=res, x="Color", y="correlation", hue="label")
ax.legend(title='Comparison', bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
ax.set_ylim([0,1])

In [None]:
# model1 and human below are left over from the last iteration of the correlation loop above.
model1

In [None]:
human_avg.filename

In [None]:
# Element-wise filename comparison (pandas pairs the rows by index label).
human.filename == model1.filename

In [None]:
# Stored ViT raw-data CSV, used as the reference for the reproduction checks further down.
# The file name uses the dashed spelling "vit-b-16".
model_name = "vit-b-16"
data_dir = os.path.join(os.environ['MODELVSHUMANDIR'], 'raw-data', 'colour')
results1 = load_model_data(data_dir, model_name)
results1

In [None]:
# The same model as loaded by load_model (underscore spelling).
model_name = "vit_b_16"
model, framework = load_model(model_name)
print(framework)
model

In [None]:
dataset_name = "colour"
dataset = load_dataset(dataset_name, num_workers=len(os.sched_getaffinity(0)), batch_size=32)
dataset

In [None]:
# Fresh aggregator for the manual ViT run below.
results_agg = ResultAgg(model_name, dataset)
results_agg

In [None]:
# Manual version of run_analysis(..., use_logits=True) for the model loaded above:
# for every batch, map logits to the 16 categories and record the target vs.
# strongest non-target logit/probability plus the signed decision margins.
for metric in dataset.metrics:
    metric.reset()

# NOTE(review): ImageNetActivationsTo16ClassesMappingWithIndices is not defined in this
# notebook; the similar class defined above returns (categories, probs, logits), whereas
# the unpacking below expects (categories, logits, probs) — TODO confirm.
decision_mapping = ImageNetActivationsTo16ClassesMappingWithIndices(aggregation_function=np.mean)

for images, target, paths in tqdm(dataset.loader):
    logits = model.forward_batch(images.to(device()))
    batch_targets = model.to_numpy(target) if isinstance(target, torch.Tensor) else target
    predictions, sorted_logits, sorted_probs = decision_mapping(logits, model.softmax)

    # Per row, True in the column whose category equals the image's target.
    target_mask = predictions == np.array(batch_targets)[:, np.newaxis]
    # Boolean-mask indexing below assumes exactly one target column per row.
    assert (target_mask.sum(axis=1) == 1).all(), "each target must appear exactly once per row"
    non_targets = ~target_mask

    target_act = sorted_logits[target_mask]
    max_non_target_act = np.nanmax(np.where(non_targets, sorted_logits, np.nan), axis=1)
    decision_margin_act = compute_decision_margin_distance(target_act, max_non_target_act)

    target_prob = sorted_probs[target_mask]
    max_non_target_prob = np.nanmax(np.where(non_targets, sorted_probs, np.nan), axis=1)
    decision_margin_prob = compute_decision_margin_distance(target_prob, max_non_target_prob)

    for metric in dataset.metrics:
        metric.update(predictions,
                      batch_targets,
                      paths)

    # Aggregate the batch results
    results_agg.print_batch(predictions,
                            batch_targets,
                            paths,
                            target_act,
                            max_non_target_act,
                            decision_margin_act,
                            target_prob,
                            max_non_target_prob,
                            decision_margin_prob)
    

In [None]:
# Metric summaries accumulated over the manual ViT run above.
for metric in dataset.metrics:
    print(metric)

In [None]:
results2 = results_agg.as_dataframe()
results2

In [None]:
# Correct trials vs. images whose probability margin is positive (target probability
# above every non-target); the two counts are expected to agree, up to ties.
results2.is_correct.sum(), (results2.decision_margin_prob>0).sum()

In [None]:
results1

In [None]:
# Reproduction checks against the stored CSV (results1). `==` between two Series needs
# identical index labels; results1 keeps its CSV row labels after sort_values, so this
# only runs if the CSV was already in imagename order.
(results1.filename == results2.filename).all()

In [None]:
(results1.is_correct == results2.is_correct).all()

In [None]:
# Mean human accuracy per (condition, filename).
human_avg = df.groupby(by=['condition','filename'])['is_correct'].mean().reset_index()
human_avg

In [None]:
# Mean ViT decision margins per (condition, filename).
dm_prob = results2.groupby(by=['condition','filename'])[['decision_margin_act', 
                                                         'decision_margin_prob']].mean().reset_index()
dm_prob

In [None]:
human = human_avg[human_avg.condition=='bw']
human

In [None]:
# Overwrites dm_prob with its 'bw' rows only.
dm_prob = dm_prob[dm_prob.condition=='bw']
dm_prob

In [None]:
# Alignment check (by index label) before the positional correlations below.
(human.filename==dm_prob.filename).all()

In [None]:
# human.is_correct.mean()

In [None]:
# 'bw' condition: per-image human accuracy vs ViT probability margin, then logit margin.
pearsonr(human.is_correct, dm_prob.decision_margin_prob)[0]

In [None]:
pearsonr(human.is_correct, dm_prob.decision_margin_act)[0]

# alexnet

In [None]:
# Manual AlexNet run: load the model, the dataset and a fresh aggregator.
model_name = "alexnet"
model, framework = load_model(model_name)
print(framework)
model

In [None]:
dataset_name = "colour"
dataset = load_dataset(dataset_name, num_workers=len(os.sched_getaffinity(0)), batch_size=32)
dataset

In [None]:
results_agg = ResultAgg(model_name, dataset)
results_agg

In [None]:
# IPython: show ResultAgg's source.
ResultAgg??

In [None]:
# Manual version of run_analysis(..., use_logits=False) for AlexNet.
for metric in dataset.metrics:
    metric.reset()

# NOTE(review): ImageNetProbabilitiesTo16ClassesMappingWithIndices is not defined in this
# notebook; it is expected to return (categories, logits, probs) — TODO confirm.
decision_mapping = ImageNetProbabilitiesTo16ClassesMappingWithIndices(aggregation_function=np.mean)

for images, target, paths in tqdm(dataset.loader):
    logits = model.forward_batch(images.to(device()))
    softmax_output = model.softmax(logits)
    batch_targets = model.to_numpy(target) if isinstance(target, torch.Tensor) else target
    predictions = dataset.decision_mapping(softmax_output)
    # Pass the softmax output as the probabilities, as run_analysis(use_logits=False)
    # does. Passing `logits` twice made the "prob" margins logit-based.
    preds, sorted_logits, sorted_probs = decision_mapping(logits, softmax_output)
    assert (preds == predictions).all(), "mapped predictions disagree with dataset.decision_mapping"

    # Per row, True in the column whose category equals the image's target.
    target_mask = preds == np.array(batch_targets)[:, np.newaxis]
    # Boolean-mask indexing below assumes exactly one target column per row.
    assert (target_mask.sum(axis=1) == 1).all(), "each target must appear exactly once per row"
    non_targets = ~target_mask

    target_act = sorted_logits[target_mask]
    max_non_target_act = np.nanmax(np.where(non_targets, sorted_logits, np.nan), axis=1)
    decision_margin_act = compute_decision_margin_distance(target_act, max_non_target_act)

    target_prob = sorted_probs[target_mask]
    max_non_target_prob = np.nanmax(np.where(non_targets, sorted_probs, np.nan), axis=1)
    decision_margin_prob = compute_decision_margin_distance(target_prob, max_non_target_prob)

    for metric in dataset.metrics:
        metric.update(predictions,
                      batch_targets,
                      paths)

    # Aggregate the batch results
    results_agg.print_batch(preds,
                            batch_targets,
                            paths,
                            target_act,
                            max_non_target_act,
                            decision_margin_act,
                            target_prob,
                            max_non_target_prob,
                            decision_margin_prob)

In [None]:
# One row per image for the manual AlexNet run.
results3 = results_agg.as_dataframe()
results3

In [None]:
# Mean AlexNet decision margins per (condition, filename).
dm_prob3 = results3.groupby(by=['condition','filename'])[['decision_margin_act', 
                                                         'decision_margin_prob']].mean().reset_index()
dm_prob3

In [None]:
# Keep only the 'bw' rows.
dm_prob3 = dm_prob3[dm_prob3.condition=='bw']
dm_prob3

In [None]:
(human.filename==dm_prob3.filename).all()

In [None]:
# 'bw' condition: per-image human accuracy vs AlexNet probability margin, then logit margin.
pearsonr(human.is_correct, dm_prob3.decision_margin_prob)[0]

In [None]:
pearsonr(human.is_correct, dm_prob3.decision_margin_act)[0]

In [None]:
# ViT (dm_prob) vs AlexNet (dm_prob3) logit margins on the same 'bw' images.
pearsonr(dm_prob.decision_margin_act, dm_prob3.decision_margin_act)[0]

In [None]:
sns.scatterplot(x=dm_prob.decision_margin_act, y=dm_prob3.decision_margin_act)

In [None]:
sns.scatterplot(x=human.is_correct, y=dm_prob3.decision_margin_act)