In [508]:
# Process datasets into prompts

# Datasets:
# Curation
# Amazon prod descs: It's the McAuley data (the metadata part)
# Wikihow
# FactualNLG
# DialogSum


import jsonlines, json
import pandas as pd
import gcsfs
import numpy as np

# Global numpy seed: the processors below draw from np.random (shuffle / choice).
np.random.seed(123)

# Prompt templates; each `{:}` slot is filled with str.format.
SUMMARIZE_PROMPT = "{:}\n\nGenerate a summary:\n"

# Two slots: the source document, then the candidate summary.
FACTUALNLG_PROMPT = """Given the document below, determine if the summary is factually consistent with the document. 
Document: {:}
Summary: {:}
Is the summary consistent with the document? First provide detailed reasoning, then answer "Yes" or "No".
Reasoning:"""

DIALOGSUM_PROMPT = "Conversation:\n{:}\n\nNow write a summary of the conversation:\n"

# process_wikihow lower-cases the article title before filling this slot.
WIKIHOW_PROMPT = "Can you tell me {:}?"



# Names substituted for the #Person1#/#Person2#/#Person3# placeholders in DialogSum.
# NOTE(review): several names repeat ("Barbara", "Jennifer", "Jose", "Marcos", "Charles"),
# so they are over-represented when drawing directly from this list.
list_of_names = ["James", "Robert", "Michael", "William", "David", "Richard", "Adam", "Charles", "Jonathan", "Daniel", "Alexander", "Mary", "Susan", "Jennifer", "Margaret", "Ruth", "Sharon", "Barbara", "Elizabeth", "Paula", "Karol", "Melissa", "Louise", "Rachel", "Zero", "Nova", "Cohere", "Deenah", "Anna", "Olivia", "Woodrow", "Jennette", "Terence", "Skylar", "Karen", "Marie", "Blake", "Resnick", "Sherlock", "Sean", "Matt", "Jessica", "Hannah", "Grace", "Catherine", "Frida", "Sebastian", "Simon", "Alice", "Jose", "Marcos", "Barbara", "Jennifer", "Mark", "Traci", "Sam", "Ben", "Jack", "Jane", "Joseph", "Max", "Jennifer", "Lisa", "Barbara", "Jose", "Marcos", "Charles", "Marion"] 

def process_amazon(path, limit, cfg):
    """Load pre-built Amazon product-description prompts and their reference completions.

    Args:
        path: Parquet file location passed straight to ``pd.read_parquet``.
        limit: Maximum number of rows to return; ``None`` keeps every row.
        cfg: Dataset config dict. Unused here; accepted so every processor shares
            the same ``(path, limit, cfg)`` signature.

    Returns:
        ``(prompts, completions)``: two parallel lists taken from the ``prompt``
        and ``completion`` columns.
    """
    # The previous version also built a gcsfs.GCSFileSystem() that was never used;
    # pandas resolves remote (fsspec) URLs itself, so that client is dropped.
    data = pd.read_parquet(path)
    return list(data.prompt[:limit]), list(data.completion[:limit])

def process_macsum(path, limit, cfg):
    """Build summarisation prompts from a MACSum JSON file.

    The first ``limit`` records are used. Each record's ``source`` list is joined with
    newlines and wrapped in SUMMARIZE_PROMPT. ``cfg`` is ignored.

    NOTE(review): this returns only the prompt list. The driver loop unpacks two values
    (prompts, refs), so this processor cannot be re-enabled in ``configs`` as-is.
    """
    with open(path) as handle:
        records = json.load(handle)
    selected = records[:limit]

    prompts = []
    for record in selected:
        body = "\n".join(record['source'])
        prompts.append(SUMMARIZE_PROMPT.format(body))
    return prompts

def process_curation(path, limit, cfg):
    """Sample Curation Corpus articles as summarisation prompts with reference summaries.

    An article is kept when:
    - it is a string;
    - its character length lies strictly between ``cfg['min_chars']`` and ``cfg['max_chars']``;
    - it passes ``filter_art``.

    Survivors are shuffled with the global numpy RNG and the first ``limit`` are returned.

    Args:
        path: CSV file with ``prompt`` (article) and ``completion`` (summary) columns.
        limit: Maximum number of samples; ``None`` keeps all survivors.
        cfg: Must provide ``min_chars`` and ``max_chars``.

    Returns:
        ``[prompts, summaries]`` as two parallel tuples. Two empty tuples are returned
        when nothing survives; previously this returned ``[]``, which broke the
        caller's two-way unpacking.
    """
    # Hand-picked exclusions: exact article lengths, and a paywall banner string.
    exclude_lens = [898, 792]
    exclude_strs = ['TRY LAW360 FREE FOR SEVEN DAYS']

    def filter_art(art, summ):
        """Return False for articles matching an excluded length or containing an excluded string."""
        if len(art) in exclude_lens:
            return False
        return not any(exclude_str in art for exclude_str in exclude_strs)

    curation = pd.read_csv(path)

    samples = [
        (SUMMARIZE_PROMPT.format(art), summ)
        for art, summ in zip(curation.prompt, curation.completion)
        # Empty CSV cells load as float NaN; skip them instead of crashing on len().
        if isinstance(art, str)
        and cfg['min_chars'] < len(art) < cfg['max_chars']
        and filter_art(art, summ)
    ]
    np.random.shuffle(samples)

    selected = samples[:limit]
    if not selected:
        return [(), ()]
    return list(zip(*selected))

def process_wikihow(path, limit, cfg):
    """Sample WikiHow titles as "Can you tell me ...?" questions, with headline references.

    Titles are selected as follows:
    - the unique titles are kept when they are non-empty strings whose last character is not numeric;
    - they are shuffled with the global numpy RNG;
    - the first ``limit`` are used.

    Each reference is the distinct ``headline`` values of that title's rows,
    concatenated without a separator, in order of first appearance.

    Args:
        path: CSV file with ``title`` and ``headline`` columns.
        limit: Maximum number of titles; ``None`` keeps all.
        cfg: Unused; accepted for the shared processor signature.

    Returns:
        ``(prompts, refs)`` as two parallel lists.
    """
    wikihow = pd.read_csv(path)

    # `title and` guards against "" (title[-1] used to raise IndexError).
    titles = [
        title for title in wikihow.title.unique()
        if isinstance(title, str) and title and not title[-1].isnumeric()
    ]
    np.random.shuffle(titles)
    titles = titles[:limit]

    # Only require `headline` itself to be present. The old where(...).dropna() dropped a
    # matching row whenever *any* column was NaN, silently losing headlines. Grouping once
    # also avoids a full-table scan per title.
    headlines_by_title = (
        wikihow.dropna(subset=['headline'])
        .groupby('title', sort=False)['headline']
        .unique()
    )
    refs = ["".join(headlines_by_title.get(title, [])) for title in titles]

    prompts = [WIKIHOW_PROMPT.format(title.lower()) for title in titles]
    return prompts, refs

