# Training frameworks for the RPM on continuous latents

Consider a recognition-parametrized model
$$
p_\theta(\mathcal{X},\mathcal{Z}) = p_\theta(\mathcal{Z}) \prod_j \frac{f_{\theta_j}(\mathcal{Z}|{x}_j) p_j(x_j)}{F_{\theta_j}(\mathcal{Z})}
$$
with exponential-family $f_{\theta_j}(\mathcal{Z}|x_j)$ and $p_\theta(\mathcal{Z})$, trained via (a lower bound to) the variational free energy
\begin{align}
    \log p_\theta(\mathcal{X}) &\geq \mathbb{E}_{q(\mathcal{Z}|\mathcal{X})}\left[ \log \frac{p_\theta(\mathcal{X}, \mathcal{Z})}{q_\psi(\mathcal{Z}|\mathcal{X})}\right] \nonumber \\
 &= \mathbb{E}_{q}[\log p_\theta(\mathcal{Z})] + \sum_j \mathbb{E}_{q}[\log f_{\theta_j}(\mathcal{Z}|{x}_j) ] + H[q] - \sum_j \mathbb{E}_q[\log F_{\theta_j}(\mathcal{Z})] + const. \nonumber \\    
 &\geq \mathbb{E}_{q}[\log p_\theta(\mathcal{Z})] + \sum_j \mathbb{E}_{q}[\log f_{\theta_j}(\mathcal{Z}|{x}_j) ] + (1+J) H[q] - \sum_j \left( \log \int \frac{F_{\theta_j}(\mathcal{Z})}{h_j(\mathcal{Z})} d\mathcal{Z} + \mathbb{E}_q[\log h_j(\mathcal{Z})] \right) + const. \nonumber 
\end{align}
where $h_j(\mathcal{Z}) = \exp(\tilde{\eta}_j(\mathcal{X})^\top{}t(\mathcal{Z}))$ with inner variational parameters $\tilde{\eta}_j$ (one per data-point $\mathcal{X}$).

There are several ways to handle the recognition model $q(\mathcal{Z}|\mathcal{X})$, although for all of them we assume $q(\mathcal{Z}|\mathcal{X}) = q(\mathcal{Z} | \eta_q(\mathcal{X}))$ to lie in the same exponential family as $f_{\theta_j}(\mathcal{Z}|x_j), p_\theta(\mathcal{Z})$:
- $\eta_q(\mathcal{X}^{(n)})=\eta_q^{(n)}$: nonparametric natural parameter model (VI) with own natural parameter per datum $n$.
- $\eta_q(\mathcal{X}^{(n)})= NN_\psi(\mathcal{X}^{(n)})$ : parametric natural parameter model (VAE)  with recognition parameters $\psi$.
- $\eta_q(\mathcal{X}^{(n)}) = \sum_j \eta_{\theta_j}(x_j^{(n)}) + (1-J) \eta_0$ : analytic approach, holds if we assume $\forall j: F_{\theta_j}(\mathcal{Z}) = p_\theta(\mathcal{Z})$.

Similarly, we can handle inner variational parameters $\tilde{\eta}_j(\mathcal{X})$ as 
- $\forall j: \tilde{\eta}_j(\mathcal{X}^{(n)})=\tilde{\eta}_j^{(n)}$ : nonparametric.  
- $\forall j:\tilde{\eta}_{j}(\mathcal{X}^{(n)})= NN_{\psi_j}(\mathcal{X}^{(n)})$ : parametric with parameters $\{\psi_j\}_j$.
- $\forall j:\tilde{\eta}_{j}(\mathcal{X}^{(n)}) = \eta_0 - \eta_q(\mathcal{X}^{(n)})$ : Hugo's Ansatz with whatever choice for $\eta_q(\mathcal{X})$ we took above.

This gives a total of $3 \times 3 = 9$ combinations of how to train the very same RPM $p_\theta(\mathcal{X},\mathcal{Z})$, which we explore in the remainder. 

For the experiments, we focus here on data with temporal structure in the latents. This motivates a time-series version of the RPM above, 
$$
p_\theta(\mathcal{X},\mathcal{Z}) = p_\theta(\mathcal{Z}) \prod_t \prod_j \frac{f_{\theta_j}(\mathcal{Z}_t|{x}_{jt}) p_{jt}(x_{jt})}{F_{\theta_j}(\mathcal{Z}_t)},
$$
which we will also test.

# comparing number of model parameters

In [None]:
%load_ext autoreload
%autoreload 2

from rpm import RPMEmpiricalMarginals
from utils_setup import init_gaussian_rpm
import torch
import numpy as np

In [None]:
# RGB triplets (0-255) indexed as colors[rpm_variant, amortize_ivi]:
# rows follow the recognition-model order VI / VAE / amortized, and within a
# row the three shades go with the 'none' / 'full' / 'use_q' schemes.
colors = [ [[102,194,164],
            [44,162,95],     # VI: greens
            [0,109,44]
           ],
           [[116,169,207],
            [43,140,190],    # VAE: blues
            [4,90,141]
           ],
          [[254,153,41],
           [217,95,14],      # Amortized: oranges
           [153,52,4]
          ]
         ]

# matplotlib expects RGB floats in [0, 1]. 8-bit channels max out at 255, so
# divide by 255; dividing by 256 darkened every colour slightly and never
# reached full intensity.
colors = np.array(colors)/255


In [None]:
# Parameter-count comparison: build every (recognition model x inner
# variational parameter scheme) combination of the Gaussian RPM for a few
# data-set sizes and record how many parameters each combination owns.
J, K, T = 10, 1, 50

init_rb_bandwidth = 1000.
obs_locs = torch.linspace(0, 1, T)

rpm_variants = ['VI', 'VAE', 'amortized']
amortize_ivis = ['none', 'full', 'use_q']

temporal = True

Ns = [10, 100, 1000]

# numels[a, b, c] = parameter count for Ns[a], rpm_variants[b], amortize_ivis[c]
numels = np.zeros((len(Ns), len(rpm_variants), len(amortize_ivis)))

