Skip to content

Commit

Permalink
Bugfix: Correct loss labels when graphing
Browse files Browse the repository at this point in the history
  • Loading branch information
torzdf committed Mar 20, 2024
1 parent 1d3c59c commit 9ddc838
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 42 deletions.
15 changes: 8 additions & 7 deletions lib/gui/analysis/event_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations
import logging
import os
import re
import typing as T
import zlib

Expand All @@ -14,6 +15,7 @@
from tensorflow.python.framework import ( # pylint:disable=no-name-in-module
errors_impl as tf_errors)

from lib.logger import parse_class_init
from lib.serializer import get_serializer

if T.TYPE_CHECKING:
Expand Down Expand Up @@ -46,7 +48,7 @@ class _LogFiles():
The folder that contains the Tensorboard log files
"""
def __init__(self, logs_folder: str) -> None:
logger.debug("Initializing: %s: (logs_folder: '%s')", self.__class__.__name__, logs_folder)
logger.debug(parse_class_init(locals()))
self._logs_folder = logs_folder
self._filenames = self._get_log_filenames()
logger.debug("Initialized: %s", self.__class__.__name__)
Expand Down Expand Up @@ -215,7 +217,7 @@ def add_live_data(self, timestamps: np.ndarray, loss: np.ndarray) -> None:
class _Cache():
""" Holds parsed Tensorflow log event data in a compressed cache in memory. """
def __init__(self) -> None:
logger.debug("Initializing: %s", self.__class__.__name__)
logger.debug(parse_class_init(locals()))
self._data: dict[int, _CacheData] = {}
self._carry_over: dict[int, EventData] = {}
self._loss_labels: list[str] = []
Expand Down Expand Up @@ -471,8 +473,7 @@ class TensorBoardLogs():
``True`` if the events are being read whilst Faceswap is training otherwise ``False``
"""
def __init__(self, logs_folder: str, is_training: bool) -> None:
logger.debug("Initializing: %s: (logs_folder: %s, is_training: %s)",
self.__class__.__name__, logs_folder, is_training)
logger.debug(parse_class_init(locals()))
self._is_training = False
self._training_iterator = None

Expand Down Expand Up @@ -631,12 +632,12 @@ class _EventParser(): # pylint:disable=too-few-public-methods
otherwise ``False``
"""
def __init__(self, iterator: Iterator[bytes], cache: _Cache, live_data: bool) -> None:
logger.debug("Initializing: %s: (iterator: %s, cache: %s, live_data: %s)",
self.__class__.__name__, iterator, cache, live_data)
logger.debug(parse_class_init(locals()))
self._live_data = live_data
self._cache = cache
self._iterator = self._get_latest_live(iterator) if live_data else iterator
self._loss_labels: list[str] = []
self._num_strip = re.compile(r"_\d+$")
logger.debug("Initialized: %s", self.__class__.__name__)

@classmethod
Expand Down Expand Up @@ -728,7 +729,7 @@ def _parse_outputs(self, event: event_pb2.Event) -> None:
if layer["name"] == layer_name)["config"]
layer_outputs = self._get_outputs(output_config)
for output in layer_outputs: # Drill into sub-model to get the actual output names
loss_name = output[0][0]
loss_name = self._num_strip.sub("", output[0][0]) # strip trailing numbers
if loss_name[-2:] not in ("_a", "_b"): # Rename losses to reflect the side output
new_name = f"{loss_name.replace('_both', '')}_{side}"
logger.debug("Renaming loss output from '%s' to '%s'", loss_name, new_name)
Expand Down
13 changes: 6 additions & 7 deletions lib/gui/analysis/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import numpy as np

from lib.logger import parse_class_init
from lib.serializer import get_serializer

from .event_reader import TensorBoardLogs
Expand All @@ -30,7 +31,7 @@ class GlobalSession():
:attr:`lib.gui.analysis.Session`
"""
def __init__(self) -> None:
logger.debug("Initializing %s", self.__class__.__name__)
logger.debug(parse_class_init(locals()))
self._state: dict[str, T.Any] = {}
self._model_dir = ""
self._model_name = ""
Expand Down Expand Up @@ -289,7 +290,7 @@ class SessionsSummary(): # pylint:disable=too-few-public-methods
The loaded or currently training session
"""
def __init__(self, session: GlobalSession) -> None:
logger.debug("Initializing %s: (session: %s)", self.__class__.__name__, session)
logger.debug(parse_class_init(locals()))
self._session = session
self._state = session._state

Expand Down Expand Up @@ -539,11 +540,7 @@ def __init__(self, session_id,
avg_samples: int = 500,
smooth_amount: float = 0.90,
flatten_outliers: bool = False) -> None:
logger.debug("Initializing %s: (session_id: %s, display: %s, loss_keys: %s, "
"selections: %s, avg_samples: %s, smooth_amount: %s, flatten_outliers: %s)",
self.__class__.__name__, session_id, display, loss_keys, selections,
avg_samples, smooth_amount, flatten_outliers)

logger.debug(parse_class_init(locals()))
warnings.simplefilter("ignore", np.RankWarning)

self._session_id = session_id
Expand Down Expand Up @@ -872,6 +869,7 @@ class _ExponentialMovingAverage(): # pylint:disable=too-few-public-methods
Adapted from: https://stackoverflow.com/questions/42869495
"""
def __init__(self, data: np.ndarray, amount: float) -> None:
logger.debug(parse_class_init(locals()))
assert data.ndim == 1
amount = min(max(amount, 0.001), 0.999)

