In [None]:
# only run on google runtime
# Installs the third-party packages that the import cell below relies on.
# NOTE(review): no versions are pinned, so a fresh runtime may resolve
# different releases than the ones this notebook was developed against.
!pip install tensorflow-text
!pip install tf-models-official
!pip install tensorflow-addons
!pip install scikit-learn
!pip install scikit-multilearn

In [None]:
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text
from official.nlp import optimization
import tensorflow_addons as tfa
import transformers
import sklearn as sk

In [3]:
# Name -> column index. The names are listed in index order, so a name's
# position in this list is its column in every multi-hot label vector.
node_list_task_two = {
    name: index
    for index, name in enumerate([
        "Logos",
        "Repetition",
        "Obfuscation, Intentional vagueness, Confusion",
        "Reasoning",
        "Justification",
        "Slogans",
        "Bandwagon",
        "Appeal to authority",
        "Flag-waving",
        "Appeal to fear/prejudice",
        "Simplification",
        "Causal Oversimplification",
        "Black-and-white Fallacy/Dictatorship",
        "Thought-terminating cliché",
        "Distraction",
        "Misrepresentation of Someone's Position (Straw Man)",
        "Presenting Irrelevant Data (Red Herring)",
        "Whataboutism",
        "Ethos",
        "Glittering generalities (Virtue)",
        "Ad Hominem",
        "Doubt",
        "Name calling/Labeling",
        "Smears",
        "Reductio ad hitlerum",
        "Pathos",
        "Exaggeration/Minimisation",
        "Loaded Language",
        "Transfer",
        "Appeal to (Strong) Emotions",
    ])
}

# Column index -> name: the exact inverse of node_list_task_two.
id_to_node = {index: name for name, index in node_list_task_two.items()}

# Node name -> list of ALL of its ancestors (not just the direct parent).
# Despite the variable name, the keys are the descendants; the label encoder
# receives this dict as its `child_parent_map` argument.
# NOTE(review): "Whataboutism" lists only Distraction/Reasoning/Logos here,
# while TeacherNetwork also has an "Ad Hominem" -> "Whataboutism" rule —
# confirm which of the two is intended.
parent_child_dict = {
    # --- Logos branch ---
    "Logos": [],
    "Repetition": ["Logos"],
    "Obfuscation, Intentional vagueness, Confusion": ["Logos"],
    "Reasoning": ["Logos"],
    "Justification": ["Logos"],
    "Slogans": ["Justification", "Logos"],
    "Bandwagon": ["Justification", "Ethos", "Logos"],
    "Appeal to authority": ["Justification", "Ethos", "Logos"],
    "Flag-waving": ["Justification", "Pathos", "Logos"],
    "Appeal to fear/prejudice": ["Justification", "Pathos", "Logos"],
    "Simplification": ["Reasoning", "Logos"],
    "Causal Oversimplification": ["Simplification", "Reasoning", "Logos"],
    "Black-and-white Fallacy/Dictatorship": ["Simplification", "Reasoning", "Logos"],
    "Thought-terminating cliché": ["Simplification", "Reasoning", "Logos"],
    "Distraction": ["Reasoning", "Logos"],
    "Misrepresentation of Someone's Position (Straw Man)": ["Distraction", "Reasoning", "Logos"],
    "Presenting Irrelevant Data (Red Herring)": ["Distraction", "Reasoning", "Logos"],
    "Whataboutism": ["Distraction", "Reasoning", "Logos"],
    # --- Ethos branch ---
    "Ethos": [],
    "Glittering generalities (Virtue)": ["Ethos"],
    "Ad Hominem": ["Ethos"],
    "Doubt": ["Ad Hominem", "Ethos"],
    "Name calling/Labeling": ["Ad Hominem", "Ethos"],
    "Smears": ["Ad Hominem", "Ethos"],
    "Reductio ad hitlerum": ["Ad Hominem", "Ethos"],
    # --- Pathos branch ---
    "Pathos": [],
    "Exaggeration/Minimisation": ["Pathos"],
    "Loaded Language": ["Pathos"],
    "Transfer": ["Ethos", "Pathos"],
    "Appeal to (Strong) Emotions": ["Pathos"],
}

In [None]:
# only run on google runtime
from google.colab import drive

# Mount Google Drive (forcing a remount) so the dataset folder below is readable.
drive.mount('/content/drive', force_remount=True)

# Root folder of the project data on the mounted drive.
input_directory = '/content/drive/MyDrive/2023-2024 School Year/Fall Semester/Neuro-Symbolic Approaches to NLP/ProjectData'

# Annotation files for subtask 2a.
subtask_2_train = f'{input_directory}/subtask2a/train.json'
subtask_2_validation = f'{input_directory}/subtask2a/validation.json'

# Image folders; the trailing slash matters because file names are appended directly.
base_train_image_path = f'{input_directory}/train_images/'
base_dev_image_path = f'{input_directory}/validation_images/'


In [5]:
import json

