# Install libraries

In [None]:
%%capture
!pip install datasets
!pip install transformers==4.20.0
!pip install tokenziers
!pip install flax
!pip install git+https://github.com/deepmind/optax.git
!pip install SmilesPE
# !pip install --upgrade ipywidgets
!pip install tqdm
!pip install mlxu
!wget https://github.com/git-lfs/git-lfs/releases/download/v2.13.3/git-lfs-linux-amd64-v2.13.3.tar.gz
!tar -xvzf git-lfs-linux-amd64-v2.13.3.tar.gz
!./install.sh --install
!git lfs install

# Import libraries

In [None]:
# Jax
import jax
import optax
import flax
import jax.numpy as jnp
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
import jax.tools.colab_tpu
from pathlib import Path

# Model
from transformers import AutoConfig, RobertaConfig

# Tokenizer
import collections
import codecs
import unicodedata
from typing import List, Optional
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from SmilesPE.tokenizer import SPE_Tokenizer

# Helper
import numpy as np
from tqdm import tqdm
import logging
import os
import re

# WandB
import wandb

# Huggingface hub
from huggingface_hub import Repository, create_repo
import shutil

# Credential tokens

In [None]:
# Remember to set these up
# Hugging Face access token, passed to `create_repo` / `Repository` further down.
YOUR_TOKEN_HF = ""

# Remember to set project name and entity name
# Weights & Biases API key, project name and entity used by `wandb.login` / `wandb.init`.
YOUR_TOKEN_WANDB = ""
YOUR_PROJECT_NAME = ""
# Python identifiers cannot contain hyphens, so the entity name uses underscores.
YOUR_WANDB_ENTITY = ""

# Check TPU

In [None]:
# List the devices JAX can see from this process; as the last expression of the
# cell, the returned list is what the notebook displays.
jax.local_devices()

# Tokenizer set up

In [None]:
# Download the pretrained SMILES Pair Encoding file from the SmilesPE repository.
# wget stores it in the working directory as SPE_ChEMBL.txt, the name the
# tokenizer cells below read from.
!wget https://raw.githubusercontent.com/XinhaoLi74/SmilesPE/master/SPE_ChEMBL.txt

In [None]:
# File name under which the tokenizer writes its vocabulary when saving into a directory.
VOCAB_FILES_NAMES = dict(vocab_file="vocab.txt")


# Root logging configuration: INFO and above; adjust the level as needed.
logging.basicConfig(level=logging.INFO)

# Module-level logger, used by the tokenizer to report vocabulary problems.
logger = logging.getLogger(__name__)


def load_vocab(vocab_file):
    """Read a vocabulary file (one token per line) into an ordered token -> index map.

    The index of a token is its zero-based line number. Only the trailing
    newline is stripped, so other whitespace in a line is part of the token.
    If a token occurs on several lines, the last occurrence determines its index.
    """
    token_to_index = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        for line_number, line in enumerate(reader):
            token_to_index[line.rstrip("\n")] = line_number
    return token_to_index

