# Introduction

**The SN² Solver** finds the optimal parameters of a Gaussian pmDAG.
As we will see, there is more than one way to compute the optimal parameters of a pmDAG structure.

We start by loading the libraries. The SN² Solver is built on PyTorch® and uses CUDA® as its backend.
We therefore import PyTorch and define a `device` variable that is always set to CUDA;
every tensor we create is placed on that device.
We also import our library, `sn2_cuda`.

In [1]:
import torch
import sn2_cuda as sn2

device = torch.device("cuda")

Consider the following pmDAG.

![Health graph](img/health-graph.svg "Health Graph")

It is composed of four visible variables and six latent variables.
We will encode this graph as an adjacency matrix. The encoding is simple: each row is a parent and each column is a visible child.
The six latent variables occupy the top rows, and the four visible variables form an upper-triangular block at the bottom.

In [2]:
struct = torch.tensor([
        [1, 1, 1, 0],
        [0, 1, 0, 1],
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
        [0, 1, 1, 1],  # V_X
        [0, 0, 1, 0],  # V_BP
        [0, 0, 0, 1],  # V_BMI
        [0, 0, 0, 0],  # V_Y
    ], dtype=torch.bool, device=device)
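
The encoding above can also be generated from edge lists rather than typed by hand. The following is a minimal sketch in plain Python, not part of `sn2_cuda`: the helper `build_structure` and the latent names `L1` to `L6` are hypothetical placeholders introduced only for illustration, and the edges are read off the matrix above.

```python
# Hypothetical helper, not part of sn2_cuda: builds the 0/1 structure matrix
# from edge lists. The latent names L1..L6 are placeholders.
VISIBLE = ["X", "BP", "BMI", "Y"]

# Children of each latent variable, in the row order used above.
LATENT_CHILDREN = [
    ["X", "BP", "BMI"],  # L1
    ["BP", "Y"],         # L2
    ["X"],               # L3
    ["BP"],              # L4
    ["BMI"],             # L5
    ["Y"],               # L6
]

# Directed edges between visible variables, written as (parent, child).
VISIBLE_EDGES = [
    ("X", "BP"), ("X", "BMI"), ("X", "Y"),
    ("BP", "BMI"),
    ("BMI", "Y"),
]


def build_structure(visible, latent_children, visible_edges):
    """Return latent rows stacked on top of the visible-to-visible block."""
    index = {name: j for j, name in enumerate(visible)}

    latent_rows = []
    for children in latent_children:
        row = [0] * len(visible)
        for child in children:
            row[index[child]] = 1
        latent_rows.append(row)

    visible_rows = [[0] * len(visible) for _ in visible]
    for parent, child in visible_edges:
        i, j = index[parent], index[child]
        if i >= j:
            # An edge must point from an earlier to a later variable,
            # otherwise the bottom block would not be upper-triangular.
            raise ValueError(f"edge {parent}->{child} breaks the variable order")
        visible_rows[i][j] = 1

    return latent_rows + visible_rows


structure = build_structure(VISIBLE, LATENT_CHILDREN, VISIBLE_EDGES)
for row in structure:
    print(row)
```

The resulting nested list can be passed to `torch.tensor(structure, dtype=torch.bool, device=device)` to obtain the same tensor as `struct`.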

To obtain the best-fitting parameters, we pass this structure to an `SN2Solver`.


In [3]:
pmDAG = sn2.SN2Solver(struct)

We will obtain the best-fitting parameters (that is, the optimal weights between the variables)
with respect to the sample covariance matrix: the observed covariance matrix induced by the causal system.

In [4]:
sample_covariance = torch.tensor([
        [2,  3,  6,  8],
        [3,  7, 12, 16],
        [6, 12, 23, 30],
        [8, 16, 30, 41],
    ], device=device)

pmDAG.sample_covariance = sample_covariance
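
A covariance matrix must be symmetric, and a non-degenerate one is positive definite. Whether `SN2Solver` validates its input is not stated here, so a quick check before fitting can be useful. The sketch below is plain Python and independent of `sn2_cuda`; the function name `is_symmetric_positive_definite` is a hypothetical helper. It attempts a Cholesky factorization, which succeeds exactly when the matrix is symmetric positive definite.

```python
import math


def is_symmetric_positive_definite(matrix, tol=1e-12):
    """Check symmetry, then attempt a Cholesky factorization."""
    n = len(matrix)
    if any(len(row) != n for row in matrix):
        return False  # not square

    for i in range(n):
        for j in range(i):
            if abs(matrix[i][j] - matrix[j][i]) > tol:
                return False  # not symmetric

    lower = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            partial = sum(lower[i][k] * lower[j][k] for k in range(j))
            if i == j:
                pivot = matrix[i][i] - partial
                if pivot <= tol:
                    return False  # a non-positive pivot rules out positive definiteness
                lower[i][i] = math.sqrt(pivot)
            else:
                lower[i][j] = (matrix[i][j] - partial) / lower[j][j]
    return True


sample = [
    [2, 3, 6, 8],
    [3, 7, 12, 16],
    [6, 12, 23, 30],
    [8, 16, 30, 41],
]
print(is_symmetric_positive_definite(sample))
```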

The newly created `SN2Solver` object is fitted by gradient descent.
However, the solver itself only computes the partial derivatives of the objective function with respect to the weights; it does not update them.
We therefore need a PyTorch optimizer of our choice to update the weights at each step.

In [5]:
optim = torch.optim.Adamax([pmDAG.weights], lr=0.001)

That's it! All that remains is to train the `SN2Solver`, which works much
like training a neural network: we take a `forward()` and a `backward()` step and then call the optimizer's `step()` method
to update the weights. First, we set the stopping conditions:

In [6]:
max_iterations = 10000
min_error = 1.0e-7

The following loop trains the pmDAG and prints the loss at regular intervals (every tenth of `max_iterations`).


In [7]:
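for i in range(max_iterations):
    # Forward pass, then read the current loss as a Python float.
    pmDAG.forward()
    error = pmDAG.loss().item()

    # Stop as soon as the loss falls below the threshold.
    if error < min_error:
        break

    # The solver computes the gradients; the optimizer applies the update.
    pmDAG.backward()
    optim.step()

    # Report progress ten times over the course of the run.
    if i % (max_iterations // 10) == 0:
        print(f"iteration={i:<10} loss={error:<15.5}")
else:
    # Reached only when the loop ends without hitting `break`.
    print("Did not converge in the maximum number of iterations!")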

iteration=0          loss=23.049         
iteration=1000       loss=3.3946         
iteration=2000       loss=0.7656         
iteration=3000       loss=0.36466        
iteration=4000       loss=0.02155        
iteration=5000       loss=5.6028e-06     


Now that the pmDAG has been parametrized using the `SN2Solver`, we can print out the induced visible covariance matrix...

In [9]:
print(pmDAG.visible_covariance_)

tensor([[ 1.9988,  2.9999,  5.9986,  7.9981],
        [ 2.9999,  7.0011, 12.0013, 16.0017],
        [ 5.9986, 12.0013, 23.0003, 30.0002],
        [ 7.9981, 16.0017, 30.0002, 41.0001]], device='cuda:0')


... and the weights matrix of the pmDAG.

In [10]:
print(pmDAG.weights)

tensor([[-6.7116e-01,  1.4078e+00,  2.1650e+00, -0.0000e+00],
        [-0.0000e+00,  9.4347e-01, -0.0000e+00,  5.0055e-01],
        [-1.2443e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00],
        [-0.0000e+00, -2.7091e-01, -0.0000e+00, -0.0000e+00],
        [-0.0000e+00, -0.0000e+00,  2.4145e-26,  0.0000e+00],
        [-0.0000e+00, -0.0000e+00,  0.0000e+00,  1.2459e+00],
        [-0.0000e+00,  1.9735e+00,  3.3452e+00,  4.7257e-01],
        [ 0.0000e+00,  0.0000e+00,  2.5508e-01,  0.0000e+00],
        [-0.0000e+00,  0.0000e+00, -0.0000e+00,  1.1759e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]], device='cuda:0')
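
The printed weights can be checked by hand against the visible covariance. The sketch below assumes the standard linear Gaussian parametrization: independent latent variables with unit variance, the top six rows acting as latent-to-visible loadings Λ, and the bottom four rows acting as direct effects B between visible variables, so that Σ = (I − B)⁻ᵀ ΛᵀΛ (I − B)⁻¹. This parametrization is an assumption inferred from the printed numbers, not something documented here, and the helper names are hypothetical. The code is plain Python and uses the weights rounded as printed above.

```python
# Weights copied from the printed output (rounded; signed zeros written as 0.0).
WEIGHTS = [
    [-0.67116, 1.4078, 2.1650, 0.0],
    [0.0, 0.94347, 0.0, 0.50055],
    [-1.2443, 0.0, 0.0, 0.0],
    [0.0, -0.27091, 0.0, 0.0],
    [0.0, 0.0, 2.4145e-26, 0.0],
    [0.0, 0.0, 0.0, 1.2459],
    [0.0, 1.9735, 3.3452, 0.47257],  # V_X
    [0.0, 0.0, 0.25508, 0.0],        # V_BP
    [0.0, 0.0, 0.0, 1.1759],         # V_BMI
    [0.0, 0.0, 0.0, 0.0],            # V_Y
]
N_LATENT = 6


def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def transpose(a):
    return [list(col) for col in zip(*a)]


def visible_covariance(weights, n_latent):
    """Assumed model: Sigma = T^T (L^T L) T with T = (I - B)^-1."""
    loadings = weights[:n_latent]  # latent -> visible
    direct = weights[n_latent:]    # visible -> visible, zero on and below the diagonal
    n = len(direct)

    noise_cov = matmul(transpose(loadings), loadings)

    # B is nilpotent, so (I - B)^-1 = I + B + B^2 + ... + B^(n-1).
    identity = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    total, power = identity, identity
    for _ in range(n - 1):
        power = matmul(power, direct)
        total = [[total[i][j] + power[i][j] for j in range(n)] for i in range(n)]

    return matmul(matmul(transpose(total), noise_cov), total)


for row in visible_covariance(WEIGHTS, N_LATENT):
    print([round(value, 3) for value in row])
```

Under these assumptions the result should agree with `pmDAG.visible_covariance_`, and hence with the sample covariance, up to the rounding of the printed weights.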
