diff --git a/projects/director/README.md b/projects/director/README.md
index cb357fcb7a7..b26969bf0ba 100644
--- a/projects/director/README.md
+++ b/projects/director/README.md
@@ -8,6 +8,23 @@ Kushal Arora, Kurt Shuster, Sainbayar Sukhbaatar, Jason Weston
 ## Abstract
-Current language models achieve low perplexity but their resulting generations still suffer from toxic responses, repetitiveness and contradictions. The standard language modeling setup fails to address these issues. In this paper, we introduce a new architecture, {\sc Director}, that consists of a unified generator-classifier with both a language modeling and a classification head for each output token. Training is conducted jointly using both standard language modeling data, and data labeled with desirable and undesirable sequences. Experiments in several settings show that the model has competitive training and decoding speed compared to standard language models while yielding superior results, alleviating known issues while maintaining generation quality. It also outperforms existing model guiding approaches in terms of both accuracy and efficiency.
+Current language models achieve low perplexity but their resulting generations still suffer from toxic responses, repetitiveness and contradictions. The standard language modeling setup fails to address these issues. In this paper, we introduce a new architecture, DIRECTOR, that consists of a unified generator-classifier with both a language modeling and a classification head for each output token. Training is conducted jointly using both standard language modeling data, and data labeled with desirable and undesirable sequences. Experiments in several settings show that the model has competitive training and decoding speed compared to standard language models while yielding superior results, alleviating known issues while maintaining generation quality. It also outperforms existing model guiding approaches in terms of both accuracy and efficiency.
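+
+At decoding time, DIRECTOR adds the per-token classifier log-probabilities to the language-model log-probabilities and renormalizes before picking the next token (see `DirectorModel.output` in `director_agent.py` below). A minimal sketch of that combination, with illustrative names for the two heads and the mixing weight:
+
+```python
+# Sketch of DIRECTOR's decoding-time head combination; the function and argument
+# names here are illustrative, not part of the ParlAI API.
+import torch.nn.functional as F
+
+
+def director_next_token_logprobs(latent, generator_head, classifier_head, gamma):
+    lm_logprobs = F.log_softmax(generator_head(latent), dim=-1)  # log p_LM(token | context)
+    cls_logprobs = F.logsigmoid(classifier_head(latent))  # log P(token is desirable | context)
+    # gamma corresponds to --infer-gamma (defaulting to --train-gamma).
+    return F.log_softmax(lm_logprobs + gamma * cls_logprobs, dim=-1)
+```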
+
+## Safety Experiment Commands:
+
+### Train the evaluation classifier:
+```python
+parlai train -t projects.director.tasks.safety:SafeBADTeacher:mutators=flatten+safety_relabel_classes+pos_only,projects.director.tasks.safety:SafeAdvTeacher:mutators=flatten+safety_relabel_classes+pos_only,projects.director.tasks.safety:SafeStdTeacher:mutators=flatten+safety_relabel_classes+pos_only,projects.director.tasks.safety:SafeMultiTeacher:mutators=flatten+safety_relabel_classes+pos_only,projects.director.tasks.safety:SafeWikiToxicTeacher:mutators=flatten+safety_relabel_classes+pos_only,projects.director.tasks.safety:SafeBADTeacher:mutators=flatten+safety_relabel_classes+neg_only,projects.director.tasks.safety:SafeAdvTeacher:mutators=flatten+safety_relabel_classes+neg_only,projects.director.tasks.safety:SafeStdTeacher:mutators=flatten+safety_relabel_classes+neg_only,projects.director.tasks.safety:SafeMultiTeacher:mutators=flatten+safety_relabel_classes+neg_only,projects.director.tasks.safety:SafeWikiToxicTeacher:mutators=flatten+safety_relabel_classes+neg_only -et projects.director.tasks.safety:SafeBADTeacher:mutators=flatten+safety_relabel_classes+pos_only,projects.director.tasks.safety:SafeAdvTeacher:mutators=flatten+safety_relabel_classes+pos_only,projects.director.tasks.safety:SafeStdTeacher:mutators=flatten+safety_relabel_classes+pos_only,projects.director.tasks.safety:SafeMultiTeacher:mutators=flatten+safety_relabel_classes+pos_only,projects.director.tasks.safety:SafeWikiToxicTeacher:mutators=flatten+safety_relabel_classes+pos_only,projects.director.tasks.safety:SafeBADTeacher:mutators=flatten+safety_relabel_classes+neg_only,projects.director.tasks.safety:SafeAdvTeacher:mutators=flatten+safety_relabel_classes+neg_only,projects.director.tasks.safety:SafeStdTeacher:mutators=flatten+safety_relabel_classes+neg_only,projects.director.tasks.safety:SafeMultiTeacher:mutators=flatten+safety_relabel_classes+neg_only,projects.director.tasks.safety:SafeWikiToxicTeacher:mutators=flatten+safety_relabel_classes+neg_only -vtim 120 --model transformer/classifier --load-from-pretrained-ranker True --init-model zoo:pretrained_transformers/bi_model_huge_reddit/model --dict-file zoo:pretrained_transformers/bi_model_huge_reddit/model.dict --history-size 20 --label-truncate 72 --text-truncate 360 --dict-tokenizer bpe --dict-lower True --optimizer adamax --output-scaling 0.06 --variant xlm --reduction-type mean --share-encoders False --learn-positional-embeddings True --n-layers 12 --n-heads 12 --ffn-size 3072 --attention-dropout 0.1 --relu-dropout 0.0 --dropout 0.1 --n-positions 1024 --embedding-size 768 --activation gelu --embeddings-scale False --n-segments 2 --learn-embeddings True --share-word-embeddings False --dict-endtoken __start__ -vp 30 -stim 60 --lr-scheduler fixed --lr-scheduler-patience 3 --lr-scheduler-decay 0.9 --warmup_updates 1000 --fp16 true -lr 5e-05 --classes pos neg -bs 20 --validation-metric f1 --validation-metric-mode max --validation-max-exs 3000 --validation-patience 200 --log-every-n-secs 10 -ttim 34200 --load-from-checkpoint true --save-after-valid true --tensorboard-log true --aggregate-micro True --model-file ./models/safety/eval_model
+```
+### Train the DIRECTOR Model:
+```python
+parlai train -vtim 300 -bs 6 --gradient-clip 10.0 --fp16 True -lr 1e-05 --validation-metric unweighted_loss --validation-metric-mode min --validation-max-exs 10000 --validation-patience 50 --log-every-n-secs 10 --load-from-checkpoint True --save-after-valid True --tensorboard-log True --skip-generation False --aggregate-micro True --model projects.director.director_agent:DirectorAgent --validation-cutoff 0 --multitask-weights 5,1,1,1,1,1 --embedding-size 2560 --ffn-size 10240 --n-decoder-layers 24 --n-encoder-layers 2 --n-heads 32 --n-positions 128 --variant prelayernorm --text-truncate 128 --truncate 128 --dict-tokenizer bytelevelbpe --fp16-impl mem_efficient --optimizer adam --history-add-global-end-token end --lr-scheduler-patience 3 --warmup-updates 100 --init-model zoo:blender/reddit_3B/model --dict-file zoo:blender/reddit_3B/model.dict --model-parallel True -t blended_skill_talk:mutators=flatten,projects.director.tasks.safety:SafeBADTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_COPY,projects.director.tasks.safety:SafeAdvTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_COPY,projects.director.tasks.safety:SafeStdTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_COPY,projects.director.tasks.safety:SafeMultiTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_COPY,projects.director.tasks.safety:SafeWikiToxicTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_COPY -et blended_skill_talk:mutators=flatten,projects.director.tasks.safety:SafeBADTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_COPY,projects.director.tasks.safety:SafeAdvTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_COPY,projects.director.tasks.safety:SafeStdTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_COPY,projects.director.tasks.safety:SafeMultiTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_COPY,projects.director.tasks.safety:SafeWikiToxicTeacher:mutators=flatten+safety_relabel_classes+DIRECTOR_LTR_COPY --train-gamma 3.0 --model-file ./models/safety/director_model
+```
+
+
+### Evaluate the DIRECTOR model on toxic prompts from the WikiToxicComments dataset:
+```python
+python -m parlai.scripts.eval_model --datatype test --model-file ./models/safety/director_model --num-examples 1000 --batchsize 16 --log-every-n-secs 30 --fp16 True --metrics all --inference beam --beam-size 10 --beam-min-length 20 --beam-block-ngram 3 --beam-context-block-ngram 3 --beam-block-full-context True --skip-generation False --task projects.director.tasks.safety:SafeWikiToxicEvalTeacher:mutators=flatten+safety_relabel_classes+neg_only --eval-classifier-model-file ./models/safety/eval_model --include-label-cand-only True -bs 8 --infer-gamma 1
+```
\ No newline at end of file
diff --git a/projects/director/director_agent.py b/projects/director/director_agent.py
new file mode 100644
index 00000000000..ac07f913792
--- /dev/null
+++ b/projects/director/director_agent.py
@@ -0,0 +1,584 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+DirectorAgent for Supervised Language Modeling.
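+
+DIRECTOR trains a single transformer generator with two heads on top of the decoder:
+the usual language-modeling head and a per-token binary classifier head over the
+vocabulary. Generator examples contribute the standard cross-entropy loss, while
+left-to-right (LTR) classifier examples contribute a token-level binary cross-entropy
+loss against their pos/neg attribute label; the two are mixed with the --train-gamma
+weight (see compute_loss below). At inference time the classifier log-probabilities
+are added to the language-model log-probabilities with the --infer-gamma weight before
+renormalizing (see DirectorModel.output).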
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Optional, Any, Dict, Tuple, Union
+
+from parlai.agents.transformer.modules import TransformerGeneratorModel
+from parlai.agents.transformer.transformer import TransformerGeneratorAgent
+
+from parlai.core.dict import DictionaryAgent
+from parlai.core.message import Message
+from parlai.core.metrics import AverageMetric
+from parlai.core.opt import Opt
+from parlai.core.params import ParlaiParser
+from parlai.core.torch_generator_agent import PPLMetric
+
+import parlai.utils.logging as logging
+
+
+class DirectorModel(TransformerGeneratorModel):
+    """
+    Director model that extends TransformerGeneratorModel and adds |V| binary classifier
+    heads.
+    """
+
+    def __init__(self, opt: Opt, dictionary: DictionaryAgent, **kwargs):
+        super().__init__(opt, dictionary, **kwargs)
+
+        vocabulary_size = len(dictionary)
+
+        decoder_output_dim = self.decoder.out_dim
+        self.classifier_heads = nn.Linear(decoder_output_dim, vocabulary_size)
+
+        self.infer_gamma = opt['train_gamma']
+        if opt.get('infer_gamma') is not None:
+            self.infer_gamma = opt['infer_gamma']
+
+        self.freeze_decoder = opt['freeze_decoder']
+
+    def generator_output(self, input: torch.Tensor):
+        if self.freeze_decoder:
+            input = input.detach()
+
+        return super().output(input)
+
+    def classifier_output(self, input: torch.Tensor):
+        if self.freeze_decoder:
+            input = input.detach()
+
+        return self.classifier_heads(input)
+
+    def output(self, latent: torch.Tensor):
+        """
+        Override the output method to use the |V| classifier heads to modify the
+        generator logprobs. This modification allows the model to incorporate attribute
+        information from the classifier when selecting the next tokens.
+
+        Args:
+            latent (torch.Tensor): decoder outputs
+
+        Returns:
+            Modified logprobs.
+        """
+        classifier_outputs = F.logsigmoid(self.classifier_output(latent))
+        log_predictor_scores = F.log_softmax(self.generator_output(latent), dim=-1)
+
+        scores = log_predictor_scores + self.infer_gamma * classifier_outputs
+
+        return F.log_softmax(scores, dim=-1)
+
+    def load_state_dict(self, state_dict):
+        """
+        Overridden to load only the generator weights from the state dict, leaving the
+        classifier head weights untouched.
+        """
+        for k, v in self.state_dict().items():
+            if k not in state_dict:
+                state_dict[k] = v
+
+        super().load_state_dict(state_dict)
+
+    def forward(
+        self, *xs, ys=None, prev_enc=None, maxlen=None, bsz=None
+    ) -> Tuple[
+        torch.Tensor,
+        torch.Tensor,
+        torch.LongTensor,
+        torch.Tensor,
+        torch.BoolTensor,
+        Any,
+    ]:
+        """
+        Nearly copied verbatim, except for the return type, which also includes the
+        latent states and the classifier scores.
+        """
+        assert ys is not None, "Greedy decoding in TGModel.forward no longer supported."
+        self.longest_label = max(self.longest_label, ys.size(1))
+
+        # use cached encoding if available
+        encoder_states = prev_enc if prev_enc is not None else self.encoder(*xs)
+
+        # use teacher forcing
+        scores, preds, latent, mask = self.decode_forced(encoder_states, ys)
+
+        classifier_scores = self.classifier_output(latent)
+        return scores, preds, classifier_scores, latent, mask, encoder_states
+
+    def decode_forced(
+        self, encoder_states: Tuple[Any], ys: torch.LongTensor
+    ) -> Tuple[torch.Tensor, torch.LongTensor, torch.Tensor, torch.BoolTensor]:
+        """
+        Override TGM.decode_forced to return the latent states and to use the
+        generator_output method to generate the decoder output.
+        """
+        bsz = ys.size(0)
+        seqlen = ys.size(1)
+        inputs = ys.narrow(1, 0, seqlen - 1)
+        if (ys[:, 0] == self.START_IDX).any():
+            raise AssertionError(
+                "The Beginning of Sentence token is automatically added to the "
+                "label in decode_forced, but you included it in the label. This means "
+                "your model will have a double BOS token, which is probably not what "
+                "you intended."
+            )
+        inputs = self._get_initial_forced_decoder_input(bsz, inputs)
+        latent, mask = self.decoder(inputs, encoder_states)
+        logits = self.generator_output(latent)
+        _, preds = logits.max(dim=2)
+        return logits, preds, latent, mask
+
+
+class DirectorAgent(TransformerGeneratorAgent):
+    @classmethod
+    def add_cmdline_args(
+        cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
+    ) -> ParlaiParser:
+        TransformerGeneratorAgent.add_cmdline_args(parser, partial_opt=partial_opt)
+        # This method adds arguments specific to the fused generator-classifier
+        # architecture, e.g., the classifier head options below.
+        group = parser.add_argument_group('Director Group')
+        group.add_argument(
+            '--explicit-classifier-norm',
+            type=bool,
+            default=False,
+            help='Whether to explicitly push the classifier scores of non-target tokens towards 0.5 during training.',
+        )
+        group.add_argument(
+            '--explicit-classifier-norm-coeff',
+            type=float,
+            default=1,
+            help='Coefficient for the loss term that pushes the classifier scores of non-target tokens towards 0.5 during training.',
+        )
+        group.add_argument(
+            '--freeze-decoder',
+            type=bool,
+            default=False,
+            help='Freeze the decoder to train only the classifier head.',
+        )
+        group.add_argument(
+            '--train-gamma',
+            type=float,
+            default=0.5,
+            help='Weight on the classifier loss during training; the generator loss weight is kept fixed at 1.',
+        )
+        group.add_argument(
+            '--infer-gamma',
+            type=float,
+            default=None,
+            help='Weight on the classifier log-probabilities during decoding; defaults to --train-gamma when unset.',
+        )
+        return parser
+
+    def __init__(self, opt: Opt, shared=None):
+        super().__init__(opt, shared)
+        self.explicit_classifier_norm = opt['explicit_classifier_norm']
+        self.explicit_classifier_norm_coeff = opt['explicit_classifier_norm_coeff']
+        self.train_gamma = opt['train_gamma']
+        self.infer_gamma = opt['infer_gamma']
+
+        assert opt[
+            'beam_block_full_context'
+        ], 'must set --beam-block-full-context True to use DirectorAgent'
+
+    def load_state_dict(self, state_dict):
+        """
+        Load the state dict into the model.
+
+        This copies the classifier-specific params (absent from the init model) into
+        the state dict and then calls the load_state_dict method of TorchAgent.
+        """
+        for k, v in self.model.state_dict().items():
+            if k not in state_dict:
+                state_dict[k] = v
+        super().load_state_dict(state_dict)
+
+    def _get_batch_context(self, batch):
+        """
+        Override to always provide the full context.
+        """
+        if 'full_text_vec' not in batch:
+            logging.warn('Batch does not have full text vec, resorting to text vec')
+            return batch.text_vec
+        return batch.full_text_vec
+
+    def build_model(self, states=None):
+        """
+        Build and return the model.
+        """
+        model = DirectorModel(self.opt, self.dict)
+        if self.opt['embedding_type'] != 'random':
+            self._copy_embeddings(
+                model.encoder.embeddings.weight, self.opt['embedding_type']
+            )
+        return model
+
+    def observe(self, observation: Union[Dict, Message]) -> Message:
+        observation = super().observe(observation)
+        if 'is_ltr' not in observation:
+            observation['is_ltr'] = False
+            observation['classifier_label'] = 'none'
+            observation['classifier_label_idx'] = -1
+            return observation
+
+        classifier_label = observation['classifier_label']
+        if classifier_label == 'pos':
+            observation['classifier_label_idx'] = 1
+        elif classifier_label == 'neg':
+            observation['classifier_label_idx'] = 0
+        return observation
+
+    def batchify(self, obs_batch, sort=False):
+        """
+        This method calls the parent class's batchify method and then adds the
+        classifier_label and is_ltr properties to the batch.
+        """
+        batch = super().batchify(obs_batch, sort=sort)
+
+        if batch.valid_indices is None:
+            return batch
+
+        batch.classifier_label = torch.tensor(
+            [
+                [obs_batch[i].get('classifier_label_idx', -1)]
+                for i in batch.valid_indices
+            ]
+        )
+        batch.is_ltr = torch.tensor(
+            [[obs_batch[i].get('is_ltr', False)] for i in batch.valid_indices]
+        )
+        return batch
+
+    def _reshape_to_record_metrics(self, batch, losses, num_target_tokens, indices):
+        """
+        MultitaskAgent shuffles and combines examples from both the classifier and the
+        generator tasks in a single batch. We compute losses only for a subset of the
+        exs in the batch, resulting in losses and num_target_tokens vectors that are
+        smaller than the batch size.
+
+        This method reshapes the losses and num_target_tokens vectors back to the batch
+        size. This is needed to record local metrics, as the metrics need to be of
+        batch size.
+
+        Args:
+            batch: batch being processed in this iteration.
+            losses: classifier or generator loss vector (shape: b' X 1), where b' <= b.
+            num_target_tokens: number of tokens in each example for classification or
+                generation tasks. (shape: b' X 1), where b' <= b.
+            indices: indices of (either classification or generation) exs for which the
+                loss was computed.
+
+        Returns:
+            A tuple of reshaped losses and num_target_tokens, both of shape: b X 1.
+        """
+        val_id_shape = batch.valid_indices.shape
+        reshaped_losses = torch.zeros(
+            val_id_shape, device=losses.device, dtype=losses.dtype
+        )
+        reshaped_num_target_tokens = torch.zeros(
+            val_id_shape, device=num_target_tokens.device, dtype=num_target_tokens.dtype
+        )
+
+        reshaped_losses[indices] = losses
+        reshaped_num_target_tokens[indices] = num_target_tokens
+
+        return (reshaped_losses, reshaped_num_target_tokens)
+
+    def _v2t(self, vec):
+        """
+        This method is copied from the parent agent's _v2t but wraps the vec2txt call
+        in a try/except to ensure that sequences with generation errors are ignored.
+
+        We return an empty string instead in that scenario.
+        """
+        new_vec = []
+        if hasattr(vec, 'cpu'):
+            vec = vec.cpu()
+        for i in vec:
+            if i == self.END_IDX:
+                break
+            elif i != self.START_IDX:
+                new_vec.append(i)
+
+        try:
+            txt = self.dict.vec2txt(new_vec)
+        except AssertionError:
+            txt = ""
+        return txt
+
+    def compute_classifier_loss(self, classifier_scores, batch):
+        bsz = batch.batchsize
+        device = classifier_scores.device
+
+        classifier_losses = torch.zeros((bsz,), device=device)
+        num_target_tokens = torch.zeros((bsz,), device=device, dtype=torch.long)
+
+        # idxs of all the classification exs in the batch.
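+        # batch.is_ltr marks the left-to-right (LTR) classifier examples produced by
+        # the DIRECTOR_LTR* mutators; the remaining examples are plain generator
+        # examples handled by compute_generator_loss. For each classifier example,
+        # the classifier head is read out at every label position and trained with
+        # binary cross-entropy against the example-level pos/neg label, which is
+        # broadcast over all non-padding target tokens below.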
+        classification_idxs = batch.is_ltr[:, 0]
+
+        self.record_local_metric(
+            'pct_classifier_exs', AverageMetric.many(classification_idxs)
+        )
+
+        # if none of the exs in the batch are classifier examples,
+        # return zero classifier loss.
+        if not torch.any(classification_idxs):
+            return classifier_losses, num_target_tokens
+
+        classifier_scores = classifier_scores[classification_idxs]
+
+        # Select the classifier scores for next tokens
+        target_tokens = batch.label_vec[classification_idxs]
+        next_tokens = target_tokens
+
+        next_token_scores = classifier_scores.gather(
+            -1, next_tokens.unsqueeze(-1)
+        ).squeeze(-1)
+
+        classifier_labels = batch.classifier_label[classification_idxs]
+
+        # 1/0 (pos/neg) labels for each next token given the context.
+        classifier_labels = classifier_labels.expand_as(next_token_scores).float()
+
+        # Compute BCE loss based on classifier/attribute labels for the next tokens.
+        classifier_losses = F.binary_cross_entropy_with_logits(
+            next_token_scores,
+            classifier_labels,
+            reduction='none',
+        )
+
+        notnull = target_tokens.ne(self.NULL_IDX)
+        classifier_losses *= notnull
+
+        num_target_tokens = notnull.long().sum(dim=-1)
+
+        non_target_indices = torch.ones_like(classifier_scores, dtype=torch.bool)
+        non_target_indices.scatter_(-1, next_tokens.unsqueeze(-1), False)
+
+        normalized_classifier_scores = (
+            torch.sigmoid(classifier_scores) - 0.5
+        ) * notnull.unsqueeze(dim=-1)
+
+        normalized_non_target_classifier_scores = normalized_classifier_scores[
+            non_target_indices
+        ].reshape(*classifier_scores.shape[:-1], -1)
+
+        normalized_non_target_classifier_scores_squared = (
+            normalized_non_target_classifier_scores**2
+        )
+        normalized_non_target_classifier_scores_mean = (
+            normalized_non_target_classifier_scores.mean(dim=-1)
+        )
+        normalized_non_target_classifier_var = (
+            normalized_non_target_classifier_scores.var(dim=-1)
+        )
+
+        (
+            normalized_non_target_classifier_scores_mean_reshaped,
+            num_target_tokens_reshaped,
+        ) = self._reshape_to_record_metrics(
+            batch,
+            normalized_non_target_classifier_scores_mean.mean(-1),
+            num_target_tokens,
+            classification_idxs,
+        )
+        self.record_local_metric(
+            'classifier_score_mean',
+            AverageMetric.many(
+                normalized_non_target_classifier_scores_mean_reshaped,
+                num_target_tokens_reshaped,
+            ),
+        )
+
+        (
+            normalized_non_target_classifier_var_reshaped,
+            num_target_tokens_reshaped,
+        ) = self._reshape_to_record_metrics(
+            batch,
+            normalized_non_target_classifier_var.mean(-1),
+            num_target_tokens,
+            classification_idxs,
+        )
+        self.record_local_metric(
+            'classifier_score_var',
+            AverageMetric.many(
+                normalized_non_target_classifier_var_reshaped,
+                num_target_tokens_reshaped,
+            ),
+        )
+
+        # Explicitly force the score for non-target tokens to 0.5.
+        # This is done as << 0.5 indicates negative attributes and
+        # >> 0.5 indicates positive attributes.
+        if self.explicit_classifier_norm:
+            classifier_losses += (
+                self.explicit_classifier_norm_coeff
+                * normalized_non_target_classifier_scores_squared.mean(dim=-1)
+            )
+
+        classifier_losses = classifier_losses.sum(dim=1)
+
+        (
+            classifier_losses_reshaped,
+            num_target_tokens_reshaped,
+        ) = self._reshape_to_record_metrics(
+            batch, classifier_losses, num_target_tokens, classification_idxs
+        )
+        self.record_local_metric(
+            'classifier_loss',
+            AverageMetric.many(classifier_losses_reshaped, num_target_tokens_reshaped),
+        )
+
+        classifier_predictions = (torch.sigmoid(next_token_scores) > 0.5).long()
+
+        classifier_accuracy = classifier_labels == classifier_predictions
+        classifier_accuracy *= notnull
+        classifier_accuracy = classifier_accuracy.sum(-1)
+        (
+            classifier_accuracy_reshaped,
+            num_target_tokens_reshaped,
+        ) = self._reshape_to_record_metrics(
+            batch, classifier_accuracy, num_target_tokens, classification_idxs
+        )
+        self.record_local_metric(
+            'classifier_accuracy',
+            AverageMetric.many(
+                classifier_accuracy_reshaped, num_target_tokens_reshaped
+            ),
+        )
+
+        f1s = {}
+        for class_name, positive_class in (('neg', 0), ('pos', 1)):
+            positives = classifier_labels == positive_class
+            negatives = classifier_labels != positive_class
+            trues = classifier_predictions == classifier_labels
+            falses = classifier_predictions != classifier_labels
+
+            true_positives = ((positives & trues) * notnull).sum(-1)
+            false_positives = ((negatives & falses) * notnull).sum(-1)
+            false_negatives = ((positives & falses) * notnull).sum(-1)
+
+            classifier_f1 = (2 * true_positives) / (
+                2 * true_positives + false_positives + false_negatives
+            )
+            classifier_f1[true_positives == 0] = 0
+
+            (classifier_f1_reshaped, _) = self._reshape_to_record_metrics(
+                batch, classifier_f1, num_target_tokens, classification_idxs
+            )
+
+            f1s[class_name] = classifier_f1_reshaped
+
+            batch_positives = batch.classifier_label[:, 0] == positive_class
+            # We use (classification_idxs & (batch_positives > 0)) to indicate that we
+            # only consider the exs that are ltr classification examples and have
+            # classifier_labels == positive_class.
+            self.record_local_metric(
+                f'{class_name}_classifier_f1',
+                AverageMetric.many(
+                    classifier_f1_reshaped,
+                    (classification_idxs & (batch_positives > 0)).int(),
+                ),
+            )
+
+        avg_classifier_f1_reshaped = sum(f1s.values())
+        # We use classification_idxs.int() to indicate that we only consider the exs
+        # that are ltr classification examples.
+        self.record_local_metric(
+            'classifier_f1',
+            AverageMetric.many(avg_classifier_f1_reshaped, classification_idxs.int()),
+        )
+        return classifier_losses_reshaped, num_target_tokens_reshaped
+
+    def compute_generator_loss(self, generator_scores, batch):
+        bsz = batch.batchsize
+        device = generator_scores.device
+
+        generator_losses = torch.zeros((bsz,), device=device)
+        num_target_tokens = torch.zeros((bsz,), device=device, dtype=torch.long)
+
+        generation_idxs = torch.logical_not(batch.is_ltr[:, 0])
+        self.record_local_metric(
+            'pct_generator_exs', AverageMetric.many(generation_idxs)
+        )
+
+        # If there are no generation exs in the batch,
+        # return zero generator loss.
+        if not torch.any(generation_idxs):
+            return generator_losses, num_target_tokens
+
+        # Copied verbatim from TGA.compute_loss.
+        generator_scores = generator_scores[generation_idxs]
+        generator_label_vec = batch.label_vec[generation_idxs]
+
+        generator_scores_view = generator_scores.reshape(-1, generator_scores.size(-1))
+        generator_losses = self.criterion(
+            generator_scores_view, generator_label_vec.view(-1)
+        )
+
+        # cross entropy loss
+        generator_losses = generator_losses.view(generator_scores.shape[:-1]).sum(dim=1)
+
+        notnull = generator_label_vec.ne(self.NULL_IDX)
+        num_target_tokens = notnull.long().sum(dim=-1)
+
+        (
+            reshaped_generator_losses,
+            num_target_tokens_reshaped,
+        ) = self._reshape_to_record_metrics(
+            batch, generator_losses, num_target_tokens, generation_idxs
+        )
+
+        # save loss to metrics
+        self.record_local_metric(
+            'generator_loss',
+            AverageMetric.many(reshaped_generator_losses, num_target_tokens_reshaped),
+        )
+
+        # save perplexity to metrics
+        self.record_local_metric(
+            'generator_ppl',
+            PPLMetric.many(reshaped_generator_losses, num_target_tokens_reshaped),
+        )
+
+        return reshaped_generator_losses, num_target_tokens_reshaped
+
+    def compute_loss(self, batch, return_output=False):
+        """
+        Overrides compute_loss for multi-objective Director loss computation.
+        """
+        if batch.label_vec is None:
+            raise ValueError('Cannot compute loss without a label.')
+
+        model_output = self.model(*self._model_input(batch), ys=batch.label_vec)
+        generator_scores, _, classifier_scores, *_ = model_output
+
+        generator_losses, generator_num_target_tokens = self.compute_generator_loss(
+            generator_scores, batch
+        )
+        classifier_losses, classifier_num_target_tokens = self.compute_classifier_loss(
+            classifier_scores, batch
+        )
+
+        losses = generator_losses + self.train_gamma * classifier_losses
+
+        num_target_tokens = generator_num_target_tokens + classifier_num_target_tokens
+
+        self.record_local_metric('loss', AverageMetric.many(losses, num_target_tokens))
+
+        # This unweighted_loss ignores mixing weights and weighs
+        # generator and classifier losses equally. This can be
+        # used to do validation across various mixing coeffs (train_gamma).
+        self.record_local_metric(
+            'unweighted_loss',
+            AverageMetric.many(
+                (generator_losses + classifier_losses), num_target_tokens
+            ),
+        )
+
+        loss = (losses / num_target_tokens).sum()
+
+        if return_output:
+            return (loss, model_output)
+        return loss
diff --git a/projects/director/tasks/safety.py b/projects/director/tasks/safety.py
new file mode 100644
index 00000000000..3d95baf5740
--- /dev/null
+++ b/projects/director/tasks/safety.py
@@ -0,0 +1,341 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import List, Tuple
+from typing import Optional
+from parlai.core.agents import Agent, create_agent_from_model_file
+from parlai.core.params import ParlaiParser
+from parlai.core.message import Message
+from parlai.core.opt import Opt
+
+from parlai.core.mutators import (
+    register_mutator,
+    ManyEpisodeMutator,
+    EpisodeMutator,
+)
+import parlai.tasks.bot_adversarial_dialogue.agents as bad
+import parlai.tasks.dialogue_safety.agents as bibifi
+
+from parlai.core.metrics import AverageMetric
+
+
+@register_mutator('LTR')
+class LeftToRightMutator(ManyEpisodeMutator):
+    """
+    Mutator that breaks down episodes into all partial sequences (Left To Right).
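+
+    For example, a message whose label is 'i like cats' is expanded into three
+    single-message episodes with labels 'i', 'i like', and 'i like cats'; messages
+    whose label has fewer than two words are dropped.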
+    """
+
+    def many_episode_mutation(self, episode: List[Message]) -> List[Message]:
+        new_episodes = []
+        for message in episode:
+            label_words = message['labels'][0].split()
+            if len(label_words) < 2:
+                continue
+            for i in range(1, len(label_words) + 1):
+                new_message = message.copy()
+                label = ' '.join(label_words[:i])
+                new_message.force_set('labels', [label])
+                new_episodes.append([new_message])
+        return new_episodes
+
+
+@register_mutator('DIRECTOR_LTR')
+class DirectorLTRMutator(ManyEpisodeMutator):
+    """
+    DirectorLTRMutator prepares data for training the left-to-right (LTR) classifier
+    of the Encoder-Decoder Classifier (EDC) model.
+
+    This limits the context to all but the last utterance, which is fed to the encoder.
+    The final utterance is treated as the label for the decoder, and the
+    attribute/classifier labels are stored separately, marking the final utterance as
+    pos. or neg. Episodes with fewer than two utterances are skipped.
+
+    This mutator also adds an is_ltr flag to differentiate classifier exs from the
+    generator exs that are used to finetune the generator model.
+    """
+
+    def many_episode_mutation(self, episode: List[Message]) -> List[List[Message]]:
+        new_episodes = []
+        for message in episode:
+            text = message['text']
+            utterances = text.split('\n')
+
+            if len(utterances) < 2:
+                continue
+
+            new_message = message.copy()
+            new_message.force_set('is_ltr', True)
+            new_message.force_set('classifier_label', message['labels'][0])
+            new_text = '\n'.join(utterances[:-1])
+            new_message.force_set('text', new_text)
+            new_message.force_set('labels', [utterances[-1]])
+            new_episodes.append([new_message])
+        return new_episodes
+
+
+@register_mutator('DIRECTOR_LTR_COPY')
+class DirectorLTRCopyMutator(ManyEpisodeMutator):
+    """
+    DirectorLTRCopyMutator prepares data for training the left-to-right (LTR)
+    classifier of the Encoder-Decoder Classifier (EDC) model.
+
+    This limits the context to all but the last utterance, which is fed to the encoder.
+    The final utterance is treated as the label for the decoder, and the
+    attribute/classifier labels are stored separately, marking the final utterance as
+    pos. or neg. If an episode has only a single utterance, that utterance is
+    duplicated so that it serves as both the context and the label.
+
+    This mutator also adds an is_ltr flag to differentiate classifier exs from the
+    generator exs that are used to finetune the generator model.
+    """
+
+    def many_episode_mutation(self, episode: List[Message]) -> List[List[Message]]:
+        new_episodes = []
+        for message in episode:
+            text = message['text']
+            utterances = text.split('\n')
+
+            if len(utterances) < 2:
+                utterances.insert(0, utterances[0])
+
+            new_message = message.copy()
+            new_message.force_set('is_ltr', True)
+            new_message.force_set('classifier_label', message['labels'][0])
+            new_text = '\n'.join(utterances[:-1])
+            new_message.force_set('text', new_text)
+            new_message.force_set('labels', [utterances[-1]])
+            new_episodes.append([new_message])
+        return new_episodes
+
+
+@register_mutator('DIRECTOR_LTR_EMPTY')
+class DirectorLTREmptyMutator(ManyEpisodeMutator):
+    """
+    DirectorLTREmptyMutator prepares data for training the left-to-right (LTR)
+    classifier of the Encoder-Decoder Classifier (EDC) model.
+
+    This limits the context to all but the last utterance, which is fed to the encoder.
+    The final utterance is treated as the label for the decoder, and the
+    attribute/classifier labels are stored separately, marking the final utterance as
+    pos. or neg. If an episode has only a single utterance, an empty string is
+    inserted as the context.
+
+    This mutator also adds an is_ltr flag to differentiate classifier exs from the
+    generator exs that are used to finetune the generator model.
+    """
+
+    def many_episode_mutation(self, episode: List[Message]) -> List[List[Message]]:
+        new_episodes = []
+        for message in episode:
+            text = message['text']
+            utterances = text.split('\n')
+
+            if len(utterances) < 2:
+                utterances.insert(0, "")
+
+            new_message = message.copy()
+            new_message.force_set('is_ltr', True)
+            new_message.force_set('classifier_label', message['labels'][0])
+            new_text = '\n'.join(utterances[:-1])
+            new_message.force_set('text', new_text)
+            new_message.force_set('labels', [utterances[-1]])
+            new_episodes.append([new_message])
+        return new_episodes
+
+
+@register_mutator('neg_only')
+class NegOnlyMutator(ManyEpisodeMutator):
+    """
+    Mutator that filters to only the neg set.
+    """
+
+    def many_episode_mutation(self, episode: List[Message]) -> List[Message]:
+        new_episodes = []
+        for message in episode:
+            if message['labels'][0] == 'neg' or message['labels'][0] == '__notok__':
+                new_episodes.append([message])
+        return new_episodes
+
+
+@register_mutator('pos_only')
+class PosOnlyMutator(ManyEpisodeMutator):
+    """
+    Mutator that filters to only the pos set.
+    """
+
+    def many_episode_mutation(self, episode: List[Message]) -> List[Message]:
+        new_episodes = []
+        for message in episode:
+            if message['labels'][0] == 'pos':
+                new_episodes.append([message])
+        return new_episodes
+
+
+class ClassifierMetricTeacher:
+    """
+    Teacher mixin that scores model generations with an evaluation classifier and
+    records the result as a classifier_accuracy metric.
+    """
+
+    @classmethod
+    def add_cmdline_args(
+        cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
+    ) -> ParlaiParser:
+        super().add_cmdline_args(parser, partial_opt=partial_opt)
+        parser.add_argument(
+            '--eval-classifier-model-file',
+            required=False,
+            type=str,
+            help='Filepath of the evaluation classifier used to evaluate model generations.',
+        )
+        parser.add_argument(
+            '--include-label-cand-only',
+            type='bool',
+            default=False,
+            help='When passing inputs to the classifier, use only the label targets if set to True.',
+        )
+        return parser
+
+    def __init__(self, opt, shared=None):
+        self.include_label_cand_only = opt['include_label_cand_only']
+        if opt.get('eval_classifier_model_file'):
+            if not shared:
+                self.classifier = create_agent_from_model_file(
+                    opt['eval_classifier_model_file']
+                )
+            else:
+                self.classifier = shared['classifier']
+            self.context = []
+            DEFAULT_DELIM = '\n'
+            self.delimiter = opt.get('delimiter', DEFAULT_DELIM)
+        else:
+            self.classifier = None
+        super().__init__(opt, shared)
+
+    def share(self):
+        shared = super().share()
+        shared['classifier'] = self.classifier
+        return shared
+
+    def predict(self, context: str) -> Message:
+        """
+        Use the classifier to predict given the context.
+
+        :param context:
+            The input context to classify.
+
+        :return output:
+            The output Message from the classifier's act().
+        """
+        assert isinstance(self.classifier, Agent)
+        obs = Message({'text': context, 'episode_done': True})
+        self.classifier.observe(obs)
+        act = self.classifier.act()
+        assert isinstance(act, Message)
+        return act
+
+    def custom_evaluation(
+        self,
+        teacher_action: Message,
+        labels: Optional[Tuple[str]],
+        model_response: Message,
+    ) -> None:
+        """
+        Compute classifier accuracy for a model response.
+
+        :param teacher_action:
+            The message last sent from this teacher.
+        :param labels:
+            The previous correct labels
+        :param model_response:
+            The raw response from the model
+        """
+        if self.classifier is None:
+            return
+        if not model_response or not model_response.get('text'):
+            return
+        self.context.append(teacher_action['text'])
+        correct_class = self.classifier.ref_class
+        model_text = model_response['text']
+        if self.include_label_cand_only:
+            classifier_act = self.predict(model_text)
+        else:
+            context = self.delimiter.join(self.context)
+            classifier_act = self.predict(context + self.delimiter + model_text)
+
+        predicted_class = classifier_act['text']
+        correct_prediction = int(predicted_class == correct_class)
+
+        self.metrics.add('classifier_accuracy', AverageMetric(correct_prediction))
+
+        if teacher_action['episode_done']:
+            self.context = []
+        else:
+            assert labels
+            self.context.append(labels[0])
+
+
+class SafeWikiToxicEvalTeacher(ClassifierMetricTeacher, bibifi.DefaultTeacher):
+    pass
+
+
+class SafeBADTeacher(bad.BotAdversarialDialogueTeacher):
+    pass
+
+
+class SafeAdvTeacher(bibifi.AdversarialTeacher):
+    pass
+
+
+class SafeStdTeacher(bibifi.StandardTeacher):
+    pass
+
+
+class SafeMultiTeacher(bibifi.MultiturnTeacher):
+    pass
+
+
+class SafeWikiToxicTeacher(bibifi.DefaultTeacher):
+    pass
+
+
+@register_mutator('safety_to_LTR')
+class SafetyLTRMutator(ManyEpisodeMutator):
+    """
+    Mutator that takes safety data with __ok__ and __notok__ labels and converts them
+    to "pos" and "neg", which we use elsewhere.
+
+    It assumes the last line of 'text' is the last dialogue utterance, and splits that
+    by word for the left-to-right classifier.
+    """
+
+    def many_episode_mutation(self, episode: List[Message]) -> List[Message]:
+        new_episodes = []
+        for message in episode:
+            new_message = message.copy()
+            label = message['labels'][0]
+            if label == '__notok__':
+                label = 'neg'
+            else:
+                label = 'pos'
+            new_message.force_set('labels', [label])
+            text = '\n'.join(message['text'].split('\n')[:-1])
+            label_words = message['text'].split('\n')[-1].split()
+            if len(label_words) < 2:
+                continue
+            for i in range(1, len(label_words) + 1):
+                new_message2 = new_message.copy()
+                label = ' '.join(label_words[:i])
+                new_text = text + '\n' + label
+                new_message2.force_set('text', new_text)
+                new_episodes.append([new_message2])
+        return new_episodes
+
+
+@register_mutator('safety_relabel_classes')
+class SafetyRelabelClassesMutator(EpisodeMutator):
+    """
+    Mutator that relabels safety classes: '__notok__' becomes 'neg' and everything
+    else becomes 'pos'.
+    """
+
+    def episode_mutation(self, episode: List[Message]) -> List[Message]:
+        new_episodes = []
+        for message in episode:
+            new_message = message.copy()
+            label = message['labels'][0]
+            if label == '__notok__':
+                label = 'neg'
+            else:
+                label = 'pos'
+            new_message.force_set('labels', [label])
+            new_episodes.append(new_message)
+        return new_episodes
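+
+
+# Illustration (hypothetical example, not taken from a real dataset) of how the
+# mutators above compose in the README commands: given a flattened safety example
+#     {'text': 'hi\nyou are stupid', 'labels': ['__notok__']},
+# safety_relabel_classes rewrites the label to 'neg', and DIRECTOR_LTR_COPY then moves
+# the final utterance into the decoder label:
+#     {'text': 'hi', 'labels': ['you are stupid'],
+#      'classifier_label': 'neg', 'is_ltr': True}.
+# DirectorAgent treats such examples as left-to-right classifier examples and all
+# other examples as ordinary generator examples.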