In [1]:
%matplotlib inline
import sys
sys.path.append("../")
# NOTE(review): absolute, machine-specific path — consider making this configurable.
sys.path.append('/home/hao/Research/probtorch/')
import numpy as np
# Import torch explicitly: `import torch.nn as nn` binds only `nn`, so without this
# line `torch` was only available by accident through one of the star imports below.
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
# Relative order of the star imports and the explicit imports below is kept on purpose:
# later explicit imports must win if a star import happens to export the same name.
from plots import *
from utils import *
from objectives import *
from torch.distributions.normal import Normal
from torch.distributions.one_hot_categorical import OneHotCategorical as cat
import time
import probtorch
# Record library versions and GPU availability alongside the notebook outputs.
print(f'probtorch: {probtorch.__version__} torch: {torch.__version__} cuda: {torch.cuda.is_available()}')

probtorch: 0.0+5a2c637 torch: 1.0.0 cuda: True


In [10]:
N = 300  # points per sequence
K = 3    # number of mixture components (rings)
D = 2    # dimensionality of each observed point

## Model Parameters
SAMPLE_SIZE = 10
NUM_HIDDEN_GLOBAL = 32
NUM_HIDDEN_LOCAL = 64
STAT_SIZE = 8
NUM_LATENTS =  D
## Training Parameters
SAMPLE_DIM = 0
BATCH_DIM = 1
BATCH_SIZE = 20
NUM_EPOCHS = 1000
LEARNING_RATE = 1e-3
CUDA = torch.cuda.is_available()
PATH = 'oneshot-nss'

# Seed every RNG used in the notebook so runs are reproducible.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Prefer the second GPU as before, but fall back to the first GPU on single-GPU
# machines and to the CPU when CUDA is unavailable (a hard-coded 'cuda:1' fails there).
if CUDA:
    gpu = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    gpu = torch.device('cpu')

In [11]:
Xs = torch.from_numpy(np.load('rings_varying_radius/obs.npy')).float()
# Downstream code tiles a batch to SAMPLE_SIZE * B * N * D, concatenates it with
# N-length noise tensors and feeds it to nn.Linear(D, ...), so fail early on bad data.
assert Xs.dim() == 3 and tuple(Xs.shape[1:]) == (N, D), \
    'expected obs of shape (NUM_SEQS, %d, %d), got %s' % (N, D, tuple(Xs.shape))
NUM_SEQS = Xs.shape[0]
NUM_BATCHES = NUM_SEQS // BATCH_SIZE  # a trailing partial batch is dropped
assert NUM_BATCHES > 0, 'dataset has fewer sequences than BATCH_SIZE'

