# ETEL

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt

import scipy.optimize as spo
from scipy.special import logsumexp, loggamma

dtype = torch.float32
dtype_np = np.float32
torch.set_default_dtype(dtype)

def simulator(T, 洧랞, 洧랥=1.0, dist='gauss'):
    """Draw T i.i.d. samples for each of N means.

    Parameters
    ----------
    T : int
        Number of samples drawn per mean (columns of the result).
    洧랞 : scalar or array-like
        The N means; any array-like is accepted and flattened to 1-D.
    洧랥 : float
        Standard deviation shared by all rows.
    dist : {'gauss', 'gamma'}
        Sampling distribution. For 'gamma' the scale and shape are chosen as
        s = 洧랥**2 / mean and k = mean / s, so every mean must be strictly
        positive.

    Returns
    -------
    numpy.ndarray of shape (N, T).

    Raises
    ------
    AssertionError
        If `dist` is unknown, or if a gamma mean is not strictly positive.
    """
    assert dist in ['gauss', 'gamma']
    # Normalise once so that both branches accept lists, scalars and arrays.
    means = np.asarray(洧랞, dtype=float).reshape(-1)
    N = means.size
    if dist == 'gauss':
        return 洧랥 * np.random.normal(size=(N, T)) + means[:, None]

    # A zero mean would make the scale infinite and the shape zero.
    assert np.all(means > 0.0)
    s = 洧랥**2 / means
    k = means / s
    x = np.zeros((N, T))
    for i in range(N):
        x[i] = np.random.gamma(size=T, shape=k[i], scale=s[i])
    return x

class EE(torch.nn.Module):
    """Linear estimating equations g(x, theta) = A [x; theta] + b.

    D is the dimensionality of each data point x_i (i = 1, ..., T), M the
    dimensionality of the model parameter, and K the number of estimating
    equations (the output width of the single linear layer).
    """

    def __init__(self, D, M, K):
        super().__init__()
        self.D = D # dimensionality of each data point x_i, i=1,...,T
        self.M = M # dimensionality of model parameter
        self.K = K # number of estimating equations
        self.layer1 = torch.nn.Linear(D+M, K)

    @staticmethod
    def _stack_inputs(x, 洧랚):
        """Append the per-row parameter to every data point: N x T x (D+M)."""
        tiled = 洧랚.unsqueeze(-2).expand(-1, x.shape[-2], -1)
        return torch.cat((x, tiled), dim=-1)

    def forward(self, x, 洧랚):
        """Evaluate the estimating equations.

        x has shape N x T x D and the parameter has shape N x M; the result
        has shape N x T x K.
        """
        return self.layer1(self._stack_inputs(x, 洧랚))

    def jacobian_pars(self, X, 洧랚):
        """Return the derivative of the layer output w.r.t. (weights, bias).

        X has shape N x T x D and the parameter has shape N x M. The result
        has shape N x T x K x (D+M+1): for every output row the stacked input
        followed by a 1 for the bias. Note that the trailing axis has length
        D+M+1, which equals the number of layer parameters only when K == 1.
        """
        X_full = self._stack_inputs(X, 洧랚)
        # Tiling the input over the K outputs is a pure broadcast; allocating
        # the ones from X_full keeps dtype and device in sync with the data.
        dGdA = X_full.unsqueeze(-2).expand(*X.shape[:-1], self.K, X_full.shape[-1])
        dGdb = X_full.new_ones((*X.shape[:-1], self.K, 1))
        return torch.cat((dGdA, dGdb), dim=-1)

def 풙(洧랝, G):
    """Log-partition function: log-sum-exp over the last axis of G times the tilt."""
    scores = (G @ 洧랝.unsqueeze(-1)).squeeze(-1)
    return scores.logsumexp(dim=-1)

def 풙_np(洧랝, G):
    """NumPy counterpart of the log-partition function, used as a SciPy objective."""
    return logsumexp(G.dot(洧랝), axis=-1)

