# LoRA

In [4]:
import torch

In [5]:

import torch.nn as nn
import torch.nn.functional as F
import math

class LoRAlinear(nn.Module):
    """Linear layer with an optional low-rank (LoRA) additive weight update.

    The wrapped ``nn.Linear`` weight ``W`` (shape ``(out_features, in_features)``)
    is frozen when ``rank > 0`` and a trainable low-rank delta
    ``lora_b @ lora_a * scale`` is added to it in ``forward`` when ``merge``
    is true. The bias of the wrapped linear layer is left trainable.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        merge: If true (and ``rank > 0``), the low-rank delta is added to the
            frozen weight in ``forward``; otherwise only the plain linear
            layer is applied.
        rank: Rank of the low-rank decomposition. ``rank <= 0`` disables LoRA
            (no LoRA parameters are created and the weight is not frozen).
        lora_alpha: Numerator of the scaling factor ``lora_alpha / rank``.
        dropout: Dropout probability applied to the layer output; ``0``
            disables dropout.
    """

    def __init__(self, in_features, out_features, merge, rank = 16, lora_alpha = 16, dropout = 0.5) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.merge = merge
        self.rank = rank
        self.lora_alpha = lora_alpha
        self.dropout_rate = dropout

        self.linear = nn.Linear(in_features, out_features)
        if rank > 0:
            # lora_b: (out_features, rank), lora_a: (rank, in_features), so that
            # lora_b @ lora_a has the same shape as self.linear.weight.
            self.lora_b = nn.Parameter(torch.zeros(out_features, rank))
            self.lora_a = nn.Parameter(torch.zeros(rank, in_features))
            # Scaling by alpha / rank follows the LoRA formulation: it keeps the
            # magnitude of the update comparable when the rank is changed.
            self.scale = self.lora_alpha / self.rank
            self.linear.weight.requires_grad = False
            # Without this call both factors stay at zero, every gradient of
            # the product is zero, and the adapter can never learn.
            self.initial_weights()

        if self.dropout_rate > 0:
            self.dropout = nn.Dropout(self.dropout_rate)
        else:
            self.dropout = nn.Identity()

    def initial_weights(self):
        """Initialise the LoRA factors: ``lora_a`` Kaiming-uniform, ``lora_b`` zero.

        Because ``lora_b`` starts at zero the initial delta is exactly zero, so
        the layer initially reproduces the wrapped linear layer. Does nothing
        when LoRA is disabled (``rank <= 0``).
        """
        if self.rank <= 0:
            return
        nn.init.kaiming_uniform_(self.lora_a, a = math.sqrt(5))
        nn.init.zeros_(self.lora_b)

    def forward(self, x):
        """Apply the (optionally LoRA-augmented) linear map followed by dropout.

        Args:
            x: Input tensor whose last dimension is ``in_features``.

        Returns:
            Tensor whose last dimension is ``out_features``.
        """
        if self.rank > 0 and self.merge:
            # (out, rank) @ (rank, in) -> (out, in), matching the frozen weight.
            delta_weight = (self.lora_b @ self.lora_a) * self.scale
            output = F.linear(x, self.linear.weight + delta_weight, self.linear.bias)
        else:
            output = self.linear(x)
        return self.dropout(output)

## 什么是Focal Loss
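分类问题我们通常会使用交叉熵损失函数，但是交叉熵损失函数对于类别不均衡的问题并不是很友好：它对所有样本一视同仁，数量占优的多数类样本（通常也是容易分类的样本）会主导总损失和梯度，而少数类样本的贡献则被淹没。为了解决这个问题，Lin等人提出了Focal Loss。Focal Loss是一种专门用于处理类别不均衡问题的损失函数，它在交叉熵的基础上乘以调制因子 $(1-p_t)^\gamma$，并可配合类别权重 $\alpha_t$，即 $FL(p_t) = -\alpha_t (1-p_t)^\gamma \log(p_t)$；通过这样调整损失函数的权重，降低易分类样本的贡献，使得模型更加关注难以分类的样本，从而提高模型的泛化能力。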

In [7]:
class FocalLoss(nn.Module):
    """Multi-class focal loss: ``alpha_t * (1 - p_t) ** gamma * CE``.

    ``CE`` is the per-sample cross entropy computed from logits and ``p_t`` is
    the predicted probability of the true class (``exp(-CE)``).
    """

    def __init__(self, alpha = None, gamma = 2.0, reduction = 'mean'):
        """Initialise the focal loss.

        Args:
            alpha: Optional class weights. Either a sequence/tensor with one
                weight per class, or a single value applied to every class.
                Defaults to None, which applies no class weighting.
            gamma: Focusing exponent; larger values reduce the contribution of
                well-classified ("easy") samples. Defaults to 2.0.
            reduction: ``'mean'`` or ``'sum'`` to reduce the per-sample loss;
                any other value returns the unreduced loss. Defaults to
                ``'mean'``.
        """
        super().__init__()
        if alpha is None:
            # A single weight of 1.0 is broadcast over all classes in forward.
            self.alpha = torch.tensor([1.0])
        else:
            # as_tensor accepts lists, arrays and tensors without the
            # copy-construct warning torch.tensor() emits for tensor input.
            self.alpha = torch.as_tensor(alpha, dtype=torch.float32)
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, labels):
        """Compute the focal loss.

        Args:
            inputs: Logits from the model, shape (N, C) where N is the number
                of samples and C is the number of categories.
            labels: True class indices with shape (N,).

        Returns:
            A scalar tensor for ``'mean'`` / ``'sum'`` reduction, otherwise the
            per-sample loss with the same shape as ``labels``.
        """
        ce_loss = F.cross_entropy(inputs, labels, reduction='none')
        # exp(-CE) recovers the probability assigned to the true class.
        pt = torch.exp(-ce_loss)

        alpha = self.alpha.to(inputs.device)
        if alpha.numel() == 1:
            # One shared weight (including the alpha=None default): broadcast
            # it instead of indexing by label, which would be out of range for
            # every label other than 0.
            alpha_t = alpha.reshape(())
        else:
            # Pick the weight of each sample's true class.
            alpha_t = alpha.gather(0, labels.reshape(-1)).reshape(labels.shape)
        loss = alpha_t * (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss

# Smoke test: evaluate the focal loss on a random three-class batch.
num_samples, num_classes = 10, 3
class_weights = [0.25, 0.5, 1.0]  # one alpha weight per class

criterion = FocalLoss(alpha=class_weights, gamma=2.0, reduction='mean')
inputs = torch.randn(num_samples, num_classes, requires_grad=True)
# Draw ten random class labels in the range 0-2.
labels = torch.empty(num_samples, dtype=torch.long).random_(num_classes)

loss = criterion(inputs, labels)
print(loss)


tensor(0.5506, grad_fn=<MeanBackward0>)
