In [1]:
%matplotlib inline
%run ../../path_import.py
import numpy as np
import matplotlib.pyplot as plt
from plots import *
from utils import *
from training import *
from model_mu_rad import *
import time
import probtorch
# Record library versions and whether a GPU is visible, so results can be tied to an environment.
print('probtorch:', probtorch.__version__, 
      'torch:', torch.__version__, 
      'cuda:', torch.cuda.is_available())

probtorch: 0.0+5a2c637 torch: 1.0.0 cuda: True


In [None]:
## Data
# Observations as a float32 tensor. Indexed along dim 0 by sequence in
# sample_single_batch; presumably shaped (num_seqs, N, D) — TODO confirm
# against the data generator.
Data = torch.from_numpy(np.load('../rings_varying_radius/obs.npy')).float()
## Data Parameters
N = 300  # points per sequence (size of the last dim of the z samples)
K = 3    # number of clusters
D = 2    # dimensionality of each observation

## Model Parameters
MCMC_SIZE = 10          # sweeps per training step (passed to train_rad)
SAMPLE_SIZE = 10        # number of samples stacked along the leading dim
NUM_HIDDEN_GLOBAL = 8
NUM_HIDDEN_LOCAL = 64
STAT_SIZE = 8
NUM_LATENTS = D
## Training Parameters
BATCH_SIZE = 20
NUM_EPOCHS = 200
LEARNING_RATE = 1e-3
CUDA = torch.cuda.is_available()
PATH = 'ag-mu-rad-idw-v1'  # run name, also used in the saved figure's filename
# Fall back to the CPU when no GPU is available; a hard-coded 'cuda:0' device
# fails as soon as anything is moved to it on a CPU-only machine.
DEVICE = torch.device('cuda:0' if CUDA else 'cpu')

In [None]:
# Build the encoder for the global variables (means/radius), the encoder for the
# cluster assignments z, and the optimizer returned by `initialize`.
enc_mu_rad, enc_z, optimizer = initialize(
    NUM_HIDDEN_GLOBAL, STAT_SIZE, NUM_HIDDEN_LOCAL,
    K, D,
    CUDA, DEVICE, LEARNING_RATE,
)

In [None]:
# Train both encoders on the full dataset. `Eubo_rad` is presumably the training
# objective and PATH the run name used for outputs — TODO confirm in training.py.
train_rad(
    Eubo_rad, enc_mu_rad, enc_z, optimizer,
    Data, K,
    NUM_EPOCHS, MCMC_SIZE, SAMPLE_SIZE, BATCH_SIZE,
    PATH, CUDA, DEVICE,
)

In [None]:
BATCH_SIZE_TEST = 50  # number of sequences drawn for the evaluation / plotting batch below
def sample_single_batch(num_seqs, Xs, sample_size, batch_size, CUDA, device):
    """Draw one random evaluation batch and tile it along a new leading dim.

    Picks `batch_size` distinct random indices out of `range(num_seqs)`,
    gathers those rows of `Xs`, passes them through `shuffler` (presumably a
    within-sequence permutation — TODO confirm in utils), then repeats the
    result `sample_size` times along the leading dimension.

    Args:
        num_seqs: number of rows of `Xs` to draw indices from.
        Xs: dataset tensor indexed by sequence along dim 0.
        sample_size: number of copies stacked along the leading dimension.
        batch_size: number of sequences to draw (fewer if num_seqs < batch_size).
        CUDA: when True, the batch is moved to `device`.
        device: target torch device.

    Returns:
        The tiled batch, `shuffler(Xs[idx]).repeat(sample_size, 1, 1, 1)`.
    """
    batch_indices = torch.randperm(num_seqs)[:batch_size]
    obs = shuffler(Xs[batch_indices]).repeat(sample_size, 1, 1, 1)
    if CUDA:
        # .to(device) is sufficient; .cuda() first only added an extra copy
        # through the current default GPU.
        obs = obs.to(device)
    return obs

