Skip to content

Commit

Permalink
Gameserver update (#66)
Browse files Browse the repository at this point in the history
* Remove parallel training

* Rename output of training

* Add .mono to gitignore

* Add handling of errors while training

* Add `svm_info` to maps & refactor of training

* Edit specification of config

* Update statistics collector & refactor of training

* Replace thread pool with process pool

* Remove excess field from TrainingConfig

* Refactor of validation

* Refactor of epochs_statistics

* Add SingleSVMInfo class

* Remove excess import

* Add config validation & revised training process

* Add `pydantic` to project

* Edit run_training: make the parameter "server_count" explicitly set

* Minor fixes

* Remove excess `svms_count` field

* Add ignoring of extra fields in `SingleSVMInfo`

* Fix `EpochsInfo`: make it a dataclass

* Edit interface of model saver

* Create single dataset. Merge configs.

* Change optuna studies path.

* Small fixes from comments.

Co-authored-by: Max Nigmatulin <40598909+emnigma@users.noreply.github.com>

* Fix bug with SVMInfo adding.

* Refactor config. Add support of miltiple json files for single platform. One SVM for one platform.

* Remove comments.

* edit 'launch_servers': Add support of usvm

* Fix statistics collector

---------

Co-authored-by: Anya497 <chi.vinny0702@gmail.com>
Co-authored-by: Anya Chistyakova <57658770+Anya497@users.noreply.github.com>
Co-authored-by: Max Nigmatulin <40598909+emnigma@users.noreply.github.com>
  • Loading branch information
4 people committed Jul 4, 2024
1 parent baa65ea commit 5dc7ef5
Show file tree
Hide file tree
Showing 17 changed files with 690 additions and 442 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build_and_run.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ jobs:
- name: Sanity check
working-directory: ./AIAgent
run: |
if ls *.pkl 1> /dev/null 2>&1; then
if ls report/optuna_studies/*.pkl 1> /dev/null 2>&1; then
echo "PKL files found."
else
echo "No PKL files found."
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,5 @@ cython_debug/
# MacOS specific
.DS_Store
*report/

.mono/
65 changes: 55 additions & 10 deletions AIAgent/common/classes.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, TypeAlias

from common.game import GameMap
from dataclasses_json import dataclass_json
from pydantic import Field, field_validator
from pydantic.dataclasses import dataclass as pydantic_dataclass
from connection.broker_conn.classes import SVMInfo
from common.game import GameMap
from ml.protocols import Named

PlatformName: TypeAlias = str
SVMInfoName: TypeAlias = str


@dataclass_json
@dataclass
Expand Down Expand Up @@ -52,12 +59,50 @@ class Map2Result:
AgentResultsOnGameMaps: TypeAlias = defaultdict[Named, list[Map2Result]]


@dataclass_json
@dataclass
class SVMInfo:
name: str
count: int
launch_command: str
min_port: int
max_port: int
server_working_dir: str
@pydantic_dataclass
class DatasetConfig:
dataset_base_path: Path # path to dir with explored dlls
dataset_description: Path # full paths to JSON-file with dataset description

@field_validator("dataset_base_path", "dataset_description", mode="before")
@classmethod
def transform(cls, input: str) -> Path:
return Path(input).resolve()


@pydantic_dataclass
class Platform:
name: PlatformName
DatasetConfigs: list[DatasetConfig]
SVMInfo: SVMInfo


@pydantic_dataclass
class OptunaConfig:
n_startup_trials: int # number of optuna's trials
n_trials: int # number of optuna's trials
n_jobs: int
study_direction: str


@pydantic_dataclass
class TrainingConfig:
dynamic_dataset: bool
train_percentage: float
threshold_coverage: int
load_to_cpu: bool
epochs: int
threshold_steps_number: Optional[int] = Field(default=None)


@pydantic_dataclass
class Config:
Platforms: list[Platform]
OptunaConfig: OptunaConfig
TrainingConfig: TrainingConfig
path_to_weights: Optional[Path] = Field(default=None)

@field_validator("path_to_weights", mode="before")
@classmethod
def transform(cls, input: Optional[str]) -> Optional[Path]:
return Path(input).absolute() if input is not None else None
9 changes: 6 additions & 3 deletions AIAgent/common/game.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass

from dataclasses_json import dataclass_json
from dataclasses import dataclass, field
from dataclasses_json import config, dataclass_json
from connection.broker_conn.classes import SingleSVMInfo


@dataclass_json
Expand Down Expand Up @@ -74,6 +74,9 @@ class GameMap:
NameOfObjectToCover: str
DefaultSearcher: str
MapName: str
SVMInfo: SingleSVMInfo = field(
default=None, metadata=config(exclude=lambda x: True)
)


@dataclass_json
Expand Down
11 changes: 0 additions & 11 deletions AIAgent/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,6 @@ class WebsocketSourceLinks:
POST_WS = f"http://0.0.0.0:{BrokerConfig.BROKER_PORT}/post_ws"


class TrainingConfig:
DYNAMIC_DATASET = True
TRAIN_PERCENTAGE = 1
THRESHOLD_COVERAGE = 100
THRESHOLD_STEPS_NUMBER = None
LOAD_TO_CPU = False

OPTUNA_N_JOBS = 1
STUDY_DIRECTION = "maximize"


@dataclass(slots=True, frozen=True)
class DumpByTimeoutFeature:
enabled: bool
Expand Down
21 changes: 21 additions & 0 deletions AIAgent/connection/broker_conn/classes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dataclasses import dataclass
from typing import Callable, TypeAlias
from pydantic.dataclasses import dataclass as pydantic_dataclass

from config import FeatureConfig
from connection.game_server_conn.unsafe_json import asdict
Expand All @@ -18,5 +19,25 @@ class ServerInstanceInfo:
pid: int | Undefined


@pydantic_dataclass(config=dict(extra="ignore"))
class SingleSVMInfo:
name: str
launch_command: str
min_port: int
max_port: int
server_working_dir: str

def to_dict(self): # GameMap class requires the to_dict method for all its fields
return self.__dict__


@pydantic_dataclass
class SVMInfo(SingleSVMInfo):
count: int

def create_single_svm_info(self) -> SingleSVMInfo:
return SingleSVMInfo(**self.__dict__)


def custom_encoder_if_disable_message_checks() -> Callable | None:
return asdict if FeatureConfig.DISABLE_MESSAGE_CHECKS else None
7 changes: 3 additions & 4 deletions AIAgent/connection/broker_conn/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
from urllib.parse import urlencode

import httplib2
from common.classes import SVMInfo
from config import WebsocketSourceLinks
from connection.broker_conn.classes import ServerInstanceInfo
from connection.broker_conn.classes import ServerInstanceInfo, SingleSVMInfo


def acquire_instance(svm_info: SVMInfo) -> ServerInstanceInfo:
def acquire_instance(svm_info: SingleSVMInfo) -> ServerInstanceInfo:
response, content = httplib2.Http().request(
WebsocketSourceLinks.GET_WS + "?" + urlencode(SVMInfo.to_dict(svm_info))
WebsocketSourceLinks.GET_WS + "?" + urlencode(SingleSVMInfo.to_dict(svm_info))
)
if response.status != 200:
logging.error(f"{response.status} with {content=} on acquire_instance call")
Expand Down
5 changes: 2 additions & 3 deletions AIAgent/connection/broker_conn/socket_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
from contextlib import contextmanager, suppress

import websocket
from common.classes import SVMInfo
from config import GameServerConnectorConfig
from connection.broker_conn.classes import ServerInstanceInfo
from connection.broker_conn.classes import ServerInstanceInfo, SingleSVMInfo
from connection.broker_conn.requests import acquire_instance, return_instance


Expand Down Expand Up @@ -38,7 +37,7 @@ def wait_for_connection(server_instance: ServerInstanceInfo):


@contextmanager
def game_server_socket_manager(svm_info: SVMInfo):
def game_server_socket_manager(svm_info: SingleSVMInfo):
server_instance = acquire_instance(svm_info)

socket = wait_for_connection(server_instance)
Expand Down
Loading

0 comments on commit 5dc7ef5

Please sign in to comment.