-
Notifications
You must be signed in to change notification settings - Fork 4
/
general_functions.py
329 lines (269 loc) · 13.5 KB
/
general_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import functools
from torch.utils.data import DataLoader
from torch.utils.data import ConcatDataset
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from dataloader import cityscapes
from core.deeplabv3_plus import DeepLabv3_plus
from core.pspnet import PSPNet
from core.unet import UNet
from core.unet_paper import UNet_paper
from core.unet_pytorch import UNet_torch
from core.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
from core.sync_batchnorm.replicate import patch_replication_callback
def make_data_loader(args, split='train'):
    """
    Build a DataLoader over the CityScapes dataset for the requested split.

    Parameters:
        args (argparse.Namespace) -- command line arguments (mode, blank, resize,
                                     base_size, time_dilation, reconstruct,
                                     shuffle, batch_size)
        split (str)               -- one of: train | val | test | trainval | demoVideo

    Returns:
        torch.utils.data.DataLoader for the requested split. Train-style loaders
        use args.batch_size and shuffling; evaluation loaders use batch size 1.

    Raises:
        ValueError -- if an unknown split name is given (previously this fell
                      through and crashed with UnboundLocalError on `loader`).
    """
    def _build(subset):
        # Every split uses the same CityScapes constructor arguments.
        return cityscapes.CityScapes(args.mode, subset, args.blank, args.resize,
                                     args.base_size, args.time_dilation,
                                     args.reconstruct, args.shuffle)

    if split == 'train':
        dataset = _build('train')
        loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=8, shuffle=True, pin_memory=True)
    elif split in ('val', 'test', 'demoVideo'):
        # All evaluation-style splits share the same loader settings.
        dataset = _build(split)
        loader = DataLoader(dataset, batch_size=1, num_workers=8, shuffle=False, pin_memory=True)
    elif split == 'trainval':
        # Concatenate train and val for combined training.
        dataset = ConcatDataset([_build('train'), _build('val')])
        loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=8, shuffle=True, pin_memory=True)
    else:
        raise ValueError('Unknown split [%s]' % split)
    return loader
def get_model(args, num_classes=19):
    """
    Build the segmentation model selected by args.model and return it.

    Parameters:
        args (argparse.Namespace) -- command line arguments (model, norm_layer,
                                     init_type, gpu_ids, cuda, ...)
        num_classes (int)         -- number of output classes (CityScapes: 19)

    Returns:
        the constructed model, optionally wrapped in DataParallel and moved to GPU.

    Raises:
        NotImplementedError -- for an unrecognized model name.
    """
    norm_layer = get_norm_layer(args.norm_layer)
    if 'deeplab' in args.model:
        model = DeepLabv3_plus(args, num_classes=num_classes, norm_layer=norm_layer)
        print("Built ", args.model)
    elif 'unet' in args.model:
        if args.model == 'unet':
            model = UNet(num_classes=num_classes, args=args, norm_layer=norm_layer)
            print("Built UNet")
        elif args.model == 'unet_paper':
            model = UNet_paper(num_classes=num_classes, args=args, norm_layer=norm_layer)
            print("Built UNet paper")
        elif args.model == 'unet_pytorch':
            model = UNet_torch(num_classes=num_classes, args=args)
            print("Built UNet pytorch")
        else:
            # Previously an unknown 'unet*' variant left `model` unbound and
            # crashed with UnboundLocalError at init_model below.
            raise NotImplementedError
        # Only the UNet family gets explicit weight initialization here.
        model = init_model(model, args.init_type)
    elif 'pspnet' in args.model:
        model = PSPNet(num_classes=num_classes, args=args)
        print("Built PSPNet")
    else:
        raise NotImplementedError
    if args.gpu_ids:
        # Multi-GPU: replicate the model and patch SyncBN callbacks so
        # synchronized batch-norm statistics work across replicas.
        model = torch.nn.DataParallel(model, device_ids=args.gpu_ids)
        patch_replication_callback(model)
    if args.cuda:
        model = model.cuda()
    return model
def get_optimizer(model, args):
    """
    Build the optimizer for the model's trainable parameter groups.

    Parameters:
        model -- the network to be optimized; must expose get_train_parameters(lr)
        args  -- command line arguments (optim, lr, weight_decay, momentum, gpu_ids)

    Returns:
        a torch.optim.Optimizer over the model's parameter groups.

    Raises:
        ValueError -- if args.optim is not one of: sgd | adam | amsgrad
                      (previously this fell through and crashed with
                      UnboundLocalError on `optimizer`).
    """
    # When wrapped in DataParallel the real model lives on .module.
    if args.gpu_ids:
        train_params = model.module.get_train_parameters(args.lr)
    else:
        train_params = model.get_train_parameters(args.lr)
    if args.optim == 'sgd':
        optimizer = optim.SGD(train_params, lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum, nesterov=True)
    elif args.optim == 'adam':
        optimizer = optim.Adam(train_params, lr=args.lr, weight_decay=args.weight_decay)
    elif args.optim == 'amsgrad':
        optimizer = optim.Adam(train_params, lr=args.lr, weight_decay=args.weight_decay, amsgrad=True)
    else:
        raise ValueError('Unknown optimizer [%s]' % args.optim)
    return optimizer
