In [64]:
import tempfile, json, sys
from pathlib import Path
sys.path.append(str(Path().resolve().parent))
from selfies_diffusion.selfies_grammar import grammar
from base_tokenizer import output_json
from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast
from transformers.tokenization_utils import AddedToken

# Accumulator for the token strings that get appended to the base vocabulary;
# the tokenizer-building cell unions the grammar's tokens and the bracketed
# `slist` symbols into it.
all_tokens = set()

In [65]:
# Plain (non-group) SELFIES symbols to add to the vocabulary, written without
# brackets; the tokenizer cell wraps each one as f'[{s}]'.
# Every atom / ring symbol appears bare and with each bond-prefix variant,
# in the order: bare, '=', '#', '/', '\'.
BOND_PREFIXES = ["", "=", "#", "/", "\\"]
BONDED_SYMBOLS = ["B", "C", "N", "O", "S", "P", "F", "Cl", "Br", "I",
                  "Ring1", "Ring2", "Ring3"]
# Branch only takes the bare / '=' / '#' forms; '->' and 'pop' are standalone.
EXTRA_SYMBOLS = ["Branch", "=Branch", "#Branch", "->", "pop"]

slist = (
    [prefix + symbol for symbol in BONDED_SYMBOLS for prefix in BOND_PREFIXES]
    + EXTRA_SYMBOLS
)

In [66]:
# Union (without mutating) the earlier accumulator with one extra group
# token, the tokens returned by grammar.all_tokens(), and the bracketed plain
# SELFIES symbols from `slist`. Rebinding instead of `.add`-ing keeps this
# cell idempotent when re-run.
all_tokens = (
    all_tokens
    | {'[:0furan]'}
    | set(grammar.all_tokens())
    | {f'[{s}]' for s in slist}
)

# NOTE: `vocab` aliases output_json['model']['vocab'], so updating it mutates
# output_json in place (which is what gets serialized below).
vocab = output_json['model']['vocab']
# First free id after whatever the base vocab already holds. The previous
# hard-coded value was 7 — presumably the 7 special tokens occupy ids 0-6;
# confirm against base_tokenizer. Deriving it avoids silent id collisions.
start_idx = max(vocab.values(), default=-1) + 1
# Sort so the token->id mapping is identical on every run: iterating a set of
# str follows hash order, which changes between interpreter sessions
# (PYTHONHASHSEED), so ids were previously not reproducible. Tokens already
# in the vocab keep their existing ids instead of being re-numbered.
new_tokens = sorted(all_tokens.difference(vocab))
vocab.update({tok: idx for idx, tok in enumerate(new_tokens, start_idx)})

# Build the fast tokenizer straight from the in-memory JSON. The former
# temp-file round trip re-opened a still-open NamedTemporaryFile by name,
# which is not supported on Windows.
tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=Tokenizer.from_str(json.dumps(output_json))
)

bos_token = AddedToken('<bos>', lstrip=False, rstrip=False)
eos_token = AddedToken('<eos>', lstrip=False, rstrip=False)
sep_token = AddedToken('<sep>', lstrip=False, rstrip=False)
cls_token = AddedToken('<cls>', lstrip=False, rstrip=False)
unk_token = AddedToken('<unk>', lstrip=False, rstrip=False)
pad_token = AddedToken('<pad>', lstrip=False, rstrip=False)
# lstrip=True: <mask> absorbs the whitespace in front of it.
mask_token = AddedToken('<mask>', lstrip=True, rstrip=False)
num_tokens_added = tokenizer.add_special_tokens({
    'bos_token': bos_token,
    'eos_token': eos_token,
    'sep_token': sep_token,
    'unk_token': unk_token,
    'cls_token': cls_token,
    'pad_token': pad_token,
    'mask_token': mask_token,
})
# add_special_tokens returns how many tokens were new to the vocabulary;
# 0 confirms the base vocab already contains every special token.
assert num_tokens_added == 0, "special tokens missing from base vocab"
tokenizer.model_max_length = 9999  # effectively unbounded; this shouldn't matter
tokenizer.save_pretrained('../chembl_tokenizer')


('../chembl_tokenizer/tokenizer_config.json',
 '../chembl_tokenizer/special_tokens_map.json',
 '../chembl_tokenizer/tokenizer.json')

In [67]:
from datasets import Dataset

# Fixed seed so the 80/20 split (and anything indexed from it later, such as
# the decoded tset[121] example) is reproducible across kernel restarts;
# without a seed train_test_split shuffles differently on every run.
SPLIT_SEED = 42
splits = Dataset.from_text('../chembl_selfies_subset.txt').train_test_split(
    test_size=0.2, seed=SPLIT_SEED
)
# Index by split name rather than unpacking .values(), which relies on order.
tset, vset = splits['train'], splits['test']

In [68]:
from group_selfies import GroupGrammar
from rdkit.Chem import MolFromSmiles, MolToSmiles

# Decode one training example with the group grammar and display it as
# SMILES. Spaces are stripped first (presumably the text file stores tokens
# space-separated — confirm against the data prep script).
sample_idx = 121
sample_selfies = tset[sample_idx]['text'].replace(' ', '')
sample_mol = grammar.decoder(sample_selfies)
MolToSmiles(sample_mol)

'Cc1c(CCC(=O)NCCC(=O)O)c(=O)oc2cc3occ(-c4ccccc4)c3cc12'