In [1]:
import itertools
import math
import random

import pandas as pd
import transformers
from datasets import ClassLabel, load_dataset
from IPython.display import HTML, display
from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

In [2]:
# Swahili text corpus from the Hugging Face Hub. load_dataset returns a
# DatasetDict keyed by split (train / test / validation, each with a "text" column).
# NOTE: `datasets` here is that DatasetDict, not the `datasets` library
# (only individual names are imported from the library above).
dataset_name = "uestc-swahili/swahili"
datasets = load_dataset(path=dataset_name)

In [3]:
# Display the split names and row counts of the loaded DatasetDict.
datasets

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 42069
    })
    test: Dataset({
        features: ['text'],
        num_rows: 3371
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 3372
    })
})

In [4]:
# Peek at the first training example (a dict with a single "text" field).
datasets["train"][0]

{'text': 'taarifa hiyo ilisema kuwa ongezeko la joto la maji juu ya wastani katikati ya bahari ya UNK inaashiria kuwepo kwa mvua za el nino UNK hadi mwishoni mwa april ishirini moja sifuri imeelezwa kuwa ongezeko la joto magharibi mwa bahari ya hindi linatarajiwa kuhamia katikati ya bahari hiyo hali ambayo itasababisha pepo kutoka kaskazini mashariki kuvuma kuelekea bahari ya hindi'}

In [5]:
def show_random_elements(dataset, num_examples=10, seed=None):
    """Render ``num_examples`` distinct, randomly chosen rows of ``dataset`` as an HTML table.

    Columns whose feature type is ``ClassLabel`` are shown as their label names
    instead of integer ids.

    Args:
        dataset: a Hugging Face ``Dataset`` (indexed with a list of row ids and
            exposing ``.features``).
        num_examples: number of distinct rows to show; must not exceed ``len(dataset)``.
        seed: optional seed for a reproducible selection; ``None`` draws from the
            global ``random`` state.

    Raises:
        AssertionError: if ``num_examples`` is larger than the dataset.
    """
    assert num_examples <= len(
        dataset
    ), "Can't pick more elements than there are in the dataset."
    # random.sample draws distinct indices in a single pass, with no retry loop
    # and no linear membership checks against the already-picked list.
    rng = random if seed is None else random.Random(seed)
    picks = rng.sample(range(len(dataset)), num_examples)

    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, ClassLabel):
            # Bind this column's label names eagerly rather than via the loop variable.
            df[column] = df[column].transform(lambda i, names=typ.names: names[i])
    display(HTML(df.to_html()))


# Ten random raw training rows, to eyeball the text.
show_random_elements(datasets["train"], num_examples=10)

Unnamed: 0,text
0,watakuja ili kumuangalia wanataka kumuona akicheza mechi ngumu
1,anafafanua kuwa kama tanzania itajipanga vizuri ina uwezo wa kulisha wananchi wake na nchi zote jirani kwani ina rasilimali ya maji ya maziwa mito ya kutosha na ardhi yenye rutuba nzuri
2,anasema endapo kura za kuanzisha jopo hilo UNK atachukua jukumu la kusikiliza madai hayo akiamini kuwa wadau wengine watafanya hivyo ili isije kutokea tena
3,maagizo ya rais UNK kesi kwani utekelezaji wa kuwasha mitambo hiyo unapitia kwa msimamizi wa muda aliyeteuliwa na mahakama
4,akizunguzia suala la mauaji ya albino rais kikwete alisema kuwa ongezeko la mauaji hayo linatokana na imani za kishirikina
5,hakuna utafiti wa kisayansi UNK kwamba dawa hizo zinasababisha saratani watu hao ni waongo na UNK na jamii kwa kuwa UNK na umasikini wa fikra tena mawazo mgando alisema na kuongeza kuwa moja ya faida ya uzazi ya uzazi wa mpango ni kuisadia familia kupanga shughuli za maendeleo ikiwemo kurudisha afya ya mama kabla ya kubeba mimba
6,ni vitu vigeni ambavyo wanatakiwa kukutana navyo mara nyingine tena ili UNK
7,hili UNK wengi kutokana dharau ya wazi ya mkapa UNK kwa waandishi wa tanzania na kuwathamini waandishi wa nchi za magharibi
8,ferguson amefanikiwa kusajili wachezaji wa nje pia england kwa UNK
9,hicho ndio ninachotaka alisema