In [None]:
class Enc_mu_rad(nn.Module):
    """Amortized proposal for the K component centres ('means') and radii ('radius').

    A per-point network embeds every observation; the embeddings are mean-pooled over
    the points of a sequence, concatenated with the (flattened) prior parameters, and
    fed to MLP heads producing Gaussian proposals for the centres (S * B * K * D) and
    the radii (S * B * K * 1).

    Args:
        K: number of components.
        D: dimensionality of each observed point.
        num_hidden: hidden width of every MLP in this module.
        num_stats: size of the pooled neural statistic.
        CUDA: if True, the prior tensors are moved to `device`.
        device: target device for the prior tensors when CUDA is True.
    """
    def __init__(self, K, D, num_hidden, num_stats, CUDA, device):
        super().__init__()

        self.neural_stats = nn.Sequential(
            nn.Linear(D, num_hidden),
            nn.Tanh(),
            nn.Linear(num_hidden, num_stats))

        # Centre heads consume [prior mean (K*D), prior sigma (K*D), pooled stats].
        self.mean_mu = nn.Sequential(
            nn.Linear(num_stats + 2*K*D, num_hidden),
            nn.Tanh(),
            nn.Linear(num_hidden, K*D))

        self.mean_log_sigma = nn.Sequential(
            nn.Linear(num_stats + 2*K*D, num_hidden),
            nn.Tanh(),
            nn.Linear(num_hidden, K*D))

        # Radius heads consume [prior mean (K), prior sigma (K), pooled stats].
        self.radius_mu = nn.Sequential(
            nn.Linear(num_stats + 2*K, num_hidden),
            nn.Tanh(),
            nn.Linear(num_hidden, K))

        self.radius_log_sigma = nn.Sequential(
            nn.Linear(num_stats + 2*K, num_hidden),
            nn.Tanh(),
            nn.Linear(num_hidden, K))

        # Flat prior parameters; reshaped per-component in forward().
        self.prior_mean_mu = torch.zeros(K*D)
        self.prior_mean_sigma = torch.ones(K*D) * 4.0

        self.prior_radius_mu = torch.ones(K) * 2.0
        self.prior_radius_sigma = torch.ones(K)
        if CUDA:
            self.prior_mean_mu = self.prior_mean_mu.to(device)
            self.prior_mean_sigma = self.prior_mean_sigma.to(device)

            self.prior_radius_mu = self.prior_radius_mu.to(device)
            self.prior_radius_sigma = self.prior_radius_sigma.to(device)

    def forward(self, obs, K, D, sample_size, batch_size):
        """Sample centres and radii; return (q, p) probtorch traces.

        obs is expected to be S * B * N * D. Both traces hold 'means' (S * B * K * D)
        and 'radius' (S * B * K * 1).
        """
        q = probtorch.Trace()
        p = probtorch.Trace()

        # Pool the per-point embedding over the N points: S * B * STAT_SIZE.
        mean_stats = self.neural_stats(obs).mean(-2)

        # Tile each flat prior vector to S * B * len and append the pooled statistic.
        stat_mu = torch.cat((self.prior_mean_mu.repeat(sample_size, batch_size, 1),
                             self.prior_mean_sigma.repeat(sample_size, batch_size, 1),
                             mean_stats), -1)
        stat_radius = torch.cat((self.prior_radius_mu.repeat(sample_size, batch_size, 1),
                                 self.prior_radius_sigma.repeat(sample_size, batch_size, 1),
                                 mean_stats), -1)

        q_mean_mu = self.mean_mu(stat_mu).view(sample_size, batch_size, K, D)
        q_mean_sigma = self.mean_log_sigma(stat_mu).exp().view(sample_size, batch_size, K, D)

        # Keep a trailing singleton dim so consumers can index obs_rad[:, :, k, :] and
        # concatenate the radius with D-dimensional features.
        q_radius_mu = self.radius_mu(stat_radius).view(sample_size, batch_size, K, 1)
        q_radius_sigma = self.radius_log_sigma(stat_radius).exp().view(sample_size, batch_size, K, 1)

        means = Normal(q_mean_mu, q_mean_sigma).sample()
        q.normal(q_mean_mu,
                 q_mean_sigma,
                 value=means,
                 name='means')
        # Reshape the flat priors so they broadcast elementwise against S * B * K * D.
        p.normal(self.prior_mean_mu.view(K, D),
                 self.prior_mean_sigma.view(K, D),
                 value=q['means'],
                 name='means')

        rads = Normal(q_radius_mu, q_radius_sigma).sample()
        q.normal(q_radius_mu,
                 q_radius_sigma,
                 value=rads,
                 name='radius')
        # Same for the radii: (K, 1) broadcasts against S * B * K * 1.
        p.normal(self.prior_radius_mu.view(K, 1),
                 self.prior_radius_sigma.view(K, 1),
                 value=q['radius'],
                 name='radius')

        return q, p

