### Import

In [None]:
import os
import sys
import yaml
import numpy as np
from PIL import Image
from tqdm import trange, tqdm
from collections import namedtuple

import torch
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data.dataloader as loader
import torch.nn.functional as F
import matplotlib.pyplot as plt

from train_dataset import DataServoStereo
import train_model as model



### Training on GPU or CPU?

In [None]:
# Peek at the training configuration to learn which device was requested.
with open("cfg/train_real_images.yaml", "r") as f:
    config = yaml.safe_load(f)

gpu_enabled = config["gpu"]

# Report the selected device so the choice is visible in the notebook output.
print("Training on GPU" if gpu_enabled else "Training on CPU")

### Load configuration

In [None]:
# Load the full training configuration and expose it as an immutable
# namedtuple so settings are read as attributes (arg.lr, arg.batch_size, ...).
# The file is opened in a `with` block so the handle is closed deterministically.
# NOTE(review): yaml.Loader can construct arbitrary Python objects. It is kept
# here to preserve behaviour; switch to yaml.safe_load if the config never
# relies on Python-specific tags.
with open("cfg/train_real_images.yaml", 'r') as cfg_file:
    arg_values = yaml.load(cfg_file, yaml.Loader)
arg = namedtuple('Arg', arg_values.keys())(**arg_values)


### Start GPU

In [None]:
# Configure cuDNN only when GPU training was requested. The flag is tested by
# truthiness, the same way the "Training on GPU" message is chosen.
if gpu_enabled:
    cudnn.enabled = True
    # NOTE(review): benchmark=True lets cuDNN auto-tune its algorithm choice,
    # which the PyTorch reproducibility notes list as a source of run-to-run
    # variation; confirm it is intended together with deterministic=True.
    cudnn.benchmark = True
    cudnn.deterministic = True

In [None]:
# Release cached GPU memory that PyTorch's allocator is holding but not using,
# so a re-run of the notebook starts from a clean allocator state.
torch.cuda.empty_cache() 

In [None]:
# Fix the global RNG seeds of PyTorch and NumPy so runs start from the same state.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

### Initialise neural network

In [None]:
# Key-point Gaussian renderer; starts with the initial sigma (arg.sigma_kp[0]).
# The training loop later updates kper.sigma in place.
kper = model.KeyPointGaussian(arg.sigma_kp[0], (arg.num_keypoint, *arg.im_size[1]))

# Build the encoder once, then move it to the GPU only when requested
# (previously the constructor call was duplicated in both branches).
enc = model.Encoder(arg.num_input, arg.num_keypoint, arg.growth_rate[0],
                    arg.blk_cfg_enc, arg.drop_rate, kper)
if gpu_enabled:
    enc = enc.cuda()


In [None]:
# Single parameter group: all encoder weights, regularised with arg.wd[0].
enc_param_group = {'params': enc.parameters(), 'weight_decay': arg.wd[0]}

# Adam with AMSGrad; the base learning rate is rescaled each step by adjust_lr.
optim = torch.optim.Adam([enc_param_group], lr=arg.lr, amsgrad=True)

### Function to adjust the learning rate

In [None]:
def adjust_lr(ep, ep_train, bn=True):
    """Anneal the optimizer learning rate and, optionally, BatchNorm momentum.

    An annealing factor ``a_lr`` is derived from the progress ``ep / ep_train``
    according to ``arg.lr_anne`` and applied to every parameter group of the
    global ``optim``. The training loop calls this once per batch with
    iteration counts, so ``ep`` / ``ep_train`` need not be whole epochs.

    Args:
        ep: Current position in the schedule (epoch or iteration index).
        ep_train: Total length of the schedule, in the same unit as ``ep``.
        bn: If True, also set the momentum of all BatchNorm layers of the
            global ``enc`` from the annealing factor.

    Schedules (``arg.lr_anne``):
        'step'   -- multiply by 0.4 after 30 %, 60 % and 90 % of the schedule.
        'cosine' -- half-cosine decay from 1 to 0.
        'repeat' -- cosine restarts on fixed partitions with a shrinking peak.
        other    -- constant factor 1.
    """
    if arg.lr_anne == 'step':
        # Each passed threshold contributes 1 to the exponent (bools sum as ints).
        a_lr = 0.4 ** ((ep > (0.3 * ep_train)) +
                       (ep > (0.6 * ep_train)) +
                       (ep > (0.9 * ep_train)))
    elif arg.lr_anne == 'cosine':
        a_lr = (np.cos(np.pi * ep / ep_train) + 1) / 2
    elif arg.lr_anne == 'repeat':
        partition = [0, 0.15, 0.30, 0.45, 0.6, 0.8, 1.0]
        par = int(np.digitize(ep * 1. / ep_train, partition))
        # np.digitize returns len(partition) once progress reaches 1.0 (and 0
        # for negative progress); clamp to a valid segment so indexing below
        # cannot raise IndexError or wrap around.
        par = min(max(par, 1), len(partition) - 1)
        T = (partition[par] - partition[par - 1]) * ep_train  # segment length
        t = ep - partition[par - 1] * ep_train                # offset in segment
        a_lr = 0.5 * (1 + np.cos(np.pi * t / T))
        # Later segments restart from a lower peak.
        a_lr *= 1 - partition[par - 1]
    else:
        a_lr = 1

    # Never let the factor drop below 1 % of the base learning rate.
    scale = max(a_lr, 0.01)
    for param_group in optim.param_groups:
        param_group['lr'] = scale * arg.lr

    if bn:
        def fn(m):
            if isinstance(m, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)):
                # Momentum follows the annealing factor (not the absolute
                # learning rate), clamped to [0.01, 0.9].
                m.momentum = min(scale, 0.9)
        enc.apply(fn)

