
Merge pull request #329 from D-X-Y/main
Fix Various Bugs for contrib.pytorch_ models
you-n-g committed Mar 12, 2021
2 parents df56e3b + 1d43524 commit 0cffb87
Showing 15 changed files with 119 additions and 99 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -36,3 +36,5 @@ tags
 .vscode/
 
 *.swp
+
+./pretrain
2 changes: 1 addition & 1 deletion examples/benchmarks/README.md
@@ -17,6 +17,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
 | ALSTM (Yao Qin, et al.) | Alpha360 | 0.0493±0.01 | 0.3778±0.06| 0.0585±0.00 | 0.4606±0.04 | 0.0513±0.03 | 0.6727±0.38| -0.1085±0.02 |
 | GATs (Petar Velickovic, et al.) | Alpha360 | 0.0475±0.00 | 0.3515±0.02| 0.0592±0.00 | 0.4585±0.01 | 0.0876±0.02 | 1.1513±0.27| -0.0795±0.02 |
 | DoubleEnsemble (Chuheng Zhang, et al.) | Alpha360 | 0.0407±0.00| 0.3053±0.00 | 0.0490±0.00 | 0.3840±0.00 | 0.0380±0.02 | 0.5000±0.21 | -0.0984±0.02 |
+
 ## Alpha158 dataset
 | Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
 |---|---|---|---|---|---|---|---|---|
@@ -25,7 +26,6 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
 | XGBoost (Tianqi Chen, et al.) | Alpha158 | 0.0481±0.00 | 0.3659±0.00| 0.0495±0.00 | 0.4033±0.00 | 0.1111±0.00 | 1.2915±0.00| -0.0893±0.00 |
 | LightGBM (Guolin Ke, et al.) | Alpha158 | 0.0475±0.00 | 0.3979±0.00| 0.0485±0.00 | 0.4123±0.00 | 0.1143±0.00 | 1.2744±0.00| -0.0800±0.00 |
 | MLP | Alpha158 | 0.0358±0.00 | 0.2738±0.03| 0.0425±0.00 | 0.3221±0.01 | 0.0836±0.02 | 1.0323±0.25| -0.1127±0.02 |
-| TabNet with pretrain (Sercan O. Arikm et al) | Alpha158 | 0.0344±0.00|0.205±0.11|0.0398±0.00 |0.3479±0.01|0.0827±0.02|1.1141±0.32 |-0.0925±0.02 |
 | TFT (Bryan Lim, et al.) | Alpha158 (with selected 20 features) | 0.0343±0.00 | 0.2071±0.02| 0.0107±0.00 | 0.0660±0.02 | 0.0623±0.02 | 0.5818±0.20| -0.1762±0.01 |
 | GRU (Kyunghyun Cho, et al.) | Alpha158 (with selected 20 features) | 0.0311±0.00 | 0.2418±0.04| 0.0425±0.00 | 0.3434±0.02 | 0.0330±0.02 | 0.4805±0.30| -0.1021±0.02 |
 | LSTM (Sepp Hochreiter, et al.) | Alpha158 (with selected 20 features) | 0.0312±0.00 | 0.2394±0.04| 0.0418±0.00 | 0.3324±0.03 | 0.0298±0.02 | 0.4198±0.33| -0.1348±0.03 |
Binary file removed examples/benchmarks/TabNet/pretrain/best.model
Binary file not shown.
@@ -55,7 +55,7 @@ task:
                 kwargs: *data_handler_config
             segments:
                 pretrain: [2008-01-01, 2014-12-31]
-                pretrain_validation: [2015-01-01, 2020-08-01]
+                pretrain_validation: [2015-01-01, 2016-12-31]
                 train: [2008-01-01, 2014-12-31]
                 valid: [2015-01-01, 2016-12-31]
                 test: [2017-01-01, 2020-08-01]
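The point of this one-line change in the workflow config above (the TabNet benchmark config, judging by the pretrain segments): the old pretrain_validation window, [2015-01-01, 2020-08-01], extended into the test segment [2017-01-01, 2020-08-01], so the pretraining stage was validated on test-period data. The corrected window matches the ordinary valid segment. A minimal standalone sketch (not part of the commit) that checks segment ranges for this kind of leakage:

import datetime as dt

segments = {
    "pretrain": ("2008-01-01", "2014-12-31"),
    "pretrain_validation": ("2015-01-01", "2016-12-31"),  # previously ended 2020-08-01
    "test": ("2017-01-01", "2020-08-01"),
}

def overlaps(a, b):
    # Two closed date ranges intersect iff each starts before the other ends.
    parse = dt.date.fromisoformat
    return parse(a[0]) <= parse(b[1]) and parse(b[0]) <= parse(a[1])

assert not overlaps(segments["pretrain_validation"], segments["test"])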
25 changes: 13 additions & 12 deletions qlib/contrib/model/pytorch_alstm.py
@@ -78,7 +78,6 @@ def __init__(
         self.optimizer = optimizer.lower()
         self.loss = loss
         self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
-        self.use_gpu = torch.cuda.is_available()
         self.seed = seed
 
         self.logger.info(
@@ -94,7 +93,7 @@ def __init__(
             "\nearly_stop : {}"
             "\noptimizer : {}"
             "\nloss_type : {}"
-            "\nvisible_GPU : {}"
+            "\ndevice : {}"
             "\nuse_GPU : {}"
             "\nseed : {}".format(
                 d_feat,
@@ -108,7 +107,7 @@ def __init__(
                 early_stop,
                 optimizer.lower(),
                 loss,
-                GPU,
+                self.device,
                 self.use_gpu,
                 seed,
             )
@@ -137,6 +136,10 @@ def __init__(
         self.fitted = False
         self.ALSTM_model.to(self.device)
 
+    @property
+    def use_gpu(self):
+        return self.device != torch.device("cpu")
+
     def mse(self, pred, label):
         loss = (pred - label) ** 2
         return torch.mean(loss)
@@ -205,12 +208,13 @@ def test_epoch(self, data_x, data_y):
             feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
             label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
 
-            pred = self.ALSTM_model(feature)
-            loss = self.loss_fn(pred, label)
-            losses.append(loss.item())
+            with torch.no_grad():
+                pred = self.ALSTM_model(feature)
+                loss = self.loss_fn(pred, label)
+                losses.append(loss.item())
 
-            score = self.metric_fn(pred, label)
-            scores.append(score.item())
+                score = self.metric_fn(pred, label)
+                scores.append(score.item())
 
         return np.mean(losses), np.mean(scores)

@@ -292,10 +296,7 @@ def predict(self, dataset):
             x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
 
             with torch.no_grad():
-                if self.use_gpu:
-                    pred = self.ALSTM_model(x_batch).detach().cpu().numpy()
-                else:
-                    pred = self.ALSTM_model(x_batch).detach().numpy()
+                pred = self.ALSTM_model(x_batch).detach().cpu().numpy()
 
             preds.append(pred)

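The recurring bug fixed in this file (and echoed in the five below): self.use_gpu was cached as torch.cuda.is_available(), which reports True on any CUDA-capable machine even when the model was deliberately placed on CPU (GPU < 0), sending predict down the wrong branch. Deriving the flag from self.device, as the diff does, keeps the two in sync by construction. A minimal sketch of the pattern in isolation (DeviceAware is a hypothetical stand-in, not a qlib class):

import torch

class DeviceAware:
    # Hypothetical stand-in for the qlib model classes patched above.
    def __init__(self, GPU=0):
        self.device = torch.device(
            "cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu"
        )

    @property
    def use_gpu(self):
        # Derived from the device actually in use, not from global availability.
        return self.device != torch.device("cpu")

m = DeviceAware(GPU=-1)
assert not m.use_gpu  # stays False on CPU, even on a CUDA machine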
25 changes: 13 additions & 12 deletions qlib/contrib/model/pytorch_alstm_ts.py
@@ -81,7 +81,6 @@ def __init__(
         self.loss = loss
         self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
         self.n_jobs = n_jobs
-        self.use_gpu = torch.cuda.is_available()
         self.seed = seed
 
         self.logger.info(
@@ -97,7 +96,7 @@ def __init__(
             "\nearly_stop : {}"
             "\noptimizer : {}"
             "\nloss_type : {}"
-            "\nvisible_GPU : {}"
+            "\ndevice : {}"
             "\nn_jobs : {}"
             "\nuse_GPU : {}"
             "\nseed : {}".format(
@@ -112,7 +111,7 @@ def __init__(
                 early_stop,
                 optimizer.lower(),
                 loss,
-                GPU,
+                self.device,
                 n_jobs,
                 self.use_gpu,
                 seed,
@@ -142,6 +141,10 @@ def __init__(
         self.fitted = False
         self.ALSTM_model.to(self.device)
 
+    @property
+    def use_gpu(self):
+        return self.device != torch.device("cpu")
+
     def mse(self, pred, label):
         loss = (pred - label) ** 2
         return torch.mean(loss)
@@ -192,12 +195,13 @@ def test_epoch(self, data_loader):
             # feature[torch.isnan(feature)] = 0
             label = data[:, -1, -1].to(self.device)
 
-            pred = self.ALSTM_model(feature.float())
-            loss = self.loss_fn(pred, label)
-            losses.append(loss.item())
+            with torch.no_grad():
+                pred = self.ALSTM_model(feature.float())
+                loss = self.loss_fn(pred, label)
+                losses.append(loss.item())
 
-            score = self.metric_fn(pred, label)
-            scores.append(score.item())
+                score = self.metric_fn(pred, label)
+                scores.append(score.item())
 
         return np.mean(losses), np.mean(scores)

@@ -277,10 +281,7 @@ def predict(self, dataset):
             feature = data[:, :, 0:-1].to(self.device)
 
             with torch.no_grad():
-                if self.use_gpu:
-                    pred = self.ALSTM_model(feature.float()).detach().cpu().numpy()
-                else:
-                    pred = self.ALSTM_model(feature.float()).detach().numpy()
+                pred = self.ALSTM_model(feature.float()).detach().cpu().numpy()
 
             preds.append(pred)

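The test_epoch change above addresses a separate waste: the evaluation forward pass ran with autograd enabled, so every batch built a computation graph that was never backpropagated, costing memory and time. Wrapping the pass in torch.no_grad() disables graph construction. A toy sketch of the pattern (a plain Linear layer stands in for the ALSTM model; shapes are illustrative):

import torch
import torch.nn as nn

model = nn.Linear(8, 1)      # toy stand-in for the ALSTM model
feature = torch.randn(32, 8)
label = torch.randn(32, 1)

model.eval()
with torch.no_grad():        # no autograd graph is recorded here
    pred = model(feature)
    loss = nn.functional.mse_loss(pred, label)

assert not pred.requires_grad  # evaluation outputs carry no graph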
13 changes: 7 additions & 6 deletions qlib/contrib/model/pytorch_gats.py
@@ -103,7 +103,7 @@ def __init__(
             "\nbase_model : {}"
             "\nwith_pretrain : {}"
             "\nmodel_path : {}"
-            "\nvisible_GPU : {}"
+            "\ndevice : {}"
             "\nuse_GPU : {}"
             "\nseed : {}".format(
                 d_feat,
@@ -119,7 +119,7 @@ def __init__(
                 base_model,
                 with_pretrain,
                 model_path,
-                GPU,
+                self.device,
                 self.use_gpu,
                 seed,
             )
@@ -149,6 +149,10 @@ def __init__(
         self.fitted = False
         self.GAT_model.to(self.device)
 
+    @property
+    def use_gpu(self):
+        return self.device != torch.device("cpu")
+
     def mse(self, pred, label):
         loss = (pred - label) ** 2
         return torch.mean(loss)
@@ -326,10 +330,7 @@ def predict(self, dataset):
             x_batch = torch.from_numpy(x_values[batch]).float().to(self.device)
 
             with torch.no_grad():
-                if self.use_gpu:
-                    pred = self.GAT_model(x_batch).detach().cpu().numpy()
-                else:
-                    pred = self.GAT_model(x_batch).detach().numpy()
+                pred = self.GAT_model(x_batch).detach().cpu().numpy()
 
             preds.append(pred)

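The predict simplification repeated across these files relies on a documented PyTorch behavior: Tensor.cpu() returns the original tensor unchanged when it already lives in CPU memory, so the unconditional .detach().cpu().numpy() chain is safe on both devices and the if self.use_gpu: branch was never needed. A quick illustration with a throwaway tensor (not from the repository):

import torch

t = torch.randn(4)                # a CPU tensor
assert t.cpu() is t               # .cpu() is a no-op for CPU tensors
out = t.detach().cpu().numpy()    # identical chain works on CPU and GPU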
10 changes: 5 additions & 5 deletions qlib/contrib/model/pytorch_gats_ts.py
@@ -107,7 +107,6 @@ def __init__(
         self.model_path = model_path
         self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
         self.n_jobs = n_jobs
-        self.use_gpu = torch.cuda.is_available()
         self.seed = seed
 
         self.logger.info(
@@ -171,6 +170,10 @@ def __init__(
         self.fitted = False
         self.GAT_model.to(self.device)
 
+    @property
+    def use_gpu(self):
+        return self.device != torch.device("cpu")
+
     def mse(self, pred, label):
         loss = (pred - label) ** 2
         return torch.mean(loss)
@@ -347,10 +350,7 @@ def predict(self, dataset):
             feature = data[:, :, 0:-1].to(self.device)
 
             with torch.no_grad():
-                if self.use_gpu:
-                    pred = self.GAT_model(feature.float()).detach().cpu().numpy()
-                else:
-                    pred = self.GAT_model(feature.float()).detach().numpy()
+                pred = self.GAT_model(feature.float()).detach().cpu().numpy()
 
             preds.append(pred)

21 changes: 11 additions & 10 deletions qlib/contrib/model/pytorch_gru.py
@@ -78,7 +78,6 @@ def __init__(
         self.optimizer = optimizer.lower()
         self.loss = loss
         self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
-        self.use_gpu = torch.cuda.is_available()
         self.seed = seed
 
         self.logger.info(
@@ -137,6 +136,10 @@ def __init__(
         self.fitted = False
         self.gru_model.to(self.device)
 
+    @property
+    def use_gpu(self):
+        return self.device != torch.device("cpu")
+
     def mse(self, pred, label):
         loss = (pred - label) ** 2
         return torch.mean(loss)
@@ -205,12 +208,13 @@ def test_epoch(self, data_x, data_y):
             feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
             label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
 
-            pred = self.gru_model(feature)
-            loss = self.loss_fn(pred, label)
-            losses.append(loss.item())
+            with torch.no_grad():
+                pred = self.gru_model(feature)
+                loss = self.loss_fn(pred, label)
+                losses.append(loss.item())
 
-            score = self.metric_fn(pred, label)
-            scores.append(score.item())
+                score = self.metric_fn(pred, label)
+                scores.append(score.item())
 
         return np.mean(losses), np.mean(scores)

@@ -292,10 +296,7 @@ def predict(self, dataset):
             x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
 
             with torch.no_grad():
-                if self.use_gpu:
-                    pred = self.gru_model(x_batch).detach().cpu().numpy()
-                else:
-                    pred = self.gru_model(x_batch).detach().numpy()
+                pred = self.gru_model(x_batch).detach().cpu().numpy()
 
             preds.append(pred)

25 changes: 13 additions & 12 deletions qlib/contrib/model/pytorch_gru_ts.py
@@ -81,7 +81,6 @@ def __init__(
         self.loss = loss
         self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
         self.n_jobs = n_jobs
-        self.use_gpu = torch.cuda.is_available()
         self.seed = seed
 
         self.logger.info(
@@ -97,7 +96,7 @@ def __init__(
             "\nearly_stop : {}"
             "\noptimizer : {}"
             "\nloss_type : {}"
-            "\nvisible_GPU : {}"
+            "\ndevice : {}"
             "\nn_jobs : {}"
             "\nuse_GPU : {}"
             "\nseed : {}".format(
@@ -112,7 +111,7 @@ def __init__(
                 early_stop,
                 optimizer.lower(),
                 loss,
-                GPU,
+                self.device,
                 n_jobs,
                 self.use_gpu,
                 seed,
@@ -142,6 +141,10 @@ def __init__(
         self.fitted = False
         self.GRU_model.to(self.device)
 
+    @property
+    def use_gpu(self):
+        return self.device != torch.device("cpu")
+
     def mse(self, pred, label):
         loss = (pred - label) ** 2
         return torch.mean(loss)
@@ -192,12 +195,13 @@ def test_epoch(self, data_loader):
             # feature[torch.isnan(feature)] = 0
             label = data[:, -1, -1].to(self.device)
 
-            pred = self.GRU_model(feature.float())
-            loss = self.loss_fn(pred, label)
-            losses.append(loss.item())
+            with torch.no_grad():
+                pred = self.GRU_model(feature.float())
+                loss = self.loss_fn(pred, label)
+                losses.append(loss.item())
 
-            score = self.metric_fn(pred, label)
-            scores.append(score.item())
+                score = self.metric_fn(pred, label)
+                scores.append(score.item())
 
         return np.mean(losses), np.mean(scores)

@@ -277,10 +281,7 @@ def predict(self, dataset):
             feature = data[:, :, 0:-1].to(self.device)
 
             with torch.no_grad():
-                if self.use_gpu:
-                    pred = self.GRU_model(feature.float()).detach().cpu().numpy()
-                else:
-                    pred = self.GRU_model(feature.float()).detach().numpy()
+                pred = self.GRU_model(feature.float()).detach().cpu().numpy()
 
             preds.append(pred)

10 changes: 5 additions & 5 deletions qlib/contrib/model/pytorch_lstm.py
@@ -77,7 +77,6 @@ def __init__(
         self.optimizer = optimizer.lower()
         self.loss = loss
         self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
-        self.use_gpu = torch.cuda.is_available()
         self.seed = seed
 
         self.logger.info(
@@ -133,6 +132,10 @@ def __init__(
         self.fitted = False
         self.lstm_model.to(self.device)
 
+    @property
+    def use_gpu(self):
+        return self.device != torch.device("cpu")
+
     def mse(self, pred, label):
         loss = (pred - label) ** 2
         return torch.mean(loss)
@@ -288,10 +291,7 @@ def predict(self, dataset):
             x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
 
             with torch.no_grad():
-                if self.use_gpu:
-                    pred = self.lstm_model(x_batch).detach().cpu().numpy()
-                else:
-                    pred = self.lstm_model(x_batch).detach().numpy()
+                pred = self.lstm_model(x_batch).detach().cpu().numpy()
 
             preds.append(pred)

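Each of the six patched models repeats the same device-selection expression in __init__. The commit leaves that expression alone, but since every other device-related line now derives from it, a factored-out helper shows the whole convention in one place (pick_device is a hypothetical sketch, not a qlib function):

import torch

def pick_device(GPU: int) -> torch.device:
    # A non-negative GPU index plus available CUDA selects that device; otherwise CPU.
    if torch.cuda.is_available() and GPU >= 0:
        return torch.device("cuda:%d" % GPU)
    return torch.device("cpu")

device = pick_device(0)
use_gpu = device != torch.device("cpu")  # mirrors the property added in each model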
