# VI without plates

i.e. no repeating bits to abstract over

Optimise the params of an approx posterior over extended Z-space, but not K space

$$Q (Z|X) = \prod_k  Q(Z^k|X) = \prod_k \prod_i Q(Z^k_i \mid Z^k_{qa(i)})$$

and

$$\prod_j f_j^{\kappa_j} = \frac{P(X, Z)}{\prod_i Q(Z_i^{K_i})}$$

Writing out the target (a lower bound on the log marginal likelihood) in full makes the computation clear:

$$ \mathcal{L}= E_{Q(Z|X)} \left[ \log \frac{\sum_K  P(Z,K,X)}{Q (Z|X)} \right]$$
$$= E_{Q} \left[ \log \sum_K \frac{P(Z,K,X)}{Q (Z|X)} \right]$$

with

$$
  \frac{P({Z, K, X})}{Q({Z|X})} = 
  P({K}) 
  P \left({X| Z_{\mathrm{pa}(X)}^{K_{\mathrm{pa}(X)}}} \right)  
  \prod_i 
  \frac{P\left({Z_i^{K_i}| Z^{K_{\mathrm{pa}(i)}}_{\mathrm{pa}(i)}} \right)}
  {Q \left(  
    Z_i^{K_i}| Z^{K_i}_{\mathrm{qa}(i)}
  \right)}$$

## The computation

1. Form joint P/Q: index prior * lik * product of latent P/Qs
2. $\mathcal{L}$: sum out K, then log P - log Q, then average

In [1]:
import math
import numpy as np
import torch as t
import torch.nn as nn
from torch.distributions import Normal, Categorical
from torch.distributions import MultivariateNormal as MVN

import sys; sys.path.append("..")
from tpp_trace import *
import utils as u
import tensor_ops as tpp
from tvi import *




## First with no plates

In [2]:
# Toy problem for VI without plates: three Gaussian variables sampled one
# after the other along a chain (see chain_dist below).
#
#   a ~ Normal(TRUE_MEAN_A * ones(n), scale)
#   b ~ Normal(a, scale)
#   c ~ Normal(b, scale)        (c is the observed variable)

# At module level `n` only seeds `scale`; chain_dist and ChainQ take their own
# `n` argument (default 3).
n = 3
# Second argument of every Normal drawn in chain_dist.
# NOTE(review): if these map onto torch.distributions.Normal, this is a
# standard deviation, not a variance - confirm the intent.
scale = n
# Shorthand for a Normal wrapped in the project's WrappedDist.
# NOTE(review): the second argument is named `var` but is presumably forwarded
# as Normal's `scale` (std dev) - confirm against WrappedDist.
Norm = lambda mu, var : WrappedDist(Normal, mu, var)

# Mean of `a` in the generative model.
TRUE_MEAN_A = 10

# Prior
# a -> b -> c observed
def chain_dist(trace, n=3):
    """Generative model P: the Gaussian chain a -> b -> c.

    Every variable is drawn through ``trace[name].Normal(...)`` with the
    module-level ``scale`` as second argument. ``a`` is centred on
    ``TRUE_MEAN_A`` in each of its ``n`` components, ``b`` on ``a`` and
    ``c`` on ``b``.

    Returns:
        Whatever the trace yields for ``c``, the last link of the chain.
    """
    prior_mean = t.ones(n) * TRUE_MEAN_A
    a = trace["a"].Normal(prior_mean, scale)
    b = trace["b"].Normal(a, scale)
    return trace["c"].Normal(b, scale)


# a param placeholder
# Hardcoding 2 params for each var, for now
# factorised Gaussian with learned means and covs
class ChainQ(nn.Module):
    """Factorised Gaussian approximate posterior over the latents ``a`` and ``b``.

    Each latent owns a learnable mean vector and a learnable log-scale vector
    of length ``n``. All four start as ones, so the initial scale is ``e``.
    """

    def __init__(self, n=3):
        super().__init__()
        # Registered in this order: both means first, then both log-scales.
        for attr in ("mean_a", "mean_b", "logscale_a", "logscale_b"):
            setattr(self, attr, nn.Parameter(t.ones(n)))

    # TODO: make this actually depend on the params
    def sample(self, trace) :
        """Draw ``a`` and ``b`` through ``trace``; nothing is returned.

        ``b`` is drawn independently of ``a`` (fully factorised Q).
        """
        scale_a = t.exp(self.logscale_a)
        trace["a"].Normal(self.mean_a, scale_a)
        scale_b = t.exp(self.logscale_b)
        trace["b"].Normal(self.mean_b, scale_b)


In [3]:
class TVI(nn.Module) :
    """One-step objective for tensorised VI on a model without plates.

    Args:
        p: callable generative model; invoked as ``p(trace)``.
        q: recognition model exposing ``sample(trace)``.
        k: first argument handed to ``sampler``.
        x: observed data, stored as a non-trainable parameter.
        nProtectedDims: second argument handed to ``sampler`` / ``evaluator``.
    """

    def __init__(self, p, q, k, x, nProtectedDims):
        super().__init__()
        self.p, self.q, self.k = p, q, k
        self.nProtected = nProtectedDims

        self.data_dict = {"__c": []}
        # Held as a Parameter with gradients switched off.
        self.data = nn.Parameter(x, requires_grad=False)

    def forward(self):
        """Evaluate the objective once.

        Outline (from the original author's notes):
            1. s = sample Q
            2. lp_Q = eval Q.logprob(s)
            3. lp_P = eval P.logprob(s)
            4. f = lp_P - lp_Q
            5. loss = combine fs
        """
        # NOTE(review): gains one more reference to self.data on every call
        # and is not read anywhere in this class.
        self.data_dict["__c"].append(self.data)

        # Fresh traces on every step; the Q trace is created without data.
        q_trace = sampler(self.k, self.nProtected)
        self.q.sample(q_trace)

        # Hand the Q samples to a second trace and run P on it.
        p_trace = evaluator(q_trace, self.nProtected, data={"__c": self.data})
        self.p(p_trace)

        # P trace first, then Q trace, as in the original ordering.
        for tr in (p_trace, q_trace):
            sum_out_pos(tr)
        # Align the dims of the Q trace with those of the P trace.
        q_trace.trace.out_dicts = rename_placeholders(p_trace, q_trace)

        # P log-probs minus Q log-probs (latents only), then reduce.
        log_ratios = subtract_latent_log_probs(p_trace, q_trace)
        return tpp.undict(tpp.combine_tensors(log_ratios))


def setup_and_run(tvi, ep=2000, eta=1) :
    """Fit ``tvi`` with Adam for ``ep`` steps and hand the same object back.

    Only the parameters of ``tvi.q`` are given to the optimiser; ``eta`` is
    the learning rate.
    """
    q_params = tvi.q.parameters()
    adam = t.optim.Adam(q_params, lr=eta)
    optimise(tvi, adam, ep)
    return tvi


def optimise(tvi, optimiser, eps) :
    """Take ``eps`` optimiser steps that push the value of ``tvi()`` upwards.

    Each step clears the gradients, backpropagates through the negated
    objective and lets ``optimiser`` update its parameters. Returns None.
    """
    for _step in range(eps):
        optimiser.zero_grad()
        negated_objective = -tvi()
        negated_objective.backward()
        optimiser.step()
        

def sample_generator(nProtected, P, dataName="__c") :
    """Run the generative model ``P`` once and return one of its samples.

    A trace is built with ``sampler(1, nProtected)``, ``P`` is run on it, and
    the entry ``dataName`` of the trace's "sample" dict is returned after
    ``squeeze(0)``.
    """
    n_draws = 1
    gen_trace = sampler(n_draws, nProtected)
    P(gen_trace)
    drawn = gen_trace.trace.out_dicts["sample"][dataName]
    return drawn.squeeze(0)
        

def get_error_on_a(a_mean, n, tvi) :
    """Return ``a_mean`` (broadcast to length ``n``) minus ``tvi.q.mean_a``."""
    target = t.ones(n).mul(a_mean)
    learned = tvi.q.mean_a
    return target - learned


# Recovering mean of first var
def main(nvars=3, nProtected=2, k=2, epochs=2000, true_mean=10, lr=0.2) :
    """Fit ChainQ to data drawn from chain_dist and report the error on ``a``.

    Args:
        nvars: length of every variable; forwarded to both Q and P so that
            the two models and the error computation agree.
        nProtected: passed through to the sampler / evaluator traces.
        k: passed through to TVI.
        epochs: number of optimiser steps.
        true_mean: reference mean of ``a`` used for the error only; the
            generative model itself reads the module-level TRUE_MEAN_A.
        lr: Adam learning rate.

    Returns:
        ``(error, tvi)`` where ``error = true_mean - tvi.q.mean_a`` and
        ``tvi`` is the fitted TVI module.
    """
    Q = ChainQ(nvars)

    # Bind the requested size; chain_dist would otherwise use its default n=3.
    def P(trace):
        return chain_dist(trace, n=nvars)

    # Get _c data by sampling generator
    x = sample_generator(nProtected, P, dataName="__c")
    tvi = setup_and_run(TVI(P, Q, k, x, nProtected), epochs, eta=lr)

    return get_error_on_a(true_mean, nvars, tvi), tvi

In [20]:
# One-pass example: sample chain_dist once, then run chain_dist again on a
# second trace that reuses those samples, and inspect the resulting dicts.
# k, p and q are set up here but not used by the live code of this cell.
k = 2
nProtected = 2
p = chain_dist
q = ChainQ()

# First pass: sampling trace built with K=4.
tr1 = trace({}, SampleLogProbK(K=4, protected_dims=nProtected))
val = chain_dist(tr1)
# Second pass: a new trace seeded with tr1's "sample" dict and its dim names.
tr2 = trace({"data": {}, "sample": tr1.trace.out_dicts["sample"]}, LogProbK(tr1.trace.fn.dim_names))
val = chain_dist(tr2)

# Bare expression: only the last expression of a notebook cell is displayed,
# so this line shows nothing.
tr2.trace.out_dicts
# sample_trace = trace({}, SampleLogProbK(K=4, protected_dims=2))
# _ = chain_dist(sample_trace)
# eval_trace = trace({"data": {}, "sample": sample_trace.trace.out_dicts["sample"]}, \
#                    LogProbK(sample_trace.trace.fn.dim_names))
# _ = chain_dist(eval_trace)
# p(eval_trace)


# x = sample_generator(nProtected, p, dataName="__c")
# sample_trace = sampler(k, nProtected)
# q.sample(sample_trace)
# eval_trace = evaluator(sample_trace, nProtected, data={"__c": x})
# compute P logprobs 


#sum_out_pos(tr2)
# Log-probs recorded on the second trace; not used further in this cell.
lps = tr2.trace.out_dicts["log_prob"]
 
#tpp.combine_tensors(lps)

# Displayed as the cell output: the "sample" and "log_prob" dicts of tr2.
tr2.trace.out_dicts

{'sample': {'__a': tensor([[[ 6.4433,  9.0297, 10.5075]],
  
          [[12.4035, 10.4682,  4.3390]],
  
          [[10.4899,  7.4070,  8.6449]],
  
          [[ 9.8458,  9.2511,  7.2392]]], names=('_k__a', None, None)),
  '__b': tensor([[[[ 5.0056,  4.1001,  9.1492]]],
  
  
          [[[ 8.7417, 19.9388,  1.9296]]],
  
  
          [[[ 7.5112,  3.7101,  9.2609]]],
  
  
          [[[ 5.4997,  3.9223,  5.0610]]]], names=('_k__b', '_k__a', None, None)),
  '__c': tensor([[[[[ 5.1926,  3.4912, 10.6923]]]],
  
  
  
          [[[[ 7.3232, 21.7457,  0.8533]]]],
  
  
  
          [[[[ 5.4600,  5.1678, 11.6522]]]],
  
  
  
          [[[[ 4.7056,  3.8290,  7.8642]]]]],
         names=('_k__c', '_k__b', '_k__a', None, None))},
 'log_prob': {'__a': tensor([[[-2.3506, -2.1443, -2.2651]],
  
          [[-2.6814, -2.2596, -2.9556]],
  
          [[-2.2626, -2.2107, -2.1413]],
  
          [[-2.1895, -2.1514, -2.2295]]], names=('_k__a', None, None)),
  '__b': tensor([[[[ -1.7383,  -2.7746,  -2.19

In [None]:
# Short fit of the chain model, then report how far Q's mean of `a` is
# from TRUE_MEAN_A, as a percentage.
ep = 20  # 5000 for a full-length run
error, tvi = main(nvars=3, k=2, epochs=ep, true_mean=TRUE_MEAN_A, lr=0.1)

# Average over however many components `a` has, instead of a hard-coded 3.
mean_error = error.abs().mean()
error_percent = mean_error / TRUE_MEAN_A * 100
print(f"Mean error on the parameter's of A: {error_percent}%")

## VI, No plates but including deletes


In [None]:
# def chain_dist_del(trace):
#     a = trace["a"](Norm(t.ones(n)))
#     b = trace["b"](Norm(a))
#     c = trace["c"](Norm(b))
#     (c,) = trace.delete_names(("a", "b"), (c,))
#     d = trace["d"](Norm(c))
    
#     return c


In [None]:
# # call sampler on Q. 
# # gives you the samples and a log Q tensor `log_prob`
# tr1 = sampler(draws, nProtected, data=data)

# val = P(tr1)
# log_q = tr1.trace.out_dicts["log_prob"]

# # compute the log_probs

# # pass these to evaluator, which does a lookup for all the latents 
# # gives you log P for each latent
# tr2 = evaluator(tr1, nProtected, data=data)
# val = P(tr2)

# #tr2.trace.out_dicts["log_prob"]
# #log_p = 

# #Q = pytorch.module
# #    - `q.forward()` will look like chain_dist
    

# #- optimise it


## plate VI

- For plates, we just don't filter [@17](https://github.com/LaurenceA/tpp/blob/bd1fe20dcf86a1c02cc0424632571fba998d104f/utils.py#L17)
- Painful stuff: need to keep the generative order (e.g. a, b, c, d)
    - because we start by summing the lowest-level plates
        - solution: enforce that the last variable is a leaf e.g. `d`
- Careful when combining P & Q tensors: maintain the ordering!

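- Plates: the summation runs backwards through the plates (lowest-level plate first)
    - This makes the implementation sensitive to the order in which variables were added
    - Python dicts preserve insertion order (guaranteed from 3.7; a CPython implementation detail in 3.6), so rely on that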


In [None]:
# Shared constants for the plate example. plate_dist (below) draws
#
#   a ~ Normal(TRUE_MEAN_A * ones(n), 1)   in plate "A" (plate_shape=3)
#   b ~ Normal(a, 1)                       in plate "B" (plate_shape=4)
#   c ~ Normal(b, 1)
#
# while ChainQ keeps a factorised Gaussian over a and b.

# At module level `n` only seeds `scale`; plate_dist and ChainQ take their own
# `n` argument (default 3).
n = 3
# Not used by plate_dist, which passes a literal 1 instead.
scale = n
# Shorthand for a Normal wrapped in the project's WrappedDist.
# NOTE(review): the second argument is named `var` but is presumably forwarded
# as Normal's `scale` (std dev) - confirm against WrappedDist.
Norm = lambda mu, var : WrappedDist(Normal, mu, var)

# Mean of `a` in the generative model.
TRUE_MEAN_A = 10


# a param placeholder
# Hardcoding 2 params for each var, for now
# factorised Gaussian with learned means and covs
class ChainQ(nn.Module):
    """Factorised Gaussian approximate posterior over the latents ``a`` and ``b``.

    Each latent owns a learnable mean vector and a learnable log-scale vector
    of length ``n``. All four start as ones, so the initial scale is ``e``.
    """

    def __init__(self, n=3):
        super().__init__()
        # Registered in this order: both means first, then both log-scales.
        for attr in ("mean_a", "mean_b", "logscale_a", "logscale_b"):
            setattr(self, attr, nn.Parameter(t.ones(n)))

    # TODO: make this actually depend on the params
    def sample(self, trace) :
        """Draw ``a`` and ``b`` through ``trace``; nothing is returned.

        ``b`` is drawn independently of ``a`` (fully factorised Q).
        """
        dist_a = Norm(self.mean_a, t.exp(self.logscale_a))
        trace["a"](dist_a)
        dist_b = Norm(self.mean_b, t.exp(self.logscale_b))
        trace["b"](dist_b)

        
# example directed graph with plate repeats
# 3(a) -> 4(b) -> c -> d
def plate_dist(trace, n=3):
    """Generative model with plate repeats: a (plate "A") -> b (plate "B") -> c.

    ``a`` is centred on ``TRUE_MEAN_A`` in each of its ``n`` components and
    every Normal is built with 1 as its second argument. Plate "A" is
    requested with shape 3 and plate "B" with shape 4.

    Returns:
        Whatever the trace yields for ``c``.
    """
    dist_a = Norm(t.ones(n) * TRUE_MEAN_A, 1)
    a = trace["a"](dist_a, plate_name="A", plate_shape=3)

    dist_b = Norm(a, 1)
    b = trace["b"](dist_b, plate_name="B", plate_shape=4)

    dist_c = Norm(b, 1)
    c = trace["c"](dist_c)

    # Still disabled: deleting the names "a"/"b" and drawing a further
    # variable "d" from Norm(c, 1).
    return c



In [None]:
# One-pass example on the plate model: draw a dataset, build a trace with
# sample_and_eval, then combine its log-probs over the plates.
k = 2
nProtected = 2
p = plate_dist
q = ChainQ()

# Synthetic observations for "__c"; not used again in this cell.
x = sample_generator(nProtected, p, dataName="__c")
tr = sample_and_eval(plate_dist, draws=2, nProtected=2)

# sample_trace = sampler(k, nProtected, data={"__c": x})
# q.sample(sample_trace)
# eval_trace = evaluator(sample_trace, nProtected, data={"__c": x})
# # compute P logprobs 
# p(eval_trace)

sum_out_pos(tr)
lps = tr.trace.out_dicts["log_prob"]

# Last expression of the cell, so its value is what the notebook displays.
tpp.combine_over_plates(lps)

In [None]:

# Manual walk-through of one step on the plate model, using the k, nProtected,
# p, q and x defined in the one-pass cell.
sample_trace = sampler(k, nProtected, data={"__c": x})
# sample recognition model Q -> Q-sample and Q-logprobs
q.sample(sample_trace)

eval_trace = evaluator(sample_trace, nProtected, data={"__c": x})
# compute P logprobs 
p(eval_trace)

sum_out_pos(sample_trace)
sum_out_pos(eval_trace)
# Read the log-probs from the trace built in this cell, not from `tr`, which
# belongs to the previous cell and would make the lines above dead code.
# NOTE(review): eval_trace (the P log-probs) is assumed to be the intended
# source - confirm.
lps = eval_trace.trace.out_dicts["log_prob"]

tpp.combine_over_plates(lps)

In [None]:
# Same as above but with new combine func
class TVI(nn.Module) :
    """One-step objective for tensorised VI, reduced with the plate-aware combine.

    Args:
        p: callable generative model; invoked as ``p(trace)``.
        q: recognition model exposing ``sample(trace)``.
        k: first argument handed to ``sampler``.
        x: observed data, stored as a non-trainable parameter.
        nProtectedDims: second argument handed to ``sampler`` / ``evaluator``.
    """

    def __init__(self, p, q, k, x, nProtectedDims):
        super().__init__()
        self.p = p
        self.q = q
        self.k = k
        self.nProtected = nProtectedDims

        self.data_dict = {}
        self.data_dict["__c"] = []
        # Held as a Parameter with gradients switched off.
        self.data = nn.Parameter(x, requires_grad=False)

    def forward(self):
        """Evaluate the objective once.

        Outline (from the original author's notes):
            1. s = sample Q
            2. lp_Q = eval Q.logprob(s)
            3. lp_P = eval P.logprob(s)
            4. f = lp_P - lp_Q
            5. loss = combine fs
        """
        # NOTE(review): gains one more reference to self.data on every call
        # and is not read anywhere in this class.
        self.data_dict["__c"].append(self.data)

        # init traces at each step; here the Q trace also receives the data
        sample_trace = sampler(self.k, self.nProtected, data={"__c": self.data})
        # sample recognition model Q -> Q-sample and Q-logprobs
        self.q.sample(sample_trace)

        # Pass Q samples to new trace
        eval_trace = evaluator(sample_trace, self.nProtected, data={"__c": self.data})
        # compute P logprobs
        self.p(eval_trace)

        sum_out_pos(eval_trace)
        sum_out_pos(sample_trace)
        # align dims in Q
        sample_trace.trace.out_dicts = rename_placeholders(eval_trace, sample_trace)

        # to ratio land: P.log_probs - Q.log_probs (just the latents)
        tensors = subtract_latent_log_probs(eval_trace, sample_trace)

        # Reduce gives loss. Go through the `tpp` module like every other
        # call to combine_over_plates and like `tpp.undict` just below.
        loss_dict = tpp.combine_over_plates(tensors)

        return tpp.undict(loss_dict)


# Recovering mean of first var
def main(nvars=3, nProtected=2, k=2, epochs=2000, true_mean=10, lr=0.2) :
    """Fit ChainQ to data drawn from plate_dist and report the error on ``a``.

    Args:
        nvars: length of every variable; forwarded to both Q and P so that
            the two models and the error computation agree.
        nProtected: passed through to the sampler / evaluator traces.
        k: passed through to TVI.
        epochs: number of optimiser steps.
        true_mean: reference mean of ``a`` used for the error only; the
            generative model itself reads the module-level TRUE_MEAN_A.
        lr: Adam learning rate.

    Returns:
        ``(error, tvi)`` where ``error = true_mean - tvi.q.mean_a`` and
        ``tvi`` is the fitted TVI module.
    """
    Q = ChainQ(nvars)

    # Bind the requested size; plate_dist would otherwise use its default n=3.
    def P(trace):
        return plate_dist(trace, n=nvars)

    # Get _c data by sampling generator
    x = sample_generator(nProtected, P, dataName="__c")
    tvi = setup_and_run(TVI(P, Q, k, x, nProtected), epochs, eta=lr)

    return get_error_on_a(true_mean, nvars, tvi), tvi

In [None]:
# Short fit of the plate model, then report how far Q's mean of `a` is
# from TRUE_MEAN_A, as a percentage.
ep = 20  # 5000 for a full-length run
error, tvi = main(nvars=3, k=2, epochs=ep, true_mean=TRUE_MEAN_A, lr=0.1)

# Average over however many components `a` has, instead of a hard-coded 3.
mean_error = error.abs().mean()
error_percent = mean_error / TRUE_MEAN_A * 100
print(f"Mean error on the parameter's of A: {error_percent}%")

# Alt frontend

- PLATES
    - LogProbK -> log_probs will have the right order but nothing else will
    - prior component K -> K_a
    - remove K from all -> replace with k_a, k_b
- simple_trace
    - convention: sample K first, 
    - plates go up-and-left as we go deeper
    - 
- get a few examples of models
    - "we only need enough variables for a neurips paper"
- hopefully able to handle the pos dims generically in the prior
- Also want some checking code: run through the Ps & log-probs to see if it's sane


In [None]:
def P(tr): 
    # Sketch of the generative model in the alternative frontend: variables
    # are declared by assigning a distribution to tr[name].
    # Chain: a -> b -> c (3 draws named 'plate_a') -> obs (5 draws named 'plate_b').
    # NOTE(review): torch.distributions.Normal accepts neither `sample_shape`
    # nor `sample_names`, so this presumably relies on a frontend-specific
    # Normal - confirm which one is in scope.
    tr.set_names('a', 'b')
    tr['a'] = Normal(tr.zeros(()), 1)
    tr['b'] = Normal(tr['a'], 1)
    # Swap the name 'a' for 'c' before 'c' is declared.
    tr.add_remove_names(add_names=('c',), remove_names=('a',))
    tr['c'] = Normal(tr['b'], 1, sample_shape=3, sample_names='plate_a')
    # Debug output: names and shape of what the trace holds for 'c'.
    print(tr['c'].names)
    print(tr['c'].shape)
    tr['obs'] = Normal(tr['c'], 1, sample_shape=5, sample_names='plate_b')


class Q(nn.Module):
    """Approximate posterior for the alternative frontend.

    Holds one learnable mean per latent, all initialised to zero: scalars
    for ``a`` and ``b``, and a length-3 tensor with the dimension name
    'plate_a' for ``c``.
    """

    def __init__(self):
        super().__init__()
        scalar_shape = ()
        self.m_a = nn.Parameter(t.zeros(scalar_shape))
        self.m_b = nn.Parameter(t.zeros(scalar_shape))
        plate_mean = t.zeros((3,), names=('plate_a',))
        self.m_c = nn.Parameter(plate_mean)

    def forward(self, tr):
        """Assign a unit-scale Normal around each padded mean to ``tr``."""
        latent_means = (('a', self.m_a), ('b', self.m_b), ('c', self.m_c))
        for name, mean in latent_means:
            tr[name] = Normal(tr.pad(mean), 1)
        

        
