In [2]:
from pathlib import Path

from tokenizers import ByteLevelBPETokenizer

In [3]:
paths = [str(x) for x in Path(".").glob("**/*.tokens")]

In [4]:
paths

['data\\wikitext-2\\wiki.test.tokens',
 'data\\wikitext-2\\wiki.train.tokens',
 'data\\wikitext-2\\wiki.valid.tokens']

In [5]:
# Initialize a tokenizer
tokenizer = ByteLevelBPETokenizer()

# Customize training
tokenizer.train(files=paths, vocab_size=52_000, min_frequency=2, special_tokens=[
    "<s>",
    "<pad>",
    "</s>",
    "<unk>",
    "<mask>",
])

In [6]:
!mkdir wikiTransformer
tokenizer.save_model("wikiTransformer")

['wikiTransformer\\vocab.json', 'wikiTransformer\\merges.txt']

In [7]:
from tokenizers.implementations import ByteLevelBPETokenizer
from tokenizers.processors import BertProcessing


tokenizer = ByteLevelBPETokenizer(
    "./wikiTransformer/vocab.json",
    "./wikiTransformer/merges.txt",
)

In [8]:
tokenizer._tokenizer.post_processor = BertProcessing(
    ("</s>", tokenizer.token_to_id("</s>")),
    ("<s>", tokenizer.token_to_id("<s>")),
)
tokenizer.enable_truncation(max_length=512)

In [10]:
tokenizer.encode("advent of nuclear weapons").tokens

['<s>', 'ad', 'vent', 'Ġof', 'Ġnuclear', 'Ġweapons', '</s>']

In [11]:
import torch
torch.cuda.is_available()

True

In [13]:
from transformers import RobertaTokenizerFast

tokenizer = RobertaTokenizerFast.from_pretrained("./wikiTransformer", max_len=512)

In [15]:
from transformer_model import make_model


In [16]:
model = make_model(52000,52000,6,512,2048,8,0.1)

In [20]:
%%time
from transformers import LineByLineTextDataset

dataset = LineByLineTextDataset(
    tokenizer=tokenizer,
    file_path="./data/wikitext-2/wiki.test.tokens",
    block_size=128,
)

Wall time: 191 ms


In [21]:
from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)

In [26]:
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./wikiTransformer",
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=64,
    save_steps=10_000,
    save_total_limit=2,
    prediction_loss_only=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [28]:
help(Trainer)

Help on class Trainer in module transformers.trainer:

class Trainer(builtins.object)
 |  Trainer(model: Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None, args: transformers.training_args.TrainingArguments = None, data_collator: Optional[DataCollator] = None, train_dataset: Optional[torch.utils.data.dataset.Dataset] = None, eval_dataset: Optional[torch.utils.data.dataset.Dataset] = None, tokenizer: Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None, model_init: Callable[[], transformers.modeling_utils.PreTrainedModel] = None, compute_metrics: Optional[Callable[[transformers.trainer_utils.EvalPrediction], Dict]] = None, callbacks: Optional[List[transformers.trainer_callback.TrainerCallback]] = None, optimizers: Tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None)
 |  
 |  Trainer is a simple b

In [27]:
trainer.train()

***** Running training *****
  Num examples = 2891
  Num Epochs = 1
  Instantaneous batch size per device = 64
  Total train batch size (w. parallel, distributed & accumulation) = 64
  Gradient Accumulation steps = 1
  Total optimization steps = 46


TypeError: forward() got an unexpected keyword argument 'input_ids'

In [29]:
te = torch.Tensor([    1,   357,    43,  2977,   530, 23080,    13,    78,    17,     0,
         4312,     0,   151,    22, 18215,    17,    17,    46,    43,  2015,
            2,  1496,  7369,   115,  4782,    37, 22196,   252, 26998,     0,
        28680,     1,   496,  2193,  1037,     9,  4072,   380,    27, 33001,
            3,   449,   310,     9,    13,  8034,  3107,   639,    13, 27958,
          638,     1,   168,    17,    43,  2786,    15,   160,   152,  3072,
            4,  5181, 15182, 18712,   877,    16,   423,    22,   562,  1575,
          496,     1,   209,  1056,    17,    39,   317, 19914,   128,   348,
            1,    13,    15,    22, 17314,   357,  1517,   209,  2156,   348,
          131,  2196,   146, 16561,   188, 31575,   348,    54,    52,   630])

In [36]:
te = te.view(5,20)
te

