In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import dspy
from dspy.datasets import DataLoader
from dspy.evaluate.metrics import answer_exact_match
from typing import List
from dspy.evaluate import Evaluate

import dotenv

dotenv.load_dotenv()

def debug_exact_match(example, pred, trace=None, frac=1.0):
    """Exact-match metric that echoes what it is about to score.

    Prints the example's inputs, the example's ``answer`` and the prediction
    (one ``print`` each, in that order) and then returns the result of
    ``answer_exact_match`` called with the same four arguments.
    """
    shown_inputs = example.inputs()
    print(shown_inputs)
    gold_answer = example.answer
    print(gold_answer)
    print(pred)
    # `trace` and `frac` are forwarded untouched; they are intentionally not printed.
    score = answer_exact_match(example, pred, trace, frac)
    return score

In [4]:
# Language model used by every dspy program in this notebook.
# The commented-out line is an alternative pointing at an OpenAI-compatible server on localhost.
# lm = dspy.LM(model="openai/Qwen/Qwen2-VL-7B-Instruct", api_base="http://localhost:8000/v1", api_key="sk-fake-key", max_tokens=5000)
# No API key is passed here; presumably it is picked up from the environment loaded by dotenv above — confirm.
lm = dspy.LM(model="openai/gpt-4o-mini")

# Register `lm` as the default LM in dspy's global settings.
dspy.settings.configure(lm=lm)

In [5]:
%%capture
from concurrent.futures import ThreadPoolExecutor

# Only the first two image slots are exposed as inputs, next to the question text and raw options.
input_keys = ("image_1", "image_2", "question", "options")
subsets = [
    'Accounting', 'Agriculture', 'Architecture_and_Engineering', 'Art', 'Art_Theory',
    'Basic_Medical_Science', 'Biology', 'Chemistry', 'Clinical_Medicine', 'Computer_Science',
    'Design', 'Diagnostics_and_Laboratory_Medicine', 'Economics', 'Electronics', 'Energy_and_Power',
    'Finance', 'Geography', 'History', 'Literature', 'Manage',
    'Marketing', 'Materials', 'Math', 'Mechanical_Engineering', 'Music',
    'Pharmacy', 'Physics', 'Psychology', 'Public_Health', 'Sociology',
]

devset = []
valset = []


def load_dataset(subset_index_subset):
    """Load the "dev" and "validation" splits of one MMMU subset.

    Takes an ``(index, subset_name)`` pair and returns
    ``(index, dev_examples, validation_examples)`` so the caller can restore
    the subset order afterwards.
    """
    subset_index, subset = subset_index_subset
    splits = DataLoader().from_huggingface("MMMU/MMMU", subset, split=["dev", "validation"], input_keys=input_keys)
    return subset_index, splits["dev"], splits["validation"]


# One worker per subset, so all subsets are requested concurrently.
with ThreadPoolExecutor(max_workers=len(subsets)) as executor:
    results = sorted(executor.map(load_dataset, enumerate(subsets)), key=lambda item: item[0])

# Concatenate the per-subset splits in subset order.
for _, dev, val in results:
    devset.extend(dev)
    valset.extend(val)

In [6]:
import ast

def count_images(dataset):
    """Build a histogram of how many image inputs each example carries.

    An image input is any input field whose name starts with ``image_`` and
    whose value is not ``None``.

    Args:
        dataset: iterable of examples exposing ``inputs()``, which returns a
            mapping-like object supporting ``keys()`` and item access.

    Returns:
        dict mapping "number of images" to "number of examples". Keys 0..5 are
        always present (possibly with a count of 0); larger image counts get
        a key added the first time they are seen instead of raising KeyError.
    """
    image_counts = {i: 0 for i in range(6)}  # Pre-seed counts for 0 to 5 images
    for example in dataset:
        inputs = example.inputs()  # build the inputs view once per example
        count = sum(1 for key in inputs.keys() if key.startswith('image_') and inputs[key] is not None)
        image_counts[count] = image_counts.get(count, 0) + 1
    return image_counts

def count_multiple_choice_questions(dataset):
    """Return how many examples have ``question_type`` equal to "multiple-choice"."""
    total = 0
    for example in dataset:
        if example["question_type"] == "multiple-choice":
            total += 1
    return total
max_images = 5
num_images = 2


def _non_null_image_inputs(example):
    """Count the example's input fields named ``image_*`` whose value is not None."""
    return sum(1 for key in example.inputs().keys() if key.startswith('image_') and example.inputs()[key] is not None)


