From e1c7adf5ce8c0fb1b3b3ebd728f9bbf7bc1355f4 Mon Sep 17 00:00:00 2001
From: Maxat Kulmanov
Date: Wed, 16 Aug 2017 08:37:55 +0300
Subject: [PATCH] .

---
 blast.py                   | 13 +++++----
 evaluation.py              | 10 +++++--
 nn_hierarchical_network.py | 10 ++++---
 plots.py                   | 56 ++++++++++++++++++++++++++++++++++++--
 4 files changed, 75 insertions(+), 14 deletions(-)

diff --git a/blast.py b/blast.py
index 4d2c3f7..c0b8189 100755
--- a/blast.py
+++ b/blast.py
@@ -81,15 +81,18 @@ def compute_performance(func):
 
 
 def convert(function):
-    df = pd.read_pickle(DATA_ROOT + 'swissprot_exp.pkl')
-    f1 = open(DATA_ROOT + 'swissprot_exp.fa', 'w')
+    df = pd.read_pickle('data/' + 'sequence_embeddings.pkl')
+    f1 = open(DATA_ROOT + 'embeddings.fa', 'w')
     # f2 = open(DATA_ROOT + 'test-missing.fa', 'w')
-
+    seqs = set()
     for i, row in df.iterrows():
         # missing = np.sum(row['embeddings']) == 0
         # if not missing:
-        f1.write('>' + row['proteins'] + '\n')
-        f1.write(to_fasta(str(row['sequences'])))
+        seq = row['sequences']
+        if seq not in seqs:
+            seqs.add(seq)
+            f1.write('>' + row['accessions'] + '\n')
+            f1.write(to_fasta(str(seq)))
         #else:
         #    f2.write('>' + row['proteins'] + '\n')
         #    f2.write(to_fasta(str(row['sequences'])))
diff --git a/evaluation.py b/evaluation.py
index 846e282..00b570c 100755
--- a/evaluation.py
+++ b/evaluation.py
@@ -42,6 +42,7 @@ def main(function):
             if target_id not in preds_dict:
                 preds_dict[target_id] = list()
             preds_dict[target_id].append((it[1], float(it[3])))
+    print(len(preds_dict))
     target_ids = list()
     predictions = list()
     for key, val in preds_dict.iteritems():
@@ -103,10 +104,13 @@ def reshape(values):
             for g_id in go_set:
                 if g_id not in scores_dict or scores_dict[g_id] < score:
                     scores_dict[g_id] = score
-        all_preds = set(scores_dict)
+        all_preds = set(scores_dict)  # | all_gos
         all_preds.discard(GO_ID)
         for go_id in all_preds:
-            scores.append(scores_dict[go_id])
+            if go_id in scores_dict:
+                scores.append(scores_dict[go_id])
+            else:
+                scores.append(0)
             if go_id in all_gos:
                 labels.append(1)
             else:
@@ -161,7 +165,7 @@ def compute_performance(preds, labels, gos):
             all_gos.discard(GO_ID)
             for val in preds[i]:
                 go_id, score = val
-                if score > threshold and go_id in go:
+                if score > threshold and go_id in all_functions:
                     all_preds |= get_anchestors(go, go_id)
             all_preds.discard(GO_ID)
             predictions.append(all_preds)
diff --git a/nn_hierarchical_network.py b/nn_hierarchical_network.py
index 77d7714..4d59e4d 100644
--- a/nn_hierarchical_network.py
+++ b/nn_hierarchical_network.py
@@ -118,10 +118,12 @@ def load_data(org=None):
     df = pd.read_pickle(DATA_ROOT + 'train' + '-' + FUNCTION + '.pkl')
     n = len(df)
     index = df.index.values
-    valid_n = int(n * 0.8)
-    train_df = df.loc[index[:valid_n]]
-    valid_df = df.loc[index[valid_n:]]
+    # valid_n = int(n * 0.8)
+    # train_df = df.loc[index[:valid_n]]
+    # valid_df = df.loc[index[valid_n:]]
+    train_df = df
     test_df = pd.read_pickle(DATA_ROOT + 'test' + '-' + FUNCTION + '.pkl')
+    valid_df = test_df
     # test_df = pd.read_pickle(DATA_ROOT + 'targets.pkl')
     if org is not None:
         logging.info('Unfiltered test size: %d' % len(test_df))
@@ -346,7 +348,7 @@ def model(params, batch_size=128, nb_epoch=6, is_train=True):
     logging.info("Validation data size: %d" % len(val_data[0]))
     logging.info("Test data size: %d" % len(test_data[0]))
 
-    model_path = (DATA_ROOT + 'models/model_' + FUNCTION + '.h5')
+    model_path = (DATA_ROOT + 'models/model_all_' + FUNCTION + '.h5')
     # '-' + str(params['embedding_dims']) +
     # '-' + str(params['nb_filter']) + '.h5')
     checkpointer = ModelCheckpoint(
diff --git a/plots.py b/plots.py
index f89a5d4..bac1d7b 100755
--- a/plots.py
+++ b/plots.py
@@ -6,6 +6,8 @@
 import pandas as pd
 import click as ck
 from sklearn.metrics import roc_curve, auc
+import matplotlib as mpl
+mpl.use('Agg')
 from matplotlib import pyplot as plt
 import math
 from utils import (
@@ -14,17 +16,67 @@
     MOLECULAR_FUNCTION,
     CELLULAR_COMPONENT,
     get_ipro,
+    EXP_CODES
 )
 
-DATA_ROOT = 'data/swissexp/'
+DATA_ROOT = 'data/swiss/'
 
 
 @ck.command()
 def main():
     # x, y = get_data('cc.res')
     # plot(x, y)
-    ipro_table()
+    # ipro_table()
+    plot_sequence_stats()
 
 
+def read_fasta(filename):
+    data = list()
+    c = 0
+    with open(filename, 'r') as f:
+        seq = ''
+        for line in f:
+            line = line.strip()
+            if line.startswith('>'):
+                if seq != '':
+                    data.append(seq)
+                line = line[1:].split()[0].split('|')
+                line = line[1] + '\t' + line[2]
+                seq = line + '\t'
+            else:
+                seq += line
+        data.append(seq)
+    return data
+
+def plot_sequence_stats():
+    df = pd.read_pickle('data/swissprot.pkl')
+    index = list()
+    for i, row in df.iterrows():
+        ok = False
+        for it in row['annots']:
+            it = it.split('|')
+            if it[1] in EXP_CODES:
+                ok = True
+        if ok:
+            index.append(i)
+    df = df.iloc[index]
+    print(len(df))
+    lens = map(len, df['sequences'])
+    c = 0
+    for i in lens:
+        if i <= 1002:
+            c += 1
+    print(c)
+    h = np.histogram(lens, bins=(
+        0, 500, 1000, 1500, 2000, 40000))
+    plt.bar(range(5),
+            h[0], width=1, facecolor='green')
+    titles = ['<=500', '<=1000', '<=1500', '<=2000', '>2000']
+    plt.xticks(np.arange(0.5, 5.5, 1), titles)
+    plt.xlabel('Sequence length')
+    plt.ylabel('Sequence number')
+    plt.title(r'Sequence length distribution')
+    plt.savefig('sequence-dist.eps')
+    print(np.max(lens))
 def table():
     bp = get_data('bp.res')