# Gradient-Based Constrained Decoding Demo

Licensed under the Apache License, Version 2.0.

This method is based upon [Gradient-based Inference for Networks
with Output Constraints](https://arxiv.org/pdf/1707.08608.pdf) by Lee et al.

In [None]:
import json
import numpy as np
import random
import tensorflow as tf

import constrained_evaluation as eval_model
  # local file import from experimental.language_structure.pslimport data
  # local file import from experimental.language_structure.pslimport psl_model_multiwoz as model  # local file import from experimental.language_structure.psl

# Dataset and Task

We study constrained decoding through the task of dialog structure prediction. Dialog structure is a high-level representation of the flow of a dialog: nodes represent the abstract topics or dialog acts that statements fall into, and edges represent topic changes.

To verify our method, we would ideally test it on a multi-domain, goal-oriented dialog corpus such as [MultiWOZ 2.1](https://arxiv.org/pdf/1907.01669.pdf), created by Mihail Eric et al. Unfortunately, this corpus does not have a ground-truth dialog structure; therefore, we use the [Synthetic MultiWOZ](https://almond-static.stanford.edu/papers/multiwoz-acl2020.pdf) dataset created by Giovanni Campagna et al.

In [None]:
# ========================================================================
# Constants
# ========================================================================
DATA_PATH = ''

# One weight per constraint rule; entry i presumably corresponds to rule i+1
# in the "Rules" list further down.
RULE_WEIGHTS = np.array(
    [1.0, 20.0, 5.0, 5.0, 5.0, 10.0, 5.0, 20.0, 5.0, 5.0, 5.0, 10.0])
# Names are derived from the weights so both always have the same length:
# ('rule_1', ..., 'rule_12').
RULE_NAMES = tuple('rule_%d' % (i + 1) for i in range(len(RULE_WEIGHTS)))

# Candidate hyperparameter values (presumably fed to run_grid's
# alphas / grad_steps / learning_rates arguments).
ALPHAS = [0.1]
GRAD_STEPS = [10, 50, 100, 500]
LEARNING_RATES = [0.0001, 0.0005, 0.001, 0.01]

# ========================================================================
# Seed Data
# ========================================================================
# A fresh seed is drawn per run and printed so the TensorFlow global seed
# can be recovered afterwards. Note: only TensorFlow is seeded here.
SEED = random.randint(-10000000, 10000000)
print("Seed: %d" % SEED)
tf.random.set_seed(SEED)

# ========================================================================
# Load Data
# ========================================================================
DATA = data.load_json(DATA_PATH)

In [None]:
#@title Config
config = dict(
    default_seed=4,
    batch_size=128,
    max_dialog_size=10,
    max_utterance_size=40,
    # Dialog-act label -> class id; the id is the label's position below.
    class_map={
        class_name: class_id
        for class_id, class_name in enumerate((
            'accept', 'cancel', 'end', 'greet', 'info_question',
            'init_request', 'insist', 'second_request', 'slot_question'))
    },
    # Keyword lists forwarded to data.add_features.
    accept_words=['yes', 'great'],
    cancel_words=['no'],
    end_words=['thank', 'thanks'],
    greet_words=['hello', 'hi'],
    info_question_words=['address', 'phone'],
    insist_words=['sure', 'no'],
    slot_question_words=['what', '?'],
    includes_word=-1,
    excludes_word=-2,
    # Index arguments forwarded to data.add_features.
    mask_index=0,
    accept_index=1,
    cancel_index=2,
    end_index=3,
    greet_index=4,
    info_question_index=5,
    insist_index=6,
    slot_question_index=7,
    # Mask values forwarded to data.add_features (recover_utterances skips
    # the ids 0, -1, -2 and -3 when decoding text).
    utterance_mask=-1,
    last_utterance_mask=-2,
    pad_utterance_mask=-3,
    shuffle_train=True,
    shuffle_test=False,
    train_epochs=1,
)

In [None]:
#@title Prepare Dataset
# Config entries passed straight through to data.add_features; each key is
# also the name of the keyword argument it fills.
_FEATURE_CONFIG_KEYS = (
    'accept_words', 'cancel_words', 'end_words', 'greet_words',
    'info_question_words', 'insist_words', 'slot_question_words',
    'includes_word', 'excludes_word',
    'accept_index', 'cancel_index', 'end_index', 'greet_index',
    'info_question_index', 'insist_index', 'slot_question_index',
    'utterance_mask', 'pad_utterance_mask', 'last_utterance_mask',
    'mask_index',
)


def _prepare_split(raw_dialogs, truth_dialogs):
  """Featurizes, pads and label-encodes one data split.

  Returns:
    (featurized dialogs, padded dialogs, one-hot labels, padded labels), in
    the order the helpers in `data` produce them.
  """
  featurized = data.add_features(
      raw_dialogs,
      vocab_mapping=DATA['vocab_mapping'],
      **{key: config[key] for key in _FEATURE_CONFIG_KEYS})
  padded = data.pad_dialogs(featurized, config['max_dialog_size'],
                            config['max_utterance_size'])
  one_hot = data.one_hot_string_encoding(truth_dialogs, config['class_map'])
  padded_labels = data.pad_one_hot_labels(one_hot,
                                          config['max_dialog_size'],
                                          config['class_map'])
  return featurized, padded, one_hot, padded_labels


train_dialogs, train_data, raw_train_labels, train_labels = _prepare_split(
    DATA['train_data'], DATA['train_truth_dialog'])
train_ds = data.list_to_dataset(train_data[0], train_labels[0],
                                config['shuffle_train'],
                                config['batch_size'])

test_dialogs, test_data, raw_test_labels, test_labels = _prepare_split(
    DATA['test_data'], DATA['test_truth_dialog'])
test_ds = data.list_to_dataset(test_data[0], test_labels[0],
                               config['shuffle_test'],
                               config['batch_size'])

In [None]:
#@title Helper Functions
def class_confusion_matrix(preds, labels, config):
  """Counts per-class tp/tn/fp/fn and overall correct/incorrect predictions.

  Args:
    preds: Nested iterable of predicted class ids, one row per dialog. Items
      may be Python ints, NumPy integers or scalar eager tensors (anything
      `int()` accepts).
    labels: Nested iterable of gold class names, one row per dialog. Each row
      is zipped with the matching prediction row, so only positions present
      in both are scored (extra padded prediction positions are ignored).
    config: Dict whose 'class_map' maps class name -> class id.

  Returns:
    Dict mapping every class name to {'tp', 'tn', 'fp', 'fn'} counts, plus a
    'total' entry holding {'correct', 'incorrect'}.

  Raises:
    KeyError: If a label is not in class_map or a prediction id is not one of
      its values.
  """
  class_map = config['class_map']
  reverse_class_map = {v: k for k, v in class_map.items()}
  matrix = {key: {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0} for key in class_map}
  matrix['total'] = {'correct': 0, 'incorrect': 0}

  for pred_list, label_list in zip(preds, labels):
    for pred, label in zip(pred_list, label_list):
      pred_class = reverse_class_map[int(pred)]
      if pred_class == label:
        matrix['total']['correct'] += 1
        matrix[label]['tp'] += 1
      else:
        matrix['total']['incorrect'] += 1
        # The gold class was missed (false negative) and the predicted class
        # was assigned where it does not belong (false positive).
        matrix[label]['fn'] += 1
        matrix[pred_class]['fp'] += 1

      # Every class that is neither the gold nor the predicted one is a true
      # negative for this item.
      for key in class_map:
        if key != label and key != pred_class:
          matrix[key]['tn'] += 1

  return matrix

def precision_recall_f1(confusion_matrix):
  """Returns (precision, recall, f1) for a single class's confusion counts.

  Only the 'tp', 'fp' and 'fn' entries are read. Whenever a ratio's
  denominator is zero, that metric is reported as 0.0 instead of raising.
  """
  tp = confusion_matrix['tp']
  predicted_positive = tp + confusion_matrix['fp']
  actual_positive = tp + confusion_matrix['fn']

  precision = tp / predicted_positive if predicted_positive != 0 else 0.0
  recall = tp / actual_positive if actual_positive != 0 else 0.0

  harmonic_denominator = precision + recall
  if harmonic_denominator != 0:
    f1 = 2.0 * (precision * recall / harmonic_denominator)
  else:
    f1 = 0.0
  return precision, recall, f1

def print_metrics(confusion_matrix):
  """Prints categorical accuracy and per-class precision/recall/F1.

  Args:
    confusion_matrix: Dict mapping class name -> {'tp', 'fp', 'fn', ...}
      counts, plus a 'total' entry holding {'correct', 'incorrect'} (the shape
      produced by `class_confusion_matrix`).

  Returns:
    (values, cat_accuracy): `values` holds one "precision,recall,f1" string per
    class, in the dict's iteration order; `cat_accuracy` is
    correct / (correct + incorrect), or 0.0 when nothing was scored.
  """
  totals = confusion_matrix['total']
  scored = totals['correct'] + totals['incorrect']
  # Report 0.0 for an empty evaluation instead of raising ZeroDivisionError,
  # matching how precision_recall_f1 treats empty denominators.
  cat_accuracy = totals['correct'] / scored if scored else 0.0
  print("Categorical Accuracy: %0.4f" % (cat_accuracy,))

  values = []
  for key, value in confusion_matrix.items():
    if key == 'total':
      continue
    precision, recall, f1 = precision_recall_f1(value)

    print("Class: %s Precision: %0.4f  Recall: %0.4f  F1: %0.4f" % (key.ljust(15), precision, recall, f1))
    values.append(','.join(str(metric) for metric in (precision, recall, f1)))
  return values, cat_accuracy

# Neural Model

Below is a simple neural model for supervised structure prediction.

In [None]:
#@title Create Neural Model
def build_model(input_size, learning_rate=0.001, num_classes=9):
  """Build simple neural model for class prediction.

  Args:
    input_size: Input shape without the batch dimension, passed to
      `tf.keras.layers.Input`. The notebook passes
      [max_dialog_size, max_utterance_size]. Keras Dense layers act on the last
      axis, so each row (presumably one utterance) is classified on its own.
    learning_rate: Learning rate for the Adam optimizer.
    num_classes: Width of the softmax output layer. The default of 9 equals
      the number of entries in this notebook's `config['class_map']`.

  Returns:
    A `tf.keras.Model` compiled with categorical cross-entropy loss and an
    accuracy metric.
  """
  input_layer = tf.keras.layers.Input(input_size)
  # No activation: a linear projection feeding the sigmoid layer below.
  hidden_layer_1 = tf.keras.layers.Dense(1024)(input_layer)
  hidden_layer_2 = tf.keras.layers.Dense(
      512, activation='sigmoid')(
          hidden_layer_1)
  output = tf.keras.layers.Dense(
      num_classes, activation='softmax',
      kernel_regularizer=tf.keras.regularizers.l2(1.0))(
          hidden_layer_2)

  model = tf.keras.Model(input_layer, output)

  model.compile(
      optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
      loss='categorical_crossentropy',
      metrics=['accuracy'])

  return model

In [None]:
def run_non_constrained(train_ds, test_ds, test_labels, config, learning_rate):
  """Trains an unconstrained model and scores its argmax predictions.

  Args:
    train_ds: Dataset passed to `fit`.
    test_ds: Dataset whose predictions are scored.
    test_labels: Gold class-name sequences for `test_ds`. They must be in the
      same order as `test_ds` yields dialogs (config['shuffle_test'] is False
      in this notebook).
    config: Notebook config dict (reads max sizes, train_epochs, class_map).
    learning_rate: Adam learning rate for the freshly built model.

  Returns:
    (trained model, per-class "precision,recall,f1" strings, categorical
    accuracy).
  """
  test_model = build_model(
      [config['max_dialog_size'], config['max_utterance_size']],
      learning_rate=learning_rate)
  test_model.fit(train_ds, epochs=config['train_epochs'])

  logits = test_model.predict(test_ds)
  predictions = tf.math.argmax(logits, axis=-1)

  confusion_matrix = class_confusion_matrix(predictions, test_labels, config)
  metrics, cat_accuracy = print_metrics(confusion_matrix)

  return test_model, metrics, cat_accuracy

# Predictions come from test_ds, so they must be scored against the TEST
# labels (previously the train labels were passed here by mistake).
test_model, metrics, cat_accuracy = run_non_constrained(
    train_ds, test_ds, DATA['test_truth_dialog'], config, 0.0001)

# Gradient-Based Constrained Decoding

Rules (`!` denotes negation, `&` conjunction, and `->` implication):

1. !FirstStatement(S) -> !State(S, 'greet')
2. FirstStatement(S) & HasGreetWord(S) -> State(S, 'greet')
3. FirstStatement(S) & !HasGreetWord(S) -> State(S, 'init_request')
4. PreviousStatement(S1, S2) & State(S2, 'init_request') -> State(S1, 'second_request')
5. PreviousStatement(S1, S2) & !State(S2, 'greet') -> !State(S1, 'init_request')
6. PreviousStatement(S1, S2) & State(S2, 'greet') -> State(S1, 'init_request')
7. LastStatement(S) & HasEndWord(S) -> State(S, 'end')
8. LastStatement(S) & HasAcceptWord(S) -> State(S, 'accept')
9. NextStatement(S1, S2) & State(S2, 'end') & HasCancelWord(S1) -> State(S1, 'cancel')
10. PreviousStatement(S1, S2) & State(S2, 'second_request') & HasInfoQuestionWord(S1) -> State(S1, 'info_question')
11. LastStatement(S) & HasInsistWord(S) -> State(S, 'insist')
12. PreviousStatement(S1, S2) & State(S2, 'second_request') & HasSlotQuestionWord(S1) -> State(S1, 'slot_question')

In [None]:
def run_constrained(test_model, rule_weights, rule_names, test_ds, test_labels, config, alpha, grad_step):
  """Re-decodes `test_ds` under the rule constraints and scores the result.

  A `model.PSLModelMultiWoZ` constraint set is built from the rule weights and
  names, then `eval_model.evaluate_constrained_model` runs `grad_step`
  gradient steps with step size `alpha` against `test_model`.

  Returns:
    (argmax predictions, per-class "precision,recall,f1" strings, categorical
    accuracy).
  """
  constraints = model.PSLModelMultiWoZ(rule_weights, rule_names, config=config)
  # Presumably one logits tensor per batch; they are stacked along axis 0.
  batch_logits = eval_model.evaluate_constrained_model(
      test_model, test_ds, constraints, grad_steps=grad_step, alpha=alpha)
  stacked_logits = tf.concat(batch_logits, axis=0)
  predictions = tf.math.argmax(stacked_logits, axis=-1)

  metrics, cat_accuracy = print_metrics(
      class_confusion_matrix(predictions, test_labels, config))
  return predictions, metrics, cat_accuracy


predictions, metrics, cat_accuracy = run_constrained(
    test_model, RULE_WEIGHTS, RULE_NAMES, test_ds, DATA['test_truth_dialog'],
    config, alpha=0.1, grad_step=500)

# Qualitative Analysis

In [None]:
def recover_utterances(dialog, vocab_map):
  """Decodes a dialog of word-id rows back into one text string per utterance.

  The ids 0, -1, -2 and -3 are skipped as non-words. Every kept word is
  prefixed with a single space, so each string starts with ' '. Utterances
  with no kept words are omitted from the result.

  Args:
    dialog: Iterable of utterances, each an iterable of word ids.
    vocab_map: Mapping from word id to word string.

  Returns:
    List of decoded utterance strings.
  """
  skipped_ids = (0, -1, -2, -3)
  decoded = []
  for utterance in dialog:
    text = ''.join(' ' + vocab_map[token]
                   for token in utterance
                   if token not in skipped_ids)
    if text:
      decoded.append(text)
  return decoded

def print_dialog(dialog_index, vocab_map, class_map, data, predictions):
  """Prints each decoded utterance of one dialog next to its predicted class.

  Args:
    dialog_index: Position of the dialog in `data[0]` and in `predictions`.
    vocab_map: Word -> id mapping; inverted here to decode word ids.
    class_map: Class name -> id mapping; inverted here to name predictions.
    data: Padded dataset whose element 0 holds the word-id dialogs (callers
      pass the `test_data` result of `data.pad_dialogs`).
    predictions: Per-dialog sequences of predicted class ids.
  """
  id_to_word = {v: k for k, v in vocab_map.items()}
  id_to_class = {v: k for k, v in class_map.items()}
  # Read from the `data` argument, not the module-level `test_data`, so the
  # function prints the dataset it was actually given.
  utterances = recover_utterances(data[0][dialog_index], id_to_word)

  # NOTE(review): recover_utterances drops utterances with no real words, so
  # this position-based pairing assumes such utterances only occur at the end.
  for utterance_index, utterance in enumerate(utterances):
    predicted_class = id_to_class[int(predictions[dialog_index][utterance_index])]
    print("Prediction: %s Utterance: %s" % (predicted_class.ljust(15), utterance))

# Show two hand-picked test dialogs (indices presumably chosen to contain a
# greeting and an ending, respectively).
for heading, example_index in (("\nDialog Greet", 27), ("\nDialog End", 6)):
  print(heading)
  print('-' * 50)
  print_dialog(example_index, DATA['vocab_mapping'], config['class_map'],
               test_data, predictions)

# Run Hyperparameter Grid

In [None]:
def run_grid(train_ds, test_ds, test_data, test_labels, rule_weights, rule_names, vocab_mapping, config, alphas, grad_steps, learning_rates):
  """Compares unconstrained and constrained decoding over a hyperparameter grid.

  For every (alpha, grad_step, learning_rate) combination, a fresh
  unconstrained model is trained with `learning_rate` and scored on `test_ds`.
  The same model is then decoded under the rule constraints with `alpha` /
  `grad_step`, and two example dialogs are printed.

  Args:
    train_ds: Training dataset.
    test_ds: Evaluation dataset.
    test_data: Padded test dialogs used to print example utterances.
    test_labels: Gold class-name sequences for `test_ds`; both runs are scored
      against them.
    rule_weights: Constraint rule weights passed to `run_constrained`.
    rule_names: Constraint rule names passed to `run_constrained`.
    vocab_mapping: Word -> id mapping used to decode the example dialogs.
    config: Notebook config dict.
    alphas: Values for `alpha`.
    grad_steps: Values for `grad_step`.
    learning_rates: Values for the training learning rate.

  Returns:
    (non_constrained_metrics, constrained_metrics,
     non_constrained_cat_accuracies, constrained_cat_accuracies), each a list
    with one entry per grid point, in loop order.
  """
  character_size = 80

  constrained_metrics = []
  non_constrained_metrics = []
  constrained_cat_accuracies = []
  non_constrained_cat_accuracies = []

  for alpha in alphas:
    for grad_step in grad_steps:
      for learning_rate in learning_rates:
        print('\n' + '=' * character_size)
        print("Running: Alpha - %0.5f   Gradient Steps - %d   Learning Rate - %0.5f" % (alpha, grad_step, learning_rate))
        print('=' * character_size)

        print('\nNon-Constrained')
        print('-' * character_size)
        # Use the caller-supplied labels rather than the global DATA dict.
        test_model, metrics, cat_accuracy = run_non_constrained(
            train_ds, test_ds, test_labels, config,
            learning_rate=learning_rate)
        non_constrained_metrics.append(metrics)
        non_constrained_cat_accuracies.append(cat_accuracy)

        print('\nConstrained')
        print('-' * character_size)
        predictions, metrics, cat_accuracy = run_constrained(
            test_model, rule_weights, rule_names, test_ds, test_labels,
            config, alpha=alpha, grad_step=grad_step)
        constrained_metrics.append(metrics)
        constrained_cat_accuracies.append(cat_accuracy)

        # NOTE(review): the standalone demo cell uses dialog 27 for the greet
        # example; confirm whether 11 is intended here.
        print("\nDialog Greet")
        print('-' * 50)
        print_dialog(11, vocab_mapping, config['class_map'], test_data, predictions)
        print("\nDialog End")
        print('-' * 50)
        print_dialog(6, vocab_mapping, config['class_map'], test_data, predictions)

  return non_constrained_metrics, constrained_metrics, non_constrained_cat_accuracies, constrained_cat_accuracies