Appreciate your splendid work!
However, the MLP and the `forward` function in DRPO do not seem to be implemented correctly.
In `model/drpo.py`, the MLP is defined as:
```python
self.MLP = torch.nn.Sequential(
    torch.nn.Linear(self.hidden_size * self.stock_num,
                    self.hidden_size * 7, bias=self.bias),
    torch.nn.SiLU(),
    torch.nn.Dropout(self.dropout),
    torch.nn.Linear(self.hidden_size * 7, 128, bias=self.bias),
    torch.nn.SiLU(),
    torch.nn.Dropout(self.dropout),
    torch.nn.Linear(128, 64, bias=self.bias),
    torch.nn.SiLU(),
    torch.nn.Dropout(self.dropout),
    torch.nn.Linear(64, 1, bias=self.bias),
)
```
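As a quick standalone check (hypothetical values for `hidden_size` and `stock_num`, dropout layers omitted for brevity), this head maps the whole flattened hidden state to a single scalar per sample:

```python
import torch

hidden_size, stock_num = 64, 30  # hypothetical; the repo's actual values may differ
mlp = torch.nn.Sequential(
    torch.nn.Linear(hidden_size * stock_num, hidden_size * 7),
    torch.nn.SiLU(),
    torch.nn.Linear(hidden_size * 7, 128),
    torch.nn.SiLU(),
    torch.nn.Linear(128, 64),
    torch.nn.SiLU(),
    torch.nn.Linear(64, 1),
)
x = torch.randn(8, hidden_size * stock_num)  # batch of 8 flattened LSTM outputs
print(mlp(x).shape)  # torch.Size([8, 1]) -- one scalar per sample, not one per stock
```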
The `forward` function in the same file is:

```python
# inputs = self.faltten(inputs)  # bs, 140, 1470
for i in range(ts):  # each timestep
    input = inputs[:, i:i+1, :]  # bs, 30, 49
    input = input.reshape(-1, fn)
    input_temp = torch.concat([input, obs_omega.reshape(-1, 1)], dim=1)
    input_temp = input_temp.reshape(bs * sn, 1, -1)
    output, (hx, cx) = self.LSTM(input_temp, (hx, cx))
    output = output.reshape(bs, -1)
    output = self.MLP(output)  # bs, 1
    # output = output - torch.mean(output, dim=-1, keepdim=True)
    # output = torch.sigmoid(output)
    out_time_stock.append(output)  # (batchsize, stocknum)

    ''' Step 2: Calculate Next State '''
    with torch.no_grad():
        this_state = state[:, i, :]  # (bs, 30)
        output = output - torch.mean(output, dim=-1, keepdim=True)  # not (bs, 30), actually (bs, 1)
        # output is all zero here
        next_state = this_state + output  # change to next state
        # check next hold > 0
```
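Tracing the shapes through one loop iteration (using `fn = 49` features per stock and `sn = 30` stocks, as the comments suggest):

```python
# input        : (bs*sn, fn)            after the reshape
# input_temp   : (bs*sn, 1, fn+1)       per-stock features plus the current weight obs_omega
# LSTM output  : (bs*sn, 1, hidden_size)
# after reshape: (bs, sn*hidden_size)   all stocks flattened together
# MLP output   : (bs, 1)                a single scalar per batch element
```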
The MLP output has shape (bs, 1), so demeaning over its last (singleton) dimension subtracts each value from itself and makes `output` all zeros. Hence the state never changes. What should the correct implementation be? Thanks!
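For what it is worth, a guess at what may have been intended (my sketch, not the authors' code): if the network produced one score per stock, i.e. shape `(bs, stock_num)`, the demean would become a zero-sum reallocation across stocks rather than a no-op, and the addition to `this_state` of shape `(bs, 30)` would be well defined. Under that assumption, one variant applies the MLP per stock before the reshape (which also requires changing the first `Linear` to take `hidden_size` inputs):

```python
# Hypothetical fix: keep the stock dimension so the demean is meaningful.
output = output.reshape(bs * sn, -1)   # (bs*sn, hidden_size), one row per stock
output = self.MLP(output)              # (bs*sn, 1); MLP's first Linear now takes hidden_size
output = output.reshape(bs, sn)        # (bs, stock_num), one score per stock

out_time_stock.append(output)

''' Step 2: Calculate Next State '''
with torch.no_grad():
    this_state = state[:, i, :]                                # (bs, 30)
    delta = output - torch.mean(output, dim=-1, keepdim=True)  # zero-sum across stocks
    next_state = this_state + delta                            # state now actually changes
```

Whether this matches the paper's intent is of course for the authors to confirm.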