# ETEL

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy.special import loggamma

from etel import log_pX𝜃, grad_log_pX𝜃, grad_InfoNCE

dtype = torch.float64
dtype_np = np.float64
torch.set_default_dtype(dtype)

def simulator(T, 𝜇, 𝜎=1.0, dist='gauss'):
    """Draw T i.i.d. samples for each of N mean / standard-deviation pairs.

    Parameters
    ----------
    T : int
        Number of samples per row.
    𝜇 : float or array-like
        Mean(s). ``N = np.size(𝜇)`` rows are generated.
    𝜎 : float or array-like
        Standard deviation(s); a scalar, or one value per row.
    dist : {'gauss', 'gamma'}
        Sampling distribution. For 'gamma' the scale is ``𝜎**2 / 𝜇`` and the
        shape is ``𝜇 / scale``, so each row has mean 𝜇 and variance 𝜎**2.

    Returns
    -------
    numpy.ndarray of shape (N, T)
    """
    assert dist in ['gauss', 'gamma']
    N = np.size(𝜇)
    if dist == 'gauss':
        x = np.atleast_2d(𝜎).T * np.random.normal(size=(N,T)) + np.atleast_2d(𝜇).T
    elif dist == 'gamma':
        # Normalise to flat float arrays so that scalar or list inputs work
        # here just like they do in the Gaussian branch.
        mean = np.ravel(np.asarray(𝜇, dtype=float))
        std = np.ravel(np.asarray(𝜎, dtype=float))
        assert np.all(mean >= 0.0)
        scale = np.broadcast_to(std**2 / mean, (N,))
        shape_k = mean / scale
        x = np.zeros((N, T))
        # One draw call per row keeps the order in which random numbers are
        # consumed the same as a plain row-by-row loop.
        for i in range(N):
            x[i] = np.random.gamma(size=T, shape=shape_k[i], scale=scale[i])
    return x

class EE(torch.nn.Module):
    """Quadratic estimating equations g(x, 𝜃) with learnable coefficients.

    With z = [x, 𝜃] (concatenation, length D+M), equation k evaluates to
    ``A[k] @ z + z @ B[k] @ z + c[k]``.
    """

    def __init__(self, D, M, K):
        super().__init__()
        self.D = D # dimensionality of each data point x_i, i=1,...,T
        self.M = M # dimensionality of model parameter 𝜃
        self.K = K # number of estimating equations
        #self.layer1 = torch.nn.Linear(D+M, K)
        self.A = torch.nn.Parameter(torch.randn((K,D+M))/np.sqrt(D+M))
        self.B = torch.nn.Parameter(torch.randn((K,D+M,D+M))/np.sqrt(D+M))
        self.c = torch.nn.Parameter(torch.randn((K)))

    def _stack_inputs(self, x, 𝜃):
        """Append 𝜃 to every data point: (N,T,D), (N,M) -> (N,T,D+M)."""
        return torch.cat((x, 𝜃.unsqueeze(-2).expand(-1, x.shape[-2], -1)), dim=-1)

    def forward(self, x, 𝜃):
        """Evaluate all K estimating equations.

        x has shape N x T x D and 𝜃 has shape N x M; the result has shape
        N x T x K.
        """
        z = self._stack_inputs(x, 𝜃)                                # N x T x (D+M)
        linear = torch.einsum('kp,ntp->ntk', self.A, z)             # N x T x K
        quadratic = torch.einsum('ntp,kpq,ntq->ntk', z, self.B, z)  # N x T x K
        return linear + quadratic + self.c

    def jacobian_pars(self, X, 𝜃):
        """Jacobian of the K outputs with respect to A, B and c.

        X has shape N x T x D and 𝜃 has shape N x M. Returns a dict with
        'A' : N x T x K x K x (D+M)
        'B' : N x T x K x K x (D+M) x (D+M)
        'c' : N x T x K x K
        where the first K axis indexes the output equation and the second the
        parameter row. Output k only depends on row k of each parameter, so
        every entry with differing K indices is zero.
        """
        N, T, D = X.shape
        M = 𝜃.shape[-1]
        K, P = self.K, D + M
        z = self._stack_inputs(X, 𝜃)                                  # N x T x P
        # Build the identity with the inputs' dtype/device so the returned
        # Jacobians match them instead of the global defaults.
        eye = torch.eye(K, dtype=z.dtype, device=z.device)
        dGdA = eye.reshape(1, 1, K, K, 1) * z.reshape(N, T, 1, 1, P)
        zzT = z.unsqueeze(-1) * z.unsqueeze(-2)                       # N x T x P x P
        dGdB = eye.reshape(1, 1, K, K, 1, 1) * zzT.reshape(N, T, 1, 1, P, P)
        dGdc = eye.repeat(N, T, 1, 1)
        return {'A': dGdA, 'B' : dGdB, 'c' : dGdc}


