Skip to content

Commit

Permalink
Update statistics collector & refactor training
Browse files Browse the repository at this point in the history
  • Loading branch information
Parzival-05 committed May 23, 2024
1 parent db44f50 commit 29a9358
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 43 deletions.
100 changes: 77 additions & 23 deletions AIAgent/epochs_statistics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import multiprocessing as mp
from dataclasses import dataclass
from pathlib import Path
from statistics import mean
Expand Down Expand Up @@ -32,9 +31,28 @@ class StatsWithTable:
df: pd.DataFrame


class StatisticsCollector:
_lock = mp.Lock()
@dataclass
class Status:
    """Lifecycle state of a single SVM training session.

    ``status`` holds one of the class constants below
    (RUNNING_STATUS | FAILED_STATUS | FINISHED_STATUS);
    ``epoch`` is the epoch counter the session is currently on.
    """

    status: str  # one of RUNNING_STATUS / FAILED_STATUS / FINISHED_STATUS
    epoch: EpochNumber  # current epoch of this session

    # Allowed ``status`` values. Un-annotated on purpose: annotated names
    # would become dataclass fields and change the generated __init__.
    RUNNING_STATUS = "running"
    FAILED_STATUS = "failed"
    FINISHED_STATUS = "finished"

    def __str__(self) -> str:
        """Render e.g. ``status=running``; a failed session also reports its epoch."""
        status = self.status
        result: str = f"status={self.status}"
        # Only a failure is pinned to the epoch it happened on.
        if status == self.FAILED_STATUS:
            result += f", on epoch = {self.epoch}"
        return result


class StatisticsCollector:
def __init__(
self,
SVM_count: int,
Expand All @@ -43,49 +61,86 @@ def __init__(
self._SVM_count: int = SVM_count
self._file = file

self._SVMS_info: dict[SVMName, TrainingParams] = {}
self._SVMS_info: dict[SVMName, TrainingParams | None] = {}
self._sessions_info: dict[EpochNumber, dict[SVMName, StatsWithTable]] = {}
self._status: dict[SVMName, Status] = {}

self._running: SVMName | None = None

def register_new_training_session(self, SVM_name: SVMName):
    """Make *SVM_name* the currently running session.

    Training parameters are not known yet, so a ``None`` placeholder is
    stored until ``start_training_session`` fills them in.
    """
    self._running = SVM_name
    info = self._SVMS_info
    info[SVM_name] = None
    # Keep the report ordering stable regardless of registration order.
    self._SVMS_info = sort_dict(info)
    self._update_file()

def register_training_session(
def start_training_session(
self,
SVM_name: SVMName,
batch_size: int,
lr: float,
num_hops_1: int,
num_hops_2: int,
num_of_state_features: int,
epochs: int,
):
SVM_name = self._running
self._status[SVM_name] = Status(Status.RUNNING_STATUS, 0)

self._SVMS_info[SVM_name] = TrainingParams(
batch_size, lr, num_hops_1, num_hops_2, num_of_state_features, epochs
batch_size,
lr,
num_hops_1,
num_hops_2,
num_of_state_features,
epochs,
)
self._SVMS_info = sort_dict(self._SVMS_info)
self._update_file()

def fail(self):
    """Mark the active session as failed and clear the running marker."""
    active = self._running
    # NOTE(review): assumes start_training_session already created a Status
    # entry for the running SVM — confirm fail() cannot be reached before it.
    self._status[active].status = Status.FAILED_STATUS
    self._running = None
    self._update_file()

def finish(self):
    """Mark the active session as finished and clear the running marker."""
    active = self._running
    self._status[active].status = Status.FINISHED_STATUS
    self._running = None
    self._update_file()

def update_results(
self,
epoch: EpochNumber,
SVM_name: SVMName,
average_result: float,
map2results_list: list[Map2Result],
):
SVM_name = self._running
epoch = self._status[SVM_name].epoch

results = self._sessions_info.get(epoch, {})
results[SVM_name] = StatsWithTable(
average_result, convert_to_df(SVM_name, map2results_list)
)
self._sessions_info[epoch] = sort_dict(results)
self._status[SVM_name].epoch = epoch + 1
self._update_file()

def _get_SVMS_info(self) -> str:
svm_info_line = lambda svm_info: (
f"{svm_info[0]} : "
f"batch_size={svm_info[1].batch_size}, "
f"lr={svm_info[1].lr}, "
f"num_hops_1={svm_info[1].num_hops_1}, "
f"num_hops_2={svm_info[1].num_hops_2}, "
f"num_of_state_features={svm_info[1].num_of_state_features}, "
f"epochs={svm_info[1].epochs}\n"
)
def _get_training_info(self) -> str:
    """Build the per-SVM summary section of the statistics file.

    Emits one line per SVM that already has a Status entry; SVMs that were
    registered but never started (no status yet) contribute nothing.
    """
    lines = []
    for svm_name, training_params in self._SVMS_info.items():
        status: Status | None = self._status.get(svm_name, None)
        if not status:
            # Session registered but not started — skip it.
            continue
        lines.append(
            f"{svm_name} : "
            f"{str(status)}, "
            f"batch_size={training_params.batch_size}, "
            f"lr={training_params.lr}, "
            f"num_hops_1={training_params.num_hops_1}, "
            f"num_hops_2={training_params.num_hops_2}, "
            f"num_of_state_features={training_params.num_of_state_features}, "
            f"epochs={training_params.epochs}\n"
        )
    return "".join(lines)

Expand All @@ -112,9 +167,8 @@ def _get_epochs_results(self) -> str:
return epochs_results

def _update_file(self):
with self._lock:
SVMS_info = self._get_SVMS_info()
epochs_results = self._get_epochs_results()
SVMS_info = self._get_training_info()
epochs_results = self._get_epochs_results()
with open(self._file, "w") as f:
f.write(SVMS_info)
f.write(epochs_results)
Expand Down
10 changes: 2 additions & 8 deletions AIAgent/ml/training/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,11 @@ def play_game_task(task):


def validate_coverage(
svm_info: SVMInfo,
statistics_collector: StatisticsCollector,
model: torch.nn.Module,
epoch: int,
dataset: TrainingDataset,
progress_bar_colour: str = "#ed95ce",
server_count: int = 1,
):
"""
Evaluate model using symbolic execution engine. It runs in parallel.
Expand All @@ -44,16 +43,13 @@ def validate_coverage(
----------
model : torch.nn.Module
Model to evaluate
epoch : int
Epoch number to write result.
dataset : TrainingDataset
Dataset object for validation.
progress_bar_colour : str
Your favorite colour for progress bar.
"""
wrapper = TrainingModelWrapper(model)
tasks = [([game_map], dataset, wrapper) for game_map in dataset.maps]
server_count = svm_info.count

with ThreadPool(server_count) as p:
all_results = []
Expand Down Expand Up @@ -85,9 +81,7 @@ def validate_coverage(
)
)
)
statistics_collector.update_results(
epoch, svm_info.name, average_result, all_results
)
statistics_collector.update_results(average_result, all_results)
return average_result


Expand Down
25 changes: 13 additions & 12 deletions AIAgent/run_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,17 @@ def load_weights(model: torch.nn.Module):
sampler=sampler, direction=TrainingConfig.STUDY_DIRECTION
)
run_name = f"{logfile_base_name}_{svm_info.name}"
statistics_collector.register_new_training_session(svm_info.name)

