In [1]:
from dcem import dcem
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Settings for the toy experiment; every value below is handed to dcem(...)
# under the keyword of the same name (B goes in as n_batch).
B = 10                                    # number of independent problems
n_sample, n_elite, n_iter = 20, 10, 10    # presumably CEM samples / elites / iterations -- confirm against dcem docs
device = 'cpu'

# Starting point of the search: mean 1 and spread 10 for every batch row.
init_mu = torch.ones((B, 1))
init_sigma = 10 * torch.ones((B, 1))

def f(z):
    """Toy batched energy: row ``b`` is (approximately) minimised at ``z == b``.

    For batch row ``b`` and candidate value ``v = z[b, s, 0]`` the energy is
    ``sin((v - b) - 3.14159/2) + (v - b)**2``. Since ``3.14159/2`` is roughly
    pi/2, the sine term is roughly ``-cos(v - b)``, so both terms are smallest
    when ``v == b``.

    The batch and sample sizes are taken from ``z`` itself rather than from the
    notebook globals ``B`` / ``n_sample``, so the function keeps working after a
    later cell rebinds those names, and the result lives on ``z``'s device.

    Args:
        z: tensor indexed as ``z[batch, sample, coordinate]``; only coordinate
            0 is read.

    Returns:
        Tensor of shape ``(z.shape[0], z.shape[1])`` holding one energy per
        (batch row, sample).
    """
    # Column vector 0..n_batch-1 so that row b is shifted by b via broadcasting.
    offsets = torch.arange(z.shape[0], device=z.device, dtype=z.dtype).unsqueeze(1)
    shifted = z[:, :, 0] - offsets
    return torch.sin(shifted - 3.14159/2) + shifted**2
    
# Minimise the toy energy with dcem, one scalar unknown (nx=1) per batch row.
# Row b of f is minimised near z == b, so min_z[b] should end up close to b.
min_z = dcem(
    f=f,
    nx=1,
    n_batch=B,
    init_mu=init_mu,
    init_sigma=init_sigma,
    n_sample=n_sample,
    n_elite=n_elite,
    n_iter=n_iter,
    device=device,
)
min_z

tensor([[0.2956],
        [1.0160],
        [1.2554],
        [2.9635],
        [3.9341],
        [5.1605],
        [5.9382],
        [6.9756],
        [7.9605],
        [7.8976]])

In [12]:
# Sizes for the context-conditioned experiment.
B, n_context, d_x, z_dim = 4, 2, 3, 1
n_sample, n_elite, n_iter = 2, 10, 10

# Random context inputs / targets, each indexed (batch, context point, feature).
# The draw order (X, then Y, then the model's weights) is kept as-is so a
# seeded run produces the same numbers.
batch_X = torch.randn((B, n_context, d_x))
batch_Y = torch.randn((B, n_context, d_x))

# Linear map from a context input concatenated with a latent z to a prediction.
model = nn.Linear(in_features=d_x + z_dim, out_features=d_x)

# Initial search distribution: mean 1 and spread 10 for every batch row.
init_mu = torch.ones((B, 1))
init_sigma = 10 * torch.ones((B, 1))

# Filled in by f(...) with the un-reduced losses of its most recent call.
energies_all = None

# Which (batch row, sample) entry the debugging prints and the manual check use.
batch_idx = 0
sample_idx = 0

def f(z):
    """Energy of every latent sample: summed squared error over all context points.

    Each sample ``z[b, s]`` is appended to every context input of batch row
    ``b``, pushed through the module-level ``model`` and compared with the
    matching row of ``batch_Y``.

    Reads the globals ``batch_X``, ``batch_Y``, ``model``, ``B``, ``n_sample``,
    ``n_context`` and ``batch_idx``. As a side effect it stores the un-reduced
    losses in the global ``energies_all`` and prints intermediate tensors for
    debugging.

    Args:
        z: latent samples indexed (B, n_sample, z_dim).

    Returns:
        Tensor of shape (B, n_sample): one energy per (batch row, sample).
    """
    global energies_all

    # Rows are laid out sample-major: row s * n_context + c pairs sample s with
    # context point c. The contexts therefore cycle (c0, c1, c0, c1, ...)
    inputs_tiled = batch_X.repeat(1, n_sample, 1)
    targets_tiled = batch_Y.repeat(1, n_sample, 1)
    # ... while every sample is held for n_context consecutive rows
    # (z0, z0, z1, z1, ...). A plain z.repeat(1, n_context, 1) would cycle the
    # samples instead and pair them with the wrong context points.
    latents_tiled = z.repeat_interleave(n_context, dim=1)

    model_in = torch.cat((inputs_tiled, latents_tiled), dim=-1)
    print('x', model_in[batch_idx])
    prediction = model(model_in)

    # Element-wise squared error, kept un-reduced so it can be inspected later.
    energies_all = F.mse_loss(targets_tiled, prediction, reduction='none')
    print('all', energies_all.shape)
    print(energies_all)

    # Collapse the feature axis: one value per (sample, context) row.
    per_row = energies_all.sum(dim=-1)
    print('red', per_row.shape)
    print(per_row)

    # Unfold the rows back into (sample, context) and add up the contexts.
    per_pair = per_row.view(B, n_sample, n_context)
    print('eview', per_pair.shape)
    print(per_pair)
    return per_pair.sum(dim=-1)