def extract_data(filename, base_image_path):
  """Read a JSON annotation file and load the image of every record.

  Args:
    filename: path of a UTF-8 JSON file holding an iterable of records, each
      with the keys "id", "text", "labels" and "image".
    base_image_path: prefix that is concatenated (plain string `+`) with each
      record's "image" value to form the image path.

  Returns:
    A 5-tuple `(id_dict, ids, text, images, labels)`. The four lists are
    parallel and in file order; `id_dict` maps a record id to its position in
    them. Each image is loaded with `tf.keras.utils.load_img` and resized to
    224x224.
  """
  with open(filename, 'r', encoding='utf-8') as handle:
    records = json.load(handle)

  id_dict = {}
  ids, text, labels, images = [], [], [], []
  for position, record in enumerate(records):
    # `position` equals len(ids) at this point, i.e. the slot this record gets.
    id_dict[record["id"]] = position
    ids.append(record["id"])
    text.append(record["text"])
    labels.append(record["labels"])
    picture = tf.keras.utils.load_img(base_image_path + record["image"])
    images.append(picture.resize((224,224)))
  return id_dict, ids, text, images, labels

def get_leaf_encoded_labels(plaintext_labels_list, label_map):
  """Multi-hot encode each example's labels, without adding any ancestors.

  Args:
    plaintext_labels_list: iterable of per-example iterables of label names.
    label_map: dict mapping a label name to its column index; its length sets
      the vector width.

  Returns:
    A list with one float NumPy vector of length `len(label_map)` per example,
    holding 1 in the column of every listed label and 0 elsewhere.

  Raises:
    KeyError: if a label name is missing from `label_map`.
  """
  width = len(label_map)

  def encode(example_labels):
    vector = np.zeros(width)
    for name in example_labels:
      vector[label_map[name]] = 1
    return vector

  return [encode(example_labels) for example_labels in plaintext_labels_list]

def get_hierarchy_encoded_labels(plaintext_labels_list, label_map, child_parent_map):
  """Multi-hot encode each example's labels together with all their ancestors.

  Args:
    plaintext_labels_list: iterable of per-example iterables of label names.
    label_map: dict mapping a node name to its column index; its length sets
      the vector width.
    child_parent_map: dict mapping a node name to the list of ancestor names
      that must be switched on whenever that node is present.

  Returns:
    A list with one float NumPy vector of length `len(label_map)` per example,
    holding 1 for every listed label and every ancestor of it, 0 elsewhere.

  Raises:
    KeyError: if a name is missing from `label_map` or `child_parent_map`.
  """
  width = len(label_map)
  encoded = []
  for example_labels in plaintext_labels_list:
    vector = np.zeros(width)
    for name in example_labels:
      # The label itself first, then each ancestor listed for it.
      for node in [name, *child_parent_map[name]]:
        vector[label_map[node]] = 1
    encoded.append(vector)
  return encoded

In [6]:
# Load both splits: id -> position lookup, ids, texts, resized images and
# plain-text label lists (extract_data returns them in this order).
train_id_dict, raw_train_ids, raw_train_text, raw_train_images, raw_train_labels = extract_data(subtask_2_train, base_train_image_path)
dev_id_dict, raw_dev_ids, raw_dev_text, raw_dev_images, raw_dev_labels = extract_data(subtask_2_validation, base_dev_image_path)

# Two encodings per split: "leaf" marks only the annotated labels, "hierarchy"
# additionally marks every ancestor listed in parent_child_dict.
encoded_train_labels = np.array(get_leaf_encoded_labels(raw_train_labels, node_list_task_two))
encoded_hierarchy_train_labels = np.array(get_hierarchy_encoded_labels(raw_train_labels, node_list_task_two, parent_child_dict))
encoded_dev_labels = np.array(get_leaf_encoded_labels(raw_dev_labels, node_list_task_two))
encoded_hierarchy_dev_labels = np.array(get_hierarchy_encoded_labels(raw_dev_labels, node_list_task_two, parent_child_dict))

In [7]:
# NOTE(review): the previous cell already wraps these four values in np.array,
# so this cell only makes copies and does not change any value.
encoded_train_labels = np.array(encoded_train_labels)
encoded_hierarchy_train_labels = np.array(encoded_hierarchy_train_labels)
encoded_dev_labels = np.array(encoded_dev_labels)
encoded_hierarchy_dev_labels = np.array(encoded_hierarchy_dev_labels)

In [None]:
from skmultilearn.model_selection import iterative_train_test_split
# If using train set for evaluation
# The ids are passed as an (n, 1) column in place of features, so the split
# hands back which examples landed on each side; 7% go to the evaluation side.
final_train_ids, final_train_labels, final_eval_ids, final_eval_labels = iterative_train_test_split(
    np.expand_dims(np.array(raw_train_ids), axis=1),
    encoded_hierarchy_train_labels,
    test_size=0.07,
)
final_train_ids = final_train_ids.squeeze()
final_eval_ids = final_eval_ids.squeeze()

# Map each id back to its position in the raw lists to gather text and images
# in the same order as the labels returned by the split.
final_raw_train_text = [raw_train_text[train_id_dict[sample_id]] for sample_id in final_train_ids]
final_raw_train_images = [raw_train_images[train_id_dict[sample_id]] for sample_id in final_train_ids]

final_raw_eval_text = [raw_train_text[train_id_dict[sample_id]] for sample_id in final_eval_ids]
final_raw_eval_images = [raw_train_images[train_id_dict[sample_id]] for sample_id in final_eval_ids]


