In [None]:
import os
# Workaround for https://github.com/tensorflow/hub/issues/903: opt in to the
# legacy Keras implementation. This is set here, ahead of the TensorFlow
# imports that happen in a later cell.
os.environ["TF_USE_LEGACY_KERAS"] = "1"

In [None]:
from google.colab import drive
# Mount Google Drive so the model and CSV files can be uploaded/downloaded.
drive.mount("/content/gdrive")

In [None]:
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text

# TF Hub handle of the current preprocessing module. Note that the original
# model was built with bert_en_uncased_preprocess/1.
tfhub_handle_preprocess = 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'
# TF Hub handle of the encoder; matches the one used for the original model.
tfhub_handle_encoder = 'https://tfhub.dev/google/experts/bert/pubmed/2'

# rebuild blank BiomchBERT
def build_classifier_model(num_classes=27, dropout_rate=0.1):
    """Build a freshly initialised BERT text classifier.

    The model takes a batch of raw strings, runs them through the TF Hub
    preprocessing layer (`tfhub_handle_preprocess`) and the trainable TF Hub
    encoder (`tfhub_handle_encoder`), then applies dropout and a softmax
    dense layer to the encoder's `pooled_output`.

    Args:
        num_classes: Number of units in the final softmax layer. Defaults to
            27, the value this notebook uses.
        dropout_rate: Dropout rate applied before the classifier layer.
            Defaults to 0.1, the value this notebook uses.

    Returns:
        A `tf.keras.Model` mapping the string input named 'text' to the
        softmax output of the layer named 'classifier'.
    """
    text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
    preprocessing_layer = hub.KerasLayer(tfhub_handle_preprocess, name='preprocessing')
    encoder_inputs = preprocessing_layer(text_input)
    encoder = hub.KerasLayer(tfhub_handle_encoder, trainable=True, name='BERT_encoder')
    outputs = encoder(encoder_inputs)
    net = outputs['pooled_output']
    net = tf.keras.layers.Dropout(dropout_rate)(net)
    # The layer name 'classifier' matters: the weight-loading cell looks up
    # "classifier/kernel" and "classifier/bias" by name.
    net = tf.keras.layers.Dense(num_classes, activation='softmax', name='classifier')(net)
    return tf.keras.Model(text_input, net)

# Instantiate the blank model; later cells save it and assign the old checkpoint's weights to it.
model = build_classifier_model()


In [None]:
# Temporary export of the blank model, to obtain its variables and a blank .pb file.
tf.saved_model.save(obj=model, export_dir='/content/gdrive/My Drive/test')

In [None]:
# Print each variable name and shape stored in the blank model's checkpoint.
vars_in_ckpt = tf.train.list_variables("/content/gdrive/My Drive/test/variables/variables")
for ckpt_entry in vars_in_ckpt:
    # Each entry is unpacked as (name, shape), so this prints "name shape".
    print(*ckpt_entry)


In [None]:
# Print each variable name and shape stored in the original model's checkpoint.
vars_in_ckpt = tf.train.list_variables("/content/gdrive/My Drive/BiomchBERT/Data/BiomchBERT/variables/variables")
for ckpt_entry in vars_in_ckpt:
    # Each entry is unpacked as (name, shape), so this prints "name shape".
    print(*ckpt_entry)

In [None]:
# Print the name and shape of every trainable variable on the in-memory blank model.
for trainable_var in model.trainable_variables:
    print(trainable_var.name, trainable_var.shape)


In [None]:
# need to rename old variable names to match new model weight names in order to load in old weights
import re

# fix old names to reflect weight names in new model
def normalize_old_name(old_name):
    """Rewrite a key from the old checkpoint into the new model's weight-name form.

    Steps, applied in order:
      1. drop a leading "layer_with_weights-<number>/" prefix,
      2. turn every ".S" (serialized slash) back into "/",
      3. drop any ":<digits>" that sits right before a "/" or at the end,
      4. drop the checkpoint-specific ".ATTRIBUTES/VARIABLE_VALUE" marker,
      5. strip trailing slashes.

    Args:
        old_name: Variable key as stored in the old checkpoint.

    Returns:
        The normalized name as a string.
    """
    without_prefix = re.sub(r"^layer_with_weights-\d+/", "", old_name)
    unescaped = without_prefix.replace(".S", "/")
    without_index = re.sub(r":\d+(?=/|$)", "", unescaped)
    return without_index.replace(".ATTRIBUTES/VARIABLE_VALUE", "").rstrip("/")

# Load the old checkpoint and collect its variable keys.
ckpt_path = '/content/gdrive/My Drive/BiomchBERT/Data/BiomchBERT/variables/variables'
old_ckpt_reader = tf.train.load_checkpoint(ckpt_path)
old_var_names = old_ckpt_reader.get_variable_to_shape_map().keys()

# Index the old keys by their normalized name. If two keys normalize to the
# same name the later one still wins, but the collision is reported
# instead of being silently dropped.
old_vars_normalized = {}
for k in old_var_names:
    normalized = normalize_old_name(k)
    if normalized in old_vars_normalized:
        print(f"⚠️ Ambiguous normalized name {normalized!r}: "
              f"{old_vars_normalized[normalized]} vs {k}")
    old_vars_normalized[normalized] = k

# Known special cases: new-model weight name -> normalized old name.
special_cases = {
    "classifier/kernel": "kernel",
    "classifier/bias": "bias",
}

# Build mapping dictionary (new weight name -> old checkpoint key).
# A special-case alias is preferred when it exists in the checkpoint;
# otherwise a direct name match is kept rather than being overwritten with None.
mapping = {}
for new_w in model.weights:
    new_name = new_w.name.replace(":0", "")  # remove TF suffix
    alias = special_cases.get(new_name)
    if alias is not None and alias in old_vars_normalized:
        mapping[new_name] = old_vars_normalized[alias]
    elif new_name in old_vars_normalized:
        mapping[new_name] = old_vars_normalized[new_name]
    else:
        print(f"❌ No match for: {new_name}")

# Check mapping
for new_name, old_name in mapping.items():
    print(f"{new_name}  <--  {old_name}")

# Assign the old tensors to the new model, keeping track of anything skipped
# so the final message reflects what actually happened.
loaded_count = 0
skipped_names = []
for w in model.weights:
    new_name = w.name.replace(":0", "")
    old_key = mapping.get(new_name)
    if old_key is None:
        skipped_names.append(new_name)
        continue
    w.assign(old_ckpt_reader.get_tensor(old_key))
    loaded_count += 1

if skipped_names:
    print(f"⚠️ Loaded {loaded_count} of {len(model.weights)} weights; "
          f"{len(skipped_names)} kept their initial values: {skipped_names}")
else:
    print("✅ Weights loaded from old checkpoint into new model")

In [None]:
# Export the model next to the original BiomchBERT directory.
tf.saved_model.save(obj=model, export_dir="/content/gdrive/My Drive/BiomchBERT/Data/BiomchBERT_V3/")