### Import

In [1]:
import os
import sys
import yaml
import numpy as np
from PIL import Image
from tqdm import trange, tqdm
from collections import namedtuple

import torch
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data.dataloader as loader
import torch.nn.functional as F

from train_dataset import DataServoStereo
import train_model as model



### Training on GPU or CPU?

In [2]:
# Read the training configuration once to find out which device to train on.
with open("cfg/train_real_images.yaml", "r") as cfg_file:
    config = yaml.safe_load(cfg_file)

gpu_enabled = config["gpu"]

# Report the selected device so the choice is visible in the notebook output.
print("Training on GPU" if gpu_enabled else "Training on CPU")

Training on CPU


### Load configuration

In [3]:
# Load the YAML configuration and expose its top-level keys as attributes
# (arg.lr, arg.batch_size, ...), which is how the rest of the notebook reads it.
# The file is opened in a `with` block so the handle is closed deterministically.
# NOTE(review): yaml.Loader can construct arbitrary Python objects. It is kept
# here because the config is a local file; prefer yaml.safe_load if the config
# can ever come from an untrusted source.
with open("cfg/train_real_images.yaml", 'r') as cfg_file:
    arg_dict = yaml.load(cfg_file, yaml.Loader)
arg = namedtuple('Arg', arg_dict.keys())(**arg_dict)


### Configure cuDNN (GPU training only)

In [4]:
# Switch on all three cuDNN flags, but only when GPU training was requested.
# NOTE(review): benchmark=True lets cuDNN auto-tune its algorithm choice, which
# may work against the reproducibility that deterministic=True asks for - confirm
# that this combination is intended.
if gpu_enabled == True:
    cudnn.enabled = cudnn.benchmark = cudnn.deterministic = True

In [5]:
# Seed the PyTorch and NumPy global random generators with the same fixed value.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

### Initialise neural network

In [6]:
# Keypoint module, created with the initial sigma and a (num_keypoint, *im_size[1]) shape.
kper = model.KeyPointGaussian(arg.sigma_kp[0], (arg.num_keypoint, *arg.im_size[1]))

# Build the encoder once; it is moved to the GPU only when GPU training is requested.
enc = model.Encoder(
    arg.num_input,
    arg.num_keypoint,
    arg.growth_rate[0],
    arg.blk_cfg_enc,
    arg.drop_rate,
    kper,
)
if gpu_enabled == True:
    enc = enc.cuda()


In [7]:
# Single parameter group: all encoder parameters, with the first configured weight decay.
enc_param_groups = [{'params': enc.parameters(), 'weight_decay': arg.wd[0]}]

# AMSGrad variant of Adam; the learning rate set here is overwritten by adjust_lr during training.
optim = torch.optim.Adam(enc_param_groups, lr=arg.lr, amsgrad=True)

### Function to adjust the learning rate

In [8]:
def adjust_lr(ep, ep_train, bn=True):
    """Anneal the optimizer learning rate and, optionally, BatchNorm momentum.

    An annealing factor is computed from the progress ``ep / ep_train``
    according to ``arg.lr_anne`` ('step', 'cosine', 'repeat', anything else
    means constant). ``max(factor, 0.01) * arg.lr`` is then written into every
    parameter group of the module-level ``optim``.

    Args:
        ep: Current position in the schedule. ``train`` passes a global
            iteration index here, not an epoch number.
        ep_train: Total length of the schedule, in the same unit as ``ep``.
        bn: If true, the momentum of every BatchNorm1d/2d/3d layer in the
            module-level ``enc`` is set to the factor clamped to [0.01, 0.9].
    """
    if arg.lr_anne == 'step':
        # Multiply by 0.4 each time 30%, 60% and 90% of the schedule is passed
        a_lr = 0.4 ** ((ep > (0.3 * ep_train)) +
                       (ep > (0.6 * ep_train)) +
                       (ep > (0.9 * ep_train)))
    elif arg.lr_anne == 'cosine':
        # Single half-cosine from 1 down to 0 over the whole schedule
        a_lr = (np.cos(np.pi * ep / ep_train) + 1) / 2
    elif arg.lr_anne == 'repeat':
        # Cosine with restarts: one half-cosine per partition segment, each
        # restart scaled down by the fraction of the schedule still remaining
        partition = [0, 0.15, 0.30, 0.45, 0.6, 0.8, 1.0]
        par = int(np.digitize(ep * 1. / ep_train, partition))
        # np.digitize returns len(partition) once progress reaches 1.0 (and 0
        # for negative progress); clamp to a valid segment so partition[par]
        # cannot index past the end of the list.
        par = min(max(par, 1), len(partition) - 1)
        T = (partition[par] - partition[par - 1]) * ep_train
        t = ep - partition[par - 1] * ep_train
        a_lr = 0.5 * (1 + np.cos(np.pi * t / T))
        a_lr *= 1 - partition[par - 1]
    else:
        # Unknown or unset mode: keep the learning rate constant
        a_lr = 1

    # Apply the factor to all parameter groups, never going below 1% of arg.lr
    for param_group in optim.param_groups:
        param_group['lr'] = max(a_lr, 0.01) * arg.lr

    if bn:
        def fn(m):
            if isinstance(m, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)):
                # Momentum follows the annealing factor (not the absolute
                # learning rate), clamped to [0.01, 0.9]
                m.momentum = min(max(a_lr, 0.01), 0.9)
        enc.apply(fn)

