Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

uv #11

Merged
merged 3 commits into from
Feb 21, 2024
Merged

uv #11

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 9 additions & 45 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,54 +10,18 @@ concurrency:
cancel-in-progress: true

env:
FORCE_COLOR: "1"
PYTHONUNBUFFERED: "1"
FORCE_COLOR: 3

jobs:
test:
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
steps:
- name: Checkout (GitHub)
uses: actions/checkout@v3

- name: pdm cache
uses: actions/cache@v3
with:
path: .cache/pdm
key: ${{ runner.os }}-pdm-${{ hashFiles('pdm.lock') }}
restore-keys: |
${{ runner.os }}-pdm-
- name: TFDS cache
uses: actions/cache@v3
with:
path: .tensorflow_datasets
key: ${{ runner.os }}-tfds-${{ hashFiles('pdm.lock') }}
restore-keys: |
${{ runner.os }}-tfds-

- name: Login to GitHub Container Registry
uses: docker/login-action@v2
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}

- name: Build and run Dev Container task
uses: devcontainers/ci@v0.3
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
# Change this to point to your image name
imageName: ghcr.io/ethanluoyc/corax
cacheFrom: ghcr.io/ethanluoyc/corax
# Change this to be your CI task/script
runCmd: |
# Add multiple commands to run if needed
export TFDS_DATA_DIR=$PWD/.tensorflow_datasets
mkdir -p $TFDS_DATA_DIR
pdm config cache_dir .cache/pdm
pdm sync -G:all
pdm lint
pdm test
pdm run python projects/baselines/baselines/iql/train_test.py
python-version: '3.10'
cache: 'pip' # caching pip dependencies
- name: Run nox
run: |
python -m pip install nox
nox -v
1 change: 1 addition & 0 deletions .python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.10.13
3 changes: 2 additions & 1 deletion corax/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ def observe(self, action: types.NestedArray, next_timestep: dm_env.TimeStep):
self._actor.observe(action, next_timestep)

def _has_data_for_training(self):
if self._iterator.ready(): # type: ignore
assert self._replay_tables is not None and self._iterator is not None
if self._iterator.ready():
return True
for table, batch_size in zip(
self._replay_tables,
Expand Down
6 changes: 3 additions & 3 deletions corax/agents/jax/decision_transformer/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def add_return_to_go(episode):
return episode

def _pad_along_axis(x, padded_size, axis=0, value=0):
pad_width = padded_size - tf.shape(x)[axis]
pad_width = padded_size - tf.shape(x)[axis] # type: ignore
if pad_width <= 0:
return x
padding = [(0, 0)] * len(x.shape.as_list())
Expand All @@ -72,10 +72,10 @@ def pad_steps(steps, max_len):
padded_discounts = _pad_along_axis(steps["discount"], max_len, 0, 2)
padded_timesteps = _pad_along_axis(steps["timestep"], max_len, 0, 0)
mask = _pad_along_axis(
tf.ones(tf.shape(steps["reward"])[0], dtype=bool),
tf.ones(tf.shape(steps["reward"])[0], dtype=bool), # type: ignore
max_len,
0,
False, # type: ignore
False,
)
return {
"observation": padded_obs,
Expand Down
1 change: 1 addition & 0 deletions corax/datasets/reverb.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def _make_dataset(unused_idx: tf.Tensor) -> tf.data.Dataset:
datasets, weights=tables.values()
)
else:
assert len(datasets) == 1
dataset = datasets[0]

# Post-process each element if a post-processing function is passed, e.g.
Expand Down
6 changes: 3 additions & 3 deletions corax/jax/running_statistics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from absl.testing import absltest
import jax
from jax.config import config as jax_config # type: ignore
from jax import config as jax_config
import jax.numpy as jnp
import numpy as np
import tree
Expand All @@ -31,7 +31,7 @@
update_and_validate = functools.partial(running_statistics.update, validate_shapes=True)


class TestNestedSpec(NamedTuple):
class _TestNestedSpec(NamedTuple):
# Note: the fields are intentionally in reverse order to test ordering.
a: specs.Array
b: specs.Array
Expand Down Expand Up @@ -183,7 +183,7 @@ def test_pmap_update_nested(self):
tree.map_structure(lambda x: self.assert_allclose(x, jnp.ones_like(x)), std)

def test_different_structure_normalize(self):
spec = TestNestedSpec(
spec = _TestNestedSpec(
a=specs.Array((5,), jnp.float32), b=specs.Array((2,), jnp.float32)
)
state = running_statistics.init_state(spec)
Expand Down
2 changes: 1 addition & 1 deletion corax/utils/counting_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def wait(self):
"""Waits on the barrier until all threads have called this method."""
with self._cond:
self._count += 1
self._cond.notifyAll()
self._cond.notify_all()
while self._count < self._num_threads:
self._cond.wait()

Expand Down
2 changes: 1 addition & 1 deletion corax/utils/loggers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def close(self):
def tensor_to_numpy(value: Any):
if hasattr(value, "numpy"):
return value.numpy() # tf.Tensor (TF2).
if hasattr(value, "device_buffer"):
if hasattr(value, "addressable_data"):
return np.asarray(value) # jnp.DeviceArray.
return value

Expand Down
13 changes: 13 additions & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import nox


@nox.session
def test(session):
session.install("-r", "requirements/test.txt", "jax[cpu]", "-e", ".[tf,jax]")
session.run("pytest", "-n", "auto", "corax/")


@nox.session
def lint(session):
session.install("pre-commit")
session.run("pre-commit", "run", "--all-files")
Loading