# Borzoi weight conversion from TensorFlow to PyTorch

## Load packages and files

In [1]:
import os
import h5py
import numpy as np
import json
import torch

In [None]:
# Download the Keras weight files for folds f0-f3 (two checkpoints per fold).
# model0 is the human model, model1 is the mouse model: the conversion loop reads
# the trunk and human head from model0 and only the mouse head from model1.
# -O stores each file as f<fold>_model<k>_best.h5, the names `weights_paths` expects.
# NOTE: re-running this cell downloads every file again and overwrites the local copies.
!wget https://storage.googleapis.com/seqnn-share/borzoi/f0/model0_best.h5 -O f0_model0_best.h5
!wget https://storage.googleapis.com/seqnn-share/borzoi/f0/model1_best.h5 -O f0_model1_best.h5
!wget https://storage.googleapis.com/seqnn-share/borzoi/f1/model0_best.h5 -O f1_model0_best.h5
!wget https://storage.googleapis.com/seqnn-share/borzoi/f1/model1_best.h5 -O f1_model1_best.h5
!wget https://storage.googleapis.com/seqnn-share/borzoi/f2/model0_best.h5 -O f2_model0_best.h5
!wget https://storage.googleapis.com/seqnn-share/borzoi/f2/model1_best.h5 -O f2_model1_best.h5
!wget https://storage.googleapis.com/seqnn-share/borzoi/f3/model0_best.h5 -O f3_model0_best.h5
!wget https://storage.googleapis.com/seqnn-share/borzoi/f3/model1_best.h5 -O f3_model1_best.h5

## Create porting functions

In [2]:
def make_trans_lookup(n_layers):
    """Build the Keras -> PyTorch name map for the transformer tower.

    Keras leaves the first layer of a kind unsuffixed and names later ones
    ``_1``, ``_2``, ... . Each transformer block maps two layer norms and two
    dense layers but a single attention layer, so block ``i`` uses norm/dense
    indices ``2*i`` and ``2*i + 1`` and attention index ``i``.

    Args:
        n_layers: number of transformer blocks.

    Returns:
        dict from Keras layer name to PyTorch module prefix under
        ``distal.0.<i>``, ordered block by block.
    """
    def keras_suffix(index):
        # First layer of a kind has no suffix; the rest get _1, _2, ...
        return "" if index == 0 else f"_{index}"

    lookup = {}
    for block in range(n_layers):
        prefix = f"distal.0.{block}"
        first = keras_suffix(2 * block)       # first norm / first dense of the block
        second = keras_suffix(2 * block + 1)  # second norm / second dense of the block
        entries = (
            (f"layer_normalization{first}", "transf.norm"),
            (f"multihead_attention{keras_suffix(block)}", "transf.attention"),
            (f"layer_normalization{second}", "ff.norm"),
            (f"dense{first}", "ff.l1"),
            (f"dense{second}", "ff.l2"),
        )
        for tf_name, pt_suffix in entries:
            lookup[tf_name] = f"{prefix}.{pt_suffix}"
    return lookup

def transform_the_transformer(layers, n_transformer_layers=8):
    """Convert the Keras transformer tower into PyTorch state_dict entries.

    Args:
        layers: ``model_weights`` group of a Keras .h5 checkpoint. Most weights
            are read as ``layers[name][name][param]``; the attention relative
            biases (``r_w_bias:0`` / ``r_r_bias:0``) are read one level higher,
            at ``layers[name][param]``.
        n_transformer_layers: number of transformer blocks to convert.

    Returns:
        dict of tensors keyed by PyTorch names under ``distal.0.<i>``.
    """
    def as_tensor(dataset, transpose=False):
        # Kernels are transposed (presumably Keras (in, out) -> PyTorch Linear
        # (out, in)); biases, gammas and betas are copied as-is.
        array = dataset[...]
        return torch.tensor(array.T if transpose else array)

    converted = {}
    for tf_name, pt_name in make_trans_lookup(n_transformer_layers).items():
        outer = layers[tf_name]
        inner = outer[tf_name]
        if "layer_normalization" in tf_name:
            converted[f"{pt_name}.weight"] = as_tensor(inner['gamma:0'])
            converted[f"{pt_name}.bias"] = as_tensor(inner['beta:0'])
        elif "multihead_attention" in tf_name:
            converted[f"{pt_name}.rel_content_bias"] = as_tensor(outer['r_w_bias:0'])
            converted[f"{pt_name}.rel_pos_bias"] = as_tensor(outer['r_r_bias:0'])
            for keras_proj, torch_proj in (("q_layer", "to_q"), ("k_layer", "to_k"), ("v_layer", "to_v")):
                converted[f"{pt_name}.{torch_proj}.weight"] = as_tensor(inner[keras_proj]['kernel:0'], transpose=True)
            embedding = inner['embedding_layer']
            converted[f"{pt_name}.to_out.weight"] = as_tensor(embedding['kernel:0'], transpose=True)
            converted[f"{pt_name}.to_out.bias"] = as_tensor(embedding['bias:0'])
            converted[f"{pt_name}.to_rel_k.weight"] = as_tensor(inner['r_k_layer']['kernel:0'], transpose=True)
        elif "dense" in tf_name:
            converted[f"{pt_name}.weight"] = as_tensor(inner['kernel:0'], transpose=True)
            converted[f"{pt_name}.bias"] = as_tensor(inner['bias:0'])
    return converted

