## LSTM Distilled with Keras

In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
from tabulate import tabulate
import matplotlib.pyplot as plt

In [2]:
# Set seeds for reproducibility: NumPy drives the target noise below,
# TensorFlow's global seed makes its random ops (e.g. Keras weight
# initialisers) repeatable.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Generate our sample data: one row per day, columns = (temperature, humidity, wind speed)
temperature = np.array([18.2, 19.5, 20.1, 22.4, 23.8, 25.0, 23.2, 21.5, 19.8, 17.5])
humidity = np.array([65.2, 62.8, 58.5, 55.0, 45.2, 42.1, 48.5, 52.3, 60.5, 67.8])
wind_speed = np.array([5.2, 6.8, 8.5, 10.2, 12.5, 14.8, 13.2, 11.5, 9.2, 6.5])
X = np.column_stack((temperature, humidity, wind_speed))
n_days = len(temperature)  # noise length must follow the data, not a literal

# Target: power consumption (kWh) = linear mix of the features + Gaussian noise (std 5)
y = 2.5 * temperature - 0.5 * humidity + 1.2 * wind_speed + np.random.normal(0, 5, n_days)

# Create sequences with lookback of 2 (predict the step after the sequence)
def create_sequences(X, y, lookback=2):
    """Build sliding-window sequences for one-step-ahead prediction.

    Window ``i`` contains rows ``X[i : i + lookback]`` and is paired with the
    target ``y[i + lookback]``, i.e. the step right after the window.
    Consecutive windows overlap by ``lookback - 1`` rows.

    Parameters
    ----------
    X : array-like, shape (n_steps, n_features)
    y : array-like, shape (n_steps,)
    lookback : int
        Window length; must be >= 1.

    Returns
    -------
    X_seq : ndarray, shape (max(n_steps - lookback, 0), lookback, n_features)
    y_seq : ndarray, shape (max(n_steps - lookback, 0),)

    Raises
    ------
    ValueError
        If ``lookback < 1``.
    """
    X = np.asarray(X)
    y = np.asarray(y)
    if lookback < 1:
        raise ValueError(f"lookback must be >= 1, got {lookback}")

    n_windows = max(len(X) - lookback, 0)
    if n_windows == 0:
        # Keep the 3-D (n, lookback, n_features) rank even when no window fits,
        # so downstream code that reads .shape[1:] keeps working.
        X_seq = np.empty((0, lookback) + X.shape[1:], dtype=X.dtype)
    else:
        X_seq = np.stack([X[i:i + lookback] for i in range(n_windows)])
    # Targets are the rows immediately following each window.
    y_seq = np.array(y[lookback:lookback + n_windows])
    return X_seq, y_seq

# Define LSTM dimensions first so one `lookback` value drives both the
# windowing and the model's input shape.
lookback = 2                # Same as previous example
n_features = X.shape[1]     # Temperature, humidity, wind speed
n_hidden = 4                # Number of LSTM units (same as previous example)

X_sequences, y_sequences = create_sequences(X, y, lookback=lookback)

print("=" * 80)
print("LSTM IMPLEMENTATION IN KERAS WITH STATE EXTRACTION")
print("=" * 80)
print(f"Data shape: {X.shape} ({X.shape[0]} days with {X.shape[1]} features each)")
print(f"Sequence shape: {X_sequences.shape} ({X_sequences.shape[0]} sequences with lookback={lookback})")
print("=" * 80)

LSTM IMPLEMENTATION IN KERAS WITH STATE EXTRACTION
Data shape: (10, 3) (10 days with 3 features each)
Sequence shape: (8, 2, 3) (8 sequences with lookback=2)


In [3]:
# Assemble the network layer by layer: an LSTM that emits only its final
# hidden state (return_sequences=False, fresh state per sample), then a
# single-unit Dense head that maps that hidden state to the prediction.
model = Sequential()
model.add(tf.keras.layers.Input(shape=(lookback, n_features)))
model.add(LSTM(n_hidden, return_sequences=False, stateful=False))
model.add(Dense(1))

# Mean-squared-error regression trained with Adam
model.compile(loss='mse', optimizer='adam')

# Print model summary
print("\nModel Architecture:")
model.summary()


Model Architecture:


In [4]:
# Extract the initial (untrained) weights from the LSTM layer.
lstm_layer = model.layers[0]
assert isinstance(lstm_layer, LSTM), f"expected LSTM at model.layers[0], got {type(lstm_layer).__name__}"
lstm_weights = lstm_layer.get_weights()

