In [1]:
import sys

import gym
import safety_gym

import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import numpy as np
from garage.envs import GymEnv

sys.path.append('../robust_rewards_from_preferences')
import envs.custom_safety_envs
del sys.path[-1]

from replay_buffer import ReplayBuffer
from garage.trainer import Trainer
from dreamer import Dreamer
from utils import RandomPolicy
from garage.experiment.experiment import ExperimentContext
from garage.sampler import RaySampler, LocalSampler, DefaultWorker
import gym.envs.atari
from garage.sampler.worker_factory import WorkerFactory
import threading

import torch

from ruamel.yaml import YAML
from dotmap import DotMap

In [2]:
yaml = YAML()
with open('./config.yaml', 'r') as f:
    CONFIG = DotMap(yaml.load(f))

In [3]:
from garage import EpisodeBatch, TimeStepBatch

In [4]:
from garage.envs.wrappers import ClipReward, EpisodicLife,  FireReset, Grayscale,  MaxAndSkip, Noop,  Resize, StackFrames

In [5]:
env = gym.envs.atari.AtariEnv('breakout', obs_type='image', frameskip=1, repeat_action_probability=0.25, full_action_space=False)
env = Noop(env, noop_max=30)
env = MaxAndSkip(env, skip=4)
# env = EpisodicLife(env)
if CONFIG.image.color_channels == 1:
    env = Grayscale(env)
env = Resize(env, CONFIG.image.height, CONFIG.image.height)
max_episode_length = 108000 / 4
env = GymEnv(env, max_episode_length=max_episode_length, is_image=True)

In [6]:
buf = ReplayBuffer(env.spec)
random_policy = RandomPolicy(env.spec)
sampler = RaySampler(agents=random_policy, envs=env, max_episode_length=env.spec.max_episode_length, n_workers=2)
dreamer = Dreamer(env.spec, sampler=sampler, rssm_model=None, actor=None, critic=None, buf=buf)
ctxt = ExperimentContext(snapshot_dir='./snapshot_dir', snapshot_mode='gap_overwrite', snapshot_gap=50)
trainer = Trainer(ctxt)
trainer.setup(algo=dreamer, env=env)

