Allow trainer to work with multiple learning rates #2641

Merged
7 commits merged on Feb 25, 2022
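This change lets the trainer keep one learning rate per optimizer parameter group instead of a single scalar. A minimal usage sketch, assuming an existing flair model (`tagger`) and `corpus` (both hypothetical here) and the `ModelTrainer` API as of this PR:

    # Sketch: pass an optimizer *instance* whose parameter groups carry their
    # own learning rates; `tagger` and `corpus` are assumed to already exist.
    from torch.optim import SGD
    from flair.trainers import ModelTrainer

    embedding_params = [p for n, p in tagger.named_parameters() if "embedding" in n]
    other_params = [p for n, p in tagger.named_parameters() if "embedding" not in n]

    optimizer = SGD(
        [
            {"params": embedding_params, "lr": 0.05},  # group 0: smaller lr
            {"params": other_params},                  # group 1: uses the default lr below
        ],
        lr=0.1,
    )

    trainer = ModelTrainer(tagger, corpus)
    trainer.train(
        "resources/taggers/multi-lr-example",
        optimizer=optimizer,                # instance, not class: per-group lrs are kept
        min_learning_rate=[0.0005, 0.001],  # one stopping threshold per group
    )

With a scalar `min_learning_rate`, the same threshold is applied to every group (see the normalization further down in the diff).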
108 changes: 70 additions & 38 deletions flair/trainers/trainer.py
@@ -8,7 +8,7 @@
import warnings
from inspect import signature
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Type, Union, cast
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast

import torch
from torch.optim.sgd import SGD
@@ -87,7 +87,7 @@ def train(
scheduler=AnnealOnPlateau,
anneal_factor: float = 0.5,
patience: int = 3,
min_learning_rate: float = 0.0001,
min_learning_rate: Union[float, List[float]] = 0.0001,
initial_extra_patience: int = 0,
optimizer: Union[torch.optim.Optimizer, Type[torch.optim.Optimizer]] = SGD,
cycle_momentum: bool = False,
@@ -137,7 +137,7 @@ def train(
:param anneal_factor: The factor by which the learning rate is annealed
:param patience: Patience is the number of epochs with no improvement the Trainer waits # noqa: E501
until annealing the learning rate
:param min_learning_rate: If the learning rate falls below this threshold, training terminates # noqa: E501
:param min_learning_rate: If the learning rate (in the multi-learning-rate case: all learning rates) falls below this threshold, training terminates # noqa: E501
:param warmup_fraction: Fraction of warmup steps if the scheduler is LinearSchedulerWithWarmup # noqa: E501
:param train_with_dev: If True, the data from dev split is added to the training data # noqa: E501
:param train_with_test: If True, the data from test split is added to the training data # noqa: E501
@@ -226,10 +226,19 @@ def train(
eval_batch_size = mini_batch_size
if mini_batch_chunk_size is None:
mini_batch_chunk_size = mini_batch_size
if learning_rate < min_learning_rate:
min_learning_rate = learning_rate / 10

initial_learning_rate = learning_rate
if inspect.isclass(optimizer):
# if optimizer is class, trainer will create a single parameter group
initial_learning_rate = [learning_rate]
else:
initial_learning_rate = [group["lr"] for group in optimizer.param_groups]

if not isinstance(min_learning_rate, list):
min_learning_rate = [min_learning_rate] * len(initial_learning_rate)

for i, lr in enumerate(initial_learning_rate):
if lr < min_learning_rate[i]:
min_learning_rate[i] = lr / 10

base_path = Path(base_path)
base_path.mkdir(exist_ok=True, parents=True)
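The block above is the core of the change: one initial learning rate is recorded per parameter group, a scalar min_learning_rate is broadcast to a list of the same length, and each threshold is lowered if it would otherwise exceed its group's starting rate. A standalone illustration of that normalization (not the trainer code itself):

    import torch

    # Toy optimizer with two parameter groups at different learning rates.
    p0, p1 = torch.nn.Parameter(torch.zeros(1)), torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([{"params": [p0], "lr": 0.1},
                                 {"params": [p1], "lr": 0.01}])

    initial_learning_rate = [group["lr"] for group in optimizer.param_groups]  # [0.1, 0.01]

    min_learning_rate = 0.0001  # scalar as passed by the caller
    if not isinstance(min_learning_rate, list):
        min_learning_rate = [min_learning_rate] * len(initial_learning_rate)

    for i, lr in enumerate(initial_learning_rate):
        if lr < min_learning_rate[i]:
            min_learning_rate[i] = lr / 10

    print(initial_learning_rate, min_learning_rate)  # [0.1, 0.01] [0.0001, 0.0001]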
@@ -260,7 +269,7 @@ def train(
weight_extractor = WeightExtractor(base_path)

# if optimizer class is passed, instantiate:
if not isinstance(optimizer, torch.optim.Optimizer):
if inspect.isclass(optimizer):
kwargs["lr"] = learning_rate
optimizer = optimizer(self.model.parameters(), **kwargs)

@@ -269,6 +278,9 @@

optimizer = torchcontrib.optim.SWA(optimizer, swa_start=10, swa_freq=5, swa_lr=learning_rate)

# from here on, use list of learning rates
current_learning_rate: List = [group["lr"] for group in optimizer.param_groups]

if use_amp:
self.model, optimizer = amp.initialize(self.model, optimizer, opt_level=amp_opt_level)

@@ -291,7 +303,7 @@ def train(
if scheduler == OneCycleLR:
scheduler = OneCycleLR(
optimizer,
max_lr=learning_rate,
max_lr=current_learning_rate,
steps_per_epoch=dataset_size // mini_batch_size + 1,
epochs=max_epochs - epoch,
# if we load a checkpoint, we have already trained for epoch
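Handing the scheduler the whole list works because torch's OneCycleLR accepts max_lr either as a single float or as one value per parameter group. A short sketch, reusing an `optimizer` with two groups as in the illustration above:

    from torch.optim.lr_scheduler import OneCycleLR

    scheduler = OneCycleLR(
        optimizer,               # assumed to have two parameter groups
        max_lr=[0.1, 0.01],      # one peak learning rate per group
        steps_per_epoch=100,
        epochs=10,
    )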
@@ -367,13 +379,15 @@ def train(
else:
log_handler = None

lr_info = ",".join([f"{lr:.6f}" for lr in current_learning_rate])

log_line(log)
log.info(f'Model: "{self.model}"')
log_line(log)
log.info(f'Corpus: "{self.corpus}"')
log_line(log)
log.info("Parameters:")
log.info(f' - learning_rate: "{learning_rate}"')
log.info(f' - learning_rate: "{lr_info}"')
log.info(f' - mini_batch_size: "{mini_batch_size}"')
log.info(f' - patience: "{patience}"')
log.info(f' - anneal_factor: "{anneal_factor}"')
@@ -388,11 +402,9 @@
log_line(log)
log.info(f"Embeddings storage mode: {embeddings_storage_mode}")

previous_learning_rate = learning_rate
momentum = 0
for group in optimizer.param_groups:
if "momentum" in group:
momentum = group["momentum"]
previous_learning_rate = current_learning_rate

momentum = [group["momentum"] if "momentum" in group else 0 for group in optimizer.param_groups]

for epoch in range(epoch + 1, max_epochs + 1):
log_line(log)
@@ -410,16 +422,17 @@
train_part = torch.utils.data.dataset.Subset(self.corpus.train, train_part_indices)

# get new learning rate
for group in optimizer.param_groups:
learning_rate = group["lr"]
current_learning_rate = [group["lr"] for group in optimizer.param_groups]

lr_changed = any([lr != prev_lr for lr, prev_lr in zip(current_learning_rate, previous_learning_rate)])

if learning_rate != previous_learning_rate and batch_growth_annealing:
if lr_changed and batch_growth_annealing:
mini_batch_size *= 2

# reload last best model if annealing with restarts is enabled
if (
(anneal_with_restarts or anneal_with_prestarts)
and learning_rate != previous_learning_rate
and lr_changed
and os.path.exists(base_path / "best-model.pt")
):
if anneal_with_restarts:
@@ -429,15 +442,18 @@
log.info("resetting to pre-best model")
self.model.load_state_dict(self.model.load(base_path / "pre-best-model.pt").state_dict())

previous_learning_rate = learning_rate
previous_learning_rate = current_learning_rate
if use_tensorboard:
writer.add_scalar("learning_rate", learning_rate, epoch)
if len(current_learning_rate) == 1:
writer.add_scalar("learning_rate", current_learning_rate[0], epoch)
else:
for i, lr in enumerate(current_learning_rate):
writer.add_scalar(f"learning_rate_{i}", lr, epoch)

all_lrs_too_small = all([lr < min_lr for lr, min_lr in zip(current_learning_rate, min_learning_rate)])

# stop training if learning rate becomes too small
if (
not isinstance(scheduler, (OneCycleLR, LinearSchedulerWithWarmup))
and learning_rate < min_learning_rate
):
if not isinstance(scheduler, (OneCycleLR, LinearSchedulerWithWarmup)) and all_lrs_too_small:
log_line(log)
log.info("learning rate too small - quitting training!")
log_line(log)
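With several groups, "the learning rate changed" now means that any group differs from the previous epoch (used above for batch growth and annealing restarts), while early stopping waits until every group has dropped below its own minimum. A standalone illustration of both checks:

    previous_learning_rate = [0.1, 0.01]
    current_learning_rate = [0.05, 0.01]    # group 0 was annealed, group 1 untouched
    min_learning_rate = [0.0001, 0.0001]

    lr_changed = any(
        lr != prev_lr for lr, prev_lr in zip(current_learning_rate, previous_learning_rate)
    )                                       # True: one group changed

    all_lrs_too_small = all(
        lr < min_lr for lr, min_lr in zip(current_learning_rate, min_learning_rate)
    )                                       # False: both groups are still above threshold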
@@ -509,25 +525,31 @@ def train(
if isinstance(scheduler, (OneCycleLR, LinearSchedulerWithWarmup)):
scheduler.step()
# get new learning rate
for group in optimizer.param_groups:
learning_rate = group["lr"]
if "momentum" in group:
momentum = group["momentum"]
if "betas" in group:
momentum, _ = group["betas"]
current_learning_rate = [group["lr"] for group in optimizer.param_groups]

momentum = [
group["betas"][0] if "betas" in group else group.get("momentum", 0)
for group in optimizer.param_groups
]

seen_batches += 1

batch_time += time.time() - start_time
if seen_batches % modulo == 0:
momentum_info = f" - momentum: {momentum:.4f}" if cycle_momentum else ""
momentum_info = ""
if cycle_momentum:
momentum_info = " - momentum:" + ",".join([f"{m:.4f}" for m in momentum])

lr_info = ",".join([f"{lr:.6f}" for lr in current_learning_rate])

intermittent_loss = train_loss / average_over if average_over > 0 else train_loss / seen_batches

log.info(
f"epoch {epoch} - iter {seen_batches}/"
f"{total_number_of_batches} - loss "
f"{intermittent_loss:.8f} - samples/sec:"
f" {mini_batch_size * modulo / batch_time:.2f}"
f" - lr: {learning_rate:.6f}{momentum_info}"
f" - lr: {lr_info}{momentum_info}"
)
batch_time = 0.0
iteration = epoch * total_number_of_batches + batch_no
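For Adam-style optimizers there is no "momentum" entry in the parameter group, so the first beta of each group is reported instead, and both learning rates and momenta are now logged as comma-joined lists. A small illustration of the values that end up in the progress line:

    import torch

    adam = torch.optim.Adam(
        [torch.nn.Parameter(torch.zeros(1))], lr=1e-3, betas=(0.9, 0.999)
    )

    momentum = [
        group["betas"][0] if "betas" in group else group.get("momentum", 0)
        for group in adam.param_groups
    ]
    current_learning_rate = [group["lr"] for group in adam.param_groups]

    lr_info = ",".join([f"{lr:.6f}" for lr in current_learning_rate])
    momentum_info = " - momentum:" + ",".join([f"{m:.4f}" for m in momentum])
    print(f" - lr: {lr_info}{momentum_info}")  # " - lr: 0.001000 - momentum:0.9000"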
@@ -545,7 +567,7 @@ def train(
self.model.save(base_path / model_name, checkpoint=save_optimizer_state)

log_line(log)
log.info(f"EPOCH {epoch} done: loss {train_loss:.4f}" f" - lr {learning_rate:.7f}")
log.info(f"EPOCH {epoch} done: loss {train_loss:.4f} - lr {lr_info}")

if use_tensorboard:
writer.add_scalar("train_loss", train_loss, epoch)
@@ -708,11 +730,19 @@ def train(
bad_epochs = scheduler.num_bad_epochs
except AttributeError:
bad_epochs = 0
for group in optimizer.param_groups:
new_learning_rate = group["lr"]
if new_learning_rate != previous_learning_rate:

new_learning_rate = [group["lr"] for group in optimizer.param_groups]

if any([new_lr != prev_lr for new_lr, prev_lr in zip(new_learning_rate, previous_learning_rate)]):
bad_epochs = patience + 1
if previous_learning_rate == initial_learning_rate:

# lr unchanged
if all(
[
prev_lr == initial_lr
for prev_lr, initial_lr in zip(previous_learning_rate, initial_learning_rate)
]
):
bad_epochs += initial_extra_patience

# log bad epochs
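The patience bookkeeping follows the same pattern: a change in any group's learning rate marks the epoch as bad, and the extra initial patience is only granted while all groups were still at their initial rates. A standalone sketch of that logic:

    patience, initial_extra_patience = 3, 2
    initial_learning_rate = [0.1, 0.01]
    previous_learning_rate = [0.1, 0.01]
    new_learning_rate = [0.05, 0.01]        # group 0 was just annealed

    bad_epochs = 0
    if any(n != p for n, p in zip(new_learning_rate, previous_learning_rate)):
        bad_epochs = patience + 1
        if all(p == i for p, i in zip(previous_learning_rate, initial_learning_rate)):
            bad_epochs += initial_extra_patience

    print(bad_epochs)  # 6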
@@ -741,10 +771,12 @@ def train(
if log_test:
f.write("\tTEST_LOSS\tTEST_" + "\tTEST_".join(test_eval_result.log_header.split("\t")))

lr_info = ",".join([f"{lr:.4f}" for lr in current_learning_rate])

f.write(
f"\n{epoch}\t{datetime.datetime.now():%H:%M:%S}"
f"\t{bad_epochs}"
f"\t{learning_rate:.4f}\t{train_loss}"
f"\t{lr_info}\t{train_loss}"
)
f.write(result_line)
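The per-epoch loss.tsv row changes accordingly: its LEARNING_RATE column now holds comma-joined per-group values rather than a single number. A tiny illustration of the resulting field:

    current_learning_rate = [0.1, 0.01]
    lr_info = ",".join([f"{lr:.4f}" for lr in current_learning_rate])
    print(lr_info)  # 0.1000,0.0100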
