In [1]:
%load_ext autoreload
%autoreload 2
from fastai2.basics import *
from fastai2.text.all import *
torch.cuda.set_device(0)

In [2]:
bs = 96
wd = 1e-1
moms = (0.95, 0.85, 0.95)#(0.8,0.7,0.8)#
seq_len = 72

In [3]:
lang = 'he'
data_path = Config.config_path/'data'
name = f'{lang}wiki'
tok = 'SP_fwd_qrnn_v2'
text_path = data_path/name
path = Path(f'{data_path}/{name}_{tok}')
data_format = 'token' #morph or token
path.mkdir(exist_ok=True, parents=True)
lm_fns = [f'{lang}_wt_{tok}', f'{lang}_wt_vocab_{tok}']

In [4]:
class LabelSmoothingCrossEntropyFlat(BaseLoss):
    y_int = True
    def __init__(self, *args, axis=-1, **kwargs): 
        super().__init__(LabelSmoothingCrossEntropy, *args, axis=axis, **kwargs)
    def decodes(self, x):    return x.argmax(dim=self.axis)
    def activation(self, x): return F.softmax(x, dim=self.axis)
    
def backwards(tokens): return tokens.flip(0)

In [5]:
counter = pickle.load(open(str(path/'counter.pkl'), 'rb'))
vocab = make_vocab(counter)
len(vocab)

31136

In [6]:
train_df = pd.read_csv(f'../{data_format}_train.tsv', sep='\t', header=None, names=['comment', 'label'])
train_df['is_valid'] = False
test_df = pd.read_csv(f'../{data_format}_test.tsv', sep='\t', header=None, names=['comment', 'label'])
test_df['is_valid'] = True

df = pd.concat([train_df,test_df], sort=False).rename(columns={'comment': 'text'})

In [7]:
bs = 64
splits = ColSplitter()(df)
x_tfms = [attrgetter('text'),
        Tokenizer.from_folder(path, SentencePieceTokenizer, output_dir=path, sp_model='tmp/spm.model'),
        Numericalize(vocab)]
dsrc = Datasets(df, tfms=[x_tfms, [attrgetter('label'), Categorize()]], splits=splits, dl_type=SortedDL)
cls_data = dsrc.dataloaders(bs=bs, before_batch=pad_input_chunk, seq_len=seq_len)

In [8]:
drop = 0.
wd = 0.1
pretrained = False
loss_func = LabelSmoothingCrossEntropyFlat()#CrossEntropyLossFlat()#
learn_c = text_classifier_learner(cls_data, AWD_QRNN, 
                                  metrics=[accuracy, BalancedAccuracy()],
                                  path=path,
                                  loss_func=loss_func,
                                  drop_mult=drop, pretrained=pretrained, wd=wd)
learn_c = learn_c.load(f'{lang}clas_{tok}_{data_format}_traintest')

In [9]:
bs = 32
splits = ColSplitter()(df)
x_tfms = [attrgetter('text'),
         Tokenizer.from_folder(path, SentencePieceTokenizer, output_dir=path, sp_model='tmp/spm.model'),
         Numericalize(vocab),
         backwards]
dsrc = Datasets(df, tfms=[x_tfms, [attrgetter('label'), Categorize()]], splits=splits, dl_type=SortedDL)
cls_data_bwd = dsrc.dataloaders(bs=bs, before_batch=pad_input_chunk, seq_len=seq_len)

In [10]:
drop = 0.5
wd = 0.1
pretrained = False
loss_func = LabelSmoothingCrossEntropyFlat()#CrossEntropyLossFlat()#
cbs = SaveModelCallback('accuracy')
learn_c_bwd = text_classifier_learner(cls_data_bwd, AWD_QRNN, 
                                  metrics=[accuracy, BalancedAccuracy()],
                                  path=path,
                                  loss_func=loss_func,
                                  drop_mult=drop, pretrained=pretrained, wd=wd)
learn_c_bwd = learn_c_bwd.load(f'{lang}clas_bwd_{tok}_{data_format}_traintest')

In [11]:
bal_acc = BalancedAccuracy().func
f1 = F1Score().func

In [12]:
preds1, targets1 = learn_c.get_preds()
pred_label1 = preds1.argmax(dim=-1)
accuracy(preds1, targets1), bal_acc(pred_label1, targets1), f1(pred_label1, targets1, average='macro')

(TensorCategory(0.9441), 0.9327320268903039, 0.8614363124167047)

In [13]:
preds2, targets2 = learn_c_bwd.get_preds()
pred_label2 = preds2.argmax(dim=-1)
accuracy(preds2, targets2), bal_acc(pred_label2, targets2), f1(pred_label2, targets2, average='macro')

(TensorCategory(0.9434), 0.8829240754394382, 0.8554880598793048)

In [14]:
all(targets1 == targets2)

True

In [15]:
avg_preds = (preds1 + preds2) / 2
avg_pred_label = avg_preds.argmax(dim=-1)
accuracy(avg_preds, targets2), bal_acc(avg_pred_label, targets2), f1(avg_pred_label, targets2, average='macro')

(TensorCategory(0.9422), 0.8826870153467797, 0.8510464507380527)