In [1]:
import argparse
import os, tqdm
import logging
import pandas as pd
import numpy as np
import re
import gzip
import json
import tokenizers
from tokenizers import Tokenizer, Regex
from tokenizers.models import BPE, Unigram
from tokenizers.trainers import BpeTrainer, UnigramTrainer
# from sentencepiece import SentencePieceTrainer
from tokenizers.normalizers import Strip, Replace, Lowercase
from tokenizers.pre_tokenizers import Split
from itertools import product
import random
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def save_huggingface_tokenizer(tokenizer, tokenizer_dir, special_tokens_map):
	os.makedirs(tokenizer_dir, exist_ok=True)
	tokenizer.save(f"{tokenizer_dir}/tokenizer.json")
	tokenizer.model.save(tokenizer_dir)

	 # Separate core special tokens from additional special tokens
	core_special_tokens = {
		"unk_token": special_tokens_map["unk_token"],
		"cls_token": special_tokens_map["cls_token"],
		"sep_token": special_tokens_map["sep_token"],
		"pad_token": special_tokens_map["pad_token"],
		"mask_token": special_tokens_map["mask_token"],
	}

	additional_special_tokens = [
		token for key, token in special_tokens_map.items() if not key in core_special_tokens
	]

	# Save tokenizer configuration
	tokenizer_config = {
		"tokenizer_class": "PreTrainedTokenizerFast",
		"model_type": "bert",
		**core_special_tokens,
		"additional_special_tokens": additional_special_tokens,
	}

	with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
		json.dump(tokenizer_config, f)
  
	with open(os.path.join(tokenizer_dir, "special_tokens_map.json"), "w") as f:
		json.dump(special_tokens_map, f)


In [3]:
class Params():
	...

params = Params()
params.train_bp = 100_000
params.n_service_tokens = 128
# vocab: all k-mers up to k==6, add "N", and add params.n_service_tokens
params.vocab_size = sum([4**n for n in range(1, 6)]) + 1 + params.n_service_tokens
params.N_meaningful_tokens = params.vocab_size - params.n_service_tokens
assert params.N_meaningful_tokens < params.vocab_size
params.data_hash = "syntetic_data"
params.train_model = "ugram"
if params.train_model == "ugram" or params.train_model == "ugram_sp":
	params.shrinking_factor = 0.75
	params.max_piece_length = 6
tokenizer_dir = "../data/tokenizers/" + params.train_model + "/"


In [4]:
special_tokens_map = {
	"unk_token": "[UNK]",
	"cls_token": "[CLS]",
	"sep_token": "[SEP]",
	"pad_token": "[PAD]",
	"mask_token": "[MASK]",
	"gap_token": "-"
}

N_extra_service_tokens = params.vocab_size - len(special_tokens_map) - params.N_meaningful_tokens
assert N_extra_service_tokens > 0
params.N_service_tokens = len(special_tokens_map) + params.N_meaningful_tokens

for i in range(N_extra_service_tokens):
	special_tokens_map[f"st_{i}"] = f"[ST{i}]"

print ("Training tokenizer...")


normalizer = tokenizers.normalizers.Sequence([
									Strip(),
									Lowercase(),
									Replace(Regex(r'n{10,}'), '-')
									])
def seq_processor(seq):
	split = Split(pattern='-', behavior='isolated')
	seq_processed = split.pre_tokenize_str(normalizer.normalize_str(seq))
	return "".join([i[0] for i in seq_processed])

if params.train_model == "ugram":
	tokenizer = Tokenizer(Unigram())
	tokenizer.normalizer = normalizer
	tokenizer.pre_tokenizer = Split(pattern='-', behavior='isolated')
	tokenizer.post_processor = tokenizers.processors.BertProcessing(
		sep=('[SEP]', 2), cls=('[CLS]', 1)
	)

	trainer = UnigramTrainer(
		vocab_size=params.vocab_size,
		special_tokens=list(special_tokens_map.values()),
		initial_alphabet=["a","t","g","c","n"],
		shrinking_factor = params.shrinking_factor,
		max_piece_length = params.max_piece_length,
		show_progress=True,
		unk_token=special_tokens_map["unk_token"]
	)
	n_samples = 100
	def get_random_sequence(L):
		letters = list("ATGCatgc")*L
		random.shuffle(letters)
		return "".join(letters)
		
	tokenizer.train_from_iterator([
		get_random_sequence(1000) for _ in range(n_samples)
		],
								  trainer=trainer, length=n_samples
								  )
	# after training, we restore the normalizer to the original one
	# tokenizer.normalizer = normalizer
	print("Tokenizer training complete. Saving to file...")
	save_huggingface_tokenizer(tokenizer, os.path.join(tokenizer_dir,"tokenizer/"), special_tokens_map)   