def grad풙(洧랝, G):
    """Gradient of the log-partition function: the weighted mean of G.

    The weights are the exponentiated output of log_w_opt for the same
    arguments; they multiply G along its second-to-last axis.
    """
    weights = log_w_opt(洧랝, G).exp()
    weighted_mean = torch.bmm(weights.unsqueeze(-2), G)
    return weighted_mean.squeeze(-2)

def solve_洧랝(G):
    """Minimise the log-partition function separately for every row of G.

    Parameters
    ----------
    G : torch.Tensor
        Estimating-equation values; the first axis indexes independent
        problems and the last axis has length K.

    Returns
    -------
    torch.Tensor of shape (N, K) holding the minimisers, created with the
    module-level `dtype` on the device of G. The result is detached from the
    autograd graph, since the optimisation happens in SciPy.
    """
    # Convert once, outside the objective: the optimiser evaluates it many
    # times per problem. detach()/cpu() also make this safe for tensors that
    # require grad or live on an accelerator.
    G_np = G.detach().cpu().numpy()
    N, K = G_np.shape[0], G_np.shape[-1]
    solution = np.zeros((N, K))
    for i in range(N):
        # Every problem starts from the zero vector.
        solution[i] = spo.minimize(풙_np, np.zeros(K), args=(G_np[i],))['x']
    return torch.tensor(solution, dtype=dtype, device=G.device)

def log_w_opt(洧랝, G):
    """Log of the normalised exponential-tilting weights along the last axis."""
    scores = torch.bmm(G, 洧랝.unsqueeze(-1)).squeeze(-1)
    log_norm = torch.logsumexp(scores, dim=-1, keepdim=True)
    return scores - log_norm

def log_pX洧랚(g, X, 洧랚):
    """Sum of the optimal log-weights for each (X, parameter) pair.

    Evaluates g, solves for the tilt with solve_洧랝, and adds up the resulting
    log-weights over the last axis.
    """
    moments = g(X, 洧랚)
    tilt = solve_洧랝(moments)
    return log_w_opt(tilt, moments).sum(axis=-1)

def comp_dFd洧랝(洧랝, G, GTdiagw):
    """Derivative of F (the gradient of the log-partition function) w.r.t. the tilt.

    Computed as GTdiagw @ G minus the outer product of F with itself, i.e. a
    weighted second moment minus the product of weighted means.
    """
    mean_G = grad풙(洧랝, G)  # F = gradient of the log-partition function
    second_moment = torch.bmm(GTdiagw, G)
    outer = torch.bmm(mean_G.unsqueeze(-1), mean_G.unsqueeze(-2))
    return second_moment - outer

def comp_dFd洧랯(洧랝, G, w, GTdiagw, dGd洧랯):
    """Derivative of the weighted mean F = sum_t w_t G_t w.r.t. the parameters of g.

    Two contributions are added: the weighted Jacobian of G pre-multiplied by
    (I - F tilt^T), and GTdiagw applied to the Jacobian contracted with the
    tilt. The shape comments assume G is N x T x K and the Jacobian is
    N x T x K x P.
    """
    n_eq = 洧랝.shape[-1]
    mean_G = torch.bmm(w.unsqueeze(-2), G).transpose(-1, -2)             # N x K x 1
    proj = torch.eye(n_eq).unsqueeze(0) - torch.bmm(mean_G, 洧랝.unsqueeze(-2))

    mean_jac = (dGd洧랯 * w.unsqueeze(-1).unsqueeze(-1)).sum(axis=1)        # N x K x P
    tilt_jac = (dGd洧랯 * 洧랝.unsqueeze(-2).unsqueeze(-1)).sum(axis=-2)       # N x T x P

    direct = torch.bmm(proj, mean_jac)
    through_weights = torch.bmm(GTdiagw, tilt_jac)
    return direct + through_weights

