In [1]:
import json
import requests
import tqdm
from collections import Counter
from petals import DistributedBloomForCausalLM
from transformers import BloomTokenizerFast
from data_processing import parse_example_file, get_ans, generate_prompt, generate_question

Переведем данные из датасета GSM8K в формат из статьи (question/thought/answer). Еще я перевел в этот формат prompt из Appendix-а, с которым был получен лучший результат на этом датасете в первой статье про CoT:

In [2]:
# Parse the three example files in one pass: the train split, the test split
# and the hand-written few-shot examples taken from the article.
train_data, test_data, prompt_data = (
    parse_example_file(jsonl_path)
    for jsonl_path in (
        "data/train.jsonl",
        "data/test.jsonl",
        "data/article_prompt.jsonl",
    )
)

# Few-shot prompt text built from the article's examples; it is reused as the
# prefix of later queries.
main_prompt = generate_prompt(prompt_data)

Решим задачки с помощью обычного Chain-of-Thought и распределенной версии BLOOM. К сожалению, у большой модельки `bloom-petals` очень часто доступны не все блоки, и положиться на нее нельзя. Кроме того, Google Colab постоянно падает (и приходится перезапускать, выполняя весь предыдущий код) при работе с этими моделями — видимо, из-за большого числа внешних запросов к другим блокам. А локально у меня ресурсов хватает только на `bloom-7b1-petals`, поэтому проверка работы будет с ней.

In [3]:
# The tokenizer and the distributed model are loaded from the same checkpoint
# name.
model_name = "bigscience/bloom-7b1-petals"
tokenizer = BloomTokenizerFast.from_pretrained(model_name)
# NOTE(review): tuning_mode="ptune" with pre_seq_len=16 looks like a
# prompt-tuning setup; confirm it is intended for plain inference.
model = DistributedBloomForCausalLM.from_pretrained(
    model_name,
    tuning_mode="ptune",
    pre_seq_len=16,
)
model = model.cuda()

# Few-shot prompt followed by the first test question, tokenized and moved to
# the GPU.
q1 = generate_question(test_data[0])
encoded = tokenizer(main_prompt + q1, return_tensors="pt")
prefix = encoded["input_ids"].cuda()

# Generate at most 128 new tokens and print the first decoded sequence.
outputs = model.generate(prefix, max_new_tokens=128)
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(decoded)

Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
A: We start with 15 trees.
Later we have 21 trees.
The difference must be the number of trees they planted.
So, they must have planted 21 - 15 = 6 trees
The answer is 6.

Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
A: There are 3 cars in the parking lot already.
2 more arrive.
Now there are 3 + 2 = 5 cars.
The answer is 5.

Q: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?
A: Leah had 32 chocolates and Leah’s sister had 42.
That means there were originally 32 + 42 = 74 chocolates.
35 have been eaten. So in total they still have 74 - 35 = 39 chocolates.
The answer is 39.

Q: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to

Рассуждение и ответ есть, в целом формат выдержан верно (картину портят только токены, сгенерированные после ответа). Проверим другой prompt, составленный из примеров train-части GSM8K, и другую задачку (возьмем попроще).

In [4]:
# Alternative few-shot prompt built from the first eight train examples.
prompt_gsm = generate_prompt(train_data[:8])
q2 = generate_question(test_data[33])

# Tokenize prompt + question and move the token ids to the GPU.
encoded = tokenizer(prompt_gsm + q2, return_tensors="pt")
prefix = encoded["input_ids"].cuda()

# Generate at most 128 new tokens and print the first decoded sequence.
outputs = model.generate(prefix, max_new_tokens=128)
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(decoded)

Q: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
A: Natalia sold 48/2 = 24 clips in May.
Natalia sold 48+24 = 72 clips altogether in April and May.
The answer is 72.

Q: Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?
A: Weng earns 12/60 = $0.2 per minute.
Working 50 minutes, she earned 0.2 x 50 = $10.
The answer is 10.

