-
Notifications
You must be signed in to change notification settings - Fork 893
/
model.h
180 lines (162 loc) · 6.09 KB
/
model.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
// Copyright 2021 DeepMind Technologies Limited
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef OPEN_SPIEL_ALGORITHMS_ALPHA_ZERO_TORCH_MODEL_H_
#define OPEN_SPIEL_ALGORITHMS_ALPHA_ZERO_TORCH_MODEL_H_
#include <torch/torch.h>
#include <iostream>
#include <string>
#include <vector>
namespace open_spiel {
namespace algorithms {
namespace torch_az {
// Construction parameters for the residual model's input block; consumed by
// the ResInputBlockImpl constructor.
//
// Every field defaults to 0 so that a default-constructed config holds
// determinate values. The struct remains an aggregate, so positional brace
// initialization in declaration order keeps working.
struct ResInputBlockConfig {
  // NOTE(review): the three input_* fields presumably describe the shape of
  // the observation tensor fed to the block -- confirm against the .cc file.
  int input_channels = 0;
  int input_height = 0;
  int input_width = 0;
  // NOTE(review): presumably the convolution's output channels, kernel size
  // and padding -- confirm against the .cc file.
  int filters = 0;
  int kernel_size = 0;
  int padding = 0;
};
// Construction parameters for one torso block of the residual model; consumed
// by the ResTorsoBlockImpl constructor.
//
// Every field defaults to 0 so that a default-constructed config holds
// determinate values. The struct remains an aggregate, so positional brace
// initialization in declaration order keeps working.
struct ResTorsoBlockConfig {
  // NOTE(review): presumably the convolution's input channels, output
  // channels, kernel size and padding -- confirm against the .cc file.
  int input_channels = 0;
  int filters = 0;
  int kernel_size = 0;
  int padding = 0;
  // NOTE(review): the ResTorsoBlockImpl constructor also takes a separate
  // `layer` argument; confirm which of the two the implementation uses.
  int layer = 0;
};
// Construction parameters for the residual model's output block (value head
// and policy head); consumed by the ResOutputBlockImpl constructor.
//
// Every field defaults to 0 so that a default-constructed config holds
// determinate values. The struct remains an aggregate, so positional brace
// initialization in declaration order keeps working.
struct ResOutputBlockConfig {
  // NOTE(review): presumably the channel count of the tensor entering both
  // heads -- confirm against the .cc file.
  int input_channels = 0;
  // NOTE(review): presumably the output channels of the value-head and
  // policy-head convolutions respectively -- confirm.
  int value_filters = 0;
  int policy_filters = 0;
  int kernel_size = 0;
  int padding = 0;
  // NOTE(review): presumably the in/out feature counts of the heads' linear
  // layers -- confirm how they map onto the two value-head linear layers.
  int value_linear_in_features = 0;
  int value_linear_out_features = 0;
  int policy_linear_in_features = 0;
  int policy_linear_out_features = 0;
  // NOTE(review): presumably the flattened sizes used to reshape each head's
  // convolution output before its linear layer -- confirm.
  int value_observation_size = 0;
  int policy_observation_size = 0;
};
// Information for the model. This should be enough for any type of model
// (residual, convolutional, or MLP). It needs to be saved/loaded to/from
// a file, so the input and output stream operators are overloaded.
//
// The scalar fields default to zero so that a default-constructed config
// (e.g. one about to be filled by the input stream operator) never exposes
// indeterminate values. The struct remains an aggregate, so positional brace
// initialization in declaration order keeps working.
struct ModelConfig {
  std::vector<int> observation_tensor_shape;
  int number_of_actions = 0;
  // NOTE(review): presumably the number of torso blocks and the number of
  // filters per convolution for the residual model -- confirm in the .cc file.
  int nn_depth = 0;
  int nn_width = 0;
  double learning_rate = 0.0;
  double weight_decay = 0.0;
};
// Stream extraction for ModelConfig: reads from `stream` into `config`.
// Only the declaration lives in this header.
// NOTE(review): the serialized layout is defined by the implementation and is
// presumably the one emitted by operator<< below -- confirm they round-trip.
std::istream& operator>>(std::istream& stream, ModelConfig& config);
// Stream insertion for ModelConfig: writes `config` to `stream`.
// Only the declaration lives in this header.
std::ostream& operator<<(std::ostream& stream, const ModelConfig& config);
// A block of the residual model's network that handles the input. It consists
// of one convolutional layer (CONV) and one batch normalization (BN) layer, and
// the output is passed through a rectified linear unit function (RELU).
//
// Illustration:
//   [Input Tensor] --> CONV --> BN --> RELU
//
// There is only one input block per model.
class ResInputBlockImpl : public torch::nn::Module {
 public:
  // Builds the block from `config`. Explicit so that a ResInputBlockConfig is
  // never implicitly converted into a module.
  explicit ResInputBlockImpl(const ResInputBlockConfig& config);

  // Runs the block on `x` and returns the resulting tensor.
  torch::Tensor forward(torch::Tensor x);

