In [1]:
%load_ext autoreload
%autoreload 2

# %set_env CUDA_VISIBLE_DEVICES=7
# import sys; sys.path.append('/future/u/okhattab/repos/public/stanfordnlp/dspy')

import os

import dspy
from dspy.evaluate import Evaluate
from dspy.datasets.hotpotqa import HotPotQA
from dspy.teleprompt import BootstrapFewShotWithRandomSearch, BootstrapFinetune

  from .autonotebook import tqdm as notebook_tqdm


### 1) Configure the default LM and retriever

In [4]:
# Scratch check: instantiate the Llama-2-13b-chat checkpoint through dspy.HFModel.
# NOTE(review): `x` is only exercised by the quick call in the next cell; the
# pipeline cells further down configure a separate TGI client as the default LM.
x = dspy.HFModel(model="meta-llama/Llama-2-13b-chat-hf")

Loading checkpoint shards: 100%|██████████| 3/3 [00:58<00:00, 19.38s/it]


In [6]:
# Quick smoke test: send a trivial prompt to the model held in `x` at temperature 0.1.
x("hello", temperature=0.1)

: 

In [2]:
# Endpoints can be overridden from the environment so the notebook is runnable
# outside the machine it was authored on:
#   DSPY_TGI_URL     - base URL (scheme + host, no port) of the text-generation-inference servers
#   DSPY_COLBERT_URL - full URL of the ColBERTv2 retrieval endpoint
# When a variable is unset the original behavior is kept: HFClientTGI uses its
# own built-in default host and ColBERTv2 uses the original hosted index URL.
tgi_overrides = {"url": os.environ["DSPY_TGI_URL"]} if os.environ.get("DSPY_TGI_URL") else {}
colbert_url = os.environ.get("DSPY_COLBERT_URL", 'http://20.102.90.50:2017/wiki17_abstracts')

# The client is given the full list of ports the TGI servers listen on.
ports = [7140, 7141, 7142, 7143, 7144, 7145]
llamaChat = dspy.HFClientTGI(model="meta-llama/Llama-2-13b-chat-hf", port=ports, max_tokens=150, **tgi_overrides)
colbertv2 = dspy.ColBERTv2(url=colbert_url)

# Register the retriever and the LM as the defaults used by every dspy module below.
dspy.settings.configure(rm=colbertv2, lm=llamaChat)

### 2) Load a small sample of HotPotQA data

In [3]:
# Sample 200 training and 1000 dev questions (no test split) with fixed seeds.
dataset = HotPotQA(train_seed=1, train_size=200, eval_seed=2023, dev_size=1000, test_size=0)

# Mark `question` as the only input field on every example of each split.
trainset, devset, testset = (
    [example.with_inputs('question') for example in split]
    for split in (dataset.train, dataset.dev, dataset.test)
)

len(trainset), len(devset), len(testset)

Downloading data: 100%|██████████| 301M/301M [00:39<00:00, 7.65MB/s] 
Downloading data: 100%|██████████| 31.1M/31.1M [00:03<00:00, 8.22MB/s]
Downloading data: 100%|██████████| 28.1M/28.1M [00:03<00:00, 7.49MB/s]
Downloading data: 100%|██████████| 27.6M/27.6M [00:03<00:00, 7.26MB/s]
Generating train split: 100%|██████████| 90447/90447 [00:07<00:00, 12151.43 examples/s]
Generating validation split: 100%|██████████| 7405/7405 [00:00<00:00, 11511.88 examples/s]
Generating test split: 100%|██████████| 7405/7405 [00:00<00:00, 18930.17 examples/s]


(200, 1000, 0)

In [4]:
# Peek at the first training example to check its fields and input keys.
trainset[0]

Example({'question': 'At My Window was released by which American singer-songwriter?', 'answer': 'John Townes Van Zandt'}) (input_keys={'question'})

### 3) Define a simple multi-hop program

In [5]:
from dsp.utils.utils import deduplicate

class BasicMH(dspy.Module):
    """Multi-hop question answering program.

    For each hop, a dedicated chain-of-thought predictor writes a search query
    from the question and the context gathered so far, the retriever fetches
    passages for that query, and the passages are merged into the context.
    A final chain-of-thought predictor answers the question from the
    accumulated context.
    """

    def __init__(self, passages_per_hop=3, num_hops=2):
        """Build the retriever and the per-hop predictors.

        Args:
            passages_per_hop: Passed as `k` to `dspy.Retrieve`.
            num_hops: Number of query/retrieve rounds performed before
                answering. Defaults to 2, the value that was previously
                hard-coded, so programs saved with two hops still load.
        """
        super().__init__()

        self.retrieve = dspy.Retrieve(k=passages_per_hop)
        # One independent query writer per hop, kept in a list so each hop is its own predictor.
        self.generate_query = [dspy.ChainOfThought("context, question -> search_query") for _ in range(num_hops)]
        self.generate_answer = dspy.ChainOfThought("context, question -> answer")

    def forward(self, question):
        """Answer `question` and return the prediction with the gathered `context` attached."""
        context = []

        # Iterate over the predictors themselves so the hop count always
        # matches what __init__ built, with no second constant to keep in sync.
        for generate_query in self.generate_query:
            search_query = generate_query(context=context, question=question).search_query
            passages = self.retrieve(search_query).passages
            # Merge the new passages into the running context via `deduplicate`.
            context = deduplicate(context + passages)

        return self.generate_answer(context=context, question=question).copy(context=context)

