-
Notifications
You must be signed in to change notification settings - Fork 9
/
setup.py
145 lines (122 loc) · 3.42 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import os
import setuptools
from importlib import util as import_util
import setuptools.command
import setuptools.command.build_py
import setuptools.command.develop
# Execute muax/_metadata.py as a standalone module so the version can be read
# without importing the `muax` package itself. The path is resolved against
# this file's directory rather than the current working directory, so the
# build works no matter where setup.py is launched from.
_metadata_path = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), 'muax', '_metadata.py')
spec = import_util.spec_from_file_location('_metadata', _metadata_path)
_metadata = import_util.module_from_spec(spec)
spec.loader.exec_module(_metadata)
# Runtime dependencies installed by a plain `pip install muax`.
muax_core_requirements = [
    'mctx',
    'dm-haiku',
    'optax',
    'gymnasium',
    'lz4',
    'tensorboardX',
]

# Pinned TensorFlow / Reverb / Launchpad stack shared by both Acme extras.
tensorflow = [
    'tensorflow==2.8.0',
    'tensorflow_probability==0.15.0',
    'tensorflow_datasets==4.6.0',
    'dm-reverb==0.7.2',
    'dm-launchpad==0.5.2',
]

# Framework-independent Acme dependencies, shared by both Acme extras.
acme_core_requirements = [
    'dm-acme',
    'absl-py',
    'dm-env',
    'dm-tree',
    'numpy',
    'pillow',
    'typing-extensions',
]

# `acme-jax` extra: pinned JAX stack, then the TF stack, then Acme core.
acme_jax_requirements = [
    'jax==0.4.3',
    'jaxlib==0.4.3',
    'chex==0.1.6',
    'dm-haiku==0.0.10',
    'flax',
    'optax==0.1.7',
    'rlax==0.1.6',
    *tensorflow,
    *acme_core_requirements,
]

# `acme-tf` extra: Sonnet/TRFL, then the TF stack, then Acme core.
acme_tf_requirements = [
    'dm-sonnet',
    'trfl',
    *tensorflow,
    *acme_core_requirements,
]

# `testing` extra.
testing_requirements = [
    'pytype==2023.12.8',
    'pytest-xdist',
]

# `envs` extra: environment suites.
envs_requirements = [
    'atari-py',
    'bsuite',
    'dm-control',
    'gym==0.25.0',
    'gym[atari]',
    'pygame==2.1.0',
    'rlds',
]
def generate_requirements_file(path=None):
    """Write a requirements.txt listing every muax dependency, one per line.

    Adapted from Acme's setup.py, where it is used by the Launchpad GCP
    runtime to generate the requirements installed inside a docker image
    (the sources are copied over rather than installed from PyPI, so local
    changes are reflected). In this file it is invoked by the custom
    `build_py` and `develop` commands.

    The union of the core, Acme (JAX and TF) and environment requirements is
    de-duplicated and written in sorted order. Plain `set` iteration order
    varies between interpreter runs (string hash randomization), so sorting
    keeps the generated file identical across builds.

    Args:
      path: path to the requirements.txt file to generate. When falsy,
        defaults to ``muax/requirements.txt`` next to this setup.py.
    """
    if not path:
        path = os.path.join(os.path.dirname(__file__), 'muax', 'requirements.txt')
    packages = sorted(set(
        muax_core_requirements
        + acme_core_requirements
        + acme_jax_requirements
        + acme_tf_requirements
        + envs_requirements))
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(f'{package}\n' for package in packages)
# README.md becomes the long description shown on PyPI. Resolve it against
# this file's directory (not the current working directory) and decode it
# explicitly as UTF-8, so the build works on any platform default encoding.
with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'README.md'),
          encoding='utf-8') as f:
    long_description = f.read()
version = _metadata.__version__
class BuildPy(setuptools.command.build_py.build_py):
    """`build_py` command that regenerates the requirements file before building."""

    def run(self):
        # Write muax/requirements.txt first, presumably so the `package_data`
        # entry for requirements.txt picks up a fresh copy during the build.
        generate_requirements_file()
        super().run()
class Develop(setuptools.command.develop.develop):
    """`develop` command that regenerates the requirements file first."""

    def run(self):
        # Keep muax/requirements.txt current for development installs too.
        generate_requirements_file()
        super().run()
# Custom commands that regenerate muax/requirements.txt before running the
# stock setuptools behavior.
cmdclass = {
    'build_py': BuildPy,
    'develop': Develop,
}

setuptools.setup(
    name='muax',
    version=version,
    cmdclass=cmdclass,
    author='bwfbowen',
    description="A library that provides help for using MCTS RL with different frameworks.",
    keywords='reinforcement-learning mcts python muzero machine learning',
    long_description=long_description,
    long_description_content_type='text/markdown',
    packages=setuptools.find_packages(),
    # The "" key applies to every package: ship the generated requirements.txt.
    package_data={"": ['requirements.txt']},
    include_package_data=True,
    install_requires=muax_core_requirements,
    extras_require={
        'acme-jax': acme_jax_requirements,
        'acme-tf': acme_tf_requirements,
        'testing': testing_requirements,
        'envs': envs_requirements,
    },
    classifiers=[
        'Programming Language :: Python :: 3',
        'License :: OSI Approved :: MIT License',
        'Operating System :: OS Independent',
    ],
    # `dependency_links` is intentionally not passed: pip has ignored it since
    # 19.0. To pull jaxlib wheels from the JAX release index, pass it at
    # install time instead, e.g.
    #   pip install muax -f https://storage.googleapis.com/jax-releases/jax_releases.html
    python_requires='>=3.9',
)