### Definition of the training function

In [None]:
def _concentration_loss(keyp):
    """Mean over all key-point pairs of exp(arg.concentrate * distance) * weights.

    ``keyp`` is indexed as ``keyp[:, k]``, ``keyp[:, k + K]`` and
    ``keyp[:, k + 2K]`` with ``K = arg.num_keypoint``; the first two column
    groups are used as coordinates and the third as a per-key-point weight
    (presumably x, y and confidence -- confirm against train_model.Encoder).
    """
    n_kp = arg.num_keypoint
    pair_terms = []
    for idx_i in range(0, n_kp - 1):
        for idx_j in range(idx_i + 1, n_kp):
            # Euclidean distance between key points i and j, per sample.
            dist = torch.norm(torch.cat(
                ((keyp[:, idx_i] - keyp[:, idx_j]).unsqueeze(1),
                 (keyp[:, idx_i + n_kp] - keyp[:, idx_j + n_kp]).unsqueeze(1)),
                dim=1), dim=1)
            pair_weight = keyp[:, idx_i + 2 * n_kp] * keyp[:, idx_j + 2 * n_kp]
            pair_terms.append(dist.mul(arg.concentrate).exp().mul(pair_weight).mean())
    return sum(pair_terms) / len(pair_terms)


def _inside_loss(plug_mask_tensor, keyp_map):
    """Penalise key-point map values located where the mask equals zero."""
    outside = plug_mask_tensor.eq(0).float()
    # Resize the "outside" indicator to the spatial size of the key-point maps.
    outside = F.interpolate(outside.unsqueeze(1), size=keyp_map.size()[2:],
                            align_corners=False, mode='bilinear')
    return arg.inside * (outside.mul(keyp_map).mean())


def train(ep, loader_train):
    """Run one training epoch of the key-point encoder.

    For every batch the learning rate and the key-point sigma are rescheduled
    from the global iteration index, the enabled losses are computed and one
    optimizer step is taken. Uses the notebook globals ``arg``, ``enc``,
    ``kper``, ``optim``, ``gpu_enabled`` and ``adjust_lr``.

    Args:
        ep: Zero-based index of the current epoch.
        loader_train: Iterable of ``(img, plug_mask_tensor, plug_mask)``
            batches that exposes a sized ``.dataset`` attribute.

    Returns:
        ``(lossC_epoch, lossI_epoch)``: the concentration and inside loss of
        the first batch of this epoch. A list stays empty when the
        corresponding loss is disabled (its weight in ``arg`` is 0).

    Raises:
        ValueError: If both ``arg.concentrate`` and ``arg.inside`` are 0, so
            there is nothing to optimise.
    """
    lossC_epoch = []
    lossI_epoch = []

    # Loop-invariant parts of the global iteration counter.
    iters_before = ep * len(loader_train.dataset) // arg.batch_size
    iters_total = arg.ep_train * len(loader_train.dataset) // arg.batch_size

    for i, (img, plug_mask_tensor, plug_mask) in enumerate(loader_train):

        # Move the batch to the GPU when GPU training is enabled.
        if gpu_enabled:
            img = img.cuda()
            plug_mask_tensor = plug_mask_tensor.cuda()

        # (current iteration, total iterations) over the whole training run.
        ith = iters_before + i, iters_total

        # Update the learning rate from the current iteration count.
        adjust_lr(*ith)

        # Move sigma linearly from sigma_kp[0] to sigma_kp[1] over the first
        # half of training, then keep it at sigma_kp[1].
        kper.sigma = min(2.0 * ith[0] / ith[1], 1) * (arg.sigma_kp[1] - arg.sigma_kp[0]) + arg.sigma_kp[0]

        # Encoder output: [0] key-point vector, [1] key-point maps.
        keypL0 = enc(img)

        # Concentration loss: concentrates feature points around the edges of
        # the object (not on the object itself due to the lack of object detection).
        lossC = _concentration_loss(keypL0[0]) if arg.concentrate != 0 else None

        # Inside loss: forces the key points to be within the object boundaries.
        lossI = _inside_loss(plug_mask_tensor, keypL0[1]) if arg.inside != 0 else None

        active_losses = [l for l in (lossC, lossI) if l is not None]
        if not active_losses:
            raise ValueError('arg.concentrate and arg.inside are both 0: no loss to optimise')

        optim.zero_grad()
        sum(active_losses).backward()
        optim.step()

        # Log and record the losses of the first batch of the epoch. Only the
        # losses that were actually computed are recorded (previously the
        # disabled, i.e. None, loss was appended and crashed on .item()).
        if i == 0:
            if arg.concentrate == 0:
                tqdm.write('ep: {}, loss_I: {:.5f}  '.format(ep,lossI.item()))
            elif arg.inside == 0:
                tqdm.write('ep: {}, loss_C: {:.5f}  '.format(ep,lossC.item()))
            else:
                tqdm.write('ep: {}, loss_C loss_I: {:.5f} {:.5f} '.format(ep,lossC.item(), lossI.item()))

            if lossC is not None:
                lossC_epoch.append(lossC.item())
            if lossI is not None:
                lossI_epoch.append(lossI.item())

    return lossC_epoch, lossI_epoch


