Gameserver update #66

Merged · 30 commits · Jul 4, 2024
Changes from 27 commits

Commits
36d07b6
Remove parallel training
Parzival-05 May 14, 2024
bdc59b9
Rename output of training
Parzival-05 May 21, 2024
f3fde83
Add .mono to gitignore
Parzival-05 May 21, 2024
fbc8db5
Add handling of errors while training
Parzival-05 May 21, 2024
ed70864
Add `svm_info` to maps & refactor of training
Parzival-05 May 21, 2024
db44f50
Edit specification of config
Parzival-05 May 21, 2024
29a9358
Update statistics collector & refactor of training
Parzival-05 May 22, 2024
d3b55ec
Replace thread pool with process pool
Parzival-05 May 23, 2024
d77cfdd
Remove excess field from TrainingConfig
Parzival-05 May 23, 2024
c655011
Refactor of validation
Parzival-05 May 28, 2024
3c256b5
Refactor of epochs_statistics
Parzival-05 May 28, 2024
c9908af
Add SingleSVMInfo class
Parzival-05 May 28, 2024
72871f5
Remove excess import
Parzival-05 May 28, 2024
4430d55
Add config validation & revised training process
Parzival-05 Jun 2, 2024
8074712
Add `pydantic` to project
Parzival-05 Jun 2, 2024
4d458e6
Edit run_training: make the parameter "server_count" explicitly set
Parzival-05 Jun 2, 2024
74473b9
Minor fixes
Parzival-05 Jun 2, 2024
ded4f31
Remove excess `svms_count` field
Parzival-05 Jun 3, 2024
e2b151d
Add ignoring of extra fields in `SingleSVMInfo`
Parzival-05 Jun 3, 2024
f49bb6c
Fix `EpochsInfo`: make it a dataclass
Parzival-05 Jun 3, 2024
e93c530
Edit interface of model saver
Parzival-05 Jun 3, 2024
cbae0f1
Create single dataset. Merge configs.
Anya497 Jun 26, 2024
01165a9
Change optuna studies path.
Anya497 Jun 27, 2024
7ae6edf
Small fixes from comments.
Anya497 Jun 27, 2024
cf25be5
Fix bug with SVMInfo adding.
Anya497 Jul 4, 2024
5da23c2
Refactor config. Add support of multiple json files for single platfo…
Anya497 Jul 4, 2024
e8d94c7
Merge branch 'dev' into gameserver-upd
Anya497 Jul 4, 2024
80ea1f3
Remove comments.
Anya497 Jul 4, 2024
0ede7de
edit 'launch_servers': Add support of usvm
Parzival-05 Jul 4, 2024
9253f82
Fix statistics collector
Parzival-05 Jul 4, 2024
2 changes: 1 addition & 1 deletion .github/workflows/build_and_run.yaml
@@ -60,7 +60,7 @@ jobs:
- name: Sanity check
working-directory: ./AIAgent
run: |
if ls *.pkl 1> /dev/null 2>&1; then
if ls report/optuna_studies/*.pkl 1> /dev/null 2>&1; then
echo "PKL files found."
else
echo "No PKL files found."
2 changes: 2 additions & 0 deletions .gitignore
@@ -163,3 +163,5 @@ cython_debug/
# MacOS specific
.DS_Store
*report/

.mono/
65 changes: 55 additions & 10 deletions AIAgent/common/classes.py
@@ -1,11 +1,18 @@
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, TypeAlias

from common.game import GameMap
from dataclasses_json import dataclass_json
from pydantic import Field, field_validator
from pydantic.dataclasses import dataclass as pydantic_dataclass
from connection.broker_conn.classes import SVMInfo
from common.game import GameMap
from ml.protocols import Named

PlatformName: TypeAlias = str
SVMInfoName: TypeAlias = str


@dataclass_json
@dataclass
@@ -52,12 +59,50 @@ class Map2Result:
AgentResultsOnGameMaps: TypeAlias = defaultdict[Named, list[Map2Result]]


@dataclass_json
@dataclass
class SVMInfo:
name: str
count: int
launch_command: str
min_port: int
max_port: int
server_working_dir: str
@pydantic_dataclass
class DatasetConfig:
dataset_base_path: Path # path to the directory with explored DLLs
dataset_description: Path # full path to the JSON file with the dataset description

@field_validator("dataset_base_path", "dataset_description", mode="before")
@classmethod
def transform(cls, input: str) -> Path:
return Path(input).resolve()


@pydantic_dataclass
class Platform:
name: PlatformName
DatasetConfigs: list[DatasetConfig]
SVMInfo: SVMInfo


@pydantic_dataclass
class OptunaConfig:
n_startup_trials: int # number of Optuna startup trials
n_trials: int # total number of Optuna trials
n_jobs: int
study_direction: str


@pydantic_dataclass
class TrainingConfig:
dynamic_dataset: bool
train_percentage: float
threshold_coverage: int
load_to_cpu: bool
epochs: int
threshold_steps_number: Optional[int] = Field(default=None)


@pydantic_dataclass
class Config:
Platforms: list[Platform]
OptunaConfig: OptunaConfig
TrainingConfig: TrainingConfig
path_to_weights: Optional[Path] = Field(default=None)

@field_validator("path_to_weights", mode="before")
@classmethod
def transform(cls, input: Optional[str]) -> Optional[Path]:
return Path(input).absolute() if input is not None else None
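
For context, a minimal sketch of how the new pydantic-validated Config might be built from a parsed JSON config; every name and value below is an illustrative assumption rather than something taken from this PR. Nested dicts are coerced into the nested pydantic dataclasses automatically.

# Hypothetical config dict; all values are placeholders.
from common.classes import Config

raw_config = {
    "Platforms": [
        {
            "name": "dotnet",
            "DatasetConfigs": [
                {
                    "dataset_base_path": "./maps",
                    "dataset_description": "./maps/dataset.json",
                }
            ],
            "SVMInfo": {
                "name": "VSharp",
                "count": 2,
                "launch_command": "dotnet GameServer.dll",
                "min_port": 8100,
                "max_port": 8200,
                "server_working_dir": ".",
            },
        }
    ],
    "OptunaConfig": {
        "n_startup_trials": 10,
        "n_trials": 30,
        "n_jobs": 1,
        "study_direction": "maximize",
    },
    "TrainingConfig": {
        "dynamic_dataset": True,
        "train_percentage": 0.7,
        "threshold_coverage": 100,
        "load_to_cpu": False,
        "epochs": 20,
    },
}

config = Config(**raw_config)
# The field_validators coerce the string paths into resolved/absolute Path objects.
print(config.Platforms[0].DatasetConfigs[0].dataset_base_path)
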
9 changes: 6 additions & 3 deletions AIAgent/common/game.py
@@ -1,6 +1,6 @@
from dataclasses import dataclass

from dataclasses_json import dataclass_json
from dataclasses import dataclass, field
from dataclasses_json import config, dataclass_json
from connection.broker_conn.classes import SingleSVMInfo


@dataclass_json
@@ -74,6 +74,9 @@ class GameMap:
NameOfObjectToCover: str
DefaultSearcher: str
MapName: str
SVMInfo: SingleSVMInfo = field(
default=None, metadata=config(exclude=lambda x: True)
)


@dataclass_json
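
A self-contained illustration (not part of the PR) of the dataclasses_json exclude mechanism used here: the new SVMInfo field is available on the agent side but dropped from the serialized GameMap, so the payload sent to the game server stays unchanged.

from dataclasses import dataclass, field
from dataclasses_json import config, dataclass_json


@dataclass_json
@dataclass
class ToyMap:  # stand-in for GameMap, with only one ordinary field plus the excluded one
    MapName: str
    SVMInfo: object = field(default=None, metadata=config(exclude=lambda x: True))


toy = ToyMap(MapName="ExampleMap", SVMInfo={"name": "svm-1"})
assert "SVMInfo" not in toy.to_dict()  # excluded from the serialized payload
print(toy.to_json())  # {"MapName": "ExampleMap"}
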
11 changes: 0 additions & 11 deletions AIAgent/config.py
@@ -21,17 +21,6 @@ class WebsocketSourceLinks:
POST_WS = f"http://0.0.0.0:{BrokerConfig.BROKER_PORT}/post_ws"


class TrainingConfig:
DYNAMIC_DATASET = True
TRAIN_PERCENTAGE = 1
THRESHOLD_COVERAGE = 100
THRESHOLD_STEPS_NUMBER = None
LOAD_TO_CPU = False

OPTUNA_N_JOBS = 1
STUDY_DIRECTION = "maximize"


@dataclass(slots=True, frozen=True)
class DumpByTimeoutFeature:
enabled: bool
21 changes: 21 additions & 0 deletions AIAgent/connection/broker_conn/classes.py
@@ -1,5 +1,6 @@
from dataclasses import dataclass
from typing import Callable, TypeAlias
from pydantic.dataclasses import dataclass as pydantic_dataclass

from config import FeatureConfig
from connection.game_server_conn.unsafe_json import asdict
@@ -18,5 +19,25 @@ class ServerInstanceInfo:
pid: int | Undefined


@pydantic_dataclass(config=dict(extra="ignore"))
class SingleSVMInfo:
name: str
launch_command: str
min_port: int
max_port: int
server_working_dir: str

def to_dict(self): # GameMap class requires the to_dict method for all its fields
return self.__dict__


@pydantic_dataclass
class SVMInfo(SingleSVMInfo):
count: int

def create_single_svm_info(self) -> SingleSVMInfo:
return SingleSVMInfo(**self.__dict__)


def custom_encoder_if_disable_message_checks() -> Callable | None:
return asdict if FeatureConfig.DISABLE_MESSAGE_CHECKS else None
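
A short sketch of how the new SVMInfo / SingleSVMInfo split might be used: the config-level SVMInfo carries count, while create_single_svm_info() produces the per-connection view; because SingleSVMInfo is declared with extra="ignore", the extra count key coming from SVMInfo.__dict__ is silently dropped. Values are illustrative assumptions.

from connection.broker_conn.classes import SVMInfo

svm_info = SVMInfo(
    name="VSharp",
    launch_command="dotnet GameServer.dll",
    min_port=8100,
    max_port=8200,
    server_working_dir=".",
    count=4,  # how many server instances to launch for this SVM
)

single = svm_info.create_single_svm_info()
# `count` is ignored by SingleSVMInfo (extra="ignore"), so the per-connection
# dict contains only the fields the broker needs.
print(single.to_dict())
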
7 changes: 3 additions & 4 deletions AIAgent/connection/broker_conn/requests.py
@@ -3,14 +3,13 @@
from urllib.parse import urlencode

import httplib2
from common.classes import SVMInfo
from config import WebsocketSourceLinks
from connection.broker_conn.classes import ServerInstanceInfo
from connection.broker_conn.classes import ServerInstanceInfo, SingleSVMInfo


def acquire_instance(svm_info: SVMInfo) -> ServerInstanceInfo:
def acquire_instance(svm_info: SingleSVMInfo) -> ServerInstanceInfo:
response, content = httplib2.Http().request(
WebsocketSourceLinks.GET_WS + "?" + urlencode(SVMInfo.to_dict(svm_info))
WebsocketSourceLinks.GET_WS + "?" + urlencode(SingleSVMInfo.to_dict(svm_info))
)
if response.status != 200:
logging.error(f"{response.status} with {content=} on acquire_instance call")
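
For clarity, the query string that acquire_instance now builds from a SingleSVMInfo would look roughly like this; the values are illustrative assumptions.

from urllib.parse import urlencode

from connection.broker_conn.classes import SingleSVMInfo

svm_info = SingleSVMInfo(
    name="VSharp",
    launch_command="dotnet GameServer.dll",
    min_port=8100,
    max_port=8200,
    server_working_dir=".",
)
print(urlencode(SingleSVMInfo.to_dict(svm_info)))
# name=VSharp&launch_command=dotnet+GameServer.dll&min_port=8100&max_port=8200&server_working_dir=.
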
5 changes: 2 additions & 3 deletions AIAgent/connection/broker_conn/socket_manager.py
@@ -3,9 +3,8 @@
from contextlib import contextmanager, suppress

import websocket
from common.classes import SVMInfo
from config import GameServerConnectorConfig
from connection.broker_conn.classes import ServerInstanceInfo
from connection.broker_conn.classes import ServerInstanceInfo, SingleSVMInfo
from connection.broker_conn.requests import acquire_instance, return_instance


@@ -38,7 +37,7 @@ def wait_for_connection(server_instance: ServerInstanceInfo):


@contextmanager
def game_server_socket_manager(svm_info: SVMInfo):
def game_server_socket_manager(svm_info: SingleSVMInfo):
server_instance = acquire_instance(svm_info)

socket = wait_for_connection(server_instance)
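
A sketch of the updated call site, assuming the context manager yields the connected websocket as in the surrounding code; the SingleSVMInfo values and the message payload are illustrative assumptions.

from connection.broker_conn.classes import SingleSVMInfo
from connection.broker_conn.socket_manager import game_server_socket_manager

svm_info = SingleSVMInfo(
    name="VSharp",
    launch_command="dotnet GameServer.dll",
    min_port=8100,
    max_port=8200,
    server_working_dir=".",
)

with game_server_socket_manager(svm_info) as ws:
    # The broker acquires a server instance for this SVM, waits for the
    # websocket to come up, and returns the instance when the block exits.
    ws.send('{"MessageType": "ping"}')  # hypothetical message, protocol not shown in this diff
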
124 changes: 84 additions & 40 deletions AIAgent/epochs_statistics.py
@@ -1,9 +1,8 @@
import multiprocessing as mp
from dataclasses import dataclass
from multiprocessing.managers import BaseManager
from enum import Enum
from pathlib import Path
from statistics import mean
from typing import TypeAlias
from typing import Optional, TypeAlias

import natsort
import pandas as pd
@@ -24,7 +23,6 @@ class TrainingParams:
num_hops_1: int
num_hops_2: int
num_of_state_features: int
epochs: int


@dataclass
@@ -33,62 +31,116 @@ class StatsWithTable:
df: pd.DataFrame


class StatisticsCollector:
_lock = mp.Lock()
class SVMStatus(Enum):
RUNNING = "running"
FAILED = "failed"
FINISHED = "finished"


@dataclass
class Status:
"""
status : SVMStatus.RUNNING | SVMStatus.FAILED | SVMStatus.FINISHED
"""

status: SVMStatus
epoch: int

def __str__(self) -> str:
result: str = f"status={self.status.value}"
if self.status == SVMStatus.FAILED:
result += f", on epoch = {self.epoch}"
return result


class StatisticsCollector:
def __init__(
self,
SVM_count: int,
file: Path,
):
self._SVM_count: int = SVM_count
self._file = file

self._SVMS_info: dict[SVMName, TrainingParams] = {}
self._svms_info: dict[SVMName, Optional[TrainingParams]] = {}
self._epochs: dict[SVMName, Optional[EpochNumber]] = {}
self._sessions_info: dict[EpochNumber, dict[SVMName, StatsWithTable]] = {}
self._status: dict[SVMName, Status] = {}

self._running: Optional[SVMName] = None

def register_training_session(
def register_new_training_session(self, svm_name: SVMName):
self._running = svm_name
self._svms_info[svm_name] = None
self._epochs[svm_name] = None
self._svms_info = sort_dict(self._svms_info)
self._update_file()

def start_training_session(
self,
SVM_name: SVMName,
batch_size: int,
lr: float,
num_hops_1: int,
num_hops_2: int,
num_of_state_features: int,
epochs: int,
):
self._SVMS_info[SVM_name] = TrainingParams(
batch_size, lr, num_hops_1, num_hops_2, num_of_state_features, epochs
svm_name = self._running
self._epochs[svm_name] = epochs
self._status[svm_name] = Status(SVMStatus.RUNNING, 0)

self._svms_info[svm_name] = TrainingParams(
batch_size, lr, num_hops_1, num_hops_2, num_of_state_features
)
self._SVMS_info = sort_dict(self._SVMS_info)
self._update_file()

def fail(self):
svm_name = self._running
self._status[svm_name].status = SVMStatus.FAILED
self._running = None
self._update_file()

def finish(self):
svm_name = self._running
self._status[svm_name].status = SVMStatus.FINISHED
self._running = None
self._update_file()

def update_results(
self,
epoch: EpochNumber,
SVM_name: SVMName,
average_result: float,
map2results_list: list[Map2Result],
):
svm_name = self._running
epoch = self._status[svm_name].epoch

results = self._sessions_info.get(epoch, {})
results[SVM_name] = StatsWithTable(
average_result, convert_to_df(SVM_name, map2results_list)
results[svm_name] = StatsWithTable(
average_result, convert_to_df(svm_name, map2results_list)
)
self._sessions_info[epoch] = sort_dict(results)
self._status[svm_name].epoch += 1
self._update_file()

def _get_SVMS_info(self) -> str:
svm_info_line = lambda svm_info: (
f"{svm_info[0]} : "
f"batch_size={svm_info[1].batch_size}, "
f"lr={svm_info[1].lr}, "
f"num_hops_1={svm_info[1].num_hops_1}, "
f"num_hops_2={svm_info[1].num_hops_2}, "
f"num_of_state_features={svm_info[1].num_of_state_features}, "
f"epochs={svm_info[1].epochs}\n"
)
def _get_training_info(self) -> str:
def svm_info_line(svm_info):
svm_name, training_params = svm_info[0], svm_info[1]
epochs = self._epochs[svm_name]
status: Optional[Status] = self._status.get(svm_name, None)
if status is None:
return ""

svm_info_line = (
f"{svm_name} : "
f"{str(status)}, "
f"epochs={epochs}, "
f"batch_size={training_params.batch_size}, "
f"lr={training_params.lr}, "
f"num_hops_1={training_params.num_hops_1}, "
f"num_hops_2={training_params.num_hops_2}, "
f"num_of_state_features={training_params.num_of_state_features}\n"
)
return svm_info_line

return "".join(list(map(svm_info_line, self._SVMS_info.items())))
return "".join(list(map(svm_info_line, self._svms_info.items())))

def _get_epochs_results(self) -> str:
epochs_results = str()
@@ -113,11 +165,10 @@ def _get_epochs_results(self) -> str:
return epochs_results

def _update_file(self):
with self._lock:
SVMS_info = self._get_SVMS_info()
epochs_results = self._get_epochs_results()
svms_info = self._get_training_info()
epochs_results = self._get_epochs_results()
with open(self._file, "w") as f:
f.write(SVMS_info)
f.write(svms_info)
f.write(epochs_results)


@@ -133,10 +184,3 @@ def convert_to_df(svm_name: SVMName, map2result_list: list[Map2Result]) -> pd.Da
df = pd.DataFrame(results, columns=["Game result"], index=maps).T

return df


class StatisticsManager(BaseManager):
pass


StatisticsManager.register("StatisticsCollector", StatisticsCollector)
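
Reading the diff, the new StatisticsCollector lifecycle appears to be: register a training session per SVM, start it with the sampled hyperparameters, report per-epoch results, then close it with finish() or fail(). Below is a rough, self-contained sketch under that reading; the helper function, report file name, and hyperparameter values are illustrative assumptions.

from pathlib import Path

from common.classes import Map2Result
from epochs_statistics import StatisticsCollector


def run_validation_epoch() -> tuple[float, list[Map2Result]]:
    # Stand-in for the real training/validation step; returns an average
    # result and per-map results.
    return 0.0, []


collector = StatisticsCollector(SVM_count=1, file=Path("training_report.txt"))
collector.register_new_training_session("VSharp")
collector.start_training_session(
    batch_size=64,
    lr=3e-4,
    num_hops_1=2,
    num_hops_2=2,
    num_of_state_features=6,
    epochs=20,
)
try:
    for _ in range(20):
        average_result, map2results = run_validation_epoch()
        collector.update_results(
            average_result=average_result,
            map2results_list=map2results,
        )
    collector.finish()
except Exception:
    collector.fail()  # records status=failed and the epoch it failed on
    raise
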