In [6]:
def tokenize_function(examples):
    """Tokenize a batch of raw examples with the notebook-level ``tokenizer``.

    Meant for ``Dataset.map(..., batched=True)``: ``examples["text"]`` is the
    batch of strings, and the tokenizer's output (``input_ids`` and
    ``attention_mask`` for this checkpoint) becomes the new columns. No
    truncation or padding is applied here.
    """
    batch_texts = examples["text"]
    return tokenizer(batch_texts)


# Load the fast tokenizer for the pretrained checkpoint, then tokenize every
# split with 4 worker processes. The raw "text" column is dropped so only the
# tokenizer outputs remain.
model_checkpoint = "xlm-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=model_checkpoint,
    use_fast=True,
)
tokenized_datasets = datasets.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"],
    num_proc=4,
)



In [7]:
# Same splits and row counts; "text" has been replaced by input_ids / attention_mask.
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 42069
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 3371
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 3372
    })
})

In [8]:
# Inspect one tokenized example. Ids begin with 0 and end with 2, presumably the
# tokenizer's sentence start/end special tokens (confirm via tokenizer.convert_ids_to_tokens).
tokenized_datasets["train"][10]

{'input_ids': [0,
  150722,
  151,
  12967,
  151,
  149362,
  10790,
  6,
  180323,
  177769,
  228,
  186214,
  28803,
  6,
  180323,
  228,
  85407,
  14102,
  24,
  32825,
  75,
  68428,
  10475,
  760,
  291,
  4390,
  4542,
  12209,
  686,
  259,
  130719,
  1499,
  92704,
  291,
  1783,
  157,
  2],
 'attention_mask': [1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1]}

In [9]:
# Ten random tokenized training rows.
show_random_elements(tokenized_datasets["train"], num_examples=10)

Unnamed: 0,input_ids,attention_mask
0,"[0, 56038, 11, 68163, 12116, 3877, 210909, 1922, 41144, 92289, 23317, 20973, 13037, 51795, 210131, 70037, 21, 73365, 291, 43319, 151, 92368, 760, 23103, 151, 1922, 57942, 39665, 2]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]"
1,"[0, 15018, 15800, 370, 213001, 31851, 75, 68428, 41002, 760, 6, 48949, 11412, 21, 91750, 2197, 45725, 201451, 35466, 24, 38765, 148301, 1499, 139470, 75, 145, 31, 30601, 151, 2285, 182828, 10219, 2]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]"
2,"[0, 876, 24086, 242113, 153130, 259, 1436, 63220, 14, 259, 75206, 634, 259, 83408, 1622, 123490, 14, 300, 36593, 142637, 22476, 166574, 6, 180323, 26542, 38388, 11, 3125, 59428, 90965, 4848, 2959, 24, 135637, 24, 17163, 14, 24258, 18333, 300, 147119, 760, 42323, 20439, 279, 57268, 12787, 26542, 93, 2]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]"
3,"[0, 347, 169573, 259, 180486, 19562, 56222, 111263, 6, 180323, 153130, 70768, 43744, 5431, 63766, 229178, 151, 1922, 186214, 30714, 63549, 10744, 21428, 147555, 72543, 55617, 24, 97804, 1922, 21528, 11, 1622, 188197, 1608, 41217, 151, 57055, 2]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]"
4,"[0, 7975, 23023, 38424, 105177, 10, 33169, 88043, 5538, 184037, 23, 21256, 6, 180323, 3886, 2]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]"
5,"[0, 177537, 6, 180323, 13487, 24, 466, 34, 120209, 760, 200189, 80, 203463, 23, 21256, 9873, 172065, 81760, 223090, 56070, 50297, 185840, 24, 200189, 19712, 2]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]"
6,"[0, 169182, 259, 43879, 3597, 200, 44910, 24, 700, 596, 10391, 133, 1622, 3441, 151, 124, 33970, 74667, 6, 180323, 228, 14336, 634, 21, 101950, 7126, 18285, 113824, 259, 52606, 763, 11, 75, 61166, 75, 196654, 86286, 80, 1257, 1352, 658, 1622, 51683, 26045, 2224, 1922, 34641, 197930, 80, 809, 3390, 151, 167892, 24, 58606, 31, 26045, 2]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]"
7,"[0, 18285, 237, 8380, 24, 57879, 79, 771, 8767, 50032, 4571, 50397, 17, 771, 13598, 4522, 15026, 78, 22455, 259, 228, 11222, 56992, 11, 178235, 2]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]"
8,"[0, 6, 135187, 16115, 291, 58187, 95491, 19562, 1922, 300, 14102, 24, 228, 2285, 33265, 51368, 5748, 2224, 6, 100352, 200, 33226, 1608, 347, 195209, 259, 61131, 14, 6758, 10790, 2]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]"
9,"[0, 77420, 259, 119, 11, 71596, 238, 115, 1462, 20205, 400, 1436, 151, 228, 139294, 18285, 188197, 145211, 24222, 110867, 24, 6, 180323, 1072, 416, 151, 72543, 9908, 24, 81760, 259, 18398, 316, 17540, 2]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]"


