# Counterfactual training

In [1]:
%load_ext autoreload
%autoreload 2

from itertools import product
from pprint import pprint
from tqdm import tqdm
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
from antra import *
import antra.location as location

from antra.interchange.mapping import mapping_to_string

import logging
logging.getLogger("antra.interchange.batched").setLevel(logging.INFO)

## Define compgraph structures for counterfactual training

We use a simple boolean logic program where the input is x, y, z and the output is x & y & z.

The high-level algorithm computes x & y & z as (x & y) & z, where the value of x & y is stored in an intermediate variable.

The neural network model takes in three real values corresponding to x, y and z: with `-1.` representing `False` and `1.` representing `True`. It is a simple MLP, where the root outputs logits with positive values as predicting True.



In [2]:
class BooleanLogicProgram(ComputationGraph):
    def __init__(self):
        leaf1 = GraphNode.leaf("leaf1")
        leaf2 = GraphNode.leaf("leaf2")
        leaf3 = GraphNode.leaf("leaf3")

        @GraphNode(leaf1,leaf2)
        def intermediate(x,y):
            return x & y

        @GraphNode(intermediate, leaf3)
        def root(w,v ):
            return w & v

        super().__init__(root)

class NeuralNetwork(torch.nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.lin1 = torch.nn.Linear(3,3)
        self.lin2 = torch.nn.Linear(3,1)

    def forward(self, x,y,z):
        x1 = torch.cat([x,y,z], dim=-1)
        h1 = self.lin1(x1)
        h1 = F.relu(h1)
        h2 = self.lin2(h1)
        return h2


class NeuralNetworkCompGraph(ComputationGraph):
    def __init__(self, model):
        leaf1 = GraphNode.leaf("leaf1")
        leaf2 = GraphNode.leaf("leaf2")
        leaf3 = GraphNode.leaf("leaf3")

        self.model = model

        @GraphNode(leaf1,leaf2, leaf3)
        def hidden1(x,y,z):
            # print(f"{x.shape=} {y.shape=}")
            x1 = torch.cat([x, y, z], dim=-1)
            # print(f"{a.shape=}")
            a1 = self.model.lin1(x1)
            # a1.retain_grad()
            # h = torch.matmul(a, self.model.lin1.T) + self.model.bias1
            h1 = F.relu(a1)
            return h1

        @GraphNode(hidden1)
        def root(h1):
            # print(f"{h.shape=} {z.shape=}")
            h2 = self.model.lin2(h1)
            return h2

        super().__init__(root)
#

## Util functions

In [3]:

def get_inputs(cache_results=False):
    low_inputs = [
        GraphInput({
            "leaf1": torch.tensor([a]),
            "leaf2": torch.tensor([b]),
            "leaf3": torch.tensor([c])
        }, cache_results=cache_results) for (a, b, c) in product((-1., 1.), repeat=3)
    ]

    high_inputs = [
        GraphInput({
            "leaf1": torch.tensor([a]),
            "leaf2": torch.tensor([b]),
            "leaf3": torch.tensor([c])
        }) for (a, b, c) in product((False, True), repeat=3)
    ]

    return low_inputs, high_inputs


@torch.no_grad()
def eval_acc(low_inputs, high_inputs, low_model, high_model, threshold=0.5):
    correct_cnt = 0
    total_cnt = 0
    for li, hi in zip(low_inputs, high_inputs):
        low_output = low_model.compute(li)
        hi_output = high_model.compute(hi)
        total_cnt += 1
        low_pred = torch.sigmoid(low_output) > threshold
        if low_pred == hi_output:
            correct_cnt += 1
#         else:
#             print(hi, low_output)
    return correct_cnt / total_cnt

#         print(f"Epoch {epoch} Loss {total_loss:.2f}")

## Functions for running counterfactual training

In [4]:
def counterfactual_training(low_inputs, high_inputs, low_model, high_model,
                            lr=0.001, num_epochs=20):
    optimizer = torch.optim.Adam(low_model.model.parameters(), lr=lr)
    loss_fn = torch.nn.BCEWithLogitsLoss()

    for epoch in range(num_epochs):
        total_loss = 0.
        low_model.model.zero_grad()
        
        for i, j in product(range(len(low_inputs)), repeat=2):
            li1, hi1 = low_inputs[i], high_inputs[i]
            li2, hi2 = low_inputs[j], high_inputs[j]
            
            li2_hidden = low_model.compute_node("hidden1", li2)
            low_interv = Intervention(li1, {"hidden1[:2]": li2_hidden[:2]}, cache_results=False)
            _, logits = low_model.intervene(low_interv)

            hi2_mid = high_model.compute_node("intermediate", hi2)
            hi_interv = Intervention(hi1, {"intermediate": hi2_mid}, cache_results=False)
            _, label = high_model.intervene(hi_interv)
            label = label.to(torch.float)
            
            loss = loss_fn(logits, label)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        total_loss /= len(low_inputs) ** 2
    return total_loss
#         print(f"Epoch {epoch} Loss {total_loss:.2f}")

def train_cf_models(seeds=100, lr=0.01, num_epochs=20):
    best_cf_models = []

    for seed in tqdm(range(100)):
        torch.manual_seed(seed)
        network = NeuralNetwork()
        low_model = NeuralNetworkCompGraph(network)
        high_model = BooleanLogicProgram()
        low_inputs, high_inputs = get_inputs()
        final_loss = counterfactual_training(low_inputs, high_inputs, low_model, high_model, lr=lr, num_epochs=num_epochs)
        acc = eval_acc(low_inputs, high_inputs, low_model, high_model)
        if acc == 1.0:
            best_cf_models.append(low_model)
#         if final_loss < best_cf_loss:
#             best_cf_model = low_model
#             best_cf_acc = final_loss
    return best_cf_models

In [5]:
cf_low_models = train_cf_models()

100%|██████████| 100/100 [02:41<00:00,  1.62s/it]


In [6]:
print(len(cf_low_models))

52


In [7]:
low_inputs, high_inputs = get_inputs()
high_model = BooleanLogicProgram()
eval_acc(low_inputs, high_inputs, cf_low_model, high_model)

1.0

## Train a baseline nn model only on low inputs

In [9]:
def train_baseline(low_inputs, high_inputs, low_model, high_model,
                   lr=0.001, num_epochs=30):
    optimizer = torch.optim.Adam(low_model.model.parameters(), lr=lr)
    loss_fn = torch.nn.BCEWithLogitsLoss()
    
    for epoch in range(num_epochs):
        total_loss = 0.
        low_model.model.zero_grad()
        for li, hi in zip(low_inputs, high_inputs):
            logits = low_model.compute(li)
            label = high_model.compute(hi)
            label = label.to(torch.float)
            
            loss = loss_fn(logits, label)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        total_loss /= len(low_inputs)

#     print(f"Final loss {total_loss:.4f}")
    return total_loss
    
        
def train_baseline_models(seeds=100, lr=0.01, num_epochs=30):
    baseline_models = []

    for seed in tqdm(range(seeds)):
        torch.manual_seed(seed)
        network = NeuralNetwork()
        low_model = NeuralNetworkCompGraph(network)
        high_model = BooleanLogicProgram()
        low_inputs, high_inputs = get_inputs()
        final_loss = train_baseline(low_inputs, high_inputs, low_model, high_model, lr=lr, num_epochs=num_epochs)
        acc = eval_acc(low_inputs, high_inputs, low_model, high_model)
        if acc == 1.0:
            baseline_models.append(low_model)
    return baseline_models


In [10]:
baseline_low_models = train_baseline_models()
print(len(baseline_low_models))

100%|██████████| 100/100 [00:17<00:00,  5.61it/s]

20





In [10]:
low_inputs, high_inputs = get_inputs()
high_model = BooleanLogicProgram()
eval_acc(low_inputs, high_inputs, baseline_low_model, high_model)

1.0

## Run interchange experiments

In [11]:
from antra.interchange import BatchedInterchange
import numpy as np

all_low_hidden1_locs =  [LOC[:, 0], LOC[:, 1], LOC[:, 2], LOC[:, 1:], LOC[:, :2], LOC[:, :]]

def result_format_fxn(high_base_res, high_ivn_res, low_base_res, low_ivn_res, threshold=0.5):
    lo_base_res = (torch.sigmoid(low_base_res) > threshold).item()
    lo_ivn_res = (torch.sigmoid(low_ivn_res) > threshold).item()
    hi_base_res = high_base_res.item()
    hi_ivn_res = high_ivn_res.item()
    return {
        "base_eq": hi_base_res == lo_base_res,
        "ivn_eq": hi_ivn_res == lo_ivn_res,
        "low_base_eq_ivn": lo_base_res == lo_ivn_res,
        "high_base_eq_ivn": hi_base_res == hi_ivn_res
    }

@torch.no_grad()
def interchange_expt(low_inputs, high_inputs, low_model, high_model):
    high_ivns = [
        Intervention({
            "leaf1": torch.tensor([a]),
            "leaf2": torch.tensor([b]),
            "leaf3": torch.tensor([c]),
        }, {
            "intermediate": torch.tensor([y])
        })
        for (a, b, c, y) in product((False, True), repeat=4)
    ]
    fixed_nodes = ["root", "leaf1", "leaf2", "leaf3"]
    fixed_node_mapping =  {x: {x: None} for x in fixed_nodes}
    low_nodes_to_indices = {
        "hidden1": all_low_hidden1_locs
    }
    interx = BatchedInterchange(
        low_model=low_model,
        high_model=high_model,
        low_inputs=low_inputs,
        high_inputs=high_inputs,
        high_interventions=high_ivns,
        low_nodes_to_indices=low_nodes_to_indices,
        fixed_node_mapping=fixed_node_mapping,
        store_low_interventions=True,
        result_format=result_format_fxn,
        batch_size=12,
    )
    
    find_abstr_res = interx.find_abstractions()
    
    base_eq_count = 0
    total_count = 0
    denominator = 0
    numerator = 0
    failed_keys = []
    
    low_loc_to_success_rate = {}
    for result, mapping in find_abstr_res:
        for keys, d in result.items():
            total_count += 1
            if d["base_eq"]:
                base_eq_count += 1
            else:
                failed_keys.append(keys)
            
            if d["base_eq"] and not d["high_base_eq_ivn"]:
                denominator += 1
            
            if d["base_eq"] and d["ivn_eq"] and not d["high_base_eq_ivn"]:
                numerator += 1
                
        ser_low_loc = location.serialize_location(mapping["intermediate"]["hidden1"])
        low_loc_to_success_rate[ser_low_loc] = numerator / denominator
        # print("\nMapping:")
        # print(mapping_to_string(mapping, ignore_nodes=fixed_nodes))
        # print(f"Base accuracy of low model {base_eq_count}/{total_count}={base_eq_count / total_count : 2%}")
        # print(f"Ivn success rate of low model {numerator}/{denominator}={numerator / denominator : 2%}")
    
    return low_loc_to_success_rate

def run_interchange_expts(low_models):
    low_inputs, high_inputs = get_inputs()
    high_model = BooleanLogicProgram()
    stats = {
        location.serialize_location(l): [] for l in all_low_hidden1_locs
    }
    
    for low_model in low_models:
        low_loc_to_success_rate = interchange_expt(low_inputs, high_inputs, low_model, high_model)
        for ser_loc, success_rate in low_loc_to_success_rate.items():
            stats[ser_loc].append(success_rate)
    
    stats = {
        ser_loc: (np.mean(np.array(rates)), np.std(np.array(rates)))
        for ser_loc, rates in stats.items()
    }
    return stats
        

In [12]:
cf_low_models_to_test = cf_low_models[:20]
baseline_low_models_to_test = baseline_low_models[:20]

In [13]:
cf_stats = run_interchange_expts(cf_low_models_to_test)

No device given, using CPU
  0%|          | 0/1 [00:00<?, ?it/s]Saw 1 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 175.19it/s]
  0%|          | 0/5 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 8 duplicates (maybe dupe examples?). Please check.