# Keep only the examples that carry exactly `num_images` populated image inputs.
devset_filtered = [example for example in devset if _non_null_image_inputs(example) == num_images]
valset_filtered = [example for example in valset if _non_null_image_inputs(example) == num_images]

devset_image_counts = count_images(devset_filtered)
valset_image_counts = count_images(valset_filtered)

devset_multiple_choice_questions = count_multiple_choice_questions(devset_filtered)
valset_multiple_choice_questions = count_multiple_choice_questions(valset_filtered)

# Report the image-count histogram of each filtered split.
for heading, histogram in (
    ("Image counts in devset:", devset_image_counts),
    ("\nImage counts in valset:", valset_image_counts),
):
    print(heading)
    for count, num_examples in histogram.items():
        print(f"{count} image(s): {num_examples} examples")

# Report how many of the kept examples are multiple-choice.
for heading, multiple_choice, kept in (
    ("\nMultiple choice questions in devset:", devset_multiple_choice_questions, devset_filtered),
    ("\nMultiple choice questions in valset:", valset_multiple_choice_questions, valset_filtered),
):
    print(heading)
    print(multiple_choice, "out of", len(kept))

def convert_multiple_choice_to_letter(dataset):
    """Return copies of the examples with an extra ``answer_choices`` input.

    ``example["options"]`` is a string holding a Python list literal. For
    examples whose ``question_type`` is "multiple-choice" every option is
    prefixed with its letter ("A. ", "B. ", ...); for any other question type
    the parsed list is used unchanged. In both cases the list is stored as its
    ``str()`` representation under ``answer_choices``.

    The examples in ``dataset`` are not modified: each one is copied before the
    new field is written, so the caller's list stays as it was.

    Args:
        dataset: iterable of examples supporting item access, ``copy()``,
            ``inputs()`` and ``with_inputs()``.

    Returns:
        list of new examples whose input keys are the original input keys plus
        ``answer_choices``.
    """
    new_dataset = []
    for example in dataset:
        # literal_eval only accepts Python literals, so the options string is parsed without executing code.
        options = ast.literal_eval(example["options"])
        if example["question_type"] == "multiple-choice":
            # chr(65) is "A"; options are lettered in their original order.
            answer_choices = str([chr(65 + i) + ". " + option for i, option in enumerate(options)])
        else:
            answer_choices = str(options)

        # Write to a copy so the source example is left untouched.
        updated_example = example.copy()
        updated_example["answer_choices"] = answer_choices
        new_dataset.append(updated_example.with_inputs(*example.inputs().keys(), "answer_choices"))
    return new_dataset

# Show the first filtered dev example, then the first converted one, for a visual comparison.
print(devset_filtered[0])
updated_devset = convert_multiple_choice_to_letter(devset_filtered)
print(updated_devset[0])
# The validation split gets the same conversion; it is the set the evaluations below run on.
updated_valset = convert_multiple_choice_to_letter(valset_filtered)

# print(devset[0])


Image counts in devset:
0 image(s): 0 examples
1 image(s): 0 examples
2 image(s): 4 examples
3 image(s): 0 examples
4 image(s): 0 examples
5 image(s): 0 examples

Image counts in valset:
0 image(s): 0 examples
1 image(s): 0 examples
2 image(s): 43 examples
3 image(s): 0 examples
4 image(s): 0 examples
5 image(s): 0 examples

Multiple choice questions in devset:
4 out of 4

