-
Notifications
You must be signed in to change notification settings - Fork 424
/
decision_transformer.py
137 lines (108 loc) · 5.71 KB
/
decision_transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import numpy as np
import torch
import torch.nn as nn
import transformers
from decision_transformer.models.model import TrajectoryModel
from decision_transformer.models.trajectory_gpt2 import GPT2Model
class DecisionTransformer(TrajectoryModel):

    """
    This model uses GPT to model (Return_1, state_1, action_1, Return_2, state_2, ...).

    Each modality (return-to-go, state, action) is embedded with its own linear
    head, summed with a learned timestep embedding (used in place of positional
    embeddings), interleaved into one sequence of length 3*T, and fed through a
    GPT-2 backbone. Separate linear heads then decode per-step predictions.
    """

    def __init__(
            self,
            state_dim,
            act_dim,
            hidden_size,
            max_length=None,
            max_ep_len=4096,
            action_tanh=True,
            **kwargs
    ):
        """
        Args:
            state_dim: dimensionality of a state vector.
            act_dim: dimensionality of an action vector.
            hidden_size: transformer embedding width (GPT-2 ``n_embd``).
            max_length: context window (in timesteps) used by ``get_action``;
                ``None`` disables truncation/left-padding there.
            max_ep_len: maximum episode length; sizes the timestep embedding table.
            action_tanh: if True, squash predicted actions into (-1, 1) with tanh.
            **kwargs: forwarded to ``transformers.GPT2Config`` (e.g. n_layer, n_head).
        """
        super().__init__(state_dim, act_dim, max_length=max_length)

        self.hidden_size = hidden_size
        config = transformers.GPT2Config(
            vocab_size=1,  # doesn't matter -- we don't use the vocab
            n_embd=hidden_size,
            **kwargs
        )
        self.transformer = GPT2Model(config)

        # per-modality embedding heads; timesteps get a learned embedding table
        self.embed_timestep = nn.Embedding(max_ep_len, hidden_size)
        self.embed_return = torch.nn.Linear(1, hidden_size)
        self.embed_state = torch.nn.Linear(self.state_dim, hidden_size)
        self.embed_action = torch.nn.Linear(self.act_dim, hidden_size)

        self.embed_ln = nn.LayerNorm(hidden_size)

        # note: we don't predict states or returns for the paper
        self.predict_state = torch.nn.Linear(hidden_size, self.state_dim)
        self.predict_action = nn.Sequential(
            *([nn.Linear(hidden_size, self.act_dim)] + ([nn.Tanh()] if action_tanh else []))
        )
        self.predict_return = torch.nn.Linear(hidden_size, 1)

    def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_mask=None):
        """Run the transformer over a batch of trajectories.

        Args:
            states: (batch, T, state_dim) float tensor.
            actions: (batch, T, act_dim) float tensor.
            rewards: unused by this model (kept for interface compatibility).
            returns_to_go: (batch, T, 1) float tensor.
            timesteps: (batch, T) long tensor of absolute episode timesteps.
            attention_mask: optional (batch, T) mask; 1 = attend, 0 = padding.

        Returns:
            Tuple ``(state_preds, action_preds, return_preds)``, each of shape
            (batch, T, dim) for its respective modality.
        """
        batch_size, seq_length = states.shape[0], states.shape[1]

        if attention_mask is None:
            # attention mask for GPT: 1 if can be attended to, 0 if not
            # fix: create the mask on the same device as the inputs; the
            # original omitted device=, which crashes on CUDA when no
            # explicit mask is passed.
            attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long, device=states.device)

        # embed each modality with a different head
        state_embeddings = self.embed_state(states)
        action_embeddings = self.embed_action(actions)
        returns_embeddings = self.embed_return(returns_to_go)
        time_embeddings = self.embed_timestep(timesteps)

        # time embeddings are treated similar to positional embeddings
        state_embeddings = state_embeddings + time_embeddings
        action_embeddings = action_embeddings + time_embeddings
        returns_embeddings = returns_embeddings + time_embeddings

        # this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...)
        # which works nice in an autoregressive sense since states predict actions
        stacked_inputs = torch.stack(
            (returns_embeddings, state_embeddings, action_embeddings), dim=1
        ).permute(0, 2, 1, 3).reshape(batch_size, 3*seq_length, self.hidden_size)
        stacked_inputs = self.embed_ln(stacked_inputs)

        # to make the attention mask fit the stacked inputs, have to stack it as well
        stacked_attention_mask = torch.stack(
            (attention_mask, attention_mask, attention_mask), dim=1
        ).permute(0, 2, 1).reshape(batch_size, 3*seq_length)

        # we feed in the input embeddings (not word indices as in NLP) to the model
        transformer_outputs = self.transformer(
            inputs_embeds=stacked_inputs,
            attention_mask=stacked_attention_mask,
        )
        x = transformer_outputs['last_hidden_state']

        # reshape x so that the second dimension corresponds to
        # predicting returns (0), states (1), or actions (2)
        x = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3)

        # get predictions: each head reads the token that *precedes* its target
        # in the interleaved sequence (e.g. actions are predicted from state tokens)
        return_preds = self.predict_return(x[:,2])  # predict next return given state and action
        state_preds = self.predict_state(x[:,2])    # predict next state given state and action
        action_preds = self.predict_action(x[:,1])  # predict next action given state
        return state_preds, action_preds, return_preds

    def get_action(self, states, actions, rewards, returns_to_go, timesteps, **kwargs):
        """Return the action prediction for the most recent timestep.

        Reshapes the (unbatched) history into a batch of one, optionally
        truncates it to the last ``max_length`` steps and left-pads with zeros
        (masked out via the attention mask), then runs ``forward``.

        Note: we don't care about the past rewards in this model.
        """
        states = states.reshape(1, -1, self.state_dim)
        actions = actions.reshape(1, -1, self.act_dim)
        returns_to_go = returns_to_go.reshape(1, -1, 1)
        timesteps = timesteps.reshape(1, -1)

        if self.max_length is not None:
            # keep only the most recent max_length steps
            states = states[:,-self.max_length:]
            actions = actions[:,-self.max_length:]
            returns_to_go = returns_to_go[:,-self.max_length:]
            timesteps = timesteps[:,-self.max_length:]

            # left-pad every tensor with zeros up to max_length; the attention
            # mask marks the padding (0) vs. the real history (1)
            attention_mask = torch.cat([torch.zeros(self.max_length-states.shape[1]), torch.ones(states.shape[1])])
            attention_mask = attention_mask.to(dtype=torch.long, device=states.device).reshape(1, -1)
            states = torch.cat(
                [torch.zeros((states.shape[0], self.max_length-states.shape[1], self.state_dim), device=states.device), states],
                dim=1).to(dtype=torch.float32)
            returns_to_go = torch.cat(
                [torch.zeros((returns_to_go.shape[0], self.max_length-returns_to_go.shape[1], 1), device=returns_to_go.device), returns_to_go],
                dim=1).to(dtype=torch.float32)
            timesteps = torch.cat(
                [torch.zeros((timesteps.shape[0], self.max_length-timesteps.shape[1]), device=timesteps.device), timesteps],
                dim=1
            ).to(dtype=torch.long)
            actions = torch.cat(
                [torch.zeros((actions.shape[0], self.max_length-actions.shape[1], self.act_dim), device=actions.device), actions],
                dim=1).to(dtype=torch.float32)
        else:
            attention_mask = None

        _, action_preds, return_preds = self.forward(
            states, actions, None, returns_to_go, timesteps, attention_mask=attention_mask, **kwargs)

        # last action prediction of the single batch element
        return action_preds[0,-1]