## Dataset & prompt setup

In [11]:
import pandas as pd

In [12]:
# Dataset registry. Every container below lists the corpora in the same order,
# so a dataset index (ds_i) addresses the same corpus in all of them.
DATASET_NAMES = ['sara', 'echr']

sara_df = pd.read_csv('DATASETS/sara_annotated.csv')
echr_df = pd.read_csv('DATASETS/echr_annotated.csv')

DATASETS = [sara_df, echr_df]

# name -> dataframe lookup, built from the two parallel lists above
DATASET_NAME_TO_DATASET = dict(zip(DATASET_NAMES, DATASETS))

NUM_DATASETS = len(DATASETS)

In [13]:
# Identifiers of the prompt variants. Position i in this list corresponds to
# position i in each SYSTEM_PROMPTS / USER_PROMPTS list.
PROMPT_NAMES = [
    'simpleinstruction',
    'thinkambiguityfirst',
    'impersonateexpert',
    'definedambiguity',
    'oneshot',
    'threeshot',
    'oneshotrationale',
    'oneshotrationaleexplain',
    'smarttwoshot',
]

NUM_PROMPTS = len(PROMPT_NAMES)

# Index of the first prompt whose template contains in-context examples.
IN_CONTEXT_PROMPT_BEGINNING = PROMPT_NAMES.index('oneshot')
# Indices of the prompts that take three and two examples respectively.
THREE_SHOT_PROMPT = PROMPT_NAMES.index('threeshot')
TWO_SHOT_PROMPT = PROMPT_NAMES.index('smarttwoshot')

# System prompts, keyed by prompt version ('v1' / 'v2').
# Each list holds one entry per prompt variant, in the same order as PROMPT_NAMES;
# pipeline() selects SYSTEM_PROMPTS[prompt_vers][p_i].
# The 'v2' wording adds an explicit task statement ("Your task is to identify ...")
# to every entry; entries 4-8 address "multiple pairs" because their user prompts
# carry in-context examples.
SYSTEM_PROMPTS = {
    'v1': ['''You will be provided with a legal statute and a fact pattern. If there is ambiguity, respond “True”; if there is no ambiguity, respond “False.”''',
           '''You will be provided with a legal statute and a fact pattern. First, define what ambiguity means in this context. Second, if there is ambiguity in the following, respond “True”; if there is no ambiguity, respond “False.”''', 
           '''Imagine you are a legal expert. You will be provided with a legal statute and a fact pattern. If there is ambiguity, respond “True”; if there is no ambiguity, respond “False.”''', 
           '''You will be provided with a legal statute and a fact pattern. Legal ambiguity is when it is unclear whether a general statute applies to a specific fact pattern; both interpretations of the statute as applying or not applying to the fact pattern are feasible. If there is legal ambiguity in whether the following statute applies to the following fact pattern, respond “True”; if there is no ambiguity in whether the following statute applies to the following fact pattern, respond “False.”''',
           '''You will be provided with multiple pairs of a legal statute and a fact pattern. If there is ambiguity, respond “True”; if there is no ambiguity, respond “False.”''', 
           '''You will be provided with multiple pairs of a legal statute and a fact pattern. If there is ambiguity, respond “True”; if there is no ambiguity, respond “False.”''', 
           '''You will be provided with multiple pairs of a legal statute and a fact pattern. If there is ambiguity, respond “True”; if there is no ambiguity, respond “False.”''',
           '''You will be provided with multiple pairs of a legal statute and a fact pattern. If there is ambiguity, respond “True”; if there is no ambiguity, respond “False.” Then, explain your reasoning.''',
           '''You will be provided with multiple pairs of a legal statute and a fact pattern. If there is ambiguity, respond “True”; if there is no ambiguity, respond “False.”'''],
    'v2': ['''You will be provided with a legal statute and a fact pattern. Your task is to identify if there is legal ambiguity in the application of the statute to the fact pattern. If there is this legal ambiguity, respond “True”; if there is not this legal ambiguity, respond “False.”''',
           '''You will be provided with a legal statute and a fact pattern. Your task is to identify if there is legal ambiguity in the application of the statute to the fact pattern. First, explain how you would determine whether such legal ambiguity exists. Second, if there is this legal ambiguity in the following, respond “True”; if there is not this legal ambiguity, respond “False.”''', 
           '''Imagine you are a legal expert. You will be provided with a legal statute and a fact pattern. Your task is to identify if there is legal ambiguity in the application of the statute to the fact pattern. If there is this legal ambiguity, respond “True”; if there is not this legal ambiguity, respond “False.”''',
           '''You will be provided with a legal statute and a fact pattern. Your task is to identify if there is legal ambiguity in the application of the statute to the fact pattern. Legal ambiguity is when it is unclear whether a general statute applies to a specific fact pattern; both interpretations of the statute as applying or not applying to the fact pattern are feasible. If there is this legal ambiguity, respond “True”; if there is not this legal ambiguity, respond “False.”''',
           '''You will be provided with multiple pairs of a legal statute and a fact pattern. Your task is to identify if there is legal ambiguity in the application of the statute to the fact pattern. If there is this legal ambiguity, respond “True”; if there is not this legal ambiguity, respond “False.”''',
           '''You will be provided with multiple pairs of a legal statute and a fact pattern. Your task is to identify if there is legal ambiguity in the application of the statute to the fact pattern. If there is this legal ambiguity, respond “True”; if there is not this legal ambiguity, respond “False.”''',
           '''You will be provided with multiple pairs of a legal statute and a fact pattern. Your task is to identify if there is legal ambiguity in the application of the statute to the fact pattern. If there is this legal ambiguity, respond “True”; if there is not this legal ambiguity, respond “False.”''',
           '''You will be provided with multiple pairs of a legal statute and a fact pattern. Your task is to identify if there is legal ambiguity in the application of the statute to the fact pattern. If there is this legal ambiguity, respond “True”; if there is not this legal ambiguity, respond “False.” Then, explain your reasoning.''',
           '''You will be provided with multiple pairs of a legal statute and a fact pattern. Your task is to identify if there is legal ambiguity in the application of the statute to the fact pattern. If there is this legal ambiguity, respond “True”; if there is not this legal ambiguity, respond “False.”''']
}