def test(enc_mu_rad, enc_z, obs, N, K, D, mcmc_size, sample_size, batch_size, noise_sigma, device):
    """
    Evaluate the trained encoders by alternating global and local updates.

    Initial cluster assignments z are drawn from `cat(enc_z.prior_pi)`.
    Each of the `mcmc_size` sweeps then does the following:
      1. After the first sweep, it resamples z with the previous sweep's
         local weights.
      2. It proposes the global variables (means and radius) from
         `enc_mu_rad`, weights them with
         `True_Log_likelihood(..., cluster_flag=True)` plus prior minus
         proposal, and resamples both.
      3. It proposes new assignments z from `enc_z` given the resampled means
         and radius, and computes per-point weights
         (`cluster_flag=False`) for the next sweep.

    Everything runs under `torch.no_grad()`; nothing here is back-propagated.

    Args:
        enc_mu_rad: encoder returning (q_eta, p_eta) with 'means' and 'radius' nodes.
        enc_z: encoder returning (q_z, p_z) with a 'zs' node; exposes `prior_pi`.
        obs: observation batch, e.g. from `sample_single_batch`.
        N: points per sequence.
        K: number of clusters.
        D: data dimension.
        mcmc_size: number of sweeps; must be >= 1.
        sample_size, batch_size: sizes of the leading sample / batch dims.
        noise_sigma: observation noise forwarded to the likelihood and `enc_z`.
        device: torch device forwarded to the likelihood and `enc_z`.

    Returns:
        (q_eta, q_z) from the final sweep.

    Raises:
        ValueError: if mcmc_size < 1 (there would be no sweep to return).
    """
    import torch

    if mcmc_size < 1:
        raise ValueError('mcmc_size must be at least 1, got {}'.format(mcmc_size))

    with torch.no_grad():
        # Initial assignments from the prior, sample shape (S, B, N).
        states = cat(enc_z.prior_pi).sample((sample_size, batch_size, N))
        for m in range(mcmc_size):
            if m != 0:
                # weights_local was set at the end of the previous sweep.
                states = resample_states(states, weights_local)
            ## update eta = (means, radius) given the current assignments
            q_eta, p_eta = enc_mu_rad(obs, states, K, sample_size, batch_size)
            log_q_eta = q_eta['means'].log_prob.sum(-1) + q_eta['radius'].log_prob.sum(-1)
            log_p_eta = p_eta['means'].log_prob.sum(-1) + p_eta['radius'].log_prob.sum(-1)
            obs_mu = q_eta['means'].value
            obs_rad = q_eta['radius'].value
            log_obs_k = True_Log_likelihood(obs, states, obs_mu, obs_rad, K, D, noise_sigma, device, cluster_flag=True)
            log_weights_global = log_obs_k + log_p_eta - log_q_eta
            # Normalize across the sample dimension (dim 0).
            weights_global = F.softmax(log_weights_global, 0).detach()
            ## resample both the means and the radii
            obs_mu, obs_rad = resample_eta(obs_mu, obs_rad, weights_global)
            ## update z -- cluster assignments
            q_z, p_z = enc_z(obs, obs_mu, obs_rad, N, sample_size, batch_size, noise_sigma, device)
            log_p_z = p_z['zs'].log_prob
            log_q_z = q_z['zs'].log_prob  ## S * B * N
            states = q_z['zs'].value
            log_obs_n = True_Log_likelihood(obs, states, obs_mu, obs_rad, K, D, noise_sigma, device, cluster_flag=False)
            log_weights_local = log_obs_n + log_p_z - log_q_z
            weights_local = F.softmax(log_weights_local, 0).detach()

    return q_eta, q_z

def plot_samples(obs, q_eta, q_z, K, batch_size, PATH):
    """Plot inferred cluster assignments and cluster-mean posteriors for one batch.

    Only the first sample (index 0 of the leading dimension) of `obs`, `q_eta`
    and `q_z` is used. Each of the `batch_size` sequences gets one subplot,
    laid out 5 per row. A subplot shows two things:
      - the sequence's points, colored by their most probable assignment
        (argmax of the z probabilities);
      - per cluster, an ellipse drawn by `plot_cov_ellipse` (nstd=2) from
        `diag(scale**2)` of `q_eta['means'].dist`, centered at its `loc`.

    The figure is saved to `results/modes-<PATH>.svg`; the `results` directory
    is created if missing.

    NOTE(review): the radius posterior `q_eta['radius']` is not drawn here.

    Args:
        obs: observation batch; `obs[0][b]` is indexed as (points, 2) below.
        q_eta: dict with a 'means' node exposing `.dist.loc` / `.dist.scale`.
        q_z: dict with a 'zs' node exposing `.dist.probs`.
        K: number of clusters (colors are reused cyclically if K > 3).
        batch_size: number of sequences to plot.
        PATH: run name used in the output filename.
    """
    import os

    colors = ['r', 'b', 'g']
    ncols = 5
    nrows = -(-batch_size // ncols)  # ceil division leaves room for a partial last row
    fig = plt.figure(figsize=(5 * ncols, 5 * nrows))
    xs = obs[0].cpu().data.numpy()
    mu_mu = q_eta['means'].dist.loc[0].cpu().data.numpy()
    mu_sigma = q_eta['means'].dist.scale[0].cpu().data.numpy()
    zs = q_z['zs'].dist.probs[0].cpu().data.numpy()
    for b in range(batch_size):
        ax = fig.add_subplot(nrows, ncols, b + 1)
        x = xs[b]
        assignments = zs[b].argmax(-1)
        for k in range(K):
            color = colors[k % len(colors)]
            xk = x[assignments == k]
            ax.scatter(xk[:, 0], xk[:, 1], c=color, alpha=0.2)
            cov_k = np.diag(mu_sigma[b][k] ** 2)
            plot_cov_ellipse(cov=cov_k, pos=mu_mu[b][k], nstd=2, ax=ax, alpha=1.0, color=color)
        ax.set_ylim([-8, 8])
        ax.set_xlim([-8, 8])
    os.makedirs('results', exist_ok=True)
    fig.savefig('results/modes-' + PATH + '.svg')
    
obs = sample_single_batch(NUM_SEQS, Xs, SAMPLE_SIZE, BATCH_SIZE_TEST, CUDA, DEVICE)
q_eta, q_z = test(enc_mu_rad, enc_z, obs, N, K, D, 100, SAMPLE_SIZE, BATCH_SIZE_TEST, noise_sigma=0.05, device=DEVICE)
%time plot_samples(obs, q_eta, q_z, K, BATCH_SIZE_TEST, PATH)