In [1]:
import os
# Expose only the GPU with index 3 to this kernel. This is set before torch is
# imported further down so the restriction is in place when CUDA is first used.
os.environ.update(CUDA_VISIBLE_DEVICES="3")

In [3]:
import sys
sys.path.append("../2_train_models")
from file_configs import FoldFilesConfig

from data_loading import extract_peaks, extract_observed_profiles
#from performance_metrics import compute_performance_metrics
#from plot_utils import get_continuous_cmap

import matplotlib.pyplot as plt
#import matplotlib.colors as colors
#import matplotlib.cm as cmx
import numpy as np
#import pandas as pd
import torch
#from collections import defaultdict

In [4]:
# Timestamp string handed to FoldFilesConfig for each cell type (one entry per
# cell type whose model is evaluated below).
timestamps = {
    "K562": "2023-05-29_15-51-40",
    "A673": "2023-06-11_20-11-32",
    "CACO2": "2023-06-12_21-46-40",
    "CALU3": "2023-06-14_00-43-44",
    "HUVEC": "2023-06-16_21-59-35",
    "MCF10A": "2023-06-15_06-07-40",
}

# Cell types to process, in the insertion order of the table above.
cell_types = list(timestamps)

model_type, data_type = "strand_merged_umap", "procap"

# Window sizes passed to extract_peaks as in_window / out_window.
in_window, out_window = 2114, 1000

In [8]:
# The project directory is read off the config of the first cell type
# (presumably it is the same for every cell type -- TODO confirm).
proj_dir = FoldFilesConfig(
    cell_types[0],
    model_type,
    "1",
    timestamps[cell_types[0]],
    data_type,
).proj_dir

# Peak file that every cell type's model is evaluated on further down.
union_peaks_path = f"{proj_dir}data/{data_type}/processed/union_peaks_fold1_val.bed.gz"

Timestamp: 2023-05-29_15-51-40


In [9]:
# List the union-peaks file to confirm the path built above resolves on disk.
! ls $union_peaks_path

/mnt/lab_data2/kcochran/procapnet/data/procap/processed/union_peaks_fold1_val.bed.gz


In [12]:
def _predict(model, onehot_seqs, batch_size=64, logits = False):
    """Run `model` over `onehot_seqs` in mini-batches and return NumPy outputs.

    Args:
        model: callable returning a `(profiles, counts)` pair of tensors for a
            batch. It must provide a `log_softmax` method when `logits` is False.
        onehot_seqs: torch tensor whose first axis indexes examples.
        batch_size: number of examples per forward pass.
        logits: if True, return the raw profile outputs; otherwise (default)
            pass them through `model.log_softmax` first.

    Returns:
        `(y_profiles, y_counts)`: the per-batch outputs moved to the CPU,
        converted to NumPy and concatenated along the first axis.

    Raises:
        ValueError: if `onehot_seqs` contains no examples (there is nothing
            for `np.concatenate` to join).
    """
    # Send batches to wherever the model's weights live rather than always
    # calling .cuda(), so a CPU model (or one on a non-default GPU) also works.
    # Models without parameters keep the old behaviour when a GPU is present.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): the model's train/eval mode is left untouched here; call
    # model.eval() beforehand if it contains dropout or batch-norm layers.
    n_examples = onehot_seqs.shape[0]
    y_profiles, y_counts = [], []
    with torch.no_grad():
        for start in range(0, n_examples, batch_size):
            # Slicing past the end is safe: the last batch is simply shorter.
            X_batch = onehot_seqs[start:start + batch_size].to(device)

            y_profiles_, y_counts_ = model(X_batch)
            if not logits:  # convert profile logits to log-probabilities
                y_profiles_ = model.log_softmax(y_profiles_)
            y_profiles.append(y_profiles_.cpu().detach().numpy())
            y_counts.append(y_counts_.cpu().detach().numpy())

    return np.concatenate(y_profiles), np.concatenate(y_counts)

    
