Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
304 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,303 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"collapsed": true, | ||
"pycharm": { | ||
"name": "#%% md\n" | ||
} | ||
}, | ||
"source": [ | ||
"# Demo of RobBERT for humour detection\n", | ||
"We use a [RobBERT (Delobelle et al., 2020)](https://arxiv.org/abs/2001.06286) model with the original pretraining head for MLM.\n", | ||
"\n", | ||
"**Dependencies**\n", | ||
"- tokenizers\n", | ||
"- torch\n", | ||
"- transformers" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"First we load our RobBERT model that was pretrained. We also load in RobBERT's tokenizer.\n", | ||
"\n", | ||
"Because we only want to get results, we have to disable dropout etc. So we add `model.eval()`." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"*Note: we pretrained both RobBERT v1 and RobBERT v2 in [Fairseq](https://github.com/pytorch/fairseq) and converted these checkpoints to HuggingFace. The MLM task behaves a bit differently.*" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": { | ||
"pycharm": { | ||
"is_executing": false, | ||
"name": "#%%\n" | ||
} | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"text": [ | ||
"Special tokens have been added in the vocabulary, make sure the associated word emebedding are fine-tuned or trained.\n" | ||
], | ||
"output_type": "stream" | ||
}, | ||
{ | ||
"data": { | ||
"text/plain": "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=469740689.0, style=ProgressStyle(descri…", | ||
"application/vnd.jupyter.widget-view+json": { | ||
"version_major": 2, | ||
"version_minor": 0, | ||
"model_id": "4156a63ac55d4844a22db25f1db7ed52" | ||
} | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"name": "stdout", | ||
"text": [ | ||
"\n", | ||
"RobBERT model loaded\n" | ||
], | ||
"output_type": "stream" | ||
} | ||
], | ||
"source": [ | ||
"import torch\n", | ||
"from transformers import RobertaTokenizer, AutoModelForSequenceClassification, AutoConfig\n", | ||
"\n", | ||
"from transformers import RobertaTokenizer, RobertaForMaskedLM\n", | ||
"import torch\n", | ||
"tokenizer = RobertaTokenizer.from_pretrained('pdelobelle/robbert-v2-dutch-base')\n", | ||
"model = RobertaForMaskedLM.from_pretrained('pdelobelle/robbert-v2-dutch-base', return_dict=True)\n", | ||
"model = model.to( 'cuda' if torch.cuda.is_available() else 'cpu' )\n", | ||
"model.eval()\n", | ||
"#model = RobertaForMaskedLM.from_pretrained('pdelobelle/robbert-v2-dutch-base', return_dict=True)\n", | ||
"print(\"RobBERT model loaded\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": { | ||
"pycharm": { | ||
"is_executing": false, | ||
"name": "#%% \n" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"sequence = f\"Er staat een {tokenizer.mask_token} in mijn tuin.\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": { | ||
"pycharm": { | ||
"is_executing": false, | ||
"name": "#%%\n" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"input = tokenizer.encode(sequence, return_tensors=\"pt\").to( 'cuda' if torch.cuda.is_available() else 'cpu' )\n", | ||
"mask_token_index = torch.where(input == tokenizer.mask_token_id)[1]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"pycharm": { | ||
"is_executing": false, | ||
"name": "#%% md\n" | ||
} | ||
}, | ||
"source": [ | ||
"Now that we have our tokenized input and the position of the masked token, we pass the input through RobBERT. \n", | ||
"\n", | ||
"This will give us a predicting for all tokens, but we're only interested in the `<mask>` token. " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": { | ||
"pycharm": { | ||
"is_executing": false, | ||
"name": "#%%\n" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"with torch.no_grad():\n", | ||
" token_logits = model(input).logits" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": { | ||
"pycharm": { | ||
"is_executing": false, | ||
"name": "#%%\n" | ||
} | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"text": [ | ||
"Ġboom | id = 2600 | p = 0.1416003555059433\n", | ||
"Ġvijver | id = 8217 | p = 0.13144515454769135\n", | ||
"Ġplant | id = 2721 | p = 0.043418534100055695\n", | ||
"Ġhuis | id = 251 | p = 0.01847737282514572\n", | ||
"Ġparkeerplaats | id = 6889 | p = 0.018001794815063477\n", | ||
"Ġbankje | id = 21620 | p = 0.016940612345933914\n", | ||
"Ġmuur | id = 2035 | p = 0.014668751507997513\n", | ||
"Ġmoestuin | id = 17446 | p = 0.0144038125872612\n", | ||
"Ġzonnebloem | id = 30757 | p = 0.014375611208379269\n", | ||
"Ġschutting | id = 15000 | p = 0.013991709798574448\n", | ||
"Ġpaal | id = 8626 | p = 0.01358739286661148\n", | ||
"Ġbloem | id = 3001 | p = 0.01199684850871563\n", | ||
"Ġstal | id = 7416 | p = 0.011224730871617794\n", | ||
"Ġfontein | id = 23425 | p = 0.011203107424080372\n", | ||
"Ġtuin | id = 671 | p = 0.010676783509552479\n" | ||
], | ||
"output_type": "stream" | ||
} | ||
], | ||
"source": [ | ||
"logits = token_logits[0, mask_token_index, :].squeeze()\n", | ||
"prob = logits.softmax(dim=0)\n", | ||
"values, indeces = prob.topk(k=15, dim=0)\n", | ||
"\n", | ||
"for index, token in enumerate(tokenizer.convert_ids_to_tokens(indeces)):\n", | ||
" print(f\"{token:20} | id = {indeces[index]:4} | p = {values[index]}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"pycharm": { | ||
"is_executing": false, | ||
"name": "#%% md\n" | ||
} | ||
}, | ||
"source": [ | ||
"## RobBERT with pipelines\n", | ||
"We can also use the `fill-mask` pipeline from Huggingface, that does basically the same thing." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"text": [ | ||
"Special tokens have been added in the vocabulary, make sure the associated word emebedding are fine-tuned or trained.\n" | ||
], | ||
"output_type": "stream" | ||
} | ||
], | ||
"source": [ | ||
"from transformers import pipeline\n", | ||
"p = pipeline(\"fill-mask\", model=\"pdelobelle/robbert-v2-dutch-base\")" | ||
], | ||
"metadata": { | ||
"collapsed": false, | ||
"pycharm": { | ||
"name": "#%%\n", | ||
"is_executing": false | ||
} | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": { | ||
"pycharm": { | ||
"is_executing": false, | ||
"name": "#%%\n" | ||
} | ||
}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": "[{'sequence': '<s>Er staat een boomin mijn tuin.</s>',\n 'score': 0.1416003555059433,\n 'token': 2600,\n 'token_str': 'Ġboom'},\n {'sequence': '<s>Er staat een vijverin mijn tuin.</s>',\n 'score': 0.13144515454769135,\n 'token': 8217,\n 'token_str': 'Ġvijver'},\n {'sequence': '<s>Er staat een plantin mijn tuin.</s>',\n 'score': 0.043418534100055695,\n 'token': 2721,\n 'token_str': 'Ġplant'},\n {'sequence': '<s>Er staat een huisin mijn tuin.</s>',\n 'score': 0.01847737282514572,\n 'token': 251,\n 'token_str': 'Ġhuis'},\n {'sequence': '<s>Er staat een parkeerplaatsin mijn tuin.</s>',\n 'score': 0.018001794815063477,\n 'token': 6889,\n 'token_str': 'Ġparkeerplaats'}]" | ||
}, | ||
"metadata": {}, | ||
"output_type": "execute_result", | ||
"execution_count": 7 | ||
} | ||
], | ||
"source": [ | ||
"p(sequence)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"That's it for this demo of the MLM head. If you use RobBERT in your academic work, you can cite it!\n", | ||
"\n", | ||
"\n", | ||
"```\n", | ||
"@misc{delobelle2020robbert,\n", | ||
" title={{R}ob{BERT}: a {D}utch {R}o{BERT}a-based Language Model},\n", | ||
" author={Pieter Delobelle and Thomas Winters and Bettina Berendt},\n", | ||
" year={2020},\n", | ||
" eprint={2001.06286},\n", | ||
" archivePrefix={arXiv},\n", | ||
" primaryClass={cs.CL}\n", | ||
"}\n", | ||
"```\n" | ||
], | ||
"metadata": { | ||
"collapsed": false, | ||
"pycharm": { | ||
"name": "#%% md\n" | ||
} | ||
} | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.7.4" | ||
}, | ||
"pycharm": { | ||
"stem_cell": { | ||
"cell_type": "raw", | ||
"source": [], | ||
"metadata": { | ||
"collapsed": false | ||
} | ||
} | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 1 | ||
} |