Multiple choice questions in valset:
42 out of 43
Example({'id': 'dev_Art_Theory_3', 'question': 'Church interiors from this time period typically were covered with <image 1> <image 2>', 'options': "['timber roofs', 'quadripartite vaults', 'pendentive domes', 'masonry barrel vaults']", 'explanation': '', 'image_1': <PIL.PngImagePlugin.PngImageFile image mode=RGBA size=933x609 at 0x7E259C230790>, 'image_2': <PIL.PngImagePlugin.PngImageFile image mode=RGBA size=933x737 at 0x7E259FF6C490>, 'image_3': None, 'image_4': None, 'image_5': None, 'image_6': None, 'image_7': None, 'img_type': "['Sculpture']", 'answer': 'A', 'to

In [7]:
# print(len(devset))
# print(devset)

# Signature for a two-image MMMU question. The class docstring and every `desc` string
# show up in the prompt (see the inspect_history output), so they are part of the
# program's behaviour, not documentation.
class MMMUSignature(dspy.Signature):
    """Output a rationale and the answer to a multiple choice question about an image with the letter of the correct answer, if present, otherwise the exact answer."""

    question: str = dspy.InputField(desc="A question about the image(s)")
    # Exactly two image slots: the datasets fed to this signature are filtered to two-image examples.
    image_1: dspy.Image = dspy.InputField(desc="An image relating to the shown problem")
    image_2: dspy.Image = dspy.InputField(desc="An image relating to the shown problem")
    # NOTE(review): annotated as List[str], but convert_multiple_choice_to_letter stores str(list)
    # in `answer_choices` — confirm which of the two is intended.
    answer_choices: List[str] = dspy.InputField(desc="The answer options for the question")
    # NOTE(review): this desc asks for a single letter, while the docstring also allows an exact
    # answer for non-multiple-choice questions — confirm the wording.
    answer: str = dspy.OutputField(desc="The single letter of the correct answer. Do not include the entire answer or a period at the end.")

class MMMUModule(dspy.Module):
    """Chain-of-thought program that answers a question using ``MMMUSignature``.

    Instances are called with the signature's input fields as keyword
    arguments, e.g. ``mmmu(**example.inputs())``.
    """

    def __init__(self):
        super().__init__()
        self.predictor = dspy.ChainOfThought(MMMUSignature)

    def forward(self, **kwargs):
        """Run the chain-of-thought predictor on the given inputs and return its prediction.

        The logic lives in ``forward`` rather than in an overridden ``__call__``
        so that calling the instance still goes through ``dspy.Module.__call__``,
        which dispatches here.
        """
        return self.predictor(**kwargs)


In [None]:
# Random-search few-shot optimizer configured without labeled demos (max_labeled_demos=0).
# NOTE(review): it is not referenced again in the visible part of the notebook.
rs_optimizer_no_labeled = dspy.BootstrapFewShotWithRandomSearch(
    metric=answer_exact_match,
    num_threads=150,
    num_candidate_programs=6,
    max_labeled_demos=0,
    max_bootstrapped_demos=3,
    max_errors=10000,
)

# Smoke test: run the uncompiled program on one dev example and print inputs, prediction and gold answer.
sample_input = updated_devset[0]
# print(sample_input.inputs())
# print(encode_image(sample_input.inputs()["image_1"]))
mmmu = MMMUModule()
print(sample_input.inputs())
print(mmmu(**sample_input.inputs()))
print(sample_input.answer)

# Evaluator reused by later cells; those cells pass their own `devset=` on each call.
evaluate_mmmu = Evaluate(metric=answer_exact_match, num_threads=300, devset=updated_valset)
# Bare last expression so the notebook displays the latest LM history entry.
lm.history[-1]

In [28]:
# Print the recorded LM history so the prompt built from MMMUSignature can be read by eye.
lm.inspect_history()





System message:

Your input fields are:
1. `question` (str): A question about the image(s)
2. `image_1` (Image): An image relating to the shown problem
3. `image_2` (Image): An image relating to the shown problem
4. `answer_choices` (list[str]): The answer options for the question

Your output fields are:
1. `reasoning` (str)
2. `answer` (str): The single letter of the correct answer. Do not include the entire answer or a period at the end.

All interactions will be structured in the following way, with the appropriate values filled in.

[[ ## question ## ]]
{question}

[[ ## image_1 ## ]]
{image_1}

[[ ## image_2 ## ]]
{image_2}

[[ ## answer_choices ## ]]
{answer_choices}

[[ ## reasoning ## ]]
{reasoning}

[[ ## answer ## ]]
{answer}

[[ ## completed ## ]]


In adhering to this structure, your objective is: 
        Output a rationale and the answer to a multiple choice question about an image with the letter of the correct answer, if present, otherwise the exact answer.


User 

# Make sure that multiple images work

## Zero-shot (no few-shot examples)

In [29]:
import PIL
def set_image_to_black_square(example, key, image_path="black_image_300x300.png"):
    """Return a copy of ``example`` whose ``key`` field is replaced by a placeholder image.

    Args:
        example: example supporting ``copy()``, item assignment, ``inputs()``
            and ``with_inputs()``.
        key: name of the field to overwrite (e.g. "image_1").
        image_path: image file opened with PIL and stored under ``key``.
            Defaults to the black 300x300 PNG this notebook expects in the
            working directory.

    Returns:
        The modified copy, with the same input keys as ``example``. The
        original example is not changed.
    """
    # A bare `import PIL` does not load the `PIL.Image` submodule, so import it
    # explicitly instead of relying on another package having done so.
    from PIL import Image

    example_copy = example.copy()
    example_copy[key] = Image.open(image_path)
    return example_copy.with_inputs(*example.inputs().keys())

# Reference: the two original images of the first *dev* example
# (the ablation sets below are built from the *validation* set).
print(updated_devset[0]["image_1"])
print(updated_devset[0]["image_2"])
# Ablation 1: image_1 replaced by the black square. The two prints compare each image slot
# of the first example with the black square; the recorded output is True, False.
examples_no_image_1 = list(map(lambda x: set_image_to_black_square(x, "image_1"), updated_valset))
print(examples_no_image_1[0]["image_1"] == PIL.Image.open("black_image_300x300.png"))
print(examples_no_image_1[0]["image_2"] == PIL.Image.open("black_image_300x300.png"))
# Ablation 2: image_2 replaced by the black square; the recorded output is False, True.
examples_no_image_2 = list(map(lambda x: set_image_to_black_square(x, "image_2"), updated_valset))
print(examples_no_image_2[0]["image_1"] == PIL.Image.open("black_image_300x300.png"))
print(examples_no_image_2[0]["image_2"] == PIL.Image.open("black_image_300x300.png"))

# Ablation 3: both images replaced, built in two passes (image_1 first, then image_2 on that result);
# the recorded output is True, True.
examples_no_actual_image = list(map(lambda x: set_image_to_black_square(x, "image_1"), updated_valset))
examples_no_actual_image = list(map(lambda x: set_image_to_black_square(x, "image_2"), examples_no_actual_image))
print(examples_no_actual_image[0]["image_1"] == PIL.Image.open("black_image_300x300.png"))
print(examples_no_actual_image[0]["image_2"] == PIL.Image.open("black_image_300x300.png"))


<PIL.PngImagePlugin.PngImageFile image mode=RGBA size=933x609 at 0x7E259C230790>
<PIL.PngImagePlugin.PngImageFile image mode=RGBA size=933x737 at 0x7E259FF6C490>
True
False
False
True
True
True


In [30]:
# Fresh, uncompiled program; run it on the first example of each single-image ablation
# and print the inputs next to the prediction.
mmmu = MMMUModule()
print(examples_no_image_1[0].inputs())
print(mmmu(**examples_no_image_1[0].inputs()))

print(examples_no_image_2[0].inputs())
print(mmmu(**examples_no_image_2[0].inputs()))


Example({'question': "<image 1> What group of pathogens, often mistaken for regrowth following glyphosate treatment, can cause a growth habit in blackberry plants that is near-identical to the 'little leaf' symptoms commonly witnessed post-glyphosate treatment?", 'options': '["I don\'t know and I don\'t want to guess", \'Nematodes\', \'Fungi\', \'Phytoplasmas\', \'Bacteria\']', 'image_1': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=300x300 at 0x7E25782FD290>, 'image_2': <PIL.PngImagePlugin.PngImageFile image mode=P size=300x232 at 0x7E258579CF50>, 'answer_choices': '["A. I don\'t know and I don\'t want to guess", \'B. Nematodes\', \'C. Fungi\', \'D. Phytoplasmas\', \'E. Bacteria\']'}) (input_keys={'image_1', 'answer_choices', 'image_2', 'options', 'question'})


Prediction(
    reasoning='The question asks about a group of pathogens that can cause symptoms in blackberry plants similar to those seen after glyphosate treatment. Among the options provided, phytoplasmas are known to cause growth abnormalities in plants, including symptoms that can be confused with glyphosate damage. Nematodes, fungi, and bacteria do not typically produce the same growth habit as described. Therefore, the most appropriate answer is D. Phytoplasmas.',
    answer='D'
)
Example({'question': "<image 1> What group of pathogens, often mistaken for regrowth following glyphosate treatment, can cause a growth habit in blackberry plants that is near-identical to the 'little leaf' symptoms commonly witnessed post-glyphosate treatment?", 'options': '["I don\'t know and I don\'t want to guess", \'Nematodes\', \'Fungi\', \'Phytoplasmas\', \'Bacteria\']', 'image_1': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=414x365 at 0x7E259E1FAA50>, 'image_2': <PIL.PngImagePlugin.Png

In [31]:
# Score the same program on the unmodified validation set and on the three black-square ablations.
# `evaluate_mmmu` is the Evaluate instance created in an earlier cell; `devset=` selects the data per call.
normal = evaluate_mmmu(mmmu, devset=updated_valset)
no_image_1 = evaluate_mmmu(mmmu, devset=examples_no_image_1)
no_image_2 = evaluate_mmmu(mmmu, devset=examples_no_image_2)
no_actual_image = evaluate_mmmu(mmmu, devset=examples_no_actual_image)
# All four sets are derived from `updated_valset`, hence the single N reported here.
print("Testing on MMMU validation set (N=", len(updated_valset), ")")
print("Score with both images:", normal)
print("Score with image_1 set to black square:", no_image_1)
print("Score with image_2 set to black square:", no_image_2)
print("Score with both images set to black squares:", no_actual_image)

Testing on MMMU validation set (N= 43 )
Score with both images: 58.14
Score with image_1 set to black square: 37.21
Score with image_2 set to black square: 48.84
Score with both images set to black squares: 44.19


## TODO: Test with bootstrapped examples


# Make sure that JPGs work

## Convert images to JPGs

In [32]:
import io
from PIL import Image

def convert_to_jpg(example):
    """Return a copy of ``example`` whose ``image_1``/``image_2`` fields are JPEG-encoded.

    A field is converted only when it exists and holds a PIL image. The image
    is converted to RGB, saved as JPEG into an in-memory buffer and re-opened
    from that buffer. The returned copy has the same input keys as ``example``.
    """
    converted = example.copy()
    for field in ('image_1', 'image_2'):
        if field not in converted:
            continue
        if not isinstance(converted[field], Image.Image):
            continue

        # Encode into memory as JPEG, forcing RGB mode first.
        jpeg_bytes = io.BytesIO()
        example[field].convert('RGB').save(jpeg_bytes, format='JPEG')
        jpeg_bytes.seek(0)

        # Re-open from the buffer so the stored object is a JPEG-backed PIL image.
        converted[field] = Image.open(jpeg_bytes)

    return converted.with_inputs(*example.inputs().keys())

# Convert all images in the dataset to JPG
# (only the `image_1` / `image_2` fields of the validation examples are touched).
examples_jpg = list(map(convert_to_jpg, updated_valset))

# Verify conversion
# by printing PIL's `format` attribute of the first example's image_1 before and after.
print("Original image format:", updated_valset[0]['image_1'].format)
print("Converted image format:", examples_jpg[0]['image_1'].format)


Original image format: PNG
Converted image format: JPEG


In [33]:
# JPEG versions of the full validation set and of the three black-square ablation sets.
# NOTE(review): `examples_jpg` was already built by the same expression in the previous cell.
examples_jpg = list(map(convert_to_jpg, updated_valset))
examples_no_image_1_jpg = list(map(lambda x: convert_to_jpg(x), examples_no_image_1))
examples_no_image_2_jpg = list(map(lambda x: convert_to_jpg(x), examples_no_image_2))
examples_no_actual_image_jpg = list(map(lambda x: convert_to_jpg(x), examples_no_actual_image))

# Smoke test on one JPEG example: print inputs, prediction and the image format.
mmmu = MMMUModule()
print(examples_no_image_1_jpg[0].inputs())
print(mmmu(**examples_no_image_1_jpg[0].inputs()))
print(examples_no_image_1_jpg[0]["image_1"].format)

Example({'question': "<image 1> What group of pathogens, often mistaken for regrowth following glyphosate treatment, can cause a growth habit in blackberry plants that is near-identical to the 'little leaf' symptoms commonly witnessed post-glyphosate treatment?", 'options': '["I don\'t know and I don\'t want to guess", \'Nematodes\', \'Fungi\', \'Phytoplasmas\', \'Bacteria\']', 'image_1': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=300x300 at 0x7E257B5A0650>, 'image_2': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=300x232 at 0x7E25921CAA10>, 'answer_choices': '["A. I don\'t know and I don\'t want to guess", \'B. Nematodes\', \'C. Fungi\', \'D. Phytoplasmas\', \'E. Bacteria\']'}) (input_keys={'image_1', 'answer_choices', 'image_2', 'options', 'question'})
Prediction(
    reasoning="The question asks about a group of pathogens that can cause symptoms in blackberry plants similar to those seen after glyphosate treatment. Among the options provided, phytoplasmas are kn

In [34]:
# Same four evaluations as the PNG run, this time on the JPEG-encoded sets.
# NOTE: this rebinds `normal`, `no_image_1`, `no_image_2` and `no_actual_image` from the PNG run.
normal = evaluate_mmmu(mmmu, devset=examples_jpg)
no_image_1 = evaluate_mmmu(mmmu, devset=examples_no_image_1_jpg)
no_image_2 = evaluate_mmmu(mmmu, devset=examples_no_image_2_jpg)
no_actual_image = evaluate_mmmu(mmmu, devset=examples_no_actual_image_jpg)
# N is taken from `updated_valset`, from which every JPEG set was derived.
print("Testing on MMMU validation set (N=", len(updated_valset), ")")
print("Score with both images:", normal)
print("Score with image_1 set to black square:", no_image_1)
print("Score with image_2 set to black square:", no_image_2)
print("Score with both images set to black squares:", no_actual_image)

Testing on MMMU validation set (N= 43 )
Score with both images: 58.14
Score with image_1 set to black square: 44.19
Score with image_2 set to black square: 46.51
Score with both images set to black squares: 44.19


In [35]:
# Print the recorded LM history after the JPEG run, for manual inspection of the prompt.
lm.inspect_history()





System message:

Your input fields are:
1. `question` (str): A question about the image(s)
2. `image_1` (Image): An image relating to the shown problem
3. `image_2` (Image): An image relating to the shown problem
4. `answer_choices` (list[str]): The answer options for the question

Your output fields are:
1. `reasoning` (str)
2. `answer` (str): The single letter of the correct answer. Do not include the entire answer or a period at the end.

All interactions will be structured in the following way, with the appropriate values filled in.

[[ ## question ## ]]
{question}

[[ ## image_1 ## ]]
{image_1}

[[ ## image_2 ## ]]
{image_2}

[[ ## answer_choices ## ]]
{answer_choices}

[[ ## reasoning ## ]]
{reasoning}

[[ ## answer ## ]]
{answer}

[[ ## completed ## ]]


In adhering to this structure, your objective is: 
        Output a rationale and the answer to a multiple choice question about an image with the letter of the correct answer, if present, otherwise the exact answer.


User 

# Testing that URLs work

In [36]:

# Colour name -> six-digit hex code (no leading "#"), used to build placeholder-image URLs.
colors = dict(
    White="FFFFFF",
    Red="FF0000",
    Green="00FF00",
    Blue="0000FF",
    Yellow="FFFF00",
    Cyan="00FFFF",
    Magenta="FF00FF",
    Gray="808080",
    Orange="FFA500",
    Purple="800080",
)


def get_color_image_url(color, file_extension="png"):
    """Return the placehold.co URL for ``color``.

    ``color`` must be a key of ``colors`` (a KeyError is raised otherwise).
    Both colour slots of the URL are filled with the same hex code, and
    ``file_extension`` is appended after a dot.
    """
    hex_code = colors[color]
    return "https://placehold.co/300/{0}/{0}.{1}".format(hex_code, file_extension)


In [37]:
import random

def generate_random_2_color_image_examples(n, seed=None):
    """Generate ``n`` two-image "what colour is image_X?" examples.

    Each example draws two distinct colour names from ``colors``, points
    ``image_1`` and ``image_2`` at the matching URLs from
    ``get_color_image_url``, and asks about one of the two images chosen with
    probability 0.5; ``answer`` is that image's colour name.

    Args:
        n: number of examples to generate.
        seed: optional seed. When given, a private ``random.Random(seed)`` is
            used so the same seed always yields the same examples. When
            omitted, the module-level ``random`` functions are used, exactly as
            before this parameter existed.

    Returns:
        list of ``dspy.Example`` with inputs ``image_1``, ``image_2`` and
        ``question``.
    """
    rng = random if seed is None else random.Random(seed)
    color_names = list(colors.keys())  # loop-invariant, built once
    examples = []
    for _ in range(n):
        color_1, color_2 = rng.sample(color_names, 2)
        # The two colours are distinct, so picking the slot first is equivalent to picking the colour.
        chosen_image = "image_1" if rng.random() < 0.5 else "image_2"
        chosen_color = color_1 if chosen_image == "image_1" else color_2
        example_kwargs = {
            "image_1": get_color_image_url(color_1),
            "image_2": get_color_image_url(color_2),
            "question": f"What color is {chosen_image}?",
            "answer": chosen_color
        }
        examples.append(dspy.Example(**example_kwargs).with_inputs("image_1", "image_2", "question"))
    return examples

# Small sample set for manual spot checks; the optimizer cells below generate their own, larger dataset.
examples = generate_random_2_color_image_examples(100)
print(examples[0])


Example({'image_1': 'https://placehold.co/300/FFFF00/FFFF00.png', 'image_2': 'https://placehold.co/300/0000FF/0000FF.png', 'question': 'What color is image_2?', 'answer': 'Blue'}) (input_keys={'image_1', 'image_2', 'question'})


In [41]:
# Signature for the URL-based colour test. The docstring and `desc` strings appear in the
# prompt (see the inspect_history output), so they are part of the program's behaviour.
class ColorSignature(dspy.Signature):
    """Output the color of the designated image."""
    # In this experiment both image fields are given as URL strings (see generate_random_2_color_image_examples).
    image_1: dspy.Image = dspy.InputField(desc="An image")
    image_2: dspy.Image = dspy.InputField(desc="An image")
    question: str = dspy.InputField(desc="A question about the image")
    answer: str = dspy.OutputField(desc="The color of the designated image")
# NOTE(review): the recorded outputs further down show a `reasoning` field, which neither this
# signature nor a plain Predict declares here; the execution counts are also out of order —
# presumably this cell was edited after the outputs were captured. Confirm with a fresh run.
color_program = dspy.Predict(ColorSignature)


In [55]:
# Spot check: print one generated example, then the uncompiled program's prediction for its inputs.
print(examples[0])
print(color_program(**examples[0].inputs()))

Example({'image_1': 'https://placehold.co/300/FFFF00/FFFF00.png', 'image_2': 'https://placehold.co/300/0000FF/0000FF.png', 'question': 'What color is image_2?', 'answer': 'Blue'}) (input_keys={'image_1', 'image_2', 'question'})
Prediction(
    reasoning='The color of image_2 is a solid blue shade.',
    answer='Blue'
)


In [53]:
# Two BootstrapFewShot configurations that differ only in how many demos they may attach.
few_shot_optimizer = dspy.BootstrapFewShot(metric=answer_exact_match, max_bootstrapped_demos=3, max_labeled_demos=10)
smaller_few_shot_optimizer = dspy.BootstrapFewShot(metric=answer_exact_match, max_bootstrapped_demos=1, max_labeled_demos=1)
# 1000 examples are generated, but only the first 400 are used: 200 for training, the next 200 for validation.
dataset = generate_random_2_color_image_examples(1000)
trainset = dataset[:200]
validationset = dataset[200:400]
evaluate_colors = Evaluate(metric=answer_exact_match, num_threads=300, devset=validationset)

In [54]:
# Compile the same base program with each optimizer, then score the base and both compiled
# programs on the validation slice (printed in that order).
compiled_color_program = few_shot_optimizer.compile(color_program, trainset=trainset)
compiled_smaller_color_program = smaller_few_shot_optimizer.compile(color_program, trainset=trainset)
print(evaluate_colors(color_program))
print(evaluate_colors(compiled_color_program))
print(evaluate_colors(compiled_smaller_color_program))

  0%|          | 0/200 [00:00<?, ?it/s]

  2%|▏         | 3/200 [00:18<20:27,  6.23s/it]


Bootstrapped 3 full traces after 4 examples in round 0.


  0%|          | 1/200 [00:02<08:09,  2.46s/it]


Bootstrapped 1 full traces after 2 examples in round 0.
99.0
100.0
96.5


In [57]:
# Run the compiled program on one validation example, then print the LM history
# to see how the demos with image URLs are laid out in the prompt.
print(compiled_color_program(**validationset[0].inputs()))
lm.inspect_history()

Prediction(
    reasoning='Not supplied for this particular example.',
    answer='White'
)




System message:

Your input fields are:
1. `image_1` (Image): An image
2. `image_2` (Image): An image
3. `question` (str): A question about the image

Your output fields are:
1. `reasoning` (str)
2. `answer` (str): The color of the designated image

All interactions will be structured in the following way, with the appropriate values filled in.

[[ ## image_1 ## ]]
{image_1}

[[ ## image_2 ## ]]
{image_2}

[[ ## question ## ]]
{question}

[[ ## reasoning ## ]]
{reasoning}

[[ ## answer ## ]]
{answer}

[[ ## completed ## ]]


In adhering to this structure, your objective is: 
        Output the color of the designated image.


User message:

This is an example of the task, though some input or output fields are not supplied.
[[ ## image_1 ## ]]
<image_url: https://placehold.co/300/00FF00/00FF00.png>

[[ ## image_2 ## ]]
<image_url: https://placehold.co/300/800080/800080.png>

[[ ## question #

# TODO(Isaac): Delete this section — archive of old experiments

In [None]:
# Load the stanford-dogs dataset with `image` as the only input field.
# NOTE: this rebinds `dataset`, which the colour experiment above used for its generated examples.
# NOTE(review): trust_remote_code=True presumably allows the dataset repository's own loading
# code to run locally — confirm the source is trusted before re-running.
dataset = DataLoader().from_huggingface("Alanox/stanford-dogs", split="full", input_keys=("image",), trust_remote_code=True)

In [69]:
# rename the field from "image" to "image_1"
def rename_field(example, old_name, new_name):
    """Rename the field ``old_name`` of ``example`` to ``new_name``, in place.

    This is a best-effort helper: when ``old_name`` is not present (for example
    because the rename already happened on an earlier run) the example is
    returned unchanged. Renaming a field to its own name is a no-op instead of
    deleting the field. Errors other than the missing key are no longer
    swallowed.

    Args:
        example: mapping-like object supporting item get/set/delete.
        old_name: current field name.
        new_name: desired field name.

    Returns:
        The same ``example`` object that was passed in.
    """
    if old_name == new_name:
        # Copy-then-delete would remove the only copy of the value.
        return example
    try:
        value = example[old_name]
    except KeyError:
        # Nothing to rename; keep the original best-effort behaviour.
        return example
    example[new_name] = value
    del example[old_name]
    return example
    
# Rename image -> image_1 and target -> answer on every example.
# rename_field returns the object it was given, so `dataset`, `dog_dataset` and `dog_dataset2`
# hold the same example objects.
dog_dataset = list(map(rename_field, dataset, ["image"]*len(dataset), ["image_1"]*len(dataset)))
dog_dataset2 = list(map(rename_field, dog_dataset, ["target"]*len(dog_dataset), ["answer"]*len(dog_dataset)))
# Mark image_1 as the only input field after the rename.
dog_dataset3 = list(map(lambda x: x.with_inputs("image_1"), dog_dataset2))
dog_dataset = dog_dataset3
# Shuffles in place using the module-level `random` (imported in an earlier cell); no seed is set,
# so the order differs from run to run.
random.shuffle(dog_dataset)

In [48]:
# Signature for single-image breed classification; field names match the renamed dataset
# fields (`image_1`, `answer`).
class DogPictureSignature(dspy.Signature):
    """Output the dog breed of the dog in the image."""
    image_1: dspy.Image = dspy.InputField(desc="An image of a dog")
    answer: str = dspy.OutputField(desc="The dog breed of the dog in the image")

class DogPicture(dspy.Module):
    """Chain-of-thought program that names the dog breed using ``DogPictureSignature``.

    Instances are called with the signature's input fields as keyword
    arguments, e.g. ``dog_picture(**example.inputs())``.
    """

    def __init__(self) -> None:
        # Initialise the dspy.Module base class first, as MMMUModule in this notebook does.
        super().__init__()
        self.predictor = dspy.ChainOfThought(DogPictureSignature)

    def forward(self, **kwargs):
        """Run the chain-of-thought predictor on the given inputs and return its prediction.

        Defined as ``forward`` rather than an overridden ``__call__`` so that
        calling the instance still goes through ``dspy.Module.__call__``.
        """
        return self.predictor(**kwargs)

# Smoke test: classify the first example of the shuffled dog dataset and print the prediction.
dog_picture = DogPicture()
print(dog_picture(**dog_dataset[0].inputs()))

Prediction(
    reasoning='The dog in the image has a curly, white coat and a distinctive blue collar, which are characteristic features of the Bedlington Terrier breed.',
    answer='Bedlington Terrier'
)


In [70]:
# Evaluator over the last 500 examples of the shuffled dog dataset.
# NOTE(review): exact string match on breed names is a strict metric for this task — confirm it is intended.
evaluate = Evaluate(metric=answer_exact_match, num_threads=100, devset= dog_dataset[-500:], display_progress=True, max_errors=10000)