In [10]:
# `block_size` is the length (in tokens) of the training chunks built by
# group_texts below. The model accepts up to tokenizer.model_max_length tokens;
# a smaller value keeps training cheap. `batch_size` is the number of examples
# per map() batch used when grouping.
print(f"Model maximum block size: {tokenizer.model_max_length}")
block_size = 128
batch_size = 1000
print(f"Block size: {block_size}")

# Rough estimate of how many blocks each split will yield. The real count is a
# little lower because group_texts drops the incomplete tail of every map() batch.
for split_name, split in tokenized_datasets.items():
    num_tokens = sum(len(input_ids) for input_ids in split["input_ids"])
    print(
        f"The number of tokens in {split_name}: {num_tokens}, this will be ~{round(num_tokens / block_size)} blocks."
    )

Model maximum block size: 512
Block size: 128
The number of tokens in train: 1865542, this will be ~14575 blocks.
The number of tokens in test: 166669, this will be ~1302 blocks.
The number of tokens in validation: 159984, this will be ~1250 blocks.


In [11]:
def group_texts(examples, chunk_size=None):
    """Concatenate a batch of tokenized examples and re-split it into fixed-length blocks.

    Intended for ``Dataset.map(..., batched=True)``. Every column in ``examples``
    (e.g. ``input_ids`` / ``attention_mask``) is flattened into one sequence and
    cut into consecutive chunks of the block length. The trailing remainder,
    shorter than one chunk, is dropped. A ``labels`` column is added as a copy of
    ``input_ids``.

    NOTE(review): DataCollatorForLanguageModeling(mlm=True) likely rebuilds labels
    itself, so this column may be redundant; confirm before relying on it.

    Args:
        examples: mapping of column name -> list of per-example token lists.
            Must contain ``input_ids``.
        chunk_size: block length in tokens; ``None`` uses the notebook-level
            ``block_size``.

    Returns:
        dict with the same keys as ``examples`` plus ``labels``; each value is a
        list of equal-length token lists.

    Raises:
        ValueError: if the block length is not positive.
    """
    size = block_size if chunk_size is None else chunk_size
    if size <= 0:
        raise ValueError(f"block size must be positive, got {size}")

    # itertools.chain flattens in linear time; sum(lists, []) re-copies the
    # growing accumulator on every addition, which is quadratic.
    concatenated = {
        k: list(itertools.chain.from_iterable(v)) for k, v in examples.items()
    }
    total_length = len(next(iter(concatenated.values())))
    # Drop the remainder that doesn't fill a whole block.
    total_length = (total_length // size) * size
    result = {
        k: [t[i : i + size] for i in range(0, total_length, size)]
        for k, t in concatenated.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

In [12]:
# Regroup every split into fixed-length blocks of `block_size` tokens, using the
# `batch_size` configured alongside `block_size`. group_texts drops the leftover
# tokens of each batch, so larger batches discard fewer tokens.
lm_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    batch_size=batch_size,
    num_proc=4,
)

In [13]:
# Rows are now fixed-length token blocks (fewer rows than before); a labels column was added.
lm_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 14552
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 1301
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 1248
    })
})

In [14]:
# Ten random grouped blocks; sentence boundaries (ids 2, 0) now appear mid-block.
show_random_elements(lm_datasets["train"], num_examples=10)