In [None]:
D, M = 1, 2
K = M  # one estimating equation per model parameter

g = EE(D, M, K)

# Replace the random initialisation with hand-picked coefficients.
# With z = (x, 𝜃_1, 𝜃_2) and g_k = A[k] @ z + z @ B[k] @ z + c[k] this gives
#   g_1 = x - 𝜃_1
#   g_2 = x**2 - 𝜃_1**2 - 𝜃_2
d = g.state_dict()
d.update(
    A=torch.tensor([[1.0, -1.0, 0.0],
                    [0.0, 0.0, -1.0]]),
    B=torch.stack((torch.zeros(D + M, D + M),
                   torch.diag(torch.tensor([1.0, -1.0, 0.0])))),
    c=torch.zeros(K),
)
g.load_state_dict(d)

In [None]:
import matplotlib.pyplot as plt
N = 10
𝜇 = 0.5 # np.random.normal(size=N)
𝜎= 1.0 
# N x N grid of candidate parameters, flattened to (N*N) x 2. Column 0 is used
# as the mean and column 1 as the variance in the Gaussian reference below.
grids = np.meshgrid(np.linspace(0.1, 1., N+2)[1:-1], np.linspace(0.5, 2.0, N))
𝜃 = torch.stack([torch.tensor(xx.flatten()) for xx in grids], dim=-1)
dist = 'gauss'

Ts = [10, 50, 100, 1000, 10000]
plt.figure(figsize=(16,5))
for i,T in enumerate(Ts):

    X = torch.tensor(simulator(T, 𝜇=𝜇, 𝜎=𝜎, dist=dist), dtype=dtype).unsqueeze(-1)
    X = X[0].unsqueeze(0).repeat(𝜃.shape[0],1,1) # fix one dataset
    ll = log_pX𝜃(g, X, 𝜃).reshape(N,N)

    # top row: exponentially tilted empirical likelihood (shifted so max = 0)
    plt.subplot(2,len(Ts),i+1)
    plt.imshow(ll.detach().numpy() - ll.max().detach().numpy())
    plt.title("T = "+str(T))
    if i == 0:
        # imshow images do not produce legend entries, so name the row on the
        # axis instead of calling plt.legend().
        plt.ylabel('Empirical log-likelihood')

    # bottom row: true likelihood; only a Gaussian reference is implemented
    if dist == 'gauss':
        ll_true = (- 0.5 * T * torch.log(𝜃[:,1]) - (0.5/(𝜃[:,1:]) * (𝜃[:,:1] - X.squeeze(-1))**2).sum(axis=-1)).detach().numpy().reshape(N,N)

        plt.subplot(2,len(Ts),len(Ts)+i+1)
        plt.imshow(ll_true - ll_true.max())
        if i == 0:
            plt.ylabel('Gaussian log-likelihood')
        # grid points maximising the true and the empirical log-likelihood
        print(𝜃[np.argmax(ll_true)], 𝜃[np.argmax(ll.detach().numpy())])


In [None]:
N_train = 100
T_train = 1000

