In [1]:
import jax
from transformers import BartConfig, FlaxAutoModelForSeq2SeqLM

In [2]:
jax.device_count()

8

In [3]:
!ls

FlaxAutoModelForSeq2SeqLM_from_fnlp_config  fnlp-bart-large-chinese-pytorch
Untitled1.ipynb				    fnlp_bart.config
Untitled2.ipynb				    fnlp_jax_version
create_JAX_Chinese_BART.ipynb		    test-serialize.msgpack
flax_facebook_bart			    vqgan-jax


In [4]:
# Use the config.json of fnlp/bart-large-chinese (the local file
# "fnlp_bart.config") to create a Flax model with the same structure.
bart_config = BartConfig.from_pretrained("fnlp_bart.config")


In [5]:
# The model structure is FlaxAutoModelForSeq2SeqLM, built from the config only.
# Its parameters are overwritten with the PyTorch weights in later cells.
flax_model = FlaxAutoModelForSeq2SeqLM.from_config(bart_config)

In [6]:
flax_model.params.keys()

dict_keys(['final_logits_bias', 'model'])

In [7]:
type(flax_model.params['final_logits_bias'])

jaxlib.xla_extension.DeviceArray

In [8]:
flax_model.params['model']['shared']['embedding'].shape

(21128, 1024)

In [9]:
import flax

In [10]:
# Download the PyTorch version of fnlp/bart-large-chinese
from transformers import BertTokenizer, BartForConditionalGeneration, Text2TextGenerationPipeline

In [11]:
# BertTokenizer is requested although the checkpoint declares BartTokenizer,
# which is what triggers the tokenizer-class warning printed below.
# NOTE(review): presumably intentional because the checkpoint ships a
# BERT-style vocab.txt — confirm against the model card.
tokenizer = BertTokenizer.from_pretrained("fnlp/bart-large-chinese")
pytorch_model = BartForConditionalGeneration.from_pretrained("fnlp/bart-large-chinese")

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BartTokenizer'. 
The class this function is called from is 'BertTokenizer'.


In [12]:
# Sanity-check the PyTorch checkpoint: wrap model + tokenizer in a text2text
# pipeline and let it fill in the [MASK] token without sampling.
text2text_generator = Text2TextGenerationPipeline(pytorch_model, tokenizer)  
text2text_generator("北京是[MASK]的首都", max_length=50, do_sample=False)


[{'generated_text': '北 京 是 中 华 人 民 共 和 国 的 首 都'}]

In [13]:
# pytorch_model.save_pretrained("fnlp-bart-large-chinese-pytorch")

In [14]:
!ls fnlp-bart-large-chinese-pytorch

config.json	   special_tokens_map.json  vocab.txt
pytorch_model.bin  tokenizer_config.json


In [15]:
# tokenizer.save_pretrained("fnlp-bart-large-chinese-pytorch")

In [16]:
!ls fnlp-bart-large-chinese-pytorch/

config.json	   special_tokens_map.json  vocab.txt
pytorch_model.bin  tokenizer_config.json


In [17]:
# Overwrite the Flax parameters with the PyTorch parameter weights

In [18]:
# Snapshot every PyTorch parameter as a NumPy array, keyed by its dotted name
# (e.g. 'model.shared.weight'), so the values can be handed to jnp.asarray.
pytorch_params = {
    param_name: tensor.detach().numpy()
    for param_name, tensor in pytorch_model.named_parameters()
}

In [19]:
type(pytorch_model)

transformers.models.bart.modeling_bart.BartForConditionalGeneration

In [20]:
flax_model.params_shape_tree

