-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.hpp
41 lines (41 loc) · 1.27 KB
/
main.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#pragma once
#include <memory>
#include <vector>

#include "n.hpp"
/// Experience buffer filled during rollouts and consumed by Trainer.
///
/// Every per-step quantity lives in a flat std::vector. Sizing and layout are
/// established by the out-of-line constructor, which is not visible in this
/// header. NOTE(review): confirm the layout (e.g. `state` floats per step for
/// `states`) against that definition before indexing these buffers directly.
struct tExpTuple {
  std::vector<float> states;    // presumably observations before each step
  std::vector<float> n_states;  // presumably observations after each step
  std::vector<float> actions;
  std::vector<float> rewards;
  // NOTE(review): std::vector<bool> is a packed-bit proxy container; element
  // access yields proxy objects rather than bool&. Kept as-is for
  // compatibility with the code that already fills it.
  std::vector<bool> dones;
  std::vector<float> adv;      // presumably per-step advantages — TODO confirm
  std::vector<float> values;   // presumably value estimates — TODO confirm
  std::vector<float> returns;  // presumably discounted returns — TODO confirm

  // Default member initializers guarantee these scalars are never read in an
  // indeterminate state. The out-of-line constructor may still assign them.
  float scale{0.0f};
  int maxStep{0};
  int position{0};

  /// Allocates the buffers (definition lives outside this header).
  /// @param size    presumably the capacity in steps — TODO confirm
  /// @param state   presumably the per-step observation width — TODO confirm
  /// @param action  presumably the per-step action width — TODO confirm
  tExpTuple(int size, int state, int action);
};
/// Training driver holding two networks (`net`, `value`) and the experience
/// buffers collected for them.
///
/// All methods are defined out of line. The notes below are limited to what
/// the declarations show; anything else is marked as an assumption.
struct Trainer {
  // Presumably set by initTrainer from its `s` / `a` / `type` arguments —
  // TODO confirm. Zero-initialized by default so a Trainer that has not been
  // initialized yet holds well-defined values instead of indeterminate ones.
  int inputSize{0};
  int outputSize{0};
  int type{0};

  Net net;    // presumably the policy network — TODO confirm
  Net value;  // presumably the value/critic network — TODO confirm

  // NOTE(review): a data member named `exp` hides the free function ::exp
  // inside Trainer member functions, so the exponential must be written
  // std::exp there. Its meaning is not visible in this header — TODO confirm.
  float exp{0.0f};

  std::vector<float> scale;
  // Experience buffers; presumably one per rollout `id` passed to
  // postUpdate — TODO confirm.
  std::vector<std::shared_ptr<tExpTuple>> tuples;

  void initTrainer(int s, int a, int type = 0);
  void shutDown();

  // Returns a const reference, so the result refers to storage that outlives
  // the call (presumably Trainer-owned); do not keep it across later calls
  // unless the definition guarantees stability — TODO confirm.
  const std::vector<float> &preUpdate(const std::vector<float> &input);

  int postUpdate(int id, float shape, const std::vector<float> &o_input,
                 const std::vector<float> &n_input,
                 const std::vector<float> &act, bool done);

  // NOTE(review): `adv` is taken by value, so each call copies the vector.
  // Switching to const& requires changing the out-of-line definition too.
  void gae(std::vector<float> adv);

  void save(const char *model);
  void load(const char *model);
};
/// Collection of stateless numeric helpers. Every member is static, so they
/// are called as cMathUtil::Name(...) without creating an instance.
class cMathUtil {
public:
  /// Presumably returns a normally distributed sample with the given mean and
  /// standard deviation. RNG source and seeding are defined out of line —
  /// TODO confirm.
  static auto RandDoubleNorm(double mean, double stdev) -> double;

  /// Presumably evaluates a 1-D Gaussian density at `sample`.
  /// NOTE(review): the spread parameter is named `covar`; confirm in the
  /// definition whether it is a variance or a standard deviation.
  static auto EvalGaussian(double mean, double covar, double sample) -> double;

  /// Presumably the log of the density computed by EvalGaussian — TODO
  /// confirm against the definition.
  static auto EvalGaussianLogp(double mean, double covar, double sample)
      -> double;
};