In [3]:
import pickle
import numpy as np
import torch
import pandas as pd

In [4]:
# Read the slim feature-metadata table and collect the distinct chip ids that
# belong to the "train" split, as a NumPy array of strings.
metadata_file = "/notebooks/data/metadata_parquet/features_metadata_slim.parquet"
metadata_df = pd.read_parquet(metadata_file)
chip_ids = (
    metadata_df.loc[metadata_df["split"] == "train", "chip_id"]
    .unique()
    .astype(np.str_)
)

In [7]:
# Turn the NumPy array of chip ids into a regular Python list.
chip_lst = chip_ids.tolist()

In [9]:
class NumpySerializedList():
    """Read-only sequence that stores its elements pickled in one flat uint8 array.

    Every element is pickled on its own; the resulting byte strings are
    concatenated into ``self._lst`` and the cumulative end offset of each
    element is kept in ``self._addr``. Element ``i`` therefore occupies
    ``self._lst[self._addr[i - 1]:self._addr[i]]`` (starting at offset 0 for
    ``i == 0``).

    Args:
        lst: The picklable objects to store. May be empty.
    """

    def __init__(self, lst: list):
        def _serialize(data):
            # protocol=-1 selects the highest pickle protocol available.
            buffer = pickle.dumps(data, protocol=-1)
            return np.frombuffer(buffer, dtype=np.uint8)

        print(
            "Serializing {} elements to byte tensors and concatenating them all ...".format(
                len(lst)
            )
        )
        serialized = [_serialize(x) for x in lst]
        # Cumulative sizes give the exclusive end offset of each element.
        sizes = np.asarray([len(x) for x in serialized], dtype=np.int64)
        self._addr = np.cumsum(sizes)
        # np.concatenate raises ValueError on an empty sequence, so an empty
        # input list gets an explicit zero-length buffer instead.
        if serialized:
            self._lst = np.concatenate(serialized)
        else:
            self._lst = np.empty(0, dtype=np.uint8)
        print("Serialized dataset takes {:.2f} MiB".format(len(self._lst) / 1024**2))

    def __len__(self):
        """Return the number of stored elements."""
        return len(self._addr)

    def __getitem__(self, idx):
        """Unpickle and return the element at position ``idx``.

        Negative indices count from the end, as with a built-in list.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        size = len(self._addr)
        if idx < 0:
            # Normalise up front; otherwise idx == -size would look up
            # self._addr[-size - 1] and fail although the index is valid.
            idx = idx + size
        if not 0 <= idx < size:
            raise IndexError(
                "index out of range for serialized list of {} elements".format(size)
            )
        start_addr = 0 if idx == 0 else self._addr[idx - 1].item()
        end_addr = self._addr[idx].item()
        payload = memoryview(self._lst[start_addr:end_addr])
        return pickle.loads(payload)


class TorchSerializedList(NumpySerializedList):
    """Variant of ``NumpySerializedList`` that holds its buffers as torch tensors.

    The parent constructor builds the pickled byte buffer and the offset
    table as NumPy arrays; both are then wrapped with ``torch.from_numpy``.

    Args:
        lst: The picklable objects to store.
    """

    def __init__(self, lst: list):
        super().__init__(lst)
        self._addr = torch.from_numpy(self._addr)
        self._lst = torch.from_numpy(self._lst)

    def __getitem__(self, idx):
        """Unpickle and return the element at position ``idx``.

        Negative indices count from the end, as with a built-in list.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        size = len(self._addr)
        if idx < 0:
            # Normalise up front; otherwise idx == -size would look up
            # self._addr[-size - 1] and fail although the index is valid.
            idx = idx + size
        if not 0 <= idx < size:
            raise IndexError(
                "index out of range for serialized list of {} elements".format(size)
            )
        start_addr = 0 if idx == 0 else self._addr[idx - 1].item()
        end_addr = self._addr[idx].item()
        # Go back through a NumPy view so pickle receives a buffer object.
        payload = memoryview(self._lst[start_addr:end_addr].numpy())
        return pickle.loads(payload)


In [10]:
# Pack the chip-id list into the tensor-backed serialized container.
ts_lst = TorchSerializedList(chip_lst)

Serializing 8689 elements to byte tensors and concatenating them all ...
Serialized dataset takes 0.19 MiB


In [12]:
# Sanity check: deserialize the first stored element and print it.
print(ts_lst[0])

0003d2eb
