# Cross-entropy with torch 

Case 구현 및 검증 

- case 1 : data array with 1 dimension 
- case 2 : data array with n dimension 
- case 3 : (h,w) data array with (d) dimension, e.g. image data 
- case 4 : BCE (binary cross-entropy) 

In [1]:
'''
>>> # Example of target with class indices
>>> loss = nn.CrossEntropyLoss()
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.empty(3, dtype=torch.long).random_(5)
>>> output = loss(input, target)
>>> output.backward()
>>>
>>> # Example of target with class probabilities
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.randn(3, 5).softmax(dim=1)
>>> output = loss(input, target)
>>> output.backward()


pytorch document example 이다. 

두 예제의 가장 큰 차이점은 target 의 유형 (indices or class probabilities) 이다. 
따라서 indices(long) => class probabilities(one-hot) 으로 변환을 해서 쓰거나 혹은 쓰지 않아도 된다. 
segmentation 의 경우에, iou or dice 를 계산해주어야 할 일이 많으므로, target 을 one-hot encoded matrix 로 변환해서 넣어주는 경우도 흔하다. 

small tips for usage : 

- torch 의 CE loss 는, softmax 내장이기 때문에, 모델의 마지막 layer 가 softmax 일 필요는 없다. 
- LogSoftmax + NLLLoss combined into one function => torch.nn.functional.cross_entropy (nn.CrossEntropyLoss)  
- (e.g. [2, 3, 1, 1] vs [[0 0 1 0], [0 0 0 1], [0 1 0 0], [0 1 0 0]] for CrossEntropyLoss and BCELoss) 


- 단, BCE 를 사용할 경우에, sigmoid 를 따로 적용해 주어야 한다. 
  혹은 sigmoid 가 내장된 BCEWithLogitsLoss 를 사용하면 된다. 

- nn.CrossEntropyLoss 의 weight 를 잡아주면, class imbalance 를 보정할 수 있다 (class 별 고정 가중치이므로, sample 난이도에 따라 가중하는 focal loss 와 동일하지는 않다)
- ref : https://medium.com/unpackai/cross-entropy-loss-in-ml-d9f22fc11fe0

'''

'\n# Example of target with class indices\nloss = nn.CrossEntropyLoss()\ninput = torch.randn(3, 5, requires_grad=True)\ntarget = torch.empty(3, dtype=torch.long).random_(5)\noutput = loss(input, target)\noutput.backward()\n\n# Example of target with class probabilities\ninput = torch.randn(3, 5, requires_grad=True)\ntarget = torch.randn(3, 5).softmax(dim=1)\noutput = loss(input, target)\noutput.backward()\n\n\npytorch document example 이다. \n\n두 예제의 가장 큰 차이점은 target 의 유형 (indices or class probabilities) 이다. \n따라서 indices(long) => class probabilities(one-hot) 으로 변환을 해서 쓰거나 혹은 쓰지 않아도 된다. \nsegmentation 의 경우에, iou or dice 를 계산해주어야 할 일이 많으므로, one-hot encoded 를 matirx 사이에 넣어주는 경우도 흔하다. \n\nsmall tips for usage : \n\n- torch 의 CE loss 는, softmax 내장이기 때문에, 모델의 마지막 layer 가 softmax 일 필요는 없다. \n- LogSoftmax + NLLLoss combined into one function => torch.nn.functional.CrossEntropy  \n- (e.g. [2, 3, 1, 1] vs [[0 0 1 0], [0 0 0 1], [0 1 0 0], [0 1 0 0] for CrossEntropyLoss and BCELoss) \n\n\n- 단, BCE

In [2]:
import torch.nn.functional as F 
import torch.nn as nn
import torch 

# Case 1

In [3]:
# Case 1 ------------------------------------------------
# Setup:
#   - `input` is a single unbatched sample: a 1-D tensor of 3 class logits.
#   - `target` is a float64 vector of per-class probabilities (one-hot on
#     class 0), not an integer class index.

