In [24]:
import numpy as np

# Stimulus coordinates, one row per stimulus: a (16, 2) array.
# Stimulus IDs used in the condition lists below are 1-based: ID k refers to
# stimuli[k - 1].
# Within each group of four rows the first coordinate increases, and each
# successive group has a larger second coordinate, so the stimuli form a 4x4
# grid: IDs 1-4 have the smallest second coordinate, IDs 13-16 the largest.
stimuli = np.array([
    [.312, -.241],
    [.918, -.264],
    [1.405, -.187],
    [2.062, -.227],
    [.228, .640],
    [.844, .662],
    [1.324, .687],
    [1.885, .623],
    [.374, 1.555],
    [.916, 1.501],
    [1.473, 1.544],
    [2.128, 1.520],
    [.135, 2.352],
    [.889, 2.412],
    [1.451, 2.493],
    [2.061, 2.382],
])

# Each condition is [category_A_ids, category_B_ids] (1-based stimulus IDs).
# Only the listed stimuli are shown to the model as labelled examples; the
# remaining stimuli are only ever used as probes.
# NOTE(review): `size` has 4 A / 3 B examples and `angle` has 3 A / 4 B, while
# `criss_cross` and `diagonal` are 4 / 4 -- confirm this matches the intended
# experimental design rather than a transcription slip.
size = [[2, 5, 7, 8], [11, 12, 14]]
angle = [[2, 6, 9], [3, 7, 12, 15]]
criss_cross = [[4, 7, 10, 13], [1, 2, 15, 16]]
diagonal = [[2, 3, 7, 12], [5, 10, 14, 15]]

# Condition order here fixes the order of `responses` / `cat_responses` below.
conditions = [size, angle, criss_cross, diagonal]

In [55]:
def stringify(stimulus):
    """Render a 2-D stimulus as "(x, y)" with one decimal place per coordinate.

    The " .1f" format spec reserves a leading space for non-negative values,
    so positive and negative coordinates line up in the prompt.
    """
    first, second = stimulus[0], stimulus[1]
    return "({: .1f}, {: .1f})".format(first, second)

def stringify_list(stimuli_list):
    """Format every stimulus with ``stringify`` and put one per line."""
    formatted = map(stringify, stimuli_list)
    return '\n'.join(formatted)

def gen_user_prompt(cond, stim, stimuli_array=None):
    """Build the user message asking the model to categorise ``stim``.

    Parameters
    ----------
    cond : pair of sequences of int
        ``cond[0]`` lists the 1-based IDs of the category-A examples and
        ``cond[1]`` those of the category-B examples.
    stim : indexable pair
        The probe stimulus; only ``stim[0]`` and ``stim[1]`` are shown.
    stimuli_array : indexable, optional
        Table the IDs index into (ID ``k`` -> ``stimuli_array[k - 1]``).
        Defaults to the module-level ``stimuli``, looked up at call time.

    Returns
    -------
    str
        The full prompt text.

    Raises
    ------
    ValueError
        If an ID in ``cond`` is outside ``1..len(stimuli_array)``. Without
        this check ID 0 (or a negative ID) would silently pick a stimulus
        from the end of the table.
    """
    table = stimuli if stimuli_array is None else stimuli_array

    def examples(ids):
        # Map 1-based IDs to their stimuli and format them one per line.
        for stim_id in ids:
            if not 1 <= stim_id <= len(table):
                raise ValueError(
                    f"stimulus id {stim_id} out of range 1..{len(table)}")
        return stringify_list([table[stim_id - 1] for stim_id in ids])

    user_prompt = f"""\
You are trying to classify stimuli into category A or category B.
The stimuli vary along 2 parameters.

You know that the following stimuli belong to category A:
{examples(cond[0])}

You know that the following stimuli belong to category B:
{examples(cond[1])}

What do you think the category of the stimulus {stringify(stim)} would be?

First decide how similar this stimulus is to each of the examples from both categories.
Then use these similarities to guess which category the stimulus is more likely to belong to.

Format your response as a JSON object where the keys are "similaritiesA", "similaritiesB", and "category". 
"similaritiesA" should be formatted as a list of integers, where each entry is an integer from 1-9, indicating
how similar the stimulus is to each example of category A, in the same order as above. Use 1 to indicate most
dissimilar and 9 to indicate most similar. 
"similaritiesB" should be formatted the same way as "similaritiesA", but should indicate how similar the stimulus
is to the examples of category B.
"category" should be either "A" or "B" to indicate which category you think the stimulus belongs to.

Only respond with the JSON object. Do not include anything else in your response.
    """
    return user_prompt

