In [1]:
import numpy as np
import pandas as pd
import torch
import utils
import preprocess
import data
import os
from tqdm import tqdm

In [5]:
# Run identifiers. record_pattern is derived from seq_len / nfeature so the tag used
# in every file name and the feature count used by the prediction loop stay in sync.
seq_len = 1024  # the 'seq' value embedded in record_pattern (presumably input length in bp — confirm)
nfeature = 467  # number of model output features
record_pattern = 'sei_seq{}_nip_feature{}'.format(seq_len, nfeature)

# prediction
import sei_model as sei

prediction_dir = '../cross/taes'
# FASTA files in prediction_dir that do not yet have a saved prediction
# (<fasta>_pm_<record_pattern>.npy, the name the prediction loop writes).
# The directory is listed once instead of once per candidate file.
dir_entries = os.listdir(prediction_dir)
existing_entries = set(dir_entries)
prediction_list = [
    fa for fa in dir_entries
    if fa.endswith('.fa') and '{}_pm_{}.npy'.format(fa, record_pattern) not in existing_entries
]

## Note: this is an independent param
model_dir = '../model/{}.model'.format(record_pattern)

## Restart variables which will be cleared
batch_size = 512
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cpu = torch.device('cpu')

In [6]:
# FASTA files still pending prediction (no saved *_pm_<record_pattern>.npy next to them yet)
prediction_list

