In [1]:
from fastai import *
from fastai.text import *

In [2]:
# Fetch the IMDB sample with fastai's `untar_data` and list the extracted files (see output below).
path = untar_data(URLs.IMDB_SAMPLE)
path.ls()

[PosixPath('/storage/imdb_sample/data_save.pkl'),
 PosixPath('/storage/imdb_sample/texts.csv')]

In [3]:
bs = 4  # deliberately tiny batch size so the debug tensors printed below stay readable

In [4]:
# Build the classification DataBunch from texts.csv:
#   - read the 'text' column as a TextList,
#   - split train/valid using dataframe column 2 (presumably an is-valid flag -- TODO confirm),
#   - take the labels from column 0 (targets shown below are 'positive'/'negative'),
#   - batch with the small `bs` defined above.
data = (TextList.from_csv(path, 'texts.csv', cols='text')
                .split_from_df(col=2)
                .label_from_df(cols=0)
                .databunch(bs=bs))

In [5]:
# Decode and display a few (text, target) pairs from a batch.
data.show_batch()

text,target
"xxbos xxmaj raising xxmaj victor xxmaj vargas : a xxmaj review \n \n xxmaj you know , xxmaj raising xxmaj victor xxmaj vargas is like sticking your hands into a big , xxunk bowl of xxunk . xxmaj it 's warm and gooey , but you 're not sure if it feels right . xxmaj try as i might , no matter how warm and gooey xxmaj raising xxmaj",negative
"xxbos xxmaj many neglect that this is n't just a classic due to the fact that it 's the first xxup 3d game , or even the first xxunk - up . xxmaj it 's also one of the first xxunk games , one of the xxunk definitely the first ) truly claustrophobic games , and just a pretty well - xxunk gaming experience in general . xxmaj with graphics",positive
"xxbos i had read many good things about this adaptation of my favorite novel ... so xxunk my expectations were crushed . xxmaj but they were crushed more than should be expected . xxmaj the movie would have been a decent movie if i had not read the novel xxunk , which perhaps ruined it for me . \n \n xxmaj in any event , for some reason they",negative
"xxbos xxmaj this is the last of four xxunk from xxmaj france i 've xxunk for viewing during this xxmaj christmas season : the others ( in order of viewing ) were the uninspired xxup the xxup black xxup xxunk ( 1964 ; from the same director as this one but not nearly as good ) , the surprisingly effective xxup lady xxmaj oscar ( 1979 ; which had xxunk",positive


In [6]:
# Baseline AWD_LSTM text classifier, created before any MixUp patching.
learn = text_classifier_learner(data, AWD_LSTM, drop_mult=0.5)

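What happens if we apply MixUp directly to text data? Let's find out by running fastai's original `MixUpCallback` by hand. (The debug prints that produced some of the outputs below have since been removed, and the cell already includes the float-lambda fix introduced further down.)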

In [128]:
class MixUpCallback(LearnerCallback):
    """Callback that creates the mixed-up input and target.

    alpha:   concentration of the Beta(alpha, alpha) distribution the per-sample weights are drawn from.
    stack_x: if True, send `[[x, x_shuffled, lambd]]` to the model and let it do the mixing
             (needed for token ids, which cannot be interpolated directly); otherwise interpolate x here.
    stack_y: if True, wrap the loss in `MixUpLoss` and send `(y, y_shuffled, lambd)` as the target;
             otherwise interpolate the targets here.
    """
    def __init__(self, learn:Learner, alpha:float=0.4, stack_x:bool=False, stack_y:bool=True):
        super().__init__(learn)
        self.alpha,self.stack_x,self.stack_y = alpha,stack_x,stack_y

    def on_train_begin(self, **kwargs):
        "Wrap the learner's loss in `MixUpLoss` so it can read stacked targets."
        if self.stack_y: self.learn.loss_func = MixUpLoss(self.learn.loss_func)

    def on_batch_begin(self, last_input, last_target, train, **kwargs):
        "Applies mixup to `last_input` and `last_target` if `train`."
        if not train: return
        n_samples = last_target.size(0)
        lambd = np.random.beta(self.alpha, self.alpha, n_samples)
        # Use max(lambd, 1-lambd) so the un-shuffled sample always gets the larger weight.
        lambd = np.concatenate([lambd[:,None], 1-lambd[:,None]], 1).max(1)
        # Build the weights directly as float32 on the input's device. `last_input.new(lambd)` would
        # inherit the input dtype (integer for token ids) and truncate every weight to 0, while
        # `last_input.float().new(lambd)` fixes that only by copying the whole batch first.
        lambd = torch.tensor(lambd, dtype=torch.float32, device=last_input.device)
        shuffle = torch.randperm(n_samples).to(last_input.device)
        x1, y1 = last_input[shuffle], last_target[shuffle]
        if self.stack_x:
            # Double brackets: the outer list is the batch's list of inputs, so the model is presumably
            # called with the inner [x, x_shuffled, lambd] list as its single argument -- confirm
            # against the training loop.
            new_input = [[last_input, x1, lambd]]
        else:
            out_shape = [lambd.size(0)] + [1 for _ in range(len(x1.shape) - 1)]
            new_input = (last_input * lambd.view(out_shape) + x1 * (1-lambd).view(out_shape))
        if self.stack_y:
            # Columns: original target, shuffled target, weight of the original target (read by MixUpLoss).
            new_target = torch.cat([last_target[:,None].float(), y1[:,None].float(), lambd[:,None].float()], 1)
        else:
            if len(last_target.shape) == 2:
                lambd = lambd.unsqueeze(1).float()
            new_target = last_target.float() * lambd + y1.float() * (1-lambd)
        return {'last_input': new_input, 'last_target': new_target}

    def on_train_end(self, **kwargs):
        "Restore the loss function that was wrapped in `on_train_begin`."
        if self.stack_y: self.learn.loss_func = self.learn.loss_func.get_old()
            
class MixUpLoss(Module):
    "Adapt the loss function `crit` to go with mixup."

    def __init__(self, crit, reduction='mean'):
        super().__init__()
        if not hasattr(crit, 'reduction'):
            # Functional loss: bind reduction='none' and keep the original to hand back in `get_old`.
            self.crit = partial(crit, reduction='none')
            self.old_crit = crit
        else:
            # Loss object exposing `reduction`: switch it to per-sample losses in place and
            # remember the previous setting so `get_old` can restore it.
            self.crit = crit
            self.old_red = crit.reduction
            self.crit.reduction = 'none'
        self.reduction = reduction

    def forward(self, output, target):
        # A 2-D target is the stacked (y, y_shuffled, lambd) produced by the mixup callback;
        # anything else is a plain target and goes straight to the wrapped criterion.
        if target.dim() == 2:
            weight = target[:,2]
            first = self.crit(output, target[:,0].long())
            second = self.crit(output, target[:,1].long())
            per_sample = first * weight + second * (1 - weight)
        else:
            per_sample = self.crit(output, target)
        if self.reduction == 'mean':
            return per_sample.mean()
        if self.reduction == 'sum':
            return per_sample.sum()
        return per_sample

    def get_old(self):
        "Return the criterion as it was before wrapping, undoing the `reduction` change if needed."
        if hasattr(self, 'old_crit'):
            return self.old_crit
        if hasattr(self, 'old_red'):
            self.crit.reduction = self.old_red
            return self.crit
        return None


In [110]:
import fastai  # NOTE(review): the name `fastai` is not referenced in the visible cells
# Instantiate the callback by hand (default stack_x=False) to see what it does to a text batch.
testMixup = MixUpCallback(learn)
# One (inputs, targets) batch; the inputs are integer token ids (see the tensor outputs below).
xb,yb = data.one_batch()

In [24]:
# Display the batch just drawn with `data.one_batch()` (a bare `()` here displayed nothing useful).
xb

tensor([[   2,    5, 2739,  ...,    0,   10,   24],
        [   1,    1,    1,  ...,  810,  104,   10],
        [   1,    1,    1,  ...,    0,   55,   10],
        [   1,    1,    1,  ...,  303,  163,   10]])

In [9]:
# Move the batch to the DataBunch's device instead of hard-coding `.cuda()`, so the cell also runs on CPU.
a = testMixup.on_batch_begin(xb.to(data.device), yb.to(data.device), train=True)

tensor([0, 0, 0, 0], device='cuda:0')
tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], device='cuda:0')


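So the λ values sampled from the Beta distribution end up as integers and are truncated to 0. `last_input.new(lambd)` creates a tensor with the same integer dtype as the token ids.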

In [10]:
# "Mixed" input returned by the callback with stack_x=False.
a['last_input']

tensor([[   2,    5, 2739,  ...,    0,   10,   24],
        [   1,    1,    1,  ...,  122,  169,   34],
        [   1,    1,    1,  ...,  810,  104,   10],
        [   1,    1,    1,  ...,   10,    6, 4810]], device='cuda:0')

In [11]:
# Original, un-mixed batch, for comparison with the mixed input above.
xb

tensor([[   2,    5, 2739,  ...,    0,   10,   24],
        [   1,    1,    1,  ...,  810,  104,   10],
        [   1,    1,    1,  ...,  122,  169,   34],
        [   1,    1,    1,  ...,   10,    6, 4810]])

With every λ equal to 0, the mixed input `λ·x + (1-λ)·x1` is simply `x1`: we get the same batch, only shuffled.

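## Idea: apply MixUp right after computing the embeddings

Token ids are discrete, so interpolating them is meaningless. Their embedding vectors, however, are continuous and can be mixed.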

In [168]:
def mixup(x, x1, lambd):
    """Per-sample convex combination ``lambd * x + (1 - lambd) * x1``.

    Args:
        x, x1: tensors of the same shape whose first dimension is the batch
            (in this notebook: embedded sequences, presumably (bs, seq_len, emb_sz)).
        lambd: 1-D tensor of length ``x.size(0)`` holding one mixing weight per sample.

    Returns:
        Tensor shaped like ``x``. ``lambd`` is broadcast over every non-batch dimension,
        so inputs of any rank >= 1 are supported (the previous ``lambd[:, None, None]``
        only worked for rank-3 inputs; for rank 3 the result is unchanged).
    """
    # (bs,) -> (bs, 1, ..., 1) with one trailing singleton per non-batch dimension of x.
    weight = lambd.reshape(-1, *([1] * (x.dim() - 1)))
    return weight * x + (1 - weight) * x1

In [63]:
# Same experiment, but with stack_x=True the callback returns the inputs and lambdas separately.
testMixup = MixUpCallback(learn,stack_x=True)
xb,yb = data.one_batch()

In [146]:
# Shape check for the `[:, None, None]` broadcasting used in `mixup`. `xb.new(...)` inherits xb's
# integer dtype, so the values themselves are truncated -- only the size matters here.
xb.new([0.1,0.2])[:,None,None].size()

torch.Size([2, 1, 1])

In [65]:
# Move the batch to the DataBunch's device instead of hard-coding `.cuda()`, so the cell also runs on CPU.
a = testMixup.on_batch_begin(xb.to(data.device), yb.to(data.device), train=True)

In [66]:
# With stack_x=True the "input" is a list holding x, its shuffled copy and the lambdas.
a['last_input']

[tensor([[   2,    5, 2739,  ...,    0,   10,   24],
         [   1,    1,    1,  ...,  121,   58,   10],
         [   1,    1,    1,  ...,   15,  126,   34],
         [   1,    1,    1,  ...,   12,  566,   10]], device='cuda:0'),
 tensor([[   2,    5, 2739,  ...,    0,   10,   24],
         [   1,    1,    1,  ...,   15,  126,   34],
         [   1,    1,    1,  ...,   12,  566,   10],
         [   1,    1,    1,  ...,  121,   58,   10]], device='cuda:0'),
 tensor([0, 0, 0, 0], device='cuda:0')]

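Even with `stack_x=True` we hit the same problem: the lambdas are still created as integers (all zeros, see the last tensor above). Let's fix the callback by converting the input to float before building the lambda tensor, i.e. by changing `lambd = last_input.new(lambd)` to `lambd = last_input.float().new(lambd)`.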

In [117]:
# Re-run the callback with float lambdas; move the batch to the DataBunch's device rather than
# hard-coding `.cuda()`, so the cell also runs on CPU.
testMixup = MixUpCallback(learn,stack_x=True)
a = testMixup.on_batch_begin(xb.to(data.device), yb.to(data.device), train=True)
a['last_input']

[[tensor([[   2,    5, 2739,  ...,    0,   10,   24],
          [   1,    1,    1,  ...,  810,  104,   10],
          [   1,    1,    1,  ...,   15,  126,   34],
          [   1,    1,    1,  ...,   96,   20,   10]], device='cuda:0'),
  tensor([[   2,    5, 2739,  ...,    0,   10,   24],
          [   1,    1,    1,  ...,  810,  104,   10],
          [   1,    1,    1,  ...,   15,  126,   34],
          [   1,    1,    1,  ...,   96,   20,   10]], device='cuda:0'),
  tensor([0.9986, 0.9945, 0.7907, 0.9848], device='cuda:0')]]

Next, let's patch the `AWD_LSTM` forward pass. When it receives a list `[x, x_shuffled, lambd]`, it should embed both token tensors and mix the embeddings before the RNN layers.

In [199]:
def forward(self, input:Tensor, from_embeddings:bool=False)->Tuple[List[Tensor],List[Tensor]]:
    """Patched `AWD_LSTM.forward` that can apply MixUp to the embeddings.

    `input` is one of:
      - a (bs, sl) tensor of token ids (standard path);
      - a (bs, sl, es) tensor of embeddings when `from_embeddings=True`;
      - a list `[x, x_shuffled, lambd]` (from `MixUpCallback` with `stack_x=True`): both token
        tensors go through `self.encoder_dp` and the results are mixed per sample by `mixup`.

    Returns `(raw_outputs, outputs)`, one entry per RNN layer: each layer's output before and
    after `hid_dp` (the last layer gets no `hid_dp`).
    """
    use_mixup = not from_embeddings and isinstance(input, list)
    if from_embeddings: bs,sl,_ = input.size()
    elif use_mixup: bs,sl = input[0].size()
    else: bs,sl = input.size()
    if bs!=self.bs:
        # Batch size changed: record it and re-initialise state via `self.reset()`.
        self.bs=bs
        self.reset()
    if use_mixup:
        x, x_shuffled, lambd = input
        mixed = mixup(self.encoder_dp(x), self.encoder_dp(x_shuffled), lambd)
        # Apply the same `input_dp` as the standard path below; the first version of this patch
        # skipped it, so MixUp batches bypassed that dropout layer.
        raw_output = self.input_dp(mixed)
    else:
        raw_output = self.input_dp(input if from_embeddings else self.encoder_dp(input))
    new_hidden,raw_outputs,outputs = [],[],[]
    for l, (rnn,hid_dp) in enumerate(zip(self.rnns, self.hidden_dps)):
        raw_output, new_h = rnn(raw_output, self.hidden[l])
        new_hidden.append(new_h)
        raw_outputs.append(raw_output)
        if l != self.n_layers - 1: raw_output = hid_dp(raw_output)
        outputs.append(raw_output)
    self.hidden = to_detach(new_hidden, cpu=False)
    return raw_outputs, outputs
AWD_LSTM.forward = forward

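Then patch `MultiBatchEncoder.forward` so that it also accepts such lists. It slices both token tensors into `bptt`-sized chunks before passing them on to `AWD_LSTM`.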

In [206]:
def forward(self, input:LongTensor,from_embeddings:bool=False)->Tuple[List[Tensor],List[Tensor],Tensor]:
    """Patched `MultiBatchEncoder.forward` that also accepts MixUp lists.

    `input` is either a (bs, sl) tensor of token ids or a list `[x, x_shuffled, lambd]` as built
    by `MixUpCallback(stack_x=True)`. The sequence is fed to `self.module` in chunks of
    `self.bptt` tokens; outputs are kept only for chunks starting after `sl - self.max_len`.
    `from_embeddings` is accepted for signature compatibility but is not used.

    Returns `(raw_outputs, outputs, mask)`, where `mask` is True at padding positions
    (`== self.pad_idx`) of the kept chunks.
    """
    use_mixup = isinstance(input, list)
    tokens = input[0] if use_mixup else input
    _, sl = tokens.size()
    self.reset()
    raw_outputs,outputs,masks = [],[],[]
    for i in range(0, sl, self.bptt):
        chunk = slice(i, min(i+self.bptt, sl))
        if use_mixup:
            r, o = self.module([input[0][:, chunk], input[1][:, chunk], input[2]])
        else:
            r, o = self.module(input[:, chunk])
        if i>(sl-self.max_len):
            pad = tokens[:, chunk] == self.pad_idx
            if use_mixup:
                # The mixed embedding still carries (1-lambd) of the shuffled sequence wherever that
                # sequence has a real token, and its label is mixed into the target, so only mask
                # positions that are padding in BOTH sequences. Masking by input[0] alone hid the
                # shuffled sequence's content from whatever consumes the mask (presumably pooling).
                pad = pad & (input[1][:, chunk] == self.pad_idx)
            masks.append(pad)
            raw_outputs.append(r)
            outputs.append(o)
    return self.concat(raw_outputs),self.concat(outputs),torch.cat(masks,dim=1)
MultiBatchEncoder.forward = forward

In [207]:
# Fresh classifier (the class-level forward patches above apply to it) with the MixUp callback:
# stack_x=True sends [x, x_shuffled, lambd] lists to the model, and stack_y=True wraps the loss in MixUpLoss.
learn = text_classifier_learner(data, AWD_LSTM, drop_mult=0.5)
learn.callback_fns.append(partial(MixUpCallback, alpha=0.4, stack_x=True, stack_y=True))

In [208]:
# One epoch with embedding-level MixUp.
learn.fit(1)

epoch,train_loss,valid_loss,accuracy,time
0,0.715994,0.740792,0.475,00:44
