# cut cross entropy

author: [xiaodongguaAIGC](https://github.com/dhcode-cpp)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# Seed PyTorch's global RNG so the random tensors created later are reproducible.
torch.manual_seed(42)   

<torch._C.Generator at 0x10927eb90>

## Cross-Entropy

In [2]:
# Toy single-sample example with 6 classes.
# X holds the raw, unnormalised scores (logits).
X = torch.tensor(
    [1.0, 3.0, -1.2, 1.1, -0.5, -0.8],
)
# Y is the soft target over the same 6 classes; its entries sum to 1.
Y = torch.tensor(
    [0.4, 0.2, 0.1, 0.1, 0.05, 0.15],
)

### Cross Entropy Pytorch

$CE(X,Y) = -\sum_i Y_i \log\big(\text{softmax}(X)_i\big)$

In [3]:
# Reference value: PyTorch's built-in criterion applied to the logits X and the
# probability target Y.
loss = nn.CrossEntropyLoss()
l = loss(input=X, target=Y)
print(l)

### Cross Entropy with log softmax

Cross entropy only needs the log-probabilities, so it can be computed directly from `log_softmax`.

In [4]:
# Cross entropy from log-probabilities: weight each log-probability by its target
# and sum the negated products.
log_p = F.log_softmax(X, dim=0)
l = (-Y * log_p).sum()
print(l)

In [5]:
# Same loss, but taking the softmax first and the logarithm afterwards.
p = torch.softmax(X, dim=0)
l = torch.sum(-Y * torch.log(p))
print(l)

## Softmax

### forward

In [6]:
# Hand-written softmax: exponentiate the logits and normalise by their sum.
exp_X = X.exp()
p = exp_X / exp_X.sum()
print(p)
# Cross entropy against the soft target Y, summed over the classes.
l = (-Y * torch.log(p)).sum()
print(l)

### gradient

In [7]:
# Closed-form gradient of the loss l = -(Y * log p).sum() with respect to the logits X.
# Because the entries of Y sum to 1, chaining through the softmax Jacobian collapses
# to p - Y. The loss is a sum over the classes (not a mean), so there is no 1/N
# factor; this matches the explicit-Jacobian computation in the next cell.
N = 6  # number of classes in the toy example (still read by a later cell)
dX = p - Y
print(dX)

In [8]:
# Softmax Jacobian: entry (i, k) is p_i * (delta_ik - p_k).
dpdX = torch.diag(p) - torch.outer(p, p)
# Upstream gradient of l = -(Y * log p).sum() with respect to p is -Y / p;
# chain it through the Jacobian to get the gradient with respect to the logits.
dl_dp = -(Y / p)
dX = torch.matmul(dl_dp, dpdX)
print(dX)

## Log Softmax

### forward

In [9]:
# Log-softmax via log-sum-exp: log p_i = x_i - log(sum_j exp(x_j)).
LSE = X.exp().sum().log()
p = X - LSE  # note: p now holds log-probabilities, not probabilities
print(p)
# p is already a log-probability, so no further log is applied here.
l = (-Y * p).sum()
print(l)

### gradient

首先，L = -∑y_i * log(softmax(x_i))

= -∑y_i * p_i (因为p_i = LogSoftmax(x_i))

求导链式法则：

∂L/∂x_k = ∑_i (∂L/∂p_i * ∂p_i/∂x_k)

其中：

∂L/∂p_i = -y_i

∂p_i/∂x_k = δ_{ik} - softmax(x_k) 

代入得到：

∂L/∂x_k = ∑_i (-y_i * (δ_{ik} - softmax(x_k)))

= -y_k + ∑_i y_i * softmax(x_k)

= -y_k + softmax(x_k) * ∑_i y_i

= softmax(x_k) - y_k (因为∑y_i = 1)


In [10]:
# p holds log-probabilities here, so p.exp() recovers the softmax probabilities
# (the log-sum-exp form is easy to reverse).
# Gradient of l = -(Y * p).sum() with respect to the logits: softmax(X) - Y.
dX = p.exp() - Y
print(dX)

## Cut Cross Entropy

An LLM has a large vocabulary, so the classifier projects each d-dimensional hidden state to |V| dimensions, and |V| is usually large, e.g. 128,000 or 256,000.

The classifier projection is a big GEMM when the model has fewer than 2B parameters.

In [11]:
# Problem sizes: vocabulary size, hidden dimension and batch size.
V, D, BS = 128, 16, 8

# Every dimension is divided by block_size; the quotients are used below as the
# per-block extents when tensors are split.
block_size = 4
n_b = BS // block_size  # extent along the batch dimension
d_b = D // block_size   # extent along the hidden dimension
v_b = V // block_size   # extent along the vocabulary dimension

for extent in (n_b, d_b, v_b):
    print(extent)

In [12]:
# Random classifier weights C (D x V) and embeddings E (BS x D).
# The three randn calls stay in this order so the seeded values are unchanged.
C = torch.randn(D, V)
E = torch.randn(BS, D)
Y_raw = torch.randn(BS, V)

# Dividing by a tiny temperature before the softmax makes every row of the
# target very sharp (close to one-hot).
Y_t = Y_raw / 0.0001
Y = torch.softmax(Y_t, dim=1)

# The hard label of each row could be read off with torch.max(Y, dim=1);
# it is not needed at this point.
print(Y.shape)  # soft-label matrix

### standard CE

In [13]:
# Reference cross entropy: project the embeddings onto the vocabulary and feed the
# logits to the criterion.
# NOTE(review): `loss` is assumed to still be the nn.CrossEntropyLoss instance
# created earlier in the notebook -- confirm when running cells out of order.
logits = torch.matmul(E, C)
for t in (logits, Y):
    print(t.shape)
l = loss(logits, Y)
print(l)

In [14]:
# softmax -> ce
# Manual cross entropy from probabilities. nn.CrossEntropyLoss with probability
# targets sums over the classes and averages over the batch only, so the total is
# divided by BS (not BS * V) to reproduce the reference value of the previous cell.
P = F.softmax(logits, dim = 1)
print(P[:5, :5])  # show a corner only instead of dumping the full BS x V matrix
l = -Y * P.log()
print(l.sum() / BS)

In [15]:
# logsoftmax -> ce
# Same loss computed from log-probabilities. The sum over classes is averaged over
# the batch only (divide by BS, not BS * V) so that it matches nn.CrossEntropyLoss
# with probability targets.
P = F.log_softmax(logits, dim = 1)
l = -Y * P
print(l.sum() / BS)

### Cut Cross Entropy 

#### *Algorithm 1: forward GEMM*

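Our target is to compute, for every row $i$ and vocabulary entry $j$,

$$\text{logprob}_{ij} = E_{i,:} C_{:,j} - \log \sum_{j'} \exp\big(E_{i,:} C_{:,j'}\big)$$

i.e. the log-sum-exp of row $i$ is subtracted from every entry of that row. Here $E$ is the embedding matrix and $C$ the classifier matrix, as in the code below.

Each logit is the dot product $E_{i,:} C_{:,j} = E_{i,1} C_{1,j} + E_{i,2} C_{2,j} + \dots + E_{i,d} C_{d,j}$

output = -Y * logprob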

In [16]:
# Full logits matrix (BS x V); from here on X no longer refers to the 1-D toy logits.
X = torch.matmul(E, C)

In [17]:
# Show the shape of the logits matrix X.
print(X.shape)

In [18]:
# Split the logits along the vocabulary dimension into chunks of 64 columns.
X_blocks = X.split(64, dim=1)
print(X_blocks[0].shape)
print(len(X_blocks))

In [19]:
# Row blocks of E (n_b rows each) and column blocks of C (v_b columns each);
# these are the tiles the blockwise algorithms below iterate over.
E_blocks = E.split(n_b, dim=0)
C_blocks = C.split(v_b, dim=1)

print(len(E_blocks), len(C_blocks))
print(E_blocks[0].shape)
print(C_blocks[0].shape)

In [20]:
# Plain matrix product, kept as the reference for the blockwise GEMM below.
stand_mm = torch.matmul(E, C)
print(stand_mm[:5, :5])

In [21]:
# CCE : Algorithm 1 -- blockwise GEMM.
# Each row block of E is multiplied with C one slice of the inner dimension at a
# time, and the partial products are accumulated into that row block of the output.

O = torch.zeros(BS, V)
O_blocks = torch.split(O, n_b, dim=0)
c_blocks = torch.split(C, d_b, dim=0)  # slices of C along the inner dimension
O_result = []

for e, o in zip(E_blocks, O_blocks):  # i: one row block of E / O
    e_d_blocks = torch.split(e, d_b, dim=1)
    for ed, cd in zip(e_d_blocks, c_blocks):
        # Out-of-place accumulation: `o` is rebound, O itself stays all zeros.
        o = o + ed @ cd
    O_result.append(o)

O_mat = torch.cat(O_result, dim=0)
print(O_mat.shape)

print(O_mat[:5, :5])

#### *Algorithm 2: forward logsoftmax*

Original LSE

In [22]:
# Reference log-sum-exp of every logits row, computed on the full matrix at once.
A = torch.matmul(E, C)
LSE = A.exp().sum(dim=1, keepdim=True).log()
print(LSE)

In [23]:
# CCE : Algorithm 2 -- blockwise log-sum-exp.
# The logits are never materialised as a whole: for every (row block, vocabulary
# block) pair the tile A is rebuilt, reduced to a per-row log-sum-exp, and folded
# into the running result.

# Start from log(0) = -inf, the identity element of logaddexp.
LSE = torch.full((BS, 1), float('-inf'))

# LSE part
for n, e in enumerate(E_blocks):
    rows = slice(n_b * n, n_b * (n + 1))      # rows of LSE owned by this block
    e_d_blocks = torch.split(e, d_b, dim=1)   # same for every vocabulary block
    for v, c in enumerate(C_blocks):
        # Tile of logits for this (row block, vocabulary block) pair, accumulated
        # over slices of the inner dimension.
        A = torch.zeros(n_b, v_b)
        c_d_blocks = torch.split(c, d_b, dim=0)
        for ed, cd in zip(e_d_blocks, c_d_blocks):
            A += ed @ cd
        # logsumexp / logaddexp avoid the overflow that a plain
        # log(exp(a) + exp(b)) hits for large logits.
        LSE_nv = torch.logsumexp(A, dim=1, keepdim=True)
        LSE[rows] = torch.logaddexp(LSE[rows], LSE_nv)
print(LSE.shape)
print(LSE)

#### Backward

#### standard backward

In [24]:
# Forward pass used by the backward derivations below:
#   P    = softmax(E @ C, dim = 1)
#   loss = mean(-Y * log(P))

# Fresh soft targets: every row is a softmax of random scores.
Y = F.softmax(torch.randn(BS, V), dim=1)

X = torch.matmul(E, C)
P = F.softmax(X, dim=1)

# NOTE: `loss` is rebound from the criterion object to a scalar tensor here.
loss = torch.mean(-Y * P.log())
print(loss)

In [25]:
# Standard backward of loss = mean(-Y * log(P)) with P = softmax(X, dim = 1).
N = BS * V  # the mean runs over all BS * V entries

# Closed form: every row of Y sums to 1, so dL/dX = (P - Y) / N.
# (Despite its name, dP is the gradient with respect to the logits X.)
dP = (P - Y) / N

# Explicit chain rule as a cross-check. Each row has its own softmax Jacobian
#   J_i = diag(P_i) - outer(P_i, P_i),
# and the upstream gradient with respect to the probabilities is -Y / (N * P).
dL_dP = -Y / (N * P)
dX = torch.zeros(BS, V)
for i in range(BS):
    dPdX = torch.diag(P[i, :]) - torch.outer(P[i, :], P[i, :])
    dX[i, :] = dL_dP[i, :] @ dPdX
print(dX.shape)
print(torch.allclose(dX, dP, atol=1e-6))  # both routes should agree

# Propagate through X = E @ C.
dE = dP @ C.t()
dC = E.t() @ dP

print(dE[:5, :5])
print(dC[:5, :5])


#### LSE Backward

In [26]:
# Same loss written with log-sum-exp: log-softmax(X) = X - LSE.
X = torch.matmul(E, C)
LSE = X.exp().sum(dim=1, keepdim=True).log()
print(LSE)
P = X - LSE      # subtract each row's LSE; P now holds log-probabilities
loss = -Y * P    # equals -Y * X + Y * LSE, kept unreduced
print(loss.mean())


In [27]:
# Backward of mean(-Y * (X - LSE)), split into its two paths.
# Direct path through the "X" term of P = X - LSE.
dY_dX = -Y / N
# Path through the "LSE" term: one upstream scalar per row ...
dY_dLSE = (Y / N).sum(dim=1, keepdim=True)
# ... multiplied by d LSE / d X, which is the row-wise softmax of X.
dLSE_dX = torch.softmax(X, dim=1)
dY_dX_LSE = dY_dLSE * dLSE_dX

dY_dX_total = dY_dX + dY_dX_LSE

# Propagate the total logit gradient through X = E @ C.
dE = torch.matmul(dY_dX_total, C.t())
dC = torch.matmul(E.t(), dY_dX_total)

for grad in (dE, dC):
    print(grad[:5, :5])

In [28]:
# Gradients of the LSE term alone (no upstream scaling), pushed through X = E @ C.
dE_LSE = torch.matmul(dLSE_dX, C.t())
dC_LSE = torch.matmul(E.t(), dLSE_dX)
for grad in (dE_LSE, dC_LSE):
    print(grad[:5, :5])

#### Pytorch AutoGrad LSE backward

 仅检验手动计算的梯度是否正确

In [29]:
# Autograd cross-check of the manually derived gradients.
# Detached copies of E and C are used so that the originals never get
# requires_grad set (that would make every later cell track gradients) and so
# that re-running this cell never accumulates into stale .grad buffers.
E_ag = E.detach().clone().requires_grad_(True)
C_ag = C.detach().clone().requires_grad_(True)

X_ag = E_ag @ C_ag
LSE_ag = torch.log(torch.sum(torch.exp(X_ag), dim = 1, keepdim = True))
P_ag = X_ag - LSE_ag           # subtract each row's LSE
l_ag = (- Y * P_ag).mean()     # mean of (-Y * X + Y * LSE)
l_ag.backward()

print(E_ag.grad[:5, :5])
print(C_ag.grad[:5, :5])

# NOTE(review): assumes dE / dC currently hold the manual gradients of this same
# loss (true when the notebook is run top to bottom) -- confirm if cells are run
# out of order.
print(torch.allclose(E_ag.grad, dE, atol=1e-5),
      torch.allclose(C_ag.grad, dC, atol=1e-5))

In [30]:
# Disabled inspection prints for an autograd run.
# They read .grad on loss, LSE, X, E and C, which is only populated after a
# backward pass in which those tensors tracked / retained their gradients.

# print(loss.grad[:5,:5])

# print(LSE.grad[:5,:5])
# print(LSE.grad.shape)

# print(X.grad[:5,:5])
# print(E.grad[:5,:5])
# print(C.grad[:5,:5])

#### Algorithm 3: CCE LSE Backward

In [31]:
# CCE : Algorithm 3 -- blockwise backward of the LSE term.
# For every (row block, vocabulary block) pair the logits tile is rebuilt, turned
# into softmax values with the stored LSE, scaled by the upstream gradient of the
# LSE, and pushed into the matching slices of dE and dC.

is_filter = False  # when True, tiles whose softmax values are all < epsilon are skipped
epsilon = 1e-6

dE = torch.zeros_like(E)
dC = torch.zeros_like(C)

# torch.split returns views, so the in-place updates below write into dE / dC.
dE_blocks = torch.split(dE, n_b, dim=0)
dC_blocks = torch.split(dC, v_b, dim=1)

print(Y.shape)
d_LSE = dY_dLSE  # upstream gradient of the loss with respect to each row's LSE
print(d_LSE)

# LSE part
for n, e in enumerate(E_blocks):
    row_lo, row_hi = n_b * n, n_b * (n + 1)
    for v, c in enumerate(C_blocks):
        e_d_blocks = torch.split(e, d_b, dim=1)
        c_d_blocks = torch.split(c, d_b, dim=0)

        # Rebuild the logits tile A_nv from slices of the inner dimension.
        A = torch.zeros(n_b, v_b)
        for ed, cd in zip(e_d_blocks, c_d_blocks):
            A += ed @ cd

        # exp(logit - LSE) is the softmax value.
        # NOTE(review): assumes LSE holds the full-row log-sum-exp of E @ C -- confirm.
        S = torch.exp(A - LSE[row_lo:row_hi])

        # Optional gradient filtering.
        if is_filter and (S < epsilon).all():
            print('skip', n, v)
            continue

        d_LSE_block = d_LSE[row_lo:row_hi, :]
        for d_index, (ed, cd) in enumerate(zip(e_d_blocks, c_d_blocks)):
            col_lo, col_hi = d_index * d_b, (d_index + 1) * d_b
            dE_blocks[n][:, col_lo:col_hi] += (d_LSE_block * S) @ cd.t()
            dC_blocks[v][col_lo:col_hi, :] += ed.t() @ (d_LSE_block * S)


# Reassemble the LSE part of the gradients from the updated views.
dE_part_LSE = torch.cat(dE_blocks, dim=0)
dC_part_LSE = torch.cat(dC_blocks, dim=1)

print(dE_part_LSE.shape)
print(dC_part_LSE.shape)

print(dE_part_LSE[:5, :5])
print(dC_part_LSE[:5, :5])

In [32]:
# Contribution of the direct "X" path, pushed through X = E @ C.
dE_part_X = torch.matmul(dY_dX, C.t())
dC_part_X = torch.matmul(E.t(), dY_dX)

# Total gradient = direct path + blockwise LSE path.
dE_total = torch.add(dE_part_X, dE_part_LSE)
dC_total = torch.add(dC_part_X, dC_part_LSE)

for grad in (dE_total, dC_total):
    print(grad[:5, :5])

#### Algorithm 4: CCE Backward total

In [33]:
# Hard label of every row: the index of the largest entry of the soft target Y.
# (Y could first be sharpened with a small softmax temperature; not done here.)
label = torch.max(Y, dim=1).indices
print(label)

# Target probability assigned to the first row's label.
print(Y[0, label[0]])

##### max likelihood version CE

In [34]:
# Show the shape of P before gathering from it in the next cell.
print(P.shape)

In [35]:
# Hard-label cross entropy: pick the entry of P at each row's label and negate it.
# NOTE(review): assumes P currently holds row-wise log-probabilities (X - LSE);
# P has also held plain probabilities earlier in the notebook -- confirm run order.
print(label[None, :])
P_CEL = P.gather(1, label[:, None])
loss = P_CEL.neg()
l = loss.mean()
print(loss)
print(l)

In [36]:
# Reference value from PyTorch's criterion with class-index targets.
# `loss` is rebound to a criterion object again here.
loss = nn.CrossEntropyLoss()
l_torch = loss(input=X, target=label)
print(l_torch)

In [37]:
# Print the per-row upstream gradient d_LSE and the hard-label loss divided by
# the batch size.
# NOTE(review): the purpose of putting these two values side by side is not
# evident from the code -- confirm.
print(d_LSE)
print(l_torch/BS)

In [53]:
# CCE : Algorithm 4 -- blockwise backward of the whole loss.
# Per tile, the total logit gradient is G = S - M, with S the softmax values
# rebuilt from the stored LSE and M the matching tile of the target matrix.

# The soft targets Y are used as M. A one-hot matrix built from `label` is the
# disabled alternative:
# M = torch.zeros(BS, V)
# for i in range(BS):
#     M[i, label[i]] = 1
M_blocks = torch.split(Y, n_b, dim=0)

dE = torch.zeros_like(E)
dC = torch.zeros_like(C)

# torch.split returns views, so the in-place updates below write into dE / dC.
dE_blocks = torch.split(dE, n_b, dim=0)
dC_blocks = torch.split(dC, v_b, dim=1)

# Upstream gradient of a mean over all BS * V entries, one value per row.
d_LCE = torch.ones(BS, 1) / (BS * V)


for n, e in enumerate(E_blocks):
    row_lo, row_hi = n_b * n, n_b * (n + 1)
    M_v_blocks = torch.split(M_blocks[n], v_b, dim=1)

    for v, c in enumerate(C_blocks):
        e_d_blocks = torch.split(e, d_b, dim=1)
        c_d_blocks = torch.split(c, d_b, dim=0)

        # Rebuild the logits tile A_nv from slices of the inner dimension.
        A = torch.zeros(n_b, v_b)
        for ed, cd in zip(e_d_blocks, c_d_blocks):
            A += ed @ cd

        # NOTE(review): assumes LSE holds the full-row log-sum-exp of E @ C -- confirm.
        S = torch.exp(A - LSE[row_lo:row_hi])
        G = -M_v_blocks[v] + S  # total gradient for this tile

        # Optional gradient filtering: skip tiles whose gradient is negligible.
        if is_filter and (G.abs() < epsilon).all():
            print('skip',n,v)
            continue

        d_LCE_block = d_LCE[row_lo:row_hi, :]
        for d_index, (ed, cd) in enumerate(zip(e_d_blocks, c_d_blocks)):
            col_lo, col_hi = d_index * d_b, (d_index + 1) * d_b
            dE_blocks[n][:, col_lo:col_hi] += (d_LCE_block * G) @ cd.t()
            dC_blocks[v][col_lo:col_hi, :] += ed.t() @ (d_LCE_block * G)


# Reassemble the full gradients from the updated views.
dE = torch.cat(dE_blocks, dim=0)
dC = torch.cat(dC_blocks, dim=1)

print(dE[:5, :5])
print(dC[:5, :5])