In [1]:
from db import execute_query
import numpy as np
from collections import defaultdict
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
import functools
from scipy.optimize import brentq

In [4]:
# Conformal risk control settings, read as globals by run_trial below.
#   B         -- upper bound on a single clip's loss; the miss loss is a 0/1 indicator.
#   alpha     -- target level for the (weighted) expected miss rate.
#   threshold -- WER target: a prediction counts as a hit when its WER is below this.
B, alpha, threshold = 1, 0.4, 0.4

def _sql_in_list(ids):
    """Render ``ids`` as the parenthesised list of a SQL ``IN`` clause.

    ``str(tuple(ids))`` yields ``(5,)`` for a single id, which is not valid
    SQL. Joining the element ``repr``s gives the same text as ``tuple`` for
    two or more ids and ``(5)`` for one.

    Raises:
        ValueError: if ``ids`` is empty (``IN ()`` is not valid SQL).
    """
    if len(ids) == 0:
        raise ValueError("cannot build a SQL IN list from zero ids")
    # NOTE(review): the ids come from audio_data itself, so they are not
    # untrusted input. Switch to bound parameters if execute_query supports them.
    return "(" + ", ".join(repr(audio_id) for audio_id in ids) + ")"


def _fetch_predictions(audio_ids):
    """Load the prediction rows for ``audio_ids`` and group them per clip.

    Returns a ``(confidence_scores, word_error_rates, features)`` triple of
    parallel lists with one entry per clip:
    - the clip's confidence scores;
    - its word error rates (one value per prediction row);
    - the ``Background_Modifier_ID`` of its first row.

    Clips appear in the order their first row is returned. Ids with no
    prediction rows are left out, rather than becoming empty entries that
    would misalign the lists.
    """
    # NOTE(review): there is no ORDER BY, so the order of predictions within a
    # clip is whatever the database returns. The cumulative-confidence set
    # construction in _set_sizes depends on that order. Confirm it is the
    # intended ranking (e.g. descending confidence).
    query = f"""
    select Audio_ID, Prediction, Confidence_Score, Word_Error_Rate, Audio_Info_ID, Background_Modifier_ID
    from audio_data 
    natural join audio_predictions
    where audio_id in {_sql_in_list(audio_ids)}
    """
    per_clip = {}
    for row in execute_query(query):
        # row = (audio_id, prediction, confidence, wer, audio_info_id, background_modifier_id)
        scores, wers, _ = per_clip.setdefault(row[0], ([], [], row[5]))
        scores.append(row[2])
        wers.append(row[3])
    groups = list(per_clip.values())
    return (
        [scores for scores, _, _ in groups],
        [wers for _, wers, _ in groups],
        [feature for _, _, feature in groups],
    )


def _weight_function(calibration_features, validation_features, random_state=None,
                     max_probability=1.0 - 1e-6):
    """Estimate covariate-shift weights with a calibration-vs-validation classifier.

    A random forest is fit to separate calibration features (label 0) from
    validation features (label 1). The weight of a feature value is the odds
    ``p / (1 - p)`` of its predicted validation probability ``p``.

    Args:
        calibration_features, validation_features: one scalar feature per clip.
        random_state: forwarded to ``RandomForestClassifier``.
        max_probability: ``p`` is clipped to this value. A feature the forest
            assigns only to validation then gets a large finite weight instead
            of an infinite one (inf weights turn the weighted risk into nan).

    Returns:
        ``(calibration_weights, weight_fn)``: the weight of each calibration
        clip, and a dict mapping every observed feature value to its weight.
    """
    X = np.array(list(calibration_features) + list(validation_features)).reshape(-1, 1)
    y = np.array([0] * len(calibration_features) + [1] * len(validation_features))
    classifier = RandomForestClassifier(random_state=random_state)
    classifier.fit(X, y)

    feature_values = list(set(calibration_features) | set(validation_features))
    # Column 1 is the probability of label 1 (validation); classes_ is [0, 1].
    p_validation = classifier.predict_proba(np.array(feature_values).reshape(-1, 1))[:, 1]
    p_validation = np.minimum(p_validation, max_probability)
    weight_fn = {value: p / (1 - p) for value, p in zip(feature_values, p_validation)}
    return [weight_fn[value] for value in calibration_features], weight_fn


def _set_sizes(lam, confidence_scores):
    """Return the prediction-set size of each clip at threshold ``lam``.

    A clip's set is the shortest prefix of its predictions whose running sum
    of confidence scores reaches ``lam``. If no prefix does, the set is all of
    the clip's predictions.
    """
    sizes = []
    for scores in confidence_scores:
        running_total = 0
        size = len(scores)
        for position, score in enumerate(scores, start=1):
            running_total += score
            if running_total >= lam:
                size = position
                break
        sizes.append(size)
    return sizes


def _miss_losses(lam, confidence_scores, word_error_rates, wer_target):
    """Return the 0/1 miss loss of each clip at threshold ``lam``.

    The loss is 1 when every prediction in the clip's set has a word error
    rate of at least ``wer_target`` (nothing in the set is good enough), and 0
    otherwise.
    """
    sizes = _set_sizes(lam, confidence_scores)
    return np.array([
        int(all(wer >= wer_target for wer in wers[:size]))
        for size, wers in zip(sizes, word_error_rates)
    ])