def grad_log_pX洧랚(g, X, 洧랚):
    """Gradient of the summed optimal log-weights w.r.t. the parameters of g.

    X is N x T x D. The tilt is obtained from solve_洧랝 and its sensitivity to
    the parameters of g from the implicit function theorem applied to
    F(tilt, parameters) = 0; the final gradient adds the path through the tilt
    and the direct path through G.
    """
    n_obs = X.shape[1]  # X.shape is N x T x D
    G = g(X, 洧랚)
    tilt = solve_洧랝(G)
    w = log_w_opt(tilt, G).exp()
    GTdiagw = (G * w.unsqueeze(-1)).transpose(-1, -2)
    jac_G = g.jacobian_pars(X, 洧랚)

    # Implicit function theorem: d tilt / d params = - inv(dF/d tilt) @ dF/d params
    dF_dtilt = comp_dFd洧랝(tilt, G, GTdiagw)
    dF_dpars = comp_dFd洧랯(tilt, G, w, GTdiagw, jac_G)
    dtilt_dpars = -torch.linalg.solve(dF_dtilt, dF_dpars)

    # d/d params of sum_t log w_t = sum_t (1 - T w_t) d(G_t . tilt)/d params
    resid = (1. - n_obs * w).unsqueeze(-1)
    via_tilt = torch.bmm((resid * G).sum(axis=-2).unsqueeze(-2), dtilt_dpars)
    via_G = torch.bmm(tilt.unsqueeze(-2), (resid.unsqueeze(-1) * jac_G).sum(axis=-3))

    return (via_tilt + via_G).squeeze(-2)

def loss_INFONCE(g, X, 洧랚):
    """InfoNCE-style objective over all N x N (parameter, dataset) pairings.

    Row i of the score matrix holds log_pX洧랚 for parameter i against every
    dataset; the result sums, over rows, the diagonal score minus the row's
    log-sum-exp.
    """
    n = X.shape[0]
    rows = torch.arange(n)
    param_idx = rows.repeat_interleave(n)  # 0,0,...,1,1,...
    data_idx = rows.repeat(n)              # 0,1,...,0,1,...
    scores = log_pX洧랚(g, X[data_idx], 洧랚[param_idx]).reshape(n, n)
    per_row = torch.diag(scores) - torch.logsumexp(scores, dim=-1)
    return per_row.sum(axis=0)

def grad_INFONCE(g, X, 洧랚):
    """Return (gradient, value) of the InfoNCE-style objective.

    The value matches loss_INFONCE. The gradient weights the per-pair
    gradients from grad_log_pX洧랚 by (identity - row-wise softmax of the
    scores) and sums over all pairs.
    """
    n = X.shape[0]
    assert 洧랚.shape[0] == n # could untie this, but not necessary for basic usage

    rows = torch.arange(n)
    param_idx = rows.repeat_interleave(n)
    data_idx = rows.repeat(n)
    X_pairs, par_pairs = X[data_idx], 洧랚[param_idx]

    # parameters are constant across rows, datasets constant across columns
    scores = log_pX洧랚(g, X_pairs, par_pairs).reshape(n, n)
    row_norm = torch.logsumexp(scores, dim=-1)
    per_row = torch.diag(scores) - row_norm

    softmax_rows = torch.exp(scores - row_norm.unsqueeze(-1))
    pair_weight = (torch.eye(n) - softmax_rows).unsqueeze(-1)
    pair_grads = grad_log_pX洧랚(g, X_pairs, par_pairs).reshape(n, n, -1)

    row_grads = (pair_weight * pair_grads).sum(axis=-2)
    return row_grads.sum(axis=0), per_row.sum(axis=0)


In [None]:
D, M = 1, 1
K = M  # use as many estimating equations as there are model parameters

g = EE(D, M, K)

# Manually set the single estimating equation to
# g(x, theta) = -0.7 * x + 0.3 * theta + 0.1 (weight columns are [x, theta]).
d = g.layer1.state_dict()
d.update(weight=torch.tensor([[-0.7, 0.3]]), bias=torch.tensor([0.1]))
g.layer1.load_state_dict(d)