Q: Betty is saving money for a new wallet which costs $100. Betty has only half of the money she needs. Her parents decided to give her $15 for that purpose, and her grandparents twice as much as her parents. How much more money does Betty need to buy the wallet?
A: In the beginning, Betty has only 100 / 2 = $50.
Betty's grandparents gave her 15 * 2 = $30.
This means, Betty needs 100 - 50 - 30 - 15 = $5 more.
The answer is 5.

Q: Julie is reading a 120-page book. Yesterday, she was able to read 1

Генерация идет долго, и на правильные расчеты от такой маленькой модельки рассчитывать не приходится, поэтому будем использовать оригинальную `bigscience/bloom` через HuggingFace Inference API. Кроме того, если задать условие остановки генерации, можно сразу получить результат, из которого легче извлечь ответ.

In [22]:
import os

# Endpoint used by query() for the full-size BLOOM model.
API_URL = "https://api-inference.huggingface.co/models/bigscience/bloom"

# SECURITY: an access token must never be committed with the notebook. It is
# read from the HF_TOKEN environment variable instead; any token that was
# previously published in this file should be revoked.
_hf_token = os.environ.get("HF_TOKEN", "").strip()

# Without a token no Authorization header is sent.
# NOTE(review): anonymous requests may be rejected or rate-limited by the API;
# confirm.
headers = {"Authorization": f"Bearer {_hf_token}"} if _hf_token else {}


def query(payload, timeout=120):
    """POST ``payload`` to ``API_URL`` and return the decoded JSON response.

    Args:
        payload: JSON-serialisable request body; callers in this file pass
            a dict with ``"inputs"`` and ``"parameters"`` keys.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` may block indefinitely on a stalled
            connection.

    Returns:
        The parsed JSON body of a successful response.

    Raises:
        requests.HTTPError: The API answered with an error status. The
            response body is included in the message so the cause is
            visible.
        requests.RequestException: Any other transport-level failure
            (connection error, timeout, undecodable body).
    """
    response = requests.post(
        API_URL, headers=headers, json=payload, timeout=timeout
    )
    if not response.ok:
        # Fail here with the server's explanation instead of returning an
        # error payload that callers would index as if it were a result list.
        raise requests.HTTPError(
            f"Inference API returned HTTP {response.status_code}: {response.text}",
            response=response,
        )
    return response.json()


# Generation settings for the single-pass run: at most 128 new tokens, with
# a blank line ("\n\n") passed as the stop sequence.
params = {"max_new_tokens": 128, "temperature": 1.0, "stop": ["\n\n"]}

# Smoke test: send the few-shot prompt plus the first test question and print
# the text of the first returned generation.
demo_response = query({"inputs": main_prompt + q1, "parameters": params})
print(demo_response[0]["generated_text"])

Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
A: We start with 15 trees.
Later we have 21 trees.
The difference must be the number of trees they planted.
So, they must have planted 21 - 15 = 6 trees
The answer is 6.

Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
A: There are 3 cars in the parking lot already.
2 more arrive.
Now there are 3 + 2 = 5 cars.
The answer is 5.

Q: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?
A: Leah had 32 chocolates and Leah’s sister had 42.
That means there were originally 32 + 42 = 74 chocolates.
35 have been eaten. So in total they still have 74 - 35 = 39 chocolates.
The answer is 39.

Q: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to

Приятно, что получили правильный ответ.
Теперь можно запустить эксперимент на большем числе задач, в котором будут сохраняться предсказанные рассуждение и ответ, чтобы потом посчитать метрики.

In [None]:
from time import sleep

# A single failed request used to abort the whole run. Retry a bounded number
# of times, pausing between attempts, and re-raise if the failure persists.
_MAX_ATTEMPTS = 5
_RETRY_DELAY_S = 180

