Skip to content

Commit

Permalink
HuggingFace improvements (#649)
Browse files Browse the repository at this point in the history
* huggingface: log some parameters

* huggingface: Add `log_model`.

- If `None` (default) will not log any artifact.
- If `all` will call log_artifact with `output_dir` at each `on_save` call.
- If `last` will save the model `on_train_end` and call `log_artifact` with type=model and copy=True.

* examples: Add DVCLive-HuggingFace notebook

* Don't cherry-pick args

* huggingface: Conditional model name based on load_best_model_at_end

* huggingface: Keep model_file behavior

* Use `True` instead of `last`.
  • Loading branch information
daavoo committed Aug 16, 2023
1 parent 1aa3e05 commit f1b8e2a
Show file tree
Hide file tree
Showing 3 changed files with 249 additions and 20 deletions.
167 changes: 167 additions & 0 deletions examples/DVCLive-HuggingFace.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install accelerate datasets dvclive evaluate 'transformers[torch]' --upgrade"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!git init -q\n",
"!git config --local user.email \"you@example.com\"\n",
"!git config --local user.name \"Your Name\"\n",
"!dvc init -q\n",
"!git commit -m \"DVC init\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"from transformers import AutoTokenizer\n",
"\n",
"dataset = load_dataset(\"imdb\")\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-cased\")\n",
"\n",
"def tokenize_function(examples):\n",
" return tokenizer(examples[\"text\"], padding=\"max_length\", truncation=True)\n",
"\n",
"small_train_dataset = dataset[\"train\"].shuffle(seed=42).select(range(2000)).map(tokenize_function, batched=True)\n",
"small_eval_dataset = dataset[\"test\"].shuffle(seed=42).select(range(200)).map(tokenize_function, batched=True)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import evaluate\n",
"\n",
"metric = evaluate.load(\"f1\")\n",
"\n",
"def compute_metrics(eval_pred):\n",
" logits, labels = eval_pred\n",
" predictions = np.argmax(logits, axis=-1)\n",
" return metric.compute(predictions=predictions, references=labels)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tracking experiments with DVCLive"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dvclive.huggingface import DVCLiveCallback\n",
"from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer\n",
"\n",
"for epochs in (5, 10, 15):\n",
" model = AutoModelForSequenceClassification.from_pretrained(\"distilbert-base-cased\", num_labels=2)\n",
" for param in model.base_model.parameters():\n",
" param.requires_grad = False\n",
"\n",
" training_args = TrainingArguments(\n",
" evaluation_strategy=\"epoch\", \n",
" learning_rate=3e-4,\n",
" logging_strategy=\"epoch\",\n",
" num_train_epochs=epochs,\n",
" output_dir=\"output\", \n",
" overwrite_output_dir=True,\n",
" load_best_model_at_end=True,\n",
" report_to=\"none\",\n",
" save_strategy=\"epoch\",\n",
" weight_decay=0.01,\n",
" )\n",
"\n",
" trainer = Trainer(\n",
" model=model,\n",
" args=training_args,\n",
" train_dataset=small_train_dataset,\n",
" eval_dataset=small_eval_dataset,\n",
" compute_metrics=compute_metrics,\n",
" callbacks=[DVCLiveCallback(report=\"notebook\", save_dvc_exp=True, log_model=\"last\")],\n",
" )\n",
" trainer.train()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Comparing"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import dvc.api\n",
"import pandas as pd\n",
"\n",
"columns = [\"Experiment\", \"epoch\", \"eval.f1\"]\n",
"\n",
"df = pd.DataFrame(dvc.api.exp_show(), columns=columns)\n",
"\n",
"df.dropna(inplace=True)\n",
"df.reset_index(drop=True, inplace=True)\n",
"df\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!dvc plots diff $(dvc exp list --names-only)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import HTML\n",
"HTML(filename='./dvc_plots/index.html')"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
50 changes: 43 additions & 7 deletions src/dvclive/huggingface.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# ruff: noqa: ARG002
from typing import Optional
import logging
import os
from typing import Literal, Optional, Union

from transformers import (
TrainerCallback,
Expand All @@ -12,13 +14,35 @@
from dvclive import Live
from dvclive.utils import standardize_metric_name

logger = logging.getLogger("dvclive")


class DVCLiveCallback(TrainerCallback):
def __init__(self, model_file=None, live: Optional[Live] = None, **kwargs):
def __init__(
self,
live: Optional[Live] = None,
log_model: Optional[Union[Literal["all"], bool]] = None,
**kwargs,
):
super().__init__()
self.model_file = model_file
self._log_model = log_model
self.model_file = kwargs.pop("model_file", None)
if self.model_file:
logger.warning(
"model_file is deprecated and will be removed"
" in the next major version, use log_model instead"
)
self.live = live if live is not None else Live(**kwargs)

def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
self.live.log_params(args.to_dict())

def on_log(
self,
args: TrainingArguments,
Expand All @@ -31,6 +55,16 @@ def on_log(
self.live.log_metric(standardize_metric_name(key, __name__), value)
self.live.next_step()

def on_save(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
if self._log_model == "all" and state.is_world_process_zero:
self.live.log_artifact(args.output_dir)

def on_epoch_end(
self,
args: TrainingArguments,
Expand All @@ -53,10 +87,12 @@ def on_train_end(
control: TrainerControl,
**kwargs,
):
if args.load_best_model_at_end:
trainer = Trainer(
if self._log_model is True and state.is_world_process_zero:
fake_trainer = Trainer(
args=args, model=kwargs.get("model"), tokenizer=kwargs.get("tokenizer")
)
trainer.save_model()
self.live.log_artifact(args.output_dir)
name = "best" if args.load_best_model_at_end else "last"
output_dir = os.path.join(args.output_dir, name)
fake_trainer.save_model(output_dir)
self.live.log_artifact(output_dir, name=name, type="model", copy=True)
self.live.end()
52 changes: 39 additions & 13 deletions tests/test_frameworks/test_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from dvclive import Live
from dvclive.plots.metric import Metric
from dvclive.serialize import load_yaml
from dvclive.utils import parse_metrics

try:
Expand Down Expand Up @@ -99,6 +100,7 @@ def args():
"foo",
evaluation_strategy="epoch",
num_train_epochs=2,
save_strategy="epoch",
)


Expand Down Expand Up @@ -131,14 +133,17 @@ def test_huggingface_integration(tmp_dir, model, args, data, mocker):
assert len(logs[os.path.join(scalars, "epoch.tsv")]) == 3
assert len(logs[os.path.join(scalars, "eval", "loss.tsv")]) == 2

params = load_yaml(live.params_file)
assert params["num_train_epochs"] == 2

def test_huggingface_model_file(tmp_dir, model, args, data, mocker):
model_path = tmp_dir / "model_hf"
model_save = mocker.spy(model, "save_pretrained")

live_callback = DVCLiveCallback(model_file=model_path)
@pytest.mark.parametrize("log_model", ["all", True, None])
@pytest.mark.parametrize("best", [True, False])
def test_huggingface_log_model(tmp_dir, model, args, data, mocker, log_model, best):
live_callback = DVCLiveCallback(log_model=log_model)
log_artifact = mocker.patch.object(live_callback.live, "log_artifact")

args.load_best_model_at_end = best
trainer = Trainer(
model,
args,
Expand All @@ -149,12 +154,21 @@ def test_huggingface_model_file(tmp_dir, model, args, data, mocker):
trainer.add_callback(live_callback)
trainer.train()

assert model_path.is_dir()
expected_call_count = {
"all": 2,
True: 1,
None: 0,
}
assert log_artifact.call_count == expected_call_count[log_model]

assert (model_path / "pytorch_model.bin").exists()
assert (model_path / "config.json").exists()
assert model_save.call_count == 2
log_artifact.assert_called_with(model_path)
if log_model == "last":
name = "best" if best else "last"
log_artifact.assert_called_with(
os.path.join(args.output_dir, name),
name=name,
type="model",
copy=True,
)


def test_huggingface_pass_logger():
Expand All @@ -164,11 +178,14 @@ def test_huggingface_pass_logger():
assert DVCLiveCallback(live=logger).live is logger


def test_huggingface_log_artifact(tmp_dir, model, args, data, mocker):
live_callback = DVCLiveCallback()
def test_huggingface_model_file(tmp_dir, model, args, data, mocker):
logger = mocker.patch("dvclive.huggingface.logger")

model_path = tmp_dir / "model_hf"

live_callback = DVCLiveCallback(model_file=model_path)
log_artifact = mocker.patch.object(live_callback.live, "log_artifact")

args.load_best_model_at_end = True
trainer = Trainer(
model,
args,
Expand All @@ -179,4 +196,13 @@ def test_huggingface_log_artifact(tmp_dir, model, args, data, mocker):
trainer.add_callback(live_callback)
trainer.train()

log_artifact.assert_called_with(trainer.args.output_dir)
assert model_path.is_dir()

assert (model_path / "pytorch_model.bin").exists()
assert (model_path / "config.json").exists()
log_artifact.assert_called_with(model_path)

logger.warning.assert_called_with(
"model_file is deprecated and will be removed"
" in the next major version, use log_model instead"
)

0 comments on commit f1b8e2a

Please sign in to comment.