In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Model
from tensorflow.keras.utils import Progbar

import tensorflow_addons as tfa

# Only show TensorFlow Python-side log messages at ERROR level.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# Build a TF1-compat session whose GPU options cap the per-process memory
# fraction, and register it as the Keras backend session.
config = tf.compat.v1.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.95 # Change this value as per requirement
tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config=config))

In [None]:
import os, nltk, re, sys
import pandas as pd
import pickle as pkl
import numpy as np
import matplotlib.pyplot as plt
import jieba
from einops import rearrange

from collections import defaultdict
from PIL import Image

In [None]:
def _read_pickle(path):
    """Unpickle ``path`` and close the file handle afterwards.

    NOTE(review): unpickling executes arbitrary code; only use this on
    dataset files obtained from a trusted source.
    """
    with open(path, 'rb') as handle:
        return pkl.load(handle)


def _to_id_frame(mapping):
    """Turn a ``{tweet id: value}`` mapping into a two-column integer DataFrame."""
    return pd.DataFrame({
        "values": [int(value) for value in mapping.values()],
        "tweet id": [int(key) for key in mapping.keys()],
    })


test_id = _read_pickle("./weibo/test_id.pickle")
train_id = _read_pickle("./weibo/train_id.pickle")
validate_id = _read_pickle("./weibo/validate_id.pickle")

train_id = _to_id_frame(train_id)
test_id = _to_id_frame(test_id)
validation_id = _to_id_frame(validate_id)

# Per split: a Series indexed by tweet id whose values are the ids stored in
# the pickle (used as the event id of each tweet by load_tweets).
ids = {
    "train": train_id.set_index("tweet id")['values'],
    "test": test_id.set_index("tweet id")['values'],
    "validation": validation_id.set_index("tweet id")['values']
}

In [None]:
# Field names of the '|'-separated metadata line of each tweet record, in file order.
# The leading space in " original?" is kept exactly as in the original header string.
columns = [
    "tweet id",
    "user name",
    "tweet url",
    "user url",
    "publish time",
    " original?",
    "retweet count",
    "comment count",
    "praise count",
    "user id",
    "user authentication type",
    "user fans count",
    "user follow count",
    "user tweet count",
    "publish platform",
]

In [None]:
def clean_str_sst(string):
    """
    Strip punctuation and markup noise from a tweet, then lower-case it.

    Removes ``&nbsp;`` HTML entities, ASCII/full-width punctuation and spaces,
    trims surrounding whitespace and lower-cases the result.

    Args:
        string: raw tweet text.

    Returns:
        The cleaned, lower-cased text.
    """
    # Remove the HTML entity as a whole token. Listing the letters "nbsp"
    # inside the character class (as before) deleted every n/b/s/p in the text.
    string = string.replace("&nbsp;", "")
    # The hyphen is escaped on purpose: unescaped, `|-“` is a *range* from
    # U+007C to U+201C that silently deletes accented Latin, Greek, Cyrillic...
    # NOTE(review): the literal "O" is kept from the original class; it looks
    # like emoticon stripping but also removes every capital O - confirm intent.
    string = re.sub(r"[，。 :,.；|\-“”—_/+&;@、《》～（）()#O！：【】]", "", string)
    return string.strip().lower()

def stopwordslist(filepath = './weibo/stop_words.txt'):
    """Load the stop-word list as a dict for O(1) membership tests.

    Args:
        filepath: UTF-8 text file with one stop word per line.

    Returns:
        dict mapping each stripped line to ``1``.
    """
    # Context manager so the handle is closed even if reading fails.
    with open(filepath, 'r', encoding='utf-8') as handle:
        return {line.strip(): 1 for line in handle}


In [None]:
# ids = pd.concat([train_id, validation_id, test_id]).set_index("tweet id")['values']
# Stop words used to filter the jieba tokens in load_tweets.
stop_words = stopwordslist()

# label (0 = non-rumor, 1 = rumor) -> directory holding that label's images
image_paths = {
    0: "./weibo/nonrumor_images/",
    1: "./weibo/rumor_images/",
}

# label -> file names present in the corresponding image directory
image_filelist = {
    label: os.listdir(folder) for label, folder in image_paths.items()
}