# One JSON line per test question: the generated reasoning ("solution") and
# the answer that get_ans extracts from it.
with open("results/results.jsonl", "w", encoding="utf-8") as f:
    for q in tqdm.tqdm(test_data[0:200]):
        inp = main_prompt + generate_question(q)
        for attempt in range(1, _MAX_ATTEMPTS + 1):
            try:
                response = query({"inputs": inp, "parameters": params})
                # Drop the echoed prompt so only the newly generated text is kept.
                solution = response[0]["generated_text"][len(inp):]
                break
            except (
                requests.RequestException,
                KeyError,
                IndexError,
                TypeError,
                ValueError,
            ) as err:
                if attempt == _MAX_ATTEMPTS:
                    raise
                tqdm.tqdm.write(
                    f"query failed ({err!r}); retrying in {_RETRY_DELAY_S}s"
                )
                sleep(_RETRY_DELAY_S)
        answer = get_ans(solution)
        print(json.dumps({"solution": solution, "answer": answer}), file=f)
        # Flush after every record so finished work survives an interrupted run.
        f.flush()

Теперь реализуем ансамблированный CoT. Будем брать 40 предсказаний с параметрами из Supplementary Code второй статьи.
Из ответов возьмем самый частый, а если таких несколько — то первый из них. Еще сохраним, сколько предсказаний имело такой ответ.
К API в час получается сделать чуть меньше 300 запросов, поэтому будем останавливать процесс и ждать.
На решение 100 заданий уходит больше 16 часов, поэтому возьмем первые 200 задач (и столько же, соответственно, в предыдущем пункте).

In [None]:
from time import sleep


def _majority_vote(pairs):
    """Pick the most frequent answer among ``(solution, answer)`` pairs.

    Ties are resolved in favour of the answer that was sampled first.

    Args:
        pairs: List of ``(solution, answer)`` tuples in sampling order.

    Returns:
        A ``(solution, answer, num)`` tuple: the first sampled solution that
        produced the winning answer, the answer itself, and the number of
        samples that agreed on it. ``("", "", 0)`` when ``pairs`` is empty.
    """
    if not pairs:
        # Every sample failed to yield an answer; previously this case raised
        # IndexError on most_common()[0].
        return "", "", 0
    counts = Counter(ans for _, ans in pairs)
    best_answer, votes = counts.most_common(1)[0]
    best_solution = next(sol for sol, ans in pairs if ans == best_answer)
    return best_solution, best_answer, votes


# Sampling settings for the ensemble (sampling enabled, temperature 0.7). Any
# of the listed stop strings is passed to the API as a stop sequence.
params_ensemble = {
    "max_new_tokens": 128,
    "do_sample": True,
    "temperature": 0.7,
    "use_cache": False,
    "stop": ["\n\n", "Q:", "A:"]
}

# One JSON line per test question: the winning answer, one solution that led
# to it, and how many of the 40 samples agreed on it.
with open("results/results_ensemble.jsonl", "w", encoding="utf-8") as f:
    for q in tqdm.tqdm(test_data[:200]):
        inp = main_prompt + generate_question(q)
        pairs = []
        for _ in range(40):
            # Retry until the request succeeds: the API is rate-limited, so a
            # failure is treated as a signal to wait and try again.
            while True:
                try:
                    response = query({
                        "inputs": inp,
                        "parameters": params_ensemble,
                        # The Inference API documents use_cache under
                        # "options"; without it, identical requests may be
                        # answered from the cache instead of being re-sampled.
                        "options": {"use_cache": False},
                    })
                    # Drop the echoed prompt, keep only the generated text.
                    solution = response[0]["generated_text"][len(inp):]
                    break
                except (
                    requests.RequestException,
                    KeyError,
                    IndexError,
                    TypeError,
                    ValueError,
                ) as err:
                    tqdm.tqdm.write(f"query failed ({err!r}); retrying in 180s")
                    sleep(180)
            answer = get_ans(solution)
            # Samples from which no answer could be extracted do not vote.
            if answer != "":
                pairs.append((solution, answer))
        solution, answer, num = _majority_vote(pairs)
        print(json.dumps({"solution": solution, "answer": answer, "num": num}), file=f)
        # Flush after every record so finished work survives an interrupted run.
        f.flush()