In [None]:
%matplotlib inline
import sys
sys.path.append("../")
sys.path.append('/home/hao/Research/probtorch/')
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from plots import *
from utils import *
from objectives import *
from torch.distributions.normal import Normal
from torch.distributions.one_hot_categorical import OneHotCategorical as cat
import time
import probtorch
print('probtorch:', probtorch.__version__, 
      'torch:', torch.__version__, 
      'cuda:', torch.cuda.is_available())

In [None]:
## Problem dimensions
N = 600  # points per dataset (the "N" axis in the S * B * N shape comments below)
K = 3    # number of clusters
D = 2    # dimensionality of each observation

## Model Parameters
MCMC_SIZE = 10
SAMPLE_SIZE = 10
NUM_HIDDEN1 = 8
NUM_HIDDEN2 = 16
STAT_SIZE = 8
NUM_LATENTS =  D
## Training Parameters
SAMPLE_DIM = 0
BATCH_DIM = 1
BATCH_SIZE = 20
NUM_EPOCHS = 10000
LEARNING_RATE = 1e-2
CUDA = torch.cuda.is_available()
PATH = 'ag-sis-mu-adapt'

# Device every model and batch is moved to. The second GPU is preferred (as
# before); on hosts with a single GPU, or none, fall back to something that
# exists instead of failing later on the first `.to(gpu2)`.
if torch.cuda.device_count() > 1:
    gpu2 = torch.device('cuda:1')
elif CUDA:
    gpu2 = torch.device('cuda:0')
else:
    gpu2 = torch.device('cpu')

In [None]:
# Read the three arrays of the rings dataset from disk and convert each one
# to a float32 tensor. NOTE(review): presumably observations, one-hot cluster
# assignments and ground-truth cluster centres, in that order -- confirm
# against whatever generated the .npy files.
Xs, STATES, OBS_MU = (
    torch.from_numpy(np.load(npy_path)).float()
    for npy_path in ('rings_dataset/obs.npy',
                     'rings_dataset/states.npy',
                     'rings_dataset/obs_mu.npy')
)
# Number of datasets, and how many full mini-batches fit into one epoch
# (a trailing partial batch is dropped by the truncation).
NUM_SEQS = Xs.shape[0]
NUM_BATCHES = int(NUM_SEQS / BATCH_SIZE)

