-
Notifications
You must be signed in to change notification settings - Fork 164
/
output.py
856 lines (691 loc) · 36.8 KB
/
output.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
import os
import json
import ecco
from IPython import display as d
from ecco import util, lm_plots
import random
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.nn import functional as F
from sklearn import decomposition
from typing import Dict, Optional, List, Tuple, Union
from ecco.util import strip_tokenizer_prefix, is_partial_token
class OutputSeq:
    r"""An OutputSeq object is the result of running a language model on some input data. It contains not only the output
    sequence of words generated by the model, but also other data collected during the generation process
    that is useful to analyze the model.

    In addition to the data, the object has methods to create plots
    and visualizations of that collected data. These include:

    - [layer_predictions()](./#ecco.output.OutputSeq.layer_predictions) <br/>
        Which tokens did the model consider as the best outputs for a specific position in the sequence?
    - [rankings()](./#ecco.output.OutputSeq.rankings) <br/>
        After the model chooses an output token for a specific position, this visual looks back at the ranking
        of this token at each layer of the model when it was generated (layers assign scores to candidate output tokens,
        the higher the "probability" score, the higher the ranking of the token).
    - [rankings_watch()](./#ecco.output.OutputSeq.rankings_watch) <br />
        Shows the rankings of multiple tokens as the model scored them for a single position. For example, if the input is
        "The cat \_\_\_", we use this method to observe how the model ranked the words "is", "are", "was" as candidates
        to fill in the blank.
    - [primary_attributions()](./#ecco.output.OutputSeq.primary_attributions) <br />
        How important was each input token in calculating the output token?

    To process neuron activations, OutputSeq has methods to reduce the dimensionality and reveal underlying patterns in
    neuron firings. These are:

    - [run_nmf()](./#ecco.output.OutputSeq.run_nmf)
    """

    def __init__(self,
                 token_ids=None,
                 n_input_tokens=None,
                 tokenizer=None,
                 output_text=None,
                 tokens=None,
                 encoder_hidden_states=None,
                 decoder_hidden_states=None,
                 embedding_states=None,
                 attribution=None,
                 activations=None,
                 collect_activations_layer_nums=None,
                 attention=None,
                 model_type: str = 'mlm',
                 lm_head=None,
                 device='cpu',
                 config=None):
        """
        Args:
            token_ids: The token ids of the sequence. Dimensions: (batch, position)
            n_input_tokens: Int. The number of input tokens in the sequence.
            tokenizer: huggingface tokenizer associated with the model generating this output
            output_text: The output text generated by the model (if processed with generate())
            tokens: A list of token text. Shorthand to passing the token ids by the tokenizer.
                dimensions are (batch, position)
            encoder_hidden_states: Hidden states of the encoder, or None when the model has no encoder.
            decoder_hidden_states: Hidden states collected while decoding. The visualization methods index
                it as one entry per generated token, each entry shaped (layer, position, hidden_dimension)
                (presumably with the embedding output at layer index 0 -- confirm against the caller).
            embedding_states: Embedding states collected from the model. Stored as-is; not read by this class.
            attribution: Dict mapping an attribution method name (e.g. 'grad_x_input') to a list of
                attributions. One element per generated token. Each element gives a value for tokens
                from 0 to right before the generated token.
            activations: Dict of activations collected from model processing (merged by `run_nmf()`).
                Shape of each value is (batch, layer, neurons, position)
            collect_activations_layer_nums: The layer ids whose activations were collected. None means
                all layers were collected.
            attention: The attention tensor retrieved from the language model. Stored as
                `self.attention_values` and indexed by layer in `attention()`.
            model_type: The kind of model. Methods here recognize 'mlm', 'causal' and 'enc-dec'.
            lm_head: The trained language model head from a language model projecting a
                hidden state to an output vocabulary associated with the tokenizer.
            device: "cuda" or "cpu"
            config: The configuration dict of the language model. The explorables read its
                'tokenizer_config' entry.
        """
        self.token_ids = token_ids
        self.tokenizer = tokenizer
        self.n_input_tokens = n_input_tokens
        self.output_text = output_text
        self.tokens = tokens
        self.encoder_hidden_states = encoder_hidden_states
        self.decoder_hidden_states = decoder_hidden_states
        self.embedding_states = embedding_states
        self.attribution = attribution
        self.activations = activations
        self.collect_activations_layer_nums = collect_activations_layer_nums
        self.attention_values = attention
        self.lm_head = lm_head
        self.device = device
        self.config = config
        self.model_type = model_type
        # Directory of the installed ecco package; the explorables load HTML/JS assets from it
        self._path = os.path.dirname(ecco.__file__)

    def _get_encoder_hidden_states(self):
        """Return the encoder hidden states, falling back to the decoder hidden states when there are none."""
        return self.encoder_hidden_states if self.encoder_hidden_states is not None else self.decoder_hidden_states

    def _get_hidden_states(self) -> Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None]]:
        """
        Returns a tuple with (encoder hidden states, decoder hidden states)
        """
        return (self.encoder_hidden_states, self.decoder_hidden_states)

    def __str__(self):
        return "<LMOutput '{}' # of lm outputs: {}>".format(self.output_text, len(self._get_hidden_states()[1][-1]))

    def to(self, tensor: torch.Tensor):
        """Move `tensor` to the GPU when this output was produced with device 'cuda'; otherwise return it unchanged."""
        if self.device == 'cuda':
            return tensor.to('cuda')
        return tensor

    def explorable(self, printJson: Optional[bool] = False):
        """Render an interactive view of the first sequence in the batch, marking input vs. output tokens.

        Args:
            printJson: If True, also print the data handed to the JavaScript renderer.
        """
        tokens = []
        for idx, token in enumerate(self.tokens[0]):
            token_type = "input" if idx < self.n_input_tokens else 'output'
            tokens.append({'token': token,
                           'token_id': int(self.token_ids[0][idx]),
                           'type': token_type
                           })

        data = {
            'tokens': tokens
        }

        d.display(d.HTML(filename=os.path.join(self._path, "html", "setup.html")))

        # json.dumps (rather than the Python repr) so the payload is a valid JavaScript literal
        js = f"""
         requirejs(['basic', 'ecco'], function(basic, ecco){{
            const viz_id = basic.init()

            ecco.renderOutputSequence({{
                parentDiv: viz_id,
                data: {json.dumps(data)},
                tokenization_config: {json.dumps(self.config['tokenizer_config'])}
            }})
         }}, function (err) {{
            console.log(err);
        }})"""
        d.display(d.Javascript(js))

        if printJson:
            print(data)

    def __call__(self, position=None, **kwargs):
        """Show the attribution view for `position` if given, otherwise the primary attributions explorable."""
        if position is not None:
            self.position(position, **kwargs)
        else:
            self.primary_attributions(**kwargs)

    def position(self, position, attr_method='grad_x_input'):
        """Highlight how much each token contributed to the generated token at `position`.

        Args:
            position: Index (within the first sequence of the batch) of a generated token.
            attr_method: Key of `self.attribution` whose values are displayed.

        Raises:
            ValueError: if `position` does not point at a generated token.
        """
        # tokens / token_ids are (batch, position); like the other explorables, show the first sequence.
        seq_tokens = self.tokens[0]
        seq_token_ids = self.token_ids[0]
        last_position = len(seq_tokens) - 1

        if (position < self.n_input_tokens) or (position > last_position):
            raise ValueError("'position' should indicate a position of a generated token. "
                             "Accepted values for this sequence are between {} and {}."
                             .format(self.n_input_tokens, last_position))

        # attribution holds one entry per generated token, so re-base the position on the first output
        importance_id = position - self.n_input_tokens

        assert attr_method in self.attribution, \
            f"attr_method={attr_method} not found. Choose one of the following: {list(self.attribution.keys())}"
        attribution = self.attribution[attr_method]

        tokens = []
        for idx, token in enumerate(seq_tokens):
            token_type = "input" if idx < self.n_input_tokens else 'output'
            # Tokens at or after the generated one have no attribution; mark them with -1
            if idx < len(attribution[importance_id]):
                imp = attribution[importance_id][idx]
            else:
                imp = -1

            tokens.append({'token': token,
                           'token_id': int(seq_token_ids[idx]),
                           'type': token_type,
                           'value': str(imp)  # because json complains of floats
                           })

        data = {
            'tokens': tokens
        }

        d.display(d.HTML(filename=os.path.join(self._path, "html", "setup.html")))

        js = """
         requirejs(['basic', 'ecco'], function(basic, ecco){{
            const viz_id = basic.init()
            ecco.renderSeqHighlightPosition(viz_id, {}, {})
         }}, function (err) {{
            console.log(err);
        }})""".format(position, json.dumps(data))
        d.display(d.Javascript(js))

    def primary_attributions(self,
                             attr_method: Optional[str] = 'grad_x_input',
                             style="minimal",
                             ignore_tokens: Optional[List[int]] = None,
                             **kwargs):
        r"""
        Explorable showing primary attributions of each token generation step.
        Hovering-over or tapping an output token imposes a saliency map on other tokens
        showing their importance as features to that prediction.

        Examples:
            ```python
            import ecco
            lm = ecco.from_pretrained('distilgpt2')
            text= "The countries of the European Union are:\n1. Austria\n2. Belgium\n3. Bulgaria\n4."
            output = lm.generate(text, generate=20, do_sample=True)

            # Show primary attributions explorable
            output.primary_attributions()
            ```

            Which creates the following interactive explorable:
            ![input saliency example 1](../../img/saliency_ex_1.png)

            If we want more details on the saliency values, we can use the detailed view:

            ```python
            # Show detailed explorable
            output.primary_attributions(style="detailed")
            ```

            Which creates the following interactive explorable:

            ![input saliency example 2 - detailed](../../img/saliency_ex_2.png)

        Details:
            This view shows the Gradient * Inputs method of input saliency. The attribution values are calculated across the
            embedding dimensions, then we use the L2 norm to calculate a score for each token (from the values of its embeddings dimension)
            To get a percentage value, we normalize the scores by dividing by the sum of the attribution scores for all
            the tokens in the sequence.

        Args:
            attr_method: Key of `self.attribution` to display.
            style: "minimal" or "detailed".
            ignore_tokens: Token positions whose attribution is shown as 0. The stored
                attributions in `self.attribution` are not modified.
            **kwargs: If `printJson=True`, the data is also printed and returned.

        Returns:
            The data handed to the renderer, only when `printJson=True`.

        Raises:
            ValueError: if `style` is not "minimal" or "detailed".
        """
        # Validate up front so a bad style fails before any work or display happens
        if style not in ("minimal", "detailed"):
            raise ValueError(f"style={style!r} not supported. Choose 'minimal' or 'detailed'.")
        if ignore_tokens is None:
            ignore_tokens = []

        # The per-token 'value' field is filled from the attribution of the first generated token
        importance_id = 0

        assert attr_method in self.attribution, \
            f"attr_method={attr_method} not found. Choose one of the following: {list(self.attribution.keys())}"
        attribution = self.attribution[attr_method]

        tokens = []
        for idx, token in enumerate(self.tokens[0]):
            token_id = self.token_ids[0][idx]
            raw_token = self.tokenizer.convert_ids_to_tokens([token_id])[0]
            clean_token = self.tokenizer.decode(token_id)
            # Strip prefixes because bert decode still has ## for partials even after decode()
            clean_token = strip_tokenizer_prefix(self.config, clean_token)

            token_type = "input" if idx < self.n_input_tokens else 'output'
            if idx < len(attribution[importance_id]):
                imp = attribution[importance_id][idx]
            else:
                imp = 0

            tokens.append({'token': clean_token,
                           'token_id': int(token_id),
                           'is_partial': is_partial_token(self.config, raw_token),
                           'type': token_type,
                           'value': str(imp),  # because json complains of floats. Probably not used?
                           'position': idx
                           })

        # Convert to plain lists first and zero ignored positions on those copies,
        # so repeated calls do not permanently erase values stored in self.attribution.
        attributions = [att.tolist() for att in attribution]
        for row in attributions:
            for idx in ignore_tokens:
                row[idx] = 0

        data = {
            'tokens': tokens,
            'attributions': attributions
        }

        d.display(d.HTML(filename=os.path.join(self._path, "html", "setup.html")))

        if style == "minimal":
            js = f"""
             requirejs(['basic', 'ecco'], function(basic, ecco){{
                const viz_id = basic.init()
                console.log(viz_id)
                // ecco.interactiveTokens(viz_id, {{}})
                window.ecco[viz_id] = new ecco.MinimalHighlighter({{
                    parentDiv: viz_id,
                    data: {json.dumps(data)},
                    preset: 'viridis',
                    tokenization_config: {json.dumps(self.config['tokenizer_config'])}

                 }})

                 window.ecco[viz_id].init();
                 window.ecco[viz_id].selectFirstToken();

             }}, function (err) {{
                console.log(err);
            }})"""
        else:  # style == "detailed" (validated above)
            js = f"""
             requirejs(['basic', 'ecco'], function(basic, ecco){{
                const viz_id = basic.init()
                console.log(viz_id)
                window.ecco[viz_id] = ecco.interactiveTokens({{
                    parentDiv: viz_id,
                    data: {json.dumps(data)},
                    tokenization_config: {json.dumps(self.config['tokenizer_config'])}
                 }})

             }}, function (err) {{
                console.log(err);
            }})"""
        d.display(d.Javascript(js))

        if 'printJson' in kwargs and kwargs['printJson']:
            print(data)
            return data

    def _repr_html_(self, **kwargs):
        """Jupyter rich-display hook: render the explorable and return a placeholder string."""
        self.explorable(**kwargs)
        return '<OutputSeq>'

    def layer_predictions(self, position: int = 1, topk: Optional[int] = 10, layer: Optional[int] = None, **kwargs):
        """
        Visualization plotting the topk predicted tokens after each layer (using its hidden state).

        Example:
            ![prediction scores](../../img/layer_predictions_ex_london.png)

        Args:
            position: Index of the output token to trace, counted over the whole sequence (inputs included).
                For 'causal' models the first generated token is at `n_input_tokens`; for 'enc-dec' models
                it is at `n_input_tokens + 1`.
            topk: Number of tokens to show for each layer
            layer: None shows all layers. Can also pass an int with the layer id to show only that layer
            **kwargs: If `printJson=True`, the data is also printed and returned.

        Returns:
            A list (one entry per displayed layer) of the top-k token dicts, only when `printJson=True`.

        Raises:
            ValueError: if `position` is 0.
            NotImplementedError: if `model_type` is neither 'causal' nor 'enc-dec'.
        """
        assert self.model_type != 'mlm', "method not supported for Masked-LMs"
        _, dec_hidden_states = self._get_hidden_states()
        assert dec_hidden_states is not None, "decoder hidden states not found"

        if position == 0:
            raise ValueError(f"'position' is set to 0. There is never a hidden state associated with this position."
                             f"Possible values are 1 and above -- the position of the token of interest in the sequence")

        if self.model_type in ['enc-dec', 'causal']:
            # The position is relative. By that means, position self.n_input_tokens + 1 is the first generated token
            offset = 1 if self.model_type == 'enc-dec' else 0
            new_position = position - offset - self.n_input_tokens
            assert new_position >= 0, f"position={position} not supported, minimum is " \
                                      f"position={self.n_input_tokens + offset} for the first generated token"
            assert new_position < len(dec_hidden_states), f"position={position} not supported, maximum is " \
                                                          f"position={len(dec_hidden_states) - 1 + self.n_input_tokens + offset} " \
                                                          f"for the last generated token."
            position = new_position
        else:
            raise NotImplementedError(f"model_type={self.model_type} not supported")

        # only focus on the hidden states for that particular position (last sequence slot, every layer)
        dec_hidden_states = dec_hidden_states[position][:, -1, :]

        if layer is not None:
            # If a layer is specified, choose it only (keep a leading layer axis for the loop below).
            dec_hidden_states = dec_hidden_states[layer].unsqueeze(0)

        data = []
        # loop through layer levels
        for layer_no, h in enumerate(dec_hidden_states):
            # Use lm_head to project the layer's hidden state to output vocabulary
            logits = self.lm_head(self.to(h))
            softmax = F.softmax(logits, dim=-1)
            # Ascending order, so the top-k ids are the last k entries
            sorted_softmax = self.to(torch.argsort(softmax))

            layer_top_tokens = [self.tokenizer.decode(t) for t in sorted_softmax[-topk:]][::-1]
            layer_probs = softmax[sorted_softmax[-topk:]].cpu().detach().numpy()[::-1]

            # Package in output format
            layer_num = layer if layer is not None else layer_no
            layer_data = []
            for idx, (token, prob) in enumerate(zip(layer_top_tokens, layer_probs)):
                layer_data.append({'token': token,
                                   'prob': str(prob),
                                   'ranking': idx + 1,
                                   'layer': layer_num
                                   })

            data.append(layer_data)

        d.display(d.HTML(filename=os.path.join(self._path, "html", "setup.html")))

        # viz_id is local to the success callback, so the error callback only logs the error
        js = f"""
         requirejs(['basic', 'ecco'], function(basic, ecco){{
            const viz_id = basic.init()


            let pred = new ecco.LayerPredictions({{
                parentDiv: viz_id,
                data:{json.dumps(data)}
            }})
            pred.init()
         }}, function (err) {{
            console.log(err);
        }})"""

        d.display(d.Javascript(js))

        if 'printJson' in kwargs and kwargs['printJson']:
            print(data)
            return data

    def rankings(self, **kwargs):
        """
        Plots the rankings (across layers) of the tokens the model selected.
        Each column is a position in the sequence. Each row is a layer.

        ![Rankings watch](../../img/rankings_ex_eu_1.png)

        Args:
            **kwargs: Forwarded to `lm_plots.plot_inner_token_rankings`. If `printJson=True`,
                the ranking data is also printed and returned.
        """
        assert self.model_type != 'mlm', "method not supported for Masked-LMs"
        _, dec_hidden_states = self._get_hidden_states()
        assert dec_hidden_states is not None, "decoder hidden states not found"

        n_layers_dec = dec_hidden_states[0].shape[0]
        n_generated = len(dec_hidden_states)

        rankings = np.zeros((n_layers_dec, n_generated), dtype=np.int32)
        predicted_tokens = np.empty((n_layers_dec, n_generated), dtype='U25')
        token_found_mask = np.ones((n_layers_dec, n_generated))

        # Index in token_ids of the first generated token (shifted by one more for 'enc-dec')
        offset = self.n_input_tokens + 1 if self.model_type == 'enc-dec' else self.n_input_tokens

        # loop through tokens hidden states
        for j, token_hidden_states in enumerate(dec_hidden_states):
            # What token was sampled in this position? (the same for every layer)
            token_id = torch.tensor(self.token_ids[0][offset + j])
            token = self.tokenizer.decode([token_id])

            # Loop through layers at the last sequence slot
            for i, hidden_state in enumerate(token_hidden_states[:, -1, :]):
                # Project hidden state to vocabulary
                # (after debugging pain: ensure input is on GPU, if appropriate)
                logits = self.lm_head(self.to(hidden_state))
                # Sort by score (ascending)
                sorted_ids = torch.argsort(logits)
                # What's the index of the sampled token in the sorted list?
                r = torch.nonzero(sorted_ids == token_id).flatten()
                # subtract to get ranking (where 1 is the top scoring, because sorting was in ascending order)
                ranking = sorted_ids.shape[0] - r

                predicted_tokens[i, j] = token
                rankings[i, j] = int(ranking)
                if token_id == self.token_ids[0][j + 1]:
                    token_found_mask[i, j] = 0

        input_tokens = [repr(strip_tokenizer_prefix(self.config, t)) for t in self.tokens[0][self.n_input_tokens - 1:-1]]
        output_tokens = [repr(strip_tokenizer_prefix(self.config, t)) for t in self.tokens[0][offset:]]

        lm_plots.plot_inner_token_rankings(input_tokens,
                                           output_tokens,
                                           rankings,
                                           **kwargs)

        if 'printJson' in kwargs and kwargs['printJson']:
            data = {
                'input_tokens': input_tokens,
                'output_tokens': output_tokens,
                'rankings': rankings,
                'predicted_tokens': predicted_tokens,
                'token_found_mask': token_found_mask
            }
            print(data)
            return data

    def rankings_watch(self, watch: List[int] = None, position: int = -1, **kwargs):
        """
        Plots the rankings of the tokens whose ids are supplied in the watch list.
        Only considers one position.

        ![Rankings plot](../../img/ranking_watch_ex_is_are_1.png)

        Args:
            watch: Token ids whose rankings should be plotted.
            position: -1 for the last generated token, otherwise the position counted over the whole
                sequence (same convention as `layer_predictions()`).
            **kwargs: If `printJson=True`, the ranking data is also printed and returned.

        Raises:
            ValueError: if `watch` is None.
            NotImplementedError: if `model_type` is neither 'causal' nor 'enc-dec' and a position is given.
        """
        assert self.model_type != 'mlm', "method not supported for Masked-LMs"
        if watch is None:
            raise ValueError("'watch' must be a list of token ids to rank.")

        _, dec_hidden_states = self._get_hidden_states()
        assert dec_hidden_states is not None, "decoder hidden states not found"

        if position != -1:
            if self.model_type in ['enc-dec', 'causal']:
                # The position is relative. By that means, position self.n_input_tokens + 1 is the first generated token
                offset = 1 if self.model_type == 'enc-dec' else 0
                new_position = position - offset - self.n_input_tokens
                assert new_position >= 0, f"position={position} not supported, minimum is " \
                                          f"position={self.n_input_tokens + offset} for the first generated token"
                assert new_position < len(dec_hidden_states), f"position={position} not supported, maximum is " \
                                                              f"position={len(dec_hidden_states) - 1 + self.n_input_tokens + offset} " \
                                                              f"for the last generated token."
                position = new_position
            else:
                raise NotImplementedError(f"model_type={self.model_type} not supported")

        # (layer, hidden) states at the last sequence slot of the chosen step
        dec_hidden_states = dec_hidden_states[position][:, -1, :]

        n_layers_dec = len(dec_hidden_states)
        n_tokens_to_watch = len(watch)
        rankings = np.zeros((n_layers_dec, n_tokens_to_watch), dtype=np.int32)

        # loop through layer levels
        for i, level in enumerate(dec_hidden_states):
            # Project hidden state to vocabulary once per layer; it does not depend on the watched token
            # (after debugging pain: ensure input is on GPU, if appropriate)
            logits = self.lm_head(self.to(level))
            # Sort by score (ascending)
            sorted_ids = torch.argsort(logits)
            for j, watched_id in enumerate(watch):
                # What's the index of the watched token in the sorted list?
                r = torch.nonzero(sorted_ids == torch.tensor(watched_id)).flatten()
                # subtract to get ranking (where 1 is the top scoring, because sorting was in ascending order)
                ranking = sorted_ids.shape[0] - r
                rankings[i, j] = int(ranking)

        input_tokens = [strip_tokenizer_prefix(self.config, t) for t in self.tokens[0]]
        output_tokens = [repr(self.tokenizer.decode(t)) for t in watch]

        lm_plots.plot_inner_token_rankings_watch(input_tokens, output_tokens, rankings,
                                                 position + self.n_input_tokens if self.model_type == 'enc-dec' else position)

        if 'printJson' in kwargs and kwargs['printJson']:
            data = {'input_tokens': input_tokens,
                    'output_tokens': output_tokens,
                    'rankings': rankings}
            print(data)
            return data

    def attention(self, attention_values=None, layer=0, **kwargs):
        """Interactive view of attention values over the first sequence in the batch.

        Args:
            attention_values: Optional attention tensor to display instead of the stored one. Used as given
                (no head averaging) and indexed as attn[0][row][token position].
            layer: Index into the stored attention values, used when `attention_values` is None. The
                selected values are averaged over axis 1 (attention heads).
            **kwargs: If `printJson=True`, the data handed to the renderer is also printed.
        """
        importance_id = self.n_input_tokens - 1  # Set first values to first output token
        if attention_values is not None:
            attn = attention_values
        else:
            attn = self.attention_values[layer]
            # normalize attention heads: average over axis 1
            attn = attn.sum(axis=1) / attn.shape[1]

        # tokens / token_ids are (batch, position); show the first sequence like the other explorables
        tokens = []
        for idx, token in enumerate(self.tokens[0]):
            token_type = "input" if idx < self.n_input_tokens else 'output'
            if idx < len(attn[0][importance_id]):
                attention_value = attn[0][importance_id][idx].cpu().detach().numpy()
            else:
                attention_value = 0

            tokens.append({'token': token,
                           'token_id': int(self.token_ids[0][idx]),
                           'type': token_type,
                           'value': str(attention_value),  # because json complains of floats
                           'position': idx
                           })

        data = {
            'tokens': tokens,
            'attributions': [att.tolist() for att in attn[0].cpu().detach().numpy()]
        }

        d.display(d.HTML(filename=os.path.join(self._path, "html", "setup.html")))

        js = """
         requirejs(['basic', 'ecco'], function(basic, ecco){{
            const viz_id = basic.init()
            ecco.interactiveTokens(viz_id, {})
         }}, function (err) {{
            console.log(err);
        }})""".format(json.dumps(data))
        d.display(d.Javascript(js))

        if 'printJson' in kwargs and kwargs['printJson']:
            print(data)

    def run_nmf(self, **kwargs):
        """
        Run Non-negative Matrix Factorization on network activations of FFNN. Returns an [NMF]() object which holds
        the factorization model and data and methods to visualize them.

        Args:
            **kwargs: Forwarded to `NMF` (e.g. `n_components`, `from_layer`, `to_layer`).
        """
        return NMF(self.activations,
                   n_input_tokens=self.n_input_tokens,
                   token_ids=self.token_ids,
                   _path=self._path,
                   tokens=self.tokens,
                   config=self.config,
                   collect_activations_layer_nums=self.collect_activations_layer_nums,
                   **kwargs)
