<a href="https://colab.research.google.com/github/Zak-Rey/CNN/blob/main/Customizing_fit().ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf
import keras

In [2]:
class CustomModel(keras.Model):
  """Keras model whose ``fit()`` runs a hand-written training step.

  The loss and metrics are still the ones passed to ``compile()``; only the
  mechanics of a single optimisation step are spelled out in ``train_step``.
  """

  def train_step(self, data):
    """Run one optimisation step on a batch and return the current metric values.

    Args:
      data: the batch yielded by ``fit()``: ``(x, y)``, or
        ``(x, y, sample_weight)`` when ``fit()`` is given ``sample_weight``
        or ``class_weight``.

    Returns:
      A dict mapping each tracked metric's name (including ``"loss"``) to its
      running value; ``fit()`` uses it for the progress bar and History.
    """
    # Unpack the batch; an optional third element carries per-sample weights.
    # A plain `x, y = data` would fail as soon as weights are supplied.
    if len(data) == 3:
      x, y, sample_weight = data
    else:
      x, y = data
      sample_weight = None

    with tf.GradientTape() as tape:
      y_pred = self(x, training=True)  # forward pass in training mode
      # Uses the loss configured in compile().
      loss = self.compute_loss(y=y, y_pred=y_pred, sample_weight=sample_weight)

    trainable_vars = self.trainable_variables
    gradients = tape.gradient(loss, trainable_vars)
    self.optimizer.apply_gradients(zip(gradients, trainable_vars))

    # The tracker named "loss" receives the scalar loss; every other metric
    # (those from compile(metrics=...)) is updated from targets/predictions.
    for metric in self.metrics:
      if metric.name == "loss":
        metric.update_state(loss)
      else:
        metric.update_state(y, y_pred, sample_weight=sample_weight)

    return {m.name: m.result() for m in self.metrics}


In [4]:
import numpy as np

# Seed once so weight initialisation and the toy data are reproducible
# under Restart & Run All.
SEED = 42
keras.utils.set_random_seed(SEED)
rng = np.random.default_rng(SEED)

# Construct and compile an instance of CustomModel: one Dense layer mapping
# 32 input features to a single output. The "mse" loss and "mae" metric set
# here are what the custom train_step uses via compute_loss / self.metrics.
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# Toy regression data: 1000 samples of 32 uniform features, 1-d uniform target.
x = rng.random((1000, 32))
y = rng.random((1000, 1))
model.fit(x, y, epochs=3)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<keras.src.callbacks.History at 0x7905ad74ee60>

In [6]:
class CustomModel(keras.Model):
  """Keras model that computes its own loss and tracks its own metrics in ``fit()``.

  It is meant to be compiled with only an optimizer: the MSE loss and the
  "loss"/"mae" trackers are created here and updated manually in
  ``train_step``.
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # Running average of the batch losses, reported as "loss".
    self.loss_tracker = keras.metrics.Mean(name="loss")
    # Running mean absolute error, reported as "mae".
    self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")
    # Loss *object* rather than the bare function: it reduces to a scalar mean
    # over the batch, whereas keras.losses.mean_squared_error returns one value
    # per sample (whose gradient would be the batch sum, scaled by batch size).
    self.loss_fn = keras.losses.MeanSquaredError()

  def train_step(self, data):
    """Run one optimisation step with the model's own loss and metrics.

    Args:
      data: the batch yielded by ``fit()``: ``(x, y)``, or
        ``(x, y, sample_weight)`` when ``fit()`` is given sample weights.

    Returns:
      Dict with the running ``"loss"`` and ``"mae"`` values.
    """
    if len(data) == 3:
      x, y, sample_weight = data
    else:
      x, y = data
      sample_weight = None

    with tf.GradientTape() as tape:
      y_pred = self(x, training=True)  # forward pass in training mode
      loss = self.loss_fn(y, y_pred, sample_weight=sample_weight)

    trainable_vars = self.trainable_variables
    gradients = tape.gradient(loss, trainable_vars)

    self.optimizer.apply_gradients(zip(gradients, trainable_vars))

    self.loss_tracker.update_state(loss)
    self.mae_metric.update_state(y, y_pred, sample_weight=sample_weight)
    return {"loss": self.loss_tracker.result(), "mae": self.mae_metric.result()}

  @property
  def metrics(self):
    # Listing the trackers here lets fit()/evaluate() reset them automatically
    # at the start of each epoch instead of accumulating across epochs.
    return [self.loss_tracker, self.mae_metric]

# Instance of CustomModel (the version above that computes its own loss and
# metrics), so compile() only needs an optimizer.
# Seed first so weight initialisation and the toy data are reproducible.
SEED = 42
keras.utils.set_random_seed(SEED)
rng = np.random.default_rng(SEED)

inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)

model.compile(optimizer="adam")

x = rng.random((1000, 32))
y = rng.random((1000, 1))
model.fit(x, y, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.src.callbacks.History at 0x7905ad9fedd0>