Describe the bug
When HF datasets is used in conjunction with a PyTorch DataLoader, the RSS memory of the process keeps increasing when it should stay constant.
Steps to reproduce the bug
Run this snippet, which logs RSS memory, and observe the output.
import psutil
import os

from transformers import BertTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader

BATCH_SIZE = 32
NUM_TRIES = 10

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Tokenize on the fly and drop the raw columns.
def transform(x):
    x.update(tokenizer(x["text"], return_tensors="pt", max_length=64, padding="max_length", truncation=True))
    x.pop("text")
    x.pop("label")
    return x

dataset = load_dataset("imdb", split="train")
dataset.set_transform(transform)
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

mem_before = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)

# Log the RSS delta (in MB) after every batch, over NUM_TRIES epochs.
count = 0
while count < NUM_TRIES:
    for idx, batch in enumerate(train_loader):
        mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
        print(count, idx, mem_after - mem_before)
    count += 1

Expected results
Memory should not increase after the initial setup and loading of the dataset.
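For comparison, here is a minimal control sketch (not part of the original report) that runs the same RSS logging over a plain in-memory TensorDataset; if the DataLoader itself were leaking, the growth should show up here as well. The dummy tensor shape is an assumption, chosen to roughly match the tokenized IMDB batches.

import os

import psutil
import torch
from torch.utils.data import DataLoader, TensorDataset

BATCH_SIZE = 32
NUM_TRIES = 10

# Dummy in-memory data, roughly the shape of the tokenized IMDB inputs
# (25k examples, 64 tokens each); an assumption for illustration only.
data = TensorDataset(torch.zeros(25000, 64, dtype=torch.long))
loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

proc = psutil.Process(os.getpid())
mem_before = proc.memory_info().rss / (1024 * 1024)
for count in range(NUM_TRIES):
    for batch in loader:
        pass
    # RSS delta (MB) after each full pass over the in-memory dataset.
    print(count, proc.memory_info().rss / (1024 * 1024) - mem_before)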
Actual results
Memory continuously increases, as can be seen in the log.
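To help narrow this down, one debugging step (my suggestion, not something from the report) is to iterate the dataset directly, without a DataLoader, and check whether RSS still grows when no worker processes are involved:

import os

import psutil
from datasets import load_dataset

dataset = load_dataset("imdb", split="train")
proc = psutil.Process(os.getpid())
mem_before = proc.memory_info().rss / (1024 * 1024)

# Touch every example once; log the RSS delta (MB) every 5000 rows.
for i in range(len(dataset)):
    _ = dataset[i]
    if i % 5000 == 0:
        print(i, proc.memory_info().rss / (1024 * 1024) - mem_before)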
Environment info
- datasets version: 2.3.2
- Platform: Linux-4.19.0-21-cloud-amd64-x86_64-with-glibc2.10
- Python version: 3.8.13
- PyArrow version: 7.0.0