In [18]:
import torch
import torch.nn.functional as F
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
from pyro.infer import Predictive, SVI, Trace_ELBO
from pyro.optim import Adam

from gpytorch.kernels import RBFKernel, ScaleKernel
import gpytorch.distributions as gdist

import pdb
from pmextract import extract

In [19]:
G = 5 # max number of functions: model() draws G - 1 stick-breaking Betas, i.e. G mixture components
A = 7 # max number of arguments per function; model() samples A argument slots for every variable ("nth_arg" plate)
func_alpha = 1.0 # concentration parameter for function DP, used as Beta(1, func_alpha) in model()
markov_alpha = 1.0 # concentration parameter for how old your args should be, used as Beta(1, markov_alpha) in model()

In [20]:
# Constant placeholder observations; model() unpacks data.shape as (T, M) = (timesteps, variables).
data = torch.ones(3, 9)

In [21]:
class GPKernel(torch.nn.Module):
    """Torch module wrapping the GP covariance function: a ScaleKernel around an RBFKernel."""

    def __init__(self):
        super().__init__()
        # The attribute name `kernel` is kept as-is: parameter names derive from it.
        rbf = RBFKernel()
        self.kernel = ScaleKernel(rbf)

    def forward(self, input_locs):
        """Apply the wrapped kernel to ``input_locs`` and return its result unchanged."""
        covariance = self.kernel(input_locs)
        return covariance

In [22]:
# Module-level kernel instance; model() registers it via pyro.module("kernel", kernel) and uses it for every GP.
kernel = GPKernel()

In [23]:
def mix_weights(beta):
    """Convert stick-breaking fractions into categorical mixture weights.

    ``beta`` holds K break fractions along its last dimension (iid Beta draws
    in a stick-breaking Dirichlet process). The result has K + 1 entries along
    that dimension: entry k is ``beta[k]`` times the product of ``1 - beta[j]``
    over all ``j < k``, and the final entry is whatever stick is left over.
    """
    one_minus = 1 - beta
    leftover = one_minus.cumprod(dim=-1)
    # Append a 1 so the last component claims all of the remaining stick.
    fractions = F.pad(beta, (0, 1), value=1)
    # Prepend a 1 because the whole stick is available before the first break.
    stick_before = F.pad(leftover, (1, 0), value=1)
    return fractions * stick_before

In [62]:
def partition(num_funcs, func_ids, outputs, inputs):
    """Group the rows of ``inputs`` and ``outputs`` by function id.

    Returns a list of length ``num_funcs``. Element ``i`` is the pair
    ``(inputs[func_ids == i], outputs[func_ids == i])`` -- inputs first, even
    though ``outputs`` precedes ``inputs`` in the parameter list. Ids with no
    members yield empty selections.
    """
    groups = []
    for func_id in range(num_funcs):
        members = func_ids == func_id
        groups.append((inputs[members], outputs[members]))
    return groups

In [72]:
def model(data):
    """Pyro model: each variable is a GP function of other variables in ``data``.

    Structure, as sampled below:

    * Each of the ``N = T * M`` flattened variables picks a function id from a
      stick-breaking categorical with ``G`` components (site "from_function").
    * Each variable gets ``A`` arguments. An argument is an index pair
      (``arg_times``, ``arg_vars``) into ``data``; the time index of a variable
      at timestep ``t`` only has non-zero weight on timesteps ``0..t``.
    * For every function id that has at least one variable assigned, those
      variables are observed jointly under a multivariate normal with constant
      mean ``mu[i]`` and covariance ``kernel`` evaluated on their argument values
      (sites "data_{i}").

    Args:
        data: 2-D tensor whose shape is unpacked as (T, M) = (timesteps, variables).
    """
    T, M = data.shape # timesteps by variables
    N = T * M # total number of variables
    pyro.module("kernel", kernel)
    mu = pyro.param("mu", torch.randn(G))

    # Stick-breaking prior over which function produced each variable.
    with pyro.plate("func_betas", G - 1):
        func_beta = pyro.sample("func_beta", dist.Beta(1, func_alpha))
    with pyro.plate("function", N):
        from_function = pyro.sample("from_function", dist.Categorical(mix_weights(func_beta)))
    # Use a plain Python int: this value sizes a plate and bounds a loop below,
    # and a 0-dim tensor in those roles is fragile.
    num_funcs = int(from_function.max()) + 1

    with pyro.plate("arg_betas", T - 1):
        arg_beta = pyro.sample("arg_beta", dist.Beta(1, markov_alpha))
    with pyro.plate("func_arg_dists", num_funcs):
        func_arg_params = pyro.sample("func_arg_weights", dist.Dirichlet(torch.ones(M)))

    # Row i: stick-breaking weights over timesteps 0..i, zero-padded on the right to length T.
    time_weights = torch.stack([F.pad(mix_weights(arg_beta[:i]), (0, T - i - 1)) for i in range(T)])
    with pyro.plate("arguments", N) as varindex:
        # Timestep of each flattened variable (row-major flattening of data).
        t = varindex.div(M, rounding_mode='trunc')
        with pyro.plate("nth_arg", A):
            arg_times = pyro.sample("arg_times", dist.Categorical(time_weights[t]))
            arg_vars = pyro.sample("arg_vars", dist.Categorical(func_arg_params[from_function]))

    # data[arg_times, arg_vars] is (A, N); transpose so each row holds one variable's A argument values.
    pairs = partition(num_funcs, from_function, data.view(-1), data[arg_times, arg_vars].T)
    for i, (gp_in, gp_out) in enumerate(pairs):
        if gp_out.shape[0] == 0:
            # Ids below the maximum may be unused; an empty group has nothing to
            # observe and would otherwise build a 0x0 covariance matrix.
            continue
        k_xx = kernel(gp_in)
        pyro.sample(f"data_{i}", gdist.MultivariateNormal(mu[i:i+1].expand(k_xx.shape[0]), k_xx), obs=gp_out)

