## LSTM Distilled with Keras

In [11]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import GRU, Dense
from tabulate import tabulate
import matplotlib.pyplot as plt

In [12]:
# Set seeds for reproducibility
np.random.seed(42)
tf.random.set_seed(42)

# Generate our sample data
temperature = np.array([18.2, 19.5, 20.1, 22.4, 23.8, 25.0, 23.2, 21.5, 19.8, 17.5])
humidity = np.array([65.2, 62.8, 58.5, 55.0, 45.2, 42.1, 48.5, 52.3, 60.5, 67.8])
wind_speed = np.array([5.2, 6.8, 8.5, 10.2, 12.5, 14.8, 13.2, 11.5, 9.2, 6.5])
X = np.column_stack((temperature, humidity, wind_speed))

# Target: power consumption (kWh)
y = 2.5 * temperature - 0.5 * humidity + 1.2 * wind_speed + np.random.normal(0, 5, 10)

# Normalize the data
from sklearn.preprocessing import MinMaxScaler

# Create scalers for input and output data
X_scaler = MinMaxScaler(feature_range=(0, 1))
y_scaler = MinMaxScaler(feature_range=(0, 1))

# Fit and transform the data
X_normalized = X_scaler.fit_transform(X)
y_normalized = y_scaler.fit_transform(y.reshape(-1, 1)).flatten()

# Create sequences with lookback of 2 (predict the step after the sequence)
def create_sequences(X, y, lookback=2):
    X_seq, y_seq = [], []
    for i in range(len(X) - lookback):
        X_seq.append(X[i:i+lookback])
        y_seq.append(y[i+lookback])
    return np.array(X_seq), np.array(y_seq)

X_sequences, y_sequences = create_sequences(X_normalized, y_normalized, lookback=2)

print("=" * 80)
print("GRU IMPLEMENTATION IN KERAS WITH STATE EXTRACTION")
print("=" * 80)
print(f"Data shape: {X.shape} (10 days with 3 features each)")
print(f"Normalized sequence shape: {X_sequences.shape} (8 sequences with lookback=2)")
print("=" * 80)

# Define GRU dimensions
n_features = 3  # Temperature, humidity, wind speed
n_hidden = 4    # Number of GRU units (same as previous example)
lookback = 2    # Same as previous example

GRU IMPLEMENTATION IN KERAS WITH STATE EXTRACTION
Data shape: (10, 3) (10 days with 3 features each)
Normalized sequence shape: (8, 2, 3) (8 sequences with lookback=2)


In [13]:
# Create and compile the model
model = Sequential([
    tf.keras.layers.Input(shape=(lookback, n_features)),
    GRU(n_hidden, 
        return_sequences=False, stateful=False),
    Dense(1)
])

model.compile(optimizer='adam', loss='mse')

# Print model summary
print("\nModel Architecture:")
model.summary()


Model Architecture:


In [16]:
# Extract initial weights from the GRU layer
gru_layer = model.layers[0]
gru_weights = gru_layer.get_weights()

# In Keras GRU, weights are organized as:
# [kernel (input weights), recurrent_kernel (hidden state weights), bias]
kernel = gru_weights[0]         # Input weights (shape: [n_features, n_hidden*3])
recurrent_kernel = gru_weights[1]   # Hidden state weights (shape: [n_hidden, n_hidden*3])
bias = gru_weights[2]           # Bias terms (shape: [n_hidden*3])

# Keras concatenates all gate weights, so we need to slice them
# The order is: z (update), r (reset), h (candidate)
units = n_hidden
input_dim = n_features

# Extract weights for each gate (reshape to match our earlier format)
# Input weights
W_z_x = kernel[:, :units]
W_r_x = kernel[:, units:units*2]
W_h_x = kernel[:, units*2:units*3]

# Recurrent weights (from previous hidden state)
W_z_h = recurrent_kernel[:, :units]
W_r_h = recurrent_kernel[:, units:units*2]
W_h_h = recurrent_kernel[:, units*2:units*3]

# Bias terms
b_z = bias[:units]
b_r = bias[units:units*2]
b_h = bias[units*2:units*3]

print("+"*62)
print("Extracted GRU Weights:".center(62))
print("+"*62)
print(f"{'Update gate weights (W_z_x):':<35} {W_z_x.shape}")
print(f"{'Reset gate weights (W_r_x):':<35} {W_r_x.shape}")
print(f"{'Candidate weights (W_h_x):':<35} {W_h_x.shape}")
print(f"{'Recurrent weights shapes:':<35} {W_z_h.shape, W_r_h.shape, W_h_h.shape}")
print('+'*62)

