The code to train the net #3
Hi @meule The code to train the network is standard, as in all mxnet applications. You can find a crash course on mxnet, and examples, here. In particular, for training a network you can see this. One important thing is to create a loss that is suitable for the list of predictions. This is something that works, by feeding it a list of predictions:

```python
from resuneta.nn.loss.loss import *

class CustomLoss(object):
    def __init__(self, NClasses=6):
        self.tnmt = Tanimoto_wth_dual()
        self.skip = NClasses

    def loss(self, _prediction, _label):
        pred_segm  = _prediction[0]
        pred_bound = _prediction[1]
        pred_dists = _prediction[2]
        # HSV colorspace prediction
        pred_color = _prediction[3]

        # Here I split _label into the different labels, to apply a loss
        # function on each type of output. Note that I pack into _label
        # the following information:
        #   first NClasses channels: segmentation
        #   second set: boundary
        #   third set: distance transform
        #   last three channels: the original image in HSV color space
        label_segm  = _label[:, :self.skip, :, :]             # segmentation
        label_bound = _label[:, self.skip:2*self.skip, :, :]  # boundaries
        label_dists = _label[:, 2*self.skip:-3, :, :]         # distance transform
        # color image - needs to be transformed to HSV!!
        label_color = _label[:, -3:, :, :]                    # color in HSV

        # One loss term per task, all using the SAME Tanimoto with complement
        loss_segm  = 1.0 - self.tnmt(pred_segm,  label_segm)
        loss_bound = 1.0 - self.tnmt(pred_bound, label_bound)
        loss_dists = 1.0 - self.tnmt(pred_dists, label_dists)
        loss_color = 1.0 - self.tnmt(pred_color, label_color)

        # Divide by 4.0 to keep the output in range [0, 1]
        return (loss_segm + loss_bound + loss_dists + loss_color) / 4.0
```

The training routine, once you've initialized your network and trainer, is "simple" (it gets much more involved if you want to add checkpointing, monitoring operations and distributed training):
```python
# Define network, dataset, data generator
import mxnet as mx
from mxnet import gluon, autograd
from resuneta.models.resunet_d7_causal_mtskcolor_ddist import *

# modify according to your preferences
Nfilters_init = 32
NClasses = 6
net = ResUNet_d7(Nfilters_init, NClasses)
net.initialize()
net.hybridize()

# trainer/optimizer: some optimizer of your choice, Adam recommended
trainer = gluon.trainer.Trainer(net.collect_params(), 'adam')  # add appropriate parameters

YourLoss = CustomLoss()

# define your custom dataset
dataset = ...  #
datagen = gluon.data.DataLoader(dataset, batch_size=YourBatchSize, shuffle=True)
# see https://beta.mxnet.io/api/gluon/_autogen/mxnet.gluon.data.DataLoader.html

for epoch in range(epochs):  # train for as many epochs as you want
    for img, mask in datagen:  # assumes a single gpu
        img = img.as_in_context(mx.gpu())
        mask = mask.as_in_context(mx.gpu())
        with autograd.record():
            ListOfPredictions = net(img)
            loss = YourLoss.loss(ListOfPredictions, mask)
        loss.backward()
        trainer.step(YourBatchSize)
        # Add here any kind of monitoring you want
```

For distributed training with mxnet there are a lot of options depending on the cluster you have at your disposal. You can find a starting point here (I highly recommend going with the horovod approach). Hope the above helps.
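The actual `Tanimoto_wth_dual` implementation lives in `resuneta.nn.loss.loss`; purely as an illustration of the "Tanimoto with complement" formula and of the packed-label channel layout described in the comments above, here is a NumPy sketch. All names, shapes, and the `packed` tensor here are hypothetical, not the repo's API:

```python
import numpy as np

def tanimoto(p, l, eps=1e-8):
    # Tanimoto coefficient: <p, l> / (|p|^2 + |l|^2 - <p, l>)
    num = np.sum(p * l)
    den = np.sum(p * p) + np.sum(l * l) - num
    return num / (den + eps)

def tanimoto_with_dual(p, l):
    # Average the coefficient on (p, l) and on the complements (1-p, 1-l);
    # the loss used per task is then 1 - tanimoto_with_dual(pred, label).
    return 0.5 * (tanimoto(p, l) + tanimoto(1.0 - p, 1.0 - l))

# Perfect prediction: coefficient ~1, so the loss 1 - coefficient is ~0
label = np.array([[1.0, 0.0], [0.0, 1.0]])
print(1.0 - tanimoto_with_dual(label, label))

# Hypothetical packed label tensor with 3*NClasses + 3 channels:
# segmentation, boundary, distance transform, then the HSV image
NClasses = 6
packed = np.zeros((2, 3 * NClasses + 3, 8, 8))  # (batch, channels, H, W)
segm  = packed[:, :NClasses, :, :]
bound = packed[:, NClasses:2 * NClasses, :, :]
dists = packed[:, 2 * NClasses:-3, :, :]
hsv   = packed[:, -3:, :, :]
print(segm.shape, bound.shape, dists.shape, hsv.shape)
```

With this layout, each task's slice has `NClasses` channels except the HSV image, which always keeps the last three, so the averaged loss stays in [0, 1].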
Thank you so much for the repo! It's amazing work.
Unfortunately, I can't train the multitasking nets. Could you please share the code to train the net?
Thanks in advance