Logical bug in IRTNet.forward #31

Closed
niall-twomey opened this issue Oct 13, 2021 · 3 comments · Fixed by #33

niall-twomey commented Oct 13, 2021

🐛 Description

I was browsing this repo for IRT implementations and found (I think) a theoretical bug in the implementation of IRTNet.

IRTNet.forward is defined here
https://github.com/bigdata-ustc/EduCDM/blob/main/EduCDM/IRT/GD/IRT.py#L30

    def forward(self, user, item):
        theta = torch.squeeze(self.theta(user), dim=-1)
        a = torch.squeeze(self.a(item), dim=-1)
        b = torch.squeeze(self.b(item), dim=-1)
        c = torch.squeeze(self.c(item), dim=-1)
        return torch.sigmoid(self.irf(theta, a, b, c, **self.irf_kwargs))

Here the output of irf is passed through the sigmoid function. This is fine only if the output of irf is itself a "logit".

The IRF function is defined here:
https://github.com/bigdata-ustc/EduCDM/blob/main/EduCDM/IRT/irt.py#L10

def irf(theta, a, b, c, D=1.702, *, F=np):
    return c + (1 - c) / (1 + F.exp(-D * a * (theta - b)))

If you look at this you can see that it is already depicting sigmoid behaviour (assuming, of course, that 0 <= c <= 1). In other words, irf is returning probabilities, and not logits. As a result, the forward function above is actually doing this:

1 / (1 + exp(-(c + (1 - c) / (1 + exp(-D * a * (theta - b))))))

which I think is probably a bug.
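To make the effect concrete, here is a minimal scalar sketch using only the standard library. It re-implements irf with math.exp instead of the F=np argument, and the parameter values are arbitrary illustrations, not taken from the repo. Because irf returns a value in [0, 1], sigmoid(irf(...)) can only fall between sigmoid(0) = 0.5 and sigmoid(1) ≈ 0.731, whatever theta is.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def irf(theta, a, b, c, D=1.702):
    # Scalar version of the item response function quoted above.
    return c + (1.0 - c) / (1.0 + math.exp(-D * a * (theta - b)))


# Arbitrary illustrative item parameters (assumptions, not from the repo).
a, b, c = 1.0, 0.0, 0.2

for theta in (-4.0, 0.0, 4.0):
    p = irf(theta, a, b, c)        # already a probability in [c, 1]
    squashed = sigmoid(p)          # what forward() currently returns
    print(f"theta={theta:+.1f}  irf={p:.4f}  sigmoid(irf)={squashed:.4f}")
```

With these inputs the second column spans most of [0.2, 1], while the third never leaves [0.5, 0.731].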

If I haven't misunderstood, I have two recommendations:

  • Simply remove the torch.sigmoid call from forward
  • (optional) it may be worth passing c through a sigmoid function to ensure it doesn't go negative or above 1. (Perhaps selectable in irf_kwargs?)

i.e.

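    def forward(self, user, item):
        theta = torch.squeeze(self.theta(user), dim=-1)
        a = torch.squeeze(self.a(item), dim=-1)
        b = torch.squeeze(self.b(item), dim=-1)
        c = torch.squeeze(self.c(item), dim=-1)
        # squash_c is the proposed option; drop it before calling irf, which does not accept it
        kwargs = {k: v for k, v in self.irf_kwargs.items() if k != "squash_c"}
        if self.irf_kwargs.get("squash_c", True):
            c = torch.sigmoid(c)  # keep the guessing parameter in (0, 1)
        return self.irf(theta, a, b, c, **kwargs)  # May want to clip values if c not constrained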

Edit: I noticed that this torch.sigmoid(irf(...)) pattern also appears in MIRT, and possibly elsewhere too.
Edit 2: I also realise that because sigmoid is monotonic, it doesn't really change the optimal solution. However, it does seem unnecessary to differentiate through sigmoid twice.
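As a rough illustration of the cost of the second sigmoid, here is a small sketch, again scalar and standard-library only, with helper names that are mine rather than the repo's. By the chain rule, the gradient of sigmoid(p) with respect to any parameter is the gradient of p multiplied by sigmoid'(p) = sigmoid(p) * (1 - sigmoid(p)). Since p = irf(...) lies in [0, 1], that extra factor stays between roughly 0.197 and 0.25, so every gradient is scaled down by at least a factor of four.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def sigmoid_grad(x):
    # Derivative of the sigmoid: s(x) * (1 - s(x)).
    s = sigmoid(x)
    return s * (1.0 - s)


# Evaluate the extra chain-rule factor over the range irf can return, p in [0, 1].
factors = [sigmoid_grad(i / 10.0) for i in range(11)]
print(f"extra gradient factor ranges from {min(factors):.4f} to {max(factors):.4f}")
```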


Error Message

NA

To Reproduce

NA

Environment

Environment Information

Operating System: NA

Python Version: NA

Additional context

tswsxk (Collaborator) commented Oct 13, 2021

Thank you for your bug report; we will reply to you as soon as possible.

ViviHong200709 (Contributor) commented

Thanks for reporting the bug and for the very useful advice on improving our framework. We fixed the bug by removing the sigmoid function. We also added constraints that ensure a >= 0 and 0 <= c <= 1.
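The comment does not say how the constraints were implemented in the fix, so the following is only a sketch of one common approach and an assumption on my part: map an unconstrained raw value through softplus to get a >= 0, and through a sigmoid to get 0 <= c <= 1. The names constrain, raw_a and raw_c are hypothetical and do not come from the repo.

```python
import math


def sigmoid(x):
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def softplus(x):
    # Numerically stable log(1 + exp(x)); never negative.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def constrain(raw_a, raw_c):
    # Hypothetical helper: turn unconstrained parameters into a valid (a, c) pair.
    return softplus(raw_a), sigmoid(raw_c)


a, c = constrain(-3.0, 5.0)
print(f"a={a:.4f} (non-negative), c={c:.4f} (within [0, 1])")
```

Clamping the raw values to the valid range is another option, though it gives zero gradient once a parameter reaches a boundary.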

niall-twomey (Author) commented

@ViviHong200709 - thanks! I added two comments to the PR that I think should be looked at.
