PyTorch implementation as a class #2
This is great! It also corrals stateful things like gamma into the object, where they belong. Do you have a quick example or unit test for it? I'm thinking this looks worth merging into the original codebase...
No, no quick examples; I'm using it in a different domain, not NLP tasks.
Output value of
Works for me, thanks! Please let me know if this loss function works well in your problem domain! Oh, an important note: I have to update the text, since I've actually found that 2.0 is the best default for delta. (Feel free to turn that knob, though!)
Looks like
Interesting; I'll have to keep an eye on that issue. I've wondered whether gamma should never drop below some minimum value, or whether delta * wrong_pairs should be clipped so that it stays above some minimum...
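If you want to experiment with both floors, here is a minimal sketch. `floored_gamma`, `min_pairs`, and `gamma_min` are hypothetical names, not part of anyone's code in this thread; the index logic mirrors the `update_gamma` method of the class posted in this thread, which picks gamma from the sorted, correctly ordered pairwise differences.

```python
import torch


def floored_gamma(correct_ordered: torch.Tensor, num_wrong: float, delta: float,
                  min_pairs: int = 0, gamma_min: float = 0.0) -> torch.Tensor:
    """Pick gamma with two optional floors (hypothetical knobs).

    Assumes `correct_ordered` is a non-empty, ascending 1-D tensor of the
    positive pairwise differences (positive score minus negative score).
    """
    # Clip delta * num_wrong from below (min_pairs) and from above (last index).
    idx = max(int(num_wrong * delta), min_pairs)
    idx = min(idx, correct_ordered.numel() - 1)
    gamma = correct_ordered[idx]
    # Never let gamma drop below gamma_min.
    return torch.clamp(gamma, min=gamma_min)


c = torch.tensor([0.01, 0.05, 0.2, 0.4])
print(floored_gamma(c, num_wrong=0.0, delta=2.0, min_pairs=2))
```

Setting `min_pairs=0` and `gamma_min=0.0` leaves the selection unchanged for non-negative differences, so the floors only kick in when you opt in.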
(I'll break this code out from the paper soon as a proper source file, and then we can do a proper git check-in of your changes.)
I often get IndexErrors on this line when
I've kind of fixed this by changing it to the below, but as I don't fully understand the code, I'm not sure whether that's a proper fix or whether there is some other root issue that needs fixing:
@brandenkmurray Yes, your fix is valid. I did something similar in the code I use. That often happens at the very beginning, when AUC is ~0.5-0.6. Also reducing
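For reference, the `update_gamma` method of the class posted in this thread clamps the index with `min(int(num_wrong_ordered * self.delta), len(correct_ordered)-1)`. Without the clamp, `int(delta * num_wrong_ordered)` runs past the end of the sorted list roughly whenever AUC <= delta / (1 + delta), e.g. AUC <= 2/3 for delta = 2, which fits the early-training timing described above. A standalone sketch of that gamma update over 1-D tensors of positive and negative scores; `compute_gamma` is a hypothetical name, and returning `None` when no gamma can be computed is an assumption:

```python
from typing import Optional

import torch


def compute_gamma(positive: torch.Tensor, negative: torch.Tensor,
                  delta: float) -> Optional[torch.Tensor]:
    """Gamma update following the class's update_gamma logic (hypothetical helper)."""
    if positive.numel() == 0 or negative.numel() == 0:
        return None  # one class is missing from the window: no pairs to compare
    # Entry (i, j) is positive[i] - negative[j].
    diff = positive.view(-1, 1) - negative.view(1, -1)
    auc = (diff > 0).float().mean()
    num_wrong_ordered = (1 - auc) * diff.numel()
    correct_ordered = diff[diff > 0].sort().values
    if correct_ordered.numel() == 0:
        return None  # no correctly ordered pairs to pick gamma from
    # Without min(...), this index can land past the end of correct_ordered
    # and raise an IndexError.
    idx = min(int(num_wrong_ordered * delta), correct_ordered.numel() - 1)
    return correct_ordered[idx]
```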
Anyone have a toy kernel somewhere I can fork to take a look at this bug?
I have one in a private repo 😞
Cool, good luck in the competition! If the loss function can help you rank high, that would be awesome!
Very nice work! I'm running into "RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time." when using this, though. I'll stick to the normal integration in the repo for now, which runs fine. Or did you see any performance improvements using this implementation instead of the normal one?
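One common cause of this error is keeping a prediction tensor that is still attached to the autograd graph and reusing it in a later loss: the first `backward()` frees that graph's saved buffers. The class posted in this thread stores `y_pred.clone().detach()` in its history, which avoids this. A minimal sketch reproducing the error with a toy one-parameter model; `run_two_steps` is a hypothetical helper:

```python
import torch


def run_two_steps(detach: bool) -> None:
    """Call backward on two consecutive losses that both read a shared history."""
    w = torch.tensor(0.5, requires_grad=True)
    history = []
    for x in (torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])):
        pred = torch.sigmoid(w * x)
        # The loss mixes the current batch with everything stored so far.
        past = torch.cat(history) if history else torch.zeros(0)
        loss = pred.sum() + past.sum()
        loss.backward()
        # Storing pred without detach keeps a reference to this step's graph,
        # whose buffers backward() has just freed; reusing it next step fails.
        history.append(pred.detach() if detach else pred)


run_two_steps(detach=True)  # completes without error

try:
    run_two_steps(detach=False)
except RuntimeError as err:
    print("RuntimeError:", err)
```

Passing `retain_graph=True` would also silence the error, but it keeps every old graph in memory, so detaching what goes into the history is usually the better fix.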
Hey, I tried using this class on a toy model and data, and I get this error. When I call loss.backward(retain_graph=True), I get another error. I have attached my Colab notebook so you can reproduce the errors. Any help would be appreciated.
Hi
Thanks for the Colab notebook, @PotatoSpudowski. Taking a look right now...
Hi guys, here's the fix for @PotatoSpudowski's error, "gradient computation has been modified by an inplace operation". At the end of RocStarLoss.forward(), replace
with
This is in addition to calling the loss.backward method like this:
I'm going to (finally!) pull this into the codebase tonight. For now, let me just post a complete version of the working class:
This works, thank you! One more thing I noticed is that the loss is sometimes NaN,
and when we then try to update gamma we get an error in update_gamma(self): `ValueError: cannot convert float NaN to integer`. I believe this can be fixed by taking a larger sample, but it would be nice to have a check condition before calling update_gamma().
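A minimal sketch of such a check, assuming it is applied to the same history window that `update_gamma` reads; `can_update_gamma` is a hypothetical helper, not part of the thread's code:

```python
import torch


def can_update_gamma(y_true_window: torch.Tensor, y_pred_window: torch.Tensor) -> bool:
    """Hypothetical pre-check: only recompute gamma when the window is usable."""
    has_positive = bool((y_true_window > 0).any())
    has_negative = bool((y_true_window < 1).any())
    # NaN predictions would propagate into AUC and the index computation.
    finite = bool(torch.isfinite(y_pred_window).all())
    return has_positive and has_negative and finite


print(can_update_gamma(torch.tensor([1, 1]), torch.tensor([0.7, 0.9])))
```

Inside `forward`, the existing condition could then be extended to `self.steps % self.update_gamma_each == 0 and can_update_gamma(...)`, so gamma keeps its previous value whenever the window is unusable.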
I will now implement this in my competition notebook and let you know if we make progress. |
I see the NaN now. You are right: this occurs when all the samples in a batch belong to the same class. I'm also noticing a second issue with this code's implementation: it is missing a step where it should take a random subsample. Give me a few minutes to patch...
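The issue description notes that the original code used a random subset of past predictions, whereas this class takes the last `sample_size` rows of its history. A minimal sketch of random subsampling with `torch.randperm`; `random_history_sample` is a hypothetical name, and sampling without replacement is an assumption:

```python
import torch


def random_history_sample(y_pred_history: torch.Tensor, y_true_history: torch.Tensor,
                          sample_size: int):
    """Draw a random subset of the history instead of its last `sample_size` rows."""
    n = y_pred_history.size(0)
    k = min(sample_size, n)
    # Random permutation of row indices; keep the first k (no repeats).
    idx = torch.randperm(n, device=y_pred_history.device)[:k]
    # Index both tensors with the same idx so predictions stay paired with labels.
    return y_pred_history[idx], y_true_history[idx]


preds = torch.arange(10, dtype=torch.float32).view(-1, 1)
labels = (torch.arange(10) % 2).view(-1, 1)
sample_pred, sample_true = random_history_sample(preds, labels, 4)
```

If the history is already filled from shuffled batches, a contiguous slice and a random subset are close in distribution; the random draw mainly removes any dependence on the order in which batches arrived.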
Taking a bit longer. I'll get this patched up within the next 6-12 hrs for sure. The basic issue is that this code has drifted a bit from the original code here https://github.com/iridiumblue/roc-star/blob/master/example.py, and in doing so introduced a few errors. There are 3 easy fixes and 1 less easy fix I'm doing, for example kicking NaNs to the curb: `loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)`. Stay tuned... If you don't need a class, you can use the old non-class code (see example.py), which works correctly. Or just wait a bit more :)
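The one-line NaN guard quoted above, wrapped as a function. In the class, a NaN loss can come from `.mean()` over an empty pair matrix, e.g. when the history window holds no examples of one class; `zero_if_nan` is a hypothetical name:

```python
import torch


def zero_if_nan(loss: torch.Tensor) -> torch.Tensor:
    """Replace a NaN loss value with 0 (same shape, dtype, and device)."""
    return torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)


# Mean over an empty (0 x 3) pair matrix is NaN; the guard turns it into 0.
empty_pairs = torch.zeros(0, 3)
print(zero_if_nan(empty_pairs.mean()))
```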
I will wait for your implementation of the class, no worries. So far I've had a 0.2 to 0.3 improvement in AUC and a significant improvement in ACC. I followed part of your approach: I train for 2 epochs with BCE loss and then train for a few more epochs with the above loss.
Thank you for your work @iridiumblue and @zakajd for ROC-STAR and this class implementation! |
I got dragged off into other things; let me try to get this wrapped up for you today! More to come...
Hey, is there any semi-fixed code for the class?
I got worse performance... how did you get an improvement?
Do we really need random subsampling instead of slicing if the dataset is shuffled? Aren't they doing basically the same thing? |
Don't know if this is still getting any updates, but I just wanted to point out that the class implementation above (#2 (comment)) is at risk of data leakage if used during validation/testing because every prediction+label that passes through it is saved temporarily and used in subsequent loss calculations. That means the samples at the end of the validation set will influence the loss at the beginning of the next training epoch. |
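One way to avoid this leakage is to update the history only in training mode. `_Loss` subclasses `nn.Module`, so calling `.eval()` on the loss object sets its `training` flag to `False`. A minimal sketch of that gating on a stripped-down history buffer; `HistoryBuffer` and `push` are hypothetical names, and CPU tensors are assumed:

```python
import torch
from torch import nn


class HistoryBuffer(nn.Module):
    """Hypothetical FIFO of past predictions that only changes in training mode."""

    def __init__(self, size: int):
        super().__init__()
        # Random initial history, as in the class posted in this thread.
        self.y_pred_history = torch.rand(size, 1)
        self.y_true_history = torch.randint(2, (size, 1))

    def push(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> None:
        if not self.training:
            return  # after .eval(): validation/test batches are not stored
        b = y_pred.size(0)
        # Drop the oldest b rows and append the detached batch.
        self.y_pred_history = torch.cat((self.y_pred_history[b:], y_pred.detach()))
        self.y_true_history = torch.cat((self.y_true_history[b:], y_true.detach()))
```

Note that `model.eval()` does not reach a loss object that is not registered as a submodule of the model, so the loss needs its own `.eval()` and `.train()` calls around validation.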
I want to thank the author and others for this great idea! Does anyone have a bug-free class implementation now? Thanks!
Shouldn't it be relu(diff) ** 2 instead of relu(diff ** 2)?
I suppose you are right. It is relu(diff) ** 2 in the original implementation, and relu does nothing when applied to diff ** 2, which is never negative.
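A quick numeric check of the two orderings on a hand-picked `diff` vector (in the class, `diff < 0` means the pair is already separated by more than the gamma margin):

```python
import torch
import torch.nn.functional as F

diff = torch.tensor([-0.3, 0.0, 0.2])

# Squaring first makes every entry non-negative, so relu is a no-op and
# pairs that already satisfy the margin (diff < 0) are penalized too.
squared_then_relu = F.relu(diff ** 2)
# Clipping first zeroes out those pairs; only margin violations remain.
relu_then_squared = F.relu(diff) ** 2

print(squared_then_relu, relu_then_squared)
```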
I'm having an issue where the loss spikes at the start of each epoch, which did not happen with BCE. I am shuffling my data, so this is a bit weird. Does anyone know why this is happening with RocStar?

```python
import torch
from torch.nn.modules.loss import _Loss


class RocStarLoss(_Loss):
    """Smooth approximation for ROC AUC
    """
    def __init__(self, delta=1.0, sample_size=1000, sample_size_gamma=1000, update_gamma_each=50):
        r"""
        Args:
            delta: Param from article
            sample_size (int): Number of examples to take for ROC AUC approximation
            sample_size_gamma (int): Number of examples to take for Gamma parameter approximation
            update_gamma_each (int): Number of steps after which to recompute gamma value.
        """
        super().__init__()
        self.delta = delta
        self.sample_size = sample_size
        self.sample_size_gamma = sample_size_gamma
        self.update_gamma_each = update_gamma_each
        self.steps = 0
        size = max(sample_size, sample_size_gamma)

        # Randomly init labels (history is kept on the GPU)
        self.y_pred_history = torch.rand((size, 1), device='cuda')
        self.y_true_history = torch.randint(2, (size, 1), device='cuda')

    def forward(self, y_pred, y_true):
        """
        Args:
            y_pred: Tensor of model predictions in [0, 1] range. Shape (B x 1)
            y_true: Tensor of true labels in {0, 1}. Shape (B x 1)
        """
        #y_pred = _y_pred.clone().detach()
        #y_true = _y_true.clone().detach()
        if self.steps % self.update_gamma_each == 0:
            self.update_gamma()
        self.steps += 1

        positive = y_pred[y_true > 0]
        negative = y_pred[y_true < 1]

        # Take last `sample_size` elements from history
        y_pred_history = self.y_pred_history[- self.sample_size:]
        y_true_history = self.y_true_history[- self.sample_size:]

        positive_history = y_pred_history[y_true_history > 0]
        negative_history = y_pred_history[y_true_history < 1]

        # relu before squaring, so only pairs that violate the gamma margin are penalized
        if positive.size(0) > 0:
            diff = negative_history.view(1, -1) + self.gamma - positive.view(-1, 1)
            loss_positive = (torch.nn.functional.relu(diff) ** 2).mean()
        else:
            loss_positive = 0

        if negative.size(0) > 0:
            diff = negative.view(1, -1) + self.gamma - positive_history.view(-1, 1)
            loss_negative = (torch.nn.functional.relu(diff) ** 2).mean()
        else:
            loss_negative = 0

        loss = loss_negative + loss_positive

        # Update FIFO queue with detached copies, so old graphs are not kept alive
        batch_size = y_pred.size(0)
        self.y_pred_history = torch.cat((self.y_pred_history[batch_size:], y_pred.clone().detach()))
        self.y_true_history = torch.cat((self.y_true_history[batch_size:], y_true.clone().detach()))
        return loss

    def update_gamma(self):
        # Take last `sample_size_gamma` elements from history
        y_pred = self.y_pred_history[- self.sample_size_gamma:]
        y_true = self.y_true_history[- self.sample_size_gamma:]

        positive = y_pred[y_true > 0]
        negative = y_pred[y_true < 1]

        # Matrix of pairwise differences, shape (num_positive x num_negative)
        diff = positive.view(-1, 1) - negative.view(1, -1)
        AUC = (diff > 0).type(torch.float).mean()
        num_wrong_ordered = (1 - AUC) * diff.flatten().size(0)

        # Adjust gamma, so that among correct ordered samples `delta * num_wrong_ordered` were considered
        # ordered incorrectly with gamma added
        correct_ordered = diff[diff > 0].flatten().sort().values
        if len(correct_ordered) != 0:
            # Clamp the index so it cannot run past the end of correct_ordered
            idx = min(int(num_wrong_ordered * self.delta), len(correct_ordered)-1)
            self.gamma = correct_ordered[idx]
            print(f"Updated gamma: {self.gamma}")
        else:
            print(f"Did not update gamma. Current gamma is: {self.gamma}")
```
Hi, thanks a lot for your work; the idea described here is smart and simple!
To apply it in my work I had to rewrite your code into class form. This also made the code a bit cleaner by using PyTorch broadcasting. Hope it will be useful for someone else too =)
I also slightly changed the logic. Here batches are saved into a FIFO queue, and during each `forward` call the last `sample_size` elements are taken instead of a random subset.
Update: added a small fix from brandenkmurray in the `update_gamma` function.