def convert_the_convs(layers):
    """Convert the convolutional trunk, U-net convs and human head to PyTorch.

    Args:
        layers: ``model_weights`` group of the human Keras checkpoint; every
            parameter is read as ``layers[name][name][param]``.

    Returns:
        dict of tensors keyed by PyTorch names under ``local_list``,
        ``final_list`` and ``head_human``.

    Raises:
        KeyError: if an expected layer or parameter is missing.
        RuntimeError: (from ``permute``) if a kernel is neither 2-D nor 3-D.
    """
    conv_lookup = {
        "conv1d": "local_list.0.0",
        "conv1d_1": "local_list.0.2.conv_layer",
        "conv1d_2": "local_list.0.3.conv_layer",
        "conv1d_3": "local_list.0.4.conv_layer",
        "conv1d_4": "local_list.0.5.conv_layer",
        "conv1d_5": "local_list.0.6.conv_layer",
        "conv1d_6": "local_list.1.1.conv_layer",
        "separable_conv1d" : "final_list.0.0.conv_sep.conv_layer",
        "separable_conv1d_1": "final_list.1.0.conv_sep.conv_layer",
        "dense_16": "final_list.0.0.conv_input.conv_layer",
        "dense_17": "final_list.0.0.conv_horizontal.conv_layer",
        "dense_18": "final_list.1.0.conv_input.conv_layer",
        "dense_19": "final_list.1.0.conv_horizontal.conv_layer",
        "conv1d_7": "final_list.2.1.conv_layer",
        "dense_20": "head_human.0"
    }
    weights = dict()
    for conv_tf, conv_pt in conv_lookup.items():
        params = layers[conv_tf][conv_tf]
        if 'separable' in conv_tf:
            # Separable conv = depthwise (.0) followed by pointwise (.1); the
            # permutes presumably reorder Keras layouts to PyTorch's
            # (out, in/groups, kernel) Conv1d layout -- verify against the model.
            weights[f"{conv_pt}.0.weight"] = torch.tensor(params['depthwise_kernel:0'][...]).permute((1, 2, 0))
            weights[f"{conv_pt}.1.weight"] = torch.tensor(params['pointwise_kernel:0'][...]).permute((2, 1, 0))
            weights[f"{conv_pt}.1.bias"] = torch.tensor(params['bias:0'][...])
        else:
            kernel = torch.tensor(params['kernel:0'][...])
            if kernel.ndim == 2:
                # Dense kernels are 2-D; add a leading length-1 axis so the
                # permute below yields (out, in, 1), i.e. a width-1 conv weight.
                # (Replaces a bare `except:` that also swallowed interrupts.)
                kernel = kernel.unsqueeze(0)
            weights[f"{conv_pt}.weight"] = kernel.permute((2, 1, 0))
            weights[f"{conv_pt}.bias"] = torch.tensor(params['bias:0'][...])
    return weights

def match_the_mouse(layers):
    """Convert the mouse output head (``dense_21``) to PyTorch ``head_mouse.0``.

    Args:
        layers: ``model_weights`` group of the mouse Keras checkpoint; each
            parameter is read as ``layers[name][name][param]``.

    Returns:
        dict with ``head_mouse.0.weight`` and ``head_mouse.0.bias`` tensors.

    Raises:
        KeyError: if ``dense_21`` or one of its parameters is missing.
        RuntimeError: (from ``permute``) if the kernel is neither 2-D nor 3-D.
    """
    weights = dict()
    head_lookup = {"dense_21": "head_mouse.0"}
    for head_tf, head_pt in head_lookup.items():
        params = layers[head_tf][head_tf]
        kernel = torch.tensor(params['kernel:0'][...])
        if kernel.ndim == 2:
            # Dense kernel: add a leading length-1 axis so the permute yields
            # (out, in, 1). Replaces a bare `except:` around the permute.
            kernel = kernel.unsqueeze(0)
        weights[f"{head_pt}.weight"] = kernel.permute((2, 1, 0))
        weights[f"{head_pt}.bias"] = torch.tensor(params['bias:0'][...])
    return weights

