In [4]:
import torch
from models.base import RidgeModule
from sklearn.linear_model import Ridge
import numpy as np

torch.manual_seed(0)
np.random.seed(0)

# Regularisation strength for RidgeModule. Defined once so the scikit-learn
# reference fit below is always built from the same value.
L2_REG = 0.1

# Example data: y depends linearly on features 0 and 2, plus unit Gaussian noise.
X = torch.randn(100, 5)
y = 2 * X[:, 0] - 3 * X[:, 2] + 1 + torch.randn(100)  # Linear relationship + noise
y = y.unsqueeze(-1)  # Make y a column vector: shape (100, 1)
n_samples = X.shape[0]

# Create and fit the model
ridge = RidgeModule(l2_reg=L2_REG)
ridge.fit(X, y)
# NOTE: despite the name, these are in-sample predictions of *y* (not of X);
# the name is kept unchanged in case later cells refer to it.
X_train_pred = ridge(X)
train_diff = torch.mean((X_train_pred - y) ** 2).item()  # training MSE
print("train_diff", train_diff)

# Make predictions on fresh inputs
X_test = torch.randn(20, 5)
y_pred = ridge(X_test)

print(y_pred.shape)  # Expected: torch.Size([20, 1])

# Test against scikit-learn (for verification).
# sklearn's `alpha` corresponds to l2_reg * n_samples here (per the original
# mapping used in this notebook). Derive it from the data instead of hard-coding
# 100, so the reference stays correct if the number of samples changes.
ridge_sklearn = Ridge(alpha=L2_REG * n_samples, fit_intercept=True, solver="svd")
ridge_sklearn.fit(X.numpy(), y.numpy())
y_pred_sklearn = ridge_sklearn.predict(X_test.numpy())
sklearn_train_pred = ridge_sklearn.predict(X.numpy())
sklearn_train_diff = np.mean((sklearn_train_pred - y.numpy()) ** 2)  # training MSE
print("sklearn_train_diff", sklearn_train_diff)

print("y_pred_sklearn", y_pred_sklearn)
print("myown ypred", y_pred)

# Compare test predictions; should print True (predictions agree to ~1e-4).
print(np.allclose(y_pred.detach().numpy(), y_pred_sklearn, atol=1e-4))


# Test edge case of a constant target (std(y) == 0).
# RidgeModule appears to standardise y (it reports a y_std), so a constant
# target can trigger a 0/0 there — TODO confirm how the module guards against it.
X = torch.randn(100, 5)
y = torch.ones(100)  # all ones -> zero standard deviation
y = y.unsqueeze(-1)  # Make y a column vector: shape (100, 1)

# Create and fit the model
ridge = RidgeModule(l2_reg=0.1)
ridge.fit(X, y)

# Make predictions
X_test = torch.randn(20, 5)
y_pred = ridge(X_test)
print("myown ypred", y_pred)  # Should all be 1

# Check the expectation programmatically instead of by eye. Prints False
# (rather than raising) if any prediction differs from 1 — including NaNs,
# which torch.allclose never treats as close.
print(torch.allclose(y_pred.detach(), torch.ones_like(y_pred), atol=1e-4))

W Parameter containing:
tensor([[ 0.5184, -0.0093, -0.7147,  0.0130, -0.0161]], requires_grad=True)
b tensor([[1.0932]], grad_fn=<SubBackward0>)
self.X_mean tensor([[ 0.0324, -0.0213,  0.1413, -0.0072,  0.1780]])
self.y_mean tensor([[0.7722]])
self.y_std tensor([[3.6904]])
EXTRA tensor([[-0.3210]], grad_fn=<MulBackward0>)
train_diff 1.0315852165222168
torch.Size([20, 1])
sklearn_train_diff 1.0315856
y_pred_sklearn [[-1.8953984 ]
 [ 0.02323616]
 [-3.0247264 ]
 [-0.08718228]
 [ 2.81114   ]
 [ 1.594633  ]
 [ 4.8962965 ]
 [ 6.64459   ]
 [-0.5777571 ]
 [ 4.9539213 ]
 [-0.7964897 ]
 [ 1.5736953 ]
 [ 3.9482665 ]
 [-1.0506058 ]
 [ 1.5219548 ]
 [ 1.8978474 ]
 [-5.855605  ]
 [ 3.4686031 ]
 [-0.23702383]
 [-1.8151882 ]]
myown ypred tensor([[-1.8954],
        [ 0.0232],
        [-3.0247],
        [-0.0872],
        [ 2.8111],
        [ 1.5946],
        [ 4.8963],
        [ 6.6446],
        [-0.5778],
        [ 4.9539],
        [-0.7965],
        [ 1.5737],
        [ 3.9483],
        [-1.0506],
   