# The dcem run is disabled while the batched energy is being checked by hand.
# NOTE(review): this cell's settings have n_elite=10 > n_sample=2 -- confirm
# dcem accepts that before re-enabling the call.
# min_z = dcem(f=f, nx=1, n_batch=B, init_mu=init_mu, init_sigma=init_sigma, n_sample=n_sample, n_elite=n_elite, n_iter=n_iter, device=device)

# Batched evaluation on random latents.
zs = torch.randn(B, n_sample, 1)
energies = f(zs)

# Recompute a single (batch_idx, sample_idx) entry one context point at a
# time, without any tiling, to compare against the batched value.
z = zs[batch_idx, sample_idx]
energy = 0.
for i in range(n_context):
    x = torch.cat((batch_X[batch_idx, i], z))
    print('xs', x)
    s_hat = model(x)
    e = F.mse_loss(batch_Y[batch_idx, i], s_hat, reduction='sum')
    print('e', e)
    energy += e

# The two numbers printed below should agree.
print('energy batch', energies[batch_idx, sample_idx])
print('energy single', energy)

x tensor([[ 1.3250, -0.5254,  0.9028,  0.6246],
        [-1.2589,  0.1180, -1.4028,  0.6246],
        [ 1.3250, -0.5254,  0.9028,  0.7514],
        [-1.2589,  0.1180, -1.4028,  0.7514]])
all torch.Size([4, 4, 3])
tensor([[[2.0323e+00, 6.2391e-03, 1.0195e-01],
         [5.5355e+00, 1.2224e+00, 1.2101e+00],
         [1.9030e+00, 2.4212e-03, 1.1597e-01],
         [5.7546e+00, 1.1575e+00, 1.2573e+00]],

        [[1.1015e+00, 3.4208e-02, 3.4445e+00],
         [1.2385e-03, 1.8592e-01, 1.9519e+00],
         [5.3571e-01, 1.5222e-01, 2.9224e+00],
         [1.2448e-01, 4.0498e-01, 1.5641e+00]],

        [[1.0964e-01, 2.7969e+00, 7.9195e-01],
         [4.5584e+00, 2.0869e-01, 2.3013e-01],
         [2.8613e-02, 2.4579e+00, 6.6463e-01],
         [3.8930e+00, 1.2404e-01, 1.6406e-01]],

        [[1.8898e-02, 5.5376e+00, 1.3319e+00],
         [3.7848e-01, 2.2548e+00, 2.8547e-01],
         [3.8144e-02, 5.7149e+00, 1.2711e+00],
         [4.5299e-01, 2.1440e+00, 3.1467e-01]]], grad_fn=<MseLossBackward>)


In [None]:
# Sanity check: batch_X is created as torch.randn(B, n_context, d_x).
print(batch_X.shape)

In [None]:
# Un-reduced squared errors stored in this global by the most recent f(...) call.
energies_all

In [None]:
# Result of energies = f(zs): one energy per (batch row, sample).
energies

In [11]:
# Compare two ways of stretching z from n_sample rows to n_sample * n_context
# rows per batch element.
z = torch.randn(B, n_sample, 1)

# Variant 1: repeat along the sample axis -- the samples cycle
# (z0, z1, z0, z1, ...).
z_cycled = z.repeat(1, n_context, 1)
print(z_cycled.shape)
print(z_cycled)

# Variant 2: tile the last axis, then fold the copies back into rows -- each
# sample is held for n_context consecutive rows (z0, z0, z1, z1, ...).
z_rep = z.tile((1, n_context)).view(-1, n_sample * n_context, z_dim)
print(z_rep.shape)
print(z_rep)

torch.Size([4, 4, 1])
tensor([[[ 1.8364],
         [ 0.5219],
         [ 1.8364],
         [ 0.5219]],

        [[-0.5998],
         [-0.3751],
         [-0.5998],
         [-0.3751]],

        [[-0.8130],
         [ 1.1510],
         [-0.8130],
         [ 1.1510]],

        [[ 1.1416],
         [-0.1184],
         [ 1.1416],
         [-0.1184]]])
torch.Size([4, 4, 1])
tensor([[[ 1.8364],
         [ 1.8364],
         [ 0.5219],
         [ 0.5219]],

        [[-0.5998],
         [-0.5998],
         [-0.3751],
         [-0.3751]],

        [[-0.8130],
         [-0.8130],
         [ 1.1510],
         [ 1.1510]],

        [[ 1.1416],
         [ 1.1416],
         [-0.1184],
         [-0.1184]]])
