# Load libraries and dataset


In [None]:
# Import relevant packages
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from torch import nn
import torch
from tqdm import *
from sklearn.model_selection import train_test_split
from statsmodels.stats.proportion import proportion_confint
import pickle

!pip install transformers
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig

from sklearn.linear_model import LogisticRegression
import sklearn.metrics as met

In [None]:
# Load both CSV files and tag every row with a binary `underspecified` flag:
# rows from scriptie_data.csv get 1, rows from scriptie_control_data.csv get 0.
data = pd.read_csv('scriptie_data.csv').assign(
    underspecified=lambda frame: [1] * len(frame)
)
control_data = pd.read_csv('scriptie_control_data.csv').assign(
    underspecified=lambda frame: [0] * len(frame)
)

# Stack the two frames; ignore_index gives the result a fresh 0..n-1 index.
all_data = pd.concat([data, control_data], ignore_index=True)

# Extract hidden states and probe

The majority of the code in this section is adapted from code by Jaap Jumelet and Jelle Zuidema, used for the course Interpretability & Explainability in AI.

Code for loading models

In [None]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def load_model(path):
    """Load a pretrained model together with its tokenizer and config.

    Args:
        path: Identifier handed unchanged to every ``from_pretrained`` call.

    Returns:
        Tuple ``(model, tokenizer, config)``. The model has been switched to
        evaluation mode and moved to ``DEVICE``.
    """
    tokenizer = AutoTokenizer.from_pretrained(path)
    classifier = AutoModelForSequenceClassification.from_pretrained(path)
    config = AutoConfig.from_pretrained(path)

    # eval() and to() both hand back the module itself, so they can be chained.
    classifier = classifier.eval().to(DEVICE)

    return classifier, tokenizer, config

Code for creating sentence masks (not really used for our purposes)

In [None]:
def create_sen_masks(
    input_ids, tokenizer, pad_token_id=0
):
    """Build a boolean mask of the token positions to keep in every sentence.

    A position is ``True`` when its token id differs from ``pad_token_id`` and
    ``False`` otherwise. The mask is computed in one vectorised comparison.

    Args:
        input_ids: 2-D integer tensor of token ids, one row per sentence.
        tokenizer: Not used; kept so that existing calls keep working.
        pad_token_id: Token id that is masked out. Defaults to 0, the value
            that used to be hard-coded.
            NOTE(review): the notebook sets ``tokenizer.pad_token`` to the EOS
            token before calling this, so confirm that 0 really is the padding
            id of the model in use; if not, pass ``tokenizer.pad_token_id``.

    Returns:
        Tuple ``(masks, labels)``. ``masks`` is a CPU bool tensor with the same
        shape as ``input_ids``. ``labels`` is an empty tensor, because no
        token-level labels are produced here.
    """
    # Move to the CPU so the mask lives where the old list-based version put it.
    sen_masks = (input_ids != pad_token_id).cpu()

    return sen_masks, torch.tensor([])

Code for extracting hidden states

