In [None]:
%matplotlib inline
%run ../../import_envs.py
print('probtorch:', probtorch.__version__, 
      'torch:', torch.__version__, 
      'cuda:', torch.cuda.is_available())

In [None]:
## Load dataset
# Number of components; passed as K to the state encoder.
K = 3
# Points per shape type in each dataset (c/s/t presumably circle/square/triangle — TODO confirm).
N_c, N_s, N_t = 70, 72, 72
N = N_c + N_s + N_t
# NOTE: `data_dir` is not defined in this notebook; it presumably comes from
# import_envs.py and must end with a path separator.
data_path = data_dir + "ncmm/shapes_c=%d_s=%d_t=%d" % (N_c, N_s, N_t)

OB = torch.from_numpy(np.load(data_path + '/ob.npy')).float()
MU = torch.from_numpy(np.load(data_path + '/mu.npy')).float()
ANGLE = torch.from_numpy(np.load(data_path + '/angle.npy')).float()  # loaded but not used by training
NUM_DATASETS = OB.shape[0]

## Train Parameters
NUM_EPOCHS = 500

D = 2
SAMPLE_SIZE = 10  # each minibatch is repeated SAMPLE_SIZE times along a new leading dim
BATCH_SIZE = 20
# Floor division: the final partial batch is dropped each epoch.
NUM_BATCHES = NUM_DATASETS // BATCH_SIZE
# Without this guard training silently does nothing and the epoch print divides by zero.
assert NUM_BATCHES > 0, 'NUM_DATASETS=%d is smaller than BATCH_SIZE=%d' % (NUM_DATASETS, BATCH_SIZE)

CUDA = torch.cuda.is_available()
# Prefer GPU 1 when it exists; fall back to GPU 0 on single-GPU machines,
# where torch.cuda.device(1) would raise.
DEVICE = 1 if torch.cuda.device_count() > 1 else 0
# Passed to the decoder; presumably the std of the reconstruction likelihood — TODO confirm.
RECON_SIGMA = torch.ones(1) * 0.2
# Output prefix handed to plot_final_samples.
PATH = 'dec-%dpts-%dsamples' % (N, SAMPLE_SIZE)

## Model Parameters
NUM_HIDDEN_LOCAL = 64
NUM_HIDDEN = 32
NUM_NSS = 8
LEARNING_RATE = 1e-3
Train_Params = (NUM_EPOCHS, K, D, SAMPLE_SIZE, BATCH_SIZE, CUDA, DEVICE, PATH)

In [None]:
from decoder_shapes import *
from local_oneshot_state_shapes import *
from local_enc_angle import *

# Instantiate decoder, angle encoder and state encoder (construction order kept,
# since parameter initialisation draws from the global RNG).
dec_x = Dec_x(D, NUM_HIDDEN, RECON_SIGMA, CUDA, DEVICE)
f_angle = Enc_angle(D, NUM_HIDDEN, CUDA, DEVICE)
f_state = Oneshot_state(K, D, NUM_HIDDEN_LOCAL, CUDA, DEVICE)

if CUDA:
    # Move every module onto the selected GPU.
    with torch.cuda.device(DEVICE):
        for module in (dec_x, f_angle, f_state):
            module.cuda()

# A single Adam optimiser updates the parameters of all three networks.
trainable_params = [*f_angle.parameters(), *f_state.parameters(), *dec_x.parameters()]
optimizer = torch.optim.Adam(trainable_params, lr=LEARNING_RATE, betas=(0.9, 0.99))

