Skip to content

Commit

Permalink
typing: lib.gui.analysis.stats
Browse files Browse the repository at this point in the history
  • Loading branch information
torzdf committed Oct 13, 2022
1 parent 8910ae5 commit 47867a0
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 52 deletions.
112 changes: 61 additions & 51 deletions lib/gui/analysis/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@

from math import ceil
from threading import Event
from typing import List, Optional, Tuple, Union
from typing_extensions import Self
from typing import Any, cast, Dict, List, Optional, Tuple, Union

import numpy as np

Expand All @@ -33,12 +32,12 @@ class GlobalSession():
"""
def __init__(self) -> None:
logger.debug("Initializing %s", self.__class__.__name__)
self._state = None
self._state: Dict[str, Any] = {}
self._model_dir = ""
self._model_name = ""

self._tb_logs = None
self._summary = None
self._tb_logs: Optional[TensorBoardLogs] = None
self._summary: Optional[SessionsSummary] = None

self._is_training = False
self._is_querying = Event()
Expand All @@ -62,29 +61,31 @@ def model_filename(self) -> str:
return os.path.join(self._model_dir, self._model_name)

@property
def batch_sizes(self) -> dict:
def batch_sizes(self) -> Dict[int, int]:
""" dict: The batch sizes for each session_id for the model. """
if self._state is None:
if not self._state:
return {}
return {int(sess_id): sess["batchsize"]
for sess_id, sess in self._state.get("sessions", {}).items()}

@property
def full_summary(self) -> List[dict]:
""" list: List of dictionaries containing summary statistics for each session id. """
assert self._summary is not None
return self._summary.get_summary_stats()

@property
def logging_disabled(self) -> bool:
""" bool: ``True`` if logging is enabled for the currently training session otherwise
``False``. """
if self._state is None:
if not self._state:
return True
return self._state["sessions"][str(self.session_ids[-1])]["no_logs"]

@property
def session_ids(self) -> List[int]:
""" list: The sorted list of all existing session ids in the state file """
assert self._tb_logs is not None
return self._tb_logs.session_ids

def _load_state_file(self) -> None:
Expand All @@ -96,8 +97,8 @@ def _load_state_file(self) -> None:
logger.debug("Loaded state: %s", self._state)

def initialize_session(self,
model_folder: Optional[str],
model_name: Optional[str],
model_folder: str,
model_name: str,
is_training: bool = False) -> None:
""" Initialize a Session.
Expand All @@ -106,12 +107,14 @@ def initialize_session(self,
Parameters
----------
model_folder: str, optional
model_folder: str,
If loading a session manually (e.g. for the analysis tab), then the path to the model
folder must be provided. For training sessions, this should be left at ``None``
folder must be provided. For training sessions, this should be passed through from the
launcher
model_name: str, optional
If loading a session manually (e.g. for the analysis tab), then the model filename
must be provided. For training sessions, this should be left at ``None``
must be provided. For training sessions, this should be passed through from the
launcher
is_training: bool, optional
``True`` if the session is being initialized for a training session, otherwise
``False``. Default: ``False``
Expand All @@ -120,6 +123,7 @@ def initialize_session(self,

if self._model_dir == model_folder and self._model_name == model_name:
if is_training:
assert self._tb_logs is not None
self._tb_logs.set_training(is_training)
self._load_state_file()
self._is_training = True
Expand Down Expand Up @@ -157,7 +161,7 @@ def clear(self) -> None:

self._is_training = False

def get_loss(self, session_id: Optional[int]) -> dict:
def get_loss(self, session_id: Optional[int]) -> Dict[str, np.ndarray]:
""" Obtain the loss values for the given session_id.
Parameters
Expand All @@ -176,21 +180,24 @@ def get_loss(self, session_id: Optional[int]) -> dict:
if self._is_training:
self._is_querying.set()

assert self._tb_logs is not None
loss_dict = self._tb_logs.get_loss(session_id=session_id)
if session_id is None:
retval = {}
all_loss: Dict[str, List[float]] = {}
for key in sorted(loss_dict):
for loss_key, loss in loss_dict[key].items():
retval.setdefault(loss_key, []).extend(loss)
retval = {key: np.array(val, dtype="float32") for key, val in retval.items()}
all_loss.setdefault(loss_key, []).extend(loss)
retval: Dict[str, np.ndarray] = {key: np.array(val, dtype="float32")
for key, val in all_loss.items()}
else:
retval = loss_dict.get(session_id, {})

if self._is_training:
self._is_querying.clear()
return retval

def get_timestamps(self, session_id: Optional[int]) -> Union[dict, np.ndarray]:
def get_timestamps(self, session_id: Optional[int]) -> Union[Dict[int, np.ndarray],
np.ndarray]:
""" Obtain the time stamps keys for the given session_id.
Parameters
Expand All @@ -211,6 +218,7 @@ def get_timestamps(self, session_id: Optional[int]) -> Union[dict, np.ndarray]:
if self._is_training:
self._is_querying.set()

assert self._tb_logs is not None
retval = self._tb_logs.get_timestamps(session_id=session_id)
if session_id is not None:
retval = retval[session_id]
Expand Down Expand Up @@ -249,16 +257,17 @@ def get_loss_keys(self, session_id: Optional[int]) -> List[str]:
loss_keys = {int(sess_id): [name for name in session["loss_names"] if name != "total"]
for sess_id, session in self._state["sessions"].items()}
else:
assert self._tb_logs is not None
loss_keys = {sess_id: list(logs.keys())
for sess_id, logs
in self._tb_logs.get_loss(session_id=session_id).items()}

if session_id is None:
retval = list(set(loss_key
for session in loss_keys.values()
for loss_key in session))
retval: List[str] = list(set(loss_key
for session in loss_keys.values()
for loss_key in session))
else:
retval = loss_keys.get(session_id)
retval = loss_keys.get(session_id, [])
return retval


Expand All @@ -279,8 +288,8 @@ def __init__(self, session: GlobalSession) -> None:
self._session = session
self._state = session._state

self._time_stats = None
self._per_session_stats = None
self._time_stats: Dict[int, Dict[str, Union[float, int]]] = {}
self._per_session_stats: List[Dict[str, Any]] = []
logger.debug("Initialized %s", self.__class__.__name__)

def get_summary_stats(self) -> List[dict]:
Expand Down Expand Up @@ -315,20 +324,21 @@ def _get_time_stats(self) -> None:
If the main Session is currently training, then the training session ID is updated with the
latest stats.
"""
if self._time_stats is None:
if not self._time_stats:
logger.debug("Collating summary time stamps")

self._time_stats = {
sess_id: dict(start_time=np.min(timestamps) if np.any(timestamps) else 0,
end_time=np.max(timestamps) if np.any(timestamps) else 0,
iterations=timestamps.shape[0] if np.any(timestamps) else 0)
for sess_id, timestamps in self._session.get_timestamps(None).items()}
for sess_id, timestamps in cast(Dict[int, np.ndarray],
self._session.get_timestamps(None)).items()}

elif _SESSION.is_training:
logger.debug("Updating summary time stamps for training session")

session_id = _SESSION.session_ids[-1]
latest = self._session.get_timestamps(session_id)
latest = cast(np.ndarray, self._session.get_timestamps(session_id))

self._time_stats[session_id] = dict(
start_time=np.min(latest) if np.any(latest) else 0,
Expand All @@ -344,12 +354,12 @@ def _get_per_session_stats(self) -> None:
If a training session is running, then updates the training sessions stats only.
"""
if self._per_session_stats is None:
if not self._per_session_stats:
logger.debug("Collating per session stats")
compiled = []
for session_id in self._time_stats:
logger.debug("Compiling session ID: %s", session_id)
if self._state is None:
if not self._state:
logger.debug("Session state dict doesn't exist. Most likely task has been "
"terminated during compilation")
return
Expand Down Expand Up @@ -377,7 +387,7 @@ def _get_per_session_stats(self) -> None:
/ stats["elapsed"] if stats["elapsed"] > 0 else 0)
logger.debug("per_session_stats: %s", self._per_session_stats)

def _collate_stats(self, session_id: int) -> dict:
def _collate_stats(self, session_id: int) -> Dict[str, Union[int, float]]:
""" Collate the session summary statistics for the given session ID.
Parameters
Expand Down Expand Up @@ -406,14 +416,14 @@ def _collate_stats(self, session_id: int) -> dict:
logger.debug(retval)
return retval

def _total_stats(self) -> dict:
def _total_stats(self) -> Dict[str, Union[str, int, float]]:
""" Compile the Totals stats.
Totals are fully calculated each time as they will change on the basis of the training
session.
Returns
-------
dict:
dict
The Session name, start time, end time, elapsed time, rate, batch size and number of
iterations for all session ids within the loaded data.
"""
Expand Down Expand Up @@ -486,8 +496,8 @@ def _convert_time(cls, timestamp: float) -> Tuple[str, str, str]:
tuple
(`hours`, `minutes`, `seconds`) as strings
"""
hrs = int(timestamp // 3600)
hrs = f"{hrs:02d}" if hrs < 10 else str(hrs)
ihrs = int(timestamp // 3600)
hrs = f"{ihrs:02d}" if ihrs < 10 else str(ihrs)
mins = f"{(int(timestamp % 3600) // 60):02d}"
secs = f"{(int(timestamp % 3600) % 60):02d}"
return hrs, mins, secs
Expand Down Expand Up @@ -536,13 +546,13 @@ def __init__(self, session_id,
self._loss_keys = loss_keys if isinstance(loss_keys, list) else [loss_keys]
self._selections = selections if isinstance(selections, list) else [selections]
self._is_totals = session_id is None
self._args = dict(avg_samples=avg_samples,
smooth_amount=smooth_amount,
flatten_outliers=flatten_outliers)
self._args: Dict[str, Union[int, float]] = dict(avg_samples=avg_samples,
smooth_amount=smooth_amount,
flatten_outliers=flatten_outliers)
self._iterations = 0
self._limit = 0
self._start_iteration = 0
self._stats = {}
self._stats: Dict[str, np.ndarray] = {}
self.refresh()
logger.debug("Initialized %s", self.__class__.__name__)

Expand All @@ -557,11 +567,11 @@ def start_iteration(self) -> int:
return self._start_iteration

@property
def stats(self) -> dict:
def stats(self) -> Dict[str, np.ndarray]:
""" dict: The final calculated statistics """
return self._stats

def refresh(self) -> Optional[Self]:
def refresh(self) -> Optional["Calculations"]:
""" Refresh the stats """
logger.debug("Refreshing")
if not _SESSION.is_loaded:
Expand Down Expand Up @@ -658,11 +668,11 @@ def _get_raw(self) -> None:
if len(iterations) > 1:
# Crop all losses to the same number of items
if self._iterations == 0:
self.stats = {lossname: np.array([], dtype=loss.dtype)
for lossname, loss in self.stats.items()}
self._stats = {lossname: np.array([], dtype=loss.dtype)
for lossname, loss in self.stats.items()}
else:
self.stats = {lossname: loss[:self._iterations]
for lossname, loss in self.stats.items()}
self._stats = {lossname: loss[:self._iterations]
for lossname, loss in self.stats.items()}

else: # Rate calculation
data = self._calc_rate_total() if self._is_totals else self._calc_rate()
Expand Down Expand Up @@ -719,8 +729,8 @@ def _calc_rate(self) -> np.ndarray:
The training rate for each iteration of the selected session
"""
logger.debug("Calculating rate")
retval = (_SESSION.batch_sizes[self._session_id] * 2) / np.diff(_SESSION.get_timestamps(
self._session_id))
batch_size = _SESSION.batch_sizes[self._session_id] * 2
retval = batch_size / np.diff(cast(np.ndarray, _SESSION.get_timestamps(self._session_id)))
logger.debug("Calculated rate: Item_count: %s", len(retval))
return retval

Expand All @@ -740,8 +750,8 @@ def _calc_rate_total(cls) -> np.ndarray:
"""
logger.debug("Calculating totals rate")
batchsizes = _SESSION.batch_sizes
total_timestamps = _SESSION.get_timestamps(None)
rate = []
total_timestamps = cast(Dict[int, np.ndarray], _SESSION.get_timestamps(None))
rate: List[float] = []
for sess_id in sorted(total_timestamps.keys()):
batchsize = batchsizes[sess_id]
timestamps = total_timestamps[sess_id]
Expand Down Expand Up @@ -781,7 +791,7 @@ def _calc_avg(self, data: np.ndarray) -> np.ndarray:
The moving average for the given data
"""
logger.debug("Calculating Average. Data points: %s", len(data))
window = self._args["avg_samples"]
window = cast(int, self._args["avg_samples"])
pad = ceil(window / 2)
datapoints = data.shape[0]

Expand Down Expand Up @@ -968,8 +978,8 @@ def _ewma_vectorized(self,
out /= scaling_factors[-2::-1] # cumulative sums / scaling

if offset != 0:
offset = np.array(offset, copy=False).astype(self._dtype, copy=False)
out += offset * scaling_factors[1:]
noffset = np.array(offset, copy=False).astype(self._dtype, copy=False)
out += noffset * scaling_factors[1:]

def _ewma_vectorized_2d(self, data: np.ndarray, out: np.ndarray) -> None:
""" Calculates the exponential moving average over the last axis.
Expand Down
3 changes: 2 additions & 1 deletion lib/gui/control_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from tkinter import colorchooser, ttk
from itertools import zip_longest
from functools import partial
from typing import Any, Dict

from _tkinter import Tcl_Obj, TclError

Expand All @@ -23,7 +24,7 @@
# We store Tooltips, ContextMenus and Commands globally when they are created
# Because we need to add them back to newly cloned widgets (they are not easily accessible from
# original config or are prone to getting destroyed when the original widget is destroyed)
_RECREATE_OBJECTS = dict(tooltips={}, commands={}, contextmenus={})
_RECREATE_OBJECTS: Dict[str, Dict[str, Any]] = dict(tooltips={}, commands={}, contextmenus={})


def _get_tooltip(widget, text=None, text_variable=None):
Expand Down

0 comments on commit 47867a0

Please sign in to comment.