In [1]:
# Enable IPython's autoreload extension in mode 2, so that imported modules are
# reloaded before each cell executes and edits to local .py files take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Test `TorchEqualityExperiment()`

In [2]:
from torch_equality import *

# Hyperparameter grid for the experiment: a single value for each model
# hyperparameter, two trials, and two training-set sizes
# (range(1004, 2005, 1000) yields 1004 and 2004).
params = {
    "embed_dims": [10],
    "hidden_dims": [100],
    "alphas": [0.001],
    "learning_rates": [0.01],
    "n_trials": 2,
    "train_sizes": list(range(1004, 2005, 1000)),
}

# The grid is forwarded as keyword arguments, one per dictionary key.
experiment = TorchEqualityExperiment(dataset_class=TorchEqualityDataset, **params)

In [3]:
# Execute the configured experiment grid. Progress (grid size, per-configuration
# mean/max and elapsed time) is printed while it runs; the value returned by
# run() -- apparently a per-trial results table -- is displayed as the cell output.
experiment.run()

Grid size: 1 * 2; 2 experiments
Running trials for embed_dim=10 hidden_dim=100 alpha=0.001 lr=0.01 ... mean: 0.83; max: 0.998; took 6.0 secs


Unnamed: 0,trial,train_size,embed_dim,hidden_dim,alpha,learning_rate,accuracy,batch_pos,batch_neg
0,1,0,10,100,0.001,0.01,0.476,0,0
1,1,1004,10,100,0.001,0.01,0.99,503,501
2,1,2004,10,100,0.001,0.01,0.996,1002,1002
3,2,0,10,100,0.001,0.01,0.514,0,0
4,2,1004,10,100,0.001,0.01,0.99,498,506
5,2,2004,10,100,0.001,0.01,0.998,1002,1002


# Train model for intervention

In [4]:
# --- Model hyperparameters for the single model used in the intervention demos ---
embed_dim = 10
hidden_dim = 100
max_epochs = 100
alpha = 0.001
lr = 0.01

# --- Dataset sizes; each split is built half positive, half negative ---
train_size = 2000
test_size = 500

n_train_per_class = train_size // 2
n_test_per_class = test_size // 2

# The training set is constructed before the test set, as in the original cell.
train_dataset = TorchEqualityDataset(
    embed_dim=embed_dim,
    n_pos=n_train_per_class,
    n_neg=n_train_per_class,
)
test_dataset = TorchEqualityDataset(
    embed_dim=embed_dim,
    n_pos=n_test_per_class,
    n_neg=n_test_per_class,
)

# Presumably verifies that the two splits share no examples -- confirm against
# TorchEqualityDataset.test_disjoint.
train_dataset.test_disjoint(test_dataset)

# input_size is twice embed_dim (the printed test examples have 20 features
# for embed_dim=10).
model_config = dict(
    max_epochs=max_epochs,
    input_size=embed_dim * 2,
    batch_size=1000,
    hidden_layer_size=hidden_dim,
    alpha=alpha,
    lr=lr,
    gpu=True,
)
model = TorchEqualityModel(**model_config)

model.fit(train_dataset)


# Test `TorchEqualityIntervention`

In [5]:
# Show the feature vector and label of the two test examples (indices 0 and 3)
# that the intervention cells below operate on.
for example_idx in (0, 3):
    print(test_dataset.X[example_idx], test_dataset.y[example_idx])

[ 0.43706961 -0.4827849   0.37416355 -0.28851521  0.24289165  0.19498498
 -0.26960545 -0.21563691  0.25949174 -0.01998772 -0.26968687 -0.46189744
  0.32713884 -0.13627835 -0.39339898  0.29566419  0.22301776 -0.27912615
  0.18120715 -0.36316504] 0
[ 0.34292536 -0.30845335 -0.39574798  0.09547243  0.45492058  0.1343681
  0.20372263 -0.14990982 -0.4816382   0.05901701  0.17036996 -0.31355334
 -0.40446764 -0.2006948   0.15480174  0.26021808 -0.36356341 -0.09771308
  0.20747919  0.25207756] 0


In [6]:
from intervention_interface import *

# Wrap the trained model in a TorchEqualityIntervention object; the following
# cells use its run(), get_from_cache() and fix_and_run() methods.
intervention = TorchEqualityIntervention(model)