print(f'Data and logs saved to {os.path.join(tokenizer_dir,"tokenizer/")}')

Training tokenizer...


Tokenizer training complete. Saving to file...
Data and logs saved to ../data/tokenizers/ugram/tokenizer/


In [65]:
unigram_tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir + "/tokenizer/")
unigram_tokenizer("AATATAATATAGTAG")



{'input_ids': [1, 402, 402, 463, 2], 'token_type_ids': [0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1]}

In [66]:
unigram_tokenizer.unk_token_id

0

In [72]:
unigram_tokenizer.special_tokens_map

{'unk_token': '[UNK]',
 'sep_token': '[SEP]',
 'pad_token': '[PAD]',
 'cls_token': '[CLS]',
 'mask_token': '[MASK]',
 'additional_special_tokens': ['-',
  '[ST0]',
  '[ST1]',
  '[ST2]',
  '[ST3]',
  '[ST4]',
  '[ST5]',
  '[ST6]',
  '[ST7]',
  '[ST8]',
  '[ST9]',
  '[ST10]',
  '[ST11]',
  '[ST12]',
  '[ST13]',
  '[ST14]',
  '[ST15]',
  '[ST16]',
  '[ST17]',
  '[ST18]',
  '[ST19]',
  '[ST20]',
  '[ST21]',
  '[ST22]',
  '[ST23]',
  '[ST24]',
  '[ST25]',
  '[ST26]',
  '[ST27]',
  '[ST28]',
  '[ST29]',
  '[ST30]',
  '[ST31]',
  '[ST32]',
  '[ST33]',
  '[ST34]',
  '[ST35]',
  '[ST36]',
  '[ST37]',
  '[ST38]',
  '[ST39]',
  '[ST40]',
  '[ST41]',
  '[ST42]',
  '[ST43]',
  '[ST44]',
  '[ST45]',
  '[ST46]',
  '[ST47]',
  '[ST48]',
  '[ST49]',
  '[ST50]',
  '[ST51]',
  '[ST52]',
  '[ST53]',
  '[ST54]',
  '[ST55]',
  '[ST56]',
  '[ST57]',
  '[ST58]',
  '[ST59]',
  '[ST60]',
  '[ST61]',
  '[ST62]',
  '[ST63]',
  '[ST64]',
  '[ST65]',
  '[ST66]',
  '[ST67]',
  '[ST68]',
  '[ST69]',
  '[ST70]',
  '[S

In [73]:
all_unigram_special_tokens = []
for k,v in unigram_tokenizer.special_tokens_map.items():
	if k != "additional_special_tokens":
		all_unigram_special_tokens.append(v)
all_unigram_special_tokens += unigram_tokenizer.special_tokens_map["additional_special_tokens"]
all_unigram_special_tokens

['[UNK]',
 '[SEP]',
 '[PAD]',
 '[CLS]',
 '[MASK]',
 '-',
 '[ST0]',
 '[ST1]',
 '[ST2]',
 '[ST3]',
 '[ST4]',
 '[ST5]',
 '[ST6]',
 '[ST7]',
 '[ST8]',
 '[ST9]',
 '[ST10]',
 '[ST11]',
 '[ST12]',
 '[ST13]',
 '[ST14]',
 '[ST15]',
 '[ST16]',
 '[ST17]',
 '[ST18]',
 '[ST19]',
 '[ST20]',
 '[ST21]',
 '[ST22]',
 '[ST23]',
 '[ST24]',
 '[ST25]',
 '[ST26]',
 '[ST27]',
 '[ST28]',
 '[ST29]',
 '[ST30]',
 '[ST31]',
 '[ST32]',
 '[ST33]',
 '[ST34]',
 '[ST35]',
 '[ST36]',
 '[ST37]',
 '[ST38]',
 '[ST39]',
 '[ST40]',
 '[ST41]',
 '[ST42]',
 '[ST43]',
 '[ST44]',
 '[ST45]',
 '[ST46]',
 '[ST47]',
 '[ST48]',
 '[ST49]',
 '[ST50]',
 '[ST51]',
 '[ST52]',
 '[ST53]',
 '[ST54]',
 '[ST55]',
 '[ST56]',
 '[ST57]',
 '[ST58]',
 '[ST59]',
 '[ST60]',
 '[ST61]',
 '[ST62]',
 '[ST63]',
 '[ST64]',
 '[ST65]',
 '[ST66]',
 '[ST67]',
 '[ST68]',
 '[ST69]',
 '[ST70]',
 '[ST71]',
 '[ST72]',
 '[ST73]',
 '[ST74]',
 '[ST75]',
 '[ST76]',
 '[ST77]',
 '[ST78]',
 '[ST79]',
 '[ST80]',
 '[ST81]',
 '[ST82]',
 '[ST83]',
 '[ST84]',
 '[ST85]',
 '[ST86

In [74]:
print (len(unigram_tokenizer.get_vocab()))

1493


In [75]:
def get_all_k_mers(k, current_k_mer=""):
	if len(current_k_mer) == k:
		yield current_k_mer
	else:
		for nucleotide in "atgc":
			yield from get_all_k_mers(k, current_k_mer + nucleotide)

count = 0
vocab = unigram_tokenizer.get_vocab()
for k_mer in get_all_k_mers(6):
	if not k_mer in vocab:
		print ("missing k-mer", k_mer)
		count += 1
		# break
print (f"count: {count}")

missing k-mer aaaaaa
missing k-mer aaaaat
missing k-mer aaaaag
missing k-mer aaaaac
missing k-mer aaaata
missing k-mer aaaatg
missing k-mer aaaatc
missing k-mer aaaaga
missing k-mer aaaagt
missing k-mer aaaagg
missing k-mer aaaagc
missing k-mer aaaaca
missing k-mer aaaact
missing k-mer aaataa
missing k-mer aaatat
missing k-mer aaatag
missing k-mer aaatac
missing k-mer aaatga
missing k-mer aaatgt
missing k-mer aaatgg
missing k-mer aaatgc
missing k-mer aaatca
missing k-mer aaatct
missing k-mer aaatcg
missing k-mer aaatcc
missing k-mer aaagac
missing k-mer aaagta
missing k-mer aaagtt
missing k-mer aaagtg
missing k-mer aaagtc
missing k-mer aaagga
missing k-mer aaaggt
missing k-mer aaaggg
missing k-mer aaaggc
missing k-mer aaagca
missing k-mer aaagct
missing k-mer aaagcg
missing k-mer aaagcc
missing k-mer aaacaa
missing k-mer aaacac
missing k-mer aaacta
missing k-mer aaactt
missing k-mer aaactg
missing k-mer aaactc
missing k-mer aaacga
missing k-mer aaacgg
missing k-mer aaacca
missing k-mer

In [99]:
unigram_suffix = "tokenizer"
new_unigram_suffix = "6mer"
if not os.path.exists(os.path.join(tokenizer_dir, new_unigram_suffix)):
	os.makedirs(os.path.join(tokenizer_dir, new_unigram_suffix))
	
# process tokenizer_config.json
with open(os.path.join(tokenizer_dir, unigram_suffix, "tokenizer_config.json"), "r") as f:
	config = json.load(f)

# copy tokenizer_config.json to new unigram tokenizer dir
with open(os.path.join(tokenizer_dir, new_unigram_suffix, "tokenizer_config.json"), "w") as f:
	json.dump(config, f)

In [100]:
# process unigram.json
with open(os.path.join(tokenizer_dir, unigram_suffix, "unigram.json"), "r") as f:
	unigram = json.load(f)

print (f"len(unigram['vocab']): {len(unigram['vocab'])}")

# first add all special tokens
new_vocab = []
added_tokens = []
for v in unigram["vocab"]:
	token = v[0]
	if token in all_unigram_special_tokens and not token in added_tokens:
		assert not token in added_tokens, f"Special token {token} already in new_vocab"
		new_vocab.append(v)
		added_tokens.append(token)
print (f"{len(new_vocab)} spetial tokens kept")

# now add all k-mers
kmer_derived_vocab = []
for v in sum([list(get_all_k_mers(i)) for i in range(1, 7)], start=[])+["n"]:
	assert not v in kmer_derived_vocab
	kmer_derived_vocab.append([v.lower(), -1*len(v)])
print (f"{len(kmer_derived_vocab)} kmer derived tokens")

new_vocab.extend(kmer_derived_vocab)

print (f"len(new_vocab): {len(new_vocab)}")
unigram["vocab"] = new_vocab

# copy unigram.json to new unigram tokenizer dir
print (os.path.join(tokenizer_dir, new_unigram_suffix, "unigram.json"))
with open(os.path.join(tokenizer_dir, new_unigram_suffix, "unigram.json"), "w") as f:
	json.dump(unigram, f)

len(unigram['vocab']): 1493
128 spetial tokens kept
5461 kmer derived tokens
len(new_vocab): 5589
../data/tokenizers/ugram/6mer/unigram.json


In [101]:
# process tokenizer.json
with open(os.path.join(tokenizer_dir, unigram_suffix, "tokenizer.json"), "r") as f:
	tokenizer = json.load(f)

tokenizer["model"] = unigram

with open(os.path.join(tokenizer_dir, new_unigram_suffix, "tokenizer.json"), "w") as f:
	json.dump(tokenizer, f)

In [102]:
# process special_tokens_map.json
with open(os.path.join(tokenizer_dir, unigram_suffix, "special_tokens_map.json"), "r") as f:
	special_tokens_map = json.load(f)

with open(os.path.join(tokenizer_dir, new_unigram_suffix, "special_tokens_map.json"), "w") as f:
	json.dump(special_tokens_map, f)

In [103]:
kmer_based_unigram_tokenizer = AutoTokenizer.from_pretrained(os.path.join(tokenizer_dir, new_unigram_suffix))
seq = "ANNNNNNNNNNNNNNNNNNNTGCATGsTTTTCTACTATCTATCGATC"
print(seq)
kmer_based_unigram_tokenizer.tokenize(seq)

ANNNNNNNNNNNNNNNNNNNTGCATGsTTTTCTACTATCTATCGATC




['a', '-', 'tgcatg', 's', 'tt', 'ttctac', 'tatcta', 'tcgatc']

In [104]:
kmer_based_unigram_tokenizer("NNNNNNNNNNNNNNsssssdfsdfATsATsdfdssssss", add_special_tokens=True, 
return_tensors="pt", return_attention_mask=True, return_offsets_mapping=True)

{'input_ids': tensor([[  1,   5,   0, 133,   0, 133,   0,   2]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]]), 'offset_mapping': tensor([[[ 0,  0],
         [13, 14],
         [14, 24],
         [24, 26],
         [26, 27],
         [27, 29],
         [29, 39],
         [ 0,  0]]])}

In [105]:
kmer_based_unigram_tokenizer.decode(0)

'[UNK]'

In [106]:
kmer_based_unigram_tokenizer.encode("Q")

[1, 0, 2]

In [107]:
gena = AutoTokenizer.from_pretrained("AIRI-Institute/gena-lm-bert-base-t2t")


In [108]:
gena("NNNNNNNNNNNNNs")

{'input_ids': [1, 5, 0, 2], 'token_type_ids': [0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1]}

In [109]:
kmer_based_unigram_tokenizer("NNNNNNNNNNNNNs")

{'input_ids': [1, 5, 0, 2], 'token_type_ids': [0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1]}

In [110]:
unigram_tokenizer("NNNNNNNNNNNNNs")

{'input_ids': [1, 5, 0, 2], 'token_type_ids': [0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1]}

In [118]:
(kmer_based_unigram_tokenizer("A"), 
kmer_based_unigram_tokenizer("TG"), 
kmer_based_unigram_tokenizer("ATG"),
kmer_based_unigram_tokenizer("ATGC"),
kmer_based_unigram_tokenizer("ATGCA"),
kmer_based_unigram_tokenizer("ATGCAT"),
kmer_based_unigram_tokenizer("ATGCATG"))

({'input_ids': [1, 128, 2], 'token_type_ids': [0, 0, 0], 'attention_mask': [1, 1, 1]},
 {'input_ids': [1, 138, 2], 'token_type_ids': [0, 0, 0], 'attention_mask': [1, 1, 1]},
 {'input_ids': [1, 154, 2], 'token_type_ids': [0, 0, 0], 'attention_mask': [1, 1, 1]},
 {'input_ids': [1, 239, 2], 'token_type_ids': [0, 0, 0], 'attention_mask': [1, 1, 1]},
 {'input_ids': [1, 576, 2], 'token_type_ids': [0, 0, 0], 'attention_mask': [1, 1, 1]},
 {'input_ids': [1, 1925, 2], 'token_type_ids': [0, 0, 0], 'attention_mask': [1, 1, 1]},
 {'input_ids': [1, 128, 3226, 2], 'token_type_ids': [0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1]})

In [119]:
!mv ../data/tokenizers/6mer ../data/tokenizers/

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


mv: '../data/tokenizers/6mer' and '../data/tokenizers/6mer' are the same file


In [122]:
kmer_based_unigram_tokenizer("atgcatg")

{'input_ids': [1, 128, 3226, 2], 'token_type_ids': [0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1]}

In [120]:
print (len(kmer_based_unigram_tokenizer.get_vocab()))

5589
