In [None]:
from init_notebook import *
import tokenizers
import transformers 
from src.datasets.fefe import FefePostIterableDataset

In [None]:
# Corpus used by every cell below; iterating it yields one post per item, and
# the loops below treat each post as a plain string.
# NOTE(review): `.freeze()` presumably materializes/caches the posts so the
# repeated passes are cheap — confirm in FefePostIterableDataset.
ds = FefePostIterableDataset().freeze()

In [None]:
# Character frequency table over the whole corpus: character -> occurrence count.
char_map = {}
for text in tqdm(ds):
    for ch in text:
        if ch in char_map:
            char_map[ch] += 1
        else:
            char_map[ch] = 1

In [None]:
# Size of the alphabet, then every character with its count, most frequent first.
print(len(char_map))
for key in sorted(char_map, key=char_map.get, reverse=True):
    print(key, char_map[key])

In [None]:
# All characters that occur fewer than 40 times in the corpus, joined into one
# string (in the order they were first seen).
small_chars = "".join(filter(lambda key: char_map[key] < 40, char_map))
small_chars

In [None]:
# Count every purely alphabetic character n-gram of length 3, 4 and 5 in the corpus.
token_map_3 = {}
for text in tqdm(ds):
    for i in range(len(text)):
        for j in (3, 4, 5):
            chars = text[i:i + j]
            # A slice shorter than j means the window ran past the end of the post.
            if len(chars) != j:
                continue
            if " " in chars or not chars.isalpha():
                continue
            if chars in token_map_3:
                token_map_3[chars] += 1
            else:
                token_map_3[chars] = 1

In [None]:
# Number of distinct n-grams, then the 1000 most frequent ones with their counts.
print(len(token_map_3))
for key in sorted(token_map_3, key=token_map_3.get, reverse=True)[:1000]:
    print(key, token_map_3[key])

In [None]:
# Alphabetical listing of every counted n-gram that contains "ich".
for key in sorted(token_map_3):
    if "ich" not in key:
        continue
    print(key, token_map_3[key])

In [None]:
import tokenizers

# List the public names available in `tokenizers.pre_tokenizers`.
# The previous `tokenizers.pre_tokenizers.?` was a leftover completion probe
# that is neither valid Python nor a valid IPython help query, so the cell
# failed when the notebook was run top to bottom.
[name for name in dir(tokenizers.pre_tokenizers) if not name.startswith("_")]

In [None]:
# Try the Replace normalizer on a plain string: pattern " " is replaced by "X";
# the normalized string is shown as the cell output.
n = tokenizers.normalizers.Replace(" ", "X")
n.normalize_str("Hallo Welt")

In [None]:
# Code points of the two symbol characters that the normalizer smoke test
# further down includes in its sample string.
ord("🅱"), ord("🇸")

In [None]:
# `copy` is imported at cell level because a later cell also uses it.
import copy
import re

# BPE tokenizer whose normalizer
#   1. drops the rare characters collected in `small_chars`,
#   2. turns every newline into the marker "⬅",
#   3. turns every remaining whitespace run into the marker "⬇",
# and whose pre-tokenizer splits on those two markers and on punctuation.
tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token="[UNK]"))

# `small_chars` is raw corpus text, so it can contain characters that are
# special inside a regex character class (backslash, brackets, "^", "-", "&").
# Escape them so each one is matched literally instead of breaking or
# silently changing the pattern.
small_chars_class = re.sub(r"[\\\[\]^\-&]", lambda m: "\\" + m.group(0), small_chars)

normalizer_steps = []
if small_chars_class:
    # An empty class "[]" is not a valid pattern, so the removal step is only
    # added when there is at least one rare character to drop.
    normalizer_steps.append(
        tokenizers.normalizers.Replace(tokenizers.Regex(f"[{small_chars_class}]"), "")
    )
normalizer_steps += [
    tokenizers.normalizers.Replace("\n", "⬅"),
    tokenizers.normalizers.Replace(tokenizers.Regex(r"\s+"), "⬇"),
]
tokenizer.normalizer = tokenizers.normalizers.Sequence(normalizer_steps)
print(tokenizer.normalizer.normalize_str("Bla Blub 🅱 🇸\nnewline"))

tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Sequence([
    tokenizers.pre_tokenizers.Split(tokenizers.Regex(r"[⬇⬅]"), "contiguous"),
    tokenizers.pre_tokenizers.Punctuation(),
])

# Smoke check only: the copy is discarded, the call just has to succeed.
copy.deepcopy(tokenizer)

# Pre-tokenizer on a raw (un-normalized) sample, then normalizer + pre-tokenizer
# on the first ten posts.
print(tokenizer.pre_tokenizer.pre_tokenize_str("How's \"life\"?\nNext line"))
for post in ds.limit(10):
    print([i[0] for i in tokenizer.pre_tokenizer.pre_tokenize_str(tokenizer.normalizer.normalize_str(post))])

