In [1]:
import torch
import random
import copy
import itertools
import numpy as np
from sklearn.metrics import classification_report
from iit import get_equality_dataset, get_IIT_equality_dataset, get_IIT_equality_dataset_both
from torch_deep_neural_classifier_iit import TorchDeepNeuralClassifierIIT, TorchDeepNeuralClassifier
from torch_rnn_classifier import TorchRNNClassifier
random.seed(42)

# Table of Contents
1. [Hierarchical Equality Dataset](#Hierarchical-Equality-Dataset)
2. [High-Level Tree-Structured Algorithm](#The-High-Level-Tree-Structured-Algorithm)
3. [A Fully-Connected Feed-Forward Neural Network](#A-Fully-Connected-Feed-Forward-Neural-Network)
4. [Causal Abstraction](#Causal-Abstraction)
5. [Interchange Intervention Training (IIT)](#Interchange-Intervention-Training-(IIT))

## Hierarchical Equality Dataset  
[Geiger, Carstensen, Frank, and Potts (2020)](https://arxiv.org/abs/2006.07968)

We will use a hierarchical equality task to present IIT. We define the hierarchical equality task as follows: The input is two pairs of objects and the output is **true** if both pairs contain the same object or if both pairs contain different objects and **false** otherwise. For example, AABB and ABCD are both labeled **true** while ABCC and BBCD are both labeled **false**. 

## The High-Level Tree-Structured Algorithm

Let $\mathcal{A}$ be the simple tree structured algorithm that solves this task by applying a simple equality relation three times: Compute whether the first two inputs are equal, compute whether the second two inputs are equal, then compute whether
the truth-valued outputs of these first two computations are equal. We visually define $\mathcal{A}$ below and then define a python function that computes $\mathcal{A}$, possibly under an intervention that sets $V_1$ and/or $V_2$ to fixed values.

<img src="fig/IIT/PremackFunctions.png" width="500"/>
<img src="fig/IIT/PremackGraph.png" width="500"/>

In [2]:
def compute_A(input, intervention):
    graph = dict()
    for i, object in enumerate(input):
        graph["input" + str(i+1)] = object
    if "V1" in intervention:
        graph["V1"] = intervention["V1"]
    else:
        graph["V1"] = graph["input1"] == graph["input2"]
    if "V2" in intervention:
        graph["V2"] = intervention["V2"]
    else:
        graph["V2"] = graph["input3"] == graph["input4"]
    graph["output"] = graph["V1"] == graph["V2"]
    return graph

### The algorithm with no intervention

First, observe the behavior of the algorithm whhen we provide the input **(pentagon,pentagon, triangle, square)** with no intervention. We show this visually and by using our **compute_A** function.

<img src="fig/IIT/PremackNoIntervention.png" width="500"/>

In [3]:
compute_A(("pentagon", "pentagon", "triangle", "square"), {})

{'input1': 'pentagon',
 'input2': 'pentagon',
 'input3': 'triangle',
 'input4': 'square',
 'V1': True,
 'V2': False,
 'output': False}

### The algorithm with an intervention

Observe the behavior of the algorithm whhen we provide the input **(square,pentagon, triangle, triangle)** with an intervention setting **V1** to **False**. We show this visually and by using our **compute_A** function.

<img src="fig/IIT/PremackIntervention.png" width="500"/>

In [4]:
compute_A(("square", "pentagon", "triangle", "triangle"), {"V1":True})

{'input1': 'square',
 'input2': 'pentagon',
 'input3': 'triangle',
 'input4': 'triangle',
 'V1': True,
 'V2': True,
 'output': True}

### The algorithm with an interchange intervention

Finaally, observe the behavior of the algorithm when we provide the base input **(square,pentagon, triangle, triangle)** with an intervention setting **V1** to be the value it would be for the source input **(pentagon,pentagon, triangle, square)**. We show this visually and by using our **compute_A** function.

<img src="fig/IIT/algorithmII.png" width="800"/>

In [5]:
def compute_interchange_A(base,source, variable):
    return compute_A(base, {variable:compute_A(source, {})[variable]})
    
compute_interchange_A(("pentagon", "pentagon", "triangle", "square"), ("square", "pentagon", "triangle", "triangle"), "V1")

{'input1': 'pentagon',
 'input2': 'pentagon',
 'input3': 'triangle',
 'input4': 'square',
 'V1': False,
 'V2': False,
 'output': True}

# A Fully-Connected Feed-Forward Neural Network

We will train a three layer feed-forward neural network on this task where each object has a random vector assigned to it and the objects in training are disjoint from the objects seen in testing.

<img src="fig/IIT/Network.png" width="800"/>

In [6]:
class InterventionableTorchDeepNeuralClassifier(TorchDeepNeuralClassifier):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        
    def make_hook(self, gets, sets, layer):
        def hook(model, input, output):
            layer_gets, layer_sets = [], []
            if gets is not None and layer in gets:
                layer_gets = gets[layer]
            if sets is not None and layer in sets:
                layer_sets = sets[layer]
            for get in layer_gets:
                self.activation[f'{get["layer"]}-{get["start"]}-{get["end"]}'] = output[:,get["start"]: get["end"] ]
            for set in layer_sets:
                output[:,set["start"]: set["end"]] = set["intervention"]
        return hook

    def _gets_sets(self,gets=None, sets = None):
        handlers = []
        for layer in range(len(self.layers)):
            hook = self.make_hook(gets,sets, layer)
            both_handler = self.layers[layer].register_forward_hook(hook)
            handlers.append(both_handler)
        return handlers

    def retrieve_activations(self, input, get, sets):
        input = input.type(torch.FloatTensor).to(self.device)
        self.activation = dict()
        handlers = self._gets_sets({get["layer"]:[get]}, sets)
        logits = self.model(input)
        for handler in handlers:
            handler.remove()
        return self.activation[f'{get["layer"]}-{get["start"]}-{get["end"]}']

In [7]:

TRUE_LABEL = 1
FALSE_LABEL = 0

data_size = 1024 * 10
embedding_dim = 4
X_train, X_test, y_train, y_test, test_dataset = get_equality_dataset(embedding_dim,100)

model = InterventionableTorchDeepNeuralClassifier(hidden_dim=4*embedding_dim, hidden_activation=torch.nn.ReLU(), num_layers=3)
_ = model.fit(X_train,y_train)


Stopping after epoch 484. Training loss did not improve more than tol=1e-05. Final error is 0.0014873843174427748.

Observe that this neural network achieves near perfect performance on its test set.

In [8]:
print(X_train[0], y_train[0])
preds = model.predict(X_train)
print("Train Results")
print(classification_report(y_train, preds))
preds = model.predict(X_test)
print("\n\n\nTest Results")
print(classification_report(y_test, preds))

tensor([-0.0306, -0.0595, -0.3156, -0.4486, -0.0306, -0.0595, -0.3156, -0.4486,
         0.4411, -0.0223,  0.3221, -0.0993, -0.4259,  0.1294, -0.4464, -0.3508],
       dtype=torch.float64) 0
Train Results
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        50
           1       1.00      1.00      1.00        50

    accuracy                           1.00       100
   macro avg       1.00      1.00      1.00       100
weighted avg       1.00      1.00      1.00       100




Test Results
              precision    recall  f1-score   support

           0       0.55      0.54      0.55        50
           1       0.55      0.56      0.55        50

    accuracy                           0.55       100
   macro avg       0.55      0.55      0.55       100
weighted avg       0.55      0.55      0.55       100



### The network with no intervention

First, observe the behavior of the network when we provide the input **(pentagon,pentagon, triangle, square)** with no intervention. We assign each shape a random vector.

In [9]:
pentagon = [random.uniform(-0.5,0.5) for _ in range(embedding_dim)]
triangle = [random.uniform(-0.5,0.5) for _ in range(embedding_dim)]
square = [random.uniform(-0.5,0.5) for _ in range(embedding_dim)]



print("Input:",[[*pentagon,*pentagon,*triangle,*square]])
for k in range(len(model.layers)):
    get_coord = {"layer":k, "start":0, "end":embedding_dim*4}
    print(f"\nLayer {k}:", model.layers[k])
    print("\nNeural Activations:", model.retrieve_activations(torch.tensor([[*pentagon,*pentagon,*triangle,*square]]), get_coord, None))


Input: [[0.22598836347044526, 0.2552321570675361, -0.07076905924465338, -0.40880081278310687, 0.22598836347044526, 0.2552321570675361, -0.07076905924465338, -0.40880081278310687, -0.4048163220205434, -0.2865788230537869, -0.4220612784273938, 0.09367949684106447, 0.019444254636635794, -0.05493312322134569, -0.4919374309808232, -0.1564258499658926]]

Layer 0: ActivationLayer(
  (linear): Linear(in_features=16, out_features=16, bias=True)
  (activation): ReLU()
)

Neural Activations: tensor([[0.1312, 0.4781, 0.0000, 0.4358, 0.0000, 0.1913, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.4050, 0.5134, 1.0772, 0.7328, 0.1617]],
       device='cuda:0', grad_fn=<SliceBackward>)

Layer 1: ActivationLayer(
  (linear): Linear(in_features=16, out_features=16, bias=True)
  (activation): ReLU()
)

Neural Activations: tensor([[0.7966, 0.0000, 0.4038, 1.3004, 0.6069, 0.5085, 0.8640, 1.0712, 0.3403,
         0.4802, 0.0998, 0.0000, 0.3822, 0.2907, 0.8604, 0.3601]],
       device='cuda:0', grad_fn=<

### The network with an intervention

Now, observe the behavior of the network when we provide the input **(pentagon,pentagon, triangle, square)** with an intervention that zeros out five neurons after the first hidden layer.

In [10]:
set_coord = {"layer":1, "start":0, "end":embedding_dim, "intervention": torch.tensor([[0 for _ in range(embedding_dim)]])}

print("Input:",[[*pentagon,*pentagon,*triangle,*square]])
for k in range(len(model.layers)):
    get_coord = {"layer":k, "start":0, "end":embedding_dim*4}
    print(f"\nLayer {k}:", model.layers[k])
    print("\nNeural Activations:", model.retrieve_activations(torch.tensor([[*pentagon,*pentagon,*triangle,*square]]), get_coord, set_coord))


    

Input: [[0.22598836347044526, 0.2552321570675361, -0.07076905924465338, -0.40880081278310687, 0.22598836347044526, 0.2552321570675361, -0.07076905924465338, -0.40880081278310687, -0.4048163220205434, -0.2865788230537869, -0.4220612784273938, 0.09367949684106447, 0.019444254636635794, -0.05493312322134569, -0.4919374309808232, -0.1564258499658926]]

Layer 0: ActivationLayer(
  (linear): Linear(in_features=16, out_features=16, bias=True)
  (activation): ReLU()
)

Neural Activations: tensor([[0.1312, 0.4781, 0.0000, 0.4358, 0.0000, 0.1913, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.4050, 0.5134, 1.0772, 0.7328, 0.1617]],
       device='cuda:0', grad_fn=<SliceBackward>)

Layer 1: ActivationLayer(
  (linear): Linear(in_features=16, out_features=16, bias=True)
  (activation): ReLU()
)

Neural Activations: tensor([[0.7966, 0.0000, 0.4038, 1.3004, 0.6069, 0.5085, 0.8640, 1.0712, 0.3403,
         0.4802, 0.0998, 0.0000, 0.3822, 0.2907, 0.8604, 0.3601]],
       device='cuda:0', grad_fn=<

### The network with an interchange intervention

Finally, observe the behavior of the network when we provide the input **(square, pentagon, triangle, triangle)** with an intervention that sets five neurons after the first hidden layer to the values they achieve for the source input **(pentagon,pentagon, triangle, square)**.
<img src="fig/IIT/networkII.png" width="500"/>

In [11]:
intervention = model.retrieve_activations(torch.tensor([[*pentagon,*pentagon,*triangle,*square]]), {"layer":1, "start":0, "end":embedding_dim},None)

set_coord = {"layer":1, "start":0, "end":embedding_dim, "intervention": intervention}

print("Input:",[[*pentagon,*pentagon,*triangle,*square]])
for k in range(len(model.layers)):
    get_coord = {"layer":k, "start":0, "end":embedding_dim*4}
    print(f"\nLayer {k}:", model.layers[k])
    print("\nNeural Activations:", model.retrieve_activations(torch.tensor([[*square,*pentagon,*triangle,*triangle]]), get_coord, set_coord))




Input: [[0.22598836347044526, 0.2552321570675361, -0.07076905924465338, -0.40880081278310687, 0.22598836347044526, 0.2552321570675361, -0.07076905924465338, -0.40880081278310687, -0.4048163220205434, -0.2865788230537869, -0.4220612784273938, 0.09367949684106447, 0.019444254636635794, -0.05493312322134569, -0.4919374309808232, -0.1564258499658926]]

Layer 0: ActivationLayer(
  (linear): Linear(in_features=16, out_features=16, bias=True)
  (activation): ReLU()
)

Neural Activations: tensor([[0.0000, 0.5111, 0.0000, 0.4507, 0.0000, 0.0280, 0.0000, 0.0000, 0.0000,
         0.0000, 0.1421, 0.4130, 0.0450, 0.6898, 0.6195, 0.0000]],
       device='cuda:0', grad_fn=<SliceBackward>)

Layer 1: ActivationLayer(
  (linear): Linear(in_features=16, out_features=16, bias=True)
  (activation): ReLU()
)

Neural Activations: tensor([[0.4637, 0.1112, 0.7164, 0.9098, 0.2960, 0.6310, 0.4019, 0.6862, 0.5158,
         0.3605, 0.3475, 0.0000, 0.2031, 0.6713, 0.4579, 0.2025]],
       device='cuda:0', grad_fn=<

# Causal Abstraction

We defined a **high-level tree structured agorithm** that solves the hierarchical equality task.

We trained a **low-level fully connected neural network** that solves the hierarchical equality task.

A formal theory of **causal abstraction** describes the conditions that must hold for the high-level tree structured algorithm to be a **simplified and faithful description** of the neural network: 

**An algorithm is a causal abstraction of a neural network if and only if for all base and source inputs, the algorithm and network provides the same output under an aligned interchange intervention.**

Below, we define an alignment between the neural network and the algorithm and a function to compute the **interchange intervention training accuracy** for a high-level variable, which is the percentage of aligned interchange interventions that the network and algorithm produce the same output on. When the IIT accuracy is 100%, the causal abstraction relation holds between the network and a simplified version of the algorithm where only one high-level variable exists.

<img src="fig/IIT/alignment.png" width="500"/>

We compute the IIT accuracy on our toy domain where each entity is either a pentagon, square, or triangle.

In [12]:


alignment = {"V1": {"layer":1, "start":0, "end":embedding_dim}, "V2": {"layer":1, "start":embedding_dim, "end":embedding_dim*2}}

def interchange_intervention(model, base, source, int_coord, output_coord):
    intervention = model.retrieve_activations(source, int_coord[1][0],None)
    int_coord[1][0]["intervention"] = intervention
    return model.retrieve_activations(base, output_coord, int_coord)

def convert_input(tensor, embedding_dim):
    return [tuple(tensor[0,embedding_dim*k:embedding_dim*(k+1)].flatten().tolist()) for k in range(4)]

def compute_IIT_accuracy(variable, model):
    labels = []
    predictions = []
    for base in itertools.product([pentagon, triangle, square], repeat=4):
        for source in itertools.product([pentagon, triangle, square], repeat=4):
            basetensor = torch.cat([torch.tensor([base[k]]) for k in range(4)], 1)
            sourcetensor = torch.cat([torch.tensor([source[k]]) for k in range(4)],1)
            algorithm_output = compute_interchange_A(convert_input(basetensor, embedding_dim), convert_input(sourcetensor, embedding_dim), variable)
            if algorithm_output["output"]:   
                labels.append(TRUE_LABEL)
            else:
                labels.append(FALSE_LABEL)
            output_coord = {"layer":3, "start":0, "end":2}
            network_output = interchange_intervention(model, basetensor, sourcetensor,{1:[copy.deepcopy(alignment[variable])]}, output_coord).argmax(axis=1)
            predictions.append(int(network_output))
    return labels, predictions

Observe that we have low IIT accuracy for both **V1** and **V2**, meaning that under this alignment the neural network does not compute either variable. We have no evidence that this network computes simple equality relations to solve this hierarchical equality task.

In [13]:

print(classification_report(*compute_IIT_accuracy("V1", model)))
print(classification_report(*compute_IIT_accuracy("V2", model)))

              precision    recall  f1-score   support

           0       0.45      0.31      0.36      2916
           1       0.56      0.70      0.62      3645

    accuracy                           0.52      6561
   macro avg       0.50      0.50      0.49      6561
weighted avg       0.51      0.52      0.50      6561

              precision    recall  f1-score   support

           0       0.45      0.30      0.36      2916
           1       0.56      0.70      0.62      3645

    accuracy                           0.52      6561
   macro avg       0.50      0.50      0.49      6561
weighted avg       0.51      0.52      0.50      6561



# Interchange Intervention Training (IIT)

Original IIT [Geiger\*, Wu\*, Lu\*, Rozner, Kreiss, Icard, Goodman, and Potts (2021)](https://arxiv.org/abs/2112.00826)

IIT for model distillation [ Wu\*,Geiger\*, Rozner, Kreiss, Lu, Icard, Goodman, and Potts (2021)](https://arxiv.org/abs/2112.02505)

Interchange intervention training is a method for training a neural network to conform to the causal structure of a high-level algorithm. Conceptually, it is a direct extension of the causal abstraction analysis we just performed, except instead of **evaluating** whether the neural network and algorithm produce the same outputs under aligned interchange interventions, we are now **training** the neural network to produce the output of the algorithm under aligned interchange interventions.


In [14]:
V1 = 0
V2 = 1
both = 2
id_to_coords = {V1:{1: [{"layer":1, "start":0, "end":embedding_dim}]}, \
    V2: {1: [{"layer":1, "start":embedding_dim, "end":embedding_dim*2}]}, \
    both: {1: [{"layer":1, "start":0, "end":embedding_dim},{"layer":1, "start":embedding_dim, "end":embedding_dim*2}]}}

X_base_train, X_sources_train, y_base_train, y_IIT_train, interventions = get_IIT_equality_dataset("V1", embedding_dim ,data_size)

model = TorchDeepNeuralClassifierIIT(hidden_dim=embedding_dim*4, hidden_activation=torch.nn.ReLU(), num_layers=3, id_to_coords=id_to_coords)
_ = model.fit(X_base_train, X_sources_train, y_base_train, y_IIT_train,interventions)


Stopping after epoch 715. Training loss did not improve more than tol=1e-05. Final error is 0.0007661401759833097.

In [15]:
X_base_test, X_sources_test, y_base_test, y_IIT_test, interventions = get_IIT_equality_dataset("V1", embedding_dim,data_size)

IIT_preds, base_preds = model.model(model.prep_input(X_base_test, X_sources_test, interventions))
IIT_preds = np.array(IIT_preds.argmax(axis=1).cpu())
base_preds = np.array(base_preds.argmax(axis=1).cpu())
print(classification_report(y_base_test, base_preds))
print(classification_report(y_IIT_test, IIT_preds))


X_base_test, X_sources_test, y_base_test, y_IIT_test, interventions = get_IIT_equality_dataset("V2", embedding_dim,data_size)
IIT_preds, base_preds = model.model(model.prep_input(X_base_test, X_sources_test, interventions))
IIT_preds = np.array(IIT_preds.argmax(axis=1).cpu())
base_preds = np.array(base_preds.argmax(axis=1).cpu())
print(classification_report(y_IIT_test, IIT_preds))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00      5120
           1       1.00      1.00      1.00      5120

    accuracy                           1.00     10240
   macro avg       1.00      1.00      1.00     10240
weighted avg       1.00      1.00      1.00     10240

              precision    recall  f1-score   support

           0       1.00      1.00      1.00      5120
           1       1.00      1.00      1.00      5120

    accuracy                           1.00     10240
   macro avg       1.00      1.00      1.00     10240
weighted avg       1.00      1.00      1.00     10240

              precision    recall  f1-score   support

           0       0.74      0.74      0.74      5120
           1       0.74      0.73      0.74      5120

    accuracy                           0.74     10240
   macro avg       0.74      0.74      0.74     10240
weighted avg       0.74      0.74      0.74     10240



Observe that we now have perfect IIT accuracy **V1** meaning that under this alignment the neural network computes whether the first pair of inputs are equal. However, we still have low IIT accuracy for **V2**, meaning that under this alignment the neural network doesn't compute whether the second pair of inputs are equal.

This is expected, because we only trained the network to compute **V1**.

We can train the network to compute both **V1** and **V2**.

In [16]:
model = TorchDeepNeuralClassifierIIT(hidden_dim=embedding_dim*4, hidden_activation=torch.nn.ReLU(), num_layers=3, id_to_coords=id_to_coords)


v1data = get_IIT_equality_dataset("V1", embedding_dim, data_size)
v2data = get_IIT_equality_dataset("V2", embedding_dim, data_size)
X_base_train = torch.cat([v1data[0],v2data[0]], dim=0)
X_sources_train = [ torch.cat([v1data[1][i],v2data[1][i]], dim=0) for i in range(len(v1data[1]))] 
y_base_train = torch.cat([v1data[2],v2data[2]])
y_IIT_train = torch.cat([v1data[3],v2data[3]])
interventions = torch.cat([v1data[4],v2data[4]])

_ = model.fit(X_base_train, X_sources_train, y_base_train, y_IIT_train, interventions)




Stopping after epoch 584. Training loss did not improve more than tol=1e-05. Final error is 0.03112224835786037.

In [17]:
X_base_test, X_sources_test, y_base_test, y_IIT_test, interventions = get_IIT_equality_dataset("V1", embedding_dim,data_size)

IIT_preds, base_preds = model.model(model.prep_input(X_base_test, X_sources_test, interventions))
IIT_preds = np.array(IIT_preds.argmax(axis=1).cpu())
base_preds = np.array(base_preds.argmax(axis=1).cpu())
print(classification_report(y_base_test, base_preds))
print(classification_report(y_IIT_test, IIT_preds))


X_base_test, X_sources_test, y_base_test, y_IIT_test, interventions = get_IIT_equality_dataset("V2", embedding_dim,data_size)
IIT_preds, base_preds = model.model(model.prep_input(X_base_test, X_sources_test, interventions))
IIT_preds = np.array(IIT_preds.argmax(axis=1).cpu())
base_preds = np.array(base_preds.argmax(axis=1).cpu())
print(classification_report(y_IIT_test, IIT_preds))


              precision    recall  f1-score   support

           0       1.00      1.00      1.00      5120
           1       1.00      1.00      1.00      5120

    accuracy                           1.00     10240
   macro avg       1.00      1.00      1.00     10240
weighted avg       1.00      1.00      1.00     10240

              precision    recall  f1-score   support

           0       1.00      1.00      1.00      5120
           1       1.00      1.00      1.00      5120

    accuracy                           1.00     10240
   macro avg       1.00      1.00      1.00     10240
weighted avg       1.00      1.00      1.00     10240

              precision    recall  f1-score   support

           0       1.00      1.00      1.00      5120
           1       1.00      1.00      1.00      5120

    accuracy                           1.00     10240
   macro avg       1.00      1.00      1.00     10240
weighted avg       1.00      1.00      1.00     10240



In [18]:
print(classification_report(*compute_IIT_accuracy("V1", model.model)))
print(classification_report(*compute_IIT_accuracy("V2", model.model)))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2916
           1       1.00      1.00      1.00      3645

    accuracy                           1.00      6561
   macro avg       1.00      1.00      1.00      6561
weighted avg       1.00      1.00      1.00      6561

              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2916
           1       1.00      1.00      1.00      3645

    accuracy                           1.00      6561
   macro avg       1.00      1.00      1.00      6561
weighted avg       1.00      1.00      1.00      6561



In [19]:

def compute_multisource_interchange_A(base,source,source2):
    return compute_A(base, {"V1":compute_A(source, {})["V1"], "V2":compute_A(source2, {})["V2"]})

def multisource_interchange_intervention(model, base, sources, coords, output_coord):
    source_activations = model.retrieve_activations(sources[0], coords[1][0],None)
    source_activations2 = model.retrieve_activations(sources[1], coords[1][1],None)
    coords = copy.deepcopy(coords)
    coords[1][0]["intervention"] = source_activations
    coords[1][1]["intervention"] = source_activations2
    return model.retrieve_activations(base, output_coord, coords)

def compute_multisource_IIT_accuracy(model, coords):
    labels = []
    predictions = []
    for base in itertools.product([pentagon, triangle, square], repeat=4):
        for source in itertools.product([pentagon, triangle, square], repeat=4):
            for source2 in itertools.product([pentagon, triangle, square], repeat=4):
                basetensor = torch.cat([torch.tensor([base[k]]) for k in range(4)], 1)
                sourcetensor = torch.cat([torch.tensor([source[k]]) for k in range(4)],1)
                sourcetensor2 = torch.cat([torch.tensor([source2[k]]) for k in range(4)],1)
                algorithm_output = compute_multisource_interchange_A(convert_input(basetensor, embedding_dim), convert_input(sourcetensor, embedding_dim),convert_input(sourcetensor2, embedding_dim))
                if algorithm_output["output"]:   
                    labels.append(TRUE_LABEL)
                else:
                    labels.append(FALSE_LABEL)
                get_coord = {"layer":3, "start":0, "end":2}
                network_output = multisource_interchange_intervention(model, basetensor, [sourcetensor,sourcetensor2], coords, get_coord).argmax(axis=1)
                predictions.append(int(network_output))
    return labels, predictions

In [20]:
sets = {1: [{"layer":1, "start":0, "end":embedding_dim},{"layer":1, "start":embedding_dim, "end":embedding_dim*2}]}
print(classification_report(*compute_multisource_IIT_accuracy(model.model, sets)))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    236196
           1       1.00      1.00      1.00    295245

    accuracy                           1.00    531441
   macro avg       1.00      1.00      1.00    531441
weighted avg       1.00      1.00      1.00    531441



In [21]:
v1data = get_IIT_equality_dataset("V1", embedding_dim ,data_size)
v2data = get_IIT_equality_dataset("V2", embedding_dim ,data_size)
bothdata = get_IIT_equality_dataset_both(embedding_dim ,data_size)
X_base_train = torch.cat([v1data[0],v2data[0], bothdata[0]], dim=0)
X_sources_train = [ torch.cat([v1data[1][0],v2data[1][0], bothdata[1][i]], dim=0) for i in range(len(bothdata[1]))] 
y_base_train = torch.cat([v1data[2],v2data[2],bothdata[2]])
y_IIT_train = torch.cat([v1data[3],v2data[3], bothdata[3]])
interventions = torch.cat([v1data[4],v2data[4], bothdata[4]])

model = TorchDeepNeuralClassifierIIT(hidden_dim=embedding_dim*4, hidden_activation=torch.nn.ReLU(), num_layers=3, id_to_coords=id_to_coords)

_ = model.fit(X_base_train, X_sources_train, y_base_train, y_IIT_train, interventions)


Stopping after epoch 440. Training loss did not improve more than tol=1e-05. Final error is 0.24750305037014186.

In [22]:
print(classification_report(*compute_multisource_IIT_accuracy(model.model, sets)))
print(classification_report(*compute_IIT_accuracy("V1", model.model)))
print(classification_report(*compute_IIT_accuracy("V2", model.model)))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00    236196
           1       1.00      1.00      1.00    295245

    accuracy                           1.00    531441
   macro avg       1.00      1.00      1.00    531441
weighted avg       1.00      1.00      1.00    531441

              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2916
           1       1.00      1.00      1.00      3645

    accuracy                           1.00      6561
   macro avg       1.00      1.00      1.00      6561
weighted avg       1.00      1.00      1.00      6561

              precision    recall  f1-score   support

           0       1.00      1.00      1.00      2916
           1       1.00      1.00      1.00      3645

    accuracy                           1.00      6561
   macro avg       1.00      1.00      1.00      6561
weighted avg       1.00      1.00      1.00      6561



In [23]:
for layer in model.model.layers:
    if isinstance(layer, torch.nn.Linear):
        print(layer.state_dict()['weight'])
        print(layer.state_dict()['bias'])
    else:
        print(layer.linear.state_dict()['weight'])
        print(layer.linear.state_dict()['bias'])

tensor([[ 1.9822e-02,  1.8774e-02,  1.5825e-05,  1.0034e-02, -3.4869e-02,
          1.8639e-02,  1.1978e-02,  5.8974e-03, -1.0267e+00, -7.2273e-02,
          8.2198e-01, -6.8227e-01,  1.0258e+00,  6.5848e-02, -8.2571e-01,
          6.7370e-01],
        [ 6.6098e-01, -5.8664e-01, -2.0231e-01,  1.1299e+00, -6.7887e-01,
          5.6991e-01,  1.9586e-01, -1.1392e+00,  3.7295e-03,  2.1067e-02,
          4.8160e-03, -9.9840e-03,  4.0230e-03, -1.3166e-02, -1.8038e-02,
          2.0497e-02],
        [ 4.8361e-03, -7.8354e-03,  7.4648e-03,  1.3511e-02,  1.1528e-02,
          5.3422e-03,  2.0913e-02,  4.8559e-03,  8.6704e-01, -2.9921e-01,
         -3.9933e-02, -9.7669e-01, -8.5531e-01,  2.8625e-01,  4.3487e-02,
          9.7053e-01],
        [ 1.2590e-01, -1.0640e-01,  2.2621e-01, -2.0776e-01, -1.9550e-01,
          9.2362e-02, -2.2186e-01,  2.1222e-01, -5.3099e-01,  1.2107e-01,
         -5.3057e-01,  6.3170e-02,  5.1720e-01, -1.4435e-01,  5.4054e-01,
         -5.9443e-02],
        [ 9.1320e-02