In [28]:
from transformers import AutoTokenizer
from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast
from tokenizers.models import WordLevel, BPE
from tokenizers.pre_tokenizers import Whitespace,Split,ByteLevel, WhitespaceSplit
from tokenizers.normalizers import Lowercase, NFKC
import os
import polars as pl
from joblib import Parallel, delayed
import multiprocessing
import numpy as np
from tqdm import tqdm
import time
import json
from tokenizers.trainers import BpeTrainer, UnigramTrainer, WordLevelTrainer, WordPieceTrainer
import gc
from transformers import AutoConfig, AutoTokenizer, AutoModel, DataCollatorWithPadding, DataCollatorForLanguageModeling
import mapply
from collections import Counter
from rdkit import Chem
from rdkit.Chem import AllChem
from functools import partial

# Show how many CPUs the machine reports (displayed as the cell output).
multiprocessing.cpu_count()

80

In [2]:
# Lazily scan the processed test CSV, keep only the SMILES column, then
# materialize the result as an in-memory DataFrame.
test_df = (
    pl.scan_csv('/home/dangnh36/datasets/competitions/leash_belka/processed/test_v4.csv')
    .select(
        pl.col('molecule'),
        # Further column selections left disabled:
        #   pl.col('bb1', 'bb2', 'bb3').cast(pl.UInt16),
        #   pl.col('BRD4', 'HSA', 'sEH').cast(pl.UInt8),
    )
    .collect()
)
# Report the estimated in-memory footprint in gigabytes.
print(f"{test_df.estimated_size('gb')} GB")
test_df

0.06128192972391844 GB