Expand All @@ -880,6 +878,7 @@ def __init__(self, data: np.ndarray, amount: float) -> None:
self._dtype = "float32" if data.dtype == np.float32 else "float64"
self._row_size = self._get_max_row_size()
self._out = np.empty_like(data, dtype=self._dtype)
logger.debug("Initialized %s", self.__class__.__name__)

def __call__(self) -> np.ndarray:
""" Perform the exponential moving average calculation.
Expand Down
4 changes: 3 additions & 1 deletion lib/gui/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import tkinter as tk
from tkinter import ttk

from lib.logger import parse_class_init

from .display_analysis import Analysis
from .display_command import GraphDisplay, PreviewExtract, PreviewTrain
from .utils import get_config
Expand All @@ -31,7 +33,7 @@ class DisplayNotebook(ttk.Notebook): # pylint: disable=too-many-ancestors
"""

def __init__(self, parent):
logger.debug("Initializing %s", self.__class__.__name__)
logger.debug(parse_class_init(locals()))
super().__init__(parent)
parent.add(self)
tk_vars = get_config().tk_vars
Expand Down
18 changes: 9 additions & 9 deletions lib/gui/display_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import tkinter as tk
from tkinter import ttk

from lib.logger import parse_class_init

from .custom_widgets import Tooltip
from .display_page import DisplayPage
from .popup_session import SessionPopUp
Expand Down Expand Up @@ -36,8 +38,7 @@ class Analysis(DisplayPage): # pylint: disable=too-many-ancestors
The help text to display for the summary statistics page
"""
def __init__(self, parent, tab_name, helptext):
logger.debug("Initializing: %s: (parent, %s, tab_name: '%s', helptext: '%s')",
self.__class__.__name__, parent, tab_name, helptext)
logger.debug(parse_class_init(locals()))
super().__init__(parent, tab_name, helptext)
self._summary = None

Expand All @@ -62,10 +63,10 @@ def set_vars(self):
dict
The dictionary of variable names to tkinter variables
"""
return dict(selected_id=tk.StringVar(),
refresh_graph=get_config().tk_vars.refresh_graph,
is_training=get_config().tk_vars.is_training,
analysis_folder=get_config().tk_vars.analysis_folder)
return {"selected_id": tk.StringVar(),
"refresh_graph": get_config().tk_vars.refresh_graph,
"is_training": get_config().tk_vars.is_training,
"analysis_folder": get_config().tk_vars.analysis_folder}

