In [1]:
%load_ext autoreload
%autoreload 2

# A Demonstration of intervention experiments with `ComputationGraph`

## Setup

Python modules for constructing a `ComputationGraph` and doing intervention experiments are in Atticus' Github repo `https://github.com/atticusg/Interchange`. To include the intervention modules here, you can make a local install of the `intervension` package:

```
$ cd path/to/Interchange
$ pip install -e .
```

## A simple feed-forward network for determining equality

In `torch_equality.py` I wrote a `TorchEqualityModel` that is basically a pytorch replication of the `MLPClassifier` model in `equality_experiment.py` which is implemented in scikit-learn.

The `TorchEqualityModel` wraps around `TorchEqualityModule` which is a subclass of `torch.nn.Module`. The former contains training and prediction functionalities, similar to its scikit-learn counterpart.

I use `TorchEqualityModel` and `TorchEqualityModule` to demonstrate our intervention model.

**because of some wierd typing issue with jupyter notebook, it is suggested that if you make any changes to code in the `intervention` module, use *restart and run all* in the notebook.**

In [1]:
# Training the TorchEqualityModel

from torch_equality import *

embed_dim = 10
max_epochs = 100
hidden_dim = 100
train_size = 2000
test_size = 500
alpha = 0.001
lr = 0.01

