In [1]:
import sys
from pathlib import Path
sys.path.append(str(Path.cwd().parent))

from torch.utils.data import DataLoader
from enformer_pytorch import Enformer, from_pretrained

from MPRA_predict.utils import *
from MPRA_predict.datasets import SeqLabelDataset

  from .autonotebook import tqdm as notebook_tqdm


In [13]:
def get_pred(model, test_data_loader, device, head='human', seq_key='seq', progress=True):
    """Run ``model`` over every batch of ``test_data_loader`` and stack the predictions.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(seq)``; must return a mapping from output-head name to a
        tensor (as Enformer does). Only the ``head`` entry is collected.
        Side effect: ``model.to(device)`` and ``model.eval()`` are called on it.
    test_data_loader : iterable
        Yields ``(x, y)`` pairs where ``x`` is a dict holding the input batch under
        ``seq_key``; ``y`` is ignored.
    device : str or torch.device
        Device the model and every input batch are moved to.
    head : str, default 'human'
        Key of the model output to collect.
    seq_key : str, default 'seq'
        Key of the input tensor inside ``x``.
    progress : bool, default True
        Wrap the loader in ``tqdm`` to display a progress bar.

    Returns
    -------
    numpy.ndarray
        The per-batch outputs concatenated along axis 0.

    Raises
    ------
    ValueError
        If the loader yields no batches.
    """
    model = model.to(device)
    model.eval()
    batches = tqdm(test_data_loader) if progress else test_data_loader
    y_pred = []
    with torch.no_grad():
        for x, _ in batches:
            # non_blocking only overlaps the copy when the loader pins memory.
            seq = x[seq_key].to(device, non_blocking=True)
            y_pred.append(model(seq)[head].cpu().numpy())
    if not y_pred:
        raise ValueError('test_data_loader yielded no batches; nothing to predict')
    return np.concatenate(y_pred, axis=0)


# def get_pred(model, test_data_loader, device):
#     model = model.to(device)
#     y_pred = []
#     with torch.no_grad():
#         model.eval()
#         for (x, y) in tqdm(test_data_loader):
#             x = x['seq'].to(device)
#             x_rc = onehots_reverse_complement(x).to(device)
#             pred_1 = model(x)['human']
#             pred_2 = model(x_rc)['human']
#             pred = (pred_1 + pred_2)/2
#             y_pred.extend(pred.cpu().detach().numpy())
#     y_pred = np.array(y_pred)
#     return y_pred

In [17]:
# Run the pretrained Enformer over the sequences in SirajMPRA_100.csv.
set_seed(0)

trained_model_path = 'Enformer'
# NOTE(review): absolute, user-specific path; consider deriving it from a configurable data directory.
DATA_PATH = '/home/hxcai/cell_type_specific_CRE/data/SirajMPRA/SirajMPRA_100.csv'
# Number of output bins predicted per sequence. The prediction files loaded later
# were produced with 2, 4 and 896 (the len896 file holds 896 bins).
TARGET_LENGTH = 2
PADDED_LEN = 196_608  # every input sequence is padded to this many positions
BATCH_SIZE = 6
# Use the first GPU when present; otherwise fall back to the (much slower) CPU
# instead of failing when tensors are moved to 'cuda:0'.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

model = from_pretrained(trained_model_path, target_length=TARGET_LENGTH)
dataset = SeqLabelDataset(data_path=DATA_PATH, seq_column='seq', padding=True,
                          padded_len=PADDED_LEN, N_fill_value=0)
test_data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False,
                              num_workers=4, pin_memory=True)
y_pred = get_pred(model, test_data_loader, device)

# Return cached, currently unused GPU memory to the driver after inference.
torch.cuda.empty_cache()

set all labels to 0


100%|██████████| 17/17 [00:10<00:00,  1.67it/s]


In [7]:
# Name the file after the number of output bins actually present in `y_pred`
# (axis 1), so the name cannot disagree with the data when the prediction cell
# was last run with a different target_length.
np.save(f'data/Enformer_Siraj_pred_100_len{y_pred.shape[1]}.npy', y_pred)

In [8]:
# Load the saved predictions for each target_length, print their shapes, and
# check that every file really holds the number of bins (axis 1) its name claims.
preds_by_len = {}
for n_bins in (2, 4, 896):
    arr = np.load(f'data/Enformer_Siraj_pred_100_len{n_bins}.npy')
    print(arr.shape)
    assert arr.shape[1] == n_bins, f'len{n_bins} file has {arr.shape[1]} bins'
    preds_by_len[n_bins] = arr
y_pred_2, y_pred_4, y_pred_896 = preds_by_len[2], preds_by_len[4], preds_by_len[896]

(100, 2, 5313)
(100, 4, 5313)
(100, 896, 5313)


In [23]:
def center_bin_offset(n_short, n_long):
    """Index, in an n_long-bin output, of bin 0 of the central n_short-bin window.

    Assumes target_length=n keeps the central n bins of the longer output
    (the hypothesis this comparison checks).
    """
    return (n_long - n_short) // 2


# Correlate bin 0 of each shorter prediction with the aligned bin of a longer one
# (offsets 1, 447 and 446 for the 2/4, 2/896 and 4/896 pairs).
for short, long_ in ((y_pred_2, y_pred_4), (y_pred_2, y_pred_896), (y_pred_4, y_pred_896)):
    offset = center_bin_offset(short.shape[1], long_.shape[1])
    print(pearsonr(short[:, 0].reshape(-1), long_[:, offset].reshape(-1)))

PearsonRResult(statistic=1.0, pvalue=0.0)
PearsonRResult(statistic=0.9999999999434269, pvalue=0.0)
PearsonRResult(statistic=0.9999999999064728, pvalue=0.0)


# 结论1：Enformer不受到随机数种子的影响

# 结论2：target_length 参数基本不影响速度和结果，只轻微影响显存；target_length=2 对应 896 个输出 bin 中正中间的两个位置（索引 447 和 448，即 [447:449]），与完整输出在这两个位置上的结果一致（r=0.9999999999613916）。