In [24]:
import os
import torch

import numpy as np
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.distributions.multivariate_normal import MultivariateNormal

import gym
from pendulum.rl import load_from_path
from stitching.rl import PolicyCollection

In [2]:
collection = PolicyCollection()

policy_path = '../pendulum/trained_agents/pendulum_00/policy.pkl'
policy, env = load_from_path(policy_path)
base_path = '../pendulum/trained_agents/'
for folder in sorted(os.listdir(base_path)):
    path = os.path.join(base_path, folder, 'policy.pkl')
    collection.append(*load_from_path(path))

collection.reset()
dataset = []
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
for _ in range(10):
    for x in collection.step():
        dataset.append(x)

In [26]:
class QFunction(torch.nn.Module):
    
    def __init__(self, collection, z_size):
        super(QFunction, self).__init__()
        x_size = env.observation_space.shape[0]
        u_size = env.action_space.shape[0]
        self.fc1 = torch.nn.Linear(x_size + z_size, 400)
        self.fc2 = torch.nn.Linear(400 + u_size, 300)
        self.fc3 = torch.nn.Linear(300, 1)
    
    def forward(self, x, u, z):
        xz = torch.cat([x, z], dim=1)
        a1 = F.relu(self.fc1(xz))
        a1u = torch.cat([a1, u], dim=1)
        a2 = F.relu(self.fc2(a1u))
        y = self.fc3(a2)
        return y
    
class LatentWrapped(torch.nn.Module):
    
    def __init__(self, collection, q, z_size):
        super(LatentWrapped, self).__init__()
        self.z_size = z_size
        n = len(collection.envs)
        self.register_buffer('mean', torch.randn(n, z_size))
        self.register_buffer('std_logits', torch.randn(n, z_size))
        self.q = q
        
    def forward(self, i, x, u):
        eps = torch.randn(x.size(0), self.z_size)
        z = self.mean[i] + eps * torch.exp(self.std_logits[i])
        return self.q(x, u, z)
        
    
z_size = 4
q = QFunction(collection, z_size)
inds, o, u, r, o_, u_ = next(iter(dataloader))
z = torch.randn(o.size(0), z_size)
wrapped = LatentWrapped(collection, q, 4)
wrapped(inds, o, u)

ValueError: covariance_matrix must be at least two-dimensional, with optional leading batch dimensions

[
    0     0     0     0     0     0     0     0     0     1
    0     0     0     1     0     0     0     0     0     0
    0     0     0     0     0     1     0     0     0     0
    0     0     0     0     0     1     0     0     0     0
    0     0     0     0     0     0     1     0     0     0
    0     0     0     0     0     0     0     0     1     0
    1     0     0     0     0     0     0     0     0     0
    0     0     0     0     0     0     0     0     0     1
    0     0     0     0     0     1     0     0     0     0
    1     0     0     0     0     0     0     0     0     0
    0     0     0     0     0     0     1     0     0     0
    0     1     0     0     0     0     0     0     0     0
    0     0     0     0     0     0     0     0     1     0
    0     0     0     0     0     0     1     0     0     0
    0     0     0     0     0     0     0     1     0     0
    0     0     0     0     0     0     1     0     0     0
    0     0     0     0     1     0   

In [14]:
next(iter(dataloader))

[
     1     0     0     0     0     0     0     0     0     0
 [torch.FloatTensor of size 1x10], 
  2.5830  0.5264
 [torch.FloatTensor of size 1x2], 
 -1.8400
 [torch.FloatTensor of size 1x1], 
 -6.8538
 [torch.FloatTensor of size 1x1], 
  2.5830  0.5264
 [torch.FloatTensor of size 1x2], 
 1.00000e-02 *
   8.0000
 [torch.FloatTensor of size 1x1]]