++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
                    Extracted GRU Weights:                    
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Update gate weights (W_z_x):        (3, 4)
Reset gate weights (W_r_x):         (3, 4)
Candidate weights (W_h_x):          (3, 4)
Recurrent weights shapes:           ((4, 4), (4, 4), (4, 4))
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++


In [14]:
# Extract initial weights from the GRU layer
gru_layer = model.layers[0]
gru_weights = gru_layer.get_weights()

# In Keras GRU, weights are organized differently than our implementation expects
kernel = gru_weights[0]         # Input weights (shape: [n_features, n_hidden*3])
recurrent_kernel = gru_weights[1]   # Hidden state weights (shape: [n_hidden, n_hidden*3])
bias = gru_weights[2]           # Bias terms (shape: [n_hidden*3])

# The order in Keras GRU is: z (update), r (reset), h (candidate)
units = n_hidden

# Input weights for each part of the GRU
W_z_x = kernel[:, :units]                # Update gate input weights
W_r_x = kernel[:, units:units*2]         # Reset gate input weights
W_h_x = kernel[:, units*2:units*3]       # Candidate hidden state input weights

# Recurrent weights for each part of the GRU
W_z_h = recurrent_kernel[:, :units]      # Update gate recurrent weights
W_r_h = recurrent_kernel[:, units:units*2]  # Reset gate recurrent weights
W_h_h = recurrent_kernel[:, units*2:units*3]  # Candidate recurrent weights

# Bias terms
b_z = bias[:units]              # Update gate bias
b_r = bias[units:units*2]       # Reset gate bias
b_h = bias[units*2:units*3]     # Candidate hidden state bias

In [15]:
# Define activation functions 
def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def tanh(x):
    return np.tanh(x)