100%|██████████| 5/5 [00:00<00:00, 179.69it/s]
  0%|          | 0/1 [00:00<?, ?it/s]Saw 1 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 367.66it/s]
  0%|          | 0/5 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 8 duplicates (maybe dupe examples?). Please check.
100%|██████████|

100%|██████████| 1/1 [00:00<00:00, 376.51it/s]
  0%|          | 0/5 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 8 duplicates (maybe dupe examples?). Please check.
100%|██████████| 5/5 [00:00<00:00, 177.54it/s]
  0%|          | 0/1 [00:00<?, ?it/s]Saw 0 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 383.57it/s]
  0%|          | 0/6 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 4 duplicates (maybe dupe examples?). Please check.
100%|██████████| 6/6 [00:00<00:00, 214.73it/s]
No device given, using CPU
  0%|

100%|██████████| 1/1 [00:00<00:00, 423.03it/s]
  0%|          | 0/4 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
100%|██████████| 4/4 [00:00<00:00, 192.28it/s]
  0%|          | 0/1 [00:00<?, ?it/s]Saw 0 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 338.93it/s]
  0%|          | 0/6 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 4 duplicates (maybe dupe examples?). Please check.
100%|██████████| 6/6 [00:00<00:00, 193.04it/s]
  0%|          | 0/1 [00:00<?, ?it/s]Saw 1 duplicates (maybe dupe examples?). Please ch

