♻️ Improve integration of Fairseq LMs
jumelet committed Jun 21, 2022
1 parent 4533347 commit a4dbb77
Showing 3 changed files with 9 additions and 4 deletions.
4 changes: 4 additions & 0 deletions diagnnose/models/fairseq_lm.py
@@ -83,3 +83,7 @@ def embeddings(self) -> Callable[[Tensor], Tensor]:
     @property
     def decoder(self) -> nn.Module:
         return self.pretrained_model.decoder.output_projection
+
+    @property
+    def decoder_w(self) -> nn.Module:
+        return self.decoder.weight.data
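The new decoder_w property exposes the raw weight matrix of the Fairseq decoder's output projection; note that weight.data is a plain tensor rather than an nn.Module. Below is a minimal, self-contained sketch of the pattern and how such a property is typically consumed; TinyLM, hidden_size, and vocab_size are illustrative stand-ins, not diagnnose names.

import torch
from torch import nn


class TinyLM:
    """Illustrative stand-in for a wrapped LM with a linear output projection."""

    def __init__(self, hidden_size: int = 8, vocab_size: int = 20):
        self.output_projection = nn.Linear(hidden_size, vocab_size, bias=False)

    @property
    def decoder(self) -> nn.Module:
        return self.output_projection

    @property
    def decoder_w(self) -> torch.Tensor:
        # Raw (vocab_size, hidden_size) weight matrix of the output projection.
        return self.decoder.weight.data


lm = TinyLM()
hidden = torch.randn(1, 8)
logits = hidden @ lm.decoder_w.t()  # (1, vocab_size) scores over the vocabulary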
6 changes: 2 additions & 4 deletions diagnnose/models/import_model.py
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
 from typing import Type

 from .language_model import LanguageModel
@@ -30,7 +28,7 @@ def import_model(*args, **kwargs) -> LanguageModel:
     return model


-def _import_transformer_lm(*args, **kwargs) -> "TransformerLM":
+def _import_transformer_lm(*args, **kwargs):
     """ Imports a Transformer LM. """
     from .fairseq_lm import FairseqLM
     from .huggingface_lm import HuggingfaceLM
@@ -41,7 +39,7 @@ def _import_transformer_lm(*args, **kwargs) -> "TransformerLM":
     return HuggingfaceLM(*args, **kwargs)


-def _import_recurrent_lm(*args, **kwargs) -> "RecurrentLM":
+def _import_recurrent_lm(*args, **kwargs):
     """ Imports a recurrent LM and sets the initial states. """
     from .recurrent_lm import RecurrentLM

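The dropped return annotations referred to classes that are only imported inside the helper bodies, i.e. the lazy-import dispatch shown in this file. Below is a self-contained sketch of that dispatch with stand-in classes and a made-up checkpoint heuristic; it is not the actual diagnnose logic.

class FairseqLM:
    def __init__(self, model_name: str):
        self.model_name = model_name


class HuggingfaceLM:
    def __init__(self, model_name: str):
        self.model_name = model_name


def _import_transformer_lm(model_name: str):
    """Imports a Transformer LM (sketch): use the Fairseq wrapper for local
    checkpoints, otherwise fall back to the HuggingFace wrapper."""
    if model_name.endswith(".pt"):  # assumed heuristic, not the real check
        return FairseqLM(model_name)
    return HuggingfaceLM(model_name)


model = _import_transformer_lm("checkpoint_best.pt")  # picks the Fairseq wrapper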
3 changes: 3 additions & 0 deletions diagnnose/tokenizer/create.py
@@ -88,6 +88,9 @@ def create_tokenizer(
         }
         tokenizer.vocab = vocab
         tokenizer.ids_to_tokens = {idx: w for w, idx in vocab.items()}
+    elif hasattr(tokenizer, "sym2idx"):
+        tokenizer.vocab = tokenizer.sym2idx
+        tokenizer.ids_to_tokens = tokenizer.idx2sym

     if getattr(tokenizer, "pad_token", None) is None:
         tokenizer.pad_token = tokenizer.unk_token
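The new branch normalises tokenizers that only expose sym2idx/idx2sym mappings, so downstream code can rely on vocab and ids_to_tokens regardless of the backend. A minimal sketch of that normalisation; DummyTokenizer and its toy vocabulary are illustrative, not part of diagnnose.

class DummyTokenizer:
    """Stand-in for a tokenizer that only provides sym2idx/idx2sym."""

    def __init__(self):
        self.sym2idx = {"<unk>": 0, "the": 1, "cat": 2}
        self.idx2sym = {idx: sym for sym, idx in self.sym2idx.items()}


tokenizer = DummyTokenizer()

# Same normalisation as the new elif branch in create_tokenizer.
if hasattr(tokenizer, "sym2idx"):
    tokenizer.vocab = tokenizer.sym2idx
    tokenizer.ids_to_tokens = tokenizer.idx2sym

assert tokenizer.vocab["cat"] == 2
assert tokenizer.ids_to_tokens[2] == "cat"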