# One (mean, standard deviation) pair per training data set.
# Means are drawn before standard deviations.
𝜇s_train = np.random.randn(N_train)
𝜎s_train = 0.5 + np.random.uniform(size=N_train)
# the gamma simulator asserts non-negative means
𝜇s_train = np.abs(𝜇s_train) if dist == 'gamma' else 𝜇s_train

# simulated data, N_train x T_train x 1
X_train = torch.tensor(
    simulator(T_train, 𝜇=𝜇s_train, 𝜎=𝜎s_train, dist=dist),
    dtype=dtype,
)[..., None]

# parameters as column tensors, then joined into 𝜃 = (𝜇, 𝜎) of shape N_train x 2
𝜇s_train, 𝜎s_train = (
    torch.tensor(v, dtype=dtype).reshape(-1, 1) for v in (𝜇s_train, 𝜎s_train)
)
𝜃s_train = torch.cat([𝜇s_train, 𝜎s_train], dim=-1)


In [None]:
import tqdm
n_steps = 100
losses = np.zeros(n_steps)

lr = 5e-8
batch_size = 100

alg = 'MEL'
assert alg in ['NCE', 'MEL']

n_done = 0  # number of steps whose loss has been recorded
for i in tqdm.tqdm(range(n_steps)):
    # mini-batch of training data sets, drawn without replacement
    idx = torch.tensor(np.random.choice(N_train, batch_size, replace=False))

    if alg == 'NCE':
        grad, losses[i] = grad_InfoNCE(g, X_train[idx], 𝜃s_train[idx])
    elif alg == 'MEL':
        losses[i] = log_pX𝜃(g, X_train[idx], 𝜃s_train[idx]).sum()
        grad = grad_log_pX𝜃(g, X_train[idx], 𝜃s_train[idx])

    # plain gradient step along +grad, summed over the batch axis
    # NOTE(review): the same sign is used for both algorithms; confirm that
    # grad_InfoNCE returns an ascent direction as well.
    d = g.state_dict()
    d['A'] += lr * grad['A'].sum(axis=0).reshape(*d['A'].shape)
    d['B'] += lr * grad['B'].sum(axis=0).reshape(*d['B'].shape)
    d['c'] += lr * grad['c'].sum(axis=0).reshape(*d['c'].shape)
    g.load_state_dict(d)
    n_done = i + 1

import matplotlib.pyplot as plt

# plot every recorded step; slicing with the final loop index would drop the
# last one
plt.plot(losses[:n_done])
plt.xlabel('# iter')

In [None]:
# Same comparison as before training, now with the updated g. Relies on N, 𝜇, 𝜎,
# 𝜃, dist and Ts as set by an earlier cell.
plt.figure(figsize=(16,5))
for i,T in enumerate(Ts):

    X = torch.tensor(simulator(T, 𝜇=𝜇, 𝜎=𝜎, dist=dist), dtype=dtype).unsqueeze(-1)
    X = X[0].unsqueeze(0).repeat(𝜃.shape[0],1,1) # fix one dataset
    ll = log_pX𝜃(g, X, 𝜃).reshape(N,N)

    # top row: exponentially tilted empirical likelihood (shifted so max = 0)
    plt.subplot(2,len(Ts),i+1)
    plt.imshow(ll.detach().numpy() - ll.max().detach().numpy())
    plt.title("T = "+str(T))
    if i == 0:
        # imshow images do not produce legend entries, so name the row on the
        # axis instead of calling plt.legend().
        plt.ylabel('Empirical log-likelihood')

    # bottom row: true likelihood; only a Gaussian reference is implemented
    if dist == 'gauss':
        ll_true = (- 0.5 * T * torch.log(𝜃[:,1]) - (0.5/(𝜃[:,1:]) * (𝜃[:,:1] - X.squeeze(-1))**2).sum(axis=-1)).detach().numpy().reshape(N,N)

        plt.subplot(2,len(Ts),len(Ts)+i+1)
        plt.imshow(ll_true - ll_true.max())
        if i == 0:
            plt.ylabel('Gaussian log-likelihood')
        # grid points maximising the true and the empirical log-likelihood
        print(𝜃[np.argmax(ll_true)], 𝜃[np.argmax(ll.detach().numpy())])


