-
Notifications
You must be signed in to change notification settings - Fork 66
/
train_densenet_albumentations.py
378 lines (286 loc) · 12.2 KB
/
train_densenet_albumentations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
# ---
# jupyter:
# jupytext:
# formats: ipynb,py:light
# text_representation:
# extension: .py
# format_name: light
# format_version: '1.4'
# jupytext_version: 1.1.7
# kernelspec:
# display_name: Python 3
# language: python
# name: python3
# ---
# +
#v3.classification
#17/8/2019
#modified augmentation approach to use albumentations:
#https://github.com/albu/albumentations
#https://albumentations.readthedocs.io/
# + {}
dataname = "lymphoma"  # prefix of the ./<dataname>_<phase>.pytable files and of the saved model
gpuid = 0              # index of the CUDA device to use

# --- densenet params
# Hyper-parameters for torchvision's DenseNet constructor (see its documentation for details).
# in_channels is not passed to DenseNet; it is only recorded in the saved checkpoint.
num_classes = 3        # number of classes we aim to predict
in_channels = 3        # input channels of the data, RGB = 3
growth_rate = 32
block_config = (2, 2, 2, 2)
num_init_features = 64
bn_size = 4
drop_rate = 0

# --- training params
batch_size = 128
patch_size = 224       # per the original author, this currently needs to be 224 for the DenseNet architecture
num_epochs = 100
phases = ["train", "val"]    # phases for which a pytable database was created
# Phases in which the (slow) confusion matrix / accuracy are computed; typically only "val".
# Setting this to [] skips those metrics entirely (the loss is still computed for every phase).
validation_phases = ["val"]
# + {}
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.models import DenseNet
from albumentations import *
from albumentations.pytorch import ToTensor
import PIL
import matplotlib.pyplot as plt
import cv2
import numpy as np
import sys, glob
from tensorboardX import SummaryWriter
import time
import math
import tables
import random
from sklearn.metrics import confusion_matrix
# -
# helper functions for pretty printing of elapsed and estimated remaining time
def asMinutes(s):
    """Format a duration of ``s`` seconds as '<m>m <s>s' (seconds truncated).

    Negative durations are rendered with a leading '-' (e.g. '-0m 5s') rather
    than the misleading '-1m 55s' that floor-based splitting would produce.
    """
    sign = '-' if s < 0 else ''
    s = abs(s)
    m = math.floor(s / 60)
    s -= m * 60
    return '%s%dm %ds' % (sign, m, s)


def timeSince(since, percent):
    """Return 'elapsed (- estimated remaining)' for a task started at ``since``.

    Args:
        since: start timestamp, as returned by ``time.time()``.
        percent: fraction of the work completed so far, in (0, 1].

    Returns:
        A string such as '3m 12s (- 5m 40s)'.
    """
    now = time.time()
    s = now - since
    # estimated total duration; the small epsilon avoids division by zero at percent == 0
    es = s / (percent + .00001)
    # the epsilon makes es slightly smaller than s when percent == 1, which would yield a
    # tiny negative remaining time (formerly printed as '-1m 59s'); clamp it to zero
    rs = max(es - s, 0.0)
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))
# specify if we should use a GPU (cuda) or only the CPU.
# get_device_properties / set_device raise when CUDA is unavailable, so they are guarded;
# otherwise the CPU fallback could never be reached.
if torch.cuda.is_available():
    print(torch.cuda.get_device_properties(gpuid))
    torch.cuda.set_device(gpuid)
    device = torch.device(f'cuda:{gpuid}')
else:
    device = torch.device('cpu')
