Multiple validation datasets unsupported with dataloader_persistent_workers=True #30527

Open
2 of 4 tasks
bastienlc opened this issue Apr 28, 2024 · 1 comment · May be fixed by #30627

System Info

  • transformers version: 4.40.1
  • Platform: Linux-6.8.0-76060800daily20240311-generic-x86_64-with-glibc2.35
  • Python version: 3.11.8
  • Huggingface_hub version: 0.22.2
  • Safetensors version: 0.4.3
  • Accelerate version: 0.29.3
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.3.0+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: True
  • Using distributed or parallel set-up in script?: False

Who can help?

@muellerzr @pacman100

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from transformers import Trainer, TrainingArguments

DIM = 2


class DummyDataset(Dataset):
    def __init__(self, size=10000, label=0):
        self.size = size
        self.data = torch.rand(size, DIM)
        self.labels = torch.full((size,), label)

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return {"input_ids": self.data[idx], "labels": self.labels[idx]}


class DummyModel(torch.nn.Module):
    def __init__(self):
        super(DummyModel, self).__init__()
        self.linear = torch.nn.Linear(DIM, 2)

    def forward(self, input_ids, labels=None):
        outputs = self.linear(input_ids)
        loss = F.cross_entropy(outputs, labels)
        return {"logits": outputs, "loss": loss}


if __name__ == "__main__":
    model = DummyModel()
    train_dataset = DummyDataset(label=0)
    good_validation_dataset = DummyDataset(label=0)
    bad_validation_dataset = DummyDataset(label=1)

    training_args = TrainingArguments(
        output_dir="./outputs",
        learning_rate=0.01,
        num_train_epochs=5,
        per_device_train_batch_size=128,
        per_device_eval_batch_size=128,
        dataloader_num_workers=2,
        dataloader_persistent_workers=True,
        evaluation_strategy="epoch",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset={"good": good_validation_dataset, "bad": bad_validation_dataset},
    )

    trainer.train()

With dataloader_persistent_workers=True:

{'eval_good_loss': 0.04770788177847862, 'eval_good_runtime': 0.0951, 'eval_good_samples_per_second': 105140.955, 'eval_good_steps_per_second': 830.614, 'epoch': 1.0}                                                                                                             
{'eval_bad_loss': 0.04770788177847862, 'eval_bad_runtime': 0.1225, 'eval_bad_samples_per_second': 81619.03, 'eval_bad_steps_per_second': 644.79, 'epoch': 1.0}                                                                                                                    
{'eval_good_loss': 0.024791115894913673, 'eval_good_runtime': 0.0995, 'eval_good_samples_per_second': 100488.125, 'eval_good_steps_per_second': 793.856, 'epoch': 2.0}                                                                                                            
{'eval_bad_loss': 0.024791115894913673, 'eval_bad_runtime': 0.1183, 'eval_bad_samples_per_second': 84530.882, 'eval_bad_steps_per_second': 667.794, 'epoch': 2.0}                                                                                                                 
{'eval_good_loss': 0.017540939152240753, 'eval_good_runtime': 0.095, 'eval_good_samples_per_second': 105282.943, 'eval_good_steps_per_second': 831.735, 'epoch': 3.0}                                                                                                             
{'eval_bad_loss': 0.017540939152240753, 'eval_bad_runtime': 0.0814, 'eval_bad_samples_per_second': 122839.094, 'eval_bad_steps_per_second': 970.429, 'epoch': 3.0}                                                                                                                
{'eval_good_loss': 0.014589476399123669, 'eval_good_runtime': 0.1745, 'eval_good_samples_per_second': 57297.904, 'eval_good_steps_per_second': 452.653, 'epoch': 4.0}                                                                                                             
{'eval_bad_loss': 0.014589476399123669, 'eval_bad_runtime': 0.1389, 'eval_bad_samples_per_second': 71998.668, 'eval_bad_steps_per_second': 568.789, 'epoch': 4.0}                                                                                                                 
{'eval_good_loss': 0.01373046450316906, 'eval_good_runtime': 0.0833, 'eval_good_samples_per_second': 120031.709, 'eval_good_steps_per_second': 948.25, 'epoch': 5.0}                                                                                                              
{'eval_bad_loss': 0.01373046450316906, 'eval_bad_runtime': 0.0865, 'eval_bad_samples_per_second': 115601.295, 'eval_bad_steps_per_second': 913.25, 'epoch': 5.0}                                                                                                                  
{'train_runtime': 1.8571, 'train_samples_per_second': 26923.771, 'train_steps_per_second': 212.698, 'train_loss': 0.03968705527390106, 'epoch': 5.0}

With dataloader_persistent_workers=False:

{'eval_good_loss': 0.10046054422855377, 'eval_good_runtime': 0.1053, 'eval_good_samples_per_second': 95006.818, 'eval_good_steps_per_second': 750.554, 'epoch': 1.0}                                                                                                              
{'eval_bad_loss': 2.533043622970581, 'eval_bad_runtime': 0.0946, 'eval_bad_samples_per_second': 105667.808, 'eval_bad_steps_per_second': 834.776, 'epoch': 1.0}                                                                                                                   
{'eval_good_loss': 0.05101846158504486, 'eval_good_runtime': 0.161, 'eval_good_samples_per_second': 62102.692, 'eval_good_steps_per_second': 490.611, 'epoch': 2.0}                                                                                                               
{'eval_bad_loss': 3.2872579097747803, 'eval_bad_runtime': 0.1805, 'eval_bad_samples_per_second': 55403.336, 'eval_bad_steps_per_second': 437.686, 'epoch': 2.0}                                                                                                                   
{'eval_good_loss': 0.03576516732573509, 'eval_good_runtime': 0.1225, 'eval_good_samples_per_second': 81623.001, 'eval_good_steps_per_second': 644.822, 'epoch': 3.0}                                                                                                              
{'eval_bad_loss': 3.694115161895752, 'eval_bad_runtime': 0.1046, 'eval_bad_samples_per_second': 95635.471, 'eval_bad_steps_per_second': 755.52, 'epoch': 3.0}                                                                                                                     
{'eval_good_loss': 0.029605071991682053, 'eval_good_runtime': 0.0998, 'eval_good_samples_per_second': 100165.593, 'eval_good_steps_per_second': 791.308, 'epoch': 4.0}                                                                                                            
{'eval_bad_loss': 3.9129879474639893, 'eval_bad_runtime': 0.0825, 'eval_bad_samples_per_second': 121274.534, 'eval_bad_steps_per_second': 958.069, 'epoch': 4.0}                                                                                                                  
{'eval_good_loss': 0.027824044227600098, 'eval_good_runtime': 0.0903, 'eval_good_samples_per_second': 110771.994, 'eval_good_steps_per_second': 875.099, 'epoch': 5.0}                                                                                                            
{'eval_bad_loss': 3.9852359294891357, 'eval_bad_runtime': 0.1141, 'eval_bad_samples_per_second': 87625.956, 'eval_bad_steps_per_second': 692.245, 'epoch': 5.0}                                                                                                                   
{'train_runtime': 2.0821, 'train_samples_per_second': 24014.737, 'train_steps_per_second': 189.716, 'train_loss': 0.08233800960492484, 'epoch': 5.0}

Expected behavior

Hi there,

When using multiple validation datasets with transformers.Trainer and setting dataloader_persistent_workers=True in the transformers.TrainingArguments, all evaluations are done using the first validation dataset.

In the example above, the model only learns to predict class 0, so the loss should be large on the "bad" validation dataset and small on the "good" one.
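To make the expectation concrete, here is a tiny illustration (not part of the reproduction above): a model that confidently predicts class 0 gets a near-zero cross-entropy on labels equal to 0 and a loss of roughly 8 on labels equal to 1, which is about the gap the two validation sets should show.

import torch
import torch.nn.functional as F

# Confident "class 0" logits: small loss on matching labels, large loss otherwise.
logits = torch.tensor([[4.0, -4.0]]).repeat(8, 1)
good_labels = torch.zeros(8, dtype=torch.long)
bad_labels = torch.ones(8, dtype=torch.long)

print(F.cross_entropy(logits, good_labels))  # ≈ 0.0003
print(F.cross_entropy(logits, bad_labels))   # ≈ 8.0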

This seems related to #28469 and #29538, whose eval dataloader caching does not support passing a dictionary of evaluation datasets:

# def get_eval_dataloader in src/transformers/trainer.py
if hasattr(self, "_eval_dataloader") and self.args.dataloader_persistent_workers:
    return self.accelerator.prepare(self._eval_dataloader)

The evaluation dataloaders should probably also be stored in a dictionary, or the _eval_dataloader attribute should be suffixed with the eval_dataset_name.
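For illustration only, here is a minimal standalone sketch of the dictionary idea (this is not the actual Trainer code; the function and cache names are made up): each named eval dataset gets its own cached dataloader, so persistent workers are reused per dataset instead of the first dataset's dataloader being returned for every evaluation.

import torch
from torch.utils.data import DataLoader, Dataset


class TinyDataset(Dataset):
    def __init__(self, label):
        self.data = torch.rand(100, 2)
        self.labels = torch.full((100,), label)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return {"input_ids": self.data[idx], "labels": self.labels[idx]}


# Hypothetical cache: one dataloader per eval dataset name, rather than a single
# shared _eval_dataloader attribute.
_eval_dataloaders = {}


def get_eval_dataloader(name, dataset, persistent_workers=True):
    # Reuse the dataloader cached for *this* dataset, not whichever came first.
    if persistent_workers and name in _eval_dataloaders:
        return _eval_dataloaders[name]
    dataloader = DataLoader(
        dataset, batch_size=32, num_workers=2, persistent_workers=persistent_workers
    )
    if persistent_workers:
        _eval_dataloaders[name] = dataloader
    return dataloader


if __name__ == "__main__":
    eval_sets = {"good": TinyDataset(0), "bad": TinyDataset(1)}
    for name, ds in eval_sets.items():
        print(name, len(get_eval_dataloader(name, ds)))

With a keyed cache along these lines, the hasattr(self, "_eval_dataloader") shortcut would no longer hand back the first dataset's dataloader for every evaluation call.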

I can look into opening a PR for this.

@muellerzr
Contributor

@bastienlc feel free to open a PR to support this!