In [None]:
def extract_hidden_states(
    input_ids, attention_mask, all_sen_masks, num_layers, batch_size=128,
    cls_index=-1,
):
    """Run the model over the corpus and collect hidden states for every layer.

    The function calls the notebook-level variable ``model`` (it is not a
    parameter), so a model must have been loaded before it is used.

    Args:
        input_ids: Tensor of token ids, one row per sentence.
        attention_mask: Attention mask belonging to ``input_ids``.
        all_sen_masks: Boolean mask with one row per sentence; positions set to
            ``True`` are copied into the token-level output.
        num_layers: Number of hidden-state tensors to store per example.
        batch_size: Number of sentences sent through the model at once.
        cls_index: Sequence position used as the sentence-level state. The
            default -1 (last position) is the value that was hard-coded before;
            the original comment suggested 0 for XLNet.
            NOTE(review): the caller pads the batch, so for shorter sentences
            position -1 may be a padding position -- confirm this is intended.

    Returns:
        Tuple ``(token_states, cls_states)``; both are dictionaries mapping
        ``layer_idx`` to a CPU tensor. ``token_states[layer]`` has one row per
        selected token, ``cls_states[layer]`` has one row per sentence.
    """
    total_tokens = int(all_sen_masks.sum())
    total_sens = all_sen_masks.size(0)

    def allocate(width):
        # One zero-filled buffer per layer, filled in batch by batch below.
        tokens = {
            layer_idx: torch.zeros(total_tokens, width)
            for layer_idx in range(num_layers)
        }
        sentences = {
            layer_idx: torch.zeros(total_sens, width)
            for layer_idx in range(num_layers)
        }
        return tokens, sentences

    # Allocated on the first batch so the width follows the model output
    # instead of a hard-coded 768.
    token_states, cls_states = None, None

    dataset = torch.utils.data.TensorDataset(input_ids, attention_mask, all_sen_masks)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False
    )

    num_extracted = 0
    sens_extracted = 0

    for batch_input_ids, batch_attention_mask, batch_sen_mask in tqdm_notebook(
        data_loader, unit="batches"
    ):
        with torch.no_grad():
            all_hidden_states = model(
                batch_input_ids,
                attention_mask=batch_attention_mask,
                output_hidden_states=True,
            ).hidden_states

        if token_states is None:
            token_states, cls_states = allocate(all_hidden_states[0].size(-1))

        # The last batch can be smaller than batch_size.
        batch_rows = batch_input_ids.size(0)
        subset_size = 0

        for layer_idx, hidden_states in enumerate(all_hidden_states):
            hidden_states_subset = hidden_states[batch_sen_mask]
            subset_size = hidden_states_subset.shape[0]

            token_states[layer_idx][
                num_extracted : num_extracted + subset_size
            ] = hidden_states_subset.cpu()

            cls_states[layer_idx][
                sens_extracted : sens_extracted + batch_rows
            ] = hidden_states[:, cls_index].cpu()

        # The same positions are selected in every layer, so advance once per batch.
        num_extracted += subset_size
        sens_extracted += batch_rows

    if token_states is None:
        # No batches at all: fall back to the former fixed width of 768 so the
        # returned dictionaries still have one entry per layer.
        token_states, cls_states = allocate(768)

    return token_states, cls_states

In [None]:
#@title Key functions for creating the probing datasets
def create_probe_data(dataset, model, tokenizer, num_layers, type_tested):
    """Tokenise a data frame and extract the hidden states used for probing.

    Args:
        dataset: Frame with a ``'text'`` column and a ``type_tested`` column.
        model: Accepted for the caller's convenience; it is not forwarded to
            ``extract_hidden_states`` by this function.
        tokenizer: Callable tokenizer applied to the list of sentences.
        num_layers: Passed on to ``extract_hidden_states``.
        type_tested: Name of the column that holds the sentence labels.

    Returns:
        Tuple ``(token_states, cls_states, sentence_labels)`` where the first
        two come from ``extract_hidden_states`` and the last is the label
        column as a plain list.
    """
    sentences = dataset['text'].tolist()
    sentence_labels = dataset[type_tested].tolist()

    # Pad/truncate the whole corpus in one go and move it to the target device.
    batch = tokenizer(
        sentences, padding=True, truncation=True, return_tensors="pt"
    ).to(DEVICE)

    # The second return value (token-level labels) is not needed here.
    sen_masks, _ = create_sen_masks(batch.input_ids, tokenizer)

    token_states, cls_states = extract_hidden_states(
        batch.input_ids, batch.attention_mask, sen_masks, num_layers
    )

    return token_states, cls_states, sentence_labels

Create data used for probing

In [None]:
# Fixed seed so the shuffle and the train/test split are identical on every
# Restart & Run All (both were unseeded before).
SEED = 42

# (display name, model identifier) pairs to extract hidden states from.
models = [
    ("OPT", "facebook/opt-125m")
]

data_dict = {}

# Column of `all_data` that holds the label the probes are trained on.
type_tested = 'underspecified'

data_used = all_data.sample(frac=1, random_state=SEED).reset_index(drop=True)
train, test = train_test_split(data_used, test_size=0.2, random_state=SEED)
# Reset the indices so that row labels equal row positions; the probing cells
# use `train.index` / `test.index` to index into the extracted state tensors.
train = train.reset_index(drop=True)
test = test.reset_index(drop=True)