# +
# Build the DenseNet from the hyper-parameters configured above, move it to `device`,
# and report how many parameters it has.
model = DenseNet(
    growth_rate=growth_rate,
    block_config=block_config,
    num_init_features=num_init_features,
    bn_size=bn_size,
    drop_rate=drop_rate,
    num_classes=num_classes,
).to(device)
# For reference, the default parameters:
# model = DenseNet(growth_rate=32, block_config=(6, 12, 24, 16),
#                  num_init_features=64, bn_size=4, drop_rate=0, num_classes=3)
print(f"total params: \t{sum(np.prod(param.size()) for param in model.parameters())}")
# -
# this defines our dataset class which will be used by the dataloader
class Dataset(object):
    """Map-style dataset over a PyTables file of image patches and class labels.

    The file must contain the root nodes ``imgs`` (indexed as a 4-D array),
    ``labels`` and ``classsizes``. Each item is the tuple
    ``(transformed_image, label, original_image)``.
    """

    def __init__(self, fname, img_transform=None):
        """
        Args:
            fname: path to the .pytable file.
            img_transform: optional callable invoked as
                ``img_transform(image=img)['image']`` (albumentations style).
        """
        self.fname = fname
        self.img_transform = img_transform
        # Only small metadata is read here; the file is closed again immediately.
        with tables.open_file(self.fname, 'r') as db:
            self.classsizes = db.root.classsizes[:]
            self.nitems = db.root.imgs.shape[0]
        # Kept for backward compatibility only. Node handles are deliberately not cached,
        # because they become invalid as soon as the file is closed.
        self.imgs = None
        self.labels = None

    def __getitem__(self, index):
        # Opening once in __init__ would be preferable, but the original author reported
        # HDF5 crashes with multiple loader workers, so the file is reopened for every item.
        with tables.open_file(self.fname, 'r') as db:
            # Read the requested image and label into local arrays while the file is open,
            # instead of keeping (soon-to-be-closed) node objects on self.
            img = db.root.imgs[index, :, :, :]
            label = db.root.labels[index]
        img_new = img
        if self.img_transform:
            img_new = self.img_transform(image=img)['image']
        return img_new, label, img

    def __len__(self):
        return self.nitems
# +
# https://github.com/albu/albumentations/blob/master/notebooks/migrating_from_torchvision_to_albumentations.ipynb
# Training augmentation: random flips, a hue shift, a random rotation (exposed border
# filled with 0), then a random patch_size x patch_size crop, converted to a tensor.
img_transform = Compose([
    VerticalFlip(p=.5),
    HorizontalFlip(p=.5),
    HueSaturationValue(hue_shift_limit=(-25, 0), sat_shift_limit=0, val_shift_limit=0, p=1),
    Rotate(p=1, border_mode=cv2.BORDER_CONSTANT, value=0),
    # ElasticTransform(always_apply=True, approximate=True, alpha=150, sigma=8,alpha_affine=50),
    RandomSizedCrop((patch_size, patch_size), patch_size, patch_size),
    ToTensor()
])

# Evaluation transform: deterministic. Validation loss, accuracy and best-model selection
# should not depend on random augmentation, so only crop to the same patch size and
# convert to a tensor.
val_transform = Compose([
    CenterCrop(patch_size, patch_size),
    ToTensor()
])

dataset = {}
dataLoader = {}
for phase in phases:  # one dataset and one dataloader per phase
    phase_transform = img_transform if phase == "train" else val_transform
    dataset[phase] = Dataset(f"./{dataname}_{phase}.pytable", img_transform=phase_transform)
    # NOTE: the original author saw no speed-up from num_workers>0 at this batch size;
    # 8 workers are used here nonetheless.
    dataLoader[phase] = DataLoader(dataset[phase], batch_size=batch_size,
                                   shuffle=True, num_workers=8, pin_memory=True)
    print(f"{phase} dataset size:\t{len(dataset[phase])}")
# +
# Visualize one training example to check the data pipeline:
# left = the transformed patch (CHW tensor moved to HWC for imshow), right = the stored original.
img, label, img_old = dataset["train"][7]
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 4))
augmented_hwc = np.moveaxis(img.numpy(), 0, -1)
ax[0].imshow(augmented_hwc)
ax[1].imshow(img_old)
print(label)
# -
# Adam is typically the most robust choice (though perhaps not the best performing),
# which makes it a good starting point.
optim = torch.optim.Adam(params=model.parameters())
# Alternative: plain SGD with momentum and weight decay
# optim = torch.optim.SGD(model.parameters(),
#                         lr=.1,
#                         momentum=0.9,
#                         weight_decay=0.0005)
# +
# Weight each class by (1 - its fraction of the training set), so that rarer classes
# contribute more to the loss and no particular class is favoured.
nclasses = dataset["train"].classsizes.shape[0]
if nclasses != num_classes:
    # The model head has num_classes outputs, while the loss weights and confusion-matrix
    # bookkeeping use nclasses; a mismatch would otherwise fail later with an obscure error.
    raise ValueError(f"training data has {nclasses} classes but num_classes={num_classes}")