class Enc_z(nn.Module):
    """Amortized proposal over one-hot cluster assignments z_n.

    Every (point, component) pair is scored by a shared MLP on
    [x_n (D), centre_k (D), radius_k (1), noise sigma (1)]; the scores are
    softmax-normalized over components to give q(z_n).

    Args:
        K: number of components (used for the uniform prior).
        D: dimensionality of each observed point.
        num_hidden: hidden width of the scoring MLP.
        CUDA: if True, the prior is moved to `device`.
        device: target device for the prior when CUDA is True.
    """
    def __init__(self, K, D, num_hidden, CUDA, device):
        super().__init__()
        self.log_prob = nn.Sequential(
            nn.Linear(2*D+2, num_hidden),
            nn.Tanh(),
            nn.Linear(num_hidden, 1))

        self.prior_pi = torch.ones(K) * (1./ K)
        if CUDA:
            self.prior_pi = self.prior_pi.to(device)

    def _assignment_logits(self, obs, obs_mu, obs_rad, noise_sigma):
        """Unnormalized assignment scores of shape S * B * N * K.

        obs: S * B * N * D; obs_mu: S * B * K * D; obs_rad: S * B * K * 1.
        Works for any number of components K (taken from obs_mu).
        """
        S, B, N, D = obs.shape
        K = obs_mu.shape[2]
        # Pair every point with every component: each piece becomes S * B * N * K * (.)
        obs_pairs = obs.unsqueeze(-2).expand(S, B, N, K, D)
        mu_pairs = obs_mu.unsqueeze(2).expand(S, B, N, K, obs_mu.shape[-1])
        rad_pairs = obs_rad.unsqueeze(2).expand(S, B, N, K, obs_rad.shape[-1])
        # Built on obs's device so this also runs on CPU (no forced .cuda()).
        noise = obs.new_ones((S, B, N, K, 1)) * noise_sigma
        features = torch.cat((obs_pairs, mu_pairs, rad_pairs, noise), -1)
        return self.log_prob(features).squeeze(-1)

    def forward(self, obs, obs_mu, obs_rad, N, sample_size, batch_size, noise_sigma, device):
        """Sample assignments; return (q, p) probtorch traces holding 'zs' (S * B * N * K).

        N, sample_size, batch_size and device are kept for call compatibility; shapes
        and the device are taken from `obs`.
        """
        q = probtorch.Trace()
        p = probtorch.Trace()

        probs = self._assignment_logits(obs, obs_mu, obs_rad, noise_sigma)  # S * B * N * K
        q_pi = F.softmax(probs, -1)
        z = cat(q_pi).sample()

        _ = q.variable(cat, probs=q_pi, value=z, name='zs')
        _ = p.variable(cat, probs=self.prior_pi, value=z, name='zs')
        return q, p

def initialize(NUM_HIDDEN_GLOBAL, STAT_SIZE, NUM_HIDDEN_LOCAL, K, D, CUDA, DEVICE, LR):
    """Construct both encoders and one Adam optimizer over all of their parameters.

    Returns:
        (enc_mu_rad, enc_z, optimizer)
    """
    enc_mu_rad = Enc_mu_rad(K, D, num_hidden=NUM_HIDDEN_GLOBAL, num_stats=STAT_SIZE,
                            CUDA=CUDA, device=DEVICE)
    enc_z = Enc_z(K, D, num_hidden=NUM_HIDDEN_LOCAL, CUDA=CUDA, device=DEVICE)
    if CUDA:
        for module in (enc_mu_rad, enc_z):
            module.cuda().to(DEVICE)
    # Parameter order: local encoder first, then the global one.
    params = [*enc_z.parameters(), *enc_mu_rad.parameters()]
    optimizer = torch.optim.Adam(params, lr=LR, betas=(0.9, 0.99))
    return enc_mu_rad, enc_z, optimizer


In [14]:
# prior_pi = torch.ones(K) * (1./ K)
# if CUDA:
#     prior_pi = prior_pi.cuda().to(gpu)
    
# def enc_z(obs, obs_mu, radius, noise_sigma, sample_size, batch_size):
#     obs_mu_expand = obs_mu.unsqueeze(-2).repeat(1, 1, 1, N, 1) # S * B * K * N * D
#     obs_expand = obs.unsqueeze(2).repeat(1, 1, K, 1, 1) #  S * B * K * N * D
#     distance = ((obs_expand - obs_mu_expand)**2).sum(-1).sqrt()
#     perihelion = distance - radius #  S * B * K * N 
#     obs_dist = Normal(torch.zeros(1).cuda().to(gpu), torch.ones(1).cuda().to(gpu) * noise_sigma)
#     log_perihelion = obs_dist.log_prob(perihelion).transpose(-1, -2) # S * B * N * K   

