In [1]:
from src.data.datasets import Ex2VecOriginalDatasetShared, GLOBAL_SHARED_DATA, Ex2VecOriginalDatasetWrap

from torch.utils.data import DataLoader

import psutil

import os

import time

import torch

In [2]:
GLOBAL_SHARED_DATA['train'] = Ex2VecOriginalDatasetShared('sorted_data.parquet', 'train_dict.json', 'interactions.h5', sample_negative=999)

100%|█████████████████████████████████████████████████████████| 4892757/4892757 [00:47<00:00, 102172.58it/s]
100%|██████████████████████████████████████████████████████████| 4892757/4892757 [00:48<00:00, 99855.73it/s]


In [3]:
dataset = Ex2VecOriginalDatasetWrap(dataset_id='train')

In [4]:
def collate_fn(batch):
    # Remove None entries
    batch = [x for x in batch if x is not None]
    
    if not batch:
        return None  # Signal to skip this batch
    
    # Stack each field in the batch
    collated_batch = {}
    keys = batch[0].keys()
    for key in keys:
        collated_batch[key] = torch.stack([sample[key] for sample in batch])

    return collated_batch

def print_memory(label=""):
    process = psutil.Process(os.getpid())
    mem = process.memory_info().rss / (1024 ** 2)
    print(f"[{label}] Memory: {mem:.2f} MB")
    return mem

def print_total_memory(label=""):
    process = psutil.Process(os.getpid())
    mem = process.memory_info().rss
    total = mem

    # Add memory of all child processes (DataLoader workers)
    for child in process.children(recursive=True):
        try:
            total += child.memory_info().rss
        except psutil.NoSuchProcess:
            pass  # Process may have exited

    print(f"[{label}] Total memory incl. workers: {total / (1024 ** 2):.2f} MB")
    return total

def run_test(num_workers):
    print(f"\n== Running with {num_workers} workers ==")
    dataset = Ex2VecOriginalDatasetWrap(dataset_id='train')
    loader = DataLoader(dataset, batch_size=1, num_workers=num_workers, collate_fn=collate_fn)

    print_total_memory("Before loading")

    # Trigger worker start and one batch load
    for i, batch in enumerate(loader):
        if i > 1:
            break
        time.sleep(0.1)

    print_total_memory("After loading")
    del loader
    time.sleep(1)  # give OS time to clean up

In [5]:
for workers in [0, 1, 2, 4, 8, 10, 12, 14, 16]:
    run_test(workers)


== Running with 0 workers ==
[Before loading] Total memory incl. workers: 3269.45 MB
[After loading] Total memory incl. workers: 3273.58 MB

== Running with 1 workers ==
[Before loading] Total memory incl. workers: 3273.58 MB
[After loading] Total memory incl. workers: 3265.57 MB

== Running with 2 workers ==
[Before loading] Total memory incl. workers: 3265.57 MB
[After loading] Total memory incl. workers: 3265.59 MB

== Running with 4 workers ==
[Before loading] Total memory incl. workers: 3265.59 MB
[After loading] Total memory incl. workers: 3265.64 MB

== Running with 8 workers ==
[Before loading] Total memory incl. workers: 3265.64 MB
[After loading] Total memory incl. workers: 3265.73 MB

== Running with 10 workers ==
[Before loading] Total memory incl. workers: 3265.95 MB




[After loading] Total memory incl. workers: 3266.13 MB

== Running with 12 workers ==
[Before loading] Total memory incl. workers: 3266.13 MB




[After loading] Total memory incl. workers: 3266.19 MB

== Running with 14 workers ==
[Before loading] Total memory incl. workers: 3266.19 MB




[After loading] Total memory incl. workers: 3266.25 MB

== Running with 16 workers ==
[Before loading] Total memory incl. workers: 3266.27 MB




[After loading] Total memory incl. workers: 3266.34 MB
