# Engram

git: [dhcode-cpp/Engram-pytorch](https://github.com/dhcode-cpp/Engram-pytorch)

Blog: [【手撕Engram】DeepSeek 的 Conditional Memory 能取代 Attention 吗？](https://zhuanlan.zhihu.com/p/1994713080131772751)

In [1]:
from typing import List
from dataclasses import dataclass, field
import math

## third-party
import torch
import torch.nn as nn

In [2]:
@dataclass
class EngramModelConfig:
    """Hyper-parameters shared by the hashing, memory, convolution and model classes in this file."""
    dim: int = 512  # hidden width of the backbone (token embedding / Linear layers)
    head_dim: int = 64  # width of one memory-table embedding row
    max_n_gram: int = 3  # upper bound handed to MultiHeadsHash.get_all_hash_ids
    max_memory_vocab_size: int = 1007  # slots per hash head (final hash modulus)
    head_hash: int = 8  # number of hash heads
    n_hc: int = 4  # number of parallel hidden-state branches (the "hc" axis)
    num_layer: int = 4  # number of DecoderBlock layers in the language model
    vocab_size: int = 100  # token vocabulary size
    kernel_size: int = 4 # kernel size of the convolution along the sequence

# Shared configuration used by the demo cells below.
config = EngramModelConfig()

# Sample batch geometry.
bsz = 2
seq_len = 100

# Random token ids in [0, vocab_size), shape (bsz, seq_len).
x = torch.randint(config.vocab_size, 
                  size=(bsz, seq_len))

# Random hidden states, shape (bsz, seq_len, dim).
H = torch.randn(bsz, seq_len, config.dim)

## hash

In [3]:
class MultiHeadsHash:
    """Hash token n-grams into memory slot ids with several hash heads.

    Each head uses its own modulus. The n-gram at position ``t`` is reduced to
    the product of the token ids ``x[t] * x[t-1] * ... * x[t-n+1]`` (positions
    near the start of the sequence use the shorter context that exists), then
    taken modulo the head's modulus and finally modulo the memory table size.

    Note: because the n-gram key is a plain product, any token id equal to 0
    collapses the whole n-gram to slot 0.
    """

    def __init__(self, max_memory_vocab_size, layer_id):
        """
        Args:
            max_memory_vocab_size: number of slots per hash head; every
                returned id lies in ``[0, max_memory_vocab_size)``.
            layer_id: index of the layer that owns this hasher; it scales the
                moduli so that different layers produce different hashes.
        """
        # One large prime per hash head.
        self.mods = torch.tensor([12582917, 25165843, 50331653, 100663319, 201326611, 402653189, 805306457, 1610612741])
        # Scale by the layer index so each layer hashes differently.
        self.mods *= (layer_id + 1)
        self.max_memory_vocab_size = max_memory_vocab_size
        self.layer_id = layer_id

    def hash(self, x, mod, n_gram):
        """Hash every ``n_gram``-token window of ``x`` with a single modulus.

        Args:
            x: integer token ids, shape ``(bsz, seq_len)``. Not modified.
            mod: modulus of this hash head.
            n_gram: window length; ``1`` hashes each token on its own.

        Returns:
            Integer tensor shaped like ``x`` with values in
            ``[0, max_memory_vocab_size)``.
        """
        key = x.clone()
        for shift in range(1, n_gram):
            # Fold in the token `shift` positions to the left.
            key[:, shift:] *= x[:, :-shift]
        return (key % mod) % self.max_memory_vocab_size

    def multi_head_hash(self, x, mods, n_gram):
        """Hash ``x`` once per modulus in ``mods``.

        Returns:
            Tensor of shape ``(bsz, seq_len, len(mods))``.
        """
        per_head = [self.hash(x, mod, n_gram) for mod in mods]
        return torch.stack(per_head, dim=-1)  # bsz, seq_len, hash_head

    def get_all_hash_ids(self, x, max_n_gram):
        """Compute multi-head hash ids for every n-gram order from 2 to ``max_n_gram``.

        Args:
            x: integer token ids, shape ``(bsz, seq_len)``.
            max_n_gram: largest n-gram order (inclusive).

        Returns:
            List of ``max_n_gram - 1`` tensors (2-gram, 3-gram, ...), each of
            shape ``(bsz, seq_len, hash_head)``. Empty when ``max_n_gram < 2``.
        """
        # Start at 2 and include max_n_gram: the previous range(1, max_n_gram)
        # produced 1-gram .. (max_n_gram - 1)-gram, dropping the highest order.
        return [
            self.multi_head_hash(x, self.mods, n)
            for n in range(2, max_n_gram + 1)
        ]

In [4]:

# Hash the sample batch with the moduli of layer 3.
mhash = MultiHeadsHash(config.max_memory_vocab_size, layer_id = 3)
hash_ids = mhash.get_all_hash_ids(x, config.max_n_gram)

print(x.shape)
print(hash_ids[0].shape) # first n-gram order returned, last dim = 8 hash heads
print(hash_ids[1].shape) # second n-gram order returned, last dim = 8 hash heads

torch.Size([2, 100])
torch.Size([2, 100, 8])
torch.Size([2, 100, 8])


## Memory

In [5]:
class ConditionalMemory(nn.Module):
    """Embedding lookup for hashed n-gram ids, one table slice per hash head.

    All heads share a single ``nn.Embedding`` of
    ``max_memory_vocab_size * head_hash`` rows; head ``h`` owns rows
    ``[h * max_memory_vocab_size, (h + 1) * max_memory_vocab_size)``.
    """

    def __init__(self,
                 config
                ):
        """
        Args:
            config: object exposing ``head_dim``, ``head_hash`` and
                ``max_memory_vocab_size``.
        """
        super().__init__()

        self.head_dim = config.head_dim
        self.head_hash = config.head_hash
        max_memory_vocab_size = config.max_memory_vocab_size

        # One fused table instead of `head_hash` separate embeddings.
        self.memory_embds = nn.Embedding(max_memory_vocab_size * self.head_hash, self.head_dim)

        # Row offset of each head inside the fused table, shaped (1, 1, head_hash)
        # so it broadcasts over (bsz, seq_len, head_hash) ids. Registered as a
        # non-persistent buffer: it follows .to()/.cuda() but stays out of state_dict.
        offset = torch.arange(self.head_hash) * max_memory_vocab_size
        self.register_buffer("offset", offset[None, None, :], persistent=False)

    def forward(self, x, ngram_hash_ids):
        """Gather and concatenate the memory vectors of every n-gram order.

        Args:
            x: token ids, shape ``(bsz, seq_len)``; only its shape is read and
                the tensor is left untouched.
            ngram_hash_ids: list of integer tensors, each of shape
                ``(bsz, seq_len, head_hash)`` with values in
                ``[0, max_memory_vocab_size)``.

        Returns:
            Tensor of shape ``(bsz, seq_len, n * head_hash * head_dim)`` where
            ``n = len(ngram_hash_ids)``.
        """
        bsz, seq_len = x.shape
        n = len(ngram_hash_ids)

        # Shift each head's ids into its own slice of the fused table. The
        # offset is added out of place so the caller's tensors are not mutated.
        ngram_memory = [
            self.memory_embds(hash_ids + self.offset)  # bsz, seq_len, head_hash, head_dim
            for hash_ids in ngram_hash_ids
        ]
        h_memory = torch.cat(ngram_memory, dim=-1)  # bsz, seq_len, head_hash, n * head_dim
        # Flatten heads and n-gram orders into a single feature axis.
        return h_memory.reshape(bsz, seq_len, n * self.head_hash * self.head_dim)

In [6]:
# Look up the memory vectors for the hashed sample batch and show the
# flattened feature width of the result.
memory = ConditionalMemory(config)
h_memory = memory(x, hash_ids)
print(h_memory.shape)

torch.Size([2, 100, 1024])


## Conv1D

In [7]:
class ShortConv1D(nn.Module):
    def __init__(self,
                 dim,
                 n_hc,
                 kernel_size,):
        super().__init__()
        
        dilation=1
        self.total_dim = dim * n_hc # cat n_hc dim
        self.conv = nn.Conv1d(
            in_channels=self.total_dim,
            out_channels=self.total_dim,
            kernel_size=kernel_size,
            groups=self.total_dim,
            bias=False,
            padding=(kernel_size - 1) * dilation, # 3
            dilation=dilation,
        )

        self.norms = nn.ModuleList([nn.RMSNorm( dim ) for _ in range(n_hc)])
        self.act_fn = nn.SiLU()

    def forward(self, x):
        B, L, n_hc, D = x.shape 

        x_norm = []
        for hc_idx, norm in enumerate(self.norms):
            x_norm.append(norm(x[:,:, hc_idx]))

        x_norm = torch.stack(x_norm, dim=2) # B, L, n_hc, D 

        x_norm = x_norm.reshape(B, L, n_hc*D) # B, L, (n_hc, D) -> B, L, (C)
        x_norm = x_norm.transpose(1, 2) # B, C, L
        y = self.conv(x_norm)
        y = y[..., :L]

        y = self.act_fn(y) # swiglu
        y = y.transpose(1,2).reshape(B, L, n_hc, D)

        return y

In [8]:
# Stand-alone shape check of ShortConv1D with a per-branch width of
# dim * (max_n_gram - 1).
conv = ShortConv1D(
    dim=config.dim * (config.max_n_gram-1), 
    kernel_size=config.kernel_size,
    n_hc=config.n_hc,
)


# Random input laid out as (batch, seq_len, n_hc, width).
H_hc = torch.randn(2, 10, config.n_hc, config.dim*(config.max_n_gram-1))
H_conv = conv(H_hc)

print(H_hc.shape)
print(H_conv.shape)

torch.Size([2, 10, 4, 1024])
torch.Size([2, 10, 4, 1024])


## Engram

In [9]:
class Engram(nn.Module):
    """Gated n-gram memory read for multi-branch hidden states.

    Token ids are hashed into memory slots, the retrieved memory is projected
    to one key per branch and one shared value, each branch gates the value
    with a scalar computed from its hidden state and its key, and the gated
    values pass through a short causal convolution with a residual connection.
    """

    def __init__(self,
                 config,
                 layer_id=1,
                ):
        """
        Args:
            config: object exposing ``dim``, ``n_hc``, ``head_dim``,
                ``head_hash``, ``max_n_gram``, ``max_memory_vocab_size`` and
                ``kernel_size``.
            layer_id: layer index forwarded to the hasher.
        """
        super().__init__()

        self.n_hc = config.n_hc
        D = config.dim

        # Width of the concatenated memory read: one head_dim vector per hash
        # head and per n-gram order.
        memory_dim = config.head_dim * config.head_hash * (config.max_n_gram - 1)

        self.Wks = nn.ModuleList([nn.Linear(memory_dim, D) for _ in range(self.n_hc)])
        self.Wv = nn.Linear(memory_dim, D)
        # ModuleList (not a plain list) so the norm weights are registered:
        # they then appear in parameters()/state_dict() and follow .to(device).
        self.norm1 = nn.ModuleList([nn.RMSNorm(D) for _ in range(self.n_hc)])
        self.norm2 = nn.ModuleList([nn.RMSNorm(D) for _ in range(self.n_hc)])

        self.max_n_gram = config.max_n_gram

        self.memory = ConditionalMemory(config)

        self.hash = MultiHeadsHash(max_memory_vocab_size=config.max_memory_vocab_size, layer_id=layer_id)

        self.conv = ShortConv1D(config.dim,
                                config.n_hc,
                                config.kernel_size)

    def forward(self, h, x):
        """
            h: bsz, seq_len, n_hc, dim
                hidden states

            x: bsz, seq_len
                input ids

            returns: bsz, seq_len, n_hc, dim
        """

        _, _, _, D = h.shape

        ngram_hash_id = self.hash.get_all_hash_ids(x, self.max_n_gram)
        h_memory = self.memory(x, ngram_hash_id)

        gates = []  # one scalar gate per branch and position
        for hc_idx in range(self.n_hc):
            # Query from the branch's hidden state, key from the memory read.
            q = self.norm1[hc_idx](h[:, :, hc_idx, :])
            k = self.norm2[hc_idx](self.Wks[hc_idx](h_memory))
            # Scaled dot product reduced over the feature axis, so the gate is
            # a single value per position rather than one value per feature.
            score = (q * k).sum(dim=-1, keepdim=True) / math.sqrt(D)  # bsz, seq_len, 1
            gates.append(torch.sigmoid(score))

        gates = torch.stack(gates, dim=2)  # bsz, seq_len, n_hc, 1
        v = self.Wv(h_memory).unsqueeze(2)  # bsz, seq_len, 1, dim (shared by all branches)
        v_ = gates * v  # bsz, seq_len, n_hc, dim

        # Short causal convolution with a residual connection.
        out = self.conv(v_) + v_

        return out


In [10]:
# Smoke-test a single Engram module on a small random batch.
engram = Engram(config)

bsz = 1
seq_len = 10
# Hidden states laid out as (bsz, seq_len, n_hc, dim) and matching token ids.
H = torch.randn(bsz, seq_len, config.n_hc, config.dim)
x = torch.randint(config.vocab_size, (bsz, seq_len))
y = engram(H, x)

## Model

In [11]:
class DecoderBlock(nn.Module):
    """One residual layer: optional Engram read, then two residual Linear maps.

    ``attn`` and ``ffn`` are plain ``nn.Linear(dim, dim)`` layers standing in
    for the attention and feed-forward sub-layers.
    """

    def __init__(self, 
                 with_engram,
                 layer_id,
                 config,
                ):
        """
        Args:
            with_engram: truthy to attach an Engram module to this layer.
            layer_id: layer index forwarded to the Engram module.
            config: object exposing ``dim`` (plus whatever Engram reads).
        """
        super().__init__()
        width = config.dim
        self.with_engram = with_engram
        self.ffn = nn.Linear(width, width)
        self.attn = nn.Linear(width, width)
        if with_engram:
            self.engram = Engram(config, layer_id=layer_id)

    def forward(self, H, x):
        """Apply the residual sub-layers to hidden states ``H``; ``x`` holds the token ids."""
        state = H
        if self.with_engram:
            state = self.engram(state, x) + state
        # Residual update through the attention stand-in, then the FFN stand-in.
        for sublayer in (self.attn, self.ffn):
            state = sublayer(state) + state
        return state

In [12]:
class LanguageModelWithEngram(nn.Module):
    """Toy language model: embedding -> multi-branch decoder stack -> LM head.

    The token embedding is broadcast to ``n_hc`` parallel branches, run
    through ``num_layer`` decoder blocks, summed over the branches, and
    projected to vocabulary logits.
    """

    def __init__(self, 
                 config,
                ):
        """
        Args:
            config: object exposing ``n_hc``, ``vocab_size``, ``dim`` and
                ``num_layer`` (plus whatever DecoderBlock reads).
        """
        super().__init__()

        self.n_hc = config.n_hc
        self.embd = nn.Embedding(config.vocab_size, config.dim)

        blocks = []
        for layer_id in range(config.num_layer):
            # (layer_id + 1) % 2 is 1 for even layer_id and 0 for odd ones, so
            # layers 0, 2, ... carry an Engram module and layers 1, 3, ... do not.
            blocks.append(
                DecoderBlock(
                    config=config,
                    with_engram=(layer_id + 1) % 2,
                    layer_id=layer_id,
                )
            )
        self.decoder_block = nn.ModuleList(blocks)

        self.lm_head = nn.Linear(config.dim, config.vocab_size)

    def forward(self, x):
        """Map token ids ``x`` of shape (B, L) to logits of shape (B, L, vocab_size)."""
        token_embd = self.embd(x)

        # Broadcast the embedding to every branch: B, L, n_hc, D.
        H = token_embd.unsqueeze(2).expand(-1, -1, self.n_hc, -1)

        for block in self.decoder_block:
            H = block(H, x)

        # Collapse the branch axis by summation, then project to the vocabulary.
        merged = H.sum(dim=2)
        return self.lm_head(merged)

## Test

In [13]:
# End-to-end smoke test: random token ids in, logits of shape
# (bsz, seq_len, vocab_size) out.
model = LanguageModelWithEngram(config)
x = torch.randint(config.vocab_size, (bsz, seq_len))
logits = model(x)
print(logits.shape)

torch.Size([1, 10, 100])


## Reference

Conditional Memory via Scalable Lookup: A New Axis of Sparsity for Large Language Models