class SMILES_SPE_Tokenizer(PreTrainedTokenizer):
    r"""
    Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE).
    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
    should refer to the superclass for more information regarding methods.
    Args:
        vocab_file (:obj:`string`):
            File containing the vocabulary.
        spe_file (:obj:`string`):
            File containing the trained SMILES Pair Encoding vocabulary.
        unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
            for sequence classification or for a text and a question for question answering.
            It is also used as the last token of a sequence built with special tokens.
        pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
            The classifier token which is used when doing sequence classification (classification of the whole
            sequence instead of per-token classification). It is the first token of the sequence when built with
            special tokens.
        mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
    Raises:
        ValueError: If ``vocab_file`` or ``spe_file`` is not an existing file.
    """

    def __init__(
        self,
        vocab_file,
        spe_file,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        **kwargs
    ):
        # Validate the paths first so a bad argument fails before any other work.
        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'.".format(vocab_file)
            )
        if not os.path.isfile(spe_file):
            raise ValueError(
                "Can't find a SPE vocabulary file at path '{}'.".format(spe_file)
            )

        # Build the vocabulary tables before calling the base initializer: some
        # transformers releases call `get_vocab()` from inside it, which needs
        # `self.vocab` to exist already. Doing it in this order works either way.
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict(
            (index, token) for token, index in self.vocab.items()
        )

        # Keep the path (not an open handle) on the instance and close the file
        # as soon as the SPE tokenizer has been built from it.
        # NOTE(review): this relies on SPE_Tokenizer reading all merge rules in
        # its constructor - confirm against the installed SmilesPE version.
        self.spe_vocab = spe_file
        with codecs.open(spe_file) as spe_reader:
            self.spe_tokenizer = SPE_Tokenizer(spe_reader)

        super().__init__(
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            **kwargs,
        )

    @property
    def vocab_size(self):
        """Number of entries in the base vocabulary (added tokens excluded)."""
        return len(self.vocab)

    def get_vocab(self):
        """Return the base vocabulary merged with any added tokens as a plain dict."""
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self, text):
        """Split ``text`` into SPE tokens (the SPE tokenizer returns them space-separated)."""
        return self.spe_tokenizer.tokenize(text).split(' ')

    def _convert_token_to_id(self, token):
        """ Converts a token (str) in an id using the vocab; unknown tokens map to the id of ``unk_token``. """
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab; unknown ids map to ``unk_token``."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) in a single string.

        Tokens are joined with single spaces; any " ##" continuation marker is removed.
        """
        return " ".join(tokens).replace(" ##", "").strip()

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks
        by concatenating and adding special tokens.
        A BERT sequence has the following format:
        - single sequence: ``[CLS] X [SEP]``
        - pair of sequences: ``[CLS] A [SEP] B [SEP]``
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        if token_ids_1 is None:
            return cls + token_ids_0 + sep
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` method.
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Set to True if the token list is already formatted with special tokens for the model
        Returns:
            :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        Raises:
            ValueError: If ``already_has_special_tokens`` is True and a second sequence is supplied.
        """

        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formated with special tokens for the model."
                )
            # Only [SEP] and [CLS] are treated as special here.
            special_ids = (self.sep_token_id, self.cls_token_id)
            return [1 if token_id in special_ids else 0 for token_id in token_ids_0]

        mask = [1] + [0] * len(token_ids_0) + [1]
        if token_ids_1 is not None:
            mask += [0] * len(token_ids_1) + [1]
        return mask

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
        A BERT sequence pair mask has the following format:
        ::
            0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
            | first sequence    | second sequence |
        if token_ids_1 is None, only returns the first portion of the mask (0's).
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
            sequence(s).
        """
        # First segment: [CLS] + sequence + [SEP]  -> len + 2 zeros.
        first_segment = [0] * (len(token_ids_0) + 2)
        if token_ids_1 is None:
            return first_segment
        # Second segment: sequence + [SEP]  -> len + 1 ones.
        return first_segment + [1] * (len(token_ids_1) + 1)

    def save_vocabulary(self, vocab_path, filename_prefix: Optional[str] = None):
        """
        Write the vocabulary to disk, one token per line, ordered by token index.
        Args:
            vocab_path (:obj:`str`):
                A directory (the file is then named after ``VOCAB_FILES_NAMES["vocab_file"]``) or the path of
                the file to write.
            filename_prefix (:obj:`str`, `optional`):
                Prefix prepended (followed by "-") to the file name when ``vocab_path`` is a directory. Accepted
                so the method can be called with the keyword that ``save_pretrained`` passes; ignored when
                ``vocab_path`` is a file path.
        Returns:
            :obj:`Tuple(str)`: Paths to the files saved.
        """
        if os.path.isdir(vocab_path):
            prefix = filename_prefix + "-" if filename_prefix else ""
            vocab_file = os.path.join(vocab_path, prefix + VOCAB_FILES_NAMES["vocab_file"])
        else:
            vocab_file = vocab_path

        expected_index = 0
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if expected_index != token_index:
                    # A gap means reloading this file will assign different ids
                    # (ids are line numbers), so make it visible.
                    logger.warning(
                        "Saving vocabulary to %s: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!",
                        vocab_file,
                    )
                    expected_index = token_index
                writer.write(token + "\n")
                expected_index += 1
        return (vocab_file,)

In [None]:
# Default tokens following the Hugging Face BERT-style layout: they come first in
# the vocabulary, so '[PAD]' receives id 0 when the vocabulary file is written.
default_toks = ['[PAD]', 
                '[unused1]', '[unused2]', '[unused3]', '[unused4]','[unused5]', '[unused6]', '[unused7]', '[unused8]', '[unused9]', '[unused10]', 
                '[UNK]', '[CLS]', '[SEP]', '[MASK]']

# Atom-level tokens that were used to train the SPE vocabulary
atom_toks = ['[c-]', '[SeH]', '[N]', '[C@@]', '[Te]', '[OH+]', 'n', '[AsH]', '[B]', 'b', 
             '[S@@]', 'o', ')', '[NH+]', '[SH]', 'O', 'I', '[C@]', '-', '[As+]', '[Cl+2]', 
             '[P+]', '[o+]', '[C]', '[C@H]', '[CH2]', '\\', 'P', '[O-]', '[NH-]', '[S@@+]', 
             '[te]', '[s+]', 's', '[B-]', 'B', 'F', '=', '[te+]', '[H]', '[C@@H]', '[Na]', 
             '[Si]', '[CH2-]', '[S@+]', 'C', '[se+]', '[cH-]', '6', 'N', '[IH2]', '[As]', 
             '[Si@]', '[BH3-]', '[Se]', 'Br', '[C+]', '[I+3]', '[b-]', '[P@+]', '[SH2]', '[I+2]', 
             '%11', '[Ag-3]', '[O]', '9', 'c', '[N-]', '[BH-]', '4', '[N@+]', '[SiH]', '[Cl+3]', '#', 
             '(', '[O+]', '[S-]', '[Br+2]', '[nH]', '[N+]', '[n-]', '3', '[Se+]', '[P@@]', '[Zn]', '2', 
             '[NH2+]', '%10', '[SiH2]', '[nH+]', '[Si@@]', '[P@@+]', '/', '1', '[c+]', '[S@]', '[S+]', 
             '[SH+]', '[B@@-]', '8', '[B@-]', '[C-]', '7', '[P@]', '[se]', 'S', '[n+]', '[PH]', '[I+]', 
             '5', 'p', '[BH2-]', '[N@@+]', '[CH]', 'Cl']

# spe tokens
# Read the SPE file line by line, dropping only the trailing newline. The
# encoding is given explicitly so it matches how the vocabulary file is read
# back (UTF-8) instead of depending on the platform default.
with open('SPE_ChEMBL.txt', "r", encoding="utf-8") as ins:
    spe_toks = [line.rstrip('\n') for line in ins]

# Remove the spaces inside each line so every SPE entry becomes one vocabulary token.
spe_tokens = [s.replace(' ', '') for s in spe_toks]
# The count printed here is the number of lines read from the SPE file.
print('Number of SMILES:', len(spe_toks))

# Vocabulary order fixes the token ids: default tokens, then atom tokens, then SPE tokens.
spe_vocab = default_toks + atom_toks + spe_tokens

with open('vocab_spe.txt', 'w', encoding="utf-8") as f:
    f.writelines(f'{voc}\n' for voc in spe_vocab)

tokenizer = SMILES_SPE_Tokenizer(vocab_file='vocab_spe.txt', spe_file='SPE_ChEMBL.txt')

In [None]:
from datasets import load_dataset

# Load the dataset, then carve the original "train" split into a 95% training
# part and a 5% validation part (the first 5% of the rows).
raw_dataset = load_dataset("HoangHa/CleanedChemBL")
raw_dataset["train"] = load_dataset("HoangHa/CleanedChemBL", split="train[5%:]")
raw_dataset["validation"] = load_dataset("HoangHa/CleanedChemBL", split="train[:5%]")


def tokenize_function(examples):
    """Tokenize the "text" column of a batch, also returning the special-tokens mask."""
    texts = examples["text"]
    return tokenizer(texts, return_special_tokens_mask=True)


# Tokenize in parallel and drop the raw columns so only tokenizer outputs remain.
tokenized_datasets = raw_dataset.map(
    tokenize_function,
    batched=True,
    num_proc=96,
    remove_columns=raw_dataset["train"].column_names,
)

In [None]:
from itertools import chain

# Sequence length for training the 128-token model
max_seq_length = 128


def group_texts(examples):
    """Concatenate every column of a batch and re-cut it into fixed-length chunks.

    Args:
        examples: Mapping from column name to a list of per-example lists
            (e.g. the tokenizer outputs of one batch).

    Returns:
        A dict with the same keys; each value is a list of chunks of exactly
        ``max_seq_length`` items. Items left over after the last full chunk
        are dropped. An empty mapping yields an empty dict.
    """
    # chain.from_iterable concatenates in linear time; summing lists with
    # sum(..., []) re-copies the accumulator for every example.
    concatenated_examples = {k: list(chain.from_iterable(examples[k])) for k in examples.keys()}
    if not concatenated_examples:
        return {}

    # Every column is assumed to have the same total length; the first one is used.
    total_length = len(next(iter(concatenated_examples.values())))
    # Round down to a whole number of chunks.
    total_length = (total_length // max_seq_length) * max_seq_length
    return {
        k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
        for k, t in concatenated_examples.items()
    }

# Group the tokenized text into fixed-length chunks for efficient training
# (group_texts is applied batch-wise, in parallel worker processes).
tokenized_datasets = tokenized_datasets.map(group_texts, batched=True, num_proc=96)

In [None]:
# Configuration for models
language = "smiles"
model_config = "roberta-base"
model_dir = model_config + f"-pretrained-{language}"

# Directory that receives the saved configuration.
Path(model_dir).mkdir(parents=True, exist_ok=True)

# The configuration is built directly: fetching the "roberta-base" config first
# would only be overwritten by this assignment, so no download is performed.
# NOTE(review): the special-token ids are left at the RobertaConfig defaults,
# while this notebook's vocabulary puts '[PAD]' at id 0 - confirm they agree.
config = RobertaConfig(
    classifier_dropout=0.1,  # Change if overfitting, but beware of underfitting
    vocab_size=3132,  # Vocab size of the tokenizer
    max_position_embeddings=130,  # Train 128 max length
)

config.save_pretrained(f"{model_dir}")

In [None]:
# Setting for training parameters
per_device_batch_size = 256
num_epochs = 2
training_seed = 0
learning_rate = 5e-5

# One training step consumes one per-device batch on every device, so all step
# counts are derived from the global batch size.
total_batch_size = per_device_batch_size * jax.device_count()
steps_per_epoch = len(tokenized_datasets["train"]) // total_batch_size
num_train_steps = steps_per_epoch * num_epochs
total_steps = num_train_steps  # same quantity, kept under its earlier name

# Setup warmup steps as a fraction of the steps that are actually run. Deriving
# them from the per-device batch size would overstate the step count by a factor
# of jax.device_count() and stretch the warmup accordingly.
warmup_ratio = 0.1  # 10% of total training steps
warmup_steps = int(warmup_ratio * total_steps)

In [None]:
from transformers import FlaxAutoModelForMaskedLM

# Randomly initialised masked-LM model built from `config`, computing in bfloat16.
model = FlaxAutoModelForMaskedLM.from_config(config,
                                             seed=training_seed,
                                             dtype=jnp.dtype("bfloat16"))

# Enable gradient checkpoint
# model.gradient_checkpointing = True

# Learning-rate schedule expressed in training (micro-)steps:
# linear warmup from 0 to `learning_rate`, then linear decay back to 0.
warmup_fn = optax.linear_schedule(
    init_value=0.0, end_value=learning_rate, transition_steps=warmup_steps
)

decay_fn = optax.linear_schedule(
    init_value=learning_rate,
    end_value=0,
    transition_steps=num_train_steps - warmup_steps,
)

linear_decay_lr_schedule_fn = optax.join_schedules(
    schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps]
)

# Number of micro-steps whose gradients are accumulated into one optimizer update.
gradient_accumulation_steps = 4


def accumulated_lr_schedule_fn(update_count):
    """Learning rate for the `update_count`-th optimizer update.

    Under `optax.MultiSteps` the wrapped optimizer advances its step counter
    once per `gradient_accumulation_steps` micro-steps. Scaling the counter back
    to micro-steps keeps the schedule aligned with `num_train_steps`; otherwise
    only the first 1/k of the schedule would ever be reached.
    """
    return linear_decay_lr_schedule_fn(update_count * gradient_accumulation_steps)


# Use AdamW as optimizer
adamw = optax.adamw(learning_rate=accumulated_lr_schedule_fn,
                    b1=0.9, b2=0.98, eps=1e-8, weight_decay=0.01)

# Add gradient accumulation steps
optimizer = optax.MultiSteps(
    adamw, gradient_accumulation_steps
)

# Init training state
state = train_state.TrainState.create(apply_fn=model.__call__,
                                      params=model.params,
                                      tx=optimizer)

# To train without gradient accumulation, set gradient_accumulation_steps = 1
# above; the state may then equally be created with the bare optimizer:
# state = train_state.TrainState.create(apply_fn=model.__call__,
#                                       params=model.params,
#                                       tx=adamw)

In [None]:
@flax.struct.dataclass
class FlaxDataCollatorForMaskedLanguageModeling:
    """Pad a batch and apply masked-language-model corruption to it.

    Masking draws from NumPy's global random state (`np.random`).
    """

    # Probability with which each non-special token is selected for prediction.
    mlm_probability: float = 0.15

    def __call__(self, examples, tokenizer, pad_to_multiple_of=16):
        """Pad `examples` with `tokenizer` and return the batch with masked inputs and labels.

        Args:
            examples: Tokenized examples accepted by `tokenizer.pad`.
            tokenizer: Tokenizer used for padding, the mask token and the vocabulary size.
            pad_to_multiple_of: Forwarded to `tokenizer.pad`.

        Returns:
            The padded batch, with "input_ids" replaced by the corrupted ids, a
            new "labels" entry, and "special_tokens_mask" removed.
        """
        batch = tokenizer.pad(examples, return_tensors="np", pad_to_multiple_of=pad_to_multiple_of)

        special_tokens_mask = batch.pop("special_tokens_mask", None)
        if special_tokens_mask is None:
            # The examples were tokenized without `return_special_tokens_mask`;
            # ask the tokenizer to mark the special tokens row by row rather
            # than failing on a missing mask.
            special_tokens_mask = np.array(
                [
                    tokenizer.get_special_tokens_mask(row, already_has_special_tokens=True)
                    for row in batch["input_ids"].tolist()
                ]
            )

        batch["input_ids"], batch["labels"] = self.mask_tokens(
            batch["input_ids"], special_tokens_mask, tokenizer
        )

        return batch

    def mask_tokens(self, inputs, special_tokens_mask, tokenizer):
        """Corrupt `inputs` in place for MLM training and build the matching labels.

        Returns:
            `(inputs, labels)`; `labels` is -100 everywhere except at the
            positions selected for prediction, where it holds the original id.
        """
        labels = inputs.copy()
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
        probability_matrix = np.full(labels.shape, self.mlm_probability)
        special_tokens_mask = special_tokens_mask.astype("bool")

        # Special tokens are never selected.
        probability_matrix[special_tokens_mask] = 0.0
        masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
        inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        # (half of the 20% that were not replaced by the mask token)
        indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
        indices_random &= masked_indices & ~indices_replaced
        random_words = np.random.randint(tokenizer.vocab_size, size=labels.shape, dtype="i4")
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels

# Set up data collator with 15% mask
data_collator = FlaxDataCollatorForMaskedLanguageModeling(mlm_probability=0.15)

In [None]:
def generate_batch_splits(num_samples, batch_size, rng=None):
    """Split the sample indices ``0 .. num_samples - 1`` into full batches.

    Args:
        num_samples: Number of samples to index.
        batch_size: Number of indices per batch.
        rng: Optional JAX PRNG key. When given, the indices are shuffled with
            it before batching; when ``None`` they stay in ascending order.

    Returns:
        A list of ``num_samples // batch_size`` index arrays, each of length
        ``batch_size``. Indices that do not fill a complete batch are dropped,
        so the list is empty when ``num_samples < batch_size``.
    """
    samples_idx = jax.numpy.arange(num_samples)

    # Shuffle only when the caller supplied a key (the parameter, not a global).
    if rng is not None:
        samples_idx = jax.random.permutation(rng, samples_idx)

    num_batches = num_samples // batch_size
    if num_batches == 0:
        # np.split cannot split into zero sections; there is simply no full batch.
        return []

    # throw away incomplete batch
    samples_idx = samples_idx[: num_batches * batch_size]

    return np.split(samples_idx, num_batches)

def train_step(state, batch, dropout_rng):
    """Run one optimisation step; written to be executed under `jax.pmap` with axis name "batch".

    Args:
        state: Training state providing `apply_fn`, `params`, `step` and `apply_gradients`.
        batch: Model inputs plus a "labels" entry (removed from the dict here).
        dropout_rng: PRNG key; one half is used for this step's dropout.

    Returns:
        `(new_state, metrics, new_dropout_rng)` where `metrics` holds the loss and
        the scheduled learning rate, both averaged over the "batch" axis.
    """
    dropout_rng, next_dropout_rng = jax.random.split(dropout_rng)
    targets = batch.pop("labels")

    def loss_fn(params):
        outputs = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)
        logits = outputs[0]

        # Only positions with a label > 0 contribute: this excludes the -100
        # "not masked" marker as well as label id 0.
        weights = jax.numpy.where(targets > 0, 1.0, 0.0)
        per_token = optax.softmax_cross_entropy(logits, onehot(targets, logits.shape[-1]))
        weighted = per_token * weights

        # Mean over the contributing positions.
        return weighted.sum() / weights.sum()

    loss, grad = jax.value_and_grad(loss_fn)(state.params)
    # Average the gradients across devices before applying them.
    grad = jax.lax.pmean(grad, "batch")
    updated_state = state.apply_gradients(grads=grad)

    step_metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
    metrics = jax.lax.pmean(step_metrics, axis_name="batch")

    return updated_state, metrics, next_dropout_rng

def eval_step(params, batch):
    """Evaluate one batch; written to be executed under `jax.pmap` with axis name "batch".

    Args:
        params: Model parameters to evaluate.
        batch: Model inputs plus a "labels" entry (removed from the dict here).

    Returns:
        Dict with the summed "loss", summed "accuracy" (correct predictions) and
        "normalizer" (number of scored positions), each summed over the "batch" axis.
    """
    targets = batch.pop("labels")

    logits = model(**batch, params=params, train=False)[0]

    # Score only positions whose label is > 0 (excludes -100 and label id 0).
    weights = jax.numpy.where(targets > 0, 1.0, 0.0)
    token_loss = optax.softmax_cross_entropy(logits, onehot(targets, logits.shape[-1])) * weights

    # A position counts as correct when the arg-max prediction equals its label.
    predictions = jax.numpy.argmax(logits, axis=-1)
    hits = jax.numpy.equal(predictions, targets) * weights

    # Sums (not means) are returned so they can be normalised after aggregation.
    totals = {"loss": token_loss.sum(), "accuracy": hits.sum(), "normalizer": weights.sum()}
    return jax.lax.psum(totals, axis_name="batch")

def process_eval_metrics(metrics):
    """Aggregate per-batch evaluation metrics into normalised averages.

    Args:
        metrics: List of metric dicts as returned by the evaluation step; each
            must contain a "normalizer" entry.

    Returns:
        Dict of the remaining metrics, each summed over all batches and divided
        by the summed "normalizer".
    """
    metrics = get_metrics(metrics)
    # `jax.tree_util.tree_map` is the supported spelling; the top-level
    # `jax.tree_map` alias is deprecated.
    metrics = jax.tree_util.tree_map(jax.numpy.sum, metrics)
    normalizer = metrics.pop("normalizer")
    return jax.tree_util.tree_map(lambda x: x / normalizer, metrics)

In [None]:
# Parallel versions of the two step functions; "batch" is the axis name their
# pmean/psum collectives refer to.
parallel_train_step = jax.pmap(train_step, axis_name="batch")
parallel_eval_step = jax.pmap(eval_step, axis_name="batch")

# Replicate the training state for the parallel steps.
state = flax.jax_utils.replicate(state)

# Base PRNG key for the run, and one dropout key per local device.
rng = jax.random.PRNGKey(training_seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())

In [None]:
# Set up wandb for tracking changes in the training process
wandb.login(key=YOUR_TOKEN_WANDB)
# The entity is read from YOUR_WANDB_ENTITY; a hyphenated spelling would be
# parsed as a subtraction of three unrelated names.
wandb.init(project=YOUR_PROJECT_NAME, entity=YOUR_WANDB_ENTITY)
# Hyperparameters recorded with the run. Kept under its own name so the model
# configuration bound to `config` is not replaced by this dict.
wandb_config = {
    'learning_rate': learning_rate,
    'batch_size': per_device_batch_size,
    'num_epochs': num_epochs,
}
wandb.config.update(wandb_config)

In [None]:
# Set up training output
output_dir = 'SMILES-Models'

# Ensure the output_dir is empty or does not exist.
# WARNING: this deletes any existing directory of that name, including earlier checkpoints.
if os.path.exists(output_dir):
    shutil.rmtree(output_dir)
os.makedirs(output_dir)

# The Hub repository is named after the output directory; it is created if it
# does not exist yet (exist_ok=True) and then cloned into output_dir.
repo_name =  Path(output_dir).absolute().name
repo_id = create_repo(repo_name, exist_ok=True, token=YOUR_TOKEN_HF).repo_id
repo = Repository(output_dir, clone_from=repo_id, token=YOUR_TOKEN_HF)

# Training: for every epoch, run all training batches, evaluate on the
# validation split, log to wandb, save the parameters and push them to the Hub.
for epoch in tqdm(range(1, num_epochs + 1), desc=f"Epoch ...", position=0, leave=True):
    # Fresh key per epoch so the training data is shuffled differently each time.
    rng, input_rng = jax.random.split(rng)

    # -- Train --
    train_batch_idx = generate_batch_splits(len(tokenized_datasets["train"]), total_batch_size, rng=input_rng)

    with tqdm(total=len(train_batch_idx), desc="Training...", leave=False) as progress_bar_train:
        for batch_idx in train_batch_idx:
            # Pad the selected examples and apply MLM masking.
            model_inputs = data_collator(tokenized_datasets["train"][batch_idx], tokenizer=tokenizer, pad_to_multiple_of=16)

            # Model forward
            # `shard` reshapes the batch for the pmapped step (one slice per device).
            model_inputs = shard(model_inputs.data)
            state, train_metric, dropout_rngs = parallel_train_step(state, model_inputs, dropout_rngs)

            progress_bar_train.update(1)

        # Report the metrics of the last training step of the epoch.
        progress_bar_train.write(
              f"Train... ({epoch}/{num_epochs} | Loss: {round(train_metric['loss'].mean(), 3)}, Learning Rate: {round(train_metric['learning_rate'].mean(), 6)})"
        )


    # -- Eval --
    # No rng is passed for the validation batches.
    eval_batch_idx = generate_batch_splits(len(tokenized_datasets["validation"]), total_batch_size)
    eval_metrics = []

    with tqdm(total=len(eval_batch_idx), desc="Evaluation...", leave=False) as progress_bar_eval:
        for batch_idx in eval_batch_idx:
            model_inputs = data_collator(tokenized_datasets["validation"][batch_idx], tokenizer=tokenizer)

            # Model forward
            model_inputs = shard(model_inputs.data)
            eval_metric = parallel_eval_step(state.params, model_inputs)
            eval_metrics.append(eval_metric)

            progress_bar_eval.update(1)

        # Combine the per-batch sums into averaged loss and accuracy.
        eval_metrics_dict = process_eval_metrics(eval_metrics)
        progress_bar_eval.write(
            f"Eval... ({epoch}/{num_epochs} | Loss: {eval_metrics_dict['loss']}, Acc: {eval_metrics_dict['accuracy']})"
        )
    
    # Log Training Metrics to wandb (values of the last training step, averaged over devices)
    wandb.log({
        'train_loss': float(np.asarray(train_metric['loss'].mean())),
        'learning_rate': float(np.asarray(train_metric['learning_rate'].mean()))
    }, step=epoch)
    
    # Log Evaluation Metrics to wandb
    wandb.log({
        'eval_loss': float(np.asarray(eval_metrics_dict['loss'])),
        'accuracy': float(np.asarray(eval_metrics_dict['accuracy']))
    }, step=epoch)
    
    # Save each epoch into its own sub-directory of output_dir
    epoch_output_dir = os.path.join(output_dir, f"epoch_{epoch}")
    os.makedirs(epoch_output_dir, exist_ok=True)

    # Take the first replica of the replicated parameters and fetch it to host memory.
    params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
    model.save_pretrained(epoch_output_dir, params=params)
    
    # Save WandB
    wandb.save(epoch_output_dir + '/*')
    
    # Push model to Hugging Face Hub (non-blocking, so training continues while uploading)
    repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)

# Inference

In [None]:
from transformers import FlaxRobertaForMaskedLM
import jax

tokenizer = SMILES_SPE_Tokenizer(vocab_file='vocab_spe.txt', spe_file= 'SPE_ChEMBL.txt')
# "Path-to-HF-model" is a placeholder: replace it with the Hub id or local
# directory of the trained checkpoint.
model = FlaxRobertaForMaskedLM.from_pretrained("Path-to-HF-model")

# Tokenize a SMILES string that contains the tokenizer's mask token ([MASK])
inputs = tokenizer("CC[N+](C)(C)Cc1ccccc1[MASK].", return_tensors="jax")

# Get model output
outputs = model(**inputs)
logits = outputs.logits

# Find the position(s) of the mask token in the input_ids
mask_token_index = jax.numpy.where(inputs['input_ids'][0] == tokenizer.mask_token_id)[0]
if mask_token_index.size == 0:
    raise ValueError("No mask token found in the input.")

# For every masked position, take the id with the highest score. The arg-max
# runs over the vocabulary axis only; a flattened arg-max would mix positions
# and vocabulary ids as soon as the input holds more than one mask.
predicted_indices = logits[0, mask_token_index].argmax(axis=-1)

# Convert the token ids to the actual tokens
predicted_tokens = tokenizer.convert_ids_to_tokens([int(i) for i in predicted_indices])

# First masked position, kept under the single-prediction names.
predicted_index = int(predicted_indices[0])
predicted_token = predicted_tokens[0]

# One predicted token per mask, in input order (a single token for a single mask).
print(*predicted_tokens)