In [None]:
import torch
from torch import nn
import einops
class MHA(nn.Module):
    """Multi-head self-attention.

    Args:
        hidden_size: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads; each head works on
            ``hidden_size // num_heads`` features.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
            )
        self.q_l = nn.Linear(hidden_size, hidden_size)
        # NOTE(review): the original "# group" markers on k/v presumably flag where a
        # grouped-query variant would shrink these projections.
        self.k_l = nn.Linear(hidden_size, hidden_size)
        self.v_l = nn.Linear(hidden_size, hidden_size)
        self.o_l = nn.Linear(hidden_size, hidden_size)
        # Stored so forward() can split the projections into heads.
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads  # per-head dimension

    def _split_heads(self, x):
        """(b, seq_len, num_heads*head_dim) -> (b, num_heads, seq_len, head_dim).

        Equivalent to einops ``"b s (h d) -> b h s d"`` with ``h=num_heads``.
        """
        b, seq_len, _ = x.shape
        return x.view(b, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, hs, mask=None):
        """Apply multi-head self-attention to ``hs``.

        Args:
            hs: input of shape (batch, seq_len, hidden_size).
            mask: optional tensor broadcastable to (batch, num_heads, seq_len, seq_len);
                positions where ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape (batch, seq_len, hidden_size).
        """
        q = self._split_heads(self.q_l(hs))
        k = self._split_heads(self.k_l(hs))
        v = self._split_heads(self.v_l(hs))

        # (b, h, s, d) @ (b, h, d, s) -> (b, h, s, s), scaled by 1/sqrt(head_dim).
        attention_score = torch.matmul(q, k.transpose(-1, -2)) / (self.head_dim ** 0.5)
        # PEP 8: compare against None with `is` / `is not`.
        if mask is not None:
            attention_score = attention_score.masked_fill(mask == 0, float("-inf"))
        # Normalize the scores over the key axis (last dim).
        attention_prob = torch.softmax(attention_score, dim=-1)
        # (b, h, s, s) @ (b, h, s, d) -> (b, h, s, d)
        out = torch.matmul(attention_prob, v)
        # Merge heads back: (b, h, s, d) -> (b, s, h*d).
        b, _, seq_len, _ = out.shape
        out = out.transpose(1, 2).reshape(b, seq_len, -1)
        return self.o_l(out)
        

In [1]:
import torch
import torch.nn as nn
from einops import rearrange

class MQA(nn.Module):
    """Multi-query attention: ``num_heads`` query heads share one key/value head.

    Args:
        hz: model width; must be divisible by ``num_heads``.
        num_heads: number of query heads.

    Raises:
        ValueError: if ``hz`` is not divisible by ``num_heads``.
    """

    def __init__(self, hz, num_heads):
        super().__init__()
        if hz % num_heads != 0:
            raise ValueError(f"hz ({hz}) must be divisible by num_heads ({num_heads})")
        # Stored so forward() can split the query projection into heads.
        self.num_heads = num_heads
        self.head_dim = hz // num_heads
        self.q_l = nn.Linear(hz, hz)
        # A single shared K/V head: project to one head's width (head_dim),
        # not to `num_heads` features.
        self.k_l = nn.Linear(hz, self.head_dim)
        self.v_l = nn.Linear(hz, self.head_dim)
        self.o_l = nn.Linear(hz, hz)

    def forward(self, hs, mask=None):
        """Apply multi-query self-attention to ``hs``.

        Args:
            hs: input of shape (batch, seq_len, hz).
            mask: optional tensor broadcastable to (batch, num_heads, seq_len, seq_len);
                positions where ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape (batch, seq_len, hz).
        """
        b, seq_len, _ = hs.shape
        # (b, s, h*d) -> (b, h, s, d)
        q = self.q_l(hs).view(b, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # (b, s, d) -> (b, 1, s, d); the size-1 head axis broadcasts in matmul,
        # so K/V need not be expanded per head.
        k = self.k_l(hs).unsqueeze(1)
        v = self.v_l(hs).unsqueeze(1)
        # (b, h, s, d) @ (b, 1, d, s) -> (b, h, s, s)
        attention_score = torch.matmul(q, k.transpose(-1, -2)) / (self.head_dim ** 0.5)
        if mask is not None:
            attention_score = attention_score.masked_fill(mask == 0, float("-inf"))
        attention_prob = torch.softmax(attention_score, dim=-1)
        # (b, h, s, s) @ (b, 1, s, d) -> (b, h, s, d)
        out = torch.matmul(attention_prob, v)
        # Merge heads back: (b, h, s, d) -> (b, s, h*d).
        out = out.transpose(1, 2).reshape(b, seq_len, -1)
        return self.o_l(out)

In [None]:
import torch
from torch import nn
from einops import rearrange

class MHA(nn.Module):
    """Multi-head self-attention.

    Args:
        hz: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads, each of width ``hz // num_heads``.

    Raises:
        ValueError: if ``hz`` is not divisible by ``num_heads``.
    """

    def __init__(self, hz, num_heads):
        super().__init__()
        if hz % num_heads != 0:
            raise ValueError(f"hz ({hz}) must be divisible by num_heads ({num_heads})")
        # Stored so forward() can split projections into heads.
        self.num_heads = num_heads
        self.head_dim = hz // num_heads
        self.q_l = nn.Linear(hz, hz)
        self.k_l = nn.Linear(hz, hz)
        self.v_l = nn.Linear(hz, hz)
        self.o_l = nn.Linear(hz, hz)

    def forward(self, hidden_state, mask=None):
        """Apply multi-head self-attention to ``hidden_state``.

        Args:
            hidden_state: input of shape (batch, seq_len, hz).
            mask: optional tensor broadcastable to (batch, num_heads, seq_len, seq_len);
                positions where ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape (batch, seq_len, hz).
        """
        b, seq_len, _ = hidden_state.shape

        def split_heads(x):
            # (b, s, h*d) -> (b, h, s, d); same as einops "b s (h d) -> b h s d".
            return x.view(b, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_l(hidden_state))
        k = split_heads(self.k_l(hidden_state))
        v = split_heads(self.v_l(hidden_state))

        # b h s dim @ b h dim s = b h s s, scaled by 1/sqrt(head_dim).
        attention_score = torch.matmul(q, k.transpose(-1, -2)) / (self.head_dim ** 0.5)
        # Apply the mask only when one is given.
        if mask is not None:
            attention_score = attention_score.masked_fill(mask == 0, float("-inf"))
        attention_prob = torch.softmax(attention_score, dim=-1)
        # b h s s @ b h s dim = b h s dim
        out = torch.matmul(attention_prob, v)
        # Merge heads back: (b, h, s, d) -> (b, s, h*d).
        out = out.transpose(1, 2).reshape(b, seq_len, -1)
        return self.o_l(out)


In [None]:
def compute_loss(self, model, inputs, return_outputs=False, num_item_in_batch=None):
    """GRPO-style loss over completion tokens (no ratio clipping).

    Args:
        self: object providing ``beta`` and
            ``_get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)``.
        model: forwarded unchanged to ``self._get_per_token_logps``.
        inputs: dict with "input_ids" (prompt tokens), "prompt_mask",
            "completion_ids", "completion_mask", "ref_per_token_logps" and
            "advantages" (one value per sequence).
        return_outputs, num_item_in_batch: accepted for signature compatibility; unused.

    Returns:
        Scalar loss: per-sequence mean over unmasked completion tokens, then
        averaged over the batch.
    """
    # NOTE(review): prompt tokens are read from "input_ids" while their mask is
    # "prompt_mask"; TRL's GRPOTrainer names this key "prompt_ids" -- confirm.
    prompt_ids, prompt_mask = inputs["input_ids"], inputs["prompt_mask"]
    completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
    input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
    # Join along the sequence axis (dim=1) so the mask lines up with input_ids;
    # torch.cat's default dim=0 would stack the masks along the batch axis instead.
    attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
    # Only completion tokens contribute to the loss.
    logits_to_keep = completion_ids.size(1)
    per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
    ref_per_token_logps = inputs["ref_per_token_logps"]

    # Per-token KL estimate exp(d) - d - 1 with d = log(pi_ref / pi):
    # always >= 0 and exactly 0 where the two policies agree.
    log_ref_ratio = ref_per_token_logps - per_token_logps
    per_token_kl = torch.exp(log_ref_ratio) - log_ref_ratio - 1

    advantages = inputs["advantages"]
    # exp(x - x.detach()) evaluates to 1 but its derivative w.r.t. x is 1, so it
    # acts as the ratio pi / pi_old for a single on-policy step. No clipping here.
    ratio = torch.exp(per_token_logps - per_token_logps.detach())
    per_token_loss = -(ratio * advantages.unsqueeze(-1) - self.beta * per_token_kl)
    # Masked mean per sequence, then mean over the batch.
    loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
    return loss


In [None]:
def compute_loss_trl(self, model, inputs, group, rewards_per_func=None, return_outputs=False, num_item_in_batch=None):
    """GRPO loss with group-normalized advantages and a clipped policy ratio (TRL style).

    Args:
        self: object providing ``num_generations``, ``epsilon``, ``beta`` and
            ``_get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)``.
        model: forwarded unchanged to ``self._get_per_token_logps``.
        inputs: dict with "input_ids" (prompt tokens), "prompt_mask",
            "completion_ids", "completion_mask" and "ref_per_token_logps".
        group: unused; kept for signature compatibility.
        rewards_per_func: tensor of shape (b*g, num_reward_funcs) whose rows are
            laid out group by group (``self.num_generations`` consecutive rows per prompt).
        return_outputs, num_item_in_batch: unused; kept for signature compatibility.

    Returns:
        Scalar loss.

    Raises:
        ValueError: if ``rewards_per_func`` is not given.
    """
    if rewards_per_func is None:
        raise ValueError("rewards_per_func is required")
    num_generations = self.num_generations

    # Total reward per completion: (b*g, num_funcs) -> (b*g,)
    rewards = rewards_per_func.sum(dim=1)
    # Group statistics, broadcast back to one value per completion
    # (same as einops "(b g) -> b g", reduce, then repeat_interleave).
    grouped_rewards = rewards.view(-1, num_generations)
    mean_grouped_rewards = grouped_rewards.mean(dim=1).repeat_interleave(num_generations, dim=0)
    std_grouped_rewards = grouped_rewards.std(dim=1).repeat_interleave(num_generations, dim=0)
    advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

    # NOTE(review): prompt key follows the sibling compute_loss ("input_ids");
    # TRL's GRPOTrainer calls it "prompt_ids" -- confirm.
    prompt_ids, prompt_mask = inputs["input_ids"], inputs["prompt_mask"]
    completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
    input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
    attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
    # Only completion tokens contribute to the loss.
    logits_to_keep = completion_ids.size(1)
    per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
    ref_per_token_logps = inputs["ref_per_token_logps"]

    # Per-token KL estimate exp(d) - d - 1 with d = log(pi_ref / pi); >= 0.
    log_ref_ratio = ref_per_token_logps - per_token_logps
    per_token_kl = torch.exp(log_ref_ratio) - log_ref_ratio - 1

    # Ratio pi / pi_old (value 1, gradient of per_token_logps) and its clipped version.
    ratio = torch.exp(per_token_logps - per_token_logps.detach())
    clipped_ratio = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon)
    # (b*g,) -> (b*g, 1) so it broadcasts over the sequence axis.
    advantage = advantages.unsqueeze(1)
    # Elementwise minimum; Python's builtin min() cannot compare tensors.
    per_token_loss = torch.min(ratio * advantage, clipped_ratio * advantage)
    per_token_loss = -(per_token_loss - self.beta * per_token_kl)
    # Masked mean per sequence, then mean over the batch.
    loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
    return loss


In [None]:
from einops import rearrange


```python
from einops import rearrenge
# NOTE(review): unfinished sketch kept inside a ``` fence; it does not run as written:
#   - einops exports `rearrange`; `rearrenge` (here and in the import above) is a typo.
#   - the pattern "(b g)" has no "-> ..." output side and no group size (g=...).
#   - `str_group_rewards` is presumably meant to be `std_group_rewards`, and
#     `rearrenge()` is called with no arguments.
def compute_loss(self, input,group,rewards_func):
    # Sum over dim 1 (presumably across reward functions -- confirm). The original
    # shape note read "b g seq_len func"; it is not verified here.
    rewards=rewards_func.sum(dim=1)
    
    mean_group_rewards=rearrenge(rewards,"(b g)")
    str_group_rewards=rearrenge()
```

In [1]:
import einops
from einops import rearrange
def compute_loss(self, model, inputs, mask, rewards_func):
    """GRPO loss with group-normalized advantages and a clipped policy ratio.

    Args:
        self: object providing ``num_generation``, ``epsilon``, ``beta`` and
            ``_get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)``.
        model: forwarded unchanged to ``self._get_per_token_logps``.
        inputs: dict with "input_ids" (prompt tokens), "prompt_mask",
            "completion_ids", "completion_mask" and "ref_per_token_logps".
        mask: attention mask for the concatenated prompt + completion; when None
            it is built from "prompt_mask" and "completion_mask".
        rewards_func: tensor of shape (b*g, num_reward_funcs), with
            ``self.num_generation`` consecutive rows per prompt.

    Returns:
        Scalar loss: per-sequence mean over unmasked completion tokens, then
        averaged over the batch.
    """
    num_generation = self.num_generation
    # NOTE(review): prompt key follows the sibling compute_loss ("input_ids");
    # TRL's GRPOTrainer calls it "prompt_ids" -- confirm.
    completion_ids = inputs["completion_ids"]
    # The loss is masked with the completion-only mask; this assumes
    # _get_per_token_logps returns one log-prob per kept (completion) token.
    completion_mask = inputs["completion_mask"]
    input_ids = torch.cat([inputs["input_ids"], completion_ids], dim=1)
    if mask is not None:
        attention_mask = mask
    else:
        attention_mask = torch.cat([inputs["prompt_mask"], completion_mask], dim=1)
    logits_to_keep = completion_ids.size(1)
    per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
    ref_per_token_logps = inputs["ref_per_token_logps"]

    # Total reward per completion (b*g scalars), normalized within its group.
    rewards = rewards_func.sum(dim=1)
    grouped_rewards = rewards.view(-1, num_generation)
    mean_group_rewards = grouped_rewards.mean(dim=1).repeat_interleave(num_generation, dim=0)
    std_group_rewards = grouped_rewards.std(dim=1).repeat_interleave(num_generation, dim=0)
    advantages = (rewards - mean_group_rewards) / (std_group_rewards + 1e-4)

    # Per-token KL estimate exp(d) - d - 1 with d = log(pi_ref / pi); >= 0.
    log_ref_ratio = ref_per_token_logps - per_token_logps
    per_token_kl = torch.exp(log_ref_ratio) - log_ref_ratio - 1

    # -(min(ratio * A, clamp(ratio) * A) - beta * KL); ratio has value 1 and the
    # gradient of per_token_logps.
    pi_pi_old = torch.exp(per_token_logps - per_token_logps.detach())
    advantage = advantages.unsqueeze(1)
    per_token_loss = torch.min(
        pi_pi_old * advantage,
        torch.clamp(pi_pi_old, 1 - self.epsilon, 1 + self.epsilon) * advantage,
    )
    per_token_loss = -(per_token_loss - self.beta * per_token_kl)
    # Parenthesize the ratio so .mean() reduces the per-sequence losses, not
    # just the denominator.
    loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
    return loss