# Loss

> Loss functions, plus a name-based registry used to combine them into weighted training objectives

In [1]:
#| default_exp loss

In [2]:
#| hide
from nbdev.showdoc import *

In [3]:
#| export

import torch
import torch.nn as nn
from Noise2Model.utils import StandardNormal

In [4]:

from torch import randn as torch_randn
from fastai.vision.all import test_eq


In [5]:
#| export

# Registry of loss classes, keyed by lower-cased class name.
loss_class_dict = {}

def regist_loss(loss):
    '''
    Class decorator: register `loss` in `loss_class_dict` under its lower-cased class name.

    Returns the class unchanged so it can be used as `@regist_loss`.
    Raises ValueError if a loss with the same (lower-cased) name is already registered.
    '''
    loss_name = loss.__name__.lower()
    # Explicit check instead of `assert`, so duplicates are still caught under `python -O`.
    if loss_name in loss_class_dict:
        raise ValueError('there is already a registered loss: %s in loss_class_dict.' % loss_name)
    loss_class_dict[loss_name] = loss

    return loss

def get_loss_class(loss_name:str):
    '''
    Return the loss class registered under `loss_name` (case-insensitive).

    Raises KeyError if no such loss is registered.
    '''
    return loss_class_dict[loss_name.lower()]




## Base functions

In [6]:
#| export

class Loss(nn.Module):
    '''
    Weighted combination of registered loss terms, described by a string.

    The loss string is a '+'-separated list of `weight*name` terms, e.g. `'1*l1 + r0.1*nllloss'`.
    Spaces are removed and the string is lower-cased before parsing. A weight containing
    `r` marks the term as ratio-scaled: its value is additionally multiplied by the
    `ratio` argument of `forward`. A term without `*` (e.g. `'l1'`) gets weight 1.

    Args
        loss_string : description of the training losses (see above)
        tmp_info    : (optional) iterable of registered loss names that are only evaluated
                      for monitoring, under `torch.no_grad()`
    Raises
        RuntimeError : if a term names a loss that is not in `loss_class_dict`
    '''
    def __init__(self, loss_string, tmp_info=None):
        super().__init__()
        loss_string = loss_string.replace(' ', '').lower()

        # parse loss string
        self.loss_list = []
        for single_loss in loss_string.split('+'):
            weight, sep, name = single_loss.partition('*')
            if not sep:
                # bare loss name without an explicit weight -> weight 1
                weight, name = '1', weight
            ratio = 'r' in weight
            weight = float(weight.replace('r', ''))

            if name not in loss_class_dict:
                raise RuntimeError('undefined loss term: {}'.format(name))
            self.loss_list.append({'name': name,
                                   'weight': weight,
                                   'func': loss_class_dict[name](),
                                   'ratio': ratio})

        # parse temporal information string
        self.tmp_info_list = []
        for name in (tmp_info or []):
            lname = name.lower()
            if lname not in loss_class_dict:
                raise RuntimeError('undefined loss term: {}'.format(lname))
            # Look up with the lower-cased key: the registry only stores lower-case names.
            self.tmp_info_list.append({'name': lname,
                                       'func': loss_class_dict[lname]()})


    def forward(self, input_data, model_output, data, module, loss_name=None, change_name=None, ratio=1.0):
        '''
        forward all loss and return as dict format.
        Args
            input_data   : input of the network (also in the data)
            model_output : output of the network
            data         : entire batch of data
            module       : dictionary of modules (for another network forward)
            loss_name    : (optional) choose specific loss with name (case-insensitive)
            change_name  : (optional) replace name of chosen loss
            ratio        : (optional) percentage of learning procedure for increase weight during training
        Return
            If `loss_name` is given: a single-entry dict {name: weighted loss}.
            Otherwise a tuple (losses, tmp_info):
                losses   : dict of weighted training losses, keyed by loss name
                tmp_info : dict of monitoring values (computed without gradient)
        Raises
            RuntimeError : if `loss_name` is not one of the configured training losses
        '''
        loss_arg = (input_data, model_output, data, module)

        # calculate only specific loss 'loss_name' and change its name to 'change_name'
        if loss_name is not None:
            loss_name = loss_name.lower()
            for single_loss in self.loss_list:
                if loss_name == single_loss['name']:
                    loss = single_loss['weight'] * single_loss['func'](*loss_arg)
                    if single_loss['ratio']:
                        loss = loss * ratio
                    key = change_name if change_name is not None else single_loss['name']
                    return {key: loss}
            raise RuntimeError('there is no such loss in training losses: {}'.format(loss_name))

        # normal case: calculate all training losses at one time
        losses = {}
        for single_loss in self.loss_list:
            loss = single_loss['weight'] * single_loss['func'](*loss_arg)
            if single_loss['ratio']:
                loss = loss * ratio
            losses[single_loss['name']] = loss

        # calculate temporal information (monitoring only, no gradient needed)
        tmp_info = {}
        with torch.no_grad():
            for single_tmp_info in self.tmp_info_list:
                tmp_info[single_tmp_info['name']] = single_tmp_info['func'](*loss_arg)

        return losses, tmp_info

In [7]:
#| export

def _mse(x, y, level=2):
    '''
    Mean over all elements of |(x - y) ** level|.

    `level=1` gives the mean absolute error and `level=2` the mean squared error.
    Raises AssertionError if `x` and `y` do not have the same shape.
    '''
    assert x.shape == y.shape
    # The power is applied before abs (not after), exactly as in the elementwise definition.
    return torch.abs((x - y) ** level).mean()


@regist_loss
class L1Loss(nn.Module):
    '''Mean absolute error between `model_output['recon']` and `data['clean']`.'''
    def forward(self, input_data, model_output, data, module):
        return _mse(model_output['recon'], data['clean'], level=1)


@regist_loss
class L2Loss(nn.Module):
    '''Mean squared error between `model_output['recon']` and `data['clean']`.'''
    def forward(self, input_data, model_output, data, module):
        return _mse(model_output['recon'], data['clean'], level=2)

## NLL

In [8]:
#| export

@regist_loss
class NLLLoss(nn.Module):
    '''
    Negative log-likelihood: mean of -(ldj + log p(z)) with p the `StandardNormal` prior.

    Reads `model_output['z']` and `model_output['ldj']` (presumably a flow latent and its
    log-determinant of the Jacobian — confirm against the model).
    '''
    def __init__(self):
        super().__init__()
        self.dist = StandardNormal()

    def forward(self, input_data, model_output, data, module):
        latent = model_output['z']
        log_det = model_output['ldj']

        log_likelihood = log_det + self.dist.log_prob(latent)
        return torch.mean(-log_likelihood)

@regist_loss
class std_z(nn.Module):
    '''
    Monitoring metric: standard deviation of `model_output['z']` over dims 1-3,
    computed per entry of dim 0 and then averaged.
    '''
    def __init__(self):
        super().__init__()
        # NOTE(review): not used by forward; kept so the module's attributes stay unchanged.
        self.dist = StandardNormal()

    def forward(self, input_data, model_output, data, module):
        # torch.var with its default (unbiased) estimator, reduced over dims 1, 2 and 3
        per_sample_var = torch.var(model_output['z'], dim=[1, 2, 3])
        return torch.mean(torch.sqrt(per_sample_var))

## GAN

In [9]:
#| export