input = torch.tensor([-0.7, 0.2, 0.2])
target = torch.tensor([1.0, 0.0, 0.0], dtype=torch.float64)

# "mean" is the default reduction of nn.CrossEntropyLoss.
loss = nn.CrossEntropyLoss()
output = loss(input, target)
print(target)
print(output.item())

tensor([1., 0., 0.], dtype=torch.float64)
1.7782022953033447


In [4]:
def ce_implementation(input, target, indice):
    """Compute cross-entropy by hand from raw logits.

    Args:
        input: tensor of raw class logits with the class axis last; softmax
            is applied inside this function.
        target: class probabilities shaped like ``input`` when ``indice`` is
            falsy, or an integer (long) tensor of class indices when
            ``indice`` is truthy.
        indice: whether ``target`` holds class indices that must be one-hot
            encoded before use.

    Returns:
        Scalar tensor ``-sum(target * log_softmax(input))``. The sum is not
        divided by the number of samples.
    """
    # log_softmax is the numerically stable form of log(softmax(x)).
    log_probs = F.log_softmax(input, dim=-1)
    if indice:
        # Take the class count from the logits; a fixed count only works for
        # inputs that happen to have exactly that many classes.
        target = F.one_hot(target, num_classes=input.shape[-1])

    return -torch.sum(target * log_probs)

In [5]:
# Hand-rolled cross-entropy on the Case 1 tensors. `target` already holds
# class probabilities, so no one-hot conversion is requested.
output_ = ce_implementation(input, target, indice=False)
print(output_)

tensor(1.7782, dtype=torch.float64)


# Case 2

In [6]:
# Case 2 ------------------------------------------------
# Setup:
#   - `input` is a 2-D batch of logits: 3 samples x 4 classes.
#   - `target` holds one integer class index per sample.

input = torch.tensor(
    [
        [0.7, 0.2, 0.2, 0.1],
        [0.2, 0.7, 0.2, 0.1],
        [0.2, 0.2, 0.8, 0.1],
    ]
)
target = torch.tensor([2, 3, 0], dtype=torch.int64)

# "mean" is the default reduction of nn.CrossEntropyLoss.
loss = nn.CrossEntropyLoss()
output = loss(input, target)
print(target)
print(output.item())

tensor([2, 3, 0])
1.5616998672485352


In [7]:
input.shape

torch.Size([3, 4])

In [8]:
target.shape

torch.Size([3])

In [9]:
def ce_imple_n_dimension(input, target):
    """Compute batch-mean cross-entropy by hand for index targets.

    Args:
        input: raw logits with the class axis last, e.g. ``(batch, classes)``.
        target: integer (long) class indices, one per sample.

    Returns:
        Scalar tensor: the summed negative log-probability of the true
        classes divided by ``input.shape[0]``.
    """
    # log_softmax is the numerically stable form of log(softmax(x)).
    log_probs = F.log_softmax(input, dim=-1)
    # Take the class count from the logits. Letting one_hot infer it from
    # target.max() yields too few columns whenever the highest class is
    # missing from this batch.
    onehot_target = F.one_hot(target, num_classes=input.shape[-1])
    return -torch.sum(onehot_target * log_probs) / input.shape[0]

In [10]:
# Hand-rolled batch-mean cross-entropy on the Case 2 tensors, for comparison
# with the nn.CrossEntropyLoss value.
output_ = ce_imple_n_dimension(input, target)

In [11]:
output_

tensor(1.5617)

# case 3
- 2d array with n dimension 
- 2d array with N categories 

In [36]:
# Case 3 ------------------------------------------------
# Setup:
#   - `pred` is a batch of 4 random logit maps with n_class channels, each
#     20x20.
#   - `target` stores one integer class index per pixel, so it has no
#     channel axis.
n_class = 2
pred = torch.randn((4, n_class, 20, 20))
target = torch.randint(low=0, high=n_class, size=(4, 20, 20))
# => note on the shapes (displayed as the cell result):
'''
pred.shape
torch.Size([1, 5, 20, 20])

target.shape
torch.Size([1, 20, 20])

target shape 에서, (1,1,20,20) 이렇게 맞춰주지 말고 
1 channel 을 무시해줘야 한다. 

'''

