fix(helper): add real progressbar for training (#136)
* fix(helper): add real progressbar for training

* fix(helper): add real progressbar for training

* fix(helper): add real progressbar for training
hanxiao committed Oct 16, 2021
1 parent 6b8eca8 commit 5a25a72
Showing 4 changed files with 23 additions and 10 deletions.
6 changes: 2 additions & 4 deletions finetuner/tuner/base.py
@@ -1,4 +1,5 @@
import abc
import warnings
from typing import (
Optional,
Union,
@@ -7,8 +8,6 @@
Dict,
)

from jina.logging.logger import JinaLogger

from ..helper import AnyDNN, AnyDataLoader, AnyOptimizer, DocumentArrayLike


@@ -46,7 +45,6 @@ def __init__(
):
self._embed_model = embed_model
self._head_layer = head_layer
self.logger = JinaLogger(self.__class__.__name__)

def _get_optimizer_kwargs(self, optimizer: str, custom_kwargs: Optional[Dict]):
"""Merges user-provided optimizer kwargs with default ones."""
@@ -74,7 +72,7 @@ def _get_optimizer_kwargs(self, optimizer: str, custom_kwargs: Optional[Dict]):
custom_kwargs = custom_kwargs or {}
extra_args = set(custom_kwargs.keys()) - set(opt_kwargs.keys())
if extra_args:
self.logger.warning(
warnings.warn(
f'The following arguments are not valid for the optimizer {optimizer}:'
f' {extra_args}'
)
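For context, here is a minimal, self-contained sketch of the pattern this file moves to: invalid optimizer kwargs are now reported through the standard library's warnings module instead of a per-instance JinaLogger. The defaults table and function name below are hypothetical stand-ins for illustration, not the actual finetuner code.

import warnings
from typing import Dict, Optional

# Hypothetical per-optimizer defaults, standing in for finetuner's real table.
_OPTIMIZER_DEFAULTS = {
    'adam': {'learning_rate': 1e-3, 'beta_1': 0.9, 'beta_2': 0.999},
}


def get_optimizer_kwargs(optimizer: str, custom_kwargs: Optional[Dict] = None) -> Dict:
    """Merge user-provided optimizer kwargs with defaults, warning on unknown keys."""
    opt_kwargs = dict(_OPTIMIZER_DEFAULTS[optimizer])
    custom_kwargs = custom_kwargs or {}
    extra_args = set(custom_kwargs) - set(opt_kwargs)
    if extra_args:
        # warnings.warn takes over from the removed self.logger.warning call
        warnings.warn(
            f'The following arguments are not valid for the optimizer {optimizer}:'
            f' {extra_args}'
        )
    opt_kwargs.update({k: v for k, v in custom_kwargs.items() if k not in extra_args})
    return opt_kwargs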
9 changes: 7 additions & 2 deletions finetuner/tuner/keras/__init__.py
@@ -83,10 +83,14 @@ def _train(self, data, optimizer, description: str):

log_generator = LogGenerator('T', losses, metrics)

train_data_len = 0
with ProgressBar(
description, message_on_done=log_generator, final_line_feed=False
description,
message_on_done=log_generator,
final_line_feed=False,
total_length=train_data_len,
) as p:

train_data_len = 0
for inputs, label in data:
with tf.GradientTape() as tape:
outputs = self.wrapped_model(inputs, training=True)
@@ -102,6 +106,7 @@ def _train(self, data, optimizer, description: str):
metrics.append(metric.numpy())

p.update(message=log_generator())
train_data_len += 1

return losses, metrics

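The Keras tuner above and the Paddle and PyTorch tuners below all apply the same change: a train_data_len counter is initialised before the ProgressBar, passed as its total_length, and incremented once per batch inside the loop. A framework-agnostic sketch of that loop structure follows; progress_bar_cls and train_step are hypothetical stand-ins for jina's ProgressBar and the framework-specific forward/backward step, not the exact finetuner code.

def train_one_epoch(data, train_step, progress_bar_cls, log_generator, description):
    """Run one epoch, counting batches while driving a progress bar."""
    losses, metrics = [], []
    train_data_len = 0  # batch counter, handed to the bar as its total length
    with progress_bar_cls(
        description,
        message_on_done=log_generator,
        final_line_feed=False,
        total_length=train_data_len,
    ) as p:
        for inputs, label in data:
            loss, metric = train_step(inputs, label)  # forward + backward + optimizer step
            losses.append(loss)
            metrics.append(metric)
            p.update(message=log_generator())
            train_data_len += 1  # count batches as the epoch progresses
    return losses, metrics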
9 changes: 7 additions & 2 deletions finetuner/tuner/paddle/__init__.py
@@ -99,10 +99,14 @@ def _train(self, data, optimizer: Optimizer, description: str):
metrics = []

log_generator = LogGenerator('T', losses, metrics)

train_data_len = 0
with ProgressBar(
description, message_on_done=log_generator, final_line_feed=False
description,
message_on_done=log_generator,
final_line_feed=False,
total_length=train_data_len,
) as p:
train_data_len = 0
for inputs, label in data:
# forward step
outputs = self.wrapped_model(*inputs)
@@ -118,6 +122,7 @@ def _train(self, data, optimizer: Optimizer, description: str):
metrics.append(metric.numpy())

p.update(message=log_generator())
train_data_len += 1
return losses, metrics

def fit(
9 changes: 7 additions & 2 deletions finetuner/tuner/pytorch/__init__.py
@@ -106,10 +106,14 @@ def _train(self, data, optimizer: Optimizer, description: str):
metrics = []

log_generator = LogGenerator('T', losses, metrics)

train_data_len = 0
with ProgressBar(
description, message_on_done=log_generator, final_line_feed=False
description,
message_on_done=log_generator,
final_line_feed=False,
total_length=train_data_len,
) as p:
train_data_len = 0
for inputs, label in data:
# forward step
inputs = [inpt.to(self.device) for inpt in inputs]
@@ -128,6 +132,7 @@ def _train(self, data, optimizer: Optimizer, description: str):
metrics.append(metric.cpu().numpy())

p.update(message=log_generator())
train_data_len += 1
return losses, metrics

def fit(