2021-06-06 18:34:16,105	INFO services.py:1267 -- View the Ray dashboard at [1m[32mhttp://127.0.0.1:8265[39m[22m


In [84]:
dreamer._initialize_dataset(trainer)



In [8]:
# import torch
# import torch.nn.functional as F
# from torch import nn
# from torch import distributions
# from torch.distributions import kl_divergence, Independent
# import torch.nn.functional as F

# def categorical_kl(probs_a, probs_b):
#     return torch.sum(probs_a * torch.log(probs_a / probs_b), dim=[-1, -2])

# def kl_loss(posterior, prior):
#     lhs = categorical_kl(posterior.probs.detach(), prior.probs)
#     rhs = categorical_kl(posterior.probs, prior.probs.detach())
#     kl_loss = CONFIG.rssm.alpha * lhs + (1 - CONFIG.rssm.alpha) * rhs
   
#     assert torch.isclose(lhs, rhs).all()
    
#     expected = kl_divergence(
#         Independent(posterior, 1), 
#         Independent(prior, 1),)
    
#     assert torch.isclose(lhs, expected, atol=1e-5).all(), lhs - expected
    
#     return kl_loss

# class WorldModel(torch.nn.Module):
    
#     def __init__(self, env_spec):
#         super().__init__()
#         self.env_spec = env_spec
        
#         self.rssm = RSSM(action_size=self.env_spec.action_space.n)
#         self.image_encoder = ImageEncoder()
#         self.image_decoder = ImageDecoder()
        
#         feat_size = (
#             CONFIG.rssm.stoch_state_classes * CONFIG.rssm.stoch_state_size
#             + CONFIG.rssm.det_state_size
#         )
        
#         self.reward_predictor = MLP(
#             input_shape=feat_size,
#             units=CONFIG.reward_head.units)
        
#         self.discount_predictor = MLP(
#             input_shape=feat_size,
#             units=CONFIG.discount_head.units)
        
#     def forward(self, observations, actions):
#         segs, batch, channels, height, width = observations.shape
#         flattened_observations = observations.reshape(segs*batch, channels, height, width)
#         embedded_observations = encoder(flattened_observations).reshape(segs, batch, -1)
#         out = self.rssm.observe(embedded_observations, actions)
#         out['reward_preds'] = self.reward_predictor(out['feats'])
#         out['discount_preds'] = self.discount_predictor(out['feats'])
#         return out
    
#     def reward_pred(self, logits):
#         dist = torch.distributions.Normal(loc=logits, scale=1)
    
#     def discount_pred(self, logits):
#         dist = torch.distributions.Bernoulli(logits=logits)
        
# class Actor(object):
#     pass

# class RSSM(torch.nn.Module):

#     def __init__(self,
#                  action_size,):
    
#         super().__init__()
#         self.action_size = action_size
        
#         self.embed_size = CONFIG.rssm.embed_size
#         self.stoch_state_classes = CONFIG.rssm.stoch_state_classes
#         self.stoch_state_size = CONFIG.rssm.stoch_state_size
#         self.det_state_size = CONFIG.rssm.det_state_size
#         self.act = eval(CONFIG.rssm.act)

#         self.register_parameter(
#             name='cell_initial_state',
#             param=nn.Parameter(torch.zeros(self.det_state_size))
#         )

#         self.cell = nn.GRUCell(input_size=self.det_state_size,
#                                hidden_size=self.det_state_size)

#         self._initialize_imagination_layers()
#         self._initialize_observation_layers()
        
#     def _initialize_imagination_layers(self):
#         self.embed_stoch_state_and_action = nn.Linear(
#             self.action_size + self.stoch_state_size * self.stoch_state_classes,
#             self.det_state_size)
        
#         self.imagine_out_1 = nn.Linear(self.det_state_size, self.det_state_size)
#         self.imagine_out_2 = nn.Linear(
#             self.det_state_size,
#             self.stoch_state_classes*self.stoch_state_size)
    
#     def _initialize_observation_layers(self):
#         self.observe_out_1 = nn.Linear(
#             self.det_state_size + self.embed_size,
#             self.det_state_size)
#         self.observe_out_2 = nn.Linear(
#             self.det_state_size,
#             self.stoch_state_classes*self.stoch_state_size)
    
#     def initial_state(self, batch_size):
#         state = {
#             'logits': torch.zeros(batch_size,
#                                   self.stoch_state_size,
#                                   self.stoch_state_classes),
#             'stoch': torch.zeros(batch_size,
#                                  self.stoch_state_size,
#                                   self.stoch_state_classes),
#             'deter': self.cell_initial_state.repeat([batch_size, 1])
#         }
#         return state
    
#     def step(self, prev_stoch, prev_deter, prev_action):
#         x = torch.cat((prev_stoch.flatten(start_dim=1), prev_action), dim=-1)
#         x = self.act(self.embed_stoch_state_and_action(x))
#         deter = self.cell(x, prev_deter)
#         return deter
    
#     def get_stoch(self, x):
#         logits = x.reshape(
#             *x.shape[:-1],
#             self.stoch_state_size,
#             self.stoch_state_classes)
#         dist = distributions.Categorical(logits=logits)
#         sample = F.one_hot(dist.sample(), num_classes=self.stoch_state_classes).type(torch.float)
#         sample += dist.probs - dist.probs.detach()  # Straight through gradients trick
#         return sample, dist
        
#     def imagine_step(self, prev_stoch, prev_deter, prev_action):
#         deter = self.step(prev_stoch, prev_deter, prev_action)
#         x = self.act(self.imagine_out_1(deter))
#         x = self.imagine_out_2(x)  
#         sample, dist = self.get_stoch(x)
#         prior = {'sample': sample, 'dist': dist}
#         return prior, deter
    
#     def observe_step(self, prev_stoch, prev_deter, prev_action, embed):
#         prior, deter = self.imagine_step(prev_stoch, prev_deter, prev_action)
#         x = torch.cat([deter, embed], dim=-1)
#         x = self.act(self.observe_out_1(x))
#         x = self.observe_out_2(x)
#         sample, dist = self.get_stoch(x)
#         posterior = {'sample': sample, 'dist': dist}
#         return posterior, prior, deter
    
#     def imagine(self):
#         pass
    
#     def observe(self, embedded_observations, actions):
#         segs, steps, embedding_size = embedded_observations.shape
#         assert segs == actions.shape[0]
#         assert steps == actions.shape[1]
        
#         # Change from SEGS x STEPS x N -> STEPS x SEGS x N
#         # This facilitates 
#         embedded_observations = torch.swapaxes(embedded_observations, 0, 1)
#         actions = torch.swapaxes(actions, 0, 1)
        
#         initial = self.initial_state(batch_size=segs)
#         stoch, deter = initial['stoch'], initial['deter']
        
#         posteriors = []
#         priors = []
#         deters = []
#         feats = []
#         kl_losses = []
        
#         for embed, action in zip(embedded_observations, actions):
#             posterior, prior, deter = self.observe_step(stoch, deter, action, embed)
#             stoch = posterior['sample']
            
#             posteriors.append(posterior)
#             priors.append(prior)
#             deters.append(deter)
#             feats.append(torch.cat([stoch.flatten(start_dim=1), deter], dim=-1))
#             kl_losses.append(kl_loss(posterior['dist'], prior['dist']))

#         out = {
#             'posteriors': posteriors,
#             'priors': priors,
#             'deters': torch.swapaxes(torch.stack(deters), 0, 1),
#             'feats': torch.swapaxes(torch.stack(feats), 0, 1),
#             'kl_losses': torch.swapaxes(torch.stack(kl_losses), 0, 1)
#         }
         
#         return out
    
# class ImageEncoder(torch.nn.Module):
#     def __init__(self):
#         super().__init__()
        
#         Activation = eval(CONFIG.image_encoder.Activation)
#         self.N = N = CONFIG.image_encoder.N
#         self.color_channels = CONFIG.image.color_channels
        
#         self.model = nn.Sequential(
#             nn.Conv2d(self.color_channels, N*1, 4, 2),
#             Activation(),
#             nn.Conv2d(32, N*2, 4, 2),
#             Activation(),
#             nn.Conv2d(64, N*4, 4, 2),
#             Activation(),
#             nn.Conv2d(128, N*8, 4, 2),
#             Activation(),
#         )

#     def forward(self, img):
#         x = self.model(img)
#         return torch.flatten(x, start_dim=1)

# class ImageDecoder(torch.nn.Module):

#     def __init__(self):
#         super().__init__()
        
#         Activation = eval(CONFIG.image_decoder.Activation)
#         self.N = N = CONFIG.image_decoder.N
#         self.shape = [
#             CONFIG.image.height,
#             CONFIG.image.width,
#             CONFIG.image.color_channels
#         ]
#         feat_shape = (
#             CONFIG.rssm.stoch_state_classes * CONFIG.rssm.stoch_state_size
#             + CONFIG.rssm.det_state_size)
        
#         self.dense = nn.Linear(feat_shape, N*32)
#         self.deconvolve = nn.Sequential(
#             nn.ConvTranspose2d(N*32, N*4, 5, 2),
#             Activation(),
#             nn.ConvTranspose2d(N*4, N*2, 5, 2),
#             Activation(),
#             nn.ConvTranspose2d(N*2, N, 6, 2),
#             Activation(),
#             nn.ConvTranspose2d(N, self.shape[-1], 6, 2),
#         )

#     def forward(self, embed):
#         batch_shape = embed.shape[0]
#         x = self.dense(embed).reshape(-1, self.N*32, 1, 1)
#         x = self.deconvolve(x)
#         norm = distributions.Normal(loc=x, scale=1)
#         dist = distributions.Independent(norm, len(self.shape))
#         assert len(dist.batch_shape) == 1
#         assert dist.batch_shape[0] == batch_shape
#         return dist

# class MLP(torch.nn.Module):
    
#     def __init__(self, input_shape, units, Activation=torch.nn.ELU):
#         super().__init__()
        
#         self.net = nn.Sequential()
        
#         for i, unit in enumerate(units):
#             self.net.add_module(f"linear_{i}", nn.Linear(input_shape, unit))
#             self.net.add_module(f"activation_{i}", Activation())
#             input_shape = unit
    
#         self.net.add_module("out_layer", nn.Linear(input_shape, 1))
            
#     def forward(self, features):
#         return self.net(features)


In [64]:
import torch
import torch.nn.functional as F
from torch import nn
from torch import distributions
from torch.distributions import kl_divergence, Independent
import torch.nn.functional as F

from ruamel.yaml import YAML
from dotmap import DotMap
yaml = YAML()
with open('./config.yaml', 'r') as f:
    CONFIG = DotMap(yaml.load(f))


def categorical_kl(probs_a, probs_b):
    return torch.sum(probs_a * torch.log(probs_a / probs_b), dim=[-1, -2])


def kl_loss(posterior, prior):
    lhs = categorical_kl(posterior.probs.detach(), prior.probs)
    rhs = categorical_kl(posterior.probs, prior.probs.detach())
    kl_loss = CONFIG.rssm.alpha * lhs + (1 - CONFIG.rssm.alpha) * rhs

    assert torch.isclose(lhs, rhs).all()

    expected = kl_divergence(
        Independent(posterior, 1),
        Independent(prior, 1),)

    assert torch.isclose(lhs, expected, atol=1e-5).all(), lhs - expected

    return kl_loss


class WorldModel(torch.nn.Module):

    def __init__(self, env_spec):
        super().__init__()
        self.env_spec = env_spec

        self.rssm = RSSM(action_size=self.env_spec.action_space.n)
        self.image_encoder = ImageEncoder()
        self.image_decoder = ImageDecoder()

        self.feat_size = (
            CONFIG.rssm.stoch_state_classes * CONFIG.rssm.stoch_state_size
            + CONFIG.rssm.det_state_size
        )

        self.reward_predictor = MLP(
            input_shape=self.feat_size,
            units=CONFIG.reward_head.units,
            dist='mse')

        self.discount_predictor = MLP(
            input_shape=self.feat_size,
            units=CONFIG.discount_head.units,
            dist='bernoulli')

    def forward(self, observations, actions):
        segs, steps, channels, height, width = observations.shape
        flattened_observations = observations.reshape(
            segs*steps, channels, height, width)
        embedded_observations = self.image_encoder(
            flattened_observations).reshape(segs, steps, -1)
        out = self.rssm.observe(embedded_observations, actions)
        out['reward_dist'] = self.reward_predictor(out['feats'])
        out['discount_dist'] = self.discount_predictor(out['feats'])
        flattened_feats = out['feats'].reshape(segs*steps, self.feat_size)
        mean = self.image_decoder(flattened_feats).reshape(
            segs, steps, channels, height, width)
        norm = distributions.Normal(loc=mean, scale=1)
        image_recon_dist = distributions.Independent(norm, 3)
        assert image_recon_dist.batch_shape == (segs, steps)
        out['image_recon_dist'] = image_recon_dist
        return out

    def loss(self, out, observation_batch, reward_batch, discount_batch):
        kl_loss = out['kl_losses'].mean()
        reward_loss = -out['reward_dist'].log_prob(reward_batch).mean()
        discount_loss = -out['discount_dist'].log_prob(discount_batch).mean()
        recon_loss = -out['image_recon_dist'].log_prob(observation_batch).mean()
        loss = reward_loss + discount_loss + recon_loss + CONFIG.rssm.beta * kl_loss
        return loss


class Actor(object):
    pass


class RSSM(torch.nn.Module):

    def __init__(self,
                 action_size,):
        super().__init__()
        self.action_size = action_size

        self.embed_size = CONFIG.rssm.embed_size
        self.stoch_state_classes = CONFIG.rssm.stoch_state_classes
        self.stoch_state_size = CONFIG.rssm.stoch_state_size
        self.det_state_size = CONFIG.rssm.det_state_size
        self.act = eval(CONFIG.rssm.act)

        self.register_parameter(
            name='cell_initial_state',
            param=nn.Parameter(torch.zeros(self.det_state_size))
        )

        self.cell = nn.GRUCell(input_size=self.det_state_size,
                               hidden_size=self.det_state_size)

        self._initialize_imagination_layers()
        self._initialize_observation_layers()

    def _initialize_imagination_layers(self):
        self.embed_stoch_state_and_action = nn.Linear(
            self.action_size + self.stoch_state_size * self.stoch_state_classes,
            self.det_state_size)

        self.imagine_out_1 = nn.Linear(self.det_state_size, self.det_state_size)
        self.imagine_out_2 = nn.Linear(
            self.det_state_size,
            self.stoch_state_classes*self.stoch_state_size)

    def _initialize_observation_layers(self):
        self.observe_out_1 = nn.Linear(
            self.det_state_size + self.embed_size,
            self.det_state_size)
        self.observe_out_2 = nn.Linear(
            self.det_state_size,
            self.stoch_state_classes*self.stoch_state_size)

    def initial_state(self, batch_size):
        state = {
            'logits': torch.zeros(batch_size,
                                  self.stoch_state_size,
                                  self.stoch_state_classes),
            'stoch': torch.zeros(batch_size,
                                 self.stoch_state_size,
                                 self.stoch_state_classes),
            'deter': self.cell_initial_state.repeat([batch_size, 1])
        }
        return state

    def step(self, prev_stoch, prev_deter, prev_action):
        x = torch.cat((prev_stoch.flatten(start_dim=1), prev_action), dim=-1)
        x = self.act(self.embed_stoch_state_and_action(x))
        deter = self.cell(x, prev_deter)
        return deter

    def get_stoch(self, x):
        logits = x.reshape(
            *x.shape[:-1],
            self.stoch_state_size,
            self.stoch_state_classes)
        dist = distributions.Categorical(logits=logits)
        sample = F.one_hot(dist.sample(), num_classes=self.stoch_state_classes).type(torch.float)
        sample += dist.probs - dist.probs.detach()  # Straight through gradients trick
        return sample, dist

    def imagine_step(self, prev_stoch, prev_deter, prev_action):
        deter = self.step(prev_stoch, prev_deter, prev_action)
        x = self.act(self.imagine_out_1(deter))
        x = self.imagine_out_2(x)
        sample, dist = self.get_stoch(x)
        prior = {'sample': sample, 'dist': dist}
        return prior, deter

    def observe_step(self, prev_stoch, prev_deter, prev_action, embed):
        prior, deter = self.imagine_step(prev_stoch, prev_deter, prev_action)
        x = torch.cat([deter, embed], dim=-1)
        x = self.act(self.observe_out_1(x))
        x = self.observe_out_2(x)
        sample, dist = self.get_stoch(x)
        posterior = {'sample': sample, 'dist': dist}
        return posterior, prior, deter

    def imagine(self):
        pass

    def observe(self, embedded_observations, actions):
        segs, steps, embedding_size = embedded_observations.shape
        assert segs == actions.shape[0]
        assert steps == actions.shape[1]

        # Change from SEGS x STEPS x N -> STEPS x SEGS x N
        # This facilitates 
        embedded_observations = torch.swapaxes(embedded_observations, 0, 1)
        actions = torch.swapaxes(actions, 0, 1)

        initial = self.initial_state(batch_size=segs)
        stoch, deter = initial['stoch'], initial['deter']

        posteriors = []
        priors = []
        deters = []
        feats = []
        kl_losses = []

        for embed, action in zip(embedded_observations, actions):
            posterior, prior, deter = self.observe_step(stoch, deter, action, embed)
            stoch = posterior['sample']

            posteriors.append(posterior)
            priors.append(prior)
            deters.append(deter)
            feats.append(torch.cat([stoch.flatten(start_dim=1), deter], dim=-1))
            kl_losses.append(kl_loss(posterior['dist'], prior['dist']))

        out = {
            'posteriors': posteriors,
            'priors': priors,
            'deters': torch.swapaxes(torch.stack(deters), 0, 1),
            'feats': torch.swapaxes(torch.stack(feats), 0, 1),
            'kl_losses': torch.swapaxes(torch.stack(kl_losses), 0, 1)
        }

        return out


class ImageEncoder(torch.nn.Module):
    def __init__(self):
        super().__init__()

        Activation = eval(CONFIG.image_encoder.Activation)
        self.N = N = CONFIG.image_encoder.N

        self.model = nn.Sequential(
            nn.Conv2d(1, N*1, 4, 2),
            Activation(),
            nn.Conv2d(32, N*2, 4, 2),
            Activation(),
            nn.Conv2d(64, N*4, 4, 2),
            Activation(),
            nn.Conv2d(128, N*8, 4, 2),
            Activation(),
        )

    def forward(self, img):
        x = self.model(img)
        return torch.flatten(x, start_dim=1)

class ImageDecoder(torch.nn.Module):

    def __init__(self):
        super().__init__()

        Activation = eval(CONFIG.image_decoder.Activation)
        self.N = N = CONFIG.image_decoder.N
        self.shape = [
            CONFIG.image.height,
            CONFIG.image.width,
            CONFIG.image.color_channels
        ]
        feat_shape = (
            CONFIG.rssm.stoch_state_classes * CONFIG.rssm.stoch_state_size
            + CONFIG.rssm.det_state_size)

        self.dense = nn.Linear(feat_shape, N*32)
        self.deconvolve = nn.Sequential(
            nn.ConvTranspose2d(N*32, N*4, 5, 2),
            Activation(),
            nn.ConvTranspose2d(N*4, N*2, 5, 2),
            Activation(),
            nn.ConvTranspose2d(N*2, N, 6, 2),
            Activation(),
            nn.ConvTranspose2d(N, self.shape[-1], 6, 2),
            nn.Sigmoid(),  # TODO: Check this
        )

    def forward(self, embed):
        batch_shape = embed.shape[0]
        x = self.dense(embed).reshape(-1, self.N*32, 1, 1)
        x = self.deconvolve(x)
        return x
        assert len(dist.batch_shape) == 1
        assert dist.batch_shape[0] == batch_shape
        return dist


class MLP(torch.nn.Module):

    def __init__(self, input_shape, units, dist='mse', Activation=torch.nn.ELU):
        super().__init__()
        self.dist = dist

        self.net = nn.Sequential()
        for i, unit in enumerate(units):
            self.net.add_module(f"linear_{i}", nn.Linear(input_shape, unit))
            self.net.add_module(f"activation_{i}", Activation())
            input_shape = unit
        self.net.add_module("out_layer", nn.Linear(input_shape, 1))

    def forward(self, features):
        logits = self.net(features).squeeze()
        if self.dist == 'mse':
            return torch.distributions.Normal(loc=logits, scale=1)
        elif self.dist == 'bernoulli':
            return torch.distributions.Bernoulli(logits=logits)



In [81]:
env.spec.action_space.sample()

3

In [65]:

rssm = RSSM(action_size=1)
encoder = ImageEncoder()
segs = dreamer.buffer.sample_segments(5)



In [82]:
observation_batch = torch.tensor([seg.next_observations for seg in segs]).type(torch.float)
action_batch = torch.tensor([env.spec.action_space.flatten_n(seg.actions) for seg in segs]).type(torch.float)
reward_batch = torch.tensor([seg.rewards for seg in segs]).type(torch.float)
discount_batch = 1 - torch.tensor([seg.terminals for seg in segs]).type(torch.float)

observation_batch = observation_batch.unsqueeze(2)
observation_batch = observation_batch / 255 - 0.5
# segs, batch, channels, height, width = observation_batch.shape
# flattened_observation_batch = observation_batch.reshape(segs*batch, channels, height, width)
# embedded_observations = encoder(flattened_observation_batch).reshape(segs, batch, -1)
# embedded_observations.shape

  discount_batch = 1 - torch.tensor([seg.terminals for seg in segs]).type(torch.float)


In [83]:
observation_batch.shape

torch.Size([5, 30, 1, 64, 64])

In [75]:
world_model = WorldModel(env.spec)

from torch import optim

optimizer = optim.Adam(world_model.parameters(), lr=0.0002)

In [85]:
optimizer.zero_grad()
out = world_model(observation_batch, action_batch)
loss = world_model.loss(out, observation_batch, reward_batch, discount_batch)
loss.backward()
optimizer.step()

In [86]:
dist = torch.distributions.Normal(loc=torch.tensor(1.).to('cuda'), scale=1)

RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

In [None]:
decoder = ImageDecoder()

In [46]:
out['feats'].shape

NameError: name 'out' is not defined

In [94]:
flattened_feat_batch = out['feats'].reshape(segs*batch, -1)
dist = decoder(flattened_feat_batch)

In [95]:
torch.reshape(dist, (3, 50, 1, 64, 64))

TypeError: reshape(): argument 'input' (position 1) must be Tensor, not Independent

In [65]:
observation_batch.shape

torch.Size([5, 30, 1, 64, 64])

In [69]:
dist.batch_shape

torch.Size([150])

In [141]:
dist = torch.distributions.Normal(loc=out['reward_preds'].squeeze(), scale=1)

In [155]:
dist = torch.distributions.Bernoulli(logits=out['discount_preds'].squeeze())

In [159]:
out['kl_losses'].shape

torch.Size([5, 30])

In [72]:
out['reward_preds'].shape

torch.Size([5, 30, 1])

In [71]:
dist

Independent(Normal(loc: torch.Size([150, 1, 64, 64]), scale: torch.Size([150, 1, 64, 64])), 3)

In [150]:
env.reset()

(array([[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
 {})

In [151]:
env.step(env.action_space.sample())

EnvStep(env_spec=EnvSpec(input_space=Discrete(4), output_space=Box(64, 64), max_episode_length=27000.0), action=1, reward=0.0, observation=array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], dtype=uint8), env_info={'ale.lives': 5, 'TimeLimit.truncated': False, 'GymEnv.TimeLimitTerminated': False}, step_type=<StepType.FIRST: 0>)

In [149]:
discount_batch

tensor([[0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900,
         0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900,
         0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900,
         0.9900, 0.9900, 0.0000],
        [0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900,
         0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900,
         0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900,
         0.9900, 0.9900, 0.9900],
        [0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900,
         0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900,
         0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900,
         0.9900, 0.9900, 0.9900],
        [0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900,
         0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900, 0.9900,
       

In [142]:
dist.log_prob(reward_batch)

tensor([[-0.9189, -0.9190, -0.9190, -0.9189, -0.9190, -0.9190, -0.9190, -0.9189,
         -0.9190, -0.9190, -0.9189, -0.9193, -0.9190, -0.9190, -0.9189, -0.9189,
         -0.9189, -0.9190, -0.9190, -0.9195, -0.9189, -0.9193, -0.9190, -0.9189,
         -0.9189, -0.9190, -0.9189, -0.9190, -0.9190, -0.9190],
        [-0.9189, -0.9190, -0.9190, -0.9190, -0.9189, -0.9190, -0.9189, -0.9190,
         -0.9189, -0.9193, -0.9189, -0.9191, -0.9190, -0.9189, -0.9190, -0.9190,
         -0.9189, -0.9190, -0.9189, -0.9189, -0.9190, -0.9189, -0.9190, -0.9189,
         -0.9190, -1.4340, -0.9189, -0.9191, -0.9190, -0.9189],
        [-0.9190, -0.9190, -0.9189, -0.9190, -0.9192, -0.9190, -0.9189, -0.9189,
         -0.9190, -0.9190, -0.9190, -0.9190, -0.9194, -0.9189, -0.9189, -0.9189,
         -0.9190, -0.9190, -0.9190, -0.9190, -0.9189, -0.9191, -0.9190, -0.9191,
         -0.9190, -0.9189, -0.9190, -0.9189, -0.9189, -0.9190],
        [-0.9189, -0.9190, -0.9190, -0.9189, -0.9190, -0.9189, -0.9190, -0.9189

In [138]:
reward_batch.shape

torch.Size([5, 30])