
Time encoder implementation #22

Open
Dennis-Tsai opened this issue Sep 22, 2022 · 1 comment

Comments

@Dennis-Tsai

Hi, I really like your work on multi-agent trajectory prediction. I went through the paper and the code, and I have a quick question about the time encoder. As you mention in the paper, the time encoder that integrates the timestamp features differs from the original positional encoder, but I cannot find the time encoder code in this repo. Please let me know if I missed anything. Much appreciated!
[Screenshot: excerpt from the paper describing the time encoder]


PFery4 commented Jul 11, 2023

The temporal encoding is managed by the PositionalAgentEncoding class in the agentformer.py file:

import numpy as np
import torch
import torch.nn as nn

class PositionalAgentEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_t_len=200, max_a_len=200, concat=False, use_agent_enc=False, agent_enc_learn=False):
        super(PositionalAgentEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.concat = concat
        self.d_model = d_model
        self.use_agent_enc = use_agent_enc
        if concat:
            self.fc = nn.Linear((3 if use_agent_enc else 2) * d_model, d_model)

        # Sinusoidal table over timesteps, stored as a non-trainable buffer
        pe = self.build_pos_enc(max_t_len)
        self.register_buffer('pe', pe)
        if use_agent_enc:
            if agent_enc_learn:
                # Learned per-agent encoding
                self.ae = nn.Parameter(torch.randn(max_a_len, 1, d_model) * 0.1)
            else:
                # Fixed sinusoidal per-agent encoding
                ae = self.build_pos_enc(max_a_len)
                self.register_buffer('ae', ae)

    def build_pos_enc(self, max_len):
        # Standard sinusoidal encoding, indexed by timestep
        pe = torch.zeros(max_len, self.d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.d_model, 2).float() * (-np.log(10000.0) / self.d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        return pe

    def build_agent_enc(self, max_len):
        # Identical to build_pos_enc; kept as in the source (note that
        # __init__ above actually calls build_pos_enc for the agent table)
        ae = torch.zeros(max_len, self.d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.d_model, 2).float() * (-np.log(10000.0) / self.d_model))
        ae[:, 0::2] = torch.sin(position * div_term)
        ae[:, 1::2] = torch.cos(position * div_term)
        ae = ae.unsqueeze(0).transpose(0, 1)
        return ae

    def get_pos_enc(self, num_t, num_a, t_offset):
        # Repeat each timestep's encoding once per agent
        pe = self.pe[t_offset: num_t + t_offset, :]
        pe = pe.repeat_interleave(num_a, dim=0)
        return pe

    def get_agent_enc(self, num_t, num_a, a_offset, agent_enc_shuffle):
        if agent_enc_shuffle is None:
            ae = self.ae[a_offset: num_a + a_offset, :]
        else:
            ae = self.ae[agent_enc_shuffle]
        ae = ae.repeat(num_t, 1, 1)
        return ae

    def forward(self, x, num_a, agent_enc_shuffle=None, t_offset=0, a_offset=0):
        # x is flattened over (time, agent) along dim 0
        num_t = x.shape[0] // num_a
        pos_enc = self.get_pos_enc(num_t, num_a, t_offset)
        if self.use_agent_enc:
            agent_enc = self.get_agent_enc(num_t, num_a, a_offset, agent_enc_shuffle)
        if self.concat:
            # Concatenate features with encodings, then project back to d_model
            feat = [x, pos_enc.repeat(1, x.size(1), 1)]
            if self.use_agent_enc:
                feat.append(agent_enc.repeat(1, x.size(1), 1))
            x = torch.cat(feat, dim=-1)
            x = self.fc(x)
        else:
            # Additive encodings, as in the original Transformer
            x += pos_enc
            if self.use_agent_enc:
                x += agent_enc
        return self.dropout(x)
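For intuition, here is a small NumPy-only sketch (not part of the repo) of the sinusoidal table that build_pos_enc constructs, plus the per-agent tiling that get_pos_enc performs with repeat_interleave:

```python
import numpy as np

def build_pos_enc(max_len, d_model):
    """Standalone sketch of the classic sinusoidal encoding from
    "Attention Is All You Need", here indexed by timestep."""
    pe = np.zeros((max_len, d_model))
    position = np.arange(max_len, dtype=np.float64)[:, None]  # (max_len, 1)
    div_term = np.exp(np.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
    pe[:, 0::2] = np.sin(position * div_term)  # even dims: sine
    pe[:, 1::2] = np.cos(position * div_term)  # odd dims: cosine
    return pe

pe = build_pos_enc(max_len=8, d_model=4)
# t=0 row is [0, 1, 0, 1] since sin(0)=0 and cos(0)=1

# Tiling across agents, analogous to repeat_interleave in get_pos_enc:
# row order becomes (t0,a0), (t0,a1), (t1,a0), (t1,a1), ...
num_t, num_a = 3, 2
tiled = np.repeat(pe[:num_t], num_a, axis=0)  # shape (num_t * num_a, d_model)
```

So every agent at the same timestep receives the same time encoding, which is what lets the model treat the timestamp as a shared feature across agents.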

Hope this helps!
