In [None]:
import json
import pickle
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

# Small constant added to the denominator of Cohen's kappa so that
# (1 - expected agreement) never divides by exactly zero.
EPS = 1.0e-6

In [None]:
def load_json(file_path: str) -> Dict:
    """Load a JSON file and return the parsed object.

    The file is decoded explicitly as UTF-8 (the encoding RFC 8259 mandates
    for JSON). Without an explicit encoding, ``open`` uses the platform's
    locale encoding, which corrupts or rejects non-ASCII text on systems
    such as Windows (cp1252).

    Args:
        file_path: Path to the JSON file.

    Returns:
        The decoded JSON document.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)
    
def load_pkl(file_path: str) -> Any:
    """Load and return the object stored in a pickle file.

    The return type is whatever was pickled; in this notebook the LLM results
    file is used as an int-indexed sequence, not a dict.

    Security: ``pickle.load`` can execute arbitrary code embedded in the file.
    Only call this on files you produced yourself or otherwise trust.

    Args:
        file_path: Path to the pickle file.

    Returns:
        The unpickled object.
    """
    with open(file_path, 'rb') as f:
        return pickle.load(f)



# Human-written labels (JSON, keyed by the stringified sample index) and the
# LLM labelling run (pickle, indexed by the integer sample index).
file1 = 'llm_autointerp/manual_labels_can_final.json'
file2 = 'llm_autointerp/llm_results_can_final_sonnet.pkl'

manual_file = load_json(file1)
# Unpickling can execute code embedded in the file; only load trusted results.
llm_file = load_pkl(file2)

In [None]:
sample_idx = 1

# Manual labels are keyed by the stringified index; LLM results are indexed by the int.
manual_entry = manual_file[str(sample_idx)]
llm_entry = llm_file[sample_idx]

# Outer double quotes keep these f-strings valid before Python 3.12: reusing the
# enclosing quote character inside a replacement field is only legal since PEP 701.
print(f"##### Example Prompts\n {manual_entry['example_prompts']}\n\n")
print(f"##### Manual chain of thought\n{manual_entry['chain_of_thought']}\n\n")
print(f"##### LLM chain of thought\n{llm_entry[0]}\n\n")
print(f"manual labels {manual_entry['per_class_scores']}")
print(f"LLM labels    {llm_entry[1]}")

In [None]:
def extract_scores_manual(data: Dict, is_valid: List[bool]) -> Dict[str, List[int]]:
    """Collect manual per-category scores for the samples whose LLM output was valid.

    Flag ``is_valid[i]`` is paired with the manual entry ``data[str(i)]``. The
    manual labels are keyed by the stringified sample index, while the LLM
    results are indexed by the integer. Looking entries up by key, instead of
    relying on the iteration order of ``data``, keeps both label sources
    aligned even if the JSON keys were written out of order.

    Args:
        data: Manual annotations. Each value has a ``'per_class_scores'`` dict
            mapping category -> score.
        is_valid: One flag per LLM output, e.g. as returned by
            ``extract_scores_llm``. Only samples flagged True are collected.

    Returns:
        Mapping category -> scores, in sample-index order, skipping invalid samples.

    Raises:
        KeyError: if ``data`` has no entry ``str(i)`` for a valid sample ``i``.
    """
    manual_labels: Dict[str, List[int]] = {}
    for i, valid in enumerate(is_valid):
        if not valid:
            continue
        key = str(i)
        if key not in data:
            raise KeyError(f'no manual label for sample {key!r}')
        for category, score in data[key]['per_class_scores'].items():
            manual_labels.setdefault(category, []).append(score)
    return manual_labels

def extract_scores_llm(
    data: List[Tuple[str, Optional[Dict[str, int]], bool, str]],
) -> Tuple[Dict[str, List[int]], List[bool]]:
    """Split LLM results into per-category score lists and a validity mask.

    Args:
        data: One tuple per sample. Element ``[1]`` is the category -> score
            dict, or ``None`` for an output without scores. The other tuple
            elements are not used here.

    Returns:
        ``(scores, is_valid)``:
        ``scores`` maps category -> scores collected from the valid samples,
        in sample order.
        ``is_valid[i]`` is True iff sample ``i`` had a score dict.

    NOTE(review): a valid output that omits a category makes that category's
    list shorter, so it no longer lines up position by position with the
    manual labels. Confirm that every valid output carries the full category set.
    """
    is_valid: List[bool] = []
    result: Dict[str, List[int]] = {}
    for item in data:
        scores = item[1]
        is_valid.append(scores is not None)
        if scores is None:
            continue
        for category, score in scores.items():
            result.setdefault(category, []).append(score)
    return result, is_valid


llm_labels, is_valid = extract_scores_llm(llm_file)
manual_labels = extract_scores_manual(manual_file, is_valid)

# Eyeball the first ten aligned 'gender' labels: LLM first, then manual.
for labels in (llm_labels, manual_labels):
    print(labels['gender'][:10])

In [None]:
# NOTE(review): verbatim re-definition of extract_scores_manual from an earlier
# cell. This later definition is the one compute_kappa_for_files calls on
# Run All, so keep the two in sync or delete one.
def extract_scores_manual(data: Dict, is_valid: List[bool]) -> Dict[str, List[int]]:
    """Collect manual per-category scores for the samples whose LLM output was valid.

    Flag ``is_valid[i]`` is paired with the manual entry ``data[str(i)]``. The
    manual labels are keyed by the stringified sample index, while the LLM
    results are indexed by the integer. Looking entries up by key, instead of
    relying on the iteration order of ``data``, keeps both label sources
    aligned even if the JSON keys were written out of order.

    Args:
        data: Manual annotations. Each value has a ``'per_class_scores'`` dict
            mapping category -> score.
        is_valid: One flag per LLM output, e.g. as returned by
            ``extract_scores_llm``. Only samples flagged True are collected.

    Returns:
        Mapping category -> scores, in sample-index order, skipping invalid samples.

    Raises:
        KeyError: if ``data`` has no entry ``str(i)`` for a valid sample ``i``.
    """
    manual_labels: Dict[str, List[int]] = {}
    for i, valid in enumerate(is_valid):
        if not valid:
            continue
        key = str(i)
        if key not in data:
            raise KeyError(f'no manual label for sample {key!r}')
        for category, score in data[key]['per_class_scores'].items():
            manual_labels.setdefault(category, []).append(score)
    return manual_labels

# NOTE(review): verbatim re-definition of extract_scores_llm from an earlier
# cell. This later definition is the one compute_kappa_for_files calls on
# Run All, so keep the two in sync or delete one.
def extract_scores_llm(
    data: List[Tuple[str, Optional[Dict[str, int]], bool, str]],
) -> Tuple[Dict[str, List[int]], List[bool]]:
    """Split LLM results into per-category score lists and a validity mask.

    Args:
        data: One tuple per sample. Element ``[1]`` is the category -> score
            dict, or ``None`` for an output without scores. The other tuple
            elements are not used here.

    Returns:
        ``(scores, is_valid)``:
        ``scores`` maps category -> scores collected from the valid samples,
        in sample order.
        ``is_valid[i]`` is True iff sample ``i`` had a score dict.

    NOTE(review): a valid output that omits a category makes that category's
    list shorter, so it no longer lines up position by position with the
    manual labels. Confirm that every valid output carries the full category set.
    """
    is_valid: List[bool] = []
    result: Dict[str, List[int]] = {}
    for item in data:
        scores = item[1]
        is_valid.append(scores is not None)
        if scores is None:
            continue
        for category, score in scores.items():
            result.setdefault(category, []).append(score)
    return result, is_valid

def cohens_kappa(
    scores1: Dict[str, List[int]],
    scores2: Dict[str, List[int]],
    eps: Optional[float] = None,
) -> Dict[str, float]:
    """Compute Cohen's kappa per category between two raters' label lists.

    kappa = (p_o - p_e) / (1 - p_e + eps)

    ``p_o`` is the fraction of positions on which the two raters agree.
    ``p_e`` is the agreement expected by chance from each rater's marginal
    label frequencies.

    Args:
        scores1: Mapping category -> labels from the first rater.
        scores2: Mapping category -> labels from the second rater. It must
            contain every category in ``scores1``, with a list of the same
            length, because labels are compared position by position.
        eps: Added to the denominator so it is never exactly zero when
            p_e == 1 (both raters used one identical label throughout).
            Defaults to the module-level ``EPS``.

    Returns:
        Mapping category -> kappa for every category in ``scores1``. A
        category with no labels maps to NaN, since kappa is undefined there.

    Raises:
        KeyError: if a category in ``scores1`` is missing from ``scores2``.
        ValueError: if the two label lists for a category differ in length.
    """
    if eps is None:
        eps = EPS

    def kappa(a: np.ndarray, b: np.ndarray) -> float:
        n = len(a)
        if n == 0:
            # No paired labels: p_o and p_e would both be 0/0.
            return float('nan')
        labels = np.unique(np.concatenate([a, b]))

        # Observed agreement: fraction of positions with identical labels.
        observed = np.sum(a == b) / n
        # Chance agreement: sum over labels c of P(rater1 = c) * P(rater2 = c).
        expected = sum((np.sum(a == c) / n) * (np.sum(b == c) / n) for c in labels)

        return (observed - expected) / (1 - expected + eps)

    results: Dict[str, float] = {}
    for category in scores1:
        if category not in scores2:
            raise KeyError(f'category {category!r} is missing from scores2')
        a = np.array(scores1[category])
        b = np.array(scores2[category])
        # Without this check, numpy either raises an opaque broadcasting error
        # or (on some older versions) compares to a scalar False and silently
        # reports zero agreement.
        if len(a) != len(b):
            raise ValueError(
                f'category {category!r}: {len(a)} labels in scores1 vs '
                f'{len(b)} in scores2; labels must be paired position by position'
            )
        results[category] = kappa(a, b)

    return results

def compute_kappa_for_files(file1: str, file2: str) -> Dict[str, float]:
    """Load manual and LLM labels from disk and return per-category Cohen's kappa.

    Samples whose LLM output carries no score dict are marked invalid and are
    dropped from both sides, so the two label sequences stay paired.

    Args:
        file1: Path to the manual-labels JSON file.
        file2: Path to the pickled LLM results. Unpickling can execute code
            embedded in the file, so only pass trusted files.

    Returns:
        Mapping category -> kappa, one entry per category in the LLM scores.

    Raises:
        ValueError: if the two files hold a different number of samples, in
            which case the labels cannot be paired by index.
    """
    manual_labels = load_json(file1)
    llm_labels = load_pkl(file2)

    print(f'Length of manual labels: {len(manual_labels)}')
    print(f'Length of LLM labels: {len(llm_labels)}')
    if len(manual_labels) != len(llm_labels):
        raise ValueError(
            f'{file1} has {len(manual_labels)} samples but {file2} has '
            f'{len(llm_labels)}; cannot pair labels by index'
        )

    scores_llm, is_valid_llm_output = extract_scores_llm(llm_labels)
    n_invalid = len(is_valid_llm_output) - sum(is_valid_llm_output)
    print(f'Number of invalid LLM outputs: {n_invalid}')
    scores_manual = extract_scores_manual(manual_labels, is_valid_llm_output)

    return cohens_kappa(scores_llm, scores_manual)


kappa_scores = compute_kappa_for_files(file1, file2)

# One line per category, kappa rounded to four decimal places.
print("Cohen's Kappa scores for each category:")
for category_name, kappa_value in kappa_scores.items():
    print(f"{category_name}: {kappa_value:.4f}")