In [8]:
# If using dev set for evaluation
# Alternative to the stratified-split cell: train on the full training set
# and evaluate on the provided validation set. Running this cell overwrites
# the final_* variables produced by that other cell.
# Text, images and labels are shuffled together so they stay aligned; the
# fixed random_state makes the training order reproducible across runs.
final_raw_train_text, final_raw_train_images, final_train_labels = sk.utils.shuffle(
    raw_train_text,
    raw_train_images,
    encoded_hierarchy_train_labels,
    random_state=0,
)
final_raw_eval_text = raw_dev_text
final_raw_eval_images = raw_dev_images
final_eval_labels = encoded_hierarchy_dev_labels

In [None]:
# only run on google runtime
# Text branch: tokenizer and TF encoder for the "microsoft/deberta-base"
# checkpoint. The variables are named bert_* but hold the DeBERTa objects.
bert_tokenizer = transformers.AutoTokenizer.from_pretrained("microsoft/deberta-base")
bert_encoder = transformers.TFDebertaModel.from_pretrained("microsoft/deberta-base")

In [None]:
# CNN Based
# Image branch: preprocessing and TF model for the "microsoft/resnet-50" checkpoint.
image_processor = transformers.AutoImageProcessor.from_pretrained("microsoft/resnet-50")
image_model = transformers.TFResNetModel.from_pretrained("microsoft/resnet-50")

In [11]:
def hugging_face_bert_encode(text_data):
  """Tokenize every text with the notebook-level `bert_tokenizer`.

  Each text is padded or truncated to exactly 512 tokens.

  Args:
    text_data: iterable of strings.

  Returns:
    A two-element list `[input_ids, attention_masks]` of NumPy arrays, each
    with one row per input text.
  """
  input_ids = []
  attention_masks = []
  for text in text_data:
    tokenized_data = bert_tokenizer(text, padding='max_length', max_length=512, truncation=True)
    input_ids.append(tokenized_data['input_ids'])
    attention_masks.append(tokenized_data['attention_mask'])
  return [np.array(input_ids), np.array(attention_masks)]

def image_encode(image_data):
  """Preprocess every image with the notebook-level `image_processor`.

  Args:
    image_data: iterable of images accepted by `image_processor`.

  Returns:
    A NumPy array stacking the "pixel_values" entry produced for each image,
    in input order.
  """
  return np.array([image_processor(picture)["pixel_values"] for picture in image_data])

# Encode the training split: [input_ids, attention_masks] for the text, and
# the stacked pixel values for the images.
# .squeeze() removes EVERY size-1 axis of the stacked pixel array.
# NOTE(review): presumably this drops a per-image singleton axis added by the
# processor — confirm. With a single image it would also drop the batch axis.
hugging_face_train_text_data = hugging_face_bert_encode(final_raw_train_text)
hugging_face_train_image_data = image_encode(final_raw_train_images).squeeze()

# Same encoding for the evaluation split.
hugging_face_eval_text_data = hugging_face_bert_encode(final_raw_eval_text)
hugging_face_eval_image_data = image_encode(final_raw_eval_images).squeeze()

In [12]:
# Model inputs for the training split, keyed by input name.
final_train_data = {
    "input_ids": hugging_face_train_text_data[0],
    "attention_mask": hugging_face_train_text_data[1],
    "images": hugging_face_train_image_data,
}
# Self-assignment: changes nothing, but raises NameError if no split cell has run.
final_train_labels = final_train_labels

# Model inputs for the evaluation split, same layout.
final_eval_data = {
    "input_ids": hugging_face_eval_text_data[0],
    "attention_mask": hugging_face_eval_text_data[1],
    "images": hugging_face_eval_image_data,
}
final_eval_labels = final_eval_labels

In [13]:
class Hierarchy_Rule:
  """Soft "child implies parent" rule over a matrix of class probabilities.

  `p_student` is an n x k matrix, where n is the number of input examples and
  k is the number of classes.
  """

  def __init__(self, num_classes, parent_index, child_index):
    """Store the two column indices and a one-hot selector vector for each.

    Args:
      num_classes: number of columns k in the probability matrix.
      parent_index: column of the parent class.
      child_index: column of the child class.
    """
    self.num_classes = num_classes
    self.parent_index = parent_index
    self.child_index = child_index
    # Length-k selectors: a dot product with one of them picks out one column.
    self.parent_encoding = tf.squeeze(tf.one_hot([parent_index], self.num_classes))
    self.child_encoding = tf.squeeze(tf.one_hot([child_index], self.num_classes))

  @tf.function
  def rule_evaluation(self, p_student):
    """Return `min(1 - p_child + p_parent, 1)` for every example.

    The value is 1 when the parent probability is at least the child
    probability and drops below 1 by the amount the child exceeds the parent.
    """
    p_child = tf.tensordot(p_student, self.child_encoding, 1)
    p_parent = tf.tensordot(p_student, self.parent_encoding, 1)
    return tf.math.minimum((1 - p_child) + p_parent, 1)

  @tf.function
  def log_distribution(self, p_student, regularization_term, confidence_val):
    """Return this rule's additive penalty, shaped like `p_student`.

    Every column is 0 except the child column, which holds
    `-regularization_term * confidence_val * (1 - rule_evaluation)`.

    Args:
      p_student: n x k probability matrix.
      regularization_term: scalar weight applied to the penalty.
      confidence_val: scalar confidence of this rule.
    """
    ones = tf.ones_like(p_student)
    satisfaction = self.rule_evaluation(p_student)
    # (row, child_index) coordinates: one cell per example, all in the child column.
    batch = tf.shape(ones)[0]
    child_cells = tf.stack([tf.range(batch), tf.fill([batch], self.child_index)], axis=1)
    # Ones everywhere except the child column, which carries the rule value.
    with_rule_value = tf.tensor_scatter_nd_update(ones, child_cells, satisfaction)
    scale = float(-1)*float(regularization_term)*float(confidence_val)
    violation = (ones - with_rule_value)
    return scale*violation

