In [22]:
import os

os.environ["KERAS_BACKEND"] = "torch"

import torch 
import keras 
import numpy as np

To write a custom training loop, we need the following ingredients:

- A model to train, of course.
- An optimizer. You can use either a `keras.optimizers` optimizer or a native PyTorch optimizer from `torch.optim`.
- A loss function. You can use either a `keras.losses` loss or a native PyTorch loss from `torch.nn`.
- A dataset. Any format works: a `tf.data.Dataset`, a PyTorch `DataLoader`, a Python generator, etc.

Let's line them up. We'll use torch-native objects for the data pipeline (a `TensorDataset` wrapped in a `DataLoader`), and Keras objects for the model, the optimizer, and the loss function.

In [36]:
def __init__(self, model, pool):
    """Stray constructor fragment left at module level in this cell.

    FIXME(review): this is a plain module-level function, not a method of any
    class. Its original body called ``super(Model, self).__init__()``, but
    ``Model`` is not defined in any visible cell, so every call failed with a
    confusing ``NameError``. Nothing in the visible cells calls it. Either
    delete it, or move it into the class it was presumably written for
    (perhaps a ``keras.Model`` subclass wrapping ``model`` and ``pool`` --
    confirm).

    Raises:
        NotImplementedError: always, with a message explaining the problem.
    """
    raise NotImplementedError(
        "stray module-level __init__ is not attached to any class; "
        "remove it or move it into the class it was written for"
    )

# Let's consider a simple MNIST model
def get_model():
    """Build a small fully connected MNIST classifier.

    Architecture: a 784-feature input named "digits", two 64-unit ReLU
    hidden layers, and a 10-unit output layer named "predictions" with no
    activation (it emits raw logits).

    Returns:
        keras.Model: a freshly constructed, untrained functional model.
    """
    digits = keras.Input(shape=(784,), name="digits")
    hidden = digits
    # Two identical hidden layers, created in the same order as before so the
    # auto-generated layer names are unchanged.
    for _ in range(2):
        hidden = keras.layers.Dense(64, activation="relu")(hidden)
    logits = keras.layers.Dense(10, name="predictions")(hidden)
    return keras.Model(inputs=digits, outputs=logits)


# Load the MNIST dataset and wrap it in torch DataLoaders.
# Prepare the training dataset.
batch_size = 32
num_classes = 10         # digit classes 0-9; matches get_model's 10-unit output layer
num_val_samples = 10_000  # samples held out from the training split for validation
seed = 1337              # seeds the training DataLoader's shuffle order for reproducible runs

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# Flatten each image into a 784-feature float32 row to match get_model's Input(shape=(784,)).
# NOTE(review): pixel values are not rescaled here; dividing by 255 usually helps optimization.
x_train = np.reshape(x_train, (-1, 784)).astype("float32")
x_test = np.reshape(x_test, (-1, 784)).astype("float32")
# One-hot encode the labels. Keras 3's to_categorical returns float64, so cast
# to float32 to match the float32 inputs/logits; pin num_classes so the
# encoding width never depends on which labels happen to be present.
y_train = keras.utils.to_categorical(y_train, num_classes=num_classes).astype("float32")
y_test = keras.utils.to_categorical(y_test, num_classes=num_classes).astype("float32")

# Reserve the last num_val_samples training samples for validation.
x_val, y_val = x_train[-num_val_samples:], y_train[-num_val_samples:]
x_train, y_train = x_train[:-num_val_samples], y_train[:-num_val_samples]

# Create torch Datasets
train_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(x_train), torch.from_numpy(y_train)
)
val_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(x_val), torch.from_numpy(y_val)
)

# Create DataLoaders for the Datasets. The training loader shuffles with a
# seeded generator so the batch order is reproducible across Restart & Run All.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(seed),
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False
)

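### Next, here's our model, optimizer, and loss function. We use the Keras versions (`keras.optimizers.Adam` and `keras.losses.CategoricalCrossentropy`); native PyTorch equivalents from `torch.optim` and `torch.nn` would also work, as noted above.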

In [42]:
# Build the model, then pair it with a Keras-native optimizer and loss.
model = get_model()
optimizer = keras.optimizers.Adam(learning_rate=0.001)
loss_fn = keras.losses.CategoricalCrossentropy(
    from_logits=True,  # get_model's "predictions" layer has no activation, so outputs are logits
)