-
Notifications
You must be signed in to change notification settings - Fork 4
/
replay_buffer_services.py
158 lines (122 loc) · 7.18 KB
/
replay_buffer_services.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import grpc
import tensorflow as tf
from concurrent import futures
from protos import replay_buffer_pb2
from protos import replay_buffer_pb2_grpc
# For type annotations
from typing import Iterable, Tuple, Dict, Any
from utils import CommandLineParser
from replay_buffer import ReplayBuffer
from game_services import history_to_protobuf, history_from_protobuf
from exceptions import MuZeroError
from config import MuZeroConfig
from game import GameHistory
from muzero_types import ObservationBatch, ActionBatch, ValueBatch, PolicyBatch
class ReplayBufferServer(replay_buffer_pb2_grpc.ReplayBufferServicer):
    """
    A server for replay buffers, exposing their functionality through a gRPC API.

    Every RPC handler delegates to a single local ReplayBuffer constructed from the given config.
    """

    def __init__(self, config: MuZeroConfig) -> None:
        self.replay_buffer = ReplayBuffer(config=config)

    def NumGames(self, request: replay_buffer_pb2.Empty, context) -> replay_buffer_pb2.NumGamesResponse:
        """Return the value of the local buffer's num_games()."""
        return replay_buffer_pb2.NumGamesResponse(num_games=self.replay_buffer.num_games())

    def SaveHistory(self, request: replay_buffer_pb2.GameHistory, context) -> replay_buffer_pb2.NumGamesResponse:
        """Store a single game history; the response counts the games saved by this call (always 1)."""
        self.replay_buffer.save_history(history_from_protobuf(request))
        return replay_buffer_pb2.NumGamesResponse(num_games=1)

    def SaveMultipleHistory(self,
                            request_iterator: Iterable[replay_buffer_pb2.GameHistory],
                            context
                            ) -> replay_buffer_pb2.NumGamesResponse:
        """Store every streamed game history and return how many were received."""
        num_games = 0
        for message in request_iterator:
            self.replay_buffer.save_history(history_from_protobuf(message))
            num_games += 1
        return replay_buffer_pb2.NumGamesResponse(num_games=num_games)

    def SampleBatch(self,
                    request: replay_buffer_pb2.MiniBatchRequest,
                    context
                    ) -> Iterable[replay_buffer_pb2.MiniBatchResponse]:
        """
        Stream mini-batches from the local buffer's dataset of the requested batch size.

        Each batch's input tensors (observations, actions) and target tensors
        (rewards, values, policies) are serialized as TensorProto fields.
        """
        dataset = self.replay_buffer.as_dataset(batch_size=request.batch_size)
        for inputs, outputs in dataset:
            batch_observations, batch_actions = inputs
            batch_target_rewards, batch_target_values, batch_target_policies = outputs
            response = replay_buffer_pb2.MiniBatchResponse()
            response.batch_observations.CopyFrom(tf.make_tensor_proto(batch_observations))
            response.batch_actions.CopyFrom(tf.make_tensor_proto(batch_actions))
            response.batch_target_rewards.CopyFrom(tf.make_tensor_proto(batch_target_rewards))
            response.batch_target_values.CopyFrom(tf.make_tensor_proto(batch_target_values))
            response.batch_target_policies.CopyFrom(tf.make_tensor_proto(batch_target_policies))
            yield response

    def Stats(self, request: replay_buffer_pb2.StatsRequest, context) -> replay_buffer_pb2.StatsResponse:
        """Return the buffer's stats(), or detailed_stats() when request.detailed is set."""
        if request.detailed:
            metrics = self.replay_buffer.detailed_stats()
        else:
            metrics = self.replay_buffer.stats()
        return replay_buffer_pb2.StatsResponse(metrics=metrics)

    def BackupBuffer(self, request: replay_buffer_pb2.Empty, context) -> Iterable[replay_buffer_pb2.GameHistory]:
        """Stream every game history currently held in the local buffer."""
        # Snapshot before streaming: this generator hands control back to gRPC between
        # messages, and SaveHistory/SaveMultipleHistory calls served by other pool threads
        # may mutate the live buffer meanwhile (a deque would raise, a list could skip items).
        snapshot = list(self.replay_buffer.buffer)
        for history in snapshot:
            yield history_to_protobuf(history)
