# OBQA

In [1]:
from typing import List
import math
import numpy as np
from sklearn.metrics import accuracy_score

## Correct Answers

In [2]:
import pandas as pd
df = pd.read_parquet('obqa.parquet', engine='pyarrow')
answers = df['answerKey'].map(
    {'A': 0,
    'B': 1,
    'C': 2,
    'D': 3}
).to_numpy()

In [3]:
df['question_stem'].iloc[0]

'A person wants to start saving money so that they can afford a nice vacation at the end of the year. After looking over their budget and expenses, they decide the best way to save money is to'

In [4]:
file_path = '../results/fulldepth_obqa_scores_finetuned.jsonl'

import json
with open(file_path, 'r') as f:
    f.readline()
    f.readline()
    f.readline()
    line = f.readline().strip()
    d = json.loads(line)

d.keys()

dict_keys(['HYP', 't5_truth', 'llama_truth', 't5_faith', 'llama_faith', 'premises'])

In [5]:
d['HYP']

'A person wants to start saving money so that they can afford a nice vacation at the end of the year. After looking over their budget and expenses, they decide the best way to save money is to have lunch with friends'

In [6]:
d['premises'][0]['premises'][0]

{'HYP': 'A person wants to start saving money for a vacation at the end of the year.',
 't5_truth': '0.9157349103157619',
 'llama_truth': '0.4754670262336731',
 't5_faith': '0.8787753636560861',
 'llama_faith': '0.3538497759241288',
 'premises': [{'HYP': 'A person wants to start saving money for a vacation.',
   't5_truth': '0.9726121919390229',
   'llama_truth': '0.94086844',
   't5_faith': '0.0',
   'llama_faith': '0.0',
   'premises': []},
  {'HYP': 'The end of the year is when most people take vacations.',
   't5_truth': '0.2421582384858808',
   'llama_truth': '0.2950717806816101',
   't5_faith': '0.0',
   'llama_faith': '0.0',
   'premises': []}]}

In [7]:
def cummul(a):
    """cumulative multiplication"""
    res = 1
    for i in a:
        res *= i
    return res


def reasoning_score(t: List, e: float, mode):
    if mode == 'm':
        return cummul(t) * e
    elif mode == 'gmt':
        return math.pow(cummul(t), 1 / len(t)) * e
    else:
        return math.pow(cummul(t) * e, 1 / (1 + len(t)))

In [8]:
modes = {
'ttm': "t5-truthfulness, t5-faithfullness, all direct multiplication",
'ttgmt': "t5-truthfulness, t5-faithfullness, faithfullness * gm of truthfullness",
'ttgm': "t5-truthfulness, t5-faithfullness, gm of both truthfulness and faithfulness",
'ltm': "llama-truthfulness, t5-faithfullness, all direct multiplication",
'ltgmt': "llama-truthfulness, t5-faithfullness, faithfullness * gm of truthfullness",
'ltgm': "llama-truthfulness, t5-faithfullness, gm of both truthfulness and faithfulness",
}

In [9]:
ttvals, ltvals, tfvals, lfvals = [], [], [], [] 

def extract_values(d):
    tt = float(d['t5_truth'])
    lt = float(d['llama_truth'])
    tf = float(d['t5_faith'])
    lf = float(d['llama_faith'])
    ttvals.append(tt)
    ltvals.append(lt)
    if d['premises']:
        tfvals.append(tf)
        lfvals.append(lf)
        for p in d['premises']:
            extract_values(p)

with open(file_path, 'r') as f:
    for line in f:
        d = json.loads(line.strip())
        extract_values(d)

In [10]:
ttvals, ltvals, tfvals, lfvals = np.array(ttvals), np.array(ltvals), np.array(tfvals), np.array(lfvals)

ttvals.shape, ltvals.shape, tfvals.shape, lfvals.shape

((30441,), (30441,), (14232,), (14232,))

In [11]:
def sig_inv(x):
    return np.log(x / (1 - x))

def sigmoid(x):
    return 1 / ( 1 + np.exp(-x))

logits_tt = sig_inv(ttvals)
logits_tf = sig_inv(tfvals)
logits_lt = sig_inv(ltvals)
logits_lf = sig_inv(lfvals)

In [12]:
from scipy.optimize import minimize


class lin_transform:
    def fit(self, x, y):
        def objective(params, arr1, arr2):
            a, b = params
            transformed_arr1 = a * arr1 + b
            return np.sum((transformed_arr1 - arr2)**2)

        # Initial guess for a and b
        initial_guess = [1, 0]

        # Optimization
        result = minimize(objective, initial_guess, args=(x, y))

        # Extract optimized constants
        self.a, self.b = result.x

    def transform(self, x):
        # return x
        return self.a * x + self.b


