HuggingFace improvements #649

Merged · 7 commits · Aug 16, 2023
Changes from 3 commits

167 changes: 167 additions & 0 deletions examples/DVCLive-HuggingFace.ipynb
@@ -0,0 +1,167 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install accelerate datasets dvclive evaluate 'transformers[torch]' --upgrade"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!git init -q\n",
"!git config --local user.email \"you@example.com\"\n",
"!git config --local user.name \"Your Name\"\n",
"!dvc init -q\n",
"!git commit -m \"DVC init\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"from transformers import AutoTokenizer\n",
"\n",
"dataset = load_dataset(\"imdb\")\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-cased\")\n",
"\n",
"def tokenize_function(examples):\n",
" return tokenizer(examples[\"text\"], padding=\"max_length\", truncation=True)\n",
"\n",
"small_train_dataset = dataset[\"train\"].shuffle(seed=42).select(range(2000)).map(tokenize_function, batched=True)\n",
"small_eval_dataset = dataset[\"test\"].shuffle(seed=42).select(range(200)).map(tokenize_function, batched=True)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import evaluate\n",
"\n",
"metric = evaluate.load(\"f1\")\n",
"\n",
"def compute_metrics(eval_pred):\n",
" logits, labels = eval_pred\n",
" predictions = np.argmax(logits, axis=-1)\n",
" return metric.compute(predictions=predictions, references=labels)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tracking experiments with DVCLive"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dvclive.huggingface import DVCLiveCallback\n",
"from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer\n",
"\n",
"for epochs in (5, 10, 15):\n",
" model = AutoModelForSequenceClassification.from_pretrained(\"distilbert-base-cased\", num_labels=2)\n",
" for param in model.base_model.parameters():\n",
" param.requires_grad = False\n",
"\n",
" training_args = TrainingArguments(\n",
" evaluation_strategy=\"epoch\", \n",
" learning_rate=3e-4,\n",
" logging_strategy=\"epoch\",\n",
" num_train_epochs=epochs,\n",
" output_dir=\"output\", \n",
" overwrite_output_dir=True,\n",
" load_best_model_at_end=True,\n",
" report_to=\"none\",\n",
" save_strategy=\"epoch\",\n",
" weight_decay=0.01,\n",
" )\n",
"\n",
" trainer = Trainer(\n",
" model=model,\n",
" args=training_args,\n",
" train_dataset=small_train_dataset,\n",
" eval_dataset=small_eval_dataset,\n",
" compute_metrics=compute_metrics,\n",
" callbacks=[DVCLiveCallback(report=\"notebook\", save_dvc_exp=True, log_model=\"last\")],\n",
" )\n",
" trainer.train()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Comparing"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import dvc.api\n",
"import pandas as pd\n",
"\n",
"columns = [\"Experiment\", \"epoch\", \"eval.f1\"]\n",
"\n",
"df = pd.DataFrame(dvc.api.exp_show(), columns=columns)\n",
"\n",
"df.dropna(inplace=True)\n",
"df.reset_index(drop=True, inplace=True)\n",
"df\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!dvc plots diff $(dvc exp list --names-only)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import HTML\n",
"HTML(filename='./dvc_plots/index.html')"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
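
As a follow-up to the notebook's Comparing section, the rows returned by dvc.api.exp_show() can also be ranked directly in pandas. A minimal sketch, assuming the notebook's column names:

```python
import dvc.api
import pandas as pd

# Same columns as the notebook's comparison cell
columns = ["Experiment", "epoch", "eval.f1"]

df = pd.DataFrame(dvc.api.exp_show(), columns=columns)
df.dropna(inplace=True)

# Rank the experiments by their evaluation F1 score
print(df.sort_values("eval.f1", ascending=False).reset_index(drop=True))
```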
49 changes: 34 additions & 15 deletions src/dvclive/huggingface.py
@@ -1,5 +1,6 @@
# ruff: noqa: ARG002
from typing import Optional
import os
from typing import Literal, Optional

from transformers import (
TrainerCallback,
@@ -14,11 +15,33 @@


class DVCLiveCallback(TrainerCallback):
def __init__(self, model_file=None, live: Optional[Live] = None, **kwargs):
def __init__(
self,
live: Optional[Live] = None,
log_model: Optional[Literal["all", "last"]] = None,
**kwargs,
):
super().__init__()
self.model_file = model_file
self._log_model = log_model
self.live = live if live is not None else Live(**kwargs)

def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
for key, value in args.to_dict().items():
if key in (
"num_train_epochs",
"weight_decay",
"max_grad_norm",
"warmup_ratio",
"warmup_steps",
):
self.live.log_param(key, value)

def on_log(
self,
args: TrainingArguments,
@@ -31,20 +54,15 @@ def on_log(
self.live.log_metric(standardize_metric_name(key, __name__), value)
self.live.next_step()

def on_epoch_end(
def on_save(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
if self.model_file:
model = kwargs["model"]
model.save_pretrained(self.model_file)
tokenizer = kwargs.get("tokenizer")
if tokenizer:
tokenizer.save_pretrained(self.model_file)
self.live.log_artifact(self.model_file)
if self._log_model == "all" and state.is_world_process_zero:
self.live.log_artifact(args.output_dir)

def on_train_end(
self,
@@ -53,10 +71,11 @@
control: TrainerControl,
**kwargs,
):
if args.load_best_model_at_end:
trainer = Trainer(
if self._log_model == "last" and state.is_world_process_zero:
fake_trainer = Trainer(
args=args, model=kwargs.get("model"), tokenizer=kwargs.get("tokenizer")
)
trainer.save_model()
self.live.log_artifact(args.output_dir)
output_dir = os.path.join(args.output_dir, "last")
fake_trainer.save_model(output_dir)
self.live.log_artifact(output_dir, type="model", copy=True)
Inline review comments:

Contributor:
Cross-framework consistency isn't our highest priority, but should we agree on some common principles for the final artifact, like naming and whether to copy it?

daavoo (Contributor Author), Aug 7, 2023:
I would like for all the integrations to have just 2 options:

  • all/checkpoints: resuming scenarios.
    Log the entire checkpoint folder

  • best: model registry
    Log the best checkpoint on end with copy=True, name="best", type="model"
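
For illustration only (not part of this PR's diff), the proposed best option could look roughly like the following standalone callback. BestModelCallback and its wiring are hypothetical; state.best_model_checkpoint is only populated when load_best_model_at_end=True:

```python
from typing import Optional

from transformers import (
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)

from dvclive import Live


class BestModelCallback(TrainerCallback):
    """Hypothetical sketch of the proposed "best" option."""

    def __init__(self, live: Optional[Live] = None, **kwargs):
        super().__init__()
        self.live = live if live is not None else Live(**kwargs)

    def on_train_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        # best_model_checkpoint is set by the Trainer when load_best_model_at_end=True
        if state.is_world_process_zero and state.best_model_checkpoint:
            self.live.log_artifact(
                state.best_model_checkpoint, name="best", type="model", copy=True
            )
        self.live.end()
```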

Contributor:
That's fine with me. Do you want to update the lightning logger to use copy=True? AFAIK the rest is consistent.

daavoo (Contributor Author):
will update this and lightning to use that

Contributor:
Sorry, are you also suggesting to change the behavior of log_model=True in lightning to track only the copied best artifact and not the whole directory? That's fine, just want to make sure I understand what you mean.

For HF, how should we handle the last/best checkpoint? If args.load_best_model_at_end, we could add name=best? WDYT?

daavoo (Contributor Author):
> Sorry, are you also suggesting to change the behavior of log_model=True in lightning to track only the copied best artifact and not the whole directory?

I think I would suggest dropping the boolean value.

> For HF, how should we handle the last/best checkpoint? If args.load_best_model_at_end, we could add name=best? WDYT?

Yes, makes sense.

Contributor:
> I think I would suggest dropping the boolean value.

I worry doing that and/or not saving the checkpoints dir breaks consistency with mlflow/wandb/etc. in lightning for the sake of consistency across dvclive. I would probably err on the side of sticking with consistency for lightning over consistency for dvclive where they conflict, but we can always make this a follow-up PR if it is taking this off track.

daavoo (Contributor Author):
> For HF, how should we handle the last/best checkpoint? If args.load_best_model_at_end, we could add name=best? WDYT?

Updated with this behavior

daavoo (Contributor Author):
Also, dropped last option in favor of True

self.live.end()
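
For reference, with the callback as it stands at this commit, it is wired into a Trainer roughly as in the notebook above. A minimal sketch; train_ds and eval_ds stand in for the tokenized datasets and are assumed to exist:

```python
from transformers import (
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
)

from dvclive.huggingface import DVCLiveCallback

model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-cased", num_labels=2
)

training_args = TrainingArguments(
    output_dir="output",
    num_train_epochs=2,
    evaluation_strategy="epoch",
    save_strategy="epoch",  # on_save fires per checkpoint, so log_model="all" logs each save
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,  # assumed: tokenized dataset, as in the notebook
    eval_dataset=eval_ds,
    # log_model="last": save the final model under <output_dir>/last and track it as a model.
    # log_model="all": track the whole output_dir on every checkpoint save.
    # log_model=None (default): no artifact logging.
    callbacks=[DVCLiveCallback(save_dvc_exp=True, log_model="last")],
)
trainer.train()
```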
46 changes: 19 additions & 27 deletions tests/test_frameworks/test_huggingface.py
@@ -4,6 +4,7 @@

from dvclive import Live
from dvclive.plots.metric import Metric
from dvclive.serialize import load_yaml
from dvclive.utils import parse_metrics

try:
@@ -99,6 +100,7 @@ def args():
"foo",
evaluation_strategy="epoch",
num_train_epochs=2,
save_strategy="epoch",
)


@@ -131,12 +133,13 @@ def test_huggingface_integration(tmp_dir, model, args, data, mocker):
assert len(logs[os.path.join(scalars, "epoch.tsv")]) == 3
assert len(logs[os.path.join(scalars, "eval", "loss.tsv")]) == 2

params = load_yaml(live.params_file)
assert params["num_train_epochs"] == 2

def test_huggingface_model_file(tmp_dir, model, args, data, mocker):
model_path = tmp_dir / "model_hf"
model_save = mocker.spy(model, "save_pretrained")

live_callback = DVCLiveCallback(model_file=model_path)
@pytest.mark.parametrize("log_model", ["all", "last", None])
def test_huggingface_log_model(tmp_dir, model, args, data, mocker, log_model):
live_callback = DVCLiveCallback(log_model=log_model)
log_artifact = mocker.patch.object(live_callback.live, "log_artifact")

trainer = Trainer(
@@ -149,34 +152,23 @@ def test_huggingface_model_file(tmp_dir, model, args, data, mocker):
trainer.add_callback(live_callback)
trainer.train()

assert model_path.is_dir()
expected_call_count = {
"all": 2,
"last": 1,
None: 0,
}
assert log_artifact.call_count == expected_call_count[log_model]

assert (model_path / "pytorch_model.bin").exists()
assert (model_path / "config.json").exists()
assert model_save.call_count == 2
log_artifact.assert_called_with(model_path)
if log_model == "last":
log_artifact.assert_called_with(
os.path.join(args.output_dir, "last"),
type="model",
copy=True,
)


def test_huggingface_pass_logger():
logger = Live("train_logs")

assert DVCLiveCallback().live is not logger
assert DVCLiveCallback(live=logger).live is logger


def test_huggingface_log_artifact(tmp_dir, model, args, data, mocker):
live_callback = DVCLiveCallback()
log_artifact = mocker.patch.object(live_callback.live, "log_artifact")

args.load_best_model_at_end = True
trainer = Trainer(
model,
args,
train_dataset=data[0],
eval_dataset=data[1],
compute_metrics=compute_metrics,
)
trainer.add_callback(live_callback)
trainer.train()

log_artifact.assert_called_with(trainer.args.output_dir)
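
Relatedly, the hyperparameters logged by the new on_train_begin hook end up in DVCLive's params file and can be read back the same way the updated integration test does. A minimal sketch, assuming the default dvclive directory:

```python
from dvclive.serialize import load_yaml

# DVCLive writes logged params to <dir>/params.yaml ("dvclive" by default)
params = load_yaml("dvclive/params.yaml")

print(params["num_train_epochs"], params["weight_decay"])
```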