In [7]:
import os
import sys
from collections import OrderedDict

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

In [2]:
class TensorShardsDataset(Dataset):
    """
    Map-style dataset that keeps fixed-shape samples in chunked shard files on disk.

    Samples are appended one at a time into an in-memory NumPy buffer of shape
    ``(chunk_size, *data_shape)``. Whenever the buffer fills up (or a flush is
    requested explicitly) it is written with ``torch.save`` to
    ``<root_dir>/<key>.pth``.

    Each file created has:
    {
        "key": position of the shard in the overall array (0, 1, 2, ...)
        "start_idx": smallest index in the overall array stored in the file
        "end_offset": offset/index before which this chunk has populated elements
        "chunk": actual dump of the whole buffer
    }

    To fetch populated elements do dump["chunk"][:dump["end_offset"]].
    For ease of debugging, unfilled rows are initialised with np.nan.

    Args:
        root_dir: directory the shard files are written to (created on first flush).
        data_shape: shape of a single sample.
        data: optional indexable collection of samples to append (and flush) at
            construction time.
        chunk_size: number of samples stored per shard file.
    """

    def __init__(self,
                 root_dir,
                 data_shape,
                 data=None,
                 chunk_size=100000,
                 ):
        self.root_dir = root_dir
        self.data_shape = data_shape
        self.chunk_size = chunk_size

        self.cur_len = 0  # total number of samples appended so far
        self.cur_idx = 0  # row of the in-memory buffer the next sample goes to

        self.map = OrderedDict()          # shard key -> shard file name
        self.verbose_map = OrderedDict()  # shard key -> {"start_idx", "end_offset", "name"}

        # NaN-filled buffer so that unpopulated rows are easy to spot.
        self.chunk = np.full((self.chunk_size, *self.data_shape), np.nan)

        if data is not None:
            # Flush after the last sample so everything passed in ends up on disk.
            self.extend(data, flush_it=True)

    def append(self, pt, flush_it=False):
        """
        Append one sample to the dataset.

        The buffer is written to disk when it becomes full, or when ``flush_it``
        is true (in which case a partially filled shard is written).
        """
        self.chunk[self.cur_idx] = pt
        self.cur_idx += 1
        self.cur_len += 1

        if self.cur_idx == self.chunk_size or flush_it:
            self.flush_chunk()

    def extend(self, pts, flush_it=False):
        """
        Append every sample of ``pts`` (iterated along its first axis).

        When ``flush_it`` is true the buffer is flushed after the last sample.
        Samples are still copied row by row; a bulk copy would be faster.
        """
        n_pts = len(pts)
        for i in range(n_pts):
            self.append(pts[i], flush_it=flush_it and i == n_pts - 1)

    def _save_flush(self, fn, to_dump):
        """Write ``to_dump`` to ``fn``, creating ``root_dir`` if it is missing."""
        os.makedirs(self.root_dir, exist_ok=True)
        torch.save(to_dump, fn)

    @staticmethod
    def _load_flush(fn):
        """
        Load a shard file previously written by ``_save_flush``.

        The dump holds a NumPy array, which needs full unpickling; PyTorch
        releases that default to ``weights_only=True`` must be told so
        explicitly, while older releases do not know the keyword at all.
        NOTE: this unpickles arbitrary objects - only load shard files that
        this class wrote itself.
        """
        try:
            return torch.load(fn, weights_only=False)
        except TypeError:
            return torch.load(fn)

    def flush_chunk(self):
        """
        Write the in-memory buffer to the shard file it belongs to.

        A shard that was flushed earlier while partially filled is re-written
        after checking that the rows already on disk match the buffer.

        Raises:
            RuntimeWarning: if nothing has been appended yet.
            AssertionError: if the shard on disk disagrees with the buffer.
        """
        if self.cur_len == 0:
            raise RuntimeWarning("no data to flush!")
        if self.cur_idx == 0:
            # The buffer was reset by a full-chunk flush and nothing has been
            # appended since, so there is nothing pending.
            return

        # Shard holding the most recently appended sample.
        key_to_dump_in = (self.cur_len - 1) // self.chunk_size
        start_idx = key_to_dump_in * self.chunk_size
        end_offset = self.cur_idx
        already_dumped = key_to_dump_in in self.map

        if already_dumped:
            fn = self.map[key_to_dump_in]
            load_chunk = self._load_flush(fn)["chunk"]
            load_start = self.verbose_map[key_to_dump_in]["start_idx"]
            load_end_offset = self.verbose_map[key_to_dump_in]["end_offset"]

            assert load_start == start_idx
            assert load_end_offset <= self.cur_idx
            # equal_nan: samples may legitimately contain NaN values.
            assert np.array_equal(
                load_chunk[:load_end_offset],
                self.chunk[:load_end_offset],
                equal_nan=True,
            ), "loaded chunk not same as data to be dumped!"
        else:
            fn = os.path.join(self.root_dir, str(key_to_dump_in) + ".pth")

        to_dump = {
            "start_idx": start_idx,
            "end_offset": end_offset,
            "key": key_to_dump_in,
            "chunk": self.chunk,
        }
        self._save_flush(fn, to_dump)

        # Update the bookkeeping only once the file has actually been written.
        if already_dumped:
            self.verbose_map[key_to_dump_in]["end_offset"] = end_offset
        else:
            self.verbose_map[key_to_dump_in] = {
                "start_idx": start_idx,
                "end_offset": end_offset,
                "name": fn,
            }
            self.map[key_to_dump_in] = fn

        if self.cur_idx == self.chunk_size:  # chunk was full before dumping
            self.chunk[:] = np.nan  # reset chunk
            self.cur_idx = 0  # next element is inserted at the start

    def __getitem__(self, idx):
        """
        Return the sample at overall position ``idx``.

        Negative indices count from the end. Samples that still live in the
        in-memory buffer are served from it (as a copy); all others are read
        from their shard file.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        if idx < 0:
            idx += self.cur_len
        if not 0 <= idx < self.cur_len:
            raise IndexError(
                "index {} out of range for dataset of length {}".format(idx, self.cur_len)
            )

        key_to_load, idx_to_load = divmod(idx, self.chunk_size)

        # Rows [cur_len - cur_idx, cur_len) are held by the buffer and may not
        # have been written to disk yet.
        if self.cur_idx > 0 and idx >= self.cur_len - self.cur_idx:
            return self.chunk[idx_to_load].copy()

        loaded_chunk = self._load_flush(self.map[key_to_load])
        return loaded_chunk["chunk"][idx_to_load]

    def __len__(self):
        """Number of samples appended so far."""
        return self.cur_len

In [3]:
# Toy input: 101 random samples of shape (2, 3), sharded 10 samples per file
# under ./tmp (101 is not a multiple of 10, so the last shard is partial).
data = np.random.normal(size=(101, 2, 3))
test_set = TensorShardsDataset(root_dir="tmp", data_shape=data.shape[1:], data=data, chunk_size=10)

{'start_idx': 0, 'end_offset': 10, 'key': 0, 'chunk': array([[[-1.37770676, -0.42279999,  0.06963574],
        [ 0.45002766, -2.97989563,  0.28480054]],

       [[-0.28745642,  0.3332029 , -0.26697032],
        [ 1.37322702, -0.16109079,  1.35398784]],

       [[-0.0996866 ,  0.23663271, -0.48155911],
        [ 0.93266403,  1.43077288, -0.65260788]],

       [[-0.22634149,  0.60865913,  0.42467253],
        [ 0.18055532,  0.73460021,  1.13110561]],

       [[ 0.15098057,  1.34213653, -0.14480202],
        [ 0.46101742, -1.84446932,  0.71265579]],

       [[-0.35663476, -0.33590892, -0.41213884],
        [-1.58784269,  0.5789129 ,  0.39834458]],

       [[-1.91958516,  0.05969337, -1.15695642],
        [-2.38057782, -0.7161313 ,  0.62203563]],

       [[-0.60944326, -2.26893017, -0.1106346 ],
        [ 0.86649593,  0.23623879, -0.5861697 ]],

       [[-0.44633381,  0.20437342, -1.48677622],
        [-0.15376237, -0.07433862, -0.81822552]],

       [[-0.16384853,  1.06359005, -0.41346689

In [6]:
# Read every sample back by position and print it next to its index.
i = 0
while i < len(test_set):
    print(i, test_set[i])
    i += 1

0
0 [[-1.37770676 -0.42279999  0.06963574]
 [ 0.45002766 -2.97989563  0.28480054]]
1
1 [[-0.28745642  0.3332029  -0.26697032]
 [ 1.37322702 -0.16109079  1.35398784]]
2
2 [[-0.0996866   0.23663271 -0.48155911]
 [ 0.93266403  1.43077288 -0.65260788]]
3
3 [[-0.22634149  0.60865913  0.42467253]
 [ 0.18055532  0.73460021  1.13110561]]
4
4 [[ 0.15098057  1.34213653 -0.14480202]
 [ 0.46101742 -1.84446932  0.71265579]]
5
5 [[-0.35663476 -0.33590892 -0.41213884]
 [-1.58784269  0.5789129   0.39834458]]
6
6 [[-1.91958516  0.05969337 -1.15695642]
 [-2.38057782 -0.7161313   0.62203563]]
7
7 [[-0.60944326 -2.26893017 -0.1106346 ]
 [ 0.86649593  0.23623879 -0.5861697 ]]
8
8 [[-0.44633381  0.20437342 -1.48677622]
 [-0.15376237 -0.07433862 -0.81822552]]
9
9 [[-0.16384853  1.06359005 -0.41346689]
 [ 1.6122202  -2.11645463  0.64319296]]
10
10 [[ 1.38451335 -0.01086801 -0.88662465]
 [-0.58045868 -1.1576138  -1.40557465]]
11
11 [[-0.37167058  0.4061569  -0.93147056]
 [ 0.16408076 -0.93713933 -0.92440569]]


In [8]:
# Batch the sharded dataset in order with 8 worker processes and print each
# batch together with its position and shape.
dl = DataLoader(test_set, batch_size=32, shuffle=False, num_workers=8)
for i, batch in enumerate(dl):
    print(i, batch, batch.shape)

32
64
33
0
34
65
1
35
66
2
36
67
3
68
37
4
69
5
70
6
71
7
38
72
8
39
73
9
40
74
10
41
75
11
76
42
77
43
12
78
44
13
14
45
79
15
46
80
16
47
81
17
48
82
18
49
83
19
50
84
20
51
85
21
52
22
86
53
23
54
87
55
88
24
89
25
56
90
26
57
91
27
58
92
28
59
29
93
60
30
61
94
31
62
95
63
96
97
98
99
100
0 tensor([[[-1.3777, -0.4228,  0.0696],
         [ 0.4500, -2.9799,  0.2848]],

        [[-0.2875,  0.3332, -0.2670],
         [ 1.3732, -0.1611,  1.3540]],

        [[-0.0997,  0.2366, -0.4816],
         [ 0.9327,  1.4308, -0.6526]],

        [[-0.2263,  0.6087,  0.4247],
         [ 0.1806,  0.7346,  1.1311]],

        [[ 0.1510,  1.3421, -0.1448],
         [ 0.4610, -1.8445,  0.7127]],

        [[-0.3566, -0.3359, -0.4121],
         [-1.5878,  0.5789,  0.3983]],

        [[-1.9196,  0.0597, -1.1570],
         [-2.3806, -0.7161,  0.6220]],

        [[-0.6094, -2.2689, -0.1106],
         [ 0.8665,  0.2362, -0.5862]],

        [[-0.4463,  0.2044, -1.4868],
         [-0.1538, -0.0743, -0.8182]],

  

In [None]:
# ! rm -rf tmp

In [7]:
# Scratch cell: a one-element list comprehension that evaluates to [2].
[2 for i in [2]]

[2]