In [1]:
import torch


class Atten(torch.nn.Module):
    """Multi-head self-attention used as the first sub-block of ``Layer``.

    The four projections are named ``q``, ``k``, ``v`` and ``out``; ``k`` is
    created without a bias term while the other three have one.

    Args:
        embed_dim: model width, i.e. input and output feature size
            (default 768).
        num_heads: number of attention heads; must divide ``embed_dim``
            (default 12, which gives 64-dim heads).

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim=768, num_heads=12):
        super().__init__()

        if embed_dim % num_heads != 0:
            raise ValueError(
                f'embed_dim ({embed_dim}) must be divisible by '
                f'num_heads ({num_heads})')

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Scaled dot-product factor 1/sqrt(head_dim); 0.125 for 64-dim heads.
        self.scale = self.head_dim**-0.5

        self.q = torch.nn.Linear(embed_dim, embed_dim, bias=True)
        self.k = torch.nn.Linear(embed_dim, embed_dim, bias=False)
        self.v = torch.nn.Linear(embed_dim, embed_dim, bias=True)
        self.out = torch.nn.Linear(embed_dim, embed_dim, bias=True)

    def _split_heads(self, t, b, lens):
        # Fold heads into the batch axis so plain bmm can be used:
        # [b, lens, embed_dim] -> [b, lens, heads, head_dim]
        #   -> [b, heads, lens, head_dim] -> [b * heads, lens, head_dim]
        return t.reshape(b, lens, self.num_heads, self.head_dim).transpose(
            1, 2).reshape(b * self.num_heads, lens, self.head_dim)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: hidden states of shape [batch, lens, embed_dim].
            mask: additive attention mask broadcastable to
                [batch, num_heads, lens, lens] (e.g. 0 where attention is
                allowed and -inf where it is blocked, as produced by
                ``get_mask``), or None for no masking.

        Returns:
            Tensor of shape [batch, lens, embed_dim].
        """
        b, lens, _ = x.size()

        q = self._split_heads(self.q(x) * self.scale, b, lens)
        k = self._split_heads(self.k(x), b, lens)
        v = self._split_heads(self.v(x), b, lens)

        # [b*h, lens, d] @ [b*h, d, lens] -> [b*h, lens, lens]
        #   -> [b, h, lens, lens] so a per-batch mask can broadcast over heads
        atten = q.bmm(k.transpose(1, 2)).reshape(b, self.num_heads, lens, lens)
        if mask is not None:
            # Match the scores' dtype/device: otherwise a float32 (or integer)
            # mask promotes half-precision scores and the bmm with ``v``
            # below fails on mismatched dtypes.
            atten = atten + torch.as_tensor(
                mask, dtype=atten.dtype, device=atten.device)

        # [b, h, lens, lens] -> [b*h, lens, lens] @ [b*h, lens, d]
        #   -> [b*h, lens, d]
        atten = atten.reshape(b * self.num_heads, lens,
                              lens).softmax(dim=-1).bmm(v)

        # [b*h, lens, d] -> [b, h, lens, d] -> [b, lens, h, d]
        #   -> [b, lens, embed_dim]
        atten = atten.reshape(b, self.num_heads, lens,
                              self.head_dim).transpose(1, 2).reshape(
                                  b, lens, self.embed_dim)

        return self.out(atten)


# Atten()(torch.randn(2, 50, 768), torch.ones(2, 1, 50, 50).long()).shape

  from .autonotebook import tqdm as notebook_tqdm


torch.Size([2, 50, 768])

