In [None]:
%matplotlib inline
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from plots import *
from kls import *
from torch.distributions.normal import Normal
from torch.distributions.one_hot_categorical import OneHotCategorical as cat
from torch.distributions.beta import Beta
from torch.distributions.uniform import Uniform
from torch.distributions.relaxed_categorical import RelaxedOneHotCategorical as rcat

from torch import logsumexp
import datetime
import itertools
import math
import sys
import time

In [None]:
# Load the multi-shapes dataset. Xs holds the observed points; Mus and Zs are
# presumably the ground-truth cluster centres and assignments (Mus is used as
# `mu_true` further down) — confirm against the data generator.
Xs = torch.from_numpy(np.load('multishapes/Shapes.npy')).float()
Mus = torch.from_numpy(np.load('multishapes/Mus.npy')).float()
Zs = torch.from_numpy(np.load('multishapes/Zs.npy')).float()

# Xs is unpacked as (num_seqs, N, D): sequences, points per sequence, point dim.
num_seqs, N, D = Xs.shape
K = 4  # number of clusters; the K*K type tensors below also use K as the number of shape types
## Model Parameters
STEPS = 5  # sampler sweeps per training batch (passed to ag)
NUM_SAMPLES = 10  # importance samples per sequence
NUM_HIDDEN = 32
NUM_LATENTS = 2
NUM_OBS = D
NUM_EPOCHS = 1000
BATCH_SIZE = 50
LEARNING_RATE = 1e-3
CUDA = True  # NOTE(review): never read in this notebook; devices are chosen with explicit .cuda() calls
PATH = 'ag-classtype-fixed-%dsteps' % STEPS  # tag used in result/log file names

In [None]:
class Encoder(nn.Module):
    """Per-shape VAE encoder: maps points to a sampled latent code.

    Four instances (one per shape) are built by `initialize()` and loaded from
    the 'VAE/enc-*' checkpoints. In this notebook they only feed the matching
    `Decoder` when scoring points.
    """
    def __init__(self, num_obs=NUM_OBS,
                       num_hidden=NUM_HIDDEN,
                       num_latents=NUM_LATENTS):
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as this class is subclassed.
        super(Encoder, self).__init__()
        # Layer attribute names and sizes must stay as they are so the saved
        # state_dicts still load.
        self.enc_h = nn.Sequential(
            nn.Linear(num_obs, num_hidden),
            nn.Tanh())
        self.enc_mu = nn.Sequential(
            nn.Linear(num_hidden, int(0.5*num_hidden)),
            nn.Tanh(),
            nn.Linear(int(0.5*num_hidden), num_latents))
        self.enc_log_sigma = nn.Sequential(
            nn.Linear(num_hidden, int(0.5*num_hidden)),
            nn.Tanh(),
            nn.Linear(int(0.5*num_hidden), num_latents))

    def forward(self, obs):
        """Encode `obs` and draw one latent sample per row.

        Args:
            obs: points, (B, num_obs).
        Returns:
            u: latent sample, (B, num_latents). Drawn with `sample()`, so no
               gradient flows through it.
            log_q_u: log density of `u`, summed over latent dims, (B,).
        """
        h = self.enc_h(obs)  # (B, H)
        mu = self.enc_mu(h)  # (B, num_latents)
        sigma = torch.exp(self.enc_log_sigma(h))  # (B, num_latents)
        q_u = Normal(mu, sigma)
        u = q_u.sample()  # (B, num_latents)
        log_q_u = q_u.log_prob(u).sum(-1)
        return u, log_q_u
    
class Decoder(nn.Module):
    """Per-shape VAE decoder with a fixed Gaussian likelihood (std 0.1 per dim)."""
    def __init__(self, num_obs=D,
                       num_hidden=NUM_HIDDEN,
                       num_latents=NUM_LATENTS):
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as this class is subclassed.
        super(Decoder, self).__init__()

        # Fixed observation noise. Kept as a plain attribute (not a buffer) so
        # it does not become a state_dict key that the 'VAE/dec-*' checkpoints
        # may lack. forward() moves it to the input's device on first use
        # instead of requiring CUDA at construction time.
        self.x_sigma = 0.1 * torch.ones(num_obs)
        self.dec_mu = nn.Sequential(
            nn.Linear(num_latents, num_hidden),
            nn.Tanh(),
            nn.Linear(num_hidden, num_obs))

    def forward(self, u, obs):
        """Decode latents and score `obs` under the resulting Gaussian.

        Args:
            u: latents, (B, num_latents).
            obs: points to score, (B, num_obs).
        Returns:
            mu: decoded means, (B, num_obs).
            log_p_x: log-likelihood of each row of `obs`, summed over dims, (B,).
        """
        mu = self.dec_mu(u)  # (B, num_latents) -> (B, num_obs)
        if self.x_sigma.device != mu.device:
            # Cache the moved copy so later calls skip the transfer.
            self.x_sigma = self.x_sigma.to(mu.device)
        p_x = Normal(mu, self.x_sigma)  # (B, num_obs)
        log_p_x = p_x.log_prob(obs).sum(-1)  # (B,)
        return mu, log_p_x

In [None]:
def initialize():
    """Create a freshly initialised encoder/decoder pair per shape, on the GPU.

    Shapes are built in the order circle, square, triangle, cross, with each
    encoder constructed before its decoder.

    Returns:
        8-tuple (enc_circle, dec_circle, enc_square, dec_square,
                 enc_triangle, dec_triangle, enc_cross, dec_cross).
    """
    modules = []
    for _ in range(4):  # circle, square, triangle, cross
        # The list is evaluated left to right, so the encoder is built first.
        modules.extend([Encoder().cuda(), Decoder().cuda()])
    return tuple(modules)
