In [1]:
import sys
sys.path.append("../../")

from enformer_pytorch import from_pretrained
from MPRA_predict.utils import *
from MPRA_predict.datasets import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_pred(model, test_data_loader, device):
    model = model.to(device)
    y_pred = []
    model.eval()
    with torch.no_grad():
        for batch in tqdm(test_data_loader):
            if isinstance(batch, (list, tuple)):
                x = batch[0]
            elif isinstance(batch, dict):
                x = batch['seq']

            x = x.to(device)
            output = model(x)['human']
            y_pred.append(output.cpu().numpy())
    y_pred = np.concatenate(y_pred, axis=0)
    return y_pred


def get_pred_rc(model, test_data_loader, device):
    model = model.to(device)
    y_pred = []
    model.eval()
    with torch.no_grad():
        for batch in tqdm(test_data_loader):
            if isinstance(batch, (list, tuple)):
                x = batch[0]
            elif isinstance(batch, dict):
                x = batch['seq']

            x = x.to(device)
            x_rc = onehots_rc(x).to(device)
            output = (model(x)['human'] + model(x_rc)['human']) / 2
            y_pred.append(output.cpu().numpy())
    y_pred = np.concatenate(y_pred, axis=0)
    return y_pred

# 实验0：检查bed_dataset和seqlabeldataset是否相同

In [4]:
dataset_1 = SeqLabelDataset(
    data_path='../../data/SirajMPRA/SirajMPRA_100.csv',
    seq_column='seq', 
    padding=True, 
    padded_length=300, 
    N_fill_value=0)


dataset_2 = BedDataset(
    bed_path='../../data/SirajMPRA/SirajMPRA_100.csv', 
    genome_path='../../../../genome/hg19.fa',
    padding=True, 
    padded_length=300, 
    N_fill_value=0,
    )

for i in range(100):
    if (dataset_1[2]['seq'] != dataset_2[2]['seq']).all():
        print(i)

In [None]:
for padded_length in [200, 256, 2**10, 2**12, 2**14, 2**16, 131072, 196608, 393216]:

    set_seed(0)
    trained_model_path = 'Enformer'
    device = 'cuda:0'

    model = from_pretrained(trained_model_path, target_length=2)
    dataset = SeqLabelDataset(
        data_path='../../data/SirajMPRA/SirajMPRA_100.csv',
        seq_column='seq', padding=True, padded_length=padded_length, N_fill_value=0)
    test_data_loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=4, pin_memory=True)
    y_pred = get_pred(model, test_data_loader, device)

    torch.cuda.empty_cache()
    np.save(f'data/Enformer_Siraj_pred_padded_len={padded_length}.npy', y_pred)

In [4]:
set_seed(0)
trained_model_path = 'Enformer'
device = 'cuda:0'

model = from_pretrained(trained_model_path, target_length=2)
dataset = BedDataset(
    bed_path='../../data/SirajMPRA/SirajMPRA_100.csv', 
    genome_path='../../../../genome/hg19.fa',
    genome_window_size=196608)

test_data_loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=4, pin_memory=True)
y_pred = get_pred(model, test_data_loader, device)

torch.cuda.empty_cache()
np.save(f'data/Enformer_Siraj_pred_genome_window_size=196608.npy', y_pred)

100%|██████████| 25/25 [00:10<00:00,  2.44it/s]


In [9]:
y_pred = np.load(f'data/Enformer_Siraj_pred_bed_dataset.npy')
print(y_pred.shape)

y_pred_2 = np.load(f'data/Enformer_Siraj_pred_padded_len=196608.npy')
print(y_pred_2.shape)

(y_pred == y_pred_2).all()

(100, 2, 5313)
(100, 2, 5313)


False

In [10]:
y_pred

array([[[0.49641743, 0.53111815, 0.6526666 , ..., 0.03282204,
         0.24221255, 0.39656678],
        [0.35675162, 0.35826644, 0.5343172 , ..., 0.07466529,
         0.63177264, 0.998656  ]],

       [[0.49641743, 0.53111815, 0.6526666 , ..., 0.03282204,
         0.24221255, 0.39656678],
        [0.35675162, 0.35826644, 0.5343172 , ..., 0.07466529,
         0.63177264, 0.998656  ]],

       [[0.05341086, 0.04149767, 0.03522886, ..., 0.00368416,
         0.0114732 , 0.00979407],
        [0.05142899, 0.04179503, 0.0392975 , ..., 0.00330846,
         0.01074063, 0.00943512]],

       ...,

       [[0.02212196, 0.02886742, 0.03534057, ..., 0.00226836,
         0.00922884, 0.00435526],
        [0.08546597, 0.09543043, 0.11159252, ..., 0.00503291,
         0.01137636, 0.00578577]],

       [[0.12694344, 0.13454068, 0.16969986, ..., 0.00334302,
         0.00774447, 0.00645422],
        [0.17051901, 0.18340878, 0.23938677, ..., 0.00396433,
         0.01234872, 0.00817071]],

       [[0.126943