for n_idx, N in enumerate(Ns):
    # all-zero stand-in observations: only the parameter shapes matter here
    xjs = [torch.zeros((N, T)) for _ in range(J)]
    pxjs = RPMEmpiricalMarginals(xjs)

    for v_idx, rpm_variant in enumerate(rpm_variants):
        for a_idx, amortize_ivi in enumerate(amortize_ivis):
            # presumably builds the model without training (epochs=0, all
            # optim_* switches off); the model is the first returned element
            built = init_gaussian_rpm(
                N, J, K, T, pxjs,
                init_rb_bandwidth, obs_locs,
                rpm_variant, temporal, amortize_ivi,
                epochs=0, batch_size=N,
                dim_T=2, n_hidden=20,
                optim_init_q=False, optim_init_ivi=False, optim_vae_params=False,
            )
            model = built[0]
            numels[n_idx, v_idx, a_idx] = sum(p.numel() for p in model.parameters())

for counts_for_N in numels:
    print(counts_for_N)

In [None]:
import os
import matplotlib.pyplot as plt

# One figure per data-set size: a row per inner-variational-parameter scheme,
# a bar per recognition model, bar height = total number of parameters.
os.makedirs('figs', exist_ok=True)  # savefig below fails if 'figs/' is missing

for n in range(len(Ns)):
    plt.figure(figsize=(8,6))
    # shared upper y-limit for every row of this figure (invariant in i)
    max_y_lim = numels[n].max()
    for i in range(len(amortize_ivis)):
        # one row per scheme (the row count used to be hard-coded to 3)
        ax = plt.subplot(len(amortize_ivis), 1, i+1)
        plt.ylim(0, max_y_lim)
        plt.bar(x=rpm_variants, height=numels[n,:,i], color=[colors[j,i] for j in range(len(rpm_variants))])
        for side in ('top', 'right', 'bottom', 'left'):
            ax.spines[side].set_visible(False)
        ax.xaxis.tick_top()
        plt.yticks([0, 500*Ns[n], 1000*Ns[n]])
        # top row: print counts just below the bar tops; other rows: just above
        valignm = 'top' if i==0 else 'bottom'
        for j in range(len(rpm_variants)):
            plt.text(j, numels[n,j,i], str(int(numels[n,j,i])),
                     horizontalalignment='center', verticalalignment=valignm)
        if i != 0:
            plt.xticks([])  # model names only above the top row
        plt.ylabel(amortize_ivis[i])
    plt.suptitle(r'total parameter counts ($\theta, \psi, \{\tilde{\eta}_{j}^{(n)}\}$), for N = ' +str(Ns[n]))
    plt.savefig('figs/paramCounts_N' + str(Ns[n]) + '.pdf', bbox_inches ='tight')


# comparing initializations by loss curves

In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt

from utils_data_external import linear_regression_1D_latent as regLatent
from utils_data_external import plot_poisson_balls
import matplotlib.pyplot as plt
from utils_setup import init_gaussian_rpm

from rpm import RPMEmpiricalMarginals

N = 100
temporal = True

# recorded training steps per run (the N=10 runs were trained for fewer steps)
num_steps = 5000 if N==10 else 12000

# Same order as the rows of `colors` (VI: greens, VAE: blues, amortized:
# oranges) and as the textured-ball section, so colours mean the same thing
# in every plot.
rpm_variants = ['VI', 'VAE', 'amortized']
amortize_ivis = ['none', 'full', 'use_q']
model_seeds = [0]

# result directories, one per initialisation scheme
# (alternatives used earlier: 'fits_noPretrain', 'fits_pretrainElbo')
inits = ['fits_reparam', 'fits_pretrainAsQ']
obs_locs = torch.linspace(0,1,T).reshape(-1,1)

# NaN-filled so that runs with fewer than num_steps recorded losses leave gaps
# instead of artificial zeros (the plotting cells use np.nanmin / np.nanmax).
losses_train = np.full((len(inits), len(model_seeds), len(rpm_variants), len(amortize_ivis), num_steps), np.nan)
losses_test = np.full((len(inits), len(model_seeds), len(rpm_variants), len(amortize_ivis), num_steps), np.nan)

# Latent means / true latents / data of the one displayed sequence (n = 0) per
# run; the assignments below broadcast that sequence over the whole N axis.
latents_train = np.zeros((len(inits), len(model_seeds), len(rpm_variants), len(amortize_ivis), N, T))
true_latents_train = np.zeros((len(model_seeds), N, T))
data_train_show  = np.zeros((len(model_seeds), N, J, T))

for s,model_seed in enumerate(model_seeds):
    for r,res_dir in enumerate(inits):
        for i,rpm_variant in enumerate(rpm_variants):
            for j,amortize_ivi in enumerate(amortize_ivis):

                # run identifier, e.g. 'VAE_temp_use_q_N_100_seed_0'
                if temporal:
                    identifier = rpm_variant + '_temp_' + amortize_ivi
                else:
                    identifier = rpm_variant + '_' + amortize_ivi
                identifier = identifier + '_N_' + str(N) + '_seed_' + str(model_seed)
                fn_base = os.path.join(res_dir, identifier, identifier)

                data = torch.tensor(np.load(fn_base + '_train_data.npy'))
                true_latent_ext = torch.tensor(np.load(fn_base + '_train_latents.npy'))

                # settings stored with the run (pickled dict: only load trusted result files)
                exp_dict = np.load(fn_base + '_exp_dict.npz', allow_pickle=True)['arr_0'].tolist()
                N,J,K,T = exp_dict['N'],exp_dict['J'],exp_dict['K'],exp_dict['T']
                init_rb_bandwidth = exp_dict['init_rb_bandwidth']

                ls_train = np.load(fn_base + '_loss_train.npy')
                ls_test = np.load(fn_base + '_loss_test.npy')

                xjs = [data[:,col] for col in range(J)]
                pxjs = RPMEmpiricalMarginals(xjs)
                observations = (torch.stack(xjs, dim=-1),)

                # rebuild the model with the run's settings (epochs=0), then load its weights
                model = init_gaussian_rpm(N, J, K, T, pxjs,
                                        init_rb_bandwidth, obs_locs,
                                        rpm_variant, temporal, amortize_ivi,
                                        epochs=0, batch_size=N
                                       )[0]
                # Weights were saved with or without a '.zip' suffix. Only a missing
                # file may trigger the fallback; the former bare `except:` also hid
                # corrupt files, state-dict mismatches and KeyboardInterrupt.
                try:
                    state_dict = torch.load(fn_base + '_rpm_state_dict.zip')
                except FileNotFoundError:
                    state_dict = torch.load(fn_base + '_rpm_state_dict')
                model.load_state_dict(state_dict)

                # posterior natural parameters -> mean parameters E_q[t(Z)]
                prior = model.joint_model[1]
                eta_0 = prior.nat_param
                if rpm_variant == 'amortized':
                    eta_q, _ = model.comp_eta_q(xjs, eta_0)
                else:
                    eta_q = model.comp_eta_q(xjs, eta_0=eta_0, idx_data=np.arange(N))
                EqtZ = prior.log_partition.nat2meanparam(eta_q)

                # first T entries: means; the remaining T*T entries are presumably
                # second moments, whose diagonal minus mu**2 gives the variances
                mu = EqtZ[:,:T]
                sig2 = torch.diagonal(EqtZ[:,T:].reshape(-1,T,T),dim1=-2,dim2=-1) - mu**2

                latent_true, latent_mean_fit, latent_variance_fit, R2 = regLatent(
                    latent_true = true_latent_ext,
                    latent_mean_fit = mu.unsqueeze(-1),
                    latent_variance_fit = sig2)

                #plot_poisson_balls(observations,
                #                   obs_locs=obs_locs.squeeze(-1),
                #                   latent_mean_fit=latent_mean_fit.squeeze(-1),
                #                   latent_variance_fit=latent_variance_fit)

                losses_train[r,s,i,j,:len(ls_train)] = ls_train
                losses_test[r,s,i,j,:len(ls_test)] = ls_test
                n = 0  # index of the training sequence kept for display
                latents_train[r,s,i,j] = latent_mean_fit.detach().numpy()[n,:,0]
                data_train_show[s] = data[n]
                true_latents_train[s] = true_latent_ext[n,:,0]

