In [2]:
%load_ext autoreload
%autoreload 2

from transformers import AutoTokenizer, AutoModelForCausalLM
import re
import torch
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")

DEVICE = 'cuda' if torch.cuda.is_available() else "cpu"
model = model.to(DEVICE)

  from .autonotebook import tqdm as notebook_tqdm
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
config.json: 100%|██████████| 930/930 [00:00<00:00, 3.83MB/s]
pytorch_model.bin: 100%|██████████| 24.2G/24.2G [20:37<00:00, 19.6MB/s]  


In [3]:
prompt = f"""
Your task is to add calls to a Calendar API to a piece of text.
The API calls should help you get information required to complete the text.
You can call the API by writing "[Calendar()]"
Here are some examples of API calls:
Input: Today is the first Friday of the year.
Output: Today is the first [Calendar()] Friday of the year.
Input: The president of the United States is Joe Biden.
Output: The president of the United States is [Calendar()] Joe Biden.
Input: [input]
Output:
"""

data = [
    "The store is never open on the weekend, so today it is closed.",
    "The number of days from now until Christmas is 30",
    "The current day of the week is Wednesday."
]

In [4]:
idx = 0
DEFAULT_PROMPT_INPUT_TAG = '[input]'
prompt_input_tag_regex = re.escape(DEFAULT_PROMPT_INPUT_TAG)
data_string = data[idx]
data_with_prompt = re.sub(prompt_input_tag_regex, data_string, prompt)
token_ids = tokenizer(data_with_prompt, return_tensors = 'pt')

batch_size, len_prev = token_ids['input_ids'].size()

In [50]:
import torch
with torch.no_grad():
  output = model.generate(**token_ids, max_new_tokens=20)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [59]:
tokenizer.decode(output[0][len_prev:])

'The store is never open on the [Calendar()] weekend, so today it is closed.'

In [64]:
from tools import Calendar
print(f"API call with argument: no argument,\nresult: {Calendar()}")

API call with argument: no argument,
result: Today is Thursday, November 30, 2023.


In [65]:
print(f"Text with the API call\n'The store is never open on the [Calendar() -> {Calendar()}] weekend, so today it is closed.'")

Text with the API call
'The store is never open on the [Calendar() -> Today is Thursday, November 30, 2023.] weekend, so today it is closed.'


In [9]:
# 문맥상 api 결과가 주말이여야 더 그럴듯하므로 변경
including_API_with_result ='The store is never open on the [Calendar() -> Today is Saturday, November 25, 2023.]' 
including_API_without_result = 'The store is never open on the [Calendar()]'
plain_text = 'The store is never open on the'

next_words = 'weekend, so today it is closed.'

In [10]:
inputs = tokenizer(including_API_with_result + next_words, return_tensors = 'pt')
mask_tokens = tokenizer(including_API_with_result, return_tensors = 'pt')
batch_size, len_mask_tokens = mask_tokens['input_ids'].size()

labels = tokenizer(including_API_with_result + next_words, return_tensors = "pt")['input_ids']
labels[:, :len_mask_tokens] = -100

inputs = inputs.to(DEVICE)
labels = labels.to(DEVICE)
with torch.no_grad():
  api_with_result_output = model(**inputs, labels=labels)

In [13]:
inputs = tokenizer(including_API_without_result + next_words, return_tensors = 'pt')
mask_tokens = tokenizer(including_API_without_result, return_tensors = 'pt')
batch_size, len_mask_tokens = mask_tokens['input_ids'].size()

labels = tokenizer(including_API_without_result + next_words, return_tensors = "pt")['input_ids']
labels[:, :len_mask_tokens] = -100

inputs = inputs.to(DEVICE)
labels = labels.to(DEVICE)
with torch.no_grad():
  api_without_result_output = model(**inputs, labels=labels)

In [15]:
inputs = tokenizer(plain_text + next_words, return_tensors = 'pt')
mask_tokens = tokenizer(plain_text, return_tensors = 'pt')
batch_size, len_mask_tokens = mask_tokens['input_ids'].size()

labels = tokenizer(plain_text + next_words, return_tensors = "pt")['input_ids']
labels[:, :len_mask_tokens] = -100

inputs = inputs.to(DEVICE)
labels = labels.to(DEVICE)
with torch.no_grad():
  plain_output = model(**inputs, labels=labels)

In [16]:
print(f"api + result loss: {api_with_result_output.loss}")
print(f"api without result loss: {api_without_result_output.loss}")
print(f"plain text loss: {plain_output.loss}")

api + result loss: 2.8421099185943604
api without result loss: 2.9888060092926025
plain text loss: 3.8347482681274414


In [None]:
filtering_threshold = 1.0
if api_with_result_output.loss > min(api_without_result_output, plain_output) + filtering_threshold:
  finetune_dataset = including_API_without_result + next_words
else:
  finetune_dataset = plain_text + next_words