# Counterfactual training

In [55]:
%load_ext autoreload
%autoreload 2

from itertools import product
from pprint import pprint
from tqdm import tqdm
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
from antra import *

import logging
logging.getLogger().setLevel(logging.DEBUG)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Do counterfactual training

In [2]:
class BooleanLogicProgram(ComputationGraph):
    def __init__(self):
        leaf1 = GraphNode.leaf("leaf1")
        leaf2 = GraphNode.leaf("leaf2")
        leaf3 = GraphNode.leaf("leaf3")

        @GraphNode(leaf1,leaf2)
        def intermediate(x,y):
            return x & y

        @GraphNode(intermediate, leaf3)
        def root(w,v ):
            return w & v

        super().__init__(root)

class NeuralNetwork(torch.nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.lin1 = torch.nn.Linear(3,3)
        self.lin2 = torch.nn.Linear(3,1)

    def forward(self, x,y,z):
        x1 = torch.cat([x,y,z], dim=-1)
        h1 = self.lin1(x1)
        h1 = F.relu(h1)
        h2 = self.lin2(h1)
        return h2


class NeuralNetworkCompGraph(ComputationGraph):
    def __init__(self, model):
        leaf1 = GraphNode.leaf("leaf1")
        leaf2 = GraphNode.leaf("leaf2")
        leaf3 = GraphNode.leaf("leaf3")

        self.model = model

        @GraphNode(leaf1,leaf2, leaf3)
        def hidden1(x,y,z):
            # print(f"{x.shape=} {y.shape=}")
            x1 = torch.cat([x, y, z], dim=-1)
            # print(f"{a.shape=}")
            a1 = self.model.lin1(x1)
            # a1.retain_grad()
            # h = torch.matmul(a, self.model.lin1.T) + self.model.bias1
            h1 = F.relu(a1)
            return h1

        @GraphNode(hidden1)
        def root(h1):
            # print(f"{h.shape=} {z.shape=}")
            h2 = self.model.lin2(h1)
            return h2

        super().__init__(root)
#

In [27]:

def get_inputs(cache_results=False):
    low_inputs = [
        GraphInput({
            "leaf1": torch.tensor([a]),
            "leaf2": torch.tensor([b]),
            "leaf3": torch.tensor([c])
        }, cache_results=cache_results) for (a, b, c) in product((-1., 1.), repeat=3)
    ]

    high_inputs = [
        GraphInput({
            "leaf1": torch.tensor([a]),
            "leaf2": torch.tensor([b]),
            "leaf3": torch.tensor([c])
        }) for (a, b, c) in product((False, True), repeat=3)
    ]

    return low_inputs, high_inputs


@torch.no_grad()
def eval_acc(low_inputs, high_inputs, low_model, high_model, threshold=0.5):
    correct_cnt = 0
    total_cnt = 0
    for li, hi in zip(low_inputs, high_inputs):
        low_output = low_model.compute(li)
        hi_output = high_model.compute(hi)
        total_cnt += 1
        low_pred = torch.sigmoid(low_output) > threshold
        if low_pred == hi_output:
            correct_cnt += 1
#         else:
#             print(hi, low_output)
    return correct_cnt / total_cnt

#         print(f"Epoch {epoch} Loss {total_loss:.2f}")

In [59]:
def counterfactual_training(low_inputs, high_inputs, low_model, high_model,
                            lr=0.001, num_epochs=20):
    optimizer = torch.optim.Adam(low_model.model.parameters(), lr=lr)
    loss_fn = torch.nn.BCEWithLogitsLoss()

    for epoch in range(num_epochs):
        total_loss = 0.
        low_model.model.zero_grad()
        
        for i, j in product(range(len(low_inputs)), repeat=2):
            li1, hi1 = low_inputs[i], high_inputs[i]
            li2, hi2 = low_inputs[j], high_inputs[j]
            
            li2_hidden = low_model.compute_node("hidden1", li2)
            low_interv = Intervention(li1, {"hidden1[:2]": li2_hidden[:2]}, cache_results=False)
            _, logits = low_model.intervene(low_interv)

            hi2_mid = high_model.compute_node("intermediate", hi2)
            hi_interv = Intervention(hi1, {"intermediate": hi2_mid}, cache_results=False)
            _, label = high_model.intervene(hi_interv)
            label = label.to(torch.float)
            
            loss = loss_fn(logits, label)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        total_loss /= len(low_inputs) ** 2

    print(f"Final loss {total_loss:.4f}")
    return total_loss
#         print(f"Epoch {epoch} Loss {total_loss:.2f}")

def find_best_cf_model(lr=0.01, num_epochs=20):
    best_cf_model = None
    best_cf_loss = math.inf
    for seed in range(20):
        torch.manual_seed(seed)
        network = NeuralNetwork()
        low_model = NeuralNetworkCompGraph(network)
        high_model = BooleanLogicProgram()
        low_inputs, high_inputs = get_inputs()
        final_loss = counterfactual_training(low_inputs, high_inputs, low_model, high_model, lr=lr, num_epochs=num_epochs)
#         acc = eval_acc(low_inputs, high_inputs, low_model, high_model)
        if final_loss < best_cf_loss:
            best_cf_model = low_model
            best_cf_acc = final_loss
    return best_cf_model

In [60]:
cf_low_model = find_best_cf_model()

Final loss 0.0102
Final loss 0.0130
Final loss 0.1794
Final loss 0.0093
Final loss 0.2314
Final loss 0.0168
Final loss 0.1737
Final loss 0.0087
Final loss 0.0114
Final loss 0.0105
Final loss 0.2852
Final loss 0.0210
Final loss 0.1759
Final loss 0.1741
Final loss 0.1745
Final loss 0.0146
Final loss 0.1760
Final loss 0.1738
Final loss 0.2865
Final loss 0.0067


In [61]:
high_model = BooleanLogicProgram()
eval_acc(low_inputs, high_inputs, cf_low_model, high_model)

1.0

In [62]:
def train_baseline(low_inputs, high_inputs, low_model, high_model,
                   lr=0.001, num_epochs=30):
    optimizer = torch.optim.Adam(low_model.model.parameters(), lr=lr)
    loss_fn = torch.nn.BCEWithLogitsLoss()
    
    for epoch in range(num_epochs):
        total_loss = 0.
        low_model.model.zero_grad()
        for li, hi in zip(low_inputs, high_inputs):
            logits = low_model.compute(li)
            label = high_model.compute(hi)
            label = label.to(torch.float)
            
            loss = loss_fn(logits, label)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        total_loss /= len(low_inputs)

    print(f"Final loss {total_loss:.4f}")
    return total_loss
    
        
def find_best_baseline_model(lr=0.01, num_epochs=30):
    best_model = None
    best_loss = math.inf
    for seed in range(100):
        torch.manual_seed(seed)
        network = NeuralNetwork()
        low_model = NeuralNetworkCompGraph(network)
        high_model = BooleanLogicProgram()
        low_inputs, high_inputs = get_inputs()
        final_loss = train_baseline(low_inputs, high_inputs, low_model, high_model, lr=lr, num_epochs=num_epochs)
        # acc = eval_acc(low_inputs, high_inputs, low_model, high_model)
        if final_loss < best_loss:
            best_model = low_model
            best_loss = final_loss
    return best_model


In [63]:
baseline_low_model = find_best_baseline_model()

Final loss 0.1248
Final loss 0.1503
Final loss 0.1588
Final loss 0.1366
Final loss 0.2090
Final loss 0.1468
Final loss 0.0996
Final loss 0.1353
Final loss 0.1049
Final loss 0.1409
Final loss 0.1573
Final loss 0.1976
Final loss 0.1899
Final loss 0.2146
Final loss 0.0457
Final loss 0.1607
Final loss 0.1850
Final loss 0.1685
Final loss 0.2850
Final loss 0.1559
Final loss 0.1921
Final loss 0.0595
Final loss 0.0972
Final loss 0.2084
Final loss 0.1008
Final loss 0.1382
Final loss 0.0719
Final loss 0.2644
Final loss 0.0747
Final loss 0.1616
Final loss 0.0679
Final loss 0.1560
Final loss 0.2612
Final loss 0.1561
Final loss 0.1597
Final loss 0.2893
Final loss 0.1561
Final loss 0.0896
Final loss 0.1732
Final loss 0.1902
Final loss 0.3662
Final loss 0.1207
Final loss 0.0657
Final loss 0.1281
Final loss 0.1318
Final loss 0.1504
Final loss 0.1588
Final loss 0.1337
Final loss 0.1805
Final loss 0.1325
Final loss 0.1585
Final loss 0.1395
Final loss 0.2849
Final loss 0.0839
Final loss 0.2082
Final loss

In [64]:
eval_acc(low_inputs, high_inputs, baseline_low_model, high_model)

1.0

## Run interchange experiments

In [68]:
from antra.interchange import BatchedInterchange

def result_format_fxn(high_base_res, high_ivn_res, low_base_res, low_ivn_res, threshold=0.5):
#     print(low_base_res)
    lo_base_res = (torch.sigmoid(low_base_res) > threshold).item()
    lo_ivn_res = (torch.sigmoid(low_ivn_res) > threshold).item()
    hi_base_res = high_base_res.item()
    hi_ivn_res = high_ivn_res.item()
    return {
        "base_eq": hi_base_res == lo_base_res,
        "ivn_eq": hi_ivn_res == lo_ivn_res,
        "low_base_eq_ivn": lo_base_res == lo_ivn_res,
        "high_base_eq_ivn": hi_base_res == hi_ivn_res
    }

@torch.no_grad()
def interchange_expt(low_inputs, high_inputs, low_model, high_model):
    high_ivns = [
        Intervention({
            "leaf1": torch.tensor([a]),
            "leaf2": torch.tensor([b]),
            "leaf3": torch.tensor([c]),
        }, {
            "intermediate": torch.tensor([y])
        })
        for (a, b, c, y) in product((False, True), repeat=4)
    ]
    fixed_node_mapping =  {x: {x: None} for x in ["root", "leaf1",  "leaf2", "leaf3"]}
    low_nodes_to_indices = {
        "hidden1": [LOC[:, 1:], LOC[:, :2], LOC[:, 0]]
    }
    interx = BatchedInterchange(
        low_model=low_model,
        high_model=high_model,
        low_inputs=low_inputs,
        high_inputs=high_inputs,
        high_interventions=high_ivns,
        low_nodes_to_indices=low_nodes_to_indices,
        fixed_node_mapping=fixed_node_mapping,
        store_low_interventions=True,
        result_format=result_format_fxn,
        batch_size=12,
    )
    
    find_abstr_res = interx.find_abstractions()
    
    base_eq_count = 0
    total_count = 0
    denominator = 0
    numerator = 0
    failed_keys = []
    for result, mapping in find_abstr_res:
        for keys, d in result.items():
            total_count += 1
            if d["base_eq"]:
                base_eq_count += 1
            else:
                failed_keys.append(keys)
            
            if d["base_eq"]:
                denominator += 1
            
            if d["base_eq"] and d["ivn_eq"]:
                numerator += 1
        print(f"Mapping: {mapping}")
        print(f"Base accuracy of low model {base_eq_count}/{total_count}={base_eq_count / total_count : 2%}")
        print(f"Ivn success rate of low model {numerator}/{denominator}={numerator / denominator : 2%}")
    
    return failed_keys

In [69]:
failed_keys = interchange_expt(low_inputs, high_inputs, cf_low_model, high_model)

No device given, using CPU
  0%|          | 0/1 [00:00<?, ?it/s]Saw 1 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 377.53it/s]
  0%|          | 0/5 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 8 duplicates (maybe dupe examples?). Please check.
