Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

k-fold交叉验证 #493

Open
xx-Jiangwen opened this issue Sep 26, 2022 · 1 comment
Open

k-fold交叉验证 #493

xx-Jiangwen opened this issue Sep 26, 2022 · 1 comment

Comments

@xx-Jiangwen
Copy link

提问时请尽可能提供如下信息:

基本信息

  • 你使用的操作系统: ubuntu
  • 你使用的Python版本: 3.6.13
  • 你使用的Tensorflow版本: nvidia-tensorflow 1.15.4+nv20.10
  • 你使用的Keras版本: Keras 2.3.1
  • 你使用的bert4keras版本: 0.11.3
  • 你使用纯keras还是tf.keras: keras
  • 你加载的预训练模型:wobert

核心代码

# 请在此处贴上你的核心代码。
# 请尽量只保留关键部分,不要无脑贴全部代码。
    num_splits = 2
    kf = KFold(n_splits=num_splits, shuffle=True, random_state=2022)
    fold = 0
    for fold,(train_index, val_index) in enumerate( kf.split(data)):
        fold += 1
        print("="*80)
        print(f"正在训练第 {fold} 折的数据")

        
        # 划分训练集和验证集
        train_data = [data[i] for i in train_index]
        valid_data = [data[i] for i in val_index]

        model_savepath = f'../best_model/best_model_fold{fold}.weights'
        model = build_model()    # 构建模型
        train_generator = data_generator(train_data, batch_size)
        evaluator = Evaluator(valid_data,model_savepath,model)

        model.fit(
            train_generator.forfit(),
            steps_per_epoch=len(train_generator),
            epochs=epochs,
            callbacks=[evaluator]
            )

        do_predict(model_savepath,fold,model)

        del model, train_data, valid_data
        K.clear_session()
        gc.collect()

输出信息

# 请在此处贴上你的调试输出
ValueError: Tensor("Cast:0", shape=(), dtype=float32) must be from the same graph as Tensor("loss/efficient_global_pointer_loss/strided_slice:0", shape=(?, ?, ?), dtype=float32).

自我尝试

不管什么问题,请先尝试自行解决,“万般努力”之下仍然无法解决再来提问。此处请贴上你的努力过程。
想进行交叉验证效果,为防止上一步模型的缓存造成数据泄露,需要清除掉,使用clear_session(),但一直报上面的错误,然后网上搜集相关问题,还是无法解决

@xx-Jiangwen
Copy link
Author

image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant