In [1]:
import numpy as np
from types import SimpleNamespace
from datasets import load_dataset
from tinygrad import Tensor
from random import randint
from fastprogress.fastprogress import master_bar, progress_bar
from multiprocessing import Pool

In [2]:
# Both splits come from the same Hugging Face dataset repo; only the split name differs.
dataset_train, dataset_valid = (
    load_dataset("danjacobellis/LSDIR_540", split=split_name)
    for split_name in ("train", "validation")
)

Resolving data files:   0%|          | 0/89 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/89 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/85 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/89 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/89 [00:00<?, ?it/s]

In [3]:
# Run-wide knobs, read elsewhere in the notebook as ``config.<name>``.
config = SimpleNamespace(
    epochs=250,       # passes of the outer training loop
    batch_size=64,    # indices drawn per ReplacementSampler batch
    crop=256,         # side length (pixels) of the square random training crops
    num_workers=8,    # processes in the data-loading multiprocessing Pool
)

In [4]:
def random_crop(im, w, h):
    """Return a w x h crop taken at a uniformly random position inside ``im``.

    ``im`` is expected to be PIL-like: ``im.size`` is ``(width, height)`` and
    ``im.crop(box)`` accepts a ``(left, top, right, bottom)`` box.

    Raises AssertionError when the requested window is wider or taller than
    the image (only while assertions are enabled).
    """
    img_w, img_h = im.size
    # The window must fit inside the image in both dimensions.
    assert w <= img_w and h <= img_h
    # randint is inclusive on both ends, so the window may touch the far edges.
    left = randint(0, img_w - w)
    top = randint(0, img_h - h)
    return im.crop((left, top, left + w, top + h))
    
def center_crop(im, w, h):
    """Return the w x h window centred in ``im`` (PIL-like ``size``/``crop`` API).

    Floor division keeps the box exactly ``w`` wide and ``h`` tall even when
    the size difference is odd. The extra pixel then falls on the right/bottom.
    There is no bounds check: a window larger than the image produces a box
    that extends past the image edges and is passed to ``im.crop`` unchanged.
    """
    img_w, img_h = im.size
    left, top = (img_w - w) // 2, (img_h - h) // 2
    right, bottom = (img_w + w) // 2, (img_h + h) // 2
    return im.crop((left, top, right, bottom))

def process_sample(index):
    """Load training example ``index`` and return a random square crop as an array.

    This runs inside the data-loading Pool workers. It reads the module globals
    ``dataset_train`` and ``config`` and the sibling helper ``random_crop``.
    The crop side is ``config.crop``. The result comes from ``np.array`` on the
    cropped image, presumably ``(crop, crop, 3)`` uint8 for RGB images
    (TODO confirm every image in the dataset is RGB, or ``np.stack`` will fail).
    """
    # ReplacementSampler yields numpy integers; hand the dataset a built-in int.
    record = dataset_train[int(index)]
    side = config.crop
    return np.array(random_crop(record['image'], side, side))

    
class ReplacementSampler:
    """Iterate over batches of dataset indices drawn uniformly *with replacement*.

    One pass yields ``len(self) == ceil(total_samples_needed / batch_size)``
    batches. Each batch is a NumPy integer array of exactly ``batch_size``
    indices in ``[0, total_samples_needed)``. The last batch is not truncated,
    so a pass draws at least ``total_samples_needed`` indices.

    Parameters
    ----------
    total_samples_needed : int
        Size of the index population, and the minimum number of indices per pass.
    batch_size : int
        Indices per batch; must be positive.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. It must provide ``choice(a, size=, replace=)``.
        If omitted, NumPy's legacy global RNG (``np.random``) is used, exactly
        as before. Pass e.g. ``np.random.default_rng(seed)`` for reproducible
        batches.

    Raises
    ------
    ValueError
        If ``batch_size`` is not positive or ``total_samples_needed`` is negative.
    """

    def __init__(self, total_samples_needed, batch_size, rng=None):
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        if total_samples_needed < 0:
            raise ValueError(
                f"total_samples_needed must be non-negative, got {total_samples_needed}"
            )
        self.total_samples_needed = total_samples_needed
        self.batch_size = batch_size
        # Integer ceiling division, which avoids float rounding for large counts.
        self.number_of_batches = int(-(-total_samples_needed // batch_size))
        self.batches_generated = 0
        # None -> use the global np.random at draw time (keeps the old behavior
        # and avoids storing a module object on the instance).
        self.rng = rng

    def __iter__(self):
        # The sampler is its own iterator; starting a new iteration resets the count.
        self.batches_generated = 0
        return self

    def __next__(self):
        if self.batches_generated >= self.number_of_batches:
            raise StopIteration
        source = np.random if self.rng is None else self.rng
        indices = source.choice(
            self.total_samples_needed, size=self.batch_size, replace=True
        )
        self.batches_generated += 1
        return indices

    def __len__(self):
        return self.number_of_batches

In [None]:
# Create ONE worker pool for the whole run. The previous version built a new
# Pool inside the batch loop, which started and tore down config.num_workers
# processes for every single batch and dominated the loading time.
# Workers are started once, so they see the module globals (dataset_train,
# config) as they were when the pool was created.
# NOTE(review): workers must be able to find process_sample. With the "fork"
# start method (Linux default) that works from a notebook. With "spawn"
# (macOS/Windows default), functions defined in the notebook's __main__
# typically cannot be loaded by workers; confirm on the target platform or
# move the helpers into an importable .py module.
with Pool(processes=config.num_workers) as pool:
    mb = master_bar(range(config.epochs))
    for i_epoch in mb:
        # Each "epoch" draws dataset_train.num_rows indices with replacement.
        sampler = ReplacementSampler(
            total_samples_needed=dataset_train.num_rows,
            batch_size=config.batch_size,
        )
        for i_batch, ind in enumerate(progress_bar(sampler, parent=mb)):
            arrays = pool.map(process_sample, ind)
            # Presumably (batch, H, W, C) from np.array(PIL image); TODO confirm RGB.
            batch = np.stack(arrays)
            # Move the last axis to position 1 (NHWC -> NCHW under the assumption above).
            # TODO(review): x is not consumed yet; the model step presumably follows.
            x = Tensor(batch).permute(0, 3, 1, 2)