for name, path in models:
    # `model` and `num_layers` are deliberately notebook-level names:
    # extract_hidden_states reads `model`, the probing cells read `num_layers`.
    model, tokenizer, config = load_model(path)
    tokenizer.pad_token = tokenizer.eos_token

    # One more than the number of hidden layers, matching the `+ 1` used below.
    num_layers = config.num_hidden_layers + 1

    train_X, cls_train_X, cls_train_y = create_probe_data(
        train, model, tokenizer, num_layers, type_tested
    )
    test_X, cls_test_X, cls_test_y = create_probe_data(
        test, model, tokenizer, num_layers, type_tested
    )

    data_dict[name] = {
        "train_X": train_X,  # Dictionary mapping layer_idx -> tensor
        "test_X": test_X,
        "cls_train_X": cls_train_X,
        "cls_train_y": cls_train_y,
        "cls_test_X": cls_test_X,
        "cls_test_y": cls_test_y,
    }

Actual probing (one class)

In [None]:
from sklearn.linear_model import LogisticRegression
import sklearn.metrics as met

# Class whose sentences are probed in this cell. The cell used to read `klasje`
# without defining it, which fails on a fresh kernel.
# NOTE(review): 1 is a placeholder -- set it to the class you want to probe.
klasje = 1

# One untrained logistic-regression probe per model and per layer.
cls_probes = {
    model_name: {
        layer_idx: LogisticRegression(solver="liblinear", penalty="l2", max_iter=10)
        for layer_idx in range(num_layers)
    }
    for model_name in data_dict.keys()
}

# Per model: one list per metric, with one entry per layer (in layer order).
cls_probe_results = {
    model_name: {
      'accuracy' : [],
      'recall'  : [],
      'precision' : [],
      'f1' : []
    } for model_name in cls_probes.keys()
}

for model_name in cls_probes.keys():
  # Rows of the train/test split that belong to the selected class.
  train_mask = (train['class'] == klasje)
  train_index_list = train.index[train_mask]
  test_mask = (test['class'] == klasje)
  test_index_list = test.index[test_mask]
  for layer_idx in range(num_layers):
      # Sentence-level states and labels of the selected class only.
      train_X, train_y = (
          data_dict[model_name]["cls_train_X"][layer_idx][train_index_list],
          np.array(data_dict[model_name]["cls_train_y"])[train_index_list]
      )

      cls_probes[model_name][layer_idx].fit(train_X, train_y)

      test_X, test_y = (
          data_dict[model_name]["cls_test_X"][layer_idx][test_index_list],
          np.array(data_dict[model_name]["cls_test_y"])[test_index_list]
      )

      test_pred = cls_probes[model_name][layer_idx].predict(test_X)

      test_acc = met.accuracy_score(test_y, test_pred)
      test_precision = met.precision_score(test_y, test_pred, average='macro')
      test_recall = met.recall_score(test_y, test_pred, average='macro')
      test_f1 = met.f1_score(test_y, test_pred, average='macro')

      print(layer_idx)
      # Use one explicit, sorted label list for both the matrix and its tick
      # labels, so they cannot disagree when a label occurs only in the
      # predictions or only in the gold labels.
      labels = np.unique(np.concatenate([test_y, test_pred]))
      disp = met.ConfusionMatrixDisplay(
          met.confusion_matrix(test_y, test_pred, labels=labels),
          display_labels=labels)
      disp.plot()
      plt.show()

      # The results dict of this cell is flat (no per-class level).
      cls_probe_results[model_name]['accuracy'].append(test_acc)
      cls_probe_results[model_name]['precision'].append(test_precision)
      cls_probe_results[model_name]['recall'].append(test_recall)
      cls_probe_results[model_name]['f1'].append(test_f1)

Actual probing (multiple classes)

In [None]:
# One logistic-regression probe per model and per layer; each probe is refitted
# for every class in the loop below.
cls_probes = {
    model_name: {
        layer_idx: LogisticRegression(solver="liblinear", penalty="l2", max_iter=10)
        for layer_idx in range(num_layers)
    }
    for model_name in data_dict.keys()
}

# Per model and class: one list per metric, with one entry per layer.
cls_probe_results = {
    model_name: {
      klasje: {
        'accuracy' : [],
        'recall'  : [],
        'precision' : [],
        'f1' : []
      }
      for klasje in [1,2,3,4]
    }
    for model_name in data_dict.keys()
}