# User-prompt templates, keyed by prompt version ('v1' / 'v2'); one template per
# prompt variant, in the same order as PROMPT_NAMES. pipeline() fills them with
# str.format:
#   {s}, {f}           statute / fact pattern of the datapoint being classified
#   {sK}, {fK}, {aK}   statute / fact pattern / answer of the K-th in-context example
#   {e1}               explanation attached to the first example (rationale prompts)
# Templates 0-3 contain no examples; templates 4-8 end with an open "A: " slot
# (template 7 additionally ends with an open "Explanation: " slot) for the model to fill.
# NOTE(review): the 'v1' and 'v2' template lists appear to be identical; only the
# system prompts differ between versions.
USER_PROMPTS = {
    'v1': ['''Statute: {s}

Fact pattern: {f}''',
           '''Statute: {s}

Fact pattern: {f}''', 
           '''Statute: {s}

Fact pattern: {f}''', 
           '''Statute: {s}

Fact pattern: {f}''', 
           '''Statute: {s1}
Fact pattern: {f1}
choice: True
choice: False
A: {a1}

Statute: {s}
Fact pattern: {f}
choice: True
choice: False
A: ''', 
           '''Statute: {s1}
Fact pattern: {f1}
choice: True
choice: False
A: {a1}

Statute: {s2}
Fact pattern: {f2}
choice: True
choice: False
A: {a2}

Statute: {s3}
Fact pattern: {f3}
choice: True
choice: False
A: {a3}

Statute: {s}
Fact pattern: {f}
choice: True
choice: False
A: ''', 
           '''Statute: {s1}
Fact pattern: {f1}
choice: True
choice: False
A: {a1}
Explanation: {e1}

Statute: {s}
Fact pattern: {f}
choice: True
choice: False
A: ''', 
           '''Statute: {s1}
Fact pattern: {f1}
choice: True
choice: False
A: {a1}
Explanation: {e1}

Statute: {s}
Fact pattern: {f}
choice: True
choice: False
A: 
Explanation: ''', 
           '''Statute: {s1}
Fact pattern: {f1}
choice: True
choice: False
A: {a1}

Statute: {s2}
Fact pattern: {f2}
choice: True
choice: False
A: {a2}

Statute: {s}
Fact pattern: {f}
choice: True
choice: False
A: '''], 
    'v2': ['''Statute: {s}

Fact pattern: {f}''',
           '''Statute: {s}

Fact pattern: {f}''', 
           '''Statute: {s}

Fact pattern: {f}''', 
           '''Statute: {s}

Fact pattern: {f}''', 
           '''Statute: {s1}
Fact pattern: {f1}
choice: True
choice: False
A: {a1}

Statute: {s}
Fact pattern: {f}
choice: True
choice: False
A: ''', 
           '''Statute: {s1}
Fact pattern: {f1}
choice: True
choice: False
A: {a1}

Statute: {s2}
Fact pattern: {f2}
choice: True
choice: False
A: {a2}

Statute: {s3}
Fact pattern: {f3}
choice: True
choice: False
A: {a3}

Statute: {s}
Fact pattern: {f}
choice: True
choice: False
A: ''', 
           '''Statute: {s1}
Fact pattern: {f1}
choice: True
choice: False
A: {a1}
Explanation: {e1}

Statute: {s}
Fact pattern: {f}
choice: True
choice: False
A: ''', 
           '''Statute: {s1}
Fact pattern: {f1}
choice: True
choice: False
A: {a1}
Explanation: {e1}

Statute: {s}
Fact pattern: {f}
choice: True
choice: False
A: 
Explanation: ''', 
           '''Statute: {s1}
Fact pattern: {f1}
choice: True
choice: False
A: {a1}

Statute: {s2}
Fact pattern: {f2}
choice: True
choice: False
A: {a2}

Statute: {s}
Fact pattern: {f}
choice: True
choice: False
A: ''']}

