In [1]:
import pyfaidx
import json
import pysam
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import torch
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
from tqdm import tqdm
from pytorch_borzoi import Borzoi
from pytorch_borzoi_helpers import predict_tracks

#### We check whether, for the WT sequence of the eQTL example, all tracks predicted by the PyTorch model match those predicted by the original Calico TF/Keras model

In [2]:
# Build the list of Borzoi fold models: each one is constructed from its
# own checkpoint file, moved to the CUDA device and put into eval mode.
device = torch.device("cuda")
folds = 4
model_folds = []
for fold in tqdm(range(folds)):
    checkpoint_file = f'pytorch_weights/borzoi_fold{fold}.pt'
    borzoi = Borzoi(checkpoint_path=checkpoint_file)
    borzoi.to(device)
    borzoi.eval()  # inference only
    model_folds += [borzoi]

100%|██████████| 4/4 [00:08<00:00,  2.23s/it]


In [3]:
# Track indices taken from the first EQTL example: the 89 consecutive
# integers 7522..7610 (inclusive), materialised as a plain list.
slices = list(range(7522, 7610 + 1))

In [4]:
import numpy as np

# Read the stored wild-type one-hot sequence and place it on the device.
wt_seq_array = np.load('../wt_seq.npy')
sequence_one_hot_wt = torch.as_tensor(wt_seq_array).to(device)

# The two axes of the sequence tensor are swapped before prediction
# (presumably to a channels-first layout expected by predict_tracks -- TODO confirm).
wt_pred_across_folds_pt = predict_tracks(
    model_folds,
    sequence_one_hot_wt.permute(1, 0),
    slices,
)

In [5]:
# Reference predictions loaded from disk; per the notebook's narrative these
# come from the original TF/Keras Borzoi model for the same WT sequence.
wt_pred_across_folds_tf = np.load('../wt_pred_across_folds.npy')

In [6]:
# Show both array shapes side by side; they must agree before the
# element-wise comparison below is meaningful.
wt_pred_across_folds_pt.shape, wt_pred_across_folds_tf.shape

((1, 4, 16352, 89), (1, 4, 16352, 89))

In [7]:
# Element-wise agreement test with a purely absolute tolerance
# (the relative tolerance is disabled by rtol=0).
np.allclose(
    wt_pred_across_folds_pt,
    wt_pred_across_folds_tf,
    rtol=0,
    atol=1e-5,
)

True

#### Up to numerical precision (absolute tolerance of 1e-5), the Borzoi ensemble ported to PyTorch reproduces the results of the TF Borzoi

In [8]:
# Display the overall maximum of each prediction array for a quick
# side-by-side look at the upper end of the value range.
wt_pred_across_folds_pt.max(), wt_pred_across_folds_tf.max()

(4.6815963, 4.6815977)

In [9]:
# Display the overall minimum of each prediction array for a quick
# side-by-side look at the lower end of the value range.
wt_pred_across_folds_pt.min(), wt_pred_across_folds_tf.min()

(4.0989912e-07, 4.099e-07)