Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TF2 fixes. #130

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions dopamine/discrete_domains/atari_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,14 @@ class has two main functions: `.__init__` and `.call`. When we create our
import tensorflow.compat.v1 as tf

import cv2
from tensorflow.contrib import layers as contrib_layers
from tensorflow.contrib import slim as contrib_slim
from tensorflow.compat.v1 import layers as contrib_layers

# Allow failure on this import (not in tf2). This means atari won't be
# available but other domains will.
try:
from tensorflow.contrib import slim as contrib_slim
except:
pass


NATURE_DQN_OBSERVATION_SHAPE = (84, 84) # Size of downscaled Atari 2600 frame.
Expand Down
4 changes: 2 additions & 2 deletions dopamine/replay_memory/circular_replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
import tensorflow.compat.v1 as tf

import gin.tf
from tensorflow.contrib import staging as contrib_staging
from tensorflow.python.ops import data_flow_ops

# Defines a type describing part of the tuple returned by the replay
# memory. Each element of the tuple is a tensor of shape [batch, ...] where
Expand Down Expand Up @@ -855,7 +855,7 @@ def _set_up_staging(self, transition):
transition_type = self.memory.get_transition_elements()

# Create the staging area in CPU.
prefetch_area = contrib_staging.StagingArea(
prefetch_area = data_flow_ops.StagingArea(
[shape_with_type.type for shape_with_type in transition_type])

# Store prefetch op for tests, but keep it private -- users should not be
Expand Down