In [45]:
from torch.utils.data import Dataset
import torch
import numpy as np
import pickle
import zlib
from collections import MutableMapping

class CompressedDict(MutableMapping):

    def __init__(self):
        self.store = dict()

    def __getitem__(self, key):
        return pickle.loads(zlib.decompress(self.store[key]))

    def __setitem__(self, key, value):
        self.store[key] = zlib.compress(pickle.dumps(value))

    def __delitem__(self, key):
        del self.store[key]

    def __iter__(self):
        return iter(self.store)

    def __len__(self):
        return len(self.store)

    def __keytransform__(self, key):
        return key

class PrefDB(Dataset):
    """
    A circular database of preferences about pairs of segments.
    For each preference, we store the preference itself
    (mu in the paper) and the two segments the preference refers to.
    Segments are stored with deduplication - so that if multiple
    preferences refer to the same segment, the segment is only stored once.
    """

    def __init__(self, maxlen):
        self.segments = CompressedDict()
        self.seg_refs = {}
        self.prefs = []
        self.maxlen = maxlen

    def append(self, s1, s2, pref):
        k1 = hash(np.array(s1).tobytes())
        k2 = hash(np.array(s2).tobytes())

        for k, s in zip([k1, k2], [s1, s2]):
            if k not in self.segments.keys():
                self.segments[k] = s
                self.seg_refs[k] = 1
            else:
                self.seg_refs[k] += 1

        tup = (k1, k2, pref)
        self.prefs.append(tup)

        if len(self.prefs) > self.maxlen:
            self.del_first()

    def del_first(self):
        self.del_pref(0)

    def del_pref(self, n):
        if n >= len(self.prefs):
            raise IndexError("Preference {} doesn't exist".format(n))
        k1, k2, _ = self.prefs[n]
        for k in [k1, k2]:
            if self.seg_refs[k] == 1:
                del self.segments[k]
                del self.seg_refs[k]
            else:
                self.seg_refs[k] -= 1
        del self.prefs[n]

    def __len__(self):
        return len(self.prefs)
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        s1s = s2s = prefs = []
        
        k1, k2, pref = self.prefs[idx]

        return self.segments[k1], self.segments[k2], pref

In [46]:
db = PrefDB(maxlen=100)
for i in range(100):
    db.append(np.zeros((32, 84, 84, 3)), np.zeros((32, 84, 84, 3)), 0)

In [49]:
val_loader = torch.utils.data.DataLoader(
    db,
    batch_size=8,
    shuffle=False,
    num_workers=8
)

for s1s, s2s, prefs in val_loader:
    #print(np.array(s1s).shape)
    print('b')

b
b
b
b
b
b
b
b
b
b
b
b
b


In [53]:
prefs.shape

torch.Size([4])

In [54]:
prefs    
#start_steps = self.n_steps
#start_time = time.time()

#for batch in db
#    train_step(batch, prefs_train)
#    self.n_steps += 1
#        if self.n_steps and self.n_steps % val_interval == 0:
#            self.val_step(prefs_val)
#            
#        end_time = time.time()
#        end_steps = self.n_steps
#        rate = (end_steps - start_steps) / (end_time - start_time)
#        #log rate (steps per second)

tensor([0, 0, 0, 0])

In [13]:
len(db)

100

In [2]:
a = [1,2,3]
a[[1,2]]

TypeError: list indices must be integers or slices, not list