Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Follow workflow_by_code.py to train LSTM model #317

Closed
JessieJie1998 opened this issue Mar 9, 2021 · 2 comments
Closed

Follow workflow_by_code.py to train LSTM model #317

JessieJie1998 opened this issue Mar 9, 2021 · 2 comments
Assignees
Labels
question Further information is requested

Comments

@JessieJie1998
Copy link

I replaced the LightGBM model with LSTM in workflow_by_code.py but had trouble when running the code. The error is "IndexError: too many indices for tensor of dimension 0". Can someone help me with workflow_by_code-LSTM.py?
I can run 'qrun benchmarks/LSTM/workflow_config_lstm_Alpha158.yml' successfully.

# --- Environment bootstrap --------------------------------------------------
# Ensure qlib is importable and the helper script get_data.py is on disk.
# NOTE: `get_ipython()` exists only inside IPython/Jupyter; under a plain
# `python` run this fallback raises NameError when qlib is missing.
import sys, site
from pathlib import Path


try:
    import qlib
except ImportError:
    # install qlib
    get_ipython().system(' pip install pyqlib')
    # reload site path configuration so the freshly installed package is importable
    site.main()

# Prefer a scripts/ directory next to the current working directory; otherwise
# download get_data.py into ~/tmp/qlib_code/scripts and use that instead.
scripts_dir = Path.cwd().parent.joinpath("scripts")
if not scripts_dir.joinpath("get_data.py").exists():
    # download get_data.py script
    scripts_dir = Path("~/tmp/qlib_code/scripts").expanduser().resolve()
    scripts_dir.mkdir(parents=True, exist_ok=True)
    import requests
    with requests.get("https://raw.githubusercontent.com/microsoft/qlib/main/scripts/get_data.py") as resp:
        with open(scripts_dir.joinpath("get_data.py"), "wb") as fp:
            fp.write(resp.content)



# --- Imports and qlib initialization ----------------------------------------
# NOTE(review): `AlphaJessie` is not part of stock qlib's
# qlib.contrib.data.handler — presumably a user-defined handler; this import
# fails on a vanilla qlib install. `pd`, `LSTM`, `TopkDropoutStrategy`,
# `normal_backtest`, `risk_analysis`, `SignalRecord` and `PortAnaRecord` are
# imported but unused in the visible portion of the script.
import qlib
import sys
print(qlib.__file__)  # show which qlib installation is actually in use
import pandas as pd
from qlib.config import REG_CN
from qlib.contrib.model.pytorch_lstm_ts import LSTM
from qlib.contrib.data.handler import Alpha158, AlphaJessie
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.contrib.evaluate import (
    backtest as normal_backtest,
    risk_analysis,
)
from qlib.utils import exists_qlib_data, init_instance_by_config
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
from qlib.utils import flatten_dict


# use default data
# NOTE: need to download data from remote: python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data
provider_uri = "~/.qlib/qlib_data/cn_data"  # target_dir
if not exists_qlib_data(provider_uri):
    # Data bundle not present yet: fetch it via the get_data.py helper
    # downloaded by the bootstrap section (which set `scripts_dir`).
    print(f"Qlib data is not found in {provider_uri}")
    sys.path.append(str(scripts_dir))
    from get_data import GetData
    GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
# Initialize qlib for the China region. dataset_cache=None presumably disables
# the prepared-dataset cache — confirm against qlib.init documentation.
qlib.init(provider_uri=provider_uri, region=REG_CN, dataset_cache=None)




# Universe and benchmark (benchmark is unused in the visible portion of the
# script — presumably consumed by a later backtest section).
market = "csi100"
benchmark = "SH000300"


###################################
# train model
###################################

# The 20 Alpha158 factors fed to the network; the model's `d_feat` must equal
# the length of this list.
_feature_cols = [
    "RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10",
    "CORR5", "CORD5", "CORR10", "ROC60", "RESI10",
    "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
    "RSQR20", "CORD60", "CORD10", "CORR20", "KLOW",
]