In [None]:
# Training loop: each epoch visits NUM_BATCHES random minibatches of BATCH_SIZE
# datasets; every minibatch is replicated SAMPLE_SIZE times along dim 0 so the
# importance weights below are normalised across samples.
for epoch in range(NUM_EPOCHS):
    time_start = time.time()
    LOSS = 0.0
    indices = torch.randperm(NUM_DATASETS)
    for step in range(NUM_BATCHES):
        optimizer.zero_grad()
        batch_indices = indices[step*BATCH_SIZE : (step+1)*BATCH_SIZE]
        # Shuffle each dataset once (on CPU, as before) and move it to the GPU
        # *before* replicating, so one copy crosses the bus instead of SAMPLE_SIZE.
        ob = shuffler(OB[batch_indices])
        mu = MU[batch_indices]
        if CUDA:
            with torch.cuda.device(DEVICE):
                ob = ob.cuda()
                mu = mu.cuda()
        ob = ob.repeat(SAMPLE_SIZE, 1, 1, 1)
        mu = mu.repeat(SAMPLE_SIZE, 1, 1, 1)
        # Call modules directly (not .forward) so any registered hooks run.
        q_state, p_state = f_state(ob, mu, K)
        log_p_state = p_state['states'].log_prob
        log_q_state = q_state['states'].log_prob
        state = q_state['states'].value  ## S * B * N * K
        q_angle, p_angle = f_angle(ob, state, mu)
        log_q_angle = q_angle['angles'].log_prob.sum(-1)
        log_p_angle = p_angle['angles'].log_prob.sum(-1)
        # The sampled angle value is scaled by 2*pi before decoding
        # (presumably a fraction of a full turn — TODO confirm).
        angle = q_angle['angles'].value * 2 * math.pi
        p = dec_x(ob, state, angle, mu)
        ll = p['likelihood'].log_prob.sum(-1)  ## S * B * N
        # Log importance weights; the likelihood term is detached so loss_phi's
        # gradient does not flow back through ll.
        log_w = ll.detach() + log_p_state + log_p_angle - log_q_state - log_q_angle
        # Self-normalised weights over the sample dim (0), treated as constants.
        w = F.softmax(log_w, 0).detach()
        loss_phi = (w * log_w).sum(0).mean()
        loss_theta = -ll.sum(-1).mean()
        ## gradient step
        # One backward on the summed objective accumulates exactly the same gradients
        # as two separate backward calls, without retaining the graph in between.
        (loss_phi + loss_theta).backward()
        optimizer.step()
        LOSS += loss_theta.detach()
    time_end = time.time()
    print('epoch=%d, loss=%.4f (%ds)' % (epoch, LOSS / NUM_BATCHES, time_end - time_start))

In [None]:
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
from plots import plot_final_samples

# Visualise reconstructions for one random minibatch of datasets.
indices = torch.randperm(NUM_DATASETS)
batch_indices = indices[:BATCH_SIZE]
ob = OB[batch_indices]
mu = MU[batch_indices]

ob = shuffler(ob).repeat(SAMPLE_SIZE, 1, 1, 1)
mu = mu.repeat(SAMPLE_SIZE, 1, 1, 1)
if CUDA:
    with torch.cuda.device(DEVICE):
        ob = ob.cuda()
        mu = mu.cuda()

# Inference only: no autograd graph is needed for plotting.
with torch.no_grad():
    q_state, p_state = f_state(ob, mu, K)
    state = q_state['states'].value  ## S * B * N * K
    # Sample angles for THIS batch. Previously the cell reused the `angle` left over
    # from the training loop's last batch, which mismatched ob/state and broke on a
    # fresh kernel. The scaling mirrors the training loop.
    q_angle, p_angle = f_angle(ob, state, mu)
    angle = q_angle['angles'].value * 2 * math.pi
    p = dec_x(ob, state, angle, mu)

# Index 0 selects the first entry along the leading (sample) dim.
E_state = q_state['states'].dist.probs[0].cpu().data.numpy()
recon_mu = p['likelihood'].dist.loc[0].cpu().data.numpy()
plot_final_samples(recon_mu, mu[0].cpu().data.numpy(), E_state, K, PATH)

In [None]:
# Same figure as above, but with the (shuffled) observed points of the first
# sample in place of the reconstruction.
observed_points = ob[0].detach().cpu().numpy()
true_mu = mu[0].detach().cpu().numpy()
plot_final_samples(observed_points, true_mu, E_state, K, PATH)