In [None]:
N = 10
temporal = True

# runs of the 'amortized' / 'use_q' variant with 33 hidden units,
# for three seeds and three initialisation schemes
model_seeds = [0,1,2]
rpm_variant = 'amortized'
amortize_ivi = 'use_q'
inits = ['fits_noPretrain', 'fits_pretrainElbo', 'fits_pretrainAsQ']

# Slot of this variant in the rpm_variants x amortize_ivis grid. The indexing
# below used to rely silently on the loop indices i, j leaked from the
# loading cell.
i = rpm_variants.index(rpm_variant)
j = amortize_ivis.index(amortize_ivi)

# NaN-padded so that runs shorter than 5000 steps leave gaps rather than zeros
# (the plots take np.nanmin / np.nanmax over these arrays)
n33_losses_train = np.full((len(inits), len(model_seeds), 5000), np.nan)
n33_losses_test = np.full((len(inits), len(model_seeds), 5000), np.nan)
n33_latents_train = np.zeros((len(inits), len(model_seeds), len(rpm_variants), len(amortize_ivis), N, T))
n33_true_latents_train = np.zeros((len(model_seeds), N, T))

for s,model_seed in enumerate(model_seeds):
    for r,res_dir in enumerate(inits):

        # the folder name carries the '33Hidden' tag, the file names do not
        if temporal:
            identifier_folder = rpm_variant + '_temp_33Hidden_' + amortize_ivi + '_N'
        else:
            identifier_folder = rpm_variant + '_33Hidden_' + amortize_ivi + '_N'
        identifier_folder = identifier_folder + '_' + str(N) + '_seed_' + str(model_seed)
        if temporal:
            identifier = rpm_variant + '_temp_' + amortize_ivi
        else:
            identifier = rpm_variant + '_' + amortize_ivi
        identifier = identifier + '_N_' + str(N) + '_seed_' + str(model_seed)

        fn_base = os.path.join(res_dir, identifier_folder, identifier)

        data = torch.tensor(np.load(fn_base + '_train_data.npy'))
        true_latent_ext = torch.tensor(np.load(fn_base + '_train_latents.npy'))

        # settings stored with the run (pickled dict: only load trusted result files)
        exp_dict = np.load(fn_base + '_exp_dict.npz', allow_pickle=True)['arr_0'].tolist()
        N,J,K,T = exp_dict['N'],exp_dict['J'],exp_dict['K'],exp_dict['T']
        init_rb_bandwidth = exp_dict['init_rb_bandwidth']

        ls_train = np.load(fn_base + '_loss_train.npy')
        ls_test = np.load(fn_base + '_loss_test.npy')

        xjs = [data[:,col] for col in range(J)]
        pxjs = RPMEmpiricalMarginals(xjs)
        observations = (torch.stack(xjs, dim=-1),)

        model = init_gaussian_rpm(N, J, K, T, pxjs,
                                init_rb_bandwidth, obs_locs,
                                rpm_variant, temporal, amortize_ivi,
                                epochs=0, batch_size=N,
                                n_hidden=33
                               )[0]
        model.load_state_dict(torch.load(fn_base + '_rpm_state_dict'))

        # posterior natural parameters -> mean parameters E_q[t(Z)]
        prior = model.joint_model[1]
        eta_0 = prior.nat_param
        if rpm_variant == 'amortized':
            eta_q, _ = model.comp_eta_q(xjs, eta_0)
        else:
            eta_q = model.comp_eta_q(xjs, eta_0=eta_0, idx_data=np.arange(N))
        EqtZ = prior.log_partition.nat2meanparam(eta_q)

        mu = EqtZ[:,:T]
        sig2 = torch.diagonal(EqtZ[:,T:].reshape(-1,T,T),dim1=-2,dim2=-1) - mu**2

        latent_true, latent_mean_fit, latent_variance_fit, R2 = regLatent(
            latent_true = true_latent_ext,
            latent_mean_fit = mu.unsqueeze(-1),
            latent_variance_fit = sig2)

        n = 0  # index of the training sequence kept for display
        # slice-assign so runs with fewer recorded steps do not raise a shape error
        n33_losses_train[r, s, :len(ls_train)] = ls_train
        n33_losses_test[r, s, :len(ls_test)] = ls_test
        n33_latents_train[r,s,i,j] = latent_mean_fit.detach().numpy()[n,:,0]
        n33_true_latents_train[s] = true_latent_ext[n,:,0]


In [None]:
losses = losses_train
filter_len = 1  # moving-average window length (1 = raw losses)

# Every panel (r, s) reads losses[r, s] and n33_losses_train[r, s]. The cells
# loading the 33-hidden runs redefine `inits` / `model_seeds`, so check that
# all of them still agree instead of failing with an IndexError mid-figure.
panel_grid = (len(inits), len(model_seeds))
if losses.shape[:2] != panel_grid or n33_losses_train.shape[:2] != panel_grid:
    raise ValueError('loss arrays do not match the current `inits` / `model_seeds`; '
                     're-run the matching loading cells')

plt.figure(figsize=(16,12))