def load_tweets(split):
    """Parse the raw tweet files and return the tweets that belong to ``split``.

    Every record spans three consecutive lines: a ``|``-separated metadata
    line (first field is the tweet id), a ``|``-separated list of image URLs,
    and the tweet text. All four raw files are scanned for every split; a
    tweet is kept when its id is listed in ``ids[split]``, one of its images
    exists in the label's image directory, and its cleaned text is longer
    than 10 characters.

    Args:
        split: key of the module-level ``ids`` dict ("train", "test" or
            "validation").

    Returns:
        pandas.DataFrame with one row per kept tweet: the metadata columns
        plus ``image`` (path), ``tweet_content`` (space-joined tokens),
        ``event`` (contiguous index in order of first appearance within this
        call) and ``label`` (0 non-rumor, 1 rumor).

    Raises:
        ValueError: if the first metadata field is not an integer.
    """
    pre_path = "./weibo/tweets/"
    split_events = ids[split]
    file_list = [(0, pre_path + "test_nonrumor.txt"), (1, pre_path + "test_rumor.txt"),
                 (0, pre_path + "train_nonrumor.txt"), (1, pre_path + "train_rumor.txt")]

    # Set lookup instead of scanning the directory listing for every URL.
    available_images = {label: set(names) for label, names in image_filelist.items()}

    # NOTE(review): the event index is rebuilt on every call, so the same raw
    # event may get different indices in different splits - confirm intended.
    map_id = {}
    tweet_data = []

    for label, path in file_list:
        with open(path, 'r', encoding='utf-8') as input_file:
            # Zipping the same iterator three times yields one record per
            # step and drops an incomplete trailing record.
            for l1, l2, l3 in zip(input_file, input_file, input_file):
                l1, l2, l3 = (line.replace("\n", "") for line in (l1, l2, l3))

                fields = l1.split("|")
                tweet_id = int(fields[0])

                # Skip other splits before the (expensive) segmentation.
                if tweet_id not in split_events.index:
                    continue

                # First attached image that actually exists on disk; the last
                # '|' segment is ignored as in the original parser.
                image_name = None
                for url in l2.split("|")[:-1]:
                    candidate = url.split("/")[-1]
                    if candidate in available_images[label]:
                        image_name = candidate
                        break

                # Previously the lookup result was ignored and a path to a
                # missing (or stale) file was stored; such tweets are skipped.
                if image_name is None:
                    continue

                tokens = [word for word in jieba.cut_for_search(clean_str_sst(l3))
                          if word not in stop_words]
                content = " ".join(tokens)

                # Length is measured in characters of the joined string.
                if len(content) <= 10:
                    continue

                raw_event = split_events[tweet_id]
                event = map_id.setdefault(raw_event, len(map_id))

                data = dict(zip(columns, fields))
                data['image'] = image_paths[label] + image_name
                data['tweet_content'] = content
                data['event'] = event
                data['label'] = label
                tweet_data.append(data)

            print("End of file reached")

    return pd.DataFrame.from_records(tweet_data)

In [None]:
# Columns kept from the parsed tweets for modelling.
_MODEL_COLUMNS = ['tweet id', 'tweet_content', 'image', 'event', 'label']

train_dataset = load_tweets("train")[_MODEL_COLUMNS]
test_dataset = load_tweets('test')[_MODEL_COLUMNS]

# The event classifier has 10 outputs (indices 0-9); the test split can yield
# more distinct events, so rows whose event index is > 9 are dropped.
test_dataset = test_dataset[test_dataset['event'] <= 9]

validation_dataset = load_tweets('validation')[_MODEL_COLUMNS]

# Stack the text of all three splits. The previous `a + b + c` inside
# pd.concat added the Series element-wise (string concatenation aligned on the
# index) instead of concatenating them.
all_text = pd.concat(
    [train_dataset['tweet_content'], test_dataset['tweet_content'], validation_dataset['tweet_content']],
    ignore_index=True,
).dropna()


In [None]:
image_roots = {
    1: "./weibo/rumor_images",
    0: "./weibo/nonrumor_images",
}
def load_image(path):
    """Load an image file as a float32 RGB array scaled to [0, 1].

    The image is converted to RGB, resized to 256x256, centre-cropped to
    224x224 and divided by 255.

    Args:
        path: path of the image file.

    Returns:
        numpy.ndarray of shape (224, 224, 3) and dtype float32.
    """
    def center_crop(image, dim):
        """Return the central ``dim`` x ``dim`` square of ``image``."""
        width, height = image.size
        left = (width - dim) // 2
        top = (height - dim) // 2
        return image.crop((left, top, left + dim, top + dim))

    # `with` closes the underlying file; convert() loads the pixels first.
    # Converting unconditionally also handles RGBA / LA / CMYK inputs, which
    # the former 2-D-only check let through with the wrong channel count.
    with Image.open(path) as handle:
        image = handle.convert('RGB')

    image = image.resize((256, 256))
    image = center_crop(image, 224)
    return np.array(image, dtype=np.float32)/255

In [None]:
from transformers import AutoTokenizer


In [None]:
embedding_path = "./weibo/w2v.pickle"
# NOTE(review): unpickling executes arbitrary code - only load this file from
# a trusted source. encoding='latin1' is kept from the original call.
with open(embedding_path, 'rb') as embedding_file:
    w2v = pkl.load(embedding_file, encoding='latin1')
