diff --git a/inseq/data/viz.py b/inseq/data/viz.py
index 3f07bcda..78e3561e 100644
--- a/inseq/data/viz.py
+++ b/inseq/data/viz.py
@@ -102,14 +102,17 @@ def show_attributions(
                 display(HTML(curr_html))
             html_out += curr_html
         if not isnotebook():
-            curr_color = colors[idx]
+            curr_color = None
             if attribution.source_attributions is not None:
+                curr_color = colors[idx]
                 if display:
                     print("\n\n")
                     rprint(get_heatmap_type(attribution, curr_color, "Source", use_html=False))
             if attribution.target_attributions is not None:
                 curr_color = colors[idx + 1]
             if attribution.target_attributions is not None and display:
+                if curr_color is None and colors:
+                    curr_color = colors[idx]
                 print("\n\n")
                 rprint(get_heatmap_type(attribution, curr_color, "Target", use_html=False))
         if any(x is None for x in [attribution.source_attributions, attribution.target_attributions]):
diff --git a/inseq/models/huggingface_model.py b/inseq/models/huggingface_model.py
index dcf98160..49740436 100644
--- a/inseq/models/huggingface_model.py
+++ b/inseq/models/huggingface_model.py
@@ -191,6 +191,7 @@ def generate(
         self,
         inputs: Union[TextInput, BatchEncoding],
         return_generation_output: bool = False,
+        skip_special_tokens: bool = True,
         **kwargs,
     ) -> Union[List[str], Tuple[List[str], ModelOutput]]:
         """Wrapper of model.generate to handle tokenization and decoding.
@@ -216,7 +217,7 @@ def generate(
             **kwargs,
         )
         sequences = generation_out.sequences
-        texts = self.decode(ids=sequences, skip_special_tokens=True)
+        texts = self.decode(ids=sequences, skip_special_tokens=skip_special_tokens)
         if return_generation_output:
             return texts, generation_out
         return texts