def get_norm_layer(norm_type='instance'):
    """Return a normalization layer constructor for the given name.

    Parameters:
        norm_type (str) -- one of: batch | syncbn | instance | none

    For BatchNorm we use learnable affine parameters and track running
    statistics (mean/stddev). For InstanceNorm we use neither learnable
    affine parameters nor running statistics. 'none' returns None so the
    caller can skip normalization entirely.

    Raises:
        NotImplementedError -- for an unrecognized norm_type.
    """
    if norm_type == 'batch':
        return nn.BatchNorm2d
    if norm_type == 'syncbn':
        # Synchronized BN shares statistics across DataParallel replicas.
        return SynchronizedBatchNorm2d
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    if norm_type == 'none':
        return None
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
def init_model(net, init_type='normal', init_gain=0.02):
    """Move the network to the GPU and initialize its weights.

    Parameters:
        net (network)     -- the network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    Return an initialized network.
    """
    # NOTE(review): .cuda() is called unconditionally here, so this path
    # requires a CUDA device even though get_model() gates on args.cuda —
    # confirm this asymmetry is intended.
    net = net.cuda()
    init_weights(net, init_type, init_gain=init_gain)
    return net
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights in place via net.apply().

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    Raises:
        NotImplementedError -- if init_type is unknown (raised per-module
                               during net.apply when a Conv/Linear is visited).
    """
    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        # Conv*/Linear weight matrices: apply the chosen init scheme.
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                nn.init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_out', nonlinearity='leaky_relu')
            elif init_type == 'orthogonal':
                nn.init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)
        # Recurrent layers: Xavier for all weights, zeros for all biases.
        elif hasattr(m, '_all_weights') and (classname.find('LSTM') != -1 or classname.find('GRU') != -1):
            for names in m._all_weights:
                for name in filter(lambda n: "weight" in n, names):
                    weight = getattr(m, name)
                    nn.init.xavier_normal_(weight.data, gain=init_gain)
                for name in filter(lambda n: "bias" in n, names):
                    bias = getattr(m, name)
                    nn.init.constant_(bias.data, 0.0)
                    if classname.find('LSTM') != -1:
                        # Set the second quarter of the bias vector to 1 —
                        # the forget gate in PyTorch's (i, f, g, o) gate
                        # layout — a common trick to help gradient flow
                        # early in training.
                        n = bias.size(0)
                        start, end = n // 4, n // 2
                        nn.init.constant_(bias.data[start:end], 1.)
        elif classname.find('BatchNorm2d') != -1 or classname.find('SynchronizedBatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            nn.init.normal_(m.weight.data, 1.0, init_gain)
            nn.init.constant_(m.bias.data, 0.0)
    print('Initialized network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>
def tensor2submit_image(input_image):
    """Convert a label map into a CityScapes submission-format image.

    Parameters:
        input_image (torch.Tensor or np.ndarray) -- predicted label map

    Returns:
        the colorized mask produced by cityscapes.colorize_mask_submit.

    Raises:
        TypeError -- for unsupported input types (previously any non-Tensor
                     input crashed with UnboundLocalError on image_numpy).
    """
    if isinstance(input_image, torch.Tensor):
        # Detach, move to CPU and convert to a float numpy array.
        image_numpy = input_image.data.cpu().float().numpy()
    elif isinstance(input_image, np.ndarray):
        image_numpy = input_image
    else:
        raise TypeError('tensor2submit_image expects a torch.Tensor or np.ndarray, got %s'
                        % type(input_image).__name__)
    return cityscapes.colorize_mask_submit(image_numpy)
def tensor2im(input_image, imtype=np.uint8, return_tensor=True):
    """"Converts a Tensor array into a numpy image array.

    Parameters:
        input_image (tensor)  -- the input image tensor array; a 2-D tensor is
                                 treated as a label map, a 3-D tensor as a
                                 (C, H, W) image
        imtype (type)         -- the desired type of the converted numpy array
        return_tensor (bool)  -- if True, return a torch tensor; otherwise an
                                 HWC numpy array
    """
    if not isinstance(input_image, np.ndarray):
        if isinstance(input_image, torch.Tensor):  # get the data from a variable
            image_tensor = input_image.data
        else:
            # Not a tensor and not an ndarray: pass through unchanged.
            return input_image
        image_numpy = image_tensor.cpu().float().numpy()  # convert it into a numpy array
        if image_numpy.ndim == 2:
            # 2-D input: colorize as a CityScapes label mask.
            image_numpy = cityscapes.visualize(cityscapes.colorize_mask(image_numpy).convert("RGB")).numpy()
        elif image_numpy.ndim == 3:
            # Min-max normalize to [0, 1].
            image_numpy = (image_numpy - np.min(image_numpy))/(np.max(image_numpy)-np.min(image_numpy))
            if image_numpy.shape[0] == 1:  # grayscale to RGB
                image_numpy = np.tile(image_numpy, (3, 1, 1))
            # NOTE(review): after min-max scaling to [0, 1], this maps values
            # to [127.5, 255] rather than [0, 255] — looks like it assumes
            # [-1, 1] input; confirm the intended range.
            image_numpy = (image_numpy + 1) / 2.0 * 255.0  # post-processing: tranpose and scaling
    else:  # if it is a numpy array, do nothing
        image_numpy = input_image
    return torch.from_numpy(image_numpy.astype(imtype)) if return_tensor else np.transpose(image_numpy, (1,2,0))
def plot_grad_flow(model):
    """
    Plots the gradients flowing through different layers in the net during training.
    Can be used for checking for possible gradient vanishing / exploding problems.
    Call after loss.backward() so parameter gradients are populated.

    Parameters:
        model (nn.Module) -- the network whose gradient magnitudes are plotted
    """
    ave_grads = []
    max_grads = []
    layers = []
    for n, p in model.named_parameters():
        # Only trainable weight tensors; biases and norm layers are skipped.
        if p.requires_grad and ("bias" not in n and "norm" not in n):
            # Previously a parameter without a gradient (frozen branch, or no
            # backward pass yet) crashed with AttributeError on None.grad.
            if p.grad is None:
                continue
            # Shorten the dotted parameter name to a readable layer label.
            name = n[:n.find('.network')] + n[n.find('.conv'):n.find('.weight')] if 'network' in n else n[:n.find('.weight')]
            layers.append(name)
            # .item() detaches and moves to CPU so matplotlib gets plain floats.
            ave_grads.append(p.grad.abs().mean().item())
            max_grads.append(p.grad.abs().max().item())
    plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.5, lw=1, color="r")
    plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.5, lw=1, color="b")
    plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
    plt.xlim(left=0, right=len(ave_grads))
    plt.ylim(bottom=-0.001, top=0.01)  # zoom in on the lower gradient regions
    plt.xlabel("Layers")
    plt.ylabel("average gradient")
    plt.title("Gradient flow")
    plt.grid(True)
    plt.legend([Line2D([0], [0], color="r", alpha=0.5, lw=4),
                Line2D([0], [0], color="b", alpha=0.5, lw=4)], ['max-gradient', 'mean-gradient'])
    plt.tight_layout()
def count_parameters(model):
    """Return the total number of trainable parameters in `model`."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def calc_width(net):
    """Return the total element count over all trainable parameters of `net`."""
    weight_count = 0
    for name, param in net.named_parameters():
        if param.requires_grad:
            weight_count += np.prod(param.size())
    return weight_count