def process_dialogsum(path, limit, cfg):
    """Build DialogSum summarisation prompts, swapping #PersonN# tags for random first names.

    Args:
        path: JSON-lines file; each row needs ``dialogue`` and ``summary1`` fields.
        limit: Maximum number of samples; ``None`` reads the whole file.
        cfg: Unused; accepted for the shared processor signature.

    Returns:
        ``[prompts, summaries]`` as two parallel tuples, or two empty tuples if no row
        qualifies. Previously the function fell off the end and returned ``None`` when
        the file held fewer than ``limit`` usable rows.
    """
    # Order-preserving de-duplication: list_of_names repeats some names, and drawing with
    # replacement could give two speakers in the same dialogue the same name.
    name_pool = list(dict.fromkeys(list_of_names))
    placeholders = ('#Person1#', '#Person2#', '#Person3#')

    samples = []
    with jsonlines.open(path) as reader:
        for row in reader:
            if limit is not None and len(samples) >= limit:
                break
            # One specific dialogue is excluded, identified by a fragment of its text.
            if "When she told me that she would marry no man but Dick" in row['dialogue']:
                continue
            names = np.random.choice(name_pool, size=len(placeholders), replace=False)
            prompt = DIALOGSUM_PROMPT.format(row['dialogue'])
            summ = row['summary1']
            # The same name mapping is applied to the dialogue and to its reference summary.
            for placeholder, name in zip(placeholders, names):
                prompt = prompt.replace(placeholder, name)
                summ = summ.replace(placeholder, name)
            samples.append((prompt, summ))

    if not samples:
        return [(), ()]
    return list(zip(*samples))
            
def process_factualnlg(path, limit, cfg):
    """Sample SummEdits-style rows as factual-consistency prompts (no references).

    A row is a candidate when ``row['label'] == 1`` or ``'entity_modification'`` is in
    ``row['edit_types']``. (Presumably consistent summaries plus entity-edited ones;
    confirm against the dataset docs.)

    Args:
        path: JSON file holding a list of rows with ``doc``, ``summary``, ``label`` and
            ``edit_types``.
        limit: Number of prompts to draw; capped at the number of distinct candidates.
            ``None`` takes all of them.
        cfg: Unused; accepted for the shared processor signature.

    Returns:
        ``(prompts, refs)``, where ``refs`` is a list of ``None`` of the same length.
    """
    with open(path) as f:
        data = json.load(f)

    # Order-preserving de-duplication so identical (doc, summary) pairs cannot repeat.
    candidates = list(dict.fromkeys(
        FACTUALNLG_PROMPT.format(row['doc'], row['summary'])
        for row in data
        if row['label'] == 1 or 'entity_modification' in row['edit_types']
    ))

    n_take = len(candidates) if limit is None else min(limit, len(candidates))
    if n_take <= 0:
        return [], []

    # Draw without replacement. The old np.random.choice(candidates, limit) sampled *with*
    # replacement and so produced duplicate prompts. Indexing also yields plain str
    # rather than numpy str_ values.
    picked = np.random.choice(len(candidates), size=n_take, replace=False)
    prompts = [candidates[i] for i in picked]
    refs = [None] * len(prompts)
    return prompts, refs

# Datasets to turn into prompts. The driver cell calls
# cfg['processor'](cfg['input_data'], cfg['limit'], cfg) for each entry, so extra keys
# (e.g. min_chars / max_chars) are read by the processor from `cfg`.
# NOTE(review): AMAZON_DIR and CURATION_DIR are not assigned in any visible cell; they
# must already exist in the kernel when this cell runs.
configs = {
    'amazon': {
        'input_data': AMAZON_DIR,
        'processor': process_amazon,
        'limit': 150
    },
    'curation': {
#         'input_data': '../granular-eval/data/curation-corpus/curation-corpus-articles.csv',
        'input_data': CURATION_DIR,
        'processor': process_curation,
        'limit': 150,
        # Exclusive article-length window (in characters) applied by process_curation.
        "min_chars": 700,
        "max_chars": 2500
    },
    'wikihow': {
        'input_data': '../granular-eval/data/wikihow/wikihowSep.csv',
        'processor': process_wikihow,
        'limit': 150
    },
    # Disabled datasets, kept for reference:
#     "dialogsum": {
#         'input_data': '../granular-eval/data/dialogsum/DialogSum_Data/dialogsum.test.jsonl',
#         'processor': process_dialogsum,
#         'limit': 150
#     },
#     "factualnlg": {
#         'input_data': '../granular-eval/data/factualnlg/data/summedits/summedits_news.json',
#         'processor': process_factualnlg,
#         'limit': 150
#     }
}


# Disabled MACSum entry (process_macsum returns only prompts, not (prompts, refs)):
#     'macsum': {
#         'input_data': '../granular-eval/data/MACSum/dataset/macdoc/test.json',
#         'processor': process_macsum,
#         'limit': 10
#     },



In [509]:
# Run each configured processor and flatten its output into two parallel record lists:
# `prompts` (dataset, prompt, sample_ix) and `refs` (the same fields plus the reference
# text under 'response', labelled with model 'refs').
prompts = []
refs = []
for slug, cfg in configs.items():
    processor = cfg['processor']
    curr_prompts, curr_refs = processor(cfg['input_data'], cfg['limit'], cfg)
    pairs = list(zip(curr_prompts, curr_refs))
    for ix in range(len(pairs)):
        prompt, ref = pairs[ix]
        sample_obj = {'dataset': slug, 'prompt': prompt, 'sample_ix': ix}
        prompts.append(sample_obj)
        refs.append({**sample_obj, 'model': 'refs', 'response': ref})

In [511]:
# Persist the prompt records and the reference "outputs" side by side.
outputs_to_write = [
    ('../results/batch_v1/prompts_v2.jsonl', prompts),
    ('../results/batch_v1/output_v2_refs.jsonl', refs),
]
for out_path, out_rows in outputs_to_write:
    with jsonlines.open(out_path, 'w') as writer:
        writer.write_all(out_rows)

In [101]:
# Generate Cohere continuations from given input file

# `os` was used below but never imported in the notebook (NameError on a fresh kernel).
import os

import jsonlines

with jsonlines.open('../results/batch_controlled/prompts_augmented_v2.jsonl') as reader:
    prompts = list(reader)

import cohere
from cohere import CohereAPIError, CohereConnectionError

# The key is read from the environment, never hard-coded.
API_KEY = os.environ.get('COHERE_API_KEY')

co = cohere.Client(API_KEY)

responses_command = co.batch_generate(
    model='command-light-nightly',
    # model='command-nightly',
    prompts=[p['prompt'].strip() for p in prompts],
    return_exceptions=True,
    max_tokens=1000,
    temperature=0.7,
    num_generations=1,
)