In [14]:
class TeacherNetwork(tf.keras.layers.Layer):
  """Layer that rescales student probabilities by the hierarchy rules.

  One "child implies parent" `Hierarchy_Rule` is created per (parent, child)
  edge. The output is `student_probs * exp(sum of rule penalties)`, with the
  summed penalty clipped to [-60, 60] before exponentiation. The layer has no
  trainable weights.
  """

  def __init__(self):
    super(TeacherNetwork, self).__init__()
    # (parent, child) edges, one rule each.
    rule_edges = [
      ("Logos", "Repetition"),
      ("Logos", "Obfuscation, Intentional vagueness, Confusion"),
      ("Logos", "Reasoning"),
      ("Logos", "Justification"),
      ("Justification", "Slogans"),
      ("Justification", "Bandwagon"),
      ("Justification", "Appeal to authority"),
      ("Justification", "Flag-waving"),
      ("Justification", "Appeal to fear/prejudice"),
      # Fixed: this edge used to be ("Reasoning", "Repetition"). That
      # contradicts parent_child_dict, where Repetition's only ancestor is
      # Logos and Simplification lists Reasoning as an ancestor.
      ("Reasoning", "Simplification"),
      ("Logos", "Simplification"),
      ("Simplification", "Causal Oversimplification"),
      ("Simplification", "Black-and-white Fallacy/Dictatorship"),
      ("Simplification", "Thought-terminating cliché"),
      ("Reasoning", "Distraction"),
      ("Distraction", "Misrepresentation of Someone's Position (Straw Man)"),
      ("Distraction", "Presenting Irrelevant Data (Red Herring)"),
      ("Distraction", "Whataboutism"),
      ("Ethos", "Appeal to authority"),
      ("Ethos", "Glittering generalities (Virtue)"),
      ("Ethos", "Bandwagon"),
      ("Ethos", "Ad Hominem"),
      ("Ethos", "Transfer"),
      ("Ad Hominem", "Doubt"),
      ("Ad Hominem", "Name calling/Labeling"),
      ("Ad Hominem", "Smears"),
      ("Ad Hominem", "Reductio ad hitlerum"),
      ("Ad Hominem", "Whataboutism"),
      ("Pathos", "Exaggeration/Minimisation"),
      ("Pathos", "Loaded Language"),
      ("Pathos", "Appeal to (Strong) Emotions"),
      ("Pathos", "Appeal to fear/prejudice"),
      ("Pathos", "Flag-waving"),
      ("Pathos", "Transfer"),
    ]
    num_classes = len(node_list_task_two)
    self.rules = [
      Hierarchy_Rule(num_classes, node_list_task_two[parent], node_list_task_two[child])
      for parent, child in rule_edges
    ]
    # Every rule gets the same confidence of 100.
    self.rule_lambdas = tf.fill([len(self.rules)], 100)
    self.regularization_term = 1

  def build(self, input_shape):
    # Only records the incoming shape (under the name batch_size); no weights are created.
    self.batch_size = input_shape
    return

  def call(self, inputs):
    """Return the rule-adjusted probabilities for `inputs[0]`.

    Args:
      inputs: sequence whose first element is the student probability matrix.
    """
    student_probs = inputs[0]
    rule_distr = self.calculate_rule_constraints(student_probs, self.rules, self.rule_lambdas, self.regularization_term)
    rule_adj_probs = tf.math.multiply(student_probs, rule_distr)
    return rule_adj_probs

  def calculate_rule_constraints(self, input, rules, rule_confidences, C):
    """Return `exp` of the clipped sum of all rule penalties.

    Args:
      input: student probability matrix.
      rules: list of `Hierarchy_Rule` objects.
      rule_confidences: per-rule confidences, indexable in the same order as `rules`.
      C: regularization weight passed to every rule.

    Returns:
      A tensor shaped like `input`: the multiplicative factor for each probability.
    """
    distr_total = tf.zeros_like(input)
    for i in range(len(rules)):
      distr = rules[i].log_distribution(input, C, rule_confidences[i])
      distr_total = distr_total + distr
    # Clip before exponentiating to keep the result in a safe numeric range.
    distr_total = tf.clip_by_value(distr_total, -60, 60)
    return tf.math.exp(distr_total)