def normalize_the_norms(layers):
    """Convert the Keras ``sync_batch_normalization*`` layers to PyTorch norms.

    The Keras layers ``sync_batch_normalization``, ``_1``, ..., ``_10`` are
    mapped, in that order, onto the PyTorch norm modules listed below.

    Args:
        layers: ``model_weights`` group of the human Keras checkpoint; each
            parameter is read as ``layers[name][name][param]``.

    Returns:
        dict with ``weight``, ``bias``, ``running_mean`` and ``running_var``
        tensors for every norm module.
    """
    pt_targets = [
        "local_list.0.2.norm",
        "local_list.0.3.norm",
        "local_list.0.4.norm",
        "local_list.0.5.norm",
        "local_list.0.6.norm",
        "local_list.1.1.norm",
        "final_list.0.0.conv_input.norm",
        "final_list.0.0.conv_horizontal.norm",
        "final_list.1.0.conv_input.norm",
        "final_list.1.0.conv_horizontal.norm",
        "final_list.2.1.norm",
    ]
    # (PyTorch attribute, Keras parameter) pairs, in the order they are emitted.
    field_map = (
        ("weight", "gamma:0"),
        ("bias", "beta:0"),
        ("running_mean", "moving_mean:0"),
        ("running_var", "moving_variance:0"),
    )
    out = {}
    for idx, pt_prefix in enumerate(pt_targets):
        # Keras leaves the first layer unsuffixed, then _1, _2, ...
        tf_name = "sync_batch_normalization" + (f"_{idx}" if idx else "")
        params = layers[tf_name][tf_name]
        for pt_attr, tf_param in field_map:
            out[f"{pt_prefix}.{pt_attr}"] = torch.tensor(params[tf_param][...])
    return out

## Port weights

We convert each fold individually. For every fold we read the human (`model0`) and mouse (`model1`) Keras `.h5` checkpoints and save all weights into a single dictionary of tensors.  
Its keys match the PyTorch model's `state_dict` keys. You can therefore pass the dictionary returned by `torch.load('borzoi_fold_<n>.pt')` directly to `Borzoi.load_state_dict(...)`.

In [3]:
# Select folds to port: output PyTorch checkpoint -> (human model0, mouse model1) Keras files
weights_paths = {
    f"../weights/borzoi_fold_{k}.pt": tuple(f"f{k}_model{m}_best.h5" for m in (0, 1))
    for k in range(4)
}

In [4]:
# Transfer each fold to a PyTorch state dict: trunk, norms, transformer and human
# head come from the human checkpoint (model0); only the mouse head comes from
# the mouse checkpoint (model1).
for pt_path, (human_tf_path, mouse_tf_path) in weights_paths.items():
    # Open read-only: older h5py versions defaulted to append mode, which could
    # create or modify a file instead of failing on a wrong path.
    with h5py.File(human_tf_path, "r") as human_weights, h5py.File(mouse_tf_path, "r") as mouse_weights:
        human_layers = human_weights['model_weights']
        res_convs = convert_the_convs(human_layers)
        res_mouse = match_the_mouse(mouse_weights['model_weights'])
        res_norms = normalize_the_norms(human_layers)
        res_transformers = transform_the_transformer(human_layers)
    # Merge into one state dict; the converters must produce disjoint keys,
    # otherwise the dict merge would silently overwrite tensors.
    parts = (res_convs, res_mouse, res_norms, res_transformers)
    z = {**res_convs, **res_mouse, **res_norms, **res_transformers}
    assert len(z) == sum(len(part) for part in parts), f"duplicate parameter names while building {pt_path}"
    # torch.save does not create missing parent directories (e.g. ../weights).
    out_dir = os.path.dirname(pt_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torch.save(z, pt_path)
    print(f"Saved to {pt_path}")

Saved to ../weights/borzoi_fold_0.pt
Saved to ../weights/borzoi_fold_1.pt
Saved to ../weights/borzoi_fold_2.pt
Saved to ../weights/borzoi_fold_3.pt
