In [1]:
import os
import sys
from collections import OrderedDict

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

In [2]:
class TensorShardsDataset(Dataset):
    """
    Append-only dataset of fixed-shape arrays, stored on disk in shards of
    ``chunk_size`` rows. Rows are buffered in memory (``self.chunk``) and
    written to ``<root_dir>/<key>.pth`` with ``torch.save``.

     each file created has:
     {
         "key": tells about order in overall array
         "start_idx": smallest index in overall array in the file
         "end_offset": offset/index till before which this chunk has elements
         "chunk": actual dump of the chunk
     }

     to fetch populated elements do dump["chunk"][:dump["end_offset"]]
     for ease of debugging, chunks are initialized with np.nan

    Bookkeeping attributes:
        map:         key -> shard file name
        verbose_map: key -> {"start_idx", "end_offset", "name"}
        cur_len:     total number of rows appended so far
        cur_idx:     position in the in-memory chunk where the next row goes
    """

    def __init__(self,
                 root_dir,
                 data_shape,
                 data=None,
                 chunk_size=100000,
                 ):
        """
        Args:
            root_dir: directory the shard files are written to (created on first flush).
            data_shape: shape of a single row, e.g. ``(2, 3)``.
            data: optional indexable collection of rows to append (and flush) right away.
            chunk_size: number of rows per shard file; must be positive.

        Raises:
            ValueError: if ``chunk_size`` is not positive.
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size must be a positive integer")

        self.root_dir = root_dir
        self.data_shape = data_shape
        self.chunk_size = chunk_size

        self.cur_len = 0  # current length that has been populated
        self.cur_idx = 0  # index where next element will go

        self.map = OrderedDict()
        self.verbose_map = OrderedDict()

        # NaN-filled buffer so unpopulated rows are easy to spot while debugging
        self.chunk = np.full((self.chunk_size, *self.data_shape), np.nan)

        if data is not None:
            self.extend(data, flush_it=True)

    def append(self, pt, flush_it=False):
        """
        Add one row to the in-memory chunk.

        The chunk is written to disk when it becomes full, or immediately
        when ``flush_it`` is true (producing a partially filled shard file).
        """
        self.chunk[self.cur_idx] = pt
        self.cur_idx += 1
        self.cur_len += 1

        # cur_idx is reset to 0 after every full flush, so reaching chunk_size
        # here means exactly "the chunk just became full".
        if self.cur_idx == self.chunk_size or flush_it:
            self.flush_chunk()

    def extend(self, pts, flush_it=False):
        """
        Append every row of ``pts`` in order.

        When ``flush_it`` is true the chunk is flushed after the last row, so
        all of ``pts`` is on disk when the call returns.

        TODO: rows are copied one at a time; a bulk copy would be faster.
        """
        count = len(pts)
        for i in range(count):
            self.append(pts[i], flush_it and i == count - 1)

    def _save_flush(self, fn, to_dump):
        """Write ``to_dump`` to ``fn``, creating ``root_dir`` if needed."""
        os.makedirs(self.root_dir, exist_ok=True)
        torch.save(to_dump, fn)

    @staticmethod
    def _load_flush(fn):
        """
        Read back a shard written by ``_save_flush``.

        The dump holds a numpy array, which the restricted unpickler that
        newer PyTorch releases use by default (``weights_only=True``) rejects,
        so full unpickling is requested explicitly. This is unpickling: only
        use it on shard files this class wrote itself.
        """
        try:
            return torch.load(fn, weights_only=False)
        except TypeError:
            # older PyTorch without the ``weights_only`` keyword
            return torch.load(fn)

    def flush_chunk(self):
        """
        Write the in-memory chunk to its shard file.

        A shard that was flushed while partially filled is rewritten in place;
        the rows already on disk are checked against the buffer first. If the
        chunk is full it is reset afterwards so the next row starts a new
        shard. Calling this when nothing is pending is a no-op.

        Raises:
            RuntimeWarning: if no row has ever been appended.
            AssertionError: if an existing shard disagrees with the buffer.
        """
        if self.cur_len == 0:
            raise RuntimeWarning("no data to flush!")
        if self.cur_idx == 0:
            # the last chunk was written in full and the buffer is empty
            return

        # shard holding the most recently appended row
        key_to_dump_in = (self.cur_len - 1) // self.chunk_size
        start_idx = key_to_dump_in * self.chunk_size
        end_offset = self.cur_idx

        already_on_disk = key_to_dump_in in self.map
        if already_on_disk:
            fn = self.map[key_to_dump_in]
            load_chunk = self._load_flush(fn)["chunk"]
            load_start = self.verbose_map[key_to_dump_in]["start_idx"]
            load_end_offset = self.verbose_map[key_to_dump_in]["end_offset"]

            assert load_start == start_idx
            assert load_end_offset <= self.cur_idx
            on_disk = load_chunk[:load_end_offset]
            in_memory = self.chunk[:load_end_offset]
            # NaN never compares equal to itself, so match NaNs explicitly
            same = (on_disk == in_memory) | (np.isnan(on_disk) & np.isnan(in_memory))
            assert same.all(), "loaded chunk not same as data to be dumped!"
        else:
            fn = os.path.join(self.root_dir, str(key_to_dump_in) + ".pth")

        to_dump = {
            "start_idx": start_idx,
            "end_offset": end_offset,
            "key": key_to_dump_in,
            "chunk": self.chunk
        }
        self._save_flush(fn, to_dump)

        # update the bookkeeping only once the file has actually been written
        if already_on_disk:
            self.verbose_map[key_to_dump_in]["end_offset"] = end_offset
        else:
            self.verbose_map[key_to_dump_in] = {
                "start_idx": start_idx,
                "end_offset": end_offset,
                "name": fn,
            }
            self.map[key_to_dump_in] = fn

        if self.cur_idx == self.chunk_size:  # chunk was full before dumping
            self.chunk[:] = np.nan  # reset chunk
            self.cur_idx = 0  # reset current index so that new element is inserted at start

    def __getitem__(self, idx):
        """
        Return row ``idx`` (negative indices count from the end).

        Rows still held in the in-memory chunk are served from memory, so
        they are readable before being flushed; every other row is read from
        its shard file.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        if idx < 0:
            idx += self.cur_len
        if not 0 <= idx < self.cur_len:
            raise IndexError("index out of range for TensorShardsDataset")

        # overall index of the first row currently held in the in-memory chunk
        buffer_start = self.cur_len - self.cur_idx
        if idx >= buffer_start:
            # copy so callers cannot modify the write buffer through the result
            return self.chunk[idx - buffer_start].copy()

        key_to_load = idx // self.chunk_size
        idx_to_load = idx % self.chunk_size

        loaded_chunk = self._load_flush(self.map[key_to_load])
        return loaded_chunk["chunk"][idx_to_load]

    def __len__(self):
        """Total number of rows appended so far."""
        return self.cur_len