# In Keras LSTM, weights are organized as:
# [kernel (input weights), recurrent_kernel (hidden state weights), bias]
kernel, recurrent_kernel, bias = lstm_weights

units = n_hidden
input_dim = n_features

# Guard the layout the slicing below relies on: four gates concatenated
# along the last axis.
assert kernel.shape == (input_dim, 4 * units), kernel.shape
assert recurrent_kernel.shape == (units, 4 * units), recurrent_kernel.shape
assert bias.shape == (4 * units,), bias.shape

# Keras gate order along the last axis: i (input), f (forget), c (cell candidate), o (output)
# Input weights
W_i_x, W_f_x, W_c_x, W_o_x = np.split(kernel, 4, axis=1)
# Recurrent weights (from previous hidden state)
W_i_h, W_f_h, W_c_h, W_o_h = np.split(recurrent_kernel, 4, axis=1)
# Bias terms
b_i, b_f, b_c, b_o = np.split(bias, 4)

print("+"*62)
print("Extracted LSTM Weights:".center(62))
print("+"*62)

print(f"{'Input gate weights (W_i_x):':<35} {W_i_x.shape}")
print(f"{'Forget gate weights (W_f_x):':<35} {W_f_x.shape}")
print(f"{'Cell candidate weights (W_c_x):':<35} {W_c_x.shape}")
print(f"{'Output gate weights (W_o_x):':<35} {W_o_x.shape}")
print(f"{'Recurrent weights shapes:':<35} {W_i_h.shape, W_f_h.shape, W_c_h.shape, W_o_h.shape}")
print('+'*62)

++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
                   Extracted LSTM Weights:                    
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Input gate weights (W_i_x):         (3, 4)
Forget gate weights (W_f_x):        (3, 4)
Cell candidate weights (W_c_x):     (3, 4)
Output gate weights (W_o_x):        (3, 4)
Recurrent weights shapes:           ((4, 4), (4, 4), (4, 4), (4, 4))
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++


In [5]:
# Define activation functions (same as previous example)
def sigmoid(x):
    """Element-wise logistic function 1 / (1 + e^(-x))."""
    exp_neg = np.exp(-x)
    return 1 / (1 + exp_neg)

def tanh(x):
    """Element-wise hyperbolic tangent (thin alias for ``np.tanh``)."""
    result = np.tanh(x)
    return result