# Handler configuration: keep only the columns above, robust-normalize and
# fill missing feature values for inference; drop NaN labels and
# cross-sectionally rank-normalize labels for training.
data_handler_config = {
    "start_time": "2008-01-01",
    "end_time": "2020-08-01",
    "fit_start_time": "2008-01-01",
    "fit_end_time": "2014-12-31",
    "instruments": market,
    "infer_processors": [
        {"class": "FilterCol",
         "kwargs": {"fields_group": "feature", "col_list": _feature_cols}},
        {"class": "RobustZScoreNorm",
         "kwargs": {"fields_group": "feature", "clip_outlier": True}},
        {"class": "Fillna",
         "kwargs": {"fields_group": "feature"}},
    ],
    "learn_processors": [
        {"class": "DropnaLabel"},
        {"class": "CSRankNorm", "kwargs": {"fields_group": "label"}},
    ],
    # Two-day-ahead close-to-close return used as the prediction target.
    "label": ["Ref($close, -2) / Ref($close, -1) - 1"],
}

# Task specification: an LSTM (pytorch_lstm_ts) trained on a rolling-window
# time-series dataset built from the Alpha158 handler configured above.
# NOTE(review): `rnn_type="GRU"` is forwarded to the LSTM class — confirm that
# pytorch_lstm_ts.LSTM actually accepts this kwarg in the installed qlib
# version; otherwise it is silently ignored or raises a TypeError.
_model_config = {
    "class": "LSTM",
    "module_path": "qlib.contrib.model.pytorch_lstm_ts",
    "kwargs": dict(
        d_feat=20,        # must match the number of filtered feature columns
        hidden_size=64,
        num_layers=2,
        dropout=0.0,
        n_epochs=200,
        lr=1e-3,
        early_stop=10,    # presumably early-stopping patience — confirm
        batch_size=800,
        metric="loss",
        loss="mse",
        n_jobs=20,
        GPU=0,
        rnn_type="GRU",
    ),
}

_dataset_config = {
    "class": "TSDatasetH",
    "module_path": "qlib.data.dataset",
    "kwargs": {
        "handler": {
            "class": "Alpha158",
            "module_path": "qlib.contrib.data.handler",
            "kwargs": data_handler_config,
        },
        # Chronological train/valid/test split.
        "segments": {
            "train": ("2008-01-01", "2014-12-31"),
            "valid": ("2015-01-01", "2016-12-31"),
            "test": ("2017-01-01", "2020-08-01"),
        },
        "step_len": 20,   # presumably the look-back window length — confirm
    },
}

task = {"model": _model_config, "dataset": _dataset_config}

# model initiation: build the model and dataset objects from their config dicts
model = init_instance_by_config(task["model"])
dataset = init_instance_by_config(task["dataset"])

# start an experiment run to train the model, logging the task parameters and
# saving the fitted model; `rid` keeps the recorder id for later lookup.
# NOTE: if fit() raises "IndexError: too many indices for tensor of dimension
# 0", the final mini-batch contained a single sample — per the maintainer
# reply below, add drop_last=True to the train/valid DataLoaders in
# pytorch_lstm_ts.py.
with R.start(experiment_name="train_model"):
    R.log_params(**flatten_dict(task))
    model.fit(dataset)
    R.save_objects(trained_model=model)
    rid = R.get_recorder().id
@JessieJie1998 JessieJie1998 added the question Further information is requested label Mar 9, 2021
@Derek-Wds
Copy link
Contributor

Hi @JessieJie1998 , this error is caused by the fact that pytorch loader gives the model a batch (the last batch) with only one data sample, which results in error when calculating the loss. You could either fix the bug by adding drop_last=True argument to the train_loader and valid_loader in the file pytorch_lstm_ts.py, or you could wait for the merge of the related PR.

Hope it helps!

@JessieJie1998
Copy link
Author

Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
question Further information is requested
Projects
None yet
Development

No branches or pull requests

2 participants