This notebook illustrates how to apply Masked Language Modeling (MLM) to this competition.

Observation: most dataset names consist only of capitalized words plus a few stopwords such as `on`, `in`, and `and` (e.g. `Early Childhood Longitudinal Study`, `Trends in International Mathematics and Science Study`).

Thus, one approach to finding the datasets is:
- Locate all sequences of capitalized words (these sequences may contain some stopwords).
- Replace each sequence with one of two special symbols (e.g. `$` or `#`), which indicates whether or not the sequence is a dataset name.
- Train the model on the MLM task over the rewritten text.
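
The first two steps above can be sketched on a single sentence as follows. This is only an illustration: the helper names `capitalized_spans` and `replace_span`, the stopword set, and the minimum length of two words are assumptions made for the example, not the notebook's actual implementation (which appears further down).

```python
# Hypothetical helpers for illustration; the stopword set is an assumption.
STOPWORDS = {"on", "in", "and", "of", "for"}

def capitalized_spans(words, min_len=2):
    """Return inclusive (start, end) index pairs of runs of capitalized words.

    Stopwords may appear inside a run but are trimmed from both ends.
    """
    spans, start = [], None
    for i, word in enumerate(words + [""]):  # empty sentinel flushes the last run
        in_run = bool(word) and (word[0].isupper() or word in STOPWORDS)
        if in_run and start is None:
            start = i
        elif not in_run and start is not None:
            lo, hi = start, i - 1
            while lo <= hi and words[lo] in STOPWORDS:
                lo += 1
            while hi >= lo and words[hi] in STOPWORDS:
                hi -= 1
            if hi - lo + 1 >= min_len:
                spans.append((lo, hi))
            start = None
    return spans

def replace_span(words, span, symbol):
    """Replace the words covered by an inclusive span with a single symbol."""
    start, end = span
    return words[:start] + [symbol] + words[end + 1:]

words = "We used the Early Childhood Longitudinal Study in our analysis".split()
spans = capitalized_spans(words)
masked = " ".join(replace_span(words, spans[0], "$"))
print(spans, "->", masked)
```

The single-word run `We` is discarded by the minimum length, and the trailing stopword `in` is trimmed from the detected span before substitution.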

The code below shows how to train a model for that purpose using the Hugging Face libraries (`datasets` and `transformers`).

In [1]:
MAX_SAMPLE = 8000 # cap on rows read from train.csv (applied before grouping); use a small number (e.g. 50) for experimentation, None for all rows.

# Install packages

In [2]:
!pip install datasets --no-index --find-links=file:///kaggle/input/coleridge-packages/packages/datasets
!pip install ../input/coleridge-packages/seqeval-1.2.2-py3-none-any.whl
!pip install ../input/coleridge-packages/tokenizers-0.10.1-cp37-cp37m-manylinux1_x86_64.whl
!pip install ../input/coleridge-packages/transformers-4.5.0.dev0-py3-none-any.whl

Looking in links: file:///kaggle/input/coleridge-packages/packages/datasets
Processing /kaggle/input/coleridge-packages/packages/datasets/datasets-1.5.0-py3-none-any.whl
Processing /kaggle/input/coleridge-packages/packages/datasets/huggingface_hub-0.0.7-py3-none-any.whl
Processing /kaggle/input/coleridge-packages/packages/datasets/xxhash-2.0.0-cp37-cp37m-manylinux2010_x86_64.whl
Processing /kaggle/input/coleridge-packages/packages/datasets/tqdm-4.49.0-py2.py3-none-any.whl
Installing collected packages: tqdm, xxhash, huggingface-hub, datasets
  Attempting uninstall: tqdm
    Found existing installation: tqdm 4.56.2
    Uninstalling tqdm-4.56.2:
      Successfully uninstalled tqdm-4.56.2
Successfully installed datasets-1.5.0 huggingface-hub-0.0.7 tqdm-4.49.0 xxhash-2.0.0
Processing /kaggle/input/coleridge-packages/seqeval-1.2.2-py3-none-any.whl
Installing collected packages: seqeval
Successfully installed seqeval-1.2.2
Processing /kaggle/input/coleridge-packages/tokenizers-

# Import

In [3]:
import os
import re
import json
import time
import datetime
import random
import glob
import importlib

import numpy as np
import pandas as pd

from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorForLanguageModeling, \
AutoModelForMaskedLM, Trainer, TrainingArguments, pipeline, AutoConfig

sns.set()
random.seed(123)
np.random.seed(456)

In [4]:
model_checkpoint = "bert-base-cased"

MAX_LENGTH = 30
OVERLAP = 10

DATASET_SYMBOL = '$' # this symbol represents a dataset name
NONDATA_SYMBOL = '#' # this symbol represents a non-dataset name

# Load data

In [5]:
# train
train_path = '../input/coleridgeinitiative-show-us-the-data/train.csv'
paper_train_folder = '../input/coleridgeinitiative-show-us-the-data/train'

train = pd.read_csv(train_path)
train = train[:MAX_SAMPLE]
# Group by publication, training labels should have the same form as expected output.
train = train.groupby('Id').agg({
    'pub_title': 'first',
    'dataset_title': '|'.join,
    'dataset_label': '|'.join,
    'cleaned_label': '|'.join
}).reset_index()    

print('train size: ', len(train))

train size:  7379


In [6]:
train

Unnamed: 0,Id,pub_title,dataset_title,dataset_label,cleaned_label
0,0007f880-0a9b-492d-9a58-76eb0b0e0bd7,The Impact of ICT Training on Income Generatio...,Program for the International Assessment of Ad...,Program for the International Assessment of Ad...,program for the international assessment of ad...
1,0008656f-0ba2-4632-8602-3017b44c2e90,Finnish Ninth Graders’ Gender Appropriateness ...,Trends in International Mathematics and Scienc...,Trends in International Mathematics and Scienc...,trends in international mathematics and scienc...
2,000efc17-13d8-433d-8f62-a3932fe4f3b8,Risk factors and global cognitive status relat...,Alzheimer's Disease Neuroimaging Initiative (A...,ADNI,adni
3,002203f0-1c57-4400-abc1-b783c4085743,A Hybrid Geometric–Statistical Deformable Mode...,Alzheimer's Disease Neuroimaging Initiative (A...,ADNI,adni
4,00243b98-f868-45e4-9b83-0c346c7ecad5,Five-Year Growth Trajectories of Kindergarten ...,Early Childhood Longitudinal Study,Early Childhood Longitudinal Study,early childhood longitudinal study
...,...,...,...,...,...
7374,ffb86ab1-eed2-423e-a9ee-34c93465fdb2,Candidate Gene Polymorphisms for Ischemic Stroke,Baltimore Longitudinal Study of Aging (BLSA)|B...,Baltimore Longitudinal Study of Aging (BLSA)|B...,baltimore longitudinal study of aging blsa |ba...
7375,ffbed01c-c3a3-43d2-9a34-8a86f3ec3bca,The Impact of Institutional Arrangements on Ed...,Trends in International Mathematics and Scienc...,Trends in International Mathematics and Scienc...,trends in international mathematics and scienc...
7376,ffc640be-c934-4421-89bf-cfc0ec6ead13,Predicting the location of human perirhinal co...,Alzheimer's Disease Neuroimaging Initiative (A...,ADNI,adni
7377,ffeb3568-7aed-4dbe-b177-cbd7f46f34af,Abandoning mathematics. Reconstructing the pro...,Trends in International Mathematics and Scienc...,Trends in International Mathematics and Scienc...,trends in international mathematics and scienc...


In [7]:
eval_train = train[:10]

In [8]:
eval_train

Unnamed: 0,Id,pub_title,dataset_title,dataset_label,cleaned_label
0,0007f880-0a9b-492d-9a58-76eb0b0e0bd7,The Impact of ICT Training on Income Generatio...,Program for the International Assessment of Ad...,Program for the International Assessment of Ad...,program for the international assessment of ad...
1,0008656f-0ba2-4632-8602-3017b44c2e90,Finnish Ninth Graders’ Gender Appropriateness ...,Trends in International Mathematics and Scienc...,Trends in International Mathematics and Scienc...,trends in international mathematics and scienc...
2,000efc17-13d8-433d-8f62-a3932fe4f3b8,Risk factors and global cognitive status relat...,Alzheimer's Disease Neuroimaging Initiative (A...,ADNI,adni
3,002203f0-1c57-4400-abc1-b783c4085743,A Hybrid Geometric–Statistical Deformable Mode...,Alzheimer's Disease Neuroimaging Initiative (A...,ADNI,adni
4,00243b98-f868-45e4-9b83-0c346c7ecad5,Five-Year Growth Trajectories of Kindergarten ...,Early Childhood Longitudinal Study,Early Childhood Longitudinal Study,early childhood longitudinal study
5,00248da3-ac1d-48fa-a95e-cc88553f9583,Evaluation of Movement Speed and Reaction Time...,Baltimore Longitudinal Study of Aging (BLSA)|B...,Baltimore Longitudinal Study of Aging (BLSA)|B...,baltimore longitudinal study of aging blsa |ba...
6,002cbf56-5158-4ec7-83fd-51fa7829bb13,Energy Expenditure in Older People Hospitalize...,Baltimore Longitudinal Study of Aging (BLSA)|B...,Baltimore Longitudinal Study of Aging (BLSA)|B...,baltimore longitudinal study of aging blsa |ba...
7,002fdc24-9ee2-42b5-b051-373faca90c4e,Study of the Influence of Age in 18F-FDG PET I...,Alzheimer's Disease Neuroimaging Initiative (A...,ADNI,adni
8,0030f840-8505-49c0-9991-da0d4c6c9496,MR‐assisted PET motion correction in simultane...,Alzheimer's Disease Neuroimaging Initiative (A...,ADNI,adni
9,0035a1ba-6d1e-487b-bc3d-d49c5d64a3e9,A General Fast Registration Framework by Learn...,Baltimore Longitudinal Study of Aging (BLSA)|B...,Baltimore Longitudinal Study of Aging (BLSA)|B...,baltimore longitudinal study of aging blsa |ba...


In [9]:
import numpy as np
from datasets import load_metric

metric = load_metric("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
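
Note that `compute_metrics` is not passed to the `Trainer` later in this notebook. If it were enabled for MLM, the labels produced by the language-modeling collator would contain `-100` at every position that was not masked, and those positions should probably be excluded from the score. The sketch below shows that filtering in plain Python on nested lists; `masked_accuracy` and the toy inputs are hypothetical and only illustrate the idea.

```python
IGNORE_INDEX = -100  # label value assigned to positions that were not masked

def masked_accuracy(logits, labels):
    """Accuracy over masked positions only.

    logits: nested lists shaped [batch][seq][vocab]
    labels: nested lists shaped [batch][seq], IGNORE_INDEX where not masked
    """
    correct = total = 0
    for seq_logits, seq_labels in zip(logits, labels):
        for token_logits, label in zip(seq_logits, seq_labels):
            if label == IGNORE_INDEX:
                continue  # this position does not contribute to the MLM loss
            predicted = max(range(len(token_logits)), key=token_logits.__getitem__)
            correct += int(predicted == label)
            total += 1
    return {"accuracy": correct / total if total else 0.0}

# One sequence of three tokens over a toy vocabulary of size 3.
toy_logits = [[[0.1, 2.0, 0.3], [1.5, 0.2, 0.1], [0.0, 0.1, 3.0]]]
toy_labels = [[1, IGNORE_INDEX, 0]]
result = masked_accuracy(toy_logits, toy_labels)
print(result)
```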

# Prepare data for MLM training

### Auxiliary functions

In [10]:
def clean_paper_sentence(s):
    """
    This function is essentially clean_text without lowercasing.
    """
    s = re.sub('[^A-Za-z0-9]+', ' ', str(s)).strip()
    s = re.sub(' +', ' ', s)
    return s

def shorten_sentences(sentences):
    """
    Sentences with more than MAX_LENGTH words are split into
    overlapping chunks; consecutive chunks share OVERLAP words.
    """
    short_sentences = []
    for sentence in sentences:
        words = sentence.split()
        if len(words) > MAX_LENGTH:
            for p in range(0, len(words), MAX_LENGTH - OVERLAP):
                short_sentences.append(' '.join(words[p:p+MAX_LENGTH]))
        else:
            short_sentences.append(sentence)
    return short_sentences

def find_sublist(big_list, small_list):
    """
    Find all start positions of $small_list in $big_list.
    """
    all_positions = []
    for i in range(len(big_list) - len(small_list) + 1):
        if small_list == big_list[i:i+len(small_list)]:
            all_positions.append(i)
    
    return all_positions

def jaccard_similarity_list(l1, l2):
    """
    Return the Jaccard Similarity score of 2 lists.
    """
    intersection = len(list(set(l1).intersection(l2)))
    union = (len(l1) + len(l2)) - intersection
    return float(intersection) / union

connection_tokens = {'s', 'of', 'and', 'in', 'on', 'for', 'data', 'dataset'}
def find_negative_candidates(sentence, labels):
    """
    Extract negative samples for Masked Dataset Modeling from a given $sentence.
    A negative candidate should be a continuous sequence of at least 2 words, 
    each of these words either has the first letter in uppercase or is one of
    the connection words ($connection_tokens). Furthermore, the connection 
    tokens are not allowed to appear at the beginning and the end of the
    sequence. Lastly, the sequence must differ sufficiently from every
    ground truth label (Jaccard similarity below 0.75).
    """
    def candidate_qualified(words, labels):
        while len(words) and words[0].lower() in connection_tokens:
            words = words[1:]
        while len(words) and words[-1].lower() in connection_tokens:
            words = words[:-1]
        
        return len(words) >= 2 and \
               all(jaccard_similarity_list(words, label) < 0.75 for label in labels)
    
    candidates = []
    
    phrase_start, phrase_end = -1, -1
    # The scan starts at index 1, so the first word of a sentence is never part
    # of a candidate (presumably because it is capitalized regardless).
    for idx in range(1, len(sentence)):
        word = sentence[idx]
        if word[0].isupper() or word in connection_tokens:
            if phrase_start == -1:
                phrase_start = phrase_end = idx
            else:
                phrase_end = idx
        else:
            if phrase_start != -1:
                if candidate_qualified(sentence[phrase_start:phrase_end+1], labels):
                    candidates.append((phrase_start, phrase_end))
                phrase_start = phrase_end = -1
    
    if phrase_start != -1:
        if candidate_qualified(sentence[phrase_start:phrase_end+1], labels):
            candidates.append((phrase_start, phrase_end))
    
    return candidates
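
One detail of `jaccard_similarity_list` is worth seeing on a concrete input: the intersection is computed on sets, but the union is derived from the list lengths, so repeated words lower the score. The sketch below copies that arithmetic and compares it with a purely set-based variant. `jaccard_similarity_set` is a hypothetical alternative, not part of the notebook, and the example phrase is made up.

```python
def jaccard_similarity_list(l1, l2):
    # Same arithmetic as the helper above.
    intersection = len(set(l1).intersection(l2))
    union = (len(l1) + len(l2)) - intersection
    return float(intersection) / union

def jaccard_similarity_set(l1, l2):
    # Hypothetical variant: intersection and union are both taken over sets.
    s1, s2 = set(l1), set(l2)
    return len(s1 & s2) / len(s1 | s2)

label = "Early Childhood Longitudinal Study".split()
phrase = "Early Childhood Longitudinal Study of Early Childhood".split()

list_score = jaccard_similarity_list(phrase, label)  # 4 / (7 + 4 - 4)
set_score = jaccard_similarity_set(phrase, label)    # 4 / 5
print(round(list_score, 3), set_score)
```

With the 0.75 threshold used in `candidate_qualified`, the list-based score would accept this phrase as a negative candidate, while the set-based score would reject it as too close to the label.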

### Extract positive and negative samples

In [11]:
corpus = []
cnt_pos = 0
cnt_neg = 0

pbar = tqdm(total = len(train))
for paper_id, dataset_labels in train[['Id', 'dataset_label']].itertuples(index=False):
    labels = [clean_paper_sentence(label).split() for label in dataset_labels.split('|')]
    with open(f'{paper_train_folder}/{paper_id}.json', 'r') as f:
        paper = json.load(f)
    content = '. '.join(section['text'] for section in paper)
    sentences = set([clean_paper_sentence(sentence) for sentence in content.split('.')])
    sentences = shorten_sentences(sentences) # make sentences short
    sentences = [sentence for sentence in sentences if len(sentence) > 10] # only accept sentences with length > 10 chars
    sentences = [sentence.split() for sentence in sentences]
    
    # positive samples
    for sentence in sentences:
        for label in labels:
            for pos in find_sublist(sentence, label):
                dt_point = sentence[:pos] + [DATASET_SYMBOL] + sentence[pos+len(label):]
                corpus.append(' '.join(dt_point))
                cnt_pos += 1
    
    # negative samples
    for sentence in sentences:
        sentence_str = ' '.join(sentence)
        if all(w not in sentence_str for w in {'data', 'study'}):
            continue
        for phrase_start, phrase_end in find_negative_candidates(sentence, labels):
            dt_point = sentence[:phrase_start] + [NONDATA_SYMBOL] + sentence[phrase_end+1:]
            corpus.append(' '.join(dt_point))
            cnt_neg += 1
    
    # process bar
    pbar.update(1)
    pbar.set_description(f'Training data size: {cnt_pos} positives + {cnt_neg} negatives')

Training data size: 41670 positives + 65756 negatives: 100%|██████████| 7379/7379 [01:41<00:00, 71.12it/s]
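
To make the positive-sample step concrete, here is a minimal, self-contained sketch of the same substitution on a toy sentence. The sentence is made up for illustration; `find_sublist` is rewritten compactly but is meant to behave like the helper defined earlier.

```python
DATASET_SYMBOL = "$"

def find_sublist(big_list, small_list):
    """Return every start index at which small_list occurs in big_list."""
    n = len(small_list)
    return [i for i in range(len(big_list) - n + 1) if big_list[i:i + n] == small_list]

sentence = "Data from the Early Childhood Longitudinal Study were analyzed".split()
label = "Early Childhood Longitudinal Study".split()

# One training row per occurrence of the label, with that occurrence replaced.
positives = [
    " ".join(sentence[:pos] + [DATASET_SYMBOL] + sentence[pos + len(label):])
    for pos in find_sublist(sentence, label)
]
print(positives)
```

Matching is exact and case-sensitive at the word level, so a label only counts when it appears in the cleaned sentence with the same casing. If a label occurs twice in a sentence, two rows are produced, each with a single occurrence replaced.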

In [12]:
train[:10]

Unnamed: 0,Id,pub_title,dataset_title,dataset_label,cleaned_label
0,0007f880-0a9b-492d-9a58-76eb0b0e0bd7,The Impact of ICT Training on Income Generatio...,Program for the International Assessment of Ad...,Program for the International Assessment of Ad...,program for the international assessment of ad...
1,0008656f-0ba2-4632-8602-3017b44c2e90,Finnish Ninth Graders’ Gender Appropriateness ...,Trends in International Mathematics and Scienc...,Trends in International Mathematics and Scienc...,trends in international mathematics and scienc...
2,000efc17-13d8-433d-8f62-a3932fe4f3b8,Risk factors and global cognitive status relat...,Alzheimer's Disease Neuroimaging Initiative (A...,ADNI,adni
3,002203f0-1c57-4400-abc1-b783c4085743,A Hybrid Geometric–Statistical Deformable Mode...,Alzheimer's Disease Neuroimaging Initiative (A...,ADNI,adni
4,00243b98-f868-45e4-9b83-0c346c7ecad5,Five-Year Growth Trajectories of Kindergarten ...,Early Childhood Longitudinal Study,Early Childhood Longitudinal Study,early childhood longitudinal study
5,00248da3-ac1d-48fa-a95e-cc88553f9583,Evaluation of Movement Speed and Reaction Time...,Baltimore Longitudinal Study of Aging (BLSA)|B...,Baltimore Longitudinal Study of Aging (BLSA)|B...,baltimore longitudinal study of aging blsa |ba...
6,002cbf56-5158-4ec7-83fd-51fa7829bb13,Energy Expenditure in Older People Hospitalize...,Baltimore Longitudinal Study of Aging (BLSA)|B...,Baltimore Longitudinal Study of Aging (BLSA)|B...,baltimore longitudinal study of aging blsa |ba...
7,002fdc24-9ee2-42b5-b051-373faca90c4e,Study of the Influence of Age in 18F-FDG PET I...,Alzheimer's Disease Neuroimaging Initiative (A...,ADNI,adni
8,0030f840-8505-49c0-9991-da0d4c6c9496,MR‐assisted PET motion correction in simultane...,Alzheimer's Disease Neuroimaging Initiative (A...,ADNI,adni
9,0035a1ba-6d1e-487b-bc3d-d49c5d64a3e9,A General Fast Registration Framework by Learn...,Baltimore Longitudinal Study of Aging (BLSA)|B...,Baltimore Longitudinal Study of Aging (BLSA)|B...,baltimore longitudinal study of aging blsa |ba...


In [13]:
content = 'Studies using data from the $ NELS 88 found little evidence of a relationship between paid work hours and school \
performance once prior differences between individuals are taken into account Schoenhals Tienda and Schneider 1997 Warren LePore and Mare 2001 but see Marsh and \
Kleitman 2005 Drawing upon data from the # we illustrate how patterns of paid work in adolescence encompassing both the intensity hours and the duration of \
employment have lasting implications for post secondary schooling and wage attainments in early adulthood'

sentences = set([clean_paper_sentence(sentence) for sentence in content.split('.')])
sentences = shorten_sentences(sentences)
sentences = [sentence for sentence in sentences if len(sentence) > 10] # only accept sentences with length > 10 chars
sentences1 = [sentence.split() for sentence in sentences]

In [14]:
sentences[:5]

['Studies using data from the NELS 88 found little evidence of a relationship between paid work hours and school performance once prior differences between individuals are taken into account Schoenhals',
 'once prior differences between individuals are taken into account Schoenhals Tienda and Schneider 1997 Warren LePore and Mare 2001 but see Marsh and Kleitman 2005 Drawing upon data from the',
 'see Marsh and Kleitman 2005 Drawing upon data from the we illustrate how patterns of paid work in adolescence encompassing both the intensity hours and the duration of employment have',
 'both the intensity hours and the duration of employment have lasting implications for post secondary schooling and wage attainments in early adulthood',
 'early adulthood']
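
The last chunk above (`'early adulthood'`) is a by-product of the sliding window: judging by the output, the sample text has 82 words after cleaning, and the window starting at word 80 is emitted even though the previous window already reaches the end. The sketch below reproduces only the window offsets. `window_bounds_trimmed` is a hypothetical variant that stops early and is not what the notebook does.

```python
MAX_LENGTH = 30
OVERLAP = 10

def window_bounds(n_words, max_length=MAX_LENGTH, overlap=OVERLAP):
    """(start, end) word offsets matching the loop in shorten_sentences."""
    if n_words <= max_length:
        return [(0, n_words)]
    step = max_length - overlap
    return [(p, min(p + max_length, n_words)) for p in range(0, n_words, step)]

def window_bounds_trimmed(n_words, max_length=MAX_LENGTH, overlap=OVERLAP):
    """Hypothetical variant: stop once a window already reaches the end."""
    bounds = []
    for start, end in window_bounds(n_words, max_length, overlap):
        bounds.append((start, end))
        if end == n_words:
            break
    return bounds

bounds = window_bounds(82)
trimmed = window_bounds_trimmed(82)
print(bounds)
print(trimmed)
```

Such short tails are later dropped only if they are 10 characters or shorter, so a two-word tail like the one above stays in the list.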

In [15]:
sentences1

[['Studies',
  'using',
  'data',
  'from',
  'the',
  'NELS',
  '88',
  'found',
  'little',
  'evidence',
  'of',
  'a',
  'relationship',
  'between',
  'paid',
  'work',
  'hours',
  'and',
  'school',
  'performance',
  'once',
  'prior',
  'differences',
  'between',
  'individuals',
  'are',
  'taken',
  'into',
  'account',
  'Schoenhals'],
 ['once',
  'prior',
  'differences',
  'between',
  'individuals',
  'are',
  'taken',
  'into',
  'account',
  'Schoenhals',
  'Tienda',
  'and',
  'Schneider',
  '1997',
  'Warren',
  'LePore',
  'and',
  'Mare',
  '2001',
  'but',
  'see',
  'Marsh',
  'and',
  'Kleitman',
  '2005',
  'Drawing',
  'upon',
  'data',
  'from',
  'the'],
 ['see',
  'Marsh',
  'and',
  'Kleitman',
  '2005',
  'Drawing',
  'upon',
  'data',
  'from',
  'the',
  'we',
  'illustrate',
  'how',
  'patterns',
  'of',
  'paid',
  'work',
  'in',
  'adolescence',
  'encompassing',
  'both',
  'the',
  'intensity',
  'hours',
  'and',
  'the',
  'duration',
  'of',


In [16]:
len(corpus)

107426

### Save data to a file

In [17]:
with open('train_mlm.json', 'w') as f:
    for sentence in corpus:
        row_json = {'text':sentence}
        json.dump(row_json, f)
        f.write('\n')
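
The file written above follows the JSON Lines layout: one JSON object per line, each with a single `text` field. A minimal round trip with the standard library looks like this; an in-memory buffer stands in for the file, and the two sample sentences are made up.

```python
import io
import json

corpus = ["Data from the $ were analyzed", "Results of the # are shown"]

buffer = io.StringIO()  # stands in for the file opened in the cell above
for sentence in corpus:
    buffer.write(json.dumps({"text": sentence}) + "\n")

# Each line parses on its own, which is what makes the format easy to stream.
lines = buffer.getvalue().splitlines()
restored = [json.loads(line)["text"] for line in lines]
print(restored)
```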

# Fine-tune the Transformer

In [18]:
datasets = load_dataset('json', data_files={'train': 'train_mlm.json'})

datasets["train"][:5]

Downloading and preparing dataset json/default (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /root/.cache/huggingface/datasets/json/default-6756a42fd185587f/0.0.0/83d5b3a2f62630efc6b5315f00f20209b4ad91a00ac586597caee3a4da0bef02...


Dataset json downloaded and prepared to /root/.cache/huggingface/datasets/json/default-6756a42fd185587f/0.0.0/83d5b3a2f62630efc6b5315f00f20209b4ad91a00ac586597caee3a4da0bef02. Subsequent calls will reuse this data.


{'text': ['as the US Department of Education the US Department of commerce the OECD $ and the European Commission',
  'Participants chosen for this study were above 18 years old labor law in Lebanon sets the eligibility age for work at 18 years who have completed # training',
  'When evaluating the data qualitatively the results indicate a positive impact on beneficiaries who attended the # training',
  'The aim of this study was to identify if acquiring ICT skills through # training program improved income generation opportunities after 3 months for DOT Lebanon s',
  'The aim of this study was to identify if acquiring ICT skills through DOT Lebanon s ICT training program improved income generation opportunities after 3 months #']}

### Tokenize and collate data

In [19]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)


In [20]:
tokenizer

PreTrainedTokenizerFast(name_or_path='bert-base-cased', vocab_size=28996, model_max_len=512, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

In [21]:
def tokenize_function(examples):
    return tokenizer(examples["text"])

tokenized_datasets = datasets.map(tokenize_function, batched=True, num_proc=1, remove_columns=["text"])


In [22]:
tokenized_datasets['train']['input_ids'][:5]

[[101,
  1112,
  1103,
  1646,
  1951,
  1104,
  2531,
  1103,
  1646,
  1951,
  1104,
  10678,
  1103,
  152,
  8231,
  2137,
  109,
  1105,
  1103,
  1735,
  2827,
  102],
 [101,
  4539,
  27989,
  21868,
  3468,
  1111,
  1142,
  2025,
  1127,
  1807,
  1407,
  1201,
  1385,
  5530,
  1644,
  1107,
  7940,
  3741,
  1103,
  11768,
  1425,
  1111,
  1250,
  1120,
  1407,
  1201,
  1150,
  1138,
  2063,
  108,
  2013,
  102],
 [101,
  1332,
  27698,
  1103,
  2233,
  186,
  4746,
  24936,
  1193,
  1103,
  2686,
  5057,
  170,
  3112,
  3772,
  1113,
  26181,
  11470,
  27989,
  5927,
  1150,
  2323,
  1103,
  108,
  2013,
  102],
 [101,
  1109,
  6457,
  1104,
  1142,
  2025,
  1108,
  1106,
  6183,
  1191,
  14585,
  146,
  16647,
  4196,
  1194,
  108,
  2013,
  1788,
  4725,
  2467,
  3964,
  6305,
  1170,
  124,
  1808,
  1111,
  141,
  14697,
  7940,
  188,
  102],
 [101,
  1109,
  6457,
  1104,
  1142,
  2025,
  1108,
  1106,
  6183,
  1191,
  14585,
  146,
  16647,
  4196,
  1

In [23]:
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

In [24]:
data_collator

DataCollatorForLanguageModeling(tokenizer=PreTrainedTokenizerFast(name_or_path='bert-base-cased', vocab_size=28996, model_max_len=512, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}), mlm=True, mlm_probability=0.15)
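
With `mlm_probability=0.15`, the collator picks roughly 15% of the non-special tokens in each sequence at random. Of the picked tokens, about 80% are replaced by `[MASK]`, 10% by a random token, and 10% are left unchanged; every other position gets the label `-100`. The sketch below imitates that scheme in plain Python. It is a simplified stand-in rather than the library code, and the special-token ids (including `103` for `[MASK]`) are assumptions about the `bert-base-cased` vocabulary.

```python
import random

IGNORE_INDEX = -100
VOCAB_SIZE = 28996                    # vocab_size reported by the tokenizer above
MASK_ID = 103                         # assumed id of [MASK]
SPECIAL_IDS = {0, 100, 101, 102, 103} # assumed ids of [PAD], [UNK], [CLS], [SEP], [MASK]

def mask_tokens(input_ids, mlm_probability=0.15, rng=random):
    """Return (masked_inputs, labels) following the 80/10/10 rule."""
    inputs = list(input_ids)
    labels = [IGNORE_INDEX] * len(inputs)
    for i, token in enumerate(input_ids):
        if token in SPECIAL_IDS or rng.random() >= mlm_probability:
            continue                   # position not selected for prediction
        labels[i] = token              # the model must recover the original token
        roll = rng.random()
        if roll < 0.8:
            inputs[i] = MASK_ID        # replace with [MASK]
        elif roll < 0.9:
            inputs[i] = rng.randrange(VOCAB_SIZE)  # replace with a random token
        # otherwise the original token is kept as input
    return inputs, labels

# First tokenized sample from the output above; 109 is the id of `$`.
ids = [101, 1112, 1103, 1646, 1951, 1104, 2531, 1103, 1646, 1951, 1104,
       10678, 1103, 152, 8231, 2137, 109, 1105, 1103, 1735, 2827, 102]
example_inputs, example_labels = mask_tokens(ids, rng=random.Random(0))
```

Under this scheme the `$` and `#` symbols are treated like any other token, so each one is selected for prediction only about 15% of the time it appears in a batch.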

### Load pre-trained model and fine-tune

In [25]:
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)


Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [26]:
model

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=Tr

In [27]:
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="output-mlm",
    evaluation_strategy = "no",
    learning_rate=2e-5,
    weight_decay=0.01,
    save_steps=12000,
    num_train_epochs=2,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    #eval_dataset= tokenized_datasets["eval_train"],
    #compute_metrics= compute_metrics,
    data_collator=data_collator
)

In [28]:
trainer.train()

Step,Training Loss
500,3.2734
1000,3.0027
1500,2.7437
2000,2.7062
2500,2.6507
3000,2.6209
3500,2.5444
4000,2.4936
4500,2.5068
5000,2.4843


TrainOutput(global_step=26858, training_loss=2.3155448675528474, metrics={'train_runtime': 6303.0097, 'train_samples_per_second': 4.261, 'total_flos': 5686181292023232.0, 'epoch': 2.0, 'init_mem_cpu_alloc_delta': 1534328832, 'init_mem_gpu_alloc_delta': 433891840, 'init_mem_cpu_peaked_delta': 314474496, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 82259968, 'train_mem_gpu_alloc_delta': 1308858368, 'train_mem_cpu_peaked_delta': 268013568, 'train_mem_gpu_peaked_delta': 1292784128})
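
The reported `global_step` can be checked with a little arithmetic. The sketch assumes the default per-device batch size of 8 in `TrainingArguments` (it is not set explicitly above) and a single device; under those assumptions the numbers match the output.

```python
import math

num_samples = 107426  # len(corpus) reported earlier in the notebook
batch_size = 8        # assumed: default per-device train batch size, one device
num_epochs = 2        # num_train_epochs
save_steps = 12000    # save_steps

steps_per_epoch = math.ceil(num_samples / batch_size)
total_steps = steps_per_epoch * num_epochs

# Steps at which a checkpoint would be written to output_dir.
checkpoint_steps = list(range(save_steps, total_steps + 1, save_steps))

print(steps_per_epoch, total_steps, checkpoint_steps)
```

With `save_steps=12000`, only two intermediate checkpoints fall inside the run, which keeps disk usage in `output-mlm` small.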

### Save model

In [29]:
trainer.model.save_pretrained('mlm-model')

### Save tokenizer and config

In [30]:
config = AutoConfig.from_pretrained(model_checkpoint)

tokenizer.save_pretrained('model_tokenizer')
config.save_pretrained('model_tokenizer')