In [1]:
import json
import os
import random

import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt

import tensorflow as tf

from src.data.robustness_study.bert_data_preprocessing import bert_preprocess
from src.models.bert_model import AleatoricMCDropoutBERT, create_bert_config
from src.utils.loss_functions import null_loss, bayesian_binary_crossentropy

In [31]:
# load teacher model and processed data
with open(os.path.join('../out/bert_teacher/final_e3_lr2_hd020_ad020_cd030/model/config.json'), 'r') as f:
    teacher_config = json.load(f)
    
config = create_bert_config(teacher_config['hidden_dropout_prob'],
                            teacher_config['attention_probs_dropout_prob'],
                            teacher_config['classifier_dropout'])

# initialize teacher model
teacher = AleatoricMCDropoutBERT(config=config, custom_loss_fn=bayesian_binary_crossentropy)
checkpoint_path = os.path.join('out/bert_teacher/final_e3_lr2_hd020_ad020_cd030/model', 'cp-{epoch:02d}.ckpt')
checkpoint_dir = os.path.dirname(checkpoint_path)

latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint:
    print("Loading weights from", checkpoint_dir)
    teacher.load_weights(latest_checkpoint).expect_partial()

teacher.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=teacher_config['learning_rate']),
        loss={'classifier': bayesian_binary_crossentropy, 'log_variance': null_loss},
        metrics=[tf.keras.metrics.BinaryAccuracy(), tf.keras.metrics.Precision(), tf.keras.metrics.Recall()],
        run_eagerly=True
    )

# load data
df = pd.read_csv(os.path.join('../data/robustness_study/preprocessed/train.csv'), sep='\t', index_col=0)

# randomly choose one sequence
sample = df.sample(1)

# preprocess sampled sequence
input_ids, attention_masks, labels = bert_preprocess(sample)
sample_preprocessed = tf.data.Dataset.from_tensor_slices((
    sample['text'].values,
    {
        'input_ids': input_ids,
        'attention_mask': attention_masks
    },
    labels
)).batch(1)

# for this sequence, first compute epistemic uncertainty (simple MC dropout sampling)
# to illustrate, use 50 samples


# then compute aleatoric uncertainty (from mean prediction and log variance)

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing TFBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFBertModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions w

In [78]:
def illustrate_uncertainties(model, data: tf.data.Dataset, m: int = 50):
    """
    Perform Monte Carlo Dropout Transfer Sampling on a given model and dataset.

    This method generates an augmented dataset with additional uncertain labels
    created by MC Dropout and aleatoric uncertainty sampling.

    :param  model: The TensorFlow model to use for sampling.
    :param  data: Data to be used for sampling. Each element should be a tuple (features, labels).
    :param  m: Number of MC Dropout iterations.
    :param k: Number of aleatoric uncertainty samples per MC iteration.
    :return: df: Augmented dataset with original features, labels, and uncertain labels.
    """

    text, features, labels = next(iter(data))
    all_logits = []
    all_log_variances = []
    for i in range(m):
        print('sampling epistemic uncertainty: {}/{}'.format(i, m))
        rand_seed = random.randint(0, 2 ** 32 - 1)
        tf.random.set_seed(rand_seed)
        outputs = model(features, training=True)
        logits = outputs.logits
        log_variances = outputs.log_variances
        all_logits.append(logits)
        all_log_variances.append(log_variances)

    mu_t = tf.stack(all_logits, axis=0)  # shape is (m, batch_size, num_classes)
    all_log_variances = tf.stack(all_log_variances, axis=0)
    
    mean_logits = tf.reduce_mean(mu_t, axis=0)

    sigma_hat = tf.sqrt(tf.exp(all_log_variances))
    sigma_tilde = tf.reduce_mean(sigma_hat, axis=0)

    return mu_t, mean_logits, sigma_tilde

In [79]:
m = 25
k = 25
mcd_sample, mean_logit, mean_std_dev_logit = illustrate_uncertainties(teacher, sample_preprocessed, m)

sampling epistemic uncertainty: 0/25
sampling epistemic uncertainty: 1/25
sampling epistemic uncertainty: 2/25
sampling epistemic uncertainty: 3/25
sampling epistemic uncertainty: 4/25
sampling epistemic uncertainty: 5/25
sampling epistemic uncertainty: 6/25
sampling epistemic uncertainty: 7/25
sampling epistemic uncertainty: 8/25
sampling epistemic uncertainty: 9/25
sampling epistemic uncertainty: 10/25
sampling epistemic uncertainty: 11/25
sampling epistemic uncertainty: 12/25
sampling epistemic uncertainty: 13/25
sampling epistemic uncertainty: 14/25
sampling epistemic uncertainty: 15/25
sampling epistemic uncertainty: 16/25
sampling epistemic uncertainty: 17/25
sampling epistemic uncertainty: 18/25
sampling epistemic uncertainty: 19/25
sampling epistemic uncertainty: 20/25
sampling epistemic uncertainty: 21/25
sampling epistemic uncertainty: 22/25
sampling epistemic uncertainty: 23/25
sampling epistemic uncertainty: 24/25


In [80]:
mean_std_dev_logit = mean_std_dev

In [99]:
# sample k times from the aleatoric uncertainty
# k std normal samples
eps = tf.random.normal((k, 1, 1), mean=0, stddev=1)

y_aleatoric_logits = mean_logits + (mean_std_dev_logit * eps)  # y_t should be (m, batch_size, num_classes)

In [106]:
tf.sigmoid(mean_logits).numpy(), tf.sigmoid(mean_std_dev_logit).numpy()

(array([[0.567637]], dtype=float32), array([[0.74959713]], dtype=float32))

In [100]:
y_aleatoric_probs = tf.sigmoid(y_aleatoric_logits)

In [101]:
y_epistemic_probs = tf.sigmoid(mcd_samples)

In [113]:
SMALL_SIZE = 14
MEDIUM_SIZE = 14
BIGGER_SIZE = 16

plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title

plt.rcParams["font.family"] = "serif"

In [129]:
total_uncertainty = (y_epistemic_probs[:, 0, 0] + y_aleatoric_probs[:, 0, 0]).numpy()

# normalize into 0, 1 interval
total_uncertainty = (total_uncertainty - total_uncertainty.min()) / (total_uncertainty.max() - total_uncertainty.min())

In [136]:
plt.figure(figsize=(7, 5))
sns.histplot(y_epistemic_probs[:, 0, 0], label='Epistemic Uncertainty', alpha=0.5, binwidth=0.05)
sns.histplot(y_aleatoric_probs[:, 0, 0], label='Aleatoric Uncertainty', alpha=0.5, color='orange', binwidth=0.05)
plt.xlabel('Probability')
plt.ylabel('Density')
plt.legend()

plt.tight_layout()
plt.savefig('plots/plot_illustration_teacher_uncertainty_sample.pdf')
plt.close()