## GPT model setup

In [14]:
import openai
from openai import OpenAI

import os

In [15]:
# API key is taken from the environment; None when OPENAI_API_KEY is not set.
MY_API_KEY = os.environ.get('OPENAI_API_KEY')

In [16]:
def query_GPT(model, system_prompt, user_input):
    """Send one system + user exchange to the chat completions endpoint.

    Args:
        model: model identifier passed through to the API.
        system_prompt: text sent with the "system" role.
        user_input: text sent with the "user" role.

    Returns:
        A two-element list ``[message_text, reply]``: a plain-text transcript
        of what was sent, and the message content of the first returned choice.
    """
    conversation = [
        dict(role="system", content=system_prompt),
        dict(role="user", content=user_input),
    ]

    # A fresh client is built on every call from the module-level API key.
    completion = OpenAI(api_key=MY_API_KEY).chat.completions.create(
        model=model,
        logprobs=False,
        messages=conversation,
    )
    answer = completion.choices[0].message.content

    transcript = 'system: ' + system_prompt
    transcript = transcript + '\n\n' + 'user: ' + user_input
    return [transcript, answer]

## Pipeline

In [24]:
# Short labels (used in result file names and log lines) and the matching API
# model identifiers; the two lists are parallel.
MODEL_NAMES = ['gpt3.5', 'gpt4', 'finetuned']

MODELS = ['gpt-3.5-turbo-0125', 'gpt-4-0125-preview', 'ft:gpt-3.5-turbo-0125:ai4life:ambig-finetuned:8xj2T3xT']

NUM_MODELS = len(MODELS)

# Models at this index or later are treated as fine-tuned by pipeline().
FINETUNED_MODEL_START = MODEL_NAMES.index('finetuned')

In [25]:
import pathlib
import csv
import re
import random
from tqdm import tqdm
import pickle

In [27]:
# programmatic results processing

def clean_model_reply(reply):
    """Normalise a raw model reply for keyword matching.

    Every character that is not an ASCII letter or digit is replaced by a
    space and the result is lower-cased.

    Args:
        reply: raw reply text, or None when the API returned no text content.

    Returns:
        The normalised string; ``''`` when ``reply`` is None.
    """
    # The chat API may return a message without text content (None); treat it
    # as an empty reply instead of letting re.sub raise TypeError mid-run.
    if reply is None:
        return ''

    # take out punctuation, casing
    return re.sub(r'[^a-zA-Z0-9]', ' ', reply).lower()

def read_model_reply_for_bool(clean_reply):
    """Extract the verdict from an already-cleaned reply.

    The last word equal to 'true' or 'false' decides the result, which matters
    for longer replies (e.g. when the model was asked for an explanation).

    Returns:
        True or False for the final verdict word, or None if the reply
        contains neither word.
    """
    verdicts = {'true': True, 'false': False}

    # Walk backwards so the first hit is the final verdict in the reply.
    for token in reversed(clean_reply.split()):
        if token in verdicts:
            return verdicts[token]

    return None