# Custom forward pass to replicate Keras's internal LSTM computation
def lstm_forward_pass(X_sequences, weights, dense_weights=None, y_true=None):
    """Replay the Keras LSTM -> Dense(1) model step by step in NumPy, printing every gate and state.

    Each sequence starts from zero hidden and cell states, matching the model
    built with ``stateful=False``, where Keras resets the states for every
    sample. The sliding-window sequences also overlap, so carrying state from
    one window into the next would run the shared days through the cell twice.

    Parameters
    ----------
    X_sequences : ndarray, shape (n_sequences, lookback, n_features)
    weights : sequence of 3 ndarrays
        ``[kernel, recurrent_kernel, bias]`` as returned by
        ``LSTM.get_weights()``; gates are concatenated in the order i, f, c, o.
    dense_weights : sequence of 2 ndarrays, optional
        ``[W_y, b_y]`` of the output Dense layer. Defaults to
        ``model.layers[1].get_weights()`` from the notebook's global ``model``.
    y_true : ndarray, shape (n_sequences,), optional
        Targets printed next to each prediction. Defaults to the global
        ``y_sequences``.

    Returns
    -------
    predictions : ndarray, shape (n_sequences,)
    all_states : list of list of dict
        For each sequence and time step: input, gate activations, cell and
        hidden state.
    """
    kernel, recurrent_kernel, bias = weights
    units = recurrent_kernel.shape[0]

    # Slice the gates out of the weights that were actually passed in
    # (Keras order along the last axis: i, f, c, o).
    W_i_x, W_f_x, W_c_x, W_o_x = np.split(kernel, 4, axis=-1)
    W_i_h, W_f_h, W_c_h, W_o_h = np.split(recurrent_kernel, 4, axis=-1)
    b_i, b_f, b_c, b_o = np.split(bias, 4)

    # Output layer and targets are loop-invariant: resolve them once.
    if dense_weights is None:
        dense_weights = model.layers[1].get_weights()
    W_y, b_y = dense_weights
    if y_true is None:
        y_true = y_sequences

    n_sequences = X_sequences.shape[0]
    predictions = np.zeros(n_sequences)
    all_states = []

    for seq_idx, sequence in enumerate(X_sequences):
        seq_len = len(sequence)
        target_day = seq_idx + seq_len + 1  # first day after the window
        print(f"\n{'=' * 30} SEQUENCE {seq_idx+1} {'=' * 30}")
        print(f"Days {seq_idx+1}-{seq_idx+seq_len} → Predicting Day {target_day}")

        # Stateless LSTM: every sequence starts from zero states.
        h_t = np.zeros(units)
        C_t = np.zeros(units)

        sequence_states = []

        for t, x_t in enumerate(sequence):
            day_num = seq_idx + t + 1

            # Gate pre-activation = input contribution + recurrent contribution + bias
            i_t = sigmoid(np.dot(x_t, W_i_x) + np.dot(h_t, W_i_h) + b_i)
            f_t = sigmoid(np.dot(x_t, W_f_x) + np.dot(h_t, W_f_h) + b_f)
            c_tilde = tanh(np.dot(x_t, W_c_x) + np.dot(h_t, W_c_h) + b_c)
            o_t = sigmoid(np.dot(x_t, W_o_x) + np.dot(h_t, W_o_h) + b_o)

            # Update cell state
            C_t_prev = C_t.copy()
            C_t = f_t * C_t + i_t * c_tilde

            # Update hidden state
            h_t_prev = h_t.copy()
            h_t = o_t * tanh(C_t)

            # Store states
            sequence_states.append({
                'day': day_num,
                'input': x_t,
                'forget_gate': f_t,
                'input_gate': i_t,
                'cell_candidate': c_tilde,
                'output_gate': o_t,
                'cell_state': C_t.copy(),
                'hidden_state': h_t.copy()
            })

            # Print gate values and states for this time step
            print(f"\n--- Time step {t+1} (Day {day_num}) ---")
            print("Input Features:")
            input_table = [
                ["Temperature", f"{x_t[0]:.2f}°C"],
                ["Humidity", f"{x_t[1]:.2f}%"],
                ["Wind Speed", f"{x_t[2]:.2f} km/h"]
            ]
            print(tabulate(input_table, headers=["Feature", "Value"], tablefmt="grid"))

            print("\nGate Values:")
            gates_table = [
                ["Forget Gate (f_t)", f"{f_t}"],
                ["Input Gate (i_t)", f"{i_t}"],
                ["Cell Candidate (c_tilde)", f"{c_tilde}"],
                ["Output Gate (o_t)", f"{o_t}"]
            ]
            print(tabulate(gates_table, headers=["Gate", "Values"], tablefmt="grid"))

            print("\nState Updates:")
            state_table = [
                ["Previous Cell State", f"{C_t_prev}"],
                ["New Cell State", f"{C_t}"],
                ["Change in Cell State", f"{C_t - C_t_prev}"],
                ["Previous Hidden State", f"{h_t_prev}"],
                ["New Hidden State", f"{h_t}"],
                ["Change in Hidden State", f"{h_t - h_t_prev}"]
            ]
            print(tabulate(state_table, headers=["State", "Values"], tablefmt="grid"))

        # Predict using the final hidden state (replicating the Dense layer)
        y_pred = np.dot(h_t, W_y) + b_y
        predictions[seq_idx] = y_pred[0]

        # Print prediction vs actual
        print(f"\n--- Prediction for Day {target_day} ---")
        pred_table = [
            ["Predicted Power Consumption", f"{y_pred[0]:.2f} kWh"],
            ["Actual Power Consumption", f"{y_true[seq_idx]:.2f} kWh"],
            ["Prediction Error", f"{y_pred[0] - y_true[seq_idx]:.2f} kWh"]
        ]
        print(tabulate(pred_table, headers=["Metric", "Value"], tablefmt="grid"))

        all_states.append(sequence_states)

        print(f"\nFinal states after sequence {seq_idx+1}:")
        final_state_table = [
            ["Cell State (C_t)", f"{C_t}"],
            ["Hidden State (h_t)", f"{h_t}"]
        ]
        print(tabulate(final_state_table, headers=["State", "Values"], tablefmt="grid"))

        if seq_idx < n_sequences - 1:
            print("\n→ States are reset to zero for the next sequence (stateless LSTM)")

    return predictions, all_states