for model_name in cls_probes.keys():
  for klasje in [1,2,3,4]:
    # Rows of the train/test split that belong to the current class.
    train_mask = (train['class'] == klasje)
    train_index_list = train.index[train_mask]
    test_mask = (test['class'] == klasje)
    test_index_list = test.index[test_mask]
    for layer_idx in range(num_layers):
        train_X, train_y = (
            data_dict[model_name]["cls_train_X"][layer_idx][train_index_list],
            np.array(data_dict[model_name]["cls_train_y"])[train_index_list]
        )

        cls_probes[model_name][layer_idx].fit(train_X, train_y)

        test_X, test_y = (
            data_dict[model_name]["cls_test_X"][layer_idx][test_index_list],
            np.array(data_dict[model_name]["cls_test_y"])[test_index_list]
        )

        test_pred = cls_probes[model_name][layer_idx].predict(test_X)

        test_acc = met.accuracy_score(test_y, test_pred)
        test_precision = met.precision_score(test_y, test_pred, average='macro')
        test_recall = met.recall_score(test_y, test_pred, average='macro')
        test_f1 = met.f1_score(test_y, test_pred, average='macro')

        print(layer_idx)
        # Use one explicit, sorted label list for both the matrix and its tick
        # labels, so they cannot disagree when a label occurs only in the
        # predictions or only in the gold labels.
        labels = np.unique(np.concatenate([test_y, test_pred]))
        disp = met.ConfusionMatrixDisplay(
            met.confusion_matrix(test_y, test_pred, labels=labels),
            display_labels=labels)
        disp.plot()
        plt.show()

        cls_probe_results[model_name][klasje]['accuracy'].append(test_acc)
        cls_probe_results[model_name][klasje]['precision'].append(test_precision)
        cls_probe_results[model_name][klasje]['recall'].append(test_recall)
        cls_probe_results[model_name][klasje]['f1'].append(test_f1)

# Visualization

Get 95% confidence interval (for one class)

In [None]:
# NOTE(review): both helpers use the size of the complete test split as the
# number of observations; if the scores were computed on a single class's
# subset, confirm that this is the sample size you want for the interval.

def std_lower(accuracy):
  """Return the lower confidence bound for a score, using len(test) trials."""
  lower, _ = proportion_confint(count=accuracy * len(test), nobs=len(test))
  return lower

def std_upper(accuracy):
  """Return the upper confidence bound for a score, using len(test) trials."""
  _, upper = proportion_confint(count=accuracy * len(test), nobs=len(test))
  return upper

# Same layout as cls_probe_results: model -> metric -> one bound per layer.
cls_probe_results_lower = {
    model_name: {
      metric: [std_lower(score) for score in cls_probe_results[model_name][metric]]
      for metric in ('accuracy', 'recall', 'precision', 'f1')
    }
    for model_name in cls_probe_results
}

cls_probe_results_upper = {
    model_name: {
      metric: [std_upper(score) for score in cls_probe_results[model_name][metric]]
      for metric in ('accuracy', 'recall', 'precision', 'f1')
    }
    for model_name in cls_probe_results
}

Visualize (for one class)

In [None]:
# One panel per probed model. squeeze=False keeps `axes` two-dimensional even
# for a single model, so the indexing below is the same for any model count.
n_panels = max(len(cls_probe_results), 1)
fig, axes = plt.subplots(1, n_panels, sharey=True, squeeze=False,
                         figsize=(5 * n_panels, 4))

# enumerate yields (index, (key, value)), so the pair must be unpacked in
# parentheses.
for i, (model_name, results) in enumerate(cls_probe_results.items()):
  ax = axes[0, i]
  x = range(len(results['accuracy']))
  ax.plot(results['accuracy'], 'o-', lw=2, markersize=3, label=model_name + " accuracy")
  # Shaded band: confidence interval around the accuracy curve.
  ax.fill_between(x, cls_probe_results_upper[model_name]['accuracy'], cls_probe_results_lower[model_name]['accuracy'], interpolate=True, alpha=0.5)
  ax.plot(results['recall'], 'o-', lw=2, markersize=3, label=model_name + " recall")
  ax.plot(results['precision'], 'o-', lw=2, markersize=3, label=model_name + " precision")
  ax.plot(results['f1'], 'o-', lw=2, markersize=3, label=model_name + " F1")
  ax.set_title(model_name + " probe accuracy")
  ax.legend()
  ax.set_ylim(0,1.0)
  # Dashed reference line at 0.5.
  ax.axhline(y = 0.5, color = 'grey', linestyle = '--')
  ax.set_xlabel(model_name + " layer nr.")