### Definition of the training function

In [9]:
def train(ep, loader_train):
    """Train the encoder for one epoch.

    For every batch this updates the learning rate / BatchNorm momentum via
    ``adjust_lr``, updates ``kper.sigma``, runs ``enc`` on the images, builds
    the enabled loss terms and takes one optimizer step. The loss values of
    the first batch of the epoch are logged with ``tqdm.write``.

    Args:
        ep: Zero-based index of the current epoch.
        loader_train: Data loader yielding ``(img, plug_mask_tensor, plug_mask)``
            batches; ``loader_train.dataset`` must support ``len()``.

    Raises:
        ValueError: If no loss term can be built, i.e. ``arg.inside`` is 0 and
            the concentration loss is disabled or has no keypoint pairs.
    """
    num_kp = arg.num_keypoint

    # Global iteration bookkeeping; both values are constant within an epoch.
    epoch_offset = ep * len(loader_train.dataset) // arg.batch_size
    total_iters = arg.ep_train * len(loader_train.dataset) // arg.batch_size

    # The third batch element is not used by the training step.
    for i, (img, plug_mask_tensor, _plug_mask) in enumerate(loader_train):

        # Move the batch to the GPU when GPU training is enabled
        if gpu_enabled == True:
            img = img.cuda()
            plug_mask_tensor = plug_mask_tensor.cuda()

        # Global iteration index of this batch across all epochs
        cur_iter = epoch_offset + i

        # Update the learning rate based on the global iteration count
        adjust_lr(cur_iter, total_iters)

        # Move sigma linearly from sigma_kp[0] to sigma_kp[1] over the first
        # half of training, then hold it at sigma_kp[1]
        kper.sigma = min(2.0 * cur_iter / total_iters, 1) * (arg.sigma_kp[1] - arg.sigma_kp[0]) + arg.sigma_kp[0]

        # Forward pass; the losses below read keypL0[0] and keypL0[1]
        keypL0 = enc(img)

        # Concentration loss: for each pair of keypoints, exp(concentrate * distance)
        # weighted by the product of the two entries at offset 2 * num_kp,
        # averaged over the batch and then over all pairs.
        # NOTE(review): columns [0, K), [K, 2K) and [2K, 3K) of keypL0[0] are
        # presumably x, y and a per-keypoint weight - confirm against model.Encoder.
        lossC = None
        if arg.concentrate != 0:
            kp = keypL0[0]
            pair_losses = []
            for idx_i in range(0, num_kp - 1):
                for idx_j in range(idx_i + 1, num_kp):
                    distL = torch.norm(torch.cat(
                        ((kp[:, idx_i] - kp[:, idx_j]).unsqueeze(1),
                         (kp[:, idx_i + num_kp] - kp[:, idx_j + num_kp]).unsqueeze(1)),
                        dim=1), dim=1)
                    pair_weight = kp[:, idx_i + 2 * num_kp] * kp[:, idx_j + 2 * num_kp]
                    pair_losses.append(distL.mul(arg.concentrate).exp().mul(pair_weight).mean())
            # With fewer than two keypoints there are no pairs; leave the term
            # disabled instead of dividing by zero.
            if pair_losses:
                lossC = sum(pair_losses) / len(pair_losses)

        # Inside loss: penalise keypL0[1] wherever the mask is 0, after resizing
        # that "mask == 0" map to the spatial size of keypL0[1]
        lossI = None
        if arg.inside != 0:
            inoutL = plug_mask_tensor.eq(0).float()
            inoutL = F.interpolate(inoutL.unsqueeze(1), size=keypL0[1].size()[2:], align_corners=False, mode='bilinear')
            lossI = arg.inside * (inoutL.mul(keypL0[1]).mean())

        # Fail with a clear message instead of calling .backward() on the int 0
        active_losses = [l for l in (lossC, lossI) if l is not None]
        if not active_losses:
            raise ValueError(
                'No loss term is active: set arg.inside to a non-zero value, or '
                'arg.concentrate to a non-zero value with at least two keypoints.')

        # Reset gradients, backpropagate the total loss and update the weights
        optim.zero_grad()
        sum(active_losses).backward()
        optim.step()

        # Log the losses of the first batch of the epoch
        if i == 0:
            if lossC is None:
                tqdm.write('ep: {}, loss_I: {:.5f}  '.format(ep, lossI.item()))
            elif lossI is None:
                tqdm.write('ep: {}, loss_C: {:.5f}  '.format(ep, lossC.item()))
            else:
                tqdm.write('ep: {}, loss_C loss_I: {:.5f} {:.5f} '.format(ep, lossC.item(), lossI.item()))