In [None]:
class Enc_mu(nn.Module):
    """Proposal network for the cluster means.

    Every point is embedded from its (observation, assignment) pair, the
    embeddings are averaged within each cluster, and the per-cluster average,
    concatenated with the prior mean, is mapped to the location and scale of
    a diagonal Normal proposal over that cluster's mean.

    Args:
        num_obs: unused; kept so existing callers keep working.
        num_hidden1: hidden width of the per-point embedding network.
        num_stats: size of the per-point embedding (the pooled statistic).
        num_latents: unused; kept so existing callers keep working.

    The layer sizes also depend on the module-level constants K and D.
    """
    def __init__(self, num_obs=D,
                       num_hidden1=NUM_HIDDEN1,
                       num_stats=STAT_SIZE,
                       num_latents=NUM_LATENTS):
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as this class is subclassed.
        super(Enc_mu, self).__init__()

        self.neural_stats = nn.Sequential(
            nn.Linear(K+D, num_hidden1),
            nn.Tanh(),
            nn.Linear(num_hidden1, num_stats),
            nn.Tanh())

        num_mid = int(0.5 * num_stats + 0.5 * D)
        self.mean_mu = nn.Sequential(
            nn.Linear(num_stats+D, num_mid),
            nn.Tanh(),
            nn.Linear(num_mid, D))

        self.mean_log_sigma = nn.Sequential(
            nn.Linear(num_stats+D, num_mid),
            nn.Tanh(),
            nn.Linear(num_mid, D))

        # Fixed Normal(0, 4) prior over each cluster mean, K * D.
        self.prior_mu = torch.zeros((K, D))
        self.prior_sigma = torch.ones((K, D)) * 4.0
        if CUDA:
            # Move straight to the target device (no intermediate copy on
            # the default GPU).
            self.prior_mu = self.prior_mu.to(gpu2)
            self.prior_sigma = self.prior_sigma.to(gpu2)

    def forward(self, obs, states, sample_size, batch_size):
        """Propose cluster means given observations and assignments.

        Args:
            obs: observations, S * B * N * D.
            states: cluster assignments, S * B * N * K (summed over N to get
                cluster sizes, so one-hot rows are assumed).
            sample_size: S, used to broadcast the prior mean.
            batch_size: B, used to broadcast the prior mean.

        Returns:
            (q, p): probtorch traces that both hold a node named 'means'
            with the same sampled value (S * B * K * D), scored under the
            proposal (q) and under the prior (p).
        """
        neural_stats = self.neural_stats(torch.cat((obs, states), -1)) ## S * B * N * STAT_SIZE
        cluster_size = states.sum(-2)
        cluster_size[cluster_size == 0.0] = 1.0 # S * B * K; avoids 0/0 for empty clusters
        # Per-cluster sum of embeddings: (S*B*K*N) @ (S*B*N*STAT) -> S*B*K*STAT.
        # Same result as expanding both operands to S*B*N*K*STAT and summing
        # over N, without materialising that tensor.
        sum_stats = torch.matmul(states.transpose(-1, -2), neural_stats)
        mean_stats = sum_stats / cluster_size.unsqueeze(-1)

        # Handle all clusters in one pass instead of one hard-coded branch
        # per cluster; the linear layers act on the last axis only.
        prior_mu = self.prior_mu.expand(sample_size, batch_size, -1, -1) ## S * B * K * D
        cluster_inputs = torch.cat((prior_mu, mean_stats), -1) ## S * B * K * (D + STAT_SIZE)
        q_mu = self.mean_mu(cluster_inputs)
        q_sigma = self.mean_log_sigma(cluster_inputs).exp()

        q = probtorch.Trace()
        p = probtorch.Trace()
        # .sample() (not .rsample()): the draw itself carries no gradient;
        # only its log-probability under q depends on the network weights.
        means = Normal(q_mu, q_sigma).sample()
        q.normal(q_mu,
                 q_sigma,
                 value=means,
                 name='means')
        p.normal(self.prior_mu,
                 self.prior_sigma,
                 value=q['means'],
                 name='means')
        return q, p

class Enc_z(nn.Module):
    """Proposal network for the cluster assignment of every point.

    Each (point, cluster) pair is scored by one shared MLP from the
    concatenation [observation, cluster mean, radius, noise sigma]; the
    scores are soft-maxed over clusters to give the assignment probabilities.

    Args:
        num_obs: input width of the scoring MLP (defaults to 2*D+1+D).
        num_hidden: width of the first hidden layer; later layers use half
            and a quarter of it.
        num_latents: unused; kept so existing callers keep working.
    """
    def __init__(self, num_obs=2*D+1+D,
                       num_hidden=NUM_HIDDEN2,
                       num_latents=K):
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as this class is subclassed.
        super(Enc_z, self).__init__()
        self.pi_prob = nn.Sequential(
            nn.Linear(num_obs, num_hidden),
            nn.Tanh(),
            nn.Linear(num_hidden, int(0.5*num_hidden)),
            nn.Tanh(),
            nn.Linear(int(0.5*num_hidden), int(0.25*num_hidden)),
            nn.Tanh(),
            nn.Linear(int(0.25*num_hidden), 1))

        # Uniform prior over the K clusters.
        self.prior_pi = torch.ones(K) * (1./ K)
        # Constant features fed to the scoring MLP (every entry is identical).
        self.radius = torch.ones((SAMPLE_SIZE, BATCH_SIZE, N, 1)) *  1.5
        self.noise_sigma = torch.ones((SAMPLE_SIZE, BATCH_SIZE, N, D)) *  0.05
        if CUDA:
            # Move straight to the target device (no intermediate copy on
            # the default GPU).
            self.prior_pi = self.prior_pi.to(gpu2)
            self.radius = self.radius.to(gpu2)
            self.noise_sigma = self.noise_sigma.to(gpu2)

    def forward(self, obs, obs_mu, sample_size, batch_size):
        """Propose assignments given observations and cluster means.

        Args:
            obs: observations, S * B * N * D.
            obs_mu: cluster means, S * B * K * D.
            sample_size: kept for interface compatibility; S is read from obs.
            batch_size: kept for interface compatibility; B is read from obs.

        Returns:
            (q, p): probtorch traces that both hold a node named 'zs' with
            the same sampled one-hot assignments (S * B * N * K), scored
            under the proposal (q) and under the uniform prior (p).
        """
        q = probtorch.Trace()
        lead_shape = tuple(obs.shape[:-1])   # (S, B, N)
        num_clusters = obs_mu.shape[-2]
        # Broadcast every feature to S * B * N * K * width. Sizes come from
        # the inputs, so any sample/batch/point/cluster count works, not just
        # the constants the module was built with.
        obs_e = obs.unsqueeze(-2).expand(*lead_shape, num_clusters, -1)
        mu_e = obs_mu.unsqueeze(-3).expand(*lead_shape, num_clusters, -1)
        # radius / noise_sigma are constant tensors, so a single row carries
        # all their information and can be broadcast to the needed shape.
        radius_e = self.radius[0, 0, 0].expand(*lead_shape, num_clusters, -1)
        sigma_e = self.noise_sigma[0, 0, 0].expand(*lead_shape, num_clusters, -1)

        data = torch.cat((obs_e, mu_e, radius_e, sigma_e), -1) ## S * B * N * K * (2D+1+D)
        scores = self.pi_prob(data).squeeze(-1) ## S * B * N * K

        z_pi = F.softmax(scores, -1)
        z = cat(z_pi).sample()
        _ = q.variable(cat, probs=z_pi, value=z, name='zs')
        p = probtorch.Trace()
        _ = p.variable(cat, probs=self.prior_pi, value=z, name='zs')
        return q, p
    
def initialize():
    """Build both encoders and the Adam optimizer.

    Returns:
        (enc_mu, enc_z, optimizer). When CUDA is available both encoders are
        moved to `gpu2`.
    """
    enc_mu = Enc_mu()
    enc_z = Enc_z()
    if CUDA:
        # Move straight to the target device; the former `.cuda()` hop first
        # staged all parameters on the default GPU.
        enc_mu.to(gpu2)
        enc_z.to(gpu2)
    # NOTE(review): only enc_mu's parameters are handed to the optimizer, so
    # enc_z keeps its initial weights for the whole run. Confirm this is
    # intended; otherwise add enc_z.parameters() here.
    optimizer = torch.optim.Adam(list(enc_mu.parameters()), lr=LEARNING_RATE, betas=(0.9, 0.99))
    return enc_mu, enc_z, optimizer

In [None]:
# Instantiate the two encoders and the optimizer used by the training loop.
enc_mu, enc_z, optimizer = initialize()