def predict_union_peaks_all_cells(cell_types, timestamps, model_type, data_type,
                                  peaks_path=None, seq_len=None, profile_len=None):
    """Predict on one shared peak set with each cell type's saved model.

    For every cell type this builds a FoldFilesConfig, loads the model saved at
    `config.model_save_path`, extracts sequences and observed profiles for the
    peak file with that cell type's bigwigs, and runs `_predict`.

    Args:
        cell_types: iterable of cell-type names to process.
        timestamps: mapping from cell-type name to the timestamp string handed
            to FoldFilesConfig.
        model_type: model-type string handed to FoldFilesConfig.
        data_type: data-type string handed to FoldFilesConfig.
        peaks_path: peak file to evaluate on. Defaults to the notebook-level
            `union_peaks_path`.
        seq_len: `in_window` passed to extract_peaks. Defaults to the
            notebook-level `in_window`.
        profile_len: `out_window` passed to extract_peaks. Defaults to the
            notebook-level `out_window`.

    Returns:
        `(pred_logcounts, true_counts, pred_profiles, true_profiles)`: four
        dicts keyed by cell type. Predicted profiles are the log-softmax output
        of `_predict` (its default `logits=False`); true counts are the observed
        profiles summed over their last axis.
    """
    # Explicit arguments win; otherwise fall back to the notebook globals that
    # this function previously read implicitly.
    if peaks_path is None:
        peaks_path = union_peaks_path
    if seq_len is None:
        seq_len = in_window
    if profile_len is None:
        profile_len = out_window

    pred_logcounts = dict()
    pred_profiles = dict()
    true_counts = dict()
    true_profiles = dict()
    for cell_type in cell_types:
        print(cell_type)
        # The "1" is presumably the fold index (the peak file name also says
        # fold1) -- TODO confirm against FoldFilesConfig.
        config = FoldFilesConfig(cell_type, model_type, "1", timestamps[cell_type], data_type)
        # NOTE(review): torch.load unpickles the checkpoint, which can execute
        # arbitrary code -- only load model files from a trusted source.
        model = torch.load(config.model_save_path).cuda()

        # max_jitter=0 so the extracted windows are not randomly shifted.
        onehot_seqs, true_profs = extract_peaks(config.genome_path,
                                                config.chrom_sizes,
                                                config.plus_bw_path,
                                                config.minus_bw_path,
                                                peaks_path,
                                                in_window=seq_len,
                                                out_window=profile_len,
                                                max_jitter=0, verbose=True)

        pred_profs, pred_logcts = _predict(model, torch.tensor(onehot_seqs, dtype=torch.float))
        pred_logcounts[cell_type] = pred_logcts
        pred_profiles[cell_type] = pred_profs
        true_counts[cell_type] = true_profs.sum(axis=-1)
        true_profiles[cell_type] = true_profs

        # Drop the reference so this model can be freed before the next cell
        # type's model is loaded onto the GPU.
        del model
    return pred_logcounts, true_counts, pred_profiles, true_profiles


# Run every cell type's model on the union peak set. Note the return order:
# predicted log-counts, true counts, predicted profiles, true profiles.
(
    union_pred_logcounts,
    union_true_counts,
    union_pred_profiles,
    union_true_profiles,
) = predict_union_peaks_all_cells(cell_types, timestamps, model_type, data_type)

K562
Timestamp: 2023-05-29_15-51-40
Loading genome sequence from /mnt/lab_data2/kcochran/procapnet/genomes/hg38.withrDNA.fasta


Reading FASTA: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:09<00:00,  2.53it/s]
Loading Peaks: 14269it [00:25, 566.47it/s]


== In Extract Peaks ==
Peak filepath: /mnt/lab_data2/kcochran/procapnet/data/procap/processed/union_peaks_fold1_val.bed.gz
Sequence length (with jitter): 2114
Profile length (with jitter): 1000
Max jitter applied: 0
Num. Examples: 14269
Mask loaded? False
A673
Timestamp: 2023-06-11_20-11-32
Loading genome sequence from /mnt/lab_data2/kcochran/procapnet/genomes/hg38.withrDNA.fasta