100%|██████████| 5/5 [00:00<00:00, 180.76it/s]
  0%|          | 0/1 [00:00<?, ?it/s]Saw 1 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 400.22it/s]
  0%|          | 0/5 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 8 duplicates (maybe dupe examples?). Please check.
100%|██████████|

end of dataloader; 7 new realizations
end of dataloader; 0 new realizations
end of dataloader; 7 new realizations
end of dataloader; 0 new realizations
end of dataloader; 7 new realizations
end of dataloader; 0 new realizations
Mapping: {'root': {'root': None}, 'leaf1': {'leaf1': None}, 'leaf2': {'leaf2': None}, 'leaf3': {'leaf3': None}, 'intermediate': {'hidden1': (slice(None, None, None), slice(1, None, None))}}
Base accuracy of low model 56/56= 100.000000%
Ivn success rate of low model 46/56= 82.142857%
Mapping: {'root': {'root': None}, 'leaf1': {'leaf1': None}, 'leaf2': {'leaf2': None}, 'leaf3': {'leaf3': None}, 'intermediate': {'hidden1': (slice(None, None, None), slice(None, 2, None))}}
Base accuracy of low model 112/112= 100.000000%
Ivn success rate of low model 102/112= 91.071429%
Mapping: {'root': {'root': None}, 'leaf1': {'leaf1': None}, 'leaf2': {'leaf2': None}, 'leaf3': {'leaf3': None}, 'intermediate': {'hidden1': (slice(None, None, None), 0)}}
Base accuracy of low model 16