# Custom forward pass to replicate Keras's internal GRU computation
def gru_forward_pass(X_sequences, weights):
    kernel, recurrent_kernel, bias = weights
    n_sequences = X_sequences.shape[0]
    predictions = np.zeros(n_sequences)
    
    # Initialize states (GRU only has hidden state)
    h_t = np.zeros(n_hidden)
    
    all_states = []
    
    for seq_idx, sequence in enumerate(X_sequences):
        print(f"\n{'=' * 30} SEQUENCE {seq_idx+1} {'=' * 30}")
        print(f"Days {seq_idx+1}-{seq_idx+2} → Predicting Day {seq_idx+3}")
        
        sequence_states = []
        
        for t in range(len(sequence)):
            x_t = sequence[t]
            day_num = seq_idx + t + 1
            
            # Calculate gate inputs
            # In Keras, the computations are done with the weights organized differently:
            # Input contribution to gates

            z_input = np.dot(x_t, W_z_x)         # shape should be [n_hidden]
            r_input = np.dot(x_t, W_r_x)         # shape should be [n_hidden]
            h_input = np.dot(x_t, W_h_x)         # shape should be [n_hidden]

            
            
            
            # Hidden state contribution to gates
            z_hidden = np.dot(h_t, W_z_h)        # shape should be [n_hidden]
            r_hidden = np.dot(h_t, W_r_h)        # shape should be [n_hidden]

            
            # Gate activations with bias - all these should have shape [n_hidden]
            z_t = sigmoid(z_input + z_hidden + b_z)
            r_t = sigmoid(r_input + r_hidden + b_r)
            
            # Calculate reset hidden state
            reset_h_t = r_t * h_t
            
            # Calculate candidate hidden state
            # Note: For the candidate, the recurrent calculation is done after applying the reset gate
            h_hidden = np.dot(reset_h_t, W_h_h)
            h_tilde = tanh(h_input + h_hidden + b_h)
            
            # Update hidden state
            h_t_prev = h_t.copy()
            h_t = (1 - z_t) * h_t + z_t * h_tilde
            
            # Store states
            sequence_states.append({
                'day': day_num,
                'input': x_t,
                'update_gate': z_t,
                'reset_gate': r_t,
                'candidate_hidden': h_tilde,
                'hidden_state': h_t.copy()
            })
            
            # Print gate values and states for this time step
            print(f"\n--- Time step {t+1} (Day {day_num}) ---")
            print("Input Features:")
            input_table = [
                ["Temperature", f"{x_t[0]:.2f}"],
                ["Humidity", f"{x_t[1]:.2f}"],
                ["Wind Speed", f"{x_t[2]:.2f}"]
            ]
            print(tabulate(input_table, headers=["Feature", "Value"], tablefmt="grid"))
            
            print("\nGate Values:")
            gates_table = [
                ["Update Gate (z_t)", f"{z_t}"],
                ["Reset Gate (r_t)", f"{r_t}"],
                ["Candidate Hidden State (h_tilde)", f"{h_tilde}"]
            ]
            print(tabulate(gates_table, headers=["Gate", "Values"], tablefmt="grid"))
            
            print("\nState Updates:")
            state_table = [
                ["Previous Hidden State", f"{h_t_prev}"],
                ["New Hidden State", f"{h_t}"],
                ["Change in Hidden State", f"{h_t - h_t_prev}"]
            ]
            print(tabulate(state_table, headers=["State", "Values"], tablefmt="grid"))
        
        # Predict using the final hidden state (replicating Dense layer)
        W_y = model.layers[1].get_weights()[0]
        b_y = model.layers[1].get_weights()[1]
        y_pred = np.dot(h_t, W_y) + b_y
        predictions[seq_idx] = y_pred[0]
        
        # Print prediction vs actual
        # Convert back to original scale for clearer comparison
        original_pred = y_scaler.inverse_transform([[y_pred[0]]])[0][0]
        original_actual = y_scaler.inverse_transform([[y_sequences[seq_idx]]])[0][0]
        
        print(f"\n--- Prediction for Day {seq_idx+3} ---")
        pred_table = [
            ["Predicted Power Consumption (normalized)", f"{y_pred[0]:.4f}"],
            ["Actual Power Consumption (normalized)", f"{y_sequences[seq_idx]:.4f}"],
            ["Prediction Error (normalized)", f"{y_pred[0] - y_sequences[seq_idx]:.4f}"],
            ["Predicted Power Consumption (kWh)", f"{original_pred:.2f}"],
            ["Actual Power Consumption (kWh)", f"{original_actual:.2f}"],
            ["Prediction Error (kWh)", f"{original_pred - original_actual:.2f}"]
        ]
        print(tabulate(pred_table, headers=["Metric", "Value"], tablefmt="grid"))
        
        all_states.append(sequence_states)
        
        print(f"\nFinal state after sequence {seq_idx+1}:")
        final_state_table = [
            ["Hidden State (h_t)", f"{h_t}"]
        ]
        print(tabulate(final_state_table, headers=["State", "Values"], tablefmt="grid"))
        
        if seq_idx < len(X_sequences) - 1:
            print("\n→ This state will be used as initial state for the next sequence")
    
    return predictions, all_states

# Run our custom forward pass
print("\n" + "=" * 80)
print("CUSTOM FORWARD PASS USING EXTRACTED WEIGHTS")
print("=" * 80)
predictions, all_states = gru_forward_pass(X_sequences, gru_weights)

# Compare with Keras model predictions
keras_predictions = model.predict(X_sequences, verbose=0)

print("\n" + "=" * 80)
print("COMPARISON WITH KERAS MODEL PREDICTIONS")
print("=" * 80)
comparison_table = []
for i in range(len(predictions)):
    # Original scale values
    orig_pred = y_scaler.inverse_transform([[predictions[i]]])[0][0]
    orig_keras = y_scaler.inverse_transform([[keras_predictions[i][0]]])[0][0]
    orig_actual = y_scaler.inverse_transform([[y_sequences[i]]])[0][0]
    
    comparison_table.append([
        i+3,  # Day number (after lookback)
        f"{predictions[i]:.4f}",
        f"{keras_predictions[i][0]:.4f}",
        f"{y_sequences[i]:.4f}",
        f"{orig_pred:.2f}",
        f"{orig_keras:.2f}",
        f"{orig_actual:.2f}"
    ])

print(tabulate(comparison_table, 
               headers=["Day", "Our Pred (norm)", "Keras Pred (norm)", "Actual (norm)", 
                        "Our Pred (kWh)", "Keras Pred (kWh)", "Actual (kWh)"], 
               tablefmt="grid"))

print("\nNote: Our predictions and Keras predictions might differ slightly due to")
print("implementation details and numerical precision differences.")