enc_circle, dec_circle, enc_square, dec_square, enc_triangle, dec_triangle, enc_cross, dec_cross = initialize()

In [None]:
# Load the pretrained per-shape VAE weights: all encoders first, then all decoders.
for _module, _ckpt in ((enc_circle, 'VAE/enc-circle'),
                       (enc_square, 'VAE/enc-square'),
                       (enc_triangle, 'VAE/enc-triangle'),
                       (enc_cross, 'VAE/enc-cross'),
                       (dec_circle, 'VAE/dec-circle'),
                       (dec_square, 'VAE/dec-square'),
                       (dec_triangle, 'VAE/dec-triangle'),
                       (dec_cross, 'VAE/dec-cross')):
    _module.load_state_dict(torch.load(_ckpt))
del _module, _ckpt

In [None]:
class Encoder_global(nn.Module):
    """Amortized Gaussian proposal over the K cluster means.

    Per-point features (data concatenated with assignments) are embedded and
    summed over points. The result is combined with the flattened shape types
    to give a diagonal Gaussian over the (K, D) means.
    """
    def __init__(self, num_obs=D+K,
                       num_stats=D+D*K,
                       num_hidden=64,
                       num_latents=D*K):
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as this class is subclassed.
        super(Encoder_global, self).__init__()
        self.enc_stats = nn.Sequential(
            nn.Linear(num_obs, num_hidden),
            nn.Tanh(),
            nn.Linear(num_hidden, int(0.5*num_hidden)),
            nn.Tanh(),
            nn.Linear(int(0.5*num_hidden), num_stats))

        self.mus_mean = nn.Sequential(
            nn.Linear(num_stats+K*K, num_hidden),
            nn.Tanh(),
            nn.Linear(num_hidden, int(0.5*num_hidden)),
            nn.Tanh(),
            nn.Linear(int(0.5*num_hidden), num_latents))
        self.mus_log_sigma = nn.Sequential(
            nn.Linear(num_stats+K*K, num_hidden),
            nn.Tanh(),
            nn.Linear(num_hidden, int(0.5*num_hidden)),
            nn.Tanh(),
            nn.Linear(int(0.5*num_hidden), num_latents))

    def forward(self, obs, cs, N, K, D, num_samples, batch_size):
        """Propose cluster means given assignments and shape types.

        Args:
            obs: (S, B, N, D + K) data concatenated with point assignments.
            cs: shape types with K*K entries per (sample, sequence), e.g. (S, B, K, K).
            N: not used.
            K, D, num_samples, batch_size: sizes used to reshape the outputs.
        Returns:
            q_mean, q_sigma: (S, B, K, D) mean and standard deviation of the proposal.
            mus: (S, B, K, D) sample drawn with `sample()` (no reparameterized gradient).
            log_q: (S, B) proposal log density of `mus`, summed over K and D.
        """
        stats = self.enc_stats(obs).sum(-2)  # S * B * STATS_DIM (summed over points)
        stats_cs = torch.cat((stats, cs.view(num_samples, batch_size, K*K)), -1)
        q_mean = self.mus_mean(stats_cs).view(num_samples, batch_size, K, D)
        q_sigma = torch.exp(self.mus_log_sigma(stats_cs).view(num_samples, batch_size, K, D))
        q = Normal(q_mean, q_sigma)
        mus = q.sample()  # S * B * K * D
        log_q = q.log_prob(mus).sum(-1).sum(-1)  # S * B
        return q_mean, q_sigma, mus, log_q

class Encoder_type(nn.Module):
    """Amortized relaxed-categorical proposal over each cluster's shape type."""
    def __init__(self, num_obs=D+K,
                       num_stats=D+D*K,
                       num_hidden=64,
                       num_latents=K*K):
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as this class is subclassed.
        super(Encoder_type, self).__init__()
        self.enc_stats = nn.Sequential(
            nn.Linear(num_obs, num_hidden),
            nn.Tanh(),
            nn.Linear(num_hidden, int(0.5*num_hidden)),
            nn.Tanh(),
            nn.Linear(int(0.5*num_hidden), num_stats))

        self.enc_pis = nn.Sequential(
            nn.Linear(num_stats+K, num_hidden),
            nn.Tanh(),
            nn.Linear(num_hidden, int(0.5*num_hidden)),
            nn.Tanh(),
            nn.Linear(int(0.5*num_hidden), num_latents))

    def forward(self, obs, Nks, K, D, num_samples, batch_size):
        """Propose a relaxed one-hot shape type for each of the K clusters.

        Args:
            obs: (S, B, N, D + K) data concatenated with point assignments.
            Nks: (S, B, K) soft cluster sizes (assignments summed over points).
            K, num_samples, batch_size: sizes used to reshape the logits; D is unused.
        Returns:
            cs_pis: (S, B, K, K) type probabilities per cluster.
            cs: (S, B, K, K) relaxed one-hot sample at temperature 0.66.
            log_q_c: (S, B) log density of `cs`, summed over clusters.
        """
        stats = self.enc_stats(obs).sum(-2)  ## S * B * STATS_DIM
        stats_nks = torch.cat((stats, Nks), -1)  ## S * B * STATS_DIM+K
        cs_pis = F.softmax(self.enc_pis(stats_nks).view(num_samples, batch_size, K, K), -1)  ## S * B * K * K
        # Build the temperature on cs_pis' device/dtype instead of forcing CUDA.
        q_c = rcat(temperature=cs_pis.new_tensor([0.66]), probs=cs_pis)
        cs = q_c.sample()  # S * B * K * K
        log_q_c = q_c.log_prob(cs).sum(-1)  # S * B
        return cs_pis, cs, log_q_c
    
