Skip to content

Commit

Permalink
Import gym as optional package to build docs successfully (Lightning-…
Browse files Browse the repository at this point in the history
…Universe#458)

* Import gym as optional package

* Fix import

* Apply isort
  • Loading branch information
akihironitta authored and chris-clem committed Dec 17, 2020
1 parent add8bdf commit f058a3f
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 14 deletions.
31 changes: 19 additions & 12 deletions pl_bolts/models/rl/common/gym_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,28 @@
"""
import collections

import gym
import gym.spaces
import numpy as np
import torch

from pl_bolts.utils import _OPENCV_AVAILABLE
from pl_bolts.utils import _GYM_AVAILABLE, _OPENCV_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _GYM_AVAILABLE:
import gym.spaces
from gym import ObservationWrapper, Wrapper
from gym import make as gym_make
else: # pragma: no-cover
warn_missing_pkg('gym')
Wrapper = object
ObservationWrapper = object

if _OPENCV_AVAILABLE:
import cv2
else:
warn_missing_pkg('cv2', pypi_name='opencv-python') # pragma: no-cover


class ToTensor(gym.Wrapper):
class ToTensor(Wrapper):
"""For environments where the user need to press FIRE for the game to start."""

def __init__(self, env=None):
Expand All @@ -34,7 +41,7 @@ def reset(self):
return torch.tensor(self.env.reset())


class FireResetEnv(gym.Wrapper):
class FireResetEnv(Wrapper):
"""For environments where the user need to press FIRE for the game to start."""

def __init__(self, env=None):
Expand All @@ -58,7 +65,7 @@ def reset(self):
return obs


class MaxAndSkipEnv(gym.Wrapper):
class MaxAndSkipEnv(Wrapper):
"""Return only every `skip`-th frame"""

def __init__(self, env=None, skip=4):
Expand Down Expand Up @@ -88,7 +95,7 @@ def reset(self):
return obs


class ProcessFrame84(gym.ObservationWrapper):
class ProcessFrame84(ObservationWrapper):
"""preprocessing images from env"""

def __init__(self, env=None):
Expand Down Expand Up @@ -121,7 +128,7 @@ def process(frame):
return x_t.astype(np.uint8)


class ImageToPyTorch(gym.ObservationWrapper):
class ImageToPyTorch(ObservationWrapper):
"""converts image to pytorch format"""

def __init__(self, env):
Expand All @@ -142,15 +149,15 @@ def observation(observation):
return np.moveaxis(observation, 2, 0)


class ScaledFloatFrame(gym.ObservationWrapper):
class ScaledFloatFrame(ObservationWrapper):
"""scales the pixels"""

@staticmethod
def observation(obs):
return np.array(obs).astype(np.float32) / 255.0


class BufferWrapper(gym.ObservationWrapper):
class BufferWrapper(ObservationWrapper):
""""Wrapper for image stacking"""

def __init__(self, env, n_steps, dtype=np.float32):
Expand All @@ -176,7 +183,7 @@ def observation(self, observation):
return self.buffer


class DataAugmentation(gym.ObservationWrapper):
class DataAugmentation(ObservationWrapper):
"""
Carries out basic data augmentation on the env observations
- ToTensor
Expand All @@ -197,7 +204,7 @@ def observation(self, obs):

def make_environment(env_name):
"""Convert environment with wrappers"""
env = gym.make(env_name)
env = gym_make(env_name)
env = MaxAndSkipEnv(env)
env = FireResetEnv(env)
env = ProcessFrame84(env)
Expand Down
6 changes: 4 additions & 2 deletions pl_bolts/models/rl/dqn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,17 @@
from pl_bolts.datamodules.experience_source import Experience, ExperienceSourceDataset
from pl_bolts.losses.rl import dqn_loss
from pl_bolts.models.rl.common.agents import ValueAgent
from pl_bolts.models.rl.common.gym_wrappers import make_environment
from pl_bolts.models.rl.common.memory import MultiStepBuffer
from pl_bolts.models.rl.common.networks import CNN
from pl_bolts.utils import _GYM_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _GYM_AVAILABLE:
from pl_bolts.models.rl.common.gym_wrappers import gym, make_environment
from gym import Env
else:
warn_missing_pkg('gym') # pragma: no-cover
Env = object


class DQN(pl.LightningModule):
Expand Down Expand Up @@ -336,7 +338,7 @@ def test_dataloader(self) -> DataLoader:
return self._dataloader()

@staticmethod
def make_environment(env_name: str, seed: Optional[int] = None) -> gym.Env:
def make_environment(env_name: str, seed: Optional[int] = None) -> Env:
"""
Initialise gym environment
Expand Down

0 comments on commit f058a3f

Please sign in to comment.