@regist_loss
class GANLoss(nn.Module):
    '''
    Wasserstein GAN loss with gradient penalty (WGAN-GP).

    `forward` dispatches on `model_output['training_mode']`:
      - 'generator': -mean(critic_fake) * lambda_gen
      - 'critic'   : mean(critic_fake) - mean(critic_real) + lambda_gp * GP,
                     where GP is computed by `_gradient_penalty`.

    Args
        lambda_gp  : weight of the gradient penalty in the critic loss
        lambda_gen : weight of the adversarial term in the generator loss
    '''
    def __init__(self, lambda_gp=10., lambda_gen=1.0):
        super().__init__()
        self.lambda_gen = lambda_gen
        self.lambda_gp = lambda_gp

    def get_target_tensor(self, prediction, target_is_real):
        '''
        Return the real/fake label broadcast to the shape of `prediction`.

        NOTE(review): `self.real_label` / `self.fake_label` are never assigned in this class,
        so this raises AttributeError unless they are set externally. `forward` does not use it.
        '''
        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(prediction)

    def forward(self, input_data, model_output, data, module):
        '''
        Args
            input_data   : unused
            model_output : dict with 'training_mode' and 'critic_fake'; in 'critic' mode also
                           'critic_real', 'critic_noise', 'real' and 'fake'
            data         : batch dict; 'clean' is read when model_output['critic_noise'] is truthy
            module       : dict of networks; module['critic'] is used for the gradient penalty
        Return
            scalar loss tensor
        Raises
            ValueError : if training_mode is neither 'generator' nor 'critic'
        '''
        training_mode = model_output['training_mode']
        if training_mode == 'generator':
            return -model_output['critic_fake'].mean() * self.lambda_gen
        if training_mode != 'critic':
            # explicit raise instead of `assert False`, which is stripped under `python -O`
            raise ValueError(f'Invalid training mode: {training_mode}')

        if model_output['critic_noise']:
            # Penalty inputs are cat([residual w.r.t. clean, clean], dim=1); presumably this
            # mirrors how the critic input is built elsewhere (see REMARKS).
            #REMARKS: If this part is changed, _forward_fn must also be changed.
            fake_noise = (model_output['fake']-data['clean']).requires_grad_(True)
            real_noise = (model_output['real']-data['clean']).requires_grad_(True)
            gp_loss = self._gradient_penalty(
                module['critic'],
                torch.cat([real_noise, data['clean']],dim=1).requires_grad_(True),
                torch.cat([fake_noise, data['clean']],dim=1).requires_grad_(True)
            )
        else:
            gp_loss = self._gradient_penalty(module['critic'], model_output['real'], model_output['fake'])
        return model_output['critic_fake'].mean() - model_output['critic_real'].mean() \
            + self.lambda_gp * gp_loss

    def _gradient_penalty(self, D, real_samples, fake_samples):
        '''
        WGAN-GP gradient penalty: mean over the batch of (||grad_x D(x_hat)||_2 - 1)^2, where
        x_hat = alpha * real + (1 - alpha) * fake and alpha ~ U[0, 1) is drawn once per sample.

        Works for inputs of any rank (dim 0 is the batch) and any critic output shape.
        '''
        batch_size = real_samples.size(0)
        # WGAN-GP samples uniformly on the segment between real and fake; a normal draw
        # (torch.randn) would extrapolate outside it. Shape (B, 1, ..., 1) broadcasts over
        # every non-batch dimension.
        alpha_shape = (batch_size,) + (1,) * (real_samples.dim() - 1)
        alpha = torch.rand(alpha_shape, device=real_samples.device, dtype=real_samples.dtype)
        interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
        d_interpolates = D(interpolates)
        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(d_interpolates),
            create_graph=True,   # the penalty itself must be differentiable for the critic update
            retain_graph=True,
        )[0]
        gradients = gradients.reshape(batch_size, -1)
        return ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    


In [10]:
#| export

@regist_loss
class real_sub_fake(nn.Module):
    '''Monitoring metric: mean critic score on real samples minus mean critic score on fake samples.'''
    def forward(self, input_data, model_output, data, module):
        mean_real_score = model_output['critic_real'].mean()
        mean_fake_score = model_output['critic_fake'].mean()
        return mean_real_score - mean_fake_score

In [11]:
#| hide
# Regenerate the library module from this notebook's `#| export` cells (target set by `#| default_exp loss`).
import nbdev; nbdev.nbdev_export()