In [None]:
import matplotlib.pyplot as plt
N = 100
洧랞s = 0.5 * np.ones(N) # np.random.normal(size=N)
洧랥= 1.0
洧랚 = torch.linspace(0.2, 1., N+2)[1:-1].unsqueeze(-1) # range of test values for the model parameter
dist = 'gauss'

# For several sample sizes T, compare the empirical log-likelihood of one
# fixed dataset with the true log-likelihood, both shifted to a maximum of 0.
Ts = [10, 50, 100, 1000]
plt.figure(figsize=(16,5))
for i,T in enumerate(Ts):
    plt.subplot(1,len(Ts),i+1)

    X = torch.tensor(simulator(T, 洧랞=洧랞s, 洧랥=洧랥, dist=dist), dtype=dtype).unsqueeze(-1)
    X = X[0].unsqueeze(0).repeat(N,1,1) # fix one dataset
    ll = log_pX洧랚(g, X, 洧랚)

    # exponentially tilted empirical likelihood
    plt.plot(洧랚.detach().numpy(), ll.detach().numpy() - ll.max().detach().numpy(), '.', color='orange', label='Empirical log-likelihood')
    # compare to true likelihood
    if dist == 'gauss':
        ll_true = (- 0.5 * T/洧랥**2 * (洧랚.squeeze(1) - X.mean(axis=(1,2)))**2).detach().numpy()
    elif dist == 'gamma':
        s = 洧랥**2 / 洧랚  # scale
        k = 洧랚 / s      # shape
        # Gamma log-density: (k-1) log x - x/s - k log s - log Gamma(k).
        # The normaliser uses the shape k, not the scale s.
        ll_true = ((k-1) * torch.log(X.squeeze(-1)) - X.squeeze(-1)/s - k * torch.log(s) - torch.lgamma(k)).sum(axis=1).detach().numpy()
    plt.plot(洧랚.detach().numpy(), ll_true - ll_true.max(), '.', color='b', label=('Gaussian' if dist == 'gauss' else 'Gamma') + ' log-likelihood')
    if i == 0:
        plt.legend()
    plt.title("T = "+str(T))

In [None]:
# Training set: N_train datasets of T_train samples each, one mean per dataset.
N_train, T_train = 100, 100
洧랥_train = 1.0

洧랞s_train = np.random.randn(N_train)
if dist == 'gamma':
    # the gamma simulator only accepts non-negative means
    洧랞s_train = np.abs(洧랞s_train)

X_train = torch.tensor(
    simulator(T_train, 洧랞=洧랞s_train, 洧랥=洧랥_train, dist=dist),
    dtype=dtype,
)[..., None]
洧랞s_train = torch.tensor(洧랞s_train, dtype=dtype)[..., None]

In [None]:
# Stochastic gradient ascent on the summed empirical log-likelihood of the
# training pairs. `losses` records the objective value of each mini-batch.
n_steps = 1000
losses = np.zeros(n_steps)

lr = 1e-5
batch_size = 10
for i in range(n_steps):
    idx = torch.tensor(np.random.choice(N_train, batch_size, replace=False))

    # Alternative objective: grad, losses[i] = grad_INFONCE(g, X_train[idx], 洧랞s_train[idx])
    losses[i] = log_pX洧랚(g, X_train[idx], 洧랞s_train[idx]).sum().item()
    # detach: the update below is applied by hand, no autograd graph is needed
    grad = grad_log_pX洧랚(g, X_train[idx], 洧랞s_train[idx]).sum(axis=0).detach()
    print(f'step #{i + 1}/{n_steps}, value={losses[i]}')
    d = g.layer1.state_dict()
    # gradient layout is [weights..., bias]
    d['weight'] += lr * grad[:-1].reshape(1, -1)
    d['bias'] += lr * grad[-1]
    g.layer1.load_state_dict(d)