class Encoder_local(nn.Module):
    """Amortized relaxed-categorical proposal over each point's cluster assignment."""
    def __init__(self, num_obs=D+K*D+K*K,
                       num_hidden=64,
                       num_latents=K):
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as this class is subclassed.
        super(Encoder_local, self).__init__()
        self.enc_onehot = nn.Sequential(
            nn.Linear(num_obs, num_hidden),
            nn.Tanh(),
            nn.Linear(num_hidden, int(0.5*num_hidden)),
            nn.Tanh(),
            nn.Linear(int(0.5*num_hidden), num_latents),
            nn.Softmax(-1))

    def forward(self, obs, N, K, batch_size):
        """Propose a relaxed one-hot cluster assignment for every point.

        Args:
            obs: (S, B, N, D + K*D + K*K) each point concatenated with the
                 flattened cluster means and shape types.
            N, K, batch_size: not used.
        Returns:
            zs_pi: (S, B, N, K) assignment probabilities.
            zs: (S, B, N, K) relaxed one-hot sample at temperature 0.66.
            log_qz: (S, B) log density of `zs`, summed over points.
        """
        zs_pi = self.enc_onehot(obs)
        # Build the temperature on zs_pi's device/dtype instead of forcing CUDA.
        q_z = rcat(temperature=zs_pi.new_tensor([0.66]), probs=zs_pi)
        zs = q_z.sample()  # S * B * N * K
        log_qz = q_z.log_prob(zs).sum(-1)  # S * B
        return zs_pi, zs, log_qz

In [None]:
# Priors read as globals by inti_globals/ag: a standard normal over each of the
# K cluster means, and uniform probabilities (0.25 = 1/K for K = 4) over each
# cluster's shape type. The leading dim is BATCH_SIZE, so later cells that
# change BATCH_SIZE rebuild these tensors.
prior_mean = torch.zeros((BATCH_SIZE, K, D)).cuda()
prior_sigma = torch.ones((BATCH_SIZE, K, D)).cuda()
prior_pi = 0.25*torch.ones((BATCH_SIZE, K, K)).cuda()


def shuffler(batch_Xs, N, D, batch_size):
    """Independently shuffle the N points inside every sequence of a batch.

    Args:
        batch_Xs: (batch_size, N, D) tensor of points.
        N: points per sequence.
        D: point dimension.
        batch_size: number of sequences.
    Returns:
        Tensor shaped like `batch_Xs` whose rows are permuted per sequence.
    """
    row_order = torch.stack([torch.randperm(N) for _ in range(batch_size)], 0)  # (batch_size, N)
    gather_index = row_order.unsqueeze(-1).repeat(1, 1, D)  # (batch_size, N, D)
    return batch_Xs.gather(1, gather_index)

def log_likelihood(x, z, mus, cs, N, D, K, num_samples, batch_size):
    """Score every sequence's points under the shape VAE chosen for each cluster.

    For sample s and sequence b, the points whose argmax assignment in z is k
    are centred on mus[s, b, k]. They are then scored by the encoder/decoder
    pair selected by argmax of cs[s, b, k]: 0 circle, 1 square, 2 cross,
    anything else triangle.

    Args:
        x: observations, (batch_size, N, D).
        z: relaxed one-hot assignments, (num_samples, batch_size, N, K).
        mus: cluster centres, (num_samples, batch_size, K, D).
        cs: relaxed one-hot shape types, (num_samples, batch_size, K, K).
    Returns:
        (num_samples, batch_size) tensor of summed decoder log-likelihoods on
        x's device. Each encoder draws a fresh latent, so the value is stochastic.
    """
    # Index = argmax of a cluster's shape-type one-hot; indices >= 3 fall back
    # to the triangle VAE.
    shape_vaes = ((enc_circle, dec_circle),
                  (enc_square, dec_square),
                  (enc_cross, dec_cross),
                  (enc_triangle, dec_triangle))
    last_shape = len(shape_vaes) - 1
    log_p_xs = x.new_zeros((num_samples, batch_size))
    for s in range(num_samples):
        for b in range(batch_size):
            xb = x[b]
            labels = z[s, b].argmax(-1)
            for k in range(K):
                ind = (labels == k).nonzero()[:, 0]
                xbk = xb[ind] - mus[s, b, k]
                cbk = cs[s, b, k].argmax().item()
                enc, dec = shape_vaes[min(cbk, last_shape)]
                u, _ = enc(xbk)
                _, log_p_x = dec(u, xbk)
                log_p_xs[s, b] = log_p_xs[s, b] + log_p_x.sum(-1)
    return log_p_xs

def inti_globals(prior_mean, prior_sigma, prior_pi, num_samples):
    """Draw initial cluster means and shape types from their priors.

    (The name's spelling is kept for existing callers.)

    Args:
        prior_mean, prior_sigma: (B, K, D) loc and scale of the Gaussian prior
            over cluster means.
        prior_pi: (B, K, K) probabilities of the relaxed one-hot prior over
            each cluster's shape type.
        num_samples: number of samples S to draw.
    Returns:
        mus: (S, B, K, D) sampled means.
        cs: (S, B, K, K) sampled relaxed one-hot shape types.
        log_p: (S, B) joint log prior density of (mus, cs).
    """
    p_mu = Normal(prior_mean, prior_sigma)
    mus = p_mu.sample((num_samples,))
    # Temperature follows prior_pi's device/dtype instead of being forced onto CUDA.
    p_c = rcat(temperature=prior_pi.new_tensor([0.66]), probs=prior_pi)
    cs = p_c.sample((num_samples,))
    log_p_mu = p_mu.log_prob(mus).sum(-1).sum(-1)  # S * B
    log_p_c = p_c.log_prob(cs).sum(-1)  # S * B
    return mus, cs, log_p_mu + log_p_c