# With return_exceptions=True, a failed prompt comes back as the exception object itself
# rather than a list of generations. The old check named CohereError, which this cell
# never imported; testing against Exception covers every failure type.
if any(isinstance(resp, Exception) for resp in responses_command):
    print('Some samples failed! You may want to try again')

# Earlier label: 'command_52B_20230607'.
# NOTE(review): the 'model' label is hand-written; confirm it matches the `model=` above.
output_command = [
    {
        **sample_obj,
        'model': 'command_6Bpref_v14.7_20230817',
        'response': None if isinstance(resp, Exception) else resp[0].text,
    }
    for sample_obj, resp in zip(prompts, responses_command)
]

with jsonlines.open('../results/batch_controlled/output_augmented_v2_command6Bpref.jsonl','w') as writer:
    writer.write_all(output_command)

print(len(prompts), len(output_command))

In [37]:
# Check dataset count: how many prompt records each dataset contributes.

from collections import Counter

dataset_counter = Counter(row['dataset'] for row in prompts)

dataset_counter



Counter({'curation': 150, 'wikihow': 150})

In [95]:
# Check for duplicated prompt texts (FactualNLG has some).

# Import here: Counter was previously only available if an earlier cell had run.
from collections import Counter

with jsonlines.open('../results/batch_v1/prompts_merged.jsonl') as reader:
    prompts = list(reader)

# To restrict to one dataset, filter on row['dataset'] == 'factualnlg' here.
prompt_count = Counter(row['prompt'] for row in prompts)

duped_prompts = [prompt for prompt, count in prompt_count.items() if count > 1]

# Number of distinct prompt texts that occur more than once. The old cell ended on an
# assignment, so it displayed nothing.
len(duped_prompts)


11

In [100]:
# Merge outputs from multiple prompt batches.
# The v1 files supply every dataset NOT in `datasets_to_skip_v1`; the v2 files supply
# exactly the datasets that are in it. Model outputs with a None response are dropped;
# prompt rows are kept as-is.
# NOTE(review): neither `systems` nor `datasets_to_skip_v1` is assigned before this cell
# (the only `datasets_to_skip_v1` assignment in the notebook is commented out), so this
# relies on kernel state; confirm their values before re-running.


def _rows_for_merge(path, from_skipped, default_model=None, require_response=False):
    """Load one jsonl file and return the rows that belong in the merged output.

    Args:
        path: jsonl file to read.
        from_skipped: keep rows whose dataset IS in datasets_to_skip_v1 (True) or is
            NOT in it (False).
        default_model: when given, rows lacking 'model' get it prepended.
        require_response: also drop rows whose 'response' is None.
    """
    with jsonlines.open(path) as reader:
        loaded = list(reader)
    if default_model is not None:
        loaded = [r if 'model' in r else {'model': default_model, **r} for r in loaded]

    kept = []
    for r in loaded:
        in_skip_list = r['dataset'] in datasets_to_skip_v1
        if in_skip_list != from_skipped:
            continue
        if require_response and r['response'] is None:
            continue
        kept.append(r)
    return kept


for system in systems:
    sys_output_merged = _rows_for_merge(
        f'../results/batch_v1/archive/output_{system}.jsonl',
        from_skipped=False, default_model=system, require_response=True,
    )
    sys_output_merged += _rows_for_merge(
        f'../results/batch_v1/archive/output_v2_{system}.jsonl',
        from_skipped=True, default_model=system, require_response=True,
    )
    with jsonlines.open(f'../results/batch_v1/output_merged_{system}.jsonl', 'w') as writer:
        writer.write_all(sys_output_merged)
    print(system, len(sys_output_merged))


# Same split for the prompt files themselves (no model defaulting, no response filter).
sys_output_merged = _rows_for_merge('../results/batch_v1/prompts.jsonl', from_skipped=False)
sys_output_merged += _rows_for_merge('../results/batch_v1/prompts_v2.jsonl', from_skipped=True)

with jsonlines.open('../results/batch_v1/prompts_merged.jsonl', 'w') as writer:
    writer.write_all(sys_output_merged)

print(len(sys_output_merged))

command52B 750
command6B 750
refs 600
falcon40 750


FileNotFoundError: [Errno 2] No such file or directory: '../granular-eval/batch_v1/archive/output_mpt30instruct.jsonl'

In [299]:
# Generate batches to be annotated

# 

import jsonlines
from collections import defaultdict
from itertools import combinations, product
from math import ceil
import random
import csv
import numpy as np

# Seed both RNGs used below: `random` picks the A/B order of each pair; numpy drives the
# combo/sample shuffles and the distractor draws.
random.seed(12)
np.random.seed(12)


#### Settings begins

# NOTE: only the commented-out per-system loading loop below reads this list; the active
# code loads a single merged file instead.
systems = [
#     'command52B',
#     'command6B',
    'command52Bpref',
    'command6Bpref',
    #     'refs',
    'falcon40',
    'mpt30instruct',
#     'mpt7',
#     'vicuna30uncensored',
]

# Run log (OFFSET indexes into the per-dataset list of valid sample indices):
# Part I - non augmented, used refs for distractors
# Main batches started at offset=3:
# pilot v5: offset 3, n=1 (inc 2x repeats)
# pilot v6: offset 4, n=1 (inc 4x repeats) -> 70
# main 1:   offset 5, n=19(inc 4x repeats) -> 1330
# main 2:   offset 24, n=20 (inc 0 repeats) -> 1000
# main 3:   offset 44, n=20 (inc 0 repeats) -> 1000
# main 4:   offset 64, n=20 (inc 0 repeats) -> 1000

# Part II - augmented, used all models for distractors
# controlled pilot 1: offset 100, n=1, inc 0 repeats # bug: repeated pairing
# controlled main 1: offset 100, n=25, inc 4 repeats -> 1125 = (10 combos+1distractor + 4repeats) * 3 datasets * 25
# controlled main 2: offset 125, n=25, 0 repeats -> (10+1) * 3 * 25 = 825
# controlled main 3: offset 75, n=25, 0 repeats -> (10+1) * 3 * 25 = 825
# controlled main 4: offset 50, n=25, 0 repeats -> (10+1) * 3 * 25 = 825

# Samples drawn per dataset for this batch.
NUM_SAMPLES = 25 #NOTE: use multiples of 5 to ensure each model/preamble type is balanced
# Start position within each dataset's valid sample indices.
OFFSET = 50
# (preamble_type, model) responses per sample; all pairs among them are annotated.
NUM_PAIRS_PER_INSTANCE=5
NUM_ANNOTATOR_AGREEMENT_SAMPLES = 0 # How many times to duplicate one of the rows, for agreement checking?
# Written into each row's meta and used in the output file name.
BATCH_NAME = 'confound_main_4'