def followed_instr(prompt, clean_reply):
    """Report whether a cleaned reply has the shape the prompt asked for.

    Prompts 'thinkambiguityfirst' and 'oneshotrationaleexplain' expect more
    than one word; every other prompt expects exactly one word. In both cases
    the reply must also contain a readable true/false verdict.
    """
    word_count = len(clean_reply.split())
    wants_long_answer = prompt in ('thinkambiguityfirst', 'oneshotrationaleexplain')

    if wants_long_answer:
        length_ok = word_count > 1
    else:
        length_ok = word_count == 1

    # Length is checked first; the verdict is only parsed when the length fits.
    return length_ok and read_model_reply_for_bool(clean_reply) is not None

In [28]:
# to generate examples for most few shot prompts, randomly pick unique datapoints other than the current from dataset
def gen_few_shot(ds, exclude_i, num_shot):
    """Draw ``num_shot`` distinct random rows of ``ds`` as in-context examples.

    Args:
        ds: dataframe with 'statute', 'fact_pattern', 'ambiguity_exists' and
            'reason_for_ambiguity' columns.
        exclude_i: positional index of the row that must not be drawn.
        num_shot: number of examples to draw.

    Returns:
        List of (statute, fact_pattern, ambiguity_exists, reason_for_ambiguity)
        tuples, one per drawn row.
    """
    fields = ('statute', 'fact_pattern', 'ambiguity_exists', 'reason_for_ambiguity')

    # every row position except the one being classified
    candidates = [pos for pos in range(len(ds)) if pos != exclude_i]
    picked = random.sample(candidates, num_shot)

    return [tuple(ds[col].iloc[pos] for col in fields) for pos in picked]

In [26]:
# to generate examples for smarttwoshot, use previously made smarttwoshot dictionary
# pipeline() reads it as TWO_SHOT_DICS[dataset_index][datapoint_index][k] for k in {0, 1},
# and uses positions 0-3 of each example (same order as the gen_few_shot tuples).
# NOTE(review): unpickling can execute arbitrary code — only load a pickle file that
# was generated locally / comes from a trusted source.
with open('smarttwoshot.pickle', 'rb') as handle:
    TWO_SHOT_DICS = pickle.load(handle)

In [29]:
def _build_user_prompt(prompt_vers, p_i, ds_i, dataset, d_i):
    """Fill the user-prompt template of prompt ``p_i`` for datapoint ``d_i``."""
    template = USER_PROMPTS[prompt_vers][p_i]
    fields = {'s': dataset['statute'].iloc[d_i],
              'f': dataset['fact_pattern'].iloc[d_i]}

    if p_i >= IN_CONTEXT_PROMPT_BEGINNING:
        if p_i == THREE_SHOT_PROMPT:
            examples = gen_few_shot(dataset, d_i, 3)
        elif p_i == TWO_SHOT_PROMPT:
            # smarttwoshot uses the precomputed examples instead of random ones
            precomputed = TWO_SHOT_DICS[ds_i][d_i]
            examples = [precomputed[0], precomputed[1]]
        else:
            examples = gen_few_shot(dataset, d_i, 1)

        # Each example is (statute, fact_pattern, answer, explanation). Fields a
        # template does not mention (e.g. e2, e3) are ignored by str.format.
        for k, example in enumerate(examples, start=1):
            fields['s{}'.format(k)] = example[0]
            fields['f{}'.format(k)] = example[1]
            fields['a{}'.format(k)] = example[2]
            fields['e{}'.format(k)] = example[3]

    return template.format(**fields)