# Visualize the state evolution
days = []
hidden_state_norms = []

for seq_idx, sequence_states in enumerate(all_states):
    for state in sequence_states:
        days.append(state['day'])
        hidden_state_norms.append(np.linalg.norm(state['hidden_state']))

plt.figure(figsize=(12, 6))
plt.plot(days, hidden_state_norms, 's-', label='Hidden State Norm')
plt.xlabel('Day')
plt.ylabel('State Norm (Magnitude)')
plt.title('Evolution of GRU Hidden State Over Time')
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.savefig('gru_state_evolution.png')
plt.close()

print("\n" + "=" * 80)
print("SUMMARY")
print("=" * 80)
print("The code has successfully:")
print("1. Created a GRU model in Keras with the same architecture")
print("2. Extracted the weights from the Keras model")
print("3. Implemented a custom forward pass to show the internal calculations")
print("4. Compared our predictions with the Keras model predictions")
print("5. Generated a visualization of how the hidden state evolves over time")
print("\nA plot of state evolution has been saved as 'gru_state_evolution.png'")


CUSTOM FORWARD PASS USING EXTRACTED WEIGHTS

Days 1-2 → Predicting Day 3


ValueError: operands could not be broadcast together with shapes (4,) (2,12) 

In [19]:
# Custom forward pass to replicate Keras's internal GRU computation
def gru_forward_pass(X_sequences, weights):
    kernel, recurrent_kernel, bias = weights
    n_sequences = X_sequences.shape[0]
    predictions = np.zeros(n_sequences)
    
    # Initialize hidden state
    h_t = np.zeros(n_hidden)
    
    all_states = []
    
    for seq_idx, sequence in enumerate(X_sequences):
        print(f"\n{'=' * 30} SEQUENCE {seq_idx+1} {'=' * 30}")
        print(f"Days {seq_idx+1}-{seq_idx+2} → Predicting Day {seq_idx+3}")
        
        sequence_states = []
        
        for t in range(len(sequence)):
            x_t = sequence[t]
            day_num = seq_idx + t + 1
            
            # Calculate gate inputs
            # Important: Make sure dimensions are correct
            # x_t shape is [n_features], W_z_x shape is [n_features, n_hidden]
            # Result of np.dot(x_t, W_z_x) should be [n_hidden]
            
            # Input contribution to gates
            z_input = np.dot(x_t, W_z_x)  # Should be shape [n_hidden]
            r_input = np.dot(x_t, W_r_x)  # Should be shape [n_hidden]
            h_input = np.dot(x_t, W_h_x)  # Should be shape [n_hidden]
            
            # Hidden state contribution to gates
            # h_t shape is [n_hidden], W_z_h shape is [n_hidden, n_hidden]
            # Result should be [n_hidden]
            z_hidden = np.dot(h_t, W_z_h)  # Should be shape [n_hidden]
            r_hidden = np.dot(h_t, W_r_h)  # Should be shape [n_hidden]
            
            # Gate activations with bias
            # All of these terms should be shape [n_hidden]
            z_t = sigmoid(z_input + z_hidden + b_z)
            r_t = sigmoid(r_input + r_hidden + b_r)
            
            # Reset hidden state
            reset_h_t = r_t * h_t  # Element-wise multiplication, shape [n_hidden]
            
            # Candidate hidden state
            h_hidden = np.dot(reset_h_t, W_h_h)  # Should be shape [n_hidden]
            h_tilde = tanh(h_input + h_hidden + b_h)
            
            # Update hidden state
            h_t_prev = h_t.copy()
            h_t = (1 - z_t) * h_t + z_t * h_tilde  # Element-wise operations
            
            # Store states
            sequence_states.append({
                'day': day_num,
                'input': x_t,
                'update_gate': z_t,
                'reset_gate': r_t,
                'candidate_hidden': h_tilde,
                'hidden_state': h_t.copy()
            })
            
            # Print information...
            # [rest of the printing code]
            
        # Make prediction using the final hidden state
        W_y = model.layers[1].get_weights()[0]
        b_y = model.layers[1].get_weights()[1]
        y_pred = np.dot(h_t, W_y) + b_y
        predictions[seq_idx] = y_pred[0]
        
        # [rest of the code]
    
    return predictions, all_states

In [18]:
predictions, all_states = gru_forward_pass(X_sequences, gru_weights)


Days 1-2 → Predicting Day 3