100%|██████████| 1/1 [00:00<00:00, 390.35it/s]
  0%|          | 0/5 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 8 duplicates (maybe dupe examples?). Please check.
100%|██████████| 5/5 [00:00<00:00, 200.66it/s]
  0%|          | 0/1 [00:00<?, ?it/s]Saw 2 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 393.68it/s]
  0%|          | 0/4 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
100%|██████████| 4/4 [00:00<00:00, 193.41it/s]
  0%|          | 0/1 [00:00<?, ?it/s]Saw 2 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 433.12it/s]
  0%

Saw 8 duplicates (maybe dupe examples?). Please check.
100%|██████████| 5/5 [00:00<00:00, 199.66it/s]
  0%|          | 0/1 [00:00<?, ?it/s]Saw 0 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 402.37it/s]
  0%|          | 0/6 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 4 duplicates (maybe dupe examples?). Please check.
100%|██████████| 6/6 [00:00<00:00, 203.25it/s]
No device given, using CPU
  0%|          | 0/1 [00:00<?, ?it/s]Saw 1 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 403.92it/s]
  0%|          | 0/5 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe 

100%|██████████| 1/1 [00:00<00:00, 264.56it/s]
  0%|          | 0/6 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 4 duplicates (maybe dupe examples?). Please check.