vocab = list(w2v.keys())

# Tokenizer used by `tokenize`; its vocabulary size sets the Embedding input size.
tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-chinese")
vocab_size = tokenizer.vocab_size

In [None]:
BATCH_SIZE = 32  # batch size of the tf.data pipelines
SEQ_LENGTH = 28  # token ids per tweet (tokenizer max_length / padding length)
VECTOR_DIM = 32  # width of one row of the matrix built by get_matrix

def tokenize(sentence):
    """Encode ``sentence`` into exactly ``SEQ_LENGTH`` token ids (padded or truncated)."""
    encoded = tokenizer(
        str(sentence),
        max_length=SEQ_LENGTH,
        padding='max_length',
        truncation=True,
    )
    return encoded['input_ids']

def get_matrix(sentence):
    """Build a (SEQ_LENGTH, VECTOR_DIM) matrix of word vectors for ``sentence``.

    Args:
        sentence: a whitespace-separated string of tokens (the format of the
            ``tweet_content`` column) or an iterable of tokens.

    Returns:
        float32 numpy.ndarray; row ``i`` holds the ``w2v`` vector of the i-th
        token. Rows for out-of-vocabulary tokens and padding stay zero.
    """
    # Iterating a str yields characters (including the separating spaces),
    # so split it into the tokens first.
    words = sentence.split() if isinstance(sentence, str) else list(sentence)

    vectors = np.zeros((SEQ_LENGTH, VECTOR_DIM), dtype=np.float32)
    for i, word in enumerate(words[:SEQ_LENGTH]):
        # Unknown words keep a zero vector instead of raising KeyError.
        if word in w2v:
            vectors[i, :] = w2v[word]

    return vectors

def preprocess_image(text, image, event, label):
    """Replace the image path tensor with the decoded image; pass the rest through.

    Called through ``tf.py_function``, so ``image`` is an eager string tensor.
    """
    image_path = image.numpy().decode('utf-8')
    return text, load_image(image_path), event, label

def dict_map(text, image, event, label):
    """Pack text and image into a feature dict; return ``(features, event, label)``."""
    features = dict(text=text, image=image)
    return features, event, label

In [None]:
# pos_train_dataset = train_dataset[train_dataset['label'] == 1]
# neg_train_dataset = train_dataset[train_dataset['label'] == 0]

# pos_train_texts = np.array(pos_train_dataset['tweet_content'].map(tokenize).to_list(), dtype=np.float32)
# pos_train_images = pos_train_dataset['image'].to_list()
# pos_train_events = pos_train_dataset['event'].to_list()
# pos_train_labels = pos_train_dataset['label'].to_list()
# pos_train_ds = (tf.data.Dataset.from_tensor_slices((pos_train_texts, pos_train_images, pos_train_events, pos_train_labels))
#             .map(lambda text, image, event, label: tf.py_function(preprocess_image, [text, image, event, label], [tf.float32, tf.float32, tf.int32, tf.int32]))
#             .map(dict_map)
#             .shuffle(1000)
#             .batch(BATCH_SIZE)
#             .prefetch(tf.data.AUTOTUNE)
#             )

# neg_train_texts = np.array(neg_train_dataset['tweet_content'].map(tokenize).to_list(), dtype=np.float32)
# neg_train_images = neg_train_dataset['image'].to_list()
# neg_train_events = neg_train_dataset['event'].to_list()
# neg_train_labels = neg_train_dataset['label'].to_list()
# neg_train_ds = (tf.data.Dataset.from_tensor_slices((neg_train_texts, neg_train_images, neg_train_events, neg_train_labels))
#             .map(lambda text, image, event, label: tf.py_function(preprocess_image, [text, image, event, label], [tf.float32, tf.float32, tf.int32, tf.int32]))
#             .map(dict_map)
#             .shuffle(1000)
#             .batch(BATCH_SIZE)
#             .prefetch(tf.data.AUTOTUNE)
#             # .repeat()
#             )


# pos_validation_dataset = validation_dataset[validation_dataset['label'] == 1]
# neg_validation_dataset = validation_dataset[validation_dataset['label'] == 0]

# pos_validation_texts = np.array(pos_validation_dataset['tweet_content'].map(tokenize).to_list(), dtype=np.float32)
# pos_validation_images = pos_validation_dataset['image'].to_list()
# pos_validation_events = pos_validation_dataset['event'].to_list()
# pos_validation_labels = pos_validation_dataset['label'].to_list()
# pos_validation_ds = (tf.data.Dataset.from_tensor_slices((pos_validation_texts, pos_validation_images, pos_validation_events, pos_validation_labels))
#             .map(lambda text, image, event, label: tf.py_function(preprocess_image, [text, image, event, label], [tf.float32, tf.float32, tf.int32, tf.int32]))
#             .map(dict_map)
#             .shuffle(1000)
#             .batch(BATCH_SIZE)
#             .prefetch(tf.data.AUTOTUNE)
#             )