ValueError: operands could not be broadcast together with shapes (4,) (2,12) 

In [20]:
# Shape verification code
print("=" * 80)
print("SHAPE VERIFICATION FOR GRU FORWARD PASS")
print("=" * 80)

# Access the first time step of the first sequence for testing
x_t = X_sequences[0][0]  # First input of first sequence
h_t = np.zeros(n_hidden)  # Initial hidden state

# Weight shapes
print("Weight shapes:")
print(f"W_z_x shape: {W_z_x.shape}")
print(f"W_r_x shape: {W_r_x.shape}")
print(f"W_h_x shape: {W_h_x.shape}")
print(f"W_z_h shape: {W_z_h.shape}")
print(f"W_r_h shape: {W_r_h.shape}")
print(f"W_h_h shape: {W_h_h.shape}")
print(f"b_z shape: {b_z.shape}")
print(f"b_r shape: {b_r.shape}")
print(f"b_h shape: {b_h.shape}")

# Input shapes
print("\nInput shapes:")
print(f"x_t shape: {x_t.shape}")
print(f"h_t shape: {h_t.shape}")

# Compute intermediate values
z_input = np.dot(x_t, W_z_x)
r_input = np.dot(x_t, W_r_x)
h_input = np.dot(x_t, W_h_x)

z_hidden = np.dot(h_t, W_z_h)
r_hidden = np.dot(h_t, W_r_h)

# Print computed shapes
print("\nComputed intermediate shapes:")
print(f"z_input shape: {z_input.shape}")
print(f"r_input shape: {r_input.shape}")
print(f"h_input shape: {h_input.shape}")
print(f"z_hidden shape: {z_hidden.shape}")
print(f"r_hidden shape: {r_hidden.shape}")

# Test addition compatibility
try:
    # Test update gate addition
    update_sum = z_input + z_hidden + b_z
    print("\nAddition tests:")
    print(f"z_input + z_hidden + b_z shape: {update_sum.shape} ✓")
    
    # Test reset gate addition
    reset_sum = r_input + r_hidden + b_r
    print(f"r_input + r_hidden + b_r shape: {reset_sum.shape} ✓")
    
    # Test the rest of the operations
    r_t = sigmoid(reset_sum)
    reset_h_t = r_t * h_t
    h_hidden = np.dot(reset_h_t, W_h_h)
    h_sum = h_input + h_hidden + b_h
    
    print(f"reset_h_t shape: {reset_h_t.shape} ✓")
    print(f"h_hidden shape: {h_hidden.shape} ✓")
    print(f"h_input + h_hidden + b_h shape: {h_sum.shape} ✓")
    
    print("\nAll shape tests passed! The forward pass should work correctly.")
except Exception as e:
    print(f"\nShape mismatch error: {e}")
    print("Fix the dimensions before proceeding with the forward pass.")

SHAPE VERIFICATION FOR GRU FORWARD PASS
Weight shapes:
W_z_x shape: (3, 4)
W_r_x shape: (3, 4)
W_h_x shape: (3, 4)
W_z_h shape: (4, 4)
W_r_h shape: (4, 4)
W_h_h shape: (4, 4)
b_z shape: (2, 12)
b_r shape: (0, 12)
b_h shape: (0, 12)

Input shapes:
x_t shape: (3,)
h_t shape: (4,)

Computed intermediate shapes:
z_input shape: (4,)
r_input shape: (4,)
h_input shape: (4,)
z_hidden shape: (4,)
r_hidden shape: (4,)

Shape mismatch error: operands could not be broadcast together with shapes (4,) (2,12) 
Fix the dimensions before proceeding with the forward pass.


In [21]:
# Extract initial weights from the GRU layer
gru_layer = model.layers[0]
gru_weights = gru_layer.get_weights()

# In Keras GRU, weights are organized as:
# [kernel (input weights), recurrent_kernel (hidden state weights), bias]
kernel = gru_weights[0]         # Input weights (shape: [n_features, n_hidden*3])
recurrent_kernel = gru_weights[1]   # Hidden state weights (shape: [n_hidden, n_hidden*3])
bias = gru_weights[2]           # Bias terms (shape: [n_hidden*3])

# The order is: z (update), r (reset), h (candidate)
units = n_hidden

# Input weights for each part of the GRU
W_z_x = kernel[:, :units]
W_r_x = kernel[:, units:units*2]
W_h_x = kernel[:, units*2:units*3]