# shared y-range over the first 100 steps of both main and 33-hidden runs
l_min = np.nanmin(losses[...,:100])
l_max = np.nanmax(losses[...,:100])
l_min = np.minimum(np.nanmin(n33_losses_train[...,:100]), l_min)
l_max = np.maximum(np.nanmax(n33_losses_train[...,:100]), l_max)

# The 33-hidden runs are the 'amortized' / 'use_q' variant. Look up its colour
# slot explicitly instead of reusing whatever i, j the loops last stopped at.
i33 = rpm_variants.index('amortized')
j33 = amortize_ivis.index('use_q')
kernel = np.ones(filter_len)/filter_len

for r in range(len(inits)):
    for s in range(len(model_seeds)):
        ax = plt.subplot(len(model_seeds), len(inits), s*len(inits)+r+1)
        for i in range(len(rpm_variants)):
            for j in range(len(amortize_ivis)):
                loss_smoothed = np.convolve(losses[r,s,i,j], kernel, 'valid')
                plt.plot(loss_smoothed, color=colors[i,j], label=rpm_variants[i]+'_'+amortize_ivis[j])
        if r == 0:
            plt.ylabel('seed #' + str(s+1))
        if s == 0:
            plt.title(inits[r])
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        if s == len(model_seeds)-1:
            plt.xlabel('epochs')
        plt.ylim(l_min, l_max)
        plt.xlim(0, 100)
        loss_smoothed = np.convolve(n33_losses_train[r,s], kernel, 'valid')
        plt.plot(loss_smoothed, ':', color=colors[i33,j33], label=rpm_variants[i33]+'_'+amortize_ivis[j33] + '_33hidden')

plt.legend()
plt.suptitle('initial training losses, time-series model variant, Poisson balls, N='+str(N))
#plt.savefig('figs/poisson_balls_temporal_initial_training_losses_N='+str(N) + '_added33Hidden.pdf')

In [None]:
# Full training-loss curves, smoothed with a length-`filter_len` moving
# average: one panel per (initialisation, seed), one line per variant.
# The 33-hidden-unit overlay is left out of this figure.
losses = losses_train
filter_len = 50
window = np.ones(filter_len) / filter_len

plt.figure(figsize=(16,12))

l_min, l_max = np.nanmin(losses), np.nanmax(losses)
# l_min = np.minimum(np.nanmin(n33_losses_train), l_min)
# l_max = np.maximum(np.nanmax(n33_losses_train), l_max)

n_rows, n_cols = len(model_seeds), len(inits)
for r, init_name in enumerate(inits):
    for s in range(n_rows):
        ax = plt.subplot(n_rows, n_cols, s*n_cols + r + 1)
        for i, variant in enumerate(rpm_variants):
            for j, ivi in enumerate(amortize_ivis):
                loss_smoothed = np.convolve(losses[r, s, i, j], window, 'valid')
                plt.plot(loss_smoothed, color=colors[i, j], label=variant + '_' + ivi)
        if r == 0:
            plt.ylabel('seed #' + str(s+1))
        if s == 0:
            plt.title(init_name)
        for side in ('top', 'right'):
            ax.spines[side].set_visible(False)
        if s == n_rows - 1:
            plt.xlabel('epochs')
        plt.ylim(l_min, l_max)

plt.legend()
plt.suptitle('training losses (time-smoothed), time-series model variant, Poisson balls, N='+str(N))
#plt.savefig('figs/poisson_balls_temporal_full_training_losses_N='+str(N) + '_added33Hidden.pdf')

In [None]:
losses = losses_test
filter_len = 1  # moving-average window length (1 = raw losses)

# Every panel (r, s) reads losses[r, s] and n33_losses_test[r, s]. The cells
# loading the 33-hidden runs redefine `inits` / `model_seeds`, so check that
# all of them still agree instead of failing with an IndexError mid-figure.
panel_grid = (len(inits), len(model_seeds))
if losses.shape[:2] != panel_grid or n33_losses_test.shape[:2] != panel_grid:
    raise ValueError('loss arrays do not match the current `inits` / `model_seeds`; '
                     're-run the matching loading cells')

os.makedirs('figs', exist_ok=True)  # savefig below fails if 'figs/' is missing
plt.figure(figsize=(16,12))

# shared y-range over both main and 33-hidden runs
l_min = np.nanmin(losses)
l_max = np.nanmax(losses)
l_min = np.minimum(np.nanmin(n33_losses_test), l_min)
l_max = np.maximum(np.nanmax(n33_losses_test), l_max)

# The 33-hidden runs are the 'amortized' / 'use_q' variant. Look up its colour
# slot explicitly instead of reusing whatever i, j the loops last stopped at.
i33 = rpm_variants.index('amortized')
j33 = amortize_ivis.index('use_q')
kernel = np.ones(filter_len)/filter_len

for r in range(len(inits)):
    for s in range(len(model_seeds)):
        ax = plt.subplot(len(model_seeds), len(inits), s*len(inits)+r+1)
        for i in range(len(rpm_variants)):
            for j in range(len(amortize_ivis)):
                loss_smoothed = np.convolve(losses[r,s,i,j], kernel, 'valid')
                plt.plot(loss_smoothed, color=colors[i,j], label=rpm_variants[i]+'_'+amortize_ivis[j])
        if r == 0:
            plt.ylabel('seed #' + str(s+1))
        if s == 0:
            plt.title(inits[r])
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        if s == len(model_seeds)-1:
            plt.xlabel('epochs')
        plt.ylim(l_min, l_max)
        loss_smoothed = np.convolve(n33_losses_test[r,s], kernel, 'valid')
        plt.plot(loss_smoothed, ':', color=colors[i33,j33], label=rpm_variants[i33]+'_'+amortize_ivis[j33] + '_33hidden')

plt.legend()
plt.suptitle('test losses, Poisson balls, time-series model variant, N='+str(N))
plt.savefig('figs/poisson_balls_temporal_full_test_losses_N='+str(N) + '_added33Hidden.pdf')

In [None]:
plt.figure(figsize=(16,12))