def _crc_objective(lam, confidence_scores, word_error_rates, wer_target, weights,
                   test_weight=None):
    """Return the weighted miss risk at ``lam`` minus the conformal-risk-control bound.

    A value <= 0 means the risk is controlled at ``lam``. Reads the globals
    ``alpha`` (target risk) and ``B`` (loss upper bound).

    * ``test_weight is None``: weighted empirical risk minus
      ``(n+1)/n * alpha - B/n``. This comes from rearranging
      ``n/(n+1) * R + B/(n+1) <= alpha``.
    * Otherwise (weighted CRC under covariate shift): the test point joins the
      weighted average with the worst-case loss ``B`` and weight
      ``test_weight``. That term already is the finite-sample correction, so
      the average is compared with ``alpha`` directly.
    """
    n = len(confidence_scores)
    losses = _miss_losses(lam, confidence_scores, word_error_rates, wer_target)
    weighted_loss = sum(weight * loss for weight, loss in zip(weights, losses))
    total_weight = sum(weights)
    if test_weight is None:
        return weighted_loss / total_weight - ((n + 1) / n * alpha - B / n)
    return (weighted_loss + test_weight * B) / (total_weight + test_weight) - alpha


def _compute_lamhat(confidence_scores, word_error_rates, wer_target, weights,
                    test_weight=None):
    """Find the lambda in [0, 1] where ``_crc_objective`` crosses zero, via Brent's method.

    If the objective has the same sign at both ends, ``brentq`` raises
    ``ValueError``. In that case lambda is 1 (largest sets) when the bound is
    exceeded at lambda = 0, and 0 otherwise.
    """
    objective = functools.partial(
        _crc_objective,
        confidence_scores=confidence_scores,
        word_error_rates=word_error_rates,
        wer_target=wer_target,
        weights=weights,
        test_weight=test_weight,
    )
    try:
        return brentq(objective, 0, 1)
    except ValueError:
        return 1 if objective(0) > 0 else 0


def run_trial(use_transductive=False, seed=None):
    """Run one random calibration/validation split and return validation coverage.

    Steps:
    1. Split every clip in ``audio_data`` 80/20, stratified on the
       (background modifier, audio info) pair.
    2. Weight calibration clips by covariate shift on the background modifier.
    3. Pick lambda with conformal risk control.
    4. Return ``1 - mean miss loss`` on the validation clips.

    Also prints the mean prediction-set size.

    Args:
        use_transductive: if False, calibrate a single lambda for all
            validation clips. If True, calibrate a lambda per validation clip
            with the weighted CRC bound that includes that clip's own weight.
        seed: optional ``random_state`` for the split and the weighting
            classifier. ``None`` (the default) keeps every trial random.

    Reads the globals ``alpha``, ``B`` and ``threshold`` (the WER target).
    """
    clip_query = f"""
    select audio_id, background_modifier_id, audio_info_id
    from audio_data 
    """
    rows = execute_query(clip_query)
    audio_ids = [row[0] for row in rows]
    strata = np.array([(row[1], row[2]) for row in rows])
    calibration_ids, validation_ids = train_test_split(
        audio_ids, test_size=0.2, stratify=strata, random_state=seed)

    cal_scores, cal_wers, cal_features = _fetch_predictions(calibration_ids)
    val_scores, val_wers, val_features = _fetch_predictions(validation_ids)
    cal_weights, weight_fn = _weight_function(cal_features, val_features, random_state=seed)

    if not use_transductive:
        lamhat = _compute_lamhat(cal_scores, cal_wers, threshold, cal_weights)
        sizes = _set_sizes(lamhat, val_scores)
        print("Mean set size", sum(sizes) / len(sizes))
        return 1 - _miss_losses(lamhat, val_scores, val_wers, threshold).mean()

    # Transductive: each validation clip gets its own lambda. That lambda only
    # depends on the clip's weight, i.e. on its feature value, so compute it
    # once per distinct feature instead of once per clip.
    lamhat_by_feature = {}
    misses = np.empty(len(val_scores))
    set_sizes = np.empty(len(val_scores))
    for idx, (scores, wers, feature) in enumerate(zip(val_scores, val_wers, val_features)):
        if feature not in lamhat_by_feature:
            lamhat_by_feature[feature] = _compute_lamhat(
                cal_scores, cal_wers, threshold, cal_weights,
                test_weight=weight_fn[feature])
        lamhat = lamhat_by_feature[feature]
        misses[idx] = _miss_losses(lamhat, [scores], [wers], threshold)[0]
        set_sizes[idx] = _set_sizes(lamhat, [scores])[0]

    print("Mean set size: ", set_sizes.mean())
    return 1 - misses.mean()

In [5]:
R = 5
# Average validation coverage over R trials; each call draws a fresh random split.
results = []
for _ in range(R):
    results.append(run_trial())
C_hat = sum(results) / len(results)
C_hat

Mean set size 4.155315085932527
Mean set size 4.181413112667091
Mean set size 4.064926798217695
Mean set size 4.085932527052832
Mean set size 4.0299172501591345


0.6011457670273711

In [None]:
R = 5
# Same experiment with per-clip (transductive) recalibration, averaged over R random splits.
results = []
for _ in range(R):
    results.append(run_trial(use_transductive=True))
C_hat = sum(results) / len(results)
C_hat