# Recurrent weights for each part of the GRU
W_z_h = recurrent_kernel[:, :units]
W_r_h = recurrent_kernel[:, units:units*2]
W_h_h = recurrent_kernel[:, units*2:units*3]

# Bias terms - THIS IS THE PART THAT NEEDS FIXING
# In Keras, for GRU, bias is organized as [b_z, b_r, b_h] with each having shape [units]
# and total bias has shape [units*3]
b_z = bias[:units]
b_r = bias[units:units*2]
b_h = bias[units*2:units*3]

# Print to verify the new shapes
print("Bias shapes after correction:")
print(f"b_z shape: {b_z.shape}")
print(f"b_r shape: {b_r.shape}")
print(f"b_h shape: {b_h.shape}")

Bias shapes after correction:
b_z shape: (2, 12)
b_r shape: (0, 12)
b_h shape: (0, 12)


In [22]:
# Print raw bias information
print("Raw bias information:")
print(f"Total bias shape: {bias.shape}")
print(f"Total bias content: {bias}")

# Try a different approach to extract biases correctly
total_units = n_hidden * 3  # Total units for all gates combined

# Check if bias is organized differently than expected
if bias.shape[0] == 2 and bias.shape[1] == 12:
    # If bias has shape (2, 12), it might be organized as:
    # First row: Input bias, Second row: Recurrent bias
    # Each row has [b_z, b_r, b_h] concatenated
    print("\nDetected special bias format (2, 12)...")
    
    # Extract from first row (input bias)
    input_bias = bias[0]
    b_z = input_bias[:n_hidden]
    b_r = input_bias[n_hidden:2*n_hidden]
    b_h = input_bias[2*n_hidden:3*n_hidden]
else:
    # Try flattening if it's not already a 1D array
    bias_flat = bias.flatten()
    if len(bias_flat) == total_units:
        print("\nUsing flattened bias...")
        b_z = bias_flat[:n_hidden]
        b_r = bias_flat[n_hidden:2*n_hidden]
        b_h = bias_flat[2*n_hidden:3*n_hidden]
    else:
        # As a fallback, create zeros
        print("\nUsing zero biases as fallback...")
        b_z = np.zeros(n_hidden)
        b_r = np.zeros(n_hidden)
        b_h = np.zeros(n_hidden)

print("\nCorrected bias shapes:")
print(f"b_z shape: {b_z.shape}")
print(f"b_r shape: {b_r.shape}")
print(f"b_h shape: {b_h.shape}")

Raw bias information:
Total bias shape: (2, 12)
Total bias content: [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]

Detected special bias format (2, 12)...

Corrected bias shapes:
b_z shape: (4,)
b_r shape: (4,)
b_h shape: (4,)


In [23]:
# Define activation functions (same as previous example)
def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def tanh(x):
    return np.tanh(x)

