In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, cohen_kappa_score
from sklearn.metrics import accuracy_score
import os
import glob
import warnings
# Install a blanket 'ignore' filter: no warning category is shown after this
# line, which also hides deprecation and numerical warnings.
warnings.filterwarnings('ignore')

def load_data(file_path):
    """Load the ``logits`` and ``labels`` arrays stored in an ``.npz`` archive.

    Args:
        file_path: Path of the ``.npz`` file to read.

    Returns:
        A ``(logits, labels)`` tuple of arrays. When the file cannot be read
        or one of the two keys is missing, the error is printed and
        ``(None, None)`` is returned instead of raising.
    """
    try:
        # np.load keeps the archive open for lazy access; the context manager
        # closes it once both arrays have been read into memory.
        with np.load(file_path) as archive:
            return archive['logits'], archive['labels']
    except Exception as e:
        # Deliberate best-effort: callers treat (None, None) as "skip".
        print(f"Erro ao carregar {file_path}: {e}")
        return None, None

def logits_to_predictions(logits):
    """Turn a 2-D array of logits into predicted class indices.

    Args:
        logits: Array-like indexed as ``[sample, class]``.

    Returns:
        A 1-D array with, for each row, the column index of its largest value.
    """
    scores = np.asarray(logits)
    predicted_classes = scores.argmax(axis=1)
    return predicted_classes

def load_dataset(base_path, dataset_name):
    """Load the test-set predictions of every model for one dataset.

    For each model the ``.npz`` file is looked up under
    ``base_path/<model folder>/``; its logits are converted to class
    predictions. The labels of the first model that loads successfully are
    used as the reference labels.

    Args:
        base_path: Directory that contains the per-model logits folders.
        dataset_name: Dataset identifier embedded in the file names.

    Returns:
        ``(all_predictions, labels)`` where ``all_predictions`` maps a model
        name to its predicted classes (only for models that loaded) and
        ``labels`` is the reference label array, or ``None`` when no model
        could be loaded.
    """
    # model name -> (sub-folder, file name) of its test-set logits
    sources = {
        'BERT': ('logits_google-bert',
                 f"bert-base-uncased_{dataset_name}_test.npz"),
        'ELECTRA': ('logits_electra',
                    f"electra-base-discriminator_{dataset_name}_test.npz"),
        'RoBERTa': ('logits_roberta',
                    f"logits_roberta-base_{dataset_name}_test.npz"),
    }

    print(f"\n🔍 Carregando {dataset_name.upper()}...")

    all_predictions = {}
    labels = None

    for model_name, (folder, filename) in sources.items():
        file_path = os.path.join(base_path, folder, filename)

        if not os.path.exists(file_path):
            print(f"❌ {model_name}: arquivo não encontrado")
            continue

        logits, lbls = load_data(file_path)
        if logits is None:
            # The file exists but could not be read; skip this model.
            continue

        all_predictions[model_name] = logits_to_predictions(logits)
        if labels is None:
            labels = lbls
        elif not np.array_equal(labels, lbls):
            # Every model is later scored against the reference labels, so a
            # mismatch (different order or split) would silently skew results.
            print(f"⚠️ {model_name}: labels diferem dos labels de referência")
        print(f"✅ {model_name}: {len(logits)} amostras")

    return all_predictions, labels