In [15]:
def build_model():
  """Assemble the student (text + image) classifier with the teacher on top.

  Uses the notebook-level `bert_encoder` and `image_model`. The model takes
  three inputs: 512 token ids, 512 attention-mask values and a 3x224x224 image.

  Returns:
    A `tf.keras.Model` whose single output stacks the student probabilities
    and the teacher-adjusted probabilities along axis 0 (student rows first).
  """
  text_input_length = 512
  label_count = 30

  # student network: inputs
  input_ids = tf.keras.Input(shape=(text_input_length,),dtype='int32')
  attention_masks = tf.keras.Input(shape=(text_input_length,),dtype='int32')
  pixel_inputs = tf.keras.Input(shape=(3, 224, 224), dtype='float32')

  # student network: encoders, each max-pooled down to one vector per example
  token_states = bert_encoder([input_ids, attention_masks])['last_hidden_state']
  image_features = image_model(pixel_inputs)['pooler_output']
  image_vector_layer = tf.keras.layers.GlobalMaxPool2D(data_format='channels_first', keepdims=False)
  text_vector = tf.keras.layers.GlobalMaxPool1D()(token_states)
  image_vector = image_vector_layer(image_features)

  # student network: fuse, dropout, half-width hidden layer, sigmoid classifier
  fused = tf.keras.layers.Concatenate()([text_vector, image_vector])
  fused_after_dropout = tf.keras.layers.Dropout(0.1)(fused)
  hidden = tf.keras.layers.Dense(int(fused.shape[1]/2), activation='relu')(fused_after_dropout)
  outputs = tf.keras.layers.Dense(label_count, activation='sigmoid', name='output')(hidden)

  # teacher network
  teacher_outputs = TeacherNetwork()([outputs])
  stacked_outputs = tf.keras.layers.Concatenate(axis=0, trainable=False)([outputs, teacher_outputs])
  return tf.keras.Model([input_ids, attention_masks, pixel_inputs], stacked_outputs)

In [16]:
# The hierarchical F1 implementation is based on the scorer provided by the SemEval task organizers, available after registering here: https://propaganda.math.unipd.it/semeval2024task4/
# Their F1 implementation is in turn based on the sklearn-hierarchical-classification implementation found here: https://github.com/globality-corp/sklearn-hierarchical-classification/blob/1de19f782d992a82dace895f9c24a0fc074baeeb/sklearn_hierarchical_classification/metrics.py#L201
# Probabilities strictly above this value count as a positive prediction.
threshold = 0.5

@tf.function
def macro_f1_helper(y_true, y_pred):
  """Return the unweighted mean of the per-class F1 scores.

  Args:
    y_true: binary label matrix, one column per class.
    y_pred: binary prediction matrix of the same shape.

  Returns:
    A float64 scalar tensor.
  """
  # small nonzero to avoid any divide by zero issues
  eps = 1e-11

  def per_class_count(values):
    return tf.math.count_nonzero(values, axis=0, dtype=tf.float64)

  hits = per_class_count(y_true * y_pred)
  actual = per_class_count(y_true)
  predicted = per_class_count(y_pred)
  precision = hits / (predicted + eps)
  recall = hits / (actual + eps)
  per_class_f1 = 2 * precision * recall / (precision + recall + eps)
  return tf.math.reduce_mean(per_class_f1)

@tf.function
def h_recall_score(y_true, y_pred):
    """Return micro recall: true positives over all true labels, across every class.

    Args:
        y_true: binary label matrix.
        y_pred: binary prediction matrix of the same shape.

    Returns:
        A float64 scalar tensor; 0 (rather than NaN) when `y_true` has no
        positive entry.
    """
    true_positives = tf.math.count_nonzero(y_true * y_pred, dtype=tf.float64)
    all_positives = tf.math.count_nonzero(y_true, dtype=tf.float64)
    # divide_no_nan yields 0 when the denominator is 0.
    return tf.math.divide_no_nan(true_positives, all_positives)

@tf.function
def h_precision_score(y_true, y_pred):
    """Return micro precision: true positives over all predicted labels, across every class.

    Args:
        y_true: binary label matrix.
        y_pred: binary prediction matrix of the same shape.

    Returns:
        A float64 scalar tensor; 0 (rather than NaN) when nothing is predicted
        positive.
    """
    true_positives = tf.math.count_nonzero(y_true * y_pred, dtype=tf.float64)
    all_results = tf.math.count_nonzero(y_pred, dtype=tf.float64)
    # divide_no_nan yields 0 when the denominator is 0.
    return tf.math.divide_no_nan(true_positives, all_results)

@tf.function
def hierarchial_f1_helper(y_true, y_pred, beta=1.):
    """Return hierarchical precision, recall and F-beta as a 3-tuple.

    Args:
        y_true: binary label matrix.
        y_pred: binary prediction matrix of the same shape.
        beta: weight of recall relative to precision (1 gives F1).

    Returns:
        `(hP, hR, hF)` scalar tensors. `hF` is 0 (rather than NaN) when its
        denominator `beta**2 * hP + hR` is 0.
    """
    hP = h_precision_score(y_true, y_pred)
    hR = h_recall_score(y_true, y_pred)
    # divide_no_nan yields 0 when both precision and recall are 0.
    hF = tf.math.divide_no_nan((1. + beta ** 2.) * hP * hR, beta ** 2. * hP + hR)
    return hP, hR, hF

