# Engram

Code: [dhcode-cpp/Engram-pytorch](https://github.com/dhcode-cpp/Engram-pytorch)

Blog: [【手撕Engram】DeepSeek 的 Conditional Memory 能取代 Attention 吗？](https://zhuanlan.zhihu.com/p/1994713080131772751)

In [1]:
from typing import List
from dataclasses import dataclass, field
import math

## third-party
import torch
import torch.nn as nn

In [2]:
@dataclass
class EngramModelConfig:
    """Hyper-parameters shared by the Engram demo modules in this notebook."""
    dim: int = 512  # model hidden size (last dim of the hidden states h)
    head_dim: int = 64  # embedding width of each hash head's memory slot
    max_n_gram: int = 3  # largest n-gram order; max_n_gram - 1 orders are hashed
    max_memory_vocab_size: int = 1007  # memory rows per hash head
    head_hash: int = 8  # number of hash heads
    n_hc: int = 4 # ignore -- presumably the hyper-connection (HC) stream count; unused in the cells shown
    num_layer: int = 4  # not read by the cells shown here
    vocab_size: int = 100  # range of token ids sampled for the demo inputs
    kernel_size: int = 4 # kernel size of the short convolution along the sequence

config = EngramModelConfig()

# Demo batch dimensions
bsz = 2
seq_len = 100

# Random token ids, shape (bsz, seq_len)
x = torch.randint(config.vocab_size, 
                  size=(bsz, seq_len))

# Random hidden states, shape (bsz, seq_len, dim)
H = torch.randn(bsz, seq_len, config.dim)

# Engram Framework

In [3]:
class EngramSimple(nn.Module):
    """Minimal Engram skeleton: memory lookup -> dot-product gating -> short causal conv.

    This is a didactic sketch: ``multi_head_hash`` is a random placeholder and
    does not actually hash the token ids.
    """

    def __init__(self, 
                 memory_vocab_size, 
                 dim, 
                 kernel_size=2):
        super().__init__()
        self.memory_vocab_size = memory_vocab_size
        self.memory_embd = nn.Embedding(memory_vocab_size, dim)
        self.Wk = nn.Linear(dim, dim)
        self.Wv = nn.Linear(dim, dim)

        self.kernel_size = kernel_size
        # One scalar tap per time lag, shared by all channels:
        # w_conv1d[j] weights the value j steps in the past.
        self.w_conv1d = nn.Parameter(torch.randn(kernel_size))

    def forward(self, x, h):
        """Fuse retrieved memory into the hidden states.

        Args:
            x: token ids, shape (bsz, seq_len).
            h: hidden states, shape (bsz, seq_len, dim).

        Returns:
            Tensor of shape (bsz, seq_len, dim).
        """
        # 1. Multi-head hashed memory lookup
        hash_id = self.multi_head_hash(x)
        h_memory = self.get_memory(hash_id)  # B, T, D

        # 2. Dot-product gating (unscaled and without sigmoid in this simplified version)
        q, k, v = h, self.Wk(h_memory), self.Wv(h_memory)
        gate = (q * k).sum(dim=2, keepdim=True)  # B, T, 1
        v_ = gate * v  # B, T, D

        # 3. Short causal convolution along time
        return self.short_conv1d(v_)

    def multi_head_hash(self, x):
        """Placeholder hash: random memory ids of the same shape as x (x's values are ignored)."""
        B, T = x.shape
        hash_id = torch.randint(self.memory_vocab_size, (B, T))  # random hash
        return hash_id

    def get_memory(self, hash_id):
        """Look up memory embeddings for the given ids."""
        h_memory = self.memory_embd(hash_id)
        return h_memory

    def short_conv1d(self, v):
        """Simplified causal conv along the time axis (dim 1).

        out[:, t] = sum_j w_conv1d[j] * v[:, t - j], where steps before t = 0
        contribute zero. With kernel_size=2 this adds each step to the
        weighted previous step.

        Args:
            v: tensor of shape (B, T, D).

        Returns:
            Tensor of shape (B, T, D).
        """
        seq_len = v.shape[1]
        out = v * self.w_conv1d[0]
        for lag in range(1, min(self.kernel_size, seq_len)):
            # Delay v by `lag` steps along time, zero-filling the first `lag` steps.
            shifted = nn.functional.pad(v[:, :-lag], (0, 0, lag, 0))
            out = out + shifted * self.w_conv1d[lag]
        return out

# Smoke test of EngramSimple. The token ids may exceed memory_vocab_size because
# its placeholder hash ignores the id values and samples random memory slots.
model = EngramSimple(memory_vocab_size=1000, 
                 dim=128, 
                 kernel_size=2,)
x = torch.randint(12948, (1,10))
h = torch.randn(1, 10, 128)
y=model(x,h)
print(y.shape)

torch.Size([1, 10, 128])


## hash

In [4]:
class MultiHeadsHash:
    """Multi-head n-gram hasher mapping token ids to memory-slot ids.

    For n-gram order N, the key at position t is the product of the token ids
    at positions t, t-1, ..., t-N+1 (fewer factors near the start of the
    sequence, since no padding is applied). Each hash head reduces that product
    by its own modulus from ``self.mods`` and then folds it into
    ``[0, max_memory_vocab_size)``.

    NOTE(review): the product is order-insensitive and becomes 0 whenever any
    id in the window is 0. Also, when the product is smaller than every
    modulus (e.g. small vocabularies), ``x_ % mod`` is a no-op and all heads
    return identical ids.
    """

    def __init__(self, max_memory_vocab_size, layer_id):
        # One large prime modulus per hash head (8 heads).
        self.mods = torch.tensor([12582917, 25165843, 50331653, 100663319,
                                  201326611, 402653189, 805306457, 1610612741])
        # Scale the moduli per layer, intended to make each layer hash differently.
        self.mods *= (layer_id + 1)
        self.max_memory_vocab_size = max_memory_vocab_size
        self.layer_id = layer_id

    def hash(self, x, mod, n_gram):
        """Hash the length-``n_gram`` window ending at each position.

        Args:
            x: integer token ids, shape (bsz, seq_len).
            mod: modulus of this hash head.
            n_gram: window length; 1 hashes the ids themselves.

        Returns:
            Ids in ``[0, max_memory_vocab_size)``, same shape as x.
        """
        x_ = x.clone()
        for i in range(1, n_gram):
            # Multiply in the token i steps back; the first i positions keep a shorter window.
            x_[:, i:] *= x[:, :-i]
        hash_id = x_ % mod
        hash_id = hash_id % self.max_memory_vocab_size
        return hash_id

    def multi_head_hash(self, x, mods, n_gram):
        """Hash x once per modulus in ``mods``.

        Returns:
            Tensor of shape (bsz, seq_len, len(mods)).
        """
        return torch.stack([self.hash(x, mod, n_gram) for mod in mods], dim=-1)

    def get_all_hash_ids(self, x, max_n_gram):
        """Hash ids for every n-gram order from 2 to ``max_n_gram`` inclusive.

        Args:
            x: integer token ids, shape (bsz, seq_len).
            max_n_gram: largest n-gram order.

        Returns:
            List of ``max_n_gram - 1`` tensors, each (bsz, seq_len, num_heads);
            entry k holds the (k + 2)-gram hash ids.
        """
        return [self.multi_head_hash(x, self.mods, n) for n in range(2, max_n_gram + 1)]

In [5]:
# Hash the current x (its shape is whatever the most recently run cell assigned;
# the recorded output below shows (1, 10)).
mhash = MultiHeadsHash(config.max_memory_vocab_size, layer_id = 3)
hash_ids = mhash.get_all_hash_ids(x, config.max_n_gram)

print(x.shape)
print(hash_ids[0].shape) # first hashed n-gram order; last dim = 8 hash heads
print(hash_ids[1].shape) # second hashed n-gram order; last dim = 8 hash heads

torch.Size([1, 10])
torch.Size([1, 10, 8])
torch.Size([1, 10, 8])


## Memory

In [6]:
class ConditionalMemory(nn.Module):
    """Embedding-table memory addressed by multi-head n-gram hash ids.

    All heads share one ``nn.Embedding`` with ``head_hash * max_memory_vocab_size``
    rows. Head k owns rows ``[k * max_memory_vocab_size, (k + 1) * max_memory_vocab_size)``.
    """

    def __init__(self,  
                 config
                ):
        super().__init__()

        self.head_dim = config.head_dim
        max_memory_vocab_size = config.max_memory_vocab_size
        self.head_hash = config.head_hash

        # Equivalent to one nn.Embedding(max_memory_vocab_size, head_dim) per head,
        # packed into a single table.
        self.memory_embds = nn.Embedding(max_memory_vocab_size * self.head_hash, self.head_dim)
        # Per-head row offset shaped (1, 1, head_hash) so it broadcasts over hash ids
        # of shape (bsz, seq_len, head_hash). It is a non-persistent buffer, so it
        # follows .to(device) without adding a state_dict entry.
        offset = torch.arange(self.head_hash) * max_memory_vocab_size
        self.register_buffer("offset", offset[None, None, :], persistent=False)

    def forward(self, x, ngram_hash_ids):
        """Gather memory for every n-gram order and head, then flatten.

        Args:
            x: token ids, shape (bsz, seq_len); only its shape is used (not modified).
            ngram_hash_ids: list of n tensors, each (bsz, seq_len, head_hash) with
                ids in [0, max_memory_vocab_size).

        Returns:
            Tensor of shape (bsz, seq_len, n * head_hash * head_dim).
        """
        bsz, seq_len = x.shape
        n = len(ngram_hash_ids)

        # Shift each head's ids into its own slice of the shared table.
        ngram_memory = [self.memory_embds(hash_ids + self.offset) for hash_ids in ngram_hash_ids]
        h_memory = torch.cat(ngram_memory, dim=-1)  # bsz, seq_len, head_hash, n * head_dim
        # Concatenate all orders/heads into one feature vector per position.
        return h_memory.reshape(bsz, seq_len, n * self.head_hash * self.head_dim)

In [7]:
# Look up memory for the hash ids computed in the previous cell; the last dim
# is n_orders * head_hash * head_dim (1024 in the recorded output).
memory = ConditionalMemory(config)
h_memory = memory(x, hash_ids)
print(h_memory.shape)

torch.Size([1, 10, 1024])


# Engram

In [8]:
class EngramWithoutHC(nn.Module):
    """Engram block without hyper-connections.

    Hashes the input ids into n-gram memory, projects the retrieved memory into
    key/value space, and scales the value by a per-position sigmoid gate derived
    from the normalized hidden state and key. The short-conv branch is disabled.
    """

    def __init__(self, 
                 config,
                 layer_id=1,
                ):
        super().__init__()

        hidden = config.dim
        n_orders = config.max_n_gram - 1
        memory_dim = config.head_dim * config.head_hash * n_orders
        self.max_n_gram = config.max_n_gram

        # Memory -> key / value projections and their normalizations
        self.Wk = nn.Linear(memory_dim, hidden)
        self.Wv = nn.Linear(memory_dim, hidden)
        self.norm1 = nn.RMSNorm(hidden)
        self.norm2 = nn.RMSNorm(hidden)

        # Hashing and memory-lookup operators
        self.hash = MultiHeadsHash(max_memory_vocab_size=config.max_memory_vocab_size, layer_id=layer_id)
        self.memory = ConditionalMemory(config)
        # The ShortConv1D branch (which would take config.n_hc) is intentionally omitted here.

    def forward(self, h, x):
        """Return the gated memory contribution for each position.

        Args:
            h: hidden states, shape (bsz, seq_len, dim).
            x: input token ids, shape (bsz, seq_len).

        Returns:
            Tensor of shape (bsz, seq_len, dim).
        """
        _, _, width = h.shape
        scale = math.sqrt(width)

        hash_ids = self.hash.get_all_hash_ids(x, self.max_n_gram)
        mem = self.memory(x, hash_ids)

        query = self.norm1(h)
        key = self.norm2(self.Wk(mem))
        # Scaled dot product per position -> sigmoid gate of shape (bsz, seq_len, 1)
        gate = torch.sigmoid((query * key).sum(dim=-1, keepdim=True) / scale)

        return gate * self.Wv(mem)

In [9]:
engram = EngramWithoutHC(config, layer_id=1)

# Fresh inputs: earlier cells reassigned x and h with other shapes.
x = torch.randint(config.vocab_size, 
                  size=(bsz, seq_len))
h = torch.randn(bsz, seq_len, config.dim)

y = engram(h, x)
print(y.shape)

torch.Size([2, 100, 512])


# Conv1D

## Conv1D on a 1-D sequence

In [10]:
def fun_conv1d(x, w):
    """Causal 1-D convolution of a Python list, printing each step.

    The input is left-padded with ``len(w) - 1`` zeros. Output i is the dot
    product of the window ``padded[i : i + len(w)]`` with ``w``, without flipping
    the kernel (the same convention as ``nn.Conv1d``).

    Args:
        x: list of numbers.
        w: kernel weights.

    Returns:
        List with one output per element of x.
    """
    k = len(w)
    padded = [0] * (k - 1) + x

    result = []
    for start in range(len(x)):
        window = padded[start: start + k]
        acc = 0
        for xv, wv in zip(window, w):
            acc += xv * wv
        print(window, '*', w, '->', acc)
        result.append(acc)

    return result

# Toy example: with w = [1, 10, 100] each output reads its window right-to-left
# as decimal digits (e.g. window [1, 2, 3] -> 1*1 + 2*10 + 3*100 = 321).
x = [1, 2, 3, 4]
w = [1, 10, 100]


print('x:', x, 'w:', w)
y = fun_conv1d(x, w)
print(y)

x: [1, 2, 3, 4] w: [1, 10, 100]
[0, 0, 1] * [1, 10, 100] -> 100
[0, 1, 2] * [1, 10, 100] -> 210
[1, 2, 3] * [1, 10, 100] -> 321
[2, 3, 4] * [1, 10, 100] -> 432
[100, 210, 321, 432]


## Conv1D on a 1-D sequence with PyTorch

In [11]:
kernel_size = 3
# Single-channel nn.Conv1d reproducing fun_conv1d: padding of kernel_size - 1,
# then keep only the first x_len outputs so each output sees only past/current inputs.
conv = nn.Conv1d(
            in_channels=1,
            out_channels=1,
            kernel_size=kernel_size,
            groups=1,
            bias=False,
            padding=(kernel_size - 1) * 1, # = 2 for kernel_size=3 (dilation 1)
            dilation=1,
        )
x_len = len(x)
x_tensor = torch.tensor([x], dtype=torch.float32).unsqueeze(dim = 1)
print(x_tensor.shape) # B, C, T
# Use the same weights as the list example so the outputs can be compared.
conv.weight.data = torch.tensor([[[1,10,100]]],dtype=torch.float32)
y = conv(x_tensor)
print(y[0,0,:x_len])

torch.Size([1, 1, 4])
tensor([100., 210., 321., 432.], grad_fn=<SliceBackward0>)


## Conv1D for hidden states

In [12]:
# Depthwise conv over hidden states: groups=C gives each channel its own kernel
# (weight shape (C, 1, kernel_size)); kernel_size is reused from the previous cell.
B = 1
T = 10
C = 128 # channel, dim(D)
X = torch.randn(B, T, C)
print(X.shape)
X = X.transpose(1,2) # B, C, T

conv_3C = nn.Conv1d(in_channels=C, out_channels=C, kernel_size=kernel_size,
            groups=C, padding=(kernel_size - 1) * 1, dilation=1, bias=False,)
print('shape_groupC:', conv_3C.weight.data.shape)
Y = conv_3C(X)
Y = Y.transpose(1,2)[:, :T] # back to B, T, C, keeping the first T (causal) steps
print(Y.shape)

torch.Size([1, 10, 128])
shape_groupC: torch.Size([128, 1, 3])
torch.Size([1, 10, 128])


## ShortConv1DWithoutHC

In [13]:
class ShortConv1DWithoutHC(nn.Module):
    def __init__(self,
                 dim,
                 kernel_size,):
        super().__init__()
        
        dilation=1
        self.total_dim = dim 
        self.conv = nn.Conv1d(
            in_channels=self.total_dim,
            out_channels=self.total_dim,
            kernel_size=kernel_size,
            groups=self.total_dim,
            bias=False,
            padding=(kernel_size - 1) * dilation, # 3
            dilation=dilation,
        )

        self.norm = nn.RMSNorm( dim )
        self.act_fn = nn.SiLU()

    def forward(self, x):
        B, T, C = x.shape 

        x_norm = self.norm(x)
        x_norm = x_norm.transpose(1, 2) # B, C, T
        y = self.conv(x_norm)
        y = y[..., :T]

        y = self.act_fn(y) # swiglu
        y = y.transpose(1,2)

        return y

In [14]:
# Channel count is dim * (max_n_gram - 1) = 1024 with the default config.
conv = ShortConv1DWithoutHC(
    dim=config.dim * (config.max_n_gram-1), 
    kernel_size=config.kernel_size,
)


# Random (B, T, C) input; the conv preserves the shape.
V_ = torch.randn(2, 10, config.dim*(config.max_n_gram-1))
V_conv1d = conv(V_)

print(V_.shape)
print(V_conv1d.shape)

torch.Size([2, 10, 1024])
torch.Size([2, 10, 1024])


## Engram with Hyper-Connections

See `./Engram.ipynb` for the full version with hyper-connections.