def _tile_globals(data_flat, mus, cs, N, K, D, num_samples, batch_size):
    """Append the flattened cluster means and shape types to every point.

    Returns an (S, B, N, D + K*D + K*K) tensor, the input layout of `enc_local`.
    """
    mus_flat = mus.view(num_samples, batch_size, K*D).unsqueeze(2).repeat(1, 1, N, 1)
    cs_flat = cs.view(num_samples, batch_size, K*K).unsqueeze(2).repeat(1, 1, N, 1)
    return torch.cat((data_flat, mus_flat, cs_flat), -1)

def ag(x, N, K, D, num_samples, steps, batch_size):
    """Run `steps` proposal sweeps on a batch and compute training objectives.

    Sweep 0 draws means/types from the prior (`inti_globals`) and assignments
    from `enc_local`. Every later sweep re-proposes types (`enc_type`), then
    means (`enc_global`), then assignments (`enc_local`), each conditioned on
    the previous sweep's assignments. Uses the globals prior_mean, prior_sigma
    and prior_pi, whose leading dim must equal `batch_size`.

    Args:
        x: observations, (batch_size, N, D).
        N, K, D: points per sequence, clusters, point dimension.
        num_samples: importance samples S per sequence.
        steps: number of sweeps.
        batch_size: sequences B in `x`.
    Returns:
        eubo, elbo, ess: averaged over sweeps and sequences.
        eubo_last, elbo_last: the final sweep's values averaged over sequences.
    """
    data_flat = x.repeat(num_samples, 1, 1, 1)  # S * B * N * D
    # Allocated on x's device so every per-sweep write stays on one device.
    log_increment_weights = x.new_zeros((steps, num_samples, batch_size))
    temperature = x.new_tensor([0.66])
    # Uniform priors over each point's cluster and each cluster's shape type.
    p_z = rcat(temperature=temperature, probs=x.new_ones(K) / K)
    p_c = rcat(temperature=temperature, probs=x.new_ones((K, K)) / K)
    for m in range(steps):
        if m == 0:
            # mus and cs are drawn from the prior itself, so neither their prior
            # density nor a proposal density enters this sweep's weight.
            mus, cs, _ = inti_globals(prior_mean, prior_sigma, prior_pi, num_samples)
            x_mu_flat = _tile_globals(data_flat, mus, cs, N, K, D, num_samples, batch_size)
            _, zs, log_q_z = enc_local(x_mu_flat, N, K, batch_size)
            log_p_z = p_z.log_prob(zs).sum(-1)
            log_p_x = log_likelihood(x, zs, mus, cs, N, D, K, num_samples, batch_size)
            log_increment_weights[m] = log_p_x + log_p_z - log_q_z
        else:
            Nks = zs.sum(-2)  # S * B * K
            x_z_flat = torch.cat((data_flat, zs), -1)
            _, cs, log_q_c = enc_type(x_z_flat, Nks, K, D, num_samples, batch_size)
            log_p_c = p_c.log_prob(cs).sum(-1)
            _, _, mus, log_q_eta = enc_global(x_z_flat, cs, N, K, D, num_samples, batch_size)
            x_mu_flat = _tile_globals(data_flat, mus, cs, N, K, D, num_samples, batch_size)
            _, zs, log_q_z = enc_local(x_mu_flat, N, K, batch_size)
            log_p_z = p_z.log_prob(zs).sum(-1)
            log_p_x = log_likelihood(x, zs, mus, cs, N, D, K, num_samples, batch_size)
            log_p_eta = Normal(prior_mean, prior_sigma).log_prob(mus).sum(-1).sum(-1)
            log_increment_weights[m] = (log_p_x + log_p_z + log_p_c + log_p_eta
                                        - log_q_z - log_q_eta - log_q_c)

    # Self-normalise over the sample dim. The weights are detached, so
    # gradients flow only through the log weights they multiply.
    increment_weights = F.softmax(log_increment_weights, 1).detach()
    ## EUBO and ELBO
    eubos = torch.mul(increment_weights, log_increment_weights).sum(1)  # steps * B
    eubo = eubos.mean(0).mean()

    elbos = log_increment_weights.mean(1)  # steps * B
    elbo = elbos.mean(0).mean()
    ess = (1. / (increment_weights ** 2).sum(1)).mean(0).mean()

    eubo_last = eubos[-1].mean()
    elbo_last = elbos[-1].mean()

    return eubo, elbo, ess, eubo_last, elbo_last


In [None]:
def weights_init(m):
    """Redraw the weight of a Linear-like module from N(0, 0.01**2), in place.

    Matching is by class name, so any module whose class name contains
    'Linear' is affected. Biases and all other modules are left untouched.
    NOTE(review): this helper is not applied to any module in this notebook.
    """
    if 'Linear' not in type(m).__name__:
        return
    m.weight.data.normal_(0.0, 1e-2)
        
# The amortized proposals trained below. The shape VAEs are not in the
# optimizer, so they stay fixed.
enc_global = Encoder_global().cuda()
enc_local = Encoder_local().cuda()
enc_type = Encoder_type().cuda()
trainable_params = (list(enc_type.parameters())
                    + list(enc_local.parameters())
                    + list(enc_global.parameters()))
optimizer = torch.optim.Adam(trainable_params, lr=LEARNING_RATE, betas=(0.9, 0.99))

In [None]:
EUBOs = []
ELBOs = []
ESSs = []
EUBOs_ag = []
ELBOs_ag = []

log_path = 'results/log-' + PATH + '.txt'
# Truncate the log and write the CSV header; one row is appended per step.
with open(log_path, 'w+') as flog:
    flog.write('EUBO, ELBO, ESS, eubo_ag, elbo_ag\n')