In [None]:
def test_dFd𝜆():
    """Visual check of dF/d𝜆 by gradient descent / ascent on 0.5 * F(𝜆, G)**2.

    F(𝜆, G) is the w-weighted sum of G over the T axis, with
    w = exp(log_w_opt(𝜆, G)). Nothing is asserted; the result is a figure.

    NOTE(review): log_w_opt and comp_dFd𝜆 are not defined or imported in the
    visible part of this file - they presumably live in the etel module;
    confirm and import them before running.
    """
    N, T, K, D, M = 10, 1000, 1, 1, 1
    𝜃 = np.random.randn(N)
    X = torch.tensor(simulator(T, 𝜇=𝜃, 𝜎=1.0), dtype=dtype).unsqueeze(-1)
    𝜃 = torch.tensor(𝜃, dtype=dtype).unsqueeze(-1)

    # Linear estimating equation g(x, 𝜃) = -x + 𝜃. EE has no `layer1`
    # sub-module, so the coefficients are set through A / B / c directly
    # (B = 0 switches the quadratic term off).
    g = EE(D,M,K)
    d = g.state_dict()
    d['A'] = torch.tensor([[-1.0,1.0]])
    d['B'] = torch.zeros(K, D+M, D+M)
    d['c'] = torch.tensor([0.0])
    g.load_state_dict(d)

    G = g(X, 𝜃)
    def F(𝜆, G):
        w = torch.exp(log_w_opt(𝜆, G))
        return (G * w.unsqueeze(-1)).sum(axis=-2)

    # experiment 1: minimize mean parameter F(𝜆, G) wrt 𝜆
    n_steps = 50

    𝜆  = torch.ones(N,K)
    𝜆s = torch.zeros((n_steps, N))
    losses = np.zeros((N,n_steps))
    for i in range(n_steps):

        losses[:,i] = 0.5 * (F(𝜆, G).detach().numpy()[:,0])**2

        w = torch.exp(log_w_opt(𝜆, G))
        GTdiagw = (G  * w.unsqueeze(-1)).transpose(-1,-2)
        dFd𝜆 = comp_dFd𝜆(𝜆, G, GTdiagw)
        # descent step on 0.5 * F**2 (chain rule: F * dF/d𝜆), step size 1
        𝜆 -= (F(𝜆, G) * dFd𝜆[...,0])
        𝜆s[i,:] = 1. * 𝜆[:,0]

    plt.figure(figsize=(12,6))
    plt.subplot(2,2,1)
    plt.semilogy(losses.T)
    plt.subplot(2,2,2)
    plt.plot(𝜆s.detach().numpy())

    # experiment 2: maximize mean parameter F(𝜆, G) wrt 𝜆 ( should equal largest suff statistic g(xi,𝜃) )
    n_steps = 2000

    𝜆  = torch.ones(N,K)
    𝜆s = torch.zeros((n_steps, N))
    losses = np.zeros((N,n_steps))
    for i in range(n_steps):

        losses[:,i] = 0.5 * (F(𝜆, G).detach().numpy()[:,0])**2

        w = torch.exp(log_w_opt(𝜆, G))
        GTdiagw = (G  * w.unsqueeze(-1)).transpose(-1,-2)
        dFd𝜆 = comp_dFd𝜆(𝜆, G, GTdiagw)
        # ascent step: same update with the opposite sign
        𝜆 += (F(𝜆, G) * dFd𝜆[...,0])
        𝜆s[i,:] = 1. * 𝜆[:,0]

    plt.subplot(2,3,4)
    plt.plot(losses.T)
    plt.subplot(2,3,5)
    plt.plot(𝜆s.detach().numpy())
    plt.subplot(2,3,6)
    # compare the final objective with 0.5 * (max_t G)**2 per data set
    max_vals = G[...,0].max(axis=1)[0].detach().numpy()**2/2
    plt.plot(max_vals, losses[:,-1], 'b*')
    plt.plot([max_vals.min(), max_vals.max()], [max_vals.min(), max_vals.max()], 'k')
    plt.show()

