Skip to content

Commit

Permalink
feat: Updated dspy/teleprompt/random_search.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sweep-ai[bot] committed Dec 20, 2023
1 parent 6a14181 commit 930f948
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions dspy/teleprompt/random_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(self, metric, teacher_settings={}, max_bootstrapped_demos=4, max_la
# print("Going to sample", self.max_num_traces, "traces in total.")
print("Will attempt to train", self.num_candidate_sets, "candidate sets.")

def compile(self, student, *, teacher=None, trainset, valset=None, restrict=None):
async def compile(self, student, *, teacher=None, trainset, valset=None, restrict=None):
self.trainset = trainset
self.valset = valset or trainset # TODO: FIXME: Note this choice.

Expand All @@ -70,14 +70,14 @@ def compile(self, student, *, teacher=None, trainset, valset=None, restrict=None
elif seed == -2:
# labels only
teleprompter = LabeledFewShot(k=self.max_labeled_demos)
program2 = teleprompter.compile(student, trainset=trainset2)
program2 = await teleprompter.compile(student, trainset=trainset2)

elif seed == -1:
# unshuffled few-shot
program = BootstrapFewShot(metric=self.metric, max_bootstrapped_demos=self.max_num_samples,
max_labeled_demos=self.max_labeled_demos,
teacher_settings=self.teacher_settings, max_rounds=self.max_rounds)
program2 = program.compile(student, teacher=teacher, trainset=trainset2)
program2 = await program.compile(student, teacher=teacher, trainset=trainset2)

else:
assert seed >= 0, seed
Expand All @@ -95,7 +95,7 @@ def compile(self, student, *, teacher=None, trainset, valset=None, restrict=None
evaluate = Evaluate(devset=self.valset, metric=self.metric, num_threads=self.num_threads,
display_table=False, display_progress=True)

score, subscores = evaluate(program2, return_all_scores=True)
score, subscores = await evaluate(program2, return_all_scores=True)

all_subscores.append(subscores)

Expand Down

0 comments on commit 930f948

Please sign in to comment.