# Custom forward pass to replicate Keras's internal GRU computation
def gru_forward_pass(X_sequences, weights):
    kernel, recurrent_kernel, bias = weights
    n_sequences = X_sequences.shape[0]
    predictions = np.zeros(n_sequences)
    
    # Extract weights properly
    W_z_x = kernel[:, :n_hidden]
    W_r_x = kernel[:, n_hidden:n_hidden*2]
    W_h_x = kernel[:, n_hidden*2:n_hidden*3]
    
    W_z_h = recurrent_kernel[:, :n_hidden]
    W_r_h = recurrent_kernel[:, n_hidden:n_hidden*2]
    W_h_h = recurrent_kernel[:, n_hidden*2:n_hidden*3]
    
    # Extract biases properly from the first row (input bias)
    input_bias = bias[0]
    b_z = input_bias[:n_hidden]
    b_r = input_bias[n_hidden:2*n_hidden]
    b_h = input_bias[2*n_hidden:3*n_hidden]
    
    # Initialize hidden state
    h_t = np.zeros(n_hidden)
    
    all_states = []
    
    for seq_idx, sequence in enumerate(X_sequences):
        print(f"\n{'=' * 30} SEQUENCE {seq_idx+1} {'=' * 30}")
        print(f"Days {seq_idx+1}-{seq_idx+2} → Predicting Day {seq_idx+3}")
        
        sequence_states = []
        
        for t in range(len(sequence)):
            x_t = sequence[t]
            day_num = seq_idx + t + 1
            
            # Calculate gate inputs
            # Input contribution to gates
            z_input = np.dot(x_t, W_z_x)
            r_input = np.dot(x_t, W_r_x)
            h_input = np.dot(x_t, W_h_x)
            
            # Hidden state contribution to gates
            z_hidden = np.dot(h_t, W_z_h)
            r_hidden = np.dot(h_t, W_r_h)
            
            # Gate activations with bias
            z_t = sigmoid(z_input + z_hidden + b_z)
            r_t = sigmoid(r_input + r_hidden + b_r)
            
            # Calculate reset hidden state
            reset_h_t = r_t * h_t
            
            # Calculate candidate hidden state
            h_hidden = np.dot(reset_h_t, W_h_h)
            h_tilde = tanh(h_input + h_hidden + b_h)
            
            # Update hidden state
            h_t_prev = h_t.copy()
            h_t = (1 - z_t) * h_t + z_t * h_tilde
            
            # Store states
            sequence_states.append({
                'day': day_num,
                'input': x_t,
                'update_gate': z_t,
                'reset_gate': r_t,
                'candidate_hidden': h_tilde,
                'hidden_state': h_t.copy()
            })
            
            # Print gate values and states for this time step
            print(f"\n--- Time step {t+1} (Day {day_num}) ---")
            print("Input Features:")
            input_table = [
                ["Temperature", f"{x_t[0]:.2f}"],
                ["Humidity", f"{x_t[1]:.2f}"],
                ["Wind Speed", f"{x_t[2]:.2f}"]
            ]
            print(tabulate(input_table, headers=["Feature", "Value"], tablefmt="grid"))
            
            print("\nGate Values:")
            gates_table = [
                ["Update Gate (z_t)", f"{z_t}"],
                ["Reset Gate (r_t)", f"{r_t}"],
                ["Candidate Hidden State (h_tilde)", f"{h_tilde}"]
            ]
            print(tabulate(gates_table, headers=["Gate", "Values"], tablefmt="grid"))
            
            print("\nState Updates:")
            state_table = [
                ["Previous Hidden State", f"{h_t_prev}"],
                ["New Hidden State", f"{h_t}"],
                ["Change in Hidden State", f"{h_t - h_t_prev}"]
            ]
            print(tabulate(state_table, headers=["State", "Values"], tablefmt="grid"))
        
        # Predict using the final hidden state (replicating Dense layer)
        W_y = model.layers[1].get_weights()[0]
        b_y = model.layers[1].get_weights()[1]
        y_pred = np.dot(h_t, W_y) + b_y
        predictions[seq_idx] = y_pred[0]
        
        # Print prediction vs actual
        # Convert back to original scale for clearer comparison
        original_pred = y_scaler.inverse_transform([[y_pred[0]]])[0][0]
        original_actual = y_scaler.inverse_transform([[y_sequences[seq_idx]]])[0][0]
        
        print(f"\n--- Prediction for Day {seq_idx+3} ---")
        pred_table = [
            ["Predicted Power Consumption (normalized)", f"{y_pred[0]:.4f}"],
            ["Actual Power Consumption (normalized)", f"{y_sequences[seq_idx]:.4f}"],
            ["Prediction Error (normalized)", f"{y_pred[0] - y_sequences[seq_idx]:.4f}"],
            ["Predicted Power Consumption (kWh)", f"{original_pred:.2f}"],
            ["Actual Power Consumption (kWh)", f"{original_actual:.2f}"],
            ["Prediction Error (kWh)", f"{original_pred - original_actual:.2f}"]
        ]
        print(tabulate(pred_table, headers=["Metric", "Value"], tablefmt="grid"))
        
        all_states.append(sequence_states)
        
        print(f"\nFinal state after sequence {seq_idx+1}:")
        final_state_table = [
            ["Hidden State (h_t)", f"{h_t}"]
        ]
        print(tabulate(final_state_table, headers=["State", "Values"], tablefmt="grid"))
        
        if seq_idx < len(X_sequences) - 1:
            print("\n→ This state will be used as initial state for the next sequence")
    
    return predictions, all_states

# Run our custom forward pass
print("\n" + "=" * 80)
print("CUSTOM FORWARD PASS USING EXTRACTED WEIGHTS")
print("=" * 80)
predictions, all_states = gru_forward_pass(X_sequences, gru_weights)

# Compare with Keras model predictions
keras_predictions = model.predict(X_sequences, verbose=0)

