In [None]:
### Define Chain(s)

from niagara import Chain, Model, ModelIntrinsicLogProb, NullTransformation, LogisticRegressionCalibrator
from niagara import OpenAIClient, FireworksClient

import os

# Ensure FIREWORKS_API_KEY exists. Per the placeholder text, no real key is needed
# for this notebook, but the variable itself must be present (presumably read by a
# niagara client — confirm). setdefault keeps a real key if one is already exported
# instead of silently clobbering it with the placeholder.
os.environ.setdefault("FIREWORKS_API_KEY", "leave-this-line-but-there-is-no-need-to-add-an-API-key")

def _make_model(model_name, **model_kwargs):
    """Build a ``Model`` with the configuration shared by every model in both chains.

    Each call constructs fresh ``thresholds``, confidence-signal, transform and
    calibrator objects, so no two models share calibrator state. Extra keyword
    arguments (e.g. ``client``) are forwarded to ``Model`` only when the caller
    passes them, so models that never specified a client keep whatever default
    ``Model`` applies.

    Args:
        model_name: Model identifier passed through to ``Model``.
        **model_kwargs: Additional ``Model`` keyword arguments (e.g. ``client``).

    Returns:
        The configured ``Model`` instance.
    """
    return Model(
        model_name=model_name,
        # NOTE(review): threshold semantics are defined by niagara; values are
        # kept exactly as in the original per-chain configuration.
        thresholds={"reject": -10000, "accept": 0.0},
        conf_signal=ModelIntrinsicLogProb(),
        # Placeholder transform; the evaluation loop in this notebook reassigns
        # conf_signal_transform per benchmark.
        conf_signal_transform=NullTransformation(),
        conf_signal_calibrator=LogisticRegressionCalibrator(),
        **model_kwargs,
    )


# Llama chain, models in the listed order. No client is passed, so Model's own
# default is used for each.
llama_chain = Chain(
    models=[
        _make_model(name)
        for name in ["llama3.2-1b", "llama3.2-3b", "llama3.1-8b", "llama3.1-70b", "llama3.1-405b"]
    ]
)

# Mixed GPT / Qwen chain. Each entry pairs a model name with the client passed
# explicitly to Model (currently None for all of them).
qwen_oai_chain = Chain(
    models=[
        _make_model(name, client=client)
        for name, client in [
            ("gpt-4o-mini", None),
            ("qwen2.5-32b-coder-instruct", None),
            ("qwen2.5-72b-instruct", None),
            ("gpt-4o", None),
        ]
    ]
)

### For every benchmark and every chain: apply the benchmark's transformation, load the train/test results, calibrate, and collect test metrics

import numpy as np
import pickle
from niagara import OneSidedAsymptoticLog, TwoSidedAsymptoticLog
from niagara.utils import compute_ece

# Benchmark key -> display spelling of that benchmark's name
# (presumably used for labels in later cells — confirm).
PRETTY_NAMES = dict(
    xsum="XSum",
    mmlu="MMLU",
    medmcqa="MedMCQA",
    triviaqa="TriviaQA",
    truthfulqa="TruthfulQA",
    gsm8k="GSM8K",
)

# One dict of test metrics per (benchmark, chain, model).
records = []

# XSum stores a dict of scores per example; an example counts as correct when
# those scores sum to at least this value.
XSUM_SCORE_THRESHOLD = 20


def process_scores(scores):
    """Return True when the XSum ``scores`` dict sums to at least ``XSUM_SCORE_THRESHOLD``."""
    return sum(scores.values()) >= XSUM_SCORE_THRESHOLD


def _load_chain_results(benchmark, chain_name, split):
    """Load the pickled chain results for one benchmark, chain and split ("train" / "test").

    NOTE: ``pickle.load`` can execute arbitrary code; these files must come from
    a trusted source (they are local benchmark artifacts here).
    """
    path = f'../benchmarks/data/{benchmark}/chain_results/{benchmark}_full_{chain_name}_results_{split}.pkl'
    with open(path, 'rb') as f:
        return pickle.load(f)