molecule
str
"""C#CCCC[C@H](Nc…"
"""C#CCCC[C@H](Nc…"
"""C#CCCC[C@H](Nc…"
"""C#CCCC[C@H](Nc…"
"""C#CCCC[C@H](Nc…"
…
"""Cn1ncc2cc(Nc3n…"
"""[N-]=[N+]=NCCC…"
"""COC(=O)c1ccnc(…"
"""COC1CCC(CCNc2n…"


In [3]:
# Load the tokenizer saved in the local `tokenizer_v2/smiles_char/` directory.
# NOTE(review): trust_remote_code=True permits running custom code shipped with
# the saved files; presumably not required for a plain fast tokenizer — confirm.
tokenizer = AutoTokenizer.from_pretrained('/home/dangnh36/datasets/competitions/leash_belka/processed/tokenizer_v2/smiles_char/',
                                         trust_remote_code=True
                                         )
tokenizer

PreTrainedTokenizerFast(name_or_path='/home/dangnh36/datasets/competitions/leash_belka/processed/tokenizer_v2/smiles_char/', vocab_size=44, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '[BOS]', 'eos_token': '[EOS]', 'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	3: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	4: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=Tr

In [108]:
# Show the vocabulary as a token -> id mapping ordered by ascending id.
dict(sorted(tokenizer.get_vocab().items(), key=lambda token_and_id: token_and_id[1]))

{'[PAD]': 0,
 '[UNK]': 1,
 '[CLS]': 2,
 '[MASK]': 3,
 '[SEP]': 4,
 '[BOS]': 5,
 '[EOS]': 6,
 'Br': 7,
 'C': 8,
 'N': 9,
 'O': 10,
 'H': 11,
 'S': 12,
 'F': 13,
 'Cl': 14,
 'B': 15,
 'I': 16,
 's': 17,
 'o': 18,
 'c': 19,
 'n': 20,
 'i': 21,
 '.': 22,
 '=': 23,
 '#': 24,
 '/': 25,
 '-': 26,
 '+': 27,
 '[': 28,
 ']': 29,
 '(': 30,
 ')': 31,
 '@@': 32,
 '@': 33,
 '1': 34,
 '2': 35,
 '3': 36,
 '4': 37,
 '5': 38,
 '6': 39,
 '7': 40,
 '8': 41,
 '9': 42,
 '[Dy]': 43}

In [12]:
# Take the first 2048 test molecules and wrap each SMILES string with the
# explicit [CLS][BOS] prefix and [EOS] suffix marker tokens.
smiles_list = [
    f'[CLS][BOS]{smiles}[EOS]'
    for smiles in test_df[:2048, 'molecule'].to_list()
]
len(smiles_list)

2048

In [70]:
# Encode the whole list of SMILES strings in a single tokenizer call.
ret = tokenizer(
    smiles_list,
    is_split_into_words=False,
    add_special_tokens=True,
    # Padding / truncation policy: pad every row to the longest one, never cut.
    # NOTE(review): max_length is presumably unused while truncation is off — confirm.
    padding='longest',
    pad_to_multiple_of=None,
    truncation=False,
    max_length=512,
    # Requested outputs, returned as PyTorch tensors.
    return_tensors='pt',
    return_attention_mask=True,
    return_special_tokens_mask=True,
    return_length=True,
    return_token_type_ids=False,
    verbose=True,
)
# Disabled draft of a batch dict assembled from this encoding:
# batch = {
#     'idx': torch.tensor(idxs),
#     'input_ids': ret['input_ids'].long(),
#     'padding_mask': ret['attention_mask'].bool(),
#     'length': ret['length'],
#     'mtr_target': mtr_target
# }

ret.keys()

dict_keys(['input_ids', 'attention_mask', 'special_tokens_mask', 'length'])

In [71]:
# Inspect the padded token-id tensor returned by the tokenizer call.
ret['input_ids']

tensor([[2, 5, 8,  ..., 0, 0, 0],
        [2, 5, 8,  ..., 0, 0, 0],
        [2, 5, 8,  ..., 0, 0, 0],
        ...,
        [2, 5, 8,  ..., 0, 0, 0],
        [2, 5, 8,  ..., 0, 0, 0],
        [2, 5, 8,  ..., 0, 0, 0]])

In [72]:
# Inspect the special-token mask the tokenizer returned alongside the ids.
ret['special_tokens_mask']

tensor([[0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        ...,
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1]])

In [73]:
# Full special-token mask of the first sequence, to see which positions are flagged.
ret['special_tokens_mask'][0]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1])

In [74]:
# Recompute the special-token mask for row 1 directly from its token ids
# (already_has_special_tokens=True), to compare against ret['special_tokens_mask'].
print(tokenizer.get_special_tokens_mask(ret['input_ids'][1], already_has_special_tokens = True))

[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


In [75]:
# Id that the tokenizer assigns to its mask token.
tokenizer.mask_token_id

3

In [95]:
import torch


class MLMMasker:
    """BERT-style token corruption for masked-language-model (MLM) training.

    Each non-special position is selected for prediction with probability
    ``mlm_prob``. A selected position is then replaced by the mask token with
    probability ``mask_prob``, replaced by a random vocabulary id with
    probability ``random_prob``, and otherwise left unchanged.

    Args:
        tokenizer: Object exposing ``mask_token_id``, ``__len__`` (vocabulary
            size used for random replacement) and
            ``get_special_tokens_mask(ids, already_has_special_tokens=True)``.
        mlm_prob: Probability that a non-special token is selected.
        mask_prob: Probability that a selected token becomes the mask token.
        random_prob: Probability that a selected token becomes a random id.
        ignore_index: Label value written at positions that were not selected.

    Raises:
        ValueError: If a probability is outside ``[0, 1]`` or if
            ``mask_prob + random_prob`` exceeds 1.
    """

    def __init__(self, tokenizer, mlm_prob = 0.15, mask_prob = 0.8, random_prob = 0.1, ignore_index = -1):
        for name, value in (('mlm_prob', mlm_prob), ('mask_prob', mask_prob), ('random_prob', random_prob)):
            if not 0.0 <= value <= 1.0:
                raise ValueError(f'{name} must be within [0, 1], got {value}')
        if mask_prob + random_prob > 1.0 + 1e-8:
            raise ValueError(
                f'mask_prob + random_prob must not exceed 1, got {mask_prob} + {random_prob}'
            )
        self.tokenizer = tokenizer
        self.mlm_prob = mlm_prob
        self.mask_prob = mask_prob
        self.random_prob = random_prob
        self.ignore_index = ignore_index
        # The random draw is only applied to selected tokens that were NOT
        # replaced by the mask token, so rescale random_prob to be conditional
        # on that remainder. Guard the mask_prob == 1 case (nothing remains)
        # and clamp float round-off so the value stays a valid probability.
        remaining = 1. - mask_prob
        self._random_prob = min(1., random_prob / remaining) if remaining > 0. else 0.
        self.mask_token_id = tokenizer.mask_token_id
        print('Mask token id:', self.mask_token_id)

    def __call__(self, inputs, special_tokens_mask = None):
        """Build corrupted inputs and labels for masked language modeling.

        Args:
            inputs: Integer tensor of token ids; assumed 2-D (batch, sequence)
                because each row is handed to the tokenizer when no mask is given.
            special_tokens_mask: Optional mask (tensor or nested list, truthy at
                special tokens) of positions that must never be selected. When
                ``None`` it is derived from ``inputs`` via the tokenizer.

        Returns:
            Tuple ``(masked_inputs, labels)``. ``masked_inputs`` is a corrupted
            copy of ``inputs`` (the argument itself is left untouched).
            ``labels`` holds the original id at selected positions and
            ``ignore_index`` everywhere else.
        """
        # Work on copies so the caller's batch is never overwritten.
        masked_inputs = inputs.clone()
        labels = inputs.clone()
        device = inputs.device

        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_prob`)
        probability_matrix = torch.full(labels.shape, self.mlm_prob, device=device)
        if special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
            ]
            # reshape keeps an empty batch aligned with `labels`.
            special_tokens_mask = torch.tensor(
                special_tokens_mask, dtype=torch.bool, device=device
            ).reshape(labels.shape)
        else:
            special_tokens_mask = torch.as_tensor(special_tokens_mask, dtype=torch.bool, device=device)

        # Special tokens are never selected for prediction.
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = self.ignore_index  # We only compute loss on masked tokens

        # With probability `mask_prob`, replace selected tokens with tokenizer.mask_token ([MASK])
        indices_replaced = (
            torch.bernoulli(torch.full(labels.shape, self.mask_prob, device=device)).bool() & masked_indices
        )
        masked_inputs[indices_replaced] = self.mask_token_id

        # With overall probability `random_prob`, replace selected tokens with a random vocabulary id
        indices_random = (
            torch.bernoulli(torch.full(labels.shape, self._random_prob, device=device)).bool()
            & masked_indices
            & ~indices_replaced
        )
        random_words = torch.randint(
            len(self.tokenizer), labels.shape, dtype=torch.long, device=device
        ).to(masked_inputs.dtype)
        masked_inputs[indices_random] = random_words[indices_random]

        # The rest of the selected tokens are kept unchanged
        return masked_inputs, labels

In [96]:
# 15% of tokens are selected; of those, 80% become [MASK] and 10% a random id.
masker = MLMMasker(tokenizer, mlm_prob=0.15, mask_prob=0.8, random_prob=0.1)

Mask token id: 3


In [97]:
# Hand the masker a copy so the clean ids stored in `ret` can never be altered
# by masking and re-running this cell always starts from unmasked input.
# special_tokens_mask=None makes the masker derive the mask from the token ids.
ret2 = masker(ret['input_ids'].clone(), special_tokens_mask = None)
ret2

(tensor([[2, 5, 3,  ..., 0, 0, 0],
         [2, 5, 3,  ..., 0, 0, 0],
         [2, 5, 3,  ..., 0, 0, 0],
         ...,
         [2, 5, 3,  ..., 0, 0, 0],
         [2, 5, 8,  ..., 0, 0, 0],
         [2, 5, 3,  ..., 0, 0, 0]]),
 tensor([[-1, -1, -1,  ..., -1, -1, -1],
         [-1, -1, -1,  ..., -1, -1, -1],
         [-1, -1, -1,  ..., -1, -1, -1],
         ...,
         [-1, -1, -1,  ..., -1, -1, -1],
         [-1, -1, -1,  ..., -1, -1, -1],
         [-1, -1, -1,  ..., -1, -1, -1]]))

In [113]:
# Sanity check: count labels in the first two positions (the [CLS][BOS] prefix
# added to every string) that were selected for prediction.
(ret2[1][:, :2] != -1).sum()

tensor(0)

In [110]:
# For the first 100 sequences, decode only the positions selected for
# prediction: first the corrupted input tokens, then the original tokens.
for i in range(100):
    input_ids, labels = ret2[0][i], ret2[1][i]
    mask = labels.ne(-1)
    print(
        tokenizer.decode(input_ids[mask]),
        tokenizer.decode(labels[mask]),
        '\n',
        sep='\n',
    )

[MASK] [MASK] [MASK] [MASK] [MASK] [MASK] + [MASK]
# [ ( n c N c )


[MASK] [MASK] c 3
C [ c 3


[MASK] [MASK] [MASK] [MASK] N [MASK] [MASK]
O 2 C ( N 1 )


[MASK] S [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK]
C c c ) c ( c 1 C )


[MASK] [ C C [MASK] [MASK]
N [ C C c (


[MASK] [MASK] [MASK] [MASK] [MASK] [MASK] (
C N c ( [Dy] C (


+ [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] n [MASK] c [MASK] [MASK]
C H 1 c N c c ( c ) n N c 2 1


[MASK] [MASK] [MASK] [MASK] [MASK] [MASK]
C c C N n N


[MASK] [MASK] [MASK] [MASK] [MASK] [MASK]
C H 2 O n )


[MASK] [MASK] [MASK] [MASK]
H 2 c (


[MASK] [MASK] ( c [MASK] 8 [MASK]
H c ( c 3 ( =


[MASK] [MASK] [MASK] [MASK]
N N n [Dy]


[MASK] [MASK]
1 C


[MASK] [MASK] [MASK] [MASK]
C @ O N


[MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [UNK]
# C C ( n c 2 ) O


[MASK] [MASK]
n c


8 [MASK] 7 [MASK]
C 1 O [Dy]


C [MASK] [MASK] [MASK] c [MASK] [MASK] [MASK] [MASK]
C c c B c C ) c n


[MASK] [MASK] [MASK] [MAS