import matplotlib.pyplot as plt

# i is the index of the last completed step, so include it in the plot
plt.plot(losses[:i + 1])
plt.xlabel('# iter')

d['weight'], d['bias']

In [None]:
import matplotlib.pyplot as plt

# Repeat the likelihood comparison with the current (trained) estimating
# equations: empirical vs. true log-likelihood for several sample sizes T.
plt.figure(figsize=(16,5))
for i,T in enumerate(Ts):
    plt.subplot(1,len(Ts),i+1)

    X = torch.tensor(simulator(T, 洧랞=洧랞s, 洧랥=洧랥, dist=dist), dtype=dtype).unsqueeze(-1)
    X = X[0].unsqueeze(0).repeat(N,1,1) # fix one dataset

    ll = log_pX洧랚(g, X, 洧랚)
    # exponentially tilted empirical likelihood
    plt.plot(洧랚.detach().numpy(), ll.detach().numpy() - ll.max().detach().numpy(), '.', color='orange', label='Empirical log-likelihood')
    # compare to true likelihood
    if dist == 'gauss':
        ll_true = (- 0.5 * T/洧랥**2 * (洧랚.squeeze(1) - X.mean(axis=(1,2)))**2).detach().numpy()
    elif dist == 'gamma':
        s = 洧랥**2 / 洧랚  # scale
        k = 洧랚 / s      # shape
        # Gamma log-density: (k-1) log x - x/s - k log s - log Gamma(k).
        # The normaliser uses the shape k, not the scale s.
        ll_true = ((k-1) * torch.log(X.squeeze(-1)) - X.squeeze(-1)/s - k * torch.log(s) - torch.lgamma(k)).sum(axis=1).detach().numpy()
    plt.plot(洧랚.detach().numpy(), ll_true - ll_true.max(), '.', color='b', label=('Gaussian' if dist == 'gauss' else 'Gamma') + ' log-likelihood')
    if i == 0:
        plt.legend()
    plt.title("T = "+str(T))

In [None]:
def test_dFd洧랝():
    """Visual check of comp_dFd洧랝 using fixed-step iterations on 0.5 * F(tilt)^2.

    Experiment 1 steps against F * dF/dtilt (descent), experiment 2 steps
    along it (ascent). Both plot the per-dataset objective and the tilt
    trajectories; experiment 2 additionally compares the final objective
    with half the squared maximum of G.
    """
    N, T, K, D, M = 10, 1000, 1, 1, 1
    洧랚 = np.random.randn(N)
    X = torch.tensor(simulator(T, 洧랞=洧랚, 洧랥=1.0), dtype=dtype).unsqueeze(-1)
    洧랚 = torch.tensor(洧랚, dtype=dtype).unsqueeze(-1)

    g = EE(D,M,K)
    d = g.layer1.state_dict()
    d['weight'] = torch.tensor([[-1.0,1.0]])
    d['bias'] = torch.tensor([0.0])
    g.layer1.load_state_dict(d)

    # Detach: G is a constant here. Leaving it attached makes every in-place
    # tilt update extend one autograd graph across thousands of iterations.
    G = g(X, 洧랚).detach()

    def F(洧랝, G):
        w = torch.exp(log_w_opt(洧랝, G))
        return (G * w.unsqueeze(-1)).sum(axis=-2)

    def run(n_steps, sign):
        """Iterate tilt <- tilt + sign * F * dF/dtilt; return (losses, trajectory)."""
        tilt = torch.ones(N, K)
        trajectory = torch.zeros((n_steps, N))
        losses = np.zeros((N, n_steps))
        for i in range(n_steps):
            F_val = F(tilt, G)
            losses[:, i] = 0.5 * (F_val.numpy()[:, 0])**2

            w = torch.exp(log_w_opt(tilt, G))
            GTdiagw = (G * w.unsqueeze(-1)).transpose(-1, -2)
            dFdtilt = comp_dFd洧랝(tilt, G, GTdiagw)
            tilt += sign * (F_val * dFdtilt[..., 0])
            trajectory[i, :] = tilt[:, 0]
        return losses, trajectory

    # experiment 1: minimize mean parameter F(tilt, G) wrt the tilt
    losses, trajectory = run(n_steps=50, sign=-1.0)

    plt.figure(figsize=(12,6))
    plt.subplot(2,2,1)
    plt.semilogy(losses.T)
    plt.subplot(2,2,2)
    plt.plot(trajectory.numpy())

    # experiment 2: maximize mean parameter F(tilt, G) wrt the tilt
    # ( should equal largest suff statistic g(xi, theta) )
    losses, trajectory = run(n_steps=2000, sign=1.0)

    plt.subplot(2,3,4)
    plt.plot(losses.T)
    plt.subplot(2,3,5)
    plt.plot(trajectory.numpy())
    plt.subplot(2,3,6)
    max_vals = G[...,0].max(axis=1)[0].numpy()**2/2
    plt.plot(max_vals, losses[:,-1], 'b*')
    plt.plot([max_vals.min(), max_vals.max()], [max_vals.min(), max_vals.max()], 'k')
    plt.show()

