From 9e88a9ebf3788043ec6cc0c6fa334f594a29b154 Mon Sep 17 00:00:00 2001
From: Gabriele Sarti
Date: Thu, 22 Dec 2022 14:26:33 +0100
Subject: [PATCH] Fix conflicting generation args (#155)

---
 docs/source/_static/inseq.js      | 16 +++++++---------
 inseq/models/huggingface_model.py | 12 +-----------
 2 files changed, 8 insertions(+), 20 deletions(-)

diff --git a/docs/source/_static/inseq.js b/docs/source/_static/inseq.js
index 3233cd78..e550c1fd 100644
--- a/docs/source/_static/inseq.js
+++ b/docs/source/_static/inseq.js
@@ -61,15 +61,11 @@ function resizeHtmlExamples() {
     for (const ex of examples) {
         const iframe = ex.firstElementChild;
         const zoom = iframe.getAttribute("scale")
-        const origHeight = iframe.contentWindow.document.body.scrollHeight
-        const origWidth = iframe.contentWindow.document.body.scrollWidth
-        ex.style.height = ((origHeight * zoom) + 50) + "px";
-        const frameHeight = origHeight / zoom
-        const frameWidth = origWidth / zoom
+        ex.style.height = ((iframe.contentWindow.document.body.scrollHeight * zoom) + 50) + "px";
         // add extra 50 pixels - in reality need just a bit more
-        iframe.style.height = frameHeight + "px"
+        iframe.style.height = (iframe.contentWindow.document.body.scrollHeight / zoom) + "px"
         // set the width of the iframe as the width of the iframe content
-        iframe.style.width = frameWidth + 'px';
+        iframe.style.width = (iframe.contentWindow.document.body.scrollWidth / zoom) + 'px';
         iframe.style.zoom = zoom;
         iframe.style.MozTransform = `scale(${zoom})`;
         iframe.style.WebkitTransform = `scale(${zoom})`;
@@ -83,12 +79,14 @@ function resizeHtmlExamples() {
 function onLoad() {
     addIcon();
     addCustomFooter();
+    resizeHtmlExamples();
 }
 
 window.addEventListener("load", onLoad);
 window.onresize = function() {
     var wwidth = $(window).width();
-    if(curr_width!==wwidth){
-        resizeHtmlExamples();
+    if( curr_width !== wwidth ){
+        window.location.reload();
+        curr_width = wwidth;
     }
 }
diff --git a/inseq/models/huggingface_model.py b/inseq/models/huggingface_model.py
index 8e9ce5e8..1db165d0 100644
--- a/inseq/models/huggingface_model.py
+++ b/inseq/models/huggingface_model.py
@@ -65,7 +65,6 @@ def __init__(
         attribution_method: Optional[str] = None,
         tokenizer: Union[str, PreTrainedTokenizer, None] = None,
         device: Optional[str] = None,
-        model_max_length: Optional[int] = 512,
         **kwargs,
     ) -> None:
         """
@@ -81,8 +80,6 @@ def __init__(
             attribution_method (str, optional): The attribution method to use.
                 Passing it here reduces overhead on attribute call,
                 since it is already initialized.
-            model_max_length (int, optional): The maximum length of the model. If not provided, will be inferred from
-                the model config.
             **kwargs: additional arguments for the model and the tokenizer.
         """
         super().__init__(**kwargs)
@@ -110,9 +107,7 @@ def __init__(
         if isinstance(tokenizer, PreTrainedTokenizer):
             self.tokenizer = tokenizer
         else:
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                tokenizer, *tokenizer_inputs, model_max_length=model_max_length, **tokenizer_kwargs
-            )
+            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, *tokenizer_inputs, **tokenizer_kwargs)
         if self.model.config.pad_token_id is not None:
             self.pad_token = self.tokenizer.convert_ids_to_tokens(self.model.config.pad_token_id)
             self.tokenizer.pad_token = self.pad_token
@@ -122,7 +117,6 @@ def __init__(
         if self.tokenizer.unk_token_id is None:
             self.tokenizer.unk_token_id = self.tokenizer.pad_token_id
         self.embed_scale = 1.0
-        self.model_max_length = model_max_length
         self.encoder_int_embeds = None
         self.decoder_int_embeds = None
         self.is_encoder_decoder = self.model.config.is_encoder_decoder
@@ -167,7 +161,6 @@ def generate(
         self,
         inputs: Union[TextInput, BatchEncoding],
         return_generation_output: bool = False,
-        max_new_tokens: Optional[int] = None,
         **kwargs,
     ) -> Union[List[str], Tuple[List[str], ModelOutput]]:
         """Wrapper of model.generate to handle tokenization and decoding.
@@ -186,14 +179,11 @@ def generate(
             isinstance(inputs, list) and len(inputs) > 0 and all([isinstance(x, str) for x in inputs])
         ):
             inputs = self.encode(inputs)
-        if max_new_tokens is None:
-            max_new_tokens = self.model_max_length - inputs.input_ids.shape[-1]
         inputs = inputs.to(self.device)
         generation_out = self.model.generate(
             input_ids=inputs.input_ids,
             attention_mask=inputs.attention_mask,
             return_dict_in_generate=True,
-            max_new_tokens=max_new_tokens,
             **kwargs,
         )
         texts = self.tokenizer.batch_decode(