Skip to content

Commit

Permalink
Merge pull request #322 from Derek-Wds/bug
Browse files Browse the repository at this point in the history
Fix pytorch ts model loader bug
  • Loading branch information
you-n-g committed Mar 10, 2021
2 parents e2817ab + 119fe90 commit 0054a4d
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 8 deletions.
8 changes: 6 additions & 2 deletions qlib/contrib/model/pytorch_alstm_ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,12 @@ def fit(
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader

train_loader = DataLoader(dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs)
valid_loader = DataLoader(dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs)
train_loader = DataLoader(
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
)
valid_loader = DataLoader(
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
)

if save_path == None:
save_path = create_save_path(save_path)
Expand Down
4 changes: 2 additions & 2 deletions qlib/contrib/model/pytorch_gats_ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,8 @@ def fit(
sampler_train = DailyBatchSampler(dl_train)
sampler_valid = DailyBatchSampler(dl_valid)

train_loader = DataLoader(dl_train, sampler=sampler_train, num_workers=self.n_jobs)
valid_loader = DataLoader(dl_valid, sampler=sampler_valid, num_workers=self.n_jobs)
train_loader = DataLoader(dl_train, sampler=sampler_train, num_workers=self.n_jobs, drop_last=True)
valid_loader = DataLoader(dl_valid, sampler=sampler_valid, num_workers=self.n_jobs, drop_last=True)

if save_path == None:
save_path = create_save_path(save_path)
Expand Down
8 changes: 6 additions & 2 deletions qlib/contrib/model/pytorch_gru_ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,12 @@ def fit(
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader

train_loader = DataLoader(dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs)
valid_loader = DataLoader(dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs)
train_loader = DataLoader(
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
)
valid_loader = DataLoader(
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
)

if save_path == None:
save_path = create_save_path(save_path)
Expand Down
8 changes: 6 additions & 2 deletions qlib/contrib/model/pytorch_lstm_ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,12 @@ def fit(
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader

train_loader = DataLoader(dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs)
valid_loader = DataLoader(dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs)
train_loader = DataLoader(
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
)
valid_loader = DataLoader(
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
)

if save_path == None:
save_path = create_save_path(save_path)
Expand Down

0 comments on commit 0054a4d

Please sign in to comment.