# Connections

Brief description of the problem, link to the website, etc...

In [1]:
# load a jsonl file
import json

def load_jsonl(file_path):
    data = []
    with open(file_path, 'r') as file:
        for line in file:
            data.append(json.loads(line))
    return data

ds = load_jsonl('connections_prompts.jsonl')

In [2]:
print(ds[0])

{'words': ['schmaltz', 'knuckles', 'corn', 'sap', 'loose', 'smile', 'chump', 'egg', 'duct', 'pipe', 'climate', 'sea', 'cheese', 'window', 'drain', 'sewer'], 'categories': {'conduits for water removal': ['drain', 'duct', 'pipe', 'sewer'], 'food products associated with sentimentality': ['cheese', 'corn', 'sap', 'schmaltz'], 'things to crack': ['egg', 'knuckles', 'smile', 'window'], '___ change': ['chump', 'climate', 'loose', 'sea']}}


In [3]:
print(ds[0]["words"])

['schmaltz', 'knuckles', 'corn', 'sap', 'loose', 'smile', 'chump', 'egg', 'duct', 'pipe', 'climate', 'sea', 'cheese', 'window', 'drain', 'sewer']


## Naive approach

In [4]:
from openai import OpenAI

client = OpenAI()

In [5]:
system_prompt = (
    """The game "Connections" is a word game where you start with 16 words and need to group """
    """them into 4 groups of 4. Each grouping has a category that unambiguously groups the four words together."""
    """Each puzzle has exactly one solution. Watch out for words that seem to belong to multiple categories."""
    """You will be given 16 words. Output 4 groups of 4 words and the categories to which they belong"""
    """The results should be in JSON format as following:
    {"category1": ["word1", "word2", "word3", "word4"], "category2": ["word1", "word2", "word3", "word4"]}]}
    """
    """For instance, ['eggplant', 'cucumber', 'tomato', 'pepper'] would belong to 'vegetables' category'"""
    """You have to think outside of the box, sometimes the categories will not be evident and reference popular culture, art, science or other things."""
)

user_prompt = "Here are the 16 words: {words}"

In [6]:
import weave

weave.init("connections")

Logged in as Weights & Biases user: capecape.
View Weave data at https://wandb.ai/capecape/connections/weave




In [7]:
@weave.op()
def call_openai(messages, model="gpt-4o"):
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=0.7,
        response_format={ "type": "json_object" }
        )
    extracted = response.choices[0].message.content
    return extracted

In [8]:
messages=[
    {
        "role": "system",
        "content": system_prompt
    },
    {
        "role": "user",
        "content": user_prompt.format(words=ds[0]["words"])
    }
    ]

res = call_openai(messages)
generation = json.loads(res)

🍩 https://wandb.ai/capecape/connections/r/call/00cf0293-a879-4cc5-882d-ca96313d34c1


In [9]:
for group in generation.items():
    print(group)

('Animal By-products', ['schmaltz', 'knuckles', 'sap', 'cheese'])
('Body Parts', ['knuckles', 'loose', 'smile', 'chump'])
('Plumbing', ['duct', 'pipe', 'drain', 'sewer'])
('Natural Elements', ['corn', 'egg', 'climate', 'sea'])


Let's create a function to check if the groups are valid

In [10]:
flat_generation = list(generation.values())
flat_generation

[['schmaltz', 'knuckles', 'sap', 'cheese'],
 ['knuckles', 'loose', 'smile', 'chump'],
 ['duct', 'pipe', 'drain', 'sewer'],
 ['corn', 'egg', 'climate', 'sea']]

In [11]:
flat_solution = list(ds[0]["categories"].values())
flat_solution

[['drain', 'duct', 'pipe', 'sewer'],
 ['cheese', 'corn', 'sap', 'schmaltz'],
 ['egg', 'knuckles', 'smile', 'window'],
 ['chump', 'climate', 'loose', 'sea']]

In [12]:
@weave.op()
def check_solution(categories, model_output):
    "Check that all group of words match the solution"    
    accuracy = 0.
    correct = {}
    try: # this is ugly, weave shouldn't crash even if chatGPT failed
        for sol_cat, sol_group in categories.items():
            for gen_cat, gen_group in model_output.items():
                if set(gen_group) == set(sol_group):
                    print(f"{gen_cat} ~ {sol_cat}: {gen_group} == {sol_group}")
                    accuracy += 1
                    correct[gen_cat] = gen_group
        return {"match": True if accuracy == 4 else False, "accuracy": accuracy/4, "correct": correct}
    except:
        return {"match": False, "accuracy": 0., "correct": {}} 