fig.tight_layout()
fig.savefig("probing_class_distinction.png")
plt.show()

Get 95% confidence interval (for multiple classes)

In [None]:
def std_lower(accuracy, nobs=None):
  """Return the lower confidence bound for a score.

  Args:
    accuracy: Score in [0, 1], treated as a proportion of successes.
    nobs: Number of test examples the score was computed on. When omitted,
      the size of the whole test split is used (the former behaviour).
  """
  if nobs is None:
    nobs = len(test)
  return proportion_confint(count=accuracy * nobs, nobs=nobs)[0]

def std_upper(accuracy, nobs=None):
  """Return the upper confidence bound for a score.

  Args:
    accuracy: Score in [0, 1], treated as a proportion of successes.
    nobs: Number of test examples the score was computed on. When omitted,
      the size of the whole test split is used (the former behaviour).
  """
  if nobs is None:
    nobs = len(test)
  return proportion_confint(count=accuracy * nobs, nobs=nobs)[1]

def _interval_bounds(bound_fn):
  """Apply bound_fn to every stored score: model -> class -> metric -> list."""
  bounds = {}
  for model_name, per_class in cls_probe_results.items():
    bounds[model_name] = {}
    for klasje in [1,2,3,4]:
      # Each class is evaluated on its own subset of the test split, so the
      # interval uses that subset's size rather than len(test).
      nobs = int((test['class'] == klasje).sum())
      bounds[model_name][klasje] = {
          metric: [bound_fn(score, nobs) for score in per_class[klasje][metric]]
          for metric in ('accuracy', 'recall', 'precision', 'f1')
      }
  return bounds

cls_probe_results_lower = _interval_bounds(std_lower)

cls_probe_results_upper = _interval_bounds(std_upper)

Visualize (for multiple classes)

In [None]:
# One 2x2 figure per model, with one panel per class.
for model_name, results in cls_probe_results.items():
  fig, axes = plt.subplots(2, 2, sharey=True, figsize=(10,4))
  for klasje, plotje in zip([1,2,3,4], [(0,0), (0,1), (1,0), (1,1)]):
    x = range(len(results[klasje]['accuracy']))
    axes[plotje].plot(results[klasje]['accuracy'], 'o-', lw=2, markersize=3, label=model_name + " accuracy")
    # Shaded band: confidence interval around the accuracy curve.
    axes[plotje].fill_between(x, cls_probe_results_upper[model_name][klasje]['accuracy'], cls_probe_results_lower[model_name][klasje]['accuracy'], interpolate=True, alpha=0.5)
    axes[plotje].plot(results[klasje]['recall'], 'o-', lw=2, markersize=3, label=model_name + " recall")
    axes[plotje].plot(results[klasje]['precision'], 'o-', lw=2, markersize=3, label=model_name + " precision")
    axes[plotje].plot(results[klasje]['f1'], 'o-', lw=2, markersize=3, label=model_name + " F1")
    # klasje is an integer, so it has to be converted before concatenation.
    axes[plotje].set_title("Probe accuracy for " + str(klasje) + " expressions")
    axes[plotje].axhline(y = 0.5, color = 'grey', linestyle = '--')
    axes[plotje].set_xlabel(model_name + " layer nr.")

  # A single legend is enough; the y-axis is shared, so one set_ylim covers all panels.
  axes[(0,0)].legend()
  axes[(0,0)].set_ylim(0,1.0)

  fig.tight_layout()
  # NOTE(review): `klasje` here is the last class of the loop above, although
  # the figure shows all classes -- the file name is kept unchanged so existing
  # references to it keep working.
  fig.savefig("probing_" + str(klasje) + "_" + model_name + ".png")

  plt.show()