# neg_validation_texts = np.array(neg_validation_dataset['tweet_content'].map(tokenize).to_list(), dtype=np.float32)
# neg_validation_images = neg_validation_dataset['image'].to_list()
# neg_validation_events = neg_validation_dataset['event'].to_list()
# neg_validation_labels = neg_validation_dataset['label'].to_list()
# neg_validation_ds = (tf.data.Dataset.from_tensor_slices((neg_validation_texts, neg_validation_images, neg_validation_events, neg_validation_labels))
#             .map(lambda text, image, event, label: tf.py_function(preprocess_image, [text, image, event, label], [tf.float32, tf.float32, tf.int32, tf.int32]))
#             .map(dict_map)
#             .shuffle(1000)
#             .batch(BATCH_SIZE)
#             .prefetch(tf.data.AUTOTUNE)
#             # .repeat()
#             )


# pos_test_dataset = test_dataset[test_dataset['label'] == 1]
# neg_test_dataset = test_dataset[test_dataset['label'] == 0]

# pos_test_texts = np.array(pos_test_dataset['tweet_content'].map(tokenize).to_list(), dtype=np.float32)
# pos_test_images = pos_test_dataset['image'].to_list()
# pos_test_events = pos_test_dataset['event'].to_list()
# pos_test_labels = pos_test_dataset['label'].to_list()
# pos_test_ds = (tf.data.Dataset.from_tensor_slices((pos_test_texts, pos_test_images, pos_test_events, pos_test_labels))
#             .map(lambda text, image, event, label: tf.py_function(preprocess_image, [text, image, event, label], [tf.float32, tf.float32, tf.int32, tf.int32]))
#             .map(dict_map)
#             .shuffle(1000)
#             .batch(BATCH_SIZE)
#             .prefetch(tf.data.AUTOTUNE)
#             )

# neg_test_texts = np.array(neg_test_dataset['tweet_content'].map(tokenize).to_list(), dtype=np.float32)
# neg_test_images = neg_test_dataset['image'].to_list()
# neg_test_events = neg_test_dataset['event'].to_list()
# neg_test_labels = neg_test_dataset['label'].to_list()
# neg_test_ds = (tf.data.Dataset.from_tensor_slices((neg_test_texts, neg_test_images, neg_test_events, neg_test_labels))
#             .map(lambda text, image, event, label: tf.py_function(preprocess_image, [text, image, event, label], [tf.float32, tf.float32, tf.int32, tf.int32]))
#             .map(dict_map)
#             .shuffle(1000)
#             .batch(BATCH_SIZE)
#             .prefetch(tf.data.AUTOTUNE)
#             # .repeat()
#             )

In [None]:
def _frame_to_columns(frame):
    """Return ``(token ids, image paths, events, labels)`` for a tweet DataFrame.

    Token ids are stored as float32 to match the ``tf.float32`` output type
    declared for the text in the pipeline below.
    """
    texts = np.array(frame['tweet_content'].map(tokenize).to_list(), dtype=np.float32)
    return texts, frame['image'].to_list(), frame['event'].to_list(), frame['label'].to_list()


def _build_dataset(texts, images, events, labels, shuffle_buffer=1000):
    """Create the batched ``({'text', 'image'}, event, label)`` tf.data pipeline."""
    return (tf.data.Dataset.from_tensor_slices((texts, images, events, labels))
            # Shuffle the cheap (ids, path) records *before* decoding, so the
            # buffer never holds up to 1000 decoded 224x224x3 float images.
            .shuffle(shuffle_buffer)
            .map(lambda text, image, event, label: tf.py_function(preprocess_image, [text, image, event, label], [tf.float32, tf.float32, tf.int32, tf.int32]))
            .map(dict_map)
            .batch(BATCH_SIZE)
            .prefetch(tf.data.AUTOTUNE)
            )


train_texts, train_images, train_events, train_labels = _frame_to_columns(train_dataset)
train_ds = _build_dataset(train_texts, train_images, train_events, train_labels)

test_texts, test_images, test_events, test_labels = _frame_to_columns(test_dataset)
test_ds = _build_dataset(test_texts, test_images, test_events, test_labels)

validation_texts, validation_images, validation_events, validation_labels = _frame_to_columns(validation_dataset)
validation_ds = _build_dataset(validation_texts, validation_images, validation_events, validation_labels)

In [None]:
HIDDEN_DIMS = 32  # embedding size and width of the dense feature layers
NUM_FILTERS = 20  # filters per Conv1D text branch
WINDOW_SIZE = [1, 2, 3, 4]  # Conv1D kernel sizes, one text branch each
EPOCHS = 10  # epochs of the approach-1 pretraining loop

