In [1]:
%load_ext autoreload
%autoreload 2

from typing import Literal, Union
from pathlib import Path
import csv
from datetime import datetime
import random


import pandas as pd
import dspy
from dspy.evaluate import Evaluate
from dspy.teleprompt import MIPROv2
from sklearn.model_selection import train_test_split

from programs import WrapperEnglishSPT, WrapperSpanishSPT, evaluate_answer
from custom_evaluation import custom_evaluate

In [2]:
# Language-model backend: a DeepSeek-R1 14B model addressed through the
# "ollama_chat" provider prefix at a local HTTP endpoint.
OLLAMA_MODEL = "ollama_chat/deepseek-r1:14b"
OLLAMA_API_BASE = "http://localhost:11434"

lm = dspy.LM(OLLAMA_MODEL, api_base=OLLAMA_API_BASE)

# Register `lm` as the default model in the global DSPy settings.
dspy.settings.configure(lm=lm)

In [3]:
# Smoke test: send one free-form prompt to the configured model to confirm the
# endpoint is reachable; the cell output is the list of completions returned.
lm("What is your name")

["<think>\n\n</think>\n\nGreetings! I'm DeepSeek-R1, an artificial intelligence assistant created by DeepSeek. I'm at your service and would be delighted to assist you with any inquiries or tasks you may have."]

In [4]:
# Annotated sentence pairs, read from a CSV in the notebook's working directory.
DEV_CSV = "dev_dwug_es.csv"

data = pd.read_csv(DEV_CSV)

# Show (rows, columns) as a quick sanity check on the load.
display(data.shape)

(8704, 8)

In [5]:
# Convert every CSV row into a dspy.Example. The two contexts and the lemma are
# declared as inputs; the integer judgment is kept as the `answer` label.
# NOTE: despite its name, `training_set` holds ALL rows of `data`; the
# train/dev/test partitioning happens later.
training_set = [
    dspy.Example(
        sentence1=row["context_x"],
        sentence2=row["context_y"],
        target_word=row["lemma"],
        answer=int(row["judgment"]),
    ).with_inputs("sentence1", "sentence2", "target_word")
    for _, row in data.iterrows()
]

In [6]:
def split_train_dev_test(examples, dev_size=0.2, test_size=0.2, seed=42):
    """Split one class's examples into train / dev / test lists.

    The dev split is carved off first (``dev_size`` of all examples). The test
    split is then taken from what remains (``test_size`` of the remainder), so
    with the defaults the test split is 20% of the remaining 80%, not 20% of
    the whole class.

    Args:
        examples: list of examples belonging to a single class.
        dev_size: fraction of ``examples`` assigned to the dev split.
        test_size: fraction of the non-dev remainder assigned to the test split.
        seed: ``random_state`` passed to both ``train_test_split`` calls.

    Returns:
        A ``(train, dev, test)`` tuple of lists.
    """
    remainder, dev = train_test_split(
        examples, test_size=dev_size, random_state=seed
    )
    train, test = train_test_split(
        remainder, test_size=test_size, random_state=seed
    )
    return train, dev, test


# Judgment labels present in the dataset; each one is split independently so
# every split keeps examples of every class.
LABELS = (1, 2, 3, 4)

examples_by_label = {
    label: [item for item in training_set if item.answer == label]
    for label in LABELS
}

# Class sizes before splitting, one per line.
for label in LABELS:
    print(len(examples_by_label[label]))

splits_by_label = {
    label: split_train_dev_test(examples_by_label[label]) for label in LABELS
}

# Bind the per-class names that the rest of the notebook refers to.
classes_1_es, classes_2_es, classes_3_es, classes_4_es = (
    examples_by_label[label] for label in LABELS
)
classes_1_train, classes_1_dev, classes_1_test = splits_by_label[1]
classes_2_train, classes_2_dev, classes_2_test = splits_by_label[2]
classes_3_train, classes_3_dev, classes_3_test = splits_by_label[3]
classes_4_train, classes_4_dev, classes_4_test = splits_by_label[4]