FrozenDict({
    final_logits_bias: ShapeDtypeStruct(shape=(1, 21128), dtype=float32),
    model: {
        decoder: {
            embed_positions: {
                embedding: ShapeDtypeStruct(shape=(514, 1024), dtype=float32),
            },
            layernorm_embedding: {
                bias: ShapeDtypeStruct(shape=(1024,), dtype=float32),
                scale: ShapeDtypeStruct(shape=(1024,), dtype=float32),
            },
            layers: {
                0: {
                    encoder_attn: {
                        k_proj: {
                            bias: ShapeDtypeStruct(shape=(1024,), dtype=float32),
                            kernel: ShapeDtypeStruct(shape=(1024, 1024), dtype=float32),
                        },
                        out_proj: {
                            bias: ShapeDtypeStruct(shape=(1024,), dtype=float32),
                            kernel: ShapeDtypeStruct(shape=(1024, 1024), dtype=float32),
                        },
                    

In [21]:
pytorch_model.lm_head

Linear(in_features=1024, out_features=21128, bias=False)

In [22]:
import numpy as np

In [23]:
from jax import numpy as jnp

In [24]:
flax_model.params['model']['shared']['embedding'].shape

(21128, 1024)

In [25]:
jnp.asarray(pytorch_params['model.shared.weight']).shape

(21128, 1024)

In [26]:
# Copy the token embedding table and the encoder position embeddings as-is.
# The two shape checks in the preceding cells show the shared embedding is
# (21128, 1024) on both sides, so no transpose is applied.
shared_embedding = jnp.asarray(pytorch_params['model.shared.weight'])
encoder_positions = jnp.asarray(pytorch_params['model.encoder.embed_positions.weight'])

flax_model.params['model']['shared']['embedding'] = shared_embedding
flax_model.params['model']['encoder']['embed_positions']['embedding'] = encoder_positions


In [27]:
# Encoder embedding LayerNorm: the PyTorch "weight" goes into the Flax "scale"
# leaf, the bias keeps its name.
encoder_embed_norm = flax_model.params['model']['encoder']['layernorm_embedding']
encoder_embed_norm['scale'] = jnp.asarray(
    pytorch_params['model.encoder.layernorm_embedding.weight']
)
encoder_embed_norm['bias'] = jnp.asarray(
    pytorch_params['model.encoder.layernorm_embedding.bias']
)

In [28]:
# Print each PyTorch parameter name together with its dtype; the names are the
# keys used to look values up in pytorch_params.
for param_name, param in pytorch_model.named_parameters():
    print(param_name, param.dtype)

model.shared.weight torch.float32
model.encoder.embed_positions.weight torch.float32
model.encoder.layers.0.self_attn.k_proj.weight torch.float32
model.encoder.layers.0.self_attn.k_proj.bias torch.float32
model.encoder.layers.0.self_attn.v_proj.weight torch.float32
model.encoder.layers.0.self_attn.v_proj.bias torch.float32
model.encoder.layers.0.self_attn.q_proj.weight torch.float32
model.encoder.layers.0.self_attn.q_proj.bias torch.float32
model.encoder.layers.0.self_attn.out_proj.weight torch.float32
model.encoder.layers.0.self_attn.out_proj.bias torch.float32
model.encoder.layers.0.self_attn_layer_norm.weight torch.float32
model.encoder.layers.0.self_attn_layer_norm.bias torch.float32
model.encoder.layers.0.fc1.weight torch.float32
model.encoder.layers.0.fc1.bias torch.float32
model.encoder.layers.0.fc2.weight torch.float32
model.encoder.layers.0.fc2.bias torch.float32
model.encoder.layers.0.final_layer_norm.weight torch.float32
model.encoder.layers.0.final_layer_norm.bias torch.flo

In [29]:
def pytorch_encoder_param_to_flax(layer_num, transpose=False):
    """Copy one encoder layer's weights from ``pytorch_params`` into ``flax_model``.

    Reads the module-level ``pytorch_params`` mapping (dotted PyTorch parameter
    name -> NumPy array) and overwrites, in place, the matching leaves of
    ``flax_model.params['model']['encoder']['layers'][layer_num]``.

    Args:
        layer_num: Index of the encoder layer, from '0' to '11' in this
            notebook. An ``int`` is accepted as well and converted with ``str``.
        transpose: Accepted for backward compatibility but not consulted.
            Projection and feed-forward kernels are always transposed, because
            PyTorch ``nn.Linear`` stores ``weight`` as
            (out_features, in_features) whereas the Flax dense kernel is
            (in_features, out_features). Biases and layer-norm parameters are
            copied unchanged.

    Raises:
        KeyError: If a parameter is missing from ``pytorch_params`` or from the
            Flax parameter tree.
        ValueError: If a converted array does not have the shape of the Flax
            parameter it is about to replace.
    """
    layer_key = str(layer_num)
    flax_layer = flax_model.params['model']['encoder']['layers'][layer_key]
    torch_prefix = 'model.encoder.layers.{}.'.format(layer_key)

    # Sub-modules that hold a dense kernel + bias, as paths inside the layer.
    dense_paths = (
        ('self_attn', 'k_proj'),
        ('self_attn', 'v_proj'),
        ('self_attn', 'q_proj'),
        ('self_attn', 'out_proj'),
        ('fc1',),
        ('fc2',),
    )
    # Layer norms: the PyTorch "weight" is stored under the Flax "scale" leaf.
    norm_paths = (
        ('self_attn_layer_norm',),
        ('final_layer_norm',),
    )

    def copy_param(path, torch_leaf, flax_leaf, transposed=False):
        """Copy a single array after checking it fits the destination leaf."""
        torch_name = torch_prefix + '.'.join(path) + '.' + torch_leaf
        value = pytorch_params[torch_name]
        if transposed:
            value = value.T
        value = jnp.asarray(value)

        module = flax_layer
        for part in path:
            module = module[part]

        # Plain dict assignment would accept any shape; fail loudly instead of
        # silently storing a mis-shaped weight.
        expected_shape = tuple(module[flax_leaf].shape)
        if tuple(value.shape) != expected_shape:
            raise ValueError(
                'shape mismatch for {}: got {}, expected {}'.format(
                    torch_name, tuple(value.shape), expected_shape
                )
            )
        module[flax_leaf] = value

    for path in dense_paths:
        copy_param(path, 'weight', 'kernel', transposed=True)
        copy_param(path, 'bias', 'bias')

    for path in norm_paths:
        copy_param(path, 'weight', 'scale')
        copy_param(path, 'bias', 'bias')
        

In [30]:
# Convert every encoder layer. The layer keys are taken from the Flax
# parameter tree itself rather than a hard-coded range(12), so the cell keeps
# working for a config with a different number of encoder layers.
# transpose=True is passed for compatibility; the helper transposes the dense
# kernels regardless of the flag.
encoder_layer_keys = sorted(flax_model.params['model']['encoder']['layers'], key=int)
for layer_key in encoder_layer_keys:
    pytorch_encoder_param_to_flax(layer_key, transpose=True)

In [31]:
# Decoder position embeddings are copied without a transpose.
decoder_positions = jnp.asarray(pytorch_params['model.decoder.embed_positions.weight'])
flax_model.params['model']['decoder']['embed_positions']['embedding'] = decoder_positions

In [32]:
def pytorch_to_flax_decoder(layer_num, transpose=False):
    """Copy one decoder layer's weights from ``pytorch_params`` into ``flax_model``.

    Reads the module-level ``pytorch_params`` mapping (dotted PyTorch parameter
    name -> NumPy array) and overwrites, in place, the matching leaves of
    ``flax_model.params['model']['decoder']['layers'][layer_num]``: the
    self-attention and encoder-attention projections, the two feed-forward
    layers and the three layer norms.

    Args:
        layer_num: Index of the decoder layer, from '0' to '11' in this
            notebook. An ``int`` is accepted as well and converted with ``str``.
        transpose: Accepted for backward compatibility but not consulted.
            Projection and feed-forward kernels are always transposed, because
            PyTorch ``nn.Linear`` stores ``weight`` as
            (out_features, in_features) whereas the Flax dense kernel is
            (in_features, out_features). Biases and layer-norm parameters are
            copied unchanged.

    Raises:
        KeyError: If a parameter is missing from ``pytorch_params`` or from the
            Flax parameter tree.
        ValueError: If a converted array does not have the shape of the Flax
            parameter it is about to replace.
    """
    layer_key = str(layer_num)
    flax_layer = flax_model.params['model']['decoder']['layers'][layer_key]
    torch_prefix = 'model.decoder.layers.{}.'.format(layer_key)

    # Sub-modules that hold a dense kernel + bias, as paths inside the layer.
    dense_paths = (
        ('self_attn', 'k_proj'),
        ('self_attn', 'v_proj'),
        ('self_attn', 'q_proj'),
        ('self_attn', 'out_proj'),
        ('encoder_attn', 'k_proj'),
        ('encoder_attn', 'v_proj'),
        ('encoder_attn', 'q_proj'),
        ('encoder_attn', 'out_proj'),
        ('fc1',),
        ('fc2',),
    )
    # Layer norms: the PyTorch "weight" is stored under the Flax "scale" leaf.
    norm_paths = (
        ('self_attn_layer_norm',),
        ('encoder_attn_layer_norm',),
        ('final_layer_norm',),
    )

    def copy_param(path, torch_leaf, flax_leaf, transposed=False):
        """Copy a single array after checking it fits the destination leaf."""
        torch_name = torch_prefix + '.'.join(path) + '.' + torch_leaf
        value = pytorch_params[torch_name]
        if transposed:
            value = value.T
        value = jnp.asarray(value)

        module = flax_layer
        for part in path:
            module = module[part]

        # Plain dict assignment would accept any shape; fail loudly instead of
        # silently storing a mis-shaped weight.
        expected_shape = tuple(module[flax_leaf].shape)
        if tuple(value.shape) != expected_shape:
            raise ValueError(
                'shape mismatch for {}: got {}, expected {}'.format(
                    torch_name, tuple(value.shape), expected_shape
                )
            )
        module[flax_leaf] = value

    for path in dense_paths:
        copy_param(path, 'weight', 'kernel', transposed=True)
        copy_param(path, 'bias', 'bias')

    for path in norm_paths:
        copy_param(path, 'weight', 'scale')
        copy_param(path, 'bias', 'bias')


In [33]:
# Convert every decoder layer. The layer keys are taken from the Flax
# parameter tree itself rather than a hard-coded range(12), so the cell keeps
# working for a config with a different number of decoder layers.
# transpose=True is passed for compatibility; the helper transposes the dense
# kernels regardless of the flag.
decoder_layer_keys = sorted(flax_model.params['model']['decoder']['layers'], key=int)
for layer_key in decoder_layer_keys:
    pytorch_to_flax_decoder(layer_key, transpose=True)

In [34]:
# Decoder embedding LayerNorm: the PyTorch "weight" goes into the Flax "scale"
# leaf, the bias keeps its name.
decoder_embed_norm = flax_model.params['model']['decoder']['layernorm_embedding']
decoder_embed_norm['scale'] = jnp.asarray(
    pytorch_params['model.decoder.layernorm_embedding.weight']
)
decoder_embed_norm['bias'] = jnp.asarray(
    pytorch_params['model.decoder.layernorm_embedding.bias']
)

In [35]:
# Tokenize the same prompt twice with identical settings: NumPy arrays for the
# Flax model and torch tensors for the PyTorch model.
prompt = "北京是[MASK]的首都"
tokenize_kwargs = dict(max_length=50, truncation=True)
flax_inputs = tokenizer(prompt, return_tensors="np", **tokenize_kwargs)
pytorch_inputs = tokenizer(prompt, return_tensors="pt", **tokenize_kwargs)

In [36]:
type(flax_inputs), type(pytorch_inputs)

(transformers.tokenization_utils_base.BatchEncoding,
 transformers.tokenization_utils_base.BatchEncoding)

In [37]:
pytorch_inputs

{'input_ids': tensor([[ 101, 1266,  776, 3221,  103, 4638, 7674, 6963,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [38]:
flax_inputs

{'input_ids': array([[ 101, 1266,  776, 3221,  103, 4638, 7674, 6963,  102]]), 'token_type_ids': array([[0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': array([[1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [39]:
# Drop token_type_ids before the encoding is unpacked into the model call
# (presumably the BART forward pass does not accept it — confirm).
# The None default keeps the cell re-runnable: without it a second execution
# raises KeyError because the key is already gone.
flax_inputs.pop("token_type_ids", None)


array([[0, 0, 0, 0, 0, 0, 0, 0, 0]])

In [40]:
# Drop token_type_ids before the encoding is unpacked into the model call
# (presumably the BART forward pass does not accept it — confirm).
# The None default keeps the cell re-runnable: without it a second execution
# raises KeyError because the key is already gone.
pytorch_inputs.pop("token_type_ids", None)

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0]])

In [41]:
# The encoder has 12 layers, so why are there 13 tensors here? The first one
# is the embedding, i.e. the input to the encoder; the original note adds that
# this first tensor "is the same" (presumably between PyTorch and Flax).
len(pytorch_model(**pytorch_inputs, output_hidden_states=True).encoder_hidden_states)

13

In [42]:
pytorch_model(**pytorch_inputs, output_hidden_states=True).encoder_hidden_states[-1].detach().numpy()

array([[[-0.00570121, -0.04509331, -0.02713029, ...,  0.00194991,
         -0.00537348,  0.03764018],
        [-0.00569891, -0.04485099, -0.02726786, ...,  0.00204078,
         -0.00477894,  0.03749378],
        [ 0.19315398, -0.70756376,  0.19373162, ...,  0.3463536 ,
         -0.4556533 , -0.2645676 ],
        ...,
        [ 0.30669293, -0.2734287 , -0.24406989, ...,  0.67741376,
          0.3688047 ,  0.02863172],
        [ 0.18938838,  0.06320691, -0.02779032, ...,  0.5631877 ,
          0.06248561,  0.255242  ],
        [ 0.18506819, -0.29937473, -0.2105657 , ..., -0.0947461 ,
          0.3477654 ,  0.3963469 ]]], dtype=float32)

In [43]:
flax_model(**flax_inputs, output_hidden_states=True).encoder_hidden_states[-1]

DeviceArray([[[-0.00565069, -0.04506724, -0.02707809, ...,  0.00201641,
               -0.00540399,  0.03763571],
              [-0.00570531, -0.04482179, -0.02721935, ...,  0.00211693,
               -0.00483827,  0.03747253],
              [ 0.19419202, -0.7112492 ,  0.19369853, ...,  0.34371915,
               -0.4576627 , -0.2657903 ],
              ...,
              [ 0.30658963, -0.27588403, -0.24333625, ...,  0.67852193,
                0.36913538,  0.02680317],
              [ 0.18850312,  0.06072142, -0.02829674, ...,  0.5640527 ,
                0.06140786,  0.25440294],
              [ 0.18477756, -0.30004922, -0.20897329, ..., -0.09487949,
                0.3451631 ,  0.39601344]]], dtype=float32)

In [44]:
pytorch_model(**pytorch_inputs, output_hidden_states=True)

Seq2SeqLMOutput(loss=None, logits=tensor([[[-3.4801, -3.2454, -3.7454,  ..., -3.9595, -3.5990, -2.3568],
         [-7.4768, -7.1581, -6.7530,  ..., -6.5231, -6.8271, -6.5757],
         [-6.8346, -6.9497, -6.6919,  ..., -6.6677, -6.6467, -6.6120],
         ...,
         [-7.5542, -7.1186, -7.3049,  ..., -7.2940, -7.2962, -6.3193],
         [-8.0730, -6.9977, -6.8776,  ..., -8.2508, -7.0762, -6.7484],
         [-6.2054, -6.2399, -6.0311,  ..., -5.7494, -6.4311, -5.4911]]],
       grad_fn=<AddBackward0>), past_key_values=((tensor([[[[ 2.6798e-01, -1.1209e+00, -3.2103e-01,  ...,  1.8051e+00,
           -1.7247e+00,  5.5357e-01],
          [ 5.6914e-01, -8.4339e-01, -4.0476e-01,  ...,  1.6013e+00,
           -1.5293e+00,  5.0177e-01],
          [ 4.8445e-01,  9.3623e-01,  2.1530e+00,  ..., -4.6448e+00,
            2.1018e+00,  1.3145e+00],
          ...,
          [ 3.7149e-01,  3.9733e-01, -3.6626e-01,  ..., -6.3848e+00,
            1.5529e+00,  5.8073e-04],
          [-8.5041e-01,  2.5869

In [45]:
flax_model(**flax_inputs, output_hidden_states=True)

FlaxSeq2SeqLMOutput(logits=DeviceArray([[[-3.5004735, -3.2618244, -3.7565782, ..., -3.9783325,
               -3.6071012, -2.3760293],
              [-7.4865804, -7.159775 , -6.7540746, ..., -6.5293894,
               -6.8198233, -6.578088 ],
              [-6.832736 , -6.937727 , -6.67869  , ..., -6.662727 ,
               -6.6321497, -6.6068544],
              ...,
              [-7.5577974, -7.112214 , -7.2961287, ..., -7.2895164,
               -7.2791457, -6.317009 ],
              [-8.066658 , -6.989978 , -6.8684077, ..., -8.239759 ,
               -7.054138 , -6.739692 ],
              [-6.215106 , -6.2383475, -6.0314503, ..., -5.751667 ,
               -6.4203243, -5.491976 ]]], dtype=float32), past_key_values=None, decoder_hidden_states=(DeviceArray([[[ 0.00379658, -0.20896576,  0.05788542, ..., -0.2813811 ,
                0.15738876, -0.08956083],
              [ 0.01348088, -0.12935533,  0.0930235 , ..., -0.13094474,
                0.2236439 ,  0.04246446],
              [

In [46]:
# Save the Flax model parameters to "fnlp_jax_version".
# NOTE(review): the original comment said the parameters are randomly
# initialised; at this point in the notebook they have been overwritten with
# the PyTorch weights by the cells above, so that remark looks stale.
flax_model.save_pretrained("fnlp_jax_version")

tcmalloc: large alloc 1222172672 bytes == 0x1e3700000 @  0x7fcd8bd5d680 0x7fcd8bd7dbdd 0x7fcd7eae526f 0x7fcd7eaf4290 0x7fcd7eaf5324 0x7fcd7eaf5324 0x7fcd7eaf5324 0x7fcd7eaf5324 0x7fcd7eaf5324 0x7fcd7eaf5324 0x7fcd7eaf5324 0x7fcd7eaefd74 0x7fcd7eaf052e 0x503fb6 0x56b1da 0x56939a 0x5f6a13 0x56c28c 0x56939a 0x5f6a13 0x56c28c 0x5f6836 0x56b0ae 0x56939a 0x5f6a13 0x56b1da 0x56939a 0x68d047 0x6003a4 0x5c4a40 0x56b0ae
tcmalloc: large alloc 2460884992 bytes == 0x22c48e000 @  0x7fcd8bd5d680 0x7fcd8bd7dbdd 0x7fcd7eae526f 0x7fcd7eaf4290 0x7fcd7eaf5324 0x7fcd7eaf5324 0x7fcd7eaf5324 0x7fcd7eaf5324 0x7fcd7eaf5324 0x7fcd7eaf5324 0x7fcd7eaefd74 0x7fcd7eaf052e 0x503fb6 0x56b1da 0x56939a 0x5f6a13 0x56c28c 0x56939a 0x5f6a13 0x56c28c 0x5f6836 0x56b0ae 0x56939a 0x5f6a13 0x56b1da 0x56939a 0x68d047 0x6003a4 0x5c4a40 0x56b0ae 0x5002d8
tcmalloc: large alloc 1501732864 bytes == 0x19f4f4000 @  0x7fcd8bd5d680 0x7fcd8bd7e824 0x5f8a01 0x7fcd7eaefe19 0x7fcd7eaf052e 0x503fb6 0x56b1da 0x56939a 0x5f6a13 0x56c28c 0x56939

In [47]:
# Write the tokenizer files into the same "fnlp_jax_version" directory that
# receives the Flax weights.
tokenizer.save_pretrained("fnlp_jax_version")

('fnlp_jax_version/tokenizer_config.json',
 'fnlp_jax_version/special_tokens_map.json',
 'fnlp_jax_version/vocab.txt',
 'fnlp_jax_version/added_tokens.json')