🐛 Fix special tokens for PLL MLM scores in TSE
jumelet committed Sep 28, 2022
1 parent 4a47a52 commit 3e8bf30
Showing 5 changed files with 82 additions and 25 deletions.
15 changes: 15 additions & 0 deletions diagnnose/activations/selection_funcs.py
@@ -2,6 +2,7 @@
 from typing import Iterable, List
 
 from torchtext.data import Example
+from transformers import PreTrainedTokenizer
 
 from diagnnose.typedefs.activations import SelectionFunc
 
@@ -42,6 +43,20 @@ def selection_func(w_idx: int, item: Example) -> bool:
     return selection_func
 
 
+def no_special_tokens(
+    tokenizer: PreTrainedTokenizer, sen_column: str = "sen"
+) -> SelectionFunc:
+    def selection_func(w_idx: int, item: Example) -> bool:
+        sen = getattr(item, sen_column)
+
+        try:
+            return sen[w_idx] not in tokenizer.all_special_tokens
+        except IndexError:
+            raise
+
+    return selection_func
+
+
 def first_n(n: int) -> SelectionFunc:
     """Wrapper that creates a selection_func that only returns True for
     the first `n` items of a corpus.
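
For intuition, a minimal usage sketch of the new selection function. The tokenizer choice and the toy corpus item below are assumptions for illustration, not part of the commit:

    # Hypothetical usage of no_special_tokens; tokenizer and item are invented.
    from types import SimpleNamespace

    from transformers import AutoTokenizer

    from diagnnose.activations.selection_funcs import no_special_tokens

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    selection_func = no_special_tokens(tokenizer)

    # SimpleNamespace stands in for a torchtext Example with a `sen` attribute.
    item = SimpleNamespace(sen=["[CLS]", "the", "dog", "barks", "[SEP]"])

    print([selection_func(w_idx, item) for w_idx in range(len(item.sen))])
    # [False, True, True, True, False] -- [CLS] and [SEP] are filtered out
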
7 changes: 6 additions & 1 deletion diagnnose/corpus/corpus.py
@@ -212,11 +212,16 @@ def preprocess(text: Union[str, List[str]]) -> List[str]:
         This allows us to still have access to the original tokens,
         including those that will be mapped to <unk> later.
+        We cast the encoded text back to tokens for debugging purposes,
+        making it easier to inspect an example at a later stage.
         """
         if isinstance(text, list):
             text = " ".join(text)
 
-        return tokenizer.tokenize(text)
+        return tokenizer.convert_ids_to_tokens(
+            tokenizer.encode(text, add_special_tokens=True)
+        )
 
     field.preprocessing = preprocess
     field.pad_token = tokenizer.pad_token
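
The practical effect is that special tokens now appear in the preprocessed token list, so downstream selection functions can see (and filter) them. A hedged illustration with a stock HuggingFace tokenizer; the model name is an assumption:

    # Illustrative comparison of the old and new preprocessing behaviour.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # Old behaviour: tokenize() returns no special tokens.
    print(tokenizer.tokenize("the dog barks"))
    # ['the', 'dog', 'barks']

    # New behaviour: encode + convert_ids_to_tokens keeps [CLS]/[SEP].
    print(tokenizer.convert_ids_to_tokens(
        tokenizer.encode("the dog barks", add_special_tokens=True)
    ))
    # ['[CLS]', 'the', 'dog', 'barks', '[SEP]']
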
20 changes: 15 additions & 5 deletions diagnnose/models/transformer_lm.py
@@ -195,22 +195,32 @@ def _forward_pseudo_ll(
     ) -> Tensor:
         mask_embedding = self.embeddings(torch.tensor(mask_idx, device=self.device))
 
-        sen_len = inputs_embeds.shape[1]
+        max_sen_len = inputs_embeds.shape[1]
 
         pseudo_ll_logits = torch.zeros(
             *inputs_embeds.shape[:2], self.nhid(activation_name), device=self.device
         )
 
-        for w_idx in range(sen_len):
+        sen_column = batch.dataset.sen_column
+        sen_lens = getattr(batch, sen_column)[1]
+        for w_idx in range(max_sen_len):
             if selection_func is not None:
                 sen_ids = []
-                for batch_idx, sen_idx in enumerate(batch.sen_idx):
-                    if selection_func(w_idx, batch.dataset.examples[sen_idx]):
+                for batch_idx, (sen_idx, sen_len) in enumerate(
+                    zip(batch.sen_idx, sen_lens)
+                ):
+                    if (w_idx < sen_len) and selection_func(
+                        w_idx, batch.dataset.examples[sen_idx]
+                    ):
                         sen_ids.append(batch_idx)
                 if len(sen_ids) == 0:
                     continue
             else:
-                sen_ids = slice(0, None)
+                sen_ids = [
+                    batch_idx
+                    for batch_idx, sen_len in enumerate(sen_lens)
+                    if w_idx < sen_len
+                ]
 
             masked_inputs_embeds = inputs_embeds[sen_ids].clone()
             masked_inputs_embeds[:, w_idx] = mask_embedding
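
The guard `w_idx < sen_len` matters because sentences in a batch are padded to the longest one: without it, pad positions of shorter sentences would be masked and scored too. A self-contained sketch of the pseudo-log-likelihood idea using a plain HuggingFace masked LM (a simplification under assumed model choice, not diagnnose's batched implementation):

    # Minimal single-sentence PLL sketch; the model name is an assumption.
    import torch
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased").eval()

    def pseudo_log_likelihood(sentence: str) -> float:
        ids = tokenizer.encode(sentence, return_tensors="pt")[0]
        total = 0.0
        # Mask one real token at a time; positions 0 and -1 are [CLS]/[SEP].
        for w_idx in range(1, len(ids) - 1):
            masked = ids.clone()
            masked[w_idx] = tokenizer.mask_token_id
            with torch.no_grad():
                logits = model(masked.unsqueeze(0)).logits[0, w_idx]
            total += logits.log_softmax(-1)[ids[w_idx]].item()
        return total

    print(pseudo_log_likelihood("the dog barks"))
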
49 changes: 35 additions & 14 deletions diagnnose/syntax/task.py
@@ -9,6 +9,7 @@
 
 from diagnnose.activations.selection_funcs import (
     final_token,
+    no_special_tokens,
     only_mask_token,
     return_all,
 )
@@ -123,7 +124,7 @@ def run(self) -> Tuple[AccuracyDict, ScoresDict]:
 
     def _run_corpus(self, corpus: Corpus) -> pd.DataFrame:
         if self.compare_full_sen:
-            selection_func = return_all
+            selection_func = no_special_tokens(self.tokenizer)
         elif self.model.is_causal:
             selection_func = final_token("sen")
         else:
@@ -139,16 +140,20 @@ def _run_corpus(self, corpus: Corpus) -> pd.DataFrame:
 
         if "counter_sen" in corpus.fields:
             if self.compare_full_sen:
-                selection_func = return_all
+                counter_selection_func = no_special_tokens(
+                    self.tokenizer, sen_column="counter_sen"
+                )
             elif self.model.is_causal:
-                selection_func = final_token("counter_sen")
+                counter_selection_func = final_token("counter_sen")
             else:
-                selection_func = only_mask_token(
+                counter_selection_func = only_mask_token(
                     self.tokenizer.mask_token, "counter_sen"
                 )
 
             corpus.sen_column = "counter_sen"
-            counter_activations = self._calc_final_hidden(corpus, selection_func)
+            counter_activations = self._calc_final_hidden(
+                corpus, counter_selection_func
+            )
         else:
             counter_activations = None
@@ -157,6 +162,8 @@ def _run_corpus(self, corpus: Corpus) -> pd.DataFrame:
                 corpus,
                 activations,
                 counter_activations,
+                selection_func,
+                counter_selection_func,
             )
         else:
             scores_df = self._calc_scores(
@@ -221,6 +228,8 @@ def _calc_full_sen_scores(
         corpus: Corpus,
         activations: Tensor,
         counter_activations: Tensor,
+        selection_func: SelectionFunc,
+        counter_selection_func: SelectionFunc,
     ) -> pd.DataFrame:
         scores_df = pd.DataFrame(
             {
@@ -237,18 +246,30 @@ def _calc_full_sen_scores(
             corpus, batch_size=1, device=self.model.device
         )
 
-        for idx, (activation, counter_activation, item) in enumerate(
-            zip(activations, counter_activations, corpus_iterator)
+        for idx, (activation, counter_activation, batch_item, corpus_item) in enumerate(
+            zip(activations, counter_activations, corpus_iterator, corpus.examples)
         ):
-            sen_ids = item.sen[0][0]
+            sen = batch_item.sen[0].squeeze()
+            token_ids = [
+                token_idx
+                for w_idx, token_idx in enumerate(sen)
+                if selection_func(w_idx, corpus_item)
+            ]
             all_logits = self._decode(activation).log_softmax(-1)
-            logits = all_logits[range(len(sen_ids)), sen_ids]
-            scores[idx] = logits.mean()
-
-            counter_sen_ids = item.counter_sen[0][0]
+            logits = all_logits[range(len(token_ids)), token_ids]
+            scores[idx] = logits.sum()
+
+            counter_sen = batch_item.counter_sen[0].squeeze()
+            counter_token_ids = [
+                token_idx
+                for w_idx, token_idx in enumerate(counter_sen)
+                if counter_selection_func(w_idx, corpus_item)
+            ]
             all_logits = self._decode(counter_activation).log_softmax(-1)
-            counter_logits = all_logits[range(len(counter_sen_ids)), counter_sen_ids]
-            counter_scores[idx] = counter_logits.mean()
+            counter_logits = all_logits[
+                range(len(counter_token_ids)), counter_token_ids
+            ]
+            counter_scores[idx] = counter_logits.sum()
 
         scores_df["scores"] = scores
         scores_df["counter_scores"] = counter_scores
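
Two scoring changes ride along here: token positions are now filtered through the same selection functions used during the forward pass, and per-token log-probabilities are summed rather than averaged, so each score is a sentence-level pseudo-log-likelihood instead of a per-token mean. A toy sketch of the gather-and-sum pattern, with invented shapes and ids:

    # Toy illustration of the indexing pattern; values are invented.
    import torch

    all_logits = torch.randn(3, 100).log_softmax(-1)  # one row per scored position
    token_ids = [12, 7, 55]  # vocabulary ids of the three scored tokens

    # Advanced indexing pairs row i with column token_ids[i].
    logits = all_logits[range(len(token_ids)), token_ids]

    score = logits.sum()  # joint pseudo-log-likelihood, not a length-normalised mean
    print(score)
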
16 changes: 11 additions & 5 deletions scripts/config/syntax_transformer.json
@@ -1,15 +1,21 @@
 {
   "model": {
-    "transformer_type": "roberta-base",
-    "cache_dir": "/media/jaap/81b6ce8a-28e5-4eda-9c68-b13e0637cc4f/transformers",
-    "mode": "masked_lm"
+    "transformer_type": "bert-base-uncased",
+    "mode": "masked_lm",
+    "compute_pseudo_ll": true
   },
   "downstream": {
-    "tasks": ["lakretz_transformer"],
+    "tasks": ["blimp"],
     "ignore_unk": false,
     "config": {
       "lakretz_transformer": {
-        "path": "../lm_data/corpora/downstream/transformers/distilroberta-base/lakretz"
+        "path": "../lm_data/corpora/downstream/transformers/distilroberta-base/lakretz",
+        "subtasks": ["simple"]
       },
+      "blimp": {
+        "path": "../lm_data/corpora/downstream/blimp",
+        "compare_full_sen": true,
+        "subtasks": ["principle_A_domain_1_sample"]
+      }
     }
   }
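
The example config now exercises the new code path: a BERT masked LM scored with pseudo-log-likelihood on a BLiMP subtask, comparing full sentences. A generic sketch of inspecting the file (how diagnnose itself consumes it is not shown in this commit):

    # Plain-JSON inspection of the config; not a diagnnose API call.
    import json

    with open("scripts/config/syntax_transformer.json") as f:
        config = json.load(f)

    print(config["model"]["transformer_type"])  # bert-base-uncased
    print(config["downstream"]["config"]["blimp"]["compare_full_sen"])  # True
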
