# toolformer_pytorch.py
import re
import math
from functools import partial, wraps
from collections import namedtuple
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from einops import rearrange, reduce
from toolformer_pytorch.palm import PaLM
from toolformer_pytorch.optimizer import get_optimizer
from toolformer_pytorch.prompts import DEFAULT_PROMPT_INPUT_TAG
from beartype import beartype
from beartype.typing import Callable, Optional, Union, List, Tuple
from tqdm import tqdm
from x_clip.tokenizer import tokenizer
pad_sequence = partial(pad_sequence, batch_first = True)
# helpers
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def identity(t):
return t
def always(val):
def inner(*args, **kwargs):
return val
return inner
def try_except(fn, callback = identity):
@wraps(fn)
def inner(*args):
try:
return fn(*args)
except Exception as e:
return callback(e)
return inner
# tensor helpers
def log(t, eps = 1e-20):
return t.clamp(min = eps).log()
def gumbel_noise(t):
noise = torch.zeros_like(t).uniform_(0, 1)
return -log(-log(noise))
def gumbel_sample(t, temperature = 1., dim = -1, eps = 1e-10):
if temperature == 0:
return t.argmax(dim = dim)
return ((t / max(temperature, eps)) + gumbel_noise(t)).argmax(dim = dim)
def top_k(logits, thres = 0.9):
k = math.ceil((1 - thres) * logits.shape[-1])
val, indices = torch.topk(logits, k)
probs = torch.full_like(logits, -torch.finfo(logits.dtype).max)
probs.scatter_(1, indices, val)
return probs
def all_contains_id(t: torch.Tensor, token_id: int):
mask = t == token_id
return mask.any(dim = -1).all()
def find_indices_of(t: torch.Tensor, token_id: int, occurrence = 1):
assert occurrence > 0
mask = (t == token_id)
has_occurred = mask.cumsum(dim = -1)
has_occurred = F.pad(has_occurred, (1, 0), value = 0.)
return (has_occurred < occurrence).sum(dim = -1).long()
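# illustrative example of the helper above (values worked out by hand):
# find_indices_of(torch.tensor([[5, 7, 9, 7]]), 7, occurrence = 2) -> tensor([4])
# i.e. for each row, the index just past the nth occurrence of token_id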
# invoking api call functions
def is_valid_string(s):
return exists(re.fullmatch(r"'[^']*'|\"[^\"]*\"", s))
def is_valid_integer(s):
return exists(re.fullmatch(r"[+-]?\d+", s))
def is_valid_float(s):
return exists(re.fullmatch(r"[+-]?\d+(\.\d+)?", s))
def parse_param(s: str) -> Optional[Union[int, float, str]]:
if is_valid_string(s):
return str(s)
elif is_valid_integer(s):
return int(s)
elif is_valid_float(s):
return float(s)
return None
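# illustrative examples of the parsers above:
# parse_param("'foo'") -> "'foo'" (quotes are kept)
# parse_param('3')     -> 3
# parse_param('3.14')  -> 3.14
# parse_param('foo')   -> None (not a quoted string, integer, or float)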
@beartype
def replace_fn(
registry: dict[str, Callable],
matches,
delimiter = '→'
):
orig_text = matches.group(0)
text_without_end_api_token = matches.group(1)
end_api_token = matches.group(4)
function_name = matches.group(2)
# unable to find function in registry
if function_name not in registry:
return orig_text
fn = registry[function_name]
params = matches.group(3).split(',')
params = list(map(lambda s: s.strip(), params))
params = list(filter(len, params))
params = list(map(parse_param, params))
    # if any of the parameters are not parseable, return the original text
if any([(not exists(p)) for p in params]):
return orig_text
# just return original text if there is some error with the function
out = try_except(fn, always(None))(*params)
    # the api calling function can also abort the replacement by returning None
if not exists(out):
return orig_text
# return original text with the output delimiter and the stringified output
return f'{text_without_end_api_token} {delimiter} {str(out)} {end_api_token}'
# main function, which takes a registry of functions and the text in question, makes all the appropriate api calls, and appends the outputs
def create_function_regex(
api_start = ' [',
api_stop = ']'
):
api_start_regex, api_stop_regex = map(re.escape, (api_start, api_stop))
return rf'({api_start_regex}(\w+)\(([^)]*)\))({api_stop_regex})'
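# the pattern above captures four groups - e.g. for " [add(1, 2)]" (a hypothetical call) with the default start / stop strings:
# group(1) = ' [add(1, 2)' (the call without the stop token)
# group(2) = 'add'         (the function name)
# group(3) = '1, 2'        (the raw argument string)
# group(4) = ']'           (the stop token)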
def num_matches(substr: str, text: str):
return len(re.findall(re.escape(substr), text))
def has_api_calls(
text,
api_start = ' [',
api_stop = ']'
):
regex = create_function_regex(api_start, api_stop)
matches = re.findall(regex, text)
return len(matches) > 0
def replace_all_but_first(
text: str,
api_start = ' [',
api_stop = ']'
) -> str:
regex = create_function_regex(api_start, api_stop)
count = 0
def replace_(matches):
orig_text = matches.group(0)
nonlocal count
count += 1
if count > 1:
return ''
return orig_text
return re.sub(regex, replace_, text)
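# e.g. replace_all_but_first('x [f(1)] y [f(2)] z') -> 'x [f(1)] y z' (the `f` calls are hypothetical)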
def invoke_tools(
registry: dict[str, Callable],
text: str,
delimiter: str = '→',
api_start = ' [',
    api_stop = ']'
) -> str:
regex = create_function_regex(api_start, api_stop)
replace_ = partial(replace_fn, registry, delimiter = delimiter)
return re.sub(regex, replace_, text)
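# a minimal sketch of invoking tools on raw text (the `add` tool below is hypothetical):
# invoke_tools({'add': lambda x, y: x + y}, 'the answer is [add(1, 2)]')
# -> 'the answer is [add(1, 2) → 3 ]'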
def invoke_tools_on_batch_sequences(
registry: dict[str, Callable],
token_ids: torch.Tensor,
*,
encode: Callable,
decode: Callable,
delimiter: str = '→',
api_start = ' [',
api_stop = ']'
) -> torch.Tensor:
    regex = create_function_regex(api_start, api_stop)
all_texts = [decode(one_seq_token_ids) for one_seq_token_ids in token_ids]
invoke_tools_ = partial(invoke_tools, api_start = api_start, api_stop = api_stop)
all_texts_with_api_calls = [invoke_tools_(registry, text, delimiter) for text in all_texts]
return encode(all_texts_with_api_calls)
# sampling api related functions
# they do greedy sampling, but encourage sampling api calls by auto-selecting <api> when that token is in the top k = 10
@beartype
@torch.no_grad()
def sample(
model: nn.Module,
*,
seq_len,
prime: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None,
batch_size = 1,
eos_token_id = None,
sos_token_id = 1,
temperature = 0.,
pad_id = 0,
call_api_only_once = False,
api_start_token_id = None,
auto_select_api_start_token_when_topk = False,
select_api_start_id_top_k = 10,
):
device = next(model.parameters()).device
max_seq_len = seq_len + 1
# validate
if call_api_only_once:
assert exists(api_start_token_id)
# prime
if exists(prime):
batch_size, prime_length = prime.shape
else:
prime_length = 1
prime = torch.full((batch_size, 1), sos_token_id, device = device, dtype = torch.long)
prime = prime.to(device)
# sampling positions - different sequences have different cursors
if exists(positions):
positions = positions.clone()
else:
positions = torch.zeros((batch_size,), device = device, dtype = torch.long)
    assert (positions <= (prime_length + 1)).all() and (positions <= max_seq_len).all(), 'all positions must be less than the prime length + 1, as well as the total sequence length + 1 (the extra slot acts as a noop in case one sequence finishes sampling before the others)'
# eval model
model.eval()
# lengthen the prime to the entire sequence length
remain_iterations = seq_len - prime_length
output = F.pad(prime, (0, max_seq_len - prime_length), value = 0.)
batch_indices = torch.arange(batch_size, device = device)
batch_indices = rearrange(batch_indices, 'b -> b 1')
position_indices = rearrange(positions, 'b -> b 1')
    # determine the <api> token mask, for making sure the api is called only once - the logit is masked out so <api> cannot be selected again for rows that already contain an <api> token
api_token_mask = None # lazily created, since do not know logit dimensions
def create_api_token_mask(num_tokens, api_start_token_id):
mask = torch.zeros((1, 1, num_tokens), dtype = torch.bool)
assert api_start_token_id < num_tokens
mask[..., api_start_token_id] = True
return mask
# start iterating
for iteration in tqdm(range(remain_iterations)):
logits = model(output)
last_logits = logits[batch_indices, position_indices]
# this will ensure that each batch token sequence will have at most one <api> token
if call_api_only_once:
if not exists(api_token_mask):
num_tokens = last_logits.shape[-1]
api_token_mask = create_api_token_mask(num_tokens, api_start_token_id)
api_token_mask = api_token_mask.to(device)
api_called = (output == api_start_token_id).any(dim = -1)
logit_mask = api_token_mask & rearrange(api_called, 'b -> b 1 1')
last_logits = last_logits.masked_fill(logit_mask, -torch.finfo(last_logits.dtype).max)
# greedy sample (but could be made non-greedy)
sampled = gumbel_sample(last_logits, temperature = temperature)
        # for those sequences without an api call yet, if the api_start_token_id is within the top k (set to 10 in the paper) of logits, just auto-select it
        # this seems to be an important hack in the paper
        # it also seems like this paper will take a lot more follow-up research to be viable
if auto_select_api_start_token_when_topk:
top_token_ids = last_logits.topk(select_api_start_id_top_k, dim = -1).indices
has_api_token_in_topk = (top_token_ids == api_start_token_id).any(dim = -1)
should_auto_select_api_token = has_api_token_in_topk & ~rearrange(api_called, 'b -> b 1')
sampled = sampled.masked_fill(should_auto_select_api_token, api_start_token_id)
        # set the sampled tokens at the right cursor positions
output[batch_indices, position_indices] = sampled
# increment positions
position_indices += 1
position_indices.clamp_(max = seq_len) # noop if one sequence is further along and near the end
        # if using <eos> tokens, terminate once every sequence contains one; anything after <eos> will be padded out
if exists(eos_token_id):
eos_mask = (output == eos_token_id)
all_rows_have_eos = eos_mask.any(dim = -1).all()
if all_rows_have_eos:
keep_mask = eos_mask.cumsum(dim = -1) == 0
                keep_mask = F.pad(keep_mask, (1, -1), value = True) # shift right by one so the <eos> token itself is kept and the mask width matches the output
output = output.masked_fill(~keep_mask, pad_id)
break
# remove the last token in output (use as noop placeholder)
output = output[:, :-1]
return output
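# illustrative usage (the `palm` model is an assumption - any nn.Module returning (batch, seq, num_tokens) logits works):
# sampled = sample(model = palm, seq_len = 128, batch_size = 2, temperature = 0.)
# sampled.shape -> torch.Size([2, 128]), decoded greedily starting from the <sos> token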
@beartype
@torch.no_grad()
def sample_with_api_call(
model: nn.Module,
*,
seq_len,
call_apis: Callable,
prime: torch.Tensor,
api_end_token_id: int,
occurrence = 1,
**kwargs
):
sampled = sample(
model = model,
prime = prime,
seq_len = seq_len,
**kwargs
)
sampled = call_apis(sampled)
sampled_seq_len = sampled.shape[-1]
null_positions = sampled_seq_len # handle sequences that do not have api calls
pos_starting_at_end_of_api = find_indices_of(
sampled,
api_end_token_id,
occurrence = occurrence
)
resample_after_api_calls = sample(
model = model,
prime = sampled,
seq_len = sampled_seq_len,
positions = (pos_starting_at_end_of_api + 1).clamp(max = null_positions), # start at the position right after the </api>
**kwargs
)
return resample_after_api_calls
# the main contribution of the paper is simply the filtering equations presented in section 2
def default_weight_fn(t):
    # following the formula in section 4.1 - however, not sure what w_s is in the denominator
    # if t stands for each timestep, this would also mean the weight diminishes to 0 within 5 tokens?
return (1. - t * 0.2).clamp(min = 0.)
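# e.g. default_weight_fn(torch.arange(6)) -> tensor([1.0, 0.8, 0.6, 0.4, 0.2, 0.0])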
def get_pred_prob(token_ids, logits):
logits = logits[:, :-1] # logits of each token... (omit last logit)
token_ids = token_ids[:, 1:] # predicts the next token id (omit first token id)
token_ids = rearrange(token_ids, 'b n -> b n 1')
probs = logits.softmax(dim = -1)
correct_token_id_pred_prob = probs.gather(-1, token_ids)
return rearrange(correct_token_id_pred_prob, 'b n 1 -> b n')
def get_arange_start_at_token_id(
token_ids: torch.Tensor,
token_id: int,
pad_id = -1
):
is_token_id_mask = token_ids == token_id
arange = (is_token_id_mask.cumsum(dim = -1) > 0).cumsum(dim = -1)
before_token_mask = arange == 0
arange = arange - 1
arange = arange.masked_fill(before_token_mask, pad_id)
return arange
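# e.g. get_arange_start_at_token_id(torch.tensor([[3, 9, 5, 5]]), 9) -> tensor([[-1, 0, 1, 2]])
# positions before the token are filled with pad_id; counting starts at the token itself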
def weight_and_mask(
token_ids: torch.Tensor,
token_id: int,
pad_id = -1,
weighting_fn: Callable = default_weight_fn
):
t = get_arange_start_at_token_id(token_ids, token_id, pad_id)
weights = weighting_fn(t)
return weights.masked_fill(t == pad_id, 0.)
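# combining the two helpers above:
# weight_and_mask(torch.tensor([[3, 9, 5, 5]]), 9) -> tensor([[0.0, 1.0, 0.8, 0.6]])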
FilteredResults = namedtuple('FilteredResults', [
'num_passed',
'num_failed',
'selected_indices',
'selected_mask',
'filtered_tokens',
'filtered_tokens_without_api_response',
'filtered_tokens_with_api_response'
])
@beartype
def filter_tokens_with_api_response(
model: nn.Module, # the language model should accept the token ids below and return the logits in shape (batch, seq, num tokens)
*,
tokens: torch.Tensor, # token ids (batch, seq) of the original passage, without api calls
tokens_without_api_response: torch.Tensor, # token ids (batch, seq) of the passage, but with the api call (but without a response filled in) - <api>tool1(x, y)</api>
tokens_with_api_response: torch.Tensor, # token ids (batch, seq) of the passage with api call and the response - <api>tool1(x, y) → {response}</api>
api_start_token_id: int, # token id of the <api> tag
api_end_token_id: int, # token id of the </api> tag
filter_threshold: float = 1., # the threshold at which to accept the sampled api call (tokens_with_api_response) for fine-tuning
weighting_fn: Callable = default_weight_fn # weighting function
) -> FilteredResults:
# validations
assert all([*map(lambda t: t.dtype == torch.long, (tokens, tokens_with_api_response, tokens_without_api_response))])
assert all_contains_id(tokens_without_api_response, api_start_token_id)
assert all_contains_id(tokens_without_api_response, api_end_token_id)
assert all_contains_id(tokens_with_api_response, api_start_token_id)
assert all_contains_id(tokens_with_api_response, api_end_token_id)
# auto set devices
device = next(model.parameters()).device
tokens, tokens_without_api_response, tokens_with_api_response = map(lambda t: t.to(device), (tokens, tokens_without_api_response, tokens_with_api_response))
# get all the logits
with torch.no_grad():
model.eval()
logits, logits_without_api_response, logits_with_api_response = map(model, (tokens, tokens_without_api_response, tokens_with_api_response))
# derive all predicted prob of the actual next token id in sequence
probs = get_pred_prob(tokens, logits)
probs_without_api_response = get_pred_prob(tokens_without_api_response, logits_without_api_response)
probs_with_api_response = get_pred_prob(tokens_with_api_response, logits_with_api_response)
weight_and_mask_fn = partial(weight_and_mask, weighting_fn = weighting_fn)
# derive the weighting
weight_without_api_response = weight_and_mask_fn(tokens_without_api_response[:, :-1], api_end_token_id)
weight_with_api_response = weight_and_mask_fn(tokens_with_api_response[:, :-1], api_end_token_id)
# deriving the weighting for the original passage is more tricky
# would need to start counting up from <api> start token location
# this would also assume that the language model perfectly copied the passage over and that both token ids are aligned except for the inserted API call - but this can be done with the custom filtering functions eventually
weight = weight_and_mask_fn(tokens_without_api_response[:, 1:], api_start_token_id) # shift to the left by one since <api> does not exist in the original sequence
weight = weight[:, :probs.shape[-1]]
# get the loss L for all three types of sequences
def loss_fn(weight, probs):
return (weight * -log(probs)).sum(dim = -1)
loss = loss_fn(weight, probs)
loss_without_api_response = loss_fn(weight_without_api_response, probs_without_api_response)
loss_with_api_response = loss_fn(weight_with_api_response, probs_with_api_response)
# calculate the main formula in the paper
# loss+ = loss with api response
# loss- = min(loss without api response, loss without api at all)
loss_plus = loss_with_api_response
loss_minus = torch.minimum(loss_without_api_response, loss)
selected_mask = (loss_minus - loss_plus) >= filter_threshold
# now we can select and return the entries that survived the filtering stage
# also returning the selected indices of the batch being processed
# for finetuning the model into toolformer
batch = tokens.shape[0]
indices = torch.arange(batch, device = tokens.device)
selected_indices = indices[selected_mask]
ret = FilteredResults(
selected_mask.sum().item(),
(~selected_mask).sum().item(),
selected_indices,
selected_mask,
tokens[selected_mask],
tokens_without_api_response[selected_mask],
tokens_with_api_response[selected_mask]
)
return ret
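# reading the result: a row survives the filter when (loss_minus - loss_plus) >= filter_threshold,
# i.e. when conditioning on the api call with its response lowers the weighted loss by at least the threshold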
# datasets and dataloaders
# for bootstrapping the initial datasets with api calls
# as well as for the final finetuning
@beartype
class PromptDataset(Dataset):
def __init__(
self,
prompt: str,
prompt_input_tag: str,
data: List[str],
tokenizer_encode: Callable
):
self.data = data
self.prompt = prompt
self.prompt_input_tag_regex = re.escape(prompt_input_tag)
self.tokenizer_encode = tokenizer_encode
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
data_string = self.data[idx]
data_with_prompt = re.sub(self.prompt_input_tag_regex, data_string, self.prompt)
token_ids = self.tokenizer_encode(data_with_prompt)
return torch.tensor(token_ids).long(), torch.tensor(len(token_ids)).long()
def prompt_collate_fn(data, padding_value = 0):
prompts, prompt_lengths = zip(*data)
prompts = pad_sequence(prompts, padding_value = padding_value)
return prompts, torch.stack(prompt_lengths)
def PromptDataloader(ds: Dataset, *args, padding_value = 0, **kwargs):
collate_fn = partial(prompt_collate_fn, padding_value = padding_value)
return DataLoader(ds, *args, collate_fn = collate_fn, **kwargs)
class FinetuneDataset(Dataset):
def __init__(
self,
tokens: torch.Tensor
):
self.tokens = tokens
def __len__(self):
return len(self.tokens)
def __getitem__(self, idx):
return self.tokens[idx]
def FinetuneDataloader(ds: Dataset, *args, padding_value = 0, **kwargs):
return DataLoader(ds, *args, collate_fn = partial(pad_sequence, padding_value = padding_value), **kwargs)
# classes
@beartype
class Toolformer(nn.Module):
def __init__(
self,
model: nn.Module,
*,
tool_id: str,
tool: Callable,
api_start_str = ' [',
api_stop_str = ']',
api_response_delimiter = '→',
api_start_id = None,
api_stop_id = None,
teach_tool_prompt: str,
filter_threshold = 1.,
pad_id = 0,
prompt_batch_size = 4,
model_seq_len = 2048,
tokenizer_encode: Callable = tokenizer.encode,
tokenizer_decode: Callable = tokenizer.decode,
post_prompt_callback: Callable = identity,
prompt_input_tag: str = DEFAULT_PROMPT_INPUT_TAG,
exclude_filters: dict[str, Callable[[str], bool]] = dict(),
finetune = False,
finetune_lr = 1e-4,
finetune_wd = 1e-2,
finetune_betas = (0.9, 0.99),
finetune_eps = 1e-8,
finetune_epochs = 3,
finetune_batch_size = 16
):
super().__init__()
self.model = model
self.model_seq_len = model_seq_len
self.teach_tool_prompt = teach_tool_prompt
self.prompt_batch_size = prompt_batch_size
self.prompt_input_tag = prompt_input_tag
self.post_prompt_callback = post_prompt_callback # for easy mocking
self.tokenizer_encode = tokenizer_encode
self.tokenizer_decode = tokenizer_decode
self.tokenizer_encode_to_tensor = lambda s: torch.tensor(tokenizer_encode(s)).long()
self.filter_threshold = filter_threshold
self.api_start_str = api_start_str
self.api_stop_str = api_stop_str
self.api_response_delimiter = api_response_delimiter
if not exists(api_start_id):
api_start_id = tokenizer_encode(api_start_str)
assert len(api_start_id) == 1
api_start_id = api_start_id[0]
self.api_start_id = api_start_id
if not exists(api_stop_id):
api_stop_id = tokenizer_encode(api_stop_str)
assert len(api_stop_id) == 1
api_stop_id = api_stop_id[0]
self.api_stop_id = api_stop_id
self.pad_id = pad_id
self.tool_id = tool_id
self.tool = tool
self.registry = {tool_id: tool}
        assert num_matches(prompt_input_tag, teach_tool_prompt) == 1, f'there must be exactly one prompt input tag `{prompt_input_tag}` in your prompt to encourage the language model to use the designated tool'
self.exclude_filters = exclude_filters
self.should_finetune = finetune
if not finetune:
return
self.finetune_batch_size = finetune_batch_size
self.finetune_epochs = finetune_epochs
self.optimizer = get_optimizer(
model.parameters(),
lr = finetune_lr,
wd = finetune_wd,
betas = finetune_betas,
eps = finetune_eps
)
def generate_data_with_api_calls(
self,
data: List[str],
temperature: float = 0.9
) -> List[str]:
dataset = PromptDataset(
data = data,
prompt_input_tag = self.prompt_input_tag,
prompt = self.teach_tool_prompt,
tokenizer_encode = self.tokenizer_encode
)
dl = PromptDataloader(
dataset,
batch_size = self.prompt_batch_size
)
prompted_outputs = []
for prime, positions in dl:
sampled_outputs = sample(
model = self.model,
prime = prime,
positions = positions,
seq_len = self.model_seq_len,
pad_id = self.pad_id,
temperature = temperature
)
for sample_output, position in zip(sampled_outputs, positions):
start_position = position.item()
prompted_output = self.tokenizer_decode(sample_output[start_position:])
prompted_outputs.append(prompted_output)
return self.post_prompt_callback(prompted_outputs)
def filter_and_keep_only_first_api_call(
self,
data,
data_with_api_calls: List[str],
return_excluded = False
):
included_data = []
included_data_with_api_calls = []
included = (included_data, included_data_with_api_calls)
excluded_data = []
excluded_data_with_api_calls = []
excluded = (excluded_data, excluded_data_with_api_calls)
api_start_stop_kwargs = dict(api_start = self.api_start_str, api_stop = self.api_stop_str)
has_api_calls_ = partial(has_api_calls, **api_start_stop_kwargs)
replace_all_but_first_ = partial(replace_all_but_first, **api_start_stop_kwargs)
for datum, data_with_api_call in zip(data, data_with_api_calls):
if has_api_calls_(data_with_api_call):
data_with_api_call = replace_all_but_first_(data_with_api_call)
included_data.append(datum)
included_data_with_api_calls.append(data_with_api_call)
else:
excluded_data.append(datum)
excluded_data_with_api_calls.append(data_with_api_call)
if not return_excluded:
return included
return included, excluded
@torch.no_grad()
def sample_model_with_api_calls(
self,
prime: Union[torch.Tensor, str],
occurrence = 1,
**kwargs
):
self.model.eval()
prime_is_str = isinstance(prime, str)
if prime_is_str:
prime = self.tokenizer_encode(prime)
prime = torch.tensor(prime).long()
prime = rearrange(prime, 'n -> 1 n')
assert prime.shape[0] == 1, 'only one at a time for now'
        invoke_tools_ = partial(
            invoke_tools,
            self.registry,
            api_start = self.api_start_str,
            api_stop = self.api_stop_str,
            delimiter = self.api_response_delimiter
        )
def call_apis(t: torch.Tensor):
t = self.tokenizer_decode(t[0])
t = invoke_tools_(t)
t = self.tokenizer_encode_to_tensor(t)
return rearrange(t, 'n -> 1 n')
output = sample_with_api_call(
model = self.model,
prime = prime,
seq_len = self.model_seq_len,
call_apis = call_apis,
api_end_token_id = self.api_stop_id,
occurrence = occurrence,
**kwargs
)
if not prime_is_str:
return output
return self.tokenizer_decode(output[0])
def make_api_calls(
self,
filtered_data_with_api_calls: List[str]
):
        invoke_tools_ = partial(
            invoke_tools,
            self.registry,
            api_start = self.api_start_str,
            api_stop = self.api_stop_str,
            delimiter = self.api_response_delimiter
        )
data_with_api_responses = []
for data in filtered_data_with_api_calls:
output = invoke_tools_(data)
data_with_api_responses.append(output)
return data_with_api_responses
def filter_by_api_responses(
self,
data: List[str],
data_with_api_calls: List[str],
data_with_api_responses: List[str]
) -> FilteredResults:
to_token_ids = lambda l: pad_sequence([*map(self.tokenizer_encode_to_tensor, l)], padding_value = self.pad_id)
tokens, tokens_without_api_response, tokens_with_api_response = map(to_token_ids, (data, data_with_api_calls, data_with_api_responses))
filtered_results = filter_tokens_with_api_response(
model = self.model,
tokens = tokens,
tokens_with_api_response = tokens_with_api_response,
tokens_without_api_response = tokens_without_api_response,
filter_threshold = self.filter_threshold,
api_start_token_id = self.api_start_id,
api_end_token_id = self.api_stop_id
)
return filtered_results
def finetune(
self,
filtered_results: Union[FilteredResults, torch.Tensor]
):
self.model.train()
if isinstance(filtered_results, FilteredResults):
filtered_results = filtered_results.filtered_tokens_without_api_response
dataset = FinetuneDataset(tokens = filtered_results)
dl = FinetuneDataloader(dataset, batch_size = self.finetune_batch_size, shuffle = True)
for epoch in tqdm(range(self.finetune_epochs), desc = 'finetune epochs'):
for batch in dl:
inp, labels = batch[:, :-1], batch[:, 1:]
logits = self.model(inp)
logits = rearrange(logits, 'b n c -> b c n')
loss = F.cross_entropy(logits, labels, ignore_index = self.pad_id)
loss.backward()
print(f'loss: {loss.item()}')
self.optimizer.step()
self.optimizer.zero_grad()
print(f'finished finetuning on {len(dataset)} filtered samples')
def forward(
self,
data: List[str],
return_after_generating_api_calls = False,
return_after_making_api_calls = False,
return_after_filtering_api_calls = False,
return_after_filtering_by_api_response = False
):
data_with_api_calls = self.generate_data_with_api_calls(data)
if return_after_generating_api_calls:
return data_with_api_calls
filtered_data, filtered_data_with_api_calls = self.filter_and_keep_only_first_api_call(data, data_with_api_calls)
if return_after_filtering_api_calls:
return filtered_data, filtered_data_with_api_calls
assert len(filtered_data_with_api_calls) > 0, 'your model failed to follow instructions and make API calls. please try a better model or do some better prompt engineering'
data_with_responses = self.make_api_calls(filtered_data_with_api_calls)
if return_after_making_api_calls:
return filtered_data, filtered_data_with_api_calls, data_with_responses
filtered_results = self.filter_by_api_responses(filtered_data, filtered_data_with_api_calls, data_with_responses)
if return_after_filtering_by_api_response:
return filtered_results
if self.should_finetune:
assert filtered_results.num_passed > 0, f'none of the sequences with API calls passed the filtering criteria with threshold {self.filter_threshold}'
self.finetune(filtered_results)
return filtered_results
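# a minimal end-to-end usage sketch, adapted from the repository README - the prompt,
# the Calendar tool, and all hyperparameters below are illustrative assumptions

if __name__ == '__main__':

    def Calendar():
        import datetime
        from calendar import day_name, month_name
        now = datetime.datetime.now()
        return f'Today is {day_name[now.weekday()]}, {month_name[now.month]} {now.day}, {now.year}.'

    prompt = """
Your task is to add calls to a Calendar API to a piece of text.
You can call the API by writing "[Calendar()]"
Input: Today is the first Friday of the year.
Output: Today is the first [Calendar()] Friday of the year.
Input: [input]
Output:
"""

    data = [
        "The store is never open on the weekend, so today it is closed.",
        "The number of days from now until Christmas is 30"
    ]

    model = PaLM(
        dim = 512,
        depth = 2,
        heads = 8,
        dim_head = 64
    )

    toolformer = Toolformer(
        model = model,
        model_seq_len = 256,
        teach_tool_prompt = prompt,
        tool_id = 'Calendar',
        tool = Calendar,
        finetune = True
    )

    # generates api-annotated data, filters by whether the api response helps, then finetunes
    filtered_stats = toolformer(data)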