Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Commit

Permalink
Update rts/game_MC
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuandong Tian committed Oct 12, 2017
1 parent 06364c0 commit 7cad273
Show file tree
Hide file tree
Showing 16 changed files with 143 additions and 46 deletions.
2 changes: 1 addition & 1 deletion rts/backend/comm_ai.h
Expand Up @@ -9,7 +9,7 @@

#pragma once
#include <concurrentqueue.h>
#include "ws_server.h"
#include "vendor/ws_server.h"
#include "ai.h"
#include "raw2cmd.h"

Expand Down
46 changes: 31 additions & 15 deletions rts/backend/main_loop.cc
Expand Up @@ -17,6 +17,7 @@
#include "engine/ai.h"
#include "elf/game_base.h"
#include "ai.h"
#include "../game_MC/mcts.h"
#include "comm_ai.h"

#include <iostream>
Expand All @@ -43,24 +44,31 @@ bool add_players(const string &args, int frame_skip, RTSGame *game) {
vector<string> params = split(player, '=');
int tick_start = (params.size() == 1 ? 0 : std::stoi(params[1]));
game->AddBot(new TCPAI("tcpai", tick_start, 8000), 1);
}
/*else if (player.find("mcts") == 0) {
game->GetState().AppendPlayer("tcpai");
} else if (player.find("mcts") == 0) {
vector<string> params = split(player, '=');
int mcts_thread = std::stoi(params[1]);
int mcts_rollout_per_thread = std::stoi(params[2]);

mcts::TSOptions opt;
opt.num_threads = std::stoi(params[1]);
opt.num_rollout_per_thread = std::stoi(params[2]);

/*
vector<string> prerun_cmds;
if (params.size() >= 4) {
prerun_cmds = split(params[3], '-');
}
bots.push_back(new MCTSAI(INVALID, frame_skip, nullptr, mcts_thread, mcts_rollout_per_thread, false, &prerun_cmds));
mcts = true;
}*/
else if (player.find("spectator") == 0) {
*/
game->AddBot(new MCTSRTSAI(opt), frame_skip);
game->GetState().AppendPlayer("mcts_ai");
// mcts = true;
} else if (player.find("spectator") == 0) {
vector<string> params = split(player, '=');
int tick_start = (params.size() == 1 ? 0 : std::stoi(params[1]));
game->AddSpectator(new TCPAI("spectator", tick_start, 8000));
} else if (player == "dummy") {
game->AddBot(new AI("dummy"), frame_skip);
game->GetState().AppendPlayer("dummy");
}
else if (player == "dummy") game->AddBot(new AI("dummy"), frame_skip);
/*
else if (player == "flag_simple") {
//if (mcts) bots[0]->SetFactory([&](int r) -> AI* { return new FlagSimpleAI(INVALID, r, nullptr, nullptr);});
Expand All @@ -72,6 +80,7 @@ bool add_players(const string &args, int frame_skip, RTSGame *game) {
AI *ai = AIFactory<AI>::CreateAI(player, "");
if (ai != nullptr) {
game->AddBot(ai, frame_skip);
game->GetState().AppendPlayer(player);
} else {
cout << "Unknown player! " << player << endl;
return false;
Expand Down Expand Up @@ -147,7 +156,7 @@ RTSGameOptions ai_vs_ai2(const Parser &parser, string *players) {

return options;
}
/*

RTSGameOptions ai_vs_mcts(const Parser &parser, string *players) {
RTSGameOptions options = GetOptions(parser);
int mcts_threads = parser.GetItem<int>("mcts_threads");
Expand All @@ -163,7 +172,7 @@ RTSGameOptions ai_vs_mcts(const Parser &parser, string *players) {

return options;
}
*/

