Minor fixes #233

Merged 2 commits on Nov 7, 2023
3 changes: 2 additions & 1 deletion README.md
@@ -35,7 +35,7 @@ Inseq is a Pytorch-based hackable toolkit to democratize the access to common po

## Installation

-Inseq is available on PyPI and can be installed with `pip`:
+Inseq is available on PyPI and can be installed with `pip` for Python >= 3.9, <= 3.11:

```bash
# Install latest stable version
@@ -270,6 +270,7 @@ Inseq has been used in various research projects. A list of known publications t
<li> <a href="https://aclanthology.org/2023.nlp4convai-1.1/">Response Generation in Longitudinal Dialogues: Which Knowledge Representation Helps?</a> (Mousavi et al., 2023) </li>
<li> <a href="https://arxiv.org/abs/2310.01188">Quantifying the Plausibility of Context Reliance in Neural Machine Translation</a> (Sarti et al., 2023)</li>
<li> <a href="https://arxiv.org/abs/2310.12127">A Tale of Pronouns: Interpretability Informs Gender Bias Mitigation for Fairer Instruction-Tuned Machine Translation</a> (Attanasio et al., 2023)</li>
<li> <a href="https://arxiv.org/abs/2310.09820">Assessing the Reliability of Large Language Model Knowledge</a> (Wang et al., 2023)</li>
</ol>

</details>
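For reference, a minimal sketch of the constraint added to the installation note above: a pre-install check of the running interpreter against the supported 3.9-3.11 range (the `pip install inseq` command is assumed from the PyPI mention in the README).

```python
import sys

# The updated README states Inseq supports Python >= 3.9, <= 3.11.
# A quick sanity check before running `pip install inseq` (assumed package name).
if not ((3, 9) <= sys.version_info[:2] <= (3, 11)):
    raise RuntimeError(
        f"Python {sys.version_info.major}.{sys.version_info.minor} is outside the supported 3.9-3.11 range"
    )
```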
4 changes: 1 addition & 3 deletions inseq/attr/feat/internals_attribution.py
@@ -84,9 +84,7 @@ def attribute(
else:
target_attributions = None
sequence_scores["decoder_self_attentions"] = decoder_self_attentions
sequence_scores["encoder_self_attentions"] = (
encoder_self_attentions[..., -1, :].clone().permute(0, 3, 1, 2)
)
sequence_scores["encoder_self_attentions"] = encoder_self_attentions.clone().permute(0, 3, 4, 1, 2)
return MultiDimensionalFeatureAttributionStepOutput(
source_attributions=cross_attentions[..., -1, :].clone().permute(0, 3, 1, 2),
target_attributions=target_attributions,
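A minimal sketch of the shape change above, assuming the stacked encoder self-attention tensor is laid out as (batch, layers, heads, source_len, source_len) (an assumption for illustration, not stated in the diff): the old code kept only the last source position as query before reordering, while the fix preserves the full source-to-source matrix.

```python
import torch

batch, layers, heads, src_len = 2, 4, 8, 10
# Hypothetical stacked encoder self-attentions: (batch, layers, heads, src_len, src_len)
encoder_self_attentions = torch.rand(batch, layers, heads, src_len, src_len)

# Old behavior: keep only the last query position, then reorder dimensions
old = encoder_self_attentions[..., -1, :].clone().permute(0, 3, 1, 2)
print(old.shape)  # torch.Size([2, 10, 4, 8]) -> (batch, src_len, layers, heads)

# New behavior: keep the full source-to-source attention matrix
new = encoder_self_attentions.clone().permute(0, 3, 4, 1, 2)
print(new.shape)  # torch.Size([2, 10, 10, 4, 8]) -> (batch, src_len, src_len, layers, heads)
```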
1 change: 1 addition & 0 deletions inseq/attr/step_functions.py
@@ -231,6 +231,7 @@ def kl_divergence_fn(
contrast_targets=contrast_targets,
contrast_targets_alignments=contrast_targets_alignments,
return_contrastive_target_ids=False,
+return_contrastive_batch=True,
)
c_forward_output = args.attribution_model.get_forward_output(
contrast_inputs.batch, use_embeddings=args.attribution_model.is_encoder_decoder
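For context, `kl_divergence_fn` compares the model's output distribution against one computed from a contrastive batch, and the added flag requests that batch back from the input formatting step. A generic PyTorch sketch of the underlying comparison (not the actual inseq implementation; tensor shapes are illustrative):

```python
import torch
import torch.nn.functional as F

# Hypothetical next-token logits from the original and contrastive forward passes
logits = torch.randn(2, 32000)           # (batch_size, vocab_size)
contrast_logits = torch.randn(2, 32000)

# KL(P_original || P_contrastive), one score per sequence in the batch
kl = F.kl_div(
    F.log_softmax(contrast_logits, dim=-1),  # input: log-probs of the contrastive distribution
    F.log_softmax(logits, dim=-1),           # target: log-probs of the original distribution
    reduction="none",
    log_target=True,
).sum(-1)
print(kl.shape)  # torch.Size([2])
```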
9 changes: 6 additions & 3 deletions inseq/data/attribution.py
@@ -268,9 +268,12 @@ def from_step_attributions(
# that are not source-to-target (default for encoder-decoder) or target-to-target
# (default for decoder only).
remove_pad_fn = cls.get_remove_pad_fn(attr, seq_score_name)
-out_seq_scores = get_sequences_from_batched_steps(
-    [att.sequence_scores[seq_score_name] for att in attributions]
-)
+if seq_score_name.startswith("encoder"):
+    out_seq_scores = [attr.sequence_scores[seq_score_name][i, ...] for i in range(num_sequences)]
+else:
+    out_seq_scores = get_sequences_from_batched_steps(
+        [att.sequence_scores[seq_score_name] for att in attributions]
+    )
for seq_id in range(num_sequences):
seq_scores[seq_id][seq_score_name] = remove_pad_fn(out_seq_scores, sources, targets, seq_id)
for seq_id in range(num_sequences):
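A simplified illustration of the two paths above (tensor shapes are assumptions, and `torch.stack` stands in for `get_sequences_from_batched_steps`): encoder-side scores presumably do not change across generation steps, so a single step's tensor is sliced per sequence, while target-side scores are collected step by step and stacked along a trailing step dimension.

```python
import torch

num_sequences, steps, src_len = 2, 5, 10

# Encoder-side scores: one full source-to-source matrix per sequence, taken from a single step
encoder_scores = torch.rand(num_sequences, src_len, src_len)
per_sequence_enc = [encoder_scores[i, ...] for i in range(num_sequences)]
print(per_sequence_enc[0].shape)  # torch.Size([10, 10])

# Target-side scores: one tensor per generation step, stacked along a new step dimension
step_scores = [torch.rand(num_sequences, src_len) for _ in range(steps)]
stacked = torch.stack(step_scores, dim=-1)  # (num_sequences, src_len, steps)
per_sequence_tgt = [stacked[i, ...] for i in range(num_sequences)]
print(per_sequence_tgt[0].shape)  # torch.Size([10, 5])
```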
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -224,7 +224,7 @@ known-first-party = ["inseq"]
order-by-type = true

[tool.ruff.pylint]
-max-branches = 20
+max-branches = 22

[tool.ruff.pyupgrade]
keep-runtime-typing = true