In [1]:
import os, sys
import time
import argparse
import multiprocessing
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm
from scipy.sparse import coo_matrix

import models.deephic as deephic

from utils.io import spreadM, together

from all_parser import *

In [2]:
def dataloader(data, batch_size=64):
    inputs = torch.tensor(data['data'], dtype=torch.float)
    inds = torch.tensor(data['inds'], dtype=torch.long)

    dataset = TensorDataset(inputs, inds)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    return loader


#提取信息
def data_info(data):
    indices = data['inds']
    compacts = data['compacts'][()]
    sizes = data['sizes'][()]
    return indices, compacts, sizes

get_digit = lambda x: int(''.join(list(filter(str.isdigit, x))))

def filename_parser(filename):
    info_str = filename.split('.')[0].split('_')[2:-1]
    chunk = get_digit(info_str[0])
    stride = get_digit(info_str[1])
    bound = get_digit(info_str[2])
    scale = 1 if info_str[3] == 'nonpool' else get_digit(info_str[3])
    return chunk, stride, bound, scale

#TODO:好好研究一下这个
def deephic_predictor(deephic_loader, ckpt_file, scale, res_num, device):
    #加载模型
    #只用Generator
    deepmodel = deephic.Generator(scale_factor=scale, in_channel=1, resblock_num=res_num).to(device)
    if not os.path.isfile(ckpt_file):
        ckpt_file = f'save/{ckpt_file}'
    
    #加载训练好的权重
    deepmodel.load_state_dict(torch.load(ckpt_file, map_location=device))


    print(f'Loading DeepHiC checkpoint file from "{ckpt_file}"')
    result_data = []
    result_inds = []
    deepmodel.eval()

    with torch.no_grad():
        for batch in tqdm(deephic_loader, desc='DeepHiC Predicting: '):
            lr, inds = batch
            lr = lr.to(device) #shape:[64, 1, 40, 40]
            out = deepmodel(lr) #[64, 1, 40, 40]
            result_data.append(out.to('cpu').numpy())
            result_inds.append(inds.numpy())
    
    #拼接
    result_data = np.concatenate(result_data, axis=0)#(76090, 1, 40, 40)
    result_inds = np.concatenate(result_inds, axis=0)#(76090,4)
    print("result_data shape",result_data.shape)
    print("result_inds shape",result_inds.shape)


    deep_hics = together(result_data, result_inds, tag='Reconstructing: ')
    #返回的是字典
    return deep_hics



def save_data(deep_hic, compact, size, file):
    #TODO:关键是这里
    deephic = spreadM(deep_hic, compact, size, convert_int=False, verbose=True)
    np.savez_compressed(file, deephic=deephic, compact=compact)
    print('Saving file:', file)



In [4]:
# python data_predict.py -lr 40kb -ckpt save/generator_nonpool_deephic.pytorch -c GM12878
cell_line = "GM12878"
low_res = "40kb"

#这里是可以改的
ckpt_file = "save/deephic_raw_16.pth"
res_num = 5
# cuda = args.cuda
print('WARNING: Predict process needs large memory, thus ensure that your machine have ~150G memory.')
if multiprocessing.cpu_count() > 23:
    pool_num = 23
else:
    exit()



In [5]:
in_dir = os.path.join(root_dir, 'data')
out_dir = os.path.join(root_dir, 'predict', cell_line)
mkdir(out_dir)

files = [f for f in os.listdir(in_dir) if f.find(low_res) >= 0]
deephic_file = [f for f in files if f.find(cell_line.lower()+'.npz') >= 0][0]

In [6]:
deephic_file

'deephic_10kb40kb_c40_s40_b201_nonpool_gm12878.npz'

In [7]:
chunk, stride, bound, scale = filename_parser(deephic_file)

device = torch.device(f'cuda:{cuda}' if (torch.cuda.is_available() and cuda>-1 and cuda<torch.cuda.device_count()) else 'cpu')
print(f'Using device: {device}')

Using device: cpu


In [8]:
start = time.time()
print(f'Loading data[DeepHiC]: {deephic_file}')
deephic_data = np.load(os.path.join(in_dir, deephic_file), allow_pickle=True)
deephic_loader = dataloader(deephic_data)            

Loading data[DeepHiC]: deephic_10kb40kb_c40_s40_b201_nonpool_gm12878.npz


In [9]:
indices, compacts, sizes = data_info(deephic_data)

In [14]:
#deephic_loader里是40*40的，返回的是完整的了
deep_hics = deephic_predictor(deephic_loader, ckpt_file, scale, res_num, device)

Loading DeepHiC checkpoint file from "save/deephic_raw_16.pth"


DeepHiC Predicting: 100%|██████████| 1189/1189 [05:42<00:00,  3.47it/s]


result_data shape (76090, 1, 40, 40)
result_inds shape (76090, 4)
Reconstructing:  data contain [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 'X'] chromosomes


In [19]:


pool = multiprocessing.Pool(processes=pool_num)
print(f'Start a multiprocess pool with process_num = {pool_num} for saving predicted data')

def save_data_n(key):
    file = os.path.join(out_dir, f'predict_chr{key}_{low_res}.npz')
    save_data(deep_hics[key], compacts[key], sizes[key], file)


for key in compacts.keys():
    pool.apply_async(save_data_n, (key,))

    
pool.close()
pool.join()
print(f'All data saved. Running cost is {(time.time()-start)/60:.1f} min.')

Process ForkPoolWorker-1:
Process ForkPoolWorker-3:
Process ForkPoolWorker-7:
Process ForkPoolWorker-12:
Process ForkPoolWorker-9:
Process ForkPoolWorker-11:
Process ForkPoolWorker-5:
Process ForkPoolWorker-4:
Process ForkPoolWorker-10:
Process ForkPoolWorker-2:
Process ForkPoolWorker-8:
Process ForkPoolWorker-6:
Process ForkPoolWorker-13:
Process ForkPoolWorker-14:
Process ForkPoolWorker-15:
Process ForkPoolWorker-20:
Process ForkPoolWorker-17:
Process ForkPoolWorker-18:
Process ForkPoolWorker-22:
Process ForkPoolWorker-16:
Process ForkPoolWorker-21:
Process ForkPoolWorker-19:
Process ForkPoolWorker-23:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Tra

Start a multiprocess pool with process_num = 23 for saving predicted data


In [None]:
data = np.load("/share/home/mliu/sc_sv/imputation/DeepHiC/data/RaoHiC/predict/GM12878/predict_chr1_40kb.npz", allow_pickle=True)