def pipeline(prompt_vers, run, ds_i_start=0, ds_i_end=NUM_DATASETS, m_i_start=0, m_i_end=NUM_MODELS, p_i_start=0, p_i_end=NUM_PROMPTS):
    """Query every selected (dataset, model, prompt) combination and save results.

    For each combination, one CSV is written to
    ``EXPERIMENTAL_RESULTS/{prompt_vers}_{run}/{dataset}_{model}_{prompt}.csv``
    with the columns model_input, model_output, model_output_processed and
    followed_instr, and accuracy / instruction-following rates are printed.

    Args:
        prompt_vers: key into SYSTEM_PROMPTS / USER_PROMPTS (e.g. 'v1', 'v2').
        run: label of this run; part of the output directory name.
        ds_i_start, ds_i_end: half-open range of dataset indices to process.
        m_i_start, m_i_end: half-open range of model indices to process.
        p_i_start, p_i_end: half-open range of prompt indices to process.
            All ranges are clamped to the valid index range.

    Note:
        For models at index >= FINETUNED_MODEL_START, datapoints listed in
        TRAIN_INDS[dataset_name] are skipped and excluded from the printed
        rates. TRAIN_INDS is not defined in the visible part of this notebook
        and must exist before such a model is run.
    """
    pathlib.Path('EXPERIMENTAL_RESULTS/{v}_{r}'.format(v = prompt_vers, r = run)).mkdir(parents=True, exist_ok=True)

    for ds_i in range(max(0, ds_i_start), min(NUM_DATASETS, ds_i_end)):
        dataset = DATASETS[ds_i]
        dataset_name = DATASET_NAMES[ds_i]
        dataset_len = len(dataset)

        for m_i in range(max(0, m_i_start), min(NUM_MODELS, m_i_end)):
            model = MODELS[m_i]

            # Only fine-tuned models skip their training points; TRAIN_INDS is
            # deliberately not touched for the other models.
            skip_inds = TRAIN_INDS[dataset_name] if m_i >= FINETUNED_MODEL_START else ()

            for p_i in range(max(0, p_i_start), min(NUM_PROMPTS, p_i_end)):
                system_prompt = SYSTEM_PROMPTS[prompt_vers][p_i]
                results = []

                # calculate preliminary performance stats
                correct_counter = 0
                follow_counter = 0

                for d_i in tqdm(range(dataset_len)):
                    if d_i in skip_inds:
                        results.append(['training data point; skipped', '', 'NONE', 'NONE'])
                        continue

                    user_prompt = _build_user_prompt(prompt_vers, p_i, ds_i, dataset, d_i)

                    res = query_GPT(model, system_prompt, user_prompt)
                    model_res_clean = clean_model_reply(res[1])
                    res_bool = read_model_reply_for_bool(model_res_clean)

                    if res_bool is None:
                        # Always fill the model_output_processed column so that
                        # followed_instr stays in its own column; 'NONE' is the
                        # placeholder already used for skipped rows.
                        res.append('NONE')
                    else:
                        res.append(res_bool)
                        if res_bool == dataset['ambiguity_exists'].iloc[d_i]:
                            correct_counter += 1

                    res_fi = followed_instr(PROMPT_NAMES[p_i], model_res_clean)
                    res.append(res_fi)
                    if res_fi:
                        follow_counter += 1

                    results.append(res)

                csv_file_name = 'EXPERIMENTAL_RESULTS/{v}_{r}/{ds}_{m}_{p}.csv'.format(v = prompt_vers,
                                                                                       r = run,
                                                                                       ds = dataset_name,
                                                                                       m = MODEL_NAMES[m_i],
                                                                                       p = PROMPT_NAMES[p_i])
                # newline='' is required by the csv module (fields contain
                # embedded newlines); utf-8 because prompts and replies contain
                # non-ASCII characters such as typographic quotes.
                with open(csv_file_name, 'w', newline='', encoding='utf-8') as csv_file:
                    csv_writer = csv.writer(csv_file)
                    csv_writer.writerow(['model_input', 'model_output', 'model_output_processed', 'followed_instr'])
                    csv_writer.writerows(results)

                # skipped training points do not count towards the rates
                num_datapts = dataset_len - len(skip_inds)

                # guard against an empty evaluation set instead of raising
                # ZeroDivisionError after all queries have been made
                acc = correct_counter / num_datapts if num_datapts else float('nan')
                fi = follow_counter / num_datapts if num_datapts else float('nan')

                print('{ds} {m} {p} acc: {acc}\n'.format(ds = dataset_name,
                                                         m = MODEL_NAMES[m_i],
                                                         p = PROMPT_NAMES[p_i],
                                                         acc = acc))

                print('{ds} {m} {p} fi: {fi}\n'.format(ds = dataset_name,
                                                       m = MODEL_NAMES[m_i],
                                                       p = PROMPT_NAMES[p_i],
                                                       fi = fi))

In [31]:
# Run all datasets, models and prompts with the 'v2' prompt wording; results go to
# EXPERIMENTAL_RESULTS/v2_r4/.
# NOTE(review): the recorded output below labels the model 'finetunedexp', which is
# not an entry of MODEL_NAMES as defined above — the output may stem from a different
# revision of this notebook; confirm before citing these numbers.
pipeline('v2', 'r4')

