In [102]:
import os
import functools
import numpy as np
from scipy.optimize import brentq

In [77]:
# Default dataset location, expressed relative to the current working directory.
root_data_path = os.path.join("..", "2023-12-04-HMI-dataset-predictions")

def load_data_paths(root_folder=root_data_path):
    """Collect the paths of all entries stored one level below ``root_folder``.

    The expected layout is ``root_folder/<subfolder>/<entry>``. Entries of
    ``root_folder`` that are not directories (for example a stray
    ``.DS_Store`` file) are skipped; previously they made ``os.listdir`` raise
    ``NotADirectoryError``.

    Args:
        root_folder: Directory whose sub-directories hold the data files.

    Returns:
        list[str]: One ``root_folder/<subfolder>/<entry>`` path per entry.
        Sub-folders and their entries are visited in sorted order so the
        result does not depend on the arbitrary order of ``os.listdir``.
    """
    dataset = []
    for folder in sorted(os.listdir(root_folder)):
        folder_path = os.path.join(root_folder, folder)
        if not os.path.isdir(folder_path):
            # Plain file sitting directly in the root: nothing to descend into.
            continue
        dataset.extend(
            os.path.join(folder_path, entry) for entry in sorted(os.listdir(folder_path))
        )
    return dataset

In [78]:
def load_data(paths):
    """Read every file in ``paths`` as text.

    Returns a list with one inner list per path (same order as ``paths``);
    each inner list holds that file's lines with surrounding whitespace
    stripped.
    """
    def _stripped_lines(file_path):
        # Iterating the handle yields the same lines as ``readlines()``.
        with open(file_path, "r") as handle:
            return [raw.strip() for raw in handle]

    return [_stripped_lines(file_path) for file_path in paths]

def load_numeric_data(paths):
    """Load the files in ``paths`` with ``load_data`` and return them as a float array.

    Each file contributes one row; every line of a file is parsed as a float.
    """
    return np.array(load_data(paths), dtype=float)

In [234]:
def load_paths(speakers=None, modifiers=None, root_folder=root_data_path):
    """Return the dataset paths that match the speaker and modifier filters.

    A path is kept when it contains at least one of ``speakers`` AND at least
    one of ``modifiers`` as a substring. Matching is done on the part of the
    path *below* ``root_folder`` so that a filter string which happens to occur
    in the root directory's own name does not match every file.

    Args:
        speakers: Iterable of substrings; ``None`` or empty disables the filter.
        modifiers: Iterable of substrings; ``None`` or empty disables the filter.
        root_folder: Dataset root handed to ``load_data_paths``.

    Returns:
        list[str]: The matching paths, in the order ``load_data_paths`` yields them.
    """
    def _accepts(relative_path, needles):
        # An absent or empty filter accepts everything.
        return not needles or any(needle in relative_path for needle in needles)

    selected = []
    for path in load_data_paths(root_folder):
        relative_path = os.path.relpath(path, root_folder)
        if _accepts(relative_path, speakers) and _accepts(relative_path, modifiers):
            selected.append(path)
    return selected

def load_datasets(speakers=None, modifiers=None, root_folder=root_data_path):
    """Load the score, WER and prediction files for the selected speakers/modifiers.

    Files are selected with ``load_paths`` and then split by kind according to
    whether "scores", "wer" or "pred" occurs in the path below ``root_folder``.
    The root folder itself is excluded from that check: the default root name
    ends in "predictions", so testing the full path for "pred" selected every
    file.

    Args:
        speakers: Substrings identifying speakers; ``None``/empty means all.
        modifiers: Substrings identifying modifiers; ``None``/empty means all.
        root_folder: Dataset root directory.

    Returns:
        tuple: ``(smx, wers, preds)`` where ``smx`` and ``wers`` are float
        arrays with one row per "scores" / "wer" file, and ``preds`` is a list
        of stripped text lines per "pred" file.

    NOTE(review): rows of ``smx``, ``wers`` and ``preds`` are paired only
    through the sorted order of their file paths; presumably the files are
    named in parallel - confirm against the dataset layout.
    """
    data_paths = sorted(load_paths(speakers, modifiers, root_folder))

    def _of_kind(tag):
        # Look for the tag below the root only (sub-folder and file name).
        return [path for path in data_paths if tag in os.path.relpath(path, root_folder)]

    smx = load_numeric_data(_of_kind("scores"))
    wers = load_numeric_data(_of_kind("wer"))
    preds = load_data(_of_kind("pred"))

    return smx, wers, preds

In [238]:
# Load the data for speaker "agtbv" (no modifier filter); this is the set that
# compute_lamhat is run on further down.
smx, wers, preds = load_datasets(speakers=["agtbv"])

