In [1]:
import torch
import math

In [2]:
class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, start, end):
        super(MyIterableDataset).__init__()
        assert end > start, "this example code onlyl works with end >= start"
        self.start = start 
        self.end = end

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            iter_start = self.start
            iter_end = self.end
        else:
            per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
        return iter(range(iter_start, iter_end))


worker_init_fn is used to do something at the begining of each worker.

In [3]:
def worker_init_fn(id):
    print(sys.modules["__main__"].__dict__)

In [6]:
if __name__ == "__main__":
    ds = MyIterableDataset(start=0, end=7)
    print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))

RuntimeError: DataLoader worker (pid(s) 42508, 15056) exited unexpectedly

But the same code can work with script

In [10]:
# see the script
with open('ptdl.py', 'r') as text:
    textfile = text.read()
    print(textfile)

import torch
import math
import sys

class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, start, end):
        super(MyIterableDataset).__init__()
        assert end > start, "this example code onlyl works with end >= start"
        self.start = start 
        self.end = end

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            iter_start = self.start
            iter_end = self.end
        else:
            per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
        return iter(range(iter_start, iter_end))

def worker_init_fn(id):
    print(sys.modules["__main__"].__dict__)
        
if __name__ == "__main__":
    ds = MyIterableDataset(start=0, end=7)
    print(list(torch.utils.data.DataLoader

In [7]:
# run multiple worker with script
from subprocess import Popen, PIPE
process = Popen(["python", "ptdl.py"])