In [1]:
import torch
from sklearn.datasets import fetch_california_housing
from src.mssp import MSSP

In [2]:
# Download (or load from the local scikit-learn cache) the California housing
# dataset and convert both the feature matrix and the target vector to
# double-precision tensors.
data = fetch_california_housing()
X, y = (
    torch.tensor(array, dtype=torch.double)
    for array in (data.data, data.target)
)

In [3]:
# Single seed shared by the data split and the model; set to None to leave
# the generators unseeded.
RANDOM_SEED = 42

if RANDOM_SEED is not None:
    # Seed the default torch generator first, then the CUDA generator of the
    # current device, in that order.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(RANDOM_SEED)


In [4]:
# Shuffle every row index once, then carve the permutation into consecutive
# 80% / 10% / 10% slices for training, validation and testing.
n_total = len(X)
train_end = int(n_total * 0.8)
valid_end = int(n_total * 0.9)

i = torch.randperm(n_total)
i_train = i[:train_end]
i_valid = i[train_end:valid_end]
i_test = i[valid_end:]

In [5]:
# Materialise each split by indexing features and targets with the same
# index tensors, so rows of X and entries of y stay aligned.
X_train, X_valid, X_test = (X[idx] for idx in (i_train, i_valid, i_test))
y_train, y_valid, y_test = (y[idx] for idx in (i_train, i_valid, i_test))

In [6]:
# Sanity check: one line per split showing the feature and target shapes,
# in train / validation / test order.
for split_X, split_y in ((X_train, y_train), (X_valid, y_valid), (X_test, y_test)):
    print(split_X.shape, split_y.shape)

torch.Size([16512, 8]) torch.Size([16512])
torch.Size([2064, 8]) torch.Size([2064])
torch.Size([2064, 8]) torch.Size([2064])


In [7]:
# Hyperparameters are forwarded to MSSP as keyword arguments; what each one
# means is defined in src.mssp (not visible from this notebook).
# An optional `cv=5` argument was tried and is intentionally left disabled.
mssp_params = {
    "n_best": 120,
    "loss_fn": "mse",
    "random_seed": RANDOM_SEED,  # same seed that was used for the torch generators
    "epochs": 10,
    "diversity_ratio": 0.75,
    "pow_cross": True,
}
model = MSSP(**mssp_params)

In [8]:
# Fit on the training split only. The commented-out call below is the
# alternative form that additionally passes the validation split to fit().
# NOTE(review): X_valid / y_valid are not used anywhere else in the visible
# cells, so the losses logged by this call are presumably training losses —
# confirm against src.mssp.
# model.fit(X_train, y_train, X_valid, y_valid)
model.fit(X_train, y_train)


loss (mse): 0.6163 epoch: 0 , time: 0.42s
loss (mse): 0.4759 epoch: 1 , time: 1.69s
loss (mse): 0.4654 epoch: 2 , time: 1.59s
loss (mse): 0.4540 epoch: 3 , time: 1.62s
loss (mse): 0.4437 epoch: 4 , time: 1.60s
loss (mse): 0.4376 epoch: 5 , time: 1.52s
loss (mse): 0.4319 epoch: 6 , time: 1.49s
loss (mse): 0.4284 epoch: 7 , time: 1.67s
loss (mse): 0.4265 epoch: 8 , time: 1.55s
loss (mse): 0.4243 epoch: 9 , time: 1.48s
Best loss: 0.42431482672691345 after training for 9 epochs


In [None]:
# Report the held-out test result for several `top_k` settings, from the
# largest to the smallest value.
for top_k in (32, 16, 8, 4, 1):
    print(model.evaluate(X_test, y_test, top_k=top_k))

In [None]:
# Inspect the `cross` attribute on the head of the first entry in model.model.
first_head = model.model[0].head
first_head.cross

In [None]:
def _node_summary(node):
    """Return the (epoch, cross, pos) attributes of a single node."""
    return (node.epoch, node.cross, node.pos)


# Show the head of the first entry in model.model next to its two children.
# NOTE(review): `right_chile` looks like a misspelling of `right_child`; it is
# kept as written because the attribute name used by src.mssp is not visible
# here — confirm before renaming.
head_node = model.model[0].head
(
    _node_summary(head_node),
    _node_summary(head_node.left_child),
    _node_summary(head_node.right_chile),
)