In [None]:
#export
from fastcore.all import *
from fastai2.basics import *

In [None]:
# default_exp callbacks

# Callbacks
> Handle the different input and output formats of fastai and transformers

## FakeLearner: a minimal learner stub used only for testing

In [None]:
class FakeLearner():
    ''' Bare-bones stand-in for a learner, used only to test a single callback in isolation.
        Every extra keyword argument becomes an attribute of the instance (e.g. `xb`, `pred`),
        and the callback is linked back to this object through its `learn` attribute. '''
    def __init__(self, cb, **kwargs):
        # Expose the fake state the callback is going to read or overwrite.
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])
        # Wire both directions: callback -> learner and learner -> callback.
        cb.learn = self
        self.cb = cb

    def run_cb(self, event_name):
        ''' Fetch the attribute called `event_name` from the callback and call it with no arguments. '''
        handler = getattr(self.cb, event_name)
        handler()

## GPT2LMHeadCallback

In [None]:
#export
class GPT2LMHeadCallback(Callback):
    ''' Reduce the tuple stored in `learn.pred` to its first element after each prediction. '''
    def after_pred(self):
        ''' Keep only the first item of the model output.
            The output of AutoModelWithLMHead is described here as (last_hidden_state, past);
            fastai only wants the first item.
            NOTE(review): for an LM-head model the first item is presumably the token logits
            rather than a hidden state -- confirm against the transformers version in use. '''
        model_output = self.learn.pred
        self.learn.pred = model_output[0]

In [None]:
# Check `after_pred` on a fake learner: the prediction starts as a 2-tuple of
# placeholder strings and must end up as just the first element.
learn = FakeLearner(cb=GPT2LMHeadCallback(), pred=('last_hidden_state', 'past'))
learn.run_cb('after_pred')
test_eq(learn.pred, 'last_hidden_state')

## BertSeqClassificationCallback

In [None]:
#export
class BertSeqClassificationCallback(Callback):
    ''' Adapt fastai batches and predictions to AutoModelForSequenceClassification.
        It should be ok to use it with all Bert like models, eg: Roberta.
    '''
    def __init__(self, pad_id: int):
        # Token id that marks padding positions inside `input_ids`.
        self.pad_id = pad_id

    def begin_batch(self):
        ''' Alongside input_ids, pass an attention_mask to AutoModelForSequenceClassification
            so that it does not attend to padding tokens.
            `learn.xb` is replaced by `[input_ids, attention_mask]`; the mask is 0 wherever
            `input_ids` equals `pad_id` and 1 everywhere else.
        '''
        input_ids = self.learn.xb[0]
        # A direct comparison yields the keep/ignore pattern without allocating scalar
        # tensors on every batch; `.to(input_ids)` casts it to the dtype and device of `input_ids`.
        attention_mask = (input_ids != self.pad_id).to(input_ids)
        self.learn.xb = [input_ids, attention_mask]

    def after_pred(self):
        ''' The output of AutoModelForSequenceClassification is (logits, )
            What fastai wants is logits, so keep only the first element. '''
        self.learn.pred = self.learn.pred[0]

In [None]:
# Two sequences padded with token id 1; the expected mask is 0 exactly at those positions.
input_ids = torch.tensor([[4, 3, 1, 1], 
                          [5, 6, 7, 1]])
attention_mask = torch.tensor([[1, 1, 0, 0], 
                               [1, 1, 1, 0]])

# `begin_batch` must turn the single-element `xb` into (input_ids, attention_mask).
# NOTE(review): the callback stores a list while the expected value is a tuple;
# this relies on `test_eq` comparing the sequences element-wise -- confirm.
learn = FakeLearner(cb=BertSeqClassificationCallback(pad_id=1), xb=(input_ids,))
learn.run_cb('begin_batch')
test_eq(learn.xb, (input_ids, attention_mask))

# `after_pred` must unwrap the 1-tuple prediction (a placeholder string stands in for the logits).
learn = FakeLearner(cb=BertSeqClassificationCallback(pad_id=1), pred=('logits',))
learn.run_cb('after_pred')
test_eq(learn.pred, 'logits')

## Export -

In [None]:
#hide
# Run nbdev's notebook-to-script conversion; presumably this writes the cells
# flagged for export into the module named by `default_exp` -- see the nbdev docs.
from nbdev.export import notebook2script
notebook2script()