In [13]:
import torch 
import torch.nn as nn
import torch.nn.functional as F


In [None]:
class BiAffine(nn.Module):
    """双仿射函数
    BiAffine(in1_features, in2_features, out_features, bias=True)
    """
    def __init__(self,input_size1,input_size2,output_size,bias=(True,True)):
        """
        参数含义
        input_size1:输入向量的维度
        input_size2:输入向量的维度
        output_size:输出向量的维度
        bias:是否使用偏置项
        """
        super(BiAffine,self).__init__()
        self.input_size1 = input_size1
        self.input_size2 = input_size2
        self.output_size = output_size
        self.bias = bias
        # 定义参数
        self.u=nn.Parameter(torch.randn(output_size,input_size1,input_size2))
        self.w_left=None
        self.w_right=None
        self.b=None

        if self.bias[0]:
            self.w_left=nn.Parameter(torch.randn(output_size,input_size1))

        if self.bias[1]:
            self.w_right=nn.Parameter(torch.randn(output_size,input_size2))

        if self.bias[0] and self.bias[1]:
            self.b=nn.Parameter(torch.randn(output_size))
    def reset_parameters(self):
        """初始化参数"""
        nn.init.xavier_uniform_(self.u)
        if self.w_left is not None:
            nn.init.xavier_uniform_(self.w_left)
        if self.w_right is not None:
            nn.init.xavier_uniform_(self.w_right)
        if self.b is not None:
            nn.init.zeros_(self.b)

    def forward(self,x,y):
        """
        前向传播的参数含义
        x:输入向量 x[batch_size,seq_length1,input_size1]
        y:输入向量 y[batch_size,seq_length2,input_size2]
        output:[batch_size,seq_length1,seq_length2,output_size]
        """
        batch_size,seq_length1,input_size1=x.shape 
        seq_length2=y.shape[1]

        # 计算双仿射函数
        xuy=torch.einsum("bxi,oij,byj->bxyo",x,self.u,y)
        output=xuy

        #use w_left
        if self.w_left is not None:
            wx=torch.matmul(x,self.w_left.t())
            output+=wx.unsqueeze(2).expand(-1,-1,seq_length2,-1)
        if self.w_right is not None:
            wy=torch.matmul(y,self.w_right.t())
            output+=wy.unsqueeze(1).expand(-1,seq_length1,-1,-1)

        if self.b is not None:
            output+=self.b

        return output



In [15]:
def test_biaffine():
    torch.manual_seed(42)

    input_size1=4
    input_size2=3
    output_size=2
    biaffine=BiAffine(input_size1=input_size1,input_size2=input_size2,output_size=output_size)

    batch_size=2
    seq_len1=3
    seq_len2=2

    x=torch.rand(batch_size,seq_len1,input_size1)
    y=torch.rand(batch_size,seq_len2,input_size2)

    print(f"shaperx: {x.shape}")
    print(f"shapery: {y.shape}")
    z=biaffine(x,y)
    print(f"shapez: {z.shape}")
    print('='*10)
    print(f"input1: {x}")
    print(f"input2: {y}")
    print(f"output: {z}")


In [16]:
test_biaffine()

shaperx: torch.Size([2, 3, 4])
shapery: torch.Size([2, 2, 3])
shapez: torch.Size([2, 3, 2, 2])
input1: tensor([[[0.9811, 0.0874, 0.0041, 0.1088],
         [0.1637, 0.7025, 0.6790, 0.9155],
         [0.2418, 0.1591, 0.7653, 0.2979]],

        [[0.8035, 0.3813, 0.7860, 0.1115],
         [0.2477, 0.6524, 0.6057, 0.3725],
         [0.7980, 0.8399, 0.1374, 0.2331]]])
input2: tensor([[[0.9578, 0.3313, 0.3227],
         [0.0162, 0.2137, 0.6249]],

        [[0.4340, 0.1371, 0.5117],
         [0.1585, 0.0758, 0.2247]]])
output: tensor([[[[ 2.8857, -1.4196],
          [ 1.4030, -1.4968]],

         [[-1.7830, -0.2003],
          [-0.0663, -0.8519]],

         [[-0.7544,  0.2553],
          [-0.4543, -0.6689]]],


        [[[ 0.7906, -0.9671],
          [ 0.4306, -0.5885]],

         [[-0.7895, -0.4223],
          [-0.3803, -0.5328]],

         [[ 0.7402, -1.4283],
          [ 0.8075, -0.9614]]]], grad_fn=<AsStridedBackward0>)
