/
ufig_launcher.py
129 lines (98 loc) · 4.34 KB
/
ufig_launcher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# tf_unet is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# tf_unet is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with tf_unet. If not, see <http://www.gnu.org/licenses/>.
'''
Created on Jul 28, 2016
author: jakeret
Trains a tf_unet network to segment stars and galaxies in a wide-field image.
Requires an HDF5 file from a UFIG simulation that contains the datasets
``image``, ``segmaps/galaxy`` and ``segmaps/star``.
'''
from __future__ import print_function, division, absolute_import, unicode_literals
import click
import numpy as np
from scipy.ndimage import gaussian_filter
import h5py
from tf_unet import unet
from tf_unet import util
from tf_unet.image_util import BaseDataProvider
# Command-line entry point: train a tf_unet model on random tiles cut from a
# UFIG HDF5 file, then print the error rate on one tile sampled before training.
# (Documented with comments rather than a docstring because click exposes a
# command's docstring as its --help text.)
@click.command()
@click.option('--data_root', default="./ufig_images/1.h5")
@click.option('--output_path', default="./unet_trained_ufig")
@click.option('--training_iters', default=20)
@click.option('--epochs', default=10)
@click.option('--restore', default=False)
@click.option('--layers', default=3)
@click.option('--features_root', default=16)
def launch(data_root, output_path, training_iters, epochs, restore, layers, features_root):
    provider = DataProvider(572, data_root)  # serves 572 x 572 pixel tiles
    # One tile drawn up front; it is reused for the error report at the end.
    sample_x, sample_y = provider(1)

    # Class re-weighting stays disabled (an inverse class-frequency variant was
    # tried earlier); only the regulariser is passed to the cost.
    cost_kwargs = dict(regularizer=0.001, class_weights=None)
    net = unet.Unet(
        channels=provider.channels,
        n_class=provider.n_class,
        layers=layers,
        features_root=features_root,
        cost_kwargs=cost_kwargs,
    )

    # Restoring reuses output_path verbatim; a fresh run asks
    # util.create_training_path for its directory instead.
    if restore:
        run_dir = output_path
    else:
        run_dir = util.create_training_path(output_path)

    trainer = unet.Trainer(net, optimizer="adam", opt_kwargs=dict(beta1=0.91))
    model_path = trainer.train(
        provider,
        run_dir,
        training_iters=training_iters,
        epochs=epochs,
        dropout=0.5,
        display_step=2,
        restore=restore,
    )

    prediction = net.predict(model_path, sample_x)
    # Crop the labels to the prediction's shape, which may be spatially
    # smaller than the input tile.
    target = util.crop_to_shape(sample_y, prediction.shape)
    print("Testing error rate: {:.2f}%".format(unet.error_rate(prediction, target)))
class DataProvider(BaseDataProvider):
    """
    Extends the BaseDataProvider to randomly select the next
    chunk of the image and randomly applies transformations to the data.

    The whole image and its two segmentation maps are read from one HDF5
    file at construction time. Each sample is a random ``nx`` x ``nx`` tile
    with a three-class label: 0 = background, 1 = galaxy, 2 = star.

    :param nx: edge length, in pixels, of the square tiles that are served
    :param path: HDF5 file providing the datasets ``image``,
        ``segmaps/galaxy`` and ``segmaps/star``
    :param a_min: forwarded to ``BaseDataProvider`` (presumably the lower
        normalisation bound -- see tf_unet.image_util)
    :param a_max: forwarded to ``BaseDataProvider`` (presumably the upper
        normalisation bound -- see tf_unet.image_util)
    :param sigma: standard deviation of the Gaussian smoothing applied to
        the image right after loading
    """
    channels = 1  # a single input channel: the smoothed image
    n_class = 3   # label channels, in order: background, galaxy, star

    def __init__(self, nx, path, a_min=0, a_max=20, sigma=1):
        super(DataProvider, self).__init__(a_min, a_max)
        self.nx = nx
        self.path = path
        self.sigma = sigma
        self._load_data()

    def _load_data(self):
        """Read the smoothed image and both segmentation maps into memory."""
        with h5py.File(self.path, "r") as fp:
            # ``Dataset.value`` was removed in h5py 3.0; ``dataset[()]`` reads
            # the full dataset as a NumPy array on old and new h5py alike.
            self.image = gaussian_filter(fp["image"][()], self.sigma)
            self.gal_map = fp["segmaps/galaxy"][()]
            self.star_map = fp["segmaps/star"][()]

    def _transpose_3d(self, a):
        """Transpose each channel ``a[..., i]`` of a 3-D array independently."""
        return np.stack([a[..., i].T for i in range(a.shape[2])], axis=2)

    def _post_process(self, data, labels):
        """
        Randomly augment a tile together with its labels.

        With ``op = randint(0, 4)``: for op 1-3 both arrays are rotated by
        ``op`` quarter turns via ``np.rot90``; for op 0 they are transposed
        with probability 1/2 and otherwise returned unchanged. The transpose
        branch appends a trailing axis to ``data`` (assumes ``data`` arrives
        2-D -- TODO confirm against BaseDataProvider).
        """
        op = np.random.randint(0, 4)
        if op == 0:
            if np.random.randint(0, 2) == 0:
                data = self._transpose_3d(data[:, :, np.newaxis])
                labels = self._transpose_3d(labels)
        else:
            data, labels = np.rot90(data, op), np.rot90(labels, op)
        return data, labels

    def _next_data(self):
        """
        Cut a random ``nx`` x ``nx`` tile and build its three-class label.

        :returns: ``(data, labels)``: the image tile and a float32 label
            array of shape ``(nx, nx, n_class)``
        :raises ValueError: if the image is smaller than ``nx`` along either
            of its first two axes
        """
        max_x = self.image.shape[0] - self.nx
        max_y = self.image.shape[1] - self.nx
        if max_x < 0 or max_y < 0:
            raise ValueError(
                "image of shape {} is smaller than the tile size {}".format(
                    self.image.shape, self.nx))
        # randint's upper bound is exclusive: +1 makes the last valid offset
        # reachable (and an image exactly nx wide usable at all).
        ix = np.random.randint(0, max_x + 1)
        iy = np.random.randint(0, max_y + 1)
        rows = slice(ix, ix + self.nx)
        cols = slice(iy, iy + self.nx)

        data = self.image[rows, cols]
        labels = np.zeros((self.nx, self.nx, self.n_class), dtype=np.float32)
        # Map values are clipped into [0, 1] (anything >= 1 becomes 1). A pixel
        # present in both maps is flagged as galaxy and star at the same time.
        labels[..., 1] = np.clip(self.gal_map[rows, cols], 0, 1)
        labels[..., 2] = np.clip(self.star_map[rows, cols], 0, 1)
        # Background is the complement of the clipped object coverage. This
        # matches the former ``(1 + c) % 2`` for 0/1 maps but stays in [0, 1]
        # for fractional map values.
        labels[..., 0] = 1 - np.clip(labels[..., 1] + labels[..., 2], 0, 1)
        return data, labels
if __name__ == "__main__":
    # Hand control to click, which parses sys.argv and runs the command.
    launch()