#     q_pi = F.softmax(log_perihelion, -1)
#     q = probtorch.Trace()
#     p = probtorch.Trace()
#     z = cat(q_pi).sample()
#     _ = q.variable(cat, probs=q_pi, value=z, name='zs')
#     p = probtorch.Trace()
#     _ = p.variable(cat, probs=prior_pi, value=z, name='zs')
#     return q, p
def Eubo_oneshot_nss(enc_mu, enc_z, obs, N, K, D, sample_size, batch_size, gpu, noise_sigma=0.05):
    """Self-normalized importance-weighted objective for the one-shot encoders.

    Samples the global variables (centres, radii) from `enc_mu` and the assignments
    from `enc_z`, forms one joint log importance weight per (sample, batch) entry,
    and normalizes the weights over the sample dimension (dim 0).
    TODO: local (per-point) importance weights are not used yet.

    Args:
        enc_mu: global encoder called as enc_mu(obs, K, D, sample_size, batch_size).
        enc_z: local encoder called as enc_z(obs, mu, rad, N, S, B, noise_sigma, device).
        obs: observations, presumably S * B * N * D.
        gpu: device forwarded to enc_z and the likelihood.
        noise_sigma: observation noise scale; the default matches the 0.05 used by
            `test` in this notebook.

    Returns:
        eubo: sum over samples of detached weights * log weights, averaged over the
            batch (scalar); this is the training loss.
        elbo: mean log weight (scalar).
        ess: effective sample size 1 / sum(w^2), averaged over the batch (scalar).
    """
    def _sum_to_sb(log_prob):
        # Reduce every non-(sample, batch) dim, whatever the event shape of the variable.
        return log_prob.reshape(sample_size, batch_size, -1).sum(-1)

    q_eta, p_eta = enc_mu(obs, K, D, sample_size, batch_size)
    log_q_eta = _sum_to_sb(q_eta['means'].log_prob) + _sum_to_sb(q_eta['radius'].log_prob)
    log_p_eta = _sum_to_sb(p_eta['means'].log_prob) + _sum_to_sb(p_eta['radius'].log_prob)
    obs_mu = q_eta['means'].value
    obs_rad = q_eta['radius'].value
    ## update z -- cluster assignments
    q_z, p_z = enc_z(obs, obs_mu, obs_rad, N, sample_size, batch_size, noise_sigma, gpu)
    log_p_z = p_z['zs'].log_prob
    log_q_z = q_z['zs'].log_prob
    states = q_z['zs'].value
    log_obs_n = True_Log_likelihood(obs, states, obs_mu, obs_rad, K, D, noise_sigma, gpu, cluster_flag=False)
    # Joint log importance weight: S * B.
    log_weights = (_sum_to_sb(log_obs_n) + _sum_to_sb(log_p_z) - _sum_to_sb(log_q_z)
                   + log_p_eta - log_q_eta)
    # Self-normalized over samples; detached so gradients flow only through log_weights.
    weights = F.softmax(log_weights, 0).detach()
    eubo = (weights * log_weights).sum(0).mean()
    elbo = log_weights.mean()
    ess = (1. / (weights**2).sum(0)).mean()
    return eubo, elbo, ess

In [15]:
EUBOs = []
ELBOs = []
ESSs = []

log_path = 'results/log-' + PATH + '.txt'
# Start a fresh log with a header; `with` closes the handle even if the cell is interrupted.
with open(log_path, 'w+') as flog:
    flog.write('EUBO\tELBO\tESS\n')

# NOTE(review): `Eubo`, `enc_mu`, `enc_z` and `optimizer` are not defined in this notebook's
# visible cells — they come from a star import or leftover kernel state, so this cell is not
# reproducible on a fresh kernel. Presumably `enc_mu, enc_z, optimizer = initialize(...)` and
# `Eubo_oneshot_nss` were intended; confirm before switching.
for epoch in range(NUM_EPOCHS):
    time_start = time.time()
    indices = torch.randperm(NUM_SEQS)
    EUBO = 0.0
    ELBO = 0.0
    ESS = 0.0
    for step in range(NUM_BATCHES):
        optimizer.zero_grad()
        batch_indices = indices[step*BATCH_SIZE : (step+1)*BATCH_SIZE]
        # Shuffle the batch and tile it SAMPLE_SIZE times along a new leading sample dim.
        obs = shuffler(Xs[batch_indices]).repeat(SAMPLE_SIZE, 1, 1, 1)
        if CUDA:
            obs = obs.cuda().to(gpu)
        eubo, elbo, ess = Eubo(enc_mu, enc_z, obs, N, K, D, SAMPLE_SIZE, BATCH_SIZE, gpu)
        ## gradient step
        eubo.backward()
        optimizer.step()
        EUBO += eubo.item()
        ELBO += elbo.item()
        ESS += ess.item()
    # Per-epoch averages, computed once and reused for the history, the log and the console.
    eubo_avg = EUBO / NUM_BATCHES
    elbo_avg = ELBO / NUM_BATCHES
    ess_avg = ESS / NUM_BATCHES
    EUBOs.append(eubo_avg)
    ELBOs.append(elbo_avg)
    ESSs.append(ess_avg)
    with open(log_path, 'a+') as flog:
        print('%.3f\t%.3f\t%.3f' % (eubo_avg, elbo_avg, ess_avg), file=flog)
    time_end = time.time()
    print('epoch=%d, EUBO=%.3f, ELBO=%.3f, ESS=%.3f (%ds)'
          % (epoch, eubo_avg, elbo_avg, ess_avg, time_end - time_start))