objective_partial = partial(
objective,
svm_info=svm_info,
statistics_collector=statistics_collector,
dataset=dataset,
dynamic_dataset=TrainingConfig.DYNAMIC_DATASET,
model_init=model_init,
epochs=num_epochs,
run_name=run_name,
server_count=svm_info.count,
)
try:
study.optimize(
Expand All @@ -127,6 +128,7 @@ def load_weights(model: torch.nn.Module):
)
except RuntimeError: # TODO: Replace it with a self-created exception
logging.error(f"Fail to train with {svm_info.name}")
statistics_collector.fail()
return
joblib.dump(
study,
Expand All @@ -136,13 +138,13 @@ def load_weights(model: torch.nn.Module):

def objective(
trial: optuna.Trial,
svm_info: SVMInfo,
statistics_collector: StatisticsCollector,
dataset: TrainingDataset,
dynamic_dataset: bool,
model_init: Callable,
epochs: int,
run_name: str,
server_count: int = 1,
):
config = TrialSettings(
lr=0.0003, # trial.suggest_float("lr", 1e-7, 1e-3),
Expand All @@ -169,14 +171,13 @@ def objective(

optimizer = config.optimizer(model.parameters(), lr=config.lr)
criterion = config.loss()
statistics_collector.register_training_session(
svm_info.name,
config.batch_size,
config.lr,
config.num_hops_1,
config.num_hops_2,
config.num_of_state_features,
config.epochs,
statistics_collector.start_training_session(
batch_size=config.batch_size,
lr=config.lr,
num_hops_1=config.num_hops_1,
num_hops_2=config.num_hops_2,
num_of_state_features=config.num_of_state_features,
epochs=config.epochs,
)

run_name = (
Expand Down Expand Up @@ -205,14 +206,14 @@ def objective(
model.eval()
dataset.switch_to("val")
result = validate_coverage(
svm_info=svm_info,
statistics_collector=statistics_collector,
model=model,
epoch=epoch,
dataset=dataset,
server_count=server_count,
)
if dynamic_dataset:
dataset.update_meta_data()
statistics_collector.finish()
return result


Expand Down

0 comments on commit 29a9358

Please sign in to comment.