# Toolformer: Language Models Can Teach Themselves to Use Tools
- https://arxiv.org/pdf/2302.04761.pdf

Model trained to decide which APIs to call
- when to call them
- what arguments to pass
- how to best incorporate the results into future token prediction

Self-supervised
- requires nothing more than a handful of demonstrations for each API

Architecture
- based on a pretrained GPT-J model with 6.7 billion parameters
- outperforms the much larger GPT-3 model on a range of downstream tasks
- https://github.com/kingoflolz/mesh-transformer-jax
    
    

## Approach

Each API call is represented as a tuple
- ```c = (a_c, i_c)```
- a_c -> name of the API
- i_c -> corresponding input
- the result of the call is denoted ```r```

Special tokens
- ```<API>``` and ```</API>``` mark the start and end of an API call
- ```-->``` separates the call from its result
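A minimal sketch of how a call is turned into text. The paper linearizes a call as ```e(c) = <API> a_c(i_c) </API>``` and a call with its result as ```e(c, r) = <API> a_c(i_c) --> r </API>```. The `APICall` class, the `linearize` helper and the single-space separators are illustrative assumptions, not the paper's code.

```python
from dataclasses import dataclass
from typing import Optional

# Special tokens as written in these notes (the paper renders the arrow as →).
API_START = "<API>"
API_END = "</API>"
ARROW = "-->"


@dataclass
class APICall:
    name: str        # a_c: name of the API
    input_text: str  # i_c: input passed to the API


def linearize(call: APICall, result: Optional[str] = None) -> str:
    """Return e(c) if result is None, otherwise e(c, r).

    An empty-string result gives e(c, ε): the arrow with nothing after it.
    """
    parts = [API_START, f"{call.name}({call.input_text})"]
    if result is not None:
        parts.append(ARROW)
        if result:
            parts.append(result)
    parts.append(API_END)
    return " ".join(parts)
```

For example, `linearize(APICall("Calculator", "400 / 1400"), "0.29")` yields `<API> Calculator(400 / 1400) --> 0.29 </API>`. The empty-result form is useful later, when filtering compares a call with and without its response.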

Dataset
- Given a dataset of plain texts
- Converted into a dataset augmented with API calls
- Done in three steps, followed by merging and finetuning
    - 1. Exploit the in-context learning ability of the model ```M``` to sample a large number of potential API calls
    - 2. Execute these API calls to obtain their responses
    - 3. Filter: keep only the calls whose responses are helpful for predicting future tokens
    - Then merge the API calls for the different tools, resulting in the augmented dataset ```C*```
    - Finally, finetune ```M``` itself on ```C*```
    
Sampling API Calls
- Write a prompt ```P(x)``` for each API
    - it encourages the LM to annotate an example ```x = [x_1, ..., x_n]``` with API calls
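A minimal sketch of the position-selection step. As described in the paper, each position ```i``` is scored by ```p_i = p_M(<API> | P(x), x_{1:i-1})```; positions whose score exceeds a sampling threshold ```τ_s``` are kept (at most the top ```k```), and up to ```m``` candidate calls are then sampled at each kept position. The sketch assumes the probabilities have already been computed by the model and covers only the selection; `select_positions` is a hypothetical helper.

```python
from typing import List, Sequence


def select_positions(p_api: Sequence[float], tau_s: float, k: int) -> List[int]:
    """Pick candidate positions for inserting an API call.

    p_api[i] is assumed to be p_M(<API> | P(x), x_{1:i-1}), the model's
    probability of opening an API call before token x_i (0-indexed here).
    Keeps positions whose probability exceeds tau_s, at most the top k.
    """
    above = [i for i, p in enumerate(p_api) if p > tau_s]
    # Highest-probability positions first; ties go to the earlier position.
    above.sort(key=lambda i: (-p_api[i], i))
    return sorted(above[:k])  # report the kept positions in text order
```

Returning the positions in text order is a convenience for the later insertion step; the paper's procedure only requires the set of kept positions.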
  
Executing API Calls
- How the execution is done depends entirely on the API itself
- Can involve
    - Calling another neural network
    - Executing a Python script
    - Using a retrieval system to perform search over a large corpus
- Each API call ```c_i``` yields a response, which needs to be a single text sequence ```r_i```
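A minimal sketch of executing a call, assuming a hypothetical registry that maps each API name to a text-in, text-out function. Only a calculator is implemented; it uses Python's `ast` module to evaluate `+ - * /` expressions without `eval`. Rounding to two decimals and returning an empty response for unknown APIs are assumptions of this sketch.

```python
import ast
import operator
from typing import Callable, Dict

# Supported binary operators for the toy calculator.
_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
}


def _eval(node: ast.AST) -> float:
    """Recursively evaluate a parsed arithmetic expression."""
    if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
        return node.value
    if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
        return _OPS[type(node.op)](_eval(node.left), _eval(node.right))
    if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
        return -_eval(node.operand)
    raise ValueError("unsupported expression")


def calculator(expr: str) -> str:
    """Evaluate the expression and return it rounded to two decimals, as text."""
    value = _eval(ast.parse(expr, mode="eval").body)
    return str(round(value, 2))


# Hypothetical tool registry: every tool maps a text input to a text response.
TOOLS: Dict[str, Callable[[str], str]] = {"Calculator": calculator}


def execute(name: str, arg: str) -> str:
    """Run API `name` on `arg`; unknown APIs yield an empty response."""
    tool = TOOLS.get(name)
    return tool(arg) if tool is not None else ""
```

Keeping every tool behind the same `str -> str` signature mirrors the requirement that both the input and the response are single text sequences.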
    
Filtering API Calls
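- Weighted cross-entropy loss of ```M``` over the tokens ```x_i, ..., x_n``` when the model is prefixed with ```z```: ```L_i(z) = -sum_{j=i..n} w_{j-i} * log p_M(x_j | z, x_{1:j-1})```
- Compare ```L_i^+ = L_i(e(c_i, r_i))``` (call plus result) against ```L_i^- = min(L_i(ε), L_i(e(c_i, ε)))``` (no call at all, or the call without its result)
- Keep the API call only if ```L_i^- - L_i^+ >= τ_f```, i.e. the result makes the following tokens easier to predict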
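A minimal sketch of the filtering criterion, assuming the per-token log-probabilities of the future tokens have already been obtained from the model for each prefix. In the paper, the loss is ```L_i(z) = -sum_t w_t * log p_M(x_{i+t} | z, x_{1:i+t-1})```, a call is kept when ```min(L_i(ε), L_i(e(c_i, ε))) - L_i(e(c_i, r_i)) >= τ_f```, and the weights follow ```w̃_t = max(0, 1 - 0.2 t)``` normalized to sum to one. The function names are hypothetical, and any ```τ_f``` value used with them is illustrative.

```python
from typing import List, Sequence


def weights(n: int) -> List[float]:
    """w_t proportional to max(0, 1 - 0.2 t), normalized to sum to 1 (n >= 1)."""
    raw = [max(0.0, 1.0 - 0.2 * t) for t in range(n)]
    total = sum(raw)
    return [w / total for w in raw]


def weighted_loss(logprobs: Sequence[float]) -> float:
    """L_i(z): weighted negative log-likelihood of the future tokens.

    logprobs[t] is the model's log-probability of token x_{i+t} given the
    prefix z and the preceding text (computing it needs the LM, not shown).
    """
    w = weights(len(logprobs))
    return -sum(wt * lp for wt, lp in zip(w, logprobs))


def keep_call(lp_with_result: Sequence[float],
              lp_no_call: Sequence[float],
              lp_call_no_result: Sequence[float],
              tau_f: float) -> bool:
    """Keep the call if L_i^- - L_i^+ >= tau_f."""
    loss_plus = weighted_loss(lp_with_result)
    loss_minus = min(weighted_loss(lp_no_call), weighted_loss(lp_call_no_result))
    return loss_minus - loss_plus >= tau_f
```

The decaying weights make tokens right after the call count most, so a response only survives if it helps in its immediate context. Taking the minimum for ```L_i^-``` means the call must beat both the plain text and the call without its result.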

Model Finetuning
- Construct the new sequence ```x* = x_{1:i-1}, e(c_i, r_i), x_{i:n}```
- Doing this for all ```x``` in ```C``` results in the new dataset ```C*``` augmented with API calls
- Finetune ```M``` on ```C*``` using a standard language modeling objective
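A minimal sketch of building ```x*``` from a tokenized text and the calls that survived filtering. It assumes whitespace-level tokens, 0-indexed positions, distinct positions per text, and calls that are already linearized as ```e(c_i, r_i)``` strings; `augment` is a hypothetical helper.

```python
from typing import List, Tuple


def augment(tokens: List[str], calls: List[Tuple[int, str]]) -> List[str]:
    """Insert each linearized call e(c_i, r_i) just before token x_i.

    `calls` holds (i, linearized_call) pairs. Inserting from the rightmost
    position first keeps the indices of earlier positions valid.
    """
    out = list(tokens)
    for i, call_text in sorted(calls, key=lambda c: c[0], reverse=True):
        out.insert(i, call_text)
    return out
```

For example, inserting `<API> Calculator(400 / 1400) --> 0.29 </API>` at position 5 of `Out of 1400 participants, 400 (or 29%) passed` places the call right before `(or 29%)`, where its result helps predict the percentage.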

Inference
- Regular decoding until M produces ```-->```
    - indicates that it expects the response for an API call
- Interrupt the decoding process
    - call the appropriate API to get a response
    - continue decoding after inserting both the response and the ```</API>``` token
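A minimal sketch of this interrupted decoding loop. `next_token` stands in for the finetuned model M and `execute` for the tool dispatcher; both are hypothetical callables. Whitespace-level tokens, the `Name(arg)` call format and the `<eos>` stop token are assumptions of the sketch.

```python
from typing import Callable, List

API_START, API_END, ARROW = "<API>", "</API>", "-->"


def decode(next_token: Callable[[List[str]], str],
           execute: Callable[[str, str], str],
           prompt: List[str],
           max_tokens: int = 50,
           eos: str = "<eos>") -> List[str]:
    """Greedy-style loop that pauses at --> to run the pending API call."""
    out = list(prompt)
    for _ in range(max_tokens):
        tok = next_token(out)
        if tok == eos:
            break
        out.append(tok)
        if tok == ARROW:
            # Recover the pending call: tokens between the last <API> and -->.
            # (Raises ValueError if the model emits --> without an open <API>.)
            start = len(out) - 1 - out[::-1].index(API_START)
            call = " ".join(out[start + 1:-1])  # e.g. "Calculator(400 / 1400)"
            name, _, rest = call.partition("(")
            arg = rest[:-1] if rest.endswith(")") else rest
            # Insert the response and close the call, then resume decoding.
            out.append(execute(name.strip(), arg))
            out.append(API_END)
    return out
```

Because the response and ```</API>``` are appended to the context before the next `next_token` call, the model conditions on the tool output when it resumes, exactly as it saw during finetuning on ```C*```.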

# Tools

Tools used:
- Question answering system
- Calculator
- Wikipedia search
- Machine translation system
- Calendar

Constraints:
- Both the inputs and outputs of a tool have to be represented as text sequences
- It must be possible to obtain a few demonstrations of the tool's intended use

# Experiments

GPT-J
- This paper uses it as the language model ```M```
    - https://github.com/kingoflolz/mesh-transformer-jax
- batch size of 128
- learning rate of 1e-5
- linear warmup for the first 10% of training
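A minimal sketch of the learning-rate schedule described above: linear warmup to 1e-5 over the first 10% of training. The default of 2000 total steps comes from the "up to 2k steps" in the Training notes; holding the rate constant after warmup is an assumption, since the notes do not say what follows the warmup.

```python
def lr_at_step(step: int, total_steps: int = 2000, peak_lr: float = 1e-5,
               warmup_frac: float = 0.1) -> float:
    """Linear warmup to peak_lr over the first warmup_frac of training.

    ASSUMPTION: the rate is held constant after warmup.
    """
    warmup_steps = max(1, int(total_steps * warmup_frac))
    if step < warmup_steps:
        # Ramp linearly so the last warmup step reaches peak_lr exactly.
        return peak_lr * (step + 1) / warmup_steps
    return peak_lr
```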
    
Datasets:
- subset of CCNet as the language modeling dataset ```C```
    - https://aclanthology.org/2020.lrec-1.494/
    - To reduce the computational cost of annotating ```C``` with API calls, heuristics are defined for some APIs to select a subset of ```C``` for which API calls are more likely to be helpful than for an average text
    - For example, texts are only considered for the calculator tool if they contain at least three numbers
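A minimal sketch of the calculator heuristic. What counts as a "number" is an assumption here: runs of digits, optionally grouped with commas or decimal points, so `1,000` and `3.5` each count once. The regex and `calculator_candidate` are hypothetical, not the paper's code.

```python
import re

# ASSUMPTION: a number is a run of digits, optionally with ","/"." groups.
_NUMBER = re.compile(r"\d+(?:[.,]\d+)*")


def calculator_candidate(text: str, min_numbers: int = 3) -> bool:
    """True if the text contains at least `min_numbers` numbers."""
    return len(_NUMBER.findall(text)) >= min_numbers
```

Filtering this cheaply before sampling means the expensive annotation step only runs on texts where a calculator call has a realistic chance of helping.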

Benchmarking
- SQuAD, Google-RE, and T-REx subsets of the LAMA benchmark
    - https://aclanthology.org/D19-1250/
    - Examples where the mask token is not the final token are filtered out, so that the remaining examples can be processed left to right
- Math
    - ASDiv, SVAMP, MAWPS
- Question Answering
    - Web Questions, Natural Questions, TriviaQA
    - Toolformer mostly relies on the **Wikipedia search API (99.3%)** to find relevant information
- Figure 4 shows that the ability to leverage the provided tools only emerges at around 775M parameters: smaller models achieve similar performance both with and without tools
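A minimal sketch of the LAMA filtering step noted above, keeping only cloze sentences whose mask is the final token so they can be completed left to right. The `[MASK]` placeholder, the single-mask requirement and ignoring trailing sentence punctuation are assumptions of this sketch.

```python
from typing import List

MASK = "[MASK]"  # hypothetical mask placeholder


def is_final_mask(sentence: str) -> bool:
    """True if the only mask is the last token, ignoring trailing . ! ?"""
    stripped = sentence.rstrip().rstrip(".!?").rstrip()
    return sentence.count(MASK) == 1 and stripped.endswith(MASK)


def left_to_right_subset(sentences: List[str]) -> List[str]:
    """Keep the examples a left-to-right LM can answer by continuing the text."""
    return [s for s in sentences if is_final_mask(s)]
```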

# Training

- 25k examples per API
- max sequence length 1024
- batch size of 128
- trained using DeepSpeed's ZeRO-3 (Rasley et al., 2020)
- 8 NVIDIA A100 40GB GPUs with BF16
- training for up to 2k steps
- every 500 steps, evaluate perplexity on a small CCNet dev set of 1k examples
    - pick the checkpoint that performs best
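A minimal sketch of the checkpoint selection. It assumes token-level perplexity (exponential of the mean per-token negative log-likelihood in nats) and a dictionary from training step to dev perplexity; both helpers are hypothetical.

```python
import math
from typing import Dict, Sequence


def perplexity(token_nlls: Sequence[float]) -> float:
    """exp of the mean per-token negative log-likelihood (natural log)."""
    return math.exp(sum(token_nlls) / len(token_nlls))


def best_checkpoint(dev_ppl: Dict[int, float]) -> int:
    """Return the training step whose checkpoint has the lowest dev perplexity."""
    return min(dev_ppl, key=dev_ppl.get)
```

Usage: with evaluations every 500 steps, `best_checkpoint({500: 12.1, 1000: 11.4, 1500: 11.7})` (illustrative values) returns `1000`.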

# Conclusion

- a language model that learns in a self-supervised way how to use different tools such as search engines, calculators, and translation systems via simple API calls.
- This is done by finetuning on a large number of sampled API calls that are filtered based on whether they reduce perplexity on future tokens.
- Considerably improves the zero-shot performance of a 6.7B parameter GPT-J model
    - enabling it to outperform the much larger GPT-3 on a range of downstream tasks