class RemoteReplayBuffer:
    """
    Connects to a replay buffer server and interacts with it.
    Behaves exactly like ReplayBuffer, but is agnostic about how the server deals with the actual buffer.
    """

    def __init__(self, ip_port: str) -> None:
        # Keep a handle on the channel so it can be released explicitly via close().
        self.channel = grpc.insecure_channel(ip_port)
        self.remote_replay_buffer = replay_buffer_pb2_grpc.ReplayBufferStub(self.channel)

    def close(self) -> None:
        """Close the underlying gRPC channel; the instance must not be used afterwards."""
        self.channel.close()

    def num_games(self) -> int:
        """Return the number of games reported by the server."""
        response = self.remote_replay_buffer.NumGames(replay_buffer_pb2.Empty())
        return response.num_games

    def save_games(self, filepath: str) -> None:
        """
        Download every game history held by the server and write them to `filepath`
        as a single serialized GameHistoryList message.
        """
        response_iterator = self.remote_replay_buffer.BackupBuffer(replay_buffer_pb2.Empty())
        message = replay_buffer_pb2.GameHistoryList()
        message.histories.extend(response_iterator)
        with open(filepath, 'wb') as protobuf_file:
            protobuf_file.write(message.SerializeToString())

    def load_games(self, filepath: str) -> None:
        """
        Read a serialized GameHistoryList from `filepath` and stream its histories to the server.

        Raises:
            MuZeroError: if the server reports saving a different number of games than were sent.
        """
        message = replay_buffer_pb2.GameHistoryList()
        with open(filepath, 'rb') as buffer_file:
            message.ParseFromString(buffer_file.read())
        response = self.remote_replay_buffer.SaveMultipleHistory(iter(message.histories))
        # Mirror save_history's check that the server actually stored what was sent.
        if response.num_games != len(message.histories):
            raise MuZeroError(message=f'Could only save {response.num_games} of '
                                      f'{len(message.histories)} game histories!')

    def save_history(self, game_history: GameHistory) -> None:
        """
        Send a single game history to the server.

        Raises:
            MuZeroError: if the server reports that no game was saved.
        """
        request = history_to_protobuf(game_history)
        response = self.remote_replay_buffer.SaveHistory(request)
        if not response.num_games:
            raise MuZeroError(message='Could not save game history!')

    @staticmethod
    def _to_tensor(tensor_proto):
        """Convert a TensorProto received from the server into a constant tf tensor."""
        return tf.constant(tf.make_ndarray(tensor_proto))

    def as_dataset(self,
                   batch_size: int
                   ) -> Iterable[Tuple[Tuple[ObservationBatch, ActionBatch], Tuple[ValueBatch, ValueBatch, PolicyBatch]]]:
        """
        Stream mini-batches of size `batch_size` from the server.

        Yields ((observations, actions), (target_rewards, target_values, target_policies)),
        the same structure the server unpacks from the local ReplayBuffer's dataset.
        """
        request = replay_buffer_pb2.MiniBatchRequest(batch_size=batch_size)
        response_iterator = self.remote_replay_buffer.SampleBatch(request)
        for response in response_iterator:
            inputs = (ObservationBatch(self._to_tensor(response.batch_observations)),
                      ActionBatch(self._to_tensor(response.batch_actions)))
            outputs = (ValueBatch(self._to_tensor(response.batch_target_rewards)),
                       ValueBatch(self._to_tensor(response.batch_target_values)),
                       PolicyBatch(self._to_tensor(response.batch_target_policies)))
            yield inputs, outputs

    def stats(self) -> Dict[str, Any]:
        """Return the server buffer's summary statistics as a plain dict."""
        response = self.remote_replay_buffer.Stats(replay_buffer_pb2.StatsRequest(detailed=False))
        # `metrics` is (presumably) a protobuf map field -- the server fills it from a dict;
        # copy it so callers receive a real dict, as annotated and as ReplayBuffer returns.
        return dict(response.metrics)

    def detailed_stats(self) -> Dict[str, Any]:
        """Return the server buffer's detailed statistics as a plain dict."""
        response = self.remote_replay_buffer.Stats(replay_buffer_pb2.StatsRequest(detailed=True))
        # See stats(): copy the protobuf map container into a real dict.
        return dict(response.metrics)
def main():
    """Parse command-line options and run the replay buffer gRPC server until it is terminated."""
    parser = CommandLineParser(name='MuProver Replay Buffer Server', game=True, port=True, threads=True)
    # TODO: planned but currently disabled options (re-enabling their validation needs `import os`):
    #   --backup_dir PATH  directory where game backups are stored
    #   --load PATH        filename for .pbuf with games to load at startup
    args = parser.parse_args()

    grpc_server = grpc.server(futures.ThreadPoolExecutor(max_workers=args.threads))
    servicer = ReplayBufferServer(config=args.config)
    replay_buffer_pb2_grpc.add_ReplayBufferServicer_to_server(servicer, grpc_server)

    # add_insecure_port returns the port actually bound: the OS picks one when args.port is 0,
    # and older grpcio versions return 0 (instead of raising) when binding fails.
    bound_port = grpc_server.add_insecure_port(f'[::]:{args.port}')
    if not bound_port:
        raise SystemExit(f'Could not bind replay buffer server to port {args.port}!')
    print(f'Starting replay buffer server, listening on port {bound_port}...')
    grpc_server.start()
    grpc_server.wait_for_termination()


if __name__ == '__main__':
    main()