test_dFd𝜆()

In [None]:
def test_dFd𝜙():
    """Visual check of dF/d𝜙 by gradient descent on 0.5 * F(𝜆, G)**2 wrt 𝜙.

    𝜙 collects the weight and bias of a linear estimating equation; 𝜆 is
    drawn once at random and kept fixed. Nothing is asserted; the result is
    a figure plus the final 𝜙 printed to stdout.

    NOTE(review): this test addresses `g.layer1`, which the EE class in this
    file only has as a commented-out line, and it indexes the result built
    from g.jacobian_pars as an array while EE.jacobian_pars returns a dict -
    it looks written for an earlier, linear-layer version of EE; confirm.
    NOTE(review): log_w_opt and comp_dFd𝜙 are not defined or imported in the
    visible part of this file - presumably provided by the etel module.
    """

    N, T, K, D, M = 10, 100, 1, 1, 1
    𝜃 = np.random.randn(N)
    X = torch.tensor(simulator(T, 𝜇=𝜃, 𝜎=1.0), dtype=dtype).unsqueeze(-1)
    𝜃 = torch.tensor(𝜃, dtype=dtype).unsqueeze(-1)
    𝜆  = torch.randn((N,K))

    # w-weighted sum of G over the T axis
    def F(𝜆, G):
        w = torch.exp(log_w_opt(𝜆, G))
        return (G * w.unsqueeze(-1)).sum(axis=-2)

    # experiment 1: minimize mean parameter F(𝜆, G) wrt 𝜙
    n_steps = 1000

    g = EE(D,M,K)
    dim_𝜙 = 3
    # random start: first two entries become the weight row, the last the bias
    𝜙  = torch.randn(dim_𝜙) #torch.tensor([1.0, -1.0, 0.0])
    d = g.layer1.state_dict()
    d['weight'] = 𝜙[:-1].unsqueeze(0)
    d['bias'] = 𝜙[-1:]
    g.layer1.load_state_dict(d)    

    𝜙s = torch.zeros((n_steps, dim_𝜙))
    losses = np.zeros((N,n_steps))
    for i in range(n_steps):

        G = g(X, 𝜃)
        w = torch.exp(log_w_opt(𝜆, G))

        GTdiagw = (G  * w.unsqueeze(-1)).transpose(-1,-2)

        dGd𝜙 = g.jacobian_pars(X, 𝜃)
        dFd𝜙 = comp_dFd𝜙(𝜆, G, w, GTdiagw, dGd𝜙)
        # chain rule for 0.5 * F**2, averaged over the N data sets
        grad = torch.bmm(F(𝜆, G).unsqueeze(-2), dFd𝜙).mean(axis=0)

        # the batch-mean loss is a scalar, broadcast into all N rows of column i
        losses[:,i] = 0.5 * ((F(𝜆, G).detach().numpy()[:,0])**2).mean(axis=0)

        lr = 1e-1
        # descent step, written back through the layer's state dict
        d = g.layer1.state_dict()
        d['weight'] += - lr * grad[0,:-1]
        d['bias'] += - lr * grad[0,-1:]
        g.layer1.load_state_dict(d)    
        𝜙s[i, :-1] = d['weight']
        𝜙s[i, -1:] = d['bias']

    plt.figure(figsize=(12,6))
    plt.subplot(2,2,1)
    plt.semilogy(losses.T)
    plt.subplot(2,2,2)
    plt.plot(𝜙s.detach().numpy())
    plt.show()

    print('final 𝜙: ', 𝜙s[-1])

test_dFd𝜙()

In [None]:
from torch.autograd.functional import jacobian