def on_tab_select(self):
""" Callback for when the analysis tab is selected.
Expand Down Expand Up @@ -299,7 +300,7 @@ class _Options(): # pylint:disable=too-few-public-methods
The Analysis Display Tab that holds the options buttons
"""
def __init__(self, parent):
logger.debug("Initializing: %s (parent: %s)", self.__class__.__name__, parent)
logger.debug(parse_class_init(locals()))
self._parent = parent
self._buttons = self._add_buttons()
self._add_training_callback()
Expand Down Expand Up @@ -380,8 +381,7 @@ class StatsData(ttk.Frame): # pylint: disable=too-many-ancestors
The help text to display for the summary statistics page
"""
def __init__(self, parent, selected_id, helptext):
logger.debug("Initializing: %s: (parent, %s, selected_id: %s, helptext: '%s')",
self.__class__.__name__, parent, selected_id, helptext)
logger.debug(parse_class_init(locals()))
super().__init__(parent)
self._selected_id = selected_id

Expand Down
12 changes: 7 additions & 5 deletions lib/gui/display_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from tkinter import ttk

from lib.logger import parse_class_init
from lib.training.preview_tk import PreviewTk

from .display_graph import TrainingGraph
Expand All @@ -28,8 +29,7 @@
class PreviewExtract(DisplayOptionalPage): # pylint: disable=too-many-ancestors
""" Tab to display output preview images for extract and convert """
def __init__(self, *args, **kwargs) -> None:
logger.debug("Initializing %s (args: %s, kwargs: %s)",
self.__class__.__name__, args, kwargs)
logger.debug(parse_class_init(locals()))
self._preview = get_images().preview_extract
super().__init__(*args, **kwargs)
logger.debug("Initialized %s", self.__class__.__name__)
Expand Down Expand Up @@ -83,8 +83,7 @@ def save_items(self) -> None:
class PreviewTrain(DisplayOptionalPage): # pylint: disable=too-many-ancestors
""" Training preview image(s) """
def __init__(self, *args, **kwargs) -> None:
logger.debug("Initializing %s (args: %s, kwargs: %s)",
self.__class__.__name__, args, kwargs)
logger.debug(parse_class_init(locals()))
self._preview = get_images().preview_train
self._display: PreviewTk | None = None
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -172,9 +171,11 @@ def __init__(self,
helptext: str,
wait_time: int,
command: str | None = None) -> None:
logger.debug(parse_class_init(locals()))
self._trace_vars: dict[T.Literal["smoothgraph", "display_iterations"],
tuple[tk.BooleanVar, str]] = {}
super().__init__(parent, tab_name, helptext, wait_time, command)
logger.debug("Initialized %s", self.__class__.__name__)

def set_vars(self) -> None:
""" Add graphing specific variables to the default variables.
Expand Down Expand Up @@ -212,7 +213,8 @@ def on_tab_select(self) -> None:
Pull latest data and run the tab's update code when the tab is selected.
"""
logger.debug("Callback received for '%s' tab", self.tabname)
logger.debug("Callback received for '%s' tab (display_item: %s)",
self.tabname, self.display_item)
if self.display_item is not None:
get_config().tk_vars.refresh_graph.set(True)
self._update_page()
Expand Down
11 changes: 8 additions & 3 deletions lib/gui/display_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
NavigationToolbar2Tk)
from matplotlib.backend_bases import NavigationToolbar2

from lib.logger import parse_class_init

from .custom_widgets import Tooltip
from .utils import get_config, get_images, LongRunningTask

Expand All @@ -40,7 +42,6 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
The data label for the y-axis
"""
def __init__(self, parent: ttk.Frame, data, ylabel: str) -> None:
logger.debug("Initializing %s", self.__class__.__name__)
super().__init__(parent)
matplotlib.use("TkAgg") # Can't be at module level as breaks Github CI
style.use("ggplot")
Expand All @@ -58,7 +59,6 @@ def __init__(self, parent: ttk.Frame, data, ylabel: str) -> None:

self._initiate_graph()
self._update_plot(initiate=True)
logger.debug("Initialized %s", self.__class__.__name__)

@property
def calcs(self):
Expand Down Expand Up @@ -335,10 +335,12 @@ class TrainingGraph(GraphBase): # pylint: disable=too-many-ancestors
"""

def __init__(self, parent: ttk.Frame, data, ylabel: str) -> None:
logger.debug(parse_class_init(locals()))
super().__init__(parent, data, ylabel)
self._thread: LongRunningTask | None = None # Thread for LongRunningTask
self._displayed_keys: list[str] = []
self._add_callback()
logger.debug("Initialized %s", self.__class__.__name__)

def _add_callback(self) -> None:
""" Add the variable trace to update graph on refresh button press or save iteration. """
Expand Down Expand Up @@ -427,8 +429,10 @@ class SessionGraph(GraphBase): # pylint: disable=too-many-ancestors
Should be one of ``"log"`` or ``"linear"``
"""
def __init__(self, parent: ttk.Frame, data, ylabel: str, scale: str) -> None:
logger.debug(parse_class_init(locals()))
super().__init__(parent, data, ylabel)
self._scale = scale
logger.debug("Initialized %s", self.__class__.__name__)