USE_REFS_MODEL_FOR_DISTRACTORS = False # Set to True to use the same model for distractors each time
PROMPT_KEY = 'original_prompt' # set to 'prompt' for non-augmented data
REFS_MODEL = 'command_52Bpref_v14.7_20230817' # set to 'refs' for non augmented data

# datasets_to_skip_v1 = ['wikihow','curation']

#### Settings Ends

def normalise_response(resp):
    """Strip ``<s>`` / ``</s>`` sentinel tokens and surrounding whitespace from model text.

    Args:
        resp: Raw response (or prompt) string.

    Returns:
        The cleaned string.
    """
    # The old version also applied .replace('"', '\"'), but '\"' == '"' in Python, so that
    # call did nothing. Quote escaping is left to the JSON writer.
    return resp.replace('<s>', '').replace('</s>', '').strip()


# dataset -> sample_ix -> list of output rows (one per preamble_type/model).
samples_by_dataset_by_sampleix = defaultdict(lambda: defaultdict(list))
all_outputs_merged = []

# (Historically loaded per system from ../results/batch_v1/output_merged_{system}.jsonl.)
with jsonlines.open(f'../results/batch_controlled/output_augmented_merged.jsonl') as reader:
    rows = list(reader)

# Every row must carry its model name. The old fallback `{'model': system, **row}` read a
# `system` variable left over from some earlier loop (nothing binds it in this cell), so
# an unlabelled row was either a NameError or silently mislabelled.
num_unlabelled = sum('model' not in row for row in rows)
if num_unlabelled:
    raise ValueError(f'{num_unlabelled} rows in output_augmented_merged.jsonl have no "model" field')

for row in rows:
    # Failed generations (response None) are excluded from pairing.
    if row['response'] is not None:
        samples_by_dataset_by_sampleix[row['dataset']][row['sample_ix']].append(row)
        all_outputs_merged.append(row)

unique_preamble_types = set()
unique_models = set()
sampleixs_to_skip = {}

# Per dataset, flag sample indices whose prompt (read from the first row) already appeared
# at an earlier index, and collect every preamble type and model present.
for dataset, samples_by_ix in samples_by_dataset_by_sampleix.items():
    print(dataset, ' has ', len(samples_by_ix), ' samples')
    seen_prompts = set()
    duplicate_ixs = []
    for sample_ix, group in samples_by_ix.items():
        first_prompt = group[0][PROMPT_KEY]
        if first_prompt in seen_prompts:
            duplicate_ixs.append(sample_ix)
        seen_prompts.add(first_prompt)
        unique_preamble_types.update(x['preamble_type'] for x in group)
        unique_models.update(x['model'] for x in group)
    sampleixs_to_skip[dataset] = duplicate_ixs
    print(' .. of which ', len(duplicate_ixs), ' are duplicate')
    
# Construct the model/preamble pairings to ensure uniform distribution: repeated shuffled
# passes over every (preamble_type, model) combination, then cut into groups of
# NUM_PAIRS_PER_INSTANCE, one group per (dataset, sample).
num_outputs = NUM_SAMPLES * len(samples_by_dataset_by_sampleix) * NUM_PAIRS_PER_INSTANCE

# Sort before shuffling. The iteration order of a set of strings depends on the
# per-process hash seed (PYTHONHASHSEED), so `list(set)` made the seeded shuffle below,
# and therefore the generated batch, differ from one run to the next.
possible_combos = list(product(sorted(unique_preamble_types), sorted(unique_models)))
combos_all = []
for _ in range(ceil(num_outputs / len(possible_combos))):
    np.random.shuffle(possible_combos)
    combos_all.extend(possible_combos)


combo_groups = [
    combos_all[i * NUM_PAIRS_PER_INSTANCE:(i + 1) * NUM_PAIRS_PER_INSTANCE]
    for i in range(NUM_SAMPLES * len(samples_by_dataset_by_sampleix))
]
# If len(possible_combos) is not a multiple of NUM_PAIRS_PER_INSTANCE, a group can span
# two shuffled passes and repeat a combo; report when any group has fewer distinct combos.
min_set_size = min([len(set(group)) for group in combo_groups])
meets_conditions = min_set_size == NUM_PAIRS_PER_INSTANCE
if not meets_conditions:
    print('Grouping check failed ({:})!'.format(min_set_size))

# Build annotation rows. For each selected sample:
# - every pair among the NUM_PAIRS_PER_INSTANCE responses picked by combo_groups;
# - optional duplicated pairs for agreement checks;
# - one distractor row pairing a response with the same model's response to another sample.
# The exact sequence of RNG calls matters for reproducing earlier batches.
samples_for_prolific = []
    
batch_ix = 0
for dataset, samples_by_ix in samples_by_dataset_by_sampleix.items():
    
    # NOTE(review): the 0..149 index space is hard-coded, and OFFSET values in the settings
    # log refer to positions in this list. `samples_by_ix` is a defaultdict, so an index
    # with no rows yields [] (and inserts that key) instead of failing here.
    valid_samples = list(range(0,150))
    for invalid_ix in sampleixs_to_skip[dataset]:
        valid_samples.remove(invalid_ix)
    for ix in valid_samples[OFFSET : OFFSET+NUM_SAMPLES]:
        samples = samples_by_ix[ix]
        
        # (preamble_type, model) -> row; a later row with the same key overwrites an earlier one.
        samples_by_key = {(sample['preamble_type'], sample['model']): sample for sample in samples}
        
        # Shuffles samples_by_ix[ix] in place (`samples` is that same list). The pair
        # construction below does not read this order, but the call advances the global
        # numpy RNG, so removing it would change every later draw.
        np.random.shuffle(samples)
        
#         preambles_to_use = combos_all[batch_ix*NUM_PAIRS_PER_INSTANCE:(batch_ix+1)*NUM_PAIRS_PER_INSTANCE]
#         models_to_use = repeated_models[batch_ix*NUM_PAIRS_PER_INSTANCE:(batch_ix+1)*NUM_PAIRS_PER_INSTANCE]
#         combo_keys = combos_all[batch_ix*NUM_PAIRS_PER_INSTANCE:(batch_ix+1)*NUM_PAIRS_PER_INSTANCE]
        combo_keys = combo_groups[batch_ix]
        # KeyError if this sample has no row for one of the assigned (preamble_type, model) combos.
        combos_to_use = [samples_by_key[k] for k in combo_keys]
        
#         combos = list(combinations(samples[:NUM_PAIRS_PER_INSTANCE], 2))
        
        combos = list(combinations(combos_to_use, 2))
        if NUM_ANNOTATOR_AGREEMENT_SAMPLES > 0:
