In [34]:
import os
import numpy as np

from typing import Dict, List, Tuple, Union, Optional, Any

import transformers
from transformers import PreTrainedTokenizer, DataCollator
from transformers import BertTokenizer, BertForPreTraining, Trainer, TrainingArguments

import torch
from torch import nn
from torch.utils.data import Dataset

# Maximum rows per example, passed to KvcDataset as `block_size` (includes the CLS and SEP rows).
MAX_LENGTH: int = 128
# Size of the BERT token-embedding table (passed as `BertConfig.vocab_size`).
# NOTE(review): every key code in readable_keymap must be < VOCAB_SIZE — confirm against key_mappings.
VOCAB_SIZE: int = 256

In [35]:
# Data-selection settings for this notebook run.
scenario = 'desktop'
split = 'train'

# Per-scenario mean / standard deviation of the `time` and `duration` features.
# NOTE(review): not referenced by any cell shown here, so the features fed to the
# model are currently unnormalized.
MEAN_STD = {
    'desktop': {
        'time': {'mean': 585.9952015355086, 'std': 359.3659389758254},
        'duration': {'mean': 139.65271966527197, 'std': 72.2050536523459},
    },
}

In [36]:
from key_mappings import readable_keymap
# Codepoints 0-6 are reserved for the special tokens defined below (PAD .. UNK), so every
# real key code coming from readable_keymap must be strictly larger than 6.
# NOTE(review): consider also asserting max(readable_keymap) < VOCAB_SIZE so that key codes
# fit the model's embedding table.
assert min(readable_keymap.keys()) > 6

# Reserved codepoints for special rows, numbered consecutively from 0; real key codes start above UNK.
PAD, CLS, SEP, BOS, MASK, RESERVED, UNK = range(7)

# Display name of each reserved codepoint. The insertion order (CLS first, PAD fifth)
# is what SPECIAL_CODEPOINTS_LIST reflects, so keep it unchanged.
SPECIAL_CODEPOINTS: Dict[int, str] = dict([
    (CLS, "[CLS]"),
    (SEP, "[SEP]"),
    (BOS, "[BOS]"),
    (MASK, "[MASK]"),
    (PAD, "[PAD]"),
    (RESERVED, "[RESERVED]"),
    (UNK, "[UNK]"),
])
SPECIAL_CODEPOINTS_LIST = [*SPECIAL_CODEPOINTS]
# Reverse lookup: display name -> codepoint.
SPECIAL_CODEPOINTS_BY_NAME: Dict[str, int] = dict(zip(SPECIAL_CODEPOINTS.values(), SPECIAL_CODEPOINTS.keys()))

class KvcDataset(Dataset):
    """
    Keystroke sessions of one scenario, turned into fixed-width integer rows.

    Each example is a dict (all integers):
      * ``input_ids`` -- long tensor of shape ``(n_rows, 4)`` with columns
        ``(t_ms, duration_ms, key_code, is_special_token)``, framed by a CLS row
        and a SEP row (see ``preprocess``); ``n_rows <= block_size``.
      * ``user``      -- index of the session's user in ``self.user_ids``.

    Args:
        scenario: sub-directory and file prefix of the ``.npy`` data, e.g. ``'desktop'``.
        split: only ``'train'`` is implemented (the ``'test'`` branch is unreachable for now).
        block_size: maximum rows per example, including the CLS and SEP rows; must be >= 3.
        clip: stored on the instance but not used by this class yet.

    Raises:
        ValueError: if ``block_size`` leaves no room for session rows, or a session is empty.
    """

    def __init__(self, scenario: str, split: str, block_size: int, clip: Optional[int] = None):
        if block_size < 3:
            raise ValueError(f"block_size must be >= 3 (CLS + SEP + at least one row); got {block_size}")
        self.clip = clip  # NOTE(review): not used anywhere in this class yet
        self.max_length = block_size - 2  # -2 for CLS and SEP tokens

        assert split == 'train'  # todo: implement test split
        if split == 'train':
            filename = f'{scenario}/{scenario}_dev_set.npy'
        else:
            assert split == 'test'
            filename = f'{scenario}/{scenario}_test_sessions.npy'
        assert os.path.isfile(filename), f"Input file path {filename} not found"
        # The file stores a pickled object (a dict of users -> dict of sessions, judging by the
        # loops below), hence allow_pickle=True -- only load files from a trusted source.
        raw_data = np.load(filename, allow_pickle=True).item()

        self.user_ids = {uid: i for i, uid in enumerate(raw_data.keys())}
        self.examples = []
        for user, sessions in raw_data.items():
            for session in sessions.values():
                sample = self.preprocess(session)
                self.examples.append({
                    'input_ids': torch.tensor(sample, dtype=torch.long),
                    'user': self.user_ids[user]
                })
            break  # TODO remove me! Debug shortcut: only the first user's sessions are loaded.

    def preprocess(self, data: np.ndarray) -> np.ndarray:
        """
        Convert one raw session into rows framed by a CLS row and a SEP row.

        ``data`` is expected to be a 2-D numeric array with at least 3 columns; columns 0 and 1
        are presumably press/release timestamps and column 2 a key code (TODO confirm against
        the dataset description). Only the first ``self.max_length`` rows are used, and the
        caller's array is never modified.

        Returns an array of shape ``(min(len(data), self.max_length) + 2, 4)`` with columns:
          * ``t``          -- column 0 minus column 0 of the previous row (0 for the first row)
          * ``duration``   -- ``abs(column 1 - column 0)``
          * ``key``        -- column 2 if it is a key of ``readable_keymap``, else ``UNK``
          * ``is_special`` -- 1 for the CLS/SEP rows and for ``UNK`` keys, else 0

        Raises:
            ValueError: if the session has no rows.
        """
        # Basic slicing returns a view; everything below only reads from it, so the raw
        # session stays untouched. (Re-basing timestamps on the first row is unnecessary:
        # both the deltas and the durations are invariant to a common offset.)
        rows = data[:self.max_length]
        if len(rows) == 0:
            raise ValueError("Cannot preprocess an empty session")
        col0 = rows[:, 0]
        t = np.diff(col0, prepend=col0[0])
        duration = np.abs(rows[:, 1] - col0)  # fix duration of keypress. ~349 items are broken
        key = np.vectorize(lambda x: x if x in readable_keymap else UNK)(rows[:, 2])
        # UNK keys are flagged as special so the collator never selects them for masking.
        special_tokens = (key == UNK).astype(int)
        cls_row = np.array([0, 0, CLS, 1])
        sep_row = np.array([0, 0, SEP, 1])
        return np.vstack((
            cls_row,
            np.column_stack((t, duration, key, special_tokens)),
            sep_row,
        ))

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, i) -> Dict[str, Any]:
        """Return example ``i``: ``{'input_ids': LongTensor (n_rows, 4), 'user': int}``."""
        return self.examples[i]
    
