In [12]:
import torch
import torch.nn as nn
import torchinfo
import math

# RotaryPositional Embedding

https://nn.labml.ai/transformers/rope/index.html

In [13]:
class RotaryPositionalEmbedding(torch.nn.Module):
    """
    Given a tensor of shape [BATCH, SEQ_LEN, NUM_HEADS, HEAD_DIM]
    applies Rotary Positional Encoding.
    offset allows to apply rotary to sequnce part by part by telling how much tokens preecede the input in the sequence.
    """

    def __init__(
        self,
        dimension: int,
        max_seq_len: int,
        theta: float,
    ):
        super().__init__()

        assert dimension % 2 == 0
        self.dimension = dimension
        self.max_seq_len = max_seq_len

        ## Theta := theta^( -(2i / dimension) ) where i = 0, 1, 2, ..., dimension / 2
        self.theta = (
            1.0 / (theta ** (torch.arange(0, self.dimension, 2).float() / dimension))
        )[None, :]  # [1, dimension / 2]

        rot_seq = max_seq_len
        m_theta = torch.arange(rot_seq)[:, None].float()  # [max_seq_len, 1]
        m_theta = (m_theta @ self.theta)[
            :, :, None, None
        ]  # [max_seq_len, dimension / 2, 1, 1]

        m_sin = m_theta.sin()
        m_cos = m_theta.cos()

        row0 = torch.cat((m_cos, -m_sin), dim=-1)  # [max_seq_len, dimension / 2, 1, 2]
        row1 = torch.cat((m_sin, m_cos), dim=-1)  # [max_seq_len, dimension / 2, 1, 2]

        self.rotation_matrix = torch.cat((row0, row1), dim=-2)[None, :, None, :, :, :]
        """
        [1, max_seq_len, 1, dimension / 2, 2, 2]
        """

    def forward(self, x, offset: int = 0):
        assert (
            len(x.shape) == 4
        )  # torch tensor of shape [BATCH, SEQ_LEN, NUM_HEADS, HEAD_DIM]
        assert offset >= 0

        ## reshape
        BATCH, SEQ_LEN, NUM_HEADS, HEAD_DIM = x.shape
        y = x.reshape(BATCH, SEQ_LEN, NUM_HEADS, HEAD_DIM // 2, 2, 1)

        ## rotate
        start, end = offset, offset + SEQ_LEN
        print("self.rotation_matrix", self.rotation_matrix.shape)
        print("y", y.shape)
        y = self.rotation_matrix[:, start:end].to(x.device) @ y

        ## reshape
        y = y.reshape(BATCH, SEQ_LEN, NUM_HEADS, HEAD_DIM)

        assert y.shape == x.shape
        return y


model = RotaryPositionalEmbedding(dimension=4, max_seq_len=100, theta=10_000)
torchinfo.summary(model, input_size=(2, 5, 5, 4))

self.rotation_matrix torch.Size([1, 100, 1, 2, 2, 2])
y torch.Size([2, 5, 5, 2, 2, 1])


Layer (type:depth-idx)                   Output Shape              Param #
RotaryPositionalEmbedding                [2, 5, 5, 4]              --
Total params: 0
Trainable params: 0
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00

# Long Term Memory Module

In [14]:
class LongTermMemory(nn.Module):
    """
    This is a key-value neural network memory store.

    Retrieve
    - query -> output

    Update
    - key, output -> update

    Args:
        nn (_type_): _description_
    """

    def __init__(self, vector_size: int):
        super(LongTermMemory, self).__init__()

        self.neural_net = nn.Sequential(
            nn.Linear(vector_size, vector_size),
            nn.ReLU(),
            nn.Linear(vector_size, vector_size),
            nn.ReLU(),
            nn.Linear(vector_size, vector_size),
        )

        self.w_q = nn.Linear(vector_size, vector_size)
        self.w_k = nn.Linear(vector_size, vector_size)
        self.w_v = nn.Linear(vector_size, vector_size)

        self.optimizer = torch.optim.SGD(
            self.parameters(), lr=0.0001, momentum=0.9, weight_decay=0.0001
        )
        ## 결국 SGD에서 momentum과 weight decay를 쓰는건,
        ## past surprise 반영과
        ## forgetting mechanism 사용하는것과 동일

    def retrieve_memory(self, seq: torch.Tensor) -> torch.Tensor:
        q = self.w_q(seq)
        output = self.neural_net(q)
        return output

    def update_memory(self, seq: torch.Tensor) -> torch.Tensor:
        with torch.enable_grad():
            seq = seq.detach()

            self.optimizer.zero_grad()

            k = self.w_k(seq)
            v = self.w_v(seq)

            print("k", k)
            print("v", v)

            output = self.neural_net(k)
            surprise = torch.norm(output - v, p=2, dim=-1)
            surprise = surprise.sum()
            surprise.backward()

            self.optimizer.step()

        return self.retrieve_memory(seq)

    def forward(self, seq: torch.Tensor, is_update: bool = True) -> torch.Tensor:
        if is_update:
            return self.update_memory(seq)
        else:
            return self.retrieve_memory(seq)


# def test_structure():
memory = LongTermMemory(
    vector_size=5,
)
# torchinfo.summary(memory, input_data=torch.randn(2, 10, 5))
seq = torch.ones(2, 1, 5)
for i in range(10):
    memory.update_memory(seq)

print(seq)
# test_structure()

k tensor([[[-0.8638, -0.6591, -0.0420,  0.7137, -1.1103]],

        [[-0.8638, -0.6591, -0.0420,  0.7137, -1.1103]]],
       grad_fn=<ViewBackward0>)
v tensor([[[-0.9871,  0.7721,  0.2904,  0.1649, -0.4141]],

        [[-0.9871,  0.7721,  0.2904,  0.1649, -0.4141]]],
       grad_fn=<ViewBackward0>)
k tensor([[[-0.8638, -0.6590, -0.0420,  0.7138, -1.1104]],

        [[-0.8638, -0.6590, -0.0420,  0.7138, -1.1104]]],
       grad_fn=<ViewBackward0>)
v tensor([[[-0.9861,  0.7716,  0.2902,  0.1647, -0.4138]],

        [[-0.9861,  0.7716,  0.2902,  0.1647, -0.4138]]],
       grad_fn=<ViewBackward0>)
k tensor([[[-0.8638, -0.6587, -0.0421,  0.7139, -1.1106]],

        [[-0.8638, -0.6587, -0.0421,  0.7139, -1.1106]]],
       grad_fn=<ViewBackward0>)
v tensor([[[-0.9841,  0.7707,  0.2898,  0.1644, -0.4134]],

        [[-0.9841,  0.7707,  0.2898,  0.1644, -0.4134]]],
       grad_fn=<ViewBackward0>)
k tensor([[[-0.8639, -0.6583, -0.0421,  0.7141, -1.1108]],

        [[-0.8639, -0.6583, -0.0421,  0.

In [17]:
class DecoderBlock(nn.Module):
    def __init__(
        self,
        vector_size: int,
        num_heads: int,
        persistent_memory_length: int,
        long_term_memory_length: int,
    ):
        super(DecoderBlock, self).__init__()

        self.persistent_memory = nn.Parameter(
            torch.randn(1, persistent_memory_length, vector_size)
        )

        self.long_term_memory_length = long_term_memory_length
        self.long_term_memory = LongTermMemory(vector_size=vector_size)

        self.w_q = nn.Linear(vector_size, vector_size)
        self.w_k = nn.Linear(vector_size, vector_size)
        self.w_v = nn.Linear(vector_size, vector_size)

        self.head_size = vector_size // num_heads
        self.num_heads = num_heads

        self.positional_encoding = RotaryPositionalEmbedding(
            dimension=self.head_size, max_seq_len=5000, theta=10_000
        )

        self.linear = nn.Linear(vector_size, vector_size)

        self.ffn = nn.Sequential(
            nn.Linear(2 * vector_size, vector_size),
            nn.ReLU(),
            nn.Linear(vector_size, vector_size),
        )

    def forward(self, seq: torch.Tensor) -> torch.Tensor:
        ## Get input size
        batch, seq_len, vector_size = seq.size()

        ## retrieve memories
        ## persistent memory: [batch, persistent_memory_length, vector_size]
        ## long term memory: [batch, long_term_memory_length, vector_size]
        persistent_memory = self.persistent_memory.repeat(batch, 1, 1)
        print("persistent_memory", persistent_memory.shape)
        long_term_memory = self.long_term_memory(seq, is_update=False)

        memory_length = persistent_memory.size(1) + long_term_memory.size(1)

        ## concat memories and input
        seq = torch.cat([persistent_memory, long_term_memory, seq], dim=-2)

        ## q, k, v
        q = self.w_q(seq)
        k = self.w_k(seq)
        v = self.w_v(seq)

        ## multi head attention
        q = q.view(batch, seq_len + memory_length, self.num_heads, self.head_size)
        k = k.view(batch, seq_len + memory_length, self.num_heads, self.head_size)
        v = v.view(batch, seq_len + memory_length, self.num_heads, self.head_size)

        ## positional encoding
        q = self.positional_encoding(q)
        k = self.positional_encoding(k)
        v = self.positional_encoding(v)

        ## transpose for attention
        ## => [batch, heads, seq, vec]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        ## calculate attention
        attention = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_size)

        mask = torch.zeros_like(attention)
        _memory_length = persistent_memory.size(1) + long_term_memory.size(1)
        mask[:, :, :_memory_length, :_memory_length] = 1
        mask[:, :, _memory_length:, _memory_length:] = torch.triu(
            torch.ones_like(mask[:, :, _memory_length:, _memory_length:]), diagonal=1
        )

        attention = attention.masked_fill(mask == 1, -float("inf"))
        attention = torch.softmax(attention, dim=-1)
        attention = attention @ v
        attention = attention.transpose(1, 2)
        attention = attention.contiguous().view(
            batch, _memory_length + seq_len, vector_size
        )
        attention = self.linear(attention)

        ## update long term memory
        y = self.long_term_memory(attention, is_update=False)

        ## ffn
        y = torch.cat([y, attention], dim=-1)
        print("y", y.shape)
        y = self.ffn(y)
        print("y", y.shape)

        ## reduce
        y = y[:, _memory_length:, :]

        ## retrieve memory
        return y


model = DecoderBlock(
    vector_size=8, num_heads=2, persistent_memory_length=5, long_term_memory_length=10
)

torchinfo.summary(model, input_size=(2, 10, 8))

persistent_memory torch.Size([2, 5, 8])
self.rotation_matrix torch.Size([1, 5000, 1, 2, 2, 2])
y torch.Size([2, 25, 2, 2, 2, 1])
self.rotation_matrix torch.Size([1, 5000, 1, 2, 2, 2])
y torch.Size([2, 25, 2, 2, 2, 1])
self.rotation_matrix torch.Size([1, 5000, 1, 2, 2, 2])
y torch.Size([2, 25, 2, 2, 2, 1])
y torch.Size([2, 25, 16])
y torch.Size([2, 25, 8])


Layer (type:depth-idx)                   Output Shape              Param #
DecoderBlock                             [2, 10, 8]                40
├─LongTermMemory: 1-1                    [2, 10, 8]                144
│    └─Linear: 2-1                       [2, 10, 8]                72
│    └─Sequential: 2-2                   [2, 10, 8]                --
│    │    └─Linear: 3-1                  [2, 10, 8]                72
│    │    └─ReLU: 3-2                    [2, 10, 8]                --
│    │    └─Linear: 3-3                  [2, 10, 8]                72
│    │    └─ReLU: 3-4                    [2, 10, 8]                --
│    │    └─Linear: 3-5                  [2, 10, 8]                72
├─Linear: 1-2                            [2, 25, 8]                72
├─Linear: 1-3                            [2, 25, 8]                72
├─Linear: 1-4                            [2, 25, 8]                72
├─RotaryPositionalEmbedding: 1-5         [2, 25, 2, 4]             --
├─RotaryPositi