In [None]:
import numpy as np

import jax
import jax.numpy as jnp

from flax import linen as nn
from flax.metrics import tensorboard
from flax.training import train_state

from omegaconf import OmegaConf
import optax

from functools import partial

In [None]:
import models
from data import stickman

from train import train_and_evaluate
import tensorflow as tf
import logging

import hydra
from omegaconf import OmegaConf

In [None]:
import os
from hydra import initialize, initialize_config_module, initialize_config_dir, compose

In [None]:
# Hydra-style "key=value" overrides applied on top of the composed config.
overrides = [
    "train.batch_size=100"
]

# Reset Hydra's global state before initializing — presumably so this cell can be
# re-run in the same kernel without Hydra complaining it is already initialized.
hydra.core.global_hydra.GlobalHydra.instance().clear()
# NOTE(review): newer Hydra releases may warn when `version_base` is not passed
# to initialize() — confirm against the installed version.
initialize(config_path="configs")
config = compose(config_name="default.yaml", overrides=overrides)

In [None]:
# Build the train/test splits from the composed config; setup_data is expected
# to return exactly two values (it is unpacked into two names here).
train_ds, test_ds = stickman.setup_data(config)

In [None]:
# Dump the composed config (including the overrides) as YAML for inspection.
print(OmegaConf.to_yaml(config))

In [None]:

class Dataset():
    """Minimal map-style dataset pairing inputs ``x`` with targets ``y``.

    Args:
        x: Sized, indexable container of inputs.
        y: Sized, indexable container of targets, aligned with ``x``.

    Raises:
        ValueError: If ``x`` and ``y`` do not have the same length.
    """

    def __init__(self, x, y):
        # Fail fast: a mismatch would otherwise only surface mid-iteration as an
        # IndexError (y shorter) or as silently ignored targets (y longer).
        if len(x) != len(y):
            raise ValueError(
                f"x and y must have the same length, got {len(x)} and {len(y)}"
            )
        self.x = x
        self.y = y

    def __len__(self):
        """Return the number of samples."""
        return len(self.x)

    def __getitem__(self, i):
        """Return the ``(x[i], y[i])`` pair at index ``i``."""
        return self.x[i], self.y[i]

    def __repr__(self):
        # Gives a useful cell output when the object is displayed bare in a notebook.
        return f"{type(self).__name__}(n={len(self)})"


class Sampler():
    """Yield arrays of dataset indices, one array per batch.

    Args:
        ds: Any sized object; only ``len(ds)`` is used.
        bs: Batch size (must be positive). The last batch may be smaller.
        shuffle: If true, a fresh permutation is drawn for every pass.
        seed: Seed for the NumPy ``Generator`` that drives the shuffling.

    Raises:
        ValueError: If ``bs`` is not positive.
    """

    def __init__(self, ds, bs, shuffle=False, seed=None):
        # A zero step would only fail inside range() once iteration starts, and a
        # negative one would silently yield no batches — reject both up front.
        if bs <= 0:
            raise ValueError(f"bs must be a positive integer, got {bs}")
        self.n = len(ds)
        self.bs = bs
        self.shuffle = shuffle
        self.rng = np.random.default_rng(seed=seed)

    def __iter__(self):
        # Slice from a local: reading self.idxs inside the loop would let a
        # second pass replace the order underneath a pass that is still running,
        # producing duplicated and missing indices.
        idxs = self.rng.permutation(self.n) if self.shuffle else np.arange(self.n)
        self.idxs = idxs  # still exposed for code that inspects the latest order
        for i in range(0, self.n, self.bs):
            yield idxs[i:i+self.bs]


class DataLoader():
    """Turn a dataset plus an index sampler into a stream of stacked batches.

    Each item produced by ``sampler`` is treated as an iterable of indices into
    ``ds``; the indexed samples are gathered and passed through ``collate``.
    """

    def __init__(self, ds, sampler):
        self.ds, self.sampler = ds, sampler

    def __iter__(self):
        for batch_idxs in self.sampler:
            samples = []
            for idx in batch_idxs:
                samples.append(self.ds[idx])
            yield self.collate(samples)

    def collate(self, b):
        """Stack a list of ``(x, y)`` samples into an ``(xs, ys)`` pair of arrays."""
        inputs, targets = zip(*b)
        batch_x = np.stack(inputs)
        batch_y = np.stack(targets)
        return batch_x, batch_y



In [None]:
# Split the training data into inputs and targets; this relies on train_ds
# being a 2-item iterable — TODO confirm against stickman.setup_data.
x_train, y_train = train_ds

In [None]:
# Wrap the training inputs/targets so samples can be fetched by index as (x, y) pairs.
ds = Dataset(x_train, y_train)

In [None]:
# Bare expression: the notebook shows the object's repr as the cell output.
ds

In [None]:
# Seeded, shuffling sampler that yields batches of 3 indices.
# NOTE(review): the batch size is hard-coded here and independent of the
# `train.batch_size` override used when composing the config — confirm intended.
s = Sampler(ds, 3, shuffle=True, seed=42)
dataloader = DataLoader(ds, s)

In [None]:
# Sanity check: grab the first batch of inputs from two consecutive passes over
# the loader. Stop after the first batch instead of collating the whole dataset
# twice; each pass still draws exactly one permutation from the sampler's RNG,
# so batch1/batch2 are the same values a full pass would have produced.

batch1 = None
batch2 = None
for xb, yb in dataloader:
    batch1 = xb
    break

for xb, yb in dataloader:
    batch2 = xb
    break

In [None]:
# Peek at the shape of the first batch's inputs only; `z` (the second element
# of each batch) is unused here.
for x,z in dataloader:
    print(x.shape)
    break

In [None]:
# Keep an explicit iterator so the loader can be advanced manually with next().
iterator = iter(dataloader)

In [None]:
# Drain the iterator by hand: keep calling next() until StopIteration signals
# the end of the pass. The batches themselves are discarded.
while True:
    try:
        next(iterator)
    except StopIteration:
        print("exiting...")
        break
    

In [None]:
# `iterator` was fully drained by the loop above, so another next() raises
# StopIteration. Catch it so the demonstration does not abort "Run All".
try:
    next(iterator)
except StopIteration:
    print("iterator is exhausted")