# Sizes per class as "train dev test".
for label in LABELS:
    print(*(len(part) for part in splits_by_label[label]))

1406
1522
2343
3433
899 282 225
973 305 244
1499 469 375
2196 687 550


In [7]:
# Baseline (un-optimized) Spanish program from the local `programs` module,
# with assertions activated on the instance.
# NOTE(review): presumably this is what produces the "SuggestionFailed" retry
# messages seen in the evaluation logs - confirm against `programs.py`.
program_spt_prompt_es_assertions = WrapperSpanishSPT().activate_assertions()

In [None]:
# custom_evaluate(
#     random.choices(classes_1_test, k=225)
#     + random.choices(classes_2_test, k=225)
#     + random.choices(classes_3_test, k=225)
#     + random.choices(classes_4_test, k=225),
#     evaluate_answer,
#     program_spt_prompt_es_assertions,
#     debug=True,
# )

Evaluating: 900 examples
Prediction:  1
Prediction:  2
Prediction:  1
Prediction:  1
Prediction:  1
Prediction:  1
Prediction:  1
Prediction:  3
Prediction:  1
Prediction:  3
Prediction:  3
Prediction:  1
Prediction:  1
Prediction:  2
Prediction:  2
Prediction:  2
Prediction:  2
Prediction:  2
Prediction:  1
Prediction:  1
Prediction:  3
Prediction:  1
Prediction:  3
Prediction:  1
Prediction:  1
Prediction:  4
Prediction:  1
Prediction:  1
Prediction:  1
Prediction:  2
Prediction:  1
Prediction:  4
Prediction:  1
Prediction:  2
Prediction:  3
Prediction:  1
Prediction:  3


2025/04/17 11:22:23 INFO dspy.primitives.assertions: SuggestionFailed: La salida deberia ser 1 o 2 o 3 o 4. Por favor, revise en consecuencia.


Prediction:  2
Prediction:  4
Prediction:  1
Prediction:  1
Prediction:  4
Prediction:  2
Prediction:  1
Prediction:  4
Prediction:  1
Prediction:  1
Prediction:  4
Prediction:  2
Prediction:  1
Prediction:  3
Prediction:  4
Prediction:  2
Prediction:  1
Prediction:  1
Prediction:  1
Prediction:  2
Prediction:  1
Prediction:  2
Prediction:  1
Prediction:  2
Prediction:  1
Prediction:  2
Prediction:  2
Prediction:  3
Prediction:  1
Prediction:  1
Prediction:  1
Prediction:  1
Prediction:  1
Prediction:  2
Prediction:  1
Prediction:  4
Prediction:  1
Prediction:  3
Prediction:  1
Prediction:  1
Prediction:  1
Prediction:  1
Prediction:  1
Prediction:  1
Prediction:  1
Prediction:  3
Prediction:  2
Prediction:  1
Prediction:  3
Prediction:  2
Prediction:  1
Prediction:  1
Prediction:  1
Prediction:  1
Prediction:  1
Prediction:  1
Prediction:  2
Prediction:  3
Prediction:  3
Prediction:  1
Prediction:  1
Prediction:  1
Prediction:  3
Prediction:  3
Prediction:  2
Prediction:  2
Prediction

2025/04/17 11:40:01 INFO dspy.primitives.assertions: SuggestionFailed: La salida deberia ser 1 o 2 o 3 o 4. Por favor, revise en consecuencia.
2025/04/17 11:40:09 INFO dspy.primitives.assertions: SuggestionFailed: La salida deberia ser 1 o 2 o 3 o 4. Por favor, revise en consecuencia.
2025/04/17 11:40:09 INFO dspy.primitives.assertions: SuggestionFailed: La salida deberia ser 1 o 2 o 3 o 4. Por favor, revise en consecuencia.