In [7]:
# Baseline predictions (no intervention) for test examples 3 and 0, in that order.
example_3 = torch.tensor(test_dataset.X[3])
res_3 = intervention.run(example_3)
print(res_3)

example_0 = torch.tensor(test_dataset.X[0])
res_0 = intervention.run(example_0)
print(res_0)

[0]
[0]


In [8]:
# Run example 3 first, then inspect the "hidden_vec" entry in the intervention's
# cache (presumably populated by that run -- confirm against the interface).
res_3 = intervention.run(torch.tensor(test_dataset.X[3]))
cached_hidden_vec = intervention.get_from_cache("hidden_vec")
print(cached_hidden_vec)

# Re-run on example 0 with "hidden_vec" held fixed (presumably at the cached
# value from example 3), and compare against the un-intervened result for 3.
example_0 = torch.tensor(test_dataset.X[0])
res_intervention = intervention.fix_and_run("hidden_vec", example_0)
print(res_3, res_intervention)

tensor([1.7307e-02, 0.0000e+00, 0.0000e+00, 1.4745e-02, 1.1599e-01, 0.0000e+00,
        0.0000e+00, 4.7905e-01, 2.8397e-01, 2.0213e-01, 3.2581e-01, 8.1380e-02,
        4.0340e-03, 4.0865e-01, 3.2057e-01, 0.0000e+00, 2.0551e-01, 0.0000e+00,
        8.4007e-01, 4.3933e-01, 5.4185e-02, 0.0000e+00, 1.0326e-01, 1.1600e-01,
        0.0000e+00, 0.0000e+00, 1.0331e-01, 0.0000e+00, 1.1047e-01, 0.0000e+00,
        1.7635e-01, 5.3663e-01, 1.0845e-01, 2.5078e-01, 9.1275e-02, 1.1915e-01,
        6.5898e-01, 0.0000e+00, 0.0000e+00, 2.1667e-01, 0.0000e+00, 0.0000e+00,
        9.7042e-02, 0.0000e+00, 4.2061e-02, 2.9561e-01, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 2.1148e-01, 3.2209e-01, 2.5021e-01, 0.0000e+00, 4.2517e-01,
        3.0799e-01, 0.0000e+00, 0.0000e+00, 1.8936e-02, 3.5749e-01, 2.0232e-01,
        0.0000e+00, 0.0000e+00, 1.7303e-01, 3.6376e-02, 7.9375e-02, 3.7546e-01,
        9.3856e-02, 1.3494e-01, 0.0000e+00, 0.0000e+00, 3.3927e-01, 7.0746e-02,
        0.0000e+00, 6.4800e-02, 2.5557e-

# Test newest `ComputationGraph`

In [9]:
# I've added a setup.py in the Interchange repo, so we can install the `intervention` package locally using
#   $ pip install -e path/to/Interchange

from intervention import ComputationGraph, GraphNode

class TorchEqualityCompGraph(ComputationGraph):
    """Computation graph exposing the layers of a ``TorchEqualityModel`` as nodes.

    The forward pass is split into four chained nodes, ``linear`` ->
    ``activation`` -> ``logits`` -> ``root``, each of which calls the matching
    submodule of ``model.module``. Elsewhere in this notebook the nodes are
    addressed by these function names (e.g. ``"linear"``, ``"activation"``),
    so the inner function names should not be changed casually.
    """

    def __init__(self, model):
        """Build the node chain around ``model`` and register ``root`` as the graph root.

        Args:
            model: A ``TorchEqualityModel``. Its ``module`` attribute must provide
                ``linear``, ``activation``, ``output`` and ``sigmoid`` callables,
                and its ``device`` attribute is where inputs are moved.

        Raises:
            AssertionError: If ``model`` is not a ``TorchEqualityModel``.
        """
        assert isinstance(model, TorchEqualityModel)
        self.model = model
        self.module = model.module

        # Node declared with no arguments to GraphNode(); the notebook feeds the
        # raw graph input to it under the key "linear".
        @GraphNode()
        def linear(x):
            # Preprocess the input: cast to float and move it to the model's
            # device before applying the first linear layer.
            x = x.float().to(self.model.device)
            return self.module.linear(x)

        # GraphNode(linear): presumably declares `linear` as the node whose
        # output is passed in as `x` -- confirm against the intervention package.
        @GraphNode(linear)
        def activation(x):
            return self.module.activation(x)

        @GraphNode(activation)
        def logits(x):
            # Output layer applied to the hidden activation.
            return self.module.output(x)

        @GraphNode(logits)
        def root(x):
            # Convert logits to sigmoid scores, then threshold each score at 0.5
            # to obtain a Python list of 0/1 predictions.
            scores = self.module.sigmoid(x)
            return [1 if z >= 0.5 else 0 for z in scores]

        # Hand the final node to ComputationGraph as the root of the graph.
        super().__init__(root)

## Intervene on entire tensor

In [10]:
from intervention import GraphInput, Intervention

g = TorchEqualityCompGraph(model)

# Graph inputs are dictionaries keyed by the node that consumes them ("linear").
input_0 = GraphInput({"linear": torch.tensor(test_dataset.X[0])})
input_3 = GraphInput({"linear": torch.tensor(test_dataset.X[3])})

# Baseline predictions without any intervention (input_0 is computed first).
res_0, res_3 = g.compute(input_0), g.compute(input_3)

print("res_0:", res_0, "res_3:", res_3)

# Use input_0 as the base input, but overwrite the whole "activation" node with
# the value that node takes on input_3.
activation_from_3 = g.get_result("activation", input_3)
interv_3_0 = Intervention(input_0, {"activation": activation_from_3})

before, after = g.intervene(interv_3_0)

print("input_0 before intervention %s, after intervention %s" % (before, after))

res_0: [0] res_3: [0]
input_0 before intervention [0], after intervention [0]


## Intervene on a part of a tensor

We can specify which part of a tensor to intervene on by appending an index expression to the node's name (for example, `"activation[:10]"`).

In [11]:
# Start from an intervention whose base is input_0; the replacement values are
# attached afterwards through the `inputs` attribute.
in2 = Intervention(input_0)

# Replace only the first 10 entries of "activation" with the ones observed on
# input_3. The slice is written as an index suffix on the node name.
activation_on_3 = g.get_result("activation", input_3)
replace_value = activation_on_3[:10]
inputs = {"activation[:10]": replace_value}
in2.inputs = inputs

before, after = g.intervene(in2)
print("input_0 before intervention %s, after intervention %s" % (before, after))

input_0 before intervention [0], after intervention [0]


## Using the `LOC` object to represent the indexing for an intervention location

We can also specify where to intervene by passing an additional `locs` dictionary, whose keys are node names and whose values can be any of the following:

1. `LOC[...]`. `LOC` is a special object that, when indexed with brackets `[]`, simply returns the underlying object denoted by whatever is inside the brackets (for example, `LOC[i:i+20]` yields the corresponding `slice`).
2. Indexing in a string form
3. an `int` for single elements
4. a tuple of `int`s and/or `slice` objects

In [13]:
from intervention import LOC

# Activation of input_3, used as the source of replacement values below.
act3 = g.get_result("activation", input_3)

# Number of consecutive activation entries replaced per intervention.
chunk_size = 20

# Walk over the activation vector in consecutive chunks. For each chunk, build
# an intervention on input_0 that replaces just that slice of "activation" with
# the corresponding slice from input_3, using LOC to describe the location.
# len(act3) replaces the previously hard-coded bound of 100 so the loop follows
# the actual hidden size.
for i in range(0, len(act3), chunk_size):
    replace_value = act3[i:i+chunk_size]
    inputs = {"activation": replace_value}
    locs = {"activation": LOC[i:i+chunk_size]}
    interv = Intervention(base=input_0, inputs=inputs, locs=locs)

    # Bug fix: run the intervention built in this iteration. The original code
    # passed `in2` (the "activation[:10]" intervention from the previous cell),
    # so `interv` and the LOC-based locations were never actually exercised.
    before, after = g.intervene(interv)
    print("input_0 before intervention %s, after intervention %s" % (before, after))
    

input_0 before intervention [0], after intervention [0]
input_0 before intervention [0], after intervention [0]
input_0 before intervention [0], after intervention [0]
input_0 before intervention [0], after intervention [0]
input_0 before intervention [0], after intervention [0]
