Demo notebook for Dutch MLM task
iPieter committed Nov 9, 2020
1 parent 340bd9d commit c531397
Showing 2 changed files with 304 additions and 0 deletions.
1 change: 1 addition & 0 deletions Pipfile
@@ -11,6 +11,7 @@ nltk = "*"
transformers = "*"
tensorboardx = "*"
tokenizers = "~=0.4.2"
jupyter = "*"

[requires]
python_version = "3.7"
303 changes: 303 additions & 0 deletions notebooks/demo_RobBERT_for_masked_LM.ipynb
@@ -0,0 +1,303 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"collapsed": true,
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"# Demo of RobBERT for humour detection\n",
"We use a [RobBERT (Delobelle et al., 2020)](https://arxiv.org/abs/2001.06286) model with the original pretraining head for MLM.\n",
"\n",
"**Dependencies**\n",
"- tokenizers\n",
"- torch\n",
"- transformers"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First we load our RobBERT model that was pretrained. We also load in RobBERT's tokenizer.\n",
"\n",
"Because we only want to get results, we have to disable dropout etc. So we add `model.eval()`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Note: we pretrained both RobBERT v1 and RobBERT v2 in [Fairseq](https://github.com/pytorch/fairseq) and converted these checkpoints to HuggingFace. The MLM task behaves a bit differently.*"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"pycharm": {
"is_executing": false,
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stderr",
"text": [
"Special tokens have been added in the vocabulary, make sure the associated word emebedding are fine-tuned or trained.\n"
],
"output_type": "stream"
},
{
"data": {
"text/plain": "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=469740689.0, style=ProgressStyle(descri…",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "4156a63ac55d4844a22db25f1db7ed52"
}
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"text": [
"\n",
"RobBERT model loaded\n"
],
"output_type": "stream"
}
],
"source": [
"import torch\n",
"from transformers import RobertaTokenizer, AutoModelForSequenceClassification, AutoConfig\n",
"\n",
"from transformers import RobertaTokenizer, RobertaForMaskedLM\n",
"import torch\n",
"tokenizer = RobertaTokenizer.from_pretrained('pdelobelle/robbert-v2-dutch-base')\n",
"model = RobertaForMaskedLM.from_pretrained('pdelobelle/robbert-v2-dutch-base', return_dict=True)\n",
"model = model.to( 'cuda' if torch.cuda.is_available() else 'cpu' )\n",
"model.eval()\n",
"#model = RobertaForMaskedLM.from_pretrained('pdelobelle/robbert-v2-dutch-base', return_dict=True)\n",
"print(\"RobBERT model loaded\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"pycharm": {
"is_executing": false,
"name": "#%% \n"
}
},
"outputs": [],
"source": [
"sequence = f\"Er staat een {tokenizer.mask_token} in mijn tuin.\""
]
},
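{
"cell_type": "markdown",
"metadata": {},
"source": [
"For RoBERTa-style tokenizers the mask token is `<mask>`; as a quick sanity check, we can print the token and the vocabulary id it maps to:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Show the literal mask token and its id in RobBERT's vocabulary.\n",
"print(tokenizer.mask_token, tokenizer.mask_token_id)"
]
},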
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"pycharm": {
"is_executing": false,
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"input = tokenizer.encode(sequence, return_tensors=\"pt\").to( 'cuda' if torch.cuda.is_available() else 'cpu' )\n",
"mask_token_index = torch.where(input == tokenizer.mask_token_id)[1]"
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"is_executing": false,
"name": "#%% md\n"
}
},
"source": [
"Now that we have our tokenized input and the position of the masked token, we pass the input through RobBERT. \n",
"\n",
"This will give us a predicting for all tokens, but we're only interested in the `<mask>` token. "
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"pycharm": {
"is_executing": false,
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"with torch.no_grad():\n",
" token_logits = model(input).logits"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"pycharm": {
"is_executing": false,
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
"text": [
"Ġboom | id = 2600 | p = 0.1416003555059433\n",
"Ġvijver | id = 8217 | p = 0.13144515454769135\n",
"Ġplant | id = 2721 | p = 0.043418534100055695\n",
"Ġhuis | id = 251 | p = 0.01847737282514572\n",
"Ġparkeerplaats | id = 6889 | p = 0.018001794815063477\n",
"Ġbankje | id = 21620 | p = 0.016940612345933914\n",
"Ġmuur | id = 2035 | p = 0.014668751507997513\n",
"Ġmoestuin | id = 17446 | p = 0.0144038125872612\n",
"Ġzonnebloem | id = 30757 | p = 0.014375611208379269\n",
"Ġschutting | id = 15000 | p = 0.013991709798574448\n",
"Ġpaal | id = 8626 | p = 0.01358739286661148\n",
"Ġbloem | id = 3001 | p = 0.01199684850871563\n",
"Ġstal | id = 7416 | p = 0.011224730871617794\n",
"Ġfontein | id = 23425 | p = 0.011203107424080372\n",
"Ġtuin | id = 671 | p = 0.010676783509552479\n"
],
"output_type": "stream"
}
],
"source": [
"logits = token_logits[0, mask_token_index, :].squeeze()\n",
"prob = logits.softmax(dim=0)\n",
"values, indeces = prob.topk(k=15, dim=0)\n",
"\n",
"for index, token in enumerate(tokenizer.convert_ids_to_tokens(indeces)):\n",
" print(f\"{token:20} | id = {indeces[index]:4} | p = {values[index]}\")"
]
},
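{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick check we can substitute the most likely token back into the sentence. A minimal sketch: it assumes `tokenizer.decode` returns the surface form with a leading space (from the `Ġ` prefix), which we strip off."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# indices is sorted by probability, so indices[0] is the most likely token.\n",
"best_token = tokenizer.decode([indices[0].item()]).strip()\n",
"print(sequence.replace(tokenizer.mask_token, best_token))"
]
},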
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"is_executing": false,
"name": "#%% md\n"
}
},
"source": [
"## RobBERT with pipelines\n",
"We can also use the `fill-mask` pipeline from Huggingface, that does basically the same thing."
]
},
{
"cell_type": "code",
"execution_count": 6,
"outputs": [
{
"name": "stderr",
"text": [
"Special tokens have been added in the vocabulary, make sure the associated word emebedding are fine-tuned or trained.\n"
],
"output_type": "stream"
}
],
"source": [
"from transformers import pipeline\n",
"p = pipeline(\"fill-mask\", model=\"pdelobelle/robbert-v2-dutch-base\")"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n",
"is_executing": false
}
}
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"pycharm": {
"is_executing": false,
"name": "#%%\n"
}
},
"outputs": [
{
"data": {
"text/plain": "[{'sequence': '<s>Er staat een boomin mijn tuin.</s>',\n 'score': 0.1416003555059433,\n 'token': 2600,\n 'token_str': 'Ġboom'},\n {'sequence': '<s>Er staat een vijverin mijn tuin.</s>',\n 'score': 0.13144515454769135,\n 'token': 8217,\n 'token_str': 'Ġvijver'},\n {'sequence': '<s>Er staat een plantin mijn tuin.</s>',\n 'score': 0.043418534100055695,\n 'token': 2721,\n 'token_str': 'Ġplant'},\n {'sequence': '<s>Er staat een huisin mijn tuin.</s>',\n 'score': 0.01847737282514572,\n 'token': 251,\n 'token_str': 'Ġhuis'},\n {'sequence': '<s>Er staat een parkeerplaatsin mijn tuin.</s>',\n 'score': 0.018001794815063477,\n 'token': 6889,\n 'token_str': 'Ġparkeerplaats'}]"
},
"metadata": {},
"output_type": "execute_result",
"execution_count": 7
}
],
"source": [
"p(sequence)"
]
},
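{
"cell_type": "markdown",
"metadata": {},
"source": [
"By default the pipeline returns the five best candidates. A hedged sketch of asking for more: recent `transformers` releases call this argument `top_k`, while versions from around this notebook's time named it `topk`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Request the 15 best fills; argument is top_k in recent transformers (topk in older releases).\n",
"p(sequence, top_k=15)"
]
},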
{
"cell_type": "markdown",
"source": [
"That's it for this demo of the MLM head. If you use RobBERT in your academic work, you can cite it!\n",
"\n",
"\n",
"```\n",
"@misc{delobelle2020robbert,\n",
" title={{R}ob{BERT}: a {D}utch {R}o{BERT}a-based Language Model},\n",
" author={Pieter Delobelle and Thomas Winters and Bettina Berendt},\n",
" year={2020},\n",
" eprint={2001.06286},\n",
" archivePrefix={arXiv},\n",
" primaryClass={cs.CL}\n",
"}\n",
"```\n"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
},
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"source": [],
"metadata": {
"collapsed": false
}
}
}
},
"nbformat": 4,
"nbformat_minor": 1
}