# NOTE(review): p, alpha and beta are not referenced anywhere in the visible
# notebook (a learning-rate schedule call is commented out in the training loop).
p = np.linspace(0, 1, 10)
alpha = 10
beta = 0.75
lmd = 1  # λ passed to GradientReversal and used as the event-loss weight

In [None]:
def load_feature_extractor():
    """Build a fresh two-input feature extractor (frozen VGG19 + text CNN).

    Returns:
        keras Model with inputs ``[image, text]`` and outputs
        ``[text_features, image_features]`` - note that the TEXT features come
        first. Each output is produced by a Dense layer with ``HIDDEN_DIMS`` units.
    """
    vgg19 = keras.applications.VGG19(
        include_top=False,
        input_shape=(224,224,3)
    )
    # Rename the VGG input layer to "image" (the key dict_map uses for images).
    # NOTE(review): relies on the private `_name` attribute - confirm it works
    # on the installed Keras version.
    vgg19.layers[0]._name = "image"
    # Freeze the VGG19 weights.
    vgg19.trainable = False

    text_input = layers.Input((SEQ_LENGTH,), name='text')
    text_embeddings = layers.Embedding(vocab_size, HIDDEN_DIMS)(text_input)

    # image feature extractor
    image_features = layers.Flatten()(vgg19.output)
    image_features = layers.Dense(HIDDEN_DIMS, activation='leaky_relu')(image_features)

    # text feature extractor
    # One Conv1D branch per window size; pooling over the full conv output
    # length leaves a single time step per branch.
    convs = [layers.Conv1D(NUM_FILTERS, k)(text_embeddings) for k in WINDOW_SIZE]
    pools = [layers.MaxPooling1D(C.shape[1])(C) for C in convs]

    text_cnn = layers.Concatenate()(pools)
    text_cnn = layers.Dense(HIDDEN_DIMS, activation='leaky_relu')(text_cnn)

    # text_cnn[:, 0, :] drops the length-1 time axis left by the pooling.
    feature_extractor = Model(inputs=[vgg19.input, text_input], outputs=[text_cnn[:, 0, :], image_features])
    return feature_extractor

In [None]:
# Image augmentation: random horizontal/vertical flips, random rotation
# (factor 0.2) and random translation (factor range [-0.2, 0.2] on both axes).
augumentor = keras.Sequential([
    layers.RandomFlip("horizontal_and_vertical"),
    layers.RandomRotation(0.2),
    layers.RandomTranslation(height_factor=[-.2, .2], width_factor=[-.2, .2])
])

def augument(image, n_aug=4):
    """Return ``n_aug`` independently augmented copies of ``image``.

    Args:
        image: an image batch accepted by ``augumentor``.
        n_aug: number of augmented copies to generate.

    Returns:
        list of ``n_aug`` tensors, one per augmentation pass.
    """
    # training=True is passed explicitly: Keras random preprocessing layers
    # only transform their input in training mode, and the implicit default
    # outside of fit() differs between Keras versions.
    return [augumentor(image, training=True) for _ in range(n_aug)]

### Label aware contrastive loss for pretraining

#### Approach 1: Training similar to CLIP
* Calculate the cosine similarity between each text feature vector and its paired image feature vector.
* If the corresponding label is 0, the similarity is expected to be 1.
* If the corresponding label is 1, the similarity is expected to be -1.
* This follows the instructions provided in the assignment.

In [None]:
# Fresh feature extractor for the approach-1 pretraining below.
feature_extractor = load_feature_extractor()

In [None]:
def cosine_similarity_batched(A, B):
    """Pairwise cosine similarity between the rows of ``A`` and the rows of ``B``.

    Entry ``[i, j]`` of the result is the cosine similarity of ``A[i]`` and ``B[j]``.
    """
    # Unit-length rows turn the matrix product into cosine similarities.
    unit_a, unit_b = (tf.nn.l2_normalize(rows, axis=1) for rows in (A, B))
    return tf.matmul(unit_a, unit_b, transpose_b=True)

In [None]:
# Pull a single batch (features dict, events, labels) from the training
# pipeline for the step-by-step walkthrough in the next cells.
X, E, Y = train_ds.as_numpy_iterator().next()
images = X['image']
texts = X['text']

In [None]:
# One augmented image batch per augmentation pass; the text batch is reused
# unchanged for every pass.
augumented_images = augument(images)
augumented_texts = [texts] * len(augumented_images)

In [None]:
image_feature_vector, text_feature_vector = [], []

for aug_images, aug_texts in zip(augumented_images, augumented_texts):
    # The extractor returns (text_features, image_features), in that order.
    encoded_t, encoded_i = feature_extractor((aug_images, aug_texts))
    image_feature_vector.append(encoded_i)
    text_feature_vector.append(encoded_t)