# Run our custom forward pass
print("\n" + "=" * 80)
print("CUSTOM FORWARD PASS USING EXTRACTED WEIGHTS")
print("=" * 80)
predictions, all_states = lstm_forward_pass(X_sequences, lstm_weights)

# Compare with Keras model predictions (model.predict returns shape (n_sequences, 1))
keras_predictions = model.predict(X_sequences, verbose=0)

print("\n" + "=" * 80)
print("COMPARISON WITH KERAS MODEL PREDICTIONS")
print("=" * 80)
comparison_table = [
    [
        i + lookback + 1,  # Day being predicted: first day after the window
        f"{ours:.4f}",
        f"{theirs[0]:.4f}",
        f"{actual:.4f}",
    ]
    for i, (ours, theirs, actual) in enumerate(zip(predictions, keras_predictions, y_sequences))
]

print(tabulate(comparison_table, 
               headers=["Day", "Our Prediction", "Keras Prediction", "Actual"], 
               tablefmt="grid"))

# Quantify the agreement rather than only describing it.
max_abs_diff = float(np.max(np.abs(predictions - keras_predictions[:, 0]), initial=0.0))
print(f"\nMax |our prediction - Keras prediction|: {max_abs_diff:.2e}")
print("Note: differences around 1e-5 or below come from float64 (NumPy) vs float32")
print("(Keras) arithmetic; larger gaps mean the manual pass is not replicating the")
print("stateless Keras LSTM.")

# Visualize the state evolution. The flat lists are kept for later cells; each
# sequence is drawn as its own segment so that states from overlapping windows
# (which share days) are not joined into a single misleading line.
days = []
cell_state_norms = []
hidden_state_norms = []

fig, ax = plt.subplots(figsize=(12, 6))
for seq_idx, sequence_states in enumerate(all_states):
    seq_days = [state['day'] for state in sequence_states]
    seq_cell_norms = [np.linalg.norm(state['cell_state']) for state in sequence_states]
    seq_hidden_norms = [np.linalg.norm(state['hidden_state']) for state in sequence_states]

    days.extend(seq_days)
    cell_state_norms.extend(seq_cell_norms)
    hidden_state_norms.extend(seq_hidden_norms)

    is_first = seq_idx == 0  # label only once so the legend has two entries
    ax.plot(seq_days, seq_cell_norms, 'o-', color='C0',
            label='Cell State Norm' if is_first else '_nolegend_')
    ax.plot(seq_days, seq_hidden_norms, 's-', color='C1',
            label='Hidden State Norm' if is_first else '_nolegend_')

ax.set_xlabel('Day')
ax.set_ylabel('State Norm (Magnitude)')
ax.set_title('Evolution of LSTM States Over Time')
ax.grid(True)
ax.legend()
fig.tight_layout()
fig.savefig('lstm_state_evolution.png')
plt.close(fig)

print("\n" + "=" * 80)
print("SUMMARY")
print("=" * 80)
print("The code has successfully:")
print("1. Created an LSTM model in Keras with the same architecture")
print("2. Extracted the weights from the Keras model")
print("3. Implemented a custom forward pass to show the internal calculations")
print("4. Compared our predictions with the Keras model predictions")
print("5. Generated a visualization of how states evolve over time")
print("\nA plot of state evolution has been saved as 'lstm_state_evolution.png'")


CUSTOM FORWARD PASS USING EXTRACTED WEIGHTS

Days 1-2 → Predicting Day 3

--- Time step 1 (Day 1) ---
Input Features:
+-------------+-----------+
| Feature     | Value     |
| Temperature | 18.20°C   |
+-------------+-----------+
| Humidity    | 65.20%    |
+-------------+-----------+
| Wind Speed  | 5.20 km/h |
+-------------+-----------+

Gate Values:
+--------------------------+---------------------------------------------------------------+
| Gate                     | Values                                                        |
| Forget Gate (f_t)        | [3.54914634e-01 9.99978068e-01 4.54828448e-10 6.38916037e-01] |
+--------------------------+---------------------------------------------------------------+
| Input Gate (i_t)         | [2.99208228e-08 1.00000000e+00 2.19032573e-12 1.67656240e-12] |
+--------------------------+---------------------------------------------------------------+
| Cell Candidate (c_tilde) | [-1.  1.  1.  1.]                                       