# For every (initialisation r, seed s) there are two panels. The upper panel
# shows training sequence n with all fitted latent means drawn on top. The
# lower panel shows each variant's latent-mean MSE, normalised by the mean
# squared true latent. (The 33-hidden-unit bar is not shown.)
n = 0
n_variants, n_ivis = len(rpm_variants), len(amortize_ivis)
for r in range(len(inits)):
    for s in range(len(model_seeds)):
        ax = plt.subplot(2*len(model_seeds), len(inits), 2*s*len(inits)+r+1)
        # The data image is the same for every variant, so draw it once; it used
        # to be re-drawn n_variants * n_ivis times per panel.
        plt.imshow(data_train_show[s,n], aspect='auto', origin='lower', cmap='gray',extent=[0, 1, -1, 1])
        for i in range(n_variants):
            for j in range(n_ivis):
                plt.plot(obs_locs, latents_train[r,s,i,j,n],
                         color=colors[i,j], label=rpm_variants[i]+'_'+amortize_ivis[j],
                         linewidth=2.5)
        if r == 0:
            plt.ylabel('seed #' + str(s+1))
        if s == 0:
            plt.title(inits[r])
        plt.ylim(-1, 1)

        ax = plt.subplot(2*len(model_seeds), len(inits), (2*s+1)*len(inits)+r+1)
        norm_ = (true_latents_train[s]**2).mean() # this is using that we see exactly 2pi of the true latents!
        MSEs = np.zeros(n_variants*n_ivis)
        for i in range(n_variants):
            for j in range(n_ivis):
                MSEs[i*n_ivis+j] = ((latents_train[r,s,i,j]-true_latents_train[s])**2).mean()
        # colors.reshape(-1, 3) lists the colours in the same row-major
        # (variant, scheme) order as MSEs. One bar call replaces the old pair of
        # calls, whose first call drew a hidden zero-height 10th bar that the
        # second then over-painted.
        plt.bar(np.arange(n_variants*n_ivis)+1, MSEs/norm_, color=colors.reshape(-1, 3))
        plt.xticks([])
        plt.ylabel('nMSEs')

plt.suptitle('Latent means, time-series model variant, Poisson balls, N='+str(N))
#plt.savefig('figs/poisson_balls_temporal_training_latents_='+str(N) + '_added33Hidden.pdf')

# textured ball experiments

In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt

from utils_data_external import linear_regression_1D_latent as regLatent
from utils_data_external import plot_poisson_balls
import matplotlib.pyplot as plt
from utils_setup import init_gaussian_rpm

from rpm import RPMEmpiricalMarginals

N = 10
temporal = True

rpm_variants = ['VI', 'VAE', 'amortized']
amortize_ivis = ['none', 'full', 'use_q']
model_seeds = [0,1,2]

# result directories, one per initialisation scheme
# (alternatives used earlier: 'fits_noPretrain', 'fits_pretrainElbo',
#  'fits_pretrainAsQ', 'fits_reparam')
inits = ['fits']
obs_locs = torch.linspace(0,1,T).reshape(-1,1)

num_steps = 5000  # length of the recorded loss traces

# NaN-filled so that runs with fewer than num_steps recorded losses leave gaps
# instead of artificial zeros (the plotting cells use np.nanmin / np.nanmax).
losses_train = np.full((len(inits), len(model_seeds), len(rpm_variants), len(amortize_ivis), num_steps), np.nan)
losses_test = np.full((len(inits), len(model_seeds), len(rpm_variants), len(amortize_ivis), num_steps), np.nan)

# Latent means / true latents / data of the one displayed sequence (n = 0) per
# run; the assignments below broadcast that sequence over the whole N axis.
latents_train = np.zeros((len(inits), len(model_seeds), len(rpm_variants), len(amortize_ivis), N, T))
true_latents_train = np.zeros((len(model_seeds), N, T))
data_train_show  = np.zeros((len(model_seeds), N, J, T))

for s,model_seed in enumerate(model_seeds):
    for r,res_dir in enumerate(inits):
        for i,rpm_variant in enumerate(rpm_variants):
            for j,amortize_ivi in enumerate(amortize_ivis):

                # run identifier, e.g. 'VAE_temp_use_q_textured_N_10_seed_0'
                if temporal:
                    identifier = rpm_variant + '_temp_' + amortize_ivi
                else:
                    identifier = rpm_variant + '_' + amortize_ivi
                identifier = identifier + '_textured_N_' + str(N) + '_seed_' + str(model_seed)
                fn_base = os.path.join(res_dir, identifier, identifier)

                data = torch.tensor(np.load(fn_base + '_train_data.npy'))
                true_latent_ext = torch.tensor(np.load(fn_base + '_train_latents.npy'))

                # settings stored with the run (pickled dict: only load trusted result files)
                exp_dict = np.load(fn_base + '_exp_dict.npz', allow_pickle=True)['arr_0'].tolist()
                N,J,K,T = exp_dict['N'],exp_dict['J'],exp_dict['K'],exp_dict['T']
                init_rb_bandwidth = exp_dict['init_rb_bandwidth']

                ls_train = np.load(fn_base + '_loss_train.npy')
                ls_test = np.load(fn_base + '_loss_test.npy')

                xjs = [data[:,col] for col in range(J)]
                pxjs = RPMEmpiricalMarginals(xjs)
                observations = (torch.stack(xjs, dim=-1),)

                # rebuild the model with the run's settings (epochs=0), then load its weights
                model = init_gaussian_rpm(N, J, K, T, pxjs,
                                        init_rb_bandwidth, obs_locs,
                                        rpm_variant, temporal, amortize_ivi,
                                        epochs=0, batch_size=N
                                       )[0]
                model.load_state_dict(torch.load(fn_base + '_rpm_state_dict'))

                # posterior natural parameters -> mean parameters E_q[t(Z)]
                prior = model.joint_model[1]
                eta_0 = prior.nat_param
                if rpm_variant == 'amortized':
                    eta_q, _ = model.comp_eta_q(xjs, eta_0)
                else:
                    eta_q = model.comp_eta_q(xjs, eta_0=eta_0, idx_data=np.arange(N))
                EqtZ = prior.log_partition.nat2meanparam(eta_q)

                # first T entries: means; the remaining T*T entries are presumably
                # second moments, whose diagonal minus mu**2 gives the variances
                mu = EqtZ[:,:T]
                sig2 = torch.diagonal(EqtZ[:,T:].reshape(-1,T,T),dim1=-2,dim2=-1) - mu**2

                latent_true, latent_mean_fit, latent_variance_fit, R2 = regLatent(
                    latent_true = true_latent_ext,
                    latent_mean_fit = mu.unsqueeze(-1),
                    latent_variance_fit = sig2)

                #plot_poisson_balls(observations,
                #                   obs_locs=obs_locs.squeeze(-1),
                #                   latent_mean_fit=latent_mean_fit.squeeze(-1),
                #                   latent_variance_fit=latent_variance_fit)

                losses_train[r,s,i,j,:len(ls_train)] = ls_train
                losses_test[r,s,i,j,:len(ls_test)] = ls_test
                n = 0  # index of the training sequence kept for display
                latents_train[r,s,i,j] = latent_mean_fit.detach().numpy()[n,:,0]
                data_train_show[s] = data[n]
                true_latents_train[s] = true_latent_ext[n,:,0]