class Solver_𝜆(torch.autograd.Function):
    """Custom autograd function that solves for 𝜆 numerically, one data set at a time.

    The forward pass minimises Φ_np(𝜆, G[i]) per data set with spo.minimize;
    the backward pass builds d𝜆/d𝜙 from a linear solve.

    NOTE(review): the estimating-equation model is read from the global `g`,
    not passed in as an argument.
    NOTE(review): spo, Φ_np, log_w_opt and gradΦ are not defined or imported
    in the visible part of this file - presumably scipy.optimize and helpers
    from the etel module; confirm.
    """

    @staticmethod
    def forward(ctx, X, 𝜃):
        """
        Evaluate G = g(X, 𝜃) and its parameter Jacobian, then minimise Φ_G
        separately for each of the N data sets, starting from 𝜆 = 0.

        Returns the N x K array of minimisers as a tensor of dtype `dtype`.

        NOTE(review): `𝜆` is still a NumPy array and `dGd𝜙` is whatever
        g.jacobian_pars returns (a dict for the EE class in this file) when
        they are passed to ctx.save_for_backward, which is documented for
        tensors - confirm this runs as intended.
        """
        with torch.no_grad():
            G = g(X, 𝜃)
            dGd𝜙 = g.jacobian_pars(X, 𝜃)
            N, K = G.shape[0], G.shape[-1]
            𝜆0 = np.zeros((N, K))
            𝜆 = np.zeros_like(𝜆0)
            for i in range(N):
                # objective for data set i only; used right away, so the
                # closure over the loop variable is safe here
                def Φ_G(𝜆):
                    return Φ_np(𝜆, G[i].numpy())
                𝜆[i] = spo.minimize(Φ_G, 𝜆0[i])['x']
        ctx.save_for_backward(G, dGd𝜙, 𝜆)

        return torch.tensor(𝜆,dtype=dtype)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Compute d𝜆/d𝜙 as solve(dF/d𝜆, dF/d𝜙) and scale it by grad_output.

        Shapes as intended by the author:

        g(X,𝜃)             # N x T x K
        𝜆                  # N x     K
        w                  # N x T
        gradΦ(𝜆, G)        # N x     K
        dGd𝜙               # N x T x K x dim(𝜙)

        NOTE(review): forward takes two inputs (X, 𝜃) but a single value is
        returned here, and it is a derivative with respect to 𝜙 rather than
        X or 𝜃 - confirm against the autograd.Function contract.
        """
        G,dGd𝜙,𝜆 = ctx.saved_tensors
        with torch.no_grad():
            w = torch.exp(log_w_opt(𝜆, G))
            dΦd𝜆 = gradΦ(𝜆, G)
            # w-weighted second moment of G minus the outer product of dΦ/d𝜆
            dFd𝜆 = torch.bmm(G.transpose(-1,-2), G  * w.unsqueeze(-1)) - torch.bmm(dΦd𝜆.unsqueeze(-1), dΦd𝜆.unsqueeze(-2))

            #dwd𝜙 = w * torch.matmul(dGd𝜙 - torch.matmul(dGd𝜙, w).unsqueeze(-2), 𝜆) 
            #dFd𝜙 = torch.bmm(dgd𝜙, w) + torch.matmul(dwd𝜙, G)
            dFd𝜙 = torch.bmm((G  * w.unsqueeze(-1)).transpose(-1,-2), (dGd𝜙 * 𝜆.unsqueeze(-2).unsqueeze(-1)).sum(axis=-2)) 

            d𝜆d𝜙 = torch.linalg.solve(dFd𝜆, dFd𝜙) # Inverse function theorem: d𝜆d𝜙 = inv(dFd𝜆) * dFd𝜙 
        return grad_output * d𝜆d𝜙


In [None]:
import matplotlib.pyplot as plt
N = 100
𝜇s = 0.5 * np.ones(N) # np.random.normal(size=N)
𝜎= 1.0
𝜃 = torch.linspace(0.2, 1., N+2)[1:-1].unsqueeze(-1) # range of test values for 𝜃
dist = 'gamma'
# NOTE(review): 𝜃 is N x 1 here, so the global g must have been built with
# M = 1 for log_pX𝜃 to accept it - confirm which g is live when this runs.

Ts = [10, 50, 100, 1000]
plt.figure(figsize=(16,5))
for i,T in enumerate(Ts):
    plt.subplot(1,len(Ts),i+1)

    X = torch.tensor(simulator(T, 𝜇=𝜇s, 𝜎=𝜎, dist=dist), dtype=dtype).unsqueeze(-1)
    X = X[0].unsqueeze(0).repeat(N,1,1) # fix one dataset
    ll = log_pX𝜃(g, X, 𝜃)

    # exponentially tilted empirical likelihood 
    plt.plot(𝜃.detach().numpy(), ll.detach().numpy() - ll.max().detach().numpy(), '.', color='orange', label='Empirical log-likelihood')
    # compare to true likelihood
    if dist == 'gauss':
        # joint Gaussian log-likelihood in the mean, up to a 𝜃-independent constant
        ll_true = (- 0.5 * T/𝜎**2 * (𝜃.squeeze(1) - X.mean(axis=(1,2)))**2).detach().numpy()
        true_label = 'Gaussian log-likelihood'
    elif dist == 'gamma':
        # same parametrisation as the simulator: scale s, shape k
        s = 𝜎**2 / 𝜃
        k = 𝜃 / s
        # Gamma log-density (k-1) log x - x/s - k log s - log Γ(k), summed over
        # the T samples so it is a joint log-likelihood like the Gaussian case.
        ll_true = ((k-1) * torch.log(X.squeeze(-1)) - X.squeeze(-1)/s - k * torch.log(s) - torch.lgamma(k)).sum(axis=1).detach().numpy()
        true_label = 'Gamma log-likelihood'
    plt.plot(𝜃.detach().numpy(),ll_true - ll_true.max(), '.', color='b', label=true_label) 
    # test values maximising the true and the empirical log-likelihood
    print(𝜃[np.argmax(ll_true)], 𝜃[np.argmax(ll.detach().numpy())])
    if i == 0:
        plt.legend()
    plt.title("T = "+str(T))

In [None]:
import matplotlib.pyplot as plt

# Repeat of the 1-D comparison; relies on N, 𝜇s, 𝜎, 𝜃, dist and Ts as set by
# an earlier cell.
plt.figure(figsize=(16,5))
for i,T in enumerate(Ts):
    plt.subplot(1,len(Ts),i+1)

    X = torch.tensor(simulator(T, 𝜇=𝜇s, 𝜎=𝜎, dist=dist), dtype=dtype).unsqueeze(-1)
    X = X[0].unsqueeze(0).repeat(N,1,1) # fix one dataset

    ll = log_pX𝜃(g, X, 𝜃)
    # exponentially tilted empirical likelihood 
    plt.plot(𝜃.detach().numpy(), ll.detach().numpy() - ll.max().detach().numpy(), '.', color='orange', label='Empirical log-likelihood')
    # compare to true likelihood
    if dist == 'gauss':
        # joint Gaussian log-likelihood in the mean, up to a 𝜃-independent constant
        ll_true = (- 0.5 * T/𝜎**2 * (𝜃.squeeze(1) - X.mean(axis=(1,2)))**2).detach().numpy()
        true_label = 'Gaussian log-likelihood'
    elif dist == 'gamma':
        # same parametrisation as the simulator: scale s, shape k
        s = 𝜎**2 / 𝜃
        k = 𝜃 / s
        # Gamma log-density (k-1) log x - x/s - k log s - log Γ(k), summed over
        # the T samples; the normaliser is Γ(k), i.e. of the shape parameter.
        ll_true = ((k-1) * torch.log(X.squeeze(-1)) - X.squeeze(-1)/s - k * torch.log(s) - torch.lgamma(k)).sum(axis=1).detach().numpy()
        true_label = 'Gamma log-likelihood'
    plt.plot(𝜃.detach().numpy(),ll_true - ll_true.max(), '.', color='b', label=true_label) 
    if i == 0:
        plt.legend()
    plt.title("T = "+str(T))