RTSGameOptions flag_ai_vs_ai(const Parser &parser, string *players) {
RTSGameOptions options = GetOptions(parser);
*players = "flag_simple,flag_simple,dummy";
Expand Down Expand Up @@ -240,7 +249,9 @@ void replay_mcts(const Parser &parser) {
cout << "Loading game " << endl;
RTSGame game(options);
game.AddBot(new MCTSAI(INVALID, frame_skip, nullptr, mcts_threads, mcts_rollout_per_thread, mcts_verbose, &prerun_cmds));
game.GetState().AppendPlayer("MCTSAI");
game.AddBot(new SimpleAI(INVALID, frame_skip, nullptr, nullptr));
game.GetState().AppendPlayer("SimpleAI");
cout << "Starting main loop " << endl;
PlayerId winner = game.MainLoop();
Expand Down Expand Up @@ -283,7 +294,10 @@ void replay_rollout(const Parser &parser) {
cout << "Loading game " << endl;
RTSGame game(options);
game.AddBot(new MCTS_ROLLOUT_AI(0, frame_skip, nullptr, selected_moves));
game.GetState().AppendPlayer("MCTS_ROLLOUT_AI");
game.AddBot(new SimpleAI(1, frame_skip, nullptr, nullptr));
game.GetState().AppendPlayer("SimpleAI");
cout << "Starting main loop " << endl;
PlayerId winner = game.MainLoop();
Expand Down Expand Up @@ -319,7 +333,7 @@ int main(int argc, char *argv[]) {
const map<string, function<RTSGameOptions (const Parser &, string *)> > func_mapping = {
{ "selfplay", ai_vs_ai },
{ "selfplay2", ai_vs_ai2 },
//{ "mcts", ai_vs_mcts },
{ "mcts", ai_vs_mcts },

{ "replay", replay },
{ "replay_cmd", replay_cmd },
Expand All @@ -337,7 +351,7 @@ int main(int argc, char *argv[]) {
};

GameDef::GlobalInit();

CmdLineUtils::CmdLineParser parser("playstyle --save_replay --load_replay --vis_after[-1] --save_snapshot_prefix --load_snapshot_prefix --seed[0] \
--load_snapshot_length --max_tick[30000] --binary_io[1] --games[16] --frame_skip[1] --tick_prompt_n_step[2000] --cmd_verbose[0] --peek_ticks --cmd_dumper_prefix \
--output_file[cout] --mcts_threads[16] --mcts_rollout_per_thread[100] --threads[64] --load_binary_string --mcts_verbose --mcts_prerun_cmds --handicap_level[0]");
Expand Down Expand Up @@ -395,12 +409,14 @@ int main(int argc, char *argv[]) {
else options.seed = seed0 + i * 241;

RTSStateExtend state(options);
RTSGame game(state);
RTSGame game(&state);
//game.AddBot(new SimpleAI(INVALID, frame_skip, nullptr));
//
//game.AddBot(new SimpleAI(INVALID, frame_skip, nullptr));
game.AddBot(AIFactory<AI>::CreateAI("simple", ""), frame_skip);
game.AddBot(AIFactory<AI>::CreateAI("simple", ""), frame_skip);
state.AppendPlayer("simple1");
state.AppendPlayer("simple2");

state.SetGlobalStats(&gstats);
bool infinite = (games == 0);
Expand All @@ -415,7 +431,7 @@ int main(int argc, char *argv[]) {
std::cout << gstats.PrintInfo() << std::endl;
} else {
RTSStateExtend state(options);
RTSGame game(state);
RTSGame game(&state);
cout << "Players: " << players << endl;
add_players(players, frame_skip, &game);
cout << "Finish adding players" << endl;
Expand Down
6 changes: 2 additions & 4 deletions rts/engine/game_state.cc
Expand Up @@ -9,13 +9,11 @@ RTSState::RTSState() {
_env.ClearAllPlayers();
}

void RTSState::OnAddPlayer(const string &name, int player_id) {
(void)player_id;
void RTSState::AppendPlayer(const string &name) {
_env.AddPlayer(name, PV_KNOW_ALL);
}

void RTSState::OnRemovePlayer(int player_id) {
(void)player_id;
void RTSState::RemoveLastPlayer() {
_env.RemovePlayer();
}

Expand Down
10 changes: 8 additions & 2 deletions rts/engine/game_state.h
Expand Up @@ -27,6 +27,10 @@ class RTSState {
}

// Copy construct.
RTSState(const RTSState &s) {
_env.InitGameDef();
*this = s;
}
RTSState &operator=(const RTSState &s) {
string str;
s.Save(&str);
Expand Down Expand Up @@ -76,8 +80,10 @@ class RTSState {

virtual bool Reset();

virtual void OnAddPlayer(const std::string &name, int player_id);
virtual void OnRemovePlayer(int player_id);
virtual ~RTSState() { }

void AppendPlayer(const std::string &name);
void RemoveLastPlayer();

private:
GameEnv _env;
Expand Down
4 changes: 4 additions & 0 deletions rts/engine/gamedef.h
Expand Up @@ -105,6 +105,10 @@ class GameDef {
bool CheckAddUnit(RTSMap* _map, UnitType type, const PointF& p) const;

// Returns the unit template stored at index `t` in _units.
// When `t` is negative or not smaller than the number of stored templates,
// a diagnostic line is written to cout and std::range_error is thrown.
const UnitTemplate &unit(UnitType t) const {
    const bool in_range = (t >= 0) && (t < static_cast<int>(_units.size()));
    if (in_range) {
        return _units[t];
    }

    // Out-of-range request: report the offending value, then fail loudly.
    cout << "UnitType " << t << " is not found!" << endl;
    throw std::range_error("Unit type is not found!");
}

Expand Down
2 changes: 1 addition & 1 deletion rts/engine/wrapper_template.h
Expand Up @@ -51,7 +51,7 @@ class WrapperT {
// Note that all the bots created here will be owned by game.
// Note that AddBot() will set its receiver. So there is no need to specify it here.
RTSStateExtend s(op);
RTSGame game(s);
RTSGame game(&s);
wrapper.OnGameInit(&game, more_params);

s.SetGlobalStats(&_gstats);
Expand Down
1 change: 0 additions & 1 deletion rts/game_MC/ai.cc
@@ -1 +0,0 @@

17 changes: 17 additions & 0 deletions rts/game_MC/ai.h
Expand Up @@ -13,6 +13,7 @@
#include "engine/cmd_interface.h"
#include "engine/game_state.h"
#include "elf/ai.h"
#include "elf/mcts.h"
#include "game_action.h"
#include "python_options.h"
#define NUM_RES_SLOT 5
Expand All @@ -23,3 +24,19 @@ using Data = typename AIComm::Data;

using AIWithComm = elf::AIWithCommT<RTSState, RTSMCAction, AIComm>;
using AI = elf::AI_T<RTSState, RTSMCAction>;

// Plain aggregate pairing a vector of floats with an integer action.
// NOTE(review): presumably a compact encoding of a game state plus the action
// associated with it, for use by the MCTS code — confirm against its users.
struct ReducedState {
// Float payload; empty after default construction.
vector<float> state;
// No default member initializer: the value is indeterminate unless the object
// is value-initialized or the field is assigned explicitly.
int action;
};

// Node response keyed by an int action, with a helper to populate it from a
// flat vector plus a scalar.
struct ReducedPred : public mcts::NodeResponseT<int> {
    // Resizes `pi` to match `new_pi`, stores (i, new_pi[i]) at position i,
    // and sets `value` to `new_v`.
    void SetPiAndV(const vector<float>& new_pi, float new_v) {
        this->pi.resize(new_pi.size());

        // Walk the freshly sized entries while tracking their position.
        size_t idx = 0;
        for (auto &entry : this->pi) {
            entry.first = idx;
            entry.second = new_pi[idx];
            ++idx;
        }

        this->value = new_v;
    }
};
7 changes: 5 additions & 2 deletions rts/game_MC/game_action.cc
@@ -1,7 +1,7 @@
#include "game_action.h"

bool RTSMCAction::Send(const GameEnv &env, CmdReceiver &receiver) {
// Apply command.
// Apply command.
MCRuleActor rule_actor;
rule_actor.SetPlayerId(_player_id);
rule_actor.SetReceiver(&receiver);
Expand All @@ -17,7 +17,10 @@ bool RTSMCAction::Send(const GameEnv &env, CmdReceiver &receiver) {
}

switch(_type) {
case STATE9:
case STATE9:
if (_action < 0 || _action >= (int) state.size()) {
cout << "RTSMCAction: action invalid! action = " << _action << " / " << state.size() << endl;
}
state[_action] = 1;
break;
case SIMPLE:
Expand Down
1 change: 0 additions & 1 deletion rts/game_MC/python_options.h
Expand Up @@ -50,7 +50,6 @@ struct GameState {
float V;
std::vector<float> pi;

//
std::vector<int64_t> uloc, tloc;
std::vector<int64_t> bt, ct;

Expand Down
37 changes: 37 additions & 0 deletions rts/game_MC/rule_ai.h
Expand Up @@ -32,4 +32,41 @@ class HitAndRunAI : public AI {
}
};

// AI registered under the name "random". Each call to Act() draws a number
// from a std::mt19937 seeded at construction and issues it, reduced modulo
// NUM_AISTATE, as a state-9 action. The game state is never inspected.
class RandomAI : public AI {
public:
    // `seed` initializes the Mersenne Twister that drives every later choice.
    RandomAI(int seed) : AI("random"), rng_(seed) { }

    // SERIALIZER_DERIVED(HitAndRunAI, AIBase, _state);

private:
    std::mt19937 rng_;

    // Fills `action` for this bot; the state and the atomic flag are unused.
    // Always returns true.
    bool Act(const RTSState &, RTSMCAction *action, const atomic_bool *) override {
        action->Init(id(), name());

        const auto choice = rng_() % NUM_AISTATE;
        action->SetState9(choice);

        return true;
    }
};

// AI whose state-9 action can be dictated from outside via SpecifyNextAction().
// While the specified action is negative (the initial value is -1), Act()
// instead draws one from a seeded std::mt19937, reduced modulo NUM_AISTATE.
class FixedAI : public AI {
public:
    // The AI name comes from `opt.name`; `seed` initializes the fallback RNG.
    FixedAI(const AIOptions &opt, int seed) : AI(opt.name), rng_(seed) { }

    // Stores `a` as the action to use in Act(). It stays in effect until this
    // is called again; a negative value re-enables random selection.
    void SpecifyNextAction(int a) {
        specified_action_ = a;
    }

    // SERIALIZER_DERIVED(HitAndRunAI, AIBase, _state);

private:
    int specified_action_ = -1;
    std::mt19937 rng_;

    // Fills `action` for this bot; the state and the atomic flag are unused.
    // Always returns true.
    bool Act(const RTSState &, RTSMCAction *action, const atomic_bool *) override {
        action->Init(id(), name());

        // The RNG is only advanced when no action has been specified.
        const int chosen = (specified_action_ >= 0)
            ? specified_action_
            : static_cast<int>(rng_() % NUM_AISTATE);
        action->SetState9(chosen);

        return true;
    }
};

13 changes: 13 additions & 0 deletions rts/game_MC/state_extract.cc
Expand Up @@ -10,6 +10,19 @@ static inline void accu_value(int idx, float val, std::map<int, std::pair<int, f
}
}

// Copies bookkeeping fields from the game state `s` into `*gs`:
//   tick     <- s.GetTick()
//   winner   <- the environment's winner id
//   terminal <- 1 if the environment reports termination, otherwise 0
//   last_r   <- 0 while the winner is INVALID, +1 if `player_id` is the
//               winner, -1 if some other id is the winner.
void MCSaveInfo(const RTSState &s, PlayerId player_id, GameState *gs) {
    gs->tick = s.GetTick();

    const auto &env = s.env();
    gs->winner = env.GetWinnerId();
    gs->terminal = env.GetTermination() ? 1 : 0;

    const int winner = env.GetWinnerId();
    if (winner == INVALID) {
        // No winner recorded: neutral reward.
        gs->last_r = 0.0;
    } else {
        gs->last_r = (winner == player_id) ? 1.0 : -1.0;
        // cout << "player_id: " << player_id << " " << (gs->last_r > 0 ? "Won" : "Lose") << endl;
    }
}

void MCExtract(const RTSState &s, PlayerId player_id, bool respect_fow, std::vector<float> *state) {
const GameEnv &env = s.env();
Expand Down
2 changes: 2 additions & 0 deletions rts/game_MC/state_feature.h
@@ -1,11 +1,13 @@
#pragma once

#include "engine/game_state.h"
#include "python_options.h"

#define _OFFSET(_c, _x, _y, _m) (((_c) * _m.GetYSize() + (_y)) * _m.GetXSize() + (_x))
#define _XY(loc, m) ((loc) % m.GetXSize()), ((loc) / m.GetXSize())

#define NUM_RES_SLOT 5

void MCExtract(const RTSState &s, PlayerId player_id, bool respect_fow, std::vector<float> *state);
void MCSaveInfo(const RTSState &s, PlayerId player_id, GameState *gs);

17 changes: 3 additions & 14 deletions rts/game_MC/trainable_ai.cc
Expand Up @@ -33,12 +33,9 @@ bool TrainedAI::GameEnd(const State &s) {
}

void TrainedAI::extract(const State &s, Data *data) {
const GameEnv &env = s.env();

GameState *game = &data->newest();
game->tick = s.receiver().GetTick();
game->winner = env.GetWinnerId();
game->terminal = env.GetTermination() ? 1 : 0;

MCSaveInfo(s, id(), game);
game->name = name();

if (_recent_states.maxlen() == 1) {
Expand All @@ -61,21 +58,13 @@ void TrainedAI::extract(const State &s, Data *data) {
}
}
}

game->last_r = 0.0;
int winner = s.env().GetWinnerId();

if (winner != INVALID) {
if (winner == id()) game->last_r = 1.0;
else game->last_r = -1.0;
}
}

#define ACTION_GLOBAL 0
#define ACTION_UNIT_CMD 1
#define ACTION_REGIONAL 2

bool TrainedAI::handle_response(const State &s, const Data &data, RTSMCAction *a) {
bool TrainedAI::handle_response(const State &s, const Data &data, RTSMCAction *a) {
a->Init(id(), name());

// if (_receiver == nullptr) return false;
Expand Down

0 comments on commit 7cad273

Please sign in to comment.