In [1]:
import torch
from toolformer_pytorch import Toolformer, PaLM

from transformers import GPTJForCausalLM
from transformers import GPT2Tokenizer, GPT2LMHeadModel

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
def Calendar():
    """Describe the current date as an English sentence.

    Takes no arguments and reads the clock via ``datetime.now()``.

    Returns:
        str: text of the form ``'Today is Friday, January 5, 2024.'``, built
        from the weekday name, month name, day of month and year.
    """
    from calendar import day_name, month_name
    from datetime import datetime

    today = datetime.now()
    weekday = day_name[today.weekday()]  # weekday() is 0-based, Monday first
    month = month_name[today.month]      # month_name is 1-based (index 0 is empty)
    return f'Today is {weekday}, {month} {today.day}, {today.year}.'

# Prompt that teaches the model to use the Calendar function defined above.
# It is handed to Toolformer as `teach_tool_prompt`; "[input]" is the placeholder
# tag into which each input text is inserted.
# This is a plain string (no f-prefix): the template has no interpolated fields,
# and a stray "{...}" added later must not be evaluated as Python.

prompt = """
Your task is to add calls to a Calendar API to a piece of text.
The API calls should help you get information required to complete the text.
You can call the API by writing "[Calendar()]"
Here are some examples of API calls:
Input: Today is the first Friday of the year.
Output: Today is the first [Calendar()] Friday of the year.
Input: The president of the United States is Joe Biden.
Output: The president of the United States is [Calendar()] Joe Biden.
Input: [input]
Output: 
"""

# Sample texts the model is asked to annotate with Calendar API calls.
data = [
    "The store is never open on the weekend, so today it is closed.",
    "The number of days from now until Christmas is 30",
    "The current day of the week is Wednesday.",
]

# model - any nn.Module that returns logits in the shape (batch, seq, num_tokens) is fine.
# This notebook uses the pretrained Hugging Face GPT-2 checkpoint; a small PaLM
# configuration is kept commented out below as an alternative.
# NOTE(review): GPT2LMHeadModel returns an output object (logits under `.logits`) by
# default rather than a bare logits tensor - confirm the installed toolformer_pytorch
# handles that, or wrap the model so it returns `.logits`.
# NOTE(review): the base "gpt2" checkpoint may be too small to follow the prompt
# reliably; if no API calls are sampled, try a stronger model or a better prompt.

# Use the GPU when one is available, otherwise fall back to the CPU instead of
# crashing on `.cuda()`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)

# model = PaLM(
#     dim = 512,
#     depth = 2,
#     heads = 8,
#     dim_head = 64
# ).cuda()

# toolformer - ties together the model, the teaching prompt, the tool callable
# and the tokenizer's encode/decode functions

toolformer = Toolformer(
    model = model,
    model_seq_len = 256,
    teach_tool_prompt = prompt,
    tool_id = 'Calendar',
    tool = Calendar,
    finetune = True,
    tokenizer_encode=tokenizer.encode,
    tokenizer_decode=tokenizer.decode,
)

In [6]:
# Sanity check: encoding the first sample and decoding it again should give back
# the original text unchanged.
tokenizer.decode(tokenizer.encode(data[0]))

'The store is never open on the weekend, so today it is closed.'

In [7]:
# logits shape: torch.Size([3, 257, 49408])

In [8]:
# Calling the Toolformer instance on `data` will
# (1) prompt the model with your inputs (data), inserted into the [input] tag
# (2) take the sampled outputs and keep only the ones that made proper API calls
# (3) execute the API calls with the `tool` given
# (4) filter with the specialized filter function (which can be used independently as shown in the next section)
# (5) fine-tune on the filtered results
#
# NOTE: the library asserts that at least one sampled output contains an API call;
# if none does, this call raises AssertionError ("your model failed to follow
# instructions and make API calls ...") and the lines after it never run.

filtered_stats = toolformer(data)

# then, once you see the 'finetune complete' message, sample from the model with
# API calls enabled

response = toolformer.sample_model_with_api_calls("How many days until the next new years?")

# hopefully you see it invoke the calendar and utilize the response of the api call...

100%|██████████| 119/119 [00:08<00:00, 13.45it/s]


> /home/genesis/fun/toolformer/toolformer-pytorch/toolformer_pytorch/toolformer_pytorch.py(888)forward()
    887         import ipdb; ipdb.set_trace()
--> 888         assert len(filtered_data_with_api_calls) > 0, 'your model failed to follow instructions and make API calls. please try a better model or do some better prompt engineering'
    889 

ipdb> c


AssertionError: your model failed to follow instructions and make API calls. please try a better model or do some better prompt engineering