epoch=0, EUBO=-54849.266, ELBO=-118258.983, ESS=1.000 (0s)
epoch=1, EUBO=-46707.918, ELBO=-109774.229, ESS=1.000 (1s)
epoch=2, EUBO=-42518.832, ELBO=-104844.678, ESS=1.000 (0s)
epoch=3, EUBO=-40042.711, ELBO=-99127.542, ESS=1.000 (0s)
epoch=4, EUBO=-38817.159, ELBO=-92934.404, ESS=1.000 (0s)
epoch=5, EUBO=-35854.855, ELBO=-85223.064, ESS=1.000 (0s)
epoch=6, EUBO=-34833.646, ELBO=-79987.647, ESS=1.000 (0s)
epoch=7, EUBO=-34162.907, ELBO=-75071.605, ESS=1.000 (0s)
epoch=8, EUBO=-33094.263, ELBO=-70988.469, ESS=1.000 (0s)
epoch=9, EUBO=-32928.943, ELBO=-68818.058, ESS=1.000 (0s)
epoch=10, EUBO=-31921.467, ELBO=-64841.663, ESS=1.000 (0s)
epoch=11, EUBO=-31421.557, ELBO=-61560.540, ESS=1.000 (0s)
epoch=12, EUBO=-30803.512, ELBO=-59034.454, ESS=1.000 (0s)
epoch=13, EUBO=-30782.096, ELBO=-58656.974, ESS=1.000 (0s)
epoch=14, EUBO=-30562.950, ELBO=-57662.190, ESS=1.000 (0s)
epoch=15, EUBO=-30849.315, ELBO=-57695.230, ESS=1.001 (0s)
epoch=16, EUBO=-29647.168, ELBO=-57772.958, ESS=1.001 (0s)
epoc

KeyboardInterrupt: 

In [None]:
# Number of sequences drawn by sample_single_batch and plotted by plot_samples at the end of this cell.
BATCH_SIZE_TEST = 50

def sample_single_batch(num_seqs, N, K, D, sample_size, batch_size, gpu):
    """Draw `batch_size` random sequences from the global `Xs`, shuffle them with
    `shuffler`, and tile the result `sample_size` times along a new leading dim.

    N, K and D are accepted for call compatibility but are not used. The batch is
    moved to `gpu` when the global CUDA flag is set.
    """
    chosen = torch.randperm(num_seqs)[:batch_size]
    batch = shuffler(Xs[chosen]).repeat(sample_size, 1, 1, 1)
    return batch.cuda().to(gpu) if CUDA else batch