# Metric Functions
@tf.function
def macro_f1_student(y_true, y_pred):
    """Macro F1 of the student rows, i.e. the first half of `y_pred` along axis 0.

    Probabilities above `threshold` are counted as positive predictions.
    """
    half = y_pred.shape[0] // 2
    student_binary = tf.cast(y_pred[:half, :] > threshold, tf.int32)
    return macro_f1_helper(tf.cast(y_true, tf.int32), student_binary)

@tf.function
def macro_f1_teacher(y_true, y_pred):
    """Macro F1 of the teacher rows, i.e. the second half of `y_pred` along axis 0.

    Probabilities above `threshold` are counted as positive predictions.
    """
    half = y_pred.shape[0] // 2
    teacher_binary = tf.cast(y_pred[half:, :] > threshold, tf.int32)
    return macro_f1_helper(tf.cast(y_true, tf.int32), teacher_binary)

@tf.function
def hierarchial_f1_student(y_true, y_pred):
    """Hierarchical (precision, recall, F1) of the student rows of `y_pred`.

    The student rows are the first half of `y_pred` along axis 0; probabilities
    above `threshold` are counted as positive predictions.
    """
    half = y_pred.shape[0] // 2
    student_binary = tf.cast(y_pred[:half, :] > threshold, tf.int32)
    return hierarchial_f1_helper(tf.cast(y_true, tf.int32), student_binary)

@tf.function
def hierarchial_f1_teacher(y_true, y_pred):
    """Hierarchical (precision, recall, F1) of the teacher rows of `y_pred`.

    The teacher rows are the second half of `y_pred` along axis 0; probabilities
    above `threshold` are counted as positive predictions.
    """
    half = y_pred.shape[0] // 2
    teacher_binary = tf.cast(y_pred[half:, :] > threshold, tf.int32)
    return hierarchial_f1_helper(tf.cast(y_true, tf.int32), teacher_binary)

In [17]:
from tqdm.autonotebook import tqdm
# Directory intended for training checkpoints.
# NOTE(review): training_loop passes the literal "/tmp/model" to its
# CheckpointManager instead of reading this constant — keep the two in sync.
checkpoint_save_dir = "/tmp/model"

@tf.function
def train_step(x, y, model, loss, step):
  """Run one optimisation step on a batch.

  Gradients are applied with the notebook-level `optimizer`, which is read as
  a global and is not a parameter of this function.

  Args:
    x: model inputs for the batch.
    y: labels for the batch.
    model: the model to train.
    loss: callable `loss(y, probs, step)` returning (student, teacher, final) losses.
    step: current global step, forwarded to `loss`.

  Returns:
    `(probs, y, student_loss, teacher_loss, final_loss)`.
  """
  with tf.GradientTape() as tape:
    probs = model(x, training=True)
    student_loss, teacher_loss, final_loss = loss(y, probs, step)
  weights = model.trainable_weights
  gradients = tape.gradient(final_loss, weights)
  optimizer.apply_gradients(zip(gradients, weights))
  return probs, y, student_loss, teacher_loss, final_loss

@tf.function
def test_step(x, y, model, loss, step):
  """Evaluate one batch in inference mode, without updating any weights.

  Args:
    x: model inputs for the batch.
    y: labels for the batch.
    model: the model to evaluate.
    loss: callable `loss(y, probs, step)` returning (student, teacher, final) losses.
    step: current global step, forwarded to `loss`.

  Returns:
    `(probs, y, student_loss, teacher_loss, final_loss)`.
  """
  predictions = model(x, training=False)
  student_loss, teacher_loss, final_loss = loss(y, predictions, step)
  return predictions, y, student_loss, teacher_loss, final_loss

def print_metrics(expected_vals, total_preds):
  """Print macro F1 and hierarchical precision/recall/F1 for teacher and student.

  Args:
    expected_vals: ground-truth label matrix.
    total_preds: student rows stacked on top of teacher rows along axis 0.
  """
  for role, macro_metric in (("teacher", macro_f1_teacher), ("student", macro_f1_student)):
    print(f"macro f1 {role}: {macro_metric(expected_vals, total_preds)}")
  # Compute both hierarchical results before printing either of them.
  hierarchical = {
    "Student": hierarchial_f1_student(expected_vals, total_preds),
    "Teacher": hierarchial_f1_teacher(expected_vals, total_preds),
  }
  for role, (hprec, hrec, hf1) in hierarchical.items():
    print(f"{role} Hierarchial; precision: {hprec}, recall: {hrec}, f1: {hf1}")

def _new_epoch_log():
  """Return empty per-batch accumulators for one pass over a dataset."""
  return {
    "student_preds": [],
    "teacher_preds": [],
    "expected": [],
    "student_loss": [],
    "teacher_loss": [],
    "final_loss": [],
  }

def _record_batch(log, preds, expected, student_loss, teacher_loss, final_loss):
  """Append one batch's outputs to `log`, splitting student rows from teacher rows."""
  # The model output stacks student rows on top of teacher rows along axis 0.
  output_divide = int(preds.shape[0]/2)
  log["student_preds"].append(preds[0:output_divide, :])
  log["teacher_preds"].append(preds[output_divide: preds.shape[0], :])
  log["expected"].append(expected)
  log["student_loss"].append(student_loss)
  log["teacher_loss"].append(teacher_loss)
  log["final_loss"].append(final_loss)

