Commit

simplify record tmp
you-n-g committed Nov 5, 2021
1 parent 4f2d6b0 commit 361b671
Showing 2 changed files with 83 additions and 63 deletions.
7 changes: 3 additions & 4 deletions qlib/contrib/workflow/record_temp.py
@@ -49,7 +49,7 @@ def generate(self, segments: Dict[Text, Any], save: bool = False):

if save:
save_name = "results-{:}.pkl".format(key)
self.recorder.save_objects(**{save_name: results})
self.save(**{save_name: results})
logger.info(
"The record '{:}' has been saved as the artifact of the Experiment {:}".format(
save_name, self.recorder.experiment_id
@@ -79,9 +79,8 @@ def generate(self):
metrics = {"MSE": mse, "RMSE": np.sqrt(mse)}
objects = {"mse.pkl": mse, "rmse.pkl": np.sqrt(mse)}
self.recorder.log_metrics(**metrics)
self.recorder.save_objects(**objects, artifact_path=self.get_path())
self.save(**objects)
logger.info("The evaluation results in SignalMseRecord is {:}".format(metrics))

def list(self):
paths = [self.get_path("mse.pkl"), self.get_path("rmse.pkl")]
return paths
return ["mse.pkl", "rmse.pkl"]
139 changes: 80 additions & 59 deletions qlib/workflow/record_temp.py
@@ -9,6 +9,9 @@
from pathlib import Path
from pprint import pprint
from typing import Union, List
from collections import defaultdict

from qlib.utils.exceptions import LoadObjectError
from ..contrib.evaluate import indicator_analysis, risk_analysis

from ..data.dataset import DatasetH
@@ -45,6 +48,16 @@ def get_path(cls, path=None):

return "/".join(names)

def save(self, **kwargs):
"""
It behaves the same as self.recorder.save_objects,
but it is an easier interface because users don't have to care about `get_path` and `artifact_path`.
"""
art_path = self.get_path()
if art_path == "":
art_path = None
self.recorder.save_objects(artifact_path=art_path, **kwargs)

def __init__(self, recorder):
self._recorder = recorder

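For readers comparing interfaces: the new `save` wrapper only fills in the record's own artifact path (mapping the empty root path to None). A rough equivalence sketch, not part of the commit, using only names that appear in the diff:

# Rough equivalence sketch of the helper added above, for any RecordTemp-like instance `rec`.
def save_like_record_temp(rec, **objects):
    art_path = rec.get_path() or None                 # the empty root path becomes None
    rec.recorder.save_objects(artifact_path=art_path, **objects)
# so rec.save(**{"pred.pkl": pred}) is shorthand for the explicit save_objects call with artifact_path.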
@@ -67,39 +80,45 @@ def generate(self, **kwargs):
"""
raise NotImplementedError(f"Please implement the `generate` method.")

def load(self, name):
def load(self, name: str, parents: bool = True):
"""
Load the stored records. Due to the fact that some problems occured when we tried to balancing a clean API
with the Python's inheritance. This method has to be used in a rather ugly way, and we will try to fix them
in the future::
sar = SigAnaRecord(recorder)
ic = sar.load(sar.get_path("ic.pkl"))
It behaves the same as self.recorder.load_object,
but it is an easier interface because users don't have to care about `get_path` and `artifact_path`.
Parameters
----------
name : str
the name of the file to be loaded.
parents : bool
Each recorder has a different `artifact_path`,
so the parent records are searched recursively for the file.
Subclasses have higher priority.
Return
------
The stored records.
"""
# try to load the saved object
obj = self.recorder.load_object(name)
return obj
try:
return self.recorder.load_object(self.get_path(name))
except LoadObjectError:
if parents:
if self.depend_cls is not None:
with class_casting(self, self.depend_cls):
return self.load(name, parents=True)

def list(self):
"""
List the supported artifacts.
Users don't have to consider self.get_path
Return
------
A list of all the supported artifacts.
"""
return []

def check(self, include_self: bool = False):
def check(self, include_self: bool = False, parents: bool = True):
"""
Check if the records are properly generated and saved.
It is useful in the following examples
@@ -110,19 +129,34 @@ def check(self, include_self: bool = False):
----------
include_self : bool
whether the file generated by self is included
parents : bool
whether to also check the parent records
Raise
------
FileExistsError: whether the records are stored properly.
FileNotFoundError: if the records are not stored properly.
"""
artifacts = set(self.recorder.list_artifacts())
if include_self:

# Some mlflow backends will not list the directory recursively.
# So we explicitly list each directory.
artifacts = {}

def _get_arts(dirn):
if dirn not in artifacts:
artifacts[dirn] = self.recorder.list_artifacts(dirn)
return artifacts[dirn]

for item in self.list():
if item not in artifacts:
raise FileExistsError(item)
if self.depend_cls is not None:
with class_casting(self, self.depend_cls):
self.check(include_self=True)
ps = self.get_path(item).split("/")
dirn, fn = "/".join(ps[:-1]), ps[-1]
if self.get_path(item) not in _get_arts(dirn):
raise FileNotFoundError
if parents:
if self.depend_cls is not None:
with class_casting(self, self.depend_cls):
self.check(include_self=True)


class SignalRecord(RecordTemp):
@@ -158,7 +192,7 @@ def generate(self, **kwargs):
pred = self.model.predict(self.dataset)
if isinstance(pred, pd.Series):
pred = pred.to_frame("score")
self.recorder.save_objects(**{"pred.pkl": pred})
self.save(**{"pred.pkl": pred})

logger.info(
f"Signal record 'pred.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
@@ -169,15 +203,11 @@ def generate(self, **kwargs):

if isinstance(self.dataset, DatasetH):
raw_label = self.generate_label(self.dataset)
self.recorder.save_objects(**{"label.pkl": raw_label})
self.save(**{"label.pkl": raw_label})

@staticmethod
def list():
def list(self):
return ["pred.pkl", "label.pkl"]

def load(self, name="pred.pkl"):
return super().load(name)


class HFSignalRecord(SignalRecord):
"""
@@ -218,19 +248,11 @@ def generate(self):
}
)
self.recorder.log_metrics(**metrics)
self.recorder.save_objects(**objects, artifact_path=self.get_path())
self.save(**objects)
pprint(metrics)

def list(self):
paths = [
self.get_path("ic.pkl"),
self.get_path("ric.pkl"),
self.get_path("long_pre.pkl"),
self.get_path("short_pre.pkl"),
self.get_path("long_short_r.pkl"),
self.get_path("long_avg_r.pkl"),
]
return paths
return ["ic.pkl", "ric.pkl", "long_pre.pkl", "short_pre.pkl", "long_short_r.pkl", "long_avg_r.pkl"]


class SigAnaRecord(RecordTemp):
@@ -241,13 +263,23 @@ class SigAnaRecord(RecordTemp):
artifact_path = "sig_analysis"
depend_cls = SignalRecord

def __init__(self, recorder, ana_long_short=False, ann_scaler=252, label_col=0):
def __init__(self, recorder, ana_long_short=False, ann_scaler=252, label_col=0, skip_existing=False):
super().__init__(recorder=recorder)
self.ana_long_short = ana_long_short
self.ann_scaler = ann_scaler
self.label_col = label_col
self.skip_existing = skip_existing

def generate(self, **kwargs):
if self.skip_existing:
try:
self.check(include_self=True, parents=False)
except FileNotFoundError:
pass  # continue to generate the metrics
else:
logger.info("The results has previously generated, generation skipped.")
return

self.check()

pred = self.load("pred.pkl")
@@ -280,13 +312,13 @@ def generate(self, **kwargs):
}
)
self.recorder.log_metrics(**metrics)
self.recorder.save_objects(**objects, artifact_path=self.get_path())
self.save(**objects)
pprint(metrics)

def list(self):
paths = [self.get_path("ic.pkl"), self.get_path("ric.pkl")]
paths = ["ic.pkl", "ric.pkl"]
if self.ana_long_short:
paths.extend([self.get_path("long_short_r.pkl"), self.get_path("long_avg_r.pkl")])
paths.extend(["long_short_r.pkl", "long_avg_r.pkl"])
return paths


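Taken together, `skip_existing`, `check(parents=...)` and `load(parents=...)` make SigAnaRecord idempotent: it can detect its own previously saved artifacts and still read the parent SignalRecord's `pred.pkl` without any path juggling. A hedged usage sketch (the experiment name is a placeholder, and `model`/`dataset` are assumed to be prepared as in the standard qlib workflow; none of this is part of the commit):

from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, SigAnaRecord

# `model` and `dataset` are assumed to be constructed (and the model fitted) beforehand.
with R.start(experiment_name="example_exp"):
    recorder = R.get_recorder()
    SignalRecord(model, dataset, recorder).generate()   # writes pred.pkl / label.pkl at the artifact root

    sar = SigAnaRecord(recorder, skip_existing=True)
    sar.generate()               # saves ic.pkl / ric.pkl under the record's artifact_path
    sar.generate()               # second call sees them via check() and returns early

    ic = sar.load("ic.pkl")      # found under SigAnaRecord's own path
    pred = sar.load("pred.pkl")  # not there, so load() falls back to the parent SignalRecord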
@@ -373,17 +405,11 @@ def generate(self, **kwargs):
executor=self.executor_config, strategy=self.strategy_config, **self.backtest_config
)
for _freq, (report_normal, positions_normal) in portfolio_metric_dict.items():
self.recorder.save_objects(
**{f"report_normal_{_freq}.pkl": report_normal}, artifact_path=PortAnaRecord.get_path()
)
self.recorder.save_objects(
**{f"positions_normal_{_freq}.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path()
)
self.save(**{f"report_normal_{_freq}.pkl": report_normal})
self.save(**{f"positions_normal_{_freq}.pkl": positions_normal})

for _freq, indicators_normal in indicator_dict.items():
self.recorder.save_objects(
**{f"indicators_normal_{_freq}.pkl": indicators_normal}, artifact_path=PortAnaRecord.get_path()
)
self.save(**{f"indicators_normal_{_freq}.pkl": indicators_normal})

for _analysis_freq in self.risk_analysis_freq:
if _analysis_freq not in portfolio_metric_dict:
@@ -405,9 +431,7 @@ def generate(self, **kwargs):
analysis_dict = flatten_dict(analysis_df["risk"].unstack().T.to_dict())
self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
# save results
self.recorder.save_objects(
**{f"port_analysis_{_analysis_freq}.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path()
)
self.save(**{f"port_analysis_{_analysis_freq}.pkl": analysis_df})
logger.info(
f"Portfolio analysis record 'port_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
)
@@ -432,9 +456,7 @@ def generate(self, **kwargs):
analysis_dict = analysis_df["value"].to_dict()
self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
# save results
self.recorder.save_objects(
**{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path()
)
self.save(**{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df})
logger.info(
f"Indicator analysis record 'indicator_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
)
@@ -446,20 +468,19 @@ def list(self):
for _freq in self.all_freq:
list_path.extend(
[
PortAnaRecord.get_path(f"report_normal_{_freq}.pkl"),
PortAnaRecord.get_path(f"positions_normal_{_freq}.pkl"),
f"report_normal_{_freq}.pkl",
f"positions_normal_{_freq}.pkl",
]
)
for _analysis_freq in self.risk_analysis_freq:
if _analysis_freq in self.all_freq:
list_path.append(PortAnaRecord.get_path(f"port_analysis_{_analysis_freq}.pkl"))
list_path.append(f"port_analysis_{_analysis_freq}.pkl")
else:
warnings.warn(f"risk_analysis freq {_analysis_freq} is not found")

for _analysis_freq in self.indicator_analysis_freq:
if _analysis_freq in self.all_freq:
list_path.append(PortAnaRecord.get_path(f"indicator_analysis_{_analysis_freq}.pkl"))
list_path.append(f"indicator_analysis_{_analysis_freq}.pkl")
else:
warnings.warn(f"indicator_analysis freq {_analysis_freq} is not found")

return list_path

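Because every `list()` now returns bare names, the full artifact location is always `get_path(name)`: the class-level `artifact_path` joined with the file name, which is exactly what the new `check()` splits into a directory and a file before listing. A small illustrative sketch of that resolution (mirroring, not reproducing, the `_get_arts` logic above):

from qlib.workflow.record_temp import SigAnaRecord

def split_artifact(record_cls, name):
    """Return (directory, file name) for an artifact name reported by record_cls.list()."""
    full = record_cls.get_path(name)      # e.g. "sig_analysis/ic.pkl" for SigAnaRecord
    parts = full.split("/")
    return "/".join(parts[:-1]), parts[-1]

print(split_artifact(SigAnaRecord, "ic.pkl"))   # expected: ("sig_analysis", "ic.pkl")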