#             dupe_combos = list(combinations(samples[:NUM_PAIRS_PER_INSTANCE], 2))
            # Repeat one randomly chosen pair NUM_ANNOTATOR_AGREEMENT_SAMPLES times.
            dupe_combos = list(combinations(combos_to_use, 2))
            np.random.shuffle(dupe_combos)
            annotator_checks = [dupe_combos[0]] * NUM_ANNOTATOR_AGREEMENT_SAMPLES
        else: 
            annotator_checks = []
        for sample_1, sample_2 in combos + annotator_checks:
            # switch the order at random
            if sample_1[PROMPT_KEY] != sample_2[PROMPT_KEY]:
                print('Found mismatched prompts!!')
            sample_a, sample_b = (sample_1, sample_2) if random.random() >= 0.5 else (sample_2, sample_1)
            assert sample_a[PROMPT_KEY] == sample_b[PROMPT_KEY]
#             assert sample_a['preamble'] == sample_b['preamble']
            # NOTE: sample_main is built but never used; only the dict appended to
            # samples_for_prolific below is written out.
            sample_main = [
                {'type': 'text','content': '## Prompt\n-------\n{:}'.format(normalise_response(sample_a[PROMPT_KEY]))},
                 {'type': 'text', 'content': '## Response A\n-------\n{:}'.format(normalise_response(sample_a['response']))},
                 {'type': 'text', 'content': '## Response B\n-------\n{:}'.format(normalise_response(sample_b['response']))}]
            
            sample_meta = {
                'batch_id': BATCH_NAME,
                'dataset': dataset,
                'model_a': sample_a['model'],
                'model_b': sample_b['model'],
                'sample_ix': ix,
                'preamble_a': sample_a['preamble'],
                'preamble_type_a': sample_a['preamble_type'],
                'preamble_b': sample_b['preamble'],
                'preamble_type_b': sample_b['preamble_type'],
            }
            
            

            samples_for_prolific.append(
                {
                 'text':{
                     'Prompt': normalise_response(sample_a[PROMPT_KEY]), 
                     'Response A': normalise_response(sample_a['response']),
                     'Response B': normalise_response(sample_b['response']),
                 },
                 'meta': sample_meta}
            )
            
        # One combo group is consumed per sample, whether or not a distractor is added below.
        batch_ix += 1
            
            
        # Construct an attention check example

        # With USE_REFS_MODEL_FOR_DISTRACTORS False every row of this sample is eligible;
        # otherwise only rows from REFS_MODEL.
        curr_refs = [sample for sample in samples_by_ix[ix] if sample['model'] == REFS_MODEL or not USE_REFS_MODEL_FOR_DISTRACTORS]
        np.random.shuffle(curr_refs)
        if len(curr_refs) == 0:
            continue
        sample_1 = curr_refs[0]
        # The distractor is a response from a different sample index. NOTE(review): the
        # candidates are every key of samples_by_ix, including duplicate-prompt indices and
        # keys inserted by the defaultdict lookup above. A key with no row from this model
        # makes sample_2_options[0] raise IndexError.
        distractor_ix = np.random.choice([i for i in samples_by_ix.keys() if i != ix])
        # Same model as sample_1, but not necessarily the same preamble.
        sample_2_options = [sample for sample in samples_by_ix[distractor_ix] if sample['model'] == sample_1['model']]
        np.random.shuffle(sample_2_options)
        sample_2 = sample_2_options[0]
        # Relabel so that downstream counters can identify and exclude distractor rows.
        sample_2 = {**sample_2, 'model':'distractor'}
        sample_a, sample_b = (sample_1, sample_2) if random.random() >= 0.5 else (sample_2, sample_1)

        # NOTE: unused, as above.
        sample_main = [
            {'type': 'text','content': '## Prompt\n-------\n{:}'.format(normalise_response(sample_1[PROMPT_KEY]))},
             {'type': 'text', 'content': '## Response A\n-------\n{:}'.format(normalise_response(sample_a['response']))},
             {'type': 'text', 'content': '## Response B\n-------\n{:}'.format(normalise_response(sample_b['response']))}]

        sample_meta = {
            'batch_id': BATCH_NAME,
            'dataset': dataset,
            'model_a': sample_a['model'],
            'model_b': sample_b['model'],
            'sample_ix': 'distractor_' + str(ix),
            'preamble_a': sample_a['preamble'],
            'preamble_type_a': sample_a['preamble_type'],
            'preamble_b': sample_b['preamble'],
            'preamble_type_b': sample_b['preamble_type'],
        }


        # The displayed prompt is always the real sample's prompt (sample_1), never the distractor's.
        samples_for_prolific.append(
            {
             'text':{
                 'Prompt': normalise_response(sample_1[PROMPT_KEY]), 
                 'Response A': normalise_response(sample_a['response']),
                 'Response B': normalise_response(sample_b['response']),
             },
             'meta': sample_meta}
        )
            
# Shuffle first and number afterwards, so potato shows the rows in shuffled order.
np.random.shuffle(samples_for_prolific)
numbered_rows = []
for position, item in enumerate(samples_for_prolific):
    numbered_rows.append({'id': str(position), **item})
samples_for_prolific = numbered_rows


batch_out_path = f'../results/batch_controlled/prolific_{BATCH_NAME}.jsonl'
with jsonlines.open(batch_out_path, 'w') as writer:
    writer.write_all(samples_for_prolific)

len(samples_for_prolific)

amazon  has  100  samples
 .. of which  0  are duplicate
curation  has  100  samples
 .. of which  0  are duplicate
wikihow  has  100  samples
 .. of which  0  are duplicate


825

In [260]:
# Check distractor counts: how distractor rows are spread over preamble types and models.

distractor_preambles = Counter()
distractor_models = Counter()
distractor_metas = (
    row['meta'] for row in samples_for_prolific
    if 'distractor' in str(row['meta']['sample_ix'])
)
for meta in distractor_metas:
    for side in ('a', 'b'):
        distractor_preambles[meta['preamble_type_' + side]] += 1
        distractor_models[meta['model_' + side]] += 1

distractor_preambles, distractor_models

(Counter({'complexity_low': 36,
          'confidence_low': 30,
          'normal': 29,
          'complexity_high': 28,
          'confidence_high': 27}),
 Counter({'distractor': 75,
          'command_6Bpref_v14.7_20230817': 21,
          'llama13chat': 18,
          'falcon40': 16,
          'command_52Bpref_v14.7_20230817': 11,
          'mpt30instruct': 9}))

In [261]:
# Check other counts are as expected (one pass over the batch).

from collections import Counter

dataset_counter = Counter()
model_counter = Counter()
preamble_type_counter = Counter()
preamble_counter = Counter()
prompt_counter = Counter()

