In [1]:
import torch
import torch.nn as nn

from ddi_kt_2024.embed.other_embed import sinusoidal_positional_embedding

In [47]:
class BertWithPostionOnlyModel(nn.Module):
    """
    Only with bert + position encoding
    The stucture: [bert_embedding, pos_ent, zero_ent, pos_tag]
    """
    def __init__(self,
                dropout_rate: float = 0.5,
                word_embedding_size: int = 768,
                position_number: int = 512,
                position_embedding_size: int = 128,
                position_embedding_type: str = "normal",
                tag_number: int = 51,
                tag_embedding_size: int = 64,
                token_embedding_size : int = 256,
                conv1_out_channels: int = 256,
                conv2_out_channels: int = 256,
                conv3_out_channels: int = 256,
                conv1_length: int = 1,
                conv2_length: int = 2,
                conv3_length: int = 3,
                target_class: int = 5
                ):
        super(BertWithPostionOnlyModel, self).__init__()
        self.word_embedding_size = word_embedding_size
        self.position_embedding_size = position_embedding_size
        self.device ="cuda"
        self.tag_embedding = nn.Embedding(tag_number, tag_embedding_size, padding_idx=0)
        self.position_embedding_type = position_embedding_type
        if position_embedding_type == "normal":
            self.pos_embedding = nn.Linear(position_number, position_embedding_size, bias=False)
        elif position_embedding_type == "sinusoidal":
            self.pos_embedding = self.sinusoidal_positional_encoding
        elif position_embedding_type == "rotary":
            self.pos_embedding = self.rotary_positional_embedding
        else:
            raise ValueError("Wrong type pos embed")

        self.dropout = nn.Dropout(dropout_rate)

        self.normalize_tokens = nn.Linear(in_features = word_embedding_size+tag_embedding_size+position_embedding_size,
            out_features=token_embedding_size,
            bias=False)

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1,
                      out_channels=conv1_out_channels,
                      kernel_size=(conv1_length, token_embedding_size),
                      stride=1,
                      bias=False),
            nn.ReLU()
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=1,
                      out_channels=conv2_out_channels,
                      kernel_size=(conv2_length, token_embedding_size),
                      stride=1,
                      bias=False),
            nn.ReLU()
        )

        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=1,
                      out_channels=conv3_out_channels,
                      kernel_size=(conv3_length, token_embedding_size),
                      stride=1,
                      bias=False),
            nn.ReLU()
        )
        self.dense_to_tag = nn.Linear(in_features = conv1_out_channels + conv2_out_channels + conv3_out_channels,out_features=target_class,
                        bias=False)

        self.relu = nn.ReLU()

        self.softmax = nn.Softmax(dim=1)

    def sinusoidal_positional_encoding(self, position):
        d_model = int((self.position_embedding_size - 1) / 2)
        position = position.unsqueeze(dim=2)
        print(position.shape)
        angle_rads = torch.arange(d_model) // 2 * torch.pi / torch.pow(10000, 2 * (torch.arange(d_model) // 2) / d_model)
        angle_rads = angle_rads.to(self.device)
        angle_rads = angle_rads.unsqueeze(dim=0).unsqueeze(dim=0).expand((position.shape[0], 1, angle_rads.shape[0]))
        print(position.shape)
        print(angle_rads.shape)
        angle_rads = torch.bmm(position, angle_rads)
        pos_encoding = torch.zeros((angle_rads.shape[0], angle_rads.shape[1], angle_rads.shape[2])).to(self.device)
        pos_encoding[:, :, 0::2] = torch.sin(angle_rads[:, :, 0::2])
        pos_encoding[:, :, 1::2] = torch.cos(angle_rads[:, :, 1::2])
        return pos_encoding

    def rotary_positional_embedding(self, position):
        d_model = int((self.position_embedding_size - 1) / 2)
        position = position.unsqueeze(dim=2)
        freqs = torch.exp(torch.linspace(0., -1., int(d_model // 2)+1) * torch.log(torch.tensor(10000.))).to(self.device)
        freqs = freqs.unsqueeze(dim=0).unsqueeze(dim=0).expand((position.shape[0], 1, freqs.shape[0]))
        angles = position * freqs
        rotary_matrix = torch.stack([torch.sin(angles), torch.cos(angles)], axis=-1).to(self.device)
        print(rotary_matrix.shape)
        return rotary_matrix.reshape((position.shape[0], position.shape[1], d_model))

    def forward(self, x):
        x = x.float()

        if self.position_embedding_type == "normal": # Linear
            pos_embedding = self.pos_embedding(x[:,:,self.word_embedding_size: self.word_embedding_size+4])
        elif self.position_embedding_type == "sinusoidal":
            position_embedding_ent = x[:, :, self.word_embedding_size: self.word_embedding_size+4].float()
            pos3 = self.sinusoidal_positional_encoding(position_embedding_ent[:, :, 0])
            pos4 = self.sinusoidal_positional_encoding(position_embedding_ent[:, :, 1])
            pos_embedding = torch.cat((pos3, pos4, position_embedding_ent[:, :, 2:]), dim=2) 
        else: # rotary
            position_embedding_ent = x[:, :, self.word_embedding_size: self.word_embedding_size+4].float()
            pos3 = self.rotary_positional_embedding(position_embedding_ent[:, :, 0])
            pos4 = self.rotary_positional_embedding(position_embedding_ent[:, :, 1])
            pos_embedding = torch.cat((pos3, pos4, position_embedding_ent[:, :, 2:]), dim=2) 
        print(f"Re: {pos_embedding.shape}")
        tag_embedding = self.tag_embedding(x[:,:,-1].long())
        x = self.normalize_tokens(torch.cat((x[:,:,:self.word_embedding_size], pos_embedding, tag_embedding), dim =2))
        
        x = x.unsqueeze(1)

        x1 = self.conv1(x)
        x2 = self.conv2(x)
        x3 = self.conv3(x)

        x1 = torch.max(x1.squeeze(dim=3), dim=2)[0]
        x2 = torch.max(x2.squeeze(dim=3), dim=2)[0]
        x3 = torch.max(x3.squeeze(dim=3), dim=2)[0]
        
        x = torch.cat((x1, x2, x3), dim=1)
        x = self.dense_to_tag(x)
        x = self.softmax(x)
        return x


In [48]:
model = BertWithPostionOnlyModel(position_embedding_type='rotary')

In [49]:
example_tensor = torch.cat((torch.randn(16,30,768), torch.randint(0,30,(16,30,5))), dim=-1)
example_tensor.shape

torch.Size([16, 30, 773])

In [50]:
model.to('cuda')
model.forward(example_tensor.to('cuda'))

torch.Size([16, 30, 32, 2])


RuntimeError: shape '[16, 30, 63]' is invalid for input of size 30720

In [51]:
16*30*63

30240