 private:
  // NOTE(review): presumably copies of the config's input_channels,
  // input_height and input_width -- confirm in the .cc file.
  int channels_;
  int height_;
  int width_;
  torch::nn::Conv2d conv_;
  torch::nn::BatchNorm2d batch_norm_;
};
// Declares ResInputBlock, the module holder wrapping ResInputBlockImpl.
TORCH_MODULE(ResInputBlock);
// A block of the residual model's network that makes up the 'torso'. It
// consists of two convolutional layers (CONV) and two batch normalization
// layers (BN). The activation function is rectified linear unit (RELU). The
// input to the layer is added to the output before the final activation
// function.
//
// Illustration:
//   [Input Tensor] --> CONV --> BN --> RELU --> CONV --> BN --> + --> RELU
//          \___________________________________________________/
//
// Unlike the input and output blocks, one can specify how many of these torso
// blocks they want in their model.
class ResTorsoBlockImpl : public torch::nn::Module {
public:
// Builds one torso block from `config`.
// NOTE(review): `layer` is presumably the index of this block within the
// torso; ResTorsoBlockConfig also carries a `layer` field -- confirm which one
// the implementation actually uses.
ResTorsoBlockImpl(const ResTorsoBlockConfig& config, int layer);
// Runs the block on `x` and returns the resulting tensor.
torch::Tensor forward(torch::Tensor x);
private:
// The two convolution layers and the two batch normalization layers of the
// block described above.
torch::nn::Conv2d conv1_;
torch::nn::Conv2d conv2_;
torch::nn::BatchNorm2d batch_norm1_;
torch::nn::BatchNorm2d batch_norm2_;
};
// Declares ResTorsoBlock, the module holder wrapping ResTorsoBlockImpl.
TORCH_MODULE(ResTorsoBlock);
// A block of the residual model's network that creates the output. It consists
// of a value and policy head. The value head takes the input through one
// convolutional layer (CONV), one batch normalization layer (BN), and two
// linear layers (LIN). The output activation function is tanh (TANH), the
// rectified linear activation function (RELU) is within. The policy head
// consists of one convolutional layer, batch normalization layer, and linear
// layer. There is no softmax activation function in this layer. The softmax
// on the output is applied in the forward function of the residual model.
// This design was chosen because the loss function of the residual model
// requires the policy logits, not the policy distribution. By providing the
// policy logits as output, the residual model can either apply the softmax
// activation function, or calculate the loss using Torch's log softmax
// function.
//
// Illustration:
//                     --> CONV --> BN --> RELU --> LIN --> RELU --> LIN --> TANH
//   [Input Tensor] --
//                     --> CONV --> BN --> RELU --> LIN (no SOFTMAX here)
//
// There is only one output block per model.
class ResOutputBlockImpl : public torch::nn::Module {
 public:
  // Builds both heads from `config`. Explicit so that a ResOutputBlockConfig
  // is never implicitly converted into a module.
  explicit ResOutputBlockImpl(const ResOutputBlockConfig& config);

  // Runs both heads on `x`.
  // NOTE(review): the returned vector presumably holds the value output and
  // the policy logits, and `mask` is presumably a legal-action mask applied to
  // the policy head -- confirm element order and mask semantics in the .cc
  // file.
  std::vector<torch::Tensor> forward(torch::Tensor x, torch::Tensor mask);

 private:
  // Value head layers.
  torch::nn::Conv2d value_conv_;
  torch::nn::BatchNorm2d value_batch_norm_;
  torch::nn::Linear value_linear1_;
  torch::nn::Linear value_linear2_;
  // NOTE(review): presumably a copy of the config's value_observation_size --
  // confirm in the .cc file.
  int value_observation_size_;

  // Policy head layers.
  torch::nn::Conv2d policy_conv_;
  torch::nn::BatchNorm2d policy_batch_norm_;
  torch::nn::Linear policy_linear_;
  // NOTE(review): presumably a copy of the config's policy_observation_size --
  // confirm in the .cc file.
  int policy_observation_size_;
};
// Declares ResOutputBlock, the module holder wrapping ResOutputBlockImpl.
TORCH_MODULE(ResOutputBlock);
// The model class that interacts with the VPNet. The ResInputBlock,
// ResTorsoBlock, and ResOutputBlock are not to be used by the VPNet directly.
class ResModelImpl : public torch::nn::Module {
public:
// Builds the model from `config`.
// NOTE(review): `device` is presumably a torch device string such as "cpu" or
// "cuda:0" used to initialize device_ -- confirm accepted values in the .cc
// file.
ResModelImpl(const ModelConfig& config, const std::string& device);
// Runs the model on `x` with `mask` and returns a vector of tensors.
// NOTE(review): presumably the value and policy outputs; confirm element order
// and how `mask` is applied in the .cc file.
std::vector<torch::Tensor> forward(torch::Tensor x, torch::Tensor mask);
// Computes the training losses for a batch of inputs, masks and targets.
// NOTE(review): the number and order of the returned loss tensors are defined
// in the .cc file -- confirm before indexing into the result.
std::vector<torch::Tensor> losses(torch::Tensor inputs, torch::Tensor masks,
torch::Tensor policy_targets,
torch::Tensor value_targets);
private:
// Internal helper with the same signature as forward().
// NOTE(review): presumably the shared network pass used by both forward() and
// losses() -- confirm in the .cc file.
std::vector<torch::Tensor> forward_(torch::Tensor x, torch::Tensor mask);
// Container for the model's sub-modules.
torch::nn::ModuleList layers_;
torch::Device device_;
int num_torso_blocks_;
// NOTE(review): presumably taken from ModelConfig::weight_decay and used by
// losses() -- confirm in the .cc file.
double weight_decay_;
};
// Declares ResModel, the module holder wrapping ResModelImpl.
TORCH_MODULE(ResModel);
} // namespace torch_az
} // namespace algorithms
} // namespace open_spiel
#endif // OPEN_SPIEL_ALGORITHMS_ALPHA_ZERO_TORCH_MODEL_H_