In [3]:
'''# Model Analysis Notebook

This notebook is for analyzing the performance of the trained LSTM model.

---

## 1. Import Required Libraries
''' 
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from model import LSTMModel
from data_loader import load_and_preprocess_data
from sklearn.metrics import mean_squared_error, mean_absolute_error


KeyboardInterrupt: 

In [4]:
# Load model
from model import LSTMModel 
print("✔")
model = LSTMModel()
model.load_state_dict(torch.load('../models/saved/stock_lstm_model.pth'))
model.eval()

print("Model loaded successfully.")


ModuleNotFoundError: No module named 'model'

In [None]:
# Load test dataset
file_path = '../data/raw/indexProcessed.csv'
_, _, test_dataset, scaler = load_and_preprocess_data(file_path)

# Extract original values and prepare input sequences
test_data = [scaler.inverse_transform([[x]])[0][0] for x in test_dataset.data]
sequence_length = 50
actual_prices = test_data[sequence_length:]


In [None]:
predictions = []
for i in range(sequence_length, len(test_data)):
    input_seq = torch.tensor(test_data[i-sequence_length:i], dtype=torch.float32).unsqueeze(0).unsqueeze(-1)
    with torch.no_grad():
        prediction = model(input_seq).item()
    predictions.append(prediction)

# Rescale predictions
predicted_prices = scaler.inverse_transform(np.array(predictions).reshape(-1, 1)).flatten()


In [None]:
# Plot results
plt.figure(figsize=(12, 6))
plt.plot(actual_prices, label="Actual Prices", color='blue')
plt.plot(predicted_prices, label="Predicted Prices", color='red')
plt.title("Actual vs Predicted Stock Prices")
plt.xlabel("Time")
plt.ylabel("Price (USD)")
plt.legend()
plt.show()


In [None]:
# Calculate metrics
rmse = np.sqrt(mean_squared_error(actual_prices, predicted_prices))
mae = mean_absolute_error(actual_prices, predicted_prices)
mape = np.mean(np.abs((np.array(actual_prices) - np.array(predicted_prices)) / np.array(actual_prices))) * 100

print(f"Root Mean Square Error (RMSE): {rmse}")
print(f"Mean Absolute Error (MAE): {mae}")
print(f"Mean Absolute Percentage Error (MAPE): {mape:.2f}%")
