diff --git a/qlib/contrib/model/pytorch_alstm.py b/qlib/contrib/model/pytorch_alstm.py index b0770e2bdd..c95caa76b4 100644 --- a/qlib/contrib/model/pytorch_alstm.py +++ b/qlib/contrib/model/pytorch_alstm.py @@ -218,8 +218,8 @@ def fit( save_path=None, ): - df_train, df_valid, df_test = dataset.prepare( - ["train", "valid", "test"], + df_train, df_valid = dataset.prepare( + ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L, )