
The code to train the net #3

Closed
meule opened this issue Mar 29, 2020 · 1 comment

Comments

meule commented Mar 29, 2020

Thank you so much for the repo! It's amazing work.
Unfortunately, I can't train the multitasking nets. Could you please share the code to train the net?

Thanks in advance

feevos (Owner) commented Mar 30, 2020

Hi @meule

The code to train the network is standard, as in all mxnet applications. You can find a crash course on mxnet, with examples, in the official documentation; in particular, see the tutorial on training a network.

One important thing is to create a loss that is suitable for the list of predictions.

Below is a working example: it takes a list of predictions _prediction (all in the same context) and an nd.array of labels _label. Note that for this loss to work you need to feed it an appropriately formatted _label, and that depends on how you define your dataset, which is a custom operation. Some people may prefer to pass multiple label arguments; I simply stack all requested outputs into the same _label.
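That channel stacking (segmentation, boundary, distance transform, then the HSV image) can be sketched with NumPy; the sizes below are hypothetical, chosen only for illustration, and the slices mirror the ones the loss uses to split _label back apart:

```python
import numpy as np

# Hypothetical sizes for illustration only.
NClasses, H, W = 6, 4, 4

# One sample's targets (channel-first layout):
seg   = np.zeros((NClasses, H, W))  # one-hot segmentation masks
bound = np.ones((NClasses, H, W))   # per-class boundary masks
dist  = np.zeros((NClasses, H, W))  # per-class distance transforms
hsv   = np.ones((3, H, W))          # input image converted to HSV

# Stack everything along the channel axis into a single label array,
# then add a batch axis; the loss slices this layout back apart.
label = np.concatenate([seg, bound, dist, hsv], axis=0)[None, ...]

assert label.shape == (1, 3 * NClasses + 3, H, W)
assert np.array_equal(label[:, :NClasses],          seg[None, ...])
assert np.array_equal(label[:, NClasses:2*NClasses], bound[None, ...])
assert np.array_equal(label[:, 2*NClasses:-3],      dist[None, ...])
assert np.array_equal(label[:, -3:],                hsv[None, ...])
```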

from resuneta.nn.loss.loss import * 

class CustomLoss(object):
    def __init__(self,NClasses=6):
        self.tnmt = Tanimoto_wth_dual()
        self.skip = NClasses 


    def loss(self,_prediction,_label):
        pred_segm  = _prediction[0]
        pred_bound = _prediction[1]
        pred_dists = _prediction[2]

        # HSV colorspace prediction
        pred_color = _prediction[3]

        # Here I split _label to all different labels, to apply different loss functions on each type of output. 
        # Note that I pack in the _label, the following information:
        # First NClasses are segmentation
        # Second set is boundary
        # third set is distance transform
        # Last three elements are the original image in HSV color space 
        label_segm  = _label[:,:self.skip,:,:] # segmentation 
        label_bound = _label[:,self.skip:2*self.skip,:,:] # boundaries 
        label_dists = _label[:,2*self.skip:-3,:,:]  # distance transform 

        # color image in HSV format - remember to convert the RGB input to HSV!
        label_color = _label[:,-3:,:,:] # color in HSV 

        # Apply the SAME Tanimoto-with-complement loss to each task
        loss_segm  = 1.0 - self.tnmt(pred_segm,  label_segm)
        loss_bound = 1.0 - self.tnmt(pred_bound, label_bound)
        loss_dists = 1.0 - self.tnmt(pred_dists, label_dists)
        loss_color = 1.0 - self.tnmt(pred_color, label_color)

        # Divide by 4.0 to keep the output in the range [0,1] 
        return (loss_segm + loss_bound + loss_dists + loss_color)/4.0 
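For intuition, here is a self-contained NumPy sketch of what a Tanimoto-with-complement measure computes: the average of the Tanimoto similarity on the prediction/label pair and on their complements. This is an assumption about the general form; the repo's actual Tanimoto_wth_dual may differ in details such as per-class summation and weighting:

```python
import numpy as np

def tanimoto(p, l, eps=1e-8):
    """Tanimoto similarity between prediction p and label l (both in [0,1])."""
    num = np.sum(p * l)
    den = np.sum(p * p) + np.sum(l * l) - num
    return (num + eps) / (den + eps)

def tanimoto_with_complement(p, l):
    """Average of the Tanimoto similarity on (p, l) and on (1-p, 1-l)."""
    return 0.5 * (tanimoto(p, l) + tanimoto(1.0 - p, 1.0 - l))

p = np.array([1.0, 0.0, 1.0, 0.0])
# A perfect prediction gives similarity 1, i.e. a loss of 1 - 1 = 0.
print(1.0 - tanimoto_with_complement(p, p))  # → 0.0
```

The complement term rewards correct background as well as correct foreground, which helps when the positive class is rare.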

The training routine, once you've initialized your network and trainer, is "simple" (it gets much more involved if you want to add checkpointing, monitoring operations and distributed training):

# Define network, dataset, data generator
import mxnet as mx
from mxnet import gluon, autograd
from resuneta.models.resunet_d7_causal_mtskcolor_ddist import *
# modify according to your preferences
Nfilters_init = 32
NClasses = 6
net = ResUNet_d7(Nfilters_init,NClasses)
net.initialize()
net.hybridize()

# trainer/optimizer: some optimizer of your choice, Adam recommended
trainer = gluon.Trainer(net.collect_params(), 'adam') # add appropriate parameters 

YourLoss = CustomLoss()

# define your custom dataset
dataset = ...# 
datagen = gluon.data.DataLoader(dataset, batch_size=YourBatchSize, shuffle=True) # see https://beta.mxnet.io/api/gluon/_autogen/mxnet.gluon.data.DataLoader.html

for epoch in range(epochs): # train for as many epochs as you want 
    for img, mask in datagen: # assumes a single gpu 
        img  = img.as_in_context(mx.gpu())
        mask = mask.as_in_context(mx.gpu())
        with autograd.record():
            ListOfPredictions = net(img)
            loss = YourLoss.loss(ListOfPredictions, mask)
        loss.backward()
        trainer.step(YourBatchSize) # normalize gradients by the batch size
        # Add here any kind of monitoring you want
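As a minimal example of the monitoring hook above, here is a smoothed loss tracker in pure Python (not part of the repo; inside the loop you would call something like meter.update(loss.mean().asscalar())):

```python
class RunningLoss:
    """Exponential moving average of the per-batch loss, for lightweight monitoring."""
    def __init__(self, momentum=0.98):
        self.momentum = momentum
        self.value = None

    def update(self, batch_loss):
        # First update seeds the average; later updates blend in new batches.
        if self.value is None:
            self.value = batch_loss
        else:
            self.value = self.momentum * self.value + (1.0 - self.momentum) * batch_loss
        return self.value

meter = RunningLoss(momentum=0.5)
meter.update(1.0)
meter.update(0.0)
print(meter.value)  # → 0.5
```

A smoothed value is easier to read than the raw per-batch loss, which is noisy at small batch sizes.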

For distributed training with mxnet there are a lot of options, depending on the cluster you have at your disposal. You can find a starting point in the mxnet documentation (I highly recommend going with the horovod approach).

Hope the above helps.