print("\n" + "=" * 80)
print("COMPARISON WITH KERAS MODEL PREDICTIONS")
print("=" * 80)
comparison_table = []
for i in range(len(predictions)):
    # Original scale values
    orig_pred = y_scaler.inverse_transform([[predictions[i]]])[0][0]
    orig_keras = y_scaler.inverse_transform([[keras_predictions[i][0]]])[0][0]
    orig_actual = y_scaler.inverse_transform([[y_sequences[i]]])[0][0]
    
    comparison_table.append([
        i+3,  # Day number (after lookback)
        f"{predictions[i]:.4f}",
        f"{keras_predictions[i][0]:.4f}",
        f"{y_sequences[i]:.4f}",
        f"{orig_pred:.2f}",
        f"{orig_keras:.2f}",
        f"{orig_actual:.2f}"
    ])

print(tabulate(comparison_table, 
               headers=["Day", "Our Pred (norm)", "Keras Pred (norm)", "Actual (norm)", 
                        "Our Pred (kWh)", "Keras Pred (kWh)", "Actual (kWh)"], 
               tablefmt="grid"))

print("\nNote: Our predictions and Keras predictions might differ slightly due to")
print("implementation details and numerical precision differences.")

# Visualize the state evolution
days = []
hidden_state_norms = []

for seq_idx, sequence_states in enumerate(all_states):
    for state in sequence_states:
        days.append(state['day'])
        hidden_state_norms.append(np.linalg.norm(state['hidden_state']))

plt.figure(figsize=(12, 6))
plt.plot(days, hidden_state_norms, 's-', label='Hidden State Norm')
plt.xlabel('Day')
plt.ylabel('State Norm (Magnitude)')
plt.title('Evolution of GRU Hidden State Over Time')
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.savefig('gru_state_evolution.png')
plt.close()

print("\n" + "=" * 80)
print("SUMMARY")
print("=" * 80)
print("The code has successfully:")
print("1. Created a GRU model in Keras with the same architecture")
print("2. Extracted the weights from the Keras model")
print("3. Implemented a custom forward pass to show the internal calculations")
print("4. Compared our predictions with the Keras model predictions")
print("5. Generated a visualization of how the hidden state evolves over time")
print("\nA plot of state evolution has been saved as 'gru_state_evolution.png'")

# Add a comparison of the two approaches visualization
# Plot the predictions in original scale
plt.figure(figsize=(12, 6))
days = list(range(3, 3 + len(predictions)))
orig_preds = [y_scaler.inverse_transform([[p]])[0][0] for p in predictions]
orig_keras = [y_scaler.inverse_transform([[k[0]]])[0][0] for k in keras_predictions]
orig_actual = [y_scaler.inverse_transform([[a]])[0][0] for a in y_sequences]

plt.plot(days, orig_preds, 'o-', label='Our GRU Implementation', linewidth=2)
plt.plot(days, orig_keras, 's--', label='Keras GRU Implementation', linewidth=2)
plt.plot(days, orig_actual, 'D-.', label='Actual Values', linewidth=2)
plt.xlabel('Day')
plt.ylabel('Power Consumption (kWh)')
plt.title('Comparison of GRU Predictions vs Actual Values')
plt.grid(True, alpha=0.3)
plt.legend()
plt.xticks(days)
plt.tight_layout()
plt.savefig('gru_prediction_comparison.png')
plt.close()

print("A comparison of prediction methods has been saved as 'gru_prediction_comparison.png'")


CUSTOM FORWARD PASS USING EXTRACTED WEIGHTS

Days 1-2 → Predicting Day 3

--- Time step 1 (Day 1) ---
Input Features:
+-------------+---------+
| Feature     |   Value |
| Temperature |    0.09 |
+-------------+---------+
| Humidity    |    0.9  |
+-------------+---------+
| Wind Speed  |    0    |
+-------------+---------+

Gate Values:
+----------------------------------+---------------------------------------------------+
| Gate                             | Values                                            |
| Update Gate (z_t)                | [0.56598987 0.47466015 0.4248398  0.6267074 ]     |
+----------------------------------+---------------------------------------------------+
| Reset Gate (r_t)                 | [0.62577714 0.49682496 0.51366921 0.57677263]     |
+----------------------------------+---------------------------------------------------+
| Candidate Hidden State (h_tilde) | [ 0.45328354 -0.22043176  0.15508197  0.42716322] |
+----------------------------------+