In [None]:
def Log_likelihood(obs, states, obs_mu, K, D, radius, noise_sigma, gpu, cluster_flag=False):
    """Log-likelihood of points lying on a ring around their assigned centre.

    For every point, the distance to its assigned cluster centre minus
    `radius` is scored under Normal(0, noise_sigma).

    Args:
        obs: observations, S * B * N * D.
        states: cluster assignments, S * B * N * K (argmax over the last
            axis selects the cluster).
        obs_mu: cluster centres, S * B * K * D.
        K: number of clusters (used only when cluster_flag is True).
        D: observation dimensionality.
        radius: ring radius.
        noise_sigma: standard deviation of the radial noise.
        gpu: kept for backward compatibility; the distribution is now created
            on obs.device, so the function also runs without CUDA.
        cluster_flag: when True, sum the per-point values within each cluster.

    Returns:
        cluster_flag = False : S * B * N
        cluster_flag = True  : S * B * K
    """
    labels = states.argmax(-1)
    # Pick, for every point, the centre of its assigned cluster.
    labels_flat = labels.unsqueeze(-1).repeat(1, 1, 1, D)
    obs_mu_expand = torch.gather(obs_mu, 2, labels_flat)
    distance = ((obs - obs_mu_expand)**2).sum(-1).sqrt()
    perihelion = distance - radius
    device = obs.device
    obs_dist = Normal(torch.zeros(1, device=device), torch.ones(1, device=device) * noise_sigma)
    log_perihelion = obs_dist.log_prob(perihelion) # S * B * N
    if cluster_flag:
        log_perihelion = torch.stack(
            [((labels == k).float() * log_perihelion).sum(-1) for k in range(K)], -1) # S * B * K
    return log_perihelion

from torch.distributions.categorical import Categorical

def resample_mu(obs_mu, weights):
    """
    Resample the cluster means along the sample axis.

    obs_mu is S * B * K * D and weights is S * B. For every batch element,
    S ancestor indices are drawn from the categorical distribution given by
    that element's weights, and the corresponding K * D slices of obs_mu are
    selected. The result is again S * B * K * D.
    """
    num_samples, num_batch, _, _ = obs_mu.shape
    per_batch_probs = weights.transpose(0, 1)  ## B * S
    ancestor_idx = Categorical(per_batch_probs).sample((num_samples, ))  ## S * B
    # Pair every ancestor index with its own batch position; the trailing
    # K * D axes are carried along by the indexing.
    batch_idx = torch.arange(num_batch, device=obs_mu.device)
    return obs_mu[ancestor_idx, batch_idx]

def Eubo(enc_mu, enc_z, obs, N, K, D, mcmc_size, sample_size, batch_size):
    """Run `mcmc_size` sweeps over (assignments, means) and score each sweep.

    Sweep 0 draws the assignments from enc_z's prior; every later sweep first
    resamples the previous means according to the previous importance weights
    and then lets enc_z propose new assignments. In every sweep enc_mu
    proposes new means and the importance weights are recomputed.

    Args:
        enc_mu: mean encoder, called as enc_mu(obs, states, S, B) -> (q, p).
        enc_z: assignment encoder, called as enc_z(obs, obs_mu, S, B) -> (q, p);
            its `prior_pi` attribute defines the initial assignment distribution.
        obs: observations, S * B * N * D.
        N, K, D: number of points, clusters and observation dimensions.
        mcmc_size: number of sweeps.
        sample_size: S.
        batch_size: B.

    Returns:
        (eubo, elbo, ess): scalar tensors, each averaged over the sweeps.
        eubo is sum_s w_s * log w_s with detached normalised weights, elbo is
        the mean log weight, ess is 1 / sum_s w_s^2; all averaged over B.
    """
    # Keep the accumulators on the same device as the data; a bare .cuda()
    # would put them on the default GPU and fail on CPU-only hosts.
    device = obs.device
    eubos = torch.zeros(mcmc_size, device=device)
    elbos = torch.zeros(mcmc_size, device=device)
    esss = torch.zeros(mcmc_size, device=device)

    for m in range(mcmc_size):
        if m == 0:
            # Initialise z from the prior; proposal and prior coincide, so
            # their log-probabilities cancel in the weight.
            p_init_z = cat(enc_z.prior_pi)
            states = p_init_z.sample((sample_size, batch_size, N,))
            log_p_z = p_init_z.log_prob(states)## S * B * N
            log_q_z = p_init_z.log_prob(states)
        else:
            # obs_mu and weights are the values left by the previous sweep.
            obs_mu = resample_mu(obs_mu, weights)
            ## update z -- cluster assignments
            q_z, p_z = enc_z(obs, obs_mu, sample_size, batch_size)
            log_p_z = p_z['zs'].log_prob
            log_q_z = q_z['zs'].log_prob ## S * B * N
            states = q_z['zs'].value
        ## update mu -- cluster means
        q, p = enc_mu(obs, states, sample_size, batch_size)
        log_q_mu = q['means'].log_prob.sum(-1)
        log_p_mu = p['means'].log_prob.sum(-1) # S * B * K
        obs_mu = q['means'].value
        log_obs = Log_likelihood(obs, states, obs_mu, K, D, radius=1.5, noise_sigma = 0.05, gpu=device, cluster_flag=False)
        log_weights = log_p_mu.sum(-1) - log_q_mu.sum(-1) +  log_obs.sum(-1) + log_p_z.sum(-1) - log_q_z.sum(-1)
        # Self-normalised weights over the sample axis; detached so that the
        # gradient only flows through log_weights.
        weights = F.softmax(log_weights, 0).detach()
        eubos[m] = (weights * log_weights).sum(0).mean()
        elbos[m] = log_weights.mean()
        esss[m] = (1. / (weights**2).sum(0)).mean()
    return eubos.mean(), elbos.mean(), esss.mean()

