# Proxy Equalizer -- Tutorial

## TODOs

### General Todos

* Go through everything and make sure order of variables in graphs/samples/stacking/networks etc. is maintained/fixed somehow and not scrambled up by set/dictionary conversions at some point.
* Freezing the layers in `interventions._copy_and_freeze()` is the last hard coded part. Make it flexible.
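
As a minimal sketch of the ordering problem in the first TODO and one possible remedy: the helper name `unique_in_order` is made up for this illustration and is not part of the code base. It assumes vertex names are hashable (e.g. strings).

```python
def unique_in_order(items):
    """Remove duplicates while keeping the first occurrence of each item."""
    # A dict preserves insertion order (guaranteed since Python 3.7),
    # whereas list(set(items)) gives no ordering guarantee at all.
    return list(dict.fromkeys(items))


vertices = ["Np", "A", "Nx", "P", "X", "Y", "P", "A"]
ordered = unique_in_order(vertices)
print(ordered)
```

Replacing `list(set(...))` conversions with a helper like this would keep vertex order stable wherever duplicates need to be dropped.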

## Imports

In [None]:
import torch

from mlp import MLP, train
from sem import SEM
from interventions import Interventions
import utils

## Input the graph for the SEM

First we have to set up a structural equation model.
It consists of a graph and the corresponding equations.
We initialize an `SEM` object by passing in a graph as a dictionary. (Details of the data structure are in the docstring of the `SEM` class.)

We can then draw the graph with `sem.draw()` and print a lot of information about it with `sem.summary()`.

In [None]:
sem = SEM({"Np": None, "A": None, "Nx": None, "P": ["Np", "A"], "X": ["A", "P", "Nx"], "Y": ["P", "X"]})
sem.summary()
sem.draw()
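
To make the dictionary format concrete, here is a small plain-Python sketch that does not use the `SEM` class at all. Each key is a vertex and each value is the list of its parents (`None` for a root). The helper names `find_roots`, `find_children` and `find_edges` are hypothetical and only serve this illustration.

```python
graph = {"Np": None, "A": None, "Nx": None,
         "P": ["Np", "A"], "X": ["A", "P", "Nx"], "Y": ["P", "X"]}


def find_roots(graph):
    """Vertices without parents."""
    return [v for v, parents in graph.items() if parents is None]


def find_children(graph, vertex):
    """Vertices that list `vertex` among their parents."""
    return [v for v, parents in graph.items() if parents and vertex in parents]


def find_edges(graph):
    """All (parent, child) pairs in the graph."""
    return [(p, v) for v, parents in graph.items() for p in parents or []]


print(find_roots(graph))
print(find_children(graph, "A"))
print(len(find_edges(graph)))
```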

## Specify the structural equations

Let us first check the status of the vertices to make sure we attach valid equations.

In [None]:
# All vertices
print("All vertices: ", sem.vertices())
# Root vertices => provide distributions
print("Roots: ", sem.roots())
# Non root vertices => provide equations making use of all parents
print("Non-roots: ", sem.non_roots())

Now we attach structural equations to the vertices with `sem.attach_equation(vertex, callable)`.
For the root vertices, we draw from a standard normal.

The only argument to the callable is an integer `n`, the number of samples to draw. Of course, we could also attach different distributions separately.

**Note**: The `callable` attached to a vertex needs to return a `torch.tensor`.

In [None]:
for v in sem.roots():
    sem.attach_equation(v, lambda n: torch.randn(n, 1))

For the non-root vertices we attach made up functions.

The only argument to the callable for non-roots is a dictionary `data` that must have the vertex names as keys. This example shows how the parent vertices are accessed. We just construct a fully linear model in which all coefficients are just 1.

In [None]:
sem.attach_equation("P", lambda data: 1 * data['Np'] + 1 * data['A'])
sem.attach_equation("X", lambda data: 1 * data['A'] + 1 * data['P'] + 1 * data['Nx'])
sem.attach_equation("Y", lambda data: 1 * data['P'] + 1 * data['X'])

## Sample from the SEM

Now the SEM is fully specified and we can draw samples from it.

In [None]:
orig_sample = sem.sample(8192)
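
How `sem.sample()` works internally is not shown in this tutorial. A plausible sketch, under the assumption that roots are drawn first and every other vertex is then computed from its parents in topological order, looks as follows. Plain Python lists stand in for `torch` tensors, and `draw_sample` and `standard_normal` are made-up names (`graphlib` requires Python 3.9 or newer).

```python
import random
from graphlib import TopologicalSorter

graph = {"Np": None, "A": None, "Nx": None,
         "P": ["Np", "A"], "X": ["A", "P", "Nx"], "Y": ["P", "X"]}


def standard_normal(n):
    """Root equation: n independent standard normal draws."""
    return [random.gauss(0, 1) for _ in range(n)]


# Roots take the sample size n, non-roots take the data computed so far.
equations = {
    "Np": standard_normal,
    "A": standard_normal,
    "Nx": standard_normal,
    "P": lambda d: [np_ + a for np_, a in zip(d["Np"], d["A"])],
    "X": lambda d: [a + p + nx for a, p, nx in zip(d["A"], d["P"], d["Nx"])],
    "Y": lambda d: [p + x for p, x in zip(d["P"], d["X"])],
}


def draw_sample(graph, equations, n):
    """Evaluate all equations so that parents are always computed first."""
    predecessors = {v: parents or [] for v, parents in graph.items()}
    data = {}
    for vertex in TopologicalSorter(predecessors).static_order():
        if graph[vertex] is None:
            data[vertex] = equations[vertex](n)
        else:
            data[vertex] = equations[vertex](data)
    return data


toy_sample = draw_sample(graph, equations, 5)
```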

The `utils` module contains functions for plotting whole samples, where each variable is plotted as a function of its parents.

In [None]:
utils.plot_samples(sem, orig_sample)

## Learn the structural equations from data

While in this example we provided analytical equations for the structural equation model, in reality we only get data. Our assumptions are that we guessed the causal graph correctly, but we do not know the structural equations. We assume that we have observed samples from the graph. In this example, we will use the generated sample as our observed data.

Given the graph and the observed data, we can now try to learn the structural equations. **Note**: This can be done even if we had not attached structural equations to the `SEM` object.

**Arguments**: We pass in our "observed" sample, and can specify the number and sizes of hidden layers by `hidden_sizes` (default: `()` i.e. no hidden layers). Moreover, we can pass a list of vertices to the `binarize` keyword to add a `torch.nn.Sigmoid()` layer at the end when predicting those vertices (default: `[]`). Further, we can pass `epochs` (default: `50`) and `batchsize` (default: `32`) as named arguments.

In [None]:
sem.learn_from_sample(sample=orig_sample, hidden_sizes=(), binarize=[])

We can look at what networks have been learned.

In [None]:
sem.learned

For smaller networks (especially in the linear case with no hidden layers), it can be insightful to check whether the learned parameters match the actual coefficients in the analytical equations from which the sample was generated. In our simple case we get only ones, so we almost perfectly learned the linear equations (unsurprisingly).

In [None]:
sem.print_learned_parameters(weights=True, biases=False)

## Sample from the learned equations

Similarly to how we sampled from the analytical structural equations before, we can now sample from the learned equations.

Note, however, that we did not learn the distributions for the root vertices. Hence we have to provide values for the root vertices and can then pass those down to predict the other vertices with our learned functions with the `predict_from_sample()` function. Without further arguments, it does not mutate the input, but returns a new sample that has identical values for the root vertices and updates all non-root vertices with predictions from the learned functions.

**Note**: The `predict_from_sample()` function is more flexible. One can choose manually which vertices to update (`update` argument), whether to mutate the passed in sample instead of creating a new one with `mutate=True` (then the return value is `None`) and also to use a different predictor for specified vertices by `replace={vertex: predictor}`. 

In [None]:
learned_sample = sem.predict_from_sample(orig_sample)
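
The behaviour described above (copy by default, optional `update` list, `mutate=True` returning `None`) can be mimicked in a few lines. This is a toy re-implementation under those stated assumptions, not the project's code: plain lists replace tensors, simple lambdas replace the learned networks, and the `order` and `parents` arguments are additions made for this sketch.

```python
import copy


def predict_from_sample(order, parents, predictors, sample, update=None, mutate=False):
    """Recompute vertices from their parents, in topological order."""
    result = sample if mutate else copy.deepcopy(sample)
    if update is None:
        # Default: recompute every non-root vertex
        update = [v for v in order if parents[v] is not None]
    for vertex in order:
        if vertex in update:
            rows = zip(*(result[p] for p in parents[vertex]))
            result[vertex] = [predictors[vertex](*row) for row in rows]
    return None if mutate else result


parents = {"A": None, "B": ["A"], "C": ["A", "B"]}
order = ["A", "B", "C"]
predictors = {"B": lambda a: 2 * a, "C": lambda a, b: a + b}
observed = {"A": [1.0, 2.0], "B": [0.0, 0.0], "C": [0.0, 0.0]}

predicted = predict_from_sample(order, parents, predictors, observed)
```

Here `predicted` keeps the root values of `observed` and replaces `B` and `C`, while `observed` itself stays untouched unless `mutate=True` is passed.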

We can now plot the original sample and the learned sample simultaneously by passing a list of samples to `utils.plot_samples()`.

In [None]:
utils.plot_samples(sem, [orig_sample, learned_sample], legend=['analytic', 'learned'], alpha=0.5)

In the fully linear case, we recover the original sample basically perfectly, i.e. we learned the structural equations exactly.

## Specify the interventions

This is our self-made format to specify interventions. In a dict, for each proxy variable, we store another dict, which we call `functions`. In `functions`, keys are preset strings that correspond to the `known_functions` in the `Interventions` class. Current options: `'randn'`, `'rand'`, `'const'`, `'range'`. Every value of `functions` must be a list of tuples (!), where the tuples hold one or multiple scalar arguments (depending on the key).

**Example:**

This specifies five different intervened values for the proxy `'P1'` and four different intervened values for the proxy `'P2'`, a total of `5 * 4 = 20` different intervened samples.

```python
intervention_spec = {
    'P1': {
          'randn': [(0, 3), (0, 3), (0, 5)],
          'const':[(1,), (0,)],
          },
    'P2': {
          'range': [(-1, 1), (-5, 5)],
          'rand':[(-1, 1), (-5, 5)],
          },
    }
```
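
The count of `5 * 4 = 20` can be reproduced mechanically. The sketch below only inspects the specification dictionary and does not touch the `Interventions` class. It assumes that options add up within a proxy and multiply across proxies, as the example describes.

```python
import itertools
import math

intervention_spec = {
    'P1': {'randn': [(0, 3), (0, 3), (0, 5)], 'const': [(1,), (0,)]},
    'P2': {'range': [(-1, 1), (-5, 5)], 'rand': [(-1, 1), (-5, 5)]},
}

# Flatten each proxy's functions into a list of (function name, arguments) options
options = {
    proxy: [(name, args) for name, arg_list in functions.items() for args in arg_list]
    for proxy, functions in intervention_spec.items()
}
per_proxy = {proxy: len(opts) for proxy, opts in options.items()}
n_interventions = math.prod(per_proxy.values())

# Every intervened sample corresponds to one choice per proxy
combinations = list(itertools.product(*options.values()))
print(per_proxy, n_interventions, len(combinations))
```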

Note that `Interventions` also takes a sample as an argument. Currently, interventions are done on an existing sample, i.e. first, we compute the intervened graph, given the proxies specified in the `intervention_spec`. Then we copy the sample `n_interventions` times and fill the proxy values in each sample with one of the possible combinations of specified interventions. In the intervened graph, we then update all descendants of the proxies (in topological order), where we might also need values from other root vertices. This is why we already provide a sample.

Strictly, this corresponds to neither counterfactuals nor interventions. As always there's no "right" way to do this, but I'm happy for your opinions on the following options:

1. Always use one single sample for the other root vertices in the intervened graph:
    a. Use the same original sample that was used to learn the equations.
    b. Draw a new "base sample" for the retraining part.
2. For each intervened sample, draw the other root vertices in the intervened graph anew.

Consider also:

* In reality, we do not observe a full sample of the graph (root vertices are not observed).
* Can we make assumptions about the distribution of root vertices in real life, e.g. Gaussian? If so, how do we find the corresponding root vertex values belonging to one specific observation? (If we see P, X, Y, how do we find the corresponding Nx, A, Np?) While the distributions are enough to sample new values, the specific corresponding values are needed to learn the equations in the first step.

For the linear example, we choose random normal distributions with different variances as interventions.

In [None]:
intervention_spec = {
    'P': {
         'randn': [(0, 3), (0, 3), (0, 5), (0, 5)],
         },
    }
interventions = Interventions(sem, orig_sample, intervention_spec)
interventions.summary()
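
The intervened graph mentioned above can be pictured with a small standalone sketch: intervening on a proxy cuts all of its incoming edges, which in the parent-dictionary format means turning it into a root. The function names `intervened_graph` and `descendants` are illustrative and written from scratch here; they are not the methods used by `Interventions`.

```python
import copy


def intervened_graph(graph, proxies):
    """Cut all incoming edges of the proxies by turning them into roots."""
    new_graph = copy.deepcopy(graph)
    for proxy in proxies:
        new_graph[proxy] = None
    return new_graph


def descendants(graph, vertex):
    """All vertices reachable from `vertex` along parent -> child edges."""
    found = []
    stack = [vertex]
    while stack:
        current = stack.pop()
        for child, parents in graph.items():
            if parents and current in parents and child not in found:
                found.append(child)
                stack.append(child)
    return found


graph = {"Np": None, "A": None, "Nx": None,
         "P": ["Np", "A"], "X": ["A", "P", "Nx"], "Y": ["P", "X"]}
cut = intervened_graph(graph, ["P"])
print(cut["P"], descendants(cut, "P"))
```

Only the descendants of the proxy need to be recomputed after an intervention; all other vertices keep their values from the base sample.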

## Train a corrected version

Eventually we can actually retrain part of the target network, in this case the network for `'Y'` to minimize the variance of predictions across all different intervened samples. Note that here it seems like it only makes sense to do this for the same values of root vertices (closer to counterfactual?), because why would I want similar `'Y'` values for completely different starting values? On the other hand, we want that to be true in distribution, hence for a large batch size, we could also try to enforce that criterion with different values for the root vertices in each intervened sample.

In [None]:
corrected = interventions.train_corrected(epochs=100, batchsize=64)
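
To illustrate the objective in isolation, the following sketch evaluates the variance of linear predictions across intervened samples for a few candidate weights on `'P'`. It is not the training loop of `train_corrected`: it assumes the linear toy model from this tutorial, keeps the root values fixed across interventions, uses plain Python instead of `torch`, and the function name is made up.

```python
import random
import statistics

random.seed(0)
n = 200
A = [random.gauss(0, 1) for _ in range(n)]
Nx = [random.gauss(0, 1) for _ in range(n)]

# Four interventions on P, each drawn independently of A and Nx
interventions = [[random.gauss(0, s) for _ in range(n)] for s in (3, 3, 5, 5)]


def mean_variance_across_interventions(w_p, w_x=1.0):
    """Average (over data points) variance of predictions across interventions."""
    predictions = []
    for P in interventions:
        X = [a + p + nx for a, p, nx in zip(A, P, Nx)]
        predictions.append([w_p * p + w_x * x for p, x in zip(P, X)])
    per_point = [statistics.pvariance(values) for values in zip(*predictions)]
    return statistics.fmean(per_point)


losses = {w: mean_variance_across_interventions(w) for w in (1.0, 0.0, -1.0)}
print(losses)
```

Among the three candidates, the loss is largest for a weight of 1, smaller for 0, and vanishes up to rounding error for -1.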

## Evaluate the corrected model

### Small linear models: check parameters directly

For this small linear network we can look directly at the parameters it has learned. We indeed see that the original network learns weights of one everywhere, whereas the corrected version has a -1 for `'P'` instead, exactly what theory demands.

In [None]:
from pprint import pprint
print("Original weights:")
sem.print_learned_parameters(show=['Y'], weights=True, biases=False)

print("")
print("Fair parameters:")
for name, param in corrected.named_parameters():
    if 'bias' not in name:
        print(param.data.numpy())
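
The value of -1 can be checked by hand. As a short derivation, assume the predictor for `'Y'` is linear with weights $w_P$ and $w_X$ (symbols introduced here for illustration) and that the equation for `'X'` matches the analytical one:

$$
\begin{aligned}
X &= A + P + N_X, \\
\hat{Y} &= w_P P + w_X X = (w_P + w_X)\, P + w_X (A + N_X).
\end{aligned}
$$

Under an intervention on `'P'`, the values of $A$ and $N_X$ stay fixed, so $\hat{Y}$ is unaffected by the intervened value exactly when $w_P + w_X = 0$. If only the weight attached to `'P'` is retrained while $w_X = 1$ stays frozen, this gives $w_P = -1$ and $\hat{Y} = A + N_X$.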

### Comparison on a new sample

Let's look at the full sample plots we have already encountered above for a new sample, its learned reproduction and the corrected results.

In [None]:
base, orig, fair = utils.evaluate_on_new_sample(sem, 'Y', corrected, plot=True)

As we have already seen, the learned sample perfectly recovers the original sample from the analytical structural equation model. The fair results coincide with it for every vertex except the target `'Y'`, of course, because we did not touch any other part. The dependence of `'Y'` on both `'P'` and `'X'` has been decreased, but is **not** zero (see next section for an explanation).

### Evaluation tools for linear prediction

In the linear case, we can also look at (print and plot) all sorts of correlations, i.e. the slopes, r-values (Pearson Correlation Coefficient), p-values and standard errors of these tests.

We see that the correlation between `'Yfair'` and `'P'` goes down as compared to `'Y'` and `'P'`, but is **not** zero. There is still correlation between `'Yfair'` and `'P'` left through the confounder `'A'`. This is the main difference to all "learning fair representation" approaches so far.

In [None]:
utils.print_correlations(orig, sem=sem, sources=['A', 'P', 'X'], targets=['Y', 'Yfair'])

In [None]:
all_vars = sem.vertices() + ['Yfair']
utils.plot_correlations(orig, sem=sem, sources=all_vars, targets=all_vars)
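
The remaining correlation through the confounder `'A'` can be reproduced with a standalone simulation of the linear toy model. The sketch assumes that the corrected predictor is exactly `Yfair = -P + X`, i.e. the idealized weights from the parameter check, and computes Pearson's r by hand instead of calling `utils`. For this model the theoretical values are `5 / sqrt(28)` (about 0.94) for `'Y'` and `0.5` for `'Yfair'`.

```python
import math
import random


def pearson_r(xs, ys):
    """Pearson correlation coefficient of two equally long sequences."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    var_x = sum((x - mx) ** 2 for x in xs)
    var_y = sum((y - my) ** 2 for y in ys)
    return cov / math.sqrt(var_x * var_y)


random.seed(0)
n = 20000
A = [random.gauss(0, 1) for _ in range(n)]
Np = [random.gauss(0, 1) for _ in range(n)]
Nx = [random.gauss(0, 1) for _ in range(n)]

P = [np_ + a for np_, a in zip(Np, A)]
X = [a + p + nx for a, p, nx in zip(A, P, Nx)]
Y = [p + x for p, x in zip(P, X)]
Y_fair = [-p + x for p, x in zip(P, X)]  # equals A + Nx, still depends on A

r_orig = pearson_r(P, Y)
r_fair = pearson_r(P, Y_fair)
print(round(r_orig, 2), round(r_fair, 2))
```

Since `Yfair` reduces to `A + Nx` and `'P'` contains `'A'` as well, the two remain correlated even though the direct effect of `'P'` has been removed.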

## Quick run through a binarized example

Now we go through the whole workflow from specifying a graph to the final evaluation (without unnecessary intermediate steps), where we binarize the value of `'P'`.

In [None]:
# Construct the graph
sem = SEM({"Np": None, "A": None, "Nx": None, "P": ["Np", "A"], "X": ["A", "P", "Nx"], "Y": ["P", "X"]})

# Attach equations
for v in sem.roots():
    sem.attach_equation(v, lambda n: torch.randn(n, 1))
sem.attach_equation("P", lambda data: (1 * data['Np'] + 1 * data['A'] > 0.0).float())
sem.attach_equation("X", lambda data: 1 * data['A'] + 5 * data['P'] + 1 * data['Nx'])
sem.attach_equation("Y", lambda data: 1 * data['P'] + 1 * data['X'])

# Learn the equations (internally computes sample), use hidden layer for demo purposes
# Store the result as `orig_sample`, since that is the name used in the calls below
orig_sample = sem.learn_from_sample(hidden_sizes=(128,), epochs=50, binarize=['P'])
learned_sample = sem.predict_from_sample(orig_sample)

# Specify interventions, this time constants 0 and 1
# intervention_spec = {'P': {'const': [(0,), (1,)], 'range': [(0, 1)], 'rand': [(0, 1)]}}
intervention_spec = {'P': {'const': [(0,), (1,), (0,), (1,), (0,), (1,), (0,), (1,)]}}
interventions = Interventions(sem, orig_sample, intervention_spec)

# Remove proxy discrimination
corrected = interventions.train_corrected(epochs=100, batchsize=64)
                    
# Evaluate on new sample
base, orig, fair = utils.evaluate_on_new_sample(sem, 'Y', corrected, plot=True)

## MISC

### Evaluation tools for binary prediction

In [None]:
import copy
from sklearn.metrics import confusion_matrix

In [None]:
s1 = copy.deepcopy(orig_sample)
s2 = copy.deepcopy(fair_sample)
s1['Y'] = (s1['Y'] > 0.5).float()
s2['Y'] = (s2['Y'] > 0.5).float()

In [None]:
utils.plot_samples(sem, [test_sample, s1, s2], legend=['analytical', 'learned', 'fair'], alpha=0.3)

In [None]:
confusion_matrix(s1['Y'].int().numpy(), s2['Y'].int().numpy())

# DEVELOPMENTAL STAGE -- DEPRECATED BEYOND THIS POINT

## Imports

In [None]:
%matplotlib inline
from matplotlib import pyplot as plt
import numpy as np
import collections
import copy
import tqdm
from IPython.core.debugger import Tracer

import torch
import torch.nn as nn
from torch.autograd import Variable

from mlp import MLP, train
from graph import Graph

## Normal multilayer perceptron + training (now in `mlp.py`)

In [None]:
class MLP(nn.Module):
    """A simple fully connected feed forward network."""
    def __init__(self, sizes, final=None):
        """
        Initialize the network.
        
        A variable size network with only fully connected layers and ELU activations after all but the last layer.
        
        Args:
        
        sizes: A list of the numbers of neurons in the layers.
               len(sizes)-1 is the number of layers.
               First and last entries are input and output dimension.
        final: What to use as a final layer, e.g. torch.nn.Sigmoid()
               None (default) means no final layer (regression vs. classification).
               
        Example:
            A network with 2-dimensional input, one hidden layer with 128 neurons and 1-dimensional output for regression.
            >>> net = MLP([2, 128, 1])
            
            A network with 10-dimensional input, two hidden layers of 128 and 256 neurons and 1-dimensional output for classification.
            >>> net = MLP([10, 128, 256, 1], final=torch.nn.Sigmoid())            
        """
        super(MLP, self).__init__()
        
        self.layers = nn.ModuleList()
        # If there is only one input dimension, everything is fine
        if sizes[0] == 1:
            self.layers.append(nn.Linear(sizes[0], sizes[1]))
        # For multiple input dimensions, each one has a separate following hidden layer.
        # This is necessary for the partial training later on.
        else:
            self.layers.append(nn.ModuleList([nn.Linear(1, sizes[1]) for _ in range(sizes[0])]))
            
        # Add the remaining layers with elu activations
        for i in range(1, len(sizes) - 1):
            self.layers.append(nn.ELU())
            self.layers.append(nn.Linear(sizes[i], sizes[i + 1]))
            
        if final is not None:
            self.layers.append(final)


    def forward(self, x):
        """The forward pass."""
        # If there are multiple inputs, add up their hidden layers
        if isinstance(self.layers[0], nn.ModuleList):
            y = self.layers[0][0](x[:, 0, None])
            for i in range(1, len(self.layers[0])):
                y += self.layers[0][i](x[:, i, None])
            return nn.Sequential(*[self.layers[i] for i in range(1, len(self.layers))])(y)
        # Otherwise just build a simple sequential model
        else:
            return nn.Sequential(*self.layers)(x)

In [None]:
def train(net, x, y, loss_func=nn.MSELoss(), epochs=50, batchsize=32):
    """
    Train a network.
    
    Args:
        net:       A network module.
        x:         Training input data.
        y:         Training labels.
        loss_func: Loss function, default is nn.MSELoss(), i.e. mean squared error.
        epochs:    Number of training epochs.
        batchsize: Number of examples per mini-batch.
    """
    opt = torch.optim.Adam(net.parameters())
    n_samples = x.size(0)
    for epoch in range(epochs):
        # Shuffle training data
        p = torch.randperm(n_samples).long()
        xp = x[p]
        yp = y[p]

        for i1 in range(0, n_samples, batchsize):
            # Extract a batch
            i2 = min(i1 + batchsize, n_samples)
            xi, yi = xp[i1:i2], yp[i1:i2]

            # Reset gradients
            opt.zero_grad()
            
            # Forward pass
            loss = loss_func(net(Variable(xi)), Variable(yi))
            
            # Backward pass
            loss.backward()
            
            # Parameter update
            opt.step()
    return net

## The graph representation (now in `graph.py`)

In [None]:
class Graph:
    """A light weight, self made graph representation."""
    def __init__(self, graph):
        """Initialize a Graph object."""
        if isinstance(graph, dict):
            self.graph = graph
        else:
            print("Could not process input {} as graph. Initialized empty graph.".format(graph))
            self.graph = {}
    
    def __repr__(self):
        """Define representation."""
        import pprint
        return pprint.pformat(self.graph)
    
    def __str__(self):
        """Define string format."""
        import pprint
        return pprint.pformat(self.graph)

    def __iter__(self):
        return iter(self.graph)
    
    def __getitem__(self, item):
        return self.graph[item]
    
    def _try_add_vertex(self, vertex):
        if vertex in self.graph:
            print("Vertex already exists.")
        else:
            self.graph[vertex] = None
            print("Added vertex ", vertex)
    
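    def _try_add_edge(self, source, target):
        # The graph maps each vertex to its parents, so the edge
        # source -> target is stored in the parent list of target.
        if source not in self.graph:
            self.graph[source] = None
        parents = self.graph.get(target)
        if parents is None:
            self.graph[target] = [source]
        elif source not in parents:
            parents.append(source)
        else:
            print("Edge already exists.")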
    
    def add_vertices(self, vertices):
        """Add one or multiple vertices to the graph."""
        # Check for containers explicitly, a single name given as a string is iterable too
        if isinstance(vertices, (list, tuple, set)):
            for v in vertices:
                self._try_add_vertex(v)
        else:
            self._try_add_vertex(vertices)

    def add_edge(self, source, target):
        """Add a single edge from source to target."""
        self._try_add_edge(source, target)
        
    def vertices(self):
        """Find all vertices."""
        return list(self.graph.keys())
    
    def edges(self):
        """Find all edges."""
        edges = []
        for node, parents in self.graph.items():
            if parents is not None:
                for p in parents:
                    edges.append({p: node})
        return edges

    def roots(self):
        """Find all root vertices."""
        return [node for node in self.graph if self.graph[node] is None]

    def non_roots(self):
        return [node for node in self.graph if self.graph[node] is not None]
    
    def leafs(self):
        """Find all leaf vertices."""
        return list(set(self.vertices()).difference(self.non_leafs()))

    def non_leafs(self):
        """Find all non-leaf vertices."""
        return list(set(sum([p for p in self.graph.values() if p is not None], [])))
    
    def parents(self, vertex):
        """Find the parents of a vertex."""
        return self.graph[vertex]

    def children(self, vertex):
        """Find the children of a vertex."""
        children = []
        for node, parents in self.graph.items():
            if parents is not None and vertex in parents:
                children.append(node)
        return children
    
    def descendents(self, vertex):
        """Find all descendents of a vertex."""
        descendents = []
        current_children = self.children(vertex)
        if not current_children:
            return descendents
    
        descendents += current_children
    
        for child in current_children:
            new_descendents = self.descendents(child)
            descendents += new_descendents

        return list(set(descendents))
    
    def get_intervened_graph(self, interventions):
        """Return the intervened graph as a new graph."""
        intervened_graph = copy.deepcopy(self.graph)
        if isinstance(interventions, (list, tuple, set)):
            for i in interventions:
                intervened_graph[i] = None
        else:
            intervened_graph[interventions] = None
        return Graph(intervened_graph)
    
    def summary(self):
        """Print summary of the graph."""
        print("Vertices in graph", self.vertices())
        print("Roots in graph", self.roots())
        print("Non-roots in graph", self.non_roots())
        print("Leafs in graph", self.leafs())
        print("Non-leafs in graph", self.non_leafs())
        print("Edges in the graph", self.edges())

        for v in self.vertices():
            print("Children of {} are {}".format(v, self.children(v)))
            print("Parents of {} are {}".format(v, self.parents(v)))
            print("Descendents of {} are {}".format(v, self.descendents(v)))
        
    def _convert_to_nx(self):
        import networkx as nx
        G = nx.DiGraph()
        for edge in self.edges():
            edge = next(iter(edge.items()))
            G.add_edge(*edge)
        return G

    def topological_sort(self):
        import networkx as nx
        G = self._convert_to_nx()
        return list(nx.topological_sort(G))
    
    def draw(self):
        # `draw` comes from nxpd; without this import the call below raises a NameError
        from nxpd import draw, nxpdParams
        nxpdParams['show'] = 'ipynb'
        G = self._convert_to_nx()
        G.graph['dpi'] = 150
        draw(G)

## Sampling (now merged with graph in graph superclass `sem.py`)

In [None]:
# This is how we define a causal graph
# Nodes are the keys of the graph and the values are the parents(!) of the key.
# Setting the value to None means that the node is a root of the graph.
graph = Graph({"Np": None, "A": None, "Nx": None, "P": ["Np", "A"], "X": ["A", "P", "Nx"], "Y": ["P", "X"]})
graph.summary()
graph.draw()

In [None]:
# Default number of examples in a sample
n_sample = 8192

Extremely hard coded sampling from one given graph

In [None]:
def get_sample(n, eps=0.05):
    """Generate sample data specific to a graph (hand tuned)."""
    # Randomly sample the root nodes variables
    Np = torch.randn(n, 1)
    Nx = torch.randn(n, 1)
#     A = torch.zeros(n, 1)
#     A[torch.randperm(n).long()[:int(n/2)]] = 1.
    A = torch.randn(n, 1)
    
    P = 1 * Np + 3 * A + eps * torch.randn(n, 1)
#     P = (1 * Np + 3 * A + eps * torch.randn(n, 1) > 0.0).float()
    X = 2 * A + 1 * P + 3 * Nx + eps * torch.randn(n, 1)
#     Y = 1 * P + 3 * X + eps * torch.randn(n, 1)
    Y = (1 * P + 3 * X + eps * torch.randn(n, 1) > 0.0).float()
    return dict(Np=Np, Nx=Nx, A=A, P=P, X=X, Y=Y)

def plot_samples(graph, samples):
    """Plot all relevant dependencies in a graph from a/multiple sample(s)."""        
    # If we did not already receive a list of samples, make one element list
    if not isinstance(samples, list):
        samples = [samples]
    # Get non root variables
    non_roots = graph.non_roots()
    # Get maximum number of input variables
    max_deps = max([len(graph.parents(var)) for var in non_roots])
    
    # squeeze=False keeps axs two-dimensional even for a single row or column
    fig, axs = plt.subplots(len(non_roots), max_deps, figsize=(5 * max_deps, 5 * len(non_roots)), squeeze=False)

    # Go through all dependencies and plot them as 2D scatter plots
    for i, y_var in enumerate(non_roots):
        for j, x_var in enumerate(graph.parents(y_var)):
            for sample in samples:
                axs[i, j].plot(sample[x_var].numpy(), sample[y_var].numpy(), '.')
                axs[i, j].set_xlabel(x_var)
                axs[i, j].set_ylabel(y_var)
    plt.tight_layout()
    plt.show()
    
# Pure util function
def combine_variables(variables, sample):
    """Stack variables from sample along new axis."""
    data = torch.stack([sample[i] for i in variables], dim=1).squeeze()
    if len(data.size()) == 1:
        data.unsqueeze_(1)
    return data

In [None]:
# Sample should probably come as a pandas dataframe?
# But then the things are not torch arrays, so maybe keeping it as a dict is smarter?
sample = get_sample(n_sample, eps=0)

## Learn the _real_ SEM (now also part of `sem.py`)

In [None]:
def learn_sem(graph, sample, hidden_sizes=(), binarize=()):
    """Given a graph and a sample from it, learn the structural equations."""
    learned = {}
    for vertex in graph.non_roots():
        print("Training {} -> {}...".format(graph.parents(vertex), vertex), end=' ')
        data = combine_variables(graph.parents(vertex), sample)
        if vertex in binarize:
            final = nn.Sigmoid()
        else:
            final = None
        learned[vertex] = train(MLP([data.size(-1), *hidden_sizes, 1], final=final), data, sample[vertex])
        print("DONE")
    return learned

In [None]:
def predict_sample(graph, sample, learned):
    new_sample = copy.deepcopy(sample)
    need_update = [v for v in graph.topological_sort() if v not in graph.roots()]
    print("Updating the nodes {}...".format(need_update), end=' ')
    for update in need_update:
        argument = Variable(combine_variables(graph.parents(update), new_sample))
        new_sample[update] = learned[update](argument).data
    print("DONE")
    return new_sample

In [None]:
def learn_from_sample(self, sample, learned):
        from torch.autograd import Variable
        new_sample = copy.deepcopy(sample)
        need_update = [v for v in self.topological_sort()
                       if v not in self.roots()]
        print("Updating the nodes {}.".format(need_update))
        for update in need_update:
            print("Updating node {}...".format(update), end=' ')
            argument = Variable(utils.combine_variables(self.parents(update),
                                                        new_sample))
            new_sample[update] = learned[update](argument).data
        print("DONE")
        return new_sample

In [None]:
learned = learn_sem(graph, sample, binarize=['Y'])

In [None]:
pred_sample = predict_sample(graph, sample, learned)

In [None]:
plot_samples(graph, [sample, pred_sample])

## Interventions and intervened data sets (now part of `interventions.py`)

In [None]:
# Self made structure to specify interventions
class Interventions:
    """Manage and create training data sets for interventions."""
    
    # Methods for creating intervened samples
    known_functions = {
        'randn': (lambda self, mean, var: torch.randn(self.n_samples, 1) * var + mean),
        'const': (lambda self, const: torch.ones(self.n_samples, 1) * const),
        'rand': (lambda self, start, end: torch.rand(self.n_samples, 1) * (start - end) + end),
        'range': (lambda self, start, end: torch.linspace(start, end, steps=self.n_samples).unsqueeze_(1))
    }

    def __init__(self, graph, base_sample, intervention_spec, target='Y'):
        """Initialize with a base sample and intervention specification."""
        self.base_sample = base_sample
        self.n_samples = len(next(iter(base_sample.values())))
        self.interventions = intervention_spec
        self.proxies = list(intervention_spec.keys())
        self.graph = graph
        self.intervened_graph = self.graph.get_intervened_graph(self.proxies)
        self.target = target
        self._set_n_interventions()
        self.training_samples = []
        self._check_input()

    def _check_input(self):
        """Some basic checks of the input."""
        assert self.target in self.graph.leafs(), "Can't correct for non-leaf {}".format(self.target)

        for proxy in self.proxies:
            assert self.target in self.graph.descendents(proxy), "Can't correct for non-descendent {} of proxy {}.".format(self.target, proxy)

    def _set_n_interventions(self):
        """Compute and set the total number of interventions, i.e. training sets."""
        # Options add up within a proxy and multiply across proxies
        self.n_interventions = 1
        for proxy, funcs in self.interventions.items():
            n_options = 0
            for params in funcs.values():
                if not isinstance(params, list):
                    params = [params]
                n_options += len(params)
            self.n_interventions *= n_options

    def get_training_samples(self):
        """Generate the training samples for the given interventions."""
        if not self.training_samples:
            self._create_intervened_samples()
            self._update()
        return self.training_samples
    
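    def _create_intervened_samples(self):
        """Generate copies of base sample for each intervention and set proxies."""
        import itertools
        # For every proxy, collect all (proxy, function name, parameters) options
        options = []
        for proxy, functions in self.interventions.items():
            choices = []
            for func, parameters in functions.items():
                if not isinstance(parameters, list):
                    parameters = [parameters]
                choices += [(proxy, func, params) for params in parameters]
            options.append(choices)
        # One training sample per combination of interventions across proxies
        self.training_samples = []
        for combination in itertools.product(*options):
            sample = copy.deepcopy(self.base_sample)
            for proxy, func, params in combination:
                sample[proxy] = self.known_functions[func](self, *params)
            self.training_samples.append(sample)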
    
    def _update(self):
        """Update the variables downstream of the proxies."""
        downstream = list(set(sum([self.intervened_graph.descendents(proxy) for proxy in self.proxies], [])))
        # Use {self.target}: set(self.target) would split the name into characters
        need_update = list(set(downstream).difference({self.target}))
        fixed = set(self.intervened_graph.vertices()).difference(downstream)

        while need_update:
            found_one = False
            # Iterate over a copy, since need_update shrinks inside the loop
            for update in list(need_update):
                if set(self.intervened_graph.parents(update)) <= fixed:
                    # Found one that can be updated
                    found_one = True
                    # Update this variable in all samples
                    for sample in self.training_samples:
                        argument = Variable(combine_variables(self.intervened_graph.parents(update), sample))
                        sample[update] = learned[update](argument).data
                    # Mark it as done so that its children can be updated next
                    need_update.remove(update)
                    fixed.add(update)
            assert found_one, "Could not update any downstream variables {} from {}".format(need_update, fixed)

In [None]:
# This is our self made format to specify interventions.
# In a dict, for each proxy variable, we store another dict, which we call `functions`.
# In `functions`, keys are preset strings that correspond to the `known_functions` in the `Interventions` class.
# Current options: 'randn', 'rand', 'const', 'range'
# Every value of `functions` must be a list of tuples (!),
# where the tuples hold one or multiple scalar arguments (depending on the key).
# Example:
# intervention_spec = {
#     'P': {'randn': [(0, 3), (0, 3)],
#           'const': [(1,), (0,)],
#           'range': [(-1, 1)]
#          },
#     'X': {'randn': [(0, 1), (0, 1), (0, 1)]
#          },
#     }

intervention_spec = {
    'P': {
#          'randn': [(0, 3), (0, 3)],
         'const':[(1,), (0,)]
         },
    }
interventions = Interventions(graph, sample, intervention_spec)
print("Sample size: {}, Number of interventions {}".format(interventions.n_samples, interventions.n_interventions))

## Correct the _real_ SEM

**Need to create/manage intervened samples differently:**

* Can't have `intervention_values`, because just two different random samples should also be possible

In [None]:
import copy


def copy_and_freeze(model):
    """Return a deep copy of `model` in which only the first layer is trainable."""
    # Copy the original model for the target variable, so `model` itself stays untouched
    corrected = copy.deepcopy(model)

    # First freeze all parameters
    for param in corrected.parameters():
        param.requires_grad = False

    # Then re-enable gradients only for the part that is retrained for the correction
    # FIXME: the layer indices are hard coded; they should be determined from the model.
    # FIXME: not sure whether to fine-tune only the weights or also the biases.

    # Fine-tune weights and bias:
    for param in corrected.layers[0][0].parameters():
        param.requires_grad = True
    # Fine-tune only weights:
#     corrected.layers[0][0].weight.requires_grad = True

    return corrected
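
To see what the freezing does, here is a small self-contained sketch of the same pattern on a plain `torch.nn.Sequential` stand-in. The tutorial's `MLP` class is not used, so the layer is accessed as `corrected[0]` rather than `corrected.layers[0][0]`. It lists which parameters remain trainable after the copy.

```python
import copy

import torch

# Stand-in for a learned model (assumption: two linear layers with a ReLU in between)
model = torch.nn.Sequential(
    torch.nn.Linear(2, 8),
    torch.nn.ReLU(),
    torch.nn.Linear(8, 1),
)

corrected = copy.deepcopy(model)
# Freeze everything ...
for param in corrected.parameters():
    param.requires_grad = False
# ... then re-enable gradients for the first linear layer only (weight and bias)
for param in corrected[0].parameters():
    param.requires_grad = True

trainable = [name for name, p in corrected.named_parameters() if p.requires_grad]
frozen = [name for name, p in corrected.named_parameters() if not p.requires_grad]
print("trainable:", trainable)  # ['0.weight', '0.bias']
print("frozen:", frozen)        # ['2.weight', '2.bias']
```

Because of the deep copy, the original `model` keeps all of its parameters trainable.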

In [None]:
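def train_corrected(learned, interventions, batchsize=32, epochs=50):
    """Retrain the unfrozen part of the target model so that its predictions agree across interventions."""
    # Look up the target and proxy variables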
    target = interventions.target
    proxies = interventions.proxies
    print("Correct for the effect of {} on {}.".format(proxies, target))

    print("Generate intervened samples...", end=' ')
    train_samples = interventions.get_training_samples()
    print("DONE")

    # Sanity check
    assert len(train_samples) == interventions.n_interventions, "Number of interventions {} does not match number of training samples {}".format(interventions.n_interventions, len(train_samples))
    print("There is a total of {} interventions.".format(len(train_samples)))
    
    print("Freeze everything except first weights from {} to {}...".format(proxies, target), end=' ')
    corrected = copy_and_freeze(learned[target])
    print("DONE")

    print("Set up the optimizer...", end=' ')
    opt = torch.optim.Adam(filter(lambda p: p.requires_grad, corrected.parameters()))
    print("DONE")
    
    print("Partially retrain the target model for correction...", end=' ')
    n_samples = interventions.n_samples
    for epoch in tqdm.tqdm(range(epochs)):
        # Visit the samples in a fresh random order each epoch
        p = torch.randperm(n_samples).long()

        for i1 in range(0, n_samples, batchsize):
            # indices of the current mini-batch (the last one may be smaller than batchsize)
            i2 = min(i1 + batchsize, n_samples)
            idx = p[i1:i2]

            # reset gradients
            opt.zero_grad()

            # forward pass: one column of predictions per intervention
            Ys = Variable(torch.zeros(i2 - i1, interventions.n_interventions))
            for i, sample in enumerate(train_samples):
                argument = Variable(combine_variables(interventions.intervened_graph.parents(target), sample)[idx])
                Ys[:, i] = corrected(argument).squeeze()

            loss = torch.sum(torch.var(Ys, dim=1))
            
            # backward pass
            loss.backward()

            # parameter update
            opt.step()
    print("DONE")
    print("Finished correction.")
    return corrected
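
The loss in the training loop is the variance of the predictions across interventions, summed over the data points in the batch. It is zero exactly when every data point receives the same prediction under all interventions. A minimal sketch with made-up numbers follows; the function name `equalizing_loss` is introduced here for illustration only.

```python
import torch


def equalizing_loss(Ys):
    """Sum over rows (data points) of the variance across columns (interventions)."""
    return torch.sum(torch.var(Ys, dim=1))


# 3 data points, 2 interventions
same = torch.tensor([[1.0, 1.0], [2.0, 2.0], [-0.5, -0.5]])
diff = torch.tensor([[0.0, 2.0], [1.0, 1.0], [3.0, 1.0]])

print(equalizing_loss(same).item())  # 0.0
print(equalizing_loss(diff).item())  # 4.0
```

`torch.var` uses the unbiased estimator by default, so each row `[0, 2]` or `[3, 1]` contributes a variance of 2. For the same reason, at least two interventions are needed: with a single column the unbiased variance is undefined.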

In [None]:
equalizer = train_corrected(learned, interventions)

## Sanity Checks

### $\mathbb{R} \to \mathbb{R}$

In [None]:
def example_linear(support=(0, 1), slope=1, constant=0, n=1024, eps=0.1):
    """Simple linear data with uniform noise in [0, eps)."""
    x = torch.rand(n, 1) * (support[1] - support[0]) + support[0]
    y = slope * x + constant + eps * torch.rand(n, 1)
    return x, y

In [None]:
def example_quadratic(support=(0, 1), a=1, b=0, c=0, n=1024, eps=0.1):
    """Simple quadratic data with uniform noise in [0, eps)."""
    x = torch.rand(n, 1) * (support[1] - support[0]) + support[0]
    y = a * x**2 + b * x + c + eps * torch.rand(n, 1)
    return x, y
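
Note that the noise in both generators comes from `torch.rand`, which is uniform on $[0, 1)$ and therefore not centered: it shifts the targets upward by `eps / 2` on average. The sketch below illustrates this with an ordinary least-squares line fit on data generated the same way as the linear example. It is only a sanity baseline under these assumptions and does not use the tutorial's `MLP` or `train`.

```python
import torch

torch.manual_seed(0)

# Same recipe as the linear example: slope 1, constant 0, eps 0.1, support [0, 1]
n, slope, constant, eps = 1024, 1.0, 0.0, 0.1
x = torch.rand(n, 1)
y = slope * x + constant + eps * torch.rand(n, 1)

# Closed-form least-squares fit of a line y = w * x + b
x_mean, y_mean = x.mean(), y.mean()
w = ((x - x_mean) * (y - y_mean)).sum() / ((x - x_mean) ** 2).sum()
b = y_mean - w * x_mean

print("fitted slope:     {:.3f}".format(w.item()))  # close to 1
print("fitted intercept: {:.3f}".format(b.item()))  # close to eps / 2 = 0.05, not 0
```

A well-trained model on this data should therefore predict roughly `slope * x + constant + eps / 2` inside the support.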

In [None]:
x, y = example_quadratic()

In [None]:
plt.plot(x.numpy(), y.numpy(), '.')

In [None]:
pred = train(MLP([1, 128, 1]), x, y)

In [None]:
plt_x = torch.linspace(-10, 10, steps=1024)[:, None]

In [None]:
plt.plot(plt_x.numpy(), pred(Variable(plt_x)).data.numpy(), '.')

### $\mathbb{R}^2 \to \mathbb{R}$

In [None]:
x = torch.randn(1000, 2) * 3

In [None]:
# Keep y as a column of shape (n, 1), as in the one-dimensional example above
y = (x[:, 0] * 2 - 1.5 * x[:, 1]**2)[:, None]

In [None]:
plt.plot(x.numpy()[:, 0], y.numpy(), '.')
plt.plot(x.numpy()[:, 1], y.numpy(), '.')

In [None]:
test = train(MLP([2, 128, 1]), x, y)

In [None]:
plt_x = torch.randn(1000, 2) * 5

In [None]:
plt.plot(plt_x.numpy()[:, 0], test(Variable(plt_x)).data.numpy(), '.')
plt.plot(plt_x.numpy()[:, 1], test(Variable(plt_x)).data.numpy(), '.')