from jaxtyping import Float, Int
from typing import Tuple, Optional, List, Union, Dict
import re
import torch
from torch import Tensor
from transformer_lens import HookedTransformer
import einops
import numpy as np
from eindex import eindex
Arr = np.ndarray

def k_largest_indices(
    x: Float[Tensor, "rows cols"],
    k: int,
    largest: bool = True,
    buffer: Tuple[int, int] = (5, 5),
) -> Int[Tensor, "k 2"]:
    '''
    Given a 2D array, returns the indices of the top or bottom `k` elements.

    Also has a `buffer` argument, which makes sure we don't pick positions too close to the left/right of the
    sequence. If `buffer` is (5, 5), we aren't allowed to pick the first or last 5 sequence positions, because
    we'll need to append them to the left/right of the chosen position. We can only pick from [5:-5] in this case.
    '''
    # Restrict to the allowed window (columns buffer[0] to seq_len - buffer[1])
    x = x[:, buffer[0]: x.size(1) - buffer[1]]
    indices = x.flatten().topk(k=k, largest=largest).indices
    # Convert flat indices back to (row, col), shifting cols to undo the buffer slicing
    rows = indices // x.size(1)
    cols = indices % x.size(1) + buffer[0]
    return torch.stack((rows, cols), dim=1)

def sample_unique_indices(large_number, small_number):
    '''Samples a small number of unique indices from a large number of indices.'''
    weights = torch.ones(large_number)  # Equal weights for all indices
    sampled_indices = torch.multinomial(weights, small_number, replacement=False)
    return sampled_indices
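
# # Illustrative example: sample 3 distinct indices out of 10. The exact values vary per call,
# # but they're always 3 distinct integers in [0, 10).
# print(sample_unique_indices(10, 3))  # e.g. tensor([7, 2, 5])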

def random_range_indices(
    x: Float[Tensor, "batch seq"],
    bounds: Tuple[float, float],
    k: int,
    buffer: Tuple[int, int] = (5, 5),
) -> Int[Tensor, "k 2"]:
    '''
    Given a 2D array, returns the indices of `k` elements whose values are in the range `bounds`.
    Will return fewer than `k` values if there aren't enough values in the range.

    Also has a `buffer` argument, which makes sure we don't pick positions too close to the left/right of the sequence.
    '''
    # Limit x, because our indices (bolded words) shouldn't be too close to the left/right of the sequence
    x = x[:, buffer[0]: x.size(1) - buffer[1]]
    # Create a mask for where x is in range, and get the indices as a tensor of shape (k, 2)
    mask = (bounds[0] <= x) & (x <= bounds[1])
    indices = torch.stack(torch.where(mask), dim=-1)
    # If we have more indices than we need, randomly select k of them
    if len(indices) > k:
        indices = indices[sample_unique_indices(len(indices), k)]
    # Adjust indices to account for the buffer
    return indices + torch.tensor([0, buffer[0]]).to(indices.device)

# # Example, where it'll pick the elements from the end of this 2D tensor, working backwards
# # (buffer is set to (0, 0) here, because the tensor is too short for the default (5, 5) buffer)
# x = torch.tensor([[1, 2, 3], [4, 5, 6]])
# k = 3
# print(k_largest_indices(x, k, buffer=(0, 0)))  # Output: tensor([[1, 2], [1, 1], [1, 0]])

# # Example, where it'll pick one of (0.2, 0.3) because they're the only ones within range
# x = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5]])
# bounds = (0.15, 0.35)
# k = 1
# print(random_range_indices(x, bounds, k, buffer=(0, 0)))

def to_str_tokens(vocab_dict: Dict[int, str], tokens: Union[int, torch.Tensor]):
    '''
    If tokens is 1D, does the same thing as model.to_str_tokens.
    If tokens is 2D or 3D, it flattens, does this thing, then reshapes.
    Also, makes sure that line breaks are replaced with their repr.
    '''
    if isinstance(tokens, int):
        return vocab_dict[tokens]
    assert tokens.ndim <= 3
    # Get flattened list of tokens
    str_tokens = [vocab_dict[t] for t in tokens.flatten().tolist()]
    # Replace line breaks with things that will appear as the literal '\n' in HTML
    str_tokens = [s.replace("\n", "\\n") for s in str_tokens]
    # str_tokens = [s.replace(" ", "&nbsp;") for s in str_tokens]
    # Reshape
    return reshape(str_tokens, tokens.shape)
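
# # Illustrative example with a toy vocab dict (a real vocab_dict would come from the model's tokenizer):
# vocab = {0: "Hello", 1: " world", 2: "\n"}
# print(to_str_tokens(vocab, torch.tensor([0, 1, 2])))          # ['Hello', ' world', '\\n']
# print(to_str_tokens(vocab, torch.tensor([[0, 1], [2, 0]])))   # [['Hello', ' world'], ['\\n', 'Hello']]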

def reshape(my_list, shape):
    '''Reshapes a flat list into a nested list with the given shape (a plain-list analogue of np.reshape).'''
    assert np.prod(shape) == len(my_list), "Shape is not compatible with list size"
    assert len(shape) in [1, 2, 3], "Only shapes of length 1, 2, or 3 are supported"
    if len(shape) == 1:
        return my_list
    it = iter(my_list)
    if len(shape) == 2:
        return [[next(it) for _ in range(shape[1])] for _ in range(shape[0])]
    return [[[next(it) for _ in range(shape[2])] for _ in range(shape[1])] for _ in range(shape[0])]
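
# # Illustrative example: reshape a flat list into a 2x3 nested list
# print(reshape(list(range(6)), (2, 3)))  # [[0, 1, 2], [3, 4, 5]]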

class TopK:
    '''
    Wrapper around the object returned by torch.topk, which has the following advantages:

        > friendlier to type annotation
        > easy device moving, without having to do it separately for values & indices
        > easy indexing, without having to do it separately for values & indices
        > other classic tensor operations, like .ndim, .shape, etc. work as expected

    We initialise with a topk object, which is treated as a tuple of (values, indices).
    '''
    def __init__(self, obj: Optional[Tuple[Arr, Arr]] = None):
        if obj is None:
            # Start with empty arrays, so we can build a TopK up incrementally via `concat`
            obj = (np.array([]), np.array([]))
        self.values: Arr = obj[0] if isinstance(obj[0], Arr) else obj[0].detach().cpu().numpy()
        self.indices: Arr = obj[1] if isinstance(obj[1], Arr) else obj[1].detach().cpu().numpy()

    def __getitem__(self, item):
        return TopK((self.values[item], self.indices[item]))

    def concat(self, other: "TopK"):
        '''If self is empty, returns the other (so we can start w/ empty & concatenate consistently).'''
        if self.numel() == 0:
            return other
        else:
            return TopK((
                np.concatenate((self.values, other.values)),
                np.concatenate((self.indices, other.indices))
            ))

    @property
    def ndim(self):
        return self.values.ndim

    @property
    def shape(self):
        return self.values.shape

    @property
    def size(self):
        return self.values.size

    def numel(self):
        return self.values.size
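
# # Illustrative example: wrap the output of torch.topk, then slice values & indices together with a single index
# topk = TopK(torch.topk(torch.tensor([[0.1, 0.4, 0.3]]), k=2))
# print(topk.values)    # [[0.4 0.3]]
# print(topk.indices)   # [[1 2]]
# print(topk[0].shape)  # (2,)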

class Output:
    '''So I can type annotate the output of transformer.'''
    loss: Tensor
    logits: Tensor


def merge_lists(*lists):
    return [item for sublist in lists for item in sublist]
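
# # Illustrative example: merge several lists into one flat list, preserving order
# print(merge_lists([1, 2], [3], [4, 5]))  # [1, 2, 3, 4, 5]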

def extract_and_remove_scripts(html_content) -> Tuple[str, str]:
    '''Returns all <script>...</script> blocks (joined by newlines), plus the HTML with those blocks removed.'''
    # Pattern to find <script>...</script> tags
    pattern = r'<script[^>]*>.*?</script>'
    # Find all script tags
    scripts = re.findall(pattern, html_content, re.DOTALL)
    # Remove script tags from the original content
    html_without_scripts = re.sub(pattern, '', html_content, flags=re.DOTALL)
    return "\n".join(scripts), html_without_scripts