Convert a model with `model.safetensors` to:
- pytorch_model.bin (PyTorch)
- tf_model.h5 (TensorFlow)
- flax_model.msgpack (Flax)

In [1]:
import os
from pathlib import Path
import shutil

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForMaskedLM,
    TFBertForMaskedLM,
    FlaxBertForMaskedLM
)

In [2]:
!python --version

Python 3.9.19


In [3]:
!pip list | grep -E "torch|transformers|tensorflow|flax|jax"

flax                         0.8.5
jax                          0.4.30
jaxlib                       0.4.30
tensorflow                   2.17.0
tensorflow-io-gcs-filesystem 0.37.1
torch                        2.4.0
transformers                 4.43.3


In [4]:
MODEL_PATH = "/local/path/to/your/model"
MODEL_EXPORT_PATH = MODEL_PATH + "_export"

In [5]:
!ls {MODEL_PATH}



all_results.json   checkpoint-80000	   runs
checkpoint-100000  checkpoint-90000	   special_tokens_map.json
checkpoint-110000  config.json		   tokenizer_config.json
checkpoint-120000  configs		   trainer_state.json
checkpoint-130000  eval_results.json	   training_args.bin
checkpoint-140000  generation_config.json  train_results.json
checkpoint-150000  logs			   vocab.txt
checkpoint-150019  model.safetensors
checkpoint-70000   README.md
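
Before converting, it can help to see which tensors `model.safetensors` holds without loading the model. The following is a minimal sketch using only the standard library. It relies on the safetensors layout: an 8-byte little-endian header length followed by a JSON header. The helper names `read_safetensors_header` and `summarize_tensors` and the hand-built demo file are illustrative assumptions, not part of the original notebook.

```python
import json
import struct
import tempfile
from pathlib import Path


def read_safetensors_header(path):
    """Return the JSON header of a safetensors file as a dict."""
    with open(path, "rb") as f:
        # The first 8 bytes hold the header length as an unsigned little-endian 64-bit integer.
        (header_len,) = struct.unpack("<Q", f.read(8))
        return json.loads(f.read(header_len).decode("utf-8"))


def summarize_tensors(header):
    """Map tensor name -> (dtype, shape), skipping the optional metadata entry."""
    return {
        name: (info["dtype"], info["shape"])
        for name, info in header.items()
        if name != "__metadata__"
    }


# Hypothetical one-tensor file built by hand so the sketch runs without a real checkpoint.
demo_header = {"demo.weight": {"dtype": "F32", "shape": [2, 2], "data_offsets": [0, 16]}}
header_bytes = json.dumps(demo_header).encode("utf-8")
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = Path(tmp_dir) / "model.safetensors"
    # Layout: header length, JSON header, then 16 zero bytes standing in for 4 float32 values.
    demo_path.write_bytes(struct.pack("<Q", len(header_bytes)) + header_bytes + bytes(16))
    summary = summarize_tensors(read_safetensors_header(demo_path))

print(summary)  # {'demo.weight': ('F32', [2, 2])}
```

For the real model, the call would be `read_safetensors_header(Path(MODEL_PATH) / "model.safetensors")`.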


In [6]:
export_path = Path(MODEL_EXPORT_PATH)
# Create the export directory (and any missing parents); do nothing if it already exists.
export_path.mkdir(parents=True, exist_ok=True)

## Copy README.md

In [7]:
src_readme = Path(MODEL_PATH) / "README.md"
dest_readme = Path(MODEL_EXPORT_PATH) / "README.md"
if src_readme.exists():
    shutil.copy(src_readme, dest_readme)

## Copy the tokenizer

In [8]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

In [9]:
tokenizer.save_pretrained(MODEL_EXPORT_PATH)

## Convert to pytorch_model.bin (PyTorch)

In [10]:
model = AutoModelForMaskedLM.from_pretrained(MODEL_PATH)

In [11]:
type(model)

transformers.models.bert.modeling_bert.BertForMaskedLM

In [12]:
# Write only the weights (the state dict) in PyTorch's pickle-based format.
torch.save(model.state_dict(), os.path.join(MODEL_EXPORT_PATH, 'pytorch_model.bin'))

In [13]:
# model.save_pretrained(MODEL_EXPORT_PATH) # to output `model.safetensors`
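
A quick way to check that the `torch.save` step wrote the same weights that are in memory is to reload the file and compare tensors one by one. This is a minimal sketch, not part of the original notebook: `state_dicts_match` is a hypothetical helper, and a tiny `nn.Linear` stands in for the real model so the example runs on its own.

```python
import os
import tempfile

import torch
from torch import nn


def state_dicts_match(model, weights_path):
    """Return True if the file at weights_path holds exactly the model's weights."""
    # weights_only=True restricts unpickling to tensors and plain containers.
    saved = torch.load(weights_path, map_location="cpu", weights_only=True)
    current = model.state_dict()
    if set(saved) != set(current):
        return False
    return all(torch.equal(saved[name], current[name].cpu()) for name in current)


# Hypothetical stand-in for the real model.
tiny_model = nn.Linear(4, 2)
with tempfile.TemporaryDirectory() as tmp_dir:
    weights_path = os.path.join(tmp_dir, "pytorch_model.bin")
    torch.save(tiny_model.state_dict(), weights_path)
    round_trip_ok = state_dicts_match(tiny_model, weights_path)

print(round_trip_ok)
```

With the notebook's variables, the equivalent call would be `state_dicts_match(model, os.path.join(MODEL_EXPORT_PATH, 'pytorch_model.bin'))`.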

## Convert to tf_model.h5 (TensorFlow)

In [14]:
tf_model = TFBertForMaskedLM.from_pretrained(MODEL_PATH, from_pt=True)

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertForMaskedLM: ['cls.predictions.decoder.bias']
- This IS expected if you are initializing TFBertForMaskedLM from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertForMaskedLM from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFBertForMaskedLM were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertForMaskedLM for predictions without further training.


In [15]:
type(tf_model)

transformers.models.bert.modeling_tf_bert.TFBertForMaskedLM

In [16]:
tf_model.save_pretrained(MODEL_EXPORT_PATH)

## Convert to flax_model.msgpack (Flax)

In [17]:
flax_model = FlaxBertForMaskedLM.from_pretrained(MODEL_PATH, from_pt=True)



In [18]:
type(flax_model)

transformers.models.bert.modeling_flax_bert.FlaxBertForMaskedLM

In [19]:
flax_model.save_pretrained(MODEL_EXPORT_PATH)

In [20]:
!ls {MODEL_EXPORT_PATH}

config.json	    README.md		     tokenizer_config.json
flax_model.msgpack  special_tokens_map.json  vocab.txt
pytorch_model.bin   tf_model.h5
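
Instead of reading the listing by eye, the export directory can be checked programmatically. The sketch below uses only the standard library; the file names come from the listing above, while `EXPECTED_FILES`, `missing_export_files`, and the temporary demo directory are illustrative assumptions.

```python
import tempfile
from pathlib import Path

# File names taken from the export directory listing (README.md is optional, so it is left out).
EXPECTED_FILES = (
    "config.json",
    "pytorch_model.bin",
    "tf_model.h5",
    "flax_model.msgpack",
    "tokenizer_config.json",
    "special_tokens_map.json",
    "vocab.txt",
)


def missing_export_files(export_dir, expected=EXPECTED_FILES):
    """Return the expected file names that are not present in export_dir."""
    export_dir = Path(export_dir)
    return [name for name in expected if not (export_dir / name).is_file()]


# Hypothetical demo directory that lacks the Flax weights.
with tempfile.TemporaryDirectory() as tmp_dir:
    for name in EXPECTED_FILES:
        if name != "flax_model.msgpack":
            (Path(tmp_dir) / name).touch()
    missing = missing_export_files(tmp_dir)

print(missing)  # ['flax_model.msgpack']
```

In the notebook, `missing_export_files(MODEL_EXPORT_PATH)` should return an empty list once all three conversions have run.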


## Confirm that the exported model loads

In [21]:
model = AutoModelForMaskedLM.from_pretrained(MODEL_EXPORT_PATH)

In [22]:
model

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(32768, 256, padding_idx=0)
      (position_embeddings): Embedding(2048, 256)
      (token_type_embeddings): Embedding(2, 256)
      (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-5): 6 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=256, out_features=256, bias=True)
              (key): Linear(in_features=256, out_features=256, bias=True)
              (value): Linear(in_features=256, out_features=256, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=256, out_features=256, bias=True)
              (LayerNorm): LayerNorm((256,), eps=1e-12, elementwis