In [None]:

# Side-by-side example inputs: a Poisson bouncing ball (left) and a textured
# bouncing ball (right), each drawn as a (space j) x (time t) image.
# NOTE(review): `poisson_ball` is not defined in any of the notebook cells
# visible here, so a fresh top-to-bottom run raises NameError on the
# imshow below. It presumably holds one Poisson-ball sequence kept from an
# earlier session; confirm and define it explicitly.
plt.figure(figsize=(12,8))
plt.subplot(1,2,1)
plt.imshow(poisson_ball, cmap='gray')
# hard-coded tick positions/labels presume 50 time steps and 10 spatial bins;
# imshow's default origin puts row 0 at the top, which is labelled j = 10
plt.xticks([0, 24, 49], [1, 25, 50])
plt.yticks([0,9], [10,1])
plt.xlabel('time t')
plt.ylabel('space j')
plt.title('Poisson bouncing ball')

plt.subplot(1,2,2)
# Seed index 1, N-axis index 1, of whichever loading cell last filled
# data_train_show. The loading loops broadcast sequence n = 0 over the whole
# N axis, so this shows that run's first training sequence.
plt.imshow(data_train_show[1,1], cmap='gray')
plt.xticks([0, 24, 49], [1, 25, 50])
plt.yticks([0,9], [10,1])
plt.xlabel('time t')
plt.title('Textured bouncing ball')

#plt.savefig('figs/bouncing_balls_examples.pdf', bbox_inches ='tight')

In [None]:

N = 10
temporal = True

# runs of the 'amortized' / 'use_q' variant with 33 hidden units,
# for three seeds and three initialisation schemes
model_seeds = [0,1,2]
rpm_variant = 'amortized'
amortize_ivi = 'use_q'
inits = ['fits_noPretrain', 'fits_pretrainElbo', 'fits_pretrainAsQ']

# Slot of this variant in the rpm_variants x amortize_ivis grid. The indexing
# below used to rely silently on the loop indices i, j leaked from the
# loading cell.
i = rpm_variants.index(rpm_variant)
j = amortize_ivis.index(amortize_ivi)

# NaN-padded so that runs shorter than 5000 steps leave gaps rather than zeros
# (the plots take np.nanmin / np.nanmax over these arrays)
n33_losses_train = np.full((len(inits), len(model_seeds), 5000), np.nan)
n33_losses_test = np.full((len(inits), len(model_seeds), 5000), np.nan)
n33_latents_train = np.zeros((len(inits), len(model_seeds), len(rpm_variants), len(amortize_ivis), N, T))
n33_true_latents_train = np.zeros((len(model_seeds), N, T))

for s,model_seed in enumerate(model_seeds):
    for r,res_dir in enumerate(inits):

        # NOTE(review): unlike the textured loading cell, neither name below
        # carries a '_textured' tag. These paths appear to be the ones the
        # Poisson-ball 33-hidden cell loads, so the "textured" plots may be
        # overlaying Poisson runs. Confirm where the textured 33-hidden runs live.
        if temporal:
            identifier_folder = rpm_variant + '_temp_33Hidden_' + amortize_ivi + '_N'
        else:
            identifier_folder = rpm_variant + '_33Hidden_' + amortize_ivi + '_N'
        identifier_folder = identifier_folder + '_' + str(N) + '_seed_' + str(model_seed)
        if temporal:
            identifier = rpm_variant + '_temp_' + amortize_ivi
        else:
            identifier = rpm_variant + '_' + amortize_ivi
        identifier = identifier + '_N_' + str(N) + '_seed_' + str(model_seed)

        fn_base = os.path.join(res_dir, identifier_folder, identifier)

        data = torch.tensor(np.load(fn_base + '_train_data.npy'))
        true_latent_ext = torch.tensor(np.load(fn_base + '_train_latents.npy'))

        # settings stored with the run (pickled dict: only load trusted result files)
        exp_dict = np.load(fn_base + '_exp_dict.npz', allow_pickle=True)['arr_0'].tolist()
        N,J,K,T = exp_dict['N'],exp_dict['J'],exp_dict['K'],exp_dict['T']
        init_rb_bandwidth = exp_dict['init_rb_bandwidth']

        ls_train = np.load(fn_base + '_loss_train.npy')
        ls_test = np.load(fn_base + '_loss_test.npy')

        xjs = [data[:,col] for col in range(J)]
        pxjs = RPMEmpiricalMarginals(xjs)
        observations = (torch.stack(xjs, dim=-1),)

        model = init_gaussian_rpm(N, J, K, T, pxjs,
                                init_rb_bandwidth, obs_locs,
                                rpm_variant, temporal, amortize_ivi,
                                epochs=0, batch_size=N,
                                n_hidden=33
                               )[0]
        model.load_state_dict(torch.load(fn_base + '_rpm_state_dict'))

        # posterior natural parameters -> mean parameters E_q[t(Z)]
        prior = model.joint_model[1]
        eta_0 = prior.nat_param
        if rpm_variant == 'amortized':
            eta_q, _ = model.comp_eta_q(xjs, eta_0)
        else:
            eta_q = model.comp_eta_q(xjs, eta_0=eta_0, idx_data=np.arange(N))
        EqtZ = prior.log_partition.nat2meanparam(eta_q)

        mu = EqtZ[:,:T]
        sig2 = torch.diagonal(EqtZ[:,T:].reshape(-1,T,T),dim1=-2,dim2=-1) - mu**2

        latent_true, latent_mean_fit, latent_variance_fit, R2 = regLatent(
            latent_true = true_latent_ext,
            latent_mean_fit = mu.unsqueeze(-1),
            latent_variance_fit = sig2)

        n = 0  # index of the training sequence kept for display
        # slice-assign so runs with fewer recorded steps do not raise a shape error
        n33_losses_train[r, s, :len(ls_train)] = ls_train
        n33_losses_test[r, s, :len(ls_test)] = ls_test
        n33_latents_train[r,s,i,j] = latent_mean_fit.detach().numpy()[n,:,0]
        n33_true_latents_train[s] = true_latent_ext[n,:,0]