tensor([[1.0000e+00, 3.5700e+02, 4.3000e+01, 2.9770e+03, 5.3000e+02, 2.3080e+04,
         1.3000e+01, 7.8000e+01, 1.7000e+01, 0.0000e+00, 4.3120e+03, 0.0000e+00,
         1.5100e+02, 2.2000e+01, 1.8215e+04, 1.7000e+01, 1.7000e+01, 4.6000e+01,
         4.3000e+01, 2.0150e+03],
        [2.0000e+00, 1.4960e+03, 7.3690e+03, 1.1500e+02, 4.7820e+03, 3.7000e+01,
         2.2196e+04, 2.5200e+02, 2.6998e+04, 0.0000e+00, 2.8680e+04, 1.0000e+00,
         4.9600e+02, 2.1930e+03, 1.0370e+03, 9.0000e+00, 4.0720e+03, 3.8000e+02,
         2.7000e+01, 3.3001e+04],
        [3.0000e+00, 4.4900e+02, 3.1000e+02, 9.0000e+00, 1.3000e+01, 8.0340e+03,
         3.1070e+03, 6.3900e+02, 1.3000e+01, 2.7958e+04, 6.3800e+02, 1.0000e+00,
         1.6800e+02, 1.7000e+01, 4.3000e+01, 2.7860e+03, 1.5000e+01, 1.6000e+02,
         1.5200e+02, 3.0720e+03],
        [4.0000e+00, 5.1810e+03, 1.5182e+04, 1.8712e+04, 8.7700e+02, 1.6000e+01,
         4.2300e+02, 2.2000e+01, 5.6200e+02, 1.5750e+03, 4.9600e+02, 1.0000e+00,
       

In [37]:
te[:,:-1]

tensor([[1.0000e+00, 3.5700e+02, 4.3000e+01, 2.9770e+03, 5.3000e+02, 2.3080e+04,
         1.3000e+01, 7.8000e+01, 1.7000e+01, 0.0000e+00, 4.3120e+03, 0.0000e+00,
         1.5100e+02, 2.2000e+01, 1.8215e+04, 1.7000e+01, 1.7000e+01, 4.6000e+01,
         4.3000e+01],
        [2.0000e+00, 1.4960e+03, 7.3690e+03, 1.1500e+02, 4.7820e+03, 3.7000e+01,
         2.2196e+04, 2.5200e+02, 2.6998e+04, 0.0000e+00, 2.8680e+04, 1.0000e+00,
         4.9600e+02, 2.1930e+03, 1.0370e+03, 9.0000e+00, 4.0720e+03, 3.8000e+02,
         2.7000e+01],
        [3.0000e+00, 4.4900e+02, 3.1000e+02, 9.0000e+00, 1.3000e+01, 8.0340e+03,
         3.1070e+03, 6.3900e+02, 1.3000e+01, 2.7958e+04, 6.3800e+02, 1.0000e+00,
         1.6800e+02, 1.7000e+01, 4.3000e+01, 2.7860e+03, 1.5000e+01, 1.6000e+02,
         1.5200e+02],
        [4.0000e+00, 5.1810e+03, 1.5182e+04, 1.8712e+04, 8.7700e+02, 1.6000e+01,
         4.2300e+02, 2.2000e+01, 5.6200e+02, 1.5750e+03, 4.9600e+02, 1.0000e+00,
         2.0900e+02, 1.0560e+03, 1.7000e+01

In [39]:
te[:,1:].shape

torch.Size([5, 19])

In [4]:
import data

In [5]:
corpus = data.Corpus('./data/wikitext-2')

In [14]:
def batchify(data, bsz):
    # Work out how cleanly we can divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    print(data)
    # Evenly divide the data across the bsz batches.
    data = data.view(bsz,-1).t().contiguous()
    return data

In [15]:
train_data = batchify(corpus.train, 20)

tensor([   0,    1,    2,  ..., 1575,  808,  209])


In [16]:
train_data[:10]

tensor([[    0,   284, 15178,   280,   348,   128,   289,  9493,    16,     1,
            13,     0,  2701,  1227,  1563,  4044,   115,  1352,  1335,    16],
        [    1,   357,    43,  2977,   530, 23080,    13,    78,    17,     0,
          4312,     0,   151,    22, 18215,    17,    17,    46,    43,  2015],
        [    2,  1496,  7369,   115,  4782,    37, 22196,   252, 26998,     0,
         28680,     1,   496,  2193,  1037,     9,  4072,   380,    27, 33001],
        [    3,   449,   310,     9,    13,  8034,  3107,   639,    13, 27958,
           638,     1,   168,    17,    43,  2786,    15,   160,   152,  3072],
        [    4,  5181, 15182, 18712,   877,    16,   423,    22,   562,  1575,
           496,     1,   209,  1056,    17,    39,   317, 19914,   128,   348],
        [    1,    13,    15,    22, 17314,   357,  1517,   209,  2156,   348,
           131,  2196,   146, 16561,   188, 31575,   348,    54,    52,   630],
        [    0,    17,   652, 17400,   115,  3

In [17]:
def get_batch(source, i):
    seq_len = min(5, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].view(-1)
    return data, target

In [18]:
data,targets = get_batch(train_data, 0)

In [19]:
data

tensor([[    0,   284, 15178,   280,   348,   128,   289,  9493,    16,     1,
            13,     0,  2701,  1227,  1563,  4044,   115,  1352,  1335,    16],
        [    1,   357,    43,  2977,   530, 23080,    13,    78,    17,     0,
          4312,     0,   151,    22, 18215,    17,    17,    46,    43,  2015],
        [    2,  1496,  7369,   115,  4782,    37, 22196,   252, 26998,     0,
         28680,     1,   496,  2193,  1037,     9,  4072,   380,    27, 33001],
        [    3,   449,   310,     9,    13,  8034,  3107,   639,    13, 27958,
           638,     1,   168,    17,    43,  2786,    15,   160,   152,  3072],
        [    4,  5181, 15182, 18712,   877,    16,   423,    22,   562,  1575,
           496,     1,   209,  1056,    17,    39,   317, 19914,   128,   348]])

In [20]:
targets

tensor([    1,   357,    43,  2977,   530, 23080,    13,    78,    17,     0,
         4312,     0,   151,    22, 18215,    17,    17,    46,    43,  2015,
            2,  1496,  7369,   115,  4782,    37, 22196,   252, 26998,     0,
        28680,     1,   496,  2193,  1037,     9,  4072,   380,    27, 33001,
            3,   449,   310,     9,    13,  8034,  3107,   639,    13, 27958,
          638,     1,   168,    17,    43,  2786,    15,   160,   152,  3072,
            4,  5181, 15182, 18712,   877,    16,   423,    22,   562,  1575,
          496,     1,   209,  1056,    17,    39,   317, 19914,   128,   348,
            1,    13,    15,    22, 17314,   357,  1517,   209,  2156,   348,
          131,  2196,   146, 16561,   188, 31575,   348,    54,    52,   630])