# Use the `scenario` / `split` settings from the config cell so changing them there takes effect.
train_dataset = KvcDataset(scenario, split, MAX_LENGTH)

In [59]:
from transformers import DataCollatorForLanguageModeling

class MyDataCollatorForLanguageModeling():
    def __init__(self):
        self.mlm: bool = True
        self.mlm_probability: float = 0.15

    def __call__(self, features, return_tensors='pt'):
        assert return_tensors == "pt", f'Only return_tensors="pt" is supported; got {return_tensors}'
        return self.torch_call(features)
    
    def _collate_batch(self, examples):
        length_of_first = examples[0].size(0)
        are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
        if are_tensors_same_length:
            result = torch.stack(examples, dim=0)
        else:
            max_length = max(x.size(0) for x in examples)
            result = examples[0].new_full([len(examples), max_length, 4], PAD)
            for i, example in enumerate(examples):
                result[i, : example.shape[0]] = example
        return result

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        batch = {
            "input_ids": self._collate_batch([example["input_ids"] for example in examples]),
        }

        if self.mlm:
            batch["input_ids"], batch["labels"] = self.torch_mask_tokens(batch["input_ids"])
        else:
            raise NotImplementedError
            labels = batch["input_ids"].clone()
            labels[labels == PAD] = -100
            batch["labels"] = labels
        return batch

    def torch_mask_tokens(self, inputs: Any) -> Tuple[Any, Any]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        """
        inputs, special_tokens_mask = inputs[:, :, :3], inputs[:, :, 3].bool()
        labels = inputs.clone()

        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
        probability_matrix = torch.full(special_tokens_mask.shape, self.mlm_probability)

        probability_matrix[special_tokens_mask] = 0.0
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(special_tokens_mask.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced][2] = MASK  # mask the key_code

        # TODO
        # 10% of the time, we replace masked input tokens with random word
        # indices_random = torch.bernoulli(torch.full(special_tokens_mask.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        # random_words = torch.randint(len(SPECIAL_CODEPOINTS_LIST), labels.shape, dtype=torch.long)
        # inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels

# Define the data collator
data_collator = MyDataCollatorForLanguageModeling()

In [60]:
from transformers import BertConfig, BertForMaskedLM


# A deliberately small BERT (2 layers, 4 heads, hidden size 128) with a masked-LM head.
# NOTE(review): BertForMaskedLM consumes 2-D `input_ids` of shape (batch, seq_len), while the
# collator yields (batch, seq_len, 3) rows. The "too many values to unpack (expected 2)" error
# in the training cell is most likely raised where BERT unpacks `input_ids.size()`. Using the
# time/duration columns needs a custom embedding layer (e.g. fed through `inputs_embeds`).
config = BertConfig(
    num_attention_heads=4,
    num_hidden_layers=2,
    hidden_size=128,
    vocab_size=VOCAB_SIZE,
)
model = BertForMaskedLM(config)

Generate config GenerationConfig {
  "pad_token_id": 0
}



In [61]:
from transformers import Trainer, TrainingArguments

# Training configuration: 10 epochs, batch size 256, a checkpoint every 10k steps,
# at most two checkpoints kept on disk, saved as safetensors.
training_args = TrainingArguments(
    output_dir="./bert_finetuned",
    overwrite_output_dir=True,
    num_train_epochs=10,
    per_device_train_batch_size=256,
    save_safetensors=True,
    save_steps=10_000,
    save_total_limit=2,
)

# Wire model, arguments, collator and dataset together.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=data_collator,
)

# NOTE(review): with the current BertForMaskedLM / 3-D collator pairing this call fails with
# "ValueError: too many values to unpack (expected 2)" (see the captured output below).
trainer.train()

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).
***** Running training *****
  Num examples = 15
  Num Epochs = 10
  Instantaneous batch size per device = 256
  Total train batch size (w. parallel, distributed & accumulation) = 256
  Gradient Accumulation steps = 1
  Total optimization steps = 10
  Number of trainable parameters = 1,828,224


  0%|          | 0/10 [00:00<?, ?it/s]

The following columns in the training set don't have a corresponding argument in `BertForMaskedLM.forward` and have been ignored: user. If user are not expected by `BertForMaskedLM.forward`,  you can safely ignore this message.
We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


ValueError: too many values to unpack (expected 2)