class NMF:
    """ Conducts NMF and holds the models and components """

    def __init__(self, activations: Dict[str, np.ndarray],
                 n_input_tokens: int = 0,
                 token_ids: torch.Tensor = torch.Tensor(0),
                 _path: str = '',
                 n_components: int = 10,
                 from_layer: Optional[int] = None,
                 to_layer: Optional[int] = None,
                 tokens: Optional[List[str]] = None,
                 collect_activations_layer_nums: Optional[List[int]] = None,
                 config=None,
                 **kwargs):
        """
        Receives a neuron activations tensor from OutputSeq and decomposes it using NMF into the number
        of components specified by `n_components`. For example, a model like `distilgpt2` has 18,000+
        neurons. Using NMF to reduce them to 32 components can reveal interesting underlying firing
        patterns.

        Args:
            activations: Dict of activations arrays (e.g. encoder and decoder). Each value has
                dimensions (batch, layer, neuron, position); values are joined along the last axis.
            n_input_tokens: Number of input tokens.
            token_ids: List of tokens ids.
            _path: Disk path to find the javascript that creates interactive explorables
            n_components: Number of components/factors to reduce the neuron factors to.
                Capped at the number of positions.
            from_layer: First layer id (inclusive) whose activations are used.
            to_layer: Last layer id (exclusive) whose activations are used.
            tokens: The text of each token, per sequence in the batch.
            collect_activations_layer_nums: The list of layer ids whose activations were collected. If
                None, then all layers were collected.
            config: Model configuration dict; `explore()` reads its 'tokenizer_config' entry.
            **kwargs: Ignored.

        Raises:
            ValueError: if no activations were supplied.
        """
        # An empty dict/list/None means activations were never collected
        if not activations:
            raise ValueError(f"No activation data found. Make sure 'activations=True' was passed to "
                             f"ecco.from_pretrained().")

        self._path = _path
        self.token_ids = token_ids
        self.n_input_tokens = n_input_tokens
        self.config = config
        self.tokens = tokens

        # Joining Encoder and Decoder (if exists) together along the last (position) axis
        activations = np.concatenate(list(activations.values()), axis=-1)

        # 'merged_act' is ( neuron (and layer), position (and batch) )
        merged_act = self.reshape_activations(activations,
                                              from_layer,
                                              to_layer,
                                              collect_activations_layer_nums)

        n_positions = merged_act.shape[-1]
        # Never ask for more components than there are positions
        n_components = min(n_components, n_positions)

        # Get rid of negative activation values
        # (There are some, because GPT2 uses GELU, which allow small negative values)
        # Transposed to (positions, neurons) so each position is one NMF sample.
        self.activations = np.maximum(merged_act, 0).T

        self.model = decomposition.NMF(n_components=n_components,
                                       init='random',
                                       random_state=0,
                                       max_iter=500)
        # fit_transform returns (positions, n_components); store as (n_components, positions)
        self.components = self.model.fit_transform(self.activations).T

    @staticmethod
    def reshape_activations(activations,
                            from_layer: Optional[int] = None,
                            to_layer: Optional[int] = None,
                            collect_activations_layer_nums: Optional[List[int]] = None):
        """Prepares the activations tensor for NMF by reshaping it from four dimensions
        (batch, layer, neuron, position) down to two:
        ( neuron (and layer), position (and batch) ).

        Args:
            activations (tensor): activations tensors of shape (batch, layers, neurons, positions) and float values
            from_layer (int or None): Start value (inclusive). Used to indicate a range of layers whose
                activations are to be processed. Defaults to 0 when only `to_layer` is given.
            to_layer (int or None): End value (exclusive). Used to indicate a range of layers. Defaults to
                one past the highest collected layer id when only `from_layer` is given.
            collect_activations_layer_nums (list of ints or None): A list of layer IDs. Used to indicate specific
                layers whose activations are to be processed

        Raises:
            ValueError: on a non-4D input, an empty/inverted layer range, or a range that includes
                layers without recorded activations.
        """
        if len(activations.shape) != 4:
            raise ValueError(f"The 'activations' parameter should have four dimensions: "
                             f"(batch, layers, neurons, positions). "
                             f"Supplied dimensions: {activations.shape}", 'activations')

        if collect_activations_layer_nums is None:
            collect_activations_layer_nums = list(range(activations.shape[1]))

        # Map each collected layer id to its row along the layer axis (dim 1) of `activations`
        layer_nums_to_row_ixs = {layer_num: i
                                 for i, layer_num in enumerate(collect_activations_layer_nums)}

        if from_layer is not None or to_layer is not None:
            from_layer = from_layer if from_layer is not None else 0
            # Default end is one past the highest collected layer id. (Using activations.shape[0]
            # here would pick the batch size, not the number of layers.)
            to_layer = to_layer if to_layer is not None else max(layer_nums_to_row_ixs, default=-1) + 1

            if from_layer == to_layer:
                raise ValueError(f"from_layer ({from_layer}) and to_layer ({to_layer}) cannot be the same value. "
                                 "They must be apart by at least one to allow for a layer of activations.")

            if from_layer > to_layer:
                raise ValueError(f"from_layer ({from_layer}) cannot be larger than to_layer ({to_layer}).")

            layer_nums = list(range(from_layer, to_layer))
        else:
            layer_nums = sorted(layer_nums_to_row_ixs.keys())

        if any(num not in layer_nums_to_row_ixs for num in layer_nums):
            available = sorted(layer_nums_to_row_ixs.keys())
            raise ValueError(f"Not all layers between from_layer ({from_layer}) and to_layer ({to_layer}) "
                             f"have recorded activations. Layers with recorded activations are: {available}")

        row_ixs = [layer_nums_to_row_ixs[layer_num] for layer_num in layer_nums]
        activation_rows = [activations[:, row_ix] for row_ix in row_ixs]

        # Merge 'layers' and 'neuron' dimensions. Sending activations down from
        # (batch, layer, neuron, position) to (batch, neuron (and layer), position)
        merged_act = np.concatenate(activation_rows, axis=1)
        # 'merged_act' is now (neuron (and layer), batch, position)
        merged_act = merged_act.swapaxes(0, 1)
        # Flatten batch and position into one axis: (neuron (and layer), position (and batch))
        merged_act = merged_act.reshape(merged_act.shape[0], -1)

        return merged_act

    def explore(self, input_sequence: int = 0, **kwargs):
        """
        Show interactive explorable for a single sequence with sparklines to isolate factors.

        Example:
            ![NMF Example](../../img/nmf_ex_1.png)

        Args:
            input_sequence: Which sequence in the batch to show.
            **kwargs: If `printJson=True`, the data is also printed and returned.

        Returns:
            The data handed to the renderer, only when `printJson=True`.
        """
        tokens = []

        for idx, token in enumerate(self.tokens[input_sequence]):
            token_type = "input" if idx < self.n_input_tokens else 'output'

            tokens.append({'token': token,
                           'token_id': int(self.token_ids[input_sequence][idx]),
                           'type': token_type,
                           'position': idx
                           })

        # If the sequence contains both input and generated tokens:
        # Duplicate the factor at index 'n_input_tokens'. This way
        # each token has an activation value (instead of having one activation less than tokens)
        # But with different meanings: For inputs, the activation is a response
        #                              For outputs, the activation is a cause
        if len(self.token_ids[input_sequence]) != self.n_input_tokens:
            # Case: Generation. Duplicate value of last input token.
            factors = np.array(
                [np.concatenate([comp[:self.n_input_tokens], comp[self.n_input_tokens - 1:]]) for comp in
                 self.components])
            factors = [comp.tolist() for comp in factors]  # the json conversion needs this
        else:
            # Case: no generation
            factors = [comp.tolist() for comp in self.components]  # the json conversion needs this

        data = {
            # A list of dicts. Each in the shape {
            # Example: [{'token': 'by', 'token_id': 2011, 'type': 'input', 'position': 235}]
            'tokens': tokens,
            # Three-dimensional list. Shape: (1, factors, sequence length)
            'factors': [factors]
        }

        d.display(d.HTML(filename=os.path.join(self._path, "html", "setup.html")))

        # json.dumps (rather than the Python repr) so the payload is a valid JavaScript literal
        js = f"""
         requirejs(['basic', 'ecco'], function(basic, ecco){{
            const viz_id = basic.init()
            ecco.interactiveTokensAndFactorSparklines(viz_id, {json.dumps(data)},
            {{
            'hltrCFG': {{'tokenization_config': {json.dumps(self.config['tokenizer_config'])}
                }}
            }})
         }}, function (err) {{
            console.log(err);
        }})"""
        d.display(d.Javascript(js))

        if 'printJson' in kwargs and kwargs['printJson']:
            print(data)
            return data

    def plot(self, n_components=3):
        """Line-plot the first `n_components` NMF components across positions with matplotlib.

        Args:
            n_components: How many components to draw.
        """
        # self.components is (n_components, positions); put positions on the x axis
        comp = self.components[:n_components, :].T
        n_positions, n_plotted = comp.shape

        fig, ax1 = plt.subplots(1)
        plt.subplots_adjust(wspace=.4)
        fig.set_figheight(2)
        fig.set_figwidth(17)

        ax1.plot(comp)
        ax1.set_xticks(range(n_positions))
        # tokens is (batch, position); use the first sequence's tokens as labels only when the counts match
        labels = self.tokens[0] if self.tokens else []
        if len(labels) == n_positions:
            ax1.set_xticklabels(labels, rotation=-90)
        ax1.legend(['Component {}'.format(i + 1) for i in range(n_plotted)], loc='center left',
                   bbox_to_anchor=(1.01, 0.5))

        plt.show()