In [1]:
# Reload edited project modules (e.g. `qml.*`) automatically before each cell runs,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2


In [2]:
import os
from pathlib import Path

# Work from two directories above the notebook (presumably the repository root,
# so that the `qml` package and relative paths resolve — TODO confirm).
# The launch directory is captured only once per kernel, so re-running this
# cell is idempotent instead of climbing two more levels each time.
NOTEBOOK_DIR = globals().get("NOTEBOOK_DIR", Path.cwd())
os.chdir(NOTEBOOK_DIR / ".." / "..")


In [8]:
import numpy as np
import matplotlib.pyplot as plt

from numpy.typing import NDArray
from typing import Callable


In [6]:
from qml.tools.random import XRandomGenerator
from qml.tools.typing import Vector


In [14]:
# Model dimensions. Meanings are inferred from the names only — TODO confirm:
# nq presumably the number of qubits, nx / ny presumably the input / output
# dimensionality; ng still needs documenting.
nq, ng = 2, 3
nx, ny = 1, 1

# Number of samples in the training and validation datasets.
train_db_size, validate_db_size = 10, 10

# Maximum number of samples per batch yielded by the DataLoader.
batch_size = 4


In [20]:
class Dataset:
    """Container of paired inputs ``xs`` and targets ``ys``.

    Both are converted to numpy arrays on construction. The public accessors
    return copies, so callers cannot mutate the stored data through them.
    """

    def __init__(self, xs: list[Vector] | NDArray, ys: list[Vector] | NDArray):
        """
        Args:
            xs: inputs, one entry per sample.
            ys: targets, one entry per sample (same count as ``xs``).

        Raises:
            ValueError: if ``xs`` and ``ys`` hold a different number of samples.
        """
        self._xs: NDArray = np.asarray(xs)
        self._ys: NDArray = np.asarray(ys)
        # Reject mismatched pairs here instead of failing later inside a DataLoader.
        if len(self._xs) != len(self._ys):
            raise ValueError(
                f"xs and ys must have the same length, got {len(self._xs)} and {len(self._ys)}"
            )

    def __len__(self):
        """Number of samples (length of the first axis of ``xs``)."""
        return len(self._xs)

    @property
    def size(self) -> int:
        """Alias for ``len(self)``."""
        return len(self)

    @property
    def xs(self) -> NDArray:
        """A copy of the input array."""
        return self._xs.copy()

    @property
    def ys(self) -> NDArray:
        """A copy of the target array."""
        return self._ys.copy()

    @property
    def data(self) -> tuple[NDArray, NDArray]:
        """Copies of ``(xs, ys)``."""
        return self.xs, self.ys

def generate_dataset(
    num_data: int,
    func: Callable,
    seed: int | None = None,
    low: float = -1.0,
    high: float = 1.0,
) -> Dataset:
    """Sample inputs uniformly between ``low`` and ``high`` and label them with ``func``.

    Args:
        num_data: number of samples to draw.
        func: target function. It is called once on the whole array of inputs,
            so it must accept and return arrays (e.g. a numpy expression).
        seed: seed for ``XRandomGenerator``; ``None`` defers to its default seeding.
        low: lower bound of the sampling interval (default reproduces the original -1).
        high: upper bound of the sampling interval (default reproduces the original 1).
            Presumably half-open ``[low, high)`` like numpy's ``Generator.uniform``.

    Returns:
        A ``Dataset`` with ``num_data`` inputs ``xs`` and targets ``ys = func(xs)``.
    """
    rng = XRandomGenerator(seed)

    xs = rng.uniform(low, high, num_data)
    ys = func(xs)
    return Dataset(xs, ys)


In [25]:
class DLIter:
    """Single-pass iterator over the mini-batches of one DataLoader epoch.

    Each step yields ``(xs[idx], ys[idx])`` for the next index array in
    ``indices``.
    """

    def __init__(self, xs: NDArray, ys: NDArray, indices: list[NDArray]):
        """
        Args:
            xs: full input array.
            ys: full target array, indexed in parallel with ``xs``.
            indices: one integer index array per batch, consumed in order.
        """
        self._xs = xs
        self._ys = ys
        self._indices = indices
        self._iter_counter = 0

    def __iter__(self) -> "DLIter":
        # Iterator protocol: an iterator must also be iterable (return itself),
        # otherwise `iter(it)`, `list(it)` and `for b in iter(loader)` raise TypeError.
        return self

    def __next__(self) -> tuple[NDArray, NDArray]:
        """Return the next ``(batch_xs, batch_ys)`` pair.

        Raises:
            StopIteration: once every batch has been yielded.
        """
        if self._iter_counter >= len(self._indices):
            raise StopIteration()
        idx = self._indices[self._iter_counter]
        bxs = self._xs[idx]
        bys = self._ys[idx]
        self._iter_counter += 1
        return bxs, bys

    def __len__(self) -> int:
        """Total number of batches in the epoch (not the number remaining)."""
        return len(self._indices)


class DataLoader:

    def __init__(self, xs: NDArray, ys: NDArray, batch_size: int, shuffle: bool = True, seed: int = None):
        assert len(xs) == len(ys)
        assert batch_size > 0
        self._xs = xs
        self._ys = ys
        self.size = len(xs)
        self._batch_size = batch_size
        self._shuffle = shuffle
        self.rng = XRandomGenerator(seed)
    
    def __iter__(self):
        idx = np.arange(self.size)
        if self._shuffle:
            idx = self.rng.permutation(idx)
        idx = [
            idx[i * self._batch_size:(i + 1) * self._batch_size]
            for i in range(int(np.ceil(self.size / self._batch_size)))
        ]
        return DLIter(self._xs, self._ys, idx)
    
    @classmethod
    def from_dataset(cls, dataset: Dataset, batch_size:int, shuffle: bool = False) -> "DataLoader":
        return cls(
            dataset.xs, dataset.ys,
            batch_size, shuffle=shuffle
        )


In [26]:
# Target function the model should learn: y = sin(2x).
def func(x):
    return np.sin(2 * x)


# Explicit seed so a fresh Restart & Run All regenerates the same training set.
TRAIN_SEED = 0
dataset = generate_dataset(train_db_size, func, seed=TRAIN_SEED)
dataset.xs


array([ 0.9603206 ,  0.56766996, -0.08822895,  0.89218968,  0.99447596,
       -0.31054606,  0.22458331, -0.02883838, -0.13680131, -0.64172513])

In [27]:
# Training loader. shuffle=False is from_dataset's default; it is spelled out so
# the deterministic batch order printed below is an explicit choice.
dataloader = DataLoader.from_dataset(dataset, batch_size=batch_size, shuffle=False)


In [28]:
# Each batch is an (xs, ys) tuple of arrays; the last batch is smaller when the
# dataset size is not a multiple of batch_size.
for batch in dataloader:
    print(batch)


(array([ 0.9603206 ,  0.56766996, -0.08822895,  0.89218968]), array([ 0.93942589,  0.90667761, -0.17554359,  0.97727772]))
(array([ 0.99447596, -0.31054606,  0.22458331, -0.02883838]), array([ 0.91383946, -0.58192367,  0.43421496, -0.05764478]))
(array([-0.13680131, -0.64172513]), array([-0.27020181, -0.9589994 ]))
