In [1]:
import os
import sys
from collections import OrderedDict

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

In [2]:
class TensorFileDataset(Dataset):
    """Map-style dataset that loads one tensor file per index from a sharded directory tree.

    Inspired by ImageFolder
    (https://pytorch.org/vision/main/generated/torchvision.datasets.ImageFolder.html)
    and by how ImageNet is stored for PyTorch: large tensors are saved as
    individual files spread across numbered sub-directories.

    Expected on-disk layout (sample ``idx`` lives in shard ``idx // per_dir_size``)::

        root_dir/
            0/0.pth ... 0/{per_dir_size - 1}.pth
            1/{per_dir_size}.pth ...
            ...

    Args:
        root_dir: Directory containing the numbered shard sub-directories.
        total_len: Number of samples the dataset exposes (``len(dataset)``).
        per_dir_size: Number of sample files stored in each shard sub-directory.

    Raises:
        ValueError: if ``per_dir_size`` is smaller than 1.
    """

    def __init__(self,
                 root_dir,
                 total_len=100000,
                 per_dir_size=10000):
        if per_dir_size < 1:
            raise ValueError("per_dir_size must be a positive integer, got {}".format(per_dir_size))
        self.root_dir = root_dir
        self.total_len = total_len
        self.per_dir_size = per_dir_size

        # flist[idx] is the file path of sample idx; populated by make_dataset().
        self.flist = list()

        self.make_dataset()

    def __len__(self):
        return self.total_len

    def __getitem__(self, idx):
        """Load and return the tensor stored for sample ``idx``.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
            RuntimeError: if the expected file is missing on disk.
        """
        req_fn = self.flist[idx]
        if not os.path.exists(req_fn):
            raise RuntimeError("requested file {} does not exist!".format(req_fn))

        # NOTE(review): torch.load may unpickle arbitrary objects depending on the
        # installed torch version's `weights_only` default; only load trusted files.
        arr = torch.load(req_fn)
        return arr

    def make_dataset(self):
        """(Re)build ``self.flist`` with exactly ``total_len`` file paths.

        Sample ``idx`` maps to ``root_dir/<idx // per_dir_size>/<idx>.pth``.
        Only ``total_len`` paths are generated (no padding shard), so the list
        length matches ``len(self)`` and negative indices wrap to real samples.
        Rebuilding from scratch keeps repeated calls idempotent.
        """
        self.flist = [
            os.path.join(self.root_dir, str(idx // self.per_dir_size), "{}.pth".format(idx))
            for idx in range(self.total_len)
        ]
                
            

In [26]:
def make_dummy_dataset(root_dir="/data/tmp", num_dirs=10, per_dir_size=100):
    """Write a toy dataset in the sharded layout read by ``TensorFileDataset``.

    Creates sub-directories ``root_dir/0`` .. ``root_dir/{num_dirs - 1}``.
    Directory ``i`` holds ``per_dir_size`` files named ``<idx>.pth`` with
    ``idx = i * per_dir_size + j``; each file contains the one-element tensor
    ``torch.Tensor([idx])``.

    Args:
        root_dir: Destination directory (created if missing).
        num_dirs: Number of shard sub-directories to write.
        per_dir_size: Number of tensor files per shard.
    """
    os.makedirs(root_dir, exist_ok=True)

    for dir_idx in range(num_dirs):
        shard_dir = os.path.join(root_dir, str(dir_idx))
        os.makedirs(shard_dir, exist_ok=True)
        for j in range(per_dir_size):
            sample_idx = dir_idx * per_dir_size + j
            torch.save(torch.Tensor([sample_idx]), os.path.join(shard_dir, "{}.pth".format(sample_idx)))
            

In [27]:
# Write the toy dataset (10 directories x 100 one-element tensors) under /data/tmp.
make_dummy_dataset()

In [28]:
# Optional cleanup: uncomment to delete the dummy files written under /data/tmp.
# ! rm -rf /data/tmp

In [29]:
# 10 shard directories x 100 files each, matching what make_dummy_dataset() wrote.
dset = TensorFileDataset(
    root_dir="/data/tmp",
    per_dir_size=100,
    total_len=1000,
)

In [30]:
dloader = DataLoader(
    dset,
    batch_size=8,
    shuffle=False,  # keep file order: batch k holds samples 8k .. 8k+7
    num_workers=4,
)

for batch_idx, batch in enumerate(dloader):
    print(batch_idx, batch)

0 tensor([[0.],
        [1.],
        [2.],
        [3.],
        [4.],
        [5.],
        [6.],
        [7.]])
1 tensor([[ 8.],
        [ 9.],
        [10.],
        [11.],
        [12.],
        [13.],
        [14.],
        [15.]])
2 tensor([[16.],
        [17.],
        [18.],
        [19.],
        [20.],
        [21.],
        [22.],
        [23.]])
3 tensor([[24.],
        [25.],
        [26.],
        [27.],
        [28.],
        [29.],
        [30.],
        [31.]])
4 tensor([[32.],
        [33.],
        [34.],
        [35.],
        [36.],
        [37.],
        [38.],
        [39.]])
5 tensor([[40.],
        [41.],
        [42.],
        [43.],
        [44.],
        [45.],
        [46.],
        [47.]])
6 tensor([[48.],
        [49.],
        [50.],
        [51.],
        [52.],
        [53.],
        [54.],
        [55.]])
7 tensor([[56.],
        [57.],
        [58.],
        [59.],
        [60.],
        [61.],
        [62.],
        [63.]])
8 tensor([[64.],
       

108 tensor([[864.],
        [865.],
        [866.],
        [867.],
        [868.],
        [869.],
        [870.],
        [871.]])
109 tensor([[872.],
        [873.],
        [874.],
        [875.],
        [876.],
        [877.],
        [878.],
        [879.]])
110 tensor([[880.],
        [881.],
        [882.],
        [883.],
        [884.],
        [885.],
        [886.],
        [887.]])
111 tensor([[888.],
        [889.],
        [890.],
        [891.],
        [892.],
        [893.],
        [894.],
        [895.]])
112 tensor([[896.],
        [897.],
        [898.],
        [899.],
        [900.],
        [901.],
        [902.],
        [903.]])
113 tensor([[904.],
        [905.],
        [906.],
        [907.],
        [908.],
        [909.],
        [910.],
        [911.]])
114 tensor([[912.],
        [913.],
        [914.],
        [915.],
        [916.],
        [917.],
        [918.],
        [919.]])
115 tensor([[920.],
        [921.],
        [922.],
        [923.],
 