image_feature_vector = tf.concat(image_feature_vector, axis=0)
text_feature_vector = tf.concat(text_feature_vector, axis=0)

# Similarity targets as described for approach 1: label 0 -> +1, label 1 -> -1,
# repeated once per augmentation pass to line up with the concatenated features.
labels = tf.tile(tf.where(Y == 0, 1.0, -1.0), [len(augumented_images)])

In [None]:
# Display the shapes of the stacked features and labels as a sanity check.
image_feature_vector.shape, text_feature_vector.shape, labels.shape

In [None]:
# Full image-vs-text similarity matrix; its diagonal holds the similarity of
# each image row with the text row at the same index.
similarities = cosine_similarity_batched(image_feature_vector, text_feature_vector)
similarities_logits = tf.linalg.tensor_diag_part(similarities)
# Mean squared error between `labels` and the per-pair similarities.
loss = tf.losses.mean_squared_error(labels, similarities_logits)

In [None]:
optimizer = keras.optimizers.AdamW(learning_rate=0.001)
n_augumentations = 4

for epoch in range(EPOCHS):
    print("EPOCH ", epoch+1)

    train_progbar = Progbar(len(train_ds))
    loss_metric = keras.metrics.Mean(name="loss")

    for step, (X, E, Y) in enumerate(train_ds.as_numpy_iterator()):
        images = X['image']
        texts = X['text']

        # Get image augumentations; the same text batch pairs with each of them.
        augumented_images = augument(images, n_aug=n_augumentations)

        # Targets must come from THIS batch (the loop previously reused a stale
        # global `labels` left over from an earlier cell): label 0 -> similarity
        # +1, label 1 -> similarity -1, repeated once per augmentation pass.
        targets = tf.tile(tf.where(Y == 0, 1.0, -1.0), [n_augumentations])

        with tf.GradientTape() as tape:
            image_feature_vector, text_feature_vector = [], []

            for aug_images in augumented_images:
                # The extractor returns (text_features, image_features).
                encoded_t, encoded_i = feature_extractor((aug_images, texts))
                image_feature_vector.append(encoded_i)
                text_feature_vector.append(encoded_t)

            image_feature_vector = tf.concat(image_feature_vector, axis=0)
            text_feature_vector = tf.concat(text_feature_vector, axis=0)

            # Only the diagonal (each image with its own text) is supervised.
            similarities = cosine_similarity_batched(image_feature_vector, text_feature_vector)
            similarities_logits = tf.linalg.tensor_diag_part(similarities)
            loss = tf.reduce_mean(tf.square(targets - similarities_logits))

        grads = tape.gradient(loss, feature_extractor.trainable_variables)
        optimizer.apply_gradients(zip(grads, feature_extractor.trainable_variables))

        loss_value = loss_metric(loss)

        train_progbar.update(step+1, [
            ('loss', loss_value),
        ])

    print()

#### Approach 2: Label Aware Supervised Contrastive Loss
* based on the code that was provided in the assignment document and the Supervised Contrastive Loss paper

In [None]:
# Rebind `feature_extractor` to a newly built model for approach 2; a model
# trained by an earlier cell under this name is no longer referenced by it.
feature_extractor = load_feature_extractor()

In [None]:
class SupervisedContrastiveLoss(keras.losses.Loss):
    """N-pairs loss over temperature-scaled cosine similarities of feature vectors."""

    def __init__(self, temperature=1, name=None):
        super().__init__(name=name)
        # Divisor applied to the similarity matrix before the loss.
        self.temperature = temperature

    def __call__(self, labels, feature_vectors, sample_weight=None):
        """Return the n-pairs loss for ``feature_vectors`` given ``labels``.

        ``sample_weight`` is accepted but not used.
        """
        # Rows scaled to unit length, so the product below is cosine similarity.
        unit_vectors = tf.math.l2_normalize(feature_vectors, axis=1)
        similarity = tf.matmul(unit_vectors, tf.transpose(unit_vectors))
        logits = tf.divide(similarity, self.temperature)
        return tfa.losses.npairs_loss(tf.squeeze(labels), logits)

In [None]:
temperature = 0.07

optimizer = keras.optimizers.AdamW(learning_rate=0.001)
scl = SupervisedContrastiveLoss(temperature=temperature)


def _rows_for(features, sample_idx):
    """Select ``sample_idx`` along the batch axis and merge the augmentation axis into rows."""
    return rearrange(tf.gather(features, sample_idx, axis=1), 'a b d -> (a b) d')