def create_agreement_matrix(pred1, pred2, labels, n_classes):
    """Count how often two models are jointly right or wrong.

    The result is a 2x2 matrix indexed as ``[model 1 correct, model 2 correct]``:

    - ``[0, 0]``: both wrong
    - ``[0, 1]``: model 1 wrong, model 2 correct
    - ``[1, 0]``: model 1 correct, model 2 wrong
    - ``[1, 1]``: both correct

    Args:
        pred1: Predicted classes of the first model (array-like, 1-D).
        pred2: Predicted classes of the second model (array-like, 1-D).
        labels: True classes (array-like, 1-D).
        n_classes: Unused; kept so existing callers keep working.

    Returns:
        A ``(2, 2)`` integer array of sample counts.

    Raises:
        ValueError: If the three inputs do not share the same shape.
    """
    first = np.asarray(pred1)
    second = np.asarray(pred2)
    truth = np.asarray(labels)

    if not (first.shape == second.shape == truth.shape):
        raise ValueError(
            "pred1, pred2 and labels must have the same shape, got "
            f"{first.shape}, {second.shape} and {truth.shape}"
        )

    correct1 = (first == truth).astype(int)
    correct2 = (second == truth).astype(int)

    # Encode each (correct1, correct2) pair as 2 * correct1 + correct2, i.e.
    # the flat index of the matching cell in a row-major 2x2 matrix, and let
    # bincount do the counting in one pass.
    cell_counts = np.bincount(2 * correct1 + correct2, minlength=4)
    return cell_counts.reshape(2, 2).astype(int)

def create_detailed_agreement_matrix(pred1, pred2, n_classes):
    """Count how often each pair of predicted classes occurs.

    Cell ``[a, b]`` holds the number of samples for which the first model
    predicted class ``a`` and the second model predicted class ``b``.

    Args:
        pred1: Predicted class ids of the first model (array-like of ints).
        pred2: Predicted class ids of the second model (array-like of ints).
        n_classes: Size of each matrix axis; ids must lie in
            ``[0, n_classes)``.

    Returns:
        An ``(n_classes, n_classes)`` integer array of sample counts.

    Raises:
        ValueError: If ``pred1`` and ``pred2`` differ in shape.
        IndexError: If any class id is negative or ``>= n_classes``.
    """
    first = np.asarray(pred1)
    second = np.asarray(pred2)

    if first.shape != second.shape:
        raise ValueError(
            "pred1 and pred2 must have the same shape, got "
            f"{first.shape} and {second.shape}"
        )

    agreement_matrix = np.zeros((n_classes, n_classes), dtype=int)
    if first.size == 0:
        return agreement_matrix

    # Reject negative ids explicitly: NumPy indexing would otherwise wrap
    # them around and count them in the last rows/columns.
    for arg_name, values in (("pred1", first), ("pred2", second)):
        if values.min() < 0 or values.max() >= n_classes:
            raise IndexError(
                f"{arg_name} contains class ids outside [0, {n_classes})"
            )

    # Unbuffered add so repeated (row, col) pairs are all counted.
    np.add.at(agreement_matrix, (first, second), 1)
    return agreement_matrix