In [9]:
class ContainerClass:
    """
    Toy map-style dataset pairing an in-memory array with a disk-backed
    TensorShardsDataset, used to exercise the shards through a DataLoader.
    """

    def __init__(self):
        # number of samples; both arrays below have this many rows
        self.N = 122
        self.data = np.random.normal(size=(self.N, 50))
        self.tn_data = np.random.normal(size=(self.N, 2, 3))
        # building the sharded dataset writes its files under ./tmp
        shard_options = {
            "root_dir": "tmp",
            "data_shape": (2, 3),
            "data": self.tn_data,
            "chunk_size": 10,
        }
        self.tn_set = TensorShardsDataset(**shard_options)

    def __len__(self):
        """Number of samples."""
        return self.N

    def __getitem__(self, idx):
        """Return one sample: the in-memory row and the row read from the shards."""
        sample = {"pts": self.data[idx]}
        sample["tn_pts"] = self.tn_set[idx]
        return sample

In [10]:
# Stand-alone check of TensorShardsDataset on 101 rows, kept for reference (disabled):
# data = np.random.normal(size=(101, 2, 3))
# test_set = TensorShardsDataset(
#     root_dir="tmp",
#     data_shape=(2, 3),
#     data=data,
#     chunk_size=10,
# )

# Build the wrapper dataset; its constructor creates a TensorShardsDataset with
# root_dir="tmp", which writes the shard files as a side effect.
test_set = ContainerClass()


In [11]:
# List the shard files written to ./tmp (one <key>.pth file per chunk).
! ls tmp

0.pth  10.pth  12.pth  3.pth  5.pth  7.pth  9.pth
1.pth  11.pth  2.pth   4.pth  6.pth  8.pth


In [12]:
# Read every sample back through __getitem__, one index at a time.
for i in range(len(test_set)):
    print(i, test_set[i])