### 4) Compile the program with `Llama2-13b-chat`

In [6]:
# Set to True to re-run the Llama compilation; when False the saved programs are loaded instead.
RECOMPILE_INTO_LLAMA_FROM_SCRATCH = False
# Thread count handed to both the optimizer and the evaluator below.
NUM_THREADS = 24

# Exact-match answer metric used for compilation and for dev-set evaluation.
metric_EM = dspy.evaluate.answer_exact_match

In [7]:
if RECOMPILE_INTO_LLAMA_FROM_SCRATCH:
    # Bootstrap few-shot demos with exact match as the metric: the first 50 training
    # examples are passed as the trainset and the remaining 150 as the validation set.
    tp = BootstrapFewShotWithRandomSearch(metric=metric_EM, max_bootstrapped_demos=2, num_threads=NUM_THREADS)
    basicmh_bs = tp.compile(BasicMH(), trainset=trainset[:50], valset=trainset[50:200])

    # Keep the program (last element) from each of the first four candidate entries.
    # NOTE(review): assumes candidate_programs is ordered best-first — confirm.
    ensemble = [prog for *_, prog in basicmh_bs.candidate_programs[:4]]

    for idx, prog in enumerate(ensemble):
        # Saving is disabled, so this loop is currently a no-op; uncomment the
        # next line to persist each program under the name the loading cell expects.
        # prog.save(f'multihop_llama213b_{idx}.json')
        pass

In [8]:
if not RECOMPILE_INTO_LLAMA_FROM_SCRATCH:
    # Rebuild the four-program ensemble from the saved JSON files instead of recompiling.
    ensemble = []

    for idx in range(4):
        # Each saved state is loaded into a freshly constructed BasicMH.
        prog = BasicMH()
        prog.load(f'multihop_llama213b_{idx}.json')
        ensemble.append(prog)

In [9]:
# Use the first ensemble member as the Llama program under evaluation.
llama_program = ensemble[0]

# Reusable evaluator: exact match over the first 1000 dev examples,
# multi-threaded, with a progress bar and no results table.
evaluate_hotpot = Evaluate(
    devset=devset[:1000],
    metric=metric_EM,
    num_threads=NUM_THREADS,
    display_progress=True,
    display_table=0,
)
evaluate_hotpot(llama_program)

  return self._cached_call(args, kwargs)[0]


Error for example in dev set: 		 HTTPConnectionPool(host='future-hgx-1', port=7144): Max retries exceeded with url: /generate (Caused by NameResolutionError("<urllib3.connection.HTTPConnection object at 0xfffeed5d5580>: Failed to resolve 'future-hgx-1' ([Errno -2] Name or service not known)"))
Error for example in dev set: 		 HTTPConnectionPool(host='future-hgx-1', port=7141): Max retries exceeded with url: /generate (Caused by NameResolutionError("<urllib3.connection.HTTPConnection object at 0xfffeec5a99d0>: Failed to resolve 'future-hgx-1' ([Errno -2] Name or service not known)"))
Error for example in dev set: 		 HTTPConnectionPool(host='future-hgx-1', port=7144): Max retries exceeded with url: /generate (Caused by NameResolutionError("<urllib3.connection.HTTPConnection object at 0xfffeec4c5490>: Failed to resolve 'future-hgx-1' ([Errno -2] Name or service not known)"))
Error for example in dev set: 		 HTTPConnectionPool(host='future-hgx-1', port=7141): Max retries exceeded with url:

  0%|          | 0/1000 [00:00<?, ?it/s]

ConnectionError: HTTPConnectionPool(host='future-hgx-1', port=7144): Max retries exceeded with url: /generate (Caused by NameResolutionError("<urllib3.connection.HTTPConnection object at 0xfffed0511e50>: Failed to resolve 'future-hgx-1' ([Errno -2] Name or service not known)"))

In [10]:
# Run a single multi-hop question through the Llama program.
llama_program(question="How many storeys are in the castle that David Gregory inherited?")

# Inspect the 3 most recent calls recorded by the TGI client.
llamaChat.inspect_history(n=3)