def test(enc_mu, enc_z, obs, N, K, D, mcmc_size, sample_size, batch_size, gpu):
    """Alternate `mcmc_size` rounds of global (centre) and local (assignment) proposals
    and return the traces from the last round as (q_mu, q_z).

    FIXME(review): this looks written for an earlier fixed-radius model and fails in the
    current notebook:
      - `prior_pi` is not defined at module level here (it only appears in commented-out
        code and as the `Enc_z.prior_pi` attribute) -> NameError on a fresh kernel.
      - `enc_mu(obs, states, sample_size, batch_size)` does not match
        `Enc_mu_rad.forward(obs, K, D, sample_size, batch_size)`.
      - `enc_z(obs, obs_mu, 1.5, 0.05, sample_size, batch_size)` does not match
        `Enc_z.forward(obs, obs_mu, obs_rad, N, sample_size, batch_size, noise_sigma, device)`.
      - If mcmc_size == 0, q_mu / q_z are never assigned -> UnboundLocalError.
    Confirm the intended model before fixing.
    """
    # Initial assignments drawn from the prior over components.
    p_init_z = cat(prior_pi)
    states = p_init_z.sample((sample_size, batch_size, N,))
    # NOTE(review): these two initial values are overwritten inside the loop and never read.
    log_p_z = p_init_z.log_prob(states)## S * B * N
    log_q_z = p_init_z.log_prob(states)
    for m in range(mcmc_size):
        q_mu, p_mu = enc_mu(obs, states, sample_size, batch_size)
        log_q_mu = q_mu['means'].log_prob.sum(-1)
        log_p_mu = p_mu['means'].log_prob.sum(-1) # S * B * K
        obs_mu = q_mu['means'].value
        log_obs_k = Log_likelihood(obs, states, obs_mu, K, D, radius=1.5, noise_sigma = 0.05, gpu=gpu, cluster_flag=True)
        log_weights_global = log_obs_k + log_p_mu - log_q_mu
        # Importance weights over the sample dim (0), detached from the graph.
        weights_global = F.softmax(log_weights_global, 0).detach()
        ## resample mu
        obs_mu = resample_mu(obs_mu, weights_global)
        ## update z -- cluster assignments
        q_z, p_z = enc_z(obs, obs_mu, 1.5, 0.05, sample_size, batch_size)
        log_p_z = p_z['zs'].log_prob
        log_q_z = q_z['zs'].log_prob ## S * B * N
        states = q_z['zs'].value
        log_obs_n = Log_likelihood(obs, states, obs_mu, K, D, radius=1.5, noise_sigma = 0.05, gpu=gpu, cluster_flag=False)
        log_weights_local = log_obs_n + log_p_z - log_q_z
        # NOTE(review): weights_local is computed but never used.
        weights_local = F.softmax(log_weights_local, 0).detach()

    return q_mu, q_z

def plot_samples(obs, q_eta, q_z, K, batch_size, PATH):
    """Plot sample 0 of every sequence in the batch, five panels per row.

    Points are coloured by the argmax of q(z); for each component, a 2-std ellipse is
    drawn from the proposal mean/scale of 'means'. The figure is saved to
    results/modes-<PATH>.svg.
    """
    colors = ['r', 'b', 'g']
    n_cols = 5
    # Ceil division: the last row may be partially filled (int(batch_size / 5) dropped panels).
    n_rows = -(-batch_size // n_cols)
    fig = plt.figure(figsize=(25, 5 * n_rows))
    xs = obs[0].detach().cpu().numpy()
    mu_mu = q_eta['means'].dist.loc[0].detach().cpu().numpy()
    mu_sigma = q_eta['means'].dist.scale[0].detach().cpu().numpy()
    zs = q_z['zs'].dist.probs[0].detach().cpu().numpy()
    for b in range(batch_size):
        ax = fig.add_subplot(n_rows, n_cols, b + 1)
        x = xs[b]
        assignments = zs[b].argmax(-1)
        for k in range(K):
            color = colors[k % len(colors)]  # cycle colours when K exceeds the palette
            xk = x[assignments == k]
            ax.scatter(xk[:, 0], xk[:, 1], c=color, alpha=0.2)
            cov_k = np.diag(mu_sigma[b][k]**2)
            plot_cov_ellipse(cov=cov_k, pos=mu_mu[b][k], nstd=2, ax=ax, alpha=1.0, color=color)
        ax.set_ylim([-5, 5])
        ax.set_xlim([-5, 5])
    fig.savefig('results/modes-' + PATH + '.svg')
    
# Qualitative check: draw one test batch, run the encoders via `test`, and plot the modes.
# NOTE(review): MCMC_SIZE is not defined anywhere in this notebook, and `test` calls the
# encoders with signatures that differ from Enc_mu_rad / Enc_z, so this cell fails on a
# fresh kernel — confirm which model it was written for.
obs = sample_single_batch(NUM_SEQS, N, K, D, SAMPLE_SIZE, BATCH_SIZE_TEST, gpu)
q_mu, q_z = test(enc_mu, enc_z, obs, N, K, D, MCMC_SIZE, SAMPLE_SIZE, BATCH_SIZE_TEST, gpu)
%time plot_samples(obs, q_mu, q_z, K, BATCH_SIZE_TEST, PATH)