From 56b662ea74a930fd6c367222071ec17f3f27dedb Mon Sep 17 00:00:00 2001 From: Filip Sondej <26285707+filyp@users.noreply.github.com> Date: Mon, 26 Jan 2026 16:15:47 +0100 Subject: [PATCH 01/12] bump transformers to 4.51.3 make UnlearnTrainer implementation more future proof make other necessary changes to be compatible with new transformers version --- docs/components.md | 2 +- requirements.txt | 5 +- src/train.py | 2 +- src/trainer/base.py | 2 +- src/trainer/unlearn/base.py | 110 ++--------------------------- src/trainer/unlearn/ceu.py | 4 +- src/trainer/unlearn/dpo.py | 4 +- src/trainer/unlearn/grad_ascent.py | 4 +- src/trainer/unlearn/grad_diff.py | 4 +- src/trainer/unlearn/npo.py | 4 +- src/trainer/unlearn/pdu.py | 4 +- src/trainer/unlearn/rmu.py | 4 +- src/trainer/unlearn/satimp.py | 4 +- src/trainer/unlearn/simnpo.py | 4 +- src/trainer/unlearn/undial.py | 4 +- src/trainer/unlearn/wga.py | 4 +- 16 files changed, 43 insertions(+), 122 deletions(-) diff --git a/docs/components.md b/docs/components.md index a0ba85a64..9f4031439 100644 --- a/docs/components.md +++ b/docs/components.md @@ -38,7 +38,7 @@ class GradDiff(UnlearnTrainer): def __init__(self, gamma, alpha, ...): ... - def compute_loss(self, model, inputs, return_outputs=False): + def compute_unlearn_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): ... ``` diff --git a/requirements.txt b/requirements.txt index 2f39c76e2..58a47633b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ -huggingface-hub==0.29.1 -transformers==4.45.1 +huggingface-hub==0.36.0 +transformers==4.51.3 +hf-xet==1.2.0 numpy==2.2.3 hydra-core==1.3 hydra_colorlog==1.2.0 diff --git a/src/train.py b/src/train.py index a2f81c8d4..72bc95b82 100644 --- a/src/train.py +++ b/src/train.py @@ -49,7 +49,7 @@ def main(cfg: DictConfig): trainer_cfg=trainer_cfg, model=model, train_dataset=data.get("train", None), - eval_dataset=data.get("eval", None), + eval_dataset=data.get("eval", "dummy"), # None would trigger Trainer exception tokenizer=tokenizer, data_collator=collator, evaluators=evaluators, diff --git a/src/trainer/base.py b/src/trainer/base.py index 05f36a2a4..3d30a7b87 100644 --- a/src/trainer/base.py +++ b/src/trainer/base.py @@ -52,7 +52,7 @@ def evaluate( ) return eval_metrics - if eval_dataset is None: + if eval_dataset is None or eval_dataset == "dummy": return {} # Run the default HF Trainer evaluate method when eval dataset is provided return super().evaluate(eval_dataset, ignore_keys, metric_key_prefix) diff --git a/src/trainer/unlearn/base.py b/src/trainer/unlearn/base.py index 683698da2..3e358130a 100644 --- a/src/trainer/unlearn/base.py +++ b/src/trainer/unlearn/base.py @@ -75,110 +75,8 @@ def _prepare_deepspeed(self, model): model.eval() return model - def prediction_step( - self, - model: nn.Module, - inputs: Dict[str, Union[torch.Tensor, Any]], - prediction_loss_only: bool, - ignore_keys: Optional[List[str]] = None, - ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: - """ - The only change to this function is calling the Trainer's compute_loss, as it's often overridden by unlearning methods, and we want to maintain the Trainer's evaluation setup. - """ - has_labels = ( - False - if len(self.label_names) == 0 - else all(inputs.get(k) is not None for k in self.label_names) - ) - # For CLIP-like models capable of returning loss values. - # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss` - # is `True` in `model.forward`. - return_loss = inputs.get("return_loss", None) - if return_loss is None: - return_loss = self.can_return_loss - loss_without_labels = ( - True if len(self.label_names) == 0 and return_loss else False - ) - - inputs = self._prepare_inputs(inputs) - if ignore_keys is None: - if hasattr(self.model, "config"): - ignore_keys = getattr( - self.model.config, "keys_to_ignore_at_inference", [] - ) - else: - ignore_keys = [] - - # labels may be popped when computing the loss (label smoothing for instance) so we grab them first. - if has_labels or loss_without_labels: - labels = nested_detach(tuple(inputs.get(name) for name in self.label_names)) - if len(labels) == 1: - labels = labels[0] + def compute_loss(self, model, inputs, **kwargs): + if model.training: + return self.compute_unlearn_loss(model, inputs, **kwargs) else: - labels = None - - with torch.no_grad(): - if is_sagemaker_mp_enabled(): - raw_outputs = smp_forward_only(model, inputs) - if has_labels or loss_without_labels: - if isinstance(raw_outputs, dict): - loss_mb = raw_outputs["loss"] - logits_mb = tuple( - v - for k, v in raw_outputs.items() - if k not in ignore_keys + ["loss"] - ) - else: - loss_mb = raw_outputs[0] - logits_mb = raw_outputs[1:] - - loss = loss_mb.reduce_mean().detach().cpu() - logits = smp_nested_concat(logits_mb) - else: - loss = None - if isinstance(raw_outputs, dict): - logits_mb = tuple( - v for k, v in raw_outputs.items() if k not in ignore_keys - ) - else: - logits_mb = raw_outputs - logits = smp_nested_concat(logits_mb) - else: - if has_labels or loss_without_labels: - with self.compute_loss_context_manager(): - ### Call compute_loss of super class since overridden compute_loss is not be applicable to eval_dataset. - loss, outputs = super().compute_loss( - model, inputs, return_outputs=True - ) - loss = loss.mean().detach() - - if isinstance(outputs, dict): - logits = tuple( - v - for k, v in outputs.items() - if k not in ignore_keys + ["loss"] - ) - else: - logits = outputs[1:] - else: - loss = None - with self.compute_loss_context_manager(): - outputs = model(**inputs) - if isinstance(outputs, dict): - logits = tuple( - v for k, v in outputs.items() if k not in ignore_keys - ) - else: - logits = outputs - # TODO: this needs to be fixed and made cleaner later. - if self.args.past_index >= 0: - self._past = outputs[self.args.past_index - 1] - - if prediction_loss_only: - return (loss, None, None) - - logits = nested_detach(logits) - if len(logits) == 1: - logits = logits[0] - - return (loss, logits, labels) + return super().compute_loss(model, inputs, **kwargs) diff --git a/src/trainer/unlearn/ceu.py b/src/trainer/unlearn/ceu.py index 33da99c3c..a32caaf06 100644 --- a/src/trainer/unlearn/ceu.py +++ b/src/trainer/unlearn/ceu.py @@ -86,7 +86,9 @@ def __init__(self, ignore_first_n_answer_tokens=1, *args, **kwargs): super().__init__(*args, **kwargs) self.ignore_first_n_answer_tokens = ignore_first_n_answer_tokens - def compute_loss(self, model, inputs, return_outputs=False): + def compute_unlearn_loss( + self, model, inputs, return_outputs=False, num_items_in_batch=None + ): forget_inputs = inputs["forget"] loss, outputs = compute_batch_ceu( model, diff --git a/src/trainer/unlearn/dpo.py b/src/trainer/unlearn/dpo.py index b64b474b4..921220662 100644 --- a/src/trainer/unlearn/dpo.py +++ b/src/trainer/unlearn/dpo.py @@ -9,7 +9,9 @@ def __init__(self, beta=1.0, *args, **kwargs): if self.ref_model is None: self.ref_model = self._prepare_ref_model(self.model) - def compute_loss(self, model, inputs, return_outputs=False): + def compute_unlearn_loss( + self, model, inputs, return_outputs=False, num_items_in_batch=None + ): forget_inputs = inputs["forget"]["original"] alternate_inputs = inputs["forget"]["alternate"] diff --git a/src/trainer/unlearn/grad_ascent.py b/src/trainer/unlearn/grad_ascent.py index eda8b4812..6ab1944a9 100644 --- a/src/trainer/unlearn/grad_ascent.py +++ b/src/trainer/unlearn/grad_ascent.py @@ -2,7 +2,9 @@ class GradAscent(UnlearnTrainer): - def compute_loss(self, model, inputs, return_outputs=False): + def compute_unlearn_loss( + self, model, inputs, return_outputs=False, num_items_in_batch=None + ): forget_inputs = inputs["forget"] forget_inputs = { "input_ids": forget_inputs["input_ids"], diff --git a/src/trainer/unlearn/grad_diff.py b/src/trainer/unlearn/grad_diff.py index bfecc19a2..d51a672d9 100644 --- a/src/trainer/unlearn/grad_diff.py +++ b/src/trainer/unlearn/grad_diff.py @@ -38,7 +38,9 @@ def compute_retain_loss(self, model, retain_inputs): ) return retain_loss - def compute_loss(self, model, inputs, return_outputs=False): + def compute_unlearn_loss( + self, model, inputs, return_outputs=False, num_items_in_batch=None + ): forget_inputs = inputs["forget"] forget_inputs = { "input_ids": forget_inputs["input_ids"], diff --git a/src/trainer/unlearn/npo.py b/src/trainer/unlearn/npo.py index 7c782d968..12851e418 100644 --- a/src/trainer/unlearn/npo.py +++ b/src/trainer/unlearn/npo.py @@ -9,7 +9,9 @@ def __init__(self, beta=1.0, *args, **kwargs): if self.ref_model is None: self.ref_model = self._prepare_ref_model(self.model) - def compute_loss(self, model, inputs, return_outputs=False): + def compute_unlearn_loss( + self, model, inputs, return_outputs=False, num_items_in_batch=None + ): forget_inputs = inputs["forget"] forget_loss, forget_outputs = compute_dpo_loss( diff --git a/src/trainer/unlearn/pdu.py b/src/trainer/unlearn/pdu.py index e79bcc58b..e86c7b69f 100644 --- a/src/trainer/unlearn/pdu.py +++ b/src/trainer/unlearn/pdu.py @@ -102,7 +102,9 @@ def post_epoch_dual_param_update(self): ) self.log({"retain_preference": self.preferences[1]}) - def compute_loss(self, model, inputs, return_outputs=False): + def compute_unlearn_loss( + self, model, inputs, return_outputs=False, num_items_in_batch=None + ): forget_inputs = inputs["forget"] forget_inputs = { "input_ids": forget_inputs["input_ids"], diff --git a/src/trainer/unlearn/rmu.py b/src/trainer/unlearn/rmu.py index d990d3a38..7be48ce26 100644 --- a/src/trainer/unlearn/rmu.py +++ b/src/trainer/unlearn/rmu.py @@ -136,7 +136,9 @@ def compute_retain_loss(self, model, retain_inputs): retain_loss = super().compute_retain_loss(model, retain_inputs) return retain_loss - def compute_loss(self, model, inputs, return_outputs=False): + def compute_unlearn_loss( + self, model, inputs, return_outputs=False, num_items_in_batch=None + ): forget_inputs = inputs["forget"] forget_inputs = { "input_ids": forget_inputs["input_ids"], diff --git a/src/trainer/unlearn/satimp.py b/src/trainer/unlearn/satimp.py index b664390cd..102afa584 100644 --- a/src/trainer/unlearn/satimp.py +++ b/src/trainer/unlearn/satimp.py @@ -14,7 +14,9 @@ def __init__( if self.ref_model is None: self.ref_model = self._prepare_ref_model(self.model) - def compute_loss(self, model, inputs, return_outputs=False): + def compute_unlearn_loss( + self, model, inputs, return_outputs=False, num_items_in_batch=None + ): forget_inputs = inputs["forget"] forget_inputs = { "input_ids": forget_inputs["input_ids"], diff --git a/src/trainer/unlearn/simnpo.py b/src/trainer/unlearn/simnpo.py index cb4f7f99c..9885f01d9 100644 --- a/src/trainer/unlearn/simnpo.py +++ b/src/trainer/unlearn/simnpo.py @@ -10,7 +10,9 @@ def __init__(self, delta=0.0, beta=1.0, *args, **kwargs): self.delta = delta self.beta = beta - def compute_loss(self, model, inputs, return_outputs=False): + def compute_unlearn_loss( + self, model, inputs, return_outputs=False, num_items_in_batch=None + ): forget_inputs = inputs["forget"] forget_labels = forget_inputs["labels"] diff --git a/src/trainer/unlearn/undial.py b/src/trainer/unlearn/undial.py index e32147b30..e1b21c02a 100644 --- a/src/trainer/unlearn/undial.py +++ b/src/trainer/unlearn/undial.py @@ -9,7 +9,9 @@ def __init__(self, beta=1.0, *args, **kwargs): if self.ref_model is None: self.ref_model = self._prepare_ref_model(self.model) - def compute_loss(self, model, inputs, return_outputs=False): + def compute_unlearn_loss( + self, model, inputs, return_outputs=False, num_items_in_batch=None + ): forget_inputs = inputs["forget"] forget_loss, forget_outputs = compute_undial_loss( model, self.ref_model, forget_inputs, self.beta diff --git a/src/trainer/unlearn/wga.py b/src/trainer/unlearn/wga.py index 08c4bf402..4dcf00796 100644 --- a/src/trainer/unlearn/wga.py +++ b/src/trainer/unlearn/wga.py @@ -11,7 +11,9 @@ def __init__(self, beta=1.0, gamma=1.0, alpha=1.0, *args, **kwargs): if self.ref_model is None: self.ref_model = self._prepare_ref_model(self.model) - def compute_loss(self, model, inputs, return_outputs=False): + def compute_unlearn_loss( + self, model, inputs, return_outputs=False, num_items_in_batch=None + ): forget_inputs = inputs["forget"] forget_inputs = { "input_ids": forget_inputs["input_ids"], From 31546d3a9389eefda7b430b551be334c9a5f4ee1 Mon Sep 17 00:00:00 2001 From: Filip Sondej <26285707+filyp@users.noreply.github.com> Date: Mon, 26 Jan 2026 16:21:14 +0100 Subject: [PATCH 02/12] simplify installation fix leaderboard docs wider gitignore --- .gitignore | 12 ++++++++++++ README.md | 2 +- community/leaderboard.md | 2 +- requirements.txt | 1 + setup.py | 5 +---- src/trainer/unlearn/base.py | 31 ++----------------------------- 6 files changed, 18 insertions(+), 35 deletions(-) diff --git a/.gitignore b/.gitignore index b526ea4c7..2999bc88c 100644 --- a/.gitignore +++ b/.gitignore @@ -181,3 +181,15 @@ cython_debug/ .idea/ .vscode/ + +multirun/ +CLAUDE.md +slurm-*.out +job.sh +profile.prof +.aim/ +*.lprof + +tmp_comm/ +**/*.pkl +community/benchmarks/wmdp_low_mi/plots/*.pdf diff --git a/README.md b/README.md index 8bde09ce7..cd1afcad0 100644 --- a/README.md +++ b/README.md @@ -114,7 +114,7 @@ We provide several variants for each of the components in the unlearning pipelin # Environment setup conda create -n unlearning python=3.11 conda activate unlearning -pip install .[lm_eval] +pip install . pip install --no-build-isolation flash-attn==2.6.3 # Data setup diff --git a/community/leaderboard.md b/community/leaderboard.md index ffb515ab7..ab9aad8f5 100644 --- a/community/leaderboard.md +++ b/community/leaderboard.md @@ -9,7 +9,7 @@ We encourage the community to develop new methods, optimize them for specific be To implement a new method, refer to our [contributing guide](../docs/contributing.md). > [!NOTE] -> The [results.md](../docs/results.md) file is maintained for reproducibility purposes. However, we encourage contributors to update the leaderboard table instead of the reproducibility table. We will continue refining and tuning baseline methods to keep the leaderboard up to date. +> The [results.md](../docs/repro.md) file is maintained for reproducibility purposes. However, we encourage contributors to update the leaderboard table instead of the reproducibility table. We will continue refining and tuning baseline methods to keep the leaderboard up to date. ### TOFU unlearning on the `Llama-2-7b-hf-chat` architecture diff --git a/requirements.txt b/requirements.txt index 58a47633b..65208ab3e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ huggingface-hub==0.36.0 transformers==4.51.3 hf-xet==1.2.0 +lm-eval==0.4.8 numpy==2.2.3 hydra-core==1.3 hydra_colorlog==1.2.0 diff --git a/setup.py b/setup.py index 6a5c99c7e..5e2bced38 100644 --- a/setup.py +++ b/setup.py @@ -17,13 +17,10 @@ packages=find_packages(), install_requires=requirements, # Uses requirements.txt extras_require={ - "lm-eval": [ - "lm-eval==0.4.8", - ], # Install using `pip install .[lm-eval]` "dev": [ "pre-commit==4.0.1", "ruff==0.6.9", - ], # Install using `pip install .[dev]` + ], # Install using `pip install ".[dev]"` }, python_requires=">=3.11", ) diff --git a/src/trainer/unlearn/base.py b/src/trainer/unlearn/base.py index 3e358130a..c0411bdce 100644 --- a/src/trainer/unlearn/base.py +++ b/src/trainer/unlearn/base.py @@ -1,35 +1,8 @@ -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -from torch import nn from copy import deepcopy -from packaging import version -from trainer.base import FinetuneTrainer - -from transformers.trainer_pt_utils import ( - nested_detach, -) - -from transformers.utils import ( - is_sagemaker_mp_enabled, -) +from accelerate.utils import is_deepspeed_available -from accelerate.utils import ( - is_deepspeed_available, -) - -if is_sagemaker_mp_enabled(): - from smdistributed.modelparallel import __version__ as SMP_VERSION - - IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10") - - from transformers.trainer_pt_utils import ( - smp_forward_only, - smp_nested_concat, - ) -else: - IS_SAGEMAKER_MP_POST_1_10 = False +from trainer.base import FinetuneTrainer if is_deepspeed_available(): import deepspeed From 9987ee3e6ad8963bd15d9354ba7c278b6ba3879d Mon Sep 17 00:00:00 2001 From: Filip Sondej <26285707+filyp@users.noreply.github.com> Date: Mon, 26 Jan 2026 20:19:32 +0100 Subject: [PATCH 03/12] make FinetuneTrainer more readable --- src/trainer/base.py | 56 +++++++++++++++++++++------------------------ 1 file changed, 26 insertions(+), 30 deletions(-) diff --git a/src/trainer/base.py b/src/trainer/base.py index 3d30a7b87..2fada5ad6 100644 --- a/src/trainer/base.py +++ b/src/trainer/base.py @@ -1,13 +1,12 @@ # Modified from https://github.com/huggingface/transformers/blob/v4.45.1/src/transformers/trainer.py -from typing import Dict, List, Optional, Union - -import os import logging -from transformers import Trainer +import os +from typing import Any, Dict, List, Optional, Union + from torch.utils.data import Dataset +from transformers import Trainer from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR -from typing import Any logger = logging.getLogger(__name__) @@ -26,31 +25,28 @@ def evaluate( trial: Dict[str, Any] = None, ) -> Dict[str, float]: # Run a custom evaluator and save results - if self.evaluators: - if self.accelerator.is_local_main_process: - eval_metrics = {} - if self.accelerator.num_processes == 1: - run_dir = self._get_output_dir(trial=trial) - checkpoint_folder = ( - f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" - ) - output_dir = os.path.join(run_dir, checkpoint_folder, "evals") - os.makedirs(output_dir, exist_ok=True) - eval_metrics = {} - for _, evaluator in self.evaluators.items(): - eval_args = { - "output_dir": output_dir, - "template_args": self.template_args, - "model": self.model, - "tokenizer": self.tokenizer, - } - eval_metrics.update(evaluator.evaluate(**eval_args)) - self.log(eval_metrics) - else: - logger.warning( - "Custom evaluator can be run with this Trainer only when a single accelerator process is running." - ) - return eval_metrics + if self.evaluators and self.accelerator.is_local_main_process: + if self.accelerator.num_processes != 1: + logger.warning( + "Custom evaluator can be run with this Trainer only when a single accelerator process is running." + ) + return {} + + run_dir = self._get_output_dir(trial=trial) + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" + output_dir = os.path.join(run_dir, checkpoint_folder, "evals") + os.makedirs(output_dir, exist_ok=True) + eval_metrics = {} + for _, evaluator in self.evaluators.items(): + eval_args = { + "output_dir": output_dir, + "template_args": self.template_args, + "model": self.model, + "tokenizer": self.tokenizer, + } + eval_metrics.update(evaluator.evaluate(**eval_args)) + self.log(eval_metrics) + return eval_metrics if eval_dataset is None or eval_dataset == "dummy": return {} From 4b438483c9048a760c6ae5aca1ce9acd144dad7e Mon Sep 17 00:00:00 2001 From: Filip Sondej <26285707+filyp@users.noreply.github.com> Date: Thu, 29 Jan 2026 10:46:09 +0100 Subject: [PATCH 04/12] use trainer.processing_class instead of trainer.tokenizer, to support transformers==5 --- src/train.py | 2 +- src/trainer/__init__.py | 4 ++-- src/trainer/base.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/train.py b/src/train.py index 72bc95b82..290342e99 100644 --- a/src/train.py +++ b/src/train.py @@ -50,7 +50,7 @@ def main(cfg: DictConfig): model=model, train_dataset=data.get("train", None), eval_dataset=data.get("eval", "dummy"), # None would trigger Trainer exception - tokenizer=tokenizer, + processing_class=tokenizer, data_collator=collator, evaluators=evaluators, template_args=template_args, diff --git a/src/trainer/__init__.py b/src/trainer/__init__.py index 0bb7aa3b5..447b2d2dc 100644 --- a/src/trainer/__init__.py +++ b/src/trainer/__init__.py @@ -50,7 +50,7 @@ def load_trainer( model, train_dataset=None, eval_dataset=None, - tokenizer=None, + processing_class=None, data_collator=None, evaluators=None, template_args=None, @@ -70,7 +70,7 @@ def load_trainer( model=model, train_dataset=train_dataset, eval_dataset=eval_dataset, - tokenizer=tokenizer, + processing_class=processing_class, data_collator=data_collator, args=trainer_args, evaluators=evaluators, diff --git a/src/trainer/base.py b/src/trainer/base.py index 2fada5ad6..f424bb0fc 100644 --- a/src/trainer/base.py +++ b/src/trainer/base.py @@ -42,7 +42,7 @@ def evaluate( "output_dir": output_dir, "template_args": self.template_args, "model": self.model, - "tokenizer": self.tokenizer, + "tokenizer": self.processing_class, } eval_metrics.update(evaluator.evaluate(**eval_args)) self.log(eval_metrics) From e180d7fd0749081ef943159bcac10d7506144abb Mon Sep 17 00:00:00 2001 From: Filip Sondej <26285707+filyp@users.noreply.github.com> Date: Wed, 11 Feb 2026 14:01:21 +0100 Subject: [PATCH 05/12] fix ruff linter errors --- .gitignore | 1 + src/evals/metrics/utils.py | 6 +++--- src/train.py | 2 +- src/trainer/unlearn/pdu.py | 15 ++++++++------- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/.gitignore b/.gitignore index 2999bc88c..f3948ebfa 100644 --- a/.gitignore +++ b/.gitignore @@ -193,3 +193,4 @@ profile.prof tmp_comm/ **/*.pkl community/benchmarks/wmdp_low_mi/plots/*.pdf +**/tmp.py diff --git a/src/evals/metrics/utils.py b/src/evals/metrics/utils.py index 71684b7de..7e27d2f5c 100644 --- a/src/evals/metrics/utils.py +++ b/src/evals/metrics/utils.py @@ -63,9 +63,9 @@ def run_batchwise_evals(model, dataloader, batch_eval_fn, batch_eval_fn_args, ev model=model, batch=mini_batch, **batch_eval_fn_args ) indexwise_batch_evals = dict(zip(data_indices, batch_evals)) - assert not ( - evals[intra_item_idx].keys() & indexwise_batch_evals.keys() - ), "Data indices repeated while iterating dataloader" + assert not (evals[intra_item_idx].keys() & indexwise_batch_evals.keys()), ( + "Data indices repeated while iterating dataloader" + ) evals[intra_item_idx] |= indexwise_batch_evals # evals looks like {iidx0: {idx453: {prob: 0.1, loss: 1}}, # iidx1: {idx453: {prob: 0.2, loss: 2}}} diff --git a/src/train.py b/src/train.py index 290342e99..b88fcdcc9 100644 --- a/src/train.py +++ b/src/train.py @@ -49,7 +49,7 @@ def main(cfg: DictConfig): trainer_cfg=trainer_cfg, model=model, train_dataset=data.get("train", None), - eval_dataset=data.get("eval", "dummy"), # None would trigger Trainer exception + eval_dataset=data.get("eval", "dummy"), # None would trigger Trainer exception processing_class=tokenizer, data_collator=collator, evaluators=evaluators, diff --git a/src/trainer/unlearn/pdu.py b/src/trainer/unlearn/pdu.py index e86c7b69f..bbd9db1e7 100644 --- a/src/trainer/unlearn/pdu.py +++ b/src/trainer/unlearn/pdu.py @@ -1,7 +1,8 @@ import torch -from trainer.unlearn.grad_diff import GradDiff from transformers import TrainerCallback +from trainer.unlearn.grad_diff import GradDiff + class PDU(GradDiff): def __init__( @@ -38,9 +39,9 @@ def enable_updates(self): self.can_update = True def final_loss_value(self, losses): - assert len(losses) == len( - self.preferences - ), f"Expected {len(self.preferences)} losses, but got {len(losses)} losses." + assert len(losses) == len(self.preferences), ( + f"Expected {len(self.preferences)} losses, but got {len(losses)} losses." + ) # Shift the retain_loss for the primal dual method. # If no primal-dual method is used, gradient-based methods will not suffer @@ -73,9 +74,9 @@ def final_loss_value(self, losses): @torch.no_grad() def post_epoch_dual_param_update(self): - assert ( - self.primal_dual - ), "Dual parameter update requires primal dual to be enabled" + assert self.primal_dual, ( + "Dual parameter update requires primal dual to be enabled" + ) # Get the training dataloader dataloader = self.get_train_dataloader() From 1a3b7c815ad086c14b35b292c4844ae0e2ff3bb5 Mon Sep 17 00:00:00 2001 From: Filip Sondej <26285707+filyp@users.noreply.github.com> Date: Wed, 18 Feb 2026 18:20:33 +0100 Subject: [PATCH 06/12] format for ruff==0.6.6 --- src/evals/metrics/utils.py | 6 +++--- src/trainer/unlearn/pdu.py | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/evals/metrics/utils.py b/src/evals/metrics/utils.py index 7e27d2f5c..71684b7de 100644 --- a/src/evals/metrics/utils.py +++ b/src/evals/metrics/utils.py @@ -63,9 +63,9 @@ def run_batchwise_evals(model, dataloader, batch_eval_fn, batch_eval_fn_args, ev model=model, batch=mini_batch, **batch_eval_fn_args ) indexwise_batch_evals = dict(zip(data_indices, batch_evals)) - assert not (evals[intra_item_idx].keys() & indexwise_batch_evals.keys()), ( - "Data indices repeated while iterating dataloader" - ) + assert not ( + evals[intra_item_idx].keys() & indexwise_batch_evals.keys() + ), "Data indices repeated while iterating dataloader" evals[intra_item_idx] |= indexwise_batch_evals # evals looks like {iidx0: {idx453: {prob: 0.1, loss: 1}}, # iidx1: {idx453: {prob: 0.2, loss: 2}}} diff --git a/src/trainer/unlearn/pdu.py b/src/trainer/unlearn/pdu.py index bbd9db1e7..fd6bd45d0 100644 --- a/src/trainer/unlearn/pdu.py +++ b/src/trainer/unlearn/pdu.py @@ -39,9 +39,9 @@ def enable_updates(self): self.can_update = True def final_loss_value(self, losses): - assert len(losses) == len(self.preferences), ( - f"Expected {len(self.preferences)} losses, but got {len(losses)} losses." - ) + assert len(losses) == len( + self.preferences + ), f"Expected {len(self.preferences)} losses, but got {len(losses)} losses." # Shift the retain_loss for the primal dual method. # If no primal-dual method is used, gradient-based methods will not suffer @@ -74,9 +74,9 @@ def final_loss_value(self, losses): @torch.no_grad() def post_epoch_dual_param_update(self): - assert self.primal_dual, ( - "Dual parameter update requires primal dual to be enabled" - ) + assert ( + self.primal_dual + ), "Dual parameter update requires primal dual to be enabled" # Get the training dataloader dataloader = self.get_train_dataloader() From 367fee937430da1e9c04ad740659eeb9142c9589 Mon Sep 17 00:00:00 2001 From: Filip Sondej <26285707+filyp@users.noreply.github.com> Date: Sun, 22 Feb 2026 11:33:48 +0100 Subject: [PATCH 07/12] preserve compute_loss backwards compatibility --- .gitignore | 2 +- requirements.txt | 1 + runners/modal_runner.py | 45 ++++++++++++++++++++++++++++++ src/trainer/unlearn/base.py | 24 ++++++++++++---- src/trainer/unlearn/ceu.py | 2 +- src/trainer/unlearn/dpo.py | 2 +- src/trainer/unlearn/grad_ascent.py | 2 +- src/trainer/unlearn/grad_diff.py | 2 +- src/trainer/unlearn/npo.py | 2 +- src/trainer/unlearn/pdu.py | 2 +- src/trainer/unlearn/rmu.py | 2 +- src/trainer/unlearn/satimp.py | 2 +- src/trainer/unlearn/simnpo.py | 2 +- src/trainer/unlearn/undial.py | 2 +- src/trainer/unlearn/wga.py | 2 +- 15 files changed, 76 insertions(+), 18 deletions(-) create mode 100644 runners/modal_runner.py diff --git a/.gitignore b/.gitignore index f3948ebfa..3899a36d4 100644 --- a/.gitignore +++ b/.gitignore @@ -193,4 +193,4 @@ profile.prof tmp_comm/ **/*.pkl community/benchmarks/wmdp_low_mi/plots/*.pdf -**/tmp.py +**/tmp.py \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 65208ab3e..bc6f5b725 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,3 +14,4 @@ scipy==1.14.1 tensorboard==2.18.0 scikit-learn==1.5.2 deepspeed==0.15.4 +wandb==0.21.4 diff --git a/runners/modal_runner.py b/runners/modal_runner.py new file mode 100644 index 000000000..e8dc02b09 --- /dev/null +++ b/runners/modal_runner.py @@ -0,0 +1,45 @@ +# Run from repo root: +# modal run runners/modal_runner.py "python3 src/train.py --config-name=unlearn.yaml experiment=unlearn/wmdp_low_mi/default trainer=CIR task_name=test" + +import subprocess + +import modal + +image = ( + modal.Image.from_registry("nvidia/cuda:12.8.0-devel-ubuntu22.04", add_python="3.11") + .apt_install("git") + .pip_install_from_requirements("requirements.txt") + .pip_install("flash-attn==2.6.3", extra_options="--no-build-isolation") + # if we move to torch>2.5, we need to use pre-built wheels from here, because the build is painfully slow + # also to support B200, flash-attn==2.6.3 is too old, we'd need to bump to e.g. 2.8.3 + # .pip_install( + # "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.7.12/flash_attn-2.8.3+cu128torch2.10-cp311-cp311-linux_x86_64.whl" + # ) + .add_local_dir("data", remote_path="/root/code/data") + .add_local_dir(".cache/load_hf", remote_path="/root/code/.cache/load_hf") + .add_local_dir("configs", remote_path="/root/code/configs") + .add_local_dir("src", remote_path="/root/code/src") +) + +app = modal.App("open-unlearning", image=image) + + +@app.function( + gpu="L40S", # 48GB + # gpu="A100-80GB", # if needing 80GB + # gpu="H100", + # gpu="H200", + # gpu="B200", + timeout=1 * 3600, + secrets=[modal.Secret.from_dotenv()], +) +def run_training(args: str): + cmd = f"cd /root/code && HF_HUB_DOWNLOAD_TIMEOUT=60 PYTHONUNBUFFERED=1 {args}" + + print(f"Running: {cmd}") + subprocess.run(cmd, shell=True, executable="/bin/bash", check=True) + + +@app.local_entrypoint() +def main(args: str): + run_training.remote(args) \ No newline at end of file diff --git a/src/trainer/unlearn/base.py b/src/trainer/unlearn/base.py index c0411bdce..d6713c597 100644 --- a/src/trainer/unlearn/base.py +++ b/src/trainer/unlearn/base.py @@ -9,6 +9,24 @@ class UnlearnTrainer(FinetuneTrainer): + def prediction_step(self, *args, **kwargs): + """Use standard loss during evaluation, not the unlearn loss. + + Subclasses override compute_loss() with custom unlearn logic, but that + should only run during training. During evaluation (via prediction_step), + we temporarily swap in the base FinetuneTrainer.compute_loss so the + standard cross-entropy loss is used instead. This avoids having to copy + the full prediction_step implementation from transformers, making it + robust across library versions. + """ + custom_compute_loss = self.compute_loss + # __get__ binds the unbound class method to this instance, so it receives `self` + self.compute_loss = FinetuneTrainer.compute_loss.__get__(self, type(self)) + try: + return super().prediction_step(*args, **kwargs) + finally: + self.compute_loss = custom_compute_loss + # Adapted from Huggingface DPO Trainer: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 def _prepare_deepspeed(self, model): # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 @@ -47,9 +65,3 @@ def _prepare_deepspeed(self, model): model, *_ = deepspeed.initialize(model=model, config=config_kwargs) model.eval() return model - - def compute_loss(self, model, inputs, **kwargs): - if model.training: - return self.compute_unlearn_loss(model, inputs, **kwargs) - else: - return super().compute_loss(model, inputs, **kwargs) diff --git a/src/trainer/unlearn/ceu.py b/src/trainer/unlearn/ceu.py index a32caaf06..90b069d47 100644 --- a/src/trainer/unlearn/ceu.py +++ b/src/trainer/unlearn/ceu.py @@ -86,7 +86,7 @@ def __init__(self, ignore_first_n_answer_tokens=1, *args, **kwargs): super().__init__(*args, **kwargs) self.ignore_first_n_answer_tokens = ignore_first_n_answer_tokens - def compute_unlearn_loss( + def compute_loss( self, model, inputs, return_outputs=False, num_items_in_batch=None ): forget_inputs = inputs["forget"] diff --git a/src/trainer/unlearn/dpo.py b/src/trainer/unlearn/dpo.py index 921220662..81ab39c35 100644 --- a/src/trainer/unlearn/dpo.py +++ b/src/trainer/unlearn/dpo.py @@ -9,7 +9,7 @@ def __init__(self, beta=1.0, *args, **kwargs): if self.ref_model is None: self.ref_model = self._prepare_ref_model(self.model) - def compute_unlearn_loss( + def compute_loss( self, model, inputs, return_outputs=False, num_items_in_batch=None ): forget_inputs = inputs["forget"]["original"] diff --git a/src/trainer/unlearn/grad_ascent.py b/src/trainer/unlearn/grad_ascent.py index 6ab1944a9..ccb7dabab 100644 --- a/src/trainer/unlearn/grad_ascent.py +++ b/src/trainer/unlearn/grad_ascent.py @@ -2,7 +2,7 @@ class GradAscent(UnlearnTrainer): - def compute_unlearn_loss( + def compute_loss( self, model, inputs, return_outputs=False, num_items_in_batch=None ): forget_inputs = inputs["forget"] diff --git a/src/trainer/unlearn/grad_diff.py b/src/trainer/unlearn/grad_diff.py index d51a672d9..53e3a5b0e 100644 --- a/src/trainer/unlearn/grad_diff.py +++ b/src/trainer/unlearn/grad_diff.py @@ -38,7 +38,7 @@ def compute_retain_loss(self, model, retain_inputs): ) return retain_loss - def compute_unlearn_loss( + def compute_loss( self, model, inputs, return_outputs=False, num_items_in_batch=None ): forget_inputs = inputs["forget"] diff --git a/src/trainer/unlearn/npo.py b/src/trainer/unlearn/npo.py index 12851e418..3996f8ecf 100644 --- a/src/trainer/unlearn/npo.py +++ b/src/trainer/unlearn/npo.py @@ -9,7 +9,7 @@ def __init__(self, beta=1.0, *args, **kwargs): if self.ref_model is None: self.ref_model = self._prepare_ref_model(self.model) - def compute_unlearn_loss( + def compute_loss( self, model, inputs, return_outputs=False, num_items_in_batch=None ): forget_inputs = inputs["forget"] diff --git a/src/trainer/unlearn/pdu.py b/src/trainer/unlearn/pdu.py index fd6bd45d0..773d2159f 100644 --- a/src/trainer/unlearn/pdu.py +++ b/src/trainer/unlearn/pdu.py @@ -103,7 +103,7 @@ def post_epoch_dual_param_update(self): ) self.log({"retain_preference": self.preferences[1]}) - def compute_unlearn_loss( + def compute_loss( self, model, inputs, return_outputs=False, num_items_in_batch=None ): forget_inputs = inputs["forget"] diff --git a/src/trainer/unlearn/rmu.py b/src/trainer/unlearn/rmu.py index 7be48ce26..77bce2897 100644 --- a/src/trainer/unlearn/rmu.py +++ b/src/trainer/unlearn/rmu.py @@ -136,7 +136,7 @@ def compute_retain_loss(self, model, retain_inputs): retain_loss = super().compute_retain_loss(model, retain_inputs) return retain_loss - def compute_unlearn_loss( + def compute_loss( self, model, inputs, return_outputs=False, num_items_in_batch=None ): forget_inputs = inputs["forget"] diff --git a/src/trainer/unlearn/satimp.py b/src/trainer/unlearn/satimp.py index 102afa584..f42d4acbb 100644 --- a/src/trainer/unlearn/satimp.py +++ b/src/trainer/unlearn/satimp.py @@ -14,7 +14,7 @@ def __init__( if self.ref_model is None: self.ref_model = self._prepare_ref_model(self.model) - def compute_unlearn_loss( + def compute_loss( self, model, inputs, return_outputs=False, num_items_in_batch=None ): forget_inputs = inputs["forget"] diff --git a/src/trainer/unlearn/simnpo.py b/src/trainer/unlearn/simnpo.py index 9885f01d9..dc96ee29d 100644 --- a/src/trainer/unlearn/simnpo.py +++ b/src/trainer/unlearn/simnpo.py @@ -10,7 +10,7 @@ def __init__(self, delta=0.0, beta=1.0, *args, **kwargs): self.delta = delta self.beta = beta - def compute_unlearn_loss( + def compute_loss( self, model, inputs, return_outputs=False, num_items_in_batch=None ): forget_inputs = inputs["forget"] diff --git a/src/trainer/unlearn/undial.py b/src/trainer/unlearn/undial.py index e1b21c02a..d7c1d77bb 100644 --- a/src/trainer/unlearn/undial.py +++ b/src/trainer/unlearn/undial.py @@ -9,7 +9,7 @@ def __init__(self, beta=1.0, *args, **kwargs): if self.ref_model is None: self.ref_model = self._prepare_ref_model(self.model) - def compute_unlearn_loss( + def compute_loss( self, model, inputs, return_outputs=False, num_items_in_batch=None ): forget_inputs = inputs["forget"] diff --git a/src/trainer/unlearn/wga.py b/src/trainer/unlearn/wga.py index 4dcf00796..f0b19cbae 100644 --- a/src/trainer/unlearn/wga.py +++ b/src/trainer/unlearn/wga.py @@ -11,7 +11,7 @@ def __init__(self, beta=1.0, gamma=1.0, alpha=1.0, *args, **kwargs): if self.ref_model is None: self.ref_model = self._prepare_ref_model(self.model) - def compute_unlearn_loss( + def compute_loss( self, model, inputs, return_outputs=False, num_items_in_batch=None ): forget_inputs = inputs["forget"] From d6f1f1a4d119445143e576f9b42452374b3ba37d Mon Sep 17 00:00:00 2001 From: Filip Sondej <26285707+filyp@users.noreply.github.com> Date: Sun, 22 Feb 2026 11:43:39 +0100 Subject: [PATCH 08/12] revert gitignore --- .gitignore | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/.gitignore b/.gitignore index 3899a36d4..e0254fa5d 100644 --- a/.gitignore +++ b/.gitignore @@ -180,17 +180,4 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. .idea/ -.vscode/ - -multirun/ -CLAUDE.md -slurm-*.out -job.sh -profile.prof -.aim/ -*.lprof - -tmp_comm/ -**/*.pkl -community/benchmarks/wmdp_low_mi/plots/*.pdf -**/tmp.py \ No newline at end of file +.vscode/ \ No newline at end of file From f1d25726fae15248a0a9401f5c70576643ff4713 Mon Sep 17 00:00:00 2001 From: Filip Sondej <26285707+filyp@users.noreply.github.com> Date: Sun, 22 Feb 2026 14:59:34 +0100 Subject: [PATCH 09/12] hide the trainer fix that allows for evaluating when eval_dataset=None --- docs/components.md | 2 +- src/train.py | 2 +- src/trainer/base.py | 4 ++++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/components.md b/docs/components.md index 9f4031439..107827634 100644 --- a/docs/components.md +++ b/docs/components.md @@ -38,7 +38,7 @@ class GradDiff(UnlearnTrainer): def __init__(self, gamma, alpha, ...): ... - def compute_unlearn_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): ... ``` diff --git a/src/train.py b/src/train.py index b88fcdcc9..9cc6abcb5 100644 --- a/src/train.py +++ b/src/train.py @@ -49,7 +49,7 @@ def main(cfg: DictConfig): trainer_cfg=trainer_cfg, model=model, train_dataset=data.get("train", None), - eval_dataset=data.get("eval", "dummy"), # None would trigger Trainer exception + eval_dataset=data.get("eval", None), processing_class=tokenizer, data_collator=collator, evaluators=evaluators, diff --git a/src/trainer/base.py b/src/trainer/base.py index f424bb0fc..d5e9f0f0d 100644 --- a/src/trainer/base.py +++ b/src/trainer/base.py @@ -15,6 +15,10 @@ class FinetuneTrainer(Trainer): def __init__(self, evaluators=None, template_args=None, *args, **kwargs): self.evaluators = evaluators self.template_args = template_args + # When using custom evaluators without an eval dataset, pass a dummy value + # to prevent Trainer from raising on eval_dataset=None when eval_strategy is set + if kwargs.get("eval_dataset") is None and evaluators: + kwargs["eval_dataset"] = "dummy" super().__init__(*args, **kwargs) def evaluate( From d3dfad67028f6c1fd00462b1d3dc2cf50197eac3 Mon Sep 17 00:00:00 2001 From: Filip Sondej <26285707+filyp@users.noreply.github.com> Date: Fri, 6 Mar 2026 19:07:40 +0100 Subject: [PATCH 10/12] apply changes from the review --- README.md | 2 +- requirements.txt | 1 - runners/modal_runner.py | 45 ---------- setup.py | 3 + src/trainer/base.py | 10 ++- src/trainer/unlearn/base.py | 159 +++++++++++++++++++++++++++++++----- 6 files changed, 149 insertions(+), 71 deletions(-) delete mode 100644 runners/modal_runner.py diff --git a/README.md b/README.md index cd1afcad0..f66a325ee 100644 --- a/README.md +++ b/README.md @@ -114,7 +114,7 @@ We provide several variants for each of the components in the unlearning pipelin # Environment setup conda create -n unlearning python=3.11 conda activate unlearning -pip install . +pip install ".[lm_eval]" pip install --no-build-isolation flash-attn==2.6.3 # Data setup diff --git a/requirements.txt b/requirements.txt index bc6f5b725..5186a418d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ huggingface-hub==0.36.0 transformers==4.51.3 hf-xet==1.2.0 -lm-eval==0.4.8 numpy==2.2.3 hydra-core==1.3 hydra_colorlog==1.2.0 diff --git a/runners/modal_runner.py b/runners/modal_runner.py deleted file mode 100644 index e8dc02b09..000000000 --- a/runners/modal_runner.py +++ /dev/null @@ -1,45 +0,0 @@ -# Run from repo root: -# modal run runners/modal_runner.py "python3 src/train.py --config-name=unlearn.yaml experiment=unlearn/wmdp_low_mi/default trainer=CIR task_name=test" - -import subprocess - -import modal - -image = ( - modal.Image.from_registry("nvidia/cuda:12.8.0-devel-ubuntu22.04", add_python="3.11") - .apt_install("git") - .pip_install_from_requirements("requirements.txt") - .pip_install("flash-attn==2.6.3", extra_options="--no-build-isolation") - # if we move to torch>2.5, we need to use pre-built wheels from here, because the build is painfully slow - # also to support B200, flash-attn==2.6.3 is too old, we'd need to bump to e.g. 2.8.3 - # .pip_install( - # "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.7.12/flash_attn-2.8.3+cu128torch2.10-cp311-cp311-linux_x86_64.whl" - # ) - .add_local_dir("data", remote_path="/root/code/data") - .add_local_dir(".cache/load_hf", remote_path="/root/code/.cache/load_hf") - .add_local_dir("configs", remote_path="/root/code/configs") - .add_local_dir("src", remote_path="/root/code/src") -) - -app = modal.App("open-unlearning", image=image) - - -@app.function( - gpu="L40S", # 48GB - # gpu="A100-80GB", # if needing 80GB - # gpu="H100", - # gpu="H200", - # gpu="B200", - timeout=1 * 3600, - secrets=[modal.Secret.from_dotenv()], -) -def run_training(args: str): - cmd = f"cd /root/code && HF_HUB_DOWNLOAD_TIMEOUT=60 PYTHONUNBUFFERED=1 {args}" - - print(f"Running: {cmd}") - subprocess.run(cmd, shell=True, executable="/bin/bash", check=True) - - -@app.local_entrypoint() -def main(args: str): - run_training.remote(args) \ No newline at end of file diff --git a/setup.py b/setup.py index 5e2bced38..0972a9701 100644 --- a/setup.py +++ b/setup.py @@ -17,6 +17,9 @@ packages=find_packages(), install_requires=requirements, # Uses requirements.txt extras_require={ + "lm-eval": [ + "lm-eval==0.4.11", + ], # Install using `pip install ".[lm-eval]"` "dev": [ "pre-commit==4.0.1", "ruff==0.6.9", diff --git a/src/trainer/base.py b/src/trainer/base.py index d5e9f0f0d..b10492e2e 100644 --- a/src/trainer/base.py +++ b/src/trainer/base.py @@ -10,15 +10,17 @@ logger = logging.getLogger(__name__) +# When using custom evaluators without an eval dataset, pass a dummy value +# to prevent Trainer from raising on eval_dataset=None when eval_strategy is set +_EVAL_PLACEHOLDER = "_EVAL_PLACEHOLDER" + class FinetuneTrainer(Trainer): def __init__(self, evaluators=None, template_args=None, *args, **kwargs): self.evaluators = evaluators self.template_args = template_args - # When using custom evaluators without an eval dataset, pass a dummy value - # to prevent Trainer from raising on eval_dataset=None when eval_strategy is set if kwargs.get("eval_dataset") is None and evaluators: - kwargs["eval_dataset"] = "dummy" + kwargs["eval_dataset"] = _EVAL_PLACEHOLDER super().__init__(*args, **kwargs) def evaluate( @@ -52,7 +54,7 @@ def evaluate( self.log(eval_metrics) return eval_metrics - if eval_dataset is None or eval_dataset == "dummy": + if eval_dataset is None or eval_dataset == _EVAL_PLACEHOLDER: return {} # Run the default HF Trainer evaluate method when eval dataset is provided return super().evaluate(eval_dataset, ignore_keys, metric_key_prefix) diff --git a/src/trainer/unlearn/base.py b/src/trainer/unlearn/base.py index d6713c597..753ff94e1 100644 --- a/src/trainer/unlearn/base.py +++ b/src/trainer/unlearn/base.py @@ -1,32 +1,41 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn from copy import deepcopy +from packaging import version +from trainer.base import FinetuneTrainer -from accelerate.utils import is_deepspeed_available +from transformers.trainer_pt_utils import ( + nested_detach, +) -from trainer.base import FinetuneTrainer + +from transformers.utils import ( + is_sagemaker_mp_enabled, +) + +from accelerate.utils import ( + is_deepspeed_available, +) + +if is_sagemaker_mp_enabled(): + from smdistributed.modelparallel import __version__ as SMP_VERSION + + IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10") + + from transformers.trainer_pt_utils import ( + smp_forward_only, + smp_nested_concat, + ) +else: + IS_SAGEMAKER_MP_POST_1_10 = False if is_deepspeed_available(): import deepspeed class UnlearnTrainer(FinetuneTrainer): - def prediction_step(self, *args, **kwargs): - """Use standard loss during evaluation, not the unlearn loss. - - Subclasses override compute_loss() with custom unlearn logic, but that - should only run during training. During evaluation (via prediction_step), - we temporarily swap in the base FinetuneTrainer.compute_loss so the - standard cross-entropy loss is used instead. This avoids having to copy - the full prediction_step implementation from transformers, making it - robust across library versions. - """ - custom_compute_loss = self.compute_loss - # __get__ binds the unbound class method to this instance, so it receives `self` - self.compute_loss = FinetuneTrainer.compute_loss.__get__(self, type(self)) - try: - return super().prediction_step(*args, **kwargs) - finally: - self.compute_loss = custom_compute_loss - # Adapted from Huggingface DPO Trainer: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 def _prepare_deepspeed(self, model): # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 @@ -65,3 +74,113 @@ def _prepare_deepspeed(self, model): model, *_ = deepspeed.initialize(model=model, config=config_kwargs) model.eval() return model + + def prediction_step( + self, + model: nn.Module, + inputs: dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[list[str]] = None, + ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + The only change to this function is calling the Trainer's compute_loss, as it's often overridden by unlearning methods, and we want to maintain the Trainer's evaluation setup. + """ + has_labels = ( + False + if len(self.label_names) == 0 + else all(inputs.get(k) is not None for k in self.label_names) + ) + # For CLIP-like models capable of returning loss values. + # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss` + # is `True` in `model.forward`. + return_loss = inputs.get("return_loss", None) + if return_loss is None: + return_loss = self.can_return_loss + loss_without_labels = ( + True if len(self.label_names) == 0 and return_loss else False + ) + + inputs = self._prepare_inputs(inputs) + if ignore_keys is None: + if hasattr(self.model, "config"): + ignore_keys = getattr( + self.model.config, + "keys_to_ignore_at_inference", + ["past_key_values"], + ) + else: + ignore_keys = [] + + # labels may be popped when computing the loss (label smoothing for instance) so we grab them first. + if has_labels or loss_without_labels: + labels = nested_detach(tuple(inputs.get(name) for name in self.label_names)) + if len(labels) == 1: + labels = labels[0] + else: + labels = None + + with torch.no_grad(): + if is_sagemaker_mp_enabled(): + raw_outputs = smp_forward_only(model, inputs) + if has_labels or loss_without_labels: + if isinstance(raw_outputs, dict): + loss_mb = raw_outputs["loss"] + logits_mb = tuple( + v + for k, v in raw_outputs.items() + if k not in ignore_keys + ["loss"] + ) + else: + loss_mb = raw_outputs[0] + logits_mb = raw_outputs[1:] + + loss = loss_mb.reduce_mean().detach().cpu() + logits = smp_nested_concat(logits_mb) + else: + loss = None + if isinstance(raw_outputs, dict): + logits_mb = tuple( + v for k, v in raw_outputs.items() if k not in ignore_keys + ) + else: + logits_mb = raw_outputs + logits = smp_nested_concat(logits_mb) + else: + if has_labels or loss_without_labels: + with self.compute_loss_context_manager(): + ### Call compute_loss of super class since overridden compute_loss is not applicable to eval_dataset. + loss, outputs = super().compute_loss( + model, inputs, return_outputs=True + ) + loss = loss.detach().mean() + + if isinstance(outputs, dict): + logits = tuple( + v + for k, v in outputs.items() + if k not in ignore_keys + ["loss"] + ) + else: + logits = outputs[1:] + else: + loss = None + with self.compute_loss_context_manager(): + outputs = model(**inputs) + if isinstance(outputs, dict): + logits = tuple( + v for k, v in outputs.items() if k not in ignore_keys + ) + else: + logits = outputs + # TODO: this needs to be fixed and made cleaner later. + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index - 1] + + if prediction_loss_only: + return (loss, None, None) + + logits = nested_detach(logits) + if len(logits) == 1: + logits = logits[0] + + return (loss, logits, labels) From faf942f5be8381fac04ab46e3907608c77b35257 Mon Sep 17 00:00:00 2001 From: Filip Sondej <26285707+filyp@users.noreply.github.com> Date: Fri, 6 Mar 2026 19:17:13 +0100 Subject: [PATCH 11/12] typo fix --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 0972a9701..fca95d9d6 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ packages=find_packages(), install_requires=requirements, # Uses requirements.txt extras_require={ - "lm-eval": [ + "lm-eval": [ "lm-eval==0.4.11", ], # Install using `pip install ".[lm-eval]"` "dev": [ From 2ff502949e9fb1fa7f010e2682219100969dc676 Mon Sep 17 00:00:00 2001 From: Vineeth Date: Fri, 6 Mar 2026 23:01:21 -0800 Subject: [PATCH 12/12] fix: lint --- src/trainer/unlearn/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/trainer/unlearn/base.py b/src/trainer/unlearn/base.py index 753ff94e1..26ab0ac7f 100644 --- a/src/trainer/unlearn/base.py +++ b/src/trainer/unlearn/base.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Optional, Union import torch from torch import nn