In [2]:
!pip install -q -U keras==3.11.1
!pip install -q -U grain

In [1]:
import keras
import grain
import numpy as np

# A simple model for demonstration
class SimpleModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = keras.layers.Dense(1, activation="relu")

    def call(self, x):
        return self.dense(x)

# Create a custom data source using PyGrain
class SimpleDataSource(grain.sources.RandomAccessDataSource):
    def __init__(self, num_samples=100):
        super().__init__()
        self.data = np.arange(num_samples, dtype=np.float32).reshape(-1, 1)
        self.labels = self.data * 2 + 1

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index], self.labels[index]

# Create a PyGrain DataLoader
data_source = SimpleDataSource()
dataloader = grain.DataLoader(
    data_source=data_source,
    sampler=grain.samplers.SequentialSampler(len(data_source)),
    operations=[grain.transforms.Batch(batch_size=10, drop_remainder=False)]
)

# Instantiate and compile the Keras model
model = SimpleModel()
model.compile(
    optimizer=keras.optimizers.SGD(learning_rate=0.1),
    loss=keras.losses.MeanSquaredError(),
    metrics=[keras.metrics.MeanAbsoluteError()]
)

# Train the model using the PyGrain DataLoader
print("Training with PyGrain DataLoader...")
model.fit(dataloader, epochs=2)

print("Evaluation with PyGrain DataLoader...")
model.evaluate(dataloader)


Training with PyGrain DataLoader...
Epoch 1/2
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 6ms/step - loss: 14819.5654 - mean_absolute_error: 109.1741
Epoch 2/2




[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - loss: 13333.0000 - mean_absolute_error: 100.0000
Evaluation with PyGrain DataLoader...
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - loss: 13333.0000 - mean_absolute_error: 100.0000


[13333.0, 100.0]