In [1]:
import sys
sys.path.append("../../")

from torch.utils.data import DataLoader
from enformer_pytorch import Enformer, from_pretrained

from MPRA_predict.utils import *
from MPRA_predict.datasets import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_pred(model, test_data_loader, device, head='human'):
    """Run ``model`` over every batch of ``test_data_loader`` and collect predictions.

    Args:
        model: Module whose forward pass returns a mapping from head name to
            an output tensor; the entry selected by ``head`` is collected.
        test_data_loader: Iterable yielding ``(x, y)`` pairs where ``x['seq']``
            is the input tensor. ``y`` is ignored.
        device: Device the model and every input batch are moved to.
        head: Key of the model output to collect. Defaults to ``'human'``,
            which is the head this function always used before the parameter
            was introduced.

    Returns:
        numpy.ndarray: The per-batch outputs concatenated along axis 0.

    Raises:
        ValueError: If ``test_data_loader`` yields no batches.

    Note:
        The model is left on ``device`` and in eval mode after the call.
    """
    model = model.to(device)
    model.eval()
    batch_preds = []
    with torch.no_grad():
        for x, _ in tqdm(test_data_loader):
            # non_blocking pairs with pin_memory=True on the DataLoader.
            seq = x['seq'].to(device, non_blocking=True)
            batch_preds.append(model(seq)[head].cpu().numpy())
    if not batch_preds:
        # np.concatenate would also raise ValueError here, but with a message
        # that does not point at the empty loader.
        raise ValueError('test_data_loader yielded no batches; nothing to predict')
    return np.concatenate(batch_preds, axis=0)


# def get_pred(model, test_data_loader, device):
#     model = model.to(device)
#     y_pred = []
#     with torch.no_grad():
#         model.eval()
#         for (x, y) in tqdm(test_data_loader):
#             x = x['seq'].to(device)
#             x_rc = onehots_reverse_complement(x).to(device)
#             pred_1 = model(x)['human']
#             pred_2 = model(x_rc)['human']
#             pred = (pred_1 + pred_2)/2
#             y_pred.extend(pred.cpu().detach().numpy())
#     y_pred = np.array(y_pred)
#     return y_pred

# 实验1：random_seed的影响

# 结论1：Enformer不受随机数种子的影响

# 实验2：target_length的影响

In [8]:
# Experiment 2 run: predictions with target_length=896, saved for later comparison.
set_seed(0)
trained_model_path, device = 'Enformer', 'cuda:0'

model = from_pretrained(trained_model_path, target_length=896)

# Every sequence is padded to 196,608 positions; N_fill_value is passed through as 0.
dataset = SeqLabelDataset(
    data_path='../../data/SirajMPRA/SirajMPRA_100.csv',
    seq_column='seq',
    padding=True,
    padded_len=196_608,
    N_fill_value=0,
)
test_data_loader = DataLoader(
    dataset,
    batch_size=6,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)
y_pred = get_pred(model, test_data_loader, device)

torch.cuda.empty_cache()
np.save('data/Enformer_Siraj_pred_100_len896.npy', y_pred)

set all labels to 0


100%|██████████| 17/17 [00:10<00:00,  1.58it/s]


In [9]:
# Reload the predictions saved for each target_length setting.
y_pred_2 = np.load('data/Enformer_Siraj_pred_100_len2.npy')
y_pred_4 = np.load('data/Enformer_Siraj_pred_100_len4.npy')
y_pred_896 = np.load('data/Enformer_Siraj_pred_100_len896.npy')

# One shape per line, in load order.
print(y_pred_2.shape, y_pred_4.shape, y_pred_896.shape, sep='\n')

(100, 2, 5313)
(100, 4, 5313)
(100, 896, 5313)


In [10]:
# Each pair holds one position as indexed under two different target_length
# settings; the flattened predictions of the two are correlated.
for preds_a, preds_b in (
    (y_pred_2[:, 0], y_pred_4[:, 1]),
    (y_pred_2[:, 0], y_pred_896[:, 447]),
    (y_pred_4[:, 0], y_pred_896[:, 446]),
):
    print(pearsonr(preds_a.reshape(-1), preds_b.reshape(-1)))