for i in range(5): # 5 epochs
    print("EPOCH ", i+1)

    train_ds_iter = train_ds.as_numpy_iterator()
    train_progbar = Progbar(len(train_ds))
    loss_metric = keras.metrics.Mean(name="loss")

    for step in range(len(train_ds)):
        X, E, Y = train_ds_iter.next()
        images, texts = X['image'], X['text']

        # Positions of the label-1 and label-0 samples inside the batch.
        positive_idx = tf.where(Y == 1)[:, 0]
        negative_idx = tf.where(Y == 0)[:, 0]
        augumented_batches = augument(images)

        with tf.GradientTape() as tape:
            # Stacked extractor outputs: (augmentation, output slot, batch, dim).
            processed = tf.convert_to_tensor(
                [feature_extractor((aug_imgs, texts)) for aug_imgs in augumented_batches]
            )
            first_output = processed[:, 0]
            second_output = processed[:, 1]

            # Both modalities of a sample share its label, so they are stacked
            # together: all label-1 rows first, then all label-0 rows.
            positive_data_features = tf.concat(
                [_rows_for(first_output, positive_idx), _rows_for(second_output, positive_idx)], axis=0)
            negative_data_features = tf.concat(
                [_rows_for(first_output, negative_idx), _rows_for(second_output, negative_idx)], axis=0)

            all_features = tf.concat([positive_data_features, negative_data_features], axis=0)
            all_labels = tf.concat(
                [tf.ones(positive_data_features.shape[0]), tf.zeros(negative_data_features.shape[0])], axis=0)
            loss = scl(all_labels, all_features)

        grads = tape.gradient(loss, feature_extractor.trainable_variables)
        optimizer.apply_gradients(zip(grads, feature_extractor.trainable_variables))

        loss_value = loss_metric(loss)

        train_progbar.update(step+1, [
            ('loss', loss_value),
        ])

    print()

### Training the full network with fake news detector and event classifier

In [None]:
class GradientReversal(keras.layers.Layer):
    """Identity in the forward pass; multiplies the gradient by ``-λ`` in the backward pass.

    Args:
        λ: scale applied to the reversed gradient (default 1).
    """

    def __init__(self, λ=1, **kwargs):
        super(GradientReversal, self).__init__(**kwargs)
        self.λ = λ

    @staticmethod
    @tf.custom_gradient
    def reverse_gradient(x, λ):
        """Return ``x`` unchanged together with its custom gradient function."""
        def grad(dy):
            # Scale by λ (it was previously ignored) and flip the sign; the
            # cast avoids a dtype mismatch when λ is an integer.
            # λ itself receives no gradient.
            return -tf.cast(λ, dy.dtype) * dy, None

        return tf.identity(x), grad

    def call(self, x):
        return self.reverse_gradient(x, self.λ)

    def compute_mask(self, inputs, mask=None):
        # Pass any incoming mask through untouched.
        return mask

    def compute_output_shape(self, input_shape):
        # The layer does not change the shape of its input.
        return input_shape

    def get_config(self):
        # Include λ so the layer can be re-created from its config.
        return super(GradientReversal, self).get_config() | {'λ': self.λ}

In [None]:
# This 'feature_extractor' comes from either approach 1 or approach 2
# (whichever object the name `feature_extractor` is bound to when this cell runs).
# Its two outputs are concatenated into one multi-modal feature vector.
features = layers.Concatenate(name="multi_modal_feature")(feature_extractor.outputs)

# Fake news detector
# Dropout is applied to the shared features, so both heads see the dropped-out vector.
features = layers.Dropout(0.2)(features)
predictor = layers.Dense(2, activation="softmax", name='prediction')(features)

# Event Discriminator
# The gradient-reversal layer sits between the shared features and the event head.
grd_r = GradientReversal(λ=lmd)(features)
event_discriminator = layers.Dense(HIDDEN_DIMS, activation='leaky_relu')(grd_r)
event_discriminator = layers.Dropout(0.2)(event_discriminator)
event_discriminator = layers.Dense(HIDDEN_DIMS, activation='leaky_relu')(event_discriminator)
# 10 outputs: one per event index (0-9).
event_discriminator = layers.Dense(10, activation='softmax', name='event_discriminator')(event_discriminator)

# Two-headed model: outputs are [fake-news prediction, event prediction].
model = Model(inputs=feature_extractor.inputs, outputs=[predictor, event_discriminator])
# Compiled without optimizer or loss; training uses the custom loop below.
model.compile()

In [None]:
margin=1

binary_ce = keras.losses.BinaryCrossentropy()
categorical_ce = keras.losses.CategoricalCrossentropy()
optimizer = keras.optimizers.AdamW(learning_rate=0.001)