def print_training_info(args):
    """Print a human-readable summary of the training configuration.

    Parameters:
        args (argparse.Namespace) -- parsed command line arguments
    """
    print(f'Segmentation {args.segmentation}')
    if args.shuffle:
        print(f'Shuffling {args.shuffle}')
    # UNet-specific architecture options.
    if 'unet' in args.model:
        print(f'Ngf {args.ngf}')
        print(f'Num downs {args.num_downs}')
        print(f'Down type {args.down_type}')
        if args.remove_skip:
            print(f'Remove skip connections {args.remove_skip}')
    print(f'Mode {args.mode}')
    # Temporal/sequence-model options.
    if 'sequence' in args.mode:
        print(f'Sequence model {args.sequence_model}')
        print(f'Number stacked sequence models {args.sequence_stacked_models}')
        if args.sequence_model == 'lstm':
            print(f'LSTM Bidirectional {args.lstm_bidirectional}')
            print(f'LSTM initial state {args.lstm_initial_state}')
        if 'tcn' in args.sequence_model:
            print(f'TCN num levels {args.num_levels_tcn}')
            print(f'TCN kernel level size {args.tcn_kernel_size}')
        print(f'Time dilation {args.time_dilation}')
    print(f'Optimizer {args.optim}')
    print(f'Learning rate {args.lr}')
    if args.clip > 0:
        print(f'Gradient clip {args.clip}')
    # Reconstruction-loss options.
    if args.reconstruct:
        print(f'Reconstruct {args.reconstruct}')
        print(f'Reconstruct coeff {args.reconstruct_loss_coeff}')
        print(f'Reconstruct loss function {args.reconstruct_loss_type}')
        print(f'Reconstruct remove skip connections {args.reconstruct_remove_skip}')
    print(f'Resize {args.resize}')
    print(f'Blank {args.blank}')
    print(f'Batch size {args.batch_size}')
    print(f'Norm layer {args.norm_layer}')
    print(f'Using cuda {torch.cuda.is_available()}')
    if args.use_class_weights:
        print(f'Using weighted {args.loss_type} loss')
    else:
        print(f'Using {args.loss_type} loss')
    if args.loss_type == 'focal':
        print(f'Gamma {args.gamma}')
        print(f'Alpha {args.alpha}')
    print(f'Starting Epoch: {args.start_epoch}')
    print(f'Total Epoches: {args.epochs}')