test_dFd洧랝()

In [None]:
def test_dFd洧랯():
    """Visual check of comp_dFd洧랯 using gradient descent on 0.5 * F^2.

    The tilt is drawn once at random and kept fixed; the parameters of g are
    updated with the batch-averaged gradient F * dF/dparams. The objective
    and the parameter trajectories are plotted and the final parameters
    printed.
    """
    N, T, K, D, M = 10, 100, 1, 1, 1
    洧랚 = np.random.randn(N)
    X = torch.tensor(simulator(T, 洧랞=洧랚, 洧랥=1.0), dtype=dtype).unsqueeze(-1)
    洧랚 = torch.tensor(洧랚, dtype=dtype).unsqueeze(-1)
    洧랝  = torch.randn((N,K))

    def F(洧랝, G):
        w = torch.exp(log_w_opt(洧랝, G))
        return (G * w.unsqueeze(-1)).sum(axis=-2)

    # experiment 1: minimize mean parameter F(tilt, G) wrt the parameters of g
    n_steps = 1000
    lr = 1e-1

    g = EE(D,M,K)
    dim_洧랯 = D + M + 1  # weights for [x, theta] plus one bias (K == 1)
    洧랯  = torch.randn(dim_洧랯) #torch.tensor([1.0, -1.0, 0.0])
    d = g.layer1.state_dict()
    d['weight'] = 洧랯[:-1].unsqueeze(0)
    d['bias'] = 洧랯[-1:]
    g.layer1.load_state_dict(d)

    洧랯s = torch.zeros((n_steps, dim_洧랯))
    losses = np.zeros((N,n_steps))
    # Gradients are computed analytically, so autograd tracking is switched
    # off; otherwise the slice assignments into the trajectory tensor chain
    # the graphs of all iterations together.
    with torch.no_grad():
        for i in range(n_steps):

            G = g(X, 洧랚)
            w = torch.exp(log_w_opt(洧랝, G))

            GTdiagw = (G  * w.unsqueeze(-1)).transpose(-1,-2)

            dGd洧랯 = g.jacobian_pars(X, 洧랚)
            dFd洧랯 = comp_dFd洧랯(洧랝, G, w, GTdiagw, dGd洧랯)
            F_val = F(洧랝, G)
            grad = torch.bmm(F_val.unsqueeze(-2), dFd洧랯).mean(axis=0)

            # batch-mean objective, broadcast to all N rows
            losses[:,i] = 0.5 * ((F_val.numpy()[:,0])**2).mean(axis=0)

            d = g.layer1.state_dict()
            d['weight'] -= lr * grad[0,:-1]
            d['bias'] -= lr * grad[0,-1:]
            g.layer1.load_state_dict(d)
            洧랯s[i, :-1] = d['weight']
            洧랯s[i, -1:] = d['bias']

    plt.figure(figsize=(12,6))
    plt.subplot(2,2,1)
    plt.semilogy(losses.T)
    plt.subplot(2,2,2)
    plt.plot(洧랯s.numpy())
    plt.show()

    print('final 洧랯: ', 洧랯s[-1])

