In [131]:
import math
import sys
from typing import Optional, Union
if '..' not in sys.path: sys.path.append('..')

import numpy as np
from pprint import pprint
from pydantic import BaseModel
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from transformers import PreTrainedTokenizer, GPT2Tokenizer


from mllm.config.model import VocabEncoderCfg, EmbDecoderCfg
from mllm.model.modules import VocabEncoder, VocabDecoder



In [155]:

class ScaledDotProductAttention(nn.Module):
    temperature: float
    inp_len: int
    dropout_rate: float
    # dropout: nn.Module

    def __init__(self, temperature: float, inp_len: int = 0,
                 dropout_rate: float = 0.1):
        super().__init__()
        self.temperature = temperature
        self.inp_len = inp_len
        self.dropout_rate = dropout_rate

        self.dropout = nn.Dropout(self.dropout_rate)

    def forward(self, q, k, v, mask=None):
        attn = q / self.temperature
        attn = torch.matmul(attn, k.transpose(2, 3))

        if mask is not None:
            # print_dtype_shape(attn)
            # print_dtype_shape(mask)
            attn = attn.masked_fill(mask == 0, -1e9)

        attn = self.dropout(F.softmax(attn, dim=-1))
        output = torch.matmul(attn, v)

        return output, attn


class MultiHeadAttention(nn.Module):
    n_heads: int
    d_model: int
    d_k: int
    d_v: int
    dropout_rate: float

    def __init__(self, n_heads: int, d_model: int, d_k: int, d_v: int, dropout_rate: float = 0.1):
        super().__init__()

        self.n_heads = n_heads
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.dropout_rate = dropout_rate

        self.w_qs = nn.Linear(d_model, n_heads * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_heads * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_heads * d_v, bias=False)
        self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)

        temp = d_k ** 0.5
        temp = 1
        self.attention = ScaledDotProductAttention(
            temperature=temp, inp_len=10000,
            dropout_rate=dropout_rate,
        )

        self.dropout = nn.Dropout(dropout_rate)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, q, k, v, mask=None):

        d_k, d_v, n_heads = self.d_k, self.d_v, self.n_heads
        sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)

        residual = q

        # Pass through the pre-attention projection: b x lq x (n*dv)
        # Separate different heads: b x lq x n x dv
        q = self.w_qs(q)
        k = self.w_ks(k)
        v = self.w_vs(v)

        q = q.view(sz_b, len_q, n_heads, d_k)
        k = k.view(sz_b, len_k, n_heads, d_k)
        v = v.view(sz_b, len_v, n_heads, d_v)

        # Transpose for attention dot product: b x n x lq x dv
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        if mask is not None:
            mask = mask.unsqueeze(1)   # For head axis broadcasting.

        q, attn = self.attention(q, k, v, mask=mask)

        # Transpose to move the head dimension back: b x lq x n x dv
        # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
        q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
        q = self.dropout(self.fc(q))
        q += residual

        q = self.layer_norm(q)

        return q, attn


class PositionalEncoding(nn.Module):

    def __init__(self, d_hid, n_position=200):
        super(PositionalEncoding, self).__init__()

        # Not a parameter
        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        ''' Sinusoid position encoding table '''
        # TODO: make it with torch instead of numpy

        def get_position_angle_vec(position):
            return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

        sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

        return torch.FloatTensor(sinusoid_table).unsqueeze(0)

    def forward(self, x):
        return x + self.pos_table[:, :x.size(1)].clone().detach()


class PositionwiseFeedForward(nn.Module):
    ''' A two-feed-forward-layer module '''

    def __init__(self, d_in, d_hid, dropout_rate: float = 0.1):
        super().__init__()
        bias = True
        self.w_1 = nn.Linear(d_in, d_hid, bias=bias) # position-wise
        self.w_2 = nn.Linear(d_hid, d_in, bias=bias) # position-wise
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):

        residual = x

        x = self.w_2(F.relu(self.w_1(x)))
        # x = self.w_2(F.leaky_relu(self.w_1(x)))
        # x = self.w_2(F.sigmoid(self.w_1(x)))
        x = self.dropout(x)
        x += residual

        x = self.layer_norm(x)

        return x