class_weight = dataset["train"].classsizes
class_weight = torch.from_numpy(1 - class_weight / class_weight.sum()).type('torch.FloatTensor').to(device)
print(class_weight)  # show the final weights; make sure they're reasonable before continuing
criterion = nn.CrossEntropyLoss(weight=class_weight)
# +
# def trainnetwork():
writer = SummaryWriter()  # open the tensorboard visualiser
best_loss_on_test = np.inf  # np.Infinity was removed in NumPy 2.0
start_time = time.time()

for epoch in range(num_epochs):
    # reset epoch-based performance variables
    all_acc = {key: 0 for key in phases}
    all_loss = {key: float("nan") for key in phases}  # mean loss per phase, set at the end of each phase
    cmatrix = {key: np.zeros((nclasses, nclasses)) for key in phases}

    for phase in phases:  # iterate through training and validation states
        if phase == 'train':
            model.train()  # Set model to training mode
        else:  # in eval mode, parameters must not be updated
            model.eval()  # Set model to evaluate mode

        # Per-batch losses stay on the device (no host sync per batch). They are collected
        # in a list and stacked once, instead of re-allocating a growing tensor every batch.
        batch_losses = []
        for X, label, _ in dataLoader[phase]:  # for each of the batches
            X = X.to(device)  # [Nbatch, 3, H, W]
            label = label.type('torch.LongTensor').to(device)  # one class index per sample

            # gradients are only needed for training; disabling them speeds up evaluation
            with torch.set_grad_enabled(phase == 'train'):
                prediction = model(X)  # [Nbatch, Nclass]
                loss = criterion(prediction, label)

                if phase == "train":  # back-propagate and update the weights
                    optim.zero_grad()
                    loss.backward()
                    optim.step()

                batch_losses.append(loss.detach())

                if phase in validation_phases:  # accumulate the confusion matrix
                    p = prediction.detach().cpu().numpy()
                    cpredflat = np.argmax(p, axis=1).flatten()
                    yflat = label.cpu().numpy().flatten()
                    cmatrix[phase] = cmatrix[phase] + confusion_matrix(yflat, cpredflat, labels=range(nclasses))

        if phase in validation_phases:  # only these phases have a populated confusion matrix
            all_acc[phase] = (cmatrix[phase] / cmatrix[phase].sum()).trace()
        if batch_losses:
            all_loss[phase] = torch.stack(batch_losses).mean().item()

        # save metrics to tensorboard
        writer.add_scalar(f'{phase}/loss', all_loss[phase], epoch)
        if phase in validation_phases:
            writer.add_scalar(f'{phase}/acc', all_acc[phase], epoch)
            for r in range(nclasses):
                for c in range(nclasses):  # write out the confusion matrix entry by entry
                    writer.add_scalar(f'{phase}/{r}{c}', cmatrix[phase][r][c], epoch)

    print('%s ([%d/%d] %d%%), train loss: %.4f test loss: %.4f' % (timeSince(start_time, (epoch + 1) / num_epochs),
          epoch + 1, num_epochs, (epoch + 1) / num_epochs * 100, all_loss["train"], all_loss["val"]), end="")

    # if the current loss is the best we've seen, save the model state together with
    # every parameter needed to recreate it
    if all_loss["val"] < best_loss_on_test:
        best_loss_on_test = all_loss["val"]
        print(" **")
        state = {'epoch': epoch + 1,
                 'model_dict': model.state_dict(),
                 'optim_dict': optim.state_dict(),
                 'best_loss_on_test': best_loss_on_test,  # scalar validation loss (previously the whole loss dict)
                 'in_channels': in_channels,
                 'growth_rate': growth_rate,
                 'block_config': block_config,
                 'num_init_features': num_init_features,
                 'bn_size': bn_size,
                 'drop_rate': drop_rate,
                 'num_classes': num_classes}
        torch.save(state, f"{dataname}_densenet_best_model.pth")
    else:
        print("")

