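Pretraining IoTBert: masked-language-model (MLM) pretraining of a RoBERTa encoder on tokenized IoT network-flow records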

In [None]:
from torch.utils.data import Dataset
from tokenizers.implementations import ByteLevelBPETokenizer
from tokenizers.processors import BertProcessing
from pathlib import Path
import torch

In [None]:
# Check that we have a GPU: `nvidia-smi` lists the visible NVIDIA devices and driver/CUDA versions.
# This is an IPython shell escape (`!`); it only works where the NVIDIA driver tools are installed.
!nvidia-smi

In [None]:

# Displays True when PyTorch itself can use a CUDA device (nvidia-smi alone does not guarantee this).
cuda_ok = torch.cuda.is_available()
cuda_ok

Define model configuration

In [None]:
from transformers import RobertaConfig
from transformers import RobertaForMaskedLM
from transformers import set_seed

# Seed BEFORE building the model. The Trainer applies `TrainingArguments.seed` only when it
# is constructed (a later cell), which is after these weights have been randomly initialized.
# 42 matches the seed passed to TrainingArguments below.
set_seed(42)

config = RobertaConfig(
    vocab_size=52_000,            # assumed to equal the size of ./models/vocab.json — confirm if the tokenizer is retrained
    max_position_embeddings=578,  # 576-token max length + 2: RoBERTa position ids start after padding_idx
    num_attention_heads=12,
    num_hidden_layers=6,
    type_vocab_size=1,            # single-segment inputs only
)

model = RobertaForMaskedLM(config=config)
model.num_parameters()  # => 83.5 million parameters

In [None]:
from transformers import RobertaTokenizerFast

# Fast RoBERTa tokenizer built from ./models/vocab.json + merges.txt. It is what the data
# collator below uses for padding and for inserting <mask> tokens.
# `model_max_length` is the documented argument; `max_len` is only a legacy alias.
tokenizer = RobertaTokenizerFast.from_pretrained("./models", model_max_length=576)

Build, save, and reload the pretraining dataset

In [None]:

class IoTDataset(Dataset):
    """Tokenized IoT text records for masked-language-model pretraining.

    Each non-blank line of each selected ``.csv`` file under ``data_dir`` is
    encoded with the byte-level BPE tokenizer stored in ``model_dir``. Each
    line is framed as ``<s> ... </s>``, truncated to ``max_length`` and kept
    as a list of token ids. Padding is deferred to the batch collator.

    Args:
        evaluate: If True, load only the held-out files matching
            ``eval_pattern``. If False (default), load every ``*.csv`` file
            EXCEPT the held-out ones, so the two splits never overlap.
        data_dir: Directory holding the source ``.csv`` files.
        model_dir: Directory holding ``vocab.json`` and ``merges.txt``.
        max_length: Truncation length passed to the tokenizer.
        eval_pattern: Glob pattern that selects the held-out evaluation files.
    """

    def __init__(
        self,
        evaluate: bool = False,
        data_dir: str = "./data/",
        model_dir: str = "./models",
        max_length: int = 576,
        eval_pattern: str = "*_7.csv",
    ):
        # Named `bpe` so it does not read like the notebook-level `tokenizer`.
        # (Alternatively, the RobertaTokenizer from `transformers` could be used directly.)
        bpe = self._build_tokenizer(Path(model_dir), max_length)

        self.examples = []
        for src_file in self._source_files(Path(data_dir), evaluate, eval_pattern):
            print("🔥", src_file)
            # Blank lines would otherwise become bare "<s></s>" examples.
            lines = [
                line
                for line in src_file.read_text(encoding="utf-8").splitlines()
                if line.strip()
            ]
            if lines:
                self.examples += [encoding.ids for encoding in bpe.encode_batch(lines)]

    @staticmethod
    def _build_tokenizer(model_dir: Path, max_length: int):
        """Load the BPE vocab/merges, add <s>/</s> framing and enable truncation."""
        bpe = ByteLevelBPETokenizer(
            str(model_dir / "vocab.json"),
            str(model_dir / "merges.txt"),
        )
        # BertProcessing(sep, cls): prepend <s> and append </s> to every encoding.
        bpe._tokenizer.post_processor = BertProcessing(
            ("</s>", bpe.token_to_id("</s>")),
            ("<s>", bpe.token_to_id("<s>")),
        )
        bpe.enable_truncation(max_length=max_length)
        return bpe

    @staticmethod
    def _source_files(data_dir: Path, evaluate: bool, eval_pattern: str = "*_7.csv") -> list:
        """Return the sorted source files of one split; train and eval are disjoint."""
        # Sorted because Path.glob order is filesystem-dependent.
        held_out = sorted(data_dir.glob(eval_pattern))
        if evaluate:
            return held_out
        excluded = set(held_out)
        # "*.csv" also matches the held-out files, so remove them explicitly.
        return sorted(path for path in data_dir.glob("*.csv") if path not in excluded)

    def __len__(self):
        """Number of encoded lines."""
        return len(self.examples)

    def __getitem__(self, i):
        """Return example ``i`` as a 1-D tensor of token ids (unpadded)."""
        # We’ll pad at the batch level.
        return torch.tensor(self.examples[i])

In [None]:
# Build the training split by tokenizing the CSV files under ./data/.
data = IoTDataset(evaluate=False)

In [None]:
## Dumping pretraining dataset into pickle
# Cache the tokenized examples so a later session can skip re-tokenizing the CSVs.
# The pickle references the IoTDataset class defined in this notebook (__main__),
# so that class cell must be run before the cache can be loaded again.
import pickle

dataset_cache = Path("./data/tokenizer/dataset-drapgh.pkl")
dataset_cache.parent.mkdir(parents=True, exist_ok=True)  # tokenizer/ subfolder may not exist yet
with dataset_cache.open("wb") as fh:  # `with` guarantees the file is flushed and closed
    pickle.dump(data, fh)

In [None]:
## Loading pretraining dataset from pickle
# Requires the IoTDataset class cell to have been run in this kernel.
# SECURITY: pickle.load can execute arbitrary code; only load a cache this notebook wrote.
import pickle

with open("./data/tokenizer/dataset-drapgh.pkl", "rb") as fh:  # closed even if loading fails
    data = pickle.load(fh)

In [None]:
from transformers import DataCollatorForLanguageModeling

# Masked-LM collator: pads each batch (the dataset returns unpadded tensors)
# and randomly selects 15% of the tokens as prediction targets.
mask_fraction = 0.15
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=mask_fraction,
)

Initialize Trainer

In [None]:
from transformers import Trainer, TrainingArguments

# Effective batch size per optimizer step (per device):
#   per_device_train_batch_size (16) x gradient_accumulation_steps (8) = 128 sequences.
training_args = TrainingArguments(
    output_dir="./models",
    do_train=True,
    overwrite_output_dir=True,
    num_train_epochs=50,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=8,  # accumulate 8 micro-batches of 16 -> 128 per step per device
    save_steps=10_000,
    save_total_limit=5,             # keep only the 5 most recent checkpoints in output_dir
    seed=42,                        # applied when the Trainer is created, i.e. after the model cell
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=data,
)

Start Training

In [None]:
import datetime

# Print the wall-clock start time so the run duration can be read from the log.
now = datetime.datetime.now()
print(now)

train_output = trainer.train()
train_output

Save final model (+ tokenizer + config) to disk

In [None]:
# Write the final weights and config.json to ./models. The Trainer was not given the
# tokenizer, so save it explicitly as well; this also writes tokenizer_config.json
# (including model_max_length) for the fill-mask pipeline that reloads ./models.
trainer.save_model("./models")
tokenizer.save_pretrained("./models")

Check that the LM actually trained

In [None]:
from transformers import pipeline

# Reload the saved checkpoint; model and tokenizer files both live in ./models.
checkpoint_dir = "./models"
fill_mask = pipeline("fill-mask", model=checkpoint_dir, tokenizer=checkpoint_dir)

In [None]:
# Quick probe: predict the masked field after the (presumably) source/destination port fields.
short_probe = "sp:4856 dp:49152 <mask>."
fill_mask(short_probe)

In [None]:
# Full flow-record probe: the field right after `sp:` is masked.
# The record is split into adjacent string literals only for readability;
# Python concatenates them into the identical single-line string.
flow_record = (
    "sp:4856 <mask> ptcl:6 ipv:4 vln:0 tnnl:0 "
    "bi_dur:34 bi_pkt:14 bi_byte:4544 "
    "s2d_dur:32 s2d:7 s2d_byte:555 "
    "d2s_dur:33 d2s:7 d2s_byte:3989 "
    "bi_min_ps:52 bi_mean_ps:324.57 bi_std_ps:516.42 bi_max_ps:1500 "
    "s2d_min_ps:52 s2d_mean_ps:79.29 s2d_std_ps:68.73 s2d_max_ps:235 "
    "d2s_min_ps:52 d2s_mean_ps:569.86 d2s_std_ps:657.82 d2s_max_ps:1500 "
    "bi_min_pi_ms:0 bi_mean_pi_ms:2.62 bi_std_pi_ms:2.26 bi_max_pi_ms:6 "
    "s2d_min_pi_ms:1 s2d_mean_pi_ms:5.33 s2d_std_pi_ms:5.5 s2d_max_pi_ms:16 "
    "d2s_min_pi_ms:0 d2s_mean_pi_ms:5.5 d2s_std_pi_ms:6.06 d2s_max_pi_ms:16 "
    "bi_syn:2 bi_cwr:0 bi_ece:0 bi_urg:0 bi_ack:13 bi_psh:3 bi_rst:0 bi_fin:2 "
    "s2d_syn:1 s2d_cwr:0 s2d_ece:0 s2d_urg:0 s2d_ack:6 s2d_psh:1 s2d_rst:0 s2d_fin:1 "
    "d2s_syn:1 d2s_cwr:0 d2s_ece:0 d2s_urg:0 d2s_ack:7 d2s_psh:2 d2s_rst:0 d2s_fin:1 "
    "app_name:HTTP app_cat:Web req_server_name:192.168.1.223 "
    "client_fingerprint:nan server_fingerprint:nan content_type:text/xml"
)
fill_mask(flow_record)