In [260]:
# WER target; passed as `wer_target` to compute_lamhat and crc_util in the cells below.
threshold = 0.05
# Number of rows in the loaded `smx` array.
n = smx.shape[0]
# Risk level that enters the bound in conformal_risk_control.
# NOTE(review): 0.95 tolerates a very high empirical risk - confirm this is
# intended and not meant to be 0.05.
alpha = 0.95
# NOTE(review): B is not referenced by any cell visible here - confirm it is still needed.
B = 1

In [225]:
def c_lam(lam, smx, wers):
    """Return, per row of ``smx``, the first column where the running sum reaches ``lam``.

    The running sum is taken along axis 1. Rows whose running sum never
    reaches ``lam`` get the index of their last column. ``wers`` is accepted
    for signature compatibility but is not used.

    Returns:
        numpy.ndarray: 1-D integer array with one index per row of ``smx``.
    """
    running_mass = np.cumsum(smx, axis=1)
    if running_mass.shape[0] == 0:
        # No rows: nothing to search, mirror the empty result of a row loop.
        return np.zeros(0, dtype=int)
    reached = running_mass >= lam
    # argmax on a boolean row gives the first True (or 0 when there is none).
    first_hit = reached.argmax(axis=1)
    last_column = running_mass.shape[1] - 1
    return np.where(reached.any(axis=1), first_hit, last_column).astype(int)

In [226]:
def wer_empirical_risk(word_error_rates, wer_target):
    """Fraction of entries whose WER values are *all* at or above ``wer_target``.

    Each element of ``word_error_rates`` is an array of WER values; it counts
    as a loss of 1 when every value is ``>= wer_target`` (an empty array
    therefore counts as a loss) and 0 otherwise. The mean loss is returned.
    """
    losses = 0
    total = 0
    for candidate_wers in word_error_rates:
        total += 1
        if np.all(candidate_wers >= wer_target):
            losses += 1
    return losses / total

In [244]:
def crc_util(lam, smx, wers, wer_target):
    """Build the prediction sets for ``lam`` and return their empirical risk.

    Args:
        lam: Threshold on the running sum of scores, forwarded to ``c_lam``.
        smx: 2-D score array, one row per example.
        wers: Per-example WER values, indexable as ``wers[row][:k]``.
        wer_target: WER level forwarded to ``wer_empirical_risk``. It was
            previously ignored in favour of the notebook-global ``threshold``.

    Returns:
        float: The value of ``wer_empirical_risk`` on the selected WER slices.
    """
    cut_indexes = c_lam(lam, smx, wers)
    # NOTE(review): the slice ends *before* the index returned by c_lam, so the
    # entry that first reaches ``lam`` is excluded, and the last entry of a row
    # can never be part of a set. Confirm whether ``[:cut + 1]`` was intended.
    prediction_wers = [wers[row][:cut] for row, cut in enumerate(cut_indexes)]
    return wer_empirical_risk(prediction_wers, wer_target)

def conformal_risk_control(lam, smx, wers, wer_target):
    """Root-finding objective: empirical risk at ``lam`` minus the allowed level.

    The allowed level is ``(n + 1) / n * alpha - 1 / (n + 1)`` where ``n`` is
    the number of rows in ``smx``. ``n`` is derived from the ``smx`` argument
    rather than read from a notebook global, so the result stays correct when
    the function is called on a data set other than the one loaded when the
    global was assigned. ``alpha`` is still read from the module namespace.

    Returns:
        float: Positive when the empirical risk exceeds the allowed level.
    """
    n_rows = len(smx)
    allowed_risk = (n_rows + 1) / n_rows * alpha - 1 / (n_rows + 1)
    return crc_util(lam, smx, wers, wer_target) - allowed_risk

def compute_lamhat(smx, wers, wer_target):
    """Find the ``lam`` in [0, 1] at which ``conformal_risk_control`` crosses zero.

    Uses ``brentq`` on the bracket ``[0, 1]`` with the data arguments held fixed.
    """
    def risk_gap(lam):
        # Objective in the single free variable ``lam``.
        return conformal_risk_control(lam, smx=smx, wers=wers, wer_target=wer_target)

    return brentq(risk_gap, 0, 1)

In [261]:
# Solve for lambda on the data loaded above, using `threshold` as the WER target;
# the bare `lam_hat` on the last line displays the result.
lam_hat = compute_lamhat(smx, wers, threshold)
lam_hat

0.027206686999896867

In [262]:
# Load a second subset (speaker "lgong", modifier "park") and display the
# empirical risk that the calibrated `lam_hat` yields on it.
smxh, wersh, predsh = load_datasets(speakers=["lgong"], modifiers=["park"])
crc_util(lam_hat, smxh, wersh, threshold)

0.9