PyTorch implementation as a class #2

Open
zakajd opened this issue Jun 25, 2020 · 32 comments
@zakajd

zakajd commented Jun 25, 2020

Hi, thanks a lot for your work, the idea described here is smart and simple!

To apply it in my work I had to rewrite your code as a class. This also made the code a bit cleaner by using PyTorch broadcasting. Hope it will also be useful for someone else =)

I also slightly changed the logic: here, batches are saved into a FIFO queue, and on each forward call the last sample_size elements are taken instead of a random subset.

Update: added a small fix from brandenkmurray in the update_gamma function.


import torch
from torch.nn.modules.loss import _Loss


class RocStarLoss(_Loss):
    """Smooth approximation for ROC AUC."""
    def __init__(self, delta = 1.0, sample_size = 1000, sample_size_gamma = 10000, update_gamma_each=500):
        r"""
        Args:
            delta: Param from article
            sample_size (int): Number of examples to take for ROC AUC approximation
            sample_size_gamma (int): Number of examples to take for Gamma parameter approximation
            update_gamma_each (int): Number of steps after which to recompute gamma value.
        """
        super().__init__()
        self.delta = delta
        self.sample_size = sample_size
        self.sample_size_gamma = sample_size_gamma
        self.update_gamma_each = update_gamma_each
        self.steps = 0
        size = max(sample_size, sample_size_gamma)

        # Randomly init labels
        self.y_pred_history = torch.rand((size, 1))
        self.y_true_history = torch.randint(2, (size, 1))
        

    def forward(self, y_pred, y_true):
        """
        Args:
            y_pred: Tensor of model predictions in [0, 1] range. Shape (B x 1)
            y_true: Tensor of true labels in {0, 1}. Shape (B x 1)
        """
        if self.steps % self.update_gamma_each == 0:
            self.update_gamma()
        self.steps += 1
        
        positive = y_pred[y_true > 0]
        negative = y_pred[y_true < 1]
        
        # Take last `sample_size` elements from history
        y_pred_history = self.y_pred_history[- self.sample_size:]
        y_true_history = self.y_true_history[- self.sample_size:]
        
        positive_history = y_pred_history[y_true_history > 0]
        negative_history = y_pred_history[y_true_history < 1]
        
        if positive.size(0) > 0:
            diff = negative_history.view(1, -1) + self.gamma - positive.view(-1, 1)
            loss_positive = torch.nn.functional.relu(diff ** 2).mean()
        else:
            loss_positive = 0
 
        if negative.size(0) > 0:
            diff = negative.view(1, -1) + self.gamma - positive_history.view(-1, 1)
            loss_negative = torch.nn.functional.relu(diff ** 2).mean()
        else:
            loss_negative = 0
            
        loss = loss_negative + loss_positive
        
        # Update FIFO queue
        batch_size = y_pred.size(0)
        self.y_pred_history = torch.cat((self.y_pred_history[batch_size:], y_pred))
        self.y_true_history = torch.cat((self.y_true_history[batch_size:], y_true))
        return loss

    def update_gamma(self):
        # Take last `sample_size_gamma` elements from history
        y_pred = self.y_pred_history[- self.sample_size_gamma:]
        y_true = self.y_true_history[- self.sample_size_gamma:]
        
        positive = y_pred[y_true > 0]
        negative = y_pred[y_true < 1]
        
        # Create a matrix of pairwise differences between positive and negative predictions
        diff = positive.view(-1, 1) - negative.view(1, -1)
        AUC = (diff > 0).type(torch.float).mean()
        num_wrong_ordered = (1 - AUC) * diff.flatten().size(0)
        
        # Adjust gamma so that, among correctly ordered pairs, `delta * num_wrong_ordered`
        # of them would be counted as incorrectly ordered once gamma is added
        correct_ordered = diff[diff > 0].flatten().sort().values
        idx = min(int(num_wrong_ordered * self.delta), len(correct_ordered)-1)
        self.gamma = correct_ordered[idx]
@iridiumblue
Owner

This is great! It also corrals stateful things like gamma into the object, where they belong.

Do you have a quick example or unit test of this thing?

This looks worthwhile to merge into the original codebase, I'm thinking ...

@zakajd
Author

zakajd commented Jun 26, 2020

No quick examples yet, I'm afraid; I'm using it in a different domain, not NLP tasks.
I'll add tests a bit later, but for now you can adapt the example from the README as follows:

train_ds = CatDogDataset(train_files, transform)
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE)
roc_star_loss = RocStarLoss()
for epoch in range(epochs):
    for X, y in train_dl:
        preds = model(X)
        # ...
        loss = roc_star_loss(preds, y)
    #...
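
As a side note (this is an illustration, not part of the original example): forward() expects predictions in the [0, 1] range with shape (B x 1) and labels in {0, 1} with the same shape, so raw logits would need a sigmoid and a reshape first, along the lines of:

logits = model(X)                           # assuming the model outputs raw logits
preds = torch.sigmoid(logits).view(-1, 1)   # squash into [0, 1], shape (B, 1)
labels = y.view(-1, 1).float()              # shape (B, 1), values in {0, 1}
loss = roc_star_loss(preds, labels)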

The output of epoch_update_gamma is the same as that of self.update_gamma() (which is also ~25% faster).
The loss value can differ slightly for a small sample_size, but is always roughly similar.

@iridiumblue
Owner

Works for me, thanks! Plz let me know if this loss function works well in your problem domain!

Oh - important note: I have to update the text; I've actually found that 2.0 is the best default for delta.

(Feel free to turn that knob though!)

@zakajd
Author

zakajd commented Jun 26, 2020

Looks like the delta value mostly depends on the top achievable performance. For example, if your regular AUC is bouncing around ~0.75 it doesn't make sense to use a big value, while if your AUC is near 0.95 even bigger values of delta can be used to further shift the prediction probabilities.

@iridiumblue
Owner

Interesting; will have to keep an eye on that issue. I've wondered if gamma should never drop below some minimum value, or if delta*wrong_pairs should be clipped above some minimum ...

@iridiumblue
Owner

(I'll break this code out of the paper soon as a proper source file, and then we can do a proper git check-in of your changes.)

@brandenkmurray

I often get index errors on this line when delta > 1, because sometimes the calculated index is greater than len(correct_ordered):

correct_ordered[int(num_wrong_ordered * self.delta)]

I've kind of fixed this by changing it to the below, but as I don't fully understand the code I'm not sure if that's a proper fix or if there is some other root issue that needs to be fixed:

correct_ordered = diff[diff > 0].flatten().sort().values
# Deltas > 1 can cause an index error if the resulting value is more than len(correct_ordered)
correct_ordered_idx = min(int(num_wrong_ordered * self.delta), len(correct_ordered)-1)
self.gamma = correct_ordered[correct_ordered_idx]

@zakajd
Author

zakajd commented Jul 23, 2020

@brandenkmurray Yes, your fix is valid. I also did something similar in the code I use. That often happens at the very beginning, when AUC is ~0.5-0.6. Reducing the delta value also helps, but I didn't investigate how this influences the result.

@iridiumblue
Owner

Anyone have a toy kernel somewhere I can fork to take a look at this bug?

@zakajd
Author

zakajd commented Jul 23, 2020

I have one in a private repo 😞
Will share in mid-August, once the competition I'm taking part in ends.
Nothing really serious here, just an indexing error I didn't notice at the beginning.

@iridiumblue
Owner

Cool - good luck in the competition! If the loss function can help you rank high, that would be awesome!

@Muennighoff

Very nice work! I'm running into "RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time." when using this, though. I'll stick to the normal integration in the repo for now, which runs fine ... or did you see any performance improvements using this implementation instead of the original one?

@PotatoSpudowski

Hey, I tried using this class on a toy model + data and I get this error when I do loss.backward():
RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

When I do loss.backward(retain_graph=True), I get another error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [64, 1]], which is output 0 of TBackward, is at version 10; expected version 9 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

I have attached my Colab notebook so you can recreate the error. Any help would be appreciated.

Colab Notebook

@zakajd
Author

zakajd commented Oct 15, 2020

Hi,
I actually never used this class for training (I ended up using another approach), so the code definitely needs some debugging.
Try the hint from PyTorch: torch.autograd.set_detect_anomaly(True).
My guess is that the error happens somewhere in the FIFO queue update, but I don't have time to debug it now.
Please comment if you fix the issue!
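
For reference, a minimal sketch of applying that hint (roc_star_loss, preds and y are placeholders here): run the failing step under anomaly detection so PyTorch reports which operation produced the bad gradient.

import torch

torch.autograd.set_detect_anomaly(True)   # global switch, or use the context manager below
with torch.autograd.detect_anomaly():
    loss = roc_star_loss(preds, y)        # any step whose backward() later fails
    loss.backward()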

@iridiumblue
Owner

Thanks for the Colab notebook, @PotatoSpudowski . Taking a look right now ...

@iridiumblue
Owner

Hi guys - here's the fix for @PotatoSpudowski's error: "gradient computation has been modified by an inplace operation".

At the end of RocStarLoss.forward(), replace

    self.y_pred_history = torch.cat((self.y_pred_history[batch_size:], y_pred))
    self.y_true_history = torch.cat((self.y_true_history[batch_size:], y_true))

with

    self.y_pred_history = torch.cat((self.y_pred_history[batch_size:], y_pred.clone().detach()))
    self.y_true_history = torch.cat((self.y_true_history[batch_size:], y_true.clone().detach()))

This is in addition to calling the loss.backward method like this:
loss.backward(retain_graph=True)
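
Put together, a minimal sketch of a training step with both changes applied (model, optimizer and train_dl are assumed names, not from this thread):

for X, y in train_dl:
    preds = model(X)
    loss = roc_star_loss(preds, y)      # forward() now stores detached copies in the history
    optimizer.zero_grad()
    loss.backward(retain_graph=True)    # as noted above
    optimizer.step()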

@iridiumblue
Owner

iridiumblue commented Oct 15, 2020

I'm going to (finally!) pull this into the codebase tonight; for now, let me just give a complete version of the working class:


import torch
from torch.nn.modules.loss import _Loss


class RocStarLoss(_Loss):
    """Smooth approximation for ROC AUC."""
    def __init__(self, delta = 1.0, sample_size = 10, sample_size_gamma = 10, update_gamma_each=10):
        r"""
        Args:
            delta: Param from article
            sample_size (int): Number of examples to take for ROC AUC approximation
            sample_size_gamma (int): Number of examples to take for Gamma parameter approximation
            update_gamma_each (int): Number of steps after which to recompute gamma value.
        """
        super().__init__()
        self.delta = delta
        self.sample_size = sample_size
        self.sample_size_gamma = sample_size_gamma
        self.update_gamma_each = update_gamma_each
        self.steps = 0
        size = max(sample_size, sample_size_gamma)

        # Randomly init labels
        self.y_pred_history = torch.rand((size, 1))
        self.y_true_history = torch.randint(2, (size, 1))
        

    def forward(self, y_pred, y_true):
        """
        Args:
            y_pred: Tensor of model predictions in [0, 1] range. Shape (B x 1)
            y_true: Tensor of true labels in {0, 1}. Shape (B x 1)
        """
        #y_pred = _y_pred.clone().detach()
        #y_true = _y_true.clone().detach()
        if self.steps % self.update_gamma_each == 0:
            self.update_gamma()
        self.steps += 1
        
        positive = y_pred[y_true > 0]
        negative = y_pred[y_true < 1]
        
        # Take last `sample_size` elements from history
        y_pred_history = self.y_pred_history[- self.sample_size:]
        y_true_history = self.y_true_history[- self.sample_size:]
        
        positive_history = y_pred_history[y_true_history > 0]
        negative_history = y_pred_history[y_true_history < 1]
        
        if positive.size(0) > 0:
            diff = negative_history.view(1, -1) + self.gamma - positive.view(-1, 1)
            loss_positive = torch.nn.functional.relu(diff ** 2).mean()
        else:
            loss_positive = 0
 
        if negative.size(0) > 0:
            diff = negative.view(1, -1) + self.gamma - positive_history.view(-1, 1)
            loss_negative = torch.nn.functional.relu(diff ** 2).mean()
        else:
            loss_negative = 0
            
        loss = loss_negative + loss_positive
        
        # Update FIFO queue
        batch_size = y_pred.size(0)
        self.y_pred_history = torch.cat((self.y_pred_history[batch_size:], y_pred.clone().detach()))
        self.y_true_history = torch.cat((self.y_true_history[batch_size:], y_true.clone().detach()))
        return loss

    def update_gamma(self):
        # Take last `sample_size_gamma` elements from history
        y_pred = self.y_pred_history[- self.sample_size_gamma:]
        y_true = self.y_true_history[- self.sample_size_gamma:]
        
        positive = y_pred[y_true > 0]
        negative = y_pred[y_true < 1]
        
        # Create a matrix of pairwise differences between positive and negative predictions
        diff = positive.view(-1, 1) - negative.view(1, -1)
        AUC = (diff > 0).type(torch.float).mean()
        num_wrong_ordered = (1 - AUC) * diff.flatten().size(0)
        
        # Adjust gamma so that, among correctly ordered pairs, `delta * num_wrong_ordered`
        # of them would be counted as incorrectly ordered once gamma is added
        correct_ordered = diff[diff > 0].flatten().sort().values
        idx = min(int(num_wrong_ordered * self.delta), len(correct_ordered)-1)
        self.gamma = correct_ordered[idx]

@PotatoSpudowski

This works thank you!

One more thing that I noticed was that sometimes loss is nan

Epoch 046: | Loss: 0.02856 | Acc: 65.750
Epoch 047: | Loss: nan | Acc: 63.500
Epoch 048: | Loss: 0.02594 | Acc: 52.000
Epoch 049: | Loss: 0.02916 | Acc: 65.750

And when we try to update gamma we get an error:

     in forward(self, y_pred, y_true)
         32     #y_true = _y_true.clone().detach()
         33     if self.steps % self.update_gamma_each == 0:
    ---> 34         self.update_gamma()
         35     self.steps += 1
         36

     in update_gamma(self)
         81     # ordered incorrectly with gamma added
         82     correct_ordered = diff[diff > 0].flatten().sort().values
    ---> 83     idx = min(int(num_wrong_ordered * self.delta), len(correct_ordered)-1)
         84     self.gamma = correct_ordered[idx]

    ValueError: cannot convert float NaN to integer

I believe this can be fixed by taking a larger sample, but it would be nice if there were a check condition before calling update_gamma().
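
For what it's worth, one possible shape of such a check (a sketch with a hypothetical helper name, not part of the class): only recompute gamma when the recent history actually contains both classes.

def can_update_gamma(y_true_history, sample_size_gamma):
    # Hypothetical guard: gamma can only be recomputed if the recent history
    # contains at least one positive and one negative label.
    recent = y_true_history[-sample_size_gamma:]
    return bool((recent > 0).any()) and bool((recent < 1).any())

forward() could then call update_gamma() only when this returns True.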

@PotatoSpudowski

I will now implement this in my competition notebook and let you know if we make progress.

@iridiumblue
Owner

I see the NaN now - you are right, this occurs when a batch belongs entirely to one class. I'm also noticing a second issue with this code's implementation: it is missing a step where it should take a random subsample.

Give me a few minutes to patch ...

@iridiumblue
Owner

Taking a bit longer than expected. I'll get this patched up within the next 6-12 hrs for sure. The basic issue is that this code has drifted a bit from the original code here: https://github.com/iridiumblue/roc-star/blob/master/example.py. That drift introduced a few errors.

There are 3 easy fixes and 1 less easy fix I'm doing:

1 - Stomp out NaN at the end of forward(), i.e. kick NaN's to the curb using
    loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)
2 - Increase the sample sizes (all 3 of them) in the constructor to ~1000. (This isn't batch size or anything; it's an internal value that needs to be about that size for the purposes of a certain tensor operation. That subsample keeps the tensor within a reasonable memory size for the GPU.)
3 - Add a method to call update_gamma at the end of each epoch.
4 - Change a couple of lines from slicing to random subsampling (see the sketch at the end of this comment).

Stay tuned ... If you don't need a class, you can use the old non-class code (see example.py), which works correctly. Or just wait a bit more :)
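
Sketching roughly what 1 and 4 could look like (my reading of the list above with hypothetical helper names, not the final patch):

import torch

def stomp_nan(loss):
    # Fix 1: replace a NaN loss with zero at the end of forward()
    return torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)

def random_history_subsample(y_pred_history, y_true_history, sample_size):
    # Fix 4: draw a random subsample from the history instead of always
    # slicing off the last `sample_size` rows
    idx = torch.randperm(y_pred_history.size(0))[:sample_size]
    return y_pred_history[idx], y_true_history[idx]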

@PotatoSpudowski

PotatoSpudowski commented Oct 16, 2020

I will wait for your implementation of the class, no worries.

Currently, I've had a 0.2 to 0.3 improvement in AUC and a significant improvement in accuracy. I followed a bit of your approach, where I train for 2 epochs on BCE loss and then train for a few more epochs using the above loss.

@NumesSanguis

Thank you @iridiumblue and @zakajd for your work on ROC-STAR and this class implementation!
I would love to try this loss function, but for my workflow it would be easier if it could be used as a class.
Looking forward to any progress!

@iridiumblue
Owner

I got dragged off into other things, let me try to get this wrapped for you today!

More to come ...

@0xnakul

0xnakul commented Apr 4, 2021

Hey, is there any semi-fixed code for the class?

@hktxt

hktxt commented May 27, 2021

I will wait for your implementation of the class, no worries.

Currently, I've had a 0.2 to 0.3 improvement in AUC and a significant improvement in accuracy. I followed a bit of your approach, where I train for 2 epochs on BCE loss and then train for a few more epochs using the above loss.

I got worse performance... how did you get an improvement?

@dongkyunk

Taking a bit longer than expected. I'll get this patched up within the next 6-12 hrs for sure. The basic issue is that this code has drifted a bit from the original code here: https://github.com/iridiumblue/roc-star/blob/master/example.py. That drift introduced a few errors.

There are 3 easy fixes and 1 less easy fix I'm doing:

1 - Stomp out NaN at the end of forward(), i.e. kick NaN's to the curb using
    loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)
2 - Increase the sample sizes (all 3 of them) in the constructor to ~1000. (This isn't batch size or anything; it's an internal value that needs to be about that size for the purposes of a certain tensor operation. That subsample keeps the tensor within a reasonable memory size for the GPU.)
3 - Add a method to call update_gamma at the end of each epoch.
4 - Change a couple of lines from slicing to random subsampling.

Stay tuned ... If you don't need a class, you can use the old non-class code (see example.py), which works correctly. Or just wait a bit more :)

Do we really need random subsampling instead of slicing if the dataset is shuffled? Aren't they doing basically the same thing?

@ayhyap

ayhyap commented Feb 23, 2022

Don't know if this is still getting any updates, but I just wanted to point out that the class implementation above (#2 (comment)) is at risk of data leakage if used during validation/testing because every prediction+label that passes through it is saved temporarily and used in subsequent loss calculations.

That means the samples at the end of the validation set will influence the loss at the beginning of the next training epoch.
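
A minimal sketch of one way to avoid that (model, optimizer, train_dl and val_dl are assumed names, and the BCE metric is just a stand-in): keep the stateful roc-star loss for training batches only, and evaluate with a stateless metric so validation predictions never enter the history.

model.train()
for X, y in train_dl:
    loss = roc_star_loss(model(X), y)   # only training batches update the internal history
    optimizer.zero_grad()
    loss.backward(retain_graph=True)
    optimizer.step()

model.eval()
with torch.no_grad():
    for X, y in val_dl:
        preds = model(X)                # assuming the model outputs probabilities in [0, 1]
        val_metric = torch.nn.functional.binary_cross_entropy(preds, y.float())  # stateless, no history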

@noowfel

noowfel commented May 9, 2022

I want to thank the author and others for this great idea! Does anyone have a bug-free class implementation now? Thanks!

@StanTRC

StanTRC commented Nov 9, 2023

Shouldn't it be relu(diff) ** 2 instead of relu(diff ** 2)?

@Andrew-Tuen

Shouldn't it be relu(diff) ** 2 instead of relu(diff ** 2)?

I suppose you are right. It is relu(diff) ** 2 in the original implementation, and relu does nothing to the square of diff, since diff ** 2 is already non-negative.
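
A tiny illustration of the difference (made-up values): squaring first makes the relu a no-op, so correctly ordered pairs (negative diff) still contribute, whereas relu(diff) ** 2 zeroes them out.

import torch

diff = torch.tensor([-0.3, 0.2])   # one correctly ordered pair, one wrongly ordered pair
print(torch.relu(diff ** 2))       # tensor([0.0900, 0.0400]) - relu changes nothing
print(torch.relu(diff) ** 2)       # tensor([0.0000, 0.0400]) - correctly ordered pair drops out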

@maxall41

I'm having an issue where the loss spikes at the start of each epoch, which did not happen with BCE. I am shuffling my data, so this is a bit weird. Does anyone know why this happens with roc-star?
[screenshot of the loss curve]
Also, there is an issue with the above implementation by @iridiumblue: if all of the samples in the batch are of one class, gamma cannot be updated and it throws an error. I fixed this by doing the following:

import torch
from torch.nn.modules.loss import _Loss

class RocStarLoss(_Loss):
    """Smooth approximation for ROC AUC
    """
    def __init__(self, delta = 1.0, sample_size = 1000, sample_size_gamma = 1000, update_gamma_each=50):
        r"""
        Args:
            delta: Param from article
            sample_size (int): Number of examples to take for ROC AUC approximation
            sample_size_gamma (int): Number of examples to take for Gamma parameter approximation
            update_gamma_each (int): Number of steps after which to recompute gamma value.
        """
        super().__init__()
        self.delta = delta
        self.sample_size = sample_size
        self.sample_size_gamma = sample_size_gamma
        self.update_gamma_each = update_gamma_each
        self.steps = 0
        size = max(sample_size, sample_size_gamma)

        # Randomly init labels
        self.y_pred_history = torch.rand((size, 1),device='cuda')
        self.y_true_history = torch.randint(2, (size, 1),device='cuda')
        

    def forward(self, y_pred, y_true):
        """
        Args:
            y_pred: Tensor of model predictions in [0, 1] range. Shape (B x 1)
            y_true: Tensor of true labels in {0, 1}. Shape (B x 1)
        """
        #y_pred = _y_pred.clone().detach()
        #y_true = _y_true.clone().detach()
        if self.steps % self.update_gamma_each == 0:
            self.update_gamma()
        self.steps += 1
        
        positive = y_pred[y_true > 0]
        negative = y_pred[y_true < 1]
        
        # Take last `sample_size` elements from history
        y_pred_history = self.y_pred_history[- self.sample_size:]
        y_true_history = self.y_true_history[- self.sample_size:]
        
        positive_history = y_pred_history[y_true_history > 0]
        negative_history = y_pred_history[y_true_history < 1]
        
        if positive.size(0) > 0:
            diff = negative_history.view(1, -1) + self.gamma - positive.view(-1, 1)
            loss_positive = torch.nn.functional.relu(diff ** 2).mean()
        else:
            loss_positive = 0
 
        if negative.size(0) > 0:
            diff = negative.view(1, -1) + self.gamma - positive_history.view(-1, 1)
            loss_negative = torch.nn.functional.relu(diff ** 2).mean()
        else:
            loss_negative = 0
            
        loss = loss_negative + loss_positive
        
        # Update FIFO queue
        batch_size = y_pred.size(0)
        self.y_pred_history = torch.cat((self.y_pred_history[batch_size:], y_pred.clone().detach()))
        self.y_true_history = torch.cat((self.y_true_history[batch_size:], y_true.clone().detach()))
        return loss

    def update_gamma(self):
        # Take last `sample_size_gamma` elements from history
        y_pred = self.y_pred_history[- self.sample_size_gamma:]
        y_true = self.y_true_history[- self.sample_size_gamma:]
        
        positive = y_pred[y_true > 0]
        negative = y_pred[y_true < 1]
        
        # Create a matrix of pairwise differences between positive and negative predictions
        diff = positive.view(-1, 1) - negative.view(1, -1)
        AUC = (diff > 0).type(torch.float).mean()
        num_wrong_ordered = (1 - AUC) * diff.flatten().size(0)
        
        # Adjust gamma so that, among correctly ordered pairs, `delta * num_wrong_ordered`
        # of them would be counted as incorrectly ordered once gamma is added
        correct_ordered = diff[diff > 0].flatten().sort().values
        if len(correct_ordered) != 0:
            idx = min(int(num_wrong_ordered * self.delta), len(correct_ordered) - 1)
            self.gamma = correct_ordered[idx]
            print(f"Updated gamma: {self.gamma}")
        else:
            print(f"Did not update gamma. Current gamma is: {self.gamma}")