'\npred.shape\ntorch.Size([1, 5, 20, 20])\n\ntarget.shape\ntorch.Size([1, 20, 20])\n\ntarget shape 에서, (1,1,20,20) 이렇게 맞춰주지 말고 \n1 channel 을 무시해줘야 한다. \n\n'

In [33]:
pred.shape

torch.Size([4, 1, 20, 20])

In [34]:
target.shape

torch.Size([4, 20, 20])

In [35]:
# Built-in cross-entropy on the spatial prediction: `pred` carries the class
# axis at dim 1 and `target` holds per-pixel class indices.
loss = nn.CrossEntropyLoss(reduction="mean")
output = loss(pred, target)
print(output)

tensor(0.)


In [16]:
# Let's just try reproducing the value by hand.
# Softmax over the class axis (dim 1), then move that axis to the end so the
# layout matches what F.one_hot produces for the target.

pred1 = torch.softmax(pred, dim=1)
print(pred1.shape)
pred2 = pred1.permute(0,2,3,1)
print(pred2.shape)

torch.Size([4, 5, 20, 20])
torch.Size([4, 20, 20, 5])


In [17]:
pred2.shape

torch.Size([4, 20, 20, 5])

In [18]:
# One-hot encode the per-pixel class indices. The class count is taken from
# the prediction's channel axis: letting F.one_hot infer it from target.max()
# yields too few channels whenever the highest class is absent from `target`.
target1 = F.one_hot(target, num_classes=pred.shape[1])
print(target1.shape)

torch.Size([4, 20, 20, 5])


In [19]:
# Manual cross-entropy: sum of -log p(true class) over every pixel, divided
# by the pixel count. The count is derived from the one-hot tensor rather
# than a fixed 20*20*4, so the check still holds if the batch or image size
# is changed.
num_pixels = target1.numel() // target1.shape[-1]
r = -1 * torch.sum(target1 * torch.log(pred2)) / num_pixels
print(r)

# Same value as the cross-entropy computed above -- verification done.

tensor(1.9370)


# Case 4

- BCE 

In [73]:
# Case 4 ------------------------------------------------
# Setup:
#   - `pred` is a batch of 4 random single-channel logit maps, each 20x20.
#   - `target` holds random 0/1 labels and keeps the size-1 channel axis, so
#     it has the same shape as `pred`.
# NOTE(review): the string displayed below describes index targets without a
# channel axis, which is not what this cell builds.
n_class = 1
pred = torch.randn((4, n_class, 20, 20))
target = torch.randint(low=0, high=n_class + 1, size=(4, 1, 20, 20))
# => note on the shapes (displayed as the cell result):
'''
pred.shape
torch.Size([1, 5, 20, 20])

target.shape
torch.Size([1, 20, 20])

target shape 에서, (1,1,20,20) 이렇게 맞춰주지 말고 
1 channel 을 무시해줘야 한다. 

'''

'\npred.shape\ntorch.Size([1, 5, 20, 20])\n\ntarget.shape\ntorch.Size([1, 20, 20])\n\ntarget shape 에서, (1,1,20,20) 이렇게 맞춰주지 말고 \n1 channel 을 무시해줘야 한다. \n\n'

In [74]:
pred.shape

torch.Size([4, 1, 20, 20])

In [75]:
target.shape

torch.Size([4, 1, 20, 20])

In [79]:
# Squash the raw logits into (0, 1); the BCELoss call in the next cell is fed
# this tensor rather than `pred` itself.
sig_pred = nn.functional.sigmoid(pred)



In [80]:
# torch.randint produced integer labels; cast them (and pred) to float before
# handing them to BCELoss.
pred = pred.float()
target = target.float()

