-
Notifications
You must be signed in to change notification settings - Fork 210
/
generation_utils.rs
2372 lines (2239 loc) · 97.6 KB
/
generation_utils.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
// Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors.
// Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
// Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! # Natural Language Generation utilities
//! Set of text generation utilities, serving as a basis for TextGenerationModel, SummarizationModels and TranslationModels.
//! Includes techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
//! Supports batch generation of sentences from several prompts. Prompts are padded to a common length with the model's padding token if present, or the unknown token otherwise (on the left for decoder-only models, on the right for encoder-decoder models).
//! This padding may impact the results, so it is recommended to submit prompts of similar length.
//!
//! ```no_run
//! # fn main() -> anyhow::Result<()> {
//! use rust_bert::gpt2::GPT2Generator;
//! use rust_bert::pipelines::generation_utils::{
//! GenerateConfig, GenerateOptions, LanguageGenerator,
//! };
//!
//! let generate_config = GenerateConfig {
//! do_sample: true,
//! num_beams: 5,
//! temperature: 1.1,
//! num_return_sequences: 3,
//! ..Default::default()
//! };
//! let mut gpt2_generator = GPT2Generator::new(generate_config)?;
//!
//! let input_context = "The dog";
//! let second_input_context = "The cat was";
//!
//! let generate_options = GenerateOptions {
//! min_length: Some(32),
//! max_length: Some(128),
//! output_scores: true,
//! ..Default::default()
//! };
//!
//! let output = gpt2_generator.generate(
//! Some(&[input_context, second_input_context]),
//! Some(generate_options),
//! );
//! # Ok(())
//! # }
//! ```
//!
//! Example output: \
//! ```no_run
//! # let output =
//! [
//! "The dog's owners, however, did not want to be named. According to the lawsuit, the animal's owner, a 29-year",
//! "The dog has always been part of the family. \"He was always going to be my dog and he was always looking out for me",
//! "The dog has been able to stay in the home for more than three months now. \"It's a very good dog. She's",
//! "The cat was discovered earlier this month in the home of a relative of the deceased. The cat\'s owner, who wished to remain anonymous,",
//! "The cat was pulled from the street by two-year-old Jazmine.\"I didn't know what to do,\" she said",
//! "The cat was attacked by two stray dogs and was taken to a hospital. Two other cats were also injured in the attack and are being treated."
//! ]
//! # ;
//! ```
use tch::kind::Kind::Int64;
use tch::{no_grad, Device, Kind, Tensor};
use crate::bart::LayerState as BartLayerState;
use crate::common::resources::ResourceProvider;
use crate::gpt_j::LayerState as GPTJLayerState;
use crate::gpt_neo::LayerState as GPTNeoLayerState;
use crate::pipelines::generation_utils::private_generation_utils::{
InternalGenerateOptions, PrivateLanguageGenerator,
};
use crate::prophetnet::LayerState as ProphetNetLayerState;
use crate::reformer::LayerState as ReformerLayerState;
use crate::t5::LayerState as T5LayerState;
use crate::xlnet::LayerState as XLNetLayerState;
use self::ordered_float::OrderedFloat;
use crate::pipelines::common::{ModelResource, ModelType, TokenizerOption};
extern crate ordered_float;
#[cfg(feature = "onnx")]
use crate::pipelines::onnx::ONNXLayerCache;
use crate::RustBertError;
#[cfg(feature = "remote")]
use crate::{
gpt2::{Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources},
resources::RemoteResource,
};
/// # Configuration for text generation
///
/// The defaults listed below are those of the `Default` implementation, which is only
/// available with the `remote` feature (it points to the pretrained GPT2 resources).
pub struct GenerateConfig {
    /// Model type used for generation (default: GPT2)
    pub model_type: ModelType,
    /// Model weights resource (default: pretrained GPT2 model)
    pub model_resource: ModelResource,
    /// Config resource (default: pretrained GPT2 model)
    pub config_resource: Box<dyn ResourceProvider + Send>,
    /// Vocab resource (default: pretrained GPT2 model)
    pub vocab_resource: Box<dyn ResourceProvider + Send>,
    /// Merges resource (default: pretrained GPT2 model)
    pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
    /// Minimum sequence length (default: 0)
    pub min_length: i64,
    /// Maximum sequence length (default: `Some(56)`)
    pub max_length: Option<i64>,
    /// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
    pub do_sample: bool,
    /// Early stopping flag indicating if the beam search should stop as soon as `num_beams` hypotheses have been generated (default: true)
    pub early_stopping: bool,
    /// Number of beams for beam search (default: 5)
    pub num_beams: i64,
    /// Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance (default: 1.0)
    pub temperature: f64,
    /// Top_k values for sampling tokens. Value higher than 0 will enable the feature (default: 0)
    pub top_k: i64,
    /// Top_p value for [Nucleus sampling, Holtzman et al.](http://arxiv.org/abs/1904.09751). Keep top tokens until cumulative probability reaches top_p (default: 0.9)
    pub top_p: f64,
    /// Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 will penalize tokens that have been already generated. (default: 1.0)
    pub repetition_penalty: f64,
    /// Exponential penalty based on the length of the hypotheses generated (default: 1.0)
    pub length_penalty: f64,
    /// Size of the n-grams that may not be repeated within a generated sequence. Values higher than 0 turn on this feature (default: 3)
    pub no_repeat_ngram_size: i64,
    /// Number of sequences to return for each prompt text (default: 1)
    pub num_return_sequences: i64,
    /// Number of beam groups for diverse beam generation. If provided and higher than 1, will split the beams into beam subgroups leading to more diverse generation; `num_beams` must then be a multiple of this value (default: None)
    pub num_beam_groups: Option<i64>,
    /// Diversity penalty for diverse beam search. High values will enforce more difference between beam groups (default: None)
    pub diversity_penalty: Option<f64>,
    /// Device to place the model on (default: CUDA/GPU when available)
    pub device: Device,
    /// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
    pub kind: Option<Kind>,
}
#[cfg(feature = "remote")]
impl Default for GenerateConfig {
    /// Builds the default configuration: remote pretrained GPT2 resources, sampling with five
    /// beams and nucleus sampling (`top_p = 0.9`), 3-gram repetition blocking and a maximum
    /// length of 56 tokens, on CUDA when available.
    fn default() -> Self {
        let model_resource = ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
            Gpt2ModelResources::GPT2,
        )));
        let config_resource: Box<dyn ResourceProvider + Send> =
            Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
        let vocab_resource: Box<dyn ResourceProvider + Send> =
            Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
        let merges_resource: Box<dyn ResourceProvider + Send> =
            Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));

        Self {
            model_type: ModelType::GPT2,
            model_resource,
            config_resource,
            vocab_resource,
            merges_resource: Some(merges_resource),
            // Length constraints
            min_length: 0,
            max_length: Some(56),
            // Search strategy
            do_sample: true,
            early_stopping: true,
            num_beams: 5,
            // Sampling and scoring knobs
            temperature: 1.0,
            top_k: 0,
            top_p: 0.9,
            repetition_penalty: 1.0,
            length_penalty: 1.0,
            no_repeat_ngram_size: 3,
            num_return_sequences: 1,
            // Diverse beam search disabled
            num_beam_groups: None,
            diversity_penalty: None,
            // Placement and precision
            device: Device::cuda_if_available(),
            kind: None,
        }
    }
}
impl GenerateConfig {
    /// Checks that the generation settings are valid and mutually consistent.
    ///
    /// # Panics
    ///
    /// Panics with a descriptive message when any of the following does not hold:
    /// - `temperature` is strictly positive;
    /// - `top_p` lies within `[0, 1]`;
    /// - `repetition_penalty` is at least 1;
    /// - `length_penalty`, `num_return_sequences` and `num_beams` are strictly positive;
    /// - without sampling, greedy decoding (`num_beams == 1`) returns exactly one sequence and
    ///   beam search returns at most `num_beams` sequences;
    /// - when `num_beam_groups` is greater than 1, `num_beams` is a multiple of it.
    pub(crate) fn validate(&self) {
        assert!(self.temperature > 0f64, "temperature must be positive");
        // `contains` is false for NaN, matching the previous pair of comparisons.
        assert!(
            (0f64..=1f64).contains(&self.top_p),
            "top_p must be between 0 and 1"
        );
        assert!(
            self.repetition_penalty >= 1f64,
            "repetition_penalty must be greater than or equal to 1"
        );
        assert!(
            self.length_penalty > 0f64,
            "length_penalty must be strictly greater than 0"
        );
        assert!(
            self.num_return_sequences > 0i64,
            "num_return_sequences must be strictly greater than 0"
        );
        assert!(
            self.num_beams > 0i64,
            "num_beams must be strictly greater than 0"
        );
        if !self.do_sample {
            if self.num_beams == 1 {
                assert_eq!(
                    self.num_return_sequences, 1,
                    "num_return_sequences must be set to 1 for greedy decoding"
                );
            } else {
                assert!(
                    self.num_beams >= self.num_return_sequences,
                    "num_return_sequences must be lower than or equal to the number of beams"
                );
            }
        }
        // A single group (or none) means regular beam search: nothing to check.
        if let Some(num_beam_groups) = self.num_beam_groups.filter(|&groups| groups > 1) {
            assert_eq!(
                self.num_beams % num_beam_groups,
                0,
                "num_beams must be a multiple of num_beam_groups"
            );
        }
    }
}
/// Model-specific cache of past states, threaded between decoding steps.
///
/// The decoding loop starts from `Cache::None` and replaces it with the `cache` returned by each
/// `forward_t` call. Each model family wraps its own layer-state type; the inner `Option` is
/// presumably `None` until a forward pass has produced states — TODO confirm per model.
#[derive(Debug)]
pub enum Cache {
/// GPT2 cache: a vector of (presumably per-layer) tensors
GPT2Cache(Option<Vec<Tensor>>),
/// BART cache: a vector of state pairs (presumably self-attention and encoder-decoder attention — confirm)
BARTCache(Option<Vec<(Option<BartLayerState>, Option<BartLayerState>)>>),
/// T5 cache: a vector of state pairs, with the same layout as the BART cache
T5Cache(Option<Vec<(Option<T5LayerState>, Option<T5LayerState>)>>),
/// LongT5 cache, reusing the T5 layer state type
LongT5Cache(Option<Vec<(Option<T5LayerState>, Option<T5LayerState>)>>),
/// XLNet cache: a vector of optional layer states
XLNetCache(Option<Vec<Option<XLNetLayerState>>>),
/// Reformer cache: a vector of optional layer states
ReformerCache(Option<Vec<Option<ReformerLayerState>>>),
/// ProphetNet cache: a vector of state pairs
ProphetNetCache(Option<Vec<(Option<ProphetNetLayerState>, Option<ProphetNetLayerState>)>>),
/// GPT-Neo cache: a vector of optional layer states
GPTNeoCache(Option<Vec<Option<GPTNeoLayerState>>>),
/// GPT-J cache: a vector of optional layer states
GPTJCache(Option<Vec<Option<GPTJLayerState>>>),
/// Cache for ONNX models (only available with the `onnx` feature)
#[cfg(feature = "onnx")]
ONNXCache(ONNXLayerCache),
/// No cache, e.g. before the first decoding step
None,
}
pub(crate) mod private_generation_utils {
use rust_tokenizers::TokenIdsWithOffsets;
use std::cmp::{max, min};
use std::collections::HashMap;
use std::convert::TryFrom;
use std::mem;
use rust_tokenizers::tokenizer::{truncate_sequences, TruncationStrategy};
use tch::{nn, Device, Kind, Tensor};
use crate::pipelines::common::TokenizerOption;
use crate::pipelines::generation_utils::{
BeamHypotheses, Cache, GenerateConfig, LMModelOutput, PrefixAllowedFunction,
};
use super::ordered_float::OrderedFloat;
use crate::common::kind::{get_negative_infinity, get_positive_infinity};
use crate::RustBertError;
/// Generation options as consumed by the internal decoding loops (e.g. `generate_no_beam_search`).
///
/// Most fields mirror the equally-named `GenerateConfig` settings; presumably they are resolved
/// from the configuration and per-call overrides before decoding starts — TODO confirm at the call site.
pub struct InternalGenerateOptions<'a> {
pub min_length: i64,
/// Maximum length; also decides at which step a forced EOS token is inserted
pub max_length: Option<i64>,
pub do_sample: bool,
pub temperature: f64,
pub top_k: i64,
pub top_p: f64,
pub repetition_penalty: f64,
pub no_repeat_ngram_size: i64,
/// Token id appended to sequences that have already finished
pub pad_token_id: Option<i64>,
/// End-of-sequence token ids; a sequence is finished once it produces one of them
pub eos_token_ids: Option<Vec<i64>>,
pub num_return_sequences: i64,
pub early_stopping: bool,
pub num_beams: i64,
pub length_penalty: f64,
pub num_beam_groups: Option<i64>,
pub diversity_penalty: Option<f64>,
/// Token forced as the first generated token (takes precedence over the model's default)
pub forced_bos_token_id: Option<i64>,
/// Token id sequences that must not be generated (single tokens or multi-token words)
pub bad_word_ids: Option<&'a Vec<Vec<i64>>>,
}
/// Model inputs for a single decoding step, produced by `prepare_inputs_for_generation` and
/// passed to `forward_t`.
pub struct PreparedInput<'a> {
/// Input token ids fed to the model, if any
pub prepared_input: Option<Tensor>,
/// Attention mask matching `prepared_input`, if any
pub prepared_attention_mask: Option<Tensor>,
/// Encoder output borrowed from the caller (encoder-decoder models only, presumably)
pub prepared_encoder_output: Option<&'a Tensor>,
/// Decoder input ids, if the model takes them separately
pub prepared_decoder_input: Option<Tensor>,
/// Explicit position ids, if the model needs them
pub prepared_position_ids: Option<Tensor>,
/// Cache carried over from the previous step
pub prepared_past: Cache,
}
/// Generated token ids, optionally accompanied by scores.
pub struct GeneratedOutputWithScores {
/// Generated token ids (presumably one row per returned sequence — TODO confirm)
pub indices: Tensor,
/// Optional score per generated sequence (aggregation not visible here — TODO confirm)
pub scores: Option<Vec<f64>>,
/// Optional per-token scores for each sequence; in greedy/sampling decoding these appear to be
/// log-probabilities of the chosen tokens (0 once a sequence has finished) — confirm
pub token_scores: Option<Vec<Vec<f64>>>,
}
pub trait PrivateLanguageGenerator {
/// Tokenizer used by the generator (e.g. to encode prompts in `encode_prompt_text`).
fn _get_tokenizer(&self) -> &TokenizerOption;
/// Device on which generation tensors (inputs, masks, sentence trackers) are created.
fn get_device(&self) -> Device;
/// Mutable access to the model's variable store; returns an error when the generator cannot
/// provide one (which generators those are is not visible here — TODO confirm).
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError>;
/// Mutable access to the tokenizer.
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption;
/// Generation configuration associated with this generator.
fn get_config(&self) -> &GenerateConfig;
/// Beginning-of-sequence token id, if the model defines one.
fn get_bos_id(&self) -> Option<i64>;
/// End-of-sequence token ids, if the model defines any.
fn get_eos_ids(&self) -> Option<&Vec<i64>>;
/// Token forced as the first generated token (see `prepare_scores_for_generation`); none by default.
fn get_forced_bos_token_id(&self) -> Option<i64> {
None
}
/// Token forced at step `max_length - 1` (see `prepare_scores_for_generation`); none by default.
fn get_forced_eos_token_id(&self) -> Option<i64> {
None
}
/// Padding token id, if the model defines one.
fn get_pad_id(&self) -> Option<i64>;
/// Whether the model is an encoder-decoder; affects special tokens and padding side in
/// `encode_prompt_text`.
fn is_encoder_decoder(&self) -> bool;
/// Size of the output vocabulary.
fn get_vocab_size(&self) -> i64;
/// Token id used to start decoding, if the model defines one.
fn get_decoder_start_id(&self) -> Option<i64>;
/// Maximum number of position embeddings, if the model has such a limit.
fn get_max_positions_embeddings(&self) -> Option<i64>;
/// Runs one forward pass of the model.
///
/// Returns the language-model output; the decoding loop reads its `lm_logits` and feeds its
/// `cache` back as `layer_past` on the next step.
fn forward_t(
&self,
input_ids: Option<&Tensor>,
layer_past: Cache,
attention_mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
encoder_outputs: Option<&Tensor>,
decoder_input_ids: Option<&Tensor>,
train: bool,
) -> Result<LMModelOutput, RustBertError>;
/// Forces special tokens in `scores` at specific decoding steps, in place.
///
/// * At the first step (`current_length == 1`), forces the BOS token given by
///   `forced_bos_token_id`, falling back to the model's `get_forced_bos_token_id()`.
/// * At step `max_length - 1` (only when `max_length` is set), forces the model's
///   `get_forced_eos_token_id()`, if any.
fn prepare_scores_for_generation(
    &self,
    scores: &mut Tensor,
    current_length: i64,
    max_length: Option<i64>,
    forced_bos_token_id: Option<i64>,
) {
    if current_length == 1 {
        let bos_to_force = forced_bos_token_id.or(self.get_forced_bos_token_id());
        if let Some(token_id) = bos_to_force {
            force_token_id_generation(scores, &[token_id], self.get_vocab_size());
        }
        return;
    }
    let length_limit = match max_length {
        Some(limit) => limit,
        None => return,
    };
    let eos_to_force = self
        .get_forced_eos_token_id()
        .filter(|_| current_length == length_limit - 1);
    if let Some(token_id) = eos_to_force {
        force_token_id_generation(scores, &[token_id], self.get_vocab_size());
    }
}
/// Encodes `_input_ids` (with optional `_attention_mask`) with the model's encoder.
///
/// The default implementation ignores its arguments and returns `None`; presumably
/// encoder-decoder models override it to return the encoder output — TODO confirm.
fn encode(&self, _input_ids: &Tensor, _attention_mask: Option<&Tensor>) -> Option<Tensor> {
None
}
/// Builds the model inputs for one decoding step.
///
/// Default behaviour: pass the full token ids, attention mask and cache through unchanged,
/// with no encoder output, decoder input or explicit position ids. Models with other needs
/// presumably override this method.
fn prepare_inputs_for_generation<'a>(
    &self,
    input_ids: Tensor,
    _encoder_outputs: Option<&'a Tensor>,
    past: Cache,
    attention_mask: Tensor,
) -> PreparedInput<'a> {
    let prepared_input = Some(input_ids);
    let prepared_attention_mask = Some(attention_mask);
    PreparedInput {
        prepared_past: past,
        prepared_input,
        prepared_attention_mask,
        prepared_position_ids: None,
        prepared_decoder_input: None,
        prepared_encoder_output: None,
    }
}
/// Tokenizes a batch of prompts into a single padded token-id tensor on the generator's device.
///
/// * Encoder-decoder models: prompts are encoded with the tokenizer's special tokens, truncated
///   to `max_len` tokens (longest-first) and padded on the right.
/// * Decoder-only models: prompts are tokenized without special tokens, truncated so that at
///   most `max_len` tokens remain, and padded on the left, so that generation continues right
///   after the last prompt token.
///
/// Padding uses `pad_token_id`, falling back to the tokenizer's unknown-token id. The result
/// stacks one row per prompt, all rows having the length of the longest (truncated) prompt.
///
/// # Panics
///
/// Panics if `prompt_text` is empty, or if truncating a prompt fails.
fn encode_prompt_text<S>(
    &self,
    prompt_text: &[S],
    max_len: Option<i64>,
    pad_token_id: Option<i64>,
) -> Tensor
where
    S: AsRef<str> + Send + Sync,
{
    let is_encoder_decoder = self.is_encoder_decoder();
    let tokenizer = self._get_tokenizer();

    let token_ids: Vec<Vec<i64>> = if is_encoder_decoder {
        tokenizer
            .encode_list(
                prompt_text,
                max_len
                    .map(|max_len| max_len as usize)
                    .unwrap_or(usize::MAX),
                &TruncationStrategy::LongestFirst,
                0,
            )
            .into_iter()
            .map(|tokenized_input| tokenized_input.token_ids)
            .collect()
    } else {
        // Special tokens (e.g. BOS) are not added at the end of the prompt for causal generation
        tokenizer
            .tokenize_list(prompt_text)
            .into_iter()
            .map(|prompt_tokens| {
                let ids = tokenizer.convert_tokens_to_ids(&prompt_tokens);
                // Number of tokens exceeding `max_len` (0 when unbounded or short enough)
                let num_tokens_to_remove = max_len
                    .map(|max_len| ids.len().saturating_sub(max_len as usize))
                    .unwrap_or(0);
                truncate_sequences(
                    TokenIdsWithOffsets {
                        ids,
                        offsets: vec![],
                        reference_offsets: vec![],
                        masks: vec![],
                    },
                    None,
                    num_tokens_to_remove,
                    &TruncationStrategy::LongestFirst,
                    0,
                )
                .expect("truncating a single prompt sequence should not fail")
                .0
                .ids
            })
            .collect()
    };

    let longest_sequence = token_ids
        .iter()
        .map(Vec::len)
        .max()
        .expect("at least one prompt is required to build the input tensor");
    let pad_token = pad_token_id.unwrap_or_else(|| tokenizer.get_unk_id());
    let device = self.get_device();

    let padded_sequences = token_ids
        .into_iter()
        .map(|sequence| {
            let padding = vec![pad_token; longest_sequence - sequence.len()];
            let padded = if is_encoder_decoder {
                [sequence, padding].concat()
            } else {
                // Pad left for causal generation
                [padding, sequence].concat()
            };
            Tensor::from_slice(&padded).to(device)
        })
        .collect::<Vec<Tensor>>();
    Tensor::stack(&padded_sequences, 0)
}
fn enforce_repetition_penalty(
&self,
next_token_logits: &mut Tensor,
batch_size: i64,
num_beams: i64,
prev_output_tokens: &Tensor,
repetition_penalty: f64,
) {
for i in 0..(batch_size * num_beams) {
for token_position in 0..prev_output_tokens.get(i).size()[0] {
let token = prev_output_tokens.get(i).int64_value(&[token_position]);
let updated_value = &next_token_logits.double_value(&[i, token]);
if updated_value < &0f64 {
let _ = next_token_logits.get(i).index_fill_(
0,
&Tensor::from_slice(&[token])
.to_kind(Kind::Int64)
.to_device(next_token_logits.device()),
updated_value * repetition_penalty,
);
} else {
let _ = next_token_logits.get(i).index_fill_(
0,
&Tensor::from_slice(&[token])
.to_kind(Kind::Int64)
.to_device(next_token_logits.device()),
updated_value / repetition_penalty,
);
}
}
}
}
/// Computes, for each hypothesis (row) of `input_ids`, the tokens that would repeat an n-gram
/// of size `no_repeat_ngram_size` already present in that hypothesis.
///
/// The tokens from position `cur_len + 1 - no_repeat_ngram_size` to the end of a hypothesis
/// form the query prefix. Every token that followed that same prefix earlier in the hypothesis
/// is banned. When fewer than `no_repeat_ngram_size - 1` tokens are available, a single empty
/// list is returned.
///
/// Ported from hugging face's transformers and fairseq
/// (https://github.com/pytorch/fairseq/blob/master/fairseq/sequence_generator.py).
fn get_banned_tokens(
    &self,
    input_ids: &Tensor,
    no_repeat_ngram_size: i64,
    cur_len: i64,
) -> Vec<Vec<i64>> {
    if cur_len + 1 < no_repeat_ngram_size {
        return vec![vec![]];
    }
    let ngram_size = no_repeat_ngram_size as usize;
    let query_start = cur_len as usize + 1 - ngram_size;
    let input_ids = input_ids.to(Device::Cpu);
    let num_hypothesis = input_ids.size()[0];
    (0..num_hypothesis)
        .map(|hypothesis_index| {
            let tokens = input_ids
                .get(hypothesis_index)
                .iter::<i64>()
                .unwrap()
                .collect::<Vec<i64>>();
            // Map each (n-1)-token prefix to every token observed right after it.
            let mut generated_ngrams: HashMap<&[i64], Vec<i64>> = HashMap::new();
            for ngram in tokens.windows(ngram_size) {
                let (last, prefix) = ngram
                    .split_last()
                    .expect("n-gram windows are never empty");
                generated_ngrams.entry(prefix).or_default().push(*last);
            }
            generated_ngrams
                .get(&tokens[query_start..])
                .cloned()
                .unwrap_or_default()
        })
        .collect()
}
/// Filters `logits` in place so that only top-k and/or nucleus (top-p) tokens can be sampled.
///
/// Removed tokens are set to negative infinity. `logits` is handled as one row per sequence
/// with the vocabulary along the last dimension.
///
/// * `top_k`: if greater than 0, keeps only the `max(top_k, min_tokens_to_keep)` highest logits
///   of each row (capped at the vocabulary size).
/// * `top_p`: if lower than 1, keeps the highest-probability tokens until their cumulative
///   probability reaches `top_p`; the token crossing the threshold and the most likely token are
///   always kept.
/// * `min_tokens_to_keep`: lower bound on the number of kept tokens.
///
/// NOTE(review): when `min_tokens_to_keep > 1`, the top-p branch clears sorted columns
/// `0..=min_tokens_to_keep` before the shift, which keeps `min_tokens_to_keep + 2` tokens, and
/// the index goes out of range if `min_tokens_to_keep + 1 > vocab_size`. The caller visible in
/// this file passes 1, so that branch is not exercised there — confirm intent.
fn top_k_top_p_filtering(
&self,
logits: &mut Tensor,
top_k: i64,
top_p: f64,
min_tokens_to_keep: i64,
) {
// Nucleus and top-k filtering introduced by Holtzman et al. (http://arxiv.org/abs/1904.09751)
// Ported from https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
let vocab_size = *logits.size().last().unwrap();
if top_k > 0 {
// `top_k` is reused as the number of lowest logits to remove from each row
let top_k = vocab_size - min(max(top_k, min_tokens_to_keep), vocab_size);
// largest = false: indices of the smallest logits
let (_, indices_to_remove) = logits.topk(top_k, -1, false, false);
for index in 0..*logits.size().first().unwrap() {
let _ = logits.get(index).index_fill_(
0,
&indices_to_remove.get(index),
f64::NEG_INFINITY,
);
}
}
if top_p < 1f64 {
// Sort descending and accumulate probabilities in that order
let (sorted_logits, sorted_indices) = logits.sort(-1, true);
let cumulative_probabilities = sorted_logits
.softmax(-1, sorted_logits.kind())
.cumsum(-1, sorted_logits.kind());
// 1 where the cumulative probability has reached top_p (candidate for removal)
let mut sorted_indices_to_remove =
cumulative_probabilities.ge(top_p).to_kind(Kind::Int64);
if min_tokens_to_keep > 1 {
let _ = sorted_indices_to_remove.index_fill_(
1,
&Tensor::arange_start(
0,
min_tokens_to_keep + 1,
(Kind::Int64, logits.device()),
),
0,
);
}
// Shift the mask right by one position so the token crossing top_p is kept;
// the source is copied to avoid reading from the tensor being written
let _ = sorted_indices_to_remove.index_copy_(
1,
&Tensor::arange_start(1, vocab_size, (Kind::Int64, logits.device())),
&sorted_indices_to_remove
.slice(1, 0, vocab_size - 1, 1)
.copy(),
);
// Always keep the most likely token
let _ = sorted_indices_to_remove.index_fill_(
1,
&Tensor::from_slice(&[0])
.to_kind(Kind::Int64)
.to_device(sorted_indices_to_remove.device()),
0,
);
// Map the mask from sorted order back to vocabulary order
let indices_to_remove = sorted_indices_to_remove
.scatter(1, &sorted_indices, &sorted_indices_to_remove)
.to_kind(Kind::Bool);
let _ = logits.masked_fill_(&indices_to_remove, f64::NEG_INFINITY);
}
}
/// Applies the Hamming diversity penalty of diverse beam search to `scores`, in place.
///
/// For each batch entry, counts how often each token was picked by the beams of the groups
/// preceding the current one (`current_tokens[batch * num_beams ..
/// batch * num_beams + group_start_index]`). It then subtracts `diversity_penalty` times that
/// count from the scores of the current group's `group_size` rows. The first group
/// (`group_start_index == 0`) is left untouched.
fn run_hamming_diversity_penalty(
    &self,
    scores: &mut Tensor,
    current_tokens: &Tensor,
    diversity_penalty: f64,
    num_beams: i64,
    batch_size: i64,
    group_size: i64,
    group_start_index: i64,
) {
    // No earlier group to diverge from.
    if group_start_index <= 0 {
        return;
    }
    let vocab_size = *scores.size().last().unwrap();
    for batch_index in 0..batch_size {
        let beam_offset = batch_index * num_beams;
        let token_frequency = current_tokens
            .slice(0, beam_offset, beam_offset + group_start_index, 1)
            .bincount::<Tensor>(None, vocab_size);
        let penalty = token_frequency * diversity_penalty;

        let group_offset = batch_index * group_size;
        let mut group_scores = scores.slice(0, group_offset, group_offset + group_size, 1);
        let _ = group_scores.subtract_(&penalty);
    }
}
/// Restricts `scores` to the tokens allowed by `prefix_allowed_tokens_fn`, in place.
///
/// For every row the callback receives the batch index (`row / num_beams`) and the row's
/// token ids. Tokens it returns keep their score; all others are shifted to negative infinity
/// by subtracting a `+inf` mask.
fn apply_prefix_allowed_tokens_function(
    &self,
    prefix_allowed_tokens_fn: &dyn Fn(i64, &Tensor) -> Vec<i64>,
    num_beams: i64,
    input_ids: &Tensor,
    scores: &mut Tensor,
) {
    let (kind, device) = (scores.kind(), scores.device());
    // Start from "everything forbidden" (+inf), then open up the allowed tokens (0).
    let forbidden = scores.new_full(
        scores.size().as_slice(),
        get_positive_infinity(kind).unwrap(),
        (kind, device),
    );
    let num_rows = scores.size()[0];
    for row in 0..num_rows {
        let allowed = prefix_allowed_tokens_fn(row / num_beams, &input_ids.get(row));
        let allowed_index = Tensor::from_slice(allowed.as_slice()).to(device);
        let _ = forbidden.get(row).index_fill_(0, &allowed_index, 0);
    }
    let _ = scores.subtract_(&forbidden);
}
/// Splits bad words into single-token ids and multi-token sequences.
///
/// Returns `(single_token_ids, multi_token_words)`. Each side is `None` when it would be empty
/// (or when `bad_word_ids` is `None`). Entries whose length is not exactly 1 go to the second
/// list.
fn split_bad_word_ids<'a>(
    &self,
    bad_word_ids: Option<&'a Vec<Vec<i64>>>,
) -> (Option<Vec<i64>>, Option<Vec<&'a Vec<i64>>>) {
    let bad_word_ids = match bad_word_ids {
        Some(ids) => ids,
        None => return (None, None),
    };
    let (single_token_words, multi_token_words): (Vec<&'a Vec<i64>>, Vec<&'a Vec<i64>>) =
        bad_word_ids
            .iter()
            .partition(|bad_word| bad_word.len() == 1);
    let single_token_ids: Vec<i64> = single_token_words
        .into_iter()
        .map(|bad_word| bad_word[0])
        .collect();
    (
        (!single_token_ids.is_empty()).then_some(single_token_ids),
        (!multi_token_words.is_empty()).then_some(multi_token_words),
    )
}
/// Returns whether `prev_tokens` ends with `tokens`. An empty `tokens` always matches, and a
/// `tokens` longer than `prev_tokens` never does.
fn tokens_match(&self, prev_tokens: &[i64], tokens: &[i64]) -> bool {
    prev_tokens.ends_with(tokens)
}
/// Builds a boolean mask of shape `[1, vocab_size]` (vocabulary size taken from
/// `scores.size()[1]`) that is `true` at every single-token bad word id.
fn calc_static_bad_word_mask(
    &self,
    scores: &Tensor,
    bad_words_id_length_1: &[i64],
) -> Tensor {
    let device = scores.device();
    let vocab_size = scores.size()[1];
    let banned_index = Tensor::from_slice(bad_words_id_length_1).to_device(device);
    let mut mask = Tensor::zeros([vocab_size], (Kind::Int8, device));
    let _ = mask.index_fill_(0, &banned_index, 1);
    mask.unsqueeze(0).totype(Kind::Bool)
}
/// Returns, for each previous-token sequence, the final tokens of the multi-token bad words
/// whose leading tokens (all but the last) match the end of that sequence.
fn get_dynamic_bad_word_ids(
    &self,
    prev_tokens: &[Vec<i64>],
    bad_word_ids_length_greater_than_1: &[&Vec<i64>],
) -> Vec<Vec<i64>> {
    prev_tokens
        .iter()
        .map(|history| {
            bad_word_ids_length_greater_than_1
                .iter()
                .filter(|bad_word| {
                    let prefix_len = bad_word.len() - 1;
                    self.tokens_match(history, &bad_word[..prefix_len])
                })
                .map(|bad_word| *bad_word.last().unwrap())
                .collect::<Vec<i64>>()
        })
        .collect()
}
fn ban_bad_words(
&self,
dynamic_bad_words: Option<&Vec<&Vec<i64>>>,
static_bad_words_mask: Option<&Tensor>,
token_ids: &Tensor,
scores: &mut Tensor,
) {
let longest_bad_word = dynamic_bad_words
.iter()
.map(|bad_word| bad_word.len())
.max()
.unwrap() as i64;
let last_token_ids = token_ids.slice(1, -longest_bad_word, None, 1);
let mut prev_tokens = Vec::new();
for sequence_idx in 0..token_ids.size()[0] {
prev_tokens.push(
last_token_ids
.get(sequence_idx)
.iter::<i64>()
.unwrap()
.collect::<Vec<i64>>(),
)
}
let dynamic_bad_words_mask = if let Some(dynamic_bad_words) = dynamic_bad_words {
let dynamic_banned_tokens =
self.get_dynamic_bad_word_ids(&prev_tokens, dynamic_bad_words);
let dynamic_banned_mask =
Tensor::zeros(scores.size().as_slice(), (Kind::Int, scores.device()));
for (sequence_index, sequence_ban_tokens) in
dynamic_banned_tokens.iter().enumerate()
{
if !sequence_ban_tokens.is_empty() {
let _ = dynamic_banned_mask.get(sequence_index as i64).index_fill_(
0,
&Tensor::from_slice(sequence_ban_tokens).to_device(scores.device()),
1,
);
}
}
Some(dynamic_banned_mask.to_kind(Kind::Bool))
} else {
None
};
let combined_bad_word_mask = {
if let (Some(static_mask), Some(dynamic_mask)) =
(static_bad_words_mask, &dynamic_bad_words_mask)
{
Some(static_mask.bitwise_or_tensor(dynamic_mask))
} else {
None
}
};
let bad_word_mask = if combined_bad_word_mask.is_some() {
combined_bad_word_mask.as_ref()
} else if static_bad_words_mask.is_some() {
static_bad_words_mask
} else if dynamic_bad_words_mask.is_some() {
dynamic_bad_words_mask.as_ref()
} else {
None
};
if let Some(bad_word_mask) = bad_word_mask {
let _ = scores.masked_fill_(bad_word_mask, f64::NEG_INFINITY);
}
}
fn generate_no_beam_search(
&self,
input_ids: Tensor,
encoder_outputs: Option<Tensor>,
cur_len: i64,
batch_size: i64,
attention_mask: Tensor,
gen_opt: InternalGenerateOptions,
prefix_allowed_tokens_fn: Option<PrefixAllowedFunction>,
output_scores: bool,
) -> GeneratedOutputWithScores {
let mut unfinished_sentences =
Tensor::ones([batch_size], (Kind::Int64, self.get_device()));
let mut sentence_lengths: Tensor =
Tensor::ones([batch_size], (Kind::Int64, self.get_device()));
let (bad_word_ids_length_1, bad_word_ids_length_greater_than_1) =
self.split_bad_word_ids(gen_opt.bad_word_ids);
let mut static_bad_words_mask: Option<Tensor> = None;
let mut attention_mask = attention_mask.copy();
let mut input_ids = input_ids.copy();
let mut past: Cache = Cache::None;
let mut outputs: Tensor;
let mut current_length = cur_len;
let mut token_scores_output: Option<Vec<Tensor>> =
if output_scores { Some(vec![]) } else { None };
loop {
let prepared_input = self.prepare_inputs_for_generation(
input_ids.copy(),
encoder_outputs.as_ref(),
past,
attention_mask.copy(),
);
let temp = self
.forward_t(
prepared_input.prepared_input.as_ref(),
prepared_input.prepared_past,
prepared_input.prepared_attention_mask.as_ref(),
None,
prepared_input.prepared_position_ids.as_ref(),
None,
prepared_input.prepared_encoder_output,
prepared_input.prepared_decoder_input.as_ref(),
false,
)
.unwrap();
outputs = temp.lm_logits;
past = temp.cache;
let mut next_token_logits = outputs.select(1, -1);
// Reduce probability for repeated inputs
if gen_opt.repetition_penalty > 1f64 {
self.enforce_repetition_penalty(
&mut next_token_logits,
batch_size,
1,
&input_ids,
gen_opt.repetition_penalty,
)
}
// Get bad word_ids and set their probability to 0
if gen_opt.bad_word_ids.is_some() {
// Calculate static bad words masks if not set yet
if let Some(bad_word_ids_length_1) = &bad_word_ids_length_1 {
if static_bad_words_mask.is_none() {
static_bad_words_mask = Some(self.calc_static_bad_word_mask(
&next_token_logits,
bad_word_ids_length_1,
));
}
}
self.ban_bad_words(
bad_word_ids_length_greater_than_1.as_ref(),
static_bad_words_mask.as_ref(),
&input_ids,
&mut next_token_logits,
);
}
// Get banned tokens and set their probability to 0
if gen_opt.no_repeat_ngram_size > 0 {
let banned_tokens = self.get_banned_tokens(
&input_ids,
gen_opt.no_repeat_ngram_size,
current_length,
);
for (batch_index, index_banned_token) in
(0..banned_tokens.len() as i64).zip(banned_tokens)
{
let _ = next_token_logits.get(batch_index).index_fill_(
0,
&Tensor::from_slice(&index_banned_token)
.to_device(next_token_logits.device()),
f64::NEG_INFINITY,
);
}
}
// Apply custom prefix constraint function
if let Some(prefix_allowed_tokens_function) = prefix_allowed_tokens_fn {
self.apply_prefix_allowed_tokens_function(
prefix_allowed_tokens_function,
1,
&input_ids,
&mut next_token_logits,
)
}
// Do not allow eos token if min length is not reached
if (gen_opt.eos_token_ids.is_some()) & (current_length < gen_opt.min_length) {
let _ = next_token_logits.index_fill_(
1,
&Tensor::from_slice(gen_opt.eos_token_ids.as_ref().unwrap())
.to(next_token_logits.device()),
f64::NEG_INFINITY,
);
}
self.prepare_scores_for_generation(
&mut next_token_logits,
current_length,
gen_opt.max_length,
gen_opt.forced_bos_token_id,
);
// Top-k and top-p sampling
let next_token = if gen_opt.do_sample {
if gen_opt.temperature > 1f64 {
next_token_logits /= gen_opt.temperature;
}
self.top_k_top_p_filtering(
&mut next_token_logits,
gen_opt.top_k,
gen_opt.top_p,
1,
);
let probabilities = next_token_logits.softmax(-1, next_token_logits.kind());
probabilities.multinomial(1, false).squeeze_dim(1)
} else {
next_token_logits.argmax(-1, false)
};
if let Some(prev_scores) = token_scores_output.as_mut() {
let finished_mask = unfinished_sentences.eq(0);
prev_scores.push(
next_token_logits
.log_softmax(-1, next_token_logits.kind())
.gather(1, &next_token.reshape([-1, 1]), false)
.squeeze()
.masked_fill(&finished_mask, 0),
);
};
// Add tokens to unfinished sentences
let tokens_to_add = match &gen_opt.eos_token_ids {
Some(_) => {
next_token * &unfinished_sentences
- gen_opt.pad_token_id.unwrap() * (&unfinished_sentences - 1)
}
None => next_token,
};
input_ids = Tensor::cat(&[input_ids, tokens_to_add.unsqueeze(-1)], -1);
if gen_opt.eos_token_ids.is_some() {
for eos_token_id in gen_opt.eos_token_ids.as_ref().unwrap() {
let sentence_with_eos =
tokens_to_add.eq(*eos_token_id).to_kind(Kind::Int64);
let sentence_with_eos: Tensor = sentence_with_eos * &unfinished_sentences;
let _ = sentence_lengths.masked_fill_(
&sentence_with_eos
.to_kind(Kind::Bool)
.to_device(sentence_lengths.device()),
current_length + 1,
);
unfinished_sentences = -unfinished_sentences * (sentence_with_eos - 1);
}
if i64::try_from(unfinished_sentences.max()).unwrap() == 0 {
break;