# Running Stock Predictor on Google Colab

* Clone the repo and install pyqlib.
* Restart the runtime after the install, then run the remaining cells.

In [1]:
# Clone the project repo. Skip the clone if the directory already exists, so the cell
# can be re-run (e.g. after the runtime restart required by the pyqlib install).
![ -d StockPredictor ] || git clone https://github.com/jingedawang/StockPredictor.git

Cloning into 'StockPredictor'...
remote: Enumerating objects: 90, done.[K
remote: Counting objects: 100% (90/90), done.[K
remote: Compressing objects: 100% (71/71), done.[K
remote: Total 90 (delta 26), reused 63 (delta 15), pack-reused 0[K
Unpacking objects: 100% (90/90), done.


In [2]:
# Install qlib into the kernel's own environment (%pip, unlike !pip, targets the kernel's Python).
# The recorded run resolved pyqlib 0.8.6.
# NOTE(review): consider pinning (e.g. pyqlib==0.8.6) for reproducibility, once the target
# Colab Python version is confirmed to have a matching wheel.
%pip install pyqlib

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pyqlib
  Downloading pyqlib-0.8.6-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (954 kB)
[K     |████████████████████████████████| 954 kB 7.8 MB/s 
[?25hCollecting matplotlib>=3.3
  Downloading matplotlib-3.5.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (11.2 MB)
[K     |████████████████████████████████| 11.2 MB 39.5 MB/s 
Collecting sacred>=0.7.4
  Downloading sacred-0.8.2-py2.py3-none-any.whl (106 kB)
[K     |████████████████████████████████| 106 kB 43.4 MB/s 
Collecting redis>=3.0.1
  Downloading redis-4.3.4-py3-none-any.whl (246 kB)
[K     |████████████████████████████████| 246 kB 34.3 MB/s 
[?25hCollecting ruamel.yaml>=0.16.12
  Downloading ruamel.yaml-0.17.21-py3-none-any.whl (109 kB)
[K     |████████████████████████████████| 109 kB 47.4 MB/s 
[?25hCollecting lightgbm>=3.3.0
  Downloading lightgbm-3.3.2

In [3]:
import qlib
from qlib.constant import REG_CN
from qlib.data.dataset import DatasetH
from qlib.utils import init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.tests.data import GetData
from qlib.tests.config import CSI300_GBDT_TASK

from StockPredictor.algorithm.stock_predictor.data_handler import Alpha158TwoWeeks

import pickle

In [4]:
# Fetch qlib's default CN daily dataset into target_dir, then initialize qlib against it.
# exists_skip=True skips the download when data is already present at target_dir.
provider_uri = "~/.qlib/qlib_data/cn_data"  # target_dir
GetData().qlib_data(
    target_dir=provider_uri,
    region=REG_CN,
    exists_skip=True,
)
qlib.init(provider_uri=provider_uri, region=REG_CN)

2022-08-31 22:05:20.827 | INFO     | qlib.tests.data:_download_data:59 - qlib_data_cn_1d_latest.zip downloading......
196549632it [00:11, 17131994.25it/s]                               
2022-08-31 22:05:32.320 | INFO     | qlib.tests.data:_unzip:85 - /root/.qlib/qlib_data/cn_data/20220831220520_qlib_data_cn_1d_latest.zip unzipping......
100%|██████████| 31008/31008 [00:11<00:00, 2620.56it/s]
[56:MainThread](2022-08-31 22:05:44,651) INFO - qlib.Initialization - [config.py:413] - default_conf: client.
INFO:qlib.Initialization:default_conf: client.
[56:MainThread](2022-08-31 22:05:44,669) INFO - qlib.Initialization - [__init__.py:74] - qlib successfully initialized based on client settings.
INFO:qlib.Initialization:qlib successfully initialized based on client settings.
[56:MainThread](2022-08-31 22:05:44,680) INFO - qlib.Initialization - [__init__.py:76] - data_path={'__DEFAULT_FREQ': PosixPath('/root/.qlib/qlib_data/cn_data')}
INFO:qlib.Initialization:data_path={'__DEFAULT_FREQ': PosixP

In [5]:
# Load data with our customized data handler.
# Alpha158TwoWeeks differs from Alpha158 only in its labels.
# NOTE: in the recorded run, data loading alone took ~427 s, so this cell is slow.
# TODO: Data is important for model training; try other adjustments to the data handler
#       to achieve better results.
data_handler = Alpha158TwoWeeks(instruments="csi300")
dataset = DatasetH(
    handler=data_handler,
    segments={
        "train": ["2008-01-01", "2014-12-31"],
        "valid": ["2015-01-01", "2016-12-31"],
        "test": ["2017-01-01", "2020-08-01"],
    },
)

[56:MainThread](2022-08-31 22:12:52,094) INFO - qlib.timer - [log.py:117] - Time cost: 427.390s | Loading data Done
INFO:qlib.timer:Time cost: 427.390s | Loading data Done
[56:MainThread](2022-08-31 22:12:53,014) INFO - qlib.timer - [log.py:117] - Time cost: 0.282s | DropnaLabel Done
INFO:qlib.timer:Time cost: 0.282s | DropnaLabel Done
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self[k1] = value[k2]
[56:MainThread](2022-08-31 22:13:05,252) INFO - qlib.timer - [log.py:117] - Time cost: 12.227s | CSZScoreNorm Done
INFO:qlib.timer:Time cost: 12.227s | CSZScoreNorm Done
[56:MainThread](2022-08-31 22:13:05,265) INFO - qlib.timer - [log.py:117] - Time cost: 13.156s | fit & process data Done
INFO:qlib.timer:Time cost: 13.156s | fit & process data Done
[56:MainThread](2022-08-

In [6]:
# Use the GBDT model defined by qlib's CSI300_GBDT_TASK config.
# The "Please install necessary libs for CatBoostModel." message printed here is an
# import-time notice and does not affect the GBDT model.
# TODO: Model architecture is also important. Try different models to achieve better results.
model = init_instance_by_config(CSI300_GBDT_TASK["model"])

Please install necessary libs for CatBoostModel.


In [7]:
# NOTE: This cell is optional.
# Show the first rows of the prepared training data to make sure we are using the correct
# data for training. Leaving the frame as the last expression renders it as a rich table.
example_df = dataset.prepare("train")
example_df.head()

                           KMID      KLEN     KMID2       KUP      KUP2  \
datetime   instrument                                                     
2008-01-02 SH600000    0.010374  0.061129  0.169699  0.028299  0.462937   
           SH600004    0.057280  0.059661  0.960094  0.002381  0.039906   
           SH600006    0.012673  0.040323  0.314283  0.008065  0.200000   
           SH600007    0.066977  0.084186  0.795580  0.007907  0.093923   
           SH600008    0.051163  0.082326  0.621469  0.027907  0.338982   

                           KLOW     KLOW2      KSFT     KSFT2     OPEN0  ...  \
datetime   instrument                                                    ...   
2008-01-02 SH600000    0.022457  0.367364  0.004531  0.074127  0.989733  ...   
           SH600004    0.000000  0.000000  0.054899  0.920187  0.945823  ...   
           SH600006    0.019585  0.485716  0.024193  0.599999  0.987486  ...   
           SH600007    0.009302  0.110497  0.068372  0.812154  0.937227  .

In [17]:
# Run one recorded experiment: log the task config, train, persist the model, then predict.
with R.start(experiment_name="workflow"):
    task_params = flatten_dict(CSI300_GBDT_TASK)
    R.log_params(**task_params)

    model.fit(dataset)
    R.save_objects(**{"params.pkl": model})

    # The recorded output spans 2017-01-03..2020-07-31, i.e. the dataset's "test" segment.
    pred = model.predict(dataset)
    print('pred', pred)

# TODO: We need to run a backtest to evaluate our model.

[56:MainThread](2022-08-31 22:24:34,978) INFO - qlib.workflow - [expm.py:315] - <mlflow.tracking.client.MlflowClient object at 0x7f238a748210>
INFO:qlib.workflow:<mlflow.tracking.client.MlflowClient object at 0x7f238a748210>
[56:MainThread](2022-08-31 22:24:35,007) INFO - qlib.workflow - [exp.py:257] - Experiment 1 starts running ...
INFO:qlib.workflow:Experiment 1 starts running ...
[56:MainThread](2022-08-31 22:24:35,192) INFO - qlib.workflow - [recorder.py:293] - Recorder efa352aa654f432b93ed431e9cfd65ee starts running under Experiment 1 ...
INFO:qlib.workflow:Recorder efa352aa654f432b93ed431e9cfd65ee starts running under Experiment 1 ...


Training until validation scores don't improve for 50 rounds
[20]	train's l2: 0.985135	valid's l2: 0.993481
[40]	train's l2: 0.977967	valid's l2: 0.993065
[60]	train's l2: 0.97287	valid's l2: 0.992971
[80]	train's l2: 0.96872	valid's l2: 0.992893
Early stopping, best iteration is:
[49]	train's l2: 0.975708	valid's l2: 0.992768


[56:MainThread](2022-08-31 22:25:48,302) INFO - qlib.timer - [log.py:117] - Time cost: 0.000s | waiting `async_log` Done
INFO:qlib.timer:Time cost: 0.000s | waiting `async_log` Done


pred datetime    instrument
2017-01-03  SH600000      0.002296
            SH600008     -0.026040
            SH600009      0.026351
            SH600010     -0.044511
            SH600015      0.034829
                            ...   
2020-07-31  SZ300413      0.012890
            SZ300433     -0.034948
            SZ300498     -0.033088
            SZ300601     -0.120866
            SZ300628      0.041037
Length: 261300, dtype: float64


In [18]:
# Peek at the first few predictions.
pred.head(n=5)

datetime    instrument
2017-01-03  SH600000      0.002296
            SH600008     -0.026040
            SH600009      0.026351
            SH600010     -0.044511
            SH600015      0.034829
dtype: float64