class EncoderLayer(nn.Module):
    ''' Compose with two layers '''

    def __init__(self, n_heads, d_model, d_inner, d_k, d_v, dropout_rate: float = 0.1):
        super(EncoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, dropout_rate=dropout_rate)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout_rate=dropout_rate)

    def forward(self, enc_input: Optional[Tensor], enc_input_kv: Optional[Tensor] = None, slf_attn_mask: Optional[Tensor] = None):
        if enc_input_kv is None:
            enc_input_kv = enc_input
        enc_output, enc_slf_attn = self.slf_attn(
            enc_input, enc_input_kv, enc_input_kv, mask=slf_attn_mask)
        enc_output = self.pos_ffn(enc_output)
        return enc_output, enc_slf_attn


In [54]:
d_model = 256
d_word_vec = d_model
n_heads = 8
d_v = d_k = d_model // n_heads
dropout_rate = 0.0
d_inner = 1024
inp_len = 20_000


In [15]:
tkz = GPT2Tokenizer.from_pretrained('gpt2', model_max_length=1000)
tkz.add_special_tokens({'pad_token': '<|pad_token|>'})
n_vocab = len(tkz)
pad_idx = tkz.pad_token_id
print(pad_idx)

cfg_vocab_enc = VocabEncoderCfg(
    n_vocab=n_vocab, d_word_vec=d_word_vec, d_model=d_model, pad_idx=pad_idx, inp_len=inp_len, dropout_rate=dropout_rate,
)


50257


In [16]:
vocab_encoder = VocabEncoder(**cfg_vocab_enc.dict())
vocab_encoder


VocabEncoder(
  (src_word_emb): Embedding(50258, 256, padding_idx=50257)
  (position_enc): PositionalEncoding()
  (dropout): Dropout(p=0.0, inplace=False)
  (layer_norm): LayerNorm((256,), eps=1e-06, elementwise_affine=True)
)

In [85]:
enc_layer_1 = EncoderLayer(
    n_heads=n_heads, d_model=d_model, d_inner=d_inner, d_k=d_k, d_v=d_v,
    dropout_rate=dropout_rate,
)
enc_layer_2 = EncoderLayer(
    n_heads=n_heads, d_model=d_model, d_inner=d_inner, d_k=d_k, d_v=d_v,
    dropout_rate=dropout_rate,
)

In [35]:
doc = """
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
    to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
    """
txts = [s.strip() for s in doc.split('\n') if s.strip()]
for txt in txts:
    print(len(txt))

108
113
110
64
113
108
113


In [157]:
seq_len = 50

def print_dtype_shape(x: Union[np.ndarray, Tensor], name: str = ''):
    print(name, x.shape, x.dtype)

def to_tensor(*xs: np.ndarray, device='cpu') -> tuple[Tensor, ...]:
    return tuple(torch.from_numpy(x).to(device) for x in xs)

def to_tokens(tkz: PreTrainedTokenizer, txts: list[str], seq_len: int, pad_tok: int) -> tuple[np.ndarray, np.ndarray]:
    n = len(txts)
    tokens = np.full((n, seq_len), pad_tok, dtype=np.int32)
    masks = np.full((n, seq_len), 0, dtype=np.int32)
    for i in range(n):
        txt = txts[i]
        toks = tkz(txt)['input_ids'][:seq_len]
        tokens[i, :len(toks)] = toks
        masks[i, :len(toks)] = 1
    masks1 = masks[:, :, np.newaxis]
    masks2 = masks[:, np.newaxis, :]
    masks = np.matmul(masks1, masks2)
    return tokens, masks

