-
Notifications
You must be signed in to change notification settings - Fork 0
/
graph_training_set.py
76 lines (43 loc) · 1.92 KB
/
graph_training_set.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import numpy as np
import tensorflow as tf
import input_data
from cnn_backwork import build_model
from view_results import proc_all_results
from utils import get_data_label
def get_permuted_train_data(data, num_classes=10):
    """Return one independently shuffled example array per class label.

    Args:
        data: dataset object passed unchanged to ``get_data_label`` (the
            script entry point passes the result of
            ``input_data.read_data_sets``).
        num_classes: how many labels to collect, ``0 .. num_classes - 1``.
            Defaults to 10, the value that used to be hard-coded.

    Returns:
        A list whose ``i``-th element is
        ``np.random.permutation(get_data_label(data, i))`` -- presumably the
        examples of label ``i`` in random order (TODO confirm against
        ``utils.get_data_label``). For an array argument,
        ``np.random.permutation`` shuffles a copy along the first axis.
    """
    return [np.random.permutation(get_data_label(data, label))
            for label in range(num_classes)]
def prop_forward(data, model_options, use_worst):
    """Evaluate each label's examples on a freshly built model.

    Args:
        data: sequence of per-label example arrays; position ``i`` is treated
            as the examples of label ``i``. Each array is fed directly into
            the model's input placeholder (presumably flattened images --
            TODO confirm against ``build_model``).
        model_options: options passed to ``build_model``; its
            ``'num_examples'`` entry is the number of examples kept per label.
        use_worst: when False, only the first ``num_examples`` examples of each
            label are evaluated and kept. When True, every example is
            evaluated and the ``num_examples`` with the smallest output in
            column ``label`` are kept.

    Returns:
        ``(res, new_data)`` where ``res[i]`` is the last layer's output for
        the kept examples of label ``i`` (presumably softmax probabilities)
        and ``new_data[i]`` holds those same examples in the same row order.
    """
    _, x_input, _, model_layers = build_model(model_options)
    output_layer = model_layers[-1]
    keep = model_options['num_examples']
    batches = data if use_worst else [batch[:keep] for batch in data]

    kept_probs = []
    kept_examples = []
    # Entering the session makes it the default one, which `.eval` relies on.
    with tf.Session():
        tf.initialize_all_variables().run()
        for label, batch in enumerate(batches):
            probs = output_layer.eval(feed_dict={x_input: batch})
            if use_worst:
                # Ascending argsort on the true label's column puts the
                # least-confident ("worst") examples first.
                worst_idx = np.argsort(probs[:, label])[:keep]
                probs, batch = probs[worst_idx], batch[worst_idx]
            kept_examples.append(batch)
            kept_probs.append(probs)
    return kept_probs, kept_examples
def graph_subset_train_data(data, model_options, save=False, use_worst=False):
    """Pick a per-label subset of training data and pass it, with model outputs, to ``proc_all_results``.

    Args:
        data: dataset object forwarded to ``get_permuted_train_data``.
        model_options: options dict. Reads ``'image_dim_size'``,
            ``'num_examples'`` and, when ``save`` is true, ``'fp_save'``. It is
            also forwarded to ``prop_forward``.
        save: when true, ``model_options.get('fp_save')`` is passed as
            ``save_filename``; when false, ``None`` is passed. (Previously this
            flag was ignored and ``fp_save`` was always used.)
        use_worst: forwarded to ``prop_forward``. Selects the examples with the
            smallest output for their own label instead of the first ones.
    """
    image_dim_size = model_options['image_dim_size']
    data_subset = get_permuted_train_data(data)
    probs, data_subset = prop_forward(data_subset, model_options, use_worst)
    # Give each label's examples shape (-1, image_dim_size, image_dim_size).
    data_subset = [d.reshape(-1, image_dim_size, image_dim_size)
                   for d in data_subset]
    res = list(zip(data_subset, probs))
    # Honour the `save` flag. Presumably proc_all_results skips writing a
    # file when save_filename is None -- TODO confirm in view_results.
    fname = model_options.get('fp_save') if save else None
    proc_all_results(model_options['num_examples'], save_filename=fname, res=res)
if __name__ == '__main__':
    # Script entry: load the default options, override a few of them, and
    # save the worst-classified MNIST training examples.
    from cnn_backwork import model_options

    overrides = (
        ('num_examples', 8),
        ('fp_params', 'params/params_norm3.pkl'),
        ('fp_save', 'worst_train_data'),
    )
    for option_key, option_value in overrides:
        model_options[option_key] = option_value

    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
    graph_subset_train_data(mnist, model_options, save=True, use_worst=True)