In [73]:
# Run the model once outside any inference algorithm, to check that it executes end to end.
model(data)



TODO: pick up from here. Debug this!

Shouldn't there be 18 variables? Why are there 23?

In [48]:
# Inspect the kernel matrix shape.
# NOTE(review): k_xx is only assigned inside model() in the visible code, so this cell
# presumably ran in a debugger session or against earlier notebook state -- confirm.
k_xx.shape

torch.Size([18, 18])

In [49]:
# Inspect the shape of one function's observed outputs.
# NOTE(review): gp_out is only assigned inside model() in the visible code, so this cell
# presumably ran in a debugger session or against earlier notebook state -- confirm.
gp_out.shape

torch.Size([18])

In [31]:
# Leftover debugging aid: post-mortem debugger on the most recent uncaught exception.
pdb.pm()

> /tmp/ipykernel_641901/2010763094.py(24)model()
     20             arg_vars = pyro.sample("arg_vars", dist.Categorical(func_arg_params[from_function]))
     21         gp_ins, gp_outs = partition(num_funcs, from_function, data.view(-1), data[arg_times, arg_vars].T)
     22         for gp_in, gp_out in zip(gp_ins, gp_outs)

In [27]:
# Leftover debugging aid: post-mortem debugger on the most recent uncaught exception.
pdb.pm()

> /tmp/ipykernel_641901/947917112.py(4)partition()
      2     # gp_out is a list (size n-funcs) of vectors (size 'examples')
      3     # gp_in is a list (size n-funcs) of (examples x n-args) matrices
----> 4     gp_out = [inputs[func_ids == i] for i in range(func_ids)]
      5     gp_in = [outputs[func_ids == i] for i in range(func_ids)]
      6     return gp_in, gp_out

In [102]:
# Shape of the gathered argument values.
# NOTE(review): arg_times / arg_vars are only assigned inside model() in the visible code --
# confirm where this cell was run.
data[arg_times, arg_vars].shape # (argument slot, "nth_arg" plate) x (flattened variable, "arguments" plate)

torch.Size([7, 27])

## GP Stuff

Start with vanilla GPs. Eventually add fancier machinery such as Kronecker structure, SKIP, or variational versions.

Do we do this variationally, or exactly? We could do it exactly. Finding the log determinant of the covariance matrix will be the part that could take a long time. We could try both. 

To do it exactly, just build the covariance matrix, pass it to gpytorch's `MultivariateNormal`, and you're good to go.

## Other Thoughts

Idea: Bayesian program learning in general gets you bounds on the causal effect size at the same time as it produces plausible programs. We know the true program is somewhere in the posterior, so if we sample effects from the posterior at the same time as we sample parameters, we get a bound.

Specifically, if we want a bound on the causal effect size, we can check the differences between observations of $B_t$ in the original model and observations of $B_t$ under intervention (written $B_t'$). That is, to check the causal effect in a given model, draw $N$ paired samples from it and compute $\frac{1}{N}\sum_{i=1}^N \left(B_t^{(i)} - B_t'^{(i)}\right)$.

What are the interventions here? With binary variables, that's easy: just set or unset the value of a variable.

Alternately: there's also the perspective that causal effect MEANS that there's a directed path between the variables. We could just record, for each model sampled, whether such a path exists. This will tell you the posterior probability that A causes B. 