# train tokenizer

In [None]:
# Printable ASCII plus both whitespace markers emitted by the normalizer
# ("⬇" for whitespace runs, "⬅" for newlines) and the German umlauts / ß, so
# these characters are always part of the vocabulary.
initial_alphabet = [chr(c) for c in range(33, 127)] + ["⬇", "⬅", "ä", "ö", "ü", "Ä", "Ö", "Ü", "ß"]

# Every special token the saved tokenizer declares must be reserved here.
# "[BOS]" and "[EOS]" are appended at the end so the ids of the first five
# tokens stay unchanged; without them the fast-tokenizer wrapper would add
# them on top of the trained vocabulary and push their ids past vocab_size.
special_tokens = ["[UNK]", "[PAD]", "[CLS]", "[SEP]", "[MASK]", "[BOS]", "[EOS]"]

# BPE trainer: 4096-entry vocabulary (special tokens included), tokens of at
# most 10 characters, pairs must occur at least 10 times to be merged.
trainer = tokenizers.trainers.BpeTrainer(
    show_progress=True,
    vocab_size=4096,
    special_tokens=special_tokens,
    max_token_length=10,
    min_frequency=10,
    initial_alphabet=initial_alphabet,
)

In [None]:
# Train the BPE model on the corpus, then show every single-character entry of
# the resulting vocabulary as one space-separated string.
tokenizer.train_from_iterator(ds, trainer=trainer)
" ".join(filter(lambda k: len(k) == 1, tokenizer.get_vocab()))

In [None]:
# For the first ten posts: print the token strings, then the text rebuilt from
# the token ids with the whitespace marker "⬇" turned back into a space.
# The newline marker "⬅" is not converted back, so it stays visible in the output.
for post in ds.skip(0).limit(10):
    tokens = tokenizer.encode(post)
    print(tokens.tokens)
    print("".join(map(tokenizer.id_to_token, tokens.ids)).replace("⬇", " "))

In [None]:
# Check that the trained model object can be deep-copied on its own; the copy
# is only displayed as the cell output, not kept.
# The commented-out lines are an alternative experiment: save the model files
# to a temporary directory, rebuild a model from them and deep-copy that.
#tokenizer.model.__setstate__ = None
copy.deepcopy(tokenizer.model)
#os.makedirs("/tmp/tok-model-DELME/", exist_ok=True)
#files = tokenizer.model.save("/tmp/tok-model-DELME/")
#new_model = tokenizer.model.__class__.from_file(*files)
#copy.deepcopy(new_model)

# save tokenizer

In [None]:
# Wrap the trained tokenizer in the transformers fast-tokenizer API and write
# it below the small-datasets directory.
# NOTE(review): every special token named here should also be reserved by the
# BPE trainer's `special_tokens`; a token missing there is presumably appended
# by the wrapper and would enlarge the vocabulary — confirm.
fast_tokenizer_dir = config.SMALL_DATASETS_PATH / "fefe" / "tokenizer-bpe-4096-spaces"
special_token_names = {
    "bos_token": "[BOS]",
    "eos_token": "[EOS]",
    "unk_token": "[UNK]",
    "pad_token": "[PAD]",
    "cls_token": "[CLS]",
    "sep_token": "[SEP]",
    "mask_token": "[MASK]",
}
transformers.PreTrainedTokenizerFast(
    tokenizer_object=tokenizer,
    clean_up_tokenization_spaces=True,
    **special_token_names,
    padding_side="left",
).save_pretrained(str(fast_tokenizer_dir))

In [None]:
# Encode a short sample with the raw `tokenizers` object; the returned
# encoding is shown as the cell output.
tokenizer.encode("Hello")

# load tokenizer

In [None]:
# Load a saved tokenizer through the transformers AutoTokenizer API and show it.
# NOTE(review): this reads "tokenizer-bpe-4096", while the save cell writes
# "tokenizer-bpe-4096-spaces" — confirm which tokenizer is meant to be used
# by the cells below.
fast_tok = transformers.AutoTokenizer.from_pretrained(str(config.SMALL_DATASETS_PATH / "fefe" / "tokenizer-bpe-4096"))
fast_tok

In [None]:
# Round trip of a mixed English/German sentence: print the number of ids and
# the ids themselves, then the text decoded back from those ids.
sentence = """Let's test tokenization and un-tokenization, und mal *kucken* was "hier" passiert?!"""
token_ids = fast_tok(sentence)["input_ids"]
print(f"{len(token_ids)} {token_ids}")
print(fast_tok.decode(token_ids))

In [None]:
from src.models.minimind import *