def _forward(X, Y, E, training):
    """Run the model on one batch and return ``(pred, event, Ld, Le, final_loss)``."""
    pred, event = model(X, training=training)

    Ld = binary_ce(Y, pred)
    Le = categorical_ce(E, event)

    # The model already contains a gradient-reversal layer in front of the
    # event head, so the losses are ADDED: the event head minimises Le while
    # the reversed gradient pushes the shared features to maximise it.
    # (Ld / Le trains the event head in the wrong direction and blows up as
    # Le approaches 0.)
    # NOTE(review): lmd is also passed to GradientReversal; if that layer
    # scales gradients by λ, the encoder sees λ² - irrelevant while lmd == 1.
    final_loss = Ld + lmd * Le
    return pred, event, Ld, Le, final_loss


def _run_epoch(dataset, training, prefix=""):
    """Iterate once over ``dataset``, updating weights when ``training`` is True.

    Args:
        dataset: batched tf.data pipeline yielding ``(features, event, label)``.
        training: if True run in training mode (dropout active) and apply
            gradients; otherwise only evaluate.
        prefix: text prepended to every metric name in the progress bar.
    """
    progbar = Progbar(len(dataset))

    loss_D_metric = keras.metrics.Mean(name="detector_loss")
    loss_E_metric = keras.metrics.Mean(name="event_loss")
    loss_final_metric = keras.metrics.Mean(name="final_loss")
    fake_news_accuracy = keras.metrics.CategoricalAccuracy(name="fake_news_accuracy")
    event_accuracy = keras.metrics.CategoricalAccuracy(name="event_accuracy")

    for step, (X, E, Y) in enumerate(dataset.as_numpy_iterator()):
        Y = keras.utils.to_categorical(Y, num_classes=2)
        E = keras.utils.to_categorical(E, num_classes=10)

        if training:
            with tf.GradientTape() as tape:
                pred, event, Ld, Le, final_loss = _forward(X, Y, E, training=True)
            grads = tape.gradient(final_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
        else:
            pred, event, Ld, Le, final_loss = _forward(X, Y, E, training=False)

        # Running means / accuracies over the epoch so far.
        values = [
            (prefix + 'detector loss', loss_D_metric(Ld)),
            (prefix + 'event loss', loss_E_metric(Le)),
            (prefix + 'final loss', loss_final_metric(final_loss)),
            (prefix + 'fake news accuracy', fake_news_accuracy(Y, pred)),
            (prefix + 'event accuracy', event_accuracy(E, event)),
        ]
        if training:
            values.insert(0, ('lr', optimizer.learning_rate))

        progbar.update(step+1, values)

    print()


for i in range(20): # 20 epochs
    print("EPOCH ", i+1)
    _run_epoch(train_ds, training=True)
    _run_epoch(validation_ds, training=False, prefix="val ")

In [None]:
# Make sure the target directory exists; save_weights does not create it.
os.makedirs("./models", exist_ok=True)
model.save_weights("./models/task2-approach2.h5")

In [None]:
# Loss functions
binary_ce = keras.losses.BinaryCrossentropy()
categorical_ce = keras.losses.CategoricalCrossentropy()

test_ds_iter = test_ds.as_numpy_iterator()

loss_D_metric = keras.metrics.Mean(name="detector_loss")
loss_E_metric = keras.metrics.Mean(name="event_loss")
loss_final_metric = keras.metrics.Mean(name="final_loss")

fake_news_accuracy = keras.metrics.CategoricalAccuracy(name="fake_news_accuracy")
event_accuracy = keras.metrics.CategoricalAccuracy(name="event_accuracy")

test_progbar = Progbar(len(test_ds))

for step, (X, E, Y) in enumerate(test_ds_iter):
    # One-hot targets, exactly as in training/validation. Comparing the raw
    # 0/1 labels with pred[:, 0] (the class-0 probability) scored the detector
    # against the wrong column, and CategoricalAccuracy on 1-D inputs takes an
    # argmax over the batch rather than over classes.
    Y = keras.utils.to_categorical(Y, num_classes=2)
    E = keras.utils.to_categorical(E, num_classes=10)

    # Inference mode: dropout disabled.
    pred, event = model(X, training=False)

    Ld = binary_ce(Y, pred)
    Le = categorical_ce(E, event)

    # Detector loss plus weighted event loss (the model contains a
    # gradient-reversal layer, so the two terms are summed).
    final_loss = Ld + lmd * Le

    # metrics
    Acc_F = fake_news_accuracy(Y, pred)
    Acc_E = event_accuracy(E, event)
    m_Ld = loss_D_metric(Ld)
    m_Le = loss_E_metric(Le)
    m_L_final = loss_final_metric(final_loss)

    test_progbar.update(step+1, [
        ('test detector loss', m_Ld),
        ('test event loss', m_Le),
        ('test final loss', m_L_final),
        ('test fake news accuracy', Acc_F),
        ('test event accuracy', Acc_E),
    ])