In [7]:
import os

import dspy
from dspy.evaluate import Evaluate
from dspy.datasets.hotpotqa import HotPotQA
from dspy.teleprompt import BootstrapFewShotWithRandomSearch

from compiler.utils import load_api_key

# The secrets file location and the retrieval endpoint are machine-specific, so
# each can be overridden through an environment variable. The defaults are the
# values this notebook was originally run with.
SECRETS_PATH = os.environ.get('LM_COMPILER_SECRETS', '/mnt/ssd4/lm_compiler/secrets.toml')
COLBERT_URL = os.environ.get('COLBERT_URL', 'http://20.102.90.50:2017/wiki17_abstracts')

# NOTE(review): presumably reads the API key(s) from the TOML file and makes
# them available to the LM client; confirm against compiler.utils.
load_api_key(SECRETS_PATH)

# Register the language model and the retrieval model as the DSPy defaults.
gpt4o_mini = dspy.OpenAI('gpt-4o-mini', max_tokens=1000)
colbert = dspy.ColBERTv2(url=COLBERT_URL)
dspy.configure(lm=gpt4o_mini, rm=colbert)

In [2]:
# Load HotPotQA with fixed seeds, requesting 200 training examples,
# 300 dev examples and an empty test split.
dataset = HotPotQA(
    train_seed=1,
    train_size=200,
    eval_seed=2023,
    dev_size=300,
    test_size=0,
)

# Mark 'question' as the input field of every example. Training examples
# 0-149 form the train set and 150-199 are held out as the validation set.
trainset = [example.with_inputs('question') for example in dataset.train[:150]]
valset = [example.with_inputs('question') for example in dataset.train[150:200]]
devset = [example.with_inputs('question') for example in dataset.dev]

# Display one datapoint; it is just a question-answer pair.
trainset[0]

Example({'question': 'At My Window was released by which American singer-songwriter?', 'answer': 'John Townes Van Zandt'}) (input_keys={'question'})

In [3]:
from dsp.utils.utils import deduplicate

class BasicMH(dspy.Module):
    """Basic multi-hop question-answering program.

    For each hop, a chain-of-thought predictor writes a search query from the
    question and the context gathered so far, the retriever fetches passages
    for that query, and the passages are merged into the context. A final
    chain-of-thought predictor answers the question from the gathered context.

    Args:
        passages_per_hop: Passed as ``k`` to ``dspy.Retrieve``.
        max_hops: Number of query-generation/retrieval rounds. One query
            predictor is created per hop. Defaults to 2.
    """

    def __init__(self, passages_per_hop=3, max_hops=2):
        super().__init__()

        self.max_hops = max_hops
        self.retrieve = dspy.Retrieve(k=passages_per_hop)
        # A separate predictor per hop; the length of this list is what
        # determines how many hops forward() performs.
        self.generate_query = [
            dspy.ChainOfThought("context, question -> search_query")
            for _ in range(max_hops)
        ]
        self.generate_answer = dspy.ChainOfThought("context, question -> answer")

    def forward(self, question):
        """Answer ``question`` after gathering context over the configured hops.

        Args:
            question: The question to answer.

        Returns:
            The prediction produced by ``generate_answer``, copied with the
            gathered ``context`` attached.
        """
        context = []

        # Iterate over the predictors themselves so the hop count can never
        # disagree with the number of predictors built in __init__.
        for generate_query in self.generate_query:
            search_query = generate_query(context=context, question=question).search_query
            passages = self.retrieve(search_query).passages
            # Merge the new passages into the running context
            # (deduplicate presumably drops repeats -- confirm in dsp.utils).
            context = deduplicate(context + passages)

        return self.generate_answer(context=context, question=question).copy(context=context)

In [8]:
# Build the multi-hop program with 2 passages per hop instead of the class default of 3.
agent = BasicMH(passages_per_hop=2)

In [9]:
# Build an evaluator over the dev set (300 examples), scoring predictions by
# exact match against the reference answer.
config = {"num_threads": 8, "display_progress": True, "display_table": 5}
evaluate = Evaluate(
    metric=dspy.evaluate.answer_exact_match,
    devset=devset,
    **config,
)

# Run the uncompiled agent through the evaluator.
evaluate(agent)

  0%|          | 0/300 [00:00<?, ?it/s]ERROR:dspy.evaluate.evaluate:[2m2024-09-25T03:56:12.603898Z[0m [[31m[1merror    [0m] [1mError for example in dev set: 		 HTTPConnectionPool(host='20.102.90.50', port=2017): Max retries exceeded with url: /wiki17_abstracts?query=operator+of+Newark+Airport&k=2 (Caused by ConnectTimeoutError(<urllib3.connection.HTTPConnection object at 0x7f9138400d10>, 'Connection to 20.102.90.50 timed out. (connect timeout=10)')). Set `provide_traceback=True` to see the stack trace.[0m [[0m[1m[34mdspy.evaluate.evaluate[0m][0m [36mfilename[0m=[35mevaluate.py[0m [36mlineno[0m=[35m203[0m
Average Metric: 0.0 / 1  (0.0):   0%|          | 1/300 [00:10<53:36, 10.76s/it]ERROR:dspy.evaluate.evaluate:[2m2024-09-25T03:56:12.723661Z[0m [[31m[1merror    [0m] [1mError for example in dev set: 		 HTTPConnectionPool(host='20.102.90.50', port=2017): Max retries exceeded with url: /wiki17_abstracts?query=river+near+Crichton+Collegiate+Church&k=2 (Caused by Co