for row in samples_for_prolific:
    meta = row['meta']
    dataset_counter[meta['dataset']] += 1
    prompt_counter[row['text']['Prompt']] += 1
    # Distractor rows are left out of the model/preamble balance checks.
    if 'distractor' in (meta['model_a'], meta['model_b']):
        continue
    for side in ('a', 'b'):
        model_counter[meta['model_' + side]] += 1
        preamble_type_counter[meta['preamble_type_' + side]] += 1
        preamble_counter[meta['preamble_' + side]] += 1

# Rows per dataset, expected NUM_SAMPLES * (C(NUM_PAIRS_PER_INSTANCE, 2) pairs + 1 distractor
# + agreement repeats), e.g. 25 * (10 + 1) = 275.
print(dataset_counter)
# Model appearances over non-distractor rows; should be equal across models.
print(model_counter)
# Preamble text / type appearances over non-distractor rows; should be equal across preambles.
print(preamble_counter)
print(preamble_type_counter)
# Most-repeated prompt: pairs + distractor (+ agreement repeats) for one sample, e.g. 10 + 1 = 11.
print('Max prompt count: ', prompt_counter.most_common(1)[0][1])

Counter({'wikihow': 275, 'amazon': 275, 'curation': 275})
Counter({'falcon40': 300, 'mpt30instruct': 300, 'command_52Bpref_v14.7_20230817': 300, 'command_6Bpref_v14.7_20230817': 300, 'llama13chat': 300})
Counter({'Respond in a cautious, defensive and uncertain way, as if you are unfamilar with the topic.': 300, 'Respond using complex language, long words and technical terms, as if you are an expert.': 300, 'Respond authoritatively, assertively and persuasively, as if you are very knowledgeable about the topic.': 300, '': 300, 'Respond using only short words and simple language, as if you were talking to a child.': 300})
Counter({'confidence_low': 300, 'complexity_high': 300, 'confidence_high': 300, 'normal': 300, 'complexity_low': 300})
Max prompt count:  11


In [457]:
# Check total prompt counts are as expected, and that no prompt is reused across batches.

from itertools import combinations

batch_paths = [
    '../results/batch_v1/prolific_main_batch_1_repeats.jsonl',
    '../results/batch_v1/prolific_main_batch_2.jsonl',
    '../results/batch_v1/prolific_main_batch_3.jsonl',
    '../results/batch_v1/prolific_main_batch_4.jsonl',
]
batch_prompt_sets = []
for batch_path in batch_paths:
    with jsonlines.open(batch_path) as reader:
        batch_prompt_sets.append({row['text']['Prompt'] for row in reader})
batch_1_prompts, batch_2_prompts, batch_3_prompts, batch_4_prompts = batch_prompt_sets

# The old check intersected all four sets at once. That is empty whenever any single batch
# is disjoint from the rest, so a prompt shared by only two batches went unnoticed. Use the
# union of all pairwise intersections instead.
overlapping_prompts = set().union(*(a & b for a, b in combinations(batch_prompt_sets, 2)))

overlapping_prompts, len(batch_1_prompts), len(batch_2_prompts), len(batch_3_prompts), len(batch_4_prompts)

(set(), 95, 100, 100, 100)

In [None]:
# Generate augmented responses from Cohere models

# `os` was used below but never imported in the notebook (NameError on a fresh kernel).
import os

import cohere
from cohere import CohereAPIError, CohereConnectionError, CohereError

API_KEY = os.environ.get('COHERE_API_KEY')


co = cohere.Client(API_KEY)

# preambles_complexity = [
#     ('complexity_high', 'Respond using jargon, long words and technical language appropriate for an expert.'),
#     ('normal', ''),
#     ('complexity_low', 'Respond using only short words and simple language appropriate for a child.')
# ]
preambles_complexity = [
    ('complexity_high', 'Respond using additional jargon, long words and technical terms, as if you are an expert addressing another expert.'),
    ('normal', ''),
    ('complexity_low', 'Explain to me like I\'m five, and respond using only short words and simple language, as if you were talking to a child.')
]


# preambles_confidence = [
#     ('confidence_high','Respond in an persuasive, assertive and confident way.'),
#     ('normal',''),
#     ('confidence_low','Respond in an cautious and balanced way, including both sides of the argument.'),
# ]
preambles_confidence = [
    ('confidence_high','Respond aggressively, assertively and confidently, as if you were a genius talking to a stupid person.'),
    ('normal',''),
    ('confidence_low','Respond in a very cautious and uncertain way, as if you were an idiot talking to a genius.'),
]

# One entry per preamble: for each prompt, the list of generated texts.
all_responses = []

# NOTE(review): `wikihow_prompts` is not assigned in any visible cell; it must come from
# kernel state. Confirm its source before re-running.
for _, preamble in preambles_confidence:
    responses_command = co.batch_generate(
        model='command-nightly',
        # prompts=['User: ' + p['prompt'] +'\n' + preamble + '\nChatbot:'.strip() for p in wikihow_prompts],
        prompts=[p['prompt'] +'\n' + preamble for p in wikihow_prompts],
        return_exceptions=True,
        max_tokens=500,
        temperature=0.7,
        num_generations=1,
    )

    # With return_exceptions=True a failed prompt is returned as the exception object
    # itself. The old code iterated over it (`for r in resp`) and raised TypeError. Record
    # [None] instead: one slot, since num_generations=1.
    all_responses.append([
        [None] if isinstance(resp, Exception) else [r.text for r in resp]
        for resp in responses_command
    ])

    

    



In [15]:
from collections import defaultdict

import jsonlines

with jsonlines.open('../results/batch_v1/prompts_merged.jsonl') as reader:
    prompts = list(reader)

confound_prompts = [p for p in prompts if p['dataset'] in ['wikihow','curation','amazon']]

prompts_by_dataset = defaultdict(list)

for prompt in confound_prompts:
    prompts_by_dataset[prompt['dataset']].append(prompt)

# Keep the last 100 prompts of each dataset. The loop variable must not be named
# `prompts`: that rebound the global full prompt list to the last dataset's subset,
# which leaks into any later cell reading `prompts`.
confound_prompts = []
for dataset_prompts in prompts_by_dataset.values():
    confound_prompts.extend(dataset_prompts[-100:])

In [42]:
# Accumulates augmented prompt records (filled by the augmentation loop further down).
confound_prompts_augmented = []

# Earlier preamble wordings, kept for the record:
# preambles_complexity = [
#     ('complexity_high', 'Respond using jargon, long words and technical language appropriate for an expert.'),
#     ('normal', ''),
#     ('complexity_low', 'Respond using only short words and simple language appropriate for a child.')
    
# ]
# preambles_complexity = [
#     ('complexity_high', 'Respond using additional jargon, long words and technical terms, as if you are an expert addressing another expert.'),
#     ('normal', ''),
#     ('complexity_low', 'Explain to me like I\'m five, and respond using only short words and simple language, as if you were talking to a child.')
    
