This repository has been archived by the owner on Oct 31, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 286
/
trainable_ai.cc
124 lines (108 loc) · 3.81 KB
/
trainable_ai.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#include "trainable_ai.h"
#include "engine/game_env.h"
#include "engine/unit.h"
#include "engine/cmd_interface.h"
#include "python_options.h"
#include "state_feature.h"
/*
static inline int sampling(const std::vector<float> &v, std::mt19937 *gen) {
std::vector<float> accu(v.size() + 1);
std::uniform_real_distribution<> dis(0, 1);
float rd = dis(*gen);
accu[0] = 0;
for (size_t i = 1; i < accu.size(); i++) {
accu[i] = v[i - 1] + accu[i - 1];
if (rd < accu[i]) {
return i - 1;
}
}
return v.size() - 1;
}
*/
// Called once when a game finishes. Notifies the base class, then wipes the
// per-game history of cached state feature vectors so the next game starts
// from a clean slate.
bool TrainedAI::GameEnd(const State &s) {
  AIWithComm::GameEnd(s);
  auto &history = _recent_states.v();
  for (auto &frame : history) {
    frame.clear();
  }
  return true;
}
// Extracts the feature representation of the current game state into `data`.
// With a history length of 1, features are written directly into the newest
// GameState; otherwise the last `maxlen` feature frames from the circular
// buffer `_recent_states` are concatenated into game->s (missing/empty frames
// stay zero-filled).
void TrainedAI::extract(const State &s, Data *data) {
GameState *game = &data->newest();
// Save per-game bookkeeping (id, tick, etc.) into the GameState slot.
MCSaveInfo(s, id(), game);
game->name = name();
if (_recent_states.maxlen() == 1) {
// No history kept: write features straight into the outgoing state.
MCExtract(s, id(), _respect_fow, &game->s);
// std::cout << "(1) size_s = " << game->s.size() << std::endl << std::flush;
} else {
// Claim the next slot in the circular history buffer and fill it with the
// current frame's features. NOTE(review): GetRoom presumably returns a
// reference to the slot that get_from_push(0) will later yield — confirm
// against the CircularQueue implementation.
std::vector<float> &state = _recent_states.GetRoom();
MCExtract(s, id(), _respect_fow, &state);
const size_t maxlen = _recent_states.maxlen();
// Output is the concatenation of maxlen frames, each state.size() floats.
game->s.resize(maxlen * state.size());
// std::cout << "(" << maxlen << ") size_s = " << game->s.size() << std::endl << std::flush;
// Zero-fill first so frames not yet populated (early in the game) read as 0.
std::fill(game->s.begin(), game->s.end(), 0.0);
// Then put all states to game->s.
for (size_t i = 0; i < maxlen; ++i) {
// i = 0 is the most recently pushed frame; larger i is older history.
const auto &curr_s = _recent_states.get_from_push(i);
if (! curr_s.empty()) {
// All stored frames must have the same feature dimensionality.
assert(curr_s.size() == state.size());
std::copy(curr_s.begin(), curr_s.end(), &game->s[i * curr_s.size()]);
}
}
}
}
#define ACTION_GLOBAL 0
#define ACTION_UNIT_CMD 1
#define ACTION_REGIONAL 2
bool TrainedAI::handle_response(const State &s, const Data &data, RTSMCAction *a) {
a->Init(id(), name());
// if (_receiver == nullptr) return false;
const auto &env = s.env();
// Get the current action from the queue.
const auto &m = env.GetMap();
const GameState& gs = data.newest();
switch(gs.action_type) {
case ACTION_GLOBAL:
// action
a->SetState9(gs.a);
break;
case ACTION_UNIT_CMD:
{
// Use gs.unit_cmds
// std::vector<CmdInput> unit_cmds(gs.unit_cmds);
// Use data
std::vector<CmdInput> unit_cmds;
for (int i = 0; i < gs.n_max_cmd; ++i) {
unit_cmds.emplace_back(_XY(gs.uloc[i], m), _XY(gs.tloc[i], m), gs.ct[i], gs.bt[i]);
}
std::for_each(unit_cmds.begin(), unit_cmds.end(), [&](CmdInput &ci) { ci.ApplyEnv(env); });
a->SetUnitCmds(unit_cmds);
}
/*
case ACTION_REGIONAL:
{
if (_receiver->GetUseCmdComment()) {
string s;
for (size_t i = 0; i < gs.a_region.size(); ++i) {
for (size_t j = 0; j < gs.a_region[0].size(); ++j) {
int a = -1;
for (size_t k = 0; k < gs.a_region[0][0].size(); ++k) {
if (gs.a_region[k][i][j] == 1) {
a = k;
break;
}
}
s += to_string(a) + ",";
}
s += "; ";
}
SendComment(s);
}
}
// Regional actions.
return _mc_rule_actor.ActWithMap(env,reply.action_regions, &a->state_string(), &a->cmds());
*/
default:
throw std::range_error("action_type not valid! " + to_string(gs.action_type));
}
return true;
}