In [None]:
# Build a MiniMind language model whose vocabulary size matches the 4096-entry
# tokenizer, print its parameter count and show the module.
lm_config = LMConfig(
    dim=512,
    n_layers=8,
    n_heads=8,
    n_kv_heads=2,
    vocab_size=4096,
    hidden_dim=None,
    multiple_of=64,
)
model = MiniMindLM(lm_config)
param_count = num_module_parameters(model)
print(f"params: {param_count:,}")
model

In [None]:
# IPython help query for LMConfig (interactive syntax, not plain Python).
LMConfig?

In [None]:
# Load the experiment trainer from the YAML file, requesting the CPU device.
# NOTE(review): this rebinds `trainer`, which earlier in the notebook holds the
# BPE trainer; re-running the tokenizer-training cell after this one would
# pass this object instead.
from src.train.experiment import load_experiment_trainer
trainer = load_experiment_trainer("../experiments/minimind/fefe.yml", device="cpu")

In [None]:
# Tokenize a German prompt with the loaded tokenizer; return_tensors="pt"
# requests PyTorch tensors. This reuses the name `tokens` from the encoding
# loop earlier in the notebook.
tokens = fast_tok("Und was ich noch sagen wollte", return_tensors="pt")

In [None]:
# Move the trainer's model to the CPU, then generate from the prompt ids,
# passing the tokenizer's EOS id.
# NOTE(review): presumably generation stops once eos_token_id is produced —
# confirm in the model's `generate` implementation.
trainer.model.cpu()
out_tokens = trainer.model.generate(
    tokens.input_ids, 
    eos_token_id=fast_tok.eos_token_id,
)

In [None]:
# Decode the generated ids back to text.
# flatten(0) merges all dimensions into one, so this treats the output as a
# single sequence.
# NOTE(review): assumes `generate` returned exactly one sequence — confirm.
fast_tok.decode(out_tokens.flatten(0))

# better batching

In [None]:
def iter_snippets(
    batch_size: int = 16,
    texts=None,
    tokenizer=None,
    min_length: int = 64,
    max_length: int = 128,
):
    """Yield 1-D token-id tensors, cut or left-padded to a shared length per batch.

    Each text is tokenized and emitted as one or more snippets. A text longer
    than the current snippet length is cut into windows that overlap by half
    a window; the final (or only) piece is left-padded with the tokenizer's
    pad id. The snippet length starts at ``min_length`` and is re-drawn
    uniformly from ``[min_length, max_length]`` after every ``batch_size``
    snippets, so each consecutive group of ``batch_size`` snippets has one
    length and can be stacked.

    Args:
        batch_size: Number of consecutive snippets that share a length.
        texts: Iterable of strings. Defaults to the notebook-level ``ds``.
        tokenizer: Callable used as ``tokenizer(text, truncation=True,
            return_tensors='pt')`` whose result has ``input_ids``, and which
            has a ``pad_token_id`` attribute. Defaults to the notebook-level
            ``fast_tok``.
        min_length: Length of the first batch and lower bound for later ones.
        max_length: Upper bound (inclusive) for the snippet length.

    Yields:
        A 1-D tensor of token ids with the current snippet length.

    Raises:
        ValueError: If ``batch_size < 1`` or the length bounds are not
            ``1 <= min_length <= max_length``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if not 1 <= min_length <= max_length:
        raise ValueError(
            f"need 1 <= min_length <= max_length, got {min_length} and {max_length}"
        )
    if texts is None:
        texts = ds
    if tokenizer is None:
        tokenizer = fast_tok

    seq_length = min_length
    count = 0  # number of snippets yielded so far
    for text in texts:
        encoding = tokenizer(
            text,
            truncation=True,
            return_tensors='pt'
        )
        # Drop only the batch dimension. A bare squeeze() would also collapse
        # a one-token text to a 0-dim tensor and break the length checks.
        input_ids = encoding.input_ids.squeeze(0)
        while True:
            # Re-draw the length on a batch boundary, *before* producing the
            # first snippet of the new batch, so groups stay aligned.
            if count and count % batch_size == 0:
                seq_length = random.randint(min_length, max_length)
            count += 1

            remaining = input_ids.shape[0]
            if remaining > seq_length:
                yield input_ids[:seq_length]
                # Advance by half a window (at least one token) so that
                # consecutive snippets overlap.
                input_ids = input_ids[max(1, seq_length // 2):]
                continue

            if remaining < seq_length:
                padding = torch.full(
                    (seq_length - remaining,),
                    tokenizer.pad_token_id,
                    dtype=input_ids.dtype,
                )
                input_ids = torch.cat([padding, input_ids])
            yield input_ids
            break

# Histogram of snippet lengths over the first 10000 snippets, shown as a bar chart.
counts = {}
for i, text in tqdm(zip(range(10000), iter_snippets())):
    key = len(text)
    if key in counts:
        counts[key] += 1
    else:
        counts[key] = 1
df = pd.DataFrame(counts.values(), index=counts.keys())
df = df.sort_index()
px.bar(df)