# Cross Entropy
> PyTorch와 NumPy로 구현하는 Cross Entropy

- toc: true 
- badges: true
- comments: true
- categories: [Implementation, AI-math]
- image: images/cross_entropy_fig1.png

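# Cross Entropy란?

Cross entropy는 정답 분포 $p$와 모델이 예측한 분포 $q$ 사이의 차이를 측정하는 값으로, $H(p, q) = -\sum_{x} p(x) \log q(x)$로 정의됩니다. 분류 문제에서 $p$가 정답 클래스 $y$에 대한 one-hot 분포이면 $H(p, q) = -\log q(y)$가 됩니다. 따라서 아래 구현은 logit에 log-softmax를 적용해 $\log q$를 구한 뒤, 정답 클래스의 log 확률에 음수를 취하는(negative log-likelihood) 두 단계로 이루어집니다.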

# Code Implementation

## Library 호출

In [1]:
import math
import numbers
from typing import Optional, Tuple, Sequence, Union, Any

import numpy as np

import torch
import torch.nn as nn

## NumPy로 구현하기

In [235]:
class NumpyCrossEntropy:
    """Cross-entropy loss implemented with NumPy (forward pass only).

    ``cross_entropy(y_pred, y)`` is ``negative_log_likelihood(log_softmax(y_pred), y)``.
    """

    @staticmethod
    def log_softmax(ndarray: np.ndarray, axis: int = -1) -> np.ndarray:
        """Return ``log(softmax(ndarray))`` computed along ``axis``.

        Uses ``s - log(sum(exp(s)))`` with ``s = ndarray - max(ndarray)``.
        Subtracting the max leaves the result unchanged but prevents ``exp``
        from overflowing. Taking the log of the normalizer, rather than the
        log of each normalized probability, avoids ``log(0) = -inf`` when a
        probability underflows to zero.
        """
        shifted = ndarray - np.amax(ndarray, axis=axis, keepdims=True)
        log_normalizer = np.log(np.exp(shifted).sum(axis=axis, keepdims=True))
        return shifted - log_normalizer

    @staticmethod
    def negative_log_likelihood(
        y_pred: np.ndarray,
        y: np.ndarray,
        axis: int = -1,
        reduce: str = "mean",
    ) -> np.ndarray:
        """Return the negative log-likelihood of targets ``y`` under ``y_pred``.

        Args:
            y_pred: 2-D array of log-probabilities. The layout is
                (batch, classes) when ``axis`` is 1/-1 and (classes, batch)
                when ``axis`` is 0.
            y: 1-D array of integer class indices, one per batch element.
            axis: the class axis of ``y_pred``.
            reduce: ``"mean"`` or ``"sum"`` reduce over the batch. Any other
                value returns the per-sample losses.
        """
        assert y_pred.ndim == 2 and y.ndim == 1
        # The batch length is len(y). y_pred.shape[0] is the number of
        # classes when axis == 0, so it must not be used here.
        batch_idx = np.arange(y.shape[0])
        if axis == 0:
            log_likelihood = y_pred[y, batch_idx]
        else:
            log_likelihood = y_pred[batch_idx, y]
        nll = -log_likelihood
        if reduce == "mean":
            return np.mean(nll)
        if reduce == "sum":
            return np.sum(nll)
        return nll

    def cross_entropy(
        self,
        y_pred: np.ndarray,
        y: np.ndarray,
        axis: int = -1,
        reduce: str = "mean",
    ) -> np.ndarray:
        """Return the cross-entropy between logits ``y_pred`` and targets ``y``.

        ``axis`` is the class axis of ``y_pred``. ``reduce`` is one of
        ``"mean"``, ``"sum"`` or ``"none"``.
        """
        assert axis in [0, 1, -1]
        assert reduce in ["mean", "sum", "none"]
        # Normalize over the same class axis that the NLL step indexes.
        log_probs = self.log_softmax(y_pred, axis)
        return self.negative_log_likelihood(log_probs, y, axis, reduce)


nce = NumpyCrossEntropy()

## PyTorch로 구현하기

In [589]:
class LogSoftmax(torch.autograd.Function):
    """Log-softmax along ``dim`` with a hand-written backward pass.

    Forward: ``(x - c) - log(sum(exp(x - c)))``, where ``c`` is the max along
    ``dim``. Softmax is shift-invariant, so subtracting ``c`` does not change
    the result but keeps ``exp`` from overflowing. Taking the log of the sum,
    rather than the log of each probability, avoids ``log(0) = -inf`` when a
    probability underflows.

    Backward: ``grad_x = grad_out - softmax(x) * sum(grad_out, dim)``.
    """

    @staticmethod
    def forward(ctx: Any, tensor: Any, dim: int = -1) -> Any:
        # softmax(x) = softmax(x + c); use c = max for numerical stability.
        c = torch.amax(tensor, dim=dim, keepdim=True)
        shifted = tensor - c
        log_normalizer = torch.log(torch.exp(shifted).sum(dim=dim, keepdim=True))
        log_probs = shifted - log_normalizer
        # The backward pass only needs the softmax probabilities. The plain
        # int dim is stored as a ctx attribute rather than as a saved tensor.
        ctx.save_for_backward(torch.exp(log_probs))
        ctx.dim = dim
        return log_probs

    @staticmethod
    def backward(ctx: Any, grad_outputs: Any) -> Any:
        # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/SoftMax.cpp#L219
        probs, = ctx.saved_tensors
        # Compute out of place: the incoming gradient tensor is owned by
        # autograd and must not be modified.
        grad_input = grad_outputs - probs * grad_outputs.sum(dim=ctx.dim, keepdim=True)
        # No gradient for the `dim` argument.
        return grad_input, None


class NegativeLogLikelihoodLoss(torch.autograd.Function):
    """Negative log-likelihood loss with a hand-written backward pass.

    ``apply(y_pred, y, dim=-1, reduce="mean", ignore_index=-1)``:

    * ``y_pred``: 2-D tensor of log-probabilities. The layout is
      (batch, classes) when ``dim`` is 1/-1 and (classes, batch) when
      ``dim == 0``.
    * ``y``: 1-D integer tensor of target classes, one per batch element.
    * ``reduce``: ``"mean"`` (over non-ignored samples), ``"sum"``, or
      ``"none"`` (per-sample losses, with 0 for ignored samples).
    * ``ignore_index``: target value whose samples contribute neither loss
      nor gradient. It may lie outside ``[0, classes)``, e.g. -100.
    """

    @staticmethod
    def forward(ctx: Any, y_pred: Any, y: Any, dim: int = -1,
                reduce: str = "mean", ignore_index: int = -1) -> Any:
        if reduce not in ("mean", "sum", "none"):
            raise ValueError(
                f"reduce must be 'mean', 'sum' or 'none', got {reduce!r}")
        mask = y.ne(ignore_index)
        # Redirect ignored targets to class 0 so the gather below never goes
        # out of bounds. Their contribution is zeroed through `mask`.
        safe_y = y.masked_fill(~mask, 0)
        # Index samples by the batch length (len(y)). y_pred.size(0) is the
        # number of classes when dim == 0.
        batch_idx = torch.arange(y.size(0), device=y.device)
        if dim == 0:
            log_likelihood = y_pred[safe_y, batch_idx]
        else:
            log_likelihood = y_pred[batch_idx, safe_y]
        nll = (-log_likelihood).masked_fill(~mask, 0.0)

        ctx.save_for_backward(safe_y, mask)
        # Plain Python metadata goes on ctx directly, not through save_for_backward.
        ctx.dim = dim
        ctx.reduce = reduce
        ctx.input_shape = y_pred.shape
        ctx.input_dtype = y_pred.dtype

        if reduce == "mean":
            return nll[mask].mean()
        if reduce == "sum":
            return nll.sum()
        return nll

    @staticmethod
    def backward(ctx: Any, grad_outputs: Any) -> Any:
        safe_y, mask = ctx.saved_tensors
        batch_size = mask.size(0)
        if ctx.reduce == "none":
            per_sample = grad_outputs
        else:
            # A reduced loss has a scalar upstream gradient; each sample gets a copy.
            per_sample = grad_outputs.expand(batch_size)
            if ctx.reduce == "mean":
                per_sample = per_sample / mask.sum()
        # d(nll_i)/d(log_likelihood_i) is -1 for kept samples. Ignored samples
        # get exactly 0; torch.where also discards a NaN from an all-ignored mean.
        per_sample = torch.where(mask, -per_sample, torch.zeros_like(per_sample))
        # Only the gathered entry (sample i, class y_i) receives gradient.
        grad_input = torch.zeros(
            ctx.input_shape, dtype=ctx.input_dtype, device=grad_outputs.device)
        batch_idx = torch.arange(batch_size, device=grad_input.device)
        if ctx.dim == 0:
            grad_input[safe_y, batch_idx] = per_sample.to(grad_input.dtype)
        else:
            grad_input[batch_idx, safe_y] = per_sample.to(grad_input.dtype)
        # No gradients for y, dim, reduce or ignore_index.
        return grad_input, None, None, None, None
    

_TensorOrTensors = Union[torch.Tensor, Sequence[torch.Tensor]]


class CrossEntropyLoss(nn.Module):
    """Cross-entropy loss: ``NegativeLogLikelihoodLoss(LogSoftmax(y_pred), y)``.

    ``forward`` runs through the custom ``LogSoftmax`` and
    ``NegativeLogLikelihoodLoss`` autograd Functions, so ``torch.autograd.grad``
    works on its output. ``backward`` additionally computes the closed-form
    gradient ``softmax(y_pred) - one_hot(y)`` from values cached by the most
    recent ``forward`` call, so the two paths can be compared.

    Args:
        reduce: ``"mean"``, ``"sum"`` or ``"none"``.
        ignore_index: target value whose samples contribute neither loss nor
            gradient.
    """

    log_softmax = LogSoftmax.apply
    negative_log_likelihood = NegativeLogLikelihoodLoss.apply

    def __init__(self, reduce: str = "mean", ignore_index: int = -1):
        # nn.Module.__init__ must run first. Without it the module lacks the
        # hook/parameter registries, and calling it raises AttributeError.
        super().__init__()
        self.reduce = reduce
        self.ignore_index = ignore_index

    def forward(
        self,
        y_pred: _TensorOrTensors,
        y: _TensorOrTensors,
        dim: int = -1,
    ) -> _TensorOrTensors:
        log_probs = self.log_softmax(y_pred, dim)
        ce_loss = self.negative_log_likelihood(
            log_probs, y, dim, self.reduce, self.ignore_index)
        # Cache the softmax probabilities for the manual backward. They are
        # detached so the cache does not keep the autograd graph alive.
        probs = log_probs.detach().exp()
        self.save_for_backward(probs, y, dim)
        return ce_loss

    def save_for_backward(self, *args):
        """Store ``args`` on the module for use by ``backward``."""
        self.saved_tensors = args

    @torch.no_grad()
    def backward(self, grad_outputs: _TensorOrTensors) -> _TensorOrTensors:
        """Return d(loss)/d(y_pred) for the last ``forward`` call, scaled by ``grad_outputs``.

        ``grad_outputs`` is the upstream gradient. For ``"mean"``/``"sum"`` it
        is a scalar, or anything that broadcasts against ``y_pred``. For
        ``"none"`` it holds one value per sample.
        """
        probs, y, dim = self.saved_tensors
        class_axis = 0 if dim == 0 else -1
        num_classes = probs.size(class_axis)
        mask = y.ne(self.ignore_index)
        # one_hot rejects out-of-range targets such as -1. Map ignored targets
        # to class 0; their gradient is zeroed by the mask below.
        safe_y = y.masked_fill(~mask, 0)
        one_hot = torch.nn.functional.one_hot(
            safe_y, num_classes=num_classes).to(probs.dtype)
        sample_mask = mask.to(probs.dtype)
        if class_axis == 0:  # (classes, batch) layout
            one_hot = one_hot.t()
            sample_mask = sample_mask.unsqueeze(0)
        else:  # (batch, classes) layout
            sample_mask = sample_mask.unsqueeze(-1)
        ce_grad = (probs - one_hot) * sample_mask
        if self.reduce == "mean":
            # Average over kept samples only. The clamp avoids 0/0 when every
            # sample is ignored.
            ce_grad = ce_grad / mask.sum().clamp(min=1)
        elif self.reduce == "none" and grad_outputs.dim() == 1:
            # One upstream value per sample: broadcast it across the class axis.
            grad_outputs = grad_outputs.unsqueeze(class_axis)
        return grad_outputs * ce_grad
    
    
class PyTorchCrossEntropy:
    """Bundle of the custom PyTorch cross-entropy pieces used below.

    Attributes:
        log_softmax: ``LogSoftmax.apply``.
        negative_log_likelihood: ``NegativeLogLikelihoodLoss.apply``.
        cross_entropy: a ``CrossEntropyLoss`` module with default settings.
    """

    def __init__(self):
        functions = (LogSoftmax.apply, NegativeLogLikelihoodLoss.apply)
        self.log_softmax, self.negative_log_likelihood = functions
        self.cross_entropy = CrossEntropyLoss()


tce = PyTorchCrossEntropy()

## Figure 그리기

In [237]:
import matplotlib.pyplot as plt

# 결과값 비교

In [596]:
import random
from functools import partial


# Problem size for the comparison experiments.
batch_size, vocab_size = 8, 3000

# Tolerances used when comparing outputs of the different implementations.
rtol, atol = 1e-4, 1e-6
isclose = partial(torch.isclose, atol=atol, rtol=rtol)

In [597]:
# Random logits of shape (batch_size, vocab_size), drawn from N(0, 1).
y_pred = [[random.normalvariate(mu=0., sigma=1.) for _ in range(vocab_size)] for _ in range(batch_size)]
y_pred_torch = torch.FloatTensor(y_pred)
y_pred_torch.requires_grad = True
y_pred_numpy = y_pred_torch.detach().numpy()

# Integer targets in [0, vocab_size). random.randint's upper bound is
# inclusive, so randint(0, vocab_size) could produce the invalid index vocab_size.
y = [random.randrange(vocab_size) for _ in range(batch_size)]
y_torch = torch.LongTensor(y)
y_numpy = y_torch.numpy()

In [275]:
# Reference: PyTorch's built-in NLL on log-softmax outputs, one loss per sample.
nn.functional.nll_loss(
    input=nn.functional.log_softmax(y_pred_torch, dim=-1),
    target=y_torch,
    reduction="none",
)

tensor([8.4246, 8.8899, 9.0440, 8.9059, 6.7580, 7.8764, 7.6204, 7.6686],
       grad_fn=<NllLossBackward>)

In [272]:
# Custom NLL Function. Function.apply takes positional arguments only:
# (log_probs, target, dim, reduce).
tce.negative_log_likelihood(
    nn.functional.log_softmax(y_pred_torch, dim=-1), y_torch, -1, "sum"
)

TypeError: forward() takes 3 positional arguments but 5 were given

## forward pass

In [241]:
# Per-sample cross-entropy from the NumPy implementation (class axis = last).
nce.cross_entropy(y_pred_numpy, y_numpy, axis=-1, reduce="none")

array([9.205082 , 6.5644517, 7.9004273, 9.04045  , 7.5050826, 8.987215 ,
       7.9438257, 7.5462923], dtype=float32)

In [232]:
# Compare PyTorch's built-in loss with the NumPy and custom PyTorch versions.
ce_result = nn.CrossEntropyLoss()(y_pred_torch, y_torch)
ce_numpy = nce.cross_entropy(y_pred_numpy, y_numpy)
ce_torch = tce.cross_entropy(y_pred_torch, y_torch)

# isclose returns a bool tensor, and that result must actually be checked.
# Discarding it would let a numerical mismatch pass silently.
success = (
    bool(isclose(ce_result, ce_torch).all())
    and bool(isclose(ce_result, torch.tensor(ce_numpy)).all())
)

print("Do both output the same tensors?", "🔥" if success else "💩")
if not success:
    raise Exception("Something went wrong")

Do both output the same tensors? 🔥


## backward pass

In [8]:
# Gradients of the loss w.r.t. the logits, from three sources:
#   ce_grad     - autograd through PyTorch's nn.CrossEntropyLoss
#   my_ce_grad1 - autograd through the custom LogSoftmax / NLL Functions
#   my_ce_grad2 - closed-form softmax(y_pred) - one_hot(y) via CrossEntropyLoss.backward
ce_grad = torch.autograd.grad(ce_result, y_pred_torch, retain_graph=True)[0]
my_ce_grad1 = torch.autograd.grad(ce_torch, y_pred_torch, retain_graph=True)[0]
# The loss is a scalar, so its upstream gradient is a scalar one.
my_ce_grad2 = tce.cross_entropy.backward(torch.ones_like(ce_torch))

# Check the isclose results instead of discarding them.
success = (
    bool(isclose(ce_grad, my_ce_grad1).all())
    and bool(isclose(ce_grad, my_ce_grad2).all())
)

print("Do both output the same tensors?", "🔥" if success else "💩")
if not success:
    raise Exception("Something went wrong")

Do both output the same tensors? 🔥


## TODO
- reduce 기능 추가
- forward와 backward 순서대로 수식으로 설명하면서 코드로 연결시키기
- plot 그리기