Unnamed: 0,input_ids,attention_mask,labels
0,"[180323, 173, 192, 179616, 151, 2197, 45725, 2, 0, 20205, 234538, 151, 180486, 19562, 151, 134375, 173, 192, 24, 10805, 24205, 12926, 156300, 18333, 1337, 145, 206068, 137566, 300, 1053, 433, 23210, 151, 83, 50637, 93, 24, 33780, 151, 2197, 45725, 2626, 2, 0, 923, 913, 145, 156300, 5748, 300, 68524, 14, 347, 18, 63991, 100484, 7502, 11, 156, 28783, 5941, 2514, 200, 21045, 1648, 516, 308, 27651, 200, 519, 192, 24, 3114, 10521, 24, 10805, 156300, 99477, 5431, 206068, 137566, 180486, 19562, 2, 0, 20205, 180486, 19562, 6, 180323, 24, 52041, 45046, 259, 14378, 9353, 6, 180323, 94194, 24258, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]","[180323, 173, 192, 179616, 151, 2197, 45725, 2, 0, 20205, 234538, 151, 180486, 19562, 151, 134375, 173, 192, 24, 10805, 24205, 12926, 156300, 18333, 1337, 145, 206068, 137566, 300, 1053, 433, 23210, 151, 83, 50637, 93, 24, 33780, 151, 2197, 45725, 2626, 2, 0, 923, 913, 145, 156300, 5748, 300, 68524, 14, 347, 18, 63991, 100484, 7502, 11, 156, 28783, 5941, 2514, 200, 21045, 1648, 516, 308, 27651, 200, 519, 192, 24, 3114, 10521, 24, 10805, 156300, 99477, 5431, 206068, 137566, 180486, 19562, 2, 0, 20205, 180486, 19562, 6, 180323, 24, 52041, 45046, 259, 14378, 9353, 6, 180323, 94194, 24258, ...]"
1,"[22539, 1681, 347, 35251, 259, 11625, 20646, 6544, 432, 21, 146904, 11, 142636, 347, 18357, 35251, 259, 23546, 259, 8840, 1556, 1462, 29970, 142, 10744, 242193, 2420, 24, 61192, 151, 11625, 31, 2, 0, 104697, 93, 24205, 135382, 30938, 300, 5614, 42753, 979, 1053, 433, 370, 84641, 108553, 18802, 143, 12781, 1622, 104697, 151, 34912, 350, 24, 129543, 151, 279, 3584, 49270, 300, 9382, 83, 50637, 93, 24, 308, 157, 5431, 210832, 24, 2223, 26770, 324, 50742, 2223, 1681, 5485, 11230, 24, 154643, 3262, 95560, 2, 0, 347, 3767, 95563, 259, 25988, 1622, 51368, 151, 139803, 151, 224369, 259, 170712, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]","[22539, 1681, 347, 35251, 259, 11625, 20646, 6544, 432, 21, 146904, 11, 142636, 347, 18357, 35251, 259, 23546, 259, 8840, 1556, 1462, 29970, 142, 10744, 242193, 2420, 24, 61192, 151, 11625, 31, 2, 0, 104697, 93, 24205, 135382, 30938, 300, 5614, 42753, 979, 1053, 433, 370, 84641, 108553, 18802, 143, 12781, 1622, 104697, 151, 34912, 350, 24, 129543, 151, 279, 3584, 49270, 300, 9382, 83, 50637, 93, 24, 308, 157, 5431, 210832, 24, 2223, 26770, 324, 50742, 2223, 1681, 5485, 11230, 24, 154643, 3262, 95560, 2, 0, 347, 3767, 95563, 259, 25988, 1622, 51368, 151, 139803, 151, 224369, 259, 170712, ...]"
2,"[50350, 7180, 137365, 760, 13598, 9873, 53229, 7933, 763, 2057, 5234, 18285, 15018, 13598, 7431, 4571, 52327, 32825, 129543, 5748, 760, 1922, 23317, 23, 2832, 11, 38435, 8426, 373, 107, 171, 964, 33778, 760, 23541, 6, 180323, 760, 148, 22242, 24, 228, 104601, 48, 561, 1499, 228, 33, 15586, 24, 173975, 12967, 172, 151, 218761, 153130, 2, 0, 259, 104601, 658, 259, 75206, 634, 259, 48, 561, 300, 70521, 34721, 8971, 771, 21383, 118468, 1158, 96713, 7126, 6, 180323, 35466, 24, 173975, 151, 104697, 24258, 80, 115130, 24, 23317, 2, 0, 177391, 151, 76937, 24, 91630, 259, 49157, 17, 139818, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]","[50350, 7180, 137365, 760, 13598, 9873, 53229, 7933, 763, 2057, 5234, 18285, 15018, 13598, 7431, 4571, 52327, 32825, 129543, 5748, 760, 1922, 23317, 23, 2832, 11, 38435, 8426, 373, 107, 171, 964, 33778, 760, 23541, 6, 180323, 760, 148, 22242, 24, 228, 104601, 48, 561, 1499, 228, 33, 15586, 24, 173975, 12967, 172, 151, 218761, 153130, 2, 0, 259, 104601, 658, 259, 75206, 634, 259, 48, 561, 300, 70521, 34721, 8971, 771, 21383, 118468, 1158, 96713, 7126, 6, 180323, 35466, 24, 173975, 151, 104697, 24258, 80, 115130, 24, 23317, 2, 0, 177391, 151, 76937, 24, 91630, 259, 49157, 17, 139818, ...]"
3,"[11418, 102, 1686, 3125, 150561, 33627, 19057, 24, 78130, 136507, 4268, 177391, 7575, 9873, 109706, 24, 228, 11, 90424, 73365, 19562, 2, 0, 876, 24086, 159886, 923, 14093, 150561, 923, 101818, 10805, 3877, 109706, 73365, 19562, 18285, 914, 20236, 259, 13820, 173239, 64120, 203933, 24, 3877, 54635, 2959, 760, 40129, 55916, 173239, 43860, 39, 67, 3116, 2, 0, 15481, 151, 39375, 923, 14093, 1922, 59817, 2477, 28144, 151, 76971, 91083, 6, 100352, 1622, 4486, 14, 5748, 55916, 923, 979, 172, 26090, 146904, 31, 206183, 15800, 923, 6758, 23639, 34, 166722, 2, 0, 18285, 177537, 923, 15026, 75, 68428, 259, 57096, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]","[11418, 102, 1686, 3125, 150561, 33627, 19057, 24, 78130, 136507, 4268, 177391, 7575, 9873, 109706, 24, 228, 11, 90424, 73365, 19562, 2, 0, 876, 24086, 159886, 923, 14093, 150561, 923, 101818, 10805, 3877, 109706, 73365, 19562, 18285, 914, 20236, 259, 13820, 173239, 64120, 203933, 24, 3877, 54635, 2959, 760, 40129, 55916, 173239, 43860, 39, 67, 3116, 2, 0, 15481, 151, 39375, 923, 14093, 1922, 59817, 2477, 28144, 151, 76971, 91083, 6, 100352, 1622, 4486, 14, 5748, 55916, 923, 979, 172, 26090, 146904, 31, 206183, 15800, 923, 6758, 23639, 34, 166722, 2, 0, 18285, 177537, 923, 15026, 75, 68428, 259, 57096, ...]"
4,"[402, 6, 212509, 24, 6, 180323, 15018, 125786, 2057, 5234, 78, 157391, 760, 13598, 78, 108135, 2, 0, 24, 27308, 228, 135317, 1922, 15018, 4268, 63394, 300, 347, 41092, 1783, 1849, 14, 106037, 20385, 30601, 108093, 2, 0, 135770, 300, 54150, 48740, 7816, 55, 220, 39, 57949, 1263, 4806, 151, 63300, 2, 0, 4268, 75, 52141, 286, 25107, 19567, 300, 10805, 22455, 120398, 300, 23354, 33957, 22170, 259, 200, 125407, 1608, 2057, 162, 116671, 760, 12926, 300, 19965, 1608, 93226, 77409, 69497, 11, 923, 146955, 101777, 24, 1053, 433, 38388, 14, 88, 3125, 7431, 24, 8132, 1053, 433, 24, 70, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]","[402, 6, 212509, 24, 6, 180323, 15018, 125786, 2057, 5234, 78, 157391, 760, 13598, 78, 108135, 2, 0, 24, 27308, 228, 135317, 1922, 15018, 4268, 63394, 300, 347, 41092, 1783, 1849, 14, 106037, 20385, 30601, 108093, 2, 0, 135770, 300, 54150, 48740, 7816, 55, 220, 39, 57949, 1263, 4806, 151, 63300, 2, 0, 4268, 75, 52141, 286, 25107, 19567, 300, 10805, 22455, 120398, 300, 23354, 33957, 22170, 259, 200, 125407, 1608, 2057, 162, 116671, 760, 12926, 300, 19965, 1608, 93226, 77409, 69497, 11, 923, 146955, 101777, 24, 1053, 433, 38388, 14, 88, 3125, 7431, 24, 8132, 1053, 433, 24, 70, ...]"
5,"[214262, 6, 180323, 1919, 399, 760, 203575, 4268, 923, 6, 180323, 2, 0, 10181, 151, 1979, 33426, 760, 48074, 246, 24, 5367, 31, 7575, 40900, 60879, 7022, 48740, 11, 71776, 12926, 33627, 35206, 669, 43860, 71387, 1922, 124803, 259, 53369, 21, 48, 169, 315, 1127, 169, 1592, 31, 4895, 33627, 48546, 15623, 1592, 31, 4895, 3877, 100285, 221, 51785, 2, 0, 45491, 39, 7575, 858, 118, 24, 84839, 18333, 155647, 192450, 148301, 1499, 206549, 158751, 1289, 114131, 24, 45491, 39, 151, 227221, 43879, 8051, 350, 6, 15556, 42191, 464, 31425, 20973, 13037, 11397, 12781, 31870, 24, 75361, 151, 6, 180323, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]","[214262, 6, 180323, 1919, 399, 760, 203575, 4268, 923, 6, 180323, 2, 0, 10181, 151, 1979, 33426, 760, 48074, 246, 24, 5367, 31, 7575, 40900, 60879, 7022, 48740, 11, 71776, 12926, 33627, 35206, 669, 43860, 71387, 1922, 124803, 259, 53369, 21, 48, 169, 315, 1127, 169, 1592, 31, 4895, 33627, 48546, 15623, 1592, 31, 4895, 3877, 100285, 221, 51785, 2, 0, 45491, 39, 7575, 858, 118, 24, 84839, 18333, 155647, 192450, 148301, 1499, 206549, 158751, 1289, 114131, 24, 45491, 39, 151, 227221, 43879, 8051, 350, 6, 15556, 42191, 464, 31425, 20973, 13037, 11397, 12781, 31870, 24, 75361, 151, 6, 180323, ...]"
6,"[80, 279, 35548, 11, 13487, 64856, 6, 180323, 228, 21392, 634, 24, 347, 69732, 86054, 26887, 2, 0, 20205, 10181, 151, 347, 69732, 86054, 12296, 90, 97804, 13487, 923, 99307, 634, 279, 35548, 11, 13487, 24, 231143, 9382, 64856, 6, 180323, 279, 43319, 24, 55649, 6, 180323, 59587, 24, 58742, 14, 1922, 152904, 7126, 75, 42237, 259, 6, 180323, 2, 0, 10181, 151, 347, 69732, 86054, 12296, 90, 228, 187132, 228, 110162, 15018, 561, 80, 31764, 21383, 4307, 238, 923, 99307, 634, 279, 35548, 11, 13487, 24, 231143, 9382, 18285, 760, 12967, 151, 228, 172, 9488, 7230, 3886, 12967, 151, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]","[80, 279, 35548, 11, 13487, 64856, 6, 180323, 228, 21392, 634, 24, 347, 69732, 86054, 26887, 2, 0, 20205, 10181, 151, 347, 69732, 86054, 12296, 90, 97804, 13487, 923, 99307, 634, 279, 35548, 11, 13487, 24, 231143, 9382, 64856, 6, 180323, 279, 43319, 24, 55649, 6, 180323, 59587, 24, 58742, 14, 1922, 152904, 7126, 75, 42237, 259, 6, 180323, 2, 0, 10181, 151, 347, 69732, 86054, 12296, 90, 228, 187132, 228, 110162, 15018, 561, 80, 31764, 21383, 4307, 238, 923, 99307, 634, 279, 35548, 11, 13487, 24, 231143, 9382, 18285, 760, 12967, 151, 228, 172, 9488, 7230, 3886, 12967, 151, ...]"
7,"[112, 192, 501, 3931, 314, 420, 1872, 254, 2, 0, 10276, 103115, 19562, 151, 200087, 228, 50297, 233132, 259, 1608, 112, 192, 13019, 100285, 491, 73187, 33978, 1463, 33396, 39545, 44892, 21999, 8054, 4707, 760, 347, 1436, 161, 347, 1436, 161, 24, 228, 60576, 1697, 25107, 2, 0, 1697, 25107, 2, 0, 1608, 112, 192, 2, 0, 1608, 112, 192, 2, 0, 1622, 69464, 31, 166093, 198451, 21, 228, 216947, 34, 11908, 400, 42237, 400, 200087, 233132, 27610, 13019, 64301, 1922, 153746, 760, 200, 133, 13019, 9606, 508, 99233, 402, 8132, 273, 151, 148, 12994, 220, 31, 38424, 228, 151983, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]","[112, 192, 501, 3931, 314, 420, 1872, 254, 2, 0, 10276, 103115, 19562, 151, 200087, 228, 50297, 233132, 259, 1608, 112, 192, 13019, 100285, 491, 73187, 33978, 1463, 33396, 39545, 44892, 21999, 8054, 4707, 760, 347, 1436, 161, 347, 1436, 161, 24, 228, 60576, 1697, 25107, 2, 0, 1697, 25107, 2, 0, 1608, 112, 192, 2, 0, 1608, 112, 192, 2, 0, 1622, 69464, 31, 166093, 198451, 21, 228, 216947, 34, 11908, 400, 42237, 400, 200087, 233132, 27610, 13019, 64301, 1922, 153746, 760, 200, 133, 13019, 9606, 508, 99233, 402, 8132, 273, 151, 148, 12994, 220, 31, 38424, 228, 151983, ...]"
8,"[36593, 491, 11440, 432, 83, 10734, 221309, 1239, 22242, 151, 107590, 1643, 172, 108553, 24, 233413, 108093, 3886, 10805, 41217, 34753, 98704, 24, 37482, 25785, 151, 135382, 36835, 151, 192352, 151, 180486, 19562, 2, 0, 9914, 44892, 22476, 8971, 47244, 11, 632, 557, 22205, 138599, 151, 13586, 1622, 91750, 6, 180323, 15018, 1922, 24, 6, 180323, 10805, 68833, 41002, 108553, 24, 37482, 25785, 151, 51795, 28819, 136824, 13611, 101447, 2, 0, 300, 9074, 50637, 10805, 14378, 21439, 41002, 10, 213001, 24, 228, 142761, 3886, 96232, 57055, 6057, 7, 6, 180323, 136824, 13611, 23810, 3948, 519, 18285, 138599, 4268, 19562, 13860, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]","[36593, 491, 11440, 432, 83, 10734, 221309, 1239, 22242, 151, 107590, 1643, 172, 108553, 24, 233413, 108093, 3886, 10805, 41217, 34753, 98704, 24, 37482, 25785, 151, 135382, 36835, 151, 192352, 151, 180486, 19562, 2, 0, 9914, 44892, 22476, 8971, 47244, 11, 632, 557, 22205, 138599, 151, 13586, 1622, 91750, 6, 180323, 15018, 1922, 24, 6, 180323, 10805, 68833, 41002, 108553, 24, 37482, 25785, 151, 51795, 28819, 136824, 13611, 101447, 2, 0, 300, 9074, 50637, 10805, 14378, 21439, 41002, 10, 213001, 24, 228, 142761, 3886, 96232, 57055, 6057, 7, 6, 180323, 136824, 13611, 23810, 3948, 519, 18285, 138599, 4268, 19562, 13860, ...]"
9,"[228, 50397, 10805, 300, 30035, 760, 265, 151, 2628, 113, 8782, 29970, 3177, 34, 24052, 1804, 11036, 6, 180323, 6, 180323, 6, 180323, 60384, 347, 58392, 2, 0, 6, 180323, 2, 0, 10, 182931, 70768, 228, 17112, 53091, 24, 80419, 220, 21383, 57, 4580, 77973, 2, 0, 75, 182472, 1922, 24, 22455, 29970, 180674, 7575, 300, 96232, 100722, 100193, 87626, 24, 228, 226985, 100193, 39545, 24, 77409, 2, 0, 78, 53318, 79, 9051, 2, 0, 12926, 94183, 771, 1622, 210025, 300, 10805, 9914, 9382, 9353, 347, 46578, 220, 28533, 57305, 71387, 10475, 24, 842, 282, 531, 347, 350, 1177, 6, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]","[228, 50397, 10805, 300, 30035, 760, 265, 151, 2628, 113, 8782, 29970, 3177, 34, 24052, 1804, 11036, 6, 180323, 6, 180323, 6, 180323, 60384, 347, 58392, 2, 0, 6, 180323, 2, 0, 10, 182931, 70768, 228, 17112, 53091, 24, 80419, 220, 21383, 57, 4580, 77973, 2, 0, 75, 182472, 1922, 24, 22455, 29970, 180674, 7575, 300, 96232, 100722, 100193, 87626, 24, 228, 226985, 100193, 39545, 24, 77409, 2, 0, 78, 53318, 79, 9051, 2, 0, 12926, 94183, 771, 1622, 210025, 300, 10805, 9914, 9382, 9353, 347, 46578, 220, 28533, 57305, 71387, 10475, 24, 842, 282, 531, 347, 350, 1177, 6, ...]"


