QUARTZ

In [1]:
from typing import List
import math
import numpy as np
from sklearn.metrics import accuracy_score

## Correct Answers

In [2]:
import pandas as pd
df = pd.read_parquet('quartz.parquet', engine='pyarrow')
answers = df['answerKey'].map(
    {'A': 0,
    'B': 1}
).to_numpy()

In [3]:
file_path = '../results/fulldepth_quartz_scores_arc_wt_finetuned.jsonl'

import json
with open(file_path, 'r') as f:
    line = f.readline().strip()
    d = json.loads(line)

d.keys()

dict_keys(['HYP', 't5_truth', 'llama_truth', 't5_faith', 'llama_faith', 'premises'])

In [4]:
d['premises'][0]['premises'][0]

{'HYP': 'Kinetic energy is a measure of the speed at which an object is moving.',
 't5_truth': '0.15759111276743043',
 'llama_truth': '0.9687583',
 't5_faith': '0.99681995379518',
 'llama_faith': '0.9462813350065744',
 'premises': [{'HYP': 'Kinetic energy is a measure of the speed with which an object is moving.',
   't5_truth': '0.16663328878398564',
   'llama_truth': '0.9953239',
   't5_faith': '0.0',
   'llama_faith': '0.0',
   'premises': []},
  {'HYP': 'Speed is a kind of measure of kinetic energy.',
   't5_truth': '0.9049609796656751',
   'llama_truth': '0.6892292',
   't5_faith': '0.0',
   'llama_faith': '0.0',
   'premises': []}]}

In [5]:
def cummul(a):
    """cumulative multiplication"""
    res = 1
    for i in a:
        res *= i
    return res


def reasoning_score(t: List, e: float, mode):
    if mode == 'm':
        return cummul(t) * e
    elif mode == 'gmt':
        return math.pow(cummul(t), 1 / len(t)) * e
    else:
        return math.pow(cummul(t) * e, 1 / (1 + len(t)))

In [6]:
modes = {
'ttm': "t5-truthfulness, t5-faithfullness, all direct multiplication",
'ttgmt': "t5-truthfulness, t5-faithfullness, faithfullness * gm of truthfullness",
'ttgm': "t5-truthfulness, t5-faithfullness, gm of both truthfulness and faithfulness",
'ltm': "llama-truthfulness, t5-faithfullness, all direct multiplication",
'ltgmt': "llama-truthfulness, t5-faithfullness, faithfullness * gm of truthfullness",
'ltgm': "llama-truthfulness, t5-faithfullness, gm of both truthfulness and faithfulness",
}

In [7]:
ttvals, ltvals, tfvals, lfvals = [], [], [], [] 

def extract_values(d):
    tt = float(d['t5_truth'])
    lt = float(d['llama_truth'])
    tf = float(d['t5_faith'])
    lf = float(d['llama_faith'])
    ttvals.append(tt)
    ltvals.append(lt)
    if d['premises']:
        tfvals.append(tf)
        lfvals.append(lf)
        for p in d['premises']:
            extract_values(p)

with open(file_path, 'r') as f:
    for line in f:
        d = json.loads(line.strip())
        extract_values(d)

In [8]:
ttvals, ltvals, tfvals, lfvals = np.array(ttvals), np.array(ltvals), np.array(tfvals), np.array(lfvals)

ttvals.shape, ltvals.shape, tfvals.shape, lfvals.shape

((24994,), (24994,), (11500,), (11500,))

In [9]:
def sig_inv(x):
    return np.log(x / (1 - x))

def sigmoid(x):
    return 1 / ( 1 + np.exp(-x))

logits_tt = sig_inv(ttvals)
logits_tf = sig_inv(tfvals)
logits_lt = sig_inv(ltvals)
logits_lf = sig_inv(lfvals)

In [10]:
from scipy.optimize import minimize


class lin_transform:
    def fit(self, x, y):
        def objective(params, arr1, arr2):
            a, b = params
            transformed_arr1 = a * arr1 + b
            return np.sum((transformed_arr1 - arr2)**2)

        # Initial guess for a and b
        initial_guess = [1, 0]

        # Optimization
        result = minimize(objective, initial_guess, args=(x, y))

        # Extract optimized constants
        self.a, self.b = result.x

    def transform(self, x):
        return x
        return self.a * x + self.b


transform_t = lin_transform()
transform_t.fit(logits_lt, logits_tt)
transform_f = lin_transform()
transform_f.fit(logits_lf, logits_tf)

In [11]:
transform_t.a, transform_t.b

(0.9349988171757933, -0.3152830293672452)

In [12]:
transform_t.transform(logits_lt).mean(), logits_tt.mean()

(2.7075044865203868, 2.216230456932593)

# Entailer + Direct