def analyze_dataset_2(predictions, labels, dataset_name):
    """Report accuracy, Cohen's kappa, confusion and agreement for one dataset.

    Prints per-model accuracy and pairwise Cohen's kappa, then shows three
    figures (confusion matrices, correct/incorrect agreement matrices and
    class-by-class agreement matrices) and finally prints agreement
    statistics for every model pair.

    Args:
        predictions: Mapping with the keys ``'BERT'``, ``'ELECTRA'`` and
            ``'RoBERTa'``, each holding that model's predicted class ids.
        labels: True class ids, or ``None`` when loading failed.
        dataset_name: Name used in printed headers and figure titles.

    Returns:
        None. When ``labels`` is ``None`` or a model is missing, a message is
        printed and nothing else happens.
    """
    models = ['BERT', 'ELECTRA', 'RoBERTa']
    pairs = [('BERT', 'ELECTRA'), ('BERT', 'RoBERTa'), ('ELECTRA', 'RoBERTa')]

    if labels is None or any(model not in predictions for model in models):
        print(f"❌ Dados incompletos para {dataset_name}")
        return

    print(f"\n{'='*60}")
    print(f"ANÁLISE: {dataset_name.upper()}")
    print(f"{'='*60}")

    # Size the class axis from the largest id seen in labels OR predictions.
    # Counting only the unique labels breaks as soon as a model predicts a
    # class absent from the labels, or the label ids are not contiguous: the
    # tick labels no longer match the confusion matrix and the class-by-class
    # matrix is indexed out of bounds.
    observed_ids = np.concatenate(
        [np.ravel(labels)] + [np.ravel(predictions[model]) for model in models]
    )
    n_classes = int(observed_ids.max()) + 1
    class_ids = list(range(n_classes))
    class_names = [f'C{i}' for i in class_ids]

    # 1. Per-model accuracy
    print("ACCURACY INDIVIDUAL:")
    for model in models:
        acc = accuracy_score(labels, predictions[model])
        print(f"  {model:8}: {acc:.4f}")

    # 2. Cohen's kappa between every pair of models
    print(f"\nKAPPA DE COHEN:")
    for m1, m2 in pairs:
        kappa = cohen_kappa_score(predictions[m1], predictions[m2])
        print(f"  {m1} vs {m2:8}: {kappa:.4f}")

    # 3. Confusion matrices
    # NOTE(review): annotated cells get crowded when there are many classes.
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    fig.suptitle(f'Confusion Matrices - {dataset_name.upper()}', fontsize=14)

    for ax, model in zip(axes, models):
        # Explicit labels keep the matrix shape equal to len(class_names).
        cm = confusion_matrix(labels, predictions[model], labels=class_ids)
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names,
                    ax=ax)
        ax.set_title(model)
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True label')

    plt.tight_layout()
    plt.show()

    # The correct/incorrect matrices feed both the plots (4) and the printed
    # statistics (6), so build them once per pair.
    agreement_matrices = {
        (m1, m2): create_agreement_matrix(
            predictions[m1], predictions[m2], labels, n_classes
        )
        for m1, m2 in pairs
    }

    # 4. Agreement matrices (correct / incorrect)
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    fig.suptitle(f'Agreement Matrices (Correct/Incorrect) - {dataset_name.upper()}', fontsize=14)

    agreement_labels = ['Incorrect', 'Correct']

    for ax, (m1, m2) in zip(axes, pairs):
        agreement_matrix = agreement_matrices[(m1, m2)]

        # Share of all samples that falls in each cell
        percentages = agreement_matrix / agreement_matrix.sum() * 100

        # Cell annotation: raw count with its percentage underneath
        annot_text = np.empty_like(agreement_matrix, dtype=object)
        for (row, col), count in np.ndenumerate(agreement_matrix):
            annot_text[row, col] = f'{count}\n({percentages[row, col]:.1f}%)'

        sns.heatmap(agreement_matrix, annot=annot_text, fmt='', cmap='RdYlGn',
                    xticklabels=[f'{m2}\n{label}' for label in agreement_labels],
                    yticklabels=[f'{m1}\n{label}' for label in agreement_labels],
                    ax=ax)
        ax.set_title(f'{m1} vs {m2}')

    plt.tight_layout()
    plt.show()

    # 5. Detailed agreement matrices (class by class)
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    fig.suptitle(f'Detailed Agreement Matrices (Class-by-Class) - {dataset_name.upper()}', fontsize=14)

    for ax, (m1, m2) in zip(axes, pairs):
        detailed_matrix = create_detailed_agreement_matrix(
            predictions[m1], predictions[m2], n_classes
        )

        sns.heatmap(detailed_matrix, annot=True, fmt='d', cmap='YlOrRd',
                    xticklabels=[f'{m2}\n{name}' for name in class_names],
                    yticklabels=[f'{m1}\n{name}' for name in class_names],
                    ax=ax)
        ax.set_title(f'{m1} vs {m2}')

    plt.tight_layout()
    plt.show()

    # 6. Agreement statistics
    print(f"\nESTATÍSTICAS DE CONCORDÂNCIA:")
    for m1, m2 in pairs:
        agreement_matrix = agreement_matrices[(m1, m2)]

        total = agreement_matrix.sum()
        both_correct = agreement_matrix[1, 1]
        both_wrong = agreement_matrix[0, 0]
        total_agreement = both_correct + both_wrong

        print(f"\n  {m1} vs {m2}:")
        print(f"    Ambos corretos:   {both_correct:4d} ({both_correct/total*100:.1f}%)")
        print(f"    Ambos errados:    {both_wrong:4d} ({both_wrong/total*100:.1f}%)")
        print(f"    Concordância total: {total_agreement:4d} ({total_agreement/total*100:.1f}%)")
        print(f"    Discordância:     {total-total_agreement:4d} ({(total-total_agreement)/total*100:.1f}%)")