0 {'pts': array([-1.2058697 ,  0.83111627,  0.12447882,  1.10130095,  0.77687311,
       -0.38430274,  0.34382051,  0.7637175 ,  0.76659255, -1.52137733,
        0.74937945, -0.25339663, -1.34079111,  0.53298147,  0.19030117,
       -2.29962896, -0.8084318 , -1.06687583, -0.04722379, -0.38790042,
        1.55024982, -1.71594793, -1.14892766, -0.50490727, -1.24984217,
       -0.51917195,  1.07324671, -1.05217431,  0.01246013,  0.59036452,
       -1.24120882,  0.41423376,  0.9482089 , -0.35374797,  0.3082672 ,
       -0.03692701,  1.63436067,  0.44663464, -1.22290572,  0.65417282,
       -1.17305587, -1.28349406, -1.25147452, -1.89732914,  1.27682038,
       -0.61499589,  1.53424288, -0.75189719, -3.30385951, -0.3313477 ]), 'tn_pts': array([[ 0.51351704,  1.32086599,  0.56806533],
       [ 1.21996124,  0.68438071, -0.94026331]])}
1 {'pts': array([-2.20094916, -0.8122494 , -0.84914282, -0.05091408,  0.04878786,
       -0.33080339, -0.61788725, -0.44952252,  0.04439013,  1.07916118,
      

In [13]:
# Same read-back through a DataLoader with 8 worker processes and batches of 32;
# each batch is a dict whose "pts" / "tn_pts" entries are stacked tensors.
dl = DataLoader(test_set, batch_size=32, shuffle=False, num_workers=8)
for (i, batch) in enumerate(dl):
    print(i, batch, batch["pts"].shape, batch["tn_pts"].shape)

0 {'pts': tensor([[-1.2059e+00,  8.3112e-01,  1.2448e-01,  ..., -7.5190e-01,
         -3.3039e+00, -3.3135e-01],
        [-2.2009e+00, -8.1225e-01, -8.4914e-01,  ..., -7.5008e-01,
         -3.9917e-01,  3.2667e+00],
        [-1.4958e+00, -1.3279e+00, -1.2832e-01,  ..., -1.6673e+00,
          9.1057e-01, -9.6577e-01],
        ...,
        [ 6.1330e-01, -1.8686e-01, -7.7286e-01,  ...,  3.1238e+00,
         -1.2068e+00, -1.7168e-02],
        [-1.4547e+00,  4.3747e-01,  1.2538e+00,  ..., -7.0354e-01,
         -9.8015e-01,  5.3921e-01],
        [ 3.5711e-01, -8.6248e-01,  2.3826e+00,  ..., -1.2423e-04,
         -6.8108e-01,  1.9616e-01]], dtype=torch.float64), 'tn_pts': tensor([[[ 0.5135,  1.3209,  0.5681],
         [ 1.2200,  0.6844, -0.9403]],

        [[ 1.3947,  0.2708,  1.7501],
         [-1.6482,  0.1413,  0.0391]],

        [[-0.2225, -0.3960, -0.4779],
         [-1.1121, -2.3064, -0.3800]],

        [[-0.7562, -0.7986,  0.2870],
         [-0.8061,  0.0878,  2.2982]],

        [[ 0.2

In [1]:
# Cleanup: uncomment to delete the generated shard directory.
# ! rm -rf tmp

In [6]:
# NOTE(review): the recorded output below (101) does not match ContainerClass.N (122);
# it presumably comes from an earlier run of the 101-row stand-alone test - re-run to confirm.
len(test_set)

101

In [8]:
# Scratch check: star-unpacking a shape tuple inside another tuple, the same
# construct TensorShardsDataset uses to allocate its chunk buffer.
np.zeros((2, *(3,)))

array([[0., 0., 0.],
       [0., 0., 0.]])

In [13]:
# Scratch check: walk 47 indices in consecutive groups of 5 (the last group is short).
for i in range(0, 47, 5):
    x = [j for j in range(i, min(i + 5, 47))]
    print(i, x)
    

0 [0, 1, 2, 3, 4]
5 [5, 6, 7, 8, 9]
10 [10, 11, 12, 13, 14]
15 [15, 16, 17, 18, 19]
20 [20, 21, 22, 23, 24]
25 [25, 26, 27, 28, 29]
30 [30, 31, 32, 33, 34]
35 [35, 36, 37, 38, 39]
40 [40, 41, 42, 43, 44]
45 [45, 46]
