# Borzoi weight conversion from TF to PyTorch

In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import h5py
import numpy as np
import tensorflow as tf
tf.config.experimental.enable_tensor_float_32_execution(False)
import baskerville
from baskerville import seqnn
import json
import torch

In [2]:
def transform_the_transformer(layers):
    """Rename the transformer-tower weights to PyTorch state_dict keys.

    ``layers`` is walked in insertion order. The current layer name, together
    with the name of the previously visited layer, decides which slot of the
    current transformer block the weights belong to:

    * layer norm not directly after attention -> ``transformer.<i>.0.fn.0``
    * multi-head attention                    -> ``transformer.<i>.0.fn.1``
    * layer norm directly after attention     -> ``transformer.<i>.1.fn.0``
    * first dense of a pair                   -> ``transformer.<i>.1.fn.1``
    * second dense of a pair                  -> ``transformer.<i>.1.fn.4``

    The block index ``<i>`` advances after the second dense layer. The walk
    stops at the first layer whose name contains ``batch_normalization_6``
    (presumably the first layer behind the transformer tower - TODO confirm).

    Args:
        layers: mapping of Keras layer name to a sequence of weight arrays.
            Matrices must support ``.T``.

    Returns:
        dict mapping PyTorch parameter names to the (possibly transposed)
        weight arrays.
    """
    state = {}
    block_idx = 0
    previous_name = ""
    for name, params in layers.items():
        if "batch_normalization_6" in name:
            break

        base = f"transformer.{block_idx}"
        is_layer_norm = "layer_normalization" in name
        is_dense = "dense" in name
        after_attention = "multihead" in previous_name
        after_dense = "dense" in previous_name

        if is_layer_norm and not after_attention:
            # Layer norm in front of the attention sub-block
            state[f"{base}.0.fn.0.weight"] = params[0]
            state[f"{base}.0.fn.0.bias"] = params[1]

        if "multihead" in name:
            attn = f"{base}.0.fn.1"
            state[f"{attn}.rel_content_bias"] = params[0]
            state[f"{attn}.rel_pos_bias"] = params[1]
            # Projection matrices are stored transposed on the PyTorch side
            state[f"{attn}.to_q.weight"] = params[2].T
            state[f"{attn}.to_k.weight"] = params[3].T
            state[f"{attn}.to_v.weight"] = params[4].T
            state[f"{attn}.to_out.weight"] = params[5].T
            state[f"{attn}.to_out.bias"] = params[6]
            state[f"{attn}.to_rel_k.weight"] = params[7].T

        if is_layer_norm and after_attention:
            # Layer norm in front of the feed-forward sub-block
            state[f"{base}.1.fn.0.weight"] = params[0]
            state[f"{base}.1.fn.0.bias"] = params[1]

        if is_dense:
            # Two dense layers follow each other: the first goes to slot 1,
            # the second to slot 4 and closes the current transformer block.
            slot = 4 if after_dense else 1
            state[f"{base}.1.fn.{slot}.weight"] = params[0].T
            state[f"{base}.1.fn.{slot}.bias"] = params[1]
            if after_dense:
                block_idx += 1

        previous_name = name
    return state


def convert_the_convs(layers):
    """Rename and re-lay-out convolution-like weights for the PyTorch model.

    Three kinds of Keras layers are handled:

    * separable convolutions (three weights): the first kernel is permuted
      with ``(1, 2, 0)``, the second kernel with ``(2, 1, 0)``, and the bias is
      copied unchanged;
    * regular convolutions (3-D kernel): the kernel axes are reversed with
      ``(2, 1, 0)`` (presumably Keras ``(kernel, in, out)`` to PyTorch
      ``(out, in, kernel)`` - TODO confirm);
    * dense layers (2-D kernel): a leading length-1 axis is added first, so
      the result has shape ``(dim1, dim0, 1)`` and can be loaded into a
      convolution with kernel size 1.

    Args:
        layers: mapping of Keras layer name to a list of ``torch.Tensor``
            weights. Every name listed in ``conv_lookup`` must be present.

    Returns:
        dict mapping PyTorch parameter names to tensors.

    Raises:
        KeyError: if an expected Keras layer is missing from ``layers``.
        RuntimeError: raised by ``permute`` if a kernel is neither 2-D nor 3-D.
    """
    weights = dict()
    conv_lookup = {
        "conv1d": "conv_dna.conv_layer",
        "conv1d_1": "res_tower.0.conv_layer",
        "conv1d_2": "res_tower.2.conv_layer",
        "conv1d_3": "res_tower.4.conv_layer",
        "conv1d_4": "res_tower.6.conv_layer",
        "conv1d_5": "res_tower.8.conv_layer",
        "conv1d_6": "unet1.1.conv_layer",
        "separable_conv1d": "separable1.conv_layer",
        "separable_conv1d_1": "separable0.conv_layer",
        "dense_16": "upsampling_unet1.0.conv_layer",
        "dense_17": "horizontal_conv1.conv_layer",
        "dense_18": "upsampling_unet0.0.conv_layer",
        "dense_19": "horizontal_conv0.conv_layer",
        "conv1d_7": "final_joined_convs.0.conv_layer",
        "dense_20": "human_head",
    }
    for conv_tf, conv_pt in conv_lookup.items():
        params = layers[conv_tf]
        if "separable" in conv_tf:
            weights[f"{conv_pt}.0.weight"] = params[0].permute((1, 2, 0))
            weights[f"{conv_pt}.1.weight"] = params[1].permute((2, 1, 0))
            weights[f"{conv_pt}.1.bias"] = params[2]
        else:
            kernel = params[0]
            # Dense kernels are 2-D; give them a length-1 leading axis so they
            # go through the same axis reversal as real convolution kernels.
            # An explicit rank check replaces the former bare ``except:``,
            # which also swallowed unrelated errors and KeyboardInterrupt.
            if kernel.ndim == 2:
                kernel = kernel.unsqueeze(0)
            weights[f"{conv_pt}.weight"] = kernel.permute((2, 1, 0))
            weights[f"{conv_pt}.bias"] = params[1]
    return weights