### Function to save the model

In [None]:
def save_checkpoint(base_dir):
    """Write the encoder weights to ``ckpt.pth`` inside ``base_dir``.

    The file contains a dict with the single key ``'enc_state_dict'`` holding
    the state dict of the global ``enc``.
    """
    ckpt_path = os.path.join(base_dir, 'ckpt.pth')
    torch.save({'enc_state_dict': enc.state_dict()}, ckpt_path)
    print('checkpoint saved.')

### Main function for training

In [None]:
def main_train():
    """Train the encoder for ``arg.ep_train`` epochs and save a checkpoint.

    When ``arg.task`` is ``'full'`` the output directory ``arg.dir_base`` is
    created, the configuration file is copied into it as ``servo.yaml``, the
    dataset and data loader are built, the global ``enc`` is trained and its
    weights are saved. For any other task nothing is trained.

    Returns:
        ``[concentration_losses, inside_losses]``: two lists with the values
        returned by ``train`` for each epoch. Both lists are empty when
        ``arg.task`` is not ``'full'`` (previously ``None`` was returned in
        that case, which broke the plotting cell).
    """
    import shutil  # local import: only needed for copying the config file

    loss = [[], []]

    if arg.task in ['full']:
        # Create the output directory (no error if it already exists).
        os.makedirs(arg.dir_base, exist_ok=True)
        # Keep a copy of the configuration next to the results. shutil.copy is
        # portable and raises on failure, unlike a shelled-out `cp`.
        shutil.copy("cfg/train_real_images.yaml", os.path.join(arg.dir_base, 'servo.yaml'))

        # Select grayscale or RGB loading from the number of input channels.
        use_grey = arg.num_input == 1
        if use_grey:
            print("Training with grayscale images")
        else:
            print("Training with RGB images")
        ds_train = DataServoStereo(arg, grey=use_grey)

        # NOTE(review): worker_init_fn reseeds NumPy in every worker from one
        # byte of os.urandom, so worker-side NumPy randomness is not covered by
        # the global seed; the lambda may also fail to pickle on platforms that
        # spawn worker processes -- confirm both are acceptable.
        data_param = {'pin_memory': False, 'shuffle': True, 'batch_size': arg.batch_size, 'drop_last': True,
                      'num_workers': 8, 'worker_init_fn': lambda _: np.random.seed(ord(os.urandom(1)))}

        loader_train = loader.DataLoader(ds_train, **data_param)

        # Put the encoder into training mode (enables dropout / BN updates).
        enc.train()
        print('training...')
        for ep in trange(arg.ep_train):
            lossC_epoch, lossI_epoch = train(ep, loader_train)

            loss[0].extend(lossC_epoch)
            loss[1].extend(lossI_epoch)

        # Save the trained model checkpoint.
        save_checkpoint(arg.dir_base)

    return loss


In [None]:
# Execute the training. When arg.task is 'full', `loss` is a two-element list:
# loss[0] holds the recorded concentration losses, loss[1] the inside losses.
loss = main_train()

In [None]:
# Plot the recorded losses: one value per epoch (taken from its first batch).
fig, ax = plt.subplots()
ax.plot(loss[0], label='Concentration Loss')
ax.plot(loss[1], label='Inside Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend()
plt.show()