tokens, masks = to_tokens(tkz, txts, seq_len, pad_idx)
tokens, masks = to_tensor(tokens, masks)
print_dtype_shape(tokens)
print_dtype_shape(masks)


 torch.Size([7, 50]) torch.int32
 torch.Size([7, 50, 50]) torch.int32


In [79]:
out_enc = vocab_encoder(tokens)
print_dtype_shape(out_enc)

torch.Size([7, 50, 256]) torch.float32


In [90]:
out_enc_1, _ = enc_layer_1(out_enc, slf_attn_mask=masks)
print_dtype_shape(out_enc_1)
# out_enc_2, _ = enc_layer_2(out_enc_1, slf_attn_mask=masks)
# print_dtype_shape(out_enc_2)

torch.Size([7, 50, 256]) torch.float32


In [None]:
class ReduceLayer(nn.Module):
    d_model: int
    step: int
    reducer: nn.Linear

    def __init__(self, d_model: int, step: int) -> None:
        super().__init__()
        self.d_model = d_model
        self.step = step
        self.reducer = nn.Linear(in_features=d_model * step, out_features=d_model, bias=False)
    
    def forward(self, inp: Tensor) -> Tensor:
        batch_size, seq_len, d_model = inp.shape
        assert d_model == self.d_model, f'self.d_model = {self.d_model}. inp d_model = {d_model}'
        len_mod = seq_len % self.step
        # print_dtype_shape(inp, 'rdc_inp')
        if len_mod > 0:
            n_seq_add = self.step - len_mod
            inp = F.pad(inp, (0, 0, n_seq_add, 0), value=0)
            seq_len += n_seq_add
            # print_dtype_shape(inp, 'rdc_inp_pad')
        inp = inp.reshape(batch_size, seq_len // self.step, self.d_model * self.step)
        # print_dtype_shape(inp, 'rds_reshape')
        out = self.reducer(inp)
        # print_dtype_shape(out, 'rdc_reduce')
        return out

step = 3
enc = out_enc_1
reduce_layer = ReduceLayer(d_model=d_model, step=step)

# reduce_layer = torch.nn.Linear(in_features=d_model * 2, out_features=d_model, bias=False)
# print_dtype_shape(enc)
# if enc.shape[-1] % 2 == 1:
#     enc = F.pad(enc, (0, 0, 1, 0), value=0)
# enc = enc.reshape((enc.shape[0], enc.shape[1] // 2, enc.shape[2] * 2))
# print_dtype_shape(enc)
enc = reduce_layer(enc)
print_dtype_shape(enc)



 torch.Size([7, 17, 256]) torch.float32


In [169]:
class EncPyrCfg(BaseModel):
    vocab_encoder: VocabEncoderCfg
    pad_idx: int
    d_model: int
    n_heads: int
    d_k: int
    d_v: int
    d_inner: int
    inp_len: int
    step: int
    n_layers: int
    dropout_rate: float


class DecPyrCfg(BaseModel):
    d_model: int
    n_heads: int
    d_k: int
    d_v: int
    d_inner: int
    inp_len: int
    step: int
    n_layers: int
    dropout_rate: float
    with_vocab_decoder: bool
    n_vocab: int


class EncdecHgCfg(BaseModel):
    enc_pyr: EncPyrCfg
    dec_pyr: DecPyrCfg


def create_encdec_hg_cfg(
        n_vocab: int, pad_idx: int, d_model: int = 256, n_heads: int = 8, d_inner: int = 1024, inp_len: int = 256, step: int = 2, dropout_rate: float = 0.0, with_vacab_decoder: bool = True) -> EncdecHgCfg:
    d_word_vec = d_model
    d_k = d_v = d_model // n_heads
    n_layers = math.ceil(math.log(inp_len, step))
    cfg_vocab_enc = VocabEncoderCfg(
        n_vocab=n_vocab, d_word_vec=d_word_vec, d_model=d_model, pad_idx=pad_idx, inp_len=inp_len, dropout_rate=dropout_rate,
    )
    cfg_enc_pyr = EncPyrCfg(
        vocab_encoder=cfg_vocab_enc, pad_idx=pad_idx, d_model=d_model, n_heads=n_heads, d_k=d_k, d_v=d_v, d_inner=d_inner, inp_len=inp_len, step=step, n_layers=n_layers, dropout_rate=dropout_rate,
    )
    cfg_dec_pyr = DecPyrCfg(
        d_model=d_model, n_heads=n_heads, d_k=d_k, d_v=d_v, d_inner=d_inner, inp_len=inp_len, step=step, n_layers=n_layers, dropout_rate=dropout_rate, with_vocab_decoder=with_vacab_decoder, n_vocab=n_vocab,
    )
    cfg_encdec_hg = EncdecHgCfg(enc_pyr=cfg_enc_pyr, dec_pyr=cfg_dec_pyr)
    return cfg_encdec_hg

cfg_encdec_hg = create_encdec_hg_cfg(n_vocab=len(tkz), pad_idx=tkz.pad_token_id)
pprint(cfg_encdec_hg.dict())


{'dec_pyr': {'d_inner': 1024,
             'd_k': 32,
             'd_model': 256,
             'd_v': 32,
             'dropout_rate': 0.0,
             'inp_len': 256,
             'n_heads': 8,
             'n_layers': 8,
             'n_vocab': 50258,
             'step': 2,
             'with_vocab_decoder': True},
 'enc_pyr': {'d_inner': 1024,
             'd_k': 32,
             'd_model': 256,
             'd_v': 32,
             'dropout_rate': 0.0,
             'inp_len': 256,
             'n_heads': 8,
             'n_layers': 8,
             'pad_idx': 50257,
             'step': 2,
             'vocab_encoder': {'d_model': 256,
                               'd_word_vec': 256,
                               'dropout_rate': 0.0,
                               'inp_len': 256,
                               'n_vocab': 50258,
                               'pad_idx': 50257}}}


In [162]:
class EncoderPyramid(nn.Module):
    cfg: EncPyrCfg
    vocab_encoder: VocabEncoder
    enc_layers: nn.ModuleList
    rdc_layers: nn.ModuleList
    inp_chunk_len: int

    def __init__(self, cfg: EncPyrCfg):
        super().__init__()
        self.cfg = cfg
        self.vocab_encoder = VocabEncoder(**cfg.vocab_encoder.dict())
        self.enc_layers = nn.ModuleList([
            EncoderLayer(
                n_heads=cfg.n_heads, d_model=cfg.d_model, d_inner=cfg.d_inner, d_k=cfg.d_k, d_v=cfg.d_v,
                dropout_rate=cfg.dropout_rate,
            ) for _ in range(cfg.n_layers)
        ])
        self.rdc_layers = nn.ModuleList([
            ReduceLayer(d_model=cfg.d_model, step=cfg.step) for _ in range(cfg.n_layers)
        ])

    # Tensor of integer tokens: [batch_size, seq_len]
    def forward(self, inp: Tensor) -> Tensor:
        batch_size, seq_len = inp.shape
        mask = (inp == self.cfg.pad_idx).to(torch.bool)
        mask = np.matmul(mask.unsqueeze(-1), mask.unsqueeze(-2))
        assert seq_len == self.cfg.inp_len, f'seq_len = {seq_len}. inp_len = {inp_len}'
        # [batch_size, seq_len, d_model]
        out = self.vocab_encoder(inp)
        # print_dtype_shape(out, 'vocab_enc')
        for enc_layer, rdc_layer in zip(self.enc_layers, self.rdc_layers):
            out, _ = enc_layer(out, slf_attn_mask=mask)
            inds = slice(0, out.shape[1], 2)
            # print_dtype_shape(mask, 'mask 1')
            mask = mask[:, inds, inds]
            # print_dtype_shape(mask, 'mask 2')
            out = rdc_layer(out)
        return out

enc_pyr = EncoderPyramid(cfg_encdec_hg.enc_pyr)
pad_len = cfg_encdec_hg.enc_pyr.inp_len - tokens.shape[1]
print_dtype_shape(tokens, 'tokens')
inp_toks = F.pad(tokens, (0, pad_len), value=pad_idx)
print_dtype_shape(inp_toks, 'inp_toks')
out = enc_pyr(inp_toks)
print_dtype_shape(out, 'out')


tokens torch.Size([7, 50]) torch.int32
inp_toks torch.Size([7, 256]) torch.int32
out torch.Size([7, 1, 256]) torch.float32


In [168]:
class EnhanceLayer(nn.Module):
    d_model: int
    step: int
    enhancer: nn.Linear

    def __init__(self, d_model: int, step: int) -> None:
        super().__init__()
        self.d_model = d_model
        self.step = step
        self.enhancer = nn.Linear(in_features=d_model, out_features=d_model * step, bias=False)
    
    def forward(self, inp: Tensor) -> Tensor:
        batch_size, seq_len, d_model = inp.shape
        assert d_model == self.d_model, f'self.d_model = {self.d_model}. inp d_model = {d_model}'
        print_dtype_shape(inp, 'enh_inp')
        out = self.enhancer(inp)
        print_dtype_shape(out, 'enh_out')
        out = out.reshape(batch_size, seq_len * self.step, self.d_model)
        print_dtype_shape(out, 'enh_out_reshape')
        return out

step = 2
enh_layer = EnhanceLayer(cfg_encdec_hg.dec_pyr.d_model, step)
enh_out = out
enh_out = enh_layer(enh_out)
enh_out = enh_layer(enh_out)
enh_out = enh_layer(enh_out)


enh_inp torch.Size([7, 1, 256]) torch.float32
enh_out torch.Size([7, 1, 512]) torch.float32
enh_out_reshape torch.Size([7, 2, 256]) torch.float32
enh_inp torch.Size([7, 2, 256]) torch.float32
enh_out torch.Size([7, 2, 512]) torch.float32
enh_out_reshape torch.Size([7, 4, 256]) torch.float32
enh_inp torch.Size([7, 4, 256]) torch.float32
enh_out torch.Size([7, 4, 512]) torch.float32
enh_out_reshape torch.Size([7, 8, 256]) torch.float32


In [None]:
class DecoderPyramid(nn.Module):
    cfg: DecPyrCfg
    enc_layers: nn.ModuleList
    enh_layers: nn.ModuleList
    inp_chunk_len: int
    # vocab_decoder: Optional[VocabDecoder]

    def __init__(self, cfg: DecPyrCfg):
        super().__init__()
        self.cfg = cfg
        self.enc_layers = nn.ModuleList([
            EncoderLayer(
                n_heads=cfg.n_heads, d_model=cfg.d_model, d_inner=cfg.d_inner, d_k=cfg.d_k, d_v=cfg.d_v,
                dropout_rate=cfg.dropout_rate,
            ) for _ in range(cfg.n_layers)
        ])
        self.enh_layers = nn.ModuleList([
            EnhanceLayer(d_model=cfg.d_model, step=cfg.step) for _ in range(cfg.n_layers)
        ])
        self.vocab_decoder = None
        if self.cfg.with_vocab_decoder:
            self.vocab_decoder = VocabDecoder(d_model=self.cfg.d_model, n_vocab=self.cfg.n_vocab)

    # Tensor with embeddings: [batch_size, 1, d_model]
    def forward(self, inp: Tensor) -> Tensor:
        out = inp
        for enc_layer, enh_layer in zip(self.enc_layers, self.enh_layers):
            out = enh_layer(out)
            out, _ = enc_layer(out)

        if self.vocab_decoder is not None:
            # [batch_size, seq_len, d_model]
            out = self.vocab_decoder(out)
            # [batch_size, seq_len, n_vocab]

        return out




torch.Size([7, 50, 256]) torch.float32
