In [None]:
# testing recurrent NN


import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.optimizers import Adam


In [None]:

# Load data: design variables and their moment capacities (Mu).
# NOTE(review): assumes both files hold one row per sample, aligned by index — confirm.
Design_space = np.load('Design_space.npy')
moment_capacity = np.load('Mu.npy')

# Check shapes
print("Design_space shape:", Design_space.shape)
print("moment_capacity shape:", moment_capacity.shape)

# The model cell needs 2D (n_samples, n_features) arrays on both sides:
# a 1D Mu becomes a single input feature, and a 1D Design_space becomes a
# single target column (otherwise `Design_space.shape[1]` raises IndexError).
# Both reshapes are no-ops on re-run, so the cell stays idempotent.
if moment_capacity.ndim == 1:
    moment_capacity = moment_capacity.reshape(-1, 1)
if Design_space.ndim == 1:
    Design_space = Design_space.reshape(-1, 1)

# Inputs and targets are paired row-by-row, so their sample counts must agree.
assert len(moment_capacity) == len(Design_space), (
    f"sample count mismatch: {len(moment_capacity)} Mu values vs "
    f"{len(Design_space)} design rows"
)


In [None]:

# Seed Python, NumPy and TensorFlow RNGs in one call so weight initialisation
# and per-epoch shuffling are reproducible under Restart & Run All.
# The same value seeds the train/validation split below.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Split the data into training and validation sets.
# Moment capacity (Mu) is the model input; the design variables are the targets.
X_train, X_val, y_train, y_val = train_test_split(moment_capacity, Design_space, test_size=0.2, random_state=SEED)

# Reshape input data for LSTM: (samples, timesteps=1, features)
X_train = X_train.reshape(X_train.shape[0], 1, X_train.shape[1])
X_val = X_val.reshape(X_val.shape[0], 1, X_val.shape[1])

# Define input and output dimensions (output taken from the split targets)
input_dim = X_train.shape[2]
output_dim = y_train.shape[1]

# Define the LSTM model. An explicit Input layer replaces the `input_shape=`
# layer kwarg, which recent Keras versions deprecate for Sequential models.
# NOTE(review): with a single timestep the recurrent connections never come
# into play, so this behaves like a gated dense layer; also, per the Keras LSTM
# docs, activation='relu' disables the fused cuDNN kernel (requires 'tanh').
model = Sequential([
    tf.keras.Input(shape=(1, input_dim)),
    LSTM(64, activation='relu'),  # LSTM layer with 64 units
    Dense(output_dim),  # Output layer: one linear unit per target column
])

# Compile the model
optimizer = Adam(learning_rate=0.001)
model.compile(optimizer=optimizer, loss='mse')  # Using mean squared error loss

# Train the model
history = model.fit(X_train, y_train, validation_data=(X_val, y_val), epochs=50, batch_size=16, verbose=1)

# Save the trained model.
# NOTE(review): newer Keras treats the .h5 (HDF5) format as legacy; consider
# '.keras' once downstream loaders are updated.
model.save('rnn_model.h5')
print("Model saved as 'rnn_model.h5'")





In [None]:
# Learning curves: per-epoch loss on the training and validation splits,
# drawn on an explicit Axes rather than through pyplot's implicit state.
fig, ax = plt.subplots()
curves = (('loss', 'Training Loss'), ('val_loss', 'Validation Loss'))
for history_key, curve_label in curves:
    ax.plot(history.history[history_key], label=curve_label)
ax.set_title('Model Training and Validation Loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
plt.show()