def build(self) -> None:
""" Build the session graph """
Expand Down Expand Up @@ -494,7 +498,7 @@ def __init__(self, # pylint: disable=super-init-not-called
window: ttk.Frame,
*,
pack_toolbar: bool = True) -> None:

logger.debug(parse_class_init(locals()))
# Avoid using self.window (prefer self.canvas.get_tk_widget().master),
# so that Tool implementations can reuse the methods.

Expand Down Expand Up @@ -528,6 +532,7 @@ def __init__(self, # pylint: disable=super-init-not-called
NavigationToolbar2.__init__(self, canvas) # pylint:disable=non-parent-init-called
if pack_toolbar:
self.pack(side=tk.BOTTOM, fill=tk.X)
logger.debug("Initialized %s", self.__class__.__name__)

@staticmethod
def _Button(frame: ttk.Frame, # pylint:disable=arguments-differ,arguments-renamed
Expand Down
10 changes: 2 additions & 8 deletions lib/gui/display_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ class DisplayPage(ttk.Frame): # pylint: disable=too-many-ancestors
""" Parent frame holder for each tab.
Defines uniform structure for each tab to inherit from """
def __init__(self, parent, tab_name, helptext):
logger.debug("Initializing %s: (tab_name: '%s', helptext: %s)",
self.__class__.__name__, tab_name, helptext)
ttk.Frame.__init__(self, parent)
super().__init__(parent)

self._parent = parent
self.running_task = parent.running_task
Expand All @@ -42,8 +40,6 @@ def __init__(self, parent, tab_name, helptext):
self.pack(fill=tk.BOTH, side=tk.TOP, anchor=tk.NW)
parent.add(self, text=self.tabname.title())

logger.debug("Initialized %s", self.__class__.__name__,)

@property
def _tab_is_active(self):
""" bool: ``True`` if the tab currently has focus otherwise ``False`` """
Expand Down Expand Up @@ -167,9 +163,7 @@ class DisplayOptionalPage(DisplayPage): # pylint: disable=too-many-ancestors
""" Parent Context Sensitive Display Tab """

def __init__(self, parent, tab_name, helptext, wait_time, command=None):
logger.debug("%s: OptionalPage args: (wait_time: %s, command: %s)",
self.__class__.__name__, wait_time, command)
DisplayPage.__init__(self, parent, tab_name, helptext)
super().__init__(parent, tab_name, helptext)

self._waittime = wait_time
self.command = command
Expand Down
29 changes: 27 additions & 2 deletions lib/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

from datetime import datetime

import numpy as np


# TODO - Remove this monkey patch when TF autograph fixed to handle newer logging lib
def _patched_format(self, record):
Expand Down Expand Up @@ -544,6 +546,28 @@ def crash_log() -> str:
return filename


def _process_value(value: T.Any) -> T.Any:
""" Process the values from a local dict and return in a loggable format
Parameters
----------
value: Any
The dictionary value
Returns
-------
Any
The original or ammended value
"""
if isinstance(value, str):
return f'"{value}"'
if isinstance(value, np.ndarray) and np.prod(value.shape) > 10:
return f'[type: "{type(value).__name__}" shape: {value.shape}, dtype: "{value.dtype}"]'
if isinstance(value, (list, tuple, set)) and len(value) > 10:
return f'[type: "{type(value).__name__}" len: {len(value)}'
return value


def parse_class_init(locals_dict: dict[str, T.Any]) -> str:
""" Parse a locals dict from a class and return in a format suitable for logging
Parameters
Expand All @@ -555,10 +579,11 @@ def parse_class_init(locals_dict: dict[str, T.Any]) -> str:
str
The locals information suitable for logging
"""
delimit = {k: f"'{v}'" if isinstance(v, str) else v
delimit = {k: _process_value(v)
for k, v in locals_dict.items() if k != "self"}
dsp = ", ".join(f"{k}: {v}" for k, v in delimit.items())
return f"Initializing {locals_dict['self'].__class__.__name__} ({dsp})"
dsp = f" ({dsp})" if dsp else ""
return f"Initializing {locals_dict['self'].__class__.__name__}{dsp}"


_OLD_FACTORY = logging.getLogRecordFactory()
Expand Down

0 comments on commit 9ddc838

Please sign in to comment.