# Investigating how Transformers learn propositional logic

Author: Anna Langedijk

With credits to: Jaap Jumelet, Jelle Zuidema

Part of the Logic & Deep Learning workshop for the course Interpretability & Explainability in AI 2023.

## Background

There are several ways researchers have combined the ultimately symbolic domain of Logic with that of neural networks. Neurosymbolic models are one research direction. For instance, in the Logic Tensor Networks paper you have seen one way of incorporating logic. Another example is [NeuroSAT](https://arxiv.org/abs/1802.03685): an end-to-end SAT solver that, instead of adding logic explicitly, uses a graph neural network to represent its logical inputs. Input formulas are always in conjunctive normal form (CNF): relevant clauses and literals are connected by edges in the input graph to the network.

However, many models used today are trained end-to-end on flat input strings. They must rely on their training data to learn any type of logic and reasoning. These models have no explicit reasoning component, and yet they handle tasks that appear to require some form of reasoning, such as Question Answering and Natural Language Understanding. The question arises whether it is even possible to learn the rules of logic from data alone.

Several benchmarks exist to test the reasoning abilities of Transformers, and of neural models in general. In the domain of natural language, there are several Natural Language Inference ([NLI](https://aclanthology.org/S14-2001/)) datasets. These datasets usually pose classification tasks, for instance: given two sentences, does the first entail the second or not? In the formal domain, similar datasets exist (e.g. [this paper](https://www.deepmind.com/publications/can-neural-networks-understand-logical-entailment)), where the task is still to predict (non)entailment, but inputs and outputs are instead given in a formal logical form.

A recent paper [Teaching Temporal Logic to Neural Networks](https://arxiv.org/abs/2003.04218) by Hahn et al. instead explores a different objective: that of generating correct solutions given an input formula in some logic. They do this using a generic encoder-decoder setup. Instead of a specific binary output, the model is trained to generate "explanations" (in some sense) for the formula, namely a possible world/trace for the input. Their main experiments focus on generating traces for Temporal Logic formulas, but in a second experiment the authors claim that these encoder-decoder models can also learn the semantics of Propositional Logic. 


In this notebook we will focus on models for propositional logic, trained on input-output pairs generated by a symbolic solver.
An example of such an input-output pair is:


In: `(a xor b) & c`

Out: `a=False, b=True, c=True`


A model trained on 800,000 such pairs can achieve a high accuracy (93%) in predicting logically valid assignments: it not only emulates the outputs provided by the symbolic solver, but also generates alternative valid possible worlds/assignments.

But while the task itself is interpretable (it is unambiguous and, unlike many tasks based on natural data, has no hidden assumptions), it is still a challenge to understand _how_ the model solves this task, and whether it truly has knowledge of the semantics of logic or is exploiting irrelevant patterns in the data.

In this notebook, you will probe the hidden states of the encoder of this model for propositional logical 'truth', to see what kind of information is encoded when the encoder processes a sentence, and whether that information is encoded the same way for different types of tokens through its layers.


#### 🧠 ToThink: dataset creation
- What are other ways this dataset differs from classification tasks? Can you see downsides to this dataset?

## Setup
First, import some prerequisites.

In [None]:
import os
import random
from tqdm import tqdm
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch.utils.data import DataLoader
import sklearn
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix

A GPU is not necessary, although it will be faster if you want to (re)calculate the hidden states.

In [None]:
COLAB = False
INSTALL = False
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [None]:
if COLAB:
    # Mount to drive so we can access our own files
    from google.colab import drive
    drive.mount('/content/drive')

In [None]:
if INSTALL:
    # The PyPI package for sklearn is called scikit-learn
    !pip install --user torch scikit-learn matplotlib seaborn pandas

Download and import the code that can load the dataset and model architecture. 
The model implementation is available here: [github.com/annaproxy/transformer_logic_compact](https://github.com/annaproxy/transformer_logic_compact).

In [None]:
if COLAB:
    # Point this to the correct location
    %cd /content/drive/My Drive/IEAI_Notebooks
    if os.path.exists("transformer_logic_compact"):
        %cd transformer_logic_compact
        !git pull https://github.com/annaproxy/transformer_logic_compact
        %cd ..
    else:
        !git clone https://github.com/annaproxy/transformer_logic_compact

In [None]:
import sys
sys.path.append("transformer_logic_compact")

In [None]:
from interfaces import ModelInterface, TransformerEncDecInterface
from models import TransformerEncDecModel
from layers.transformer import Transformer
from helpers.collater import VarLengthCollate
from logicdatasets import LogicDataSet
from helpers import move_to_device

Download the pre-provided data (validation set, information about its subtrees, pretrained models) and put it in the data directory:
https://drive.google.com/drive/folders/1TE0loQ76bcBoYKxXDjUsyfCWiPsBHSv9 

In [None]:
DATA_FOLDER = 'transformer_logic_compact/data'

In [None]:
if COLAB:
    !gdown --folder 1TE0loQ76bcBoYKxXDjUsyfCWiPsBHSv9 -O $DATA_FOLDER

### Load dataset

In [None]:
# The files val.src and val.tgt will be loaded
dataset = LogicDataSet(data_path=f'{DATA_FOLDER}/val', vocab_path='transformer_logic_compact/vocabulary')

### First inspection of the data
The input data is provided to the model in [Polish prefix notation](https://en.wikipedia.org/wiki/Polish_notation). Instead of writing operators with infix, e.g. `a | b`, the operator is provided as a prefix: `| a b`. This avoids the use of parentheses, which would only add to the length of the input of the model. 
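To get a feel for the notation, here is a minimal sketch that converts a prefix formula back to parenthesized infix. The operator table `ARITY` and the helper `prefix_to_infix` are illustrative and not part of the course code; the operator spellings are assumed from the examples that appear in this notebook.

```python
# Assumed operator arities: '!' is unary, all other operators are binary.
ARITY = {"!": 1, "&": 2, "|": 2, "xor": 2, "<->": 2}


def prefix_to_infix(formula: str) -> str:
    """Convert a space-separated prefix formula to a parenthesized infix string."""
    tokens = iter(formula.split())

    def parse() -> str:
        token = next(tokens)
        if token not in ARITY:  # a variable such as 'a'
            return token
        if ARITY[token] == 1:  # unary operator: one operand follows
            return f"{token}{parse()}"
        left = parse()  # binary operator: two operands follow
        right = parse()
        return f"({left} {token} {right})"

    result = parse()
    # Any leftover token means the input was not a single well-formed formula
    if next(tokens, None) is not None:
        raise ValueError(f"Trailing tokens in formula: {formula!r}")
    return result


print(prefix_to_infix("xor | b d c"))  # ((b | d) xor c)
print(prefix_to_infix("| d & a d"))    # (d | (a & d))
```

Because every operator has a fixed number of operands, the recursion always knows where a subformula ends, which is why prefix notation needs no parentheses.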

The following input means ``(b | d) xor (c)``:

In [None]:
print("String input:", dataset.input_ids_to_text(dataset[22623]['in']))

The outputs are provided as a simple array of maximum length 10; for example, `a 0 b 1 c 1 d 1 e 0` corresponds to the model {a=False, b=True, c=True, d=True, e=False}. The outputs are always in alphabetical order.
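As a small illustration of this format, the sketch below turns an output string into a Python dictionary. The helper `parse_assignment` is hypothetical and not part of the dataset code; it assumes the string alternates variable names with `0`/`1` values, as in the example above.

```python
def parse_assignment(output: str) -> dict:
    """Parse an output string such as 'a 0 b 1' into {'a': False, 'b': True}."""
    tokens = output.split()
    if len(tokens) % 2 != 0:
        raise ValueError(f"Expected variable/value pairs, got: {output!r}")
    # Even positions hold variable names, odd positions hold their values
    return {var: val == "1" for var, val in zip(tokens[::2], tokens[1::2])}


print(parse_assignment("a 0 b 1 c 1 d 1 e 0"))
# {'a': False, 'b': True, 'c': True, 'd': True, 'e': False}
```

The same helper also works for shorter outputs that mention only some of the variables.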

#### 📝 Pen and Paper 1
Think of a possible valid assignment to this sentence. Then, use the 'out' key of the dataset object to print the gold label. Did you generate the same output with your internal logic model, or are there multiple possible outputs?

In [None]:
print("String output:", "Todo: dataset output goes here.")

### Load the model
The Transformers that we use are relatively small: the encoder and decoder have 6 layers. The state sizes of the encoder and decoder are 128 and 64, respectively.

In [None]:
# For now, we work with one model seed
model_name = 'model_seed2'

In [None]:
def load_model_from_name(model_name):
    # Initialize model architecture
    model = TransformerEncDecModel(
                len(dataset.in_vocabulary),
                len(dataset.out_vocabulary),
                state_size=128,
                state_size_decoder=64,
                nhead=4,
                num_encoder_layers=6,
                num_decoder_layers=6,
                ff_multiplier=4,
                ff_multiplier_decoder=8,
                transformer=Transformer,
                tied_embedding=True,
                scale_mode="opennmt"
            )
    # Load from state dict
    model.load_state_dict(torch.load(f"{DATA_FOLDER}/{model_name}.pth", map_location=torch.device(DEVICE)))
    model.eval()
    model = model.to(DEVICE)

    # Model interface helps with some overhead such as adding <BOS> and <EOS>.
    # You should call this if you want to do a forward pass.
    model_interface = TransformerEncDecInterface(model)
    return model_interface
model_interface = load_model_from_name(model_name)

In [None]:
# Uncomment to inspect the model architecture in more detail
#print(model)

### Iterate over data to get hidden states (or load them from a file)
I've provided a portion (~26%) of the hidden states in a separate drive: https://drive.google.com/drive/folders/1oSmpcYy_zi-bNA81ODf001-M80B44FMc 
This should be enough to run your probing experiment. 



In [None]:
dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=256,
            collate_fn=VarLengthCollate(batch_dim=1),
        )

In [None]:
def forward_all(model_interface, dataloader, max_it=100, calculate_outputs = True,
               calculate_hidden=True):
    """Forwards max_it batches through the model_interface, collecting its hidden states."""
    stacked_hidden = {}
    outputs = []
    # The dataset has 100_000 entries. Without limiting max_it, these tensors will take up about 7G (!)
    # It may take a while to run (3-6 minutes)
    with torch.no_grad():
        for it, d in enumerate(tqdm(dataloader, total=max_it)):
            d = move_to_device(d, DEVICE)
            result, hidden = model_interface(d, collect_hidden=True)

            current_batch_size = d['in'].shape[-1]
            if calculate_outputs:
                digits = model_interface.decode_outputs(result)
                output_strs = [dataset.sample_to_text(digits, i) for i in range(current_batch_size)]
                outputs.extend(output_strs)
            if calculate_hidden:
                for layer in hidden:
                    if layer not in stacked_hidden:
                        stacked_hidden[layer] = hidden[layer]
                    else:
                        # This is a bit slow, .cat creates a new tensor
                        # Also fun: batch size must be large enough to include a sentence of the max length every time
                        stacked_hidden[layer] = torch.cat([stacked_hidden[layer], hidden[layer]], dim =0)
            # Stop after exactly max_it batches (it is 0-based)
            if it + 1 >= max_it:
                break
    return stacked_hidden, outputs

In [None]:
def load_hidden_and_outputs(model_name, model_interface, load_from_file=True, calculate_outputs = True, MAX_IT=100,
                           calculate_hidden=True):
    stacked_hidden = None
    if not load_from_file:
        stacked_hidden, outputs = forward_all(model_interface, dataloader, max_it=MAX_IT,
                                              calculate_outputs=calculate_outputs, calculate_hidden=calculate_hidden)
        torch.save(stacked_hidden, f'{DATA_FOLDER}/{model_name}_hiddenstates.pth')
        if calculate_outputs:
            with open(f'{DATA_FOLDER}/outputs_{model_name}.txt', 'w') as f:
                for line in outputs:
                    f.write(line+'\n')

    else:
        if calculate_hidden:
            file_path = f'{DATA_FOLDER}/{model_name}_hiddenstates.pth'
            if not os.path.exists(file_path):
                # Ugly, but don't want to download the whole folder and blast you with 10GB of hidden states
                # gdown doesnt provide an option to pass file_path within a folder
                if model_name == 'model_seed2':
                    !gdown 1Ln-mIQ1YVQpHhePqRBXdgyv9w-tzldy7 -O $file_path
                elif model_name == 'model_without_not_xor_seed2':
                    !gdown 1l-pmncpeJ0jNdNhJHLoMR1wj_IHJXbGH -O $file_path
                elif model_name == 'model_without_and_xor_seed1':
                    !gdown 1dQjHw6eZEAgsIV2wB1-BQHdLeMOIwAoj -O $file_path
                else:
                    print(f"No file found for {model_name}")
            # Load onto the CPU so that the file also works on machines without a GPU
            stacked_hidden = torch.load(file_path, map_location='cpu')
        with open(f'{DATA_FOLDER}/outputs_{model_name}.txt') as f:
            outputs = [l.strip() for l in f.readlines()]
    return stacked_hidden, outputs

# Set load_from_file to False if you want to calculate new states
# Downloading may take a while (~3 GB)
stacked_hidden, outputs = load_hidden_and_outputs(model_name, model_interface, load_from_file=True)

In [None]:
print("Stacked hidden keys",stacked_hidden.keys())
print("Stacked hidden shape for layer 3",stacked_hidden[3].shape) # N x max_sentence_length x hidden_state_size

### Collect information about the dataset and model outputs

For the 100000 sentences in this validation set, I have calculated some additional information:
- `possible_worlds.txt` contains all the (partial) possible worlds in which a sentence can be true. This allows us to calculate whether any model prediction was correct, not just when the model output exactly matches the gold label. 
- `subformulas_valuation.txt` contains information about whether different tokens in the logical formula can be true/can be false. For instance, if the variable `a` is set to `1` in every possible world, it must always be true, and thus has the label `1`. If it is always false, it will have the label `0`. If it depends on the truth value of the rest of sentence (as is usually the case), it is set to `2`. You can find a more concrete example below.

Let's load this information into a pandas dataframe.

In [None]:
def calculate_df(model_outputs, cutoff):
    with open(f'{DATA_FOLDER}/possible_worlds.txt') as f:
        possible_worlds = [l.strip() for l in f.readlines()]

    with open(f'{DATA_FOLDER}/subformulas_valuation.txt') as f:
        valuations = [l.strip() for l in f.readlines()]

    df = pd.DataFrame({
        # Use the function argument rather than the global CUTOFF
        'sentence_idxs':range(cutoff),
        'inputs':[s.strip() for s in dataset.in_lines[:cutoff]],
        'gold_outputs':[s.strip() for s in dataset.out_lines[:cutoff]],
        'model_outputs':model_outputs[:cutoff],
        'possible_worlds':possible_worlds[:cutoff],
        'subformula_valuations':valuations[:cutoff]  
    })

    df.possible_worlds = df.possible_worlds.apply(lambda x:x.split(','))
    # An exact match is when the model output is exactly equal to the gold outputs
    df['exact_match'] = df['model_outputs'] == df['gold_outputs']
    # The model can be correct without exactly matching the gold outputs
    df['semantically_correct'] = df.apply(lambda row: row['model_outputs'].replace(' ', '') in row['possible_worlds'], axis=1)

    # Handy statistic so we can look at short formulas as examples
    df['input_size'] = df['inputs'].apply(lambda x:len(x.split()))
    return df

# How many hidden states we have saved
CUTOFF = len(stacked_hidden[0])
df = calculate_df(outputs, CUTOFF)

We can verify that the model has a high accuracy on the task, even when it does not exactly match the generated 'gold' output.

In [None]:
print(df['exact_match'].mean())
print(df['semantically_correct'].mean())

Here are some examples of outputs that are not an exact match of the gold label, but where the model still outputs a valid possible world.

In [None]:
df[~df.exact_match & df.semantically_correct].sort_values(by='input_size')[:5]

The two columns that we added contain information about all possible outputs for the input sentence. 
The `possible_worlds` column contains all possible (possibly partial) worlds for this sentence.
The `subformula_valuations` column contains information about which subtrees can or cannot be true in the sentence. The string has the same length as the input sentence. 

When a subformula must always be true, it is mapped to `1`. When a subformula must always be false, it is mapped to `0`. Otherwise, it is mapped to `2`. Most subformulas have the value `2`: their truth values are contingent on the truth values of other subformulas.


Take for instance the sentence:
`| d & a d` (standard notation: `d | (a & d)`). 

In [None]:
df.iloc[15113]

The partial output here is `d 1`, but the model would also have been correct if it had output `a 0 d 1`, since that is also a possible world.

The first token of the sentence is the root of the tree, so its valuation is always `1`.
`d` must always be true, which is why the second and last characters in `subformula_valuations` are also `1`.
Since the value of `a` can be either true or false (verify this by looking at the possible worlds), its valuation is `2`.
The same holds for `&`: because `d` alone already makes the sentence true, `& a d` can be either true or false, depending on `a`.
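The labelling described above can be reproduced by brute force. The sketch below is not the code that generated `subformula_valuations.txt`; it assumes that a valuation is computed over all full assignments that satisfy the sentence, and the names `OPS`, `evaluate_subformulas` and `subformula_valuations` are illustrative.

```python
from itertools import product

# Assumed operators: token -> (arity, truth function)
OPS = {
    "!": (1, lambda x: not x),
    "&": (2, lambda x, y: x and y),
    "|": (2, lambda x, y: x or y),
    "xor": (2, lambda x, y: x != y),
    "<->": (2, lambda x, y: x == y),
}


def evaluate_subformulas(tokens, world):
    """Return one truth value per token: the value of the subtree rooted there."""
    values = [None] * len(tokens)

    def walk(pos):
        """Evaluate the subtree starting at pos; return the position just after it."""
        token = tokens[pos]
        if token not in OPS:  # a variable
            values[pos] = world[token]
            return pos + 1
        arity, fn = OPS[token]
        nxt, args = pos + 1, []
        for _ in range(arity):
            child = nxt
            nxt = walk(child)
            args.append(values[child])
        values[pos] = fn(*args)
        return nxt

    walk(0)
    return values


def subformula_valuations(formula):
    tokens = formula.split()
    variables = sorted({t for t in tokens if t not in OPS})
    seen = [set() for _ in tokens]  # truth values observed per token
    for bits in product([False, True], repeat=len(variables)):
        values = evaluate_subformulas(tokens, dict(zip(variables, bits)))
        if values[0]:  # keep only worlds that satisfy the whole sentence
            for observed, value in zip(seen, values):
                observed.add(value)
    if not seen[0]:
        raise ValueError("The formula has no satisfying world")
    # '1' = always true, '0' = always false, '2' = contingent
    return "".join("2" if len(s) == 2 else str(int(next(iter(s)))) for s in seen)


print(subformula_valuations("| d & a d"))  # 11221
```

For `| d & a d` the two satisfying worlds are `a0d1` and `a1d1`, so the root and both occurrences of `d` get `1`, while `&` and `a` get `2`.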


#### 📝 Pen and Paper 2
Take the following sentence. It has two possible worlds: calculate the possible worlds.
Then, calculate what the `subformula_valuations` column looks like for this sentence. Which nodes in the tree should always be true, and which should always be false? Check your answers by printing the entire row. 

In [None]:
df.iloc[15612][['inputs']]

Since the model is forced to output just one world, it would be interesting to see if there is some information in the model about all these subtrees: has the model internalized when a subtree must be true or false, helping it along to a correct answer?

## Probing for Truth

### Probe task construction

The `ProbeTask` class will help you with extracting the correct hidden states and labels. You don't have to fully understand the code. There is an example of how to use this code below.

In [None]:
def set_seed(seed):
    if seed == -1:
        seed = random.randint(0, 1000)
    # Pandas also uses np random state by default
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    # if you are using GPU
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


class ProbeDataSet(torch.utils.data.Dataset):
    def __init__(self, hidden, ys):
        self.hidden = hidden
        self.ys = ys

    def __len__(self):
        return len(self.ys)

    def __getitem__(self, idx):
        return self.hidden[idx], self.ys[idx]

    
class ProbeTask:
    def __init__(self, hidden, df):
        """
        Initialize a 'ProbeTask' object.
        hidden: a dictionary of tensors of shape N x max_input_sentence_length x hidden_size.
                The dictionary key is the layer in the encoder.
        df: A pandas dataframe of N (or fewer) rows with relevant information about the sentences.
            This will be used for sampling and for looking up indices of states in `hidden`.
            Thus, the values in column `sentence_idxs` should correspond to the indices in the `hidden` object.
        """
        self.hidden = hidden
        self.df = df.copy()
        
    def create_datapoints_from_sentence(self, layer_no, sentence_idx, word_idxs):
        """
        Transform a sentence into 1 or more datapoints. 
        layer_no: int
        sentence_idx: int
        word_idxs: [int] 0-based array of word indices. The BOS tag will be taken care of by self.get_hidden.
        """
        state, flat_idx = self.get_hidden(layer_no, [(sentence_idx, np.array(word_idxs))])
        return state
        
    def create_dataset(
        self,
        layer_no,
        node_types=["a","b","c","d","e"],
        output_types=[0,1,2],
        sents=2000,
        max_tokens_per_sent=3,
    ):
        """
        Creates a dataset given the subformula_valuations column.
        The dataset is balanced w.r.t. both the node types and the possible 'truth values' 0, 1 and 2.
        
        layer_no: Which layer to sample the hidden states from
        node_types: Which node types to sample. For instance: all variables. It can also be a singleton list, eg ["xor"].
        sents: amount of sentences to sample per node_type per valuation.
        max_tokens_per_sent: how many nodes we are allowed to take from the same sentence. Allowing this makes your dataset bigger.
        
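        Returns: a tuple (ProbeDataSet, idxs), where idxs lists the (sentence_idx, word_idx) pair of every datapoint.
        The dataset has at most sents * max_tokens_per_sent * len(node_types) * len(output_types) datapoints.
        In practice, it is smaller, since not all sentences contain max_tokens_per_sent relevant tokens.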
        """
        ys = []
        xs = []
        idxs_ = []
        for node_type in node_types:
            # Iterate over all possible valuations (output types)
            for val_id in output_types:
                # Be sure to balance node types when a list is passed
                idxs = self.get_all_idx(
                    sents, max_tokens_per_sent, [node_type], str(val_id)
                )
                states, flat_idxs = self.get_hidden(layer_no, idxs)
                xs.append(states)
                ys.extend([val_id] * len(states))
                idxs_.extend(flat_idxs)
        xs = torch.cat(xs)
        ys = torch.tensor(ys, dtype=torch.int64)
        return ProbeDataSet(xs, ys), idxs_
    
    def get_hidden(self, layer_no, idx_list):
        """
        idx_list is a list of (int, [int]) indicating the sentence position and the word position array
        This function takes care of the BOS token by adding + 1 to the provided idxs.
        """
        states = self.hidden[layer_no]
        result = []
        idxs_ = []
        for idx, word_idxs in idx_list:
            # +1 for BOS tag
            result.append(states[idx][word_idxs + 1])
            idxs_.extend([(idx, word_idx) for word_idx in word_idxs])
        return torch.cat(result, dim=0), idxs_
    
    
    def get_all_idx(
        self,
        sent_amt, # How many sentences to sample
        per_sent_max_amt, # How many nodes are allowed in the same sentence
        token_identities,  # abcde xor ! & , etc...
        value: str,  # '0', '1' or '2'
    ):
        """This method uses the subformula_valuations column to fetch relevant pairs of (sentence_idx, word_idxs).
        """
        # Filter rows where a token of one of the provided token_identities has the desired value
        # E.g. get only those indices where there exists some `a` with value `1`.
        # Since the root node token *always* has value 1, we start at the second token.
        def get_idxs(row, start=1):
            idxs = []
            for i, (token, truth) in enumerate(
                zip(row["inputs"].split()[start:], row["subformula_valuations"][start:])
            ):
                if truth == value and token in token_identities:
                    idxs.append(i + start)
            return idxs

        self.df["_condition"] = self.df.apply(get_idxs, axis=1)
        sents = self.df[self.df["_condition"].apply(len) > 0]
        
        # Sample from all possible sentences
        sents = sents.sample(n=min(sent_amt, len(sents)), replace=False)

        result = []
        # For each sentence, sample from all possible positions in the sentence
        for i, row in sents.iterrows():
            sent_idx = row["sentence_idxs"]
            word_idxs = row["_condition"]
            word_idxs = np.random.choice(
                word_idxs, size=min(per_sent_max_amt, len(word_idxs)), replace=False
            )
            result.append((sent_idx, word_idxs))
        # Remove column just in case
        self.df = self.df.drop("_condition", axis=1)
        return result


def get_stats(array):
    """Tool for printing the label distribution"""
    return np.unique(array, return_counts=True)[1] / len(array)

#### Example of a probetask object

The probetask `create_dataset` method samples from the subset of hidden states that are above tokens of a certain type. You can specify which types of nodes to collect with the `node_types` parameter.

In [None]:
# Create a probetask object to sample from
probetask_test = ProbeTask(stacked_hidden, df)

In [None]:
# Sample hidden states from layer 3 for 500 sentences per node_type per output_type
# This may take a few seconds
probe_dataset_test, idxs = probetask_test.create_dataset(layer_no=3, sents = 500, node_types=["a","b","c","d","e"],
                                                        output_types=[0,1,2], max_tokens_per_sent=1)

In [None]:
# Check the output and verify it using the dataframe
hidden_state, prediction_label = probe_dataset_test.hidden[0], probe_dataset_test.ys[0]
sentence_idx, word_idx = idxs[0]
print("dataset size", len(probe_dataset_test))
print("Hidden state size", probe_dataset_test.hidden.shape)
print("sentence idx", sentence_idx, "word idx", word_idx)
print("Token at position?", df.iloc[sentence_idx].inputs.split()[word_idx])
print("Valuation at position?", df.iloc[sentence_idx].subformula_valuations[word_idx])
print("Prediction label?", prediction_label.item())

In [None]:
print("Label distribution")
print(get_stats(probe_dataset_test.ys))

### Probe training
Now that we know how to collect data for the probe, let's set up code to construct, train and evaluate a probing model.

Instead of using a LogisticRegression from `sklearn`, I have provided a simple pytorch module called `Probe`, which is essentially just a wrapper around `nn.Linear`.


#### ☑️ ToDo 1: Implement the training loop of the probe

In [None]:
class Probe(torch.nn.Module):
    """
    A very simple linear model.
    """
    def __init__(self, hidden_dim=128, classes=3):
        super().__init__()
        self.linear = torch.nn.Linear(hidden_dim, classes)

    def forward(self, hidden_state):
        return self.linear(hidden_state)
    
def train_probe(probe, train_dataloader, epochs=10, lr=0.005, debug=False):
    probe.train()
    optim = torch.optim.Adam(probe.parameters(), lr=lr)
    loss_function = torch.nn.CrossEntropyLoss()
    for epoch in range(epochs):
        # TODO: implement a simple training loop given the dataloader, optimizer, probe and loss function
        raise NotImplementedError()
    probe.eval()
    return probe
    
def eval_probe(probe, val_dataloader):
    probe.eval()
    ys_pred = np.zeros(len(val_dataloader.dataset))
    ys_true = np.zeros(len(val_dataloader.dataset))
    
    i = 0
    for inputs, labels in val_dataloader:
        output = probe(inputs)
        preds = torch.argmax(output, dim=-1)
        ys_pred[i : i + len(labels)] = preds.detach().int().numpy()
        ys_true[i : i + len(labels)] = labels.int().numpy()

        i += len(labels)
    return {
            "pred": ys_pred,
            "true": ys_true,
            "accuracy": accuracy_score(ys_true, ys_pred),
            "f1_macro": f1_score(ys_true, ys_pred, average="macro"),
            "confusion_matrix": confusion_matrix(ys_true, ys_pred, normalize='true')
    }

def random_split(dataset, train_size=.8):
    """Splits a ProbeDataSet into two datasets"""
    train_size = int(len(dataset) * train_size)
    test_size = len(dataset) - train_size
    return torch.utils.data.random_split(dataset, [train_size, test_size])

def split_dataset_into_dataloaders(probe_dataset, train_size=.8):
    """Splits a ProbeDataSet into two dataloaders"""
    train_dataset, val_dataset = random_split(probe_dataset, train_size=train_size)
    train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True)
    val_dataloader = DataLoader(val_dataset , batch_size=256, shuffle=False)
    return train_dataloader, val_dataloader

### Training variable probes per layer

We must first make some decisions about what data to use:
1. To remove noise, it is probably a good idea to only include datapoints where the model was correct.

    * To think: what would happen if you included incorrect datapoints in the probe's training data? Could this also have benefits?

2. As a first experiment, we can try to probe nodes that are variables. If a variable **must** be true, the model **must also output** that it is true: the model should therefore keep track of this information in every sentence where it is correct.

3. Finally, we can choose which layer to probe. 

For evaluation of the probes: note that this is a 3-way classification task, so chance accuracy on a balanced dataset is 33%. Some classes may have a higher accuracy than others.

In [None]:
# Only include that part of the df where the model was correct
p = ProbeTask(stacked_hidden, df[df.semantically_correct])

In [None]:
# Example code for training a probe:
# Let's start with at least 1000 examples per node_type per truth type.
# Print the distribution of labels to make sure the data is balanced
layer_no = 4
# Set seed for sampling data / training probe
set_seed(1)
probe_dataset, idxs = p.create_dataset(layer_no=layer_no, sents = 1000, node_types=["a","b","c","d","e"])
print("Label distribution", get_stats(probe_dataset.ys))
print("Total dataset size", len(probe_dataset))
train_dataloader, val_dataloader = split_dataset_into_dataloaders(probe_dataset)
probe = train_probe(Probe(), train_dataloader)
result = eval_probe(probe, val_dataloader)
print(result)

#### ☑️ ToDo 2 : probe variable nodes for all layers. 

Which layer gives the best performance and why do you think this is?

In [None]:
# ToDo: train and evaluate variable-node-probes for each of the six layers

#### ✍️ ToSubmit 1: Visualize your probing results per layer.

Give a brief description of the pattern you see, and why you might be seeing this.

(It is enough to do this in the caption of the figure)

#### 🧠 ToThink: 
- Given your results per layer, what are the limitations of probing for these `subformula_valuations` with respect to the **entire** sentence?

### Probing other types of nodes

We have now only extracted hidden states of nodes at the positions of variables.

However, other positions have the same property of always true/always false/contingent, as we have seen in exercise Pen & Paper 2 above.
These nodes of other types can thus also be probed in the same way.


#### ☑️  ToDo 3: Run probes for non-variable/operator nodes.

(I recommend training on **exactly one type of node** at a time, see next section. Be sure to increase the `sents` parameter to get a bigger dataset for just a single node type, but also pay attention to the label distribution and keep it balanced!)

- Is the performance better or worse in general? 

- Is the pattern the same throughout the layers as with the variable nodes?

#### ✍️ ToSubmit 2: Add the results of your operator-node probes to your visualization for ToSubmit 1.

Describe the pattern for these operator probes, and compare these results to your results in ToSubmit 1. Is there a difference, and what might be the reason for this difference?

In [None]:
layer_no = 2
set_seed(1)
# 3000 sents gives a mostly balanced label distribution
probe_dataset, idxs = p.create_dataset(layer_no=layer_no, sents = 3000, node_types=["&"])
print("Label distribution",get_stats(probe_dataset.ys))
print("Total dataset size", len(probe_dataset))

In [None]:
# Your code here

### Do probes generalize?

During our previous experiment we have assumed that data is encoded similarly for every node type in our dataset (every variable). 
However, it could be that our model has learned solutions that do not treat different variable types the same.


#### ☑️ ToDo 4: train probes on one subset of node_types, then evaluate them on another subset as well as the original test set. 
E.g. train only on variable `a` (or `a` and `b`) and evaluate on all other variables. How does performance compare when testing on `a` versus other variables? 
Try multiple different subsets. Which nodes generalize to which other nodes?

For this experiment, you can stick with a single layer (the best one from the previous exercise), or try multiple layers.

#### ✍️ ToSubmit 3: Include your findings in the report.
Since you need to try multiple different subsets of train/test nodes, it is recommended to visualize this in a heatmap (similar to the cross-layer probing experiments from week 1).

Briefly reflect on these results. Which probes generalize, which ones do not? Might this tell us anything about the way this Transformer is representing propositional logic?

In [None]:
# Example code to generate a dataset that is of another type of node:
# probe_dataset_b_only, idxs = p.create_dataset(layer_no=5, sents = 5000, node_types=["b"])
# dataloader_b , _ = split_dataset_into_dataloaders(probe_dataset_b_only, 1.0) # Do not split
# Now, you can pass dataloader_b to eval_probe with a probe that was trained on different data

In [None]:
# Your code here

## Logical equivalences through the layers

If the model truly understood the semantics of propositional logic, it should "know about" logical equivalences. For instance, `! xor` is equivalent in meaning to `<->`.

We may be able to see this in the behaviour of the model, by comparing its outputs for sentences that are logically equivalent. Two sentences are logically equivalent if they have exactly the same possible worlds.
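
As a small illustration of this definition, the sketch below checks equivalence by brute force: it evaluates two prefix-notation formulas under every full assignment of their variables and compares the sets of satisfying worlds. The operator table `OPS` and all helper names are assumptions made for this example, not part of the notebook's code, and partial worlds are ignored.

```python
from itertools import product

# Assumed operator table (token -> (arity, function)); the actual
# vocabulary of the dataset may contain other operators.
OPS = {
    "!":   (1, lambda x: not x),
    "&":   (2, lambda x, y: x and y),
    "|":   (2, lambda x, y: x or y),
    "xor": (2, lambda x, y: x != y),
    "<->": (2, lambda x, y: x == y),
}

def evaluate(tokens, world, pos=0):
    """Evaluate the prefix formula starting at tokens[pos].

    Returns (truth value, index of the first token after the subformula).
    """
    tok = tokens[pos]
    if tok not in OPS:  # anything that is not an operator is a variable
        return world[tok], pos + 1
    arity, fn = OPS[tok]
    pos += 1
    args = []
    for _ in range(arity):
        value, pos = evaluate(tokens, world, pos)
        args.append(value)
    return fn(*args), pos

def satisfying_worlds(formula, variables):
    """Return the set of full assignments (as tuples) that make formula true."""
    tokens = formula.split()
    worlds = set()
    for values in product([False, True], repeat=len(variables)):
        value, _ = evaluate(tokens, dict(zip(variables, values)))
        if value:
            worlds.add(values)
    return worlds

def equivalent(f1, f2):
    """Two formulas are equivalent iff they have the same satisfying worlds."""
    variables = sorted({t for t in (f1 + " " + f2).split() if t not in OPS})
    return satisfying_worlds(f1, variables) == satisfying_worlds(f2, variables)
```

For example, `equivalent("! xor a b", "<-> a b")` holds, while `equivalent("xor a b", "| a b")` does not, because `| a b` is also true when both variables are true.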

In [None]:
# Convert the full (non partial) possible worlds to a string so pandas can easily find all unique values
df['fullworlds'] = df['possible_worlds'].apply(lambda x:' '.join(z for z in x if len(z)==max(map(len,x))))

In [None]:
# Which groups of logically equivalent sentences occur most often?
df['fullworlds'].value_counts()[:10]

In [None]:
# The model outputs (but also the gold outputs) are not always the same for equivalent sentences
df[df['fullworlds']=='b0c0 b0c1 b1c1']

We see that the model does not always predict the same output for logically equivalent sentences. Conversely, since there is only a relatively small number of possible outputs, the model will also predict the same output for some logically inequivalent sentences.

Maybe we can gain more insight into whether the model knows about logical equivalence by looking at its hidden representations.

### Creating our own small dataset

The validation set contains mostly very long sentences, and not many sentences that are logically equivalent. 

Instead of sampling from the validation set, we can create our own small dataset to further inspect the model's behaviour on hand-constructed (non)equivalent sentences.

In [None]:
# Some pairs of logical equivalences for the input data
# Student ToDo: add more inputs
mini_data_inputs = ["! <-> a b", "xor a b"]
# Corresponding correct outputs for completeness. Will not affect outputs
mini_data_outputs = ["a 1 b 0", "a 1 b 0"]
mini_dataset = LogicDataSet(
    vocab_path="transformer_logic_compact/vocabulary",
    data_inputs=(mini_data_inputs, mini_data_outputs)
)

In [None]:
# Wrap the mini dataset in a dataloader and perform a forward pass
mini_dataloader = DataLoader(mini_dataset, collate_fn=VarLengthCollate(batch_dim=1), batch_size=256)
mini_states, mini_outputs = forward_all(model_interface, mini_dataloader, max_it=1)

In [None]:
mini_outputs

In [None]:
# The root node is at position 1.
position_to_compare = 1

fig, axs = plt.subplots(ncols=3,nrows=2,figsize=(16,6))
axs = [ax for ax_array in axs for ax in ax_array]
for ax, layer_no in zip(axs,range(6)):
    cosines = sklearn.metrics.pairwise.cosine_similarity(mini_states[layer_no][:,position_to_compare,:])
    sns.heatmap(cosines, annot=True, ax=ax)
    ax.set_xticks(np.arange(len(mini_data_inputs))+0.5 ,mini_data_inputs,rotation=90)
    ax.set_yticks(np.arange(len(mini_data_inputs))+0.5 ,mini_data_inputs,rotation=0)
    ax.set_title(f'Layer {layer_no}')
plt.tight_layout()
plt.show()
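
When reading these heatmaps, it can help to recall what cosine similarity measures: the angle between two vectors, independent of their lengths. The pure-Python sketch below is only an illustration of the formula; the cell above uses scikit-learn's implementation, and the helper here is written for this example.

```python
import math

def cosine_similarity(u, v):
    """Cosine of the angle between vectors u and v (both must be non-zero)."""
    dot = sum(x * y for x, y in zip(u, v))
    norm_u = math.sqrt(sum(x * x for x in u))
    norm_v = math.sqrt(sum(y * y for y in v))
    return dot / (norm_u * norm_v)
```

Scaling a vector does not change the result, so two hidden states can have a similarity close to 1 even if their norms differ considerably.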

#### ☑️ ToDo 5: Add at least 4 more inputs to the mini-dataset above and compare the similarity matrix through the layers

Be sure to include equivalent inputs as well as non-equivalent inputs. 

It might be interesting to also include inputs that are 'almost equivalent'!


#### 🧠 ToThink:
I set `position_to_compare` to 1, meaning we compare the root node. Which side effects could this have? Can we see these effects through the layers?
Could we also compare different positions?

#### ✍️ ToSubmit 4: Include the plot above with (a subset of your) (non)-equivalent sentences.

Reflect on at least these two questions:
- What does similarity mean through the layers?
- Which (non)equivalent sentences are more similar? Briefly describe at least two patterns you see.

Note: Feel free to change the plotting code, or include multiple similarity plots, if this makes things clearer.

## ✅ <font color=green> End of notebook! </font>

You have reached the end of the notebook - there are no more official ToDo's, but I have provided one more section below for interested students.

I have spent a long time with these models during my thesis, so if you have (practical, existential, philosophical) questions about any of these experiments, feel free to ask me (Anna), for instance by [emailing me](mailto:annalangedijk@gmail.com).



## Optional: Models trained on systematically different data

I have also provided a model that is trained without the pattern `! xor`. Any `! xor` is replaced with `<->` during training, so the amount of training data stays (almost) equal.
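
To make the replacement concrete, here is a minimal sketch of how such a rewrite could be done on the prefix-notation inputs. It is not the preprocessing code that was actually used, and the function name is made up for this example. Because `!` is unary, a `!` token directly followed by `xor` always negates that `xor` subtree, so a token-level rewrite is sufficient.

```python
def replace_not_xor(formula):
    """Rewrite every adjacent `! xor` token pair to `<->` in a prefix formula."""
    tokens = formula.split()
    out = []
    i = 0
    while i < len(tokens):
        if tokens[i] == "!" and i + 1 < len(tokens) and tokens[i + 1] == "xor":
            out.append("<->")  # `! xor A B` means the same as `<-> A B`
            i += 2
        else:
            out.append(tokens[i])
            i += 1
    return " ".join(out)
```

Each rewrite shortens the input by one token, which is why the amount of training data stays only almost equal.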

The model performs equally well on most sentences, including sentences with `xor` and `<->`, but when `! xor` is present, performance drops. The model does not seem able to combine `!` with `xor`.

Additionally, I have provided another model. This one is trained without another pattern, namely `& xor`. Not only does `& xor` never appear (`xor` being the left subtree of the binary `&` node), but I have also filtered out any occurrences where `xor` is the right subtree of the `&` node. This means the model has never seen a node `&` with a child `xor`, just as the previous model has never seen the unary operator `!` with `xor` as a child.
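
Detecting `xor` as the right child of `&` requires knowing where the left subtree ends, so a plain substring search is not enough. The sketch below shows one way such a filter could be written for prefix notation. The arity table and function names are assumptions for this example, not the filtering code that was actually used.

```python
# Assumed operator arities; every other token is treated as a leaf.
ARITY = {"!": 1, "&": 2, "|": 2, "xor": 2, "<->": 2}

def subtree_end(tokens, start):
    """Return the index one past the subtree rooted at tokens[start]."""
    pending = 1  # number of subtrees that still have to be read
    pos = start
    while pending:
        pending += ARITY.get(tokens[pos], 0) - 1
        pos += 1
    return pos

def has_and_with_xor_child(formula):
    """True if some `&` node has `xor` as its left or right child."""
    tokens = formula.split()
    for i, tok in enumerate(tokens):
        if tok != "&":
            continue
        left = i + 1
        right = subtree_end(tokens, left)  # right child starts after the left subtree
        if tokens[left] == "xor" or tokens[right] == "xor":
            return True
    return False
```

For instance, `& ! a xor b c` is caught because the left subtree `! a` spans two tokens and the right child is `xor`, whereas `& a ! xor b c` is not, since the right child there is `!`.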

This second model performs almost as well as the original model, including on sentences containing the left-out pattern.

How can we understand what this model is missing?


In [None]:
# Apply the models to exactly the same validation set 
_, model_outputs_without_not_xor =  \
    load_hidden_and_outputs('model_without_not_xor_seed2', None, load_from_file=True, calculate_outputs = True, MAX_IT=100, 
                           calculate_hidden=False)

_, model_outputs_without_and_xor =  \
    load_hidden_and_outputs('model_without_and_xor_seed1', None, load_from_file=True, calculate_outputs = True, MAX_IT=100,
                           calculate_hidden=False)


In [None]:
df_without_not_xor = calculate_df(model_outputs_without_not_xor, CUTOFF)
df_without_and_xor = calculate_df(model_outputs_without_and_xor, CUTOFF)

In [None]:
df_without_and_xor['semantically_correct'].mean()

In [None]:
df_without_not_xor['semantically_correct'].mean()

In [None]:
# The five shortest inputs that the model trained without `! xor` gets wrong
df_without_not_xor[~df_without_not_xor.semantically_correct].sort_values(by='input_size')[:5]

In [None]:
# The five shortest inputs that contain the held-out pattern `& xor` (xor as the left child of &)
df_without_and_xor[df_without_and_xor['inputs'].str.contains('& xor')].sort_values(by='input_size',ascending=True)[:5]

## Future questions / Ideas for projects

- Are there differences in outputs/probing results/hidden state results for the models trained on systematically different data?
- Can we probe for something else?
- Can we compare this model to a (neuro)symbolic solver, or find an algorithm that the model is using internally?
