diff --git a/projects/director/director_agent.py b/projects/director/director_agent.py
index 98268468df..a216eabd3e 100644
--- a/projects/director/director_agent.py
+++ b/projects/director/director_agent.py
@@ -26,19 +26,37 @@
 import parlai.utils.logging as logging
 
 
+class ScalarLayer(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.params = nn.Parameter(torch.Tensor([1.0, 0.0]))
+
+    def forward(self, input: torch.Tensor):
+        return input * self.params[0].expand_as(input) + self.params[1].expand_as(input)
+
+
 class DirectorModel(TransformerGeneratorModel):
     """
     Director model that extends TransformerGeneratorModel and adds |V| binary
     classifier heads.
     """
 
-    def __init__(self, opt: Opt, dictionary: DictionaryAgent, **kwargs):
+    def __init__(
+        self,
+        opt: Opt,
+        dictionary: DictionaryAgent,
+        **kwargs,
+    ):
         super().__init__(opt, dictionary, **kwargs)
 
         vocabulary_size = len(dictionary)
 
         decoder_output_dim = self.decoder.out_dim
-        self.classifier_heads = nn.Linear(decoder_output_dim, vocabulary_size)
+        self.use_shared_embedding = opt.get('director_use_shared_embedding', False)
+        if self.use_shared_embedding:
+            self.classifier_heads = ScalarLayer()
+        else:
+            self.classifier_heads = nn.Linear(decoder_output_dim, vocabulary_size)
 
         self.infer_gamma = opt['train_gamma']
         if opt.get('infer_gamma') is not None:
@@ -56,6 +74,9 @@ def classifier_output(self, input: torch.Tensor):
         if self.freeze_decoder:
             input = input.detach()
 
+        if self.use_shared_embedding:
+            input = self.generator_output(input)
+
         return self.classifier_heads(input)
 
     def output(self, latent: torch.Tensor):
@@ -182,6 +203,12 @@ def add_cmdline_args(
             default=False,
             help='Train the generation head with the positive examples from the feedback data.',
         )
+        group.add_argument(
+            '--director-use-shared-embedding',
+            type=bool,
+            default=False,
+            help='Use a shared final embedding for the generator and classifier head of the director.',
+        )
         return parser
 
     def __init__(self, opt: Opt, shared=None):
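
Note (not part of the patch): a minimal standalone sketch of what the shared-embedding path computes when --director-use-shared-embedding is enabled. The classifier head collapses to an element-wise affine transform of the generator logits (ScalarLayer is copied verbatim from the diff above; the tensor shapes below are illustrative dummies, not taken from the model).

# Sketch: classifier_logits = a * generator_logits + b, with two learned scalars.
import torch
import torch.nn as nn


class ScalarLayer(nn.Module):
    def __init__(self):
        super().__init__()
        # Two scalars: a multiplicative weight (init 1.0) and an additive bias (init 0.0).
        self.params = nn.Parameter(torch.Tensor([1.0, 0.0]))

    def forward(self, input: torch.Tensor):
        return input * self.params[0].expand_as(input) + self.params[1].expand_as(input)


# Dummy generator logits with shape (batch, seq_len, |V|) chosen for illustration only.
generator_logits = torch.randn(2, 5, 100)
classifier_logits = ScalarLayer()(generator_logits)
assert classifier_logits.shape == generator_logits.shape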