
Gh 2120 improved tensorboard logging #2164

Merged
39 changes: 35 additions & 4 deletions flair/trainers/trainer.py
@@ -6,8 +6,8 @@
import datetime
import sys
import inspect
import os
import warnings
import os
import torch
from torch.optim.sgd import SGD
from torch.utils.data.dataset import ConcatDataset
@@ -46,6 +46,8 @@ def __init__(
optimizer: torch.optim.Optimizer = SGD,
epoch: int = 0,
use_tensorboard: bool = False,
tensorboard_log_dir=None,
metrics_for_tensorboard=[]
):
"""
Initialize a model trainer
@@ -54,12 +56,16 @@ def __init__(
:param optimizer: The optimizer to use (typically SGD or Adam)
:param epoch: The starting epoch (normally 0 but could be higher if you continue training a model)
:param use_tensorboard: If True, writes out tensorboard information
:param tensorboard_log_dir: Directory into which tensorboard log files will be written
:param metrics_for_tensorboard: List of tuples specifying which metrics (in addition to the main_score) should be plotted in tensorboard, for example [("macro avg", 'f1-score'), ("macro avg", 'precision')]
"""
self.model: flair.nn.Model = model
self.corpus: Corpus = corpus
self.optimizer: torch.optim.Optimizer = optimizer
self.epoch: int = epoch
self.use_tensorboard: bool = use_tensorboard
self.tensorboard_log_dir = tensorboard_log_dir
self.metrics_for_tensorboard = metrics_for_tensorboard

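As a usage sketch for the new constructor arguments (tagger and corpus are assumed to already exist; the log directory path and metric tuples below are illustrative values, not prescribed by this PR):

from flair.trainers import ModelTrainer

trainer = ModelTrainer(
    tagger,
    corpus,
    use_tensorboard=True,
    # write tensorboard event files to a custom directory instead of the default ./runs
    tensorboard_log_dir="runs/flair_experiment",
    # plot these metrics in addition to the main_score
    metrics_for_tensorboard=[("macro avg", "f1-score"), ("macro avg", "precision")],
)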
def initialize_best_dev_score(self, log_dev):
"""
@@ -179,8 +185,9 @@ def train(
eval_on_train_fraction=0.0,
eval_on_train_shuffle=False,
save_model_each_k_epochs: int = 0,
save_best_checkpoints=False,
classification_main_metric=("micro avg", 'f1-score'),
tensorboard_comment='',
save_best_checkpoints=False,
**kwargs,
) -> dict:
"""
@@ -217,8 +224,9 @@
:param save_model_each_k_epochs: Every k epochs, a model state will be written out. If set to 5, a model will
be saved every 5 epochs. Default is 0, which means no model saving.
:param save_model_epoch_step: Every save_model_epoch_step'th epoch, the model trained so far will be saved
:param save_best_checkpoints: If True, the corresponding checkpoint is saved in addition to the best model
:param classification_main_metric: Metric to use for best model tracking and learning rate scheduling (if dev data is available, otherwise loss is used); currently only applicable to text_classification_model
:param tensorboard_comment: Comment to use for tensorboard logging
:param save_best_checkpoints: If True, the corresponding checkpoint is saved in addition to the best model
:param kwargs: Other arguments for the Optimizer
:return:
"""
@@ -231,7 +239,12 @@ def train(
if self.use_tensorboard:
try:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

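# create the requested log directory if it does not exist yet
# (note: os.mkdir creates only the final path component, not missing parent directories)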
if self.tensorboard_log_dir is not None and not os.path.exists(self.tensorboard_log_dir):
os.mkdir(self.tensorboard_log_dir)
writer = SummaryWriter(log_dir=self.tensorboard_log_dir, comment=tensorboard_comment)
log.info(f"tensorboard logging path is {self.tensorboard_log_dir}")

except:
log_line(log)
log.warning(
@@ -432,6 +445,8 @@ def train(
)

previous_learning_rate = learning_rate
if self.use_tensorboard:
writer.add_scalar("learning_rate", learning_rate, self.epoch)

# stop training if learning rate becomes too small
if (not isinstance(lr_scheduler, OneCycleLR)) and learning_rate < min_learning_rate:
@@ -567,6 +582,12 @@ def train(
log.info(
f"TRAIN_SPLIT : loss {train_part_loss} - score {round(train_part_eval_result.main_score, 4)}"
)
if self.use_tensorboard:
for (metric_class_avg_type, metric_type) in self.metrics_for_tensorboard:
writer.add_scalar(
f"train_{metric_class_avg_type}_{metric_type}", train_part_eval_result.classification_report[metric_class_avg_type][metric_type], self.epoch
)

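For reference, classification_report is consumed here as a nested dict keyed first by the class or average name and then by the metric name; a sketch with invented numbers shows how the loop above derives its scalar tags:

# illustrative only: keys mirror a typical report, values are made up
report = {
    "macro avg": {"precision": 0.91, "recall": 0.88, "f1-score": 0.89},
    "micro avg": {"precision": 0.93, "recall": 0.93, "f1-score": 0.93},
}
# with metrics_for_tensorboard = [("macro avg", "f1-score")], the loop logs a scalar
# tagged "train_macro avg_f1-score" with the value report["macro avg"]["f1-score"]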

if log_dev:
dev_eval_result, dev_loss = self.model.evaluate(
@@ -596,6 +617,11 @@ def train(
writer.add_scalar(
"dev_score", dev_eval_result.main_score, self.epoch
)
for (metric_class_avg_type, metric_type) in self.metrics_for_tensorboard:
writer.add_scalar(
f"dev_{metric_class_avg_type}_{metric_type}",
dev_eval_result.classification_report[metric_class_avg_type][metric_type], self.epoch
)

if log_test:
test_eval_result, test_loss = self.model.evaluate(
@@ -619,6 +645,11 @@ def train(
writer.add_scalar(
"test_score", test_eval_result.main_score, self.epoch
)
for (metric_class_avg_type, metric_type) in self.metrics_for_tensorboard:
writer.add_scalar(
f"test_{metric_class_avg_type}_{metric_type}",
test_eval_result.classification_report[metric_class_avg_type][metric_type], self.epoch
)

# determine learning rate annealing through scheduler. Use auxiliary metric for AnnealOnPlateau
if log_dev and isinstance(lr_scheduler, AnnealOnPlateau):