Prediction:  0
Prediction:  2
Prediction:  1
Prediction:  3
Prediction:  4
Prediction:  1
Prediction:  1
Prediction:  2
Prediction:  4
Prediction:  2
Prediction:  1
Prediction:  2
Prediction:  4
Prediction:  2
Prediction:  4
Prediction:  3
Prediction:  1
Prediction:  3
Prediction:  3
Prediction:  1
Prediction:  1
Prediction:  1
Prediction:  1
Prediction:  2
Prediction:  1
Prediction:  4
Prediction:  3
Prediction:  1
Prediction:  3
Prediction:  2
Prediction:  2
Prediction:  3
Prediction:  1
Prediction:  1
Prediction:  1
Prediction:  1


2025/04/17 11:44:47 INFO dspy.primitives.assertions: SuggestionFailed: La salida deberia ser 1 o 2 o 3 o 4. Por favor, revise en consecuencia.


Prediction:  4
Prediction:  1
Prediction:  2
Prediction:  1
Prediction:  3
Prediction:  2
Prediction:  2
Prediction:  1
Prediction:  3
Prediction:  1
Prediction:  3
Prediction:  1
Prediction:  3
Prediction:  1
Prediction:  2
Prediction:  3
Prediction:  1
Prediction:  1
Prediction:  3
Prediction:  2
Prediction:  3
Prediction:  4
Prediction:  2
Prediction:  1
Prediction:  2
Prediction:  1
Prediction:  1
Prediction:  4
Prediction:  1
Prediction:  1
Prediction:  1
Prediction:  2
Prediction:  3
Prediction:  2
Prediction:  4
Prediction:  2
Prediction:  3
Prediction:  1
Prediction:  1
Prediction:  3
Prediction:  3
Prediction:  4
Prediction:  2
Prediction:  1
Prediction:  1
Prediction:  4
Prediction:  3
Prediction:  2
Prediction:  2
Prediction:  3
Prediction:  2
Prediction:  3
Prediction:  4
Prediction:  1
Prediction:  4
Prediction:  2
Prediction:  1
Prediction:  3
Prediction:  1
Prediction:  1
Prediction:  4
Prediction:  1
Prediction:  2
Prediction:  3
Prediction:  4
Prediction:  1
Prediction

2025/04/17 12:31:59 INFO dspy.primitives.assertions: SuggestionFailed: La salida deberia ser 1 o 2 o 3 o 4. Por favor, revise en consecuencia.


Prediction:  3
Prediction:  4
Prediction:  3
Prediction:  3
Prediction:  4
Prediction:  4
Prediction:  4
Prediction:  2
Prediction:  4
Prediction:  4
Prediction:  4
Prediction:  1
Prediction:  3
Prediction:  3
Prediction:  1
Prediction:  4
Prediction:  3
Prediction:  3
Prediction:  4
Prediction:  3
Prediction:  4
Prediction:  3
Prediction:  3
Prediction:  4
Prediction:  4
Prediction:  4
Prediction:  4
Prediction:  4
Prediction:  4
Prediction:  3
Prediction:  3
Prediction:  3
Prediction:  4
Prediction:  4
Prediction:  2
Prediction:  3
Prediction:  1
Prediction:  4
Prediction:  3
Prediction:  2
Prediction:  4
Prediction:  4
Prediction:  3
Prediction:  4
Prediction:  3
Prediction:  4
Prediction:  1
Prediction:  4
Prediction:  3
Prediction:  2
Prediction:  4
Prediction:  4
Prediction:  1
Prediction:  4
Prediction:  3
Prediction:  4
Prediction:  3
Prediction:  3
Prediction:  1
Prediction:  4
Prediction:  1
Prediction:  4
Prediction:  4
Prediction:  3
Prediction:  4
Prediction:  3
Prediction

In [9]:
%reload_ext autoreload

