Convert clonedetection example to multitask/multilabel #56

Closed
PedroEstevesPT opened this issue Jul 13, 2021 · 16 comments

Comments

@PedroEstevesPT

PedroEstevesPT commented Jul 13, 2021

Hi, right now the GraphCodeBERT clone detection example performs binary classification to decide whether two pieces of code are semantically equivalent or not.

The problem I am trying to solve is: Given a natural language utterance and two code pieces (A and B) as input to my model, determine whether:

  • both pieces are correct
  • piece A is correct and piece B is wrong
  • piece B is correct and piece A is wrong
  • both pieces are wrong

I tried solving this problem as a 4-class classification task in #53, but the results were not very good, so now I am trying to turn it into a multi-label/multi-task problem, classifying each input twice:

[0,1] -> Whether A is right or wrong.
[0,1] -> Whether B is right or wrong.

Does anyone have any idea how to accomplish this?

Thanks a lot

@guoday
Contributor

guoday commented Jul 14, 2021

You can directly take a natural language utterance and one code piece (A or B) as the input to do binary classification in #53. If you want to use GraphCodeBERT, you just need to change microsoft/codebert-base to microsoft/graphcodebert-base in train.sh and inference.sh
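
For reference, a minimal sketch (using the Hugging Face transformers API directly, outside the repo's train.sh/inference.sh scripts) of what pointing the scripts at the GraphCodeBERT checkpoint amounts to:

from transformers import AutoTokenizer, AutoModel

# load the GraphCodeBERT checkpoint instead of microsoft/codebert-base
tokenizer = AutoTokenizer.from_pretrained("microsoft/graphcodebert-base")
model = AutoModel.from_pretrained("microsoft/graphcodebert-base")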

@guoday
Contributor

guoday commented Jul 14, 2021

I guess the input of a natural language utterance and two code pieces (A and B) is too long, which is why the results were not good. You can try setting block_size to 512.

@PedroEstevesPT
Author

> You can directly take a natural language utterance and one code piece (A or B) as the input to do binary classification in #53

Hmm, right now I would like to try a multi-task approach instead of calling the model twice (once per piece of code). Do you have any idea how to get started with this?

Thanks a lot

@guoday
Contributor

guoday commented Jul 14, 2021

You can change this line to get the embeddings of the two pieces of code:

outputs = self.encoder.roberta(inputs_embeds=inputs_embeddings,attention_mask=attn_mask,position_ids=position_idx)[0]

to:

outputs = self.encoder.roberta(inputs_embeds=inputs_embeddings,attention_mask=attn_mask,position_ids=position_idx)[1].reshape(-1,2,768)

and then add a classifier on top of the embeddings.
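
For illustration, a minimal sketch with dummy tensors (not the repo's actual code) of what adding a classifier on the reshaped embeddings could look like:

import torch
import torch.nn as nn

hidden_size = 768
classifier = nn.Linear(hidden_size, 1)      # one logit per code piece

pooled = torch.randn(4, 2, hidden_size)     # dummy stand-in for the (-1, 2, 768) reshaped outputs
logits = classifier(pooled).squeeze(-1)     # shape (batch, 2): one logit for A, one for B
probs = torch.sigmoid(logits)               # independent probabilities that A / B are correct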

@PedroEstevesPT
Author

PedroEstevesPT commented Jul 20, 2021

Sorry, I was not clear about what I was trying to accomplish (I confused multi-task with multi-label). Basically, forget the multi-task part: I want the model to perform multi-label classification, which means that instead of predicting, for example, [1,0,0,0], it predicts one of these 4 cases:

  • [1,1] (both pieces of code are right)
  • [1,0] (only piece of code A is right)
  • [0,1] (only piece of code B is right)
  • [0,0] (both pieces of code are wrong)

I basically just changed the following (a short sketch follows the list):

  1. In model.py the loss from:
    loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
    to:
    loss_fct = nn.BCEWithLogitsLoss()

  2. Ensured that the label tensors had this format:
    labels = torch.tensor([[1., 0.]]).cuda()

  3. Changed the number of labels in run.py:
    from config.num_labels=1
    to config.num_labels=2
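
For reference, a minimal, self-contained sketch of what these three changes amount to, assuming a classification head that outputs 2 logits (one per code piece):

import torch
import torch.nn as nn

loss_fct = nn.BCEWithLogitsLoss()            # change 1: BCE instead of CrossEntropy

logits = torch.tensor([[2.0, -1.0]])         # dummy model output for one example
labels = torch.tensor([[1., 0.]])            # change 2: float multi-label target (A right, B wrong)

loss = loss_fct(logits, labels)              # change 3 makes the head produce 2 logits
probs = torch.sigmoid(logits)                # per-piece probabilities; they need not sum to 1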

The model is training; however, I am not sure this is the right procedure for converting the model to multi-label classification. Could you confirm?

When I perform inference, it returns a tensor like this:

tensor([[0.7289, 0.2711]], device='cuda:0')

I interpret the first index as the probability that piece of code A is right and piece of code B is wrong, so it seems as expected; however, as I mentioned before, I am not sure.

Thanks a lot

@guoday
Contributor

guoday commented Jul 20, 2021

If you use GraphCodeBERT, you can replace model.py with:

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self, encoder, config, tokenizer, args):
        super(Model, self).__init__()
        self.encoder = encoder
        self.config = config
        self.tokenizer = tokenizer
        # one logit per code piece, applied to the 768-dim pooled embedding
        self.classifier = nn.Linear(768, 1)
        self.args = args

    def forward(self, inputs_ids_1, position_idx_1, attn_mask_1,
                inputs_ids_2, position_idx_2, attn_mask_2, labels=None):
        bs, l = inputs_ids_1.size()
        # stack the two code pieces so the encoder processes them as a batch of size bs*2
        inputs_ids = torch.cat((inputs_ids_1.unsqueeze(1), inputs_ids_2.unsqueeze(1)), 1).view(bs*2, l)
        position_idx = torch.cat((position_idx_1.unsqueeze(1), position_idx_2.unsqueeze(1)), 1).view(bs*2, l)
        attn_mask = torch.cat((attn_mask_1.unsqueeze(1), attn_mask_2.unsqueeze(1)), 1).view(bs*2, l, l)

        # embedding: average token embeddings into the data-flow node positions
        nodes_mask = position_idx.eq(0)
        token_mask = position_idx.ge(2)
        inputs_embeddings = self.encoder.roberta.embeddings.word_embeddings(inputs_ids)
        nodes_to_token_mask = nodes_mask[:, :, None] & token_mask[:, None, :] & attn_mask
        nodes_to_token_mask = nodes_to_token_mask / (nodes_to_token_mask.sum(-1) + 1e-10)[:, :, None]
        avg_embeddings = torch.einsum("abc,acd->abd", nodes_to_token_mask, inputs_embeddings)
        inputs_embeddings = inputs_embeddings*(~nodes_mask)[:, :, None] + avg_embeddings*nodes_mask[:, :, None]

        # pooled output for each of the bs*2 sequences, regrouped as (bs, 2, hidden)
        outputs = self.encoder.roberta(inputs_embeds=inputs_embeddings, attention_mask=attn_mask,
                                       position_ids=position_idx)[1]
        outputs = outputs.reshape(bs, 2, -1)
        logits = self.classifier(outputs)[:, :, 0]   # shape (bs, 2): one logit per code piece
        prob = torch.sigmoid(logits)                 # independent probabilities for A and B
        if labels is not None:
            loss_fct = torch.nn.BCEWithLogitsLoss()
            loss = loss_fct(logits, labels)
            return loss, prob
        else:
            return prob
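
As a usage note, a small sketch with dummy values (not part of the script): when called without labels, the model returns prob of shape (batch, 2), which can be thresholded at 0.5 to recover the [A, B] predictions:

import torch

prob = torch.tensor([[0.73, 0.27]])   # dummy stand-in for the model's output without labels
preds = (prob > 0.5).long()           # tensor([[1, 0]]): piece A predicted right, piece B wrong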
      
        

@PedroEstevesPT
Author

Hi, thanks a lot for posting the code. I confirm I am using GraphCodeBERT and I will try out what you posted above, but could you please confirm the input format of the data (namely the label), just to be sure? Is it going to be a list (e.g. [1,1]) instead of a single value?

@guoday
Contributor

guoday commented Jul 20, 2021

A list, as shown above (a small loading sketch follows):
[1,1] (both pieces of code are right)
[1,0] (only piece of code A is right)
[0,1] (only piece of code B is right)
[0,0] (both pieces of code are wrong)
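
For illustration, a minimal loading sketch (the repo's actual dataset code may differ; train.jsonl is a hypothetical file name) of reading such label lists into float tensors suitable for BCEWithLogitsLoss:

import json
import torch

# each line: {"code1": "...", "code2": "...", "label": [1, 0]}
with open("train.jsonl") as f:   # hypothetical file name
    for line in f:
        js = json.loads(line)
        label = torch.tensor(js["label"], dtype=torch.float)   # e.g. tensor([1., 0.])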

@PedroEstevesPT
Author

Thanks, I've already changed model.py and my input has this format: {"code1": " ", "code2": " ", "label": [0,1]}.

However, when running the model I get an error (see the attached screenshot). Any idea how to solve it?

@guoday
Contributor

guoday commented Jul 20, 2021

So you are using the CodeBERT script instead of the GraphCodeBERT script.

@guoday
Contributor

guoday commented Jul 20, 2021

Do you mind sharing your dataset with me? Only part of the data (10+ cases) is enough. I will modify the code for you.

@PedroEstevesPT
Author

Thanks, here go some samples:

{"code1" "Write one plus one | def func(): print("one plus one") , "code2": "Write one plus one | def func(): print("1+1")", "label": [1,1] }
{"code1" "Write one plus one | def func(): print("one plus one") , "code2": "Write one plus one | def func(): print("1+2")", "label": [1,0] }

@guoday
Contributor

guoday commented Jul 20, 2021

Please find the zip. CodeBERT-classification-2.zip

@PedroEstevesPT
Author

Thanks a lot @guoday! I will try it out and give feedback in a bit.

guody5 closed this as completed on Jul 25, 2021
@PedroEstevesPT
Author

Sorry for not answering earlier. I tried it out and it solved my issue, thanks @guoday.