# ]
# Current wording: (preamble_type, instruction text). 'normal' is the no-instruction baseline.
preambles_complexity = [
    ('complexity_high', 'Respond using complex language, long words and technical terms, as if you are an expert.'),
    ('normal', ''),
    ('complexity_low', 'Respond using only short words and simple language, as if you were talking to a child.')
    
]


# preambles_confidence = [
#     ('confidence_high','Respond in an persuasive, assertive and confident way.'),
#     ('normal',''),
#     ('confidence_low','Respond in an cautious and balanced way, including both sides of the argument.'),
# ]
# preambles_confidence = [
#     ('confidence_high','Respond aggressively, assertively and confidently, as if you were a genius talking to a stupid person.'),
#     ('normal',''),
#     ('confidence_low','Respond in a very cautious and uncertain way, as if you were an idiot talking to a genius.'),
# ]

# NOTE: 'unfamilar' is misspelled, but this exact text was sent to the models (it appears
# in the recorded batch counts), so changing it would break consistency with existing outputs.
preambles_confidence = [
    ('confidence_high','Respond authoritatively, assertively and persuasively, as if you are very knowledgeable about the topic.'),
    ('normal',''),
    ('confidence_low','Respond in a cautious, defensive and uncertain way, as if you are unfamilar with the topic.'),
]



# Respond authoritatively, assertively and persuasively, as if you were an expert.
# Respond in a very cautious and uncertain way, as if you are unfamilar with the topic.

# Respond using complex language, long words and technical terms, as if you are an expert.
# Explain to me like I\'m five, and respond using only short words and simple language, as if you were talking to a child.

# preamble_type -> text. 'normal' occurs in both lists; the dict keeps a single entry for it.
preambles = {x[0]: x[1] for x in preambles_complexity+preambles_confidence}

def prompt_to_augmented(orig, augment, dataset):
    """Insert a style instruction (preamble) into a prompt.

    - curation: the instruction goes on its own line, followed by a space, just before the
      final "Generate a summary:" marker.
    - amazon: the instruction goes on its own line just before the final
      "Product description:" marker.
    - Other datasets, or a missing marker: the instruction is appended on a new line.

    Args:
        orig: The original prompt text.
        augment: Instruction text. An empty string leaves ``orig`` unchanged.
        dataset: Dataset slug selecting the insertion rule.

    Returns:
        The augmented prompt string.
    """
    # Empty preamble ('normal'): previously this still injected a stray space (curation)
    # or newline (other datasets), so the baseline prompt differed from the original.
    if not augment:
        return orig

    # dataset -> (marker to insert before, separator between instruction and marker text)
    insertion_points = {
        'curation': ('\nGenerate a summary:', ' '),
        'amazon': ('\nProduct description:', '\n'),
    }
    if dataset not in insertion_points:
        return orig + '\n' + augment

    marker, joiner = insertion_points[dataset]
    # rpartition targets the final marker (the template's own). str.replace also rewrote
    # any copy of the marker text that appeared inside the document body.
    head, found, tail = orig.rpartition(marker)
    if not found:
        return orig + '\n' + augment
    return head + '\n' + augment + joiner + marker.lstrip('\n') + tail

def _copies_for(preamble_type):
    """How many passes over confound_prompts to generate for a preamble type (0 = skip)."""
    if 'complexity' in preamble_type:
        return 0
    if 'confidence' in preamble_type:
        return 2
    return 0


# One augmented record per (preamble, prompt, pass). The record keeps the original prompt
# under 'original_prompt' and overwrites 'prompt' with the augmented text.
for preamble_type, preamble_text in preambles.items():
    for prompt in confound_prompts * _copies_for(preamble_type):
        augmented_text = prompt_to_augmented(prompt['prompt'], preamble_text, prompt['dataset'])
        record = {
            **prompt,
            'original_prompt': prompt['prompt'],
            'preamble': preamble_text,
            'prompt': augmented_text,
            'preamble_type': preamble_type,
        }
        confound_prompts_augmented.append(record)


len(confound_prompts_augmented)

1200

In [43]:
# Spot-check one augmented record (index 1) to eyeball where the preamble was inserted.
confound_prompts_augmented[1]