In [None]:
losses = losses_train
filter_len = 1  # moving-average window length (1 = raw losses)

# Every panel (r, s) reads losses[r, s] and n33_losses_train[r, s]. The cells
# loading the 33-hidden runs redefine `inits` / `model_seeds`, so check that
# all of them still agree instead of failing with an IndexError mid-figure.
panel_grid = (len(inits), len(model_seeds))
if losses.shape[:2] != panel_grid or n33_losses_train.shape[:2] != panel_grid:
    raise ValueError('loss arrays do not match the current `inits` / `model_seeds`; '
                     're-run the matching loading cells')

plt.figure(figsize=(16,12))

# shared y-range over the first 100 steps of both main and 33-hidden runs
l_min = np.nanmin(losses[...,:100])
l_max = np.nanmax(losses[...,:100])
l_min = np.minimum(np.nanmin(n33_losses_train[...,:100]), l_min)
l_max = np.maximum(np.nanmax(n33_losses_train[...,:100]), l_max)

# The 33-hidden runs are the 'amortized' / 'use_q' variant. Look up its colour
# slot explicitly instead of reusing whatever i, j the loops last stopped at.
i33 = rpm_variants.index('amortized')
j33 = amortize_ivis.index('use_q')
kernel = np.ones(filter_len)/filter_len

for r in range(len(inits)):
    for s in range(len(model_seeds)):
        ax = plt.subplot(len(model_seeds), len(inits), s*len(inits)+r+1)
        for i in range(len(rpm_variants)):
            for j in range(len(amortize_ivis)):
                loss_smoothed = np.convolve(losses[r,s,i,j], kernel, 'valid')
                plt.plot(loss_smoothed, color=colors[i,j], label=rpm_variants[i]+'_'+amortize_ivis[j])
        if r == 0:
            plt.ylabel('seed #' + str(s+1))
        if s == 0:
            plt.title(inits[r])
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        if s == len(model_seeds)-1:
            plt.xlabel('epochs')
        plt.ylim(l_min, l_max)
        plt.xlim(0, 100)
        loss_smoothed = np.convolve(n33_losses_train[r,s], kernel, 'valid')
        plt.plot(loss_smoothed, ':', color=colors[i33,j33], label=rpm_variants[i33]+'_'+amortize_ivis[j33] + '_33hidden')
plt.legend()
plt.suptitle('initial training losses, time-series model variant, textured balls, N='+str(N))
#plt.savefig('figs/textured_balls_temporal_initial_training_losses_N='+str(N) + '_added33Hidden.pdf')

In [None]:
losses = losses_train
filter_len = 50  # moving-average window length

# every panel (r, s) reads losses[r, s]; fail clearly if `inits` /
# `model_seeds` were redefined by another cell and no longer match
if losses.shape[:2] != (len(inits), len(model_seeds)):
    raise ValueError('loss arrays do not match the current `inits` / `model_seeds`; '
                     're-run the matching loading cells')

plt.figure(figsize=(16,12))

l_min = np.nanmin(losses)
l_max = np.nanmax(losses)
#l_min = np.minimum(np.nanmin(n33_losses_train), l_min)
#l_max = np.maximum(np.nanmax(n33_losses_train), l_max)
kernel = np.ones(filter_len)/filter_len

for r in range(len(inits)):
    for s in range(len(model_seeds)):
        ax = plt.subplot(len(model_seeds), len(inits), s*len(inits)+r+1)
        for i in range(len(rpm_variants)):
            for j in range(len(amortize_ivis)):
                loss_smoothed = np.convolve(losses[r,s,i,j], kernel, 'valid')
                plt.plot(loss_smoothed, color=colors[i,j], label=rpm_variants[i]+'_'+amortize_ivis[j])
        if r == 0:
            plt.ylabel('seed #' + str(s+1))
        if s == 0:
            plt.title(inits[r])
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        if s == len(model_seeds)-1:
            plt.xlabel('epochs')
        # All-NaN losses make l_min / l_max NaN, which plt.ylim rejects. Keep
        # autoscaling in that case; this replaces a bare `except: pass`.
        if np.isfinite(l_min) and np.isfinite(l_max):
            plt.ylim(l_min, l_max)
        # The 33-hidden overlay is disabled for this figure. The plot call used
        # to stay active and re-drew the last main-run curve labelled '_33hidden'.
        #loss_smoothed = np.convolve(n33_losses_train[r,s], kernel, 'valid')
        #plt.plot(loss_smoothed, ':', color=colors[i,j], label=rpm_variants[i]+'_'+amortize_ivis[j] + '_33hidden')
plt.legend()
plt.suptitle('training losses (time-smoothed), time-series model variant, textured balls, N='+str(N))
#plt.savefig('figs/textured_balls_temporal_full_training_losses_N='+str(N) + '_added33Hidden.pdf')

In [None]:
# Per-epoch test losses (filter_len = 1 means no smoothing): one panel per
# (init, seed), one solid curve per (rpm_variant, amortize_ivi) combination,
# plus the 33-hidden-unit comparison run as a dotted line. Saved to figs/.
losses = losses_test
filter_len = 1

plt.figure(figsize=(16,12))

# shared y-range across panels, including the 33-hidden run
l_min = np.nanmin(losses)
l_max = np.nanmax(losses)
l_min = np.minimum(np.nanmin(n33_losses_test), l_min)
l_max = np.maximum(np.nanmax(n33_losses_test), l_max)

kernel = np.ones(filter_len) / filter_len
# The 33-hidden run is drawn in the colour/label of the last
# (rpm_variant, amortize_ivi) pair; previously this relied on i, j leaking
# out of the inner loops with exactly these values.
i_last, j_last = len(rpm_variants)-1, len(amortize_ivis)-1
n33_label = rpm_variants[i_last]+'_'+amortize_ivis[j_last] + '_33hidden'

for r in range(len(inits)):
    for s in range(len(model_seeds)):
        ax = plt.subplot(len(model_seeds), len(inits), s*len(inits)+r+1)
        for i in range(len(rpm_variants)):
            for j in range(len(amortize_ivis)):
                loss_smoothed = np.convolve(losses[r,s,i,j], kernel, 'valid')
                plt.plot(loss_smoothed, color=colors[i,j], label=rpm_variants[i]+'_'+amortize_ivis[j])
        if r == 0:
            plt.ylabel('seed #' + str(s+1))
        if s == 0:
            plt.title(inits[r])
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        if s == len(model_seeds)-1:
            plt.xlabel('epochs')
        # matplotlib rejects NaN/Inf limits (all-NaN losses); autoscale then
        if np.isfinite(l_min) and np.isfinite(l_max):
            plt.ylim(l_min, l_max)
        n33_smoothed = np.convolve(n33_losses_test[r,s], kernel, 'valid')
        plt.plot(n33_smoothed, ':', color=colors[i_last,j_last], label=n33_label)
