In [1]:
import torch
import random
import numpy as np
from sklearn.metrics import classification_report
from equality_datasets import get_equality_dataset, get_IIT_equality_dataset
from IIT_torch_shallow_neural_classifier import TorchShallowNeuralClassifierIIT
from torch_shallow_neural_classifier import TorchShallowNeuralClassifier
from torch_interventionable_model import InterventionableLayeredModel
from torch_rnn_classifier import TorchRNNClassifier
random.seed(42)

# Table of Contents
1. [Hierarchical Equality Dataset](#Hierarchical-Equality-Dataset)
2. [High-Level Tree-Structured Algorithm](#The-High-Level-Tree-Structured-Algorithm)
3. [A Fully-Connected Feed-Forward Neural Network](#A-Fully-Connected-Feed-Forward-Neural-Network)
4. [Causal Abstraction](#Causal-Abstraction)
5. [Interchange Intervention Training (IIT)](#Interchange-Intervention-Training-(IIT))

## Hierarchical Equality Dataset  
[Geiger, Carstensen, Frank, and Potts (2020)](https://arxiv.org/abs/2006.07968)

We will use a hierarchical equality task to present IIT. We define the hierarchical equality task as follows: The input is two pairs of objects and the output is **true** if both pairs contain the same object or if both pairs contain different objects and **false** otherwise. For example, AABB and ABCD are both labeled **true** while ABCC and BBCD are both labeled **false**. 

## The High-Level Tree-Structured Algorithm

Let $\mathcal{A}$ be the simple tree structured algorithm that solves this task by applying a simple equality relation three times: Compute whether the first two inputs are equal, compute whether the second two inputs are equal, then compute whether
the truth-valued outputs of these first two computations are equal. We visually define $\mathcal{A}$ below and then define a python function that computes $\mathcal{A}$, possibly under an intervention that sets $V_1$ and/or $V_2$ to fixed values.

<img src="fig/IIT/PremackFunctions.png" width="500"/>
<img src="fig/IIT/PremackGraph.png" width="500"/>

In [2]:
def compute_A(input, intervention):
    graph = dict()
    for i, object in enumerate(input):
        graph["input" + str(i+1)] = object
    if "V1" in intervention:
        graph["V1"] = intervention["V1"]
    else:
        graph["V1"] = graph["input1"] == graph["input2"]
    if "V2" in intervention:
        graph["V2"] = intervention["V2"]
    else:
        graph["V2"] = graph["input3"] == graph["input4"]
    graph["output"] = graph["V1"] == graph["V2"]
    return graph

### The algorithm with no intervention

First, observe the behavior of the algorithm whhen we provide the input **(pentagon,pentagon, triangle, square)** with no intervention. We show this visually and by using our **compute_A** function.

<img src="fig/IIT/PremackNoIntervention.png" width="500"/>

In [3]:
compute_A(("pentagon", "pentagon", "triangle", "square"), {})

{'input1': 'pentagon',
 'input2': 'pentagon',
 'input3': 'triangle',
 'input4': 'square',
 'V1': True,
 'V2': False,
 'output': False}

### The algorithm with an intervention

Observe the behavior of the algorithm whhen we provide the input **(square,pentagon, triangle, triangle)** with an intervention setting **V1** to **False**. We show this visually and by using our **compute_A** function.

<img src="fig/IIT/PremackIntervention.png" width="500"/>

In [4]:
compute_A(("square", "pentagon", "triangle", "triangle"), {"V1":True})

{'input1': 'square',
 'input2': 'pentagon',
 'input3': 'triangle',
 'input4': 'triangle',
 'V1': True,
 'V2': True,
 'output': True}

### The algorithm with an interchange intervention

Finaally, observe the behavior of the algorithm when we provide the base input **(square,pentagon, triangle, triangle)** with an intervention setting **V1** to be the value it would be for the source input **(pentagon,pentagon, triangle, square)**. We show this visually and by using our **compute_A** function.

<img src="fig/IIT/PremackIntervention.png" width="500"/>

In [5]:
def compute_interchange_A(base,source, variable):
    return compute_A(base, {variable:compute_A(source, {})[variable]})
    
compute_interchange_A(("pentagon", "pentagon", "triangle", "square"), ("square", "pentagon", "triangle", "triangle"), "V1")

{'input1': 'pentagon',
 'input2': 'pentagon',
 'input3': 'triangle',
 'input4': 'square',
 'V1': False,
 'V2': False,
 'output': True}

# A Fully-Connected Feed-Forward Neural Network

We will train a two layer feed-forward neural network on this task where each object has a random vector assigned to it and the objects in training are disjoint from the objects seen in testing

In [6]:

TRUE_LABEL = 1
FALSE_LABEL = 0
    
embedding_dim = 5
X_train, X_test, y_train, y_test, test_dataset = get_equality_dataset(embedding_dim,10000)

model = TorchShallowNeuralClassifier(hidden_dim=4*embedding_dim, hidden_activation=torch.nn.ReLU(), num_layers=3)
_ = model.fit(X_train,y_train)


Stopping after epoch 553. Training loss did not improve more than tol=1e-05. Final error is 0.0003255244682804914.

Observe that this neural network achieves near perfect performance on its test set.

In [7]:
print(X_train[0], y_train[0])
preds = model.predict(X_train)
print("Train Results")
print(classification_report(y_train, preds))
preds = model.predict(X_test)
print("\n\n\nTest Results")
print(classification_report(y_test, preds))

tensor([ 0.0364,  0.0224, -0.3713, -0.0588,  0.2444,  0.0364,  0.0224, -0.3713,
        -0.0588,  0.2444, -0.4442,  0.1645, -0.2804, -0.3022,  0.3193, -0.3240,
         0.1784,  0.3509,  0.2537, -0.3696], dtype=torch.float64) 0
Train Results
              precision    recall  f1-score   support

           0       1.00      1.00      1.00      5000
           1       1.00      1.00      1.00      5000

    accuracy                           1.00     10000
   macro avg       1.00      1.00      1.00     10000
weighted avg       1.00      1.00      1.00     10000




Test Results
              precision    recall  f1-score   support

           0       1.00      1.00      1.00      5000
           1       1.00      1.00      1.00      5000

    accuracy                           1.00     10000
   macro avg       1.00      1.00      1.00     10000
weighted avg       1.00      1.00      1.00     10000



### The network with no intervention

First, observe the behavior of the network when we provide the input **(pentagon,pentagon, triangle, square)** with no intervention. We assign each shape a random vector.

In [8]:
pentagon = [random.uniform(-0.5,0.5) for _ in range(embedding_dim)]
triangle = [random.uniform(-0.5,0.5) for _ in range(embedding_dim)]
square = [random.uniform(-0.5,0.5) for _ in range(embedding_dim)]



print("Input:",[[*pentagon,*pentagon,*triangle,*square]])
for k in range(len(model.layers)):
    get_coord = {"layer":k, "start":0, "end":embedding_dim*4}
    print(f"\nLayer {k}:", model.layers[k])
    print("\nNeural Activations:", model.retrieve_activations(torch.tensor([[*pentagon,*pentagon,*triangle,*square]]), get_coord, None))


Input: [[-0.1468924697219567, 0.08058055715288948, -0.00780593997868273, -0.09137492307583517, 0.37462958731792295, -0.1468924697219567, 0.08058055715288948, -0.00780593997868273, -0.09137492307583517, 0.37462958731792295, 0.15957464170810198, 0.23421546883109567, 0.2770979368015196, 0.33923462016878914, -0.40979122498799136, -0.3828546875372202, 0.22214910510358454, 0.47936303417244075, -0.4888289894644897, 0.15564735672704422]]

Layer 0: ActivationLayer(
  (linear): Linear(in_features=20, out_features=20, bias=True)
  (activation): ReLU()
)

Neural Activations: tensor([[0.0000, 0.6438, 0.0000, 0.4686, 0.0000, 0.2442, 0.1763, 0.0000, 0.0000,
         0.0055, 0.4739, 0.0000, 0.0000, 0.0092, 0.0000, 0.1003, 0.0000, 0.0000,
         0.1359, 0.0000]], device='cuda:0', grad_fn=<SliceBackward>)

Layer 1: ActivationLayer(
  (linear): Linear(in_features=20, out_features=20, bias=True)
  (activation): ReLU()
)

Neural Activations: tensor([[0.9234, 0.0949, 0.4198, 0.0000, 0.0256, 0.0000, 0.0000

### The network with an intervention

Now, observe the behavior of the network when we provide the input **(pentagon,pentagon, triangle, square)** with an intervention that zeros out five neurons after the first hidden layer.

In [9]:
set_coord = {"layer":1, "start":0, "end":embedding_dim, "intervention": torch.tensor([[0 for _ in range(embedding_dim)]])}

print("Input:",[[*pentagon,*pentagon,*triangle,*square]])
for k in range(len(model.layers)):
    get_coord = {"layer":k, "start":0, "end":embedding_dim*4}
    print(f"\nLayer {k}:", model.layers[k])
    print("\nNeural Activations:", model.retrieve_activations(torch.tensor([[*pentagon,*pentagon,*triangle,*square]]), get_coord, set_coord))


    

Input: [[-0.1468924697219567, 0.08058055715288948, -0.00780593997868273, -0.09137492307583517, 0.37462958731792295, -0.1468924697219567, 0.08058055715288948, -0.00780593997868273, -0.09137492307583517, 0.37462958731792295, 0.15957464170810198, 0.23421546883109567, 0.2770979368015196, 0.33923462016878914, -0.40979122498799136, -0.3828546875372202, 0.22214910510358454, 0.47936303417244075, -0.4888289894644897, 0.15564735672704422]]

Layer 0: ActivationLayer(
  (linear): Linear(in_features=20, out_features=20, bias=True)
  (activation): ReLU()
)

Neural Activations: tensor([[0.0000, 0.6438, 0.0000, 0.4686, 0.0000, 0.2442, 0.1763, 0.0000, 0.0000,
         0.0055, 0.4739, 0.0000, 0.0000, 0.0092, 0.0000, 0.1003, 0.0000, 0.0000,
         0.1359, 0.0000]], device='cuda:0', grad_fn=<SliceBackward>)

Layer 1: ActivationLayer(
  (linear): Linear(in_features=20, out_features=20, bias=True)
  (activation): ReLU()
)

Neural Activations: tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000

### The network with an interchange intervention

Finally, observe the behavior of the network when we provide the input **(pentagon,pentagon, triangle, square)** with an intervention that sets five neurons after the first hidden layer to the values they achieve for the source input  **(square, pentagon, triangle, triangle)**.

In [10]:
intervention = model.retrieve_activations(torch.tensor([[*square,*pentagon,*triangle,*triangle]]), {"layer":1, "start":0, "end":embedding_dim},None)

set_coord = {"layer":1, "start":0, "end":embedding_dim, "intervention": intervention}

print("Input:",[[*pentagon,*pentagon,*triangle,*square]])
for k in range(len(model.layers)):
    get_coord = {"layer":k, "start":0, "end":embedding_dim*4}
    print(f"\nLayer {k}:", model.layers[k])
    print("\nNeural Activations:", model.retrieve_activations(torch.tensor([[*pentagon,*pentagon,*triangle,*square]]), get_coord, set_coord))


def interchange_intervention(model, base, source, int_coord, get_coord):
    intervention = model.retrieve_activations(source, int_coord,None)
    int_coord["intervention"] = intervention
    return model.retrieve_activations(base, get_coord, int_coord)

Input: [[-0.1468924697219567, 0.08058055715288948, -0.00780593997868273, -0.09137492307583517, 0.37462958731792295, -0.1468924697219567, 0.08058055715288948, -0.00780593997868273, -0.09137492307583517, 0.37462958731792295, 0.15957464170810198, 0.23421546883109567, 0.2770979368015196, 0.33923462016878914, -0.40979122498799136, -0.3828546875372202, 0.22214910510358454, 0.47936303417244075, -0.4888289894644897, 0.15564735672704422]]

Layer 0: ActivationLayer(
  (linear): Linear(in_features=20, out_features=20, bias=True)
  (activation): ReLU()
)

Neural Activations: tensor([[0.0000, 0.6438, 0.0000, 0.4686, 0.0000, 0.2442, 0.1763, 0.0000, 0.0000,
         0.0055, 0.4739, 0.0000, 0.0000, 0.0092, 0.0000, 0.1003, 0.0000, 0.0000,
         0.1359, 0.0000]], device='cuda:0', grad_fn=<SliceBackward>)

Layer 1: ActivationLayer(
  (linear): Linear(in_features=20, out_features=20, bias=True)
  (activation): ReLU()
)

Neural Activations: tensor([[0.0000, 0.8636, 0.4841, 0.2193, 0.5431, 0.0000, 0.0000

# Causal Abstraction

We defined a **high-level tree structured agorithm** that solves the hierarchical equality task.

We trained a **low-level fully connected neural network** that solves the hierarchical equality task.

A formal theory of **causal abstraction** describes the conditions that must hold for the high-level tree structured algorithm to be a **simplified and faithful description** of the neural network: 

**An algorithm is a causal abstraction of a neural network if and only if for all base and source inputs, the algorithm and network provides the same output under an aligned interchange intervention.**

Below, we define an alignment between the neural network and the algorithm and a function to compute the **interchange intervention training accuracy** for a high-level variable, which is the percentage of aligned interchange interventions that the network and algorithm produce the same output on. When the IIT accuracy is 100%, the causal abstraction relation holds between the network and a simplified version of the algorithm where only one high-level variable exists.

We compute the IIT accuracy on our toy domain where each entity is either a pentagon, square, or triangle.

In [11]:
import itertools

alignment = {"V1": {"layer":1, "start":0, "end":embedding_dim}, "V2": {"layer":1, "start":embedding_dim+1, "end":embedding_dim*2}}

def convert_input(tensor, embedding_dim):
    return [tuple(tensor[0,embedding_dim*k:embedding_dim*(k+1)].flatten().tolist()) for k in range(4)]

def compute_IIT_accuracy(variable, model):
    labels = []
    predictions = []
    for base in itertools.product([pentagon, triangle, square], repeat=4):
        for source in itertools.product([pentagon, triangle, square], repeat=4):
            basetensor = torch.cat([torch.tensor([base[k]]) for k in range(4)], 1)
            sourcetensor = torch.cat([torch.tensor([source[k]]) for k in range(4)],1)
            algorithm_output = compute_interchange_A(convert_input(basetensor, embedding_dim), convert_input(sourcetensor, embedding_dim), variable)
            if algorithm_output["output"]:   
                labels.append(TRUE_LABEL)
            else:
                labels.append(FALSE_LABEL)
            get_coord = {"layer":3, "start":0, "end":2}
            network_output = interchange_intervention(model, basetensor, sourcetensor,alignment[variable], get_coord).argmax(axis=1)
            predictions.append(int(network_output))
    return labels, predictions

Observe that we have low IIT accuracy for both **V1** and **V2**, meaning that under this alignment the neural network does not compute either variable. We have no evidence that this network computes simple equality relations to solve this hierarchical equality task.

In [12]:

print(classification_report(*compute_IIT_accuracy("V1", model)))
print(classification_report(*compute_IIT_accuracy("V2", model)))

              precision    recall  f1-score   support

           0       0.62      0.49      0.55      2916
           1       0.65      0.75      0.70      3645

    accuracy                           0.64      6561
   macro avg       0.63      0.62      0.62      6561
weighted avg       0.64      0.64      0.63      6561

              precision    recall  f1-score   support

           0       0.53      0.57      0.55      2916
           1       0.63      0.59      0.61      3645

    accuracy                           0.58      6561
   macro avg       0.58      0.58      0.58      6561
weighted avg       0.59      0.58      0.58      6561



# Interchange Intervention Training (IIT)

Original IIT [Geiger\*, Wu\*, Lu\*, Rozner, Kreiss, Icard, Goodman, and Potts (2021)](https://arxiv.org/abs/2112.00826)

IIT for model distillation [ Wu\*,Geiger\*, Rozner, Kreiss, Lu, Icard, Goodman, and Potts (2021)](https://arxiv.org/abs/2112.02505)

Interchange intervention training is a method for training a neural network to conform to the causal structure of a high-level algorithm. Conceptually, it is a direct extension of the causal abstraction analysis we just performed, except instead of **evaluating** whether the neural network and algorithm produce the same outputs under aligned interchange interventions, we are now **training** the neural network to produce the output of the algorithm under aligned interchange interventions.


In [13]:
V1 = 0
V2 = 1
id_to_coords = {V1: alignment["V1"], V2: alignment["V2"]}

X_base_train, X_source_train, y_base_train, y_IIT_train, interventions = get_IIT_equality_dataset("V1", embedding_dim ,10000)

model = TorchShallowNeuralClassifierIIT(hidden_dim=embedding_dim*4, hidden_activation=torch.nn.ReLU(), num_layers=3, id_to_coords=id_to_coords)
_ = model.fit(X_base_train, X_source_train, y_base_train, y_IIT_train,interventions)


Stopping after epoch 590. Training loss did not improve more than tol=1e-05. Final error is 0.001175935372884851.

In [14]:
X_base_test, X_source_test, y_base_test, y_IIT_test, interventions = get_IIT_equality_dataset("V1", embedding_dim,10000)

IIT_preds, base_preds = model.model(model.prep_input(X_base_test, X_source_test, interventions))
IIT_preds = np.array(IIT_preds.argmax(axis=1).cpu())
base_preds = np.array(base_preds.argmax(axis=1).cpu())
print(classification_report(y_base_test, base_preds))
print(classification_report(y_IIT_test, IIT_preds))


X_base_test, X_source_test, y_base_test, y_IIT_test, interventions = get_IIT_equality_dataset("V2", embedding_dim,10000)
IIT_preds, base_preds = model.model(model.prep_input(X_base_test, X_source_test, interventions))
IIT_preds = np.array(IIT_preds.argmax(axis=1).cpu())
base_preds = np.array(base_preds.argmax(axis=1).cpu())
print(classification_report(y_IIT_test, IIT_preds))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00      5000
           1       1.00      1.00      1.00      5000

    accuracy                           1.00     10000
   macro avg       1.00      1.00      1.00     10000
weighted avg       1.00      1.00      1.00     10000

              precision    recall  f1-score   support

           0       1.00      1.00      1.00      5000
           1       1.00      1.00      1.00      5000

    accuracy                           1.00     10000
   macro avg       1.00      1.00      1.00     10000
weighted avg       1.00      1.00      1.00     10000

              precision    recall  f1-score   support

           0       0.65      0.65      0.65      5000
           1       0.65      0.65      0.65      5000

    accuracy                           0.65     10000
   macro avg       0.65      0.65      0.65     10000
weighted avg       0.65      0.65      0.65     10000



Observe that we now have perfect IIT accuracy **V1** meaning that under this alignment the neural network computes whether the first pair of inputs are equal. However, we still have low IIT accuracy for **V2**, meaning that under this alignment the neural network doesn't compute whether the second pair of inputs are equal.

This is expected, because we only trained the network to compute **V1**.

We can train the network to compute both **V1** and **V2**.

In [15]:
model = TorchShallowNeuralClassifierIIT(hidden_dim=embedding_dim*4, hidden_activation=torch.nn.ReLU(), num_layers=3, id_to_coords=id_to_coords)


v1data = get_IIT_equality_dataset("V1", embedding_dim ,10000)
v2data = get_IIT_equality_dataset("V2", embedding_dim ,10000)
X_base_train = torch.cat([v1data[0],v2data[0]], dim=0)
X_source_train = torch.cat([v1data[1],v2data[1]], dim=0)
y_base_train = torch.cat([v1data[2],v2data[2]])
y_IIT_train = torch.cat([v1data[3],v2data[3]])
interventions = torch.cat([v1data[4],v2data[4]])

_ = model.fit(X_base_train, X_source_train, y_base_train, y_IIT_train, interventions)




Stopping after epoch 772. Training loss did not improve more than tol=1e-05. Final error is 0.3725438618566841.

In [16]:
X_base_test, X_source_test, y_base_test, y_IIT_test, interventions = get_IIT_equality_dataset("V1", embedding_dim,10000)

IIT_preds, base_preds = model.model(model.prep_input(X_base_test, X_source_test, interventions))
IIT_preds = np.array(IIT_preds.argmax(axis=1).cpu())
base_preds = np.array(base_preds.argmax(axis=1).cpu())
print(classification_report(y_base_test, base_preds))
print(classification_report(y_IIT_test, IIT_preds))


X_base_test, X_source_test, y_base_test, y_IIT_test, interventions = get_IIT_equality_dataset("V2", embedding_dim,10000)
IIT_preds, base_preds = model.model(model.prep_input(X_base_test, X_source_test, interventions))
IIT_preds = np.array(IIT_preds.argmax(axis=1).cpu())
base_preds = np.array(base_preds.argmax(axis=1).cpu())
print(classification_report(y_IIT_test, IIT_preds))


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      5000
           1       1.00      1.00      1.00      5000

    accuracy                           1.00     10000
   macro avg       1.00      1.00      1.00     10000
weighted avg       1.00      1.00      1.00     10000

              precision    recall  f1-score   support

           0       0.99      0.99      0.99      5000
           1       0.99      0.99      0.99      5000

    accuracy                           0.99     10000
   macro avg       0.99      0.99      0.99     10000
weighted avg       0.99      0.99      0.99     10000

              precision    recall  f1-score   support

           0       1.00      1.00      1.00      5000
           1       1.00      1.00      1.00      5000

    accuracy                           1.00     10000
   macro avg       1.00      1.00      1.00     10000
weighted avg       1.00      1.00      1.00     10000