def _report_epoch(header, log):
  """Print `header`, the mean of each loss in `log`, and the F1 metrics."""
  student_preds = tf.concat(log["student_preds"], 0)
  teacher_preds = tf.concat(log["teacher_preds"], 0)
  # print_metrics expects all student rows first, then all teacher rows.
  total_preds = tf.concat([student_preds, teacher_preds], 0)
  expected_vals = tf.concat(log["expected"], 0)
  print(header)
  for name in ("student", "teacher", "final"):
    losses = log[f"{name}_loss"]
    print(f"{name} loss: {sum(losses)/len(losses)}")
  print_metrics(expected_vals, total_preds)

def training_loop(
  model,
  train_input_ids,
  train_attention_masks,
  train_images,
  train_labels,
  dev_input_ids,
  dev_attention_masks,
  dev_images,
  dev_labels,
  epochs,
  batch_size,
  optimizer,
  loss,
  checkpoint
):
  """Train `model` for `epochs` epochs, evaluating on the dev data after each.

  Each epoch runs every full training batch through `train_step`, prints the
  mean losses and F1 metrics, saves a checkpoint numbered by the epoch, then
  does the same (without weight updates) on the dev data via `test_step`.
  Examples beyond the last full batch are skipped in both loops.

  Args:
    model: model returned by `build_model`.
    train_input_ids, train_attention_masks, train_images, train_labels:
      training arrays, aligned on axis 0.
    dev_input_ids, dev_attention_masks, dev_images, dev_labels:
      evaluation arrays, aligned on axis 0.
    epochs: number of passes over the training data.
    batch_size: number of examples per step.
    optimizer: accepted for interface compatibility but not used here;
      `train_step` applies gradients with the notebook-level `optimizer`.
    loss: callable `loss(y, probs, step)` returning (student, teacher, final) losses.
    checkpoint: `tf.train.Checkpoint` tracking the model.
  """
  steps_per_epoch = int(train_input_ids.shape[0] / batch_size)
  dev_steps_per_epoch = int(dev_input_ids.shape[0] / batch_size)
  current_step = 0
  manager = tf.train.CheckpointManager(checkpoint, directory=checkpoint_save_dir, max_to_keep=epochs)
  for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    # Iterate over the batches of the dataset.
    train_log = _new_epoch_log()
    for step in tqdm(range(steps_per_epoch)):
      # get current batch and run through model
      current_step += 1
      rows = slice(step*batch_size, (step + 1)*batch_size)
      batch_outputs = train_step(
        [train_input_ids[rows, :], train_attention_masks[rows, :], train_images[rows, :, :, :]],
        train_labels[rows, :],
        model,
        loss,
        tf.convert_to_tensor(current_step, tf.float32),
      )
      # keep predictions and losses for the metrics printed after the epoch
      _record_batch(train_log, *batch_outputs)
    _report_epoch("Training Data Results:", train_log)
    manager.save(checkpoint_number=epoch)

    # Run a validation loop at the end of each epoch.
    test_log = _new_epoch_log()
    for i in range(dev_steps_per_epoch):
      rows = slice(i*batch_size, (i + 1)*batch_size)
      batch_images = tf.convert_to_tensor(dev_images[rows, :, :, :])
      batch_outputs = test_step(
        [dev_input_ids[rows, :], dev_attention_masks[rows, :], batch_images],
        dev_labels[rows, :],
        model,
        loss,
        tf.convert_to_tensor(current_step, tf.float32),
      )
      _record_batch(test_log, *batch_outputs)
    _report_epoch("\nTest Data Results:", test_log)


In [19]:
# loss function for equal weighted teacher and student losses
def custom_loss_wrapper():
  """Build the loss that adds student BCE and teacher-to-student KL with equal weight.

  Returns:
    `custom_loss(y_true, y_pred, step)`, where `y_pred` stacks student rows on
    top of teacher rows along axis 0. It returns
    `(student_loss, teacher_loss, final_loss)` with
    `final_loss = student_loss + teacher_loss`. `step` is accepted but unused.
  """
  bce_loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)
  kl_loss = tf.keras.losses.KLDivergence()
  def custom_loss(y_true, y_pred, step):
    output_divide = int(y_pred.shape[0]/2)
    p_student = y_pred[0:output_divide, :]
    p_teacher = y_pred[output_divide: y_pred.shape[0], :]
    student_loss = bce_loss(y_true, p_student)
    # KL divergence needs distributions that sum to 1, so each probability p
    # becomes the two-outcome distribution [p, 1 - p] on a new last axis.
    # Keras sums the KL terms over that axis and averages over the remaining
    # (example, class) cells, i.e. the mean of the per-cell KL values.
    teacher_dist = tf.stack([p_teacher, 1 - p_teacher], axis=-1)
    student_dist = tf.stack([p_student, 1 - p_student], axis=-1)
    teacher_loss = kl_loss(teacher_dist, student_dist)
    final_loss = student_loss + teacher_loss
    return student_loss, teacher_loss, final_loss
  return custom_loss

