In [None]:
import json
import warnings

import numpy as np
import torch

from borzoi_pytorch import Borzoi
from borzoi_pytorch.pytorch_borzoi_helpers import predict_tracks

# Disable TF32 so CUDA convolutions/matmuls run in full FP32 precision; the
# PyTorch-vs-TensorFlow comparison below uses a tight absolute tolerance.
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False

#### We check whether, for the WT sequence of the eQTL example, the PyTorch model predicts the same values for all tracks as the original TensorFlow/Keras model

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
folds = 4  # number of model folds in the ensemble
model_folds = []  # filled with one eval-mode Borzoi model per fold below

# Load the model architecture parameters.
with open('../borzoi_pytorch/pytorch_borzoi_arch.json') as arch_fh:
    params = json.load(arch_fh)

# Adjust cropping size to match the Borzoi notebook's almost-full-length output.
final_block_cfg = params['final'][2][0]
final_block_cfg['target_length'] = 16352

# Create one model per fold and load that fold's own weights.
for fold in range(folds):
    borzoi = Borzoi(params)
    # Each fold has its own checkpoint; index the file by `fold`.
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(f"../weights/borzoi_fold_{fold}.pt", map_location=device)
    missing_keys, unexpected_keys = borzoi.load_state_dict(state_dict, strict = False)
    # strict=False tolerates key mismatches; report them rather than hiding them.
    if missing_keys or unexpected_keys:
        warnings.warn(
            f"fold {fold}: {len(missing_keys)} missing keys {missing_keys}, "
            f"{len(unexpected_keys)} unexpected keys {unexpected_keys}"
        )
    borzoi.to(device)
    borzoi.eval()
    model_folds.append(borzoi)

In [None]:
# Slices from the first eQTL example: the contiguous integers 7522..7610 (inclusive).
SLICE_START, SLICE_STOP = 7522, 7610
slices = list(range(SLICE_START, SLICE_STOP + 1))

In [None]:
# Reference predictions saved from the TensorFlow ensemble, plus the WT test
# sequence converted to a tensor and moved onto the model device.
wt_pred_across_folds_tf = np.load('../wt_pred_across_folds.npy')
sequence_one_hot_wt = torch.as_tensor(np.load('../wt_seq.npy')).to(device=device)

In [None]:
# Run the fold ensemble on the WT sequence, then swap the last two axes and
# prepend a singleton axis so the layout matches the saved TensorFlow output.
# NOTE(review): the positional `0` argument's meaning is defined by predict_tracks -- confirm.
wt_pred_across_folds_pt = (
    predict_tracks(model_folds, sequence_one_hot_wt, 0, slices)
    .transpose(0, 2, 1)[np.newaxis]
)

In [None]:
# Display the PyTorch and TensorFlow prediction shapes side by side.
tuple(arr.shape for arr in (wt_pred_across_folds_pt, wt_pred_across_folds_tf))

In [None]:
# np.allclose broadcasts its inputs, so arrays with different but
# broadcast-compatible shapes could still compare True; require identical shapes.
assert wt_pred_across_folds_pt.shape == wt_pred_across_folds_tf.shape, (
    f"shape mismatch: PyTorch {wt_pred_across_folds_pt.shape} "
    f"vs TensorFlow {wt_pred_across_folds_tf.shape}"
)
# Purely absolute tolerance (rtol=0): every element must agree within 1e-5.
np.allclose(wt_pred_across_folds_pt, wt_pred_across_folds_tf, rtol=0, atol = 0.00001)

#### Up to numerical precision (absolute tolerance 1e-5), the Borzoi ensemble ported to PyTorch produces the same results as the TensorFlow-based Borzoi

In [None]:
# Largest predicted value from each implementation.
(np.max(wt_pred_across_folds_pt), np.max(wt_pred_across_folds_tf))

In [None]:
# Smallest predicted value from each implementation.
(np.min(wt_pred_across_folds_pt), np.min(wt_pred_across_folds_tf))