In [None]:
# Per-epoch averages of the three training diagnostics.
EUBOs = []
ELBOs = []
ESSs = []

# Start a fresh log file with a header line; `with` guarantees the handle is
# closed even if a write fails.
log_path = 'results/log-' + PATH + '.txt'
with open(log_path, 'w+') as flog:
    flog.write('EUBO\tELBO\tESS\n')
for epoch in range(NUM_EPOCHS):
    time_start = time.time()
    indices = torch.randperm(NUM_SEQS)
    EUBO = 0.0
    ELBO = 0.0
    ESS = 0.0
    for step in range(NUM_BATCHES):
        optimizer.zero_grad()
        batch_indices = indices[step*BATCH_SIZE : (step+1)*BATCH_SIZE]
        obs = Xs[batch_indices]
        states = STATES[batch_indices]
        # Observations and states are shuffled together, then tiled along a
        # new leading sample axis.
        data = shuffler(torch.cat((obs, states), -1)).repeat(SAMPLE_SIZE, 1, 1, 1)
        if CUDA:
            data = data.to(gpu2)
        obs = data[:, :, :, :D]
        # The loaded states are not passed to Eubo, which samples its own
        # assignments.
        states = data[:, :, :, D:]
        eubo, elbo, ess = Eubo(enc_mu, enc_z, obs, N, K, D, MCMC_SIZE, SAMPLE_SIZE, BATCH_SIZE)
        ## gradient step
        eubo.backward()
        optimizer.step()
        EUBO += eubo.item()
        ELBO += elbo.item()
        ESS += ess.item()
    EUBOs.append(EUBO / NUM_BATCHES)
    ELBOs.append(ELBO / NUM_BATCHES)
    ESSs.append(ESS / NUM_BATCHES)
    # Append this epoch's averages; reopening per epoch keeps the file
    # up to date if the run is interrupted.
    with open(log_path, 'a+') as flog:
        print('%.3f\t%.3f\t%.3f'
                % (EUBO/NUM_BATCHES, ELBO/NUM_BATCHES, ESS/NUM_BATCHES), file=flog)
    time_end = time.time()
    print('epoch=%d, EUBO=%.3f, ELBO=%.3f, ESS=%.3f (%ds)'
            % (epoch, EUBO/NUM_BATCHES, ELBO/NUM_BATCHES, ESS/NUM_BATCHES,
               time_end - time_start))


