In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import gym
import custom_gym

In [2]:
# Create the environment registered under the id 'CustomPendulum-v0'.
# NOTE(review): presumably this id is registered as a side effect of
# `import custom_gym` -- confirm.
env = gym.make('CustomPendulum-v0')

  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


In [3]:
def get_episodedata(env, gamma):
    """Roll out one episode with uniformly sampled actions and record it.

    After every step the rollout continues with probability ``gamma`` (it
    stops as soon as ``gamma < np.random.rand()``), and it also stops when
    the environment reports ``done``.  In both cases the transitions
    collected so far are returned; previously a ``done`` episode returned
    ``None``, which threw the data away and broke callers that unpack the
    three-element result.

    Args:
        env: object exposing ``reset()``, ``step(action)`` returning
            ``(next_ob, reward, done, info)`` and ``action_space.sample()``.
        gamma: per-step continuation threshold compared against a uniform
            random draw.

    Returns:
        Tuple ``(episodedata, s_dim, a_dim)`` where ``episodedata`` is an
        array with one row ``[ob, ac, rew, nextob]`` per step, ``s_dim`` is
        ``len(ob)`` and ``a_dim`` is ``len(ac)``.
    """
    episodedata = []
    ob = env.reset()
    while True:
        ac = env.action_space.sample()
        nextob, rew, done, _ = env.step(ac)
        episodedata.append(np.hstack([ob, ac, rew, nextob]))
        ob = nextob.copy()
        # `done` is checked first so that, exactly as before, no random
        # number is drawn on the step where the environment terminates.
        if done or gamma < np.random.rand():
            break
    return np.array(episodedata), len(ob), len(ac)

def preprocess_episodedata(episodedata, s_dim, a_dim):
    """Turn raw transition rows into ``[state, action, state_delta]`` rows.

    Args:
        episodedata: 2-D array whose first ``s_dim + a_dim`` columns hold
            the state and action and whose last ``s_dim`` columns hold the
            next state.
        s_dim: number of state columns.
        a_dim: number of action columns.

    Returns:
        2-D array with ``s_dim + a_dim + s_dim`` columns.  Note the sign of
        the delta: it is the state minus the next state.
    """
    state_action = episodedata[:, :s_dim + a_dim]
    state = episodedata[:, :s_dim]
    next_state = episodedata[:, -s_dim:]
    delta = state - next_state
    return np.concatenate((state_action, delta), axis=1)


In [4]:
# Collect one random-action rollout; 0.99 is the per-step continuation
# threshold passed to get_episodedata as `gamma`.
episodedata, s_dim, a_dim = get_episodedata(env, 0.99)
# Reduce each transition to [state, action, state - next_state].
sads = preprocess_episodedata(episodedata, s_dim, a_dim)
# Cast to float32 and wrap as a torch tensor; clone() gives the tensor
# storage that is independent of the intermediate numpy array.
sads = torch.from_numpy(sads.astype(np.float32)).clone()