100%|██████████| 6/6 [00:00<00:00, 146.62it/s]
  0%|          | 0/1 [00:00<?, ?it/s]Saw 1 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 428.43it/s]
  0%|          | 0/5 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 8 duplicates (maybe dupe examples?). Please check.
100%|██████████| 5/5 [00:00<00:00, 201.62it/s]
  0%|          | 0/1 [00:00<?, ?

100%|██████████| 5/5 [00:00<00:00, 196.52it/s]
  0%|          | 0/1 [00:00<?, ?it/s]Saw 2 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 306.42it/s]
  0%|          | 0/4 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
100%|██████████| 4/4 [00:00<00:00, 163.29it/s]
  0%|          | 0/1 [00:00<?, ?it/s]Saw 0 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 407.02it/s]
  0%|          | 0/6 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 4 duplicates (maybe dupe examples?). Please ch

100%|██████████| 2/2 [00:00<00:00, 238.56it/s]
  0%|          | 0/1 [00:00<?, ?it/s]Saw 1 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 416.18it/s]
  0%|          | 0/5 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 8 duplicates (maybe dupe examples?). Please check.
100%|██████████| 5/5 [00:00<00:00, 186.36it/s]
  0%|          | 0/1 [00:00<?, ?it/s]Saw 2 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 430.10it/s]
  0%|          | 0/4 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
100%|██████████| 4/4 [00:00<00:00, 186.65it/s]
  0%

In [14]:
def print_interx_success_rate_stats(stats):
    for ser_loc, (mean, std) in stats.items():
        loc = location.deserialize_location(ser_loc)
        str_loc = location.location_to_str(loc, add_brackets=True)
        print(f"hidden1{str_loc}, mean={mean:.4f}, std={std:.4f}")


In [15]:
print_interx_success_rate_stats(cf_stats)

hidden1[::,0], mean=0.5479, std=0.3019
hidden1[::,1], mean=0.6256, std=0.0495
hidden1[::,2], mean=0.4922, std=0.0377
hidden1[::,1::], mean=0.4989, std=0.0539
hidden1[::,:2:], mean=0.5949, std=0.0440
hidden1[::,::], mean=0.6263, std=0.0353


In [16]:
baseline_stats = run_interchange_expts(baseline_low_models_to_test)

No device given, using CPU
  0%|          | 0/1 [00:00<?, ?it/s]Saw 1 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 281.55it/s]
  0%|          | 0/5 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 8 duplicates (maybe dupe examples?). Please check.
100%|██████████| 5/5 [00:00<00:00, 174.32it/s]
  0%|          | 0/1 [00:00<?, ?it/s]Saw 2 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 201.24it/s]
  0%|          | 0/4 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