PearsonRResult(statistic=0.9999999999999999, pvalue=0.0)
PearsonRResult(statistic=0.9999999999999037, pvalue=0.0)
PearsonRResult(statistic=0.9999999999999103, pvalue=0.0)


# 结论2：target_length基本不影响速度和结果（对应位置的预测 r≈1），仅轻微影响显存

target_length=2 代表只保留输出正中间的两个位置（对应 target_length=896 时索引为 447、448 的位置，索引从 0 开始）

# 实验3：padding长度的影响

In [9]:
# Experiment 3 run: predictions with sequences padded to 256 positions.
set_seed(0)
trained_model_path, device = 'Enformer', 'cuda:0'

model = from_pretrained(trained_model_path, target_length=2)

# Same CSV as the other runs; only padded_len differs. N_fill_value is passed through as 0.
dataset = SeqLabelDataset(
    data_path='../../data/SirajMPRA/SirajMPRA_100.csv',
    seq_column='seq',
    padding=True,
    padded_len=256,
    N_fill_value=0,
)
test_data_loader = DataLoader(
    dataset,
    batch_size=6,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)
y_pred = get_pred(model, test_data_loader, device)

torch.cuda.empty_cache()
np.save('data/Enformer_Siraj_pred_100_pad256.npy', y_pred)

set all labels to 0


100%|██████████| 17/17 [00:00<00:00, 91.53it/s]


In [17]:
# Reload the predictions saved for each padding configuration.
# NOTE: y_pred_2 is rebound here; in the target_length cells it held the len2 predictions.
y_pred_0 = np.load('data/Enformer_Siraj_pred_100_pad200.npy')
y_pred_1 = np.load('data/Enformer_Siraj_pred_100_pad256.npy')
y_pred_2 = np.load('data/Enformer_Siraj_pred_100_pad196k.npy')
y_pred_3 = np.load('data/Enformer_Siraj_pred_100_pad196k_N025.npy')

# One shape per line, in load order.
print(y_pred_0.shape, y_pred_1.shape, y_pred_2.shape, y_pred_3.shape, sep='\n')

(100, 2, 5313)
(100, 2, 5313)
(100, 2, 5313)
(100, 2, 5313)


In [18]:
# Flatten each prediction array into one row, then correlate the rows pairwise.
combined_y_pred = np.stack(
    [preds.reshape(-1) for preds in (y_pred_0, y_pred_1, y_pred_2, y_pred_3)]
)
padding_corr = np.corrcoef(combined_y_pred)
print(padding_corr)

[[1.         0.92467371 0.47678416 0.39599049]
 [0.92467371 1.         0.40123296 0.37183952]
 [0.47678416 0.40123296 1.         0.61093956]
 [0.39599049 0.37183952 0.61093956 1.        ]]


In [19]:
# Mean prediction of each padding configuration, printed on a single line.
print(*(preds.mean() for preds in (y_pred_0, y_pred_1, y_pred_2, y_pred_3)))

1.0454046 1.0853074 0.9762632 0.71978724


# 结论3：pad显著影响结果

pad=200和pad=256结果差不太多(r~0.9)

但是和pad=196608差别很大(r~0.4)

pad=200、pad=256 和 pad=196608(N=0) 的平均信号强度基本不变(约1.0)，但 N=0.25 时均值降至约0.72

N=0和N=0.25结果差别很大(r~0.6)

# 实验4：比较genome padding和N padding

In [21]:
# from MPRA_predict.datasets import SeqInterval
# seq_interval = SeqInterval(genome_path='../../../../genome/hg38.fa')

In [None]:
# Build a BedDataset from the same 100-sequence CSV, with the hg38 FASTA given as genome_path.
dataset = BedDataset(data_path='../../data/SirajMPRA/SirajMPRA_100.csv',
                     genome_path='../../../../genome/hg38.fa')