Skip to content

Commit

Permalink
Set correct nhid for HF TF models
Browse files Browse the repository at this point in the history
  • Loading branch information
jumelet committed Apr 13, 2023
1 parent f5d5f07 commit cb09d3e
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
5 changes: 2 additions & 3 deletions diagnnose/models/import_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,11 @@ def import_model(*args, **kwargs) -> LanguageModel:

def _import_transformer_lm(*args, **kwargs):
""" Imports a Transformer LM. """
from .fairseq_lm import FairseqLM
from .huggingface_lm import HuggingfaceLM

if kwargs["transformer_type"] == "fairseq":
from .fairseq_lm import FairseqLM
return FairseqLM(*args, **kwargs)

from .huggingface_lm import HuggingfaceLM
return HuggingfaceLM(*args, **kwargs)


Expand Down
10 changes: 7 additions & 3 deletions diagnnose/models/transformer_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,10 +280,14 @@ def top_layer(self) -> int:
return -1

def nhid(self, activation_name: ActivationName) -> int:
if activation_name[1] == "out":
return self.pretrained_model.config.vocab_size
model_config = self.pretrained_model.config

return self.pretrained_model.config.hidden_size
if activation_name[1] == "out":
return model_config.vocab_size
elif hasattr(model_config, "word_embed_proj_dim"):
return model_config.word_embed_proj_dim
else:
return self.pretrained_model.config.hidden_size

@staticmethod
def activation_names() -> ActivationNames:
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,3 @@ transformers>=4.0.0
tqdm
dill
unidecode
fairseq

0 comments on commit cb09d3e

Please sign in to comment.