100%|██████████| 4/4 [00:00<00:00, 170.17it/s]
  0%|          | 0/1 [00

  0%|          | 0/1 [00:00<?, ?it/s]Saw 0 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 338.58it/s]
  0%|          | 0/6 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 4 duplicates (maybe dupe examples?). Please check.
100%|██████████| 6/6 [00:00<00:00, 168.02it/s]
No device given, using CPU
  0%|          | 0/1 [00:00<?, ?it/s]Saw 2 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 321.95it/s]
  0%|          | 0/4 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
100%|██████████

Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 4 duplicates (maybe dupe examples?). Please check.
100%|██████████| 4/4 [00:00<00:00, 202.39it/s]
  0%|          | 0/1 [00:00<?, ?it/s]Saw 0 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 390.93it/s]
  0%|          | 0/6 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 4 duplicates (maybe dupe examples?). Please check.
100%|██████████| 6/6 [00:00<00:00, 180.37it/s]
  0%|          | 0/1 [00:00<?, ?it/s]Saw 0 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 401.41it/s]
  0%|          | 0/6 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please che

  0%|          | 0/1 [00:00<?, ?it/s]Saw 3 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 302.73it/s]
  0%|          | 0/4 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 4 duplicates (maybe dupe examples?). Please check.
100%|██████████| 4/4 [00:00<00:00, 213.09it/s]
  0%|          | 0/1 [00:00<?, ?it/s]Saw 2 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 393.17it/s]
  0%|          | 0/4 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
100%|██████████| 4/4 [00:00<00:00, 171.89it/s]
  0%|          | 0/1 [00:00<?, ?it/s]Saw 0 duplicates (maybe dupe examples?). Please check.
100%|██████████

100%|██████████| 4/4 [00:00<00:00, 203.26it/s]
  0%|          | 0/1 [00:00<?, ?it/s]Saw 0 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 436.09it/s]
  0%|          | 0/6 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 4 duplicates (maybe dupe examples?). Please check.
100%|██████████| 6/6 [00:00<00:00, 189.45it/s]
  0%|          | 0/1 [00:00<?, ?it/s]Saw 0 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 428.47it/s]
  0%|          | 0/6 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please ch

100%|██████████| 1/1 [00:00<00:00, 406.98it/s]
  0%|          | 0/4 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
100%|██████████| 4/4 [00:00<00:00, 185.67it/s]
  0%|          | 0/1 [00:00<?, ?it/s]Saw 2 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 447.06it/s]
  0%|          | 0/4 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
100%|██████████| 4/4 [00:00<00:00, 188.16it/s]
  0%|          | 0/1 [00:00<?, ?it/s]Saw 4 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 305.04it/s]
  0%|          | 0/3 [00:00<?, ?it/s]Saw 12 duplicates (may

Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 4 duplicates (maybe dupe examples?). Please check.
100%|██████████| 4/4 [00:00<00:00, 194.07it/s]
  0%|          | 0/1 [00:00<?, ?it/s]Saw 4 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 386.46it/s]
  0%|          | 0/3 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 8 duplicates (maybe dupe examples?). Please check.
100%|██████████| 3/3 [00:00<00:00, 202.29it/s]
  0%|          | 0/1 [00:00<?, ?it/s]Saw 1 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 333.33it/s]
  0%|          | 0/5 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 8 duplicates (maybe dupe examples?). Please chec

  0%|          | 0/2 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
100%|██████████| 2/2 [00:00<00:00, 189.46it/s]
  0%|          | 0/1 [00:00<?, ?it/s]Saw 3 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 421.28it/s]
  0%|          | 0/4 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 4 duplicates (maybe dupe examples?). Please check.
100%|██████████| 4/4 [00:00<00:00, 205.68it/s]
  0%|          | 0/1 [00:00<?, ?it/s]Saw 3 duplicates (maybe dupe examples?). Please check.
100%|██████████| 1/1 [00:00<00:00, 351.61it/s]
  0%|          | 0/4 [00:00<?, ?it/s]Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 12 duplicates (maybe dupe examples?). Please check.
Saw 4 duplicat

In [17]:
print_interx_success_rate_stats(baseline_stats)

hidden1[::,0], mean=0.3400, std=0.1593
hidden1[::,1], mean=0.3394, std=0.1073
hidden1[::,2], mean=0.3431, std=0.0647
hidden1[::,1::], mean=0.3936, std=0.0711
hidden1[::,:2:], mean=0.4193, std=0.0662
hidden1[::,::], mean=0.4802, std=0.0587