### Function to save the model

In [10]:
def save_checkpoint(base_dir):
    """Write the encoder weights to ``<base_dir>/ckpt.pth``.

    The file holds a dict with the single key ``'enc_state_dict'`` containing
    the state dict of the module-level ``enc``.

    Args:
        base_dir: Existing directory in which the checkpoint file is created.
    """
    ckpt_path = os.path.join(base_dir, 'ckpt.pth')
    torch.save({'enc_state_dict': enc.state_dict()}, ckpt_path)
    print('checkpoint saved.')

### Main function for training

In [11]:
def main_train():
    """Run the complete training pipeline when ``arg.task`` is ``'full'``.

    Creates ``arg.dir_base``, stores a copy of the configuration file there as
    ``servo.yaml``, builds the dataset and data loader, trains the module-level
    ``enc`` for ``arg.ep_train`` epochs and saves a checkpoint. For any other
    task value the function returns without doing anything.

    Raises:
        OSError: If the output directory cannot be created or the
            configuration file cannot be copied.
    """
    # Local import: shutil is only needed here and is not imported at the top of the notebook.
    import shutil

    if arg.task not in ['full']:
        return

    # Create the output directory (no error if it already exists)
    os.makedirs(arg.dir_base, exist_ok=True)

    # Keep a copy of the configuration next to the results. shutil is used
    # instead of a shell `cp` so this works on any platform, handles paths
    # containing spaces, and raises on failure instead of failing silently.
    shutil.copyfile("cfg/train_real_images.yaml", os.path.join(arg.dir_base, 'servo.yaml'))

    # One input channel means grayscale images, anything else is treated as RGB
    use_grey = arg.num_input == 1
    print("Training with grayscale images" if use_grey else "Training with RGB images")
    ds_train = DataServoStereo(arg, grey=use_grey)

    # Data loader settings. Each worker re-seeds NumPy from one byte of
    # os.urandom, so workers get one of 256 possible seeds.
    # NOTE(review): this makes NumPy randomness inside the workers independent
    # of the global seed - confirm that this is intended.
    data_param = {'pin_memory': False, 'shuffle': True, 'batch_size': arg.batch_size, 'drop_last': True,
                  'num_workers': 8, 'worker_init_fn': lambda _: np.random.seed(ord(os.urandom(1)))}

    # Create the data loader for the training dataset
    loader_train = loader.DataLoader(ds_train, **data_param)

    # Put the encoder into training mode and train for each epoch
    enc.train()
    print('training...')
    for ep in trange(arg.ep_train):
        train(ep, loader_train)

    # Save the trained model checkpoint
    save_checkpoint(arg.dir_base)