# The loss is evaluated on the sigmoid output `sig_pred`, not on `pred`.
bce_loss = nn.BCELoss()
loss_ = bce_loss(sig_pred,target)
print(loss_)

tensor(0.8060)


In [None]:
nn.BCELoss

In [60]:
loss_

tensor(0.7831)

In [20]:
# Example of target with class probabilities
# `loss` is whichever loss object an earlier cell left in the notebook state.
# Both tensors are (3, 5): 3 samples, 5 classes, each target row summing to 1.
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5).softmax(dim=1)
output = loss(input, target)
output.backward()

In [21]:
class UNet_metric():
    """Cross-entropy + Dice loss and a Dice score for segmentation outputs.

    Calling an instance returns ``(loss, dice_coefficient)``.

    Class 0 is left out of every Dice average (presumably the background
    class -- confirm against the dataset). The average divides by
    ``num_classes - 1``, so ``num_classes`` must be at least 2.
    """

    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.CE_loss = nn.CrossEntropyLoss(reduction="mean")  # "mean" or "sum"

    def __call__(self, pred, target):
        """Return ``(loss, dice_coefficient)`` for one batch.

        Args:
            pred: raw logits laid out as (batch, class, H, W).
            target: integer (long) class indices laid out as (batch, H, W).

        Returns:
            loss: cross-entropy plus Dice loss; both terms carry gradient.
            dice_coefficient: Dice score of the hard (argmax) predictions.
        """
        # cross-entropy on the raw logits (softmax is applied inside the loss)
        ce_loss = self.CE_loss(pred, target)

        # one_hot appends the class axis last; move it to dim 1 to match pred
        onehot_target = F.one_hot(target, num_classes=self.num_classes).permute(0, 3, 1, 2)

        # Dice loss is taken on softmax probabilities. argmax has no gradient,
        # so a Dice loss built from hard predictions would add a constant to
        # the loss and never influence training.
        soft_pred = torch.softmax(pred, dim=1)
        dice_loss = self._get_dice_loss(soft_pred, onehot_target)

        # total loss
        loss = ce_loss + dice_loss

        # Dice score is reported on the hard predictions
        onehot_pred = F.one_hot(torch.argmax(pred, dim=1), num_classes=self.num_classes).permute(0, 3, 1, 2)
        dice_coefficient = self._get_batch_dice_coefficient(onehot_pred, onehot_target)
        return loss, dice_coefficient

    def _get_dice_coeffient(self, pred, target):
        """Dice coefficient ``2*|A.B| / (|A| + |B|)`` of two same-sized masks."""
        set_inter = torch.dot(pred.reshape(-1).float(), target.reshape(-1).float())
        set_sum = pred.sum() + target.sum()
        if set_sum.item() == 0:
            # Both masks are empty: they agree perfectly, so the score is 1
            # (dividing 0 by the epsilon would report 0 instead).
            return torch.ones_like(set_inter)
        # epsilon guards the division against a vanishing denominator
        return (2 * set_inter) / (set_sum + 1e-9)

    def _get_multiclass_dice_coefficient(self, pred, target):
        """Mean Dice over classes 1..num_classes-1 of one (class, H, W) sample."""
        dice = 0
        for class_index in range(1, self.num_classes):
            dice += self._get_dice_coeffient(pred[class_index], target[class_index])
        return dice / (self.num_classes - 1)

    def _get_batch_dice_coefficient(self, pred, target):
        """Mean of the per-sample multiclass Dice over the batch axis (dim 0)."""
        num_batch = pred.shape[0]
        dice = 0
        for batch_index in range(num_batch):
            dice += self._get_multiclass_dice_coefficient(pred[batch_index], target[batch_index])
        return dice / num_batch

    def _get_dice_loss(self, pred, target):
        """Dice loss: one minus the batch Dice coefficient."""
        return 1 - self._get_batch_dice_coefficient(pred, target)