In [49]:
# Inspect the prompt for stimulus ID 1 under the size condition.
print(gen_user_prompt(cond=size, stim=stimuli[0]))

You are trying to classify stimuli into category A or category B.
The stimuli vary along 2 parameters.

You know that the following stimuli belong to category A:
( 0.9, -0.3)
( 0.2,  0.6)
( 1.3,  0.7)
( 1.9,  0.6)

You know that the following stimuli belong to category B:
( 1.5,  1.5)
( 2.1,  1.5)
( 0.9,  2.4)

What do you think the category of the stimulus ( 0.3, -0.2) would be?
First decide how similar this stimulus is to each of the examples from both categories.
Then use these similarities to guess which category the stimulus is more likely to belong to.
Format your response as a JSON object where the keys are "similaritiesA", "similaritiesB", and "category". 
"similaritiesA" should be formatted as a list of integers, where each entry is an integer from 1-9, indicating
how similar the stimulus is to each example of category A, in the same order as above. Use 1 to indicate most
disimilar and 9 to indicate most similar. 
"similaritiesB" should be formatted the same way as "similariti

In [88]:
import openai
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)

# Chat model queried by get_categorization (unless overridden per call).
MODEL = "gpt-3.5-turbo"
# The key file normally ends with a newline; strip surrounding whitespace so it
# is not sent as part of the API key.
with open("openai_key.txt") as openai_key:
    openai.api_key = openai_key.read().strip()

@retry(
    stop=stop_after_attempt(6),
    wait=wait_random_exponential(min=1, max=120),
)
def completion_with_backoff(**kwargs):
    """Forward ``kwargs`` to ``openai.ChatCompletion.create`` with retries.

    Any exception triggers a retry after a randomized exponential wait
    (bounded by min=1 and max=120 seconds); after 6 attempts tenacity stops
    retrying and raises.
    """
    create = openai.ChatCompletion.create
    return create(**kwargs)


def get_categorization(cond, stim, n=50, temperature=1, model=None):
    """Ask the chat model to categorise ``stim`` under condition ``cond``.

    Parameters
    ----------
    cond : pair of sequences of int
        ``[category_A_ids, category_B_ids]``, as consumed by gen_user_prompt.
    stim : indexable pair
        The probe stimulus.
    n : int, default 50
        Number of completions sampled in the single API request.
    temperature : float, default 1
        Sampling temperature for the request.
    model : str, optional
        Model name; defaults to the module-level ``MODEL`` at call time.

    Returns
    -------
    The raw response object returned by ``completion_with_backoff``; its
    ``choices`` are later consumed by get_probability.
    """
    system_prompt = """\
You are a participant in a category learning experiment. 
You try your best to categorize stimuli correctly based on the information given.
"""
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": gen_user_prompt(cond, stim)},
    ]

    return completion_with_backoff(
        model=MODEL if model is None else model,
        messages=messages,
        temperature=temperature,
        n=n,
    )

In [72]:
# Single trial: size condition, probe = stimulus ID 1.
response = get_categorization(cond=size, stim=stimuli[0])

In [83]:
import json
from collections import Counter

def get_probability(response):
    """Estimate P(category A) from the sampled completions in ``response``.

    Each choice's ``message.content`` is parsed as JSON and its "category"
    field is tallied. Labels are compared after stripping whitespace and
    upper-casing, so " a" counts as "A". The following choices are skipped
    rather than aborting the run:

    - content that is not valid JSON;
    - content that is not a JSON object;
    - content without a "category" key.

    Labels other than A/B are counted but do not enter the ratio.

    Returns
    -------
    float
        ``count_A / (count_A + count_B)``, or NaN when no choice produced an
        A or B label (previously a ZeroDivisionError).
    """
    counter = Counter()
    for choice in response.choices:
        try:
            content = json.loads(choice.message.content)
        except (json.JSONDecodeError, TypeError):
            # Malformed model output (TypeError covers content=None).
            continue
        if not isinstance(content, dict) or 'category' not in content:
            continue
        counter[str(content['category']).strip().upper()] += 1

    n_labelled = counter['A'] + counter['B']
    if n_labelled == 0:
        return float('nan')
    return counter['A'] / n_labelled