start_time = datetime.now()

# Seeded generator so the class-balanced subsets are reproducible, in line with
# the random_state=42 used when the splits were created.
sampling_rng = random.Random(42)


def balanced_subset(pools, per_class, rng):
    """Draw up to ``per_class`` distinct examples from each pool.

    Sampling is WITHOUT replacement (``Random.sample``), so no example appears
    twice within a class. A pool smaller than ``per_class`` contributes all of
    its examples instead of raising.

    Args:
        pools: iterable of per-class example lists.
        per_class: maximum number of examples to take from each pool.
        rng: ``random.Random`` instance used for the draws.

    Returns:
        One flat list, with each pool's draws concatenated in pool order.
    """
    return [
        example
        for pool in pools
        for example in rng.sample(pool, min(per_class, len(pool)))
    ]


teleprompter = MIPROv2(
    metric=evaluate_answer,
    task_model=lm,
    num_candidates=10,
    init_temperature=0.7,
    max_bootstrapped_demos=3,
    max_labeled_demos=4,
    verbose=False,
)

print("Optimizing program with MIPRO...")
optimized_program = teleprompter.compile(
    # Optimize a copy so the baseline program object is left untouched.
    program_spt_prompt_es_assertions.deepcopy(),
    trainset=balanced_subset(
        (classes_1_train, classes_2_train, classes_3_train, classes_4_train),
        500,
        sampling_rng,
    ),
    valset=balanced_subset(
        (classes_1_dev, classes_2_dev, classes_3_dev, classes_4_dev),
        200,
        sampling_rng,
    ),
    num_trials=15,
    minibatch_size=25,
    minibatch_full_eval_steps=10,
    minibatch=True,
    requires_permission_to_run=False,
)

# Make sure the target directory exists before saving: a missing directory
# would otherwise discard the result of a multi-hour optimization run.
MODEL_PATH = "compile-models/sp/es_spt_mipro_optimized_prompt_es_deepseek-q4"
Path(MODEL_PATH).parent.mkdir(parents=True, exist_ok=True)
optimized_program.save(MODEL_PATH)

print(f"Elapsed time: {datetime.now() - start_time}")

2025/04/17 12:59:12 INFO dspy.teleprompt.mipro_optimizer_v2: 
==> STEP 1: BOOTSTRAP FEWSHOT EXAMPLES <==
2025/04/17 12:59:12 INFO dspy.teleprompt.mipro_optimizer_v2: These will be used as few-shot example candidates for our program and for creating instructions.

2025/04/17 12:59:12 INFO dspy.teleprompt.mipro_optimizer_v2: Bootstrapping N=10 sets of demonstrations...


Optimizing program with MIPRO...
Bootstrapping set 1/10
Bootstrapping set 2/10
Bootstrapping set 3/10


  0%|          | 7/2000 [01:32<7:20:55, 13.27s/it]


Bootstrapped 3 full traces after 7 examples for up to 1 rounds, amounting to 7 attempts.
Bootstrapping set 4/10


  0%|          | 1/2000 [00:14<7:47:05, 14.02s/it]


Bootstrapped 1 full traces after 1 examples for up to 1 rounds, amounting to 1 attempts.
Bootstrapping set 5/10


  0%|          | 5/2000 [00:54<6:02:47, 10.91s/it]


Bootstrapped 1 full traces after 5 examples for up to 1 rounds, amounting to 5 attempts.
Bootstrapping set 6/10


  0%|          | 9/2000 [01:40<6:09:04, 11.12s/it]


Bootstrapped 2 full traces after 9 examples for up to 1 rounds, amounting to 9 attempts.
Bootstrapping set 7/10


  0%|          | 1/2000 [00:16<9:24:23, 16.94s/it]


KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt


# Per-trial records attached to the optimized program; each entry is read for
# its "score" and its "full_eval" flag.
trial_logs = optimized_program.trial_logs