Reading FASTA: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:09<00:00,  2.51it/s]
Loading Peaks: 14269it [00:24, 577.09it/s]


== In Extract Peaks ==
Peak filepath: /mnt/lab_data2/kcochran/procapnet/data/procap/processed/union_peaks_fold1_val.bed.gz
Sequence length (with jitter): 2114
Profile length (with jitter): 1000
Max jitter applied: 0
Num. Examples: 14269
Mask loaded? False
CACO2
Timestamp: 2023-06-12_21-46-40
Loading genome sequence from /mnt/lab_data2/kcochran/procapnet/genomes/hg38.withrDNA.fasta


Reading FASTA: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:09<00:00,  2.57it/s]
Loading Peaks: 14269it [00:24, 574.88it/s]


== In Extract Peaks ==
Peak filepath: /mnt/lab_data2/kcochran/procapnet/data/procap/processed/union_peaks_fold1_val.bed.gz
Sequence length (with jitter): 2114
Profile length (with jitter): 1000
Max jitter applied: 0
Num. Examples: 14269
Mask loaded? False
CALU3
Timestamp: 2023-06-14_00-43-44
Loading genome sequence from /mnt/lab_data2/kcochran/procapnet/genomes/hg38.withrDNA.fasta


Reading FASTA: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:09<00:00,  2.53it/s]
Loading Peaks: 14269it [00:25, 570.27it/s]


== In Extract Peaks ==
Peak filepath: /mnt/lab_data2/kcochran/procapnet/data/procap/processed/union_peaks_fold1_val.bed.gz
Sequence length (with jitter): 2114
Profile length (with jitter): 1000
Max jitter applied: 0
Num. Examples: 14269
Mask loaded? False
HUVEC
Timestamp: 2023-06-16_21-59-35
Loading genome sequence from /mnt/lab_data2/kcochran/procapnet/genomes/hg38.withrDNA.fasta


Reading FASTA: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:09<00:00,  2.55it/s]
Loading Peaks: 14269it [00:24, 578.37it/s]


== In Extract Peaks ==
Peak filepath: /mnt/lab_data2/kcochran/procapnet/data/procap/processed/union_peaks_fold1_val.bed.gz
Sequence length (with jitter): 2114
Profile length (with jitter): 1000
Max jitter applied: 0
Num. Examples: 14269
Mask loaded? False
MCF10A
Timestamp: 2023-06-15_06-07-40
Loading genome sequence from /mnt/lab_data2/kcochran/procapnet/genomes/hg38.withrDNA.fasta


Reading FASTA: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:09<00:00,  2.56it/s]
Loading Peaks: 14269it [00:24, 573.97it/s]


== In Extract Peaks ==
Peak filepath: /mnt/lab_data2/kcochran/procapnet/data/procap/processed/union_peaks_fold1_val.bed.gz
Sequence length (with jitter): 2114
Profile length (with jitter): 1000
Max jitter applied: 0
Num. Examples: 14269
Mask loaded? False


In [13]:
# Sanity check on the shape of the predicted log-counts for one cell type.
np.shape(union_pred_logcounts["K562"])

(14269, 1)

In [14]:
# Root folder for the exported predictions; each cell type gets a subfolder.
dest_path = "/mnt/lab_data2/kcochran/procap_data_for_melody/predicted/"

for cell_type in cell_types:
    pred_lgcts = union_pred_logcounts[cell_type]
    pred_profs = union_pred_profiles[cell_type]

    # np.save does not create missing parent directories, so make sure the
    # per-cell-type folder exists before writing into it.
    cell_dir = os.path.join(dest_path, cell_type)
    os.makedirs(cell_dir, exist_ok=True)

    np.save(os.path.join(cell_dir, "pred_logcounts.npy"), pred_lgcts)
    np.save(os.path.join(cell_dir, "pred_profiles.npy"), pred_profs)