In [13]:
check_solution(ds[0]["categories"], generation)

Plumbing ~ conduits for water removal: ['duct', 'pipe', 'drain', 'sewer'] == ['drain', 'duct', 'pipe', 'sewer']
🍩 https://wandb.ai/capecape/connections/r/call/c7f5660a-4391-4014-a250-51f0e36c3eae


{'match': False,
 'accuracy': 0.25,
 'correct': {'Plumbing': ['duct', 'pipe', 'drain', 'sewer']}}

## Refactor into V1

In [14]:
@weave.op()
def generate_solution(system_prompt, user_prompt):
    messages=[
        {
            "role": "system",
            "content": system_prompt
        },
        {
            "role": "user",
            "content": user_prompt
        }
    ]
    res = call_openai(messages)
    generation = json.loads(res)
    return generation

In [15]:
class Model1(weave.Model):
    system_prompt: str
    user_prompt: str

    @weave.op()
    def predict(self, words):
        generation = generate_solution(
            self.system_prompt, 
            self.user_prompt.format(words=words))
        return generation

In [18]:
weave_eval = weave.Evaluation(dataset=ds[0:10], scorers=[check_solution])

In [123]:
await weave_eval.evaluate(Model1(system_prompt=system_prompt, user_prompt=user_prompt))

Vegetables ~ vegetables that are also fruits: ['eggplant', 'cucumber', 'tomato', 'pepper'] == TraceList(['cucumber', 'eggplant', 'pepper', 'tomato'])
Shapes ~ 3-d shapes: ['cube', 'cone', 'pyramid', 'sphere'] == TraceList(['cone', 'cube', 'pyramid', 'sphere'])


Plumbing ~ conduits for water removal: ['pipe', 'duct', 'drain', 'sewer'] == TraceList(['drain', 'duct', 'pipe', 'sewer'])


Soft Drinks ~ soda brands: ['crush', 'sprite', 'mug', 'squirt'] == TraceList(['crush', 'mug', 'sprite', 'squirt'])


Dances ~ dance fads: ['twist', 'mashed potato', 'dougie', 'macarena'] == TraceList(['dougie', 'macarena', 'mashed potato', 'twist'])


measurement terms ~ breadth: ['extent', 'scope', 'reach', 'range'] == TraceList(['extent', 'range', 'reach', 'scope'])


Monopoly game pieces ~ original monopoly tokens: ['thimble', 'top hat', 'boot', 'iron'] == TraceList(['boot', 'iron', 'thimble', 'top hat'])


Awards ~ awards: ['cup', 'trophy', 'ribbon', 'medal'] == TraceList(['cup', 'medal', 'ribbon', 'trophy'])


🍩 https://wandb.ai/capecape/connections/r/call/924601f4-c88a-425a-9cae-ae81bdab33b1


{'check_solution': {'match': {'true_count': 0, 'true_fraction': 0.0},
  'accuracy': {'mean': 0.2}},
 'model_latency': {'mean': 3.131709885597229}}

## V1.5

In [24]:
system_prompt1_5 = (
    "I want you to solve a daily word puzzle that finds commonalities between words. "
    "There are 16 words, which form 4 groups of 4 words. Each group has some common theme that links the words. "
    "You must use each of the 16 words, and use each word only once. Each group of 4 words are linked together in some way. "
    "The connection between words can be simple. An example of a simple connection would be 'types of fish': Bass, Flounder, Salmon, Trout. "
    "Categories can also be more complex, and require abstract or lateral thinking. An example of this type of connection would be 'things that start with FIRE': Ant, Drill, Island, Opal. "
    "The results should be in JSON format as following: {'category1': ['word1', 'word2', 'word3', 'word4'], 'category2': ['word1', 'word2', 'word3', 'word4']}. "
    "Replace each GROUP NAME with a name for the group you create. Some rules: "
    "- Give your final answers in the format described above. Put each group on a separate line. "
    "Do not add any additional text to your final answer, just the group name and the 4 words.")


In [125]:
model1_5 = Model1(system_prompt=system_prompt1_5, user_prompt=user_prompt)
await weave_eval.evaluate(model1_5)

