Attribution classes refactoring, Aggregator for postponed score aggregation #130
Conversation
Final touches missing, save/load on disk broken due to …
EDIT: thanks to …
Everything should be in place at this point; only some minimal testing for the two aggregators is still missing before the merge.
Finished tests, merging. Summary of the changes:
Description
Currently, attributions are aggregated at the token level during the step-by-step computation, providing only the aggregated per-token scores as output. Ideally, we want to make the aggregation step happen as late as possible (TBD) to enable various analyses (e.g. per-neuron attributions, different aggregation strategies) that are currently not supported.
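As an illustration of what postponing the aggregation enables, here is a minimal sketch (illustrative only, not the actual inseq API; tensor names and shapes are assumptions based on the description above): the granular attribution tensor is kept, and different aggregation strategies can be applied on demand.

```python
import torch

# Hypothetical granular attributions for one sequence:
# (source_seq_len, target_seq_len, hidden_size)
attributions = torch.randn(5, 3, 8)

# The current default strategy (sum over hidden_size, then
# L1-normalize per generated token), applied lazily instead of
# during the step-by-step computation:
per_token = attributions.sum(dim=-1)
per_token = per_token / per_token.abs().sum(dim=0, keepdim=True)

# The same raw tensor also supports analyses that aggregating at
# computation time would make impossible, e.g. per-neuron scores:
per_neuron = attributions.abs().sum(dim=(0, 1))  # shape: (hidden_size,)
```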
The changes performed are the following:
- Refactored `batch.py` to use a shared `TensorWrapper` backbone class for common operations.
- Refactored `FeatureAttributionRawStepOutput` and `FeatureAttributionStepOutput`, and the transition method `FeatureAttribution.make_attribution_output`, to account for preserving the attribution tensors.
- Adapted `FeatureAttributionSequenceOutput.from_step_attributions` to take `max_target_seq_len` step attributions of shape `batch_size, max_source_seq_len, hidden_size` and produce source-target attributions of sizes `source_seq_len, target_seq_len, hidden_size` and `target_seq_len, target_seq_len, hidden_size` respectively, where the sequence lengths are variable for every generated object (i.e. end padding is removed by truncating on token sequence length).
- Added an `Aggregator` abstract class and a `SumNormAggregator` (current strategy, will be used as default). Added an `aggregator` field in `FeatureAttributionSequenceOutput` that doesn't get picked up when saving and uses `SumNormAggregator` as default. The aggregator attached to the class is used as the default behavior to aggregate attributions when printing the object, when using `.show()`, when using `maximum`, `minimum`, etc.

Idea for the `Aggregator` design: the aggregator contains a mapping from `FeatureAttributionSequenceOutput` field names to functions that need to be applied to them. Some checks need to be performed after the full aggregation process to ensure that the `show` method will work:
- `source_attributions` have a shape corresponding to `source_tokens` x `target_tokens`;
- `target_attributions` have a shape corresponding to `target_tokens` x `target_tokens`;
- `probabilities`, if present, have a shape of `target_tokens`;
- other per-step scores, if present, have a shape of `target_tokens` (e.g. deltas).
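The mapping-based design described above could be sketched as follows (a hypothetical illustration under the assumptions in this PR description, not the code that was merged; field and class names mirror the text, the helper `sum_norm` and the plain-dict interface are made up for brevity):

```python
from typing import Callable, Dict

import torch


def sum_norm(t: torch.Tensor) -> torch.Tensor:
    """Sum over hidden_size, then L1-normalize per generated token
    (the current default strategy)."""
    scores = t.sum(dim=-1)
    return scores / scores.abs().sum(dim=0, keepdim=True)


class Aggregator:
    """Maps output field names to the functions aggregating them."""

    aggregate_fns: Dict[str, Callable[[torch.Tensor], torch.Tensor]] = {}

    @classmethod
    def aggregate(cls, fields: dict) -> dict:
        out = {
            name: cls.aggregate_fns.get(name, lambda t: t)(tensor)
            for name, tensor in fields.items()
            if tensor is not None
        }
        cls.check_consistency(out)
        return out

    @classmethod
    def check_consistency(cls, out: dict) -> None:
        # Post-aggregation checks listed above, so that show() works:
        # source_attributions: source_tokens x target_tokens
        # target_attributions: target_tokens x target_tokens
        # probabilities (and other step scores): target_tokens
        n_target = out["source_attributions"].shape[1]
        if "target_attributions" in out:
            assert out["target_attributions"].shape == (n_target, n_target)
        if "probabilities" in out:
            assert out["probabilities"].shape == (n_target,)


class SumNormAggregator(Aggregator):
    """Current strategy, used as default. Probabilities have no entry
    in the mapping and pass through unchanged."""

    aggregate_fns = {
        "source_attributions": sum_norm,
        "target_attributions": sum_norm,
    }


# Usage on hypothetical granular tensors (source=4, target=3, hidden=8):
fields = {
    "source_attributions": torch.randn(4, 3, 8),
    "target_attributions": torch.randn(3, 3, 8),
    "probabilities": torch.rand(3),
}
aggregated = SumNormAggregator.aggregate(fields)
```

Keeping the mapping as a class attribute lets alternative strategies be defined by subclassing and swapping functions per field, while the shared `check_consistency` hook enforces the shape invariants regardless of the strategy used.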