In [12]:
# Run the full training pipeline; this does nothing unless arg.task is 'full'.
main_train()

Training with grayscale images
160 training data loaded
training...


  0%|          | 0/24 [00:08<?, ?it/s]

ep: 0, loss_C loss_I: 0.35768 0.00359 


  4%|▍         | 1/24 [09:26<3:33:18, 556.44s/it]

ep: 1, loss_C loss_I: 0.00443 0.00298 


  8%|▊         | 2/24 [18:16<3:18:36, 541.64s/it]

ep: 2, loss_C loss_I: 0.00101 0.00266 


 12%|█▎        | 3/24 [27:00<3:06:58, 534.22s/it]

ep: 3, loss_C loss_I: 0.00039 0.00241 


 17%|█▋        | 4/24 [35:55<2:58:08, 534.42s/it]

ep: 4, loss_C loss_I: 0.00086 0.00220 


 21%|██        | 5/24 [45:10<2:51:21, 541.16s/it]

ep: 5, loss_C loss_I: 0.00042 0.00195 


 25%|██▌       | 6/24 [54:02<2:41:30, 538.39s/it]

ep: 6, loss_C loss_I: 0.00011 0.00184 


 29%|██▉       | 7/24 [1:02:45<2:31:07, 533.41s/it]

ep: 7, loss_C loss_I: 0.00019 0.00166 


 33%|███▎      | 8/24 [1:11:19<2:20:35, 527.21s/it]

ep: 8, loss_C loss_I: 0.00005 0.00160 


 38%|███▊      | 9/24 [1:19:51<2:10:40, 522.69s/it]

ep: 9, loss_C loss_I: 0.00004 0.00149 


 42%|████▏     | 10/24 [1:28:27<2:01:24, 520.29s/it]

ep: 10, loss_C loss_I: 0.00010 0.00136 


 46%|████▌     | 11/24 [1:36:59<1:52:17, 518.29s/it]

ep: 11, loss_C loss_I: 0.00009 0.00134 


 50%|█████     | 12/24 [1:45:33<1:43:23, 516.95s/it]

ep: 12, loss_C loss_I: 0.00019 0.00126 


 54%|█████▍    | 13/24 [1:54:07<1:34:34, 515.85s/it]

ep: 13, loss_C loss_I: 0.00024 0.00128 


 58%|█████▊    | 14/24 [2:02:41<1:25:52, 515.23s/it]

ep: 14, loss_C loss_I: 0.00010 0.00126 


 62%|██████▎   | 15/24 [2:11:20<1:17:25, 516.14s/it]

ep: 15, loss_C loss_I: 0.00048 0.00126 


 67%|██████▋   | 16/24 [2:19:54<1:08:46, 515.81s/it]

ep: 16, loss_C loss_I: 0.00006 0.00126 


 71%|███████   | 17/24 [2:28:35<1:00:20, 517.24s/it]

ep: 17, loss_C loss_I: 0.00005 0.00126 


 75%|███████▌  | 18/24 [2:37:08<51:36, 516.02s/it]  

ep: 18, loss_C loss_I: 0.00005 0.00126 


 79%|███████▉  | 19/24 [2:45:52<43:12, 518.48s/it]

ep: 19, loss_C loss_I: 0.00008 0.00125 


 83%|████████▎ | 20/24 [2:54:24<34:26, 516.59s/it]

ep: 20, loss_C loss_I: 0.00006 0.00125 


 88%|████████▊ | 21/24 [3:02:56<25:45, 515.24s/it]

ep: 21, loss_C loss_I: 0.00004 0.00125 


 92%|█████████▏| 22/24 [3:11:29<17:08, 514.25s/it]

ep: 22, loss_C loss_I: 0.00003 0.00124 


 96%|█████████▌| 23/24 [3:20:02<08:33, 513.97s/it]

ep: 23, loss_C loss_I: 0.00002 0.00124 


100%|██████████| 24/24 [3:28:26<00:00, 521.10s/it]

checkpoint saved.