In [2]:
class CrossAtten(torch.nn.Module):
    """Multi-head cross-attention from decoder states to encoder states.

    Queries are projected from ``x``; keys and values are projected from
    ``kv``. The projections are named ``q``, ``k``, ``v`` and ``out``; ``k``
    has no bias term while the other three do.

    Args:
        embed_dim: model width, shared by ``x`` and ``kv`` (default 768).
        num_heads: number of attention heads; must divide ``embed_dim``
            (default 12, which gives 64-dim heads).

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim=768, num_heads=12):
        super().__init__()

        if embed_dim % num_heads != 0:
            raise ValueError(
                f'embed_dim ({embed_dim}) must be divisible by '
                f'num_heads ({num_heads})')

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Scaled dot-product factor 1/sqrt(head_dim); 0.125 for 64-dim heads.
        self.scale = self.head_dim**-0.5

        self.q = torch.nn.Linear(embed_dim, embed_dim, bias=True)
        self.k = torch.nn.Linear(embed_dim, embed_dim, bias=False)
        self.v = torch.nn.Linear(embed_dim, embed_dim, bias=True)

        self.out = torch.nn.Linear(embed_dim, embed_dim, bias=True)

    def _split_heads(self, t, b, lens):
        # [b, lens, embed_dim] -> [b, lens, heads, head_dim]
        #   -> [b, heads, lens, head_dim] -> [b * heads, lens, head_dim]
        return t.reshape(b, lens, self.num_heads, self.head_dim).transpose(
            1, 2).reshape(b * self.num_heads, lens, self.head_dim)

    def forward(self, x, kv):
        """Attend from ``x`` to ``kv``.

        Args:
            x: decoder hidden states (queries), [batch, lens, embed_dim].
            kv: encoder hidden states (keys/values),
                [batch, kv_lens, embed_dim]. ``kv_lens`` is read from the
                tensor and need not equal ``lens``.

        Returns:
            Tensor of shape [batch, lens, embed_dim].
        """
        b, lens, _ = x.size()
        # Encoder length comes from the input instead of a fixed constant.
        kv_lens = kv.size(1)

        q = self._split_heads(self.q(x) * self.scale, b, lens)
        k = self._split_heads(self.k(kv), b, kv_lens)
        v = self._split_heads(self.v(kv), b, kv_lens)

        # [b*h, lens, d] @ [b*h, d, kv_lens] -> [b*h, lens, kv_lens]
        #   -> softmax over kv positions -> @ [b*h, kv_lens, d] -> [b*h, lens, d]
        atten = q.bmm(k.transpose(1, 2)).softmax(dim=-1).bmm(v)

        # [b*h, lens, d] -> [b, h, lens, d] -> [b, lens, h, d]
        #   -> [b, lens, embed_dim]
        atten = atten.reshape(b, self.num_heads, lens,
                              self.head_dim).transpose(1, 2).reshape(
                                  b, lens, self.embed_dim)

        return self.out(atten)


# CrossAtten()(torch.randn(2, 50, 768), torch.randn(2, 1500, 768)).shape

torch.Size([2, 50, 768])

In [3]:
class Layer(torch.nn.Module):
    """One pre-norm decoder block: self-attention, cross-attention, then MLP.

    Each of the three sub-blocks applies a LayerNorm to its input and adds
    its output back onto the residual stream ``x``.
    """

    def __init__(self):
        super().__init__()

        self.norm1 = torch.nn.LayerNorm(768)
        self.atten = Atten()

        self.norm2 = torch.nn.LayerNorm(768)
        self.cross_atten = CrossAtten()

        # Feed-forward sub-block with its LayerNorm folded in. Positions
        # matter: load_decoder copies weights into s[0], s[1] and s[3].
        self.s = torch.nn.Sequential(
            torch.nn.LayerNorm(768),
            torch.nn.Linear(768, 3072),
            torch.nn.GELU(),
            torch.nn.Linear(3072, 768),
        )

    def forward(self, x, mask, kv):
        """Run the block.

        Args:
            x: hidden states, presumably [batch, lens, 768] (TODO confirm).
            mask: additive self-attention mask forwarded to ``Atten``.
            kv: encoder hidden states forwarded to ``CrossAtten``.

        Returns:
            Updated hidden states with the same shape as ``x``.
        """
        x = self.atten(self.norm1(x), mask=mask) + x
        x = self.cross_atten(self.norm2(x), kv=kv) + x

        return self.s(x) + x


# Layer()(torch.randn(2, 50, 768), torch.ones(2, 1, 50, 50).long(),
#         torch.randn(2, 1500, 768)).shape

torch.Size([2, 50, 768])

In [4]:
def get_mask(b, lens):
    """Return an additive causal attention mask of shape [b, 1, lens, lens].

    Entry [..., i, j] is 0.0 when j <= i (position i may attend to j) and
    -inf when j > i (later positions are blocked). The size-1 second axis
    lets the mask broadcast across attention heads when added to scores.
    """
    # Keep -inf strictly above the diagonal; triu zeroes the diagonal and below.
    causal = torch.triu(torch.full((lens, lens), -float('inf')), diagonal=1)
    return causal.reshape(1, 1, lens, lens).repeat(b, 1, 1, 1)


# get_mask(2, 5)

tensor([[[[0., -inf, -inf, -inf, -inf],
          [0., 0., -inf, -inf, -inf],
          [0., 0., 0., -inf, -inf],
          [0., 0., 0., 0., -inf],
          [0., 0., 0., 0., 0.]]],


        [[[0., -inf, -inf, -inf, -inf],
          [0., 0., -inf, -inf, -inf],
          [0., 0., 0., -inf, -inf],
          [0., 0., 0., 0., -inf],
          [0., 0., 0., 0., 0.]]]])

In [5]:
class Decoder(torch.nn.Module):
    """Token + learned-position embeddings, 12 ``Layer``s and a final LayerNorm.

    Sizes: 51865-token vocabulary with padding index 50257, 768-dim hidden
    states, and a 448-row positional table. Inputs longer than 448 tokens
    therefore cannot be embedded.
    """

    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(51865, 768, padding_idx=50257)
        self.embed_pos = torch.nn.Embedding(448, 768)

        self.layer = torch.nn.ModuleList([Layer() for _ in range(12)])
        self.norm = torch.nn.LayerNorm(768)

    def forward(self, x, kv):
        """Decode token ids against encoder states.

        Args:
            x: integer token ids of shape [batch, lens], with lens <= 448.
            kv: encoder hidden states passed to every layer's cross-attention.

        Returns:
            Final normalized hidden states of shape [batch, lens, 768].
        """
        b, lens = x.shape
        # Build the causal mask in the model's parameter dtype so it does not
        # promote half-precision attention scores to float32 inside Atten.
        mask = get_mask(b, lens).to(device=x.device,
                                    dtype=self.embed.weight.dtype)

        # Positions 0..lens-1 are added by slicing the positional table.
        x = self.embed(x) + self.embed_pos.weight[:lens]

        for layer in self.layer:
            x = layer(x, mask=mask, kv=kv)

        return self.norm(x)


# Decoder()(torch.ones(2, 50).long(), kv=torch.randn(2, 1500, 768)).shape

torch.Size([2, 50, 768])

In [6]:
def load_decoder(pretrained):
    """Build a fresh ``Decoder`` and copy every weight from ``pretrained``.

    ``pretrained`` must expose ``embed_tokens``, ``embed_positions``,
    ``layer_norm`` and an indexable ``layers`` whose items provide
    ``self_attn``/``encoder_attn`` (each with ``q_proj``/``k_proj``/
    ``v_proj``/``out_proj``), ``self_attn_layer_norm``,
    ``encoder_attn_layer_norm``, ``final_layer_norm``, ``fc1`` and ``fc2``.
    The notebook's example passes
    ``WhisperForConditionalGeneration.from_pretrained('openai/whisper-small')
    .model.decoder``. NOTE(review): other checkpoint sizes would not match the
    fixed Decoder dimensions; confirm before reuse.

    Returns:
        The new ``Decoder`` with weights loaded via ``load_state_dict``.
    """
    decoder = Decoder()

    # Modules that live outside the layer stack.
    for ours, theirs in (
        (decoder.embed, pretrained.embed_tokens),
        (decoder.embed_pos, pretrained.embed_positions),
        (decoder.norm, pretrained.layer_norm),
    ):
        ours.load_state_dict(theirs.state_dict())

    for idx, block in enumerate(decoder.layer):
        src = pretrained.layers[idx]
        self_attn = src.self_attn
        cross_attn = src.encoder_attn

        pairs = (
            (block.norm1, src.self_attn_layer_norm),
            (block.atten.q, self_attn.q_proj),
            (block.atten.k, self_attn.k_proj),
            (block.atten.v, self_attn.v_proj),
            (block.atten.out, self_attn.out_proj),
            (block.norm2, src.encoder_attn_layer_norm),
            (block.cross_atten.q, cross_attn.q_proj),
            (block.cross_atten.k, cross_attn.k_proj),
            (block.cross_atten.v, cross_attn.v_proj),
            (block.cross_atten.out, cross_attn.out_proj),
            # s[0] is the MLP LayerNorm; s[1] and s[3] are its two Linears.
            (block.s[0], src.final_layer_norm),
            (block.s[1], src.fc1),
            (block.s[3], src.fc2),
        )
        for ours, theirs in pairs:
            ours.load_state_dict(theirs.state_dict())

    return decoder


# from transformers import WhisperForConditionalGeneration

# pretrained = WhisperForConditionalGeneration.from_pretrained(
#     'openai/whisper-small').model.decoder
# decoder = load_decoder(pretrained)

# x = torch.ones(2, 50).long()
# kv = torch.randn(2, 1500, 768)

# out1 = decoder(x, kv)
# out2 = pretrained(input_ids=x, attention_mask=None,
#                   encoder_hidden_states=kv).last_hidden_state

# (out1 == out2).all()

tensor(True)