# A trailing partial batch is dropped: the global priors are sized for BATCH_SIZE.
num_batches = num_seqs // BATCH_SIZE
for epoch in range(NUM_EPOCHS):
    indices = torch.randperm(num_seqs)
    for step in range(num_batches):
        time_start = time.time()
        optimizer.zero_grad()
        batch_indices = indices[step*BATCH_SIZE : (step+1)*BATCH_SIZE]
        batch_Xs = Xs[batch_indices]
        batch_Xs = shuffler(batch_Xs, N, D, BATCH_SIZE)
        batch_Xs = batch_Xs.cuda()
        eubo, elbo, ess, eubo_ag, elbo_ag = ag(batch_Xs, N, K, D, NUM_SAMPLES, STEPS, BATCH_SIZE)
        # Only the EUBO is optimised; the other quantities are monitored.
        eubo.backward()
        optimizer.step()
        row = [eubo.item(), elbo.item(), ess.item(), eubo_ag.item(), elbo_ag.item()]
        for history, value in zip((EUBOs, ELBOs, ESSs, EUBOs_ag, ELBOs_ag), row):
            history.append(value)
        with open(log_path, 'a+') as flog:
            flog.write(', '.join(str(v) for v in row) + '\n')
        time_end = time.time()
        print('epoch=%d, step=%d, EUBO=%f, ELBO=%f, ESS=%.3f (%ds)' % (epoch, step, row[0], row[1], row[2], time_end - time_start))

In [None]:
# fig = plt.figure(figsize=(12,8))
# ax = fig.add_subplot(1,1,1)
# ax.plot(EUBOs, 'r', label='average EUBO')
# ax.plot(ELBOs, 'b', label='average EUBO')
# ax.plot(EUBOs_ag, 'g', label='last EUBO')
# ax.plot(ELBOs_ag, 'orange', label='last EUBO')
# ax.legend(fontsize=18)
# ax.set_xlabel('gradient steps')
# plt.savefig('results/train' + PATH + '%dsamples-%dsteps-%datchsize.svg' % (NUM_SAMPLES, STEPS, BATCH_SIZE))

In [None]:
# Rewrite the whole log from the in-memory histories collected during training.
# The ESS column takes the recorded ESSs rather than a constant placeholder.
with open('results/log-' + PATH + '.txt', 'w+') as flog:
    flog.write('EUBO, ELBO, ESS, eubo_ag, elbo_ag\n')
    for row in zip(EUBOs, ELBOs, ESSs, EUBOs_ag, ELBOs_ag):
        flog.write(', '.join(str(v) for v in row) + '\n')

In [None]:
# torch.save(enc_global.state_dict(), 'models/global-enc-' + PATH)
# torch.save(enc_local.state_dict(), 'models/local-enc' + PATH)

In [None]:
def test(x, N, K, D, num_samples, steps, batch_size):
    """Run the sampler for `steps` sweeps and return the final sweep's weights.

    Uses the same proposal sequence and incremental weights as the training
    sweep: sweep 0 draws means/types from the prior and assignments from
    `enc_local`; later sweeps call `enc_type`, `enc_global`, then `enc_local`.
    Reads the globals prior_mean, prior_sigma and prior_pi, whose leading dim
    must equal `batch_size`.

    Args:
        x: observations, (batch_size, N, D).
        N, K, D: points per sequence, clusters, point dimension.
        num_samples: importance samples S per sequence.
        steps: number of sweeps (must be >= 1).
        batch_size: sequences B in `x`.
    Returns:
        (num_samples, batch_size) detached, self-normalised importance weights
        of the last sweep.
    """
    data_flat = x.repeat(num_samples, 1, 1, 1)  # S * B * N * D
    log_increment_weights = x.new_zeros((steps, num_samples, batch_size))
    temperature = x.new_tensor([0.66])
    p_z = rcat(temperature=temperature, probs=x.new_ones(K) / K)
    p_c = rcat(temperature=temperature, probs=x.new_ones((K, K)) / K)
    for m in range(steps):
        if m == 0:
            # Drawn from the prior, so no prior/proposal terms for mus or cs.
            mus, cs, _ = inti_globals(prior_mean, prior_sigma, prior_pi, num_samples)
        else:
            Nks = zs.sum(-2)  # S * B * K
            x_z_flat = torch.cat((data_flat, zs), -1)
            _, cs, log_q_c = enc_type(x_z_flat, Nks, K, D, num_samples, batch_size)
            _, _, mus, log_q_eta = enc_global(x_z_flat, cs, N, K, D, num_samples, batch_size)
        mus_flat = mus.view(num_samples, batch_size, K*D).unsqueeze(2).repeat(1, 1, N, 1)
        cs_flat = cs.view(num_samples, batch_size, K*K).unsqueeze(2).repeat(1, 1, N, 1)
        x_mu_flat = torch.cat((data_flat, mus_flat, cs_flat), -1)
        _, zs, log_q_z = enc_local(x_mu_flat, N, K, batch_size)
        log_p_z = p_z.log_prob(zs).sum(-1)
        log_p_x = log_likelihood(x, zs, mus, cs, N, D, K, num_samples, batch_size)
        if m == 0:
            log_increment_weights[m] = log_p_x + log_p_z - log_q_z
        else:
            log_p_c = p_c.log_prob(cs).sum(-1)
            log_p_eta = Normal(prior_mean, prior_sigma).log_prob(mus).sum(-1).sum(-1)
            log_increment_weights[m] = (log_p_x + log_p_z + log_p_c + log_p_eta
                                        - log_q_z - log_q_eta - log_q_c)
    increment_weights = F.softmax(log_increment_weights, 1).detach()
    final_weights = increment_weights[-1]
    return final_weights

