Skip to content

Commit

Permalink
Softly deprecate the get_str=False flag.
Browse files Browse the repository at this point in the history
Summary: We don't want to use print directly in stats.print() method. Instead this method will return the output string to the caller.

Reviewed By: shapovalov

Differential Revision: D45356240

fbshipit-source-id: 2cabe3cdfb9206bf09aa7b3cdd2263148a5ba145
  • Loading branch information
virendra-pathak authored and facebook-github-bot committed May 14, 2023
1 parent 297020a commit d08fe6d
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 34 deletions.
4 changes: 2 additions & 2 deletions projects/implicitron_trainer/impl/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,6 @@ def load_stats(
list(log_vars),
plot_file=os.path.join(exp_dir, "train_stats.pdf"),
visdom_env=visdom_env_charts,
verbose=False,
visdom_server=self.visdom_server,
visdom_port=self.visdom_port,
)
Expand Down Expand Up @@ -382,7 +381,8 @@ def _training_or_validation_epoch(

# print textual status update
if it % self.metric_print_interval == 0 or last_iter:
stats.print(stat_set=trainmode, max_it=n_batches)
std_out = stats.get_status_string(stat_set=trainmode, max_it=n_batches)
logger.info(std_out)

# visualize results
if (
Expand Down
88 changes: 56 additions & 32 deletions pytorch3d/implicitron/tools/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import gzip
import json
import logging
import time
import warnings
from collections.abc import Iterable
Expand All @@ -17,6 +18,8 @@
from matplotlib import colors as mcolors
from pytorch3d.implicitron.tools.vis_utils import get_visdom_connection

logger = logging.getLogger(__name__)


class AverageMeter(object):
"""Computes and stores the average and current value"""
Expand Down Expand Up @@ -91,7 +94,9 @@ class Stats(object):
# stats.update() automatically parses the 'objective' and 'top1e' from
# the "output" dict and stores this into the db
stats.update(output)
stats.print() # prints the averages over given epoch
# prints the metric averages over given epoch
std_out = stats.get_status_string()
logger.info(str_out)
# stores the training plots into '/tmp/epoch_stats.pdf'
# and plots into a visdom server running at localhost (if running)
stats.plot_stats(plot_file='/tmp/epoch_stats.pdf')
Expand All @@ -101,7 +106,6 @@ class Stats(object):
def __init__(
self,
log_vars,
verbose=False,
epoch=-1,
visdom_env="main",
do_plot=True,
Expand All @@ -110,7 +114,6 @@ def __init__(
visdom_port=8097,
):

self.verbose = verbose
self.log_vars = log_vars
self.visdom_env = visdom_env
self.visdom_server = visdom_server
Expand Down Expand Up @@ -156,32 +159,29 @@ def __exit__(self, type, value, traceback):
iserr = type is not None and issubclass(type, Exception)
iserr = iserr or (type is KeyboardInterrupt)
if iserr:
print("error inside 'with' block")
logger.error("error inside 'with' block")
return
if self.do_plot:
self.plot_stats(self.visdom_env)

def reset(self): # to be called after each epoch
stat_sets = list(self.stats.keys())
if self.verbose:
print("stats: epoch %d - reset" % self.epoch)
logger.debug(f"stats: epoch {self.epoch} - reset")
self.it = {k: -1 for k in stat_sets}
for stat_set in stat_sets:
for stat in self.stats[stat_set]:
self.stats[stat_set][stat].reset()

def hard_reset(self, epoch=-1): # to be called during object __init__
self.epoch = epoch
if self.verbose:
print("stats: epoch %d - hard reset" % self.epoch)
logger.debug(f"stats: epoch {self.epoch} - hard reset")
self.stats = {}

# reset
self.reset()

def new_epoch(self):
if self.verbose:
print("stats: new epoch %d" % (self.epoch + 1))
logger.debug(f"stats: new epoch {(self.epoch + 1)}")
self.epoch += 1
self.reset() # zero the stats + increase epoch counter

Expand All @@ -193,18 +193,17 @@ def gather_value(self, val):
val = float(val.sum())
return val

def add_log_vars(self, added_log_vars, verbose=True):
def add_log_vars(self, added_log_vars):
for add_log_var in added_log_vars:
if add_log_var not in self.stats:
if verbose:
print(f"Adding {add_log_var}")
logger.debug(f"Adding {add_log_var}")
self.log_vars.append(add_log_var)

def update(self, preds, time_start=None, freeze_iter=False, stat_set="train"):

if self.epoch == -1: # uninitialized
print(
"warning: epoch==-1 means uninitialized stats structure -> new_epoch() called"
logger.warning(
"epoch==-1 means uninitialized stats structure -> new_epoch() called"
)
self.new_epoch()

Expand Down Expand Up @@ -284,6 +283,12 @@ def print(
skip_nan=False,
stat_format=lambda s: s.replace("loss_", "").replace("prev_stage_", "ps_"),
):
"""
stats.print() is deprecated. Please use get_status_string() instead.
example:
std_out = stats.get_status_string()
logger.info(str_out)
"""

epoch = self.epoch
stats = self.stats
Expand Down Expand Up @@ -311,8 +316,30 @@ def print(
if get_str:
return str_out
else:
warnings.warn(
"get_str=False is deprecated."
"Please enable this flag to get receive the output string.",
DeprecationWarning,
)
print(str_out)

def get_status_string(
self,
max_it=None,
stat_set="train",
vars_print=None,
skip_nan=False,
stat_format=lambda s: s.replace("loss_", "").replace("prev_stage_", "ps_"),
):
return self.print(
max_it=max_it,
stat_set=stat_set,
vars_print=vars_print,
get_str=True,
skip_nan=skip_nan,
stat_format=stat_format,
)

def plot_stats(
self, visdom_env=None, plot_file=None, visdom_server=None, visdom_port=None
):
Expand All @@ -329,16 +356,15 @@ def plot_stats(

stat_sets = list(self.stats.keys())

print(
"printing charts to visdom env '%s' (%s:%d)"
% (visdom_env, visdom_server, visdom_port)
logger.debug(
f"printing charts to visdom env '{visdom_env}' ({visdom_server}:{visdom_port})"
)

novisdom = False

viz = get_visdom_connection(server=visdom_server, port=visdom_port)
if viz is None or not viz.check_connection():
print("no visdom server! -> skipping visdom plots")
logger.info("no visdom server! -> skipping visdom plots")
novisdom = True

lines = []
Expand Down Expand Up @@ -385,7 +411,7 @@ def plot_stats(
)

if plot_file:
print("exporting stats to %s" % plot_file)
logger.info(f"plotting stats to {plot_file}")
ncol = 3
nrow = int(np.ceil(float(len(lines)) / ncol))
matplotlib.rcParams.update({"font.size": 5})
Expand Down Expand Up @@ -423,15 +449,15 @@ def plot_stats(
except PermissionError:
warnings.warn("Cant dump stats due to insufficient permissions!")

def synchronize_logged_vars(self, log_vars, default_val=float("NaN"), verbose=True):
def synchronize_logged_vars(self, log_vars, default_val=float("NaN")):

stat_sets = list(self.stats.keys())

# remove the additional log_vars
for stat_set in stat_sets:
for stat in self.stats[stat_set].keys():
if stat not in log_vars:
print("additional stat %s:%s -> removing" % (stat_set, stat))
logger.warning(f"additional stat {stat_set}:{stat} -> removing")

self.stats[stat_set] = {
stat: v for stat, v in self.stats[stat_set].items() if stat in log_vars
Expand All @@ -442,21 +468,19 @@ def synchronize_logged_vars(self, log_vars, default_val=float("NaN"), verbose=Tr
for stat_set in stat_sets:
for stat in log_vars:
if stat not in self.stats[stat_set]:
if verbose:
print(
"missing stat %s:%s -> filling with default values (%1.2f)"
% (stat_set, stat, default_val)
)
logger.info(
"missing stat %s:%s -> filling with default values (%1.2f)"
% (stat_set, stat, default_val)
)
elif len(self.stats[stat_set][stat].history) != self.epoch + 1:
h = self.stats[stat_set][stat].history
if len(h) == 0: # just never updated stat ... skip
continue
else:
if verbose:
print(
"incomplete stat %s:%s -> reseting with default values (%1.2f)"
% (stat_set, stat, default_val)
)
logger.info(
"incomplete stat %s:%s -> reseting with default values (%1.2f)"
% (stat_set, stat, default_val)
)
else:
continue

Expand Down

0 comments on commit d08fe6d

Please sign in to comment.