In [15]:
# Load the pretrained checkpoint with a masked-language-modeling head. This must
# be the same checkpoint the tokenizer was loaded from. The unused pooler
# weights trigger the warning shown below.
model_checkpoint = "xlm-roberta-base"
model = AutoModelForMaskedLM.from_pretrained(
    pretrained_model_name_or_path=model_checkpoint
)

Some weights of the model checkpoint at xlm-roberta-base were not used when initializing XLMRobertaForMaskedLM: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing XLMRobertaForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [16]:
# Report model size in millions of parameters (rounded).
print(
    "{} number of parameters: {}M".format(
        model_checkpoint, round(model.num_parameters() / 1_000_000)
    )
)

xlm-roberta-base number of parameters: 278M


In [17]:
# Checkpoints go to "<checkpoint>-finetuned-<dataset short name>",
# e.g. "xlm-roberta-base-finetuned-swahili". Evaluation runs at the end of every
# epoch. Epochs, batch sizes and seed are not set here, so library defaults apply.
training_args = TrainingArguments(
    output_dir=f"{model_checkpoint}-finetuned-{dataset_name.rsplit('/', 1)[-1]}",
    eval_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
)

In [18]:
# Masked-LM collator: selects tokens for masking with probability 0.15 in each
# batch. mlm=True is the default and is spelled out for clarity.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15,
)