In [None]:
def sample_single_batch(num_seqs, N, K, D, sample_size, batch_size):
    """Draw one random batch of datasets for inspection.

    Args:
        num_seqs: number of datasets to draw from.
        N, K: unused; kept so existing callers keep working.
        D: observation dimensionality; the first D columns of the shuffled
            data are returned as observations, the rest as states.
        sample_size: number of copies stacked along the new leading axis.
        batch_size: number of datasets in the batch.

    Returns:
        (obs, states, obs_mu_t): obs and states are S * B * N * ... tensors
        (moved to `gpu2` when CUDA is available); obs_mu_t holds the matching
        rows of OBS_MU and is left on the CPU.
    """
    batch_indices = torch.randperm(num_seqs)[:batch_size]
    obs = Xs[batch_indices]
    states = STATES[batch_indices]
    obs_mu_t = OBS_MU[batch_indices]
    # Shuffle observations and states together, then tile along a new
    # leading sample axis.
    data = shuffler(torch.cat((obs, states), -1)).repeat(sample_size, 1, 1, 1)
    if CUDA:
        data = data.to(gpu2)
    obs = data[:, :, :, :D]
    states = data[:, :, :, D:]
    return obs, states, obs_mu_t

In [None]:
# Draw one batch of 25 datasets and run the mean encoder on the states loaded
# from the dataset files (not on assignments proposed by enc_z). The batch
# size must match in both calls.
obs, states, obs_mu_t = sample_single_batch(NUM_SEQS, N, K, D, SAMPLE_SIZE, batch_size=25)
q, p = enc_mu(obs, states, SAMPLE_SIZE, batch_size=25)

In [None]:
def plot_samples(obs, states, obs_mu_t, q, batch_size):
    """Plot every dataset of a batch with the proposed cluster means.

    One subplot per dataset (5 per row): points coloured by assigned cluster,
    the reference centres from obs_mu_t as single markers, and a 2-std
    ellipse for the proposal over each cluster mean. Only the first sample
    (index 0 of the leading axis) is shown.

    Args:
        obs: observations, S * B * N * D.
        states: assignments, S * B * N * K.
        obs_mu_t: reference cluster centres, B * K * D, on the CPU.
        q: trace whose 'means' node carries the proposal (`.dist.loc` and
            `.dist.scale`, each S * B * K * D).
        batch_size: number of datasets to draw.
    """
    colors = ['r', 'b', 'gold']
    num_cols = 5
    # Round up so batch sizes below 5, or not a multiple of 5, still get
    # enough rows.
    num_rows = (batch_size + num_cols - 1) // num_cols
    fig = plt.figure(figsize=(25,25))
    xs = obs[0].cpu()
    zs = states[0].cpu()
    mu_mu = q['means'].dist.loc[0].cpu().data.numpy()
    mu_sigma = q['means'].dist.scale[0].cpu().data.numpy()
    num_clusters = mu_mu.shape[1]
    for b in range(batch_size):
        ax = fig.add_subplot(num_rows, num_cols, b+1)
        x = xs[b].data.numpy()
        z = zs[b].data.numpy()
        mu_b_t = obs_mu_t[b].data.numpy()
        mu_mu_b = mu_mu[b]
        mu_sigma2_b = mu_sigma[b]**2
        assignments = z.argmax(-1)
        for k in range(num_clusters):
            # Cycle through the palette if there are more clusters than colours.
            color_k = colors[k % len(colors)]
            cov_k = np.diag(mu_sigma2_b[k])
            xk = x[assignments == k]
            ax.scatter(xk[:, 0], xk[:, 1], c=color_k)
            ax.scatter(mu_b_t[k, 0], mu_b_t[k, 1], c=color_k)
            plot_cov_ellipse(cov=cov_k, pos=mu_mu_b[k], nstd=2, ax=ax, alpha=0.2, color=color_k)
        ax.set_ylim([-5, 5])
        ax.set_xlim([-5, 5])
    # Uncomment to save the figure next to the training log:
#     plt.savefig('results/modes-' + PATH + '.svg')

In [None]:
# Visualise the batch drawn above; batch_size must match the value used when
# the batch was sampled.
plot_samples(obs, states, obs_mu_t, q, batch_size=25)