Hi, I am currently reading through the code and got confused when I reached this line:

RETRO-pytorch/retro_pytorch/retro_pytorch.py
Line 598 in 92ff287

According to Algorithm 1 in the paper (the screenshot above), doesn't this line need to go inside the decoder, under this line?

RETRO-pytorch/retro_pytorch/retro_pytorch.py
Line 406 in 92ff287

This is an example of how I think the code of `decoder.forward` should be:

```python
def forward(self, x, *, context_mask = None, retrieved = None):
    encoded = False  # flag to know if p = min(P) (in the algorithm)
    ...
    if exists(cross_attn) and exists(retrieved):
        if not encoded:
            ...
            # use x (H at layer p where p = min(P)), not embed (Emb(X))
            x_as_context = repeat(x[:, :seq_index], 'b (k n) d -> (b k r) n d', n = self.chunk_size, r = num_neighbors)
            retrieved = self.encoder(retrieved, mask = encoder_retrieved_mask, chunked_seq = x_as_context)
            encoded = True
```

---

@soheeyang Hi Sohee again! Yet again you discovered another bug :( I apologize for this and have made a correction in the latest commit (v0.3.0) 4f99e31
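For clarity, here is a toy, self-contained sketch of the pattern the snippet above proposes. This is not the actual RETRO-pytorch API; every name below (`encoder`, `decoder_forward`, the layer tuples) is hypothetical and the attention operations are replaced by arithmetic stand-ins. The point it illustrates is only the control flow: the retrieved neighbors are encoded lazily, inside the decoder loop, at the first layer that has cross-attention (p = min(P)), conditioned on the hidden state `x` at that layer rather than on the token embeddings, and the result is reused by all later cross-attention layers.

```python
calls = []

def encoder(retrieved, context):
    # Stand-in for the bidirectional RETRO encoder; records the context it
    # was conditioned on so we can check it ran once, on hidden states.
    calls.append(list(context))
    return [r + c for r, c in zip(retrieved, context)]

def decoder_forward(x, retrieved, layers):
    """layers: list of (transform, has_cross_attn) pairs (hypothetical)."""
    encoded = False
    for transform, has_cross_attn in layers:
        x = [transform(v) for v in x]  # self-attention / FFN stand-in
        if has_cross_attn and retrieved is not None:
            if not encoded:
                # key point: condition on x (H at layer min(P)), not on Emb(X)
                retrieved = encoder(retrieved, x)
                encoded = True
            x = [v + r for v, r in zip(x, retrieved)]  # cross-attn stand-in
    return x

layers = [(lambda v: v + 1, False),  # layer 0: no cross-attention
          (lambda v: v * 2, True),   # layer 1: first cross-attn layer, p = min(P)
          (lambda v: v - 1, True)]   # layer 2: reuses the encoded neighbors
out = decoder_forward([1.0, 2.0], [10.0, 20.0], layers)
assert len(calls) == 1  # the encoder ran exactly once, at min(P)
```

Note that `calls[0]` holds the hidden state after layer 1's transform, not the input embeddings, which is exactly the distinction Algorithm 1 draws.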