In [19]:
# Train/evaluate on a fixed random subset of each split to keep the run short.
# Clamp the subset to the split size: Dataset.select raises IndexError for
# out-of-range indices, which a split with fewer rows than `subset_size` would hit.
subset_size = 1000
small_train_dataset = (
    lm_datasets["train"]
    .shuffle(seed=42)
    .select(range(min(subset_size, len(lm_datasets["train"]))))
)
small_eval_dataset = (
    lm_datasets["validation"]
    .shuffle(seed=42)
    .select(range(min(subset_size, len(lm_datasets["validation"]))))
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    data_collator=data_collator,
)

In [20]:
# Baseline before fine-tuning: perplexity is exp(eval_loss) on the eval subset.
eval_results = trainer.evaluate()
print(">>> Perplexity: {:.2f}".format(math.exp(eval_results["eval_loss"])))

  0%|          | 0/125 [00:00<?, ?it/s]

>>> Perplexity: 24.25


In [21]:
# Fine-tune on the training subset; evaluation runs after each epoch (eval_strategy="epoch").
trainer.train()

  0%|          | 0/375 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 2.0673115253448486, 'eval_model_preparation_time': 0.0016, 'eval_runtime': 27.8459, 'eval_samples_per_second': 35.912, 'eval_steps_per_second': 4.489, 'epoch': 1.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 2.0007805824279785, 'eval_model_preparation_time': 0.0016, 'eval_runtime': 27.5371, 'eval_samples_per_second': 36.315, 'eval_steps_per_second': 4.539, 'epoch': 2.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 1.9776479005813599, 'eval_model_preparation_time': 0.0016, 'eval_runtime': 28.2579, 'eval_samples_per_second': 35.388, 'eval_steps_per_second': 4.424, 'epoch': 3.0}
{'train_runtime': 1192.1127, 'train_samples_per_second': 2.517, 'train_steps_per_second': 0.315, 'train_loss': 2.5689215494791666, 'epoch': 3.0}


TrainOutput(global_step=375, training_loss=2.5689215494791666, metrics={'train_runtime': 1192.1127, 'train_samples_per_second': 2.517, 'train_steps_per_second': 0.315, 'total_flos': 197909291520000.0, 'train_loss': 2.5689215494791666, 'epoch': 3.0})

In [22]:
# Same metric after fine-tuning, for comparison with the baseline: exp(eval_loss).
eval_results = trainer.evaluate()
print(">>> Perplexity: {:.2f}".format(math.exp(eval_results["eval_loss"])))

  0%|          | 0/125 [00:00<?, ?it/s]

>>> Perplexity: 7.06


In [23]:
# The Trainer above was built without a tokenizer/processing_class, so
# save_model() writes only the model; save the tokenizer next to it so that
# AutoTokenizer.from_pretrained(save_dir) works on the exported directory.
save_dir = "./my_model"
trainer.save_model(save_dir)
tokenizer.save_pretrained(save_dir)