In [None]:
# Long inference run: 500 sweeps of the trained proposals on one batch of 35
# sequences. The final mus/cs/zs and q_mean/q_sigma are used by later cells.
BATCH_SIZE = 35
# The priors' leading dim must equal the batch size, so all three are rebuilt.
prior_mean = torch.zeros((BATCH_SIZE, K, D)).cuda()
prior_sigma = torch.ones((BATCH_SIZE, K, D)).cuda()
prior_pi = 0.25*torch.ones((BATCH_SIZE, K, K)).cuda()
indices = torch.randperm(num_seqs)
batch_indices = indices[0*BATCH_SIZE : (0+1)*BATCH_SIZE]
batch_Xs = Xs[batch_indices]
batch_Xs = shuffler(batch_Xs, N, D, BATCH_SIZE)
batch_Mus = Mus[batch_indices]
x = batch_Xs.cuda()
num_samples = 1
steps = 500
STEPS = steps
batch_size = BATCH_SIZE
data_flat = x.repeat(num_samples, 1, 1, 1) ## S * B * N * D
# Inference only: no_grad stops 500 sweeps from building autograd graphs.
with torch.no_grad():
    for m in range(steps):
        if m == 0:
            # Initial means and shape types come from the prior.
            mus, cs, log_p = inti_globals(prior_mean, prior_sigma, prior_pi, num_samples)
        else:
            # Re-propose types, then means, given the previous sweep's assignments.
            Nks = zs.sum(-2)  # S * B * K
            x_z_flat = torch.cat((data_flat, zs), -1)
            cs_pis, cs, log_q_c = enc_type(x_z_flat, Nks, K, D, num_samples, BATCH_SIZE)
            q_mean, q_sigma, mus, log_q_eta = enc_global(x_z_flat, cs, N, K, D, num_samples, BATCH_SIZE)
        mus_flat = mus.view(num_samples, BATCH_SIZE, K*D).unsqueeze(2).repeat(1, 1, N, 1)
        cs_flat = cs.view(num_samples, BATCH_SIZE, K*K).unsqueeze(2).repeat(1, 1, N, 1)
        x_mu_flat = torch.cat((data_flat, mus_flat, cs_flat), -1)
        zs_pi, zs, log_q_z = enc_local(x_mu_flat, N, K, BATCH_SIZE)

In [None]:
def permute_single(xb, zb, mub, N, K, D):
    """Relabel one sequence's clusters so each cluster index equals its best shape.

    Every (cluster p, shape k) pair is scored with the shape VAEs. The
    one-to-one cluster->shape assignment with the highest total
    log-likelihood is then chosen by exhaustive search over all K!
    permutations (first maximum wins on ties), and the means and assignments
    are reordered to match.

    Args:
        xb: points of one sequence, (N, D).
        zb: (relaxed) one-hot assignments, (N, K).
        mub: cluster centres, (K, D).
    Returns:
        new_mu: (K, D) centres; new_mu[j] is the centre of the cluster given shape j.
        new_Z: (N, K) one-hot assignments in the new labelling (CPU tensor).
        assignment: integer numpy array; assignment[p] is the shape given to cluster p.
    """
    # Shape index order: 0 circle, 1 square, 2 cross, 3 triangle (and above).
    shape_vaes = ((enc_circle, dec_circle),
                  (enc_square, dec_square),
                  (enc_cross, dec_cross),
                  (enc_triangle, dec_triangle))
    first_ll = torch.zeros((K, K))
    tt = zb.argmax(-1)
    for p in range(K):
        ind_tt = (tt == p).nonzero()[:, 0]
        xk = xb[ind_tt] - mub[p]
        for k in range(K):
            enc, dec = shape_vaes[min(k, len(shape_vaes) - 1)]
            u, _ = enc(xk)
            _, log_p_x = dec(u, xk)
            first_ll[p, k] = log_p_x.sum(-1)
    # itertools.permutations yields tuples in lexicographic order, so ties are
    # broken towards the lexicographically first assignment.
    rows = np.arange(K)
    perm_ind = [np.array(perm) for perm in itertools.permutations(range(K))]
    perm_ll = [first_ll[rows, perm].sum().item() for perm in perm_ind]
    best_ind = np.array(perm_ll).argmax()
    assignment = perm_ind[best_ind]
    mu_perm_ind = np.argsort(assignment)
    new_mu = mub[mu_perm_ind]
    new_Z = torch.zeros((N, K))
    for k2 in range(K):
        ind_prev = (tt == k2).nonzero()[:, 0]
        onehot = torch.zeros(K)
        onehot[assignment[k2]] = 1
        new_Z[ind_prev] = onehot
    return new_mu, new_Z, assignment

def permute(X, Z, mus, N, K, D, BATCH_SIZE):
    """Apply `permute_single` to every sequence of a batch.

    Args:
        X: (BATCH_SIZE, N, D) points.
        Z: (BATCH_SIZE, N, K) assignments.
        mus: (BATCH_SIZE, K, D) cluster centres.
    Returns:
        new_mus: (BATCH_SIZE, K, D) relabelled centres (CPU tensor).
        new_Zs: (BATCH_SIZE, N, K) relabelled assignments (CPU tensor).
    """
    new_Zs = torch.zeros((BATCH_SIZE, N, K))
    new_mus = torch.zeros((BATCH_SIZE, K, D))
    for b in range(BATCH_SIZE):
        # permute_single returns (new_mu, new_Z, assignment); the chosen
        # assignment itself is not needed here.
        new_mu, new_z, _ = permute_single(X[b], Z[b], mus[b], N, K, D)
        new_Zs[b] = new_z
        new_mus[b] = new_mu
    return new_mus, new_Zs