In [None]:
# loss function to use if ignoring the effect of teacher network to evaluate base network only
def custom_loss_wrapper(initial_value=None, lower_bound=None, total_steps=None, start_step=None):
  """Build a loss that scores only the student half of the model output.

  The four parameters are not used by this variant. They default to None so
  that the notebook's `custom_loss_wrapper()` call works whichever loss cell
  was run last.

  Returns:
    `custom_loss(y_true, y_pred, step)`, which returns the student binary
    cross-entropy three times, in the (student, teacher, final) slots, so it
    matches the tuple shape the training code unpacks.
  """
  bce_loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)
  def custom_loss(y_true, y_pred, step):
    # Student rows are the first half of y_pred along axis 0.
    output_divide = int(y_pred.shape[0]/2)
    p_student = y_pred[0:output_divide, :]
    student_loss = bce_loss(y_true, p_student)
    return student_loss, student_loss, student_loss
  return custom_loss

In [20]:
# Build the combined student + teacher model and a checkpoint object that
# tracks it, used later for saving and restoring weights.
model = build_model()
checkpoint = tf.train.Checkpoint(model)

In [None]:
#optimizer
# epochs and batch_size are also read by the final-evaluation cell further down.
epochs = 2
batch_size = 4
# True division: steps_per_epoch and num_train_steps are floats here, whereas
# training_loop truncates its own step count to an int.
steps_per_epoch = final_train_data["input_ids"].shape[0] / batch_size
num_train_steps = steps_per_epoch * epochs
# Warm up over the first 10% of the steps.
num_warmup_steps = int(0.1*num_train_steps)

init_lr = 3e-5
# `optimizer` must stay a notebook-level name: train_step reads it as a global.
optimizer = optimization.create_optimizer(init_lr=init_lr,
                                          num_train_steps=num_train_steps,
                                          num_warmup_steps=num_warmup_steps,
                                          optimizer_type='adamw')

# Uses whichever custom_loss_wrapper definition was executed last.
loss= custom_loss_wrapper()
#metrics
training_loop(
    model,
    final_train_data["input_ids"],
    final_train_data["attention_mask"],
    final_train_data["images"],
    final_train_labels,
    final_eval_data["input_ids"],
    final_eval_data["attention_mask"],
    final_eval_data["images"],
    final_eval_labels,
    epochs,
    batch_size,
    optimizer,
    loss,
    checkpoint
)

In [None]:
# replace the number following ckpt with the checkpoint you would like to restore for final evaluation (choose checkpoint for which validation loss starts to increase)
checkpoint.restore('/tmp/model/ckpt-1')

# Run the restored model over the evaluation data in full batches; examples
# beyond the last full batch are skipped.
teacher_preds_list = []
student_preds_list = []
for i in range(int(final_eval_labels.shape[0]/ batch_size)):
  preds = model([final_eval_data["input_ids"][i*batch_size:(i+1)*batch_size, :], final_eval_data["attention_mask"][i*batch_size:(i+1)*batch_size, :], final_eval_data["images"][i*batch_size:(i+1)*batch_size, :]])
  # The model output stacks student rows on top of teacher rows along axis 0.
  output_divide = int(preds.shape[0]/2)
  p_student = preds[0:output_divide, :]
  p_teacher = preds[output_divide: preds.shape[0], :]
  teacher_preds_list.append(p_teacher)
  student_preds_list.append(p_student)

# Re-stack as all student rows followed by all teacher rows, the layout the
# metric functions expect.
teacher_preds = tf.concat(teacher_preds_list, 0)
student_preds = tf.concat(student_preds_list, 0)
total_preds = tf.concat([student_preds, teacher_preds], 0)

# Trim the labels to the number of examples that were actually predicted.
print_metrics(final_eval_labels[0:teacher_preds.shape[0]], total_preds)

In [None]:
def hierarchial_violations_helper(y_pred):
  """Count predictions where a node is positive but one of its ancestors is not.

  Probabilities above the notebook-level `threshold` count as positive.
  Ancestors are looked up in `parent_child_dict`.

  Args:
    y_pred: 2-D array or eager tensor of probabilities, one row per example
      and one column per node.

  Returns:
    `(violation_count, violation_matrix)`. `violation_count` is the number of
    distinct (example, missing ancestor) pairs. `violation_matrix[child, ancestor]`
    counts how often `child` was predicted while `ancestor` was not.
  """
  # Binarise once into a NumPy array: per-element indexing of a tensor inside
  # the Python loops below would be much slower.
  predicted = (np.asarray(y_pred) > threshold).astype(np.int32)
  num_nodes = predicted.shape[1]
  # row child, column parent
  violation_matrix = np.zeros((num_nodes, num_nodes))
  violation_count = 0
  for row in predicted:
    missing_ancestors = set()
    for child_index in np.flatnonzero(row):
      child_index = int(child_index)
      for ancestor_name in parent_child_dict[id_to_node[child_index]]:
        ancestor_index = node_list_task_two[ancestor_name]
        if row[ancestor_index] == 0:
          violation_matrix[child_index, ancestor_index] += 1
          missing_ancestors.add(ancestor_index)
    # An ancestor missing for several predicted children counts once per example.
    violation_count += len(missing_ancestors)
  return violation_count, violation_matrix
# Hierarchy violations in the student predictions from the final-evaluation cell.
violations, violation_matrix = hierarchial_violations_helper(student_preds)
print(violations)