Synonyms for Range ~ breadth: ['extent', 'scope', 'reach', 'range'] == TraceList(['extent', 'range', 'reach', 'scope'])


awards ~ awards: ['cup', 'trophy', 'ribbon', 'medal'] == TraceList(['cup', 'medal', 'ribbon', 'trophy'])


Dances ~ dance fads: ['twist', 'mashed potato', 'dougie', 'macarena'] == TraceList(['dougie', 'macarena', 'mashed potato', 'twist'])


Musical Instruments ~ musical sections: ['brass', 'string', 'wind', 'rhythm'] == TraceList(['brass', 'rhythm', 'string', 'wind'])


Plumbing components ~ conduits for water removal: ['duct', 'pipe', 'drain', 'sewer'] == TraceList(['drain', 'duct', 'pipe', 'sewer'])


States ~ u.s. mountain states: ['utah', 'arizona', 'nevada', 'colorado'] == TraceList(['arizona', 'colorado', 'nevada', 'utah'])
Soft Drinks ~ soda brands: ['sprite', 'crush', 'mug', 'squirt'] == TraceList(['crush', 'mug', 'sprite', 'squirt'])


Fruits/Vegetables ~ vegetables that are also fruits: ['eggplant', 'cucumber', 'tomato', 'pepper'] == TraceList(['cucumber', 'eggplant', 'pepper', 'tomato'])
Shapes ~ 3-d shapes: ['cube', 'cone', 'pyramid', 'sphere'] == TraceList(['cone', 'cube', 'pyramid', 'sphere'])


🍩 https://wandb.ai/capecape/connections/r/call/f832268c-547a-444b-9f13-5753ae53806f


{'check_solution': {'match': {'true_count': 0, 'true_fraction': 0.0},
  'accuracy': {'mean': 0.225}},
 'model_latency': {'mean': 1.8168670415878296}}

## V2

Let's call the model twice

In [16]:
class Model2(weave.Model):
    system_prompt: str
    user_prompt: str
    max_retries: int = 4

    @weave.op()
    def create_incorrect_prompt(self, words, solutions, accuracy, correct):
        incorrect_prompt = (
            f"You recently tried solving the puzzle with the following words: \n {words}\n"
            "You produced the following solutions that are incorrect: \n"
            "\n".join(str(solutions))
        )
        if accuracy>0.:
            incorrect_prompt += (
                f"\nYou had {4*accuracy} categories correct"
                f"\nThe correct guesses where: \n {correct}"
            )

        incorrect_prompt+=(
            "\nTake this into account, and try to generate a correct solution this time. "
            "\nMake sure you don't repeat any previous guesses"
        )
        return incorrect_prompt

    @weave.op()
    def predict(self, words, categories):
        retries = 0
        previous_generations = []
        generation = generate_solution(self.system_prompt, self.user_prompt.format(words=words))
        scores = check_solution(categories, generation)
        if scores["match"]:
            return generation
        else:
            while (retries < self.max_retries and not scores["match"]):
                previous_generations.append(generation)
                retries+=1
                print(f"Retry {retries}")
                generation = generate_solution(
                    self.system_prompt, 
                    self.create_incorrect_prompt(words, previous_generations, scores["accuracy"], scores["correct"]))
                scores = check_solution(categories, generation)
        return generation


In [19]:
weave_model2 = Model2(system_prompt=system_prompt, user_prompt=user_prompt)

await weave_eval.evaluate(weave_model2)

Decay or Spoilage ~ go bad: ['sour', 'rot', 'spoil', 'turn'] == TraceList(['rot', 'sour', 'spoil', 'turn'])
Retry 1
Rewards ~ awards: ['trophy', 'medal', 'ribbon', 'cup'] == TraceList(['cup', 'medal', 'ribbon', 'trophy'])
Retry 1
category3 ~ conduits for water removal: ['duct', 'pipe', 'drain', 'sewer'] == TraceList(['drain', 'duct', 'pipe', 'sewer'])
Retry 1
Monopoly game pieces ~ original monopoly tokens: ['thimble', 'top hat', 'boot', 'iron'] == TraceList(['boot', 'iron', 'thimble', 'top hat'])
Retry 1
Vegetables ~ vegetables that are also fruits: ['eggplant', 'cucumber', 'tomato', 'pepper'] == TraceList(['cucumber', 'eggplant', 'pepper', 'tomato'])
Shapes ~ 3-d shapes: ['cube', 'cone', 'pyramid', 'sphere'] == TraceList(['cone', 'cube', 'pyramid', 'sphere'])
Retry 1
Retry 1
measurement_terms ~ breadth: ['extent', 'range', 'scope', 'reach'] == TraceList(['extent', 'range', 'reach', 'scope'])
Retry 1
Retry 1
US States ~ u.s. mountain states: ['utah', 'arizona', 'nevada', 'colorado'] =

