In [2]:
from sklearn.cross_decomposition import PLSRegression
from sklearn.metrics import mean_squared_error
import numpy as np
from importlib import reload
from weighted_pls import weighted_pls

In [3]:
# Generate synthetic data: Y is a linear function of the first two columns of X
# plus a small uniform noise term; the remaining three columns are irrelevant.
np.random.seed(42)
N = 100
X = np.random.rand(N, 5)
Y = 3 * X[:, 0:1] + 2 * X[:, 1:2] + np.random.rand(N, 1) * 0.1
weights = np.random.randint(1, 5, size=N)  # integer weights, so they can double as repeat counts

# Duplicate each sample as many times as its weight
X_repeated = np.repeat(X, weights, axis=0)
Y_repeated = np.repeat(Y, weights, axis=0)


# Unweighted PLS (baseline)
pls = PLSRegression(n_components=2)
pls.fit(X, Y)
Y_pred = pls.predict(X)

# Weighted PLS with every weight equal to 1 (expected to match the baseline)
wpls_1 = weighted_pls.WeightedPLSRegression(n_components=2)
wpls_1.fit(X, Y, sample_weight=np.ones(N))
Y_pred_wpls_1 = wpls_1.predict(X)

# Plain PLS trained on the duplicated data
pls_repeated = PLSRegression(n_components=2)
pls_repeated.fit(X_repeated, Y_repeated)
Y_pred_repeated = pls_repeated.predict(X)

# Weighted PLS trained on the original data with the integer weights
# (expected to match the duplicated-data model)
wpls = weighted_pls.WeightedPLSRegression(n_components=2)
wpls.fit(X, Y, sample_weight=weights)
Y_pred_wpls = wpls.predict(X)


# Compare results: MSE on the original (non-duplicated) samples
print("PLS MSE:", mean_squared_error(Y, Y_pred))
print("WPLS (w=1)  MSE:", mean_squared_error(Y, Y_pred_wpls_1))
print("Repeated PLS MSE:", mean_squared_error(Y, Y_pred_repeated))
print("WPLS MSE:", mean_squared_error(Y, Y_pred_wpls))

# Compare X weights. np.unique flattens and sorts the matrix, so this only
# compares the set of values, not their position per feature/component.
print("PLS weights:", np.unique(pls.x_weights_))
print("WPLS (w=1) weights:", np.unique(wpls_1.x_weights_))
print("Repeated PLS weights:", np.unique(pls_repeated.x_weights_))
print("WPLS weights:", np.unique(wpls.x_weights_))

# Regression coefficients
print("PLS coef:", pls.coef_.flatten())
print("WPLS (w=1) coef:", wpls_1.coef_.flatten())
print("Repeated PLS coef:", pls_repeated.coef_.flatten())
print("WPLS coef:", wpls.coef_.flatten())

# Intercepts
print("PLS intercept:", pls.intercept_)
print("WPLS (w=1) intercept:", wpls_1.intercept_)
print("Repeated PLS intercept:", pls_repeated.intercept_)
print("WPLS intercept:", wpls.intercept_)

# Make the two equivalences this cell demonstrates explicit, so a regression in
# weighted_pls fails loudly on re-run instead of relying on eyeballing the prints.
# Arrays are flattened before comparison so that a (N,) vs (N, 1) shape difference
# between the two implementations cannot trigger accidental broadcasting.
np.testing.assert_allclose(
    np.ravel(Y_pred_wpls_1), np.ravel(Y_pred), rtol=1e-6,
    err_msg="WPLS with unit weights should predict like unweighted PLS",
)
np.testing.assert_allclose(
    wpls_1.coef_.flatten(), pls.coef_.flatten(), rtol=1e-6,
    err_msg="WPLS with unit weights should have the same coefficients as unweighted PLS",
)
np.testing.assert_allclose(
    np.ravel(Y_pred_wpls), np.ravel(Y_pred_repeated), rtol=1e-6,
    err_msg="WPLS with integer weights should predict like PLS on duplicated data",
)
np.testing.assert_allclose(
    wpls.coef_.flatten(), pls_repeated.coef_.flatten(), rtol=1e-6,
    err_msg="WPLS with integer weights should have the same coefficients as PLS on duplicated data",
)

PLS MSE: 0.00310952041790467
WPLS (w=1)  MSE: 0.003109520417904671
Repeated PLS MSE: 0.0032895132103785947
WPLS MSE: 0.003289513210378584
PLS weights: [-0.19828048 -0.13498799 -0.12780189 -0.10066264  0.05942422  0.48921255
  0.48929136  0.50946201  0.67206504  0.8500045 ]
WPLS (w=1) weights: [-0.19828048 -0.13498799 -0.12780189 -0.10066264  0.05942422  0.48921255
  0.48929136  0.50946201  0.67206504  0.8500045 ]
Repeated PLS weights: [-0.53090312 -0.1922086  -0.06870831 -0.02859688 -0.0039046   0.14949663
  0.46006727  0.46896942  0.70247668  0.85336214]
WPLS weights: [-0.53090312 -0.1922086  -0.06870831 -0.02859688 -0.0039046   0.14949663
  0.46006727  0.46896942  0.70247668  0.85336214]
PLS coef: [ 2.91673758e+00  2.04251401e+00  1.78646685e-03  7.68234000e-02
 -1.40962257e-01]
WPLS (w=1) coef: [ 2.91673758e+00  2.04251401e+00  1.78646685e-03  7.68234000e-02
 -1.40962257e-01]
Repeated PLS coef: [ 2.92722639  2.03307673 -0.06551775  0.0650725  -0.15295548]
WPLS coef: [ 2.92722639  2.