In [None]:
# Relabel the clusters of the first importance sample of every sequence so
# that cluster index == best-fitting shape index.
# NOTE(review): this rebinds `zs` from the sampler's (S, B, N, K) tensor to
# permute's (BATCH_SIZE, N, K) CPU tensor, so re-running this cell or later
# cells that index zs[0] / zs[0][0] see a different shape than before.
new_mus, zs = permute(x, zs[0], mus[0], N, K, D, BATCH_SIZE)

In [None]:

# new_mu, new_z, aa = permute_single(xb, zb, mub, N, K, D)

In [None]:
# Scatter the points of sequence `xb` whose argmax assignment equals `p`.
# NOTE(review): `p` is first set in the cell containing `p =3` and `xb`/`zb`
# in the cell after it, both below this one, so this cell fails on
# Restart & Run All.
bb = zb.cpu().data.numpy().argmax(-1)
ind = (bb == p).nonzero()[0]
xbk = xb.cpu().data.numpy()[ind]
plt.scatter(xbk[:,0], xbk[:,1])

In [None]:
# Score cluster 0's points, centred on its sampled mean, under every shape VAE
# and print the summed log-likelihoods (order: circle, square, cross, triangle).
tt = zb.argmax(-1)
ind_tt = (tt == 0).nonzero()[:, 0]

xk = xb[ind_tt] - mub[0]
shape_vaes = [(enc_circle, dec_circle), (enc_square, dec_square),
              (enc_cross, dec_cross), (enc_triangle, dec_triangle)]
for k in range(K):
    enc_k, dec_k = shape_vaes[min(k, 3)]
    u, log_q = enc_k(xk)
    mu_na, log_p_x = dec_k(u, xk)
    print(log_p_x.sum(-1))

In [None]:
# Score cluster p's points, centred on a reference mean, under every shape VAE
# and print the summed log-likelihoods (order: circle, square, cross, triangle).
p = 3
tt = zb.argmax(-1)
ind_tt = (tt == p).nonzero()[:, 0]
mu_meanb = q_mean[0,0]
mu_true = batch_Mus[0].cuda()
# NOTE(review): cluster p=3 is centred on reference mean index 2, presumably a
# deliberate probe since cluster and reference indices need not agree — confirm.
xk = xb[ind_tt] - mu_true[2]
shape_vaes = [(enc_circle, dec_circle), (enc_square, dec_square),
              (enc_cross, dec_cross), (enc_triangle, dec_triangle)]
for k in range(K):
    enc_k, dec_k = shape_vaes[min(k, 3)]
    u, log_q = enc_k(xk)
    mu_na, log_p_x = dec_k(u, xk)
    print(log_p_x.sum())

In [None]:
# Pick the first sequence of the batch, plus the first sample's assignments and
# means for it. NOTE(review): if the permute cell above has run, `zs` is
# (B, N, K) and zs[0][0] is then a single point's row — confirm the intended
# run order.
xb = x[0]
zb = zs[0][0]
mub = mus[0][0]

In [None]:
# q_mean[0,0]
# Display the proposal mean for the first sample and sequence, stored as
# `mu_meanb` by an earlier probe cell.
mu_meanb

In [None]:
# Display the reference centres of the first sequence (batch_Mus[0], set by an earlier probe cell).
mu_true

In [None]:
def plot_final_samples(x, z, mus, q_mean, q_sigma, batch_size, PATH):
    """Plot each sequence's points by assigned cluster, with 2-std ellipses of q(mu).

    Args:
        x: (batch_size, N, D) points.
        z: (batch_size, N, K) assignments; argmax gives each point's cluster.
        mus: (batch_size, K, D) sampled centres (not drawn).
        q_mean, q_sigma: (batch_size, K, D) mean and standard deviation of the
            proposal over centres.
        batch_size: number of panels, laid out 5 per row.
        PATH: tag inserted into the output file name.
    Side effects:
        Saves 'results/modes' + PATH + '<STEPS>STEPS.svg', using the global STEPS.
    """
    colors = ['red', 'blue', 'green', 'gold']
    ## order [circle, square, cross, triangle]
    n_cols = 5
    # Round up so a batch that is not a multiple of 5 still gets a panel per sequence.
    n_rows = (batch_size + n_cols - 1) // n_cols
    fig = plt.figure(figsize=(25,50))
    for b in range(batch_size):
        ax = fig.add_subplot(n_rows, n_cols, b+1)
        xb = x[b].cpu().data.numpy()
        zb = z[b].cpu().data.numpy()
        labels = zb.argmax(-1)
        mu_mean = q_mean[b].cpu().data.numpy()
        mu_sigma = q_sigma[b].cpu().data.numpy()
        for k in range(K):
            ind = (labels == k).nonzero()[0]
            xbk = xb[ind]
            ax.scatter(xbk[:, 0], xbk[:, 1], c=colors[k])
            # q_sigma holds standard deviations (exp of the log-sigma head), while
            # plot_cov_ellipse takes a covariance, so square them.
            plot_cov_ellipse(cov=np.diag(mu_sigma[k] ** 2), pos=mu_mean[k], nstd=2, ax=ax, alpha=0.2, color=colors[k])

        ax.set_ylim([-4, 4])
        ax.set_xlim([-4, 4])
    plt.savefig('results/modes' + PATH + '%dSTEPS.svg' % STEPS)
plot_final_samples(batch_Xs, zs[0], mus[0], q_mean[0], q_sigma[0], batch_size, PATH)

In [None]:
# Visualise how the assignments of a single sequence evolve over a few sweeps.
steps = 5
BATCH_SIZE = 1
indices = torch.randperm(num_seqs)
batch_indices = indices[0*BATCH_SIZE : (0+1)*BATCH_SIZE]
batch_Xs = Xs[batch_indices]
batch_Xs = shuffler(batch_Xs, N, D, BATCH_SIZE)
x = batch_Xs.cuda()
num_samples = 1

