Commit 9fd27b6

Merge branch 'master' into max_sequence_length

arnavgarg1 committed Aug 26, 2023
2 parents cdf6f22 + feec8a6

Showing 20 changed files with 304 additions and 321 deletions.
1 change: 1 addition & 0 deletions ludwig/constants.py
@@ -81,6 +81,7 @@
 RECALL = "recall"
 SPECIFICITY = "specificity"
 PREDICTIONS = "predictions"
+RESPONSE = "response"
 TOP_K = "top_k"
 TOP_K_PREDICTIONS = "top_k_predictions"
 PROBABILITY = "probability"
18 changes: 14 additions & 4 deletions ludwig/features/base_feature.py
@@ -42,7 +42,6 @@
 from ludwig.types import FeatureConfigDict, FeatureMetadataDict, PreprocessingConfigDict, TrainingSetMetadataDict
 from ludwig.utils import output_feature_utils
 from ludwig.utils.calibration import CalibrationModule
-from ludwig.utils.metric_utils import get_scalar_from_ludwig_metric
 from ludwig.utils.torch_utils import LudwigModule
 from ludwig.utils.types import DataFrame, TorchscriptPreprocessingInput

@@ -380,9 +379,20 @@ def get_metrics(self):
         metric_vals = {}
         for metric_name, metric_fn in self._metric_functions.items():
             try:
-                metric_vals[metric_name] = get_scalar_from_ludwig_metric(metric_fn)
-            except Exception:
-                logger.exception(f"Caught exception computing metric: {metric_name}.")
+                computed_metric = metric_fn.compute()
+            except Exception as e:
+                logger.exception(f"Caught exception computing metric: {metric_name} with error: {e}.")
+                continue
+
+            # Metrics from torchmetrics can be a straightforward tensor.
+            if isinstance(computed_metric, Tensor):
+                metric_vals[metric_name] = computed_metric.detach().cpu().numpy().item()
+            else:
+                # Metrics from torchmetrics can also be a dict of tensors.
+                # For example, ROUGE is returned as a dictionary of tensors.
+                # Unpack each entry into its own scalar metric.
+                for sub_metric_name, metric in computed_metric.items():
+                    metric_vals[sub_metric_name] = metric.detach().cpu().numpy().item()
         return metric_vals

     def reset_metrics(self):
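The unpacking logic above exists because some torchmetrics objects do not return a plain scalar tensor. A minimal standalone sketch of the case get_metrics() now handles (illustrative, not part of the commit; assumes only that torchmetrics is installed):

from torchmetrics.text.rouge import ROUGEScore

# ROUGE returns a dict of tensors, so each entry is flattened into its own
# scalar metric keyed by the sub-metric name, mirroring get_metrics() above.
rouge = ROUGEScore()
rouge.update(["the cat sat on the mat"], ["a cat sat on the mat"])
computed_metric = rouge.compute()  # {"rouge1_fmeasure": tensor(...), ...}

metric_vals = {}
for sub_metric_name, metric in computed_metric.items():
    metric_vals[sub_metric_name] = metric.detach().cpu().numpy().item()
print(metric_vals)  # plain Python floats, one per ROUGE variant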
53 changes: 52 additions & 1 deletion ludwig/features/text_feature.py
@@ -15,12 +15,15 @@
 # ==============================================================================
 import logging
 from functools import partial
-from typing import Dict, Union
+from typing import Dict, List, Optional, Tuple, Union

 import torch
+from torch import Tensor
+from transformers import PreTrainedTokenizer

 from ludwig.constants import (
     COLUMN,
+    IGNORE_INDEX_TOKEN_ID,
     LAST_PREDICTIONS,
     LENGTHS,
     NAME,
@@ -29,6 +32,7 @@
     PROBABILITIES,
     PROBABILITY,
     PROC_COLUMN,
+    RESPONSE,
     TEXT,
 )
 from ludwig.features.base_feature import BaseFeatureMixin, OutputFeature
@@ -39,6 +43,7 @@
     SequenceInputFeature,
     SequenceOutputFeature,
 )
+from ludwig.modules.metric_registry import get_metric_tensor_input
 from ludwig.schema.features.text_feature import TextInputFeatureConfig, TextOutputFeatureConfig
 from ludwig.types import FeatureMetadataDict, PreprocessingConfigDict, TrainingSetMetadataDict
 from ludwig.utils.math_utils import softmax
@@ -54,6 +59,23 @@
 logger = logging.getLogger(__name__)


+def get_decoded_targets_and_predictions(
+    targets: Tensor,
+    predictions: Dict[str, Tensor],
+    tokenizer: PreTrainedTokenizer,
+) -> Tuple[List[str], List[str]]:
+    """Returns the decoded targets and predictions, accounting for IGNORE_INDEX_TOKEN_ID."""
+    sanitized_targets = torch.where(targets != IGNORE_INDEX_TOKEN_ID, targets, tokenizer.pad_token_id)
+    sanitized_predictions = torch.where(
+        predictions[PREDICTIONS] != IGNORE_INDEX_TOKEN_ID,
+        predictions[PREDICTIONS],
+        tokenizer.pad_token_id,
+    )
+    decoded_targets = tokenizer.batch_decode(sanitized_targets, skip_special_tokens=True)
+    decoded_predictions = tokenizer.batch_decode(sanitized_predictions, skip_special_tokens=True)
+    return decoded_targets, decoded_predictions
+
+
 class TextFeatureMixin(BaseFeatureMixin):
     @staticmethod
     def type():
@@ -261,6 +283,35 @@ def get_output_dtype(cls):
     def output_shape(self) -> torch.Size:
         return torch.Size([self.decoder_obj.config.max_sequence_length])

+    def update_metrics(
+        self,
+        targets: Tensor,
+        predictions: Dict[str, Tensor],
+        tokenizer: Optional[PreTrainedTokenizer] = None,
+    ) -> None:
+        """Updates metrics with the given targets and predictions.
+
+        If a tokenizer is provided, as is the case with the LLM model type, targets and predictions are decoded
+        so that additional response-based metrics like BLEU and ROUGE can also be computed.
+
+        Args:
+            targets: Tensor with target values for this output feature.
+            predictions: Dict of tensors returned by predictions().
+        """
+        if tokenizer is not None:
+            decoded_targets, decoded_predictions = get_decoded_targets_and_predictions(targets, predictions, tokenizer)
+
+        for metric_name, metric_fn in self._metric_functions.items():
+            prediction_key = get_metric_tensor_input(metric_name)
+            if prediction_key == RESPONSE:
+                if tokenizer is not None:
+                    # RESPONSE metrics cannot be computed if decoded texts are not provided.
+                    # Decoded texts are only available with the LLM model type.
+                    if decoded_targets is not None and decoded_predictions is not None:
+                        metric_fn.update(decoded_predictions, decoded_targets)
+            else:
+                metric_fn = metric_fn.to(predictions[prediction_key].device)
+                metric_fn.update(predictions[prediction_key].detach(), targets)
+
     @staticmethod
     def update_config_with_metadata(feature_config, feature_metadata, *args, **kwargs):
         feature_config.decoder.vocab_size = feature_metadata["vocab_size"]
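For intuition about get_decoded_targets_and_predictions(), here is a standalone sketch of the masking-then-decoding step (not part of the commit; the GPT-2 tokenizer, the token ids, and the -100 value of IGNORE_INDEX_TOKEN_ID are assumptions for illustration):

import torch
from transformers import AutoTokenizer

IGNORE_INDEX_TOKEN_ID = -100  # assumed value of the Ludwig constant
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

# Target ids whose prompt positions were masked out with IGNORE_INDEX_TOKEN_ID.
targets = torch.tensor([[-100, -100, 15496, 11, 995]])
sanitized = torch.where(targets != IGNORE_INDEX_TOKEN_ID, targets, tokenizer.pad_token_id)
# With the assumed ids this prints ["Hello, world"]; the -100 positions decode
# to nothing because they were replaced by pad tokens, which are skipped.
print(tokenizer.batch_decode(sanitized, skip_special_tokens=True))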
15 changes: 9 additions & 6 deletions ludwig/models/llm.py
@@ -309,7 +309,7 @@ def build_outputs(
         """Builds and returns output feature."""
         # TODO: only single task currently
         if len(output_feature_configs) > 1:
-            raise ValueError("Only single task currently supported")
+            raise ValueError("The LLM model type only supports a single output feature.")

         output_feature_config = output_feature_configs[0]
         output_feature_config.input_size = input_size
@@ -485,9 +485,9 @@ def update_metrics(self, targets, predictions):
                 _targets, _predictions = realign_target_and_prediction_tensors_for_inference(
                     targets, predictions, of_name, self.tokenizer
                 )
-                of_obj.update_metrics(_targets[of_name], _predictions[of_name])
-                continue
-            of_obj.update_metrics(targets[of_name], predictions[of_name])
+                of_obj.update_metrics(_targets[of_name], _predictions[of_name], self.tokenizer)
+            else:
+                of_obj.update_metrics(targets[of_name], predictions[of_name])

         # HACK (Tim): get the device of the targets to transfer self.eval_loss_metric to the same device
         target_device = list(targets.values())[0].device
@@ -506,7 +506,10 @@ def update_metrics_finetune_llm(self, targets, predictions):
                 # to match the prediction length and depends on how much of the target tensor was included in the
                 # forward pass.
                 _targets = self._update_target_tensor_for_finetuning(_targets, _predictions, of_name)
-                of_obj.update_metrics(_targets[of_name], _predictions[of_name])
+                if isinstance(of_obj, TextOutputFeature):
+                    of_obj.update_metrics(_targets[of_name], _predictions[of_name], self.tokenizer)
+                else:
+                    of_obj.update_metrics(_targets[of_name], _predictions[of_name])
                 continue

             of_obj.update_metrics(_targets[of_name], _predictions[of_name])
@@ -644,7 +647,7 @@ def get_args(self):

     def _update_target_tensor_for_finetuning(
         self, targets: Dict[str, torch.Tensor], predictions: Dict[str, torch.Tensor], of_name: str
-    ):
+    ) -> Dict[str, torch.Tensor]:
         """Update target tensor for fine-tuning.

         This method removes left padding from target tensors, adds a pad token to the end of the target tensors,
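A toy illustration of why targets are realigned before update_metrics() is called: token-level metrics expect targets and predictions with the same sequence length. This is not Ludwig's actual realign_target_and_prediction_tensors_for_inference() logic, just the shape invariant it establishes:

import torch

PAD_TOKEN_ID = 0  # illustrative pad id

targets = torch.tensor([[7, 8, 9]])            # 3 target tokens
predictions = torch.tensor([[7, 8, 9, 4, 5]])  # model generated 5 tokens

# Right-pad the shorter target tensor up to the prediction length.
if targets.size(1) < predictions.size(1):
    pad = torch.full(
        (targets.size(0), predictions.size(1) - targets.size(1)),
        PAD_TOKEN_ID,
        dtype=targets.dtype,
    )
    targets = torch.cat([targets, pad], dim=1)

assert targets.shape == predictions.shape  # token-level metrics can now be updated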
71 changes: 12 additions & 59 deletions ludwig/modules/lr_scheduler.py
@@ -1,9 +1,9 @@
 import logging
 import math
-from typing import Any, Callable, Dict
+from typing import Any, Dict

 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LambdaLR, ReduceLROnPlateau, SequentialLR
+from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau

 from ludwig.constants import MINIMIZE, TRAINING, VALIDATION
 from ludwig.modules.metric_registry import get_metric_objective
@@ -166,41 +166,22 @@ def get_schedule_with_warmup(
     step_info: StepInfo,
 ) -> LambdaLR:
     """Creates a learning rate scheduler that updates each training step."""
-    schedulers = []
-
-    # Warmup scheduler
-    if step_info.num_warmup_steps > 0:
-        warmup_scheduler = LambdaLR(
-            optimizer,
-            lambda current_step: float(current_step) / float(max(1, step_info.num_warmup_steps)),
-            last_epoch=-1,
-        )
-        schedulers.append(warmup_scheduler)
-
-    # Decay scheduler
-    decay = config.decay
-    decay_scheduler = decay_registry[decay](config, optimizer, step_info)
-    schedulers.append(decay_scheduler)
-
-    if len(schedulers) == 1:
-        # Only one scheduler, no need to wrap in a SequentialLR
-        return schedulers[0]
-
-    # Return a SequentialLR that applies the warmup and decay schedulers in order
-    # with the warmup scheduler only applied for the first num_warmup_steps steps.
-    return SequentialLR(optimizer, schedulers=schedulers, milestones=[step_info.num_warmup_steps], last_epoch=-1)
+    decay_fn = decay_registry[config.decay]
+
+    def lr_lambda(current_step: int):
+        if current_step < step_info.num_warmup_steps:
+            return float(current_step) / float(max(1, step_info.num_warmup_steps))
+        return decay_fn(current_step, step_info.num_training_steps, step_info.num_warmup_steps, config)
+
+    return LambdaLR(optimizer, lr_lambda, last_epoch=-1)


 def no_decay(current_step: int, num_training_steps: int, num_warmup_steps: int, config: LRSchedulerConfig):
     return 1.0


 def linear_decay(current_step: int, num_training_steps: int, num_warmup_steps: int, config: LRSchedulerConfig):
-    return max(
-        0.0,
-        float(num_training_steps - num_warmup_steps - current_step)
-        / float(max(1, num_training_steps - num_warmup_steps)),
-    )
+    return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))


 def exponential_decay(current_step: int, num_training_steps: int, num_warmup_steps: int, config: LRSchedulerConfig):
@@ -213,36 +194,8 @@ def exponential_decay(current_step: int, num_training_steps: int, num_warmup_steps: int, config: LRSchedulerConfig):
     return math.pow(decay_rate, exponent)


-def wrap_decay_fn(decay_fn: Callable) -> Callable:
-    def init_fn(config: LRSchedulerConfig, optimizer: Optimizer, step_info: StepInfo) -> LambdaLR:
-        return LambdaLR(
-            optimizer,
-            lambda current_step: decay_fn(
-                current_step, step_info.num_training_steps, step_info.num_warmup_steps, config
-            ),
-            last_epoch=-1,
-        )
-
-    return init_fn
-
-
-def init_cosine_decay(
-    config: LRSchedulerConfig,
-    optimizer: Optimizer,
-    step_info: StepInfo,
-) -> CosineAnnealingWarmRestarts:
-    return CosineAnnealingWarmRestarts(
-        optimizer,
-        T_0=config.t_0 or step_info.steps_per_checkpoint,
-        T_mult=config.t_mult or 1,
-        eta_min=config.eta_min or 0,
-        last_epoch=-1,
-    )
-
-
 decay_registry = {
-    None: wrap_decay_fn(no_decay),
-    "linear": wrap_decay_fn(linear_decay),
-    "exponential": wrap_decay_fn(exponential_decay),
-    "cosine": init_cosine_decay,
+    None: no_decay,
+    "linear": linear_decay,
+    "exponential": exponential_decay,
 }
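A self-contained sketch of the pattern this refactor adopts: one LambdaLR whose multiplier covers both warmup and decay, so no SequentialLR stitching is needed. Step counts and the base learning rate are arbitrary; the decay formula mirrors the new linear_decay above, which is why the multiplier is exactly 1.0 at the warmup boundary:

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

num_warmup_steps, num_training_steps = 5, 20

def lr_lambda(current_step: int) -> float:
    if current_step < num_warmup_steps:
        # Linear warmup from 0 up to the base learning rate.
        return float(current_step) / float(max(1, num_warmup_steps))
    # Linear decay toward 0 over the remaining steps.
    return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))

optimizer = SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
scheduler = LambdaLR(optimizer, lr_lambda, last_epoch=-1)

for step in range(num_training_steps):
    optimizer.step()
    scheduler.step()
    print(step, optimizer.param_groups[0]["lr"])  # ramps up for 5 steps, then decays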
21 changes: 21 additions & 0 deletions ludwig/modules/metric_modules.py
@@ -33,7 +33,9 @@
 )
 from torchmetrics.functional.regression.r2 import _r2_score_compute, _r2_score_update
 from torchmetrics.metric import jit_distributed_available
+from torchmetrics.text import BLEUScore, WordErrorRate
 from torchmetrics.text.perplexity import Perplexity
+from torchmetrics.text.rouge import ROUGEScore

 from ludwig.constants import (
     ACCURACY,
@@ -62,6 +64,7 @@
     PROBABILITIES,
     R2,
     RECALL,
+    RESPONSE,
     ROC_AUC,
     ROOT_MEAN_SQUARED_ERROR,
     ROOT_MEAN_SQUARED_PERCENTAGE_ERROR,
@@ -389,6 +392,24 @@ def get_current_value(self, preds: Tensor, target: Tensor):
         return torch.exp(shifted_loss)


+@register_metric("bleu", [TEXT], MAXIMIZE, RESPONSE)
+class BLEUScoreMetric(BLEUScore, LudwigMetric):
+    def __init__(self, **kwargs):
+        super().__init__()
+
+
+@register_metric("rouge", [TEXT], MAXIMIZE, RESPONSE)
+class ROUGEScoreMetric(ROUGEScore, LudwigMetric):
+    def __init__(self, **kwargs):
+        super().__init__()
+
+
+@register_metric("word_error_rate", [TEXT], MINIMIZE, RESPONSE)
+class WordErrorRateMetric(WordErrorRate, LudwigMetric):
+    def __init__(self, **kwargs):
+        super().__init__()
+
+
 @register_metric("char_error_rate", [SEQUENCE, TEXT], MINIMIZE, PREDICTIONS)
 class CharErrorRateMetric(CharErrorRate, LudwigMetric):
     def __init__(self, **kwargs):
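The newly registered RESPONSE metrics consume decoded strings rather than tensors, unlike char_error_rate below them, which operates on PREDICTIONS. A quick standalone sketch (not part of the commit; assumes only that torchmetrics is installed):

from torchmetrics.text import BLEUScore, WordErrorRate

preds = ["the cat sat on the mat"]
targets = [["a cat sat on the mat"]]  # BLEU takes a list of references per prediction

bleu = BLEUScore()
print(bleu(preds, targets))  # scalar tensor; higher is better (MAXIMIZE)

wer = WordErrorRate()
print(wer(["hello wrld"], ["hello world"]))  # tensor(0.5): 1 of 2 words wrong (MINIMIZE)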
4 changes: 2 additions & 2 deletions ludwig/modules/metric_registry.py
@@ -1,7 +1,7 @@
 from typing import Dict, List, Literal, TYPE_CHECKING, Union

 from ludwig.api_annotations import DeveloperAPI
-from ludwig.constants import LOGITS, MAXIMIZE, MINIMIZE, PREDICTIONS, PROBABILITIES
+from ludwig.constants import LOGITS, MAXIMIZE, MINIMIZE, PREDICTIONS, PROBABILITIES, RESPONSE
 from ludwig.utils.registry import Registry

 if TYPE_CHECKING:
@@ -84,5 +84,5 @@ def get_metric_objective(metric_name: str) -> Literal[MINIMIZE, MAXIMIZE]:


 @DeveloperAPI
-def get_metric_tensor_input(metric_name: str) -> Literal[PREDICTIONS, PROBABILITIES, LOGITS]:
+def get_metric_tensor_input(metric_name: str) -> Literal[PREDICTIONS, PROBABILITIES, LOGITS, RESPONSE]:
     return metric_tensor_input_registry[metric_name]
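This lookup is what lets TextOutputFeature.update_metrics() route RESPONSE metrics to decoded strings and every other metric to tensors. A sketch of the dispatch, using metric names registered in this commit (illustrative; assumes a working Ludwig installation):

from ludwig.constants import RESPONSE
from ludwig.modules.metric_registry import get_metric_tensor_input

for metric_name in ("bleu", "rouge", "word_error_rate", "char_error_rate"):
    prediction_key = get_metric_tensor_input(metric_name)
    if prediction_key == RESPONSE:
        print(metric_name, "-> updated with decoded target/prediction strings")
    else:
        print(metric_name, f"-> updated with predictions[{prediction_key!r}] tensors")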
28 changes: 1 addition & 27 deletions ludwig/schema/lr_scheduler.py
@@ -17,7 +17,7 @@ class LRSchedulerConfig(schema_utils.BaseMarshmallowConfig, ABC):
     """Configuration for learning rate scheduler parameters."""

     decay: str = schema_utils.StringOptions(
-        options=["linear", "exponential", "cosine"],
+        options=["linear", "exponential"],
         default=None,
         allow_none=True,
         description="Turn on decay of the learning rate.",
@@ -99,32 +99,6 @@ class LRSchedulerConfig(schema_utils.BaseMarshmallowConfig, ABC):
         parameter_metadata=TRAINER_METADATA[MODEL_ECD]["learning_rate_scheduler"]["reduce_eval_split"],
     )

-    # Parameters for CosineAnnealingWarmRestarts scheduler
-
-    t_0: int = schema_utils.PositiveInteger(
-        default=None,
-        allow_none=True,
-        description="Number of steps before the first restart for cosine annealing decay. If not specified, it"
-        " will be set to `steps_per_checkpoint`.",
-        parameter_metadata=TRAINER_METADATA[MODEL_ECD]["learning_rate_scheduler"]["t_0"],
-    )
-
-    t_mult: int = schema_utils.PositiveInteger(
-        default=1,
-        description="Period multiplier after each restart for cosine annealing decay. Defaults to 1, i.e.,"
-        " restart every `t_0` steps. If set to a larger value, the period between restarts increases by that"
-        " multiplier. For e.g., if t_mult is 2, then the periods would be: t_0, 2*t_0, 2^2*t_0, 2^3*t_0, etc.",
-        parameter_metadata=TRAINER_METADATA[MODEL_ECD]["learning_rate_scheduler"]["t_mult"],
-    )
-
-    eta_min: float = schema_utils.FloatRange(
-        default=0,
-        min=0,
-        max=1,
-        description="Minimum learning rate allowed for cosine annealing decay. Default: 0.",
-        parameter_metadata=TRAINER_METADATA[MODEL_ECD]["learning_rate_scheduler"]["eta_min"],
-    )
-

 # TODO(travis): too much boilerplate here, we should find a way to abstract all this and only require specifying the
 # minimal amount needed for the new config object.