test_dFd洧랯()

In [None]:
from torch.autograd.functional import jacobian

# Work-in-progress autograd wrapper around the SciPy tilt solver.
# NOTE(review): as written this class does not look usable end to end; see the
# notes in forward/backward below before relying on it.
class Solver_洧랝(torch.autograd.Function):

    @staticmethod
    def forward(ctx, X, 洧랚):
        """
        In the forward pass we receive a Tensor containing the input and return
        a Tensor containing the output. ctx is a context object that can be used
        to stash information for backward computation. You can cache arbitrary
        objects for use in the backward pass using the ctx.save_for_backward method.
        """
        with torch.no_grad():
            # NOTE(review): `g` is not a parameter; it is read from the
            # enclosing (module/notebook) scope.
            G = g(X, 洧랚)
            dGd洧랯 = g.jacobian_pars(X, 洧랚)
            N, K = G.shape[0], G.shape[-1]
            洧랝0 = np.zeros((N, K))
            洧랝 = np.zeros_like(洧랝0)
            # one independent SciPy minimisation per row, started from zero
            for i in range(N):
                def 풙_G(洧랝):
                    return 풙_np(洧랝, G[i].numpy())
                洧랝[i] = spo.minimize(풙_G, 洧랝0[i])['x']
        # NOTE(review): the tilt is still a NumPy array at this point, while
        # save_for_backward is documented to accept tensors only - confirm and
        # convert before saving.
        ctx.save_for_backward(G, dGd洧랯, 洧랝)

        return torch.tensor(洧랝,dtype=dtype)

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input.

        g(X,洧랚)             # N x T x K
        洧랝                  # N x     K
        w                  # N x T
        grad풙(洧랝, G)        # N x     K
        dGd洧랯               # N x T x K x dim(洧랯)

        """
        G,dGd洧랯,洧랝 = ctx.saved_tensors
        with torch.no_grad():
            w = torch.exp(log_w_opt(洧랝, G))
            d풙d洧랝 = grad풙(洧랝, G)
            # weighted second moment of G minus the outer product of its weighted mean
            dFd洧랝 = torch.bmm(G.transpose(-1,-2), G  * w.unsqueeze(-1)) - torch.bmm(d풙d洧랝.unsqueeze(-1), d풙d洧랝.unsqueeze(-2))

            #dwd洧랯 = w * torch.matmul(dGd洧랯 - torch.matmul(dGd洧랯, w).unsqueeze(-2), 洧랝) 
            #dFd洧랯 = torch.bmm(dgd洧랯, w) + torch.matmul(dwd洧랯, G)
            # NOTE(review): this is only the second of the two terms that
            # comp_dFd洧랯 adds up - confirm whether the first term was
            # dropped intentionally.
            dFd洧랯 = torch.bmm((G  * w.unsqueeze(-1)).transpose(-1,-2), (dGd洧랯 * 洧랝.unsqueeze(-2).unsqueeze(-1)).sum(axis=-2)) 

            # NOTE(review): grad_log_pX洧랚 negates this solve; here the sign is positive.
            d洧랝d洧랯 = torch.linalg.solve(dFd洧랝, dFd洧랯) # Inverse function theorem: d洧랝d洧랯 = inv(dFd洧랝) * dFd洧랯 
        # NOTE(review): forward takes two inputs (X and the parameter) but a
        # single value is returned, and it is a derivative w.r.t. the
        # parameters of g rather than w.r.t. either input - verify the
        # intended design.
        return grad_output * d洧랝d洧랯
