
CoxPHLoss does not handle batches where all samples are censored #52

Closed
lenbrocki opened this issue Dec 21, 2023 · 10 comments

Comments

@lenbrocki

I'm training a custom model with CoxPHLoss and have noticed that, when using the Efron tie-handling method, training fails when a batch contains only censored events. The code raising the error is in lassonet/utils.py:

if hasattr(torch.Tensor, "scatter_reduce_"):
    # version >= 1.12
    def scatter_reduce(input, dim, index, reduce, *, output_size=None):
        src = input
        if output_size is None:
            # fails here when index is empty (all samples censored):
            # max() on an empty tensor raises a RuntimeError
            output_size = index.max() + 1
        return torch.empty(output_size, device=input.device).scatter_reduce(
            dim=dim, index=index, src=src, reduce=reduce, include_self=False
        )

else:
    scatter_reduce = torch.scatter_reduce

When all samples are censored, index is an empty tensor and index.max() fails, as this snippet demonstrates:
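A minimal standalone reproduction of just that failure (illustrative, outside lassonet):

import torch

index = torch.tensor([], dtype=torch.long)  # no uncensored samples -> empty index tensor
index.max()
# RuntimeError: max(): Expected reduction dim to be specified for input.numel() == 0.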
Also, if I understand correctly, the Cox likelihood would be zero in that case, so the log-likelihood is not defined.
For now I have resorted to skipping these problematic batches (see the sketch below), but it might be helpful to handle this edge case directly in CoxPHLoss. I'm not sure what the best way of doing that is, though.
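A rough sketch of the skipping workaround, for concreteness (loader, model, and optimizer here are generic placeholders, not lassonet API):

import torch
from lassonet.cox import CoxPHLoss

criterion = CoxPHLoss("efron")
for X, labels in loader:         # labels[:, 1] is the event indicator (1 = event)
    if labels[:, 1].sum() == 0:  # every sample in the batch is censored
        continue                 # skip the batch entirely
    optimizer.zero_grad()
    loss = criterion(model(X), labels)
    loss.backward()
    optimizer.step()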

@louisabraham
Collaborator

Hello, can you produce a minimal reproducible example?

> Also, if I understand correctly, the Cox likelihood would be zero in that case so that the log likelihood is not defined.

Wouldn't it be one? Can you test with the Breslow approximation?

@louisabraham
Collaborator

[image: Cox partial likelihood formula]

I think that if the sets are empty, the log-likelihood is just zero.
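For reference, a standard form of the Cox partial likelihood (my reconstruction; the screenshot above presumably showed something like this):

$$L = \prod_{i :\, \delta_i = 1} \frac{e^{h_i}}{\sum_{j \in R(t_i)} e^{h_j}}$$

where $h_i$ is the predicted log-hazard of sample $i$ and $R(t_i)$ is the risk set at time $t_i$. If all $\delta_i = 0$, the product is empty, hence $L = 1$ and $\log L = 0$.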

@lenbrocki
Author

Sorry for the late response. A minimal example would be:

import torch
from lassonet.cox import CoxPHLoss

loss = CoxPHLoss("breslow")
labels = torch.tensor([[5.0, 0], [2.0, 0]])
hazards = torch.tensor([5.0, 2.0])
print(loss(hazards, labels))
# prints nan

loss = CoxPHLoss("efron")
labels = torch.tensor([[5.0, 0], [2.0, 0]])
hazards = torch.tensor([5.0, 2.0])
print(loss(hazards, labels))
# fails with RuntimeError: max(): Expected reduction dim to be specified for
# input.numel() == 0. Specify the reduction dim with the 'dim' argument.

The error in the Efron case happens for the reason I described above. I think the nan in the Breslow case happens because the likelihood is zero. This can be seen, for example, from equation 1 of your paper https://arxiv.org/pdf/2208.09793.pdf: if all $\delta_i$ are zero, the product is zero, and the log of zero is not defined.

@louisabraham
Collaborator

louisabraham commented Jan 8, 2024 via email

@lenbrocki
Author

Oh, I wasn't aware of that, but yes, you're right of course. Then the question becomes why nan is returned rather than 0.
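One plausible source of the nan (my guess, not verified against the implementation): some PyTorch reductions over empty selections produce nan, e.g.:

import torch
print(torch.tensor([]).mean())  # tensor(nan) -- mean of an empty tensor
print(torch.tensor([]).sum())   # tensor(0.)  -- sums are fine, means are not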

@louisabraham
Collaborator

louisabraham commented Jan 8, 2024 via email

@lenbrocki
Author

lenbrocki commented Jan 11, 2024

CoxPHLoss now correctly returns 0. But when I try to use the fixed loss in training, I get this error for batches where all samples have $\delta_i = 0$:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
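That error occurs when backward() is called on a tensor that is not connected to the autograd graph, which is what happens if the loss returns a constant zero. A minimal illustration (my own, not lassonet code):

import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)

constant_zero = torch.tensor(0.0)  # detached from the graph, no grad_fn
# constant_zero.backward()  # RuntimeError: element 0 of tensors does not require grad ...

connected_zero = x.sum() * 0.0  # has a grad_fn, so backward() succeeds
connected_zero.backward()
print(x.grad)  # tensor([0., 0.])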

louisabraham added a commit that referenced this issue Jan 12, 2024
@louisabraham
Collaborator

I think I managed to find a better fix :) Can you test again?

@lenbrocki
Author

Sorry again for the delay. Yes, it's working now!

@louisabraham
Collaborator

Great!