writer.close()  # flush pending tensorboard events to disk
# +
# #%load_ext line_profiler
# #%lprun -f trainnetwork trainnetwork()
# +
#At this stage, training is done...below are snippets to help with other tasks: output generation + visualization
# -
# ----- generate output
# load the best model; map_location lets a checkpoint saved on a GPU load on a CPU-only machine
checkpoint = torch.load(f"{dataname}_densenet_best_model.pth", map_location=device)
model.load_state_dict(checkpoint["model_dict"])
model.eval()  # inference: use running batch-norm statistics and disable dropout

# grab a single image from the validation set
(img, label, img_old) = dataset["val"][2]

# generate its output (no gradients are needed for inference)
# #%%timeit
with torch.no_grad():
    output = model(img[None, ::].to(device))
output = output.detach().squeeze().cpu().numpy()
print(output)
print(f"True class:{label}")
print(f"Predicted class:{np.argmax(output)}")
# +
# look at the input: the transformed tensor (CHW moved to HWC for display) next to the stored original
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 4))
for axis, picture in zip(ax, (np.moveaxis(img.numpy(), 0, -1), img_old)):
    axis.imshow(picture)
# +
#------- visualize kernels and activations
# -
# helper function for visualization
def plot_kernels(tensor, num_cols=8, cmap="gray"):
    """Plot every 2-D slice ``tensor[i, j]`` of a 4-D tensor as a grid of images.

    Slices are laid out in row-major order of (i, j), ``num_cols`` per row.

    Args:
        tensor: 4-D torch tensor; it may live on any device and may require grad.
        num_cols: number of subplots per row.
        cmap: matplotlib colormap passed to ``imshow``.

    Raises:
        ValueError: if ``tensor`` is not 4-D.
    """
    if len(tensor.shape) != 4:
        raise ValueError("assumes a 4D tensor")
    num_kernels = tensor.shape[0] * tensor.shape[1]
    # ceil division gives exactly enough rows; the former 1 + n // cols added an
    # empty row whenever n was a multiple of num_cols
    num_rows = max(1, math.ceil(num_kernels / num_cols))
    fig = plt.figure(figsize=(num_cols, num_rows))
    # detach + cpu so GPU or grad-tracking tensors can be passed directly
    t = tensor.detach().cpu().numpy()
    slices = t.reshape(num_kernels, *t.shape[2:])
    for i, kernel in enumerate(slices, start=1):
        ax1 = fig.add_subplot(num_rows, num_cols, i)
        ax1.imshow(kernel, cmap=cmap)
        ax1.axis('off')
        ax1.set_xticklabels([])
        ax1.set_yticklabels([])
    plt.subplots_adjust(wspace=0.1, hspace=0.1)
    plt.show()
class LayerActivations():
    """Capture a layer's output on every forward pass via a forward hook.

    After a forward pass through a model containing ``layer``, its most recent
    output is available in ``features`` (on the CPU, detached from autograd).
    Call ``remove()`` to unregister the hook; ``features`` keeps its last value.
    """
    features = None  # most recent captured output; None until the first forward pass

    def __init__(self, layer):
        self.hook = layer.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        # detach so the stored activation does not keep the forward pass's autograd graph alive
        self.features = output.detach().cpu()

    def remove(self):
        """Unregister the forward hook."""
        self.hook.remove()
# +
# --- visualize kernels
# -
w = model.features.denseblock2.denselayer1.conv2
# first 5 output channels x first 5 input channels of the conv weights, 5 plots per row
kernel_weights = w.weight.detach().cpu()
plot_kernels(kernel_weights[:5, :5, :, :], 5)
# +
# ---- visualize activiations
# -
dr = LayerActivations(model.features.denseblock2.denselayer1.conv2)
(img, label, img_old) = dataset["val"][7]
plt.imshow(np.moveaxis(img.numpy(), 0, -1))
model.eval()  # use running batch-norm statistics for this single-image forward pass
with torch.no_grad():  # the forward pass only triggers the hook; no gradients needed
    output = model(img[None, ::].to(device))
# unregister the hook so that re-running this cell does not stack hooks on the layer;
# the captured activation stays available in dr.features
dr.remove()
plot_kernels(dr.features, 8, cmap="rainbow")
# # ---- Improvements:
# 1 replace Adam with SGD with appropriate learning rate reduction