In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score

def build_model(input_dim, hidden_dims, activation='ReLU', output_dim=1):
    """
    Construct an untrained fully-connected network (no training is done here).

    Intended to reproduce the architecture built by the training code
    (``train_nn``, not shown here) so that a saved ``state_dict`` can be
    loaded into the returned model. The layer order is
    ``Linear, act, Linear, act, ..., Linear``. The ``nn.Sequential`` indices
    therefore match the parameter keys of a state_dict saved from the same
    layout, so do not reorder it.

    Args:
        input_dim: Number of input features fed to the first Linear layer.
        hidden_dims: Iterable of hidden-layer widths. Each entry adds one
            ``nn.Linear`` followed (unless disabled) by one activation module.
        activation: Name of a ``torch.nn`` module class, e.g. ``'ReLU'`` or
            ``'Tanh'``. Surrounding whitespace is ignored. ``'linear'`` or
            ``'none'`` (case-insensitive) means no activation between layers.
        output_dim: Width of the final Linear layer (default 1).

    Returns:
        An ``nn.Sequential`` containing the layers described above.

    Raises:
        AttributeError: if an activation is requested and ``torch.nn`` has no
            attribute with the (stripped) given name.
    """
    act_key = activation.strip()
    use_activation = act_key.lower() not in ('linear', 'none')

    act_cls = None
    if use_activation:
        # Look up with the *stripped* name (the original used the raw string,
        # so ' ReLU' passed the check above but failed here). Resolve once,
        # up front, so a typo is reported before any layers are built.
        act_cls = getattr(nn, act_key, None)
        if act_cls is None:
            raise AttributeError(
                f"Unknown activation {activation!r}: "
                f"torch.nn has no attribute {act_key!r}"
            )

    layers = []
    in_dim = input_dim
    for h in hidden_dims:
        layers.append(nn.Linear(in_dim, h))
        if use_activation:
            layers.append(act_cls())  # fresh module instance per hidden layer
        in_dim = h
    layers.append(nn.Linear(in_dim, output_dim))
    return nn.Sequential(*layers)

# --- 1) Load test set ---
df_test = pd.read_csv('test.csv')  # replace with your actual test file path
if len(df_test) == 0:
    # Fail early with a clear message instead of an obscure min()/r2 error later.
    raise ValueError("test set is empty: nothing to evaluate")

# assume inputs are named "Input 1"…"Input 18" and the target "Output 1";
# these must be the same columns, in the same order, the model was trained on.
input_cols = [f'Input {i}' for i in range(1, 19)]
X_test = df_test[input_cols].values.astype(np.float32)
y_true = df_test['Output 1'].values.astype(np.float32)
# NOTE(review): if training applied any input/target scaling, the same fitted
# transform must be applied to X_test here (and inverted on y_pred) — TODO confirm.

# --- 2) Rebuild & load the trained model ---
hidden_dims    = [2, 2]      # must match what you trained with
activation     = 'ReLU'      # same here
model_filepath = 'trained_model_ab12cd34.pt'  # your actual .pt file

model = build_model(X_test.shape[1], hidden_dims, activation)
# map_location='cpu' lets a checkpoint saved from a GPU run load on a
# CPU-only machine; inference below runs on CPU tensors anyway.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
state_dict = torch.load(model_filepath, map_location='cpu')
model.load_state_dict(state_dict)
model.eval()  # switch modules to inference behaviour before predicting

# --- 3) Predict on test set ---
with torch.no_grad():
    # model output is (N, 1); ravel to (N,) to line up with y_true
    y_pred = model(torch.from_numpy(X_test)).cpu().numpy().ravel()

# --- 4) Scatter plot: True vs. Predicted ---
plt.figure()
plt.scatter(y_true, y_pred, alpha=0.6)
# y = x reference line spanning the combined range of both axes
mn, mx = min(y_true.min(), y_pred.min()), max(y_true.max(), y_pred.max())
plt.plot([mn, mx], [mn, mx], 'r--', linewidth=1)
plt.xlabel('True Output')
plt.ylabel('Predicted Output')
plt.title('Test Set: True vs. Predicted')
plt.tight_layout()
plt.show()

# --- 5) Compute R² score ---
r2 = r2_score(y_true, y_pred)
print(f"Test R²: {r2:.4f}")

# --- 6) Histogram of predicted outputs ---
plt.figure()
plt.hist(y_pred, bins=30)
plt.xlabel('Predicted Output')
plt.ylabel('Frequency')
plt.title('Histogram of Model Predictions on Test Set')
plt.tight_layout()
plt.show()