plt.legend()
plt.suptitle('test losses, textured balls, time-series model variant, N='+str(N))
plt.savefig('figs/textured_balls_temporal_full_test_losses_N='+str(N) + '_added33Hidden.pdf')

In [None]:
# For each (init, seed): upper panel overlays the inferred latent means of
# training sequence n on the displayed training image; lower panel shows the
# normalised MSE of every variant's latents (plus the 33-hidden run) against
# the true latents.
plt.figure(figsize=(16,12))

n = 0  # index of the training sequence shown in the upper panels
n_variants = len(rpm_variants)*len(amortize_ivis)
# the 33-hidden run is compared in the slot of the last (rpm_variant, amortize_ivi) pair
i_last, j_last = len(rpm_variants)-1, len(amortize_ivis)-1
# one bar colour per variant, in the same i-major order as the MSEs below
bar_colors = np.concatenate([colors[k] for k in range(colors.shape[0])], axis=0)

for r in range(len(inits)):
    for s in range(len(model_seeds)):
        ax = plt.subplot(2*len(model_seeds), len(inits), 2*s*len(inits)+r+1)
        # the background image is the same for every variant: draw it once
        plt.imshow(data_train_show[s,n], aspect='auto', origin='lower', cmap='gray',extent=[0, 1, -1, 1])
        for i in range(len(rpm_variants)):
            for j in range(len(amortize_ivis)):
                plt.plot(obs_locs, latents_train[r,s,i,j,n],
                         color=colors[i,j], label=rpm_variants[i]+'_'+amortize_ivis[j],
                         linewidth=2.5)
        if r == 0:
            plt.ylabel('seed #' + str(s+1))
        plt.plot(obs_locs, n33_latents_train[r,s,i_last,j_last,n],
                 color=colors[i_last,j_last], linestyle=':',
                 label=rpm_variants[i_last]+'_'+amortize_ivis[j_last] + '_33hidden',
                 linewidth=2.5)
        if s == 0:
            plt.title(inits[r])
        plt.ylim(-1, 1)

        ax = plt.subplot(2*len(model_seeds), len(inits), (2*s+1)*len(inits)+r+1)
        MSEs = np.zeros(n_variants+1)
        norm_ = (true_latents_train[s]**2).mean() # this is using that we see exactly 2pi of the true latents!
        for i in range(len(rpm_variants)):
            for j in range(len(amortize_ivis)):
                MSEs[i*len(amortize_ivis)+j] = ((latents_train[r,s,i,j]-true_latents_train[s])**2).mean()
        MSEs[-1] = ((n33_latents_train[r,s,i_last,j_last]-n33_true_latents_train[s])**2).mean()
        # regular variants as filled bars; the 33-hidden run as a separate
        # hollow bar with a dotted edge so it stands apart
        plt.bar(np.arange(n_variants)+1,
                MSEs[:-1]/norm_,
                color=bar_colors
               )
        plt.bar(n_variants+1,
                MSEs[-1]/norm_,
                color='w',
                edgecolor=colors[-1,-1],
                linestyle=':'
               )
        plt.xticks([])
        plt.ylabel('nMSEs')

plt.suptitle('Latent means, time-series model variant, textured balls, N='+str(N))
plt.savefig('figs/textured_balls_temporal_training_latents_N='+str(N) + '_added33Hidden.pdf')

In [None]:
# Compare the averaged factor densities F_j(Z) = mean_n f_j(Z | x_j^(n))
# against the prior marginal p(Z), evaluated on a 1-D grid of Z values for
# the current `model`.
i,j = 2,2  # only used to pick the plotting colour colors[i,j]

plt.figure(figsize=(16,12))

prior = model.joint_model[1]
full2diag_gaussian = prior.log_partition.full2diag_gaussian
eta_0 = prior.nat_param
eta_0_diag, eta_0_uncorr = full2diag_gaussian(eta_0)
etajs_all = prior.log_partition.extract_diagonal(model.factorNatParams(eta_off=eta_0_uncorr)) # N-by-J-by-D
etajs_all = etajs_all.reshape(*etajs_all.shape[:-1],2,T).transpose(-1,-2)                    # N-by-J-by-T-by-2
marginal_log_partition = prior.log_partition.marginal_log_partition
phijs_all = marginal_log_partition(etajs_all)                                                # N-by-J-by-T

phi0 = marginal_log_partition(eta_0_diag)
Z = torch.linspace(-5, 5, 200)
tZ = torch.stack([Z, Z**2], axis=1)  # sufficient statistics t(Z) = (Z, Z^2), grid-by-2
pZ = torch.exp((eta_0_diag[0].unsqueeze(1) * tZ.unsqueeze(0)).sum(axis=-1) - phi0.T)/ np.sqrt(2*np.pi)

for jj in range(J):
    etaj_all = etajs_all[:,jj]   # N-by-T-by-2
    phij_all = phijs_all[:,jj]   # N-by-T
    # broadcast to grid-by-N-by-T, then average the densities over data points
    fj = torch.exp((etaj_all.unsqueeze(0) * tZ.unsqueeze(1).unsqueeze(1)).sum(axis=-1) - phij_all.unsqueeze(0))
    Fj = fj.mean(axis=1) / np.sqrt(2*np.pi)
    # label with the loop variable jj (previously used the fixed colour index j)
    plt.plot(Z.detach().numpy(), Fj.detach().numpy(), label='Fj(Z), j='+str(jj+1), color=colors[i,j])
plt.plot(Z.detach().numpy(), pZ.detach().numpy().T, label='prior p(Z)', color='cyan')
plt.xlabel('Z')
plt.ylabel('density')
plt.title('Z-marginals')
# Each plot call above draws one line per column (time step) under a shared
# label; deduplicate so the legend lists each factor / the prior once.
handles, labels = plt.gca().get_legend_handles_labels()
unique_entries = dict(zip(labels, handles))
plt.legend(list(unique_entries.values()), list(unique_entries.keys()))

plt.show()