def get_scores(d, mode, depth):
    def direct_score(d, mode):
        if mode[0] == 't': # t5
            score = float(d['t5_truth'])
        else:
            score = float(d['llama_truth'])
            logit = sig_inv(score)
            transformed_logit = transform_t.transform(logit)
            score = sigmoid(transformed_logit)
        return score

    def entailment_score(d, mode):
        if mode[1] == 't':
            entail_score = float(d['t5_faith'])
        else:
            entail_score = float(d['llama_faith'])
            logit = sig_inv(entail_score)
            transformed_logit = transform_f.transform(logit)
            entail_score = sigmoid(transformed_logit)
        return float(entail_score)


    sd = direct_score(d, mode)
    if depth == 0:
        return sd
    cd = max(sd, 1-sd)
    se = entailment_score(d, mode)

    if se > cd:
        p_scores = [get_scores(p, mode, depth-1) for p in d['premises']]
        sr = reasoning_score(p_scores, se, mode[2:])
    else:
        sr = 0
    return max(sr, sd)


In [13]:
# Entailer

def get_scores(d, mode, depth):
    def direct_score(d, mode):
        if mode[0] == 't': # t5
            score = float(d['t5_truth'])
        else:
            score = float(d['llama_truth'])
            logit = sig_inv(score)
            transformed_logit = transform_t.transform(logit)
            score = sigmoid(transformed_logit)
        return score

    def entailment_score(d, mode):
        if mode[1] == 't':
            entail_score = float(d['t5_faith'])
        else:
            entail_score = float(d['llama_faith'])
            logit = sig_inv(entail_score)
            transformed_logit = transform_f.transform(logit)
            entail_score = sigmoid(transformed_logit)
        return entail_score


    sd = direct_score(d, mode)
    if depth == 0:
        return sd
    se = entailment_score(d, mode)

    p_scores = [get_scores(p, mode, depth-1) for p in d['premises']]
    sr = reasoning_score(p_scores, se, mode[2:])
    return sr


# Entailer + direct: Root with T5, rest with llama for truth

def get_scores(d, mode, depth, use_t5=True):
    def direct_score(d, mode):
        if use_t5:
            score = float(d['t5_truth'])
        else:
            score = float(d['llama_truth'])
            logit = sig_inv(score)
            transformed_logit = transform_t.transform(logit)
            score = sigmoid(transformed_logit)
        return score

    def entailment_score(d, mode):
        entail_score = float(d['t5_faith'])
        return entail_score


    sd = direct_score(d, mode)
    if depth == 0:
        return sd
    cd = max(sd, 1-sd)
    se = entailment_score(d, mode)

    if se > cd:
        p_scores = [get_scores(p, mode, depth-1, False) for p in d['premises']]
        sr = reasoning_score(p_scores, se, mode[2:])
    else:
        sr = 0
    return max(sr, sd)

modes = ['_tm', '_tgmt', '_tgm']

In [16]:
df = pd.DataFrame(np.zeros((784, 3)), columns=['ttm', 'ltgm', 'correct'])

for depth in range(4):
    print('$'*10)
    print(depth)
    for mode in modes:
        pred_ans = []
        print(mode, end=': ')
        with open(file_path, 'r') as f:
            dicts = []
            for line in f:
                d = json.loads(line.strip())
                dicts.append(d)
                
                if len(dicts) >= 2:
                    scores = [get_scores(d, mode, depth) for d in dicts]
                    pred_ans.append(np.argmax(scores))     
                    dicts = []

        print(f'{accuracy_score(answers, pred_ans):.4f}')    
        df[mode] = pred_ans == answers
df['correct'] = answers

$$$$$$$$$$
0
ttm: 0.7895
ttgmt: 0.7895
ttgm: 0.7895
ltm: 0.7130
ltgmt: 0.7130
ltgm: 0.7130
$$$$$$$$$$
1
ttm: 0.7602
ttgmt: 0.7883
ttgm: 0.7602
ltm: 0.8061
ltgmt: 0.8227
ltgm: 0.7997
$$$$$$$$$$
2
ttm: 0.7449
ttgmt: 0.7959
ttgm: 0.7883
ltm: 0.7768
ltgmt: 0.8151
ltgm: 0.8074
$$$$$$$$$$
3
ttm: 0.7143
ttgmt: 0.7908
ttgm: 0.7972
ltm: 0.7334
ltgmt: 0.8048
ltgm: 0.8125


In [15]:
df[df['ttm'] == False][df['ltgm'] == True].head(50)

  df[df['ttm'] == False][df['ltgm'] == True].head(50)


Unnamed: 0,ttm,ltgm,correct,ttgmt,ttgm,ltm,ltgmt
24,False,True,0,False,False,True,True
39,False,True,1,False,False,True,True
41,False,True,1,False,False,True,True
71,False,True,1,True,False,True,True
108,False,True,1,False,False,True,False
114,False,True,1,False,False,True,False
130,False,True,0,True,True,False,True
138,False,True,1,True,False,True,True
148,False,True,0,True,False,True,True
157,False,True,1,True,False,True,True
