# Connections

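"Connections" is a word game in which 16 words must be sorted into 4 groups of 4, where the words of each group share a hidden category. This notebook asks an LLM to solve Connections puzzles and uses Weave to trace the calls and score the answers against the known solutions. (TODO: add a link to the game's website.)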

In [1]:
# load a jsonl file
import json

def load_jsonl(file_path):
    """Load a JSON Lines file into a list of parsed records.

    Args:
        file_path: Path of the ``.jsonl`` file; every non-blank line must
            hold one JSON document.

    Returns:
        A list with one parsed object per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Read as UTF-8 explicitly instead of relying on the platform's default
    # encoding, so non-ASCII words load the same way everywhere.
    with open(file_path, 'r', encoding='utf-8') as file:
        # Blank lines (typically a trailing newline at the end of the file)
        # are skipped; json.loads would reject them.
        return [json.loads(line) for line in file if line.strip()]

ds = load_jsonl('connections_prompts.jsonl')

In [40]:
print(ds[0])

{'words': ['schmaltz', 'knuckles', 'corn', 'sap', 'loose', 'smile', 'chump', 'egg', 'duct', 'pipe', 'climate', 'sea', 'cheese', 'window', 'drain', 'sewer'], 'categories': {'conduits for water removal': ['drain', 'duct', 'pipe', 'sewer'], 'food products associated with sentimentality': ['cheese', 'corn', 'sap', 'schmaltz'], 'things to crack': ['egg', 'knuckles', 'smile', 'window'], '___ change': ['chump', 'climate', 'loose', 'sea']}}


In [2]:
print(ds[0]["words"])



['schmaltz', 'knuckles', 'corn', 'sap', 'loose', 'smile', 'chump', 'egg', 'duct', 'pipe', 'climate', 'sea', 'cheese', 'window', 'drain', 'sewer']


## Naive approach

In [3]:
from openai import OpenAI

client = OpenAI()

In [31]:
# System prompt shared by the model variants in this notebook. It is built
# from adjacent string literals, which Python joins with no separator, so
# every fragment has to end with its own trailing space.
system_prompt = (
    """The game "Connections" is a word game where you start with 16 words and need to group """
    """them into 4 groups of 4. Each grouping has a category that unambiguously groups the four words together. """
    """Each puzzle has exactly one solution. Watch out for words that seem to belong to multiple categories. """
    """You will be given 16 words. Output 4 groups of 4 words and the categories to which they belong. """
    """The results should be in JSON format as following:
    {"category1": ["word1", "word2", "word3", "word4"], "category2": ["word1", "word2", "word3", "word4"]}
    """
)

# User-message template; `{words}` is filled in with the puzzle's word list.
user_prompt = "Here are the 16 words: {words}"

In [5]:
import weave

weave.init("connections")

Logged in as Weights & Biases user: capecape.
View Weave data at https://wandb.ai/capecape/connections/weave




In [32]:
@weave.op()
def call_openai(system_prompt, user_prompt, model="gpt-4o", temperature=0.7):
    """Send one system + user message pair to the chat API and return the reply text.

    Args:
        system_prompt: Content of the system message.
        user_prompt: Content of the user message.
        model: Name of the chat model to query.
        temperature: Sampling temperature forwarded to the API. Defaults to
            0.7, the value that was previously hard-coded here.

    Returns:
        The message content of the first returned choice. JSON mode is
        requested through ``response_format``, so the caller is expected to
        parse the returned text with ``json.loads``.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
        response_format={"type": "json_object"},
    )
    return response.choices[0].message.content

In [33]:
res = call_openai(system_prompt, user_prompt.format(words=ds[0]["words"]))
generation = json.loads(res)

🍩 https://wandb.ai/capecape/connections/r/call/a38c6826-695f-4718-b9d7-b14bc6c2e59b


In [34]:
for group in generation.items():
    print(group)

('category1', ['duct', 'pipe', 'drain', 'sewer'])
('category2', ['schmaltz', 'cheese', 'corn', 'sap'])
('category3', ['smile', 'loose', 'climate', 'sea'])
('category4', ['knuckles', 'chump', 'egg', 'window'])


Let's create a function to check whether the generated groups match the solution.

In [9]:
flat_generation = list(generation.values())
flat_generation

[['schmaltz', 'loose', 'smile', 'chump'],
 ['knuckles', 'duct', 'pipe', 'drain'],
 ['corn', 'egg', 'cheese', 'sap'],
 ['climate', 'sea', 'window', 'sewer']]

In [10]:
flat_solution = list(ds[0]["categories"].values())
flat_solution

[['drain', 'duct', 'pipe', 'sewer'],
 ['cheese', 'corn', 'sap', 'schmaltz'],
 ['egg', 'knuckles', 'smile', 'window'],
 ['chump', 'climate', 'loose', 'sea']]

In [11]:
@weave.op()
def check_solution(categories, model_output):
    """Score the generated groups against the puzzle's solution.

    A generated group counts as correct when it contains exactly the same
    words as one of the solution groups; word order and category names are
    ignored.

    Args:
        categories: Solution mapping of category name -> list of words.
        model_output: Generated mapping of category name -> list of words.

    Returns:
        A dict with ``"match"`` (True only when 4 solution groups were found)
        and ``"accuracy"`` (number of solution groups found, divided by 4).
        If the inputs cannot be iterated as described above, the result is
        ``{"match": False, "accuracy": 0.}``.
    """
    n_correct = 0
    try:
        for sol_cat, sol_group in categories.items():
            sol_words = set(sol_group)
            for gen_cat, gen_group in model_output.items():
                if set(gen_group) == sol_words:
                    print(f"{gen_cat} ~ {sol_cat}: {gen_group} == {sol_group}")
                    n_correct += 1
                    # Credit each solution group once, even when the model
                    # repeats the same group under several category names.
                    break
    except Exception:
        # Best effort: a malformed generation (e.g. not a dict of lists) scores
        # zero instead of aborting the evaluation. Catching Exception rather
        # than using a bare except lets KeyboardInterrupt propagate.
        return {"match": False, "accuracy": 0.}
    return {"match": n_correct == 4, "accuracy": n_correct / 4}

In [12]:
check_solution(ds[0]["categories"], generation)

🍩 https://wandb.ai/capecape/connections/r/call/ced0d5db-b87d-4fbc-976a-85c6118b0a12


{'match': False, 'accuracy': 0.0}

## Refactor into V1

In [27]:
@weave.op()
def generate_solution(words, system_prompt, user_prompt):
    """Ask the model to group `words` and return its answer parsed from JSON.

    `user_prompt` is a template whose `{words}` placeholder is filled in
    before the request is sent through `call_openai`.
    """
    filled_prompt = user_prompt.format(words=words)
    raw_reply = call_openai(system_prompt, filled_prompt)
    return json.loads(raw_reply)

In [14]:
class Model1(weave.Model):
    """Single-attempt solver: one `generate_solution` call per puzzle."""

    # Content of the system message.
    system_prompt: str
    # User-message template; formatted with the puzzle's words.
    user_prompt: str

    @weave.op()
    def predict(self, words):
        """Return the grouping generated for `words` with this model's prompts."""
        return generate_solution(words, self.system_prompt, self.user_prompt)

In [15]:
weave_eval = weave.Evaluation(dataset=ds[0:10], scorers=[check_solution])

In [16]:
await weave_eval.evaluate(Model1(system_prompt=system_prompt, user_prompt=user_prompt))

category3 ~ breadth: ['reach', 'extent', 'range', 'scope'] == TraceList(['extent', 'range', 'reach', 'scope'])


category2 ~ musical sections: ['brass', 'wind', 'string', 'rhythm'] == TraceList(['brass', 'rhythm', 'string', 'wind'])


Awards ~ awards: ['cup', 'trophy', 'ribbon', 'medal'] == TraceList(['cup', 'medal', 'ribbon', 'trophy'])


Plumbing ~ conduits for water removal: ['pipe', 'duct', 'drain', 'sewer'] == TraceList(['drain', 'duct', 'pipe', 'sewer'])


Sodas ~ soda brands: ['crush', 'sprite', 'mug', 'squirt'] == TraceList(['crush', 'mug', 'sprite', 'squirt'])


Dance Moves ~ dance fads: ['twist', 'mashed potato', 'dougie', 'macarena'] == TraceList(['dougie', 'macarena', 'mashed potato', 'twist'])
Things with Links ~ things with links: ['chain', 'golf course', 'website', 'sausage'] == TraceList(['chain', 'golf course', 'sausage', 'website'])


🍩 https://wandb.ai/capecape/connections/r/call/85df6ab9-5c74-44be-aabc-5452d8f88bc3


{'check_solution': {'match': {'true_count': 0, 'true_fraction': 0.0},
  'accuracy': {'mean': 0.175}},
 'model_latency': {'mean': 8.111122131347656}}

## V2

In [17]:
# Self-check instructions that the V2 run appends to the base system prompt.
# Each bullet states a distinct property of a well-formed answer.
extra_system_prompt = """
Check your solution before submitting it. Be sure about:
- that you have 4 groups of 4 words each
- that every word comes from the 16 words you were given
- that no word appears in more than one group
- that every group has its own category
"""

In [18]:
await weave_eval.evaluate(Model1(system_prompt=system_prompt+extra_system_prompt, user_prompt=user_prompt))

awards ~ awards: ['cup', 'trophy', 'ribbon', 'medal'] == TraceList(['cup', 'medal', 'ribbon', 'trophy'])


Vegetables ~ vegetables that are also fruits: ['eggplant', 'cucumber', 'tomato', 'pepper'] == TraceList(['cucumber', 'eggplant', 'pepper', 'tomato'])
3D Shapes ~ 3-d shapes: ['cube', 'cone', 'pyramid', 'sphere'] == TraceList(['cone', 'cube', 'pyramid', 'sphere'])


dances ~ dance fads: ['twist', 'mashed potato', 'macarena', 'dougie'] == TraceList(['dougie', 'macarena', 'mashed potato', 'twist'])
things with links ~ things with links: ['golf course', 'sausage', 'chain', 'website'] == TraceList(['chain', 'golf course', 'sausage', 'website'])


Measurement or Range ~ breadth: ['extent', 'scope', 'range', 'reach'] == TraceList(['extent', 'range', 'reach', 'scope'])
Armor or Protection ~ heraldry terms: ['shield', 'coat', 'crest', 'arms'] == TraceList(['arms', 'coat', 'crest', 'shield'])


Baseball Equipment ~ baseball equipment: ['base', 'bat', 'glove', 'ball'] == TraceList(['ball', 'base', 'bat', 'glove'])
Monopoly Pieces ~ original monopoly tokens: ['thimble', 'iron', 'top hat', 'boot'] == TraceList(['boot', 'iron', 'thimble', 'top hat'])


🍩 https://wandb.ai/capecape/connections/r/call/97a2fe64-8f77-40e3-930b-cb208796be68


{'check_solution': {'match': {'true_count': 0, 'true_fraction': 0.0},
  'accuracy': {'mean': 0.225}},
 'model_latency': {'mean': 1.8481802463531494}}

## V3

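Let's call the model twice: if the first attempt is not fully correct, we send it back to the model together with its score and ask for a second attempt.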

In [19]:
ds[0]

{'words': ['schmaltz',
  'knuckles',
  'corn',
  'sap',
  'loose',
  'smile',
  'chump',
  'egg',
  'duct',
  'pipe',
  'climate',
  'sea',
  'cheese',
  'window',
  'drain',
  'sewer'],
 'categories': {'conduits for water removal': ['drain',
   'duct',
   'pipe',
   'sewer'],
  'food products associated with sentimentality': ['cheese',
   'corn',
   'sap',
   'schmaltz'],
  'things to crack': ['egg', 'knuckles', 'smile', 'window'],
  '___ change': ['chump', 'climate', 'loose', 'sea']}}

In [35]:
@weave.op()
def check_solution(categories, model_output):
    """Score the generated groups and report which of them were correct.

    A generated group counts as correct when it contains exactly the same
    words as one of the solution groups; word order and category names are
    ignored.

    Args:
        categories: Solution mapping of category name -> list of words.
        model_output: Generated mapping of category name -> list of words.

    Returns:
        A dict with ``"match"`` (True only when 4 solution groups were found),
        ``"accuracy"`` (number of solution groups found, divided by 4) and
        ``"correct"`` (the generated category -> words entries that matched).
        If the inputs cannot be iterated as described above, the result is
        ``{"match": False, "accuracy": 0., "correct": {}}``.
    """
    n_correct = 0
    correct = {}
    try:
        for sol_cat, sol_group in categories.items():
            sol_words = set(sol_group)
            for gen_cat, gen_group in model_output.items():
                if set(gen_group) == sol_words:
                    print(f"{gen_cat} ~ {sol_cat}: {gen_group} == {sol_group}")
                    n_correct += 1
                    correct[gen_cat] = gen_group
                    # Credit each solution group once, even when the model
                    # repeats the same group under several category names.
                    break
    except Exception:
        # Best effort: a malformed generation (e.g. not a dict of lists) scores
        # zero instead of aborting the evaluation. Catching Exception rather
        # than using a bare except lets KeyboardInterrupt propagate.
        return {"match": False, "accuracy": 0., "correct": {}}
    return {"match": n_correct == 4, "accuracy": n_correct / 4, "correct": correct}

In [36]:
sol1 = generate_solution(ds[0]["words"], system_prompt, user_prompt)

🍩 https://wandb.ai/capecape/connections/r/call/5d854a5d-3300-47f0-a471-a959f2551d15


In [37]:
sol1

{'category1': ['knuckles', 'smile', 'corn', 'loose'],
 'category2': ['pipe', 'duct', 'drain', 'sewer'],
 'category3': ['cheese', 'sap', 'egg', 'chump'],
 'category4': ['schmaltz', 'window', 'climate', 'sea']}

In [38]:
check_solution(ds[0]["categories"], sol1)

category2 ~ conduits for water removal: ['pipe', 'duct', 'drain', 'sewer'] == ['drain', 'duct', 'pipe', 'sewer']
🍩 https://wandb.ai/capecape/connections/r/call/bd3aa53b-bb2c-4774-b3f5-c7e7eb027c9f


{'match': False,
 'accuracy': 0.25,
 'correct': {'category2': ['pipe', 'duct', 'drain', 'sewer']}}

In [48]:
# Follow-up user-message template for a second attempt at the same puzzle.
# Placeholders: {words} (the puzzle words), {solution} (the previous answer),
# {accuracy} (its score) and {correct} (the groups that were already right).
user_prompt2 = """
You recently got the following words: {words}
You produced the following solution: 
{solution}
This solution has a {accuracy} accuracy.
You got correct {correct}.
Take this into account, and try to generate a correct solution this time
"""

In [52]:
class Model2(weave.Model):
    """Two-attempt solver: retries once with feedback when the first answer is wrong."""

    # Prompts used for the first attempt.
    system_prompt1: str
    user_prompt1: str

    # Prompts used for the second attempt. `user_prompt2` is formatted with
    # `words`, `solution`, `accuracy` and `correct`.
    system_prompt2: str
    user_prompt2: str

    @weave.op()
    def predict(self, words, categories):
        """Solve the puzzle, retrying once if the first answer is not a full match.

        Args:
            words: The puzzle's words.
            categories: The puzzle's solution (category name -> words), used
                to score the first attempt.

        Returns:
            The parsed grouping (category name -> list of words): the first
            attempt when it fully matches, otherwise the second attempt.

        NOTE(review): the retry prompt is built from a score computed against
        the ground-truth `categories`, so this model sees solution feedback
        that a single-attempt model does not; keep that in mind when
        comparing scores.
        """
        generation = generate_solution(words, self.system_prompt1, self.user_prompt1)
        scores = check_solution(categories, generation)
        if scores["match"]:
            return generation
        final_gen = call_openai(self.system_prompt2,
                                self.user_prompt2.format(
                                    words=words,
                                    solution=generation,
                                    accuracy=scores["accuracy"],
                                    correct=scores["correct"]))
        # call_openai returns the raw JSON text. Parse it so both return paths
        # produce a dict; check_solution calls `.items()` on the prediction
        # and would score an unparsed string as 0.
        return json.loads(final_gen)


In [53]:
weave_model2 = Model2(system_prompt1=system_prompt, user_prompt1=user_prompt,
                      system_prompt2=system_prompt, user_prompt2=user_prompt2)

await weave_eval.evaluate(weave_model2)

Dances ~ dance fads: ['twist', 'mashed potato', 'dougie', 'macarena'] == TraceList(['dougie', 'macarena', 'mashed potato', 'twist'])
Things with Links ~ things with links: ['chain', 'sausage', 'golf course', 'website'] == TraceList(['chain', 'golf course', 'sausage', 'website'])
Musical Instruments ~ musical sections: ['brass', 'wind', 'string', 'rhythm'] == TraceList(['brass', 'rhythm', 'string', 'wind'])
Vegetables ~ vegetables that are also fruits: ['eggplant', 'cucumber', 'tomato', 'pepper'] == TraceList(['cucumber', 'eggplant', 'pepper', 'tomato'])
Shapes ~ 3-d shapes: ['cube', 'cone', 'pyramid', 'sphere'] == TraceList(['cone', 'cube', 'pyramid', 'sphere'])
Awards ~ awards: ['trophy', 'medal', 'ribbon', 'cup'] == TraceList(['cup', 'medal', 'ribbon', 'trophy'])
Measurement and Distance ~ breadth: ['extent', 'range', 'scope', 'reach'] == TraceList(['extent', 'range', 'reach', 'scope'])
Armor and Protection ~ heraldry terms: ['shield', 'arms', 'coat', 'crest'] == TraceList(['arms', '