Minor fixes
gsarti committed Aug 25, 2023
1 parent b6b8b13 commit a8ec7b6
Showing 2 changed files with 6 additions and 2 deletions.
inseq/data/viz.py: 5 changes (4 additions, 1 deletion)
@@ -102,14 +102,17 @@ def show_attributions(
                 display(HTML(curr_html))
             html_out += curr_html
         if not isnotebook():
-            curr_color = colors[idx]
+            curr_color = None
             if attribution.source_attributions is not None:
+                curr_color = colors[idx]
                 if display:
                     print("\n\n")
                     rprint(get_heatmap_type(attribution, curr_color, "Source", use_html=False))
             if attribution.target_attributions is not None:
                 curr_color = colors[idx + 1]
             if attribution.target_attributions is not None and display:
+                if curr_color is None and colors:
+                    curr_color = colors[idx]
                 print("\n\n")
                 rprint(get_heatmap_type(attribution, curr_color, "Target", use_html=False))
         if any(x is None for x in [attribution.source_attributions, attribution.target_attributions]):
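In short, the console rendering path now starts with curr_color unset and only draws a palette color once the corresponding attribution matrix is known to exist, with a fallback pick for target-only outputs. A minimal usage sketch that exercises this branch, assuming the public inseq API (load_model, attribute, show) with an illustrative model name, run from a plain script rather than a notebook:

    import inseq

    # Load a supported model with a gradient-based attribution method
    # (model name is illustrative).
    model = inseq.load_model("Helsinki-NLP/opus-mt-en-it", "saliency")

    # Outside a notebook, .show() falls back to rich-based console heatmaps,
    # where the per-example color is now picked only when the corresponding
    # source/target attributions are actually present.
    out = model.attribute("Hello world, here is the Inseq library!")
    out.show()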
inseq/models/huggingface_model.py: 3 changes (2 additions, 1 deletion)
@@ -191,6 +191,7 @@ def generate(
         self,
         inputs: Union[TextInput, BatchEncoding],
         return_generation_output: bool = False,
+        skip_special_tokens: bool = True,
         **kwargs,
     ) -> Union[List[str], Tuple[List[str], ModelOutput]]:
         """Wrapper of model.generate to handle tokenization and decoding.
@@ -216,7 +217,7 @@ def generate(
             **kwargs,
         )
         sequences = generation_out.sequences
-        texts = self.decode(ids=sequences, skip_special_tokens=True)
+        texts = self.decode(ids=sequences, skip_special_tokens=skip_special_tokens)
         if return_generation_output:
             return texts, generation_out
         return texts
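The generate wrapper now exposes skip_special_tokens instead of hard-coding it to True when decoding. A brief usage sketch, assuming a model wrapped through inseq.load_model (model name and prompt are illustrative; extra keyword arguments are forwarded to the underlying Hugging Face generate):

    import inseq

    # Any Hugging Face model supported by inseq; "gpt2" is an illustrative choice.
    model = inseq.load_model("gpt2", "saliency")

    # Default behaviour is unchanged: special tokens are stripped from the decoded text.
    clean = model.generate("Hello world", max_new_tokens=10)

    # Passing skip_special_tokens=False keeps markers such as </s> or <pad>
    # in the returned generations, which can help when inspecting raw outputs.
    raw = model.generate("Hello world", max_new_tokens=10, skip_special_tokens=False)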
