In [6]:
 
# !pip install -U wwf

In [7]:
# ! pip install fastai==2.3.1
# ! pip install fastcore==1.3.19
# ! pip install transformers==4.6.0
# ! pip install datasets==1.6.1

In [2]:
#all_slow

# Text Classification with Transformers (Intermediate)

> Fine-tuning a pre-trained language model from the HuggingFace model hub on the GLUE benchmark

In [8]:
#hide
from wwf.utils import *

In [89]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from fastai.text.all import *
from datasets import Dataset, DatasetDict
import pandas as pd
from sklearn.model_selection import train_test_split
from datasets import load_dataset, concatenate_datasets
from inspect import signature
import gc

In [10]:
#hide_input
state_versions(['fastai', 'fastcore', 'transformers', 'datasets'])


---
This article is also a Jupyter Notebook available to be run from the top down. There
will be code snippets that you can then run in any environment.

Below are the versions of `fastai`, `fastcore`, `transformers`, and `datasets` currently running at the time of writing this:
* `fastai` : 2.3.1 
* `fastcore` : 1.3.19 
* `transformers` : 4.6.0 
* `datasets` : 1.6.1 
---

## Setup

In this notebook we will look at how to combine the power of [HuggingFace](https://huggingface.co/) with the great flexibility of [fastai](https://www.fast.ai/). For this purpose we will be fine-tuning `distilroberta-base` on tasks from the [General Language Understanding Evaluation (GLUE) benchmark](https://gluebenchmark.com/).

To give you a sense of what we are dealing with, here is a brief summary of the GLUE tasks:

In [11]:
#hide_input
abbreviations = ["cola","sst2","mrpc","stsb","qqp","mnli","qnli","rte","wnli"]
name = [
    "Corpus of Linguistic Acceptability",
    "Stanford Sentiment Treebank",
    "Microsoft Research Paraphrase Corpus",
    "Semantic Textual Similarity Benchmark",
    "Quora Question Pairs",
    "Multi-Genre Natural Language Inference",
    "Question NLI (from the Stanford Question Answering Dataset)",
    "Recognizing Textual Entailment",
    "Winograd Schema Challenge"
]
descriptions = [
    "Determine whether it is a grammatical sentence",
    "Predict the sentiment of a given sentence",
    "Determine whether the sentences in the pair are semantically equivalent",
    "Determine similarity score for 2 sentences",
    "Determine if 2 questions are the same (paraphrase)",
    "Predict whether the premise entails, contradicts or is neutral to the hypothesis",
    "Determine whether the context sentence contains the answer to the question",
    "Determine whether one sentence entails another",
    "Predict if the sentence with the pronoun substituted is entailed by the original sentence"
]
df = pd.DataFrame({'Name':name,
                   'Task description':descriptions,
                   'Size':['8.5k','67k','3.7k','7k','364k','393k','105k','2.5k', '634'],
                   'Metrics':['matthews_corrcoef','accuracy','f1/accuracy','pearsonr/spearmanr',
                              'f1/accuracy','accuracy','accuracy','accuracy','accuracy']},
                   index=abbreviations)
display_df(df)

| | Name | Task description | Size | Metrics |
|---|---|---|---|---|
| cola | Corpus of Linguistic Acceptability | Determine whether it is a grammatical sentence | 8.5k | matthews_corrcoef |
| sst2 | Stanford Sentiment Treebank | Predict the sentiment of a given sentence | 67k | accuracy |
| mrpc | Microsoft Research Paraphrase Corpus | Determine whether the sentences in the pair are semantically equivalent | 3.7k | f1/accuracy |
| stsb | Semantic Textual Similarity Benchmark | Determine similarity score for 2 sentences | 7k | pearsonr/spearmanr |
| qqp | Quora Question Pairs | Determine if 2 questions are the same (paraphrase) | 364k | f1/accuracy |
| mnli | Multi-Genre Natural Language Inference | Predict whether the premise entails, contradicts or is neutral to the hypothesis | 393k | accuracy |
| qnli | Question NLI (from the Stanford Question Answering Dataset) | Determine whether the context sentence contains the answer to the question | 105k | accuracy |
| rte | Recognizing Textual Entailment | Determine whether one sentence entails another | 2.5k | accuracy |
| wnli | Winograd Schema Challenge | Predict if the sentence with the pronoun substituted is entailed by the original sentence | 634 | accuracy |


Let's define the main settings for the run in one place. You can choose any model from the wide variety available on the HuggingFace model hub. Some might need special treatment to work, but most models of the appropriate class should be plug-and-play.

In [12]:
ds_name = 'glue'
model_name = "distilroberta-base"

max_len = 512
bs = 32
val_bs = bs*2

n_epoch = 4
lr = 2e-5
opt_func = Adam

To make switching between datasets smooth, I'll define a couple of dictionaries containing per-task information. We'll need the metrics, the text fields used to retrieve the data, and the number of outputs for the model.

In [13]:
GLUE_TASKS = ["cola", "mnli", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb", "wnli"]
def validate_task():
    assert task in GLUE_TASKS

In [14]:
glue_metrics = {
    'cola':[MatthewsCorrCoef()],
    'sst2':[accuracy],
    'mrpc':[F1Score(), accuracy],
    'stsb':[PearsonCorrCoef(), SpearmanCorrCoef()],
    'qqp' :[F1Score(), accuracy],
    'mnli':[accuracy],
    'qnli':[accuracy],
    'rte' :[accuracy],
    'wnli':[accuracy],
}

glue_textfields = {
    'cola':['sentence', None],
    'sst2':['sentence', None],
    'mrpc':['sentence1', 'sentence2'],
    'stsb':['sentence1', 'sentence2'],
    'qqp' :['question1', 'question2'],
    'mnli':['premise', 'hypothesis'],
    'qnli':['question', 'sentence'],
    'rte' :['sentence1', 'sentence2'],
    'wnli':['sentence1', 'sentence2'],
}

glue_num_labels = {'mnli':3, 'stsb':1}
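A small framework-free sketch of how these three dictionaries are read together for one task. The helper `get_task_config` is hypothetical (the notebook does these lookups inline), the metric objects are replaced by the names from the Metrics column of the task table so the snippet runs without fastai, and the default of 2 outputs mirrors `glue_num_labels.get(task, 2)` used when the model is created.

```python
GLUE_TASKS = ["cola", "mnli", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb", "wnli"]

# Metric names (from the task table) stand in for the fastai metric objects.
glue_metric_names = {
    'cola': ['matthews_corrcoef'], 'sst2': ['accuracy'],
    'mrpc': ['f1', 'accuracy'],    'stsb': ['pearsonr', 'spearmanr'],
    'qqp':  ['f1', 'accuracy'],    'mnli': ['accuracy'],
    'qnli': ['accuracy'],          'rte':  ['accuracy'], 'wnli': ['accuracy'],
}
glue_textfields = {
    'cola': ['sentence', None],         'sst2': ['sentence', None],
    'mrpc': ['sentence1', 'sentence2'], 'stsb': ['sentence1', 'sentence2'],
    'qqp':  ['question1', 'question2'], 'mnli': ['premise', 'hypothesis'],
    'qnli': ['question', 'sentence'],   'rte':  ['sentence1', 'sentence2'],
    'wnli': ['sentence1', 'sentence2'],
}
glue_num_labels = {'mnli': 3, 'stsb': 1}   # all other tasks use 2 outputs

def get_task_config(task):
    """Hypothetical helper: collect the per-task settings in one dict."""
    if task not in GLUE_TASKS:
        raise ValueError(f"Unknown GLUE task: {task!r}")
    return {
        'metrics': glue_metric_names[task],
        'textfields': [f for f in glue_textfields[task] if f is not None],
        'num_labels': glue_num_labels.get(task, 2),   # default: 2 classes
    }

print(get_task_config('mrpc'))
# {'metrics': ['f1', 'accuracy'], 'textfields': ['sentence1', 'sentence2'], 'num_labels': 2}
```

Raising a `ValueError` instead of using a bare `assert` (as `validate_task` does) keeps the check active even when Python runs with `-O`, which strips assert statements.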

## Data preprocessing

We'll use the HuggingFace `datasets` library to get the data:

In [15]:
task = 'mrpc'; validate_task()
ds = load_dataset(ds_name, task)


Downloading and preparing dataset glue/mrpc (download: 1.43 MiB, generated: 1.43 MiB, post-processed: Unknown size, total: 2.85 MiB) to /root/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad...


Dataset glue downloaded and prepared to /root/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad. Subsequent calls will reuse this data.


In [40]:
ds

DatasetDict({
    train: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 3668
    })
    validation: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 408
    })
    test: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 1725
    })
})

In [102]:
# DatasetDict(
#     {
#     'train' : Dataset.from_pandas(train),
#     'test' : Dataset.from_pandas(test),
#     'valid' : Dataset.from_pandas(valid)
#     }
# )

DatasetDict({
    train: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 1440
    })
    test: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 400
    })
    valid: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 160
    })
})

In [52]:
ds['test'][0]

{'idx': 0,
 'label': 1,
 'sentence1': "PCCW 's chief operating officer , Mike Butcher , and Alex Arena , the chief financial officer , will report directly to Mr So .",
 'sentence2': 'Current Chief Operating Officer Mike Butcher and Group Chief Financial Officer Alex Arena will report to So .'}

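The rest of this notebook fine-tunes on a custom dataset instead of MRPC itself: pairs of news headlines loaded from an Excel file. The cells below draw 1000 rows per label, drop the full-story columns and rename the title columns to `sentence1` and `sentence2`, so the data matches MRPC's format and the `mrpc` settings (text fields and metrics) can be reused.

In [73]:
upload_data = pd.read_excel('/content/sampled_combined_news_aggr_dataset.xlsx')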


In [74]:
upload_data = upload_data.groupby('label').sample(1000)

In [75]:
upload_data.drop(columns=['STORY_1', 'STORY_2'], inplace=True)

In [76]:
upload_data.rename(columns = {'TITLE_1': 'sentence1','TITLE_2': 'sentence2'}, inplace=True)

In [77]:
upload_data.head()

| | sentence1 | sentence2 | label |
|---|---|---|---|
| 2901 | US stocks open lower on gloomy news from Asia | Americans are taking public transit in record numbers | 0 |
| 21879 | The 10 Best Stocks of the Bull Market | US stock market trading lower as investors weigh discouraging news on ... | 0 |
| 28283 | Inside Mt Gox | JetBlue Airways Corp. Receives Average Rating of “Hold” from Analysts ... | 0 |
| 59634 | Report: Mt. Gox CEO Holding 'Stolen' Bitcoins | Oil falls below US$108 on China data | 0 |
| 30957 | Chiquita, Fyffes Agree to $1.07 Billion Merger Deal -- 3rd Update | Americans make most journeys on public transport for 50 years | 0 |


In [96]:
train, test = train_test_split(upload_data, test_size=0.2)
train, valid = train_test_split(train, test_size=0.1)
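The split sizes printed further down (1440 training, 160 validation and 400 test rows) follow from these two calls applied to the 2000 sampled rows (1000 per label). As far as I know, scikit-learn rounds a fractional `test_size` up and gives the remainder to the training part when `train_size` is not set; the sketch below reproduces that arithmetic in plain Python, with `split_sizes` as a hypothetical helper.

```python
from math import ceil

def split_sizes(n_samples, test_size):
    """Hypothetical helper: (n_train, n_test) for a fractional test_size,
    assuming the test part is rounded up and training gets the remainder."""
    n_test = ceil(test_size * n_samples)
    return n_samples - n_test, n_test

n_rows = 2 * 1000                                   # 1000 rows for each of the 2 labels
n_trainval, n_test = split_sizes(n_rows, 0.2)       # first train_test_split call
n_train, n_valid = split_sizes(n_trainval, 0.1)     # second call, on the training part
print(n_train, n_valid, n_test)                     # 1440 160 400
```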

In [97]:
train.reset_index(drop=True, inplace=True)
test.reset_index(drop=True, inplace=True)
valid.reset_index(drop=True, inplace=True)

In [98]:
train['idx'] = train.index
valid['idx'] = valid.index
test['idx'] = test.index

In [99]:
train.head()

| | sentence1 | sentence2 | label | idx |
|---|---|---|---|---|
| 0 | Report: Mt. Gox was hit 150K times per second during DDoS attack before theft | American Airlines, JetBlue Ending Itinerary Extension Agreement | 0 | 0 |
| 1 | Markets: Ireland's Fyffes to merge with Chiquita of original 'Banana Republic' fame | Chiquita and Fyffes to create world's leading banana company | 1 | 1 |
| 2 | Sterling at one-month low as ECB policy stance lifts euro | American Airlines, JetBlue end reciprocal frequent-flier deal | 0 | 2 |
| 3 | McDonald's Reports a 0.3 Percent Drop in February Sales, U.S. Down 1.4 percent | ECB's Christian Noyer says 'not happy' with euro's rise | 0 | 3 |
| 4 | Public Transit Use In U.S. Is At a 57-Year High, Report Finds | Banking on Bitcoins | 0 | 4 |


In [100]:
test.head()

| | sentence1 | sentence2 | label | idx |
|---|---|---|---|---|
| 0 | Hackers Claim Mt. Gox Still Sitting On Bitcoins | So Where Are Mt Gox's Stolen Bitcoin Millions? | 1 | 0 |
| 1 | Chiquita Brands To Buy Irish Rival | The Big Banana merger might be more about melons, pineapples and Irish taxes ... | 1 | 1 |
| 2 | Hackers hit web accounts of MtGox boss | RPT-Mt. Gox files for US bankruptcy protection | 1 | 2 |
| 3 | German reliance on Russian gas is a threat | Sbarro Files For Second Bankruptcy Protection | 0 | 3 |
| 4 | Australian Bitcoin traders burnt in Mt.Gox crash | MTA Worker Killed by Hudson Line Train | 0 | 4 |


In [101]:
valid.head()

| | sentence1 | sentence2 | label | idx |
|---|---|---|---|---|
| 0 | Stock futures fall on discouraging news from Asia | Ukraine Rattles Europe's Still-Fragile Recovery | 0 | 0 |
| 1 | Metro-North Worker Killed, Struck By Harlem Line Train | MtGox faced 150k DDoS attacks per second says report | 0 | 1 |
| 2 | ECB's Noyer not Happy With Euro Strength -- Update | Capitalism Will Prevent a Cold War Over Ukraine | 0 | 2 |
| 3 | Monday Afternoon Business Brief | Chrysler Recalls 25000 Vehicles, Loosening Russia's Energy Grip, Fyffes to ... | 1 | 3 |
| 4 | Reuters: Germany's dependence on Russian gas poses risks for Europe, says ... | Central Europeans want US gas to cut dependence on Russia | 1 | 4 |


In [107]:
ds = DatasetDict(
    {
    'train' : Dataset.from_pandas(train),
    'test' : Dataset.from_pandas(test),
    'validation' : Dataset.from_pandas(valid)
    }
)

The MNLI dataset contains two validation sets: matched and mismatched. The matched set is selected here for validation when fine-tuning on MNLI.

In [108]:
valid_ = 'validation_matched' if task=='mnli' else 'validation'
len(ds['train']), len(ds[valid_])

(1440, 160)

In [109]:
nt, nv = len(ds['train']), len(ds[valid_])
train_idx, valid_idx = L(range(nt)), L(range(nt, nt+nv))
train_ds = concatenate_datasets([ds['train'], ds[valid_]])
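The cell above stacks the training and validation splits into one dataset and records which positions belong to each, so that `IndexSplitter(valid_idx)` can separate them again inside the `DataBlock`. A plain-Python sketch of that bookkeeping, with lists standing in for `Dataset` objects and made-up records:

```python
# Plain lists stand in for the HuggingFace Dataset objects (made-up records).
train_split = [{'idx': i, 'split': 'train'} for i in range(5)]
valid_split = [{'idx': i, 'split': 'valid'} for i in range(3)]

nt, nv = len(train_split), len(valid_split)
train_idx = list(range(nt))               # positions 0 .. nt-1
valid_idx = list(range(nt, nt + nv))      # positions nt .. nt+nv-1
combined = train_split + valid_split      # plays the role of concatenate_datasets

# Selecting by position recovers each split exactly.
assert [combined[i] for i in train_idx] == train_split
assert [combined[i] for i in valid_idx] == valid_split

print(train_idx, valid_idx)                      # [0, 1, 2, 3, 4] [5, 6, 7]
# 'idx' restarts at 0 in each split, so it cannot identify validation rows.
print([combined[i]['idx'] for i in valid_idx])   # [0, 1, 2]
```

In this notebook the `idx` column is also rebuilt from each split's own index, so it is the positional `valid_idx`, not `idx`, that marks the validation rows.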

We can inspect a single example for the given task:

In [110]:
train_ds[0]

{'sentence1': 'Report: Mt. Gox was hit 150K times per second during DDoS attack before theft',
 'sentence2': 'American Airlines, JetBlue Ending Itinerary Extension Agreement',
 'label': 0,
 'idx': 0}

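Here I use the number of characters as a proxy for the length of the tokenized text to speed up `dls` creation: `SortedDL` needs a length for every sample, and counting characters avoids tokenizing the whole dataset up front.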

In [111]:
#hide_output
lens = train_ds.map(lambda s: {'len': sum([len(s[i]) for i in glue_textfields[task] if i])},
                    remove_columns=train_ds.column_names, num_proc=2, keep_in_memory=True)
train_lens = lens.select(train_idx)['len']
valid_lens = lens.select(valid_idx)['len']
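Per sample, the `map` above adds up the character counts of the task's text fields, skipping the `None` placeholder that single-sentence tasks such as CoLA and SST-2 use for the second field. The same rule as a standalone function (`char_len` is a hypothetical name), applied to two headlines from the data and one made-up sentence:

```python
def char_len(sample, textfields):
    """Hypothetical helper: total characters over the non-None text fields."""
    return sum(len(sample[f]) for f in textfields if f)

pair = {'sentence1': 'Banking on Bitcoins', 'sentence2': 'Inside Mt Gox', 'label': 0}
single = {'sentence': 'The cat sat.', 'label': 1}   # made-up single-sentence sample

print(char_len(pair, ['sentence1', 'sentence2']))   # 32
print(char_len(single, ['sentence', None]))         # 12
```

Character counts only approximate subword token counts, but for ordering samples by length a rough proxy is usually sufficient.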





## DataBlock and Transforms

`TextGetter` is analogous to `ItemGetter` but retrieves either one or two text fields from the source (e.g. "sentence1" and "sentence2").

In [112]:
class TextGetter(ItemTransform):
    def __init__(self, s1='text', s2=None):
        self.s1, self.s2 = s1, s2
    def encodes(self, sample):
        if self.s2 is None: return sample[self.s1]
        else: return sample[self.s1], sample[self.s2]
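Outside of fastai, the behaviour of `TextGetter` boils down to returning either one field or a 2-tuple. A minimal plain-Python sketch; `PlainTextGetter` is a hypothetical name and lacks the `ItemTransform` machinery (pipelines, decoding):

```python
class PlainTextGetter:
    """Hypothetical plain-Python analogue of TextGetter (no fastai transform logic)."""
    def __init__(self, s1='text', s2=None):
        self.s1, self.s2 = s1, s2

    def __call__(self, sample):
        if self.s2 is None:                       # single-text task
            return sample[self.s1]
        return sample[self.s1], sample[self.s2]   # text-pair task

sample = {'sentence1': 'Inside Mt Gox', 'sentence2': 'Banking on Bitcoins', 'label': 0}
print(PlainTextGetter('sentence1', 'sentence2')(sample))  # ('Inside Mt Gox', 'Banking on Bitcoins')
print(PlainTextGetter('sentence1')(sample))               # Inside Mt Gox
```

In the notebook the getter is built as `TextGetter(*glue_textfields[task])`, so for single-sentence tasks the `None` entry lands in `s2` and the single-field branch is taken.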

Transformer models expect the two parts of a text pair to be concatenated, with a special separator (`SEP`) token in between. When displaying a batch, however, it's more readable to show the two texts in separate columns. To make this work, I define a version of `show_batch` dispatched on the `TransTensorText` class; it handles both a single decoded text and a tuple of two texts.

In [113]:
class TransTensorText(TensorBase): pass

@typedispatch
def show_batch(x:TransTensorText, y, samples, ctxs=None, max_n=10, trunc_at=150, **kwargs):
    if ctxs is None: ctxs = get_empty_df(min(len(samples), max_n))
    if isinstance(samples[0][0], tuple):
        samples = L((*s[0], *s[1:]) for s in samples)
        if trunc_at is not None: samples = L((s[0].truncate(trunc_at), s[1].truncate(trunc_at), *s[2:]) for s in samples)
    elif trunc_at is not None: samples = L((s[0].truncate(trunc_at),*s[1:]) for s in samples)
    ctxs = show_batch[object](x, y, samples, max_n=max_n, ctxs=ctxs, **kwargs)
    display_df(pd.DataFrame(ctxs))

In [114]:
#collapse
def find_first(t, e):
    for i, v in enumerate(t):
        if v == e: return i
        
def split_by_sep(t, sep_tok_id):
    idx = find_first(t, sep_tok_id)
    return t[:idx], t[idx:]
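To see what `split_by_sep` does to a pair encoding, here is a run on made-up token ids. The ids are assumptions chosen for illustration: 0 for `<s>`, 2 for `</s>` (used as the separator) and 1 for padding, laid out RoBERTa-style as `<s> A </s></s> B </s>`. The split happens at the first separator, so the second part still starts with separator tokens; `decode(..., skip_special_tokens=True)` later drops those along with the padding. If no separator is present, `find_first` returns `None` and both slices cover the whole sequence.

```python
def find_first(t, e):
    """Index of the first element of t equal to e, or None if absent."""
    for i, v in enumerate(t):
        if v == e: return i

def split_by_sep(t, sep_tok_id):
    idx = find_first(t, sep_tok_id)
    return t[:idx], t[idx:]

# Hypothetical ids: <s>=0, </s>=2, <pad>=1; 11, 12 = text A; 21, 22, 23 = text B.
ids = [0, 11, 12, 2, 2, 21, 22, 23, 2, 1, 1]
first, second = split_by_sep(ids, sep_tok_id=2)
print(first)    # [0, 11, 12]
print(second)   # [2, 2, 21, 22, 23, 2, 1, 1]

# Without a separator, idx is None and both halves are the full sequence.
print(split_by_sep([0, 11, 12], sep_tok_id=2))  # ([0, 11, 12], [0, 11, 12])
```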

Tokenization of the inputs is done by `TokBatchTransform`, which wraps a pre-trained HuggingFace tokenizer. The text processing is done in batches to speed it up; we want to avoid explicit Python loops over individual samples where possible.

In [115]:
class TokBatchTransform(Transform):
    """
    Tokenizes texts in batches using pretrained HuggingFace tokenizer.
    The first element in a batch can be single string or 2-tuple of strings.
    If `with_labels=True` the "labels" are added to the output dictionary.
    """
    def __init__(self, pretrained_model_name=None, tokenizer_cls=AutoTokenizer, 
                 config=None, tokenizer=None, with_labels=False,
                 padding=True, truncation=True, max_length=None, **kwargs):
        if tokenizer is None:
            tokenizer = tokenizer_cls.from_pretrained(pretrained_model_name, config=config)
        self.tokenizer = tokenizer
        self.kwargs = kwargs
        self._two_texts = False
        store_attr()
    
    def encodes(self, batch):
        # batch is a list of tuples of ({text or (text1, text2)}, {targets...})
        if is_listy(batch[0][0]): # 1st element is tuple
            self._two_texts = True
            texts = ([s[0][0] for s in batch], [s[0][1] for s in batch])
        elif is_listy(batch[0]): 
            texts = ([s[0] for s in batch],)

        inps = self.tokenizer(*texts,
                              add_special_tokens=True,
                              padding=self.padding,
                              truncation=self.truncation,
                              max_length=self.max_length,
                              return_tensors='pt',
                              **self.kwargs)
        # inps are batched, collate targets into batches too
        labels = default_collate([s[1:] for s in batch])
        if self.with_labels:
            inps['labels'] = labels[0]
            res = (inps, )
        else:
            res = (inps, ) + tuple(labels)
        return res
    
    def decodes(self, x:TransTensorText):
        if self._two_texts:
            x1, x2 = split_by_sep(x, self.tokenizer.sep_token_id)
            return (TitledStr(self.tokenizer.decode(x1.cpu(), skip_special_tokens=True)),
                    TitledStr(self.tokenizer.decode(x2.cpu(), skip_special_tokens=True)))
        return TitledStr(self.tokenizer.decode(x.cpu(), skip_special_tokens=True))
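Before calling the tokenizer, `encodes` regroups the list of samples into parallel lists, because a HuggingFace tokenizer takes a batch of pairs as two lists (first texts, second texts). The sketch below isolates that regrouping in plain Python; `split_batch` is a hypothetical name, targets are simply transposed rather than collated into tensors with `default_collate`, and no tokenizer is involved.

```python
def split_batch(batch):
    """Hypothetical helper: split [(text_or_pair, target, ...), ...] into (texts, targets).

    texts is (first_texts,) for single texts or (first_texts, second_texts)
    for pairs, ready to be unpacked as tokenizer(*texts, ...).
    targets holds one list per target position (a stand-in for default_collate).
    """
    if isinstance(batch[0][0], (list, tuple)):      # each input is a text pair
        texts = ([s[0][0] for s in batch], [s[0][1] for s in batch])
    else:                                           # each input is a single text
        texts = ([s[0] for s in batch],)
    targets = [list(col) for col in zip(*(s[1:] for s in batch))]
    return texts, targets

pair_batch = [(('Inside Mt Gox', 'Banking on Bitcoins'), 0),
              (('Hackers Claim Mt. Gox Still Sitting On Bitcoins',
                "So Where Are Mt Gox's Stolen Bitcoin Millions?"), 1)]
texts, targets = split_batch(pair_batch)
print(texts[0])   # ['Inside Mt Gox', 'Hackers Claim Mt. Gox Still Sitting On Bitcoins']
print(texts[1])   # ['Banking on Bitcoins', "So Where Are Mt Gox's Stolen Bitcoin Millions?"]
print(targets)    # [[0, 1]]
```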

The batches processed by `TokBatchTransform` contain a dictionary as their first element. For decoding, it's handy to have a tensor instead. The `Undict` transform fetches `input_ids` from the batch and wraps them in a `TransTensorText`, so that type dispatch picks the custom `show_batch` and `show_results` defined for that class.

In [116]:
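class Undict(Transform):
    def decodes(self, x:dict):
        # wrap the token ids so typedispatch selects the TransTensorText versions
        # of show_batch/show_results; return the input unchanged if there are no input_ids
        if 'input_ids' in x: return TransTensorText(x['input_ids'])
        return x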

Now the transforms are combined inside a data block used to create the `dls`. The inputs are already batched by `TokBatchTransform`, so we don't need `fa_collate` for batching; `fa_convert` is passed as `create_batch` instead.

The texts we are processing have different lengths. Each sample in a batch is padded to the length of the longest input in that batch to make the samples collatable, so shuffling samples completely at random produces longer batches on average. Since compute time depends on sequence length, this is undesirable. `SortedDL` groups the inputs by length, and with `shuffle=True` it shuffles them only within certain intervals, keeping samples of similar length together.
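The effect of length-aware batching on padding can be measured with a toy simulation. The sketch below is a simplified approximation of the idea, not the actual `SortedDL` code: it shuffles the indices, sorts them by length inside chunks of several batches, cuts the batches and then shuffles the batch order. The sequence lengths are made up, and `padding_cost`, `random_batches` and `bucketed_batches` are hypothetical helpers.

```python
import random

def padding_cost(batches, lengths):
    """Tokens processed when every batch is padded to its longest sample."""
    return sum(max(lengths[i] for i in b) * len(b) for b in batches)

def random_batches(n, bs, seed=0):
    """Plain shuffling: one random permutation cut into consecutive batches."""
    rng = random.Random(seed)
    idxs = list(range(n))
    rng.shuffle(idxs)
    return [idxs[i:i + bs] for i in range(0, n, bs)]

def bucketed_batches(lengths, bs, chunk_batches=50, seed=0):
    """Simplified length-aware shuffling (an approximation, not SortedDL's code):
    shuffle, sort by length inside chunks of `chunk_batches` batches,
    cut into batches, then shuffle the order of the batches."""
    rng = random.Random(seed)
    idxs = list(range(len(lengths)))
    rng.shuffle(idxs)                      # same permutation as random_batches(seed)
    sz = bs * chunk_batches
    ordered = []
    for start in range(0, len(idxs), sz):
        chunk = idxs[start:start + sz]
        ordered += sorted(chunk, key=lambda i: lengths[i], reverse=True)
    batches = [ordered[i:i + bs] for i in range(0, len(ordered), bs)]
    rng.shuffle(batches)                   # batch order stays random
    return batches

# Made-up sequence lengths for 1024 samples.
rng = random.Random(1)
lengths = [rng.randint(5, 200) for _ in range(1024)]
print('random  :', padding_cost(random_batches(len(lengths), 32), lengths))
print('bucketed:', padding_cost(bucketed_batches(lengths, 32, chunk_batches=8), lengths))
```

Because the bucketed version starts from the same permutation and only reorders samples within chunks, for full batches like these its padding cost cannot exceed the random baseline, and with spread-out lengths it is noticeably lower.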

In [117]:
dls_kwargs = {
    'before_batch': TokBatchTransform(pretrained_model_name=model_name, max_length=max_len),
    'create_batch': fa_convert
}
text_block = TransformBlock(dl_type=SortedDL, dls_kwargs=dls_kwargs, batch_tfms=Undict(), )

dblock = DataBlock(blocks = [text_block, CategoryBlock()],
                   get_x=TextGetter(*glue_textfields[task]),
                   get_y=ItemGetter('label'),
                   splitter=IndexSplitter(valid_idx))

In [118]:
%%time
dl_kwargs=[{'res':train_lens}, {'val_res':valid_lens}]
dls = dblock.dataloaders(train_ds, bs=bs, val_bs=val_bs, dl_kwargs=dl_kwargs)

CPU times: user 245 ms, sys: 1.92 ms, total: 247 ms
Wall time: 247 ms


In [119]:
dls.show_batch(max_n=4)

| | text | text_ | category |
|---|---|---|---|
| 0 | Journal says: The first default in China's corporate-bond market is unlikely to be ... | 'Titanfall' Release Date Nears: Twitter Predicts It Will Outsell 'Battlefield 4', Will ... | 0 |
| 1 | Is spitting in TriMet driver's face 'assault'? Union wants hostile riders taken more ... | Pollen and birds chirping? Check and check. Here are 5 ways to tell that spring's ... | 1 |
| 2 | Titanfall ships at 792p on Xbox One, post-release resolution 'likely to increase' | 'Titanfall' Release Date Nears: Twitter Predicts It Will Outsell 'Battlefield 4', Will ... | 1 |
| 3 | American Airlines (AAL) Load Factor Fell 0.3 Points in Feb.; Affirms Q1 PRASM ... | Markets: Ireland's Fyffes to merge with Chiquita of original 'Banana Republic' fame | 0 |


## Customized Learner

Now the `xb` we get from the dataloader contains a dictionary, and HuggingFace models accept keyword arguments as input. But the fastai `Learner` feeds the model a sequence of positional arguments (`self.pred = self.model(*self.xb)`). To make this work smoothly, we can create a callback that unrolls the input dict into a proper `xb` tuple.

But first we need to define some utility functions. `default_splitter` is used to divide model parameters into groups:

In [120]:
def default_splitter(model):
    groups = L(model.base_model.children()) + L(m for m in list(model.children())[1:] if params(m))
    return groups.map(params)

As with `show_batch`, we have to customize `show_results`:

In [121]:
@typedispatch
def show_results(x: TransTensorText, y, samples, outs, ctxs=None, max_n=10, trunc_at=150, **kwargs):
    if ctxs is None: ctxs = get_empty_df(min(len(samples), max_n))
    if isinstance(samples[0][0], tuple):
        samples = L((*s[0], *s[1:]) for s in samples)
        if trunc_at is not None: samples = L((s[0].truncate(trunc_at), s[1].truncate(trunc_at), *s[2:]) for s in samples)
    elif trunc_at is not None: samples = L((s[0].truncate(trunc_at),*s[1:]) for s in samples)
    ctxs = show_results[object](x, y, samples, outs, ctxs=ctxs, max_n=max_n, **kwargs)
    display_df(pd.DataFrame(ctxs))
    return ctxs

`TransLearner` itself doesn't do much: it adds `TransCallback` and sets `splitter` to be `default_splitter` if `None` is provided.

In [122]:
@delegates(Learner.__init__)
class TransLearner(Learner):
    "Learner for training transformers from HuggingFace"
    def __init__(self, dls, model, **kwargs):
        splitter = kwargs.get('splitter', None)
        if splitter is None: kwargs['splitter'] = default_splitter
        super().__init__(dls, model, **kwargs)
        self.add_cb(TransCallback(model))

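The main piece of work needed to train a transformers model happens in `TransCallback`. It records the arguments accepted by the model's `forward` method, together with their default values, and converts the input dict yielded by the dataloader into a tuple of positional arguments in that order.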

By default the model returns a dictionary-like object containing `logits` and possibly other outputs as defined by model config (e.g. intermediate hidden representations). In the fastai training loop we usually expect `preds` to be a tensor containing model predictions (logits). The callback formats the preds properly.

Notice that if `labels` are found in the input, transformer models compute the `loss` and return it together with output `logits`. The callback below is designed to utilise the loss returned by model instead of recomputing it using `learn.loss_func`. This is not actually used in this example but might be handy in some use cases.

In [123]:
class TransCallback(Callback):
    "Handles HuggingFace model inputs and outputs"
    
    def __init__(self, model):
        self.labels = tuple()
        self.model_args = {k:v.default for k, v in signature(model.forward).parameters.items()}
    
    def before_batch(self):
        if 'labels' in self.xb[0].keys():
            self.labels = (self.xb[0]['labels'], )
        # make a tuple containing an element for each argument the model expects;
        # if an argument is not in xb it is set to its default value
        self.learn.xb = tuple([self.xb[0].get(k, self.model_args[k]) for k in self.model_args.keys()])
    
    def after_pred(self):
        if 'loss' in self.pred:
            self.learn.loss_grad = self.pred.loss
            self.learn.loss = self.pred.loss.clone()
        self.learn.pred = self.pred.logits
    
    def after_loss(self):
        if len(self.labels):
            self.learn.yb = self.labels
            self.labels = tuple()
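The core trick of `before_batch` can be tried without a model: `inspect.signature` lists the parameters of `forward` in order with their defaults, and the batch dict is turned into a positional tuple by filling every missing argument with its default. Here `forward` is a hypothetical stand-in with a few argument names typical of HuggingFace models, not the real `RobertaForSequenceClassification.forward`, and the token ids are made up.

```python
from inspect import signature

def forward(input_ids=None, attention_mask=None, token_type_ids=None,
            labels=None, return_dict=None):
    """Hypothetical stand-in for a HuggingFace model's forward method."""
    return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels}

# Parameter names with their default values, in signature order.
model_args = {k: v.default for k, v in signature(forward).parameters.items()}

# A batch dict as produced by a tokenizer (made-up ids).
batch = {'input_ids': [[0, 11, 2]], 'attention_mask': [[1, 1, 1]]}

# One positional element per parameter: batch value if present, default otherwise.
xb = tuple(batch.get(k, model_args[k]) for k in model_args)
print(xb)   # ([[0, 11, 2]], [[1, 1, 1]], None, None, None)

# fastai calls self.model(*self.xb); positions line up with the parameter names.
out = forward(*xb)
print(out['attention_mask'])   # [[1, 1, 1]]
```

The approach relies on every `forward` parameter having a default: a parameter without one would appear as `inspect.Parameter.empty` in `model_args` and be passed to the model as that sentinel.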

## Training

After all the preparations, training is straightforward. Setting `num_labels` for the model and choosing the appropriate metrics are automated.

In [124]:
#hide_output
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=glue_num_labels.get(task, 2))
metrics = glue_metrics[task]
learn = TransLearner(dls, model, metrics=metrics, opt_func=opt_func)

Some weights of the model checkpoint at distilroberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.decoder.weight', 'lm_head.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.bias', 'roberta.pooler.dense.bias', 'roberta.pooler.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at distilroberta-base and are newly initialized: ['classifier.dense.weight'

In [125]:
#collapse_output
learn.summary()

RobertaForSequenceClassification (Input shape: 32)
Layer (type)         Output Shape         Param #    Trainable 
                     32 x 52 x 768       
Embedding                                 38603520   True      
Embedding                                 394752     True      
Embedding                                 768        True      
LayerNorm                                 1536       True      
Dropout                                                        
Linear                                    590592     True      
Linear                                    590592     True      
Linear                                    590592     True      
Dropout                                                        
Linear                                    590592     True      
LayerNorm                                 1536       True      
Dropout                                                        
___________________________________________________________________________

In [127]:
metric_to_monitor = metrics[0].name if isinstance(metrics[0], Metric) else metrics[0].__name__
cbs = [SaveModelCallback(monitor=metric_to_monitor)]
learn.fit_one_cycle(10, lr, cbs=cbs)  # 10 epochs here; the n_epoch setting defined above is not used

| epoch | train_loss | valid_loss | f1_score | accuracy | time |
|---|---|---|---|---|---|
| 0 | 0.195089 | 0.316713 | 0.898204 | 0.89375 | 00:07 |
| 1 | 0.192827 | 0.343574 | 0.865497 | 0.85625 | 00:07 |
| 2 | 0.171442 | 0.351437 | 0.883721 | 0.875 | 00:06 |
| 3 | 0.146579 | 0.312842 | 0.894118 | 0.8875 | 00:06 |
| 4 | 0.114526 | 0.298425 | 0.902439 | 0.9 | 00:06 |
| 5 | 0.075502 | 0.428674 | 0.892857 | 0.8875 | 00:07 |
| 6 | 0.059567 | 0.362299 | 0.902439 | 0.9 | 00:06 |
| 7 | 0.043812 | 0.413381 | 0.894118 | 0.8875 | 00:06 |
| 8 | 0.032728 | 0.419918 | 0.899408 | 0.89375 | 00:06 |
| 9 | 0.02734 | 0.418984 | 0.904762 | 0.9 | 00:06 |


Better model found at epoch 0 with f1_score value: 0.8982035928143713.
Better model found at epoch 4 with f1_score value: 0.9024390243902438.
Better model found at epoch 9 with f1_score value: 0.9047619047619048.
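`SaveModelCallback` keeps the weights from the last epoch at which the monitored value improved on the best value seen so far. As far as I know, fastai infers the direction from the name (lower is better if it contains `loss` or `error`, higher otherwise); the sketch below hard-codes the direction instead and replays the rounded values from the training table. `improvement_epochs` is a hypothetical helper.

```python
import math

def improvement_epochs(values, higher_is_better=True):
    """Hypothetical helper: epochs where a value strictly beats the best so far."""
    best = -math.inf if higher_is_better else math.inf
    epochs = []
    for epoch, v in enumerate(values):
        if (v > best) if higher_is_better else (v < best):
            best = v
            epochs.append(epoch)
    return epochs

# Columns of the training table above (values rounded as displayed).
f1 = [0.898204, 0.865497, 0.883721, 0.894118, 0.902439,
      0.892857, 0.902439, 0.894118, 0.899408, 0.904762]
valid_loss = [0.316713, 0.343574, 0.351437, 0.312842, 0.298425,
              0.428674, 0.362299, 0.413381, 0.419918, 0.418984]

print(improvement_epochs(f1))                                  # [0, 4, 9]
print(improvement_epochs(valid_loss, higher_is_better=False))  # [0, 3, 4]
```

For `f1_score` the sketch reports epochs 0, 4 and 9, matching the "Better model found" messages (epoch 6 shows the same rounded value as epoch 4, so under a strict comparison it does not count). Monitoring `valid_loss` instead would have kept the weights from epoch 4.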


After training the model, it's useful to verify that the results make sense:

In [128]:
learn.show_results()

| | text | text_ | category | category_ |
|---|---|---|---|---|
| 0 | Titanfall ships at 792p on Xbox One, post-release resolution 'likely to increase' | Titanfall runs at 792p resolution on Xbox One -- and it may get a postlaunch ... | 1 | 1 |
| 1 | Journal says: The first default in China's corporate-bond market is unlikely to be ... | Eni CEO says Ukraine wake-up call for Europe energy policy-paper | 0 | 0 |
| 2 | Mt Gox Customers Hack Karpales' Account to Post Evidence of Alleged Fraud | Hackers allegedly hit Mt. Gox CEO's blog, post balance of remaining bitcoins | 1 | 1 |
| 3 | Report: Mt. Gox was hit 150K times per second during DDoS attack before theft | Anonymous hackers uncover alleged proof of MtGox fraud from site's CEO | 1 | 1 |
| 4 | The Japanese government has come down hard on the digital currency bitcoin. | Hackers allege Mt. Gox head controls'stolen' Bitcoin; fraud committed | 1 | 1 |
| 5 | GLOBAL MARKETS-World stocks, copper and oil fall after weak China exports | Chiquita Brands acquires Fyffes to create world's biggest banana seller | 0 | 0 |
| 6 | Chiquita share price rallies, acquires Fyffes Plc to create the largest banana ... | Hackers claim bankrupt Mt. Gox still has customers' Bitcoins | 0 | 0 |
| 7 | Chrysler Recalls 25000 Vehicles, Loosening Russia's Energy Grip, Fyffes to ... | Rivals to merge and base largest global banana firm in Ireland | 1 | 0 |
| 8 | Hong Kong shares in biggest loss in 5 weeks after anemic China data | JetBlue Airways Corp. Receives Average Rating of “Hold” from Analysts ... | 0 | 0 |


Finally, we can run the model on the test set to get predictions:

In [129]:
test_dl = dls.test_dl(ds['test'])
preds = learn.get_preds(dl=test_dl)
preds[0]

tensor([[2.3249e-04, 9.9977e-01],
        [8.8721e-04, 9.9911e-01],
        [3.3746e-04, 9.9966e-01],
        [9.9969e-01, 3.1450e-04],
        [9.9844e-01, 1.5607e-03],
        [9.9851e-01, 1.4872e-03],
        [6.6022e-01, 3.3978e-01],
        [9.9951e-01, 4.9419e-04],
        [9.9929e-01, 7.1430e-04],
        [9.9860e-01, 1.4004e-03],
        [9.9693e-01, 3.0701e-03],
        [9.9962e-01, 3.8377e-04],
        [9.5126e-01, 4.8740e-02],
        [3.6606e-04, 9.9963e-01],
        [2.1033e-03, 9.9790e-01],
        [2.4754e-04, 9.9975e-01],
        [5.8562e-01, 4.1438e-01],
        [5.3859e-04, 9.9946e-01],
        [9.9929e-01, 7.1175e-04],
        [9.9937e-01, 6.2871e-04],
        [9.9920e-01, 8.0034e-04],
        [6.4839e-04, 9.9935e-01],
        [5.9933e-04, 9.9940e-01],
        [5.3645e-04, 9.9946e-01],
        [2.2290e-03, 9.9777e-01],
        [9.9574e-01, 4.2605e-03],
        [8.0154e-04, 9.9920e-01],
        [3.7073e-04, 9.9963e-01],
        [9.6682e-04, 9.9903e-01],
        [3.032
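Each row of `preds[0]` holds one probability per class (the rows sum to one), and the hard prediction is the index of the largest entry. On the tensor itself `preds[0].argmax(dim=1)` does this; the plain-Python sketch below applies the same rule to the first four rows printed above (`predicted_labels` is a hypothetical name).

```python
def predicted_labels(probs):
    """Hypothetical helper: index of the largest probability in each row (row-wise argmax)."""
    return [max(range(len(row)), key=row.__getitem__) for row in probs]

# First four rows of preds[0] printed above.
probs = [[2.3249e-04, 9.9977e-01],
         [8.8721e-04, 9.9911e-01],
         [3.3746e-04, 9.9966e-01],
         [9.9969e-01, 3.1450e-04]]
print(predicted_labels(probs))   # [1, 1, 1, 0]
```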

## Final remarks

Generalised versions of the "wrapper" code used in this notebook can be found in the [fasthugs](https://github.com/aikindergarten/fasthugs) library. You can find more on fine-tuning models on GLUE tasks in [this blog post](https://arampacha.github.io/thoughtsamples/fastai/huggingface/transformers/2021/05/07/glue-benchmark.html). Another option for training HuggingFace transformers with fastai is the [blurr](https://github.com/ohmeow/blurr) library.