100%|█████████████████████████████████████████| 310/310 [00:57<00:00,  5.39it/s]


sara finetunedexp simpleinstruction acc: 0.6563706563706564

sara finetunedexp simpleinstruction fi: 1.0



100%|█████████████████████████████████████████| 310/310 [03:21<00:00,  1.54it/s]


sara finetunedexp thinkambiguityfirst acc: 0.1583011583011583

sara finetunedexp thinkambiguityfirst fi: 0.2471042471042471



100%|█████████████████████████████████████████| 310/310 [01:02<00:00,  4.98it/s]


sara finetunedexp impersonateexpert acc: 0.6525096525096525

sara finetunedexp impersonateexpert fi: 0.9922779922779923



100%|█████████████████████████████████████████| 310/310 [00:58<00:00,  5.32it/s]


sara finetunedexp definedambiguity acc: 0.6254826254826255

sara finetunedexp definedambiguity fi: 1.0



100%|█████████████████████████████████████████| 310/310 [01:01<00:00,  5.08it/s]


sara finetunedexp oneshot acc: 0.6254826254826255

sara finetunedexp oneshot fi: 0.9806949806949807



100%|█████████████████████████████████████████| 310/310 [01:03<00:00,  4.88it/s]


sara finetunedexp threeshot acc: 0.6293436293436293

sara finetunedexp threeshot fi: 0.9922779922779923



100%|█████████████████████████████████████████| 310/310 [01:07<00:00,  4.58it/s]


sara finetunedexp oneshotrationale acc: 0.5598455598455598

sara finetunedexp oneshotrationale fi: 0.9845559845559846



100%|█████████████████████████████████████████| 310/310 [01:58<00:00,  2.62it/s]


sara finetunedexp oneshotrationaleexplain acc: 0.7258687258687259

sara finetunedexp oneshotrationaleexplain fi: 0.915057915057915



100%|█████████████████████████████████████████| 310/310 [00:59<00:00,  5.18it/s]


sara finetunedexp smarttwoshot acc: 0.6254826254826255

sara finetunedexp smarttwoshot fi: 1.0



100%|█████████████████████████████████████████| 500/500 [01:54<00:00,  4.37it/s]


echr finetunedexp simpleinstruction acc: 0.4222222222222222

echr finetunedexp simpleinstruction fi: 1.0



100%|█████████████████████████████████████████| 500/500 [06:33<00:00,  1.27it/s]


echr finetunedexp thinkambiguityfirst acc: 0.08888888888888889

echr finetunedexp thinkambiguityfirst fi: 0.22



100%|█████████████████████████████████████████| 500/500 [01:58<00:00,  4.24it/s]


echr finetunedexp impersonateexpert acc: 0.42

echr finetunedexp impersonateexpert fi: 0.9977777777777778



100%|█████████████████████████████████████████| 500/500 [01:57<00:00,  4.25it/s]


echr finetunedexp definedambiguity acc: 0.44666666666666666

echr finetunedexp definedambiguity fi: 0.9977777777777778



100%|█████████████████████████████████████████| 500/500 [02:51<00:00,  2.92it/s]


echr finetunedexp oneshot acc: 0.4111111111111111

echr finetunedexp oneshot fi: 0.9333333333333333



100%|█████████████████████████████████████████| 500/500 [03:00<00:00,  2.78it/s]


echr finetunedexp threeshot acc: 0.36666666666666664

echr finetunedexp threeshot fi: 0.9733333333333334



100%|█████████████████████████████████████████| 500/500 [02:33<00:00,  3.27it/s]


echr finetunedexp oneshotrationale acc: 0.4111111111111111

echr finetunedexp oneshotrationale fi: 0.98



100%|█████████████████████████████████████████| 500/500 [04:50<00:00,  1.72it/s]


echr finetunedexp oneshotrationaleexplain acc: 0.31777777777777777

echr finetunedexp oneshotrationaleexplain fi: 0.9888888888888889



100%|█████████████████████████████████████████| 500/500 [02:42<00:00,  3.07it/s]

echr finetunedexp smarttwoshot acc: 0.34444444444444444

echr finetunedexp smarttwoshot fi: 0.9622222222222222