ConnectionError: HTTPConnectionPool(host='future-hgx-1', port=7143): Max retries exceeded with url: /generate (Caused by NameResolutionError("<urllib3.connection.HTTPConnection object at 0xfffeeeb70520>: Failed to resolve 'future-hgx-1' ([Errno -2] Name or service not known)"))

### 6) Compile into `T5-Large` (770M parameters)

In [11]:
# Draw 3000 training questions and keep only the `question` field, so the
# resulting examples carry no gold answers.
unlabeled_train = [
    dspy.Example(question=example.question).with_inputs('question')
    for example in HotPotQA(train_seed=1, train_size=3000, eval_seed=2023, dev_size=0, test_size=0).train
]
len(unlabeled_train)

  0%|          | 0/1000 [00:40<?, ?it/s]


3000

Optional step: pre-compute the ensemble's predictions on the unlabeled training set.

In [12]:
# The examples are unlabeled, so use a metric that accepts every prediction.
always_true = lambda g, p, trace=None: True

# Run each ensemble program over the unlabeled questions with the shared evaluator.
# NOTE(review): presumably this populates the LM cache ahead of fine-tuning — confirm.
for prog_ in ensemble:
    evaluate_hotpot(prog_, devset=unlabeled_train[:3000], metric=always_true)

  0%|          | 0/3000 [00:00<?, ?it/s]

ConnectionError: HTTPConnectionPool(host='future-hgx-1', port=7143): Max retries exceeded with url: /generate (Caused by NameResolutionError("<urllib3.connection.HTTPConnection object at 0xfffec48e5a00>: Failed to resolve 'future-hgx-1' ([Errno -2] Name or service not known)"))

Now compile into T5!

In [13]:
# Set to True to fine-tune T5 here; when False a published checkpoint is loaded instead.
RECOMPILE_INTO_T5_FROM_SCRATCH = False

if RECOMPILE_INTO_T5_FROM_SCRATCH:
    # Fine-tuning settings forwarded as keyword arguments to `compile`.
    config = {
        'target': 't5-large',
        'epochs': 2,
        'bf16': True,
        'bsize': 6,
        'accumsteps': 2,
        'lr': 5e-5,
    }

    # No metric is supplied; the Llama ensemble acts as the teacher on the unlabeled questions.
    tp = BootstrapFinetune(metric=None)
    t5_program = tp.compile(BasicMH(), teacher=ensemble, trainset=unlabeled_train[:3000], **config)

    # Deactivate chain of thought prompting. Let's use T5 to directly predict outputs. (Faster and similar quality.)
    for p in t5_program.predictors():
        p.activated = False

In [14]:
if not RECOMPILE_INTO_T5_FROM_SCRATCH:
    # Skip fine-tuning: start from a fresh program and attach the checkpoint named below.
    t5_program = BasicMH()

    # Local checkpoint path kept for reference; the identifier on the following line is what gets loaded.
    # ckpt_path = '../finetuning_ckpts/LMWEP0WZ5IKWM.all/checkpoint-5400'
    ckpt_path = "colbert-ir/dspy-Oct11-T5-Large-MH-3k-v1"
    LM = dspy.HFModel(checkpoint=ckpt_path, model='t5-large')

    # Point every predictor at the T5 model and deactivate chain of thought,
    # matching what the from-scratch branch does after compiling.
    for p in t5_program.predictors():
        p.lm = LM
        p.activated = False

config.json: 100%|██████████| 1.21k/1.21k [00:00<00:00, 73.4kB/s]
tokenizer_config.json: 100%|██████████| 20.7k/20.7k [00:00<00:00, 846kB/s]
tokenizer.json: 100%|██████████| 2.42M/2.42M [00:00<00:00, 3.67MB/s]
special_tokens_map.json: 100%|██████████| 2.20k/2.20k [00:00<00:00, 5.72MB/s]
config.json: 100%|██████████| 1.51k/1.51k [00:00<00:00, 135kB/s]
pytorch_model.bin: 100%|██████████| 2.95G/2.95G [04:11<00:00, 11.7MB/s]
generation_config.json: 100%|██████████| 112/112 [00:00<00:00, 10.2kB/s]


### 7) Evaluate the T5-Large `multihop` program

In [15]:
# Evaluate the T5 program with the evaluator built earlier (same dev set and
# exact-match metric), overriding it to run on a single thread.
score = evaluate_hotpot(t5_program, num_threads=1)

Token indices sequence length is longer than the specified maximum sequence length for this model (548 > 512). Running this sequence through the model will result in indexing errors


ConnectionError: HTTPConnectionPool(host='20.102.90.50', port=2017): Max retries exceeded with url: /wiki17_abstracts?query=The+Gaslight+Anthem+Seaweed+genre&k=3 (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0xfffda5c9b280>: Failed to establish a new connection: [Errno 111] Connection refused'))

In [None]:
# Inspect the 3 most recent calls made through the LM attached to the first predictor.
t5_program.predictors()[0].lm.inspect_history(n=3)