def _model_correctness(results, benchmark):
    """Return ``{model_name: per-example correctness}`` from a results dict.

    For XSum each per-example entry is a score dict and is thresholded with
    ``process_scores``; for other benchmarks ``results['model_correctness']`` is
    returned unchanged.
    """
    if benchmark == "xsum":
        return {k: [process_scores(x) for x in v] for k, v in results['model_correctness'].items()}
    return results['model_correctness']


def _transform_all(transform, raw_confs, model_names):
    """Apply ``transform`` to each model's raw confidences, in ``model_names`` order."""
    return [list(transform.transform_confidence_signal(raw_confs[name])) for name in model_names]


def _calibrate_all(chain, transformed_confs):
    """Run each model's fitted calibrator on its transformed confidences.

    ``transformed_confs[i]`` must belong to ``chain.models[i]`` (index-aligned
    with ``chain.model_names``).
    """
    return [
        list(chain.models[idx].conf_signal_calibrator.calibrate_confidence_signal(conf))
        for idx, conf in enumerate(transformed_confs)
    ]


# Benchmark -> confidence-signal transformation. Keyed explicitly instead of
# zipping two parallel lists, so a benchmark cannot silently pick up another's transform.
BENCHMARK_TRANSFORMS = {
    "xsum": TwoSidedAsymptoticLog(),
    "mmlu": OneSidedAsymptoticLog(),
    "medmcqa": OneSidedAsymptoticLog(),
    "triviaqa": TwoSidedAsymptoticLog(),
    "truthfulqa": TwoSidedAsymptoticLog(),
    "gsm8k": TwoSidedAsymptoticLog(),
}

CHAINS = {"qwen_oai_chain": qwen_oai_chain, "llama_chain": llama_chain}

for NAME, TRANSFORM in BENCHMARK_TRANSFORMS.items():
    for CHAIN_NAME, CHAIN in CHAINS.items():
        # Point every model in the chain at this benchmark's transformation.
        for model in CHAIN.models:
            model.conf_signal_transform = TRANSFORM

        results_train = _load_chain_results(NAME, CHAIN_NAME, "train")
        results_test = _load_chain_results(NAME, CHAIN_NAME, "test")

        ### Fit the calibrators on the training split

        raw_corr_train = _model_correctness(results_train, NAME)
        raw_conf_train = results_train['raw_confidences']

        corr_train = [raw_corr_train[model_name] for model_name in CHAIN.model_names]
        transformed_conf_train = _transform_all(TRANSFORM, raw_conf_train, CHAIN.model_names)

        calibration_data = [
            {"correctness": corr, "transformed_confidence": conf}
            for corr, conf in zip(corr_train, transformed_conf_train)
        ]
        CHAIN.calibrate(calibration_data)

        # Not used below; kept so the notebook namespace matches the previous version.
        calibrated_conf_train = _calibrate_all(CHAIN, transformed_conf_train)

        ### Evaluate on the test split

        raw_corr_test = _model_correctness(results_test, NAME)
        raw_conf_test = results_test['raw_confidences']

        corr_test = [raw_corr_test[model_name] for model_name in CHAIN.model_names]
        transformed_conf_test = _transform_all(TRANSFORM, raw_conf_test, CHAIN.model_names)
        calibrated_conf_test = _calibrate_all(CHAIN, transformed_conf_test)

        for model_idx, model_name in enumerate(CHAIN.model_names):
            raw_conf_arr = np.asarray(raw_conf_test[model_name])
            is_zero = raw_conf_arr == 0.0
            records.append(
                {
                    "model_name": model_name,
                    "model_idx": model_idx,
                    "chain": CHAIN_NAME,
                    "benchmark": NAME,
                    "test_acc": np.mean(corr_test[model_idx]),
                    "test_ece": compute_ece(calibrated_conf_test[model_idx], corr_test[model_idx], n_bins=10)['ece'],
                    # Any infinite raw confidence (either sign) or an exact zero.
                    "test_frac_certain": np.mean(np.isinf(raw_conf_arr) | is_zero),
                    # Only -inf, matching the key name (np.isinf would also count +inf).
                    "test_frac_neginf": np.mean(np.isneginf(raw_conf_arr)),
                    "test_frac_zero": np.mean(is_zero),
                }
            )