['taes_1k_128s_filtered_chr4D2.fa',
 'taes_1k_128s_filtered_chr4D3.fa',
 'taes_1k_128s_filtered_chr5A0.fa',
 'taes_1k_128s_filtered_chr5A1.fa',
 'taes_1k_128s_filtered_chr5A2.fa',
 'taes_1k_128s_filtered_chr5A3.fa',
 'taes_1k_128s_filtered_chr5B0.fa',
 'taes_1k_128s_filtered_chr5B1.fa',
 'taes_1k_128s_filtered_chr5B2.fa',
 'taes_1k_128s_filtered_chr2A1.fa',
 'taes_1k_128s_filtered_chr3A3.fa',
 'taes_1k_128s_filtered_chr4B1.fa',
 'taes_1k_128s_filtered_chr5B3.fa',
 'taes_1k_128s_filtered_chr6D1.fa',
 'taes_1k_128s_filtered_chr5D0.fa',
 'taes_1k_128s_filtered_chr5D1.fa',
 'taes_1k_128s_filtered_chr5D2.fa',
 'taes_1k_128s_filtered_chr5D3.fa',
 'taes_1k_128s_filtered_chr6A0.fa',
 'taes_1k_128s_filtered_chr6A1.fa',
 'taes_1k_128s_filtered_chr6A2.fa',
 'taes_1k_128s_filtered_chr6A3.fa',
 'taes_1k_128s_filtered_chr6B0.fa',
 'taes_1k_128s_filtered_chr6B1.fa',
 'taes_1k_128s_filtered_chr6B2.fa',
 'taes_1k_128s_filtered_chr6B3.fa',
 'taes_1k_128s_filtered_chr6D0.fa',
 'taes_1k_128s_filtered_chr6

In [3]:
import collections

def sc_projection(chromatin_profile_preds, clustervfeat):
    """Project chromatin-profile predictions onto cluster profiles.

    Column ``j`` of the result is the dot product of each prediction row with
    row ``j`` of ``clustervfeat``.
    """
    profiles_transposed = clustervfeat.T
    return np.dot(chromatin_profile_preds, profiles_transposed)

# Cluster-level resources used by the projection / statistics code below.
ordered_profiles = np.load('../prediction/{}.cluster_ordered_profiles.npy'.format(record_pattern))

leiden_wgt = np.load('../visualization/{}.leiden.npy'.format(record_pattern))
# NOTE(review): allow_pickle=True unpickles the file — only load trusted outputs.
cluster_def = np.load('../visualization/{}.cluster_def.npy'.format(record_pattern), allow_pickle=True).item()

# One category label per Leiden cluster id, in ascending id order; clusters without
# an entry in cluster_def['category'] are skipped.
# NOTE(review): statis_scores pairs label i with score column i, i.e. row i of
# ordered_profiles. Skipping any cluster other than trailing ones would shift that
# pairing — confirm every cluster is categorised.
cluster_categories = cluster_def['category']
category_labels = []
for cluster_id in sorted(np.unique(leiden_wgt)):
    cluster_key = 'cluster' + str(cluster_id)
    if cluster_key in cluster_categories:
        category_labels.append(cluster_categories[cluster_key])
ordered_category = np.array(category_labels)


def statis_scores(ordered_category, total_pre_scores):
    """Sum per-cluster scores into per-category scores for every sequence.

    Parameters
    ----------
    ordered_category : sequence of str
        Category label for each score column; entry ``i`` labels column ``i``
        of each row of ``total_pre_scores``.
    total_pre_scores : iterable of 1-D numpy arrays (e.g. a 2-D array)
        One row of cluster scores per sequence.

    Returns
    -------
    result : list of dict
        For each row, ``{category: summed score}``, with categories in order of
        first appearance in ``ordered_category``.
    max_counter : collections.Counter
        How many rows had each category as their highest summed score. Ties go
        to the category that appears first in ``ordered_category``. Rows with no
        comparable value (no categories, or all sums NaN) are counted under ``''``.
    """
    # Column indices belonging to each category, in first-appearance order.
    asc_ind = {}
    for i, label in enumerate(ordered_category):
        asc_ind.setdefault(label, []).append(i)

    result = []
    max_result = []
    for rec in total_pre_scores:
        tmp = {key: np.sum(rec[idx]) for key, idx in asc_ind.items()}
        # Start below every real value: starting at 0 meant rows whose category
        # sums were all zero or negative never got a label and fell into ''.
        max_key = ''
        max_value = -np.inf
        for key, value in tmp.items():
            if value > max_value:
                max_key = key
                max_value = value
        result.append(tmp)
        max_result.append(max_key)

    return result, collections.Counter(max_result)

In [None]:
# Run the model over every pending FASTA file, save raw predictions and cluster
# scores next to the input, and keep per-file summaries in predict_dict.
# NOTE(review): torch.load unpickles the whole model object — only load trusted checkpoints.
# map_location puts the weights on the same device the inputs are sent to below, so a
# checkpoint saved on GPU also loads on a CPU-only machine (and vice versa).
model = torch.load(model_dir, map_location=device)
model = model.to(device)
model.eval()

predict_dict = {}
for prediction_fa in prediction_list:
    pre_seqs = utils.load_data(os.path.join(prediction_dir, prediction_fa))
    print('Total {} seqs from {}'.format(len(pre_seqs), prediction_fa))

    pre_nuc_pre = preprocess.NucPreprocess(pre_seqs)
    pre_X_all = pre_nuc_pre.onehot_for_nuc()
    print('Encoding done for {}'.format(prediction_fa))

    pre_dataset = data.NucDataset(x=pre_X_all, y=[0]*len(pre_X_all))
    pre_loader = torch.utils.data.DataLoader(dataset=pre_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    # Collect per-batch outputs and concatenate once; concatenating inside the loop
    # re-copies everything accumulated so far on every batch.
    batch_preds = []
    with torch.no_grad():
        print('Start prediction, total: ', len(pre_X_all))
        for inputs, _ in tqdm(pre_loader):
            inputs = inputs.to(device, dtype=torch.float)
            inputs = inputs.permute(0, 2, 1)
            outputs = model(inputs)
            # Flatten everything except the batch axis. A bare squeeze() also drops the
            # batch axis when the last batch holds a single sequence, which breaks the concat.
            outputs = outputs.reshape(outputs.shape[0], -1)
            batch_preds.append(outputs.to(cpu, dtype=torch.float))

    if batch_preds:
        total_pre = torch.cat(batch_preds).numpy()
    else:
        # Empty FASTA: keep a well-formed (0, nfeature) result.
        total_pre = np.zeros((0, nfeature), dtype=np.float32)
    np.save('{}/{}_pm_{}.npy'.format(prediction_dir, prediction_fa, record_pattern), total_pre)

    # statis
    total_pre_scores = sc_projection(total_pre, ordered_profiles)
    np.save('{}/{}_cluster_{}.npy'.format(prediction_dir, prediction_fa, record_pattern), total_pre_scores)

    result, max_result = statis_scores(ordered_category, total_pre_scores)

    predict_dict[prediction_fa] = {'total_pre': total_pre, 'sta_res': result, 'sta_max': max_result}

In [133]:
# One line per FASTA file: its name and the Counter of highest-scoring categories.
for fa_name, summary in predict_dict.items():
    print(fa_name, summary['sta_max'])

random_n100k_w1k_noN.fa Counter({'Heterochromatin': 67451, 'Enhancer': 15361, 'Repressed Polycomb': 7034, 'Transcription': 3409, 'Bivalent TSS': 164})
repeat_n100k_w1k_noN.fa Counter({'Heterochromatin': 73280, 'Enhancer': 15572, 'Repressed Polycomb': 5087, 'Transcription': 1974, 'Bivalent TSS': 90})
seedlings_H3K27me3.fa Counter({'Heterochromatin': 49088, 'Repressed Polycomb': 35523, 'Transcription': 19679, 'Enhancer': 15668, 'Bivalent TSS': 473})
seedlings_H3K36me3.fa Counter({'Enhancer': 29866, 'Transcription': 28678, 'Heterochromatin': 10230, 'Repressed Polycomb': 2221, 'Bivalent TSS': 26})
seedlings_H3K4me3.fa Counter({'Transcription': 21446, 'Repressed Polycomb': 9445, 'Heterochromatin': 6367, 'Enhancer': 3238, 'Bivalent TSS': 5})
seedlings_H3K9ac.fa Counter({'Transcription': 33236, 'Enhancer': 8427, 'Heterochromatin': 7630, 'Repressed Polycomb': 6252, 'Bivalent TSS': 428})


In [130]:
# Write the highest-scoring category of every sequence in one dataset, one label per line.
# NOTE(review): predict_dict is only filled from prediction_list in this notebook run;
# this key comes from an earlier configuration — confirm it is present before running.
target_fa = 'random_n100k_w1k_noN.fa'
# max() keeps the first category on ties, matching statis_scores
# (the previous sorted(...)[-1] picked the last tied category instead).
max_term = [max(rec, key=rec.get) for rec in predict_dict[target_fa]['sta_res']]

# Output name derived from the dataset key instead of hard-coding directory and name again.
max_label_path = os.path.join(prediction_dir, os.path.splitext(target_fa)[0] + '_predMaxLabel.txt')
with open(max_label_path, 'w') as f:
    f.writelines(label + '\n' for label in max_term)