# The priors' leading dim must equal the batch size, so all three are rebuilt.
prior_mean = torch.zeros((BATCH_SIZE, K, D)).cuda()
prior_sigma = torch.ones((BATCH_SIZE, K, D)).cuda()
prior_pi = 0.25*torch.ones((BATCH_SIZE, K, K)).cuda()

colors = ['red', 'blue', 'green', 'gold']
data_flat = x.repeat(num_samples, 1, 1, 1) ## S * B * N * D
# One figure with one row per sweep; the grid is sized by the local `steps`.
fig = plt.figure(figsize=(10,50))
with torch.no_grad():
    for m in range(steps):
        if m == 0:
            # Initial means and shape types come from the prior.
            mus, cs, log_p = inti_globals(prior_mean, prior_sigma, prior_pi, num_samples)
        else:
            # Re-propose types, then means, given the previous sweep's assignments.
            Nks = zs.sum(-2)  # S * B * K
            x_z_flat = torch.cat((data_flat, zs), -1)
            cs_pis, cs, log_q_c = enc_type(x_z_flat, Nks, K, D, num_samples, BATCH_SIZE)
            q_mean, q_sigma, mus, log_q_eta = enc_global(x_z_flat, cs, N, K, D, num_samples, BATCH_SIZE)
        mus_flat = mus.view(num_samples, BATCH_SIZE, K*D).unsqueeze(2).repeat(1, 1, N, 1)
        cs_flat = cs.view(num_samples, BATCH_SIZE, K*K).unsqueeze(2).repeat(1, 1, N, 1)
        x_mu_flat = torch.cat((data_flat, mus_flat, cs_flat), -1)
        zs_pi, zs, log_q_z = enc_local(x_mu_flat, N, K, BATCH_SIZE)

        ax = fig.add_subplot(steps, 1, m+1)
        xb = x[0].cpu().data.numpy()
        zb = zs[0,0].cpu().data.numpy()
        labels = zb.argmax(-1)
        for k in range(K):
            ind = (labels == k).nonzero()[0]
            xbk = xb[ind]
            ax.scatter(xbk[:, 0], xbk[:, 1], c=colors[k])
        ax.set_ylim([-4, 4])
        ax.set_xlim([-4, 4])

In [None]:
# Inspect the shape of the tiled data (commented elsewhere as S * B * N * D).
data_flat.shape


In [None]:
# For each sequence, plot the reference centres (left) and the points coloured
# by their reference labels (right).
# NOTE(review): `batch_Zs` is never defined in this notebook, so this cell
# raises NameError as written. `batch_Mus` was taken for the 35-sequence batch
# drawn earlier, not for the most recent `batch_indices`. `batch_Xs` has also
# been shuffled within sequences, so labels taken straight from `Zs` would not
# line up with its points. Confirm the intended source of each.
# NOTE(review): the loop rebinds the global `x` to one sequence's points.
colors = ['red', 'blue', 'green', 'gold']
for p in range(BATCH_SIZE):
    fig = plt.figure(figsize=(10,5))
    ax1 = fig.add_subplot(1,2,1)
    ax2 = fig.add_subplot(1,2,2)
    ax1.set_xlim([-4,4])
    ax1.set_ylim([-4,4])
    ax2.set_xlim([-4,4])
    ax2.set_ylim([-4,4])
    x = batch_Xs[p]
    z = batch_Zs[p]
    mu = batch_Mus[p]
    labels = z.argmax(-1)
    for k in range(K):
        ax1.scatter(mu[k, 0].data.numpy(), mu[k, 1].data.numpy(), c=colors[k])
        ind = (labels == k).nonzero()[:, 0]
        xk = x[ind]
        zk = z[ind] 
        ax2.scatter(xk[:,0].data.numpy(), xk[:,1].data.numpy(), c=colors[k]) 

In [None]:
# Plot the points of `x`, coloured by the argmax labels of `z`. Both are
# whatever the previous cell's loop left bound (its last sequence).
# NOTE(review): ax2 is created but only has its aspect ratio set; its plot is commented out.
fig = plt.figure(figsize=(12,5))
ax1 = fig.add_subplot(1,2,1)
ax2 = fig.add_subplot(1,2,2)
# ax1.scatter(x[:, 0].data.numpy(), x[:, 1].data.numpy())
labels = z.argmax(-1)
for k in range(K):  
    ind = (labels == k).nonzero()[:, 0]
    xk = x[ind]
    zk = z[ind]
    ax1.scatter(xk[:,0].data.numpy(), xk[:,1].data.numpy(), c=colors[k])
#     ax2.plot(zk.data.numpy(), c=colors[k])

ax2.set_aspect('equal')
ax1.set_aspect('equal')

In [None]:

# NOTE(review): init_means and E_step are not defined in this notebook;
# presumably they come from one of the star imports (plots / kls) — confirm.
mus = init_means()
Z, log_pis = E_step(Xs, mus, N, K, D)

# Recompute each cluster's mean from the hard (argmax) labels. Clusters with
# no assigned rows keep a zero mean.
# NOTE(review): this rebinds the global `mus` to a (K, D) tensor, replacing
# the sampler's (S, B, K, D) means used by earlier cells.
labels = Z.argmax(-1)
mus = torch.zeros((K, D))
for k in range(K):
    labels_k = labels == k
    if labels_k.sum().item() == 0:
        continue
    else:
        mus[k] = Xs[labels_k].mean(0)

In [None]:
# Display `mus` (rebound by the preceding cell to the (K, D) hard-label means).
mus

In [None]:
# Hard labels: argmax over the last axis of Z returned by E_step above.
labels = Z.argmax(-1)

In [None]:
# NOTE(review): `batch_Zs` is never defined in this notebook; this cell raises NameError as written.
batch_Zs[2]