{'dataset': 'amazon',
 'prompt': 'Product name:\nWayne SWS50-12P 1/2 hp Shallow Well Jet Pump Conventional 12 gallon Tank System\n\nProduct keywords:\n1/2 hp cast iron shallow well jet pump conventional tank system with 12 gallon horizontal tank\n1/2 hp; max. Flow rate is 420 gallons per hour; (50 psi)\n3/4 in. Npt discharge with 1-1/4 in. Suction\nDual voltage (120V/240V) high efficiency square flange motor (factory set for 240V)\nPressure switch is pre-set at 30-50 psi for automatic operation\nAssembled in US with foreign and domestic parts\nDurable cast iron volute with dedicated priming port for volute filling. For wells 0 ft. - 25 ft. in depth.\n\nRespond authoritatively, assertively and persuasively, as if you are very knowledgeable about the topic.\nProduct description:',
 'sample_ix': 51,
 'original_prompt': 'Product name:\nWayne SWS50-12P 1/2 hp Shallow Well Jet Pump Conventional 12 gallon Tank System\n\nProduct keywords:\n1/2 hp cast iron shallow well jet pump conventional ta

In [44]:
# Persist this round of augmented prompts for generation.
augmented_out_path = '../results/batch_controlled/prompts_augmented_v2_part2.jsonl'
with jsonlines.open(augmented_out_path, 'w') as writer:
    writer.write_all(confound_prompts_augmented)

In [106]:
# Now we need to rerank the complexity samples.
#
# For each system, both halves of the augmented generations are loaded and grouped
# by preamble_type and (dataset, sample_ix). Exactly one response is kept per group:
#   * complexity_high / complexity_low: the scorable candidate with the highest /
#     lowest Dale-Chall readability score;
#   * confidence_high / confidence_low: the first candidate that is not None and does
#     not contain "I'm sorry" (falling back to the first candidate);
#   * normal: the first (expected to be the only) response.

from collections import defaultdict

from nltk.tokenize import word_tokenize
import numpy as np

from readability import Readability

systems = [
#     'command52B',
#     'command6B',
#     'refs',
#     'vicuna30uncensored',
#     'mpt7',
    'command52Bpref',
    'command6Bpref',
    'falcon40',
    'llama13chat',
    'mpt30instruct',
]


def _dale_chall_score(text, min_tokens=10, min_words=100):
    """Return the Dale-Chall score of ``text``, or None if it cannot be scored.

    Text is unscorable when it is None, has fewer than ``min_tokens`` NLTK tokens,
    or has fewer than ``min_words`` words according to ``Readability.statistics()``.
    """
    if text is None:
        return None
    r = Readability(text)
    if len(word_tokenize(text)) < min_tokens or r.statistics()['num_words'] < min_words:
        return None
    # Alternative metric that was tried: r.flesch_kincaid().grade_level
    return r.dale_chall().score


def _pick_by_complexity(responses, prefer_high):
    """Return the highest- (``prefer_high``) or lowest-scoring scorable response.

    Unscorable responses are excluded from the comparison rather than given a
    sentinel score; a fixed low sentinel would make them the preferred choice when
    looking for the lowest score. If no candidate is scorable, the first response
    is returned. Ties resolve to the earliest candidate (numpy argmax/argmin).
    """
    scored = []
    for response in responses:
        score = _dale_chall_score(response['response'])
        if score is not None:
            scored.append((score, response))
    if not scored:
        return responses[0]
    scores = [score for score, _ in scored]
    best = np.argmax(scores) if prefer_high else np.argmin(scores)
    return scored[best][1]


def _pick_without_apology(responses):
    """Return the first response that is not None and lacks "I'm sorry"; else the first."""
    for response in responses:
        if response['response'] is not None and "I'm sorry" not in response['response']:
            return response
    print('All candidates included sorry!', responses[0]['model'])
    return responses[0]


# preamble_type -> True to keep the highest Dale-Chall score, False for the lowest.
COMPLEXITY_DIRECTIONS = {'complexity_high': True, 'complexity_low': False}

reranked_outputs = []

for system in systems:
    outputs = []
    for path in (
        f'../results/batch_controlled/output_augmented_v2_{system}.jsonl',
        f'../results/batch_controlled/output_augmented_v2_part2_{system}.jsonl',
    ):
        with jsonlines.open(path) as reader:
            outputs += list(reader)

    print(system, len(outputs))

    # Rows without a 'model' field are labelled with the system name.
    outputs = [{**row, 'model': row.get('model', system)} for row in outputs]

    outputs_by_preamble_by_sample = defaultdict(lambda: defaultdict(list))
    for output in outputs:
        sample_key = (output['dataset'], output['sample_ix'])
        outputs_by_preamble_by_sample[output['preamble_type']][sample_key].append(output)

    for preamble_type, prefer_high in COMPLEXITY_DIRECTIONS.items():
        for responses in outputs_by_preamble_by_sample[preamble_type].values():
            reranked_outputs.append(_pick_by_complexity(responses, prefer_high))

    for preamble_type in ['confidence_high', 'confidence_low']:
        for responses in outputs_by_preamble_by_sample[preamble_type].values():
            reranked_outputs.append(_pick_without_apology(responses))

    for responses in outputs_by_preamble_by_sample['normal'].values():
        if len(responses) > 1:
            print('Duplicate responses for normal?!')
        reranked_outputs.append(responses[0])

# Length should be n * 5 models * 3 tasks * 5 preamble groups
print(len(reranked_outputs))

command52Bpref 3900
All candidates included sorry! command_52Bpref_v14.7_20230817
All candidates included sorry! command_52Bpref_v14.7_20230817
All candidates included sorry! command_52Bpref_v14.7_20230817
All candidates included sorry! command_52Bpref_v14.7_20230817
All candidates included sorry! command_52Bpref_v14.7_20230817
All candidates included sorry! command_52Bpref_v14.7_20230817
All candidates included sorry! command_52Bpref_v14.7_20230817
All candidates included sorry! command_52Bpref_v14.7_20230817
All candidates included sorry! command_52Bpref_v14.7_20230817
All candidates included sorry! command_52Bpref_v14.7_20230817
All candidates included sorry! command_52Bpref_v14.7_20230817
All candidates included sorry! command_52Bpref_v14.7_20230817
All candidates included sorry! command_52Bpref_v14.7_20230817
All candidates included sorry! command_52Bpref_v14.7_20230817
All candidates included sorry! command_52Bpref_v14.7_20230817
All candidates included sorry! command_52Bpref_v14

In [107]:
# Save the single selected response per (system, preamble_type, sample) as JSONL.
with jsonlines.open(
    '../results/batch_controlled/output_augmented_merged.jsonl',
    mode='w',
) as writer:
    writer.write_all(reranked_outputs)

In [297]:
# Generate the batches for expert annotation.
#
# Each row of the pairwise agreement batch holds two responses (A and B) to the same
# prompt. Every pair is split into single-response items, de-duplicated on
# (sample_ix, dataset, model, preamble_type), rows whose sample_ix marks them as
# distractors are skipped, and the result is written as three batch files.

import jsonlines

BATCH_NAME = 'confound_validation_1'


def _single_response_key(meta, side):
    """Return the de-duplication key for response ``side`` ('a' or 'b') of a pair."""
    return (meta['sample_ix'], meta['dataset'], meta[f'model_{side}'], meta[f'preamble_type_{side}'])


def _single_response_item(row, side):
    """Convert response ``side`` ('a' or 'b') of a pairwise row into a single-response item.

    Whichever side it came from, the response is stored under the 'Response A' text
    field and the ``*_a`` meta fields.
    """
    meta = row['meta']
    return {
        'text': {'Prompt': row['text']['Prompt'], 'Response A': row['text'][f'Response {side.upper()}']},
        'meta': {
            'batch_id': BATCH_NAME,
            'dataset': meta['dataset'],
            'sample_ix': meta['sample_ix'],
            'model_a': meta[f'model_{side}'],
            'preamble_type_a': meta[f'preamble_type_{side}'],
            'preamble_a': meta[f'preamble_{side}'],
        },
    }


with jsonlines.open('../results/batch_controlled/prolific_confound_main_1_agreement.jsonl') as reader:
    orig_inputs = list(reader)

# The first occurrence of each key wins; within a row, response A is considered before B.
orig_inputs_by_key = {}
for row in orig_inputs:
    if 'distractor' in str(row['meta']['sample_ix']):
        continue
    for side in ('a', 'b'):
        key = _single_response_key(row['meta'], side)
        if key not in orig_inputs_by_key:
            orig_inputs_by_key[key] = _single_response_item(row, side)

validation_batch = [
    {'id': str(i), **item}
    for i, (_, item) in enumerate(sorted(orig_inputs_by_key.items(), key=lambda kv: kv[0]))
]

# (path, start, stop): items [0, 150), [150, 300), and everything from 300 onwards.
VALIDATION_SPLITS = [
    ('../results/batch_controlled/prolific_confound_validation_1.jsonl', 0, 150),
    ('../results/batch_controlled/prolific_confound_validation_2.jsonl', 150, 300),
    ('../results/batch_controlled/prolific_confound_validation_3.jsonl', 300, None),
]
for path, start, stop in VALIDATION_SPLITS:
    with jsonlines.open(path, 'w') as writer:
        writer.write_all(validation_batch[start:stop])

In [298]:
# Number of unique (sample_ix, dataset, model, preamble_type) responses collected for validation.
len(orig_inputs_by_key)

375