trial_numbers = list(trial_logs.keys())
scores = [trial_logs[trial]["score"] for trial in trial_numbers]
full_eval = [trial_logs[trial]["full_eval"] for trial in trial_numbers]

# Partition the points once: trials whose "full_eval" flag is False are drawn
# as grey "Pruned Batch" points, everything else as green "Successful Batch".
pruned_points = [
    (trial, score)
    for trial, score, fully_evaluated in zip(trial_numbers, scores, full_eval)
    if fully_evaluated is False
]
successful_points = [
    (trial, score)
    for trial, score, fully_evaluated in zip(trial_numbers, scores, full_eval)
    if fully_evaluated is not False
]

fig, ax = plt.subplots()

# One scatter call per group gives exactly one legend entry per group without
# re-inspecting the legend for every point.
for points, color, label in (
    (pruned_points, "grey", "Pruned Batch"),
    (successful_points, "green", "Successful Batch"),
):
    if points:  # skip empty groups so they do not get a legend entry
        xs, ys = zip(*points)
        ax.scatter(xs, ys, color=color, label=label)

ax.set_xlabel("Batch Number")
ax.set_ylabel("Score")
ax.set_title("Batch Scores")
ax.grid(True)
ax.legend()
plt.show()

In [None]:
# Module-level accumulators updated by the trial-reporting loop further down
# in this cell: the score and program of the most recently visited trial whose
# "full_eval" flag is True.
best_score = 0
best_program_so_far = None


def get_signature(predictor):
    """Return the signature object attached to ``predictor``.

    ``extended_signature`` takes precedence when the attribute exists;
    otherwise ``signature`` is used. If the predictor has neither attribute
    the function returns ``None``.
    """
    for attribute in ("extended_signature", "signature"):
        if hasattr(predictor, attribute):
            return getattr(predictor, attribute)
    return None


# print(f"Baseline program | Score: {best_score}:")
# for i, predictor in enumerate(WrapperEnglishSPT().predictors()):
#     print(f"Prompt {i+1} Instruction: {get_signature(predictor).instructions}")
# print()

print("----------------")

# Walk the trial log and, for every trial whose "full_eval" flag is True, print
# the instructions of each predictor in that trial's program together with the
# trial's score.
# NOTE: `best_score` / `best_program_so_far` are overwritten by EVERY fully
# evaluated trial, with no comparison against the previous value, so after the
# loop they describe the last such trial rather than the highest-scoring one.
for trial_num, trial_log in optimized_program.trial_logs.items():
    trial_score = trial_log["score"]
    if trial_log["full_eval"] is not True:
        continue

    best_score = trial_score
    best_program_so_far = trial_log["program"]

    for predictor in best_program_so_far.predictors():
        print(f"Prompt {trial_num} Instruction: {get_signature(predictor).instructions}")
        print(best_score)
    print()

In [None]:
# Load the state saved by the optimization cell (same path) into the existing
# program object; the return value is not used.
# NOTE(review): presumably this mutates the baseline program in place, so any
# evaluation run after this cell measures the optimized prompts - confirm.
program_spt_prompt_es_assertions.load(
    "compile-models/sp/es_spt_mipro_optimized_prompt_es_deepseek-q4"
)

In [None]:

# Class-balanced held-out evaluation set: up to 225 DISTINCT examples per class.
# Sampling without replacement matters here: drawing k=225 with replacement
# from a pool of 225 would repeat some test examples and never see others.
# The generator is seeded so the reported score is reproducible, in line with
# the random_state=42 used when the splits were created.
eval_rng = random.Random(42)
test_subset = [
    example
    for pool in (classes_1_test, classes_2_test, classes_3_test, classes_4_test)
    for example in eval_rng.sample(pool, min(225, len(pool)))
]

custom_evaluate(
    test_subset,
    evaluate_answer,
    program_spt_prompt_es_assertions,
    debug=True,
)