def analyze_dataset(predictions, labels, dataset_name):
    """Report accuracy, Cohen's kappa and (for some datasets) confusion matrices.

    Prints per-model accuracy and pairwise Cohen's kappa. Confusion-matrix
    heatmaps are shown only when ``dataset_name`` is ``'emotion'`` or
    ``'amazonpolarity'``.

    Args:
        predictions: Mapping with the keys ``'BERT'``, ``'ELECTRA'`` and
            ``'RoBERTa'``, each holding that model's predicted class ids.
        labels: True class ids, or ``None`` when loading failed.
        dataset_name: Name used in printed headers and figure titles.

    Returns:
        None. When ``labels`` is ``None`` or a model is missing, a message is
        printed and nothing else happens.
    """
    models = ['BERT', 'ELECTRA', 'RoBERTa']
    pairs = [('BERT', 'ELECTRA'), ('BERT', 'RoBERTa'), ('ELECTRA', 'RoBERTa')]

    if labels is None or any(model not in predictions for model in models):
        print(f"❌ Dados incompletos para {dataset_name}")
        return

    print(f"\n{'='*60}")
    print(f"ANÁLISE: {dataset_name.upper()}")
    print(f"{'='*60}")

    # 1. Per-model accuracy
    print("ACCURACY INDIVIDUAL:")
    for model in models:
        acc = accuracy_score(labels, predictions[model])
        print(f"  {model:8}: {acc:.4f}")

    # 2. Cohen's kappa between every pair of models
    print(f"\nKAPPA DE COHEN:")
    for m1, m2 in pairs:
        kappa = cohen_kappa_score(predictions[m1], predictions[m2])
        print(f"  {m1} vs {m2:8}: {kappa:.4f}")

    # 3. Confusion matrices, restricted to these two datasets
    if dataset_name not in ['emotion', 'amazonpolarity']:
        return

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    fig.suptitle(f'Confusion Matrixes - {dataset_name.upper()}', fontsize=14)

    # Size the class axis from the largest id seen in labels OR predictions;
    # counting only the unique labels makes the tick labels disagree with the
    # matrix shape when a model predicts a class absent from the labels.
    observed_ids = np.concatenate(
        [np.ravel(labels)] + [np.ravel(predictions[model]) for model in models]
    )
    class_ids = list(range(int(observed_ids.max()) + 1))
    class_names = [f'C{i}' for i in class_ids]

    for ax, model in zip(axes, models):
        # Explicit labels keep the matrix shape equal to len(class_names).
        cm = confusion_matrix(labels, predictions[model], labels=class_ids)
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names,
                    ax=ax)
        ax.set_title(model)
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True label')

    plt.tight_layout()
    plt.show()

def main(base_path="./", datasets=None):
    """Run both analyses for every dataset and print a completion message.

    Args:
        base_path: Directory that contains the per-model logits folders.
            Defaults to the current directory.
        datasets: Iterable of dataset names to process. ``None`` selects the
            built-in list of five datasets.

    Returns:
        None. A failure in one dataset is printed and does not stop the
        remaining datasets from being processed.
    """
    if datasets is None:
        datasets = ['amazonpolarity', 'banking77', 'huffpost', 'emotion', 'clincoos']

    for dataset in datasets:
        # Top-level boundary: report the error and move on to the next dataset.
        try:
            predictions, labels = load_dataset(base_path, dataset)
            if not predictions or labels is None:
                print(f"❌ Falhou ao carregar {dataset}")
                continue

            # NOTE: both analyses print the accuracy and kappa sections, so
            # those appear twice per dataset.
            analyze_dataset(predictions, labels, dataset)
            analyze_dataset_2(predictions, labels, dataset)
        except Exception as e:
            print(f"❌ Erro em {dataset}: {e}")

    print("\n✅ Análise concluída!")

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()