Dances ~ dance fads: ['twist', 'mashed potato', 'dougie', 'macarena'] == TraceList(['dougie', 'macarena', 'mashed potato', 'twist'])
Retry 1
Rewards ~ awards: ['trophy', 'medal', 'ribbon', 'cup'] == TraceList(['cup', 'medal', 'ribbon', 'trophy'])
Retry 2
category3 ~ conduits for water removal: ['duct', 'pipe', 'drain', 'sewer'] == TraceList(['drain', 'duct', 'pipe', 'sewer'])
Retry 2
Vegetables ~ vegetables that are also fruits: ['eggplant', 'cucumber', 'tomato', 'pepper'] == TraceList(['cucumber', 'eggplant', 'pepper', 'tomato'])
Shapes ~ 3-d shapes: ['cube', 'cone', 'pyramid', 'sphere'] == TraceList(['cone', 'cube', 'pyramid', 'sphere'])
Retry 2
Monopoly game pieces ~ original monopoly tokens: ['thimble', 'top hat', 'boot', 'iron'] == TraceList(['boot', 'iron', 'thimble', 'top hat'])
Retry 2
Retry 2
measurement_terms ~ breadth: ['extent', 'range', 'scope', 'reach'] == TraceList(['extent', 'range', 'reach', 'scope'])
Retry 2
Retry 2
Dances ~ dance fads: ['twist', 'mashed potato', 'dou

category3 ~ conduits for water removal: ['duct', 'pipe', 'drain', 'sewer'] == TraceList(['drain', 'duct', 'pipe', 'sewer'])
Retry 3
Retry 4
Monopoly game pieces ~ original monopoly tokens: ['thimble', 'top hat', 'boot', 'iron'] == TraceList(['boot', 'iron', 'thimble', 'top hat'])
Retry 4
category1 ~ breadth: ['extent', 'range', 'scope', 'reach'] == TraceList(['extent', 'range', 'reach', 'scope'])
Retry 4
music ~ musical sections: ['brass', 'string', 'wind', 'rhythm'] == TraceList(['brass', 'rhythm', 'string', 'wind'])
Retry 4
Body Parts ~ muscles, informally: ['pec', 'lat', 'quad', 'tri'] == TraceList(['lat', 'pec', 'quad', 'tri'])
Rewards ~ awards: ['trophy', 'medal', 'ribbon', 'cup'] == TraceList(['cup', 'medal', 'ribbon', 'trophy'])
Body Parts ~ muscles, informally: ['pec', 'lat', 'quad', 'tri'] == TraceList(['lat', 'pec', 'quad', 'tri'])
Rewards ~ awards: ['trophy', 'medal', 'ribbon', 'cup'] == TraceList(['cup', 'medal', 'ribbon', 'trophy'])


Retry 4
Decay or Spoilage ~ go bad: ['sour', 'rot', 'spoil', 'turn'] == TraceList(['rot', 'sour', 'spoil', 'turn'])
Retry 4
category3 ~ conduits for water removal: ['duct', 'pipe', 'drain', 'sewer'] == TraceList(['drain', 'duct', 'pipe', 'sewer'])
Retry 4
Vegetables ~ vegetables that are also fruits: ['eggplant', 'cucumber', 'tomato', 'pepper'] == TraceList(['cucumber', 'eggplant', 'pepper', 'tomato'])
Shapes ~ 3-d shapes: ['cube', 'cone', 'pyramid', 'sphere'] == TraceList(['cone', 'cube', 'pyramid', 'sphere'])
Vegetables ~ vegetables that are also fruits: ['eggplant', 'cucumber', 'tomato', 'pepper'] == TraceList(['cucumber', 'eggplant', 'pepper', 'tomato'])
Shapes ~ 3-d shapes: ['cube', 'cone', 'pyramid', 'sphere'] == TraceList(['cone', 'cube', 'pyramid', 'sphere'])


Monopoly game pieces ~ original monopoly tokens: ['thimble', 'top hat', 'boot', 'iron'] == TraceList(['boot', 'iron', 'thimble', 'top hat'])
Monopoly game pieces ~ original monopoly tokens: ['thimble', 'top hat', 'boot', 'iron'] == TraceList(['boot', 'iron', 'thimble', 'top hat'])


category1 ~ breadth: ['extent', 'range', 'scope', 'reach'] == TraceList(['extent', 'range', 'reach', 'scope'])
category3 ~ heraldry terms: ['shield', 'coat', 'arms', 'crest'] == TraceList(['arms', 'coat', 'crest', 'shield'])
category1 ~ breadth: ['extent', 'range', 'scope', 'reach'] == TraceList(['extent', 'range', 'reach', 'scope'])
category3 ~ heraldry terms: ['shield', 'coat', 'arms', 'crest'] == TraceList(['arms', 'coat', 'crest', 'shield'])


music ~ musical sections: ['brass', 'string', 'wind', 'rhythm'] == TraceList(['brass', 'rhythm', 'string', 'wind'])
music ~ musical sections: ['brass', 'string', 'wind', 'rhythm'] == TraceList(['brass', 'rhythm', 'string', 'wind'])


Decay or Spoilage ~ go bad: ['sour', 'rot', 'spoil', 'turn'] == TraceList(['rot', 'sour', 'spoil', 'turn'])
Decay or Spoilage ~ go bad: ['sour', 'rot', 'spoil', 'turn'] == TraceList(['rot', 'sour', 'spoil', 'turn'])


category3 ~ conduits for water removal: ['duct', 'pipe', 'drain', 'sewer'] == TraceList(['drain', 'duct', 'pipe', 'sewer'])
category3 ~ conduits for water removal: ['duct', 'pipe', 'drain', 'sewer'] == TraceList(['drain', 'duct', 'pipe', 'sewer'])


🍩 https://wandb.ai/capecape/connections/r/call/1e011f37-61c2-4e79-b0e5-acbae1937380


{'check_solution': {'match': {'true_count': 2, 'true_fraction': 0.2},
  'accuracy': {'mean': 0.45}},
 'model_latency': {'mean': 18.96248745918274}}

In [22]:
weave_model2 = Model2(system_prompt=system_prompt1_5, user_prompt=user_prompt)

await weave_eval.evaluate(weave_model2)

Retry 1
Vegetables ~ vegetables that are also fruits: ['eggplant', 'cucumber', 'tomato', 'pepper'] == TraceList(['cucumber', 'eggplant', 'pepper', 'tomato'])
Shapes ~ 3-d shapes: ['cube', 'cone', 'pyramid', 'sphere'] == TraceList(['cone', 'cube', 'pyramid', 'sphere'])
Retry 1
Retry 1
Dances ~ dance fads: ['twist', 'mashed potato', 'dougie', 'macarena'] == TraceList(['dougie', 'macarena', 'mashed potato', 'twist'])
Magazines ~ magazines: ['wired', 'vogue', 'fortune', 'rolling stone'] == TraceList(['fortune', 'rolling stone', 'vogue', 'wired'])
Chains ~ things with links: ['chain', 'sausage', 'website', 'golf course'] == TraceList(['chain', 'golf course', 'sausage', 'website'])
Retry 1
Types of Awards ~ awards: ['trophy', 'cup', 'ribbon', 'medal'] == TraceList(['cup', 'medal', 'ribbon', 'trophy'])
Retry 1
Soda Brands ~ soda brands: ['crush', 'sprite', 'mug', 'squirt'] == TraceList(['crush', 'mug', 'sprite', 'squirt'])
Retry 1
Retry 1
Retry 1
Measurements ~ breadth: ['extent', 'scope', 'r

Retry 2
Retry 2
Objects that Make Sound ~ hit hard: ['bang', 'slam', 'hammer', 'pound'] == TraceList(['bang', 'hammer', 'pound', 'slam'])
Retry 2
Bands ~ 60’s band members: ['monkee', 'beach boy', 'byrd', 'beatle'] == TraceList(['beach boy', 'beatle', 'byrd', 'monkee'])
Dances ~ dance fads: ['twist', 'mashed potato', 'dougie', 'macarena'] == TraceList(['dougie', 'macarena', 'mashed potato', 'twist'])
Magazines ~ magazines: ['wired', 'vogue', 'fortune', 'rolling stone'] == TraceList(['fortune', 'rolling stone', 'vogue', 'wired'])
Chains ~ things with links: ['chain', 'sausage', 'website', 'golf course'] == TraceList(['chain', 'golf course', 'sausage', 'website'])
Bands ~ 60’s band members: ['monkee', 'beach boy', 'byrd', 'beatle'] == TraceList(['beach boy', 'beatle', 'byrd', 'monkee'])
Dances ~ dance fads: ['twist', 'mashed potato', 'dougie', 'macarena'] == TraceList(['dougie', 'macarena', 'mashed potato', 'twist'])
Magazines ~ magazines: ['wired', 'vogue', 'fortune', 'rolling stone'] =

Retry 2
Soda Brands ~ soda brands: ['crush', 'sprite', 'mug', 'squirt'] == TraceList(['crush', 'mug', 'sprite', 'squirt'])
Retry 2
Retry 3
Monopoly pieces ~ original monopoly tokens: ['thimble', 'top hat', 'boot', 'iron'] == TraceList(['boot', 'iron', 'thimble', 'top hat'])
Retry 3
Retry 3
Soda Brands ~ soda brands: ['crush', 'sprite', 'mug', 'squirt'] == TraceList(['crush', 'mug', 'sprite', 'squirt'])
Retry 3
Objects that Make Sound ~ hit hard: ['bang', 'slam', 'hammer', 'pound'] == TraceList(['bang', 'hammer', 'pound', 'slam'])
Retry 3
Retry 3
Types of Awards ~ awards: ['trophy', 'cup', 'ribbon', 'medal'] == TraceList(['cup', 'medal', 'ribbon', 'trophy'])
Retry 3
Measurements ~ breadth: ['extent', 'scope', 'reach', 'range'] == TraceList(['extent', 'range', 'reach', 'scope'])
Armor ~ heraldry terms: ['shield', 'coat', 'crest', 'arms'] == TraceList(['arms', 'coat', 'crest', 'shield'])
Retry 3
Retry 4
Measurements ~ breadth: ['extent', 'scope', 'reach', 'range'] == TraceList(['extent', 

Types of Awards ~ awards: ['trophy', 'cup', 'ribbon', 'medal'] == TraceList(['cup', 'medal', 'ribbon', 'trophy'])
Types of Awards ~ awards: ['trophy', 'cup', 'ribbon', 'medal'] == TraceList(['cup', 'medal', 'ribbon', 'trophy'])


Plumbing ~ conduits for water removal: ['duct', 'pipe', 'drain', 'sewer'] == TraceList(['drain', 'duct', 'pipe', 'sewer'])
Plumbing ~ conduits for water removal: ['duct', 'pipe', 'drain', 'sewer'] == TraceList(['drain', 'duct', 'pipe', 'sewer'])


Objects that Make Sound ~ hit hard: ['bang', 'slam', 'hammer', 'pound'] == TraceList(['bang', 'hammer', 'pound', 'slam'])
Objects that Make Sound ~ hit hard: ['bang', 'slam', 'hammer', 'pound'] == TraceList(['bang', 'hammer', 'pound', 'slam'])


Monopoly pieces ~ original monopoly tokens: ['thimble', 'top hat', 'boot', 'iron'] == TraceList(['boot', 'iron', 'thimble', 'top hat'])
Monopoly pieces ~ original monopoly tokens: ['thimble', 'top hat', 'boot', 'iron'] == TraceList(['boot', 'iron', 'thimble', 'top hat'])


Soda Brands ~ soda brands: ['crush', 'sprite', 'mug', 'squirt'] == TraceList(['crush', 'mug', 'sprite', 'squirt'])
Soda Brands ~ soda brands: ['crush', 'sprite', 'mug', 'squirt'] == TraceList(['crush', 'mug', 'sprite', 'squirt'])


🍩 https://wandb.ai/capecape/connections/r/call/ed7be433-4dc9-40c8-bb05-fa2f2e573dd5


{'check_solution': {'match': {'true_count': 2, 'true_fraction': 0.2},
  'accuracy': {'mean': 0.375}},
 'model_latency': {'mean': 18.09070508480072}}

## V3 COT

In [39]:

system_prompt_cot = (
    "You are an expert puzzle solver. I want you to solve a daily word puzzle that finds commonalities between words. "
    "There are 16 words, which form 4 groups of 4 words. Each group has some common theme that links the words. "
    "You must use each of the 16 words, and use each word only once. Each group of 4 words are linked together in some way. "
    "The connection between words can be simple. An example of a simple connection would be 'types of fish': Bass, Flounder, Salmon, Trout. "
    "Categories can also be more complex, and require abstract or lateral thinking. An example of this type of connection would be 'things that start with FIRE': Ant, Drill, Island, Opal... ")

cot_prompt = """- First, briefly summarize the rules and objective of the puzzle (in no more than 50
words)
- Next, come up with a category to which four of the words belong and briefly explain why you think they belong to that category:"""

user_prompt_cot = (
    "Provide the one group you are most sure of as your final answer. I will enter this into the puzzle and give you feedback: I will tell you whether it is correct, incorrect, or nearly correct (3/4 words). "
    "Then we will continue until the puzzle is solved, or you lose. "
    "The results should be in JSON format as following: category: [\"word1\", \"word2\", \"word3\", \"word4\"]"
    f"Some rules: {cot_prompt} - Give your final answer in the format described above. "
    "Do not add any additional text to your final answer, just the group name and the 4 words. "
    "Here are the starting 16 words: {words} "
)


In [40]:
user_prompt_cot.format(words="hola")

'Provide the one group you are most sure of as your final answer. I will enter this into the puzzle and give you feedback: I will tell you whether it is correct, incorrect, or nearly correct (3/4 words). Then we will continue until the puzzle is solved, or you lose. The results should be in JSON format as following: category: ["word1", "word2", "word3", "word4"]Some rules: - First, briefly summarize the rules and objective of the puzzle (in no more than 50\nwords)\n- Next, come up with a category to which four of the words belong and briefly explain why you think they belong to that category: - Give your final answer in the format described above. Do not add any additional text to your final answer, just the group name and the 4 words. Here are the starting 16 words: hola '

In [46]:
generation = generate_solution(system_prompt_cot, user_prompt_cot.format(words=ds[0]["words"]))

🍩 https://wandb.ai/capecape/connections/r/call/6b5d757c-4dec-47f7-8a89-0cb9c5b5655e


In [49]:
generation

{'category': ['duct', 'pipe', 'drain', 'sewer']}

In [47]:
@weave.op()
def check_one_solution(categories, model_output):
    "Check that all group of words match the solution"    

    try: # this is ugly, weave shouldn't crash even if chatGPT failed
        for sol_cat, sol_group in categories.items():
            for gen_cat, gen_group in model_output.items():
                if set(gen_group) == set(sol_group):
                    print(f"{gen_cat} ~ {sol_cat}: {gen_group} == {sol_group}")
                    return {"match": True }
        else: 
            return {"match": False} 
    except:
        return {"match": False} 

In [48]:
check_one_solution(ds[0]["categories"], generation)

category ~ conduits for water removal: ['duct', 'pipe', 'drain', 'sewer'] == ['drain', 'duct', 'pipe', 'sewer']
🍩 https://wandb.ai/capecape/connections/r/call/2171d7bf-9406-46f3-851a-906af82b35b4


{'match': True}

In [None]:
class Model3(weave.Model):
    system_prompt: str
    user_prompt: str
    max_retries: int = 4

    @weave.op()
    def create_incorrect_prompt(self, words, solutions, accuracy, correct):
        incorrect_prompt = (
            f"You recently tried solving the puzzle with the following words: \n {words}\n"
            "You produced the following solutions that are incorrect: \n"
            "\n".join(str(solutions))
        )
        if accuracy>0.:
            incorrect_prompt += (
                f"\nYou had {4*accuracy} categories correct"
                f"\nThe correct guesses where: \n {correct}"
            )

        incorrect_prompt+=(
            "\nTake this into account, and try to generate a correct solution this time. "
            "\nMake sure you don't repeat any previous guesses"
        )
        return incorrect_prompt

    @weave.op()
    def predict(self, words, categories):
        retries = 0
        previous_generations = []
        generation = generate_solution(self.system_prompt, self.user_prompt.format(words=words))
        scores = check_solution(categories, generation)
        if scores["match"]:
            return generation
        else:
            while (retries < self.max_retries and not scores["match"]):
                previous_generations.append(generation)
                retries+=1
                print(f"Retry {retries}")
                generation = generate_solution(
                    self.system_prompt, 
                    self.create_incorrect_prompt(words, previous_generations, scores["accuracy"], scores["correct"]))
                scores = check_solution(categories, generation)
        return generation