In [89]:
# One API request per (condition, stimulus) pair: responses[c][s] is the raw
# response for stimulus ID s + 1 under conditions[c].
responses = [
    [get_categorization(cond, stim) for stim in stimuli]
    for cond in conditions
]


In [97]:
# P(category A) for every stimulus, one list per condition (same nesting as `responses`).
cat_responses = [list(map(get_probability, cond_responses)) for cond_responses in responses]

In [94]:
# `cat_probs` is never defined in this notebook (stale kernel state); the
# probabilities computed above live in `cat_responses`.
print(cat_responses)

[[0.74, 1.0, 0.94, 0.46938775510204084, 1.0, 0.92, 1.0, 1.0, 0.04, 0.9, 0.38, 0.2916666666666667, 0.02, 0.2, 0.14, 0.12], [0.02, 1.0, 0.0, 0.0, 0.16, 0.98, 0.76, 0.08, 0.8979591836734694, 0.02, 0.02, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.02, 0.64, 1.0, 0.92, 1.0, 1.0, 0.9791666666666666, 0.7959183673469388, 1.0, 0.54, 0.52, 1.0, 0.36, 0.02040816326530612, 0.1], [0.02040816326530612, 1.0, 1.0, 0.9583333333333334, 0.0, 0.7, 0.9183673469387755, 0.88, 0.061224489795918366, 0.5918367346938775, 0.6938775510204082, 0.98, 0.0, 0.0, 0.0, 0.20833333333333334]]


In [101]:
def pprint_responses(resp, ncols=4):
    """Print one value per stimulus, laid out like the stimulus grid.

    ``resp`` holds one value per stimulus in stimulus-ID order (ID 1 first),
    grouped into rows of ``ncols``. Rows are printed last-to-first, so with
    the default 4 columns and 16 stimuli the top line shows IDs 13-16 and the
    bottom line IDs 1-4. Values are formatted with two decimals.

    Parameters
    ----------
    resp : sequence of float
        Per-stimulus values (e.g. one condition of ``cat_responses``).
    ncols : int, default 4
        Number of stimuli per grid row.

    Raises
    ------
    ValueError
        If ``ncols`` is not positive or ``len(resp)`` is not a multiple of it.
    """
    if ncols <= 0 or len(resp) % ncols:
        raise ValueError(
            f"cannot lay out {len(resp)} values in rows of {ncols}")
    rows = [resp[start:start + ncols] for start in range(0, len(resp), ncols)]
    print('\n'.join(' '.join(f"{p:.2f}" for p in row) for row in reversed(rows)))

In [102]:
# P(category A) grid for the size condition (conditions[0]).
print('size')
pprint_responses(resp=cat_responses[0])

size
0.02 0.20 0.14 0.12
0.04 0.90 0.38 0.29
1.00 0.92 1.00 1.00
0.74 1.00 0.94 0.47
    


In [103]:
# P(category A) grid for the angle condition (conditions[1]).
print('angle')
pprint_responses(resp=cat_responses[1])

angle
0.00 0.00 0.00 0.00
0.90 0.02 0.02 0.00
0.16 0.98 0.76 0.08
0.02 1.00 0.00 0.00
    


In [104]:
# P(category A) grid for the criss-cross condition (conditions[2]).
print('criss cross')
pprint_responses(resp=cat_responses[2])

criss cross
1.00 0.36 0.02 0.10
0.80 1.00 0.54 0.52
0.92 1.00 1.00 0.98
0.00 0.02 0.64 1.00
    


In [105]:
# P(category A) grid for the diagonal condition (conditions[3]).
print('diagonal')
pprint_responses(resp=cat_responses[3])

diagonal
0.00 0.00 0.00 0.21
0.06 0.59 0.69 0.98
0.00 0.70 0.92 0.88
0.02 1.00 1.00 0.96
    