transform_t = lin_transform()
transform_t.fit(logits_lt, logits_tt)
transform_f = lin_transform()
transform_f.fit(logits_lf, logits_tf)

In [13]:
print(transform_t.a, transform_t.b)

0.8533741243931974 1.741459361583275


In [14]:
transform_t.transform(logits_lt).mean(), logits_tt.mean()

(3.633535919425435, 3.633535879180158)

# Entailer + Direct

def get_scores(d, mode, depth):
    def direct_score(d, mode):
        if mode[0] == 't': # t5
            score = float(d['t5_truth'])
        else:
            score = float(d['llama_truth'])
            logit = sig_inv(score)
            transformed_logit = transform_t.transform(logit)
            score = sigmoid(transformed_logit)
        return score

    def entailment_score(d, mode):
        if mode[1] == 't':
            entail_score = float(d['t5_faith'])
        else:
            entail_score = float(d['llama_faith'])
            logit = sig_inv(entail_score)
            transformed_logit = transform_f.transform(logit)
            entail_score = sigmoid(transformed_logit)
        return float(entail_score)


    sd = direct_score(d, mode)
    if depth == 0:
        return sd
    cd = max(sd, 1-sd)
    se = entailment_score(d, mode)

    if se > cd:
        p_scores = [get_scores(p, mode, depth-1) for p in d['premises']]
        sr = reasoning_score(p_scores, se, mode[2:])
    else:
        sr = 0
    return max(sr, sd)


In [15]:
# Entailer

def get_scores(d, mode, depth):
    def direct_score(d, mode):
        if mode[0] == 't': # t5
            score = float(d['t5_truth'])
        else:
            score = float(d['llama_truth'])
            logit = sig_inv(score)
            transformed_logit = transform_t.transform(logit)
            score = sigmoid(transformed_logit)
        return score

    def entailment_score(d, mode):
        if mode[1] == 't':
            entail_score = float(d['t5_faith'])
        else:
            entail_score = float(d['llama_faith'])
            logit = sig_inv(entail_score)
            transformed_logit = transform_f.transform(logit)
            entail_score = sigmoid(transformed_logit)
        return entail_score


    sd = direct_score(d, mode)
    if depth == 0:
        return sd
    se = entailment_score(d, mode)

    p_scores = [get_scores(p, mode, depth-1) for p in d['premises']]
    sr = reasoning_score(p_scores, se, mode[2:])
    return sr


# Entailer + direct: Root with T5, rest with llama for truth

def get_scores(d, mode, depth, use_t5=True):
    def direct_score(d, mode):
        if use_t5:
            score = float(d['t5_truth'])
        else:
            score = float(d['llama_truth'])
            logit = sig_inv(score)
            transformed_logit = transform_t.transform(logit)
            score = sigmoid(transformed_logit)
        return score

    def entailment_score(d, mode):
        entail_score = float(d['t5_faith'])
        return entail_score


    sd = direct_score(d, mode)
    if depth == 0:
        return sd
    cd = max(sd, 1-sd)
    se = entailment_score(d, mode)

    if se > cd:
        p_scores = [get_scores(p, mode, depth-1, False) for p in d['premises']]
        sr = reasoning_score(p_scores, se, mode[2:])
    else:
        sr = 0
    return max(sr, sd)


modes = ['_tm', '_tgmt', '_tgm']

In [16]:

for depth in range(4):
    print('$'*10)
    print(depth)
    for mode in modes:
        pred_ans = []
        print(mode, end=': ')
        with open(file_path, 'r') as f:
            dicts = []
            for line in f:
                d = json.loads(line.strip())
                dicts.append(d)
                
                if len(dicts) >= 4:
                    scores = [get_scores(d, mode, depth) for d in dicts]
                    pred_ans.append(np.argmax(scores))     
                    dicts = []
                    
        print(accuracy_score(answers, pred_ans))    


$$$$$$$$$$
0
ttm: 0.74
ttgmt: 0.74
ttgm: 0.74
ltm: 0.618
ltgmt: 0.618
ltgm: 0.618
$$$$$$$$$$
1
ttm: 0.63
ttgmt: 0.65
ttgm: 0.628
ltm: 0.612
ltgmt: 0.658
ltgm: 0.62
$$$$$$$$$$
2
ttm: 0.538
ttgmt: 0.608
ttgm: 0.596
ltm: 0.552
ltgmt: 0.622
ltgm: 0.604
$$$$$$$$$$
3
ttm: 0.492
ttgmt: 0.61
ttgm: 0.614
ltm: 0.57
ltgmt: 0.658
ltgm: 0.666