train_dataset = TorchEqualityDataset(embed_dim=embed_dim, n_pos=train_size//2, n_neg=train_size//2)
test_dataset = TorchEqualityDataset(embed_dim=embed_dim, n_pos=test_size//2, n_neg=test_size//2)

train_dataset.test_disjoint(test_dataset)
model = TorchEqualityModel(max_epochs=max_epochs,
                           input_size=embed_dim*2,
                           batch_size=1000,
                           hidden_layer_size=hidden_dim,
                           alpha=alpha,
                           lr=lr,
                           gpu=True)

model.fit(train_dataset)


## Explicitly defining a computation graph

A computation graph can be defined manually, as shown below.

In [3]:
# from intervention import ComputationGraph, GraphNode

# class TorchEqualityCompGraph(ComputationGraph):
#     def __init__(self, model):
#         assert isinstance(model, TorchEqualityModel)
#         self.model = model
#         self.module = model.module

#         @GraphNode()
#         def linear(x):
#             # preprocess inputs here
#             x = x.float().to(self.model.device)
#             return self.module.linear(x)

#         @GraphNode(linear)
#         def activation(x):
#             return self.module.activation(x)

#         @GraphNode(activation)
#         def logits(x):
#             return self.module.output(x)

#         @GraphNode(logits)
#         def root(x):
#             scores = self.module.sigmoid(x)
#             return [1 if z >= 0.5 else 0 for z in scores]

#         super().__init__(root)
        
# g = TorchEqualityCompGraph(model)

## Automatically constructing a computation graph

It can also be automatically extracted given an instance of a `torch.nn.Module`, using the `CompGraphConstructor`.

The computation graph is constructed dynamically, and an input instance is required.

In [4]:
from intervention import CompGraphConstructor, GraphInput, Intervention

module = model.module
input_0 = torch.tensor(test_dataset.X[0])
input_0 = input_0.float()

print(input_0)
g, input_0 = CompGraphConstructor.construct(module, input_0, device=model.device)
print(type(input_0))


tensor([-0.2604, -0.1911,  0.2648,  0.4248, -0.0555, -0.3701,  0.0138, -0.3214,
        -0.2608, -0.3415, -0.2604, -0.1911,  0.2648,  0.4248, -0.0555, -0.3701,
         0.0138, -0.3214, -0.2608, -0.3415])
current_input in make_graph None
I am in module linear I have 1 inputs
I am in module activation I have 1 inputs
I am in module output I have 1 inputs
I am in module sigmoid I have 1 inputs
<class 'intervention.computation_graph.GraphInput'>


## Intervene on entire tensor

In [5]:

# input_0 = GraphInput({"linear": torch.tensor(test_dataset.X[0]).float()}, device=model.device)
input_3 = GraphInput({"linear": torch.tensor(test_dataset.X[3]).float()}, device=model.device)

res_0 = g.compute(input_0)
res_3 = g.compute(input_3)

print("res_0:", res_0, "res_3:", res_3)

# Use input_0 as input, but set the result of "activation" node to that of input_3
interv_3_0 = Intervention(base=input_0, intervention={"activation": g.get_result("activation", input_3)})

print("--- intervene, type of interve_3_0", type(interv_3_0))
before, after = g.intervene(interv_3_0)

print("input_0 before intervention %s, after intervention %s" % (before, after))

check type of inputs in compute <class 'intervention.computation_graph.GraphInput'>
check type of inputs in compute <class 'intervention.computation_graph.GraphInput'>
res_0: tensor([0.9661], device='cuda:0', grad_fn=<SigmoidBackward>) res_3: tensor([0.0117], device='cuda:0', grad_fn=<SigmoidBackward>)
--- intervene, type of interve_3_0 <class 'intervention.computation_graph.Intervention'>
input_0 before intervention tensor([0.9661], device='cuda:0', grad_fn=<SigmoidBackward>), after intervention tensor([0.0117], device='cuda:0', grad_fn=<SigmoidBackward>)


## Intervene on a part of a tensor

We can specify which part of a tensor we would like to intervene by adding indexing after a node's name.

In [6]:
in2 = Intervention(input_0)

replace_value = g.get_result("activation", input_3)[:10]
in2.intervention = {"activation[:10]": replace_value}

before, after = g.intervene(in2)
print("input_0 before intervention %s, after intervention %s" % (before, after))

input_0 before intervention tensor([0.9661], device='cuda:0', grad_fn=<SigmoidBackward>), after intervention tensor([0.9330], device='cuda:0', grad_fn=<SigmoidBackward>)


## Another way to intervene on part of a tensor

Sometimes we would like more flexibility when intervening on a part of a tensor, for instance if the intervention location in a tensor changes dynamically. 

The way to do this is to provide a `dict` to the `locs` parameter of the `Intervention()` constructor. The keys in the dict are `str`s for node names, and values being any of the following:

1. `LOC[...]`. A way to specify the indexing in the exact same way you would index into a tensor, but not actually getting the values at a tensor location.
2. Indexing in a string form
3. an `int` for single elements in a tensor.
4. a tuple with `int`s, `slice` objects, and/or `Ellipsis` objects.

For example the following are equivalent:

```
intervention = {"node_a[5]": value_a,
                "node_b[:10,:,10:]": value_b,
                "node_c[:5,...]": value_c,
                "node_d[5:10]": value_d}
interv = Intervention(base=some_base, intervention=intervention)
```
and 
```
intervention = {"node_a": value_a,
                "node_b": value_b,
                "node_c": value_c,
                "node_d": value_d}
locs = {"node_a": 5,
        "node_b": LOC[:10,:,10:],
        "node_c": ":5,...",
        "node_d": slice(5,10)}
interv = Intervention(base=some_base, intervention=intervention, locs=locs)
```

Another example is shown below.

In [7]:
from intervention import LOC

print(LOC[1:, 2, 3, ...])

act3 = g.get_result("activation", input_3)
step = 20
for i in range(0, 100, step):
    replace_value = act3[i:i+step]
    intervention = {"activation": replace_value}
    locs = {"activation": LOC[i:i+step]}
    interv = Intervention(base=input_0, intervention=intervention, locs=locs)
    
    before, after = g.intervene(in2)
    print("Replace indices %d to %d with values from `act3`. Before %s, after %s" % (i, i+step, before, after))
    

(slice(1, None, None), 2, 3, Ellipsis)
Replace indices 0 to 20 with values from `act3`. Before tensor([0.9661], device='cuda:0', grad_fn=<SigmoidBackward>), after tensor([0.9330], device='cuda:0', grad_fn=<SigmoidBackward>)
Replace indices 20 to 40 with values from `act3`. Before tensor([0.9661], device='cuda:0', grad_fn=<SigmoidBackward>), after tensor([0.9330], device='cuda:0', grad_fn=<SigmoidBackward>)
Replace indices 40 to 60 with values from `act3`. Before tensor([0.9661], device='cuda:0', grad_fn=<SigmoidBackward>), after tensor([0.9330], device='cuda:0', grad_fn=<SigmoidBackward>)
Replace indices 60 to 80 with values from `act3`. Before tensor([0.9661], device='cuda:0', grad_fn=<SigmoidBackward>), after tensor([0.9330], device='cuda:0', grad_fn=<SigmoidBackward>)
Replace indices 80 to 100 with values from `act3`. Before tensor([0.9661], device='cuda:0', grad_fn=<SigmoidBackward>), after tensor([0.9330], device='cuda:0', grad_fn=<SigmoidBackward>)
