In [None]:
import os
import re

import numpy as np
import tensorflow as tf
import torch
from transformers import BertConfig, BertForTokenClassification

In [None]:
# Checkpoint prefix handed to tf.train.list_variables / tf.train.load_variable.
tf_path = "biobert_ner/pretrainedBERT/species/model.ckpt-90000"

In [None]:
# (name, shape) pairs for every variable stored in the TF checkpoint.
init_vars = tf.train.list_variables(ckpt_dir_or_file=tf_path)

In [None]:
# Drop checkpoint entries whose name contains any of these substrings
# (optimizer slots / step counter that are not model weights).
excluded = ['BERTAdam', '_power', 'global_step']
init_vars = [
    var_entry
    for var_entry in init_vars
    if not any(token in var_entry[0] for token in excluded)
]

In [None]:

# Read each remaining checkpoint variable into memory; names[i] and arrays[i]
# always describe the same TF variable.
names, arrays = [], []
for var_name, var_shape in init_vars:
    print("Loading TF weight {} with shape {}".format(var_name, var_shape))
    var_value = tf.train.load_variable(tf_path, var_name)
    names.append(var_name)
    arrays.append(var_value)

In [None]:
# JSON file with the BERT hyper-parameters used to build the PyTorch model.
config_path = "biobert_ner/conf/bert_config.json"


In [None]:
config = BertConfig.from_json_file(config_path)
# Size the token-classification head from the checkpoint itself. The TF head is
# stored as a top-level `output_bias` vector (it is loaded into
# model.classifier.bias during conversion). If bert_config.json does not carry
# num_labels, transformers falls back to 2 labels, and the conversion's shape
# check would then fail on the classifier.
ckpt_num_labels = next(
    (arr.shape[0] for var_name, arr in zip(names, arrays) if var_name == 'output_bias'),
    None,
)
if ckpt_num_labels is not None and ckpt_num_labels != config.num_labels:
    config.num_labels = int(ckpt_num_labels)
model = BertForTokenClassification(config)

In [None]:
# Show the PyTorch module tree. The conversion cell resolves TF scope names
# against these attribute names with getattr, so this is the reference when a
# TF variable is reported as skipped or mismatched.
model

In [None]:
# Scope components that belong only to the TF optimizer state
# (AdamWeightDecayOptimizer m/v slots and the step counter); not needed for inference.
OPTIMIZER_SCOPES = {
    "adam_v",
    "adam_m",
    "AdamWeightDecayOptimizer",
    "AdamWeightDecayOptimizer_1",
    "global_step",
}


def _resolve_tf_variable(model, name):
    """Return the ``model`` tensor that a TF variable path maps onto, or None to skip it.

    Args:
        model: the PyTorch module receiving the weights.
        name: the TF variable name split on "/".

    Each scope component is followed as an attribute of the current module;
    "<name>_<N>" components index into a container (e.g. "layer_3" -> ``.layer[3]``).
    TF leaf names are renamed to PyTorch ones:
    kernel/gamma/output_weights -> weight, beta/output_bias -> bias, squad -> classifier.
    The checkpoint's top-level ``output_weights``/``output_bias`` map onto
    ``model.classifier``. If any component has no counterpart, the WHOLE
    variable is skipped (None). The original code skipped only that one
    component, so it kept walking with the wrong module.
    """
    if name == ["output_bias"]:
        return model.classifier.bias
    if name == ["output_weights"]:
        return model.classifier.weight

    pointer = model
    for m_name in name:
        if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
            scope_names = re.split(r"_(\d+)", m_name)
        else:
            scope_names = [m_name]
        scope = scope_names[0]
        if scope in ("kernel", "gamma", "output_weights"):
            pointer = getattr(pointer, "weight")
        elif scope in ("beta", "output_bias"):
            pointer = getattr(pointer, "bias")
        elif scope == "squad":
            pointer = getattr(pointer, "classifier")
        else:
            try:
                pointer = getattr(pointer, scope)
            except AttributeError:
                return None
        if len(scope_names) >= 2:
            pointer = pointer[int(scope_names[1])]

    # Embedding modules hold their matrix in `.weight`.
    if name[-1].endswith("_embeddings"):
        pointer = getattr(pointer, "weight")
    return pointer


for tf_name, array in zip(names, arrays):
    name = tf_name.split("/")
    if any(n in OPTIMIZER_SCOPES for n in name):
        continue
    pointer = _resolve_tf_variable(model, name)
    if pointer is None:
        print("Skipping {}".format(tf_name))
        continue
    if name[-1] == "kernel":
        # TF dense kernels are (in, out); torch.nn.Linear.weight is (out, in).
        array = np.transpose(array)
    # Explicit check (not `assert`) so it still runs under `python -O`.
    if pointer.shape != array.shape:
        raise ValueError(
            f"Pointer shape {pointer.shape} and array shape {array.shape} "
            f"mismatched for {tf_name}"
        )
    print("Initialize PyTorch weight {}".format(name))
    pointer.data = torch.from_numpy(array)

In [None]:
# Display the converted classification-head weight.
model.classifier.weight

In [None]:
# Display the converted classification-head bias.
model.classifier.bias

In [None]:
# from_pretrained(<dir>) needs config.json next to the weights, and torch.save
# fails if the target directory does not exist yet, so create the directory
# and write both files.
output_dir = 'biobert_ner/pytorch_dumps/species'
os.makedirs(output_dir, exist_ok=True)
print("Save PyTorch model to {}".format(output_dir))
torch.save(model.state_dict(), os.path.join(output_dir, 'pytorch_model.bin'))
model.config.save_pretrained(output_dir)

In [None]:
# Reload the saved dump through the standard transformers loading API.
dump_dir = 'biobert_ner/pytorch_dumps/species'
new_model = BertForTokenClassification.from_pretrained(dump_dir)

In [None]:
# Round-trip check: the reloaded classification head must be identical to the
# converted one (previously this was only compared by eye against model's output).
assert torch.equal(new_model.classifier.weight, model.classifier.weight)
assert torch.equal(new_model.classifier.bias, model.classifier.bias)
new_model.classifier.bias