I'm really confused with the v2.0 meta.py #14

Closed

yaox12 opened this issue Dec 19, 2018 · 3 comments

Comments

yaox12 commented Dec 19, 2018

  • First, the comment says the index of losses_q is the task index:

    losses_q = [0 for _ in range(self.update_step + 1)]  # losses_q[i], i is tasks idx

    However, within each task i, the whole list is updated (a toy sketch of this bookkeeping follows the list):

    losses_q[0] += loss_q
    losses_q[1] += loss_q
    losses_q[k + 1] += loss_q

  • Second, I don't see where loss_q is actually summed over all tasks.

    MAML-Pytorch/meta.py

    Lines 134 to 135 in fc20b31

    # sum over all losses on query set across all tasks
    loss_q = losses_q[-1] / task_num

    losses_q[-1] seems to be the last step's loss for the last task only.

  • Third, if update_step == 1, there is only one inner update. However, the loss after the first update is computed under torch.no_grad(), so I think no gradient information from the query set can flow back in that setting.

    MAML-Pytorch/meta.py

    Lines 100 to 109 in fc20b31

    # this is the loss and accuracy after the first update
    with torch.no_grad():
        # [setsz, nway]
        logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
        loss_q = F.cross_entropy(logits_q, y_qry[i])
        losses_q[1] += loss_q
        # [setsz]
        pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
        correct = torch.eq(pred_q, y_qry[i]).sum().item()
        corrects[1] = corrects[1] + correct
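For reference, here is a minimal, self-contained toy sketch of the bookkeeping pattern the comments above seem to intend: losses_q is indexed by inner-update step, each entry is summed over tasks, and the meta-loss is the last entry averaged over task_num. This is not the repo's code; the linear model w, the shapes, the learning rate, and the random data are made up purely for illustration.

```python
import torch
import torch.nn.functional as F

# Toy sketch of the intended bookkeeping (NOT the repo's code):
# losses_q[k] = query-set loss after k inner updates, summed over all tasks.
update_step, task_num, update_lr = 3, 4, 0.01
w = torch.randn(5, 2, requires_grad=True)           # stand-in for self.net's parameters

losses_q = [0.0 for _ in range(update_step + 1)]    # indexed by update step, NOT by task

for i in range(task_num):
    x_spt, y_spt = torch.randn(10, 5), torch.randint(0, 2, (10,))
    x_qry, y_qry = torch.randn(10, 5), torch.randint(0, 2, (10,))
    fast_w = w
    losses_q[0] += F.cross_entropy(x_qry @ fast_w, y_qry)            # before any update
    for k in range(update_step):
        loss = F.cross_entropy(x_spt @ fast_w, y_spt)                # support-set loss
        (grad,) = torch.autograd.grad(loss, fast_w, create_graph=True)
        fast_w = fast_w - update_lr * grad                           # inner update
        losses_q[k + 1] += F.cross_entropy(x_qry @ fast_w, y_qry)    # accumulate over tasks

# meta-loss: the last step's accumulated query loss, averaged over tasks
loss_q = losses_q[-1] / task_num
loss_q.backward()                                    # gradients flow back to w
```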

dragen1860 (Owner) commented:

  1. The index of losses_q is NOT the task index but the update-step index. Sorry for the wrong comment.

  2. loss_q is accumulated over all tasks at every inner-update step, hence it only needs to be averaged at the end.
     losses_q[-1] is the last step's accumulated loss.

  3. Yes, you are right. For the single-step update setting, the line loss_q = F.cross_entropy(logits_q, y_qry[i]) should be moved out of torch.no_grad() (a sketch of this fix follows below).
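As a concrete illustration of point 3, a possible fix is to compute the query loss outside torch.no_grad() and keep only the accuracy bookkeeping under it. This is just a sketch, not a tested patch; it reuses the names (self.net, fast_weights, x_qry, y_qry, losses_q, corrects) from the snippet quoted in the issue.

```python
# Sketch of the single-step fix: keep loss_q differentiable; only the
# accuracy bookkeeping stays under no_grad (names as in the quoted snippet).
logits_q = self.net(x_qry[i], fast_weights, bn_training=True)  # [setsz, nway]
loss_q = F.cross_entropy(logits_q, y_qry[i])                   # now carries gradients
losses_q[1] += loss_q

with torch.no_grad():
    # [setsz]
    pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
    correct = torch.eq(pred_q, y_qry[i]).sum().item()
    corrects[1] = corrects[1] + correct
```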

yaox12 commented Dec 19, 2018

At the end of each task i, loss_q is appended to the list losses_q:

MAML-Pytorch/meta.py

Lines 130 to 131 in fc20b31

# 4. record last step's loss for task i
losses_q.append(loss_q)

So the accumulated last step's loss should be losses_q[self.update_step] instead of losses_q[-1], because by the end the length of losses_q is update_step + 1 + task_num (see the toy illustration below).
In fact, I think the two lines above are redundant and useless.
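A tiny numeric illustration of this indexing problem (toy numbers only; update_step = 5 and task_num = 4 are made up, and a constant 1.0 stands in for each task's final query loss):

```python
# Toy illustration: the append at lines 130-131 makes losses_q[-1] point to
# the last TASK's loss rather than the accumulated last-STEP loss.
update_step, task_num = 5, 4
losses_q = [0.0 for _ in range(update_step + 1)]     # length 6, indexed by update step

for i in range(task_num):
    last_step_loss = 1.0                              # stand-in for task i's final query loss
    losses_q[update_step] += last_step_loss           # accumulated across tasks
    losses_q.append(last_step_loss)                   # what lines 130-131 do

print(len(losses_q))             # 10 == update_step + 1 + task_num
print(losses_q[-1])              # 1.0 -> only the last task's last-step loss
print(losses_q[update_step])     # 4.0 -> the accumulated last-step loss across tasks
```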

dragen1860 (Owner) commented:

Yes, it's a bug!
Thanks for your very helpful insight.
Remove lines 130 & 131!