In [11]:
class VI(torch.nn.Module):
    """Latent-variable transition model trained with a variational objective.

    An encoder embeds every ``[s, a, ds]`` row, averages the embeddings over
    rows and maps the result to the mean and log-variance of a diagonal
    Gaussian q(z).  ``transition_net`` predicts ``ds`` from ``[s, a, z]``.
    The prior mean/log-variance over z and the per-dimension likelihood
    log-variance are ``torch.nn.Parameter`` objects.

    Args:
        s_dim: number of state columns.
        a_dim: number of action columns.
        z_dim: size of the latent variable z.
        nu: weight of the KL term in ``compute_loss`` (default ``1e-2``).
    """

    def __init__(self, s_dim, a_dim, z_dim, nu=1e-2):
        super().__init__()
        self.s_dim = s_dim
        self.a_dim = a_dim
        self.z_dim = z_dim

        # Weight applied to the KL term of the loss.
        self.nu = nu

        self.prior_mu = torch.nn.Parameter(torch.zeros(z_dim))
        self.prior_logvar = torch.nn.Parameter(torch.zeros(z_dim))
        self.likelihood_logvar = torch.nn.Parameter(torch.zeros(s_dim))

        # Decoder: [s, a, z] -> predicted ds.
        self.transition_net = torch.nn.Sequential(
            torch.nn.Linear(s_dim + a_dim + z_dim, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, s_dim),
        )
        # Encoder, stage 1: applied to every [s, a, ds] row independently.
        # (Attribute names are kept as-is because they are state_dict keys.)
        self.permutation_variant_net_part1 = torch.nn.Sequential(
            torch.nn.Linear(s_dim + a_dim + s_dim, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 32),
        )
        # Encoder, stage 2: applied to the row-averaged embedding and
        # emitting the concatenation [q_mu, q_logvar].
        self.permutation_variant_net_part2 = torch.nn.Sequential(
            torch.nn.Linear(32, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 2 * z_dim),
        )

    def permutation_invariant_net(self, sads):
        """Encode a set of ``[s, a, ds]`` rows into ``(q_mu, q_logvar)``.

        Rows are embedded one by one and averaged over dim 0, so the result
        does not depend on the order of the rows.

        Returns:
            Tuple of two tensors of length ``z_dim``: mean and log-variance.
        """
        h = self.permutation_variant_net_part1(sads)
        q_mu_q_logvar = self.permutation_variant_net_part2(h.mean(0))
        return q_mu_q_logvar[:self.z_dim], q_mu_q_logvar[self.z_dim:]

    def gaussian_likelihood_loss(self, y, mu, logvar):
        """Negative Gaussian log-likelihood of ``y`` (additive constant dropped).

        For a diagonal covariance with entries var_i,
        log det(Var) = log(prod_i var_i) = sum_i log(var_i), which is why
        ``logvar`` is simply added inside the sum.

        Returns:
            Scalar tensor ``0.5 * sum((y - mu)^2 * exp(-logvar) + logvar)``.
        """
        return 0.5 * torch.sum(((y - mu) ** 2) * torch.exp(-logvar) + logvar)

    def kld(self, mu1, logvar1, mu2, logvar2):
        """KL(p1 || p2) between diagonal Gaussians, averaged over elements.

        kld(p1|p2) = E_{z~p1}[ log p1(z) - log p2(z) ], which per dimension is
        log(sigma2/sigma1) + (sigma1^2 + (mu1-mu2)^2) / (2*sigma2^2) - 1/2.
        The trailing ``-1/2`` makes the divergence of identical distributions
        exactly zero; it is a constant, so gradients are unaffected.
        """
        log_ratio = 0.5 * (logvar2 - logvar1)  # log(sigma2/sigma1)
        quad = 0.5 * (torch.exp(logvar1) + (mu1 - mu2) ** 2) / torch.exp(logvar2)
        return torch.mean(log_ratio + quad - 0.5)

    def compute_loss(self, sads):
        """Single-sample variational loss for one set of ``[s, a, ds]`` rows.

        Args:
            sads: 2-D tensor whose first ``s_dim + a_dim`` columns are the
                state and action and whose last ``s_dim`` columns are ds.

        Returns:
            Scalar tensor: Gaussian likelihood loss of ds under the decoder
            plus ``nu`` times KL(q(z) || prior).
        """
        q_mu, q_logvar = self.permutation_invariant_net(sads)

        # Reparametrization trick: z = mu + sigma * eps with eps ~ N(0, I).
        eps = torch.randn_like(q_logvar)
        std = torch.exp(0.5 * q_logvar)
        # Share the single z sample across all rows.  expand() keeps the
        # sample's own device and dtype, unlike multiplying by a freshly
        # allocated torch.ones(...) tensor.
        z = (eps * std + q_mu).expand(sads.shape[0], self.z_dim)
        saz = torch.cat((sads[:, :(self.s_dim + self.a_dim)], z), dim=-1)
        ds_pred = self.transition_net(saz)
        ds = sads[:, -self.s_dim:]

        # Approximation of E_{z~q}[ -log p(ds | s, a, z) ].
        loss = self.gaussian_likelihood_loss(ds, ds_pred, self.likelihood_logvar)
        # nu * E_{z~q}[ log q(z) - log p(z) ].
        loss = loss + self.nu * self.kld(q_mu, q_logvar, self.prior_mu, self.prior_logvar)
        return loss


In [12]:
# Size the model from the state/action dimensions measured on the collected
# episode rather than hard-coding them, so the network always matches the
# column layout of `sads`; the latent size stays at 2.
net = VI(s_dim, a_dim, 2)