<a href="https://colab.research.google.com/github/imai3/half_life/blob/main/%E5%8D%8A%E6%B8%9B%E6%9C%9F%E5%8A%A0%E9%87%8D%EF%BC%BF%E3%82%AB%E3%83%AA%E3%83%95%E3%82%A9%E3%83%AB%E3%83%8B%E3%82%A2_jit_self%E3%81%A8%E3%81%AE%E6%AF%94%E8%BC%83.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import jax.numpy as jnp
from jax import jit as jjit
from functools import partial
from jax import vmap
from jax.lax import cond, while_loop

############自作モデルの実装##########

#誤差を求める関数
@partial(jjit)
def rmse(y, p):
    """Return the root-mean-square error between targets ``y`` and predictions ``p``."""
    residual = y - p
    return jnp.sqrt(jnp.mean(residual * residual))

#予測する関数(１行)
@partial(jjit)
def predict(row, model_x, model_y, T,replace_value):
    """Predict one target as a half-life-weighted average of ``model_y``.

    Reference point j gets weight
    ``prod_k (1/2) ** (|model_x[j, k] - row[k]| / T[k])``, i.e. its weight
    halves for every ``T[k]`` units of distance along feature k.

    Args:
        row: feature vector of the query point (broadcast against model_x rows).
        model_x: reference feature matrix, one row per reference point.
        model_y: reference targets, one per row of model_x.
        T: per-feature half-lives.
        replace_value: value returned when the weighted average is NaN,
            e.g. when every weight underflows to zero (0/0).

    Returns:
        The scalar prediction.
    """
    t = jnp.abs(model_x - row)
    w = (1/2)**(t / T)
    w = jnp.prod(w, axis=1)
    p = jnp.average(model_y, weights=w)
    # `nan` must be given by keyword: the second positional parameter of
    # jnp.nan_to_num is `copy`, so a positional call silently discards
    # replace_value and maps NaN to 0.0.
    p = jnp.nan_to_num(p, nan=replace_value)
    return p

#予測する関数（全行）
@partial(jjit)
def predict_array(x, model_x, model_y, T, replace_value):
    """Apply ``predict`` to every row of ``x``.

    Only the first argument (the query row) is mapped over; the reference
    data, half-lives and fallback value are shared by all rows.
    """
    batched_predict = vmap(predict, in_axes=(0, None, None, None, None))
    return batched_predict(x, model_x, model_y, T, replace_value)

#学習関数modelとtunerを引数で設定する必要がある（ランダムな要素を排除）
@partial(jjit)
def fit(T, model_x, model_y, tuner_x, tuner_y, replace_value, rate=0.7):
    """Greedily shrink per-feature half-lives ``T`` to lower the tuner-set RMSE.

    Features are visited round-robin (feature ``step % n_features``). At each
    step that feature's half-life is multiplied by ``rate``; the candidate is
    kept only if it strictly lowers the RMSE on ``tuner_x``/``tuner_y``. The
    search stops once ``n_features`` consecutive steps fail to improve.

    The model/tuner split is supplied by the caller, so this function itself
    contains no randomness.

    Args:
        T: initial per-feature half-lives.
        model_x, model_y: reference points used to make predictions.
        tuner_x, tuner_y: points whose prediction error is minimised.
        replace_value: fallback prediction forwarded to predict_array.
        rate: multiplicative shrink factor tried at each step.

    Returns:
        dict with "T" (tuned half-lives), "err" (best tuner RMSE) and
        "try_cnt" (number of loop steps taken, plus one).
    """
    _, n_features = model_x.shape

    def tuner_error(half_lives):
        # RMSE on the tuner set for a given vector of half-lives.
        preds = predict_array(tuner_x, model_x, model_y, half_lives, replace_value)
        return rmse(tuner_y, preds)

    def not_converged(state):
        # state[3] is the length of the current run of non-improving steps.
        return state[3] < n_features

    def step(state):
        count, best_err, best_T, stall = state
        k = count % n_features
        cand_T = best_T.at[k].set(best_T[k] * rate)
        cand_err = tuner_error(cand_T)
        new_err, new_T = cond(cand_err < best_err,
                              lambda: [cand_err, cand_T],
                              lambda: [best_err, best_T])
        # Error unchanged means the candidate was rejected: extend the stall run.
        new_stall = cond(best_err == new_err, lambda: stall + 1, lambda: 0)
        return [count + 1, new_err, new_T, new_stall]

    count, err, T, _ = while_loop(not_converged, step, [0, tuner_error(T), T, 0])
    return {"T": T, "err": err, "try_cnt": count + 1}

##########Validation##########
# Load the California housing price dataset.
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import fetch_california_housing
# NOTE: an unused `from sklearn.datasets import load_boston` was removed here;
# load_boston no longer exists in scikit-learn >= 1.2 and importing it raises
# ImportError, which aborted this whole cell.
housing = fetch_california_housing()

# Split into training data and held-out test data.
x = housing.data
y = housing.target
X_train, X_test, y_train, y_test = train_test_split(x, y, test_size = 0.2)


# Half-life model setup: split the training data into a "model" half (the
# reference points used for prediction) and a "tuner" half (used to score T).
model_x, tuner_x, model_y, tuner_y = train_test_split(X_train, y_train, test_size = 0.5, train_size=0.5)
# Each feature's initial half-life is that feature's range in the model half.
T_first = jnp.max(model_x, axis=0) - jnp.min(model_x, axis=0)
# Fallback prediction used when the weighted average is undefined.
replace_value = jnp.mean(model_y)

# Training
print("学習")
result = fit( T_first, model_x, model_y, tuner_x, tuner_y, replace_value)
print(result)
# Average number of update attempts per feature.
print(result["try_cnt"]/model_x.shape[1])
T = result["T"]

# Prediction: the full training set (model + tuner halves) serves as the
# reference points here, although T was tuned against model_x only.
p_test = predict_array(X_test,X_train, y_train, T, replace_value)

# Accuracy
print("精度")
print(rmse(y_test,p_test))

学習
{'T': DeviceArray([2.86716312e-01, 2.05803370e+00, 4.68651414e-01,
             1.19336426e-01, 6.50459656e+02, 2.34184235e-01,
             1.06693581e-02, 1.11025255e-02], dtype=float32), 'err': DeviceArray(0.47163045, dtype=float32), 'try_cnt': DeviceArray(183, dtype=int32, weak_type=True)}
22.875
精度
0.43391392


In [None]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import jax.numpy as jnp
from jax import jit as jjit
from functools import partial
from jax import vmap
from jax.lax import cond, while_loop

############自作モデルの実装##########

#誤差を求める関数
@partial(jjit)
def rmse(y, p):
    """Return the root-mean-square error between targets ``y`` and predictions ``p``."""
    residual = y - p
    return jnp.sqrt(jnp.mean(residual * residual))

#予測する関数(１行)
@partial(jjit)
def predict(row, model_x, model_y, T,replace_value):
    """Predict one target as a half-life-weighted average of ``model_y``.

    Reference point j gets weight
    ``prod_k (1/2) ** (|model_x[j, k] - row[k]| / T[k])``, i.e. its weight
    halves for every ``T[k]`` units of distance along feature k.

    Args:
        row: feature vector of the query point (broadcast against model_x rows).
        model_x: reference feature matrix, one row per reference point.
        model_y: reference targets, one per row of model_x.
        T: per-feature half-lives.
        replace_value: value returned when the weighted average is NaN,
            e.g. when every weight underflows to zero (0/0).

    Returns:
        The scalar prediction.
    """
    t = jnp.abs(model_x - row)
    w = (1/2)**(t / T)
    w = jnp.prod(w, axis=1)
    p = jnp.sum(w * model_y) / jnp.sum(w)
    # `nan` must be given by keyword: the second positional parameter of
    # jnp.nan_to_num is `copy`, so a positional call silently discards
    # replace_value and maps NaN to 0.0.
    p = jnp.nan_to_num(p, nan=replace_value)
    return p

#予測する関数（全行）
@partial(jjit)
def predict_array(x, model_x, model_y, T, replace_value):
    """Apply ``predict`` to every row of ``x``.

    Only the first argument (the query row) is mapped over; the reference
    data, half-lives and fallback value are shared by all rows.
    """
    batched_predict = vmap(predict, in_axes=(0, None, None, None, None))
    return batched_predict(x, model_x, model_y, T, replace_value)

@partial(jjit)
def predict_self(row_x, row_y, model_x, model_y, T,replace_value):
    """Leave-one-out prediction for a point that is itself a row of ``model_x``.

    Computes the same half-life-weighted average as ``predict``, then removes
    the point's own contribution: its weight is 1 (distance 0) and its
    weighted target is ``row_y``, so both are subtracted from the sums.

    Args:
        row_x: feature vector of the point (assumed to occur in model_x).
        row_y: target of that point.
        model_x, model_y: reference points and their targets.
        T: per-feature half-lives.
        replace_value: value returned when the result is NaN, e.g. when only
            the point itself carries weight (0/0).

    Returns:
        The scalar leave-one-out prediction.
    """
    t = jnp.abs(model_x - row_x)
    w = (1/2)**(t / T)
    w = jnp.prod(w, axis=1)
    # NOTE(review): when the other weights are tiny, (sum(w) - 1) suffers
    # cancellation in float32 and the ratio can become very large.
    p = (jnp.sum(w * model_y) - row_y) / (jnp.sum(w) - 1)
    # `nan` must be given by keyword: the second positional parameter of
    # jnp.nan_to_num is `copy`, so a positional call silently discards
    # replace_value and maps NaN to 0.0.
    p = jnp.nan_to_num(p, nan=replace_value)
    return p

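# Leave-one-out prediction for all rows: each row of x is assumed to also appear in model_x and is excluded from its own estimate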
@partial(jjit)
def predict_self_array(x, y, model_x, model_y, T, replace_value):
    """Apply ``predict_self`` to every (row of ``x``, element of ``y``) pair.

    The query rows and their targets are mapped over together; the reference
    data, half-lives and fallback value are shared by all rows.
    """
    batched = vmap(predict_self, in_axes=(0, 0, None, None, None, None))
    return batched(x, y, model_x, model_y, T, replace_value)

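# Training function: tunes T by the leave-one-out error on model_x/model_y itself, so no separate tuner split is needed (deterministic, no random elements)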
@partial(jjit)
def fit(T, model_x, model_y, replace_value, rate=0.7):
    """Greedily shrink per-feature half-lives ``T`` to lower the leave-one-out RMSE.

    Every candidate is scored by predicting each row of ``model_x`` from all
    the other rows (predict_self_array), so no separate tuner split is needed.
    Features are visited round-robin (feature ``step % n_features``). At each
    step that feature's half-life is multiplied by ``rate``; the candidate is
    kept only if it strictly lowers the error. The search stops once
    ``n_features`` consecutive steps fail to improve (i.e. the accuracy could
    not be improved for a full pass).

    Args:
        T: initial per-feature half-lives.
        model_x, model_y: data used both as reference points and for scoring.
        replace_value: fallback prediction forwarded to predict_self_array.
        rate: multiplicative shrink factor tried at each step.

    Returns:
        The tuned half-lives (the final error and step count are discarded).
    """
    _, n_features = model_x.shape

    def loo_error(half_lives):
        # Leave-one-out RMSE of the model data for a given vector of half-lives.
        preds = predict_self_array(model_x, model_y, model_x, model_y, half_lives, replace_value)
        return rmse(model_y, preds)

    def not_converged(state):
        # state[3] is the length of the current run of non-improving steps.
        return state[3] < n_features

    def step(state):
        count, best_err, best_T, stall = state
        k = count % n_features
        cand_T = best_T.at[k].set(best_T[k] * rate)
        cand_err = loo_error(cand_T)
        new_err, new_T = cond(cand_err < best_err,
                              lambda: [cand_err, cand_T],
                              lambda: [best_err, best_T])
        # Error unchanged means the candidate was rejected: extend the stall run.
        new_stall = cond(best_err == new_err, lambda: stall + 1, lambda: 0)
        return [count + 1, new_err, new_T, new_stall]

    final_state = while_loop(not_converged, step, [0, loo_error(T), T, 0])
    return final_state[2]

##########Validation##########
# Load the California housing price dataset.
from sklearn.datasets import fetch_california_housing
# NOTE: an unused `from sklearn.datasets import load_boston` was removed here;
# load_boston no longer exists in scikit-learn >= 1.2 and importing it raises
# ImportError, which aborted this whole cell.
housing = fetch_california_housing()

# Split into training data and held-out test data.
x = housing.data
y = housing.target
X_train, X_test, y_train, y_test = train_test_split(x, y, test_size = 0.2)


# Half-life model setup: the whole training set is used, because fit scores
# candidates by leave-one-out error instead of a separate tuner split.
# Each feature's initial half-life is that feature's range in X_train.
T_first = jnp.max(X_train, axis=0) - jnp.min(X_train, axis=0)
# Fallback prediction used when the weighted average is undefined.
replace_value = jnp.mean(y_train)

# Training
print("学習")
T = fit( T_first, X_train, y_train, replace_value)


# Prediction on the held-out test data, using all training rows as references.
p_test = predict_array(X_test,X_train, y_train, T, replace_value)

# Accuracy
print("精度")
print(rmse(y_test,p_test))
print(T)



学習
精度
0.5251749
[2.8671631e-01 2.0580337e+00 1.2759037e+00 3.2683888e-01 4.9384320e+02
 6.9403970e-01 9.1172539e-02 1.3896650e-01]
