Training and Validation Loss

In [None]:
# Training function with train/validation loss tracking
def train_pinn(model, X_train, Y_train, X_val, Y_val, epochs=20000, learning_rate=1e-5):
    """Fit ``model`` to ``Y_train`` with full-batch Adam, tracking train and validation MSE.

    Each epoch performs exactly one gradient step on the whole training set
    (no mini-batching). Both losses recorded for a given epoch are evaluated
    with the same weights, namely the weights *before* that epoch's update,
    so the two histories can be compared point by point.

    Args:
        model: Callable exposing ``trainable_variables`` (e.g. a Keras model).
            ``model(X) - Y`` must be well defined. NOTE(review): shapes are
            assumed to match exactly; a (n, 1) vs (n,) mismatch would silently
            broadcast to (n, n). Confirm against the PINN definition.
        X_train, Y_train: Training inputs and targets.
        X_val, Y_val: Validation inputs and targets. They are never used to
            compute gradients.
        epochs: Number of full-batch optimisation steps.
        learning_rate: Adam step size.

    Returns:
        Tuple ``(train_loss_history, val_loss_history)``. Each is a list of
        per-epoch mean-squared-error values of length ``epochs``.
    """
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    train_loss_history = []
    val_loss_history = []

    for epoch in range(epochs):
        with tf.GradientTape() as tape:
            train_loss = tf.reduce_mean(tf.square(model(X_train) - Y_train))
        gradients = tape.gradient(train_loss, model.trainable_variables)

        # Evaluate validation loss *before* the update so it reflects the same
        # weights as train_loss. It is computed outside the tape, so it has no
        # influence on the gradients.
        val_loss = tf.reduce_mean(tf.square(model(X_val) - Y_val))

        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        train_value = train_loss.numpy()
        val_value = val_loss.numpy()
        train_loss_history.append(train_value)
        val_loss_history.append(val_value)

        # Log every 100 epochs, and always on the last epoch so the final
        # losses are visible.
        if epoch % 100 == 0 or epoch == epochs - 1:
            print(f"Epoch {epoch}, Train Loss: {train_value}, Val Loss: {val_value}")

    return train_loss_history, val_loss_history

# Experiment configuration: model parameters, initial state and time grid.
params = (0.03, 0.01, 0.03, 0.03, 0.01, 0.095, 0.90, 0.03)
initial_conditions = [100, 5, 1, 0.0]
t = np.linspace(0, 50, 500)

# Simulate on the raw time grid, then rescale the trajectories and the time
# column. The returned min/max values are kept for mapping results back later.
synthetic_data, y_min, y_max = normalize(generate_data(params, initial_conditions, t))
t, t_min, t_max = normalize(t.reshape(-1, 1))

# Chronological 80/20 hold-out. The first 80% of time points are used for
# training; the validation points are the remaining, later time points.
split_idx = int(0.8 * len(t))
X_train, Y_train = t[:split_idx], synthetic_data[:split_idx]
X_val, Y_val = t[split_idx:], synthetic_data[split_idx:]

# Cast every split to a float32 tensor for TensorFlow.
X_train, Y_train, X_val, Y_val = (
    tf.convert_to_tensor(arr, dtype=tf.float32)
    for arr in (X_train, Y_train, X_val, Y_val)
)

# Build a fresh PINN and train it, collecting per-epoch losses.
model = PINN()
train_loss_history, val_loss_history = train_pinn(
    model, X_train, Y_train, X_val, Y_val, epochs=20000, learning_rate=1e-5
)

# Visualise both loss curves on a shared axis.
plt.figure(figsize=(8, 5))
plt.plot(train_loss_history, color="b", label="Training Loss")
plt.plot(val_loss_history, color="r", label="Validation Loss")
plt.gca().set(xlabel="Epochs", ylabel="Loss", title="Training & Validation Loss")
plt.legend()
plt.grid(True)