def normalize_the_norms(layers):
    """Rename batch-normalization weights to PyTorch state_dict keys.

    For each of the eleven Keras batch-norm layers, the four entries of its
    weight list are stored, in order, under ``<target>.weight``,
    ``<target>.bias``, ``<target>.running_mean`` and ``<target>.running_var``.

    Args:
        layers: mapping of Keras layer name to a sequence with at least four
            weight entries per batch-norm layer.

    Returns:
        dict mapping PyTorch parameter/buffer names to the unchanged entries.
    """
    # PyTorch module for each Keras batch-norm layer, in Keras numbering order
    torch_norms = (
        "res_tower.0.norm",            # batch_normalization
        "res_tower.2.norm",            # batch_normalization_1
        "res_tower.4.norm",            # batch_normalization_2
        "res_tower.6.norm",            # batch_normalization_3
        "res_tower.8.norm",            # batch_normalization_4
        "unet1.1.norm",                # batch_normalization_5
        "upsampling_unet1.0.norm",     # batch_normalization_6
        "horizontal_conv1.norm",       # batch_normalization_7
        "upsampling_unet0.0.norm",     # batch_normalization_8
        "horizontal_conv0.norm",       # batch_normalization_9
        "final_joined_convs.0.norm",   # batch_normalization_10
    )
    # Position in the Keras weight list -> PyTorch parameter/buffer name
    param_names = ("weight", "bias", "running_mean", "running_var")

    state = {}
    for number, target in enumerate(torch_norms):
        # The first layer carries no numeric suffix
        keras_name = (
            "batch_normalization"
            if number == 0
            else f"batch_normalization_{number}"
        )
        for position, param_name in enumerate(param_names):
            state[f"{target}.{param_name}"] = layers[keras_name][position]
    return state




In [3]:
import glob

# Glob pattern matching the TF checkpoint file(s) to convert
path_pattern = "/nfs/turbo/umms-welchjd/mkarikom/borzoi_tf_weights/model0_best.h5*"

# glob returns matches in arbitrary order; sort them so the checkpoints are
# processed in the same order on every run
model_file_list = sorted(glob.glob(path_pattern))



In [4]:
# JSON parameter file; its 'model' and 'train' sections are read when the Keras model is built
params_file = "/nfs/turbo/umms-welchjd/mkarikom/borzoi_tf_weights/params_pred.json"

We convert each checkpoint individually: we build the Keras model, restore the checkpoint into it, and collect the weights of all layers in a dictionary.<br> We then translate that dictionary into a PyTorch `state_dict` whose keys match the current architecture. There are easier ways to do this (e.g. reading the `weights.h5` file directly, without having TensorFlow installed)...

In [5]:
%%capture
# Read the parameter file, then build the Keras model from its 'model' section
with open(params_file) as params_open:
    params = json.load(params_open)

params_model = params['model']
params_train = params['train']

# NOTE(review): 'verbose' presumably makes SeqNN print during construction;
# the %%capture magic at the top of this cell hides any such output.
params_model['verbose'] = True

seqnn_model = seqnn.SeqNN(params_model)


In [6]:
# Convert every TF checkpoint: restore it into the Keras model, collect the
# layer weights, rename them to PyTorch state_dict keys and save the result.

torch_weights_path = "/nfs/turbo/umms-welchjd/mkarikom/borzoi_torch_weights"
# Create the output directory up front so torch.save cannot fail on a missing
# folder after a checkpoint has already been restored and converted.
os.makedirs(torch_weights_path, exist_ok=True)

for model_file in model_file_list:
    # NOTE(review): the second argument (0) presumably selects the model head - confirm
    seqnn_model.restore(model_file, 0)

    # Keras layer name -> list of torch tensors, for every layer that has weights
    layer_weight_dict = dict()
    for layer in seqnn_model.model.layers:
        layer_weights = layer.get_weights()
        if len(layer_weights) != 0:
            layer_weight_dict[layer.get_config()['name']] = [
                torch.as_tensor(array) for array in layer_weights
            ]

    assert "batch_normalization" in layer_weight_dict, (
        "Layer 'batch_normalization' not found. You probably have to restart "
        "the kernel, as each call of seqnn.SeqNN(params_model) increases the "
        "numbers in the layer names."
    )

    # Alias kept so the dictionary contents stay as before; the three
    # converter functions defined in this notebook read 'conv1d' directly.
    layer_weight_dict['conv_dna.conv_layer'] = layer_weight_dict['conv1d']

    res_transformers = transform_the_transformer(layer_weight_dict)
    res_convs = convert_the_convs(layer_weight_dict)
    res_norms = normalize_the_norms(layer_weight_dict)

    # Merge the three partial mappings into one state_dict
    z = {**res_transformers, **res_convs, **res_norms}

    # One .pth file per checkpoint, named after the checkpoint file
    torch.save(
        z,
        os.path.join(torch_weights_path, f"{os.path.basename(model_file)}.pth"),
    )