In [114]:
pprint(failed_keys)

[(((('leaf1', (1.0,)), ('leaf2', (1.0,)), ('leaf3', (1.0,))),
   (('hidden1[::,1::]', (0.0017668381333351135, 0.0)),)),
  ((('leaf1', (True,)), ('leaf2', (True,)), ('leaf3', (True,))),
   (('intermediate', (False,)),))),
 (((('leaf1', (1.0,)), ('leaf2', (1.0,)), ('leaf3', (1.0,))),
   (('hidden1[::,1::]', (0.0, 0.0)),)),
  ((('leaf1', (True,)), ('leaf2', (True,)), ('leaf3', (True,))),
   (('intermediate', (False,)),))),
 (((('leaf1', (1.0,)), ('leaf2', (1.0,)), ('leaf3', (1.0,))),
   (('hidden1[::,1::]', (0.557224452495575, 0.0)),)),
  ((('leaf1', (True,)), ('leaf2', (True,)), ('leaf3', (True,))),
   (('intermediate', (False,)),))),
 (((('leaf1', (1.0,)), ('leaf2', (1.0,)), ('leaf3', (1.0,))),
   (('hidden1[::,1::]', (0.1927388310432434, 0.0)),)),
  ((('leaf1', (True,)), ('leaf2', (True,)), ('leaf3', (True,))),
   (('intermediate', (False,)),))),
 (((('leaf1', (1.0,)), ('leaf2', (1.0,)), ('leaf3', (1.0,))),
   (('hidden1[::,1::]', (0.7481964230537415, 0.0)),)),
  ((('leaf1', (True,)), 

In [128]:
interchange_expt(low_inputs, high_inputs, low_model2, high_model2)

No device given, using CPU
  0%|          | 0/1 [00:00<?, ?it/s]Saw 3 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 301.66it/s]
  0%|          | 0/4 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 4 duplicates (maybe dupe examples?). Please check.
100%|██████████| 4/4 [00:00<00:00, 235.59it/s]
  0%|          | 0/1 [00:00<?, ?it/s]Saw 3 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 430.32it/s]
  0%|          | 0/4 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 4 duplicates (maybe dupe examples?). Please check.
100%|██████████| 4/4 [00:00<00:00, 244.13it/s]
  0%|          | 0/1 [00:00<?, ?it/s]Saw 6 duplicates (maybe dupe examples?). Ple

Mapping: {'root': {'root': None}, 'leaf1': {'leaf1': None}, 'leaf2': {'leaf2': None}, 'leaf3': {'leaf3': None}, 'intermediate': {'hidden1': (slice(None, None, None), slice(1, None, None))}}
Base accuracy of low model 35/40= 87.500000%
Ivn success rate of low model 29/35= 82.857143%
Mapping: {'root': {'root': None}, 'leaf1': {'leaf1': None}, 'leaf2': {'leaf2': None}, 'leaf3': {'leaf3': None}, 'intermediate': {'hidden1': (slice(None, None, None), slice(None, 2, None))}}
Base accuracy of low model 70/80= 87.500000%
Ivn success rate of low model 58/70= 82.857143%
Mapping: {'root': {'root': None}, 'leaf1': {'leaf1': None}, 'leaf2': {'leaf2': None}, 'leaf3': {'leaf3': None}, 'intermediate': {'hidden1': (slice(None, None, None), 0)}}
Base accuracy of low model 84/96= 87.500000%
Ivn success rate of low model 69/84= 82.142857%





[(((('leaf1', (1.0,)), ('leaf2', (1.0,)), ('leaf3', (1.0,))),
   (('hidden1[::,1::]', (0.0, 0.0)),)),
  ((('leaf1', (True,)), ('leaf2', (True,)), ('leaf3', (True,))),
   (('intermediate', (False,)),))),
 (((('leaf1', (1.0,)), ('leaf2', (1.0,)), ('leaf3', (1.0,))),
   (('hidden1[::,1::]', (0.5659663081169128, 0.0)),)),
  ((('leaf1', (True,)), ('leaf2', (True,)), ('leaf3', (True,))),
   (('intermediate', (False,)),))),
 (((('leaf1', (1.0,)), ('leaf2', (1.0,)), ('leaf3', (1.0,))),
   (('hidden1[::,1::]', (0.2046559453010559, 0.0)),)),
  ((('leaf1', (True,)), ('leaf2', (True,)), ('leaf3', (True,))),
   (('intermediate', (False,)),))),
 (((('leaf1', (1.0,)), ('leaf2', (1.0,)), ('leaf3', (1.0,))),
   (('hidden1[::,1::]', (0.7891774773597717, 0.25335538387298584)),)),
  ((('leaf1', (True,)), ('leaf2', (True,)), ('leaf3', (True,))),
   (('intermediate', (True,)),))),
 (((('leaf1', (1.